In [None]:
import torch
import pyg_lib

print("PyTorch Version:", torch.__version__)
print("pyg-lib Version:", pyg_lib.__version__)


In [None]:
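# Install required packages (uncomment the lines below to run them).
# NOTE(review): torch is pinned to 2.6.0 below, but the PyG wheel index URL targets torch-2.4.0+cpu — confirm that the two versions match before installing.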
# !pip install torch==2.6.0
# !pip install torch-geometric torch-sparse torch-scatter torch-cluster torch-spline-conv pyg-lib -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# !pip install pytorch_frame
# !pip install -U sentence-transformers # we need another package for text encoding


In [None]:
import os
import math
import numpy as np
from tqdm import tqdm

import torch
import torch_geometric
import torch_frame


In [None]:

from torch.nn import BCEWithLogitsLoss, L1Loss
from relbench.datasets import get_dataset
from relbench.tasks import get_task

# Fetch the rel-f1 database and the driver-position task defined on it.
dataset = get_dataset("rel-f1", download=True)
task = get_task("rel-f1", "driver-position", download=True)

# One labelled table per split, in train / val / test order.
train_table, val_table, test_table = (
    task.get_table(split_name) for split_name in ("train", "val", "test")
)

# Training configuration: a single output trained with an L1 loss;
# lower values of the tuning metric are better.
out_channels, loss_fn = 1, L1Loss()
tune_metric, higher_is_better = "rmse", False

Let's check out the training table just to make sure it looks fine.

In [None]:
train_table

Note that to load the data we did not require any deep learning libraries. Now we introduce the PyTorch Frame library, which is useful for encoding individual tables into initial node features.

In [None]:
# Some book keeping
from torch_geometric.seed import seed_everything

# Seed the random number generators so that reruns are comparable.
seed_everything(42)

# Directory under which materialized data is cached.
root_dir = ".tutorials/data"

# Prefer the GPU when one is visible; check that it's cuda if you want
# the notebook to run in reasonable time!
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

The first big move is to build a graph out of the database. Here we use our pre-prepared conversion function.

The source code can be found at: https://github.com/snap-stanford/relbench/blob/main/relbench/modeling/graph.py

Each node in the graph corresponds to a single row in the database. Crucially, PyTorch Frame stores whole tables as objects in a way that is compatible with PyG minibatch sampling, meaning we can sample subgraphs as in https://arxiv.org/abs/1706.02216, and retrieve the relevant raw features.

PyTorch Frame also stores the `stype` (i.e., modality) of each column, and any specialized feature encoders (e.g., text encoders) to be used later. So we need to configure the `stype` for each column, for which we use a function that tries to automatically detect the `stype`.

In [None]:
from relbench.modeling.utils import get_stype_proposal

# Load the relational database behind the dataset.
db = dataset.get_db()
# Ask RelBench for an automatically detected stype for each column.
col_to_stype_dict = get_stype_proposal(db)
# Display the proposal so it can be reviewed (and corrected) by eye.
col_to_stype_dict

If trying a new dataset, you should definitely check through this dict of `stype`s to check that they look right, and manually fix any mistakes made by the auto-detection function.

Next we also define our text encoding model, for which we use GloVe embeddings for speed and convenience. Feel free to try alternatives here.

In [None]:
from typing import List, Optional
from sentence_transformers import SentenceTransformer
from torch import Tensor


class GloveTextEmbedding:
    """Callable text encoder backed by a GloVe SentenceTransformer model.

    Wraps the ``sentence-transformers/average_word_embeddings_glove.6B.300d``
    model so that it can be called directly on a list of sentences.
    """

    def __init__(self, device: Optional[torch.device] = None):
        """Load the GloVe model, forwarding ``device`` to SentenceTransformer."""
        model_name = "sentence-transformers/average_word_embeddings_glove.6B.300d"
        self.model = SentenceTransformer(model_name, device=device)

    def __call__(self, sentences: List[str]) -> Tensor:
        """Encode ``sentences`` and return the result as a torch tensor."""
        encoded = self.model.encode(sentences)
        return torch.from_numpy(encoded)



In [None]:
import os
import pickle

# Pickle file used to cache the (data, col_stats_dict) pair between runs.
output_file = 'output_fi.pkl'

if not os.path.exists(output_file):
    # Cache miss: build the snapshot graphs from the database, then persist them.
    from torch_frame.config.text_embedder import TextEmbedderConfig
    from relbench.modeling.graph import make_snapshot_graph

    # Relies on GloveTextEmbedding, device, db, col_to_stype_dict and root_dir
    # having been defined by earlier cells.
    text_embedder_cfg = TextEmbedderConfig(
        text_embedder=GloveTextEmbedding(device=device),
        batch_size=256,
    )

    data, col_stats_dict = make_snapshot_graph(
        db,
        col_to_stype_dict=col_to_stype_dict,  # specified column types
        main_table_name="races",  # use 'races' table as timestamp reference
        interval_days=30,  # generate snapshots every 30 days
        text_embedder_cfg=text_embedder_cfg,  # chosen text encoder
        # store materialized graph for convenience
        cache_dir=os.path.join(root_dir, "rel-f1_materialized_cache"),
    )

    with open(output_file, 'wb') as f:
        pickle.dump((data, col_stats_dict), f)
    print("Data computed and saved to file.")
else:
    # Cache hit: reuse the previously computed result.
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # load a cache file that this notebook wrote itself.
    with open(output_file, 'rb') as f:
        data, col_stats_dict = pickle.load(f)
    print("Loaded data from file.")


In [None]:
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

def visualize_hetero_graph(snapshot, num_nodes=10):
    """Visualize a small portion of a heterogeneous graph with a specific edge type.

    Only the first edge type in ``snapshot.edge_index_dict`` is drawn. Node
    labels are prefixed with their node type (``"<type>:<index>"``) so that
    index ``i`` of the source type and index ``i`` of the destination type are
    not collapsed into the same drawn node.

    Args:
        snapshot: Graph object exposing ``edge_index_dict`` and supporting
            ``snapshot[edge_type].edge_index``.
        num_nodes: Maximum number of nodes to draw.
    """
    edge_types = list(snapshot.edge_index_dict.keys())
    if not edge_types:
        # Nothing to draw; avoid an IndexError on the lookup below.
        print("Snapshot has no edge types; nothing to visualize.")
        return

    # Select the first available edge type (e.g., ('node_type1', 'relation', 'node_type2'))
    edge_type = edge_types[0]
    print(f"Using edge type: {edge_type}")  # Debugging

    # Source / destination node types are the first and last entries of the key.
    src_type, dst_type = edge_type[0], edge_type[-1]

    # Edge index for the selected edge type; row 0 holds sources, row 1 destinations.
    edge_index = snapshot[edge_type].edge_index

    # Build an undirected NetworkX graph with type-qualified node labels.
    nx_graph = nx.Graph()
    for src, dst in edge_index.t().tolist():
        nx_graph.add_edge(f"{src_type}:{src}", f"{dst_type}:{dst}")

    # Keep only the first `num_nodes` nodes (in insertion order).
    sampled_nodes = list(nx_graph.nodes)[:num_nodes]
    subgraph = nx_graph.subgraph(sampled_nodes)

    # Plot the graph; the title reports how many nodes were actually drawn.
    plt.figure(figsize=(8, 6))
    nx.draw(subgraph, with_labels=True, node_color='lightblue', edge_color='gray', node_size=500, font_size=10)
    plt.title(f"Visualization of {subgraph.number_of_nodes()} nodes from edge type {edge_type}")
    plt.show()

# Select one snapshot to visualize (index 16, not the first one; this raises
# IndexError if fewer than 17 snapshots were generated)
snapshot = data[16]  # Assuming data is a list of snapshots

# Draw up to 150 nodes of that snapshot
visualize_hetero_graph(snapshot, num_nodes=150)


We can now check out `data`, our main graph object. `data` is a list of graph snapshots; each snapshot is a heterogeneous and temporal graph, with node types given by the table it originates from.

In [None]:
data[1]

We can also check out the TensorFrame for one table like this:

In [None]:
col_stats_dict

This may be a little confusing at first, as in graph ML it is more standard to associate to the graph object `data` a tensor, e.g., `data.x` for which `data.x[idx]` is a 1D array/tensor storing all the features for node with index `idx`.

But actually each snapshot in `data` behaves similarly. For a given node type, e.g., `races` again, `data[1]['races']` stores two pieces of information:


In [None]:
print(len(data))  # Number of snapshots created


A `TensorFrame` object, and a timestamp for each node. The `TensorFrame` object acts analogously to the usual tensor of node features, and you can simply use indexing to retrieve the features of a single row (node), or group of nodes.

In [None]:
data[1]["races"].tf[10:20]

We can also check the edge indices between two different node types, such as `races` and `circuits`. Note that the edges are also heterogeneous, so we also need to specify which edge type we want to look at. Here we look at `f2p_circuitId`, which are the directed edges pointing _from_ a race (the `f` stands for `foreign key`), _to_ the circuit at which the race happened (the `p` stands for `primary key`).

In [None]:
data.__sizeof__()

In [None]:
for edge_type in data[0].edge_types:
    print(f"Edge: {edge_type}, Shape: {data[0][edge_type].edge_index.shape}")


In [None]:
for node_type in data[14].node_types:
    print(f"Node: {node_type}, Feature Shape: {data[14][node_type].tf}")


In [None]:
data[1][("races", "f2p_circuitId", "circuits")]

Now we are ready to instantiate our data loaders. For this we will need to import PyTorch Geometric, our GNN library. Whilst we're at it let's add a seed.


In [None]:
from relbench.modeling.graph import get_node_train_table_input
from torch_geometric.loader import NeighborLoader

# Build one NeighborLoader per (split, snapshot) pair.
# loader_dict maps "train" / "val" / "test" to a list of loaders, one for each
# snapshot that contains the task's entity table.
loader_dict = {}

if not data:
    raise ValueError("No snapshots found in `data`. Ensure make_snapshot_graph() returns a non-empty list.")

for split, table in [
    ("train", train_table),
    ("val", val_table),
    ("test", test_table),
]:
    table_input = get_node_train_table_input(
        table=table,
        task=task,
    )
    entity_table = table_input.nodes[0]  # node type the task predicts on

    loader_dict[split] = []
    # The snapshot loop must be nested inside the split loop: otherwise only
    # the last split ("test") ever receives loaders and train/val stay empty.
    for snapshot in data:
        # Skip snapshots that do not contain the entity table at all. This
        # check runs before any other access to the snapshot's entity store.
        if entity_table not in snapshot.node_types:
            print(f"⚠️ Warning: {entity_table} not found in snapshot. Skipping this snapshot.")
            continue

        # Skip snapshots whose entity table has no usable node count.
        if snapshot[entity_table].num_nodes is None:
            print(f"⚠️ Warning: {entity_table} has no valid nodes in snapshot. Skipping this snapshot.")
            continue

        # Only request temporal sampling when seed times are available and
        # the snapshot reports a "time" entry for the entity table.
        # NOTE(review): confirm that `snapshot.get(entity_table, {})` really
        # looks up the node store of `entity_table`; if it does not, this is
        # always None and temporal sampling is silently disabled.
        time_attr = "time" if table_input.time is not None and "time" in snapshot.get(entity_table, {}) else None

        loader = NeighborLoader(
            snapshot,
            num_neighbors=[10, 10],  # neighbours sampled per hop (two hops)
            time_attr=time_attr,
            # NOTE(review): `None` seeds from every node of `entity_table`,
            # whereas `table_input.time` and `table_input.transform` were
            # built from `table`. Confirm these line up, or seed from
            # `table_input.nodes` instead.
            input_nodes=(entity_table, None),
            input_time=table_input.time if time_attr is not None else None,
            transform=table_input.transform,
            batch_size=512,
            temporal_strategy="uniform",
            shuffle=(split == "train"),
            num_workers=0,
            persistent_workers=False,
        )

        loader_dict[split].append(loader)

# Sanity check: report how many loaders each split received, then peek at a
# single training batch instead of printing every batch of every snapshot.
for split, split_loaders in loader_dict.items():
    print(f"{split}: {len(split_loaders)} snapshot loader(s)")

if loader_dict["train"]:
    print(next(iter(loader_dict["train"][0]), None))


In [None]:
print(f"Available node types in snapshot[0]: {data[0].node_types}")


Now we need our model...




In [None]:
from typing import Any, Dict
import torch
from torch import nn, Tensor
from torch.nn import BCEWithLogitsLoss, Embedding, ModuleDict
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_frame.data.stats import StatType
from torch_geometric.typing import NodeType
from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder

# ✅ LSTM-Based Temporal Encoder for Evolving Node Embeddings
class LSTMBasedTemporalEncoder(torch.nn.Module):
    """Applies one LSTM per node type to that node type's embeddings.

    Each registered node type owns an LSTM whose input and hidden sizes both
    equal ``channels``.
    """

    def __init__(self, node_types, channels):
        """Create an LSTM for every entry of ``node_types``."""
        super().__init__()
        self.lstm_dict = torch.nn.ModuleDict()
        for node_type in node_types:
            self.lstm_dict[node_type] = torch.nn.LSTM(
                input_size=channels,
                hidden_size=channels,
                batch_first=True,
            )

    def forward(self, h_dict, time_dict, batch_dict):
        """Return a dict with one entry per registered node type.

        ``time_dict`` and ``batch_dict`` are accepted but not used. For a node
        type with at least one row in ``h_dict``, the rows are passed through
        its LSTM as a single sequence (a leading batch dimension of size 1 is
        added and removed again). A node type that is present but empty is
        passed through unchanged; a node type missing from ``h_dict`` maps to
        ``torch.zeros(0)``.
        """
        result = {}
        for node_type in self.lstm_dict.keys():
            if node_type not in h_dict:
                # No input for this node type: emit an empty placeholder.
                result[node_type] = torch.zeros(0)
                continue

            features = h_dict[node_type]
            if not features.size(0) > 0:
                # Empty input: hand it back untouched.
                result[node_type] = features
                continue

            sequence_out, _ = self.lstm_dict[node_type](features.unsqueeze(0))
            result[node_type] = sequence_out.squeeze(0)
        return result


class Model(torch.nn.Module):
    """Heterogeneous GNN that predicts one output row per seed entity.

    The forward pass chains: per-node-type feature encoding (``HeteroEncoder``),
    an optional relative-time encoding (``HeteroTemporalEncoder``), optional
    shallow per-node embeddings, message passing (``HeteroGraphSAGE``) and an
    MLP head applied to the seed rows of the entity table.

    Args:
        data: List of snapshot graphs. Only the first snapshot is used to read
            node types, edge types, column names and node counts.
            NOTE(review): this assumes all snapshots share that schema — confirm.
        col_stats_dict: Column statistics per node type, passed to the encoder.
        num_layers: Number of GNN layers.
        channels: Hidden width used by every sub-module.
        out_channels: Output width of the prediction head.
        aggr: Aggregation passed to ``HeteroGraphSAGE``.
        norm: Normalization passed to the MLP head.
        shallow_list: Node types that receive a learnable per-node embedding
            (only those with at least one node in the first snapshot).
        id_awareness: When True, a single-row embedding ``id_awareness_emb``
            is created. It is initialised here but not used by ``forward``.
    """

    def __init__(
        self,
        data: List[HeteroData],
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        shallow_list: List[NodeType] = [],
        id_awareness: bool = False,
    ):
        super().__init__()

        # The first snapshot serves as the schema reference for all sub-modules.
        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[0][node_type].tf.col_names_dict
                for node_type in data[0].node_types
            },
            node_to_col_stats=col_stats_dict,
        )

        # Only node types that carry a "time" entry get a temporal encoder.
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data[0].node_types if "time" in data[0][node_type]
            ],
            channels=channels,
        )

        self.gnn = HeteroGraphSAGE(
            node_types=data[0].node_types,
            edge_types=data[0].edge_types,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
        )

        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )

        # Shallow embeddings, skipped for node types with no nodes.
        # `shallow_list` is only iterated, never mutated.
        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data[0].num_nodes_dict.get(node, 0), channels)
                for node in shallow_list if data[0].num_nodes_dict.get(node, 0) > 0
            }
        )

        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every sub-module and the optional embeddings."""
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        """Compute predictions for the seed nodes of ``entity_table``.

        Args:
            batch: A sampled heterogeneous mini-batch.
            entity_table: Node type whose seed rows are scored.

        Returns:
            The head output for the seed rows of ``entity_table``.
        """
        entity_store = batch[entity_table]

        # Per-seed timestamps. NOTE(review): PyG's NeighborLoader is expected
        # to attach `seed_time` (when built with `input_time`) and `batch_size`
        # to the input node store — confirm against the PyG docs. The
        # timestamps of *all* sampled entity nodes (`batch.time_dict`) are not
        # seed times and must not be used to count seeds.
        seed_time = getattr(entity_store, "seed_time", None)

        x_dict = self.encoder({node_type: batch[node_type].tf for node_type in batch.node_types})

        if seed_time is None:
            # Warn once instead of printing on every batch, then skip the
            # relative-time encoding.
            import warnings

            warnings.warn(
                f"`{entity_table}` has no seed time in this batch; "
                "skipping temporal encoding."
            )
        else:
            rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict, batch.batch_dict)
            for node_type, rel_time in rel_time_dict.items():
                x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        x_dict = self.gnn(x_dict, batch.edge_index_dict)

        # Number of seed rows: prefer the loader-provided batch size, then the
        # number of seed times, and finally fall back to every entity row.
        num_seeds = getattr(entity_store, "batch_size", None)
        if num_seeds is None:
            if seed_time is not None:
                num_seeds = seed_time.size(0)
            else:
                num_seeds = x_dict[entity_table].size(0)

        return self.head(x_dict[entity_table][:num_seeds])


# Instantiate the model on the selected device.
model = Model(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=out_channels,  # defined with the task configuration
    aggr="sum",
    norm="batch_norm",
    # Creates `id_awareness_emb`; note that Model.forward does not use it.
    id_awareness=True,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
# Keep the L1 loss chosen when the task was configured. Overwriting it with
# BCEWithLogitsLoss would treat the numeric driver-position target as a
# binary label.
loss_fn = L1Loss()
epochs = 10


We also need standard train/test loops

In [None]:
from tqdm import tqdm
import numpy as np
import copy

def train() -> float:
    """Run one training epoch over every batch of every training snapshot loader.

    Uses the notebook-level ``model``, ``optimizer``, ``loss_fn``, ``task``,
    ``device`` and ``loader_dict``.

    Returns:
        The mean training loss weighted by the number of predictions per
        batch, or ``float("inf")`` when no batch was processed.
    """
    model.train()

    loss_accum = count_accum = 0
    for snapshot_loader in tqdm(loader_dict["train"]):
        for batch in snapshot_loader:
            batch = batch.to(device)

            # Everything below must run once per batch. If it sits outside
            # this loop, only the last batch of each snapshot is trained on.
            # Membership is checked against `node_types`, as in the loader cell.
            if task.entity_table not in batch.node_types:
                print(f"⚠️ Warning: `{task.entity_table}` missing in batch. Skipping.")
                continue  # Skip invalid batches

            optimizer.zero_grad()
            pred = model(batch, task.entity_table)
            # Flatten a single-column output so it matches the target shape.
            pred = pred.view(-1) if pred.size(1) == 1 else pred

            loss = loss_fn(pred.float(), batch[task.entity_table].y.float())
            loss.backward()
            optimizer.step()

            # Weight the batch loss by its size for an exact epoch mean.
            loss_accum += loss.detach().item() * pred.size(0)
            count_accum += pred.size(0)

    return loss_accum / count_accum if count_accum > 0 else float("inf")


@torch.no_grad()
def test(loader):
    """Collect model predictions over a list of snapshot loaders.

    Uses the notebook-level ``model``, ``task`` and ``device``.

    Args:
        loader: Iterable of per-snapshot loaders, each yielding batches.

    Returns:
        A NumPy array with the predictions of all batches concatenated along
        the first dimension; an empty float32 array when no batch was produced.
    """
    model.eval()
    pred_list = []
    for snapshot_loader in loader:
        for batch in snapshot_loader:
            batch = batch.to(device)

            pred = model(batch, task.entity_table)
            # Flatten a single-column output to a 1-D vector.
            pred = pred.view(-1) if pred.size(1) == 1 else pred
            pred_list.append(pred.detach().cpu())

    if not pred_list:
        # torch.cat raises on an empty list; return an empty result instead.
        return np.empty(0, dtype=np.float32)
    return torch.cat(pred_list, dim=0).numpy()


Now we are ready to train!

In [None]:
# Training Loop: one pass over the training loaders per epoch, followed by
# inference on the validation loaders.
for epoch in range(epochs):
    train_loss = train()
    # NOTE(review): val_pred is computed but never evaluated or used for model
    # selection; `tune_metric` / `higher_is_better` are likewise unused here.
    val_pred = test(loader_dict["val"])
    print(f"Epoch: {epoch+1}, Train loss: {train_loss}")

# Test Model: run inference on the test loaders with the final-epoch weights
# and print the raw predictions.
test_pred = test(loader_dict["test"])
print(f"Test predictions: {test_pred}")