In [6]:
import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.data import HeteroData

# Load the ENZYMES dataset (assumes the nodes are one-hot encoded into 3 types)
dataset = TUDataset(root='data/ENZYMES', name='ENZYMES')

# We'll store the converted heterogeneous graphs here.
hetero_graphs = []

# Define a mapping from one-hot vector to a node type label.
# This assumes the one-hot vector is of length 3.
def one_hot_to_type(x_row):
    # Find the index of the 1.0 entry.
    type_idx = int(torch.argmax(x_row))
    if type_idx == 0:
        return "A"
    elif type_idx == 1:
        return "B"
    elif type_idx == 2:
        return "C"
    else:
        return "Unknown"

# For an undirected graph with 3 node types, we only need to specify each unordered pair once.
# We choose a canonical order, e.g. using the list ['A', 'B', 'C'].
node_type_order = ['A', 'B', 'C']

# Process every graph in the dataset.
for graph in dataset:
    hetero_data = HeteroData()
    
    num_nodes = graph.x.size(0)
    # Determine node type for each node.
    node_types = [one_hot_to_type(graph.x[i]) for i in range(num_nodes)]
    
    # Group nodes by type.
    node_type_to_indices = {}
    for idx, n_type in enumerate(node_types):
        if n_type not in node_type_to_indices:
            node_type_to_indices[n_type] = []
        node_type_to_indices[n_type].append(idx)
    
    # Add node features for each type to HeteroData.
    for n_type, indices in node_type_to_indices.items():
        indices_tensor = torch.tensor(indices, dtype=torch.long)
        hetero_data[n_type].x = graph.x[indices_tensor]
    
    # Prepare a dictionary to collect edges for each relation.
    edge_dict = {}
    # Iterate over each edge in the homogeneous graph.
    # Note: graph.edge_index is assumed to be of shape [2, num_edges]
    for i in range(graph.edge_index.size(1)):
        u = int(graph.edge_index[0, i])
        v = int(graph.edge_index[1, i])
        # Get the node types for the endpoints.
        t_u = node_types[u]
        t_v = node_types[v]
        # For an undirected graph, define the relation using a canonical order.
        # For example, if t_u and t_v are different, we sort them.
        if node_type_order.index(t_u) <= node_type_order.index(t_v):
            relation = (t_u, 'to', t_v)
        else:
            relation = (t_v, 'to', t_u)
        if relation not in edge_dict:
            edge_dict[relation] = []
        edge_dict[relation].append([u, v])
    
    # To convert edge indices for heterogeneous data, we need a mapping from
    # the original node indices to the new indices within each node type group.
    mapping = {}
    for n_type, indices in node_type_to_indices.items():
        mapping[n_type] = {old_idx: new_idx for new_idx, old_idx in enumerate(indices)}
    
    # For each relation, convert the edge indices to the new indexing.
    for relation, edges in edge_dict.items():
        src_type, _, dst_type = relation
        converted_edges = []
        for edge in edges:
            u, v = edge
            # Depending on the canonical ordering chosen above, ensure we use the correct mapping.
            if node_type_order.index(src_type) <= node_type_order.index(dst_type):
                # Expect u to be from src_type and v from dst_type.
                if u in mapping[src_type] and v in mapping[dst_type]:
                    new_u = mapping[src_type][u]
                    new_v = mapping[dst_type][v]
                    converted_edges.append([new_u, new_v])
            else:
                # If the relation was reversed (should not happen due to our sorting) handle similarly.
                if u in mapping[dst_type] and v in mapping[src_type]:
                    new_u = mapping[dst_type][u]
                    new_v = mapping[src_type][v]
                    converted_edges.append([new_u, new_v])
        # Create an edge index tensor; if no edges exist, use an empty tensor.
        if len(converted_edges) == 0:
            edge_index_tensor = torch.empty((2, 0), dtype=torch.long)
        else:
            edge_index_tensor = torch.tensor(converted_edges, dtype=torch.long).t().contiguous()
        # Add the relation to hetero_data.
        hetero_data[relation].edge_index = edge_index_tensor

    # Optionally, you can also store the graph-level label.
    hetero_data.graph_label = graph.y
    
    hetero_graphs.append(hetero_data)

# Example: Print the first heterogeneous graph's structure.
print(hetero_graphs[0])


Downloading https://www.chrsmrrs.com/graphkerneldatasets/ENZYMES.zip
Processing...
Done!


HeteroData(
  graph_label=[1],
  A={ x=[24, 3] },
  B={ x=[13, 3] },
  (A, to, A)={ edge_index=[2, 86] },
  (A, to, B)={ edge_index=[2, 17] },
  (B, to, B)={ edge_index=[2, 48] }
)


In [7]:
hetero_graphs[:5]

[HeteroData(
   graph_label=[1],
   A={ x=[24, 3] },
   B={ x=[13, 3] },
   (A, to, A)={ edge_index=[2, 86] },
   (A, to, B)={ edge_index=[2, 17] },
   (B, to, B)={ edge_index=[2, 48] }
 ),
 HeteroData(
   graph_label=[1],
   A={ x=[15, 3] },
   B={ x=[8, 3] },
   (A, to, A)={ edge_index=[2, 52] },
   (A, to, B)={ edge_index=[2, 19] },
   (B, to, B)={ edge_index=[2, 12] }
 ),
 HeteroData(
   graph_label=[1],
   A={ x=[19, 3] },
   B={ x=[6, 3] },
   (A, to, A)={ edge_index=[2, 54] },
   (A, to, B)={ edge_index=[2, 16] },
   (B, to, B)={ edge_index=[2, 6] }
 ),
 HeteroData(
   graph_label=[1],
   A={ x=[18, 3] },
   B={ x=[6, 3] },
   (A, to, A)={ edge_index=[2, 50] },
   (A, to, B)={ edge_index=[2, 17] },
   (B, to, B)={ edge_index=[2, 6] }
 ),
 HeteroData(
   graph_label=[1],
   A={ x=[18, 3] },
   B={ x=[5, 3] },
   (A, to, A)={ edge_index=[2, 58] },
   (A, to, B)={ edge_index=[2, 14] },
   (B, to, B)={ edge_index=[2, 4] }
 )]

In [3]:
dataset.x.shape, dataset.y.shape

(torch.Size([19580, 3]), torch.Size([600]))

In [4]:
len(heterogeneous_graphs)

600

In [5]:
dataset[:5], heterogeneous_graphs[:5]

(ENZYMES(5),
 [HeteroData(
    type_B={ x=[37, 3] },
    (type_B, same_type, type_B)={ edge_index=[2, 168] }
  ),
  HeteroData(
    type_B={ x=[23, 3] },
    (type_B, same_type, type_B)={ edge_index=[2, 102] }
  ),
  HeteroData(
    type_B={ x=[25, 3] },
    (type_B, same_type, type_B)={ edge_index=[2, 92] }
  ),
  HeteroData(
    type_B={ x=[24, 3] },
    (type_B, same_type, type_B)={ edge_index=[2, 90] }
  ),
  HeteroData(
    type_B={ x=[23, 3] },
    (type_B, same_type, type_B)={ edge_index=[2, 90] }
  )])