In [6]:
import torch
from torch_geometric.data import Data, HeteroData

from copy import deepcopy

## Test converting `Data` to `HeteroData`

In [None]:
# Record the installed PyTorch version (provenance) instead of the bare module repr.
torch.__version__

In [2]:
# Toy 3-node path graph 0 - 1 - 2. Every edge is listed in both directions
# (0->1, 1->0, 1->2, 2->1), giving an edge_index of shape (2, 4).
edge_index = torch.tensor(
    [
        [0, 1, 1, 2],
        [1, 0, 2, 1],
    ],
    dtype=torch.long,
)

# One scalar feature per node: shape (3, 1), float32.
x = torch.tensor([[-1.0], [0.0], [1.0]], dtype=torch.float32)

data = Data(x=x, edge_index=edge_index)

In [3]:
# Node-type names used as HeteroData node keys.
REAL_NODE_T: str = "real"
VIRTUAL_NODE_T: str = "virtual"

# Edge-relation names; each abbreviates its constant name (r = real, v = virtual).
REAL_TO_REAL_EDGE_T: str = "r_to_r"
REAL_TO_VIRTUAL_EDGE_T: str = "r_to_v"
VIRTUAL_TO_REAL_EDGE_T: str = "v_to_r"
VIRTUAL_TO_VIRTUAL_EDGE_T: str = "v_to_v"

In [13]:
def convert_to_heterodata(data: Data) -> HeteroData:
    """Convert a homogeneous ``Data`` graph into a single-type ``HeteroData`` graph.

    Every node becomes a node of type ``REAL_NODE_T`` and every edge becomes an
    edge of type ``(REAL_NODE_T, REAL_TO_REAL_EDGE_T, REAL_NODE_T)``. All copied
    attributes are deep-copied, so the result shares no objects with ``data``.

    Copied node attributes: ``x``, ``pos`` and ``y``. ``y`` is kept only when
    it is node-level, i.e. its first dimension equals the number of rows of
    ``data.x``. A 0-dim ``y`` or a ``y`` with a single row is treated as
    graph-level and dropped.
    Copied edge attributes: ``edge_index``, ``edge_attr``.
    Attributes that are missing or ``None`` are skipped.

    Args:
        data: the homogeneous graph to convert.

    Returns:
        A new ``HeteroData`` with one node type and one edge type.

    Raises:
        ValueError: if ``data.y`` is present but neither graph-level nor
            node-level (including when ``data.x`` is missing, so the node
            count cannot be checked).
    """
    edge_type = (REAL_NODE_T, REAL_TO_REAL_EDGE_T, REAL_NODE_T)
    mapping = {REAL_NODE_T: {}, edge_type: {}}

    # getattr(..., None) covers both "attribute missing" and "attribute set to
    # None"; hasattr() alone does not rule out a None value.
    for key in ("x", "pos"):
        value = getattr(data, key, None)
        if value is not None:
            mapping[REAL_NODE_T][key] = deepcopy(value)

    # Compare against None explicitly: bool() of a multi-element tensor raises
    # RuntimeError, and a scalar y == 0 would otherwise be silently skipped.
    y = getattr(data, "y", None)
    is_graph_level = y is not None and (y.dim() == 0 or y.shape[0] == 1)
    if y is not None and not is_graph_level:
        x = getattr(data, "x", None)
        if x is not None and y.shape[0] == x.shape[0]:
            mapping[REAL_NODE_T]["y"] = deepcopy(y)
        else:
            raise ValueError(f"Invalid data.y: {y}")

    for key in ("edge_index", "edge_attr"):
        value = getattr(data, key, None)
        if value is not None:
            mapping[edge_type][key] = deepcopy(value)

    return HeteroData(mapping)

In [14]:
# Convert the toy graph and check that the conversion preserved the data.
test_hetero = convert_to_heterodata(data)

real_edge_type = (REAL_NODE_T, REAL_TO_REAL_EDGE_T, REAL_NODE_T)
assert torch.equal(test_hetero[REAL_NODE_T].x, data.x)
assert torch.equal(test_hetero[real_edge_type].edge_index, data.edge_index)
# The converter deep-copies attributes, so tensors must not be shared.
assert test_hetero[REAL_NODE_T].x is not data.x


In [15]:
# Node types of the converted graph; only the "real" type is expected.
test_hetero.node_types

['real']

In [16]:
# Summary of the converted graph: the "real" node store and the
# (real, r_to_r, real) edge store with their attribute shapes.
test_hetero

HeteroData(
  real={ x=[3, 1] },
  (real, r_to_r, real)={ edge_index=[2, 4] }
)

## `torch.cartesian_prod`, `torch.range` and `torch.arange` (including empty inputs)

In [18]:
# Two small int64 vectors of different lengths (4 and 3) for the product demo.
X = torch.arange(1, 5, dtype=torch.long)     # [1, 2, 3, 4]
Y = -torch.arange(1, 4, dtype=torch.long)    # [-1, -2, -3]

In [19]:
# All (x, y) pairs with x from X and y from Y, X varying slowest:
# shape (len(X) * len(Y), 2) = (12, 2).
torch.cartesian_prod(X, Y)

tensor([[ 1, -1],
        [ 1, -2],
        [ 1, -3],
        [ 2, -1],
        [ 2, -2],
        [ 2, -3],
        [ 3, -1],
        [ 3, -2],
        [ 3, -3],
        [ 4, -1],
        [ 4, -2],
        [ 4, -3]])

In [20]:
# torch.range includes the end point (11 values, 0..10) and returns floats.
# It is deprecated in favour of torch.arange and emits a warning (see output);
# the equivalent modern call is torch.arange(0, 11, dtype=torch.float32).
torch.range(0, 10)

  torch.range(0, 10)


tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])

In [21]:
# arange excludes the end point and infers int64 from integer arguments: 0..9.
torch.arange(10)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [24]:
# Edge case: the product of two empty vectors is an empty (0, 2) tensor, not an error.
torch.cartesian_prod(torch.empty(0), torch.empty(0))


tensor([], size=(0, 2))

In [25]:
# Edge case: an empty range yields an empty int64 tensor.
torch.arange(0, 0)

tensor([], dtype=torch.int64)