In [142]:
import matplotlib.pyplot as plt
import networkx as nx

import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.data import HeteroData
from torch_geometric.utils import to_networkx


In [143]:
# ENZYMES graph-classification dataset, stored under /tmp/ENZYMES
dataset = TUDataset(name='ENZYMES', root='/tmp/ENZYMES')
# Graph 61 contains nodes of all three one-hot types (A, B and C)
hom_data = dataset[61]

In [144]:
# Number of nodes = number of feature rows
hom_data.x.size(0)

39

In [145]:
# Bucket each node index by the position of the hot bit in its feature row:
# 0 -> A, 1 -> B, 2 -> C; any other position is reported and the node skipped.
type_buckets = {0: [], 1: [], 2: []}
for idx, x_row in enumerate(hom_data.x):
    type_idx = int(torch.argmax(x_row))
    if type_idx in type_buckets:
        type_buckets[type_idx].append(idx)
    else:
        print("Unknown")
a_list, b_list, c_list = type_buckets[0], type_buckets[1], type_buckets[2]


In [146]:
# Type-A node indices and their count
(a_list, len(a_list))

([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 10)

In [147]:
# Homogeneous node index -> local index within type A
a_map = dict(zip(a_list, range(len(a_list))))
a_map

{0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9}

In [148]:
# Type-B node indices and their count
(b_list, len(b_list))

([10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27], 18)

In [149]:
# Homogeneous node index -> local index within type B
b_map = dict(zip(b_list, range(len(b_list))))
b_map

{10: 0,
 11: 1,
 12: 2,
 13: 3,
 14: 4,
 15: 5,
 16: 6,
 17: 7,
 18: 8,
 19: 9,
 20: 10,
 21: 11,
 22: 12,
 23: 13,
 24: 14,
 25: 15,
 26: 16,
 27: 17}

In [150]:
# Type-C node indices and their count
(c_list, len(c_list))

([28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38], 11)

In [151]:
# Homogeneous node index -> local index within type C
c_map = dict(zip(c_list, range(len(c_list))))
c_map

{28: 0, 29: 1, 30: 2, 31: 3, 32: 4, 33: 5, 34: 6, 35: 7, 36: 8, 37: 9, 38: 10}

In [152]:
# Every node must land in exactly one of the three type buckets; assert it
# instead of comparing the displayed number with len(hom_data.x) by eye.
n_typed_nodes = len(a_list) + len(b_list) + len(c_list)
assert n_typed_nodes == len(hom_data.x), "some nodes were not assigned to A, B or C"
n_typed_nodes

39

In [153]:
# Empty heterogeneous container; node stores and relations are attached in the following cells.
het_data = HeteroData()

In [154]:
# One constant feature per node; the node type is carried by the store name.
for node_type, members in (("A", a_list), ("B", b_list), ("C", c_list)):
    het_data[node_type].x = torch.ones(len(members), 1)

In [155]:
# Homogeneous edge list: the edge filters in this notebook read row 0 as the
# source node id and row 1 as the destination node id.
hom_data.edge_index

tensor([[ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  3,  3,  3,  3,  4,  4,  4,
          5,  5,  5,  6,  6,  6,  6,  7,  7,  7,  8,  8,  8,  9,  9,  9, 10, 10,
         10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15,
         15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 18, 18, 18, 19, 19, 19, 20,
         20, 20, 21, 21, 21, 22, 22, 22, 23, 23, 23, 24, 24, 24, 24, 25, 25, 25,
         25, 26, 26, 26, 26, 27, 27, 27, 27, 28, 28, 28, 29, 29, 29, 29, 30, 30,
         30, 31, 31, 31, 32, 32, 32, 33, 33, 33, 33, 33, 34, 34, 34, 34, 34, 35,
         35, 35, 35, 36, 36, 36, 37, 37, 37, 38, 38, 38, 38, 38],
        [12, 13, 28, 29, 14, 15, 33, 34,  3, 15, 35,  2, 15, 19, 35, 19, 21, 35,
         20, 22, 36, 20, 24, 25, 37, 23, 26, 38,  9, 27, 38,  8, 27, 38, 11, 16,
         33, 34, 10, 17, 18, 32,  0, 18, 29, 30, 31,  0, 28, 29,  1, 33, 34,  1,
          2,  3, 10, 33, 34, 11, 18, 32, 11, 12, 17, 30, 31, 32,  3,  4, 35,  5,
          6, 25,  4, 22, 36,  5, 21, 36,  7

In [156]:
# Edge mask: True where the edge's source node (row 0) is of type A
a_tensor = torch.tensor(a_list)
a_mask = torch.isin(elements=hom_data.edge_index[0], test_elements=a_tensor)
a_mask

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, 

In [157]:
# Edge mask: True where the edge's destination node (row 1) is of type B
b_tensor = torch.tensor(b_list)
b_mask = torch.isin(elements=hom_data.edge_index[1], test_elements=b_tensor)
b_mask

tensor([ True,  True, False, False,  True,  True, False, False, False,  True,
        False, False,  True,  True, False,  True,  True, False,  True,  True,
        False,  True,  True,  True, False,  True,  True, False, False,  True,
        False, False,  True, False,  True,  True, False, False,  True,  True,
         True, False, False,  True, False, False, False, False, False, False,
        False, False, False, False, False, False,  True, False, False,  True,
         True, False,  True,  True,  True, False, False, False, False, False,
        False, False, False,  True, False,  True, False, False,  True, False,
        False,  True,  True, False,  True,  True, False, False,  True,  True,
        False, False,  True,  True, False, False, False,  True, False, False,
         True, False, False,  True,  True, False,  True,  True, False,  True,
         True, False,  True,  True,  True, False,  True,  True,  True, False,
        False,  True,  True,  True, False, False, False, False, 

In [158]:
# Edges going from an A node to a B node
a_b_mask = torch.logical_and(a_mask, b_mask)
a_b_mask

tensor([ True,  True, False, False,  True,  True, False, False, False,  True,
        False, False,  True,  True, False,  True,  True, False,  True,  True,
        False,  True,  True,  True, False,  True,  True, False, False,  True,
        False, False,  True, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, 

In [192]:
# Keep only the edges (columns) selected by a_b_mask: A source -> B destination.
# Node ids are still the global homogeneous indices at this point.
a_b_edge_index = hom_data.edge_index[:, a_b_mask]
a_b_edge_index

tensor([[ 0,  0,  1,  1,  2,  3,  3,  4,  4,  5,  5,  6,  6,  6,  7,  7,  8,  9],
        [12, 13, 14, 15, 15, 15, 19, 19, 21, 20, 22, 20, 24, 25, 23, 26, 27, 27]])

In [194]:
# (2, number of A->B edges)
a_b_edge_index.size()

torch.Size([2, 18])

In [160]:
# Remap global node ids to per-type local ids.
# The A->B edges are re-derived from a_b_mask rather than read from
# a_b_edge_index. Reading a_b_edge_index meant the cell overwrote its own input,
# so a second run looked up already-local ids in b_map and raised KeyError.
a_b_edges_global = hom_data.edge_index[:, a_b_mask]
a_edges = torch.tensor([a_map[orig_id.item()] for orig_id in a_b_edges_global[0]], dtype=torch.long)
b_edges = torch.tensor([b_map[orig_id.item()] for orig_id in a_b_edges_global[1]], dtype=torch.long)
a_b_edge_index = torch.stack((a_edges, b_edges))
a_b_edge_index

tensor([[ 0,  0,  1,  1,  2,  3,  3,  4,  4,  5,  5,  6,  6,  6,  7,  7,  8,  9],
        [ 2,  3,  4,  5,  5,  5,  9,  9, 11, 10, 12, 10, 14, 15, 13, 16, 17, 17]])

In [161]:
# Attach the locally re-indexed A->B edges under the relation key ("A", "connects", "B")
het_data[("A", "connects", "B")].edge_index = a_b_edge_index

In [None]:
# Number of stored index entries in the A->B relation (2 rows x number of edges).
# Previously this referenced `filtered_edges`, which is only a local variable of
# filter_and_remap_edges and raised NameError at the top level.
a_b_edge_index.numel()

In [162]:
# Inspect the hand-built graph so far: node stores A, B, C and the single A->B relation.
het_data

HeteroData(
  A={ x=[10, 1] },
  B={ x=[18, 1] },
  C={ x=[11, 1] },
  (A, connects, B)={ edge_index=[2, 18] }
)

In [163]:
def get_lists(hom_data):
    """
    Partition the node indices of a one-hot typed homogeneous graph into three lists.

    The type of node i is argmax(hom_data.x[i]). Types 0, 1 and 2 go to the
    A, B and C lists respectively. Any other type prints "Unknown", and the node
    is left out of all three lists.

    Returns:
        (a_list, b_list, c_list): node indices, each list in ascending order.
    """
    buckets = ([], [], [])
    for node_idx, row in enumerate(hom_data.x):
        kind = int(torch.argmax(row))
        if 0 <= kind < len(buckets):
            buckets[kind].append(node_idx)
        else:
            print("Unknown")

    a_list, b_list, c_list = buckets
    return a_list, b_list, c_list

def get_maps(type_list):
    """
    Map homogeneous node indices to heterogeneous (per-type) node indices.

    The i-th entry of ``type_list`` becomes local index i, e.g.
    [10, 11, 12] -> {10: 0, 11: 1, 12: 2}.
    """
    return dict((orig_idx, local_idx) for local_idx, orig_idx in enumerate(type_list))

def get_edge_index(hom_data, type1_list, type2_list, type1_map, type2_map):
    """
    Extract the edges that run from type-1 nodes to type-2 nodes, re-indexed per type.

    Args:
        hom_data: homogeneous graph whose ``edge_index`` has source node ids in
            row 0 and destination node ids in row 1.
        type1_list / type2_list: homogeneous indices of the source / destination nodes.
        type1_map / type2_map: homogeneous index -> local index for those types.

    Returns:
        Long tensor of shape [2, E] with local source ids in row 0 and local
        destination ids in row 1. When no edge matches, the shape is [2, 0] and the
        dtype is still long. Previously ``torch.tensor([])`` produced a float32 tensor
        in that case.
    """
    src_ids = torch.tensor(type1_list, dtype=torch.long)
    dst_ids = torch.tensor(type2_list, dtype=torch.long)

    # Keep an edge only if its source is type 1 and its destination is type 2
    keep = torch.isin(hom_data.edge_index[0], src_ids) & torch.isin(hom_data.edge_index[1], dst_ids)
    selected = hom_data.edge_index[:, keep]

    # Explicit dtype: an empty Python list would otherwise become a float32 tensor
    type1_edges = torch.tensor([type1_map[orig_id.item()] for orig_id in selected[0]], dtype=torch.long)
    type2_edges = torch.tensor([type2_map[orig_id.item()] for orig_id in selected[1]], dtype=torch.long)

    return torch.stack((type1_edges, type2_edges))

def convert_hom_to_het(hom_data, include_reverse=False):
    """
    Convert a one-hot typed homogeneous graph into a HeteroData with node types A, B, C.

    Node types come from ``get_lists`` (argmax of each feature row). Every node
    gets a single constant feature of 1.

    By default only the six "upper-triangular" relations are created: A-A, A-B,
    A-C, B-B, B-C and C-C. Each holds the edges whose source has the first type
    and whose destination has the second. For an undirected input this drops the
    mirrored cross-type edges (B->A, C->A, C->B), so the result is NOT undirected.
    Pass ``include_reverse=True`` to also create those three relations, which keeps
    every edge between typed nodes.

    Args:
        hom_data: homogeneous graph with one-hot ``x``, ``edge_index`` and ``y``.
        include_reverse: also add the B->A, C->A and C->B relations.

    Returns:
        HeteroData with node stores A/B/C, the requested relations, and ``y``
        copied from ``hom_data``.
    """
    het_data = HeteroData()

    # 1. Split nodes by type and build homogeneous -> local index maps
    a_list, b_list, c_list = get_lists(hom_data)
    type_lists = {"A": a_list, "B": b_list, "C": c_list}
    type_maps = {name: get_maps(nodes) for name, nodes in type_lists.items()}

    # 2. Node stores with a constant feature
    for name, nodes in type_lists.items():
        het_data[name].x = torch.ones(len(nodes), 1)

    # 3. Relations between type pairs
    relations = [("A", "A"), ("A", "B"), ("A", "C"), ("B", "B"), ("B", "C"), ("C", "C")]
    if include_reverse:
        relations += [("B", "A"), ("C", "A"), ("C", "B")]

    for src, dst in relations:
        het_data[src, "connects", dst].edge_index = get_edge_index(
            hom_data, type_lists[src], type_lists[dst], type_maps[src], type_maps[dst]
        )

    het_data.y = hom_data.y

    return het_data

# Entry point: convert a whole dataset
def convert_dataset(dataset):
    """
    Convert every graph of ``dataset`` with ``convert_hom_to_het``, preserving order.

    Note: ``convert_hom_to_het`` is looked up when this function is called, so
    it uses whichever definition was executed most recently in the kernel.
    """
    return [convert_hom_to_het(graph) for graph in dataset]

In [164]:
het_data = convert_hom_to_het(hom_data)
# Only the six "upper-triangular" relations are stored (A-A, A-B, A-C, B-B, B-C, C-C):
# 4 + 18 + 12 + 24 + 23 + 6 = 87 of the 140 directed edges in hom_data.
# Assuming hom_data is undirected, the remaining 53 are the mirrored cross-type
# edges (B->A, C->A, C->B: 18 + 12 + 23), which are not stored.
# TODO: decide whether to keep them as reverse relations so the heterogeneous
# graph carries every edge of the original.
het_data

HeteroData(
  y=[1],
  A={ x=[10, 1] },
  B={ x=[18, 1] },
  C={ x=[11, 1] },
  (A, connects, A)={ edge_index=[2, 4] },
  (A, connects, B)={ edge_index=[2, 18] },
  (A, connects, C)={ edge_index=[2, 12] },
  (B, connects, B)={ edge_index=[2, 24] },
  (B, connects, C)={ edge_index=[2, 23] },
  (C, connects, C)={ edge_index=[2, 6] }
)

In [165]:
# Homogeneous edge list for comparison with the converted graph; the edge filters
# in this notebook read row 0 as source node ids and row 1 as destination node ids.
hom_data.edge_index

tensor([[ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  3,  3,  3,  3,  4,  4,  4,
          5,  5,  5,  6,  6,  6,  6,  7,  7,  7,  8,  8,  8,  9,  9,  9, 10, 10,
         10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 13, 13, 13, 14, 14, 14, 15,
         15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 18, 18, 18, 19, 19, 19, 20,
         20, 20, 21, 21, 21, 22, 22, 22, 23, 23, 23, 24, 24, 24, 24, 25, 25, 25,
         25, 26, 26, 26, 26, 27, 27, 27, 27, 28, 28, 28, 29, 29, 29, 29, 30, 30,
         30, 31, 31, 31, 32, 32, 32, 33, 33, 33, 33, 33, 34, 34, 34, 34, 34, 35,
         35, 35, 35, 36, 36, 36, 37, 37, 37, 38, 38, 38, 38, 38],
        [12, 13, 28, 29, 14, 15, 33, 34,  3, 15, 35,  2, 15, 19, 35, 19, 21, 35,
         20, 22, 36, 20, 24, 25, 37, 23, 26, 38,  9, 27, 38,  8, 27, 38, 11, 16,
         33, 34, 10, 17, 18, 32,  0, 18, 29, 30, 31,  0, 28, 29,  1, 33, 34,  1,
          2,  3, 10, 33, 34, 11, 18, 32, 11, 12, 17, 30, 31, 32,  3,  4, 35,  5,
          6, 25,  4, 22, 36,  5, 21, 36,  7

In [166]:
# Convert a handful of sample graphs, chosen by index into the dataset
het_graphs = [convert_hom_to_het(dataset[i]) for i in (0, 10, 61, 115, 200)]

In [167]:
# Show the schema and sizes of each converted sample
for het_graph in het_graphs:
    print(het_graph)

HeteroData(
  y=[1],
  A={ x=[24, 1] },
  B={ x=[13, 1] },
  C={ x=[0, 1] },
  (A, connects, A)={ edge_index=[2, 86] },
  (A, connects, B)={ edge_index=[2, 17] },
  (A, connects, C)={ edge_index=[2, 0] },
  (B, connects, B)={ edge_index=[2, 48] },
  (B, connects, C)={ edge_index=[2, 0] },
  (C, connects, C)={ edge_index=[2, 0] }
)
HeteroData(
  y=[1],
  A={ x=[0, 1] },
  B={ x=[0, 1] },
  C={ x=[4, 1] },
  (A, connects, A)={ edge_index=[2, 0] },
  (A, connects, B)={ edge_index=[2, 0] },
  (A, connects, C)={ edge_index=[2, 0] },
  (B, connects, B)={ edge_index=[2, 0] },
  (B, connects, C)={ edge_index=[2, 0] },
  (C, connects, C)={ edge_index=[2, 12] }
)
HeteroData(
  y=[1],
  A={ x=[10, 1] },
  B={ x=[18, 1] },
  C={ x=[11, 1] },
  (A, connects, A)={ edge_index=[2, 4] },
  (A, connects, B)={ edge_index=[2, 18] },
  (A, connects, C)={ edge_index=[2, 12] },
  (B, connects, B)={ edge_index=[2, 24] },
  (B, connects, C)={ edge_index=[2, 23] },
  (C, connects, C)={ edge_index=[2, 6] }
)
Het

In [190]:
import torch
from torch_geometric.data import HeteroData



def get_node_type_mapping(hom_data):
    """
    Group node indices by the argmax position of their feature rows.

    Returns a dict {type_id: [node indices]}. Keys appear in order of first
    occurrence, and each list is in ascending node order.
    """
    type_dict = {}
    for idx, x_row in enumerate(hom_data.x):
        type_dict.setdefault(int(torch.argmax(x_row)), []).append(idx)
    return type_dict



def create_mapping_dict(type_dict):
    """
    For each node type, map homogeneous node index -> local index within that type.

    e.g. {2: [0, 3], 0: [1]} -> {2: {0: 0, 3: 1}, 0: {1: 0}}
    """
    mapping = {}
    for node_type, members in type_dict.items():
        mapping[node_type] = {orig_idx: local_idx for local_idx, orig_idx in enumerate(members)}
    return mapping



def filter_and_remap_edges(hom_data, src_type_nodes, dst_type_nodes,
                          src_map, dst_map):
    """
    Select the edges from one node type to another and re-index them per type.

    Args:
        hom_data: homogeneous graph; ``edge_index`` row 0 holds source node ids
            and row 1 holds destination node ids.
        src_type_nodes / dst_type_nodes: homogeneous indices of the source /
            destination type's nodes.
        src_map / dst_map: homogeneous index -> local index for those types.

    Returns:
        Long tensor [2, E_sel] of local (source, destination) ids. When nothing
        matches it is [2, 0] and still long, matching the original edge_index
        dtype. Previously the empty case produced float32 via ``torch.tensor([])``.
    """
    src_tensor = torch.tensor(src_type_nodes, dtype=torch.long)
    dst_tensor = torch.tensor(dst_type_nodes, dtype=torch.long)

    # Keep an edge only if its source is a src-type node and its destination a dst-type node
    mask = (torch.isin(hom_data.edge_index[0], src_tensor)
            & torch.isin(hom_data.edge_index[1], dst_tensor))
    filtered_edges = hom_data.edge_index[:, mask]

    # Explicit dtype: an empty list would otherwise default to float32
    src_indices = torch.tensor([src_map[orig.item()] for orig in filtered_edges[0]], dtype=torch.long)
    dst_indices = torch.tensor([dst_map[orig.item()] for orig in filtered_edges[1]], dtype=torch.long)

    return torch.stack([src_indices, dst_indices])



def convert_hom_to_het(hom_data):
    """
    Convert a one-hot typed homogeneous graph into HeteroData.

    Node types are discovered from the data (argmax of each feature row). Node
    store ``str(t)`` is created for each type present, with a constant feature of 1.
    A relation ``(str(src), "connects", str(dst))`` is added for every ordered
    type pair, but only if it has at least one edge. ``y`` is copied over when
    the attribute exists.
    """
    het_data = HeteroData()

    type_dict = get_node_type_mapping(hom_data)
    mapping_dict = create_mapping_dict(type_dict)

    for node_type, members in type_dict.items():
        het_data[str(node_type)].x = torch.ones(len(members), 1)  # or use original features

    # Every ordered (src, dst) pair, in discovery order
    type_pairs = [(src, dst) for src in type_dict for dst in type_dict]
    for src_type, dst_type in type_pairs:
        edge_index = filter_and_remap_edges(
            hom_data,
            type_dict[src_type],
            type_dict[dst_type],
            mapping_dict[src_type],
            mapping_dict[dst_type],
        )
        if edge_index.shape[1] == 0:
            continue  # skip empty relations
        het_data[str(src_type), "connects", str(dst_type)].edge_index = edge_index

    if hasattr(hom_data, 'y'):
        het_data.y = hom_data.y

    return het_data

In [191]:
# Convert and show the same sample graphs with the current convert_hom_to_het
for graph_idx in (0, 10, 61, 115, 200):
    print(convert_hom_to_het(dataset[graph_idx]))

HeteroData(
  y=[1],
  0={ x=[24, 1] },
  1={ x=[13, 1] },
  (0, connects, 0)={ edge_index=[2, 86] },
  (0, connects, 1)={ edge_index=[2, 17] },
  (1, connects, 0)={ edge_index=[2, 17] },
  (1, connects, 1)={ edge_index=[2, 48] }
)
HeteroData(
  y=[1],
  2={ x=[4, 1] },
  (2, connects, 2)={ edge_index=[2, 12] }
)
HeteroData(
  y=[1],
  0={ x=[10, 1] },
  1={ x=[18, 1] },
  2={ x=[11, 1] },
  (0, connects, 0)={ edge_index=[2, 4] },
  (0, connects, 1)={ edge_index=[2, 18] },
  (0, connects, 2)={ edge_index=[2, 12] },
  (1, connects, 0)={ edge_index=[2, 18] },
  (1, connects, 1)={ edge_index=[2, 24] },
  (1, connects, 2)={ edge_index=[2, 23] },
  (2, connects, 0)={ edge_index=[2, 12] },
  (2, connects, 1)={ edge_index=[2, 23] },
  (2, connects, 2)={ edge_index=[2, 6] }
)
HeteroData(
  y=[1],
  0={ x=[11, 1] },
  1={ x=[11, 1] },
  2={ x=[20, 1] },
  (0, connects, 0)={ edge_index=[2, 4] },
  (0, connects, 1)={ edge_index=[2, 12] },
  (0, connects, 2)={ edge_index=[2, 21] },
  (1, connects,

In [195]:
import torch
from torch_geometric.data import HeteroData

def get_node_type_mapping(hom_data):
    """
    Dynamically identify node types from one-hot encoded features.

    The type of node i is ``argmax(hom_data.x[i])``. Because torch.argmax returns
    the first maximal position, a row with no hot bit (all zeros) is assigned type 0.

    e.g. [[0,0,1],      {2: [0, 3],
          [1,0,0],  ->   0: [1],
          [0,1,0],       1: [2]}
          [0,0,1]]

    Keys appear in order of first occurrence, and each list holds node indices in
    ascending order. (The previous example here had the lists for types 0 and 1
    swapped.)

    Returns:
        dict mapping int type id -> list of int node indices.
    """
    type_dict = {}
    for idx, x_row in enumerate(hom_data.x):
        type_dict.setdefault(int(torch.argmax(x_row)), []).append(idx)
    return type_dict

def create_mapping_dict(type_dict):
    """
    Create ID mapping dictionaries for all node types.

    Nodes of each type are renumbered 0..k-1 in the order they appear in
    ``type_dict[t]``. The heterogeneous graph therefore uses per-type indices
    instead of the original homogeneous ones.

    e.g. {2: [0, 3], 0: [1]} -> {2: {0: 0, 3: 1}, 0: {1: 0}}
    """
    mapping_dict = {}
    for node_type, members in type_dict.items():
        mapping_dict[node_type] = dict((orig_idx, local_idx) for local_idx, orig_idx in enumerate(members))
    return mapping_dict

def filter_and_remap_edges(hom_data, src_type_nodes, dst_type_nodes,
                          src_map, dst_map, enforce_canonical=False):
    """
    Select the edges from one node type to another and re-index them per type.

    Args:
        hom_data: homogeneous graph; ``edge_index`` row 0 holds source node ids
            and row 1 holds destination node ids.
        src_type_nodes / dst_type_nodes: homogeneous indices of the source /
            destination type's nodes.
        src_map / dst_map: homogeneous index -> local index for those types.
        enforce_canonical: if True AND source and destination are the same node
            set, each edge is rewritten as (min, max) of its local endpoints.
            Duplicates this produces (u->v and v->u collapsing to one pair) are
            kept. It is ignored for cross-type relations: there row 0 and row 1
            index different node sets, so swapping endpoints would yield invalid
            indices.

    Returns:
        Long tensor [2, E_sel] of local (source, destination) ids; [2, 0] when
        no edge matches.
    """
    src_tensor = torch.tensor(src_type_nodes, dtype=torch.long)
    dst_tensor = torch.tensor(dst_type_nodes, dtype=torch.long)

    # Heterogeneous mask over the homogeneous edge_index
    mask = torch.isin(hom_data.edge_index[0], src_tensor) & torch.isin(hom_data.edge_index[1], dst_tensor)
    filtered_edges = hom_data.edge_index[:, mask]

    # No edges between these types: return an empty, correctly typed edge_index
    if filtered_edges.numel() == 0:
        return torch.empty((2, 0), dtype=torch.long)

    # Homogeneous ids -> per-type local ids
    src_indices = torch.tensor([src_map[orig.item()] for orig in filtered_edges[0]], dtype=torch.long)
    dst_indices = torch.tensor([dst_map[orig.item()] for orig in filtered_edges[1]], dtype=torch.long)

    # Canonical (min, max) ordering only makes sense when both rows share one index space
    same_index_space = list(src_type_nodes) == list(dst_type_nodes)
    if enforce_canonical and same_index_space:
        src_indices, dst_indices = torch.minimum(src_indices, dst_indices), torch.maximum(src_indices, dst_indices)

    return torch.stack([src_indices, dst_indices])

def convert_hom_to_het(hom_data, expected_types=(0, 1, 2), enforce_canonical=False):
    """
    Convert a one-hot encoded homogeneous graph to a fixed-schema heterogeneous graph.

    Node type ``t`` (the argmax of a node's feature row) becomes node store
    ``str(t)``, holding the original feature rows of those nodes. Every type in
    ``expected_types`` gets a node store, even when empty, and every ordered pair
    of expected types gets a ``(src, "connects", dst)`` relation, even when empty.
    As a result, graphs converted with the same ``expected_types`` have the same
    node and edge types.

    Args:
        hom_data: homogeneous graph with ``x``, ``edge_index`` and optionally ``y``.
        expected_types: node-type ids to materialise. Nodes of any other type are
            dropped, and a warning is issued. (A tuple default avoids a shared
            mutable list.)
        enforce_canonical: forwarded to ``filter_and_remap_edges``.

    Returns:
        HeteroData, with ``y`` copied from ``hom_data`` when it is not None.
    """
    import warnings

    het_data = HeteroData()

    # Dynamically identify one-hot encoded node types and per-type index maps
    type_dict = get_node_type_mapping(hom_data)
    mapping_dict = create_mapping_dict(type_dict)

    # Previously, nodes outside expected_types vanished silently
    unexpected = sorted(set(type_dict) - set(expected_types))
    if unexpected:
        warnings.warn(f"node types {unexpected} are not in expected_types and will be dropped")

    # Node stores: original features, or an empty placeholder of matching width/dtype/device
    for t in expected_types:
        if t in type_dict:
            het_data[str(t)].x = hom_data.x[type_dict[t]]
        else:
            het_data[str(t)].x = hom_data.x.new_empty((0, hom_data.x.shape[1]))

    # Relations for all ordered pairs of expected types (empty ones included)
    for src_type in expected_types:
        for dst_type in expected_types:
            edge_type = (str(src_type), "connects", str(dst_type))
            het_data[edge_type].edge_index = filter_and_remap_edges(
                hom_data,
                type_dict.get(src_type, []),
                type_dict.get(dst_type, []),
                mapping_dict.get(src_type, {}),
                mapping_dict.get(dst_type, {}),
                enforce_canonical,
            )

    # Copy the graph label only when one is actually set
    y = getattr(hom_data, 'y', None)
    if y is not None:
        het_data.y = y

    return het_data


In [189]:
# Convert and show the sample graphs with the fixed-schema converter
for sample_idx in (0, 10, 61, 115, 200):
    converted = convert_hom_to_het(dataset[sample_idx])
    print(converted)

HeteroData(
  y=[1],
  0={ x=[24, 3] },
  1={ x=[13, 3] },
  2={ x=[0, 3] },
  (0, connects, 0)={ edge_index=[2, 86] },
  (0, connects, 1)={ edge_index=[2, 17] },
  (0, connects, 2)={ edge_index=[2, 0] },
  (1, connects, 0)={ edge_index=[2, 17] },
  (1, connects, 1)={ edge_index=[2, 48] },
  (1, connects, 2)={ edge_index=[2, 0] },
  (2, connects, 0)={ edge_index=[2, 0] },
  (2, connects, 1)={ edge_index=[2, 0] },
  (2, connects, 2)={ edge_index=[2, 0] }
)
HeteroData(
  y=[1],
  0={ x=[0, 3] },
  1={ x=[0, 3] },
  2={ x=[4, 3] },
  (0, connects, 0)={ edge_index=[2, 0] },
  (0, connects, 1)={ edge_index=[2, 0] },
  (0, connects, 2)={ edge_index=[2, 0] },
  (1, connects, 0)={ edge_index=[2, 0] },
  (1, connects, 1)={ edge_index=[2, 0] },
  (1, connects, 2)={ edge_index=[2, 0] },
  (2, connects, 0)={ edge_index=[2, 0] },
  (2, connects, 1)={ edge_index=[2, 0] },
  (2, connects, 2)={ edge_index=[2, 12] }
)
HeteroData(
  y=[1],
  0={ x=[10, 3] },
  1={ x=[18, 3] },
  2={ x=[11, 3] },
  (0, c