<a href="https://colab.research.google.com/github/Andrea-1704/Pytorch_Geometric_tutorial/blob/main/train_model_baseline_f1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Summary
In this project we implement the Graphormer architecture, following the paper by Ying et al., 2021.

We first implement a simplified version of the proposed architecture, containing only the self-attention module in a multi-head fashion for heterogeneous graphs (the ones provided by RelBench and already discussed in the previous works).

Then, we approximate the actual Graphormer implementation by introducing the spatial encoding, the centrality encoding and the edge encoding into the attention computation.

### For colab:

In [1]:
# !pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# !pip install torch-sparse -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# !pip install torch-cluster -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# !pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# !pip install torch-geometric==2.6.0 -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# !pip install pyg-lib -f https://data.pyg.org/whl/torch-2.4.0+cpu.html

# !pip install pytorch_frame[full]==1.2.2
# !pip install relbench[full]==1.0.0
# !pip uninstall -y pyg_lib torch  # Uninstall current versions
# !pip install torch==2.6.0  # Reinstall your desired PyTorch version
# !pip install --no-cache-dir git+https://github.com/pyg-team/pyg-lib.git # Install pyg-lib; --no-cache-dir ensures a fresh install


# If you have trouble with pyg-lib, try the following (CPU build):
#!pip install pyg_lib -f https://data.pyg.org/whl/torch-2.5.0+cpu.html

In [2]:
import os
import torch
import relbench
import numpy as np
from torch.nn import BCEWithLogitsLoss, L1Loss
from relbench.datasets import get_dataset
from relbench.tasks import get_task

dataset = get_dataset("rel-f1", download=True)
task = get_task("rel-f1", "driver-position", download=True)

train_table = task.get_table("train")
val_table = task.get_table("val")
test_table = task.get_table("test")

out_channels = 1
loss_fn = L1Loss()
tune_metric = "mae"
higher_is_better = False

Note that to load the data we did not require any deep learning libraries. Now we introduce the PyTorch Frame library, which is useful for encoding individual tables into initial node features.

In [3]:
import os
import math
import numpy as np
from tqdm import tqdm

import torch
import torch_geometric
import torch_frame

# Some book keeping
from torch_geometric.seed import seed_everything

seed_everything(42)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)  # check that it's cuda if you want it to run in reasonable time!
root_dir = "./data"


cpu


The first big move is to build a graph out of the database. Here we use our pre-prepared conversion function.

The source code can be found at: https://github.com/snap-stanford/relbench/blob/main/relbench/modeling/graph.py

Each node in the graph corresponds to a single row in the database. Crucially, PyTorch Frame stores whole tables as objects in a way that is compatible with PyG minibatch sampling, meaning we can sample subgraphs as in https://arxiv.org/abs/1706.02216, and retrieve the relevant raw features.

PyTorch Frame also stores the `stype` (i.e., modality) of each column, and any specialized feature encoders (e.g., text encoders) to be used later. So we need to configure the `stype` for each column, for which we use a function that tries to automatically detect the `stype`.

In [4]:
from relbench.modeling.utils import get_stype_proposal

db = dataset.get_db()
col_to_stype_dict = get_stype_proposal(db)

Loading Database object from C:\Users\andrea\AppData\Local\relbench\relbench\Cache/rel-f1/db...
Done in 0.06 seconds.


If trying a new dataset, you should definitely look through this dict of `stype`s to check that they look right, and manually fix any mistakes made by the auto-detection function.

Next we define our text encoding model, for which we use GloVe embeddings for speed and convenience. Feel free to try alternatives here.
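The embedder below represents a sentence as the average of its word vectors. A minimal, dependency-free sketch of that idea, assuming a hypothetical 3-dimensional toy vocabulary (`TOY_VECTORS` and `average_embedding` are illustrative names, not part of any library; real GloVe vectors have 300 dimensions):

```python
from typing import Dict, List

# Hypothetical toy vocabulary, only for illustration.
TOY_VECTORS: Dict[str, List[float]] = {
    "red": [1.0, 0.0, 0.0],
    "bull": [0.0, 1.0, 0.0],
    "racing": [0.0, 0.0, 1.0],
}
DIM = 3


def average_embedding(
    sentence: str, vectors: Dict[str, List[float]], dim: int
) -> List[float]:
    """Average the vectors of all known words; fall back to a zero vector."""
    known = [vectors[w] for w in sentence.lower().split() if w in vectors]
    if not known:
        return [0.0] * dim
    # zip(*known) iterates over the coordinates of all collected vectors.
    return [sum(col) / len(known) for col in zip(*known)]


print(average_embedding("Red Bull", TOY_VECTORS, DIM))
print(average_embedding("unknown words only", TOY_VECTORS, DIM))
```

Words outside the vocabulary are simply skipped, so a sentence made only of unknown words maps to the zero vector.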

In [5]:
# import torch
# from typing import List, Optional
# from sentence_transformers import SentenceTransformer
# from torch import Tensor


# class GloveTextEmbedding:
#     def __init__(self, device: Optional[torch.device
#                                        ] = None):
#         self.model = SentenceTransformer(
#             "sentence-transformers/average_word_embeddings_glove.6B.300d",
#             device=device,
#         )

#     def __call__(self, sentences: List[str]) -> Tensor:
#         return torch.from_numpy(self.model.encode(sentences))

import torch
import numpy as np
from collections import defaultdict
import requests
from io import StringIO

class LightweightGloveEmbedder:
    def __init__(self, device=None):
        self.device = device
        self.embeddings = defaultdict(lambda: np.zeros(300))
        self._load_embeddings()

    def _load_embeddings(self):
        try:
            # (no need to extract a zip archive)
            url = "https://huggingface.co/stanfordnlp/glove/resolve/main/glove.6B.300d.txt"
            response = requests.get(url)
            response.raise_for_status()

            for line in StringIO(response.text):
                parts = line.split()
                word = parts[0]
                vector = np.array(parts[1:], dtype=np.float32)
                self.embeddings[word] = vector
        except Exception as e:
            print(f"Warning: Couldn't load GloVe embeddings ({str(e)}). Using zero vectors.")

    def __call__(self, sentences):
        results = []
        for text in sentences:
            words = text.lower().split()
            vectors = [self.embeddings[w] for w in words if w in self.embeddings]
            if vectors:
                avg_vector = np.mean(vectors, axis=0)
            else:
                avg_vector = np.zeros(300)
            results.append(avg_vector)

        tensor = torch.tensor(np.array(results), dtype=torch.float32)
        return tensor.to(self.device) if self.device else tensor

In [6]:
from torch_frame.config.text_embedder import TextEmbedderConfig
from relbench.modeling.graph import make_pkey_fkey_graph

text_embedder_cfg = TextEmbedderConfig(
    text_embedder=LightweightGloveEmbedder(device=device), batch_size=256
)

data, col_stats_dict = make_pkey_fkey_graph(
    # If this call fails, upgrading may help: !pip install --upgrade torch torchvision transformers
    db,
    col_to_stype_dict=col_to_stype_dict,  # specified column types
    text_embedder_cfg=text_embedder_cfg,  # our chosen text encoder
    cache_dir=os.path.join(
        root_dir, "rel-f1_materialized_cache"
    ),  # store materialized graph for convenience
)





We can now check out `data`, our main graph object. `data` is a heterogeneous and temporal graph, with node types given by the table each node originates from.

We can also check out the TensorFrame for one table, e.g., `data["races"].tf`.

This may be a little confusing at first, as in graph ML it is more standard to associate with the graph object `data` a tensor, e.g., `data.x`, for which `data.x[idx]` is a 1D array/tensor storing all the features of the node with index `idx`.

But actually this `data` object behaves similarly. For a given node type, e.g., `races` again, `data['races']` stores two pieces of information: a `TensorFrame` object and a timestamp for each node. The `TensorFrame` object acts analogously to the usual tensor of node features, and you can simply use indexing to retrieve the features of a single row (node), or of a group of nodes.

Now we are ready to instantiate our data loaders. For this we use `NeighborLoader` from PyTorch Geometric, our GNN library.


In [7]:
from relbench.modeling.graph import get_node_train_table_input, make_pkey_fkey_graph
from torch_geometric.loader import NeighborLoader

loader_dict = {}

for split, table in [
    ("train", train_table),
    ("val", val_table),
    ("test", test_table),
]:
    table_input = get_node_train_table_input(
        table=table,
        task=task,
    )
    entity_table = table_input.nodes[0]
    loader_dict[split] = NeighborLoader(
        data, # the complete graph
        num_neighbors=[
            128 for _ in range(2)
        ],  # we sample subgraphs of depth 2, 128 neighbors per node.
        time_attr="time",
        input_nodes=table_input.nodes,
        input_time=table_input.time,
        transform=table_input.transform,
        batch_size=512,
        temporal_strategy="uniform",
        shuffle=split == "train",
        num_workers=0,
        persistent_workers=False,
    )

Notice that `NeighborLoader` is a data loader that samples subgraphs from a heterogeneous graph during training and evaluation. It manages the batches and supports temporal sampling through the `time_attr` argument.

It is also important to notice that `num_neighbors=[128 for _ in range(2)]` means that we sample up to 128 neighbors per node, up to a maximum distance of 2 hops.

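`time_attr` indicates that the graph is temporal: the sampler uses this node attribute to enforce temporal constraints, so that only neighbors whose timestamp is not later than the seed time (`input_time`) are sampled.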
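The temporal constraint can be illustrated with a small, library-free sketch. The function name `temporally_valid_neighbors` and the toy rows are hypothetical; the real sampling happens inside `NeighborLoader`:

```python
from typing import List, Tuple


def temporally_valid_neighbors(
    neighbors: List[Tuple[str, int]], seed_time: int
) -> List[str]:
    """Keep only neighbors whose timestamp is not later than the seed time."""
    return [name for name, timestamp in neighbors if timestamp <= seed_time]


# Hypothetical neighbors of one seed node: (row id, timestamp).
neighbors = [("result_a", 10), ("result_b", 20), ("result_c", 30)]

# Only rows known at prediction time (timestamp <= 20) may be sampled.
print(temporally_valid_neighbors(neighbors, seed_time=20))
```

Filtering in this way keeps information from the future out of the subgraph that is used to predict the label of a seed node.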

This code creates **batches**. The number of batches is the number of input (seed) nodes, i.e., the rows of the task table, divided by the `batch_size` (e.g., 512) and rounded up. Each batch contains a subgraph centered around up to `batch_size` input nodes.

For each input node, up to 128 neighbors are sampled at the first hop, and for each of those, up to 128 neighbors are sampled at the second hop (`num_neighbors=[128, 128]`). This creates a subgraph with nodes up to 2 hops away. The theoretical upper bound is `128 + 128 * 128` sampled neighbors per input node for a single edge type (in a heterogeneous graph the fanout applies to each edge type separately); the actual size is limited by the graph's structure and by overlaps between neighborhoods. With `time_attr="time"`, neighbors are additionally filtered by timestamp, and `temporal_strategy="uniform"` samples uniformly among the neighbors that pass this filter.

This approach allows training on large graphs by processing smaller subgraphs in memory, making it scalable and efficient for stochastic gradient descent (SGD).
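The arithmetic behind the batch count and the neighborhood bound can be checked with a short sketch. The seed-node count of 7,500 is a hypothetical value chosen for illustration, and both helper names are made up here:

```python
import math
from typing import List


def num_batches(num_seed_nodes: int, batch_size: int) -> int:
    """Number of mini-batches needed to cover all seed nodes."""
    return math.ceil(num_seed_nodes / batch_size)


def max_sampled_neighbors(fanouts: List[int]) -> int:
    """Upper bound on sampled neighbors for one seed node and one edge type per hop."""
    total, frontier = 0, 1
    for fanout in fanouts:
        frontier *= fanout  # nodes reachable at this hop, at most
        total += frontier
    return total


print(num_batches(7500, 512))             # 15
print(max_sampled_neighbors([128, 128]))  # 128 + 128 * 128 = 16512
```

In practice the sampled subgraphs are much smaller than this bound, because few nodes have 128 valid neighbors and neighborhoods overlap.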

Now we need our model...




# Graphormer

In [8]:
import torch
from torch import nn
from torch_geometric.nn import Linear
from torch_geometric.utils import softmax

class HeteroGraphormerLayer(nn.Module):
    def __init__(self, channels, num_heads=4, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.channels = channels
        self.head_dim = channels // num_heads

        assert self.channels % num_heads == 0, "channels must be divisible by num_heads"

        self.q_lin = Linear(channels, channels)
        self.k_lin = Linear(channels, channels)
        self.v_lin = Linear(channels, channels)
        self.out_lin = Linear(channels, channels)

        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(channels)

        # Optional: learnable attention bias per edge type, filled lazily in forward()
        self.edge_type_bias = nn.ParameterDict()

    def forward(self, x_dict, edge_index_dict):
        out_dict = {k: torch.zeros_like(v) for k, v in x_dict.items()}

        for edge_type, edge_index in edge_index_dict.items():
            src_type, _, dst_type = edge_type
            x_src, x_dst = x_dict[src_type], x_dict[dst_type]
            src, dst = edge_index

            Q = self.q_lin(x_dst).view(-1, self.num_heads, self.head_dim)
            K = self.k_lin(x_src).view(-1, self.num_heads, self.head_dim)
            V = self.v_lin(x_src).view(-1, self.num_heads, self.head_dim)

            attn_scores = (Q[dst] * K[src]).sum(dim=-1) / self.head_dim**0.5

            
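            # Learnable scalar bias per edge type, created lazily on the right device.
            # NOTE: an optimizer built before the first forward pass does not know
            # about parameters created here.
            bias_name = "__".join(edge_type)
            if bias_name not in self.edge_type_bias:
                self.edge_type_bias[bias_name] = nn.Parameter(
                    torch.randn(1, device=attn_scores.device)
                )
            attn_scores = attn_scores + self.edge_type_bias[bias_name]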

            attn_weights = softmax(attn_scores, dst)
            attn_weights = self.dropout(attn_weights)

            out = V[src] * attn_weights.unsqueeze(-1)
            out = out.view(-1, self.channels)

            out_dict[dst_type].index_add_(0, dst, out)

        for node_type in out_dict:
            out_dict[node_type] = self.norm(out_dict[node_type] + x_dict[node_type])

        return out_dict

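Notice that the `edge_type_bias` term is a learnable scalar bias, one per edge type, that is added to the attention scores. It lets the layer distinguish between relations, since we are dealing with heterogeneous graphs.

This is a plain multi-head attention layer restricted to existing edges: for each edge type, the scores are normalized with a softmax over the incoming edges of every destination node, and the results are summed across edge types.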
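The per-destination normalization used in the layer (`softmax(attn_scores, dst)`) can be sketched without any library. This is an illustrative re-implementation for a single head, with the hypothetical name `segment_softmax`, not the PyG code itself:

```python
import math
from collections import defaultdict
from typing import List


def segment_softmax(scores: List[float], dst: List[int]) -> List[float]:
    """Softmax over the scores of all edges that share the same destination node."""
    # Subtract the per-destination maximum for numerical stability.
    max_per_dst = defaultdict(lambda: -math.inf)
    for score, d in zip(scores, dst):
        max_per_dst[d] = max(max_per_dst[d], score)
    exps = [math.exp(score - max_per_dst[d]) for score, d in zip(scores, dst)]
    # Sum the exponentials per destination node.
    denom = defaultdict(float)
    for e, d in zip(exps, dst):
        denom[d] += e
    return [e / denom[d] for e, d in zip(exps, dst)]


# Three edges: two point to node 0, one points to node 1.
scores = [0.0, 0.0, 1.0]
dst = [0, 0, 1]
print(segment_softmax(scores, dst))  # [0.5, 0.5, 1.0]
```

The weights of the edges entering the same node sum to 1, which is why a node with a single incoming edge always receives weight 1.0 for it.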

On top of this, Ying et al., 2021 add structural information to the computation of the attention scores, which we include in the following part of this notebook.

In particular, what we are still missing is:
1. Centrality encoding: a measure of the "importance" of a given node (its degree);
2. Spatial encoding: a measure of the relative position of two nodes in the graph, for example the length of the shortest path between nodes i and j, or -1 if the two nodes are not connected;
3. Edge encoding: a bias computed from the features of the edges along the shortest path between the two nodes.
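The first two quantities in the list above can be computed with plain Python on a small graph. This is a minimal sketch on a hypothetical undirected toy graph, using BFS for the hop distance; it is not the code used later in this notebook:

```python
from collections import deque
from typing import List, Tuple


def shortest_path_lengths(
    num_nodes: int, edges: List[Tuple[int, int]], source: int
) -> List[int]:
    """BFS hop distances from `source`; -1 marks unreachable nodes."""
    adj = [[] for _ in range(num_nodes)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    dist = [-1] * num_nodes
    dist[source] = 0
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if dist[v] == -1:  # not visited yet
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


def degrees(num_nodes: int, edges: List[Tuple[int, int]]) -> List[int]:
    """Node degrees, the centrality measure used for the centrality encoding."""
    deg = [0] * num_nodes
    for u, v in edges:
        deg[u] += 1
        deg[v] += 1
    return deg


edges = [(0, 1), (1, 2)]  # path 0 - 1 - 2, node 3 is isolated
print(shortest_path_lengths(4, edges, source=0))  # [0, 1, 2, -1]
print(degrees(4, edges))                          # [1, 2, 1, 0]
```

Both outputs are integers, so they can be used directly as indices into embedding tables (after shifting the -1 to a valid index).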

In [9]:
class HeteroGraphormer(torch.nn.Module):
    def __init__(self, node_types, edge_types, channels, num_layers=2):
        super().__init__()
        self.layers = torch.nn.ModuleList([
            HeteroGraphormerLayer(channels) for _ in range(num_layers)
        ])

    def forward(self, x_dict, edge_index_dict, *args, **kwargs):
        for layer in self.layers:
            x_dict = layer(x_dict, edge_index_dict)
        return x_dict

    def reset_parameters(self):
        for layer in self.layers:
            if hasattr(layer, "reset_parameters"):
                layer.reset_parameters()

## Model

In [10]:
from torch.nn import BCEWithLogitsLoss
import copy
from typing import Any, Dict, List

import torch
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder



class Model(torch.nn.Module):

    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        # List of node types to add shallow embeddings to input
        shallow_list: List[NodeType] = [],
        # ID awareness
        id_awareness: bool = False,
    ):
        super().__init__()

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
        )
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
        )
        self.gnn = HeteroGraphormer(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            num_layers=num_layers,
        )
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )
        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data.num_nodes_dict[node], channels)
                for node in shallow_list
            }
        )

        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.reset_parameters()

    def reset_parameters(self):
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
            batch.num_sampled_nodes_dict,
            batch.num_sampled_edges_dict,
        )

        return self.head(x_dict[entity_table][: seed_time.size(0)])

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout"
            )
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)
        # Add ID-awareness to the root node
        x_dict[entity_table][: seed_time.size(0)] += self.id_awareness_emb.weight

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        return self.head(x_dict[dst_table])


model = Model(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)


# if you try out different RelBench tasks you will need to change these
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
epochs = 10

We also need standard train/test loops

In [None]:
def train() -> float:
    model.train()

    loss_accum = count_accum = 0
    for batch in tqdm(loader_dict["train"]):
        batch = batch.to(device)

        optimizer.zero_grad()
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred

        loss = loss_fn(pred.float(), batch[entity_table].y.float())
        loss.backward()
        optimizer.step()

        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    return loss_accum / count_accum


@torch.no_grad()
def test(loader: NeighborLoader) -> np.ndarray:
    model.eval()

    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0).numpy()

Now we are ready to train!

In [12]:
import numpy as np

def rmse(true, pred):
    """Calculate the Root Mean Squared Error (RMSE)."""
    return np.sqrt(np.mean((true - pred)**2)) # Calculate RMSE manually

In [13]:
import numpy as np

def custom_evaluate(pred: np.ndarray, target_table, metrics) -> dict:
    """Custom evaluation function to replace task.evaluate."""

    # Extract target values from the target table
    target = target_table.df[task.target_col].to_numpy()

    # Check for length mismatch
    if len(pred) != len(target):
        raise ValueError(
            f"The length of pred and target must be the same (got "
            f"{len(pred)} and {len(target)}, respectively)."
        )

    # Calculate metrics
    results = {}
    for metric_fn in metrics:
        if metric_fn.__name__ == "rmse":  # Handle RMSE specifically
            results["rmse"] = np.sqrt(np.mean((target - pred)**2))
        else:  # Handle other metrics (if any)
            results[metric_fn.__name__] = metric_fn(target, pred)

    return results

In [18]:
from sklearn.metrics import mean_squared_error
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
for epoch in range(1, epochs + 1):
    train_loss = train()
    val_pred = test(loader_dict["val"])
    #val_metrics = task.evaluate(val_pred, val_table)
    val_metrics = custom_evaluate(val_pred, val_table, task.metrics)
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(model.state_dict())


model.load_state_dict(state_dict)
val_pred = test(loader_dict["val"])
val_metrics = custom_evaluate(val_pred, val_table, task.metrics)
print(f"Best Val metrics: {val_metrics}")

#test_pred = test(loader_dict["test"])
#test_metrics = custom_evaluate(test_pred, test_table, task.metrics)
#print(f"Best test metrics: {test_metrics}")


# Graphormer implementation

Now we can finally try to implement a more complete version of the Graphormer architecture, following the work of Ying et al., 2021.
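For reference, the attention score and the input features in the Graphormer paper can be written as follows. The notation is reproduced here from the paper as a reminder and should be checked against the original: $h_i$ is the representation of node $v_i$, $d$ the hidden dimension, $\phi(v_i, v_j)$ the shortest-path distance, $b$ a learnable scalar indexed by that distance, $c_{ij}$ the edge encoding, and $z^{-}$, $z^{+}$ learnable embeddings indexed by in-degree and out-degree.

```latex
A_{ij} = \frac{(h_i W_Q)(h_j W_K)^{\top}}{\sqrt{d}} + b_{\phi(v_i, v_j)} + c_{ij},
\qquad
h_i^{(0)} = x_i + z^{-}_{\deg^{-}(v_i)} + z^{+}_{\deg^{+}(v_i)}
```

The layer in this section is a simplification of these formulas: attention is computed only over existing edges, and the edge encoding $c_{ij}$ is replaced by one learnable scalar per edge type.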

In [14]:
import torch
from torch import nn
from torch_geometric.nn import Linear
from torch_geometric.utils import softmax

class HeteroGraphormerLayer_enhanced(nn.Module):
    def __init__(self, channels, num_heads=4, dropout=0.1, max_degree=10, max_distance=10):
        super().__init__()
        self.num_heads = num_heads
        self.channels = channels
        self.head_dim = channels // num_heads

        assert self.channels % num_heads == 0, "channels must be divisible by num_heads"

        self.q_lin = Linear(channels, channels)
        self.k_lin = Linear(channels, channels)
        self.v_lin = Linear(channels, channels)
        self.out_lin = Linear(channels, channels)

        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(channels)

        # Spatial encoding (e.g., shortest path distance embedding)
        self.spatial_encoding = nn.Embedding(max_distance + 2, num_heads)  # +1 for unseen, +1 for padding
        self.centrality_encoding = nn.Embedding(max_degree + 1, channels)

        self.edge_type_bias = nn.ParameterDict()

    def forward(self, x_dict, edge_index_dict, spatial_dict=None, degree_dict=None):
        out_dict = {k: torch.zeros_like(v) for k, v in x_dict.items()}

        for edge_type, edge_index in edge_index_dict.items():
            src_type, _, dst_type = edge_type
            x_src, x_dst = x_dict[src_type], x_dict[dst_type]
            src, dst = edge_index

            Q = self.q_lin(x_dst).view(-1, self.num_heads, self.head_dim)
            K = self.k_lin(x_src).view(-1, self.num_heads, self.head_dim)
            V = self.v_lin(x_src).view(-1, self.num_heads, self.head_dim)

            attn_scores = (Q[dst] * K[src]).sum(dim=-1) / self.head_dim**0.5

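            # Edge type bias: learnable scalar per edge type, created lazily on the
            # right device. NOTE: an optimizer built before the first forward pass
            # does not know about parameters created here.
            bias_name = "__".join(edge_type)
            if bias_name not in self.edge_type_bias:
                self.edge_type_bias[bias_name] = nn.Parameter(
                    torch.randn(1, device=attn_scores.device)
                )
            attn_scores = attn_scores + self.edge_type_bias[bias_name]

            # Spatial encoding bias (shortest path), one value per head
            if spatial_dict and edge_type in spatial_dict:
                spatial = spatial_dict[edge_type]  # tensor of shape [num_edges]
                # Shift by 1 so that -1 (unreachable) maps to index 0, and cap
                # large distances at the last embedding index.
                spatial_idx = (spatial + 1).clamp(
                    min=0, max=self.spatial_encoding.num_embeddings - 1
                )
                attn_scores = attn_scores + self.spatial_encoding(spatial_idx)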

            # Attention weights
            attn_weights = softmax(attn_scores, dst)
            attn_weights = self.dropout(attn_weights)

            out = V[src] * attn_weights.unsqueeze(-1)
            out = out.view(-1, self.channels)

            out_dict[dst_type].index_add_(0, dst, out)

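        # Centrality encoding (degree). Graphormer adds it to the input features;
        # this simplified layer adds it after aggregation instead.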
        for node_type in out_dict:
            if degree_dict and node_type in degree_dict:
                deg = degree_dict[node_type]  # Tensor[num_nodes]
                centrality_emb = self.centrality_encoding(deg.clamp(max=self.centrality_encoding.num_embeddings - 1))
                out_dict[node_type] = out_dict[node_type] + centrality_emb

            out_dict[node_type] = self.norm(out_dict[node_type] + x_dict[node_type])

        return out_dict


In [15]:
class HeteroGraphormer(torch.nn.Module):
    def __init__(self, node_types, edge_types, channels, num_layers=2):
        super().__init__()
        self.layers = torch.nn.ModuleList([
            HeteroGraphormerLayer_enhanced(channels) for _ in range(num_layers)
        ])

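    def forward(self, x_dict, edge_index_dict, *args, spatial_dict=None, degree_dict=None, **kwargs):
        for layer in self.layers:
            # Forward the structural encodings to every layer
            x_dict = layer(
                x_dict,
                edge_index_dict,
                spatial_dict=spatial_dict,
                degree_dict=degree_dict,
            )
        return x_dict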

    def reset_parameters(self):
        for layer in self.layers:
            if hasattr(layer, "reset_parameters"):
                layer.reset_parameters()

In [16]:
from torch.nn import BCEWithLogitsLoss
import copy
from typing import Any, Dict, List

import torch
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType

from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder



class Model(torch.nn.Module):

    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        # List of node types to add shallow embeddings to input
        shallow_list: List[NodeType] = [],
        # ID awareness
        id_awareness: bool = False,
    ):
        super().__init__()

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
        )
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
        )
        self.gnn = HeteroGraphormer(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            num_layers=num_layers,
        )
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )
        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data.num_nodes_dict[node], channels)
                for node in shallow_list
            }
        )

        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.reset_parameters()

    def reset_parameters(self):
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        # 1. Degree (centrality) encoding: in-degree of every node in the sampled
        #    subgraph, summed over all incoming edge types. Sizes and devices come
        #    from x_dict because the node storages hold a TensorFrame (`tf`), not `x`.
        degree_dict = {
            node_type: torch.zeros(x.size(0), dtype=torch.long, device=x.device)
            for node_type, x in x_dict.items()
        }
        for (_, _, dst_type), edge_index in batch.edge_index_dict.items():
            dst = edge_index[1]
            degree_dict[dst_type].index_add_(0, dst, torch.ones_like(dst))

        # 2. Spatial encoding: attention is only computed over existing edges, so
        #    every (src, dst) pair is at shortest-path distance 1.
        spatial_dict = {
            edge_type: torch.ones(edge_index.size(1), dtype=torch.long, device=edge_index.device)
            for edge_type, edge_index in batch.edge_index_dict.items()
        }

        # 3. GNN forward pass with both encodings
        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
            spatial_dict=spatial_dict,
            degree_dict=degree_dict,
        )

        return self.head(x_dict[entity_table][: seed_time.size(0)])

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout"
            )
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)
        # Add ID-awareness to the root node
        x_dict[entity_table][: seed_time.size(0)] += self.id_awareness_emb.weight

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        return self.head(x_dict[dst_table])


model = Model(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)


# if you try out different RelBench tasks you will need to change these
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
epochs = 10

In [20]:
from sklearn.metrics import mean_squared_error
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
for epoch in range(1, epochs + 1):
    train_loss = train()
    val_pred = test(loader_dict["val"])
    #val_metrics = task.evaluate(val_pred, val_table)
    val_metrics = custom_evaluate(val_pred, val_table, task.metrics)
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(model.state_dict())


model.load_state_dict(state_dict)
val_pred = test(loader_dict["val"])
val_metrics = custom_evaluate(val_pred, val_table, task.metrics)
print(f"Best Val metrics: {val_metrics}")

#test_pred = test(loader_dict["test"])
#test_metrics = custom_evaluate(test_pred, test_table, task.metrics)
#print(f"Best test metrics: {test_metrics}")
