In [None]:
import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.data import HeteroData

# Load the ENZYMES dataset
dataset = TUDataset(root="data", name="ENZYMES")

# Define node type mapping based on features
def categorize_node(node_feature):
    """Placeholder rule that labels a node from its feature vector.

    Returns "type_A" when the mean of ``node_feature`` is strictly greater
    than 0.5, and "type_B" otherwise (including a mean of exactly 0.5).

    NOTE(review): in the outputs recorded in this notebook the displayed
    graphs contain only "type_B" nodes, so this threshold may never select
    "type_A" for this dataset -- confirm before relying on two node types.
    """
    above_threshold = node_feature.mean() > 0.5
    return "type_A" if above_threshold else "type_B"

# Define edge type mapping
def categorize_edge(node_type_1, node_type_2):
    """Placeholder rule that names the relation between two node types.

    Returns "same_type" when both endpoint types compare equal and
    "diff_type" otherwise.
    """
    relation = "same_type" if node_type_1 == node_type_2 else "diff_type"
    return relation

# Convert each graph into a heterogeneous graph
heterogeneous_graphs = []

for graph in dataset:
    hetero_data = HeteroData()
    num_nodes = graph.x.size(0)

    # Assign a type to every node from its feature vector.
    node_types = [categorize_node(graph.x[i]) for i in range(num_nodes)]

    # Group the original (global) node indices by type. A dict filled in
    # first-appearance order keeps the node-type order deterministic across
    # runs, unlike iterating over a set of strings.
    node_type_mapping = {}
    for i, t in enumerate(node_types):
        node_type_mapping.setdefault(t, []).append(i)

    # Store the features of each type and remember, for every global node id,
    # which row it occupies inside its own type's feature matrix. Edges must
    # refer to these per-type (local) rows, not to the original global ids.
    local_index = [0] * num_nodes
    for node_type, indices in node_type_mapping.items():
        hetero_data[node_type].x = graph.x[torch.tensor(indices, dtype=torch.long)]
        for local_id, global_id in enumerate(indices):
            local_index[global_id] = local_id

    # Bucket edges by their full (src_type, relation, dst_type) triplet.
    # Keying on the triplet (rather than the relation name alone) keeps
    # A->B and B->A "diff_type" edges apart and avoids reusing whatever
    # src/dst types the last processed edge happened to have.
    edge_index = graph.edge_index
    edge_type_mapping = {}
    for src, dst in edge_index.t().tolist():
        src_type, dst_type = node_types[src], node_types[dst]
        relation = categorize_edge(src_type, dst_type)
        edge_type_mapping.setdefault((src_type, relation, dst_type), []).append(
            (local_index[src], local_index[dst])
        )

    # Store one [2, num_edges] edge_index per triplet, expressed in local ids.
    # A graph without edges simply stores no edge types.
    for (src_type, relation, dst_type), edges in edge_type_mapping.items():
        hetero_data[src_type, relation, dst_type].edge_index = torch.tensor(
            edges, dtype=edge_index.dtype, device=edge_index.device
        ).t().contiguous()

    heterogeneous_graphs.append(hetero_data)

# Print an example graph
print(heterogeneous_graphs[0])


HeteroData(
  type_B={ x=[37, 3] },
  (type_B, same_type, type_B)={ edge_index=[2, 168] }
)


In [3]:
# Shapes of the dataset-level node-feature and label tensors
# (presumably all graphs concatenated -- confirm).
dataset.x.shape, dataset.y.shape

(torch.Size([19580, 3]), torch.Size([600]))

In [4]:
# One HeteroData object is appended per graph in `dataset`, so this is the
# number of converted graphs.
len(heterogeneous_graphs)

600

In [5]:
# Show the first five source graphs next to the first five converted graphs.
dataset[:5], heterogeneous_graphs[:5]

(ENZYMES(5),
 [HeteroData(
    type_B={ x=[37, 3] },
    (type_B, same_type, type_B)={ edge_index=[2, 168] }
  ),
  HeteroData(
    type_B={ x=[23, 3] },
    (type_B, same_type, type_B)={ edge_index=[2, 102] }
  ),
  HeteroData(
    type_B={ x=[25, 3] },
    (type_B, same_type, type_B)={ edge_index=[2, 92] }
  ),
  HeteroData(
    type_B={ x=[24, 3] },
    (type_B, same_type, type_B)={ edge_index=[2, 90] }
  ),
  HeteroData(
    type_B={ x=[23, 3] },
    (type_B, same_type, type_B)={ edge_index=[2, 90] }
  )])