In [1]:
# Install required packages.
!pip install torch==2.4.0
!pip install torch-geometric torch-sparse torch-scatter torch-cluster torch-spline-conv pyg-lib -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
!pip install pytorch_frame
!pip install relbench

Looking in links: https://data.pyg.org/whl/torch-2.4.0+cu121.html


In [2]:
import os
import torch
import relbench

relbench.__version__

'1.1.0'

In [3]:
import numpy as np

from torch.nn import BCEWithLogitsLoss, L1Loss
from relbench.datasets import get_dataset
from relbench.tasks import get_task

dataset = get_dataset("rel-f1", download=True)
task = get_task("rel-f1", "driver-position", download=True)

train_table = task.get_table("train")
val_table = task.get_table("val")
test_table = task.get_table("test")

out_channels = 1
loss_fn = L1Loss()
tune_metric = "mae"
higher_is_better = False

Let's check out the training table just to make sure it looks fine.

In [4]:
train_table

Table(df=
           date  driverId  position
0    2004-07-05        10     10.75
1    2004-07-05        47     12.00
2    2004-03-07         7     15.00
3    2004-01-07        10      9.00
4    2003-09-09        52     13.00
...         ...       ...       ...
7448 1995-08-22        96     15.75
7449 1975-06-08       228      8.00
7450 1965-05-31       418     16.00
7451 1961-08-20       467     37.00
7452 1954-05-29       677     30.00

[7453 rows x 3 columns],
  fkey_col_to_pkey_table={'driverId': 'drivers'},
  pkey_col=None,
  time_col=date)

Note that to load the data we did not require any deep learning libraries. Now we introduce the PyTorch Frame library, which is useful for encoding individual tables into initial node features.

In [5]:
import os
import math
import numpy as np
from tqdm import tqdm

import torch
import torch_geometric
import torch_frame

# Some book keeping
from torch_geometric.seed import seed_everything

seed_everything(42)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)  # check that it's cuda if you want it to run in reasonable time!
root_dir = "./data"

cuda


The first big move is to build a graph out of the database. Here we use our pre-prepared conversion function.

The source code can be found at: https://github.com/snap-stanford/relbench/blob/main/relbench/modeling/graph.py

Each node in the graph corresonds to a single row in the database. Crucially, PyTorch Frame stores whole tables as objects in a way that is compatibile with PyG minibatch sampling, meaning we can sample subgraphs as in https://arxiv.org/abs/1706.02216, and retrieve the relevant raw features.

PyTorch Frame also stores the `stype` (i.e., modality) of each column, and any specialized feature encoders (e.g., text encoders) to be used later. So we need to configure the `stype` for each column, for which we use a function that tries to automatically detect the `stype`.

In [6]:
from relbench.modeling.utils import get_stype_proposal

db = dataset.get_db()
col_to_stype_dict = get_stype_proposal(db)
col_to_stype_dict

Loading Database object from /root/.cache/relbench/rel-f1/db...
Done in 0.08 seconds.


{'qualifying': {'qualifyId': <stype.numerical: 'numerical'>,
  'raceId': <stype.numerical: 'numerical'>,
  'driverId': <stype.numerical: 'numerical'>,
  'constructorId': <stype.numerical: 'numerical'>,
  'number': <stype.numerical: 'numerical'>,
  'position': <stype.numerical: 'numerical'>,
  'date': <stype.timestamp: 'timestamp'>},
 'drivers': {'driverId': <stype.numerical: 'numerical'>,
  'driverRef': <stype.text_embedded: 'text_embedded'>,
  'code': <stype.text_embedded: 'text_embedded'>,
  'forename': <stype.text_embedded: 'text_embedded'>,
  'surname': <stype.text_embedded: 'text_embedded'>,
  'dob': <stype.timestamp: 'timestamp'>,
  'nationality': <stype.text_embedded: 'text_embedded'>},
 'results': {'resultId': <stype.numerical: 'numerical'>,
  'raceId': <stype.numerical: 'numerical'>,
  'driverId': <stype.numerical: 'numerical'>,
  'constructorId': <stype.numerical: 'numerical'>,
  'number': <stype.numerical: 'numerical'>,
  'grid': <stype.numerical: 'numerical'>,
  'position':

If trying a new dataset, you should definitely check through this dict of `stype`s to check that look right, and manually change any mistakes by the auto-detection function.

Next we also define our text encoding model, which we use GloVe embeddings for speed and convenience. Feel free to try alternatives here.

In [7]:
!pip install torchvision==0.19



In [8]:
!pip install -U sentence-transformers # we need another package for text encoding
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from torch import Tensor


class GloveTextEmbedding:
    def __init__(self, device: Optional[torch.device
                                       ] = None):
        self.model = SentenceTransformer(
            "sentence-transformers/average_word_embeddings_glove.6B.300d",
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        return torch.from_numpy(self.model.encode(sentences))





In [9]:
from torch_frame.config.text_embedder import TextEmbedderConfig
from relbench.modeling.graph import make_pkey_fkey_graph

text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=256
)

data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,  # speficied column types
    text_embedder_cfg=text_embedder_cfg,  # our chosen text encoder
    cache_dir=os.path.join(
        root_dir, f"rel-f1_materialized_cache"
    ),  # store materialized graph for convenience
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)
  tf_dict, col_stats = torch.load(path)


We can now check out `data`, our main graph object. `data` is a heterogeneous and temporal graph, with node types given by the table it originates from.

In [10]:
data

HeteroData(
  qualifying={
    tf=TensorFrame([4082, 3]),
    time=[4082],
  },
  drivers={ tf=TensorFrame([857, 6]) },
  results={
    tf=TensorFrame([20323, 11]),
    time=[20323],
  },
  circuits={ tf=TensorFrame([77, 7]) },
  standings={
    tf=TensorFrame([28115, 4]),
    time=[28115],
  },
  constructor_results={
    tf=TensorFrame([9408, 2]),
    time=[9408],
  },
  constructor_standings={
    tf=TensorFrame([10170, 4]),
    time=[10170],
  },
  races={
    tf=TensorFrame([820, 5]),
    time=[820],
  },
  constructors={ tf=TensorFrame([211, 3]) },
  (qualifying, f2p_raceId, races)={ edge_index=[2, 4082] },
  (races, rev_f2p_raceId, qualifying)={ edge_index=[2, 4082] },
  (qualifying, f2p_driverId, drivers)={ edge_index=[2, 4082] },
  (drivers, rev_f2p_driverId, qualifying)={ edge_index=[2, 4082] },
  (qualifying, f2p_constructorId, constructors)={ edge_index=[2, 4082] },
  (constructors, rev_f2p_constructorId, qualifying)={ edge_index=[2, 4082] },
  (results, f2p_raceId, races)=

We can also check out the TensorFrame for one table like this:

In [11]:
data["races"].tf

TensorFrame(
  num_cols=5,
  num_rows=820,
  categorical (1): ['year'],
  numerical (1): ['round'],
  timestamp (2): ['date', 'time'],
  embedding (1): ['name'],
  has_target=False,
  device='cpu',
)

This may be a little confusing at first, as in graph ML it is more standard to associate to the graph object `data` a tensor, e.g., `data.x` for which `data.x[idx]` is a 1D array/tensor storing all the features for node with index `idx`.

But actually this `data` object behaves similarly. For a given node type, e.g., `races` again, `data['races']` stores two pieces of information


In [12]:
list(data["races"].keys())

['tf', 'time']

A `TensorFrame` object, and a timestamp for each node. The `TensorFrame` object acts analogously to the usual tensor of node features, and you can simply use indexing to retrieve the features of a single row (node), or group of nodes.

In [13]:
data["races"].tf[10]

TensorFrame(
  num_cols=5,
  num_rows=1,
  categorical (1): ['year'],
  numerical (1): ['round'],
  timestamp (2): ['date', 'time'],
  embedding (1): ['name'],
  has_target=False,
  device='cpu',
)

In [14]:
data["races"].tf[10:20]

TensorFrame(
  num_cols=5,
  num_rows=10,
  categorical (1): ['year'],
  numerical (1): ['round'],
  timestamp (2): ['date', 'time'],
  embedding (1): ['name'],
  has_target=False,
  device='cpu',
)

We can also check the edge indices between two different node types, such as `races` amd `circuits`. Note that the edges are also heterogenous, so we also need to specify which edge type we want to look at. Here we look at `f2p_curcuitId`, which are the directed edges pointing _from_ a race (the `f` stands for `foreign key`), _to_ the circuit at which te race happened (the `p` stands for `primary key`).

In [15]:
data[("races", "f2p_circuitId", "circuits")]

{'edge_index': tensor([[  0,   1,   2,  ..., 817, 818, 819],
        [  8,   5,  18,  ...,  21,  17,  23]])}

Now we are ready to instantiate our data loaders. For this we will need to import PyTorch Geometric, our GNN library. Whilst we're at it let's add a seed.


In [16]:
from relbench.modeling.graph import get_node_train_table_input, make_pkey_fkey_graph
from torch_geometric.loader import NeighborLoader

loader_dict = {}

for split, table in [
    ("train", train_table),
    ("val", val_table),
    ("test", test_table),
]:
    table_input = get_node_train_table_input(
        table=table,
        task=task,
    )
    entity_table = table_input.nodes[0]
    loader_dict[split] = NeighborLoader(
        data,
        num_neighbors=[
            128 for i in range(2)
        ],  # we sample subgraphs of depth 2, 128 neighbors per node.
        time_attr="time",
        input_nodes=table_input.nodes,
        input_time=table_input.time,
        transform=table_input.transform,
        batch_size=512,
        temporal_strategy="uniform",
        shuffle=split == "train",
        num_workers=0,
        persistent_workers=False,
    )

Now we need our model...




In [40]:
from typing import Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Parameter

import torch_geometric.backend
import torch_geometric.typing
from torch_geometric import is_compiling
from torch_geometric.index import index2ptr
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.typing import (
    Adj,
    OptTensor,
    SparseTensor,
    pyg_lib,
    torch_sparse,
)
from torch_geometric.utils import index_sort, one_hot, scatter, spmm


def masked_edge_index(edge_index: Adj, edge_mask: Tensor) -> Adj:
    print(f"edge_mask shape: {edge_mask.shape}")
    if isinstance(edge_index, Tensor):
        return edge_index[:, edge_mask]
    return torch_sparse.masked_select_nnz(edge_index, edge_mask, layout='coo')


class RGCNConv(MessagePassing):
    r"""The relational graph convolutional operator from the `"Modeling
    Relational Data with Graph Convolutional Networks"
    <https://arxiv.org/abs/1703.06103>`_ paper.

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta}_{\textrm{root}} \cdot
        \mathbf{x}_i + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)}
        \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j,

    where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge types.
    Edge type needs to be a one-dimensional :obj:`torch.long` tensor which
    stores a relation identifier
    :math:`\in \{ 0, \ldots, |\mathcal{R}| - 1\}` for each edge.

    .. note::
        This implementation is as memory-efficient as possible by iterating
        over each individual relation type.
        Therefore, it may result in low GPU utilization in case the graph has a
        large number of relations.
        As an alternative approach, :class:`FastRGCNConv` does not iterate over
        each individual type, but may consume a large amount of memory to
        compensate.
        We advise to check out both implementations to see which one fits your
        needs.

    .. note::
        :class:`RGCNConv` can use `dynamic shapes
        <https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index
        .html#work_dynamic_shapes>`_, which means that the shape of the interim
        tensors can be determined at runtime.
        If your device doesn't support dynamic shapes, use
        :class:`FastRGCNConv` instead.

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
            In case no input features are given, this argument should
            correspond to the number of nodes in your graph.
        out_channels (int): Size of each output sample.
        num_relations (int): Number of relations.
        num_bases (int, optional): If set, this layer will use the
            basis-decomposition regularization scheme where :obj:`num_bases`
            denotes the number of bases to use. (default: :obj:`None`)
        num_blocks (int, optional): If set, this layer will use the
            block-diagonal-decomposition regularization scheme where
            :obj:`num_blocks` denotes the number of blocks to use.
            (default: :obj:`None`)
        aggr (str, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"mean"`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        is_sorted (bool, optional): If set to :obj:`True`, assumes that
            :obj:`edge_index` is sorted by :obj:`edge_type`. This avoids
            internal re-sorting of the data and can improve runtime and memory
            efficiency. (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        num_relations: int,
        num_bases: Optional[int] = None,
        num_blocks: Optional[int] = None,
        aggr: str = 'mean',
        root_weight: bool = True,
        is_sorted: bool = False,
        bias: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', aggr)
        super().__init__(node_dim=0, **kwargs)

        if num_bases is not None and num_blocks is not None:
            raise ValueError('Can not apply both basis-decomposition and '
                             'block-diagonal-decomposition at the same time.')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_bases = num_bases
        self.num_blocks = num_blocks
        self.is_sorted = is_sorted

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)
        self.in_channels_l = in_channels[0]

        self._use_segment_matmul_heuristic_output: torch.jit.Attribute(
            None, Optional[float])

        if num_bases is not None:
            self.weight = Parameter(
                torch.empty(num_bases, in_channels[0], out_channels))
            self.comp = Parameter(torch.empty(num_relations, num_bases))

        elif num_blocks is not None:
            assert (in_channels[0] % num_blocks == 0
                    and out_channels % num_blocks == 0)
            self.weight = Parameter(
                torch.empty(num_relations, num_blocks,
                            in_channels[0] // num_blocks,
                            out_channels // num_blocks))
            self.register_parameter('comp', None)

        else:
            self.weight = Parameter(
                torch.empty(num_relations, in_channels[0], out_channels))
            self.register_parameter('comp', None)

        if root_weight:
            self.root = Parameter(torch.empty(in_channels[1], out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        super().reset_parameters()
        glorot(self.weight)
        glorot(self.comp)
        glorot(self.root)
        zeros(self.bias)

    def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
                edge_index: Adj, edge_type: OptTensor = None):
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor or tuple, optional): The input node features.
                Can be either a :obj:`[num_nodes, in_channels]` node feature
                matrix, or an optional one-dimensional node index tensor (in
                which case input features are treated as trainable node
                embeddings).
                Furthermore, :obj:`x` can be of type :obj:`tuple` denoting
                source and destination node features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_type (torch.Tensor, optional): The one-dimensional relation
                type/index for each edge in :obj:`edge_index`.
                Should be only :obj:`None` in case :obj:`edge_index` is of type
                :class:`torch_sparse.SparseTensor`. (default: :obj:`None`)
        """
        # Convert input features to a pair of node features or node indices.
        print("FORWARD DEBUGGING PRINTS")
        print(f"x: {x}")
        print(f"x_l: {x[0].shape}")
        print(f"x_r: {x[1].shape}")

        x_l: OptTensor = None
        if isinstance(x, tuple):
            x_l = x[0]
        else:
            x_l = x
        if x_l is None:
            x_l = torch.arange(self.in_channels_l, device=self.weight.device)

        x_r: Tensor = x_l
        if isinstance(x, tuple):
            x_r = x[1]

        size = (x_l.size(0), x_r.size(0))

        # print(f"edge_index: {edge_index}")
        # print(f"edge_index.storage.value(): {edge_index.storage.value()}")
        if isinstance(edge_index, SparseTensor):
            edge_type = edge_index.storage.value()
        assert edge_type is not None

        # propagate_type: (x: Tensor, edge_type_ptr: OptTensor)
        out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device)

        weight = self.weight
        if self.num_bases is not None:  # Basis-decomposition =================
            weight = (self.comp @ weight.view(self.num_bases, -1)).view(
                self.num_relations, self.in_channels_l, self.out_channels)

        if self.num_blocks is not None:  # Block-diagonal-decomposition =====

            if not torch.is_floating_point(
                    x_r) and self.num_blocks is not None:
                raise ValueError('Block-diagonal decomposition not supported '
                                 'for non-continuous input features.')

            for i in range(self.num_relations):
                tmp = masked_edge_index(edge_index, edge_type == i)
                h = self.propagate(tmp, x=x_l, edge_type_ptr=None, size=size)
                h = h.view(-1, weight.size(1), weight.size(2))
                h = torch.einsum('abc,bcd->abd', h, weight[i])
                out = out + h.contiguous().view(-1, self.out_channels)

        else:  # No regularization/Basis-decomposition ========================

            use_segment_matmul = torch_geometric.backend.use_segment_matmul
            # If `use_segment_matmul` is not specified, use a simple heuristic
            # to determine whether `segment_matmul` can speed up computation
            # given the observed input sizes:
            if use_segment_matmul is None:
                print(f"edge_index.size{edge_index.size}")
                print(f"edge_type: {edge_type}, edge_type.shape: {edge_type.shape}")
                segment_count = scatter(torch.ones_like(edge_type), edge_type,
                                        dim_size=self.num_relations)

                self._use_segment_matmul_heuristic_output = (
                    torch_geometric.backend.use_segment_matmul_heuristic(
                        num_segments=self.num_relations,
                        max_segment_size=int(segment_count.max()),
                        in_channels=self.weight.size(1),
                        out_channels=self.weight.size(2),
                    ))

                assert self._use_segment_matmul_heuristic_output is not None
                use_segment_matmul = self._use_segment_matmul_heuristic_output

            if (use_segment_matmul and torch_geometric.typing.WITH_SEGMM
                    and not is_compiling() and self.num_bases is None
                    and x_l.is_floating_point()
                    and isinstance(edge_index, Tensor)):

                if not self.is_sorted:
                    if (edge_type[1:] < edge_type[:-1]).any():
                        edge_type, perm = index_sort(
                            edge_type, max_value=self.num_relations)
                        edge_index = edge_index[:, perm]
                edge_type_ptr = index2ptr(edge_type, self.num_relations)
                out = self.propagate(edge_index, x=x_l,
                                     edge_type_ptr=edge_type_ptr, size=size)
            else:
                for i in range(self.num_relations):
                    print(edge_type.shape)
                    tmp = masked_edge_index(edge_index, edge_type == i)

                    if not torch.is_floating_point(x_r):
                        out = out + self.propagate(
                            tmp,
                            x=weight[i, x_l],
                            edge_type_ptr=None,
                            size=size,
                        )
                    else:
                        print(f"out.shape: {out.shape}")
                        print(f"x_l: {x_l}")
                        print(f"x_l.shape{x_l.shape}")
                        print(f"size: {size}")
                        h = self.propagate(tmp, x=x_l, edge_type_ptr=None,
                                           size=size)
                        print(f"h.shape{h.shape}")
                        print(f"weight[i].shape{weight[i].shape}")
                        assert h.shape[0] == size[1]
                        out = out + (h @ weight[i])

        root = self.root
        if root is not None:
            if not torch.is_floating_point(x_r):
                out = out + root[x_r]
            else:
                out = out + x_r @ root

        if self.bias is not None:
            out = out + self.bias

        return out

    def message(self, x_j: Tensor, edge_type_ptr: OptTensor) -> Tensor:
        if (torch_geometric.typing.WITH_SEGMM and not is_compiling()
                and edge_type_ptr is not None):
            # TODO Re-weight according to edge type degree for `aggr=mean`.
            return pyg_lib.ops.segment_matmul(x_j, edge_type_ptr, self.weight)

        return x_j

    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:
        if isinstance(adj_t, SparseTensor):
            adj_t = adj_t.set_value(None)
        return spmm(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, num_relations={self.num_relations})')

In [18]:
from typing import Any, Dict, List, Optional

import torch
import torch_frame
from torch import Tensor
from torch_frame.data.stats import StatType
from torch_frame.nn.models import ResNet
from torch_geometric.nn import HeteroConv, LayerNorm, PositionalEncoding, SAGEConv, GCNConv
from torch_geometric.typing import EdgeType, NodeType
from torch_geometric.typing import SparseTensor

In [46]:
class HeteroRGCN(torch.nn.Module):
  def __init__(
      self,
      node_types: List[NodeType],
      edge_types: List[EdgeType],
      channels: int,
      aggr: str = "mean",
      num_layers: int = 2):
      super().__init__()

      self.edge_types = edge_types
      #{relation type A: [first edge with this relation has edge category 0, second edge with this relation has edge category 1, etc.]}
      edge_category_set = set()
      for edge_type in edge_types:
          edge_category_set.add(edge_type[1])
      self.edge_category_map = {edge_category: i for i, edge_category in enumerate(edge_category_set)}

      self.convs = torch.nn.ModuleList()
      for _ in range(num_layers):
          conv = HeteroConv(
              {
                  edge_type: RGCNConv(channels, channels, num_relations=len(edge_types), aggr = aggr)
                  for edge_type in edge_types #edge type refers to triplet
              },
              aggr="sum",  # Aggregation across edge types
          )
          self.convs.append(conv)

      self.norms = torch.nn.ModuleList()
      for _ in range(num_layers):
          norm_dict = torch.nn.ModuleDict()
          for node_type in node_types:
              norm_dict[node_type] = LayerNorm(channels, mode="node")
          self.norms.append(norm_dict)

  def reset_parameters(self):
      for conv in self.convs:
          conv.reset_parameters()
      for norm_dict in self.norms:
          for norm in norm_dict.values():
              norm.reset_parameters()

  def convert_to_sparse(self, edge_index_dict):
      for edge_type, edge_index in edge_index_dict.items():
          edge_values = torch.ones(edge_index.size(1), device = device).to(torch.int64)
          edge_index_dict[edge_type] = SparseTensor(row=edge_index[0].to(torch.int64), col=edge_index[1].to(torch.int64), value = edge_values)

  def remove_none_relation_types(self, edge_index_dict):
      none_edges = [key for key, value in edge_index_dict.items() if value.size(1) == 0]
      for key in none_edges:
          del edge_index_dict[key]


  def forward(
      self,
      x_dict: Dict[NodeType, Tensor],
      edge_index_dict: Dict[NodeType, Tensor],
      num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None,
      num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None,
  ) -> Dict[NodeType, Tensor]:

      edge_category_dict = {
        edge_type: torch.full(
            (edge_index_dict[edge_type].size(1),),  # Use size(1) of the full tensor
            self.edge_category_map[edge_type[1]]
        )
        for edge_type in edge_index_dict.keys()
      }
      self.remove_none_relation_types(edge_index_dict)
      self.convert_to_sparse(edge_index_dict)
      print(f"x_dict size: {x_dict}")
      for _, (conv, norm_dict) in enumerate(zip(self.convs, self.norms)):
          x_dict = conv(x_dict, edge_index_dict)
          x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()}
          x_dict = {key: x.relu() for key, x in x_dict.items()}

      return x_dict

In [47]:
from torch.nn import BCEWithLogitsLoss
import copy
from typing import Any, Dict, List

import torch
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder


class Model(torch.nn.Module):

    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        # List of node types to add shallow embeddings to input
        shallow_list: List[NodeType] = [],
        # ID awareness
        id_awareness: bool = False,
    ):
        super().__init__()

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
        )
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
        )
        self.gnn = HeteroRGCN(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
        )
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )
        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data.num_nodes_dict[node], channels)
                for node in shallow_list
            }
        )

        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.reset_parameters()

    def reset_parameters(self):
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
            batch.num_sampled_nodes_dict,
            batch.num_sampled_edges_dict,
        )

        return self.head(x_dict[entity_table][: seed_time.size(0)])

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout"
            )
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)
        # Add ID-awareness to the root node
        x_dict[entity_table][: seed_time.size(0)] += self.id_awareness_emb.weight

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        return self.head(x_dict[dst_table])


model = Model(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)


# if you try out different RelBench tasks you will need to change these
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
epochs = 10

We also need standard train/test loops

In [48]:
def train() -> float:
    model.train()

    loss_accum = count_accum = 0
    for batch in tqdm(loader_dict["train"]):
        batch = batch.to(device)
        optimizer.zero_grad()
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred

        loss = loss_fn(pred.float(), batch[entity_table].y.float())
        loss.backward()
        optimizer.step()

        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    return loss_accum / count_accum


@torch.no_grad()
def test(loader: NeighborLoader) -> np.ndarray:
    model.eval()

    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0).numpy()

Now we are ready to train!

In [49]:
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
for epoch in range(1, epochs + 1):
    train_loss = train()
    val_pred = test(loader_dict["val"])
    val_metrics = task.evaluate(val_pred, val_table)
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(model.state_dict())


model.load_state_dict(state_dict)
val_pred = test(loader_dict["val"])
val_metrics = task.evaluate(val_pred, val_table)
print(f"Best Val metrics: {val_metrics}")

test_pred = test(loader_dict["test"])
test_metrics = task.evaluate(test_pred)
print(f"Best test metrics: {test_metrics}")

  0%|          | 0/15 [00:00<?, ?it/s]


x_dict size: {'qualifying': tensor([[-0.0507, -0.1374, -0.1300,  ..., -0.9554,  0.3449,  0.0474],
        [ 0.0834, -0.0418,  0.2918,  ..., -0.3810, -0.3079,  0.3386],
        [ 0.1228,  0.1362,  0.3702,  ..., -0.9824, -0.1003,  0.5063],
        ...,
        [-0.9434,  0.1796,  0.7265,  ..., -0.8174, -0.5249,  0.5687],
        [-0.1571,  0.3097,  0.6933,  ..., -0.4593, -0.0697,  0.3844],
        [ 0.0187, -0.0658, -0.1479,  ..., -0.2611,  0.7999,  0.7034]],
       device='cuda:0', grad_fn=<AddBackward0>), 'drivers': tensor([[-0.6597, -0.7554, -0.9443,  ..., -0.0114, -0.4321,  0.6976],
        [-0.1199, -0.5705, -0.5053,  ...,  0.3269,  0.4034, -0.4012],
        [ 0.3399, -0.8939, -0.5962,  ..., -0.1072, -0.2863,  0.4295],
        ...,
        [-0.1507, -0.7998, -0.2184,  ...,  0.3464, -0.0235, -0.3919],
        [-0.1319, -0.7740, -0.4964,  ...,  0.2914, -0.3622, -0.2314],
        [-0.4548, -0.1001, -0.5029,  ..., -0.4193, -0.3772,  0.6506]],
       device='cuda:0', grad_fn=<AddmmBackwa

AssertionError: 

In [None]:
print(torch.cuda.is_available())
print(torch.version.cuda)
print(device)

In [None]:
from torch_sparse import SparseTensor

# Example SparseTensor creation on CUDA
row = torch.tensor([0, 1, 2], device='cuda')
col = torch.tensor([1, 2, 0], device='cuda')
sparse_tensor = SparseTensor(row=row, col=col)
print(sparse_tensor)
print(torch_sparse.__version__)  # Ensure the version matches the one you installed