<a href="https://colab.research.google.com/github/Andrea-1704/Pytorch_Geometric_tutorial/blob/main/train_model_baseline_f1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# !pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# !pip install torch-sparse -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# !pip install torch-cluster -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# !pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# !pip install torch-geometric==2.6.0 -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# !pip install pyg-lib -f https://data.pyg.org/whl/torch-2.4.0+cpu.html

# !pip install pytorch_frame[full]==1.2.2
# !pip install relbench[full]==1.0.0
# !pip uninstall -y pyg_lib torch  # Uninstall current versions
# !pip install torch==2.6.0  # Reinstall your desired PyTorch version
# !pip install --no-cache-dir git+https://github.com/pyg-team/pyg-lib.git # Install pyg-lib; --no-cache-dir ensures a fresh install

In [2]:
import os
import torch
import relbench
import numpy as np
from torch.nn import BCEWithLogitsLoss, L1Loss
from relbench.datasets import get_dataset
from relbench.tasks import get_task
import math
from tqdm import tqdm
import torch_geometric
import torch_frame
from torch_geometric.seed import seed_everything
from relbench.modeling.utils import get_stype_proposal
from collections import defaultdict
import requests
from io import StringIO
from torch_frame.config.text_embedder import TextEmbedderConfig
from relbench.modeling.graph import make_pkey_fkey_graph
from torch.nn import BCEWithLogitsLoss
import copy
from typing import Any, Dict, List
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType
from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder
from relbench.modeling.graph import get_node_train_table_input, make_pkey_fkey_graph
from torch_geometric.loader import NeighborLoader
import pyg_lib
from sklearn.metrics import mean_squared_error

  from .autonotebook import tqdm as notebook_tqdm


# Dataset and task creation

In [3]:
# Load the rel-f1 dataset and its driver-position task.
dataset = get_dataset("rel-f1", download=True)
task = get_task("rel-f1", "driver-position", download=True)

# One table per split of the task.
train_table = task.get_table("train")
val_table = task.get_table("val")
test_table = task.get_table("test")

out_channels = 1
# A single output, because a single value is predicted per entity.
loss_fn = L1Loss()
# L1 loss is the mean absolute error; it is used here because this is a
# regression task.
tune_metric = "mae"
higher_is_better = False  # checkpoints are selected on the lowest validation "mae"

# Fix the random seed and pick the GPU when one is available.
seed_everything(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device) 
root_dir = "./data"

db = dataset.get_db()
col_to_stype_dict = get_stype_proposal(db)
# col_to_stype_dict holds the proposed semantic type (stype) of the columns.

cpu
Loading Database object from C:\Users\andrea\AppData\Local\relbench\relbench\Cache/rel-f1/db...
Done in 0.10 seconds.


# Embedder

In [4]:
# import torch
# from typing import List, Optional
# from sentence_transformers import SentenceTransformer
# from torch import Tensor


# class GloveTextEmbedding:
#     def __init__(self, device: Optional[torch.device
#                                        ] = None):
#         self.model = SentenceTransformer(
#             "sentence-transformers/average_word_embeddings_glove.6B.300d",
#             device=device,
#         )

#     def __call__(self, sentences: List[str]) -> Tensor:
#         return torch.from_numpy(self.model.encode(sentences))


class LightweightGloveEmbedder:
    """Sentence embedder that averages GloVe word vectors.

    A sentence is lower-cased, split on whitespace, and mapped to the mean of
    the vectors of its known words. Sentences without any known word (and
    entries that are not strings) map to the zero vector.

    Args:
        device: Optional torch device the returned tensors are moved to.
        dim: Dimensionality of the word vectors (300 for the default file).
        url: Location of the GloVe text file; defaults to ``GLOVE_URL``.
        timeout: Timeout in seconds passed to ``requests.get``.

    Loading is best-effort: if the download fails, a warning is printed and
    every sentence is embedded as the zero vector.
    """

    GLOVE_URL = "https://huggingface.co/stanfordnlp/glove/resolve/main/glove.6B.300d.txt"

    def __init__(self, device=None, dim=300, url=None, timeout=60):
        self.device = device
        self.dim = dim
        self.url = url or self.GLOVE_URL
        self.timeout = timeout
        # Kept as a defaultdict so direct lookups of unknown words still
        # yield a zero vector; __call__ itself only reads known words.
        self.embeddings = defaultdict(lambda: np.zeros(self.dim, dtype=np.float32))
        self._load_embeddings()

    def _load_embeddings(self):
        """Download the GloVe file and fill ``self.embeddings``."""
        try:
            # Stream the response line by line (no need to extract a zip and
            # no need to hold the whole body in memory at once). The timeout
            # prevents the notebook from hanging on a stalled connection.
            with requests.get(self.url, stream=True, timeout=self.timeout) as response:
                response.raise_for_status()
                # Force UTF-8 so iter_lines(decode_unicode=True) yields str.
                response.encoding = "utf-8"
                self._parse_lines(response.iter_lines(decode_unicode=True))
        except Exception as e:
            # Deliberate best-effort: drop any partially loaded vocabulary so
            # the behavior matches the warning below.
            self.embeddings.clear()
            print(f"Warning: Couldn't load GloVe embeddings ({str(e)}). Using zero vectors.")

    def _parse_lines(self, lines):
        """Parse ``"word v1 ... v_dim"`` lines into ``self.embeddings``.

        Blank lines and lines with the wrong number of fields or with
        non-numeric values are skipped instead of aborting the whole load.

        Returns:
            The number of word vectors stored.
        """
        loaded = 0
        for line in lines:
            parts = line.split()
            if len(parts) != self.dim + 1:
                continue
            try:
                vector = np.array(parts[1:], dtype=np.float32)
            except ValueError:
                continue
            self.embeddings[parts[0]] = vector
            loaded += 1
        return loaded

    def __call__(self, sentences):
        """Embed ``sentences`` into a float32 tensor of shape (len(sentences), dim)."""
        sentences = list(sentences)
        # Pre-allocating keeps the (0, dim) shape for an empty batch and
        # leaves the zero vector in place for sentences without known words.
        matrix = np.zeros((len(sentences), self.dim), dtype=np.float32)
        for row, text in enumerate(sentences):
            if not isinstance(text, str):
                continue  # e.g. None / NaN for a missing text value
            vectors = [
                self.embeddings[w] for w in text.lower().split() if w in self.embeddings
            ]
            if vectors:
                matrix[row] = np.mean(vectors, axis=0)

        tensor = torch.from_numpy(matrix)
        return tensor.to(self.device) if self.device is not None else tensor

In [5]:
# Text columns are embedded in batches of 256 by the GloVe-averaging embedder.
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=LightweightGloveEmbedder(device=device),
    batch_size=256,
)

# Build the graph in the form RelBench requires, together with the
# per-column statistics. The materialized graph is stored under root_dir
# for convenience.
# If this step does not work, try:
#   !pip install --upgrade torch torchvision transformers
data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,  # specified column types
    text_embedder_cfg=text_embedder_cfg,  # our chosen text encoder
    cache_dir=os.path.join(root_dir, "rel-f1_materialized_cache"),
)





In [6]:
# loader_dict ends up with one NeighborLoader per split: "train", "val", "test".
loader_dict = {}

for split, table in zip(
    ("train", "val", "test"),
    (train_table, val_table, test_table),
):
    # table_input bundles three things: the input nodes, the time of each
    # node, and the transform to apply to the sampled nodes.
    table_input = get_node_train_table_input(table=table, task=task)
    entity_table = table_input.nodes[0]

    loader_dict[split] = NeighborLoader(
        data,
        # Subgraphs of depth 2, sampling 128 neighbors per node at each hop.
        num_neighbors=[128, 128],
        time_attr="time",
        input_nodes=table_input.nodes,
        input_time=table_input.time,
        transform=table_input.transform,
        batch_size=512,
        temporal_strategy="uniform",
        shuffle=(split == "train"),  # only the training split is shuffled
        num_workers=0,
        persistent_workers=False,
    )

# Model

## GAT

In [7]:
from torch_geometric.nn import HeteroConv, GATConv

class HeteroGAT(torch.nn.Module):
    """Stack of heterogeneous GAT layers.

    Every layer is a ``HeteroConv`` holding one ``GATConv`` per edge type
    (``GATConv`` is already implemented in PyG), so each relation has its own
    attention parameters. The outputs that reach the same destination node
    type through different relations are combined with ``aggr``.

    Args:
        node_types: Node types of the graph (accepted for interface
            compatibility; not used).
        edge_types: Edge types ``(src, rel, dst)``; one ``GATConv`` is created
            for each of them in every layer.
        channels: Output size of every ``GATConv``.
        heads: Number of attention heads (``concat=False`` is used, so the
            output size stays ``channels`` regardless of ``heads``).
        num_layers: Number of stacked ``HeteroConv`` layers.
        aggr: Cross-relation aggregation passed to ``HeteroConv``.

    Note:
        ``forward`` applies the layers back to back; no activation or
        normalization is inserted between them.
    """

    def __init__(self, node_types, edge_types, channels, heads=8, num_layers=2, aggr="sum"):
        super().__init__()
        self.edge_types = edge_types
        self.convs = torch.nn.ModuleList(
            [
                HeteroConv(
                    {
                        # (-1, -1): input sizes are inferred lazily on the
                        # first forward pass.
                        edge_type: GATConv(
                            (-1, -1),
                            channels,
                            heads=heads,
                            concat=False,
                            add_self_loops=False,
                        )
                        for edge_type in edge_types
                    },
                    aggr=aggr,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x_dict, edge_index_dict, *args, **kwargs):
        """Apply every layer in turn and return the final ``x_dict``.

        Extra positional/keyword arguments are accepted and ignored.
        """
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        return x_dict

    def reset_parameters(self):
        """Re-initialize every ``GATConv`` of every layer."""
        for conv in self.convs:
            # HeteroConv resets all of its per-relation convolutions itself;
            # this avoids indexing its internal module dict with tuple keys,
            # which depends on how the installed PyG version stores them.
            conv.reset_parameters()

### How does it work?

It's pretty simple: for each layer of the network we create a HeteroConv layer, which can be seen as a wrapper around several GNN layers. Then, for each edge type in that layer, we create a GAT layer.

The reason why we need a different GAT layer for each edge type within a given layer is that each edge type must be treated separately, because its semantics differ from those of all the others.

So each edge type requires its own attention scores.

NB:

When we do something like:
~~~python
edge_type: GATConv(
    (-1, -1), channels, heads=heads, concat=False, add_self_loops=False
)
~~~

The "(-1, -1)" refers to the dimension of the input channels. In this case we are using 'lazy initialization': PyTorch Geometric will automatically infer the in_channels during the first forward() call by reading the actual dimensions of the source and destination node embeddings.

This is especially useful in heterogeneous graphs, where dimensions may vary across different types (or be decided dynamically).

In [8]:
class Model(torch.nn.Module):
    """Relational model: column encoders -> temporal encoding -> GAT -> MLP head.

    Args:
        data: The heterogeneous graph created with ``make_pkey_fkey_graph``.
        col_stats_dict: Per-table, per-column statistics for the encoders.
        num_layers: Number of GNN layers.
        channels: Hidden size used by the encoders, the GNN and the head input.
        out_channels: Size of the prediction (1 for single-value regression).
        aggr: Aggregation used to combine messages arriving at a node type
            through different edge types.
        norm: Normalization used by the MLP head.
        shallow_list: Node types that get an additional learnable shallow
            embedding added to their input features (none by default).
        id_awareness: If True, create the embedding that
            ``forward_dst_readout`` adds to the seed nodes.
    """

    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        # List of node types to add shallow embeddings to input
        shallow_list: "List[NodeType] | None" = None,
        # ID awareness
        id_awareness: bool = False,
    ):
        super().__init__()
        # Avoid a shared mutable default argument.
        shallow_list = [] if shallow_list is None else shallow_list

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
        )
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
        )
        self.gnn = HeteroGAT(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            heads=8,
            num_layers=num_layers,
        )
        # `aggr` is not among the arguments passed to HeteroGAT above, so it
        # is applied to each HeteroConv layer directly; otherwise the value
        # given by the caller would be silently ignored.
        for conv in self.gnn.convs:
            conv.aggr = aggr
        self.head = MLP(
            channels,  # input size of the head
            out_channels=out_channels,  # 1 here, since we are doing regression
            norm=norm,
            num_layers=1,
        )
        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data.num_nodes_dict[node], channels)
                for node in shallow_list
            }
        )

        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every sub-module."""
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    def _add_time_and_shallow(self, x_dict, batch: HeteroData, seed_time):
        """Add the temporal encoding and the shallow embeddings to ``x_dict``."""
        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )
        # Temporal information from the HeteroTemporalEncoder.
        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        # Optional shallow embeddings, looked up by node id.
        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)
        return x_dict

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        """Predict for the seed nodes of ``entity_table`` in ``batch``.

        Returns:
            The head output for the first ``seed_time.size(0)`` nodes of
            ``entity_table``.
        """
        # Timestamps of the nodes we want predictions for (the seed nodes,
        # not their sampled neighbours).
        seed_time = batch[entity_table].seed_time
        # One embedding per node, for every node type.
        x_dict = self.encoder(batch.tf_dict)
        x_dict = self._add_time_and_shallow(x_dict, batch, seed_time)

        x_dict = self.gnn(
            x_dict,  # node features
            batch.edge_index_dict,
            # HeteroGAT.forward accepts and ignores these two extra arguments.
            batch.num_sampled_nodes_dict,
            batch.num_sampled_edges_dict,
        )

        return self.head(x_dict[entity_table][: seed_time.size(0)])

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        """Predict for every node of ``dst_table`` in ``batch``.

        Raises:
            RuntimeError: If the model was built with ``id_awareness=False``.
        """
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout"
            )
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)
        # Add ID-awareness to the root node
        x_dict[entity_table][: seed_time.size(0)] += self.id_awareness_emb.weight
        x_dict = self._add_time_and_shallow(x_dict, batch, seed_time)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        return self.head(x_dict[dst_table])


# Default model instance: 2 GNN layers, 128 hidden channels and a single
# output value, moved to the selected device.
model = Model(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=2,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)

We also need standard train/test loops

In [9]:
def train(model, optimizer) -> float:
    """Run one training epoch over ``loader_dict["train"]``.

    Args:
        model: Module called as ``model(batch, entity_table)``.
        optimizer: Optimizer over the parameters of ``model``.

    Returns:
        The sample-weighted mean of ``loss_fn`` over the epoch, or NaN when
        the loader yields no samples.
    """
    model.train()
    # Use the task's entity table for both the forward pass and the target
    # lookup, rather than relying on a variable left over from another cell.
    target_table = task.entity_table

    loss_accum = count_accum = 0
    for batch in tqdm(loader_dict["train"]):
        batch = batch.to(device)

        optimizer.zero_grad()
        pred = model(batch, target_table)
        # Flatten (N, 1) predictions so they line up with the (N,) targets.
        if pred.dim() == 2 and pred.size(1) == 1:
            pred = pred.view(-1)

        loss = loss_fn(pred.float(), batch[target_table].y.float())
        loss.backward()
        optimizer.step()

        # Weight by batch size so the final value is a per-sample mean.
        loss_accum += loss.item() * pred.size(0)
        count_accum += pred.size(0)

    if count_accum == 0:
        return float("nan")
    return loss_accum / count_accum


@torch.no_grad()
def test(model, loader: NeighborLoader) -> np.ndarray:
    """Collect the model's predictions over ``loader`` without gradients.

    Each batch is moved to ``device`` and passed to the model together with
    ``task.entity_table``. Predictions with a single column are flattened,
    and all batches are concatenated into one NumPy array.
    """
    model.eval()

    outputs = []
    for batch in loader:
        out = model(batch.to(device), task.entity_table)
        if out.size(1) == 1:
            out = out.view(-1)
        outputs.append(out.detach().cpu())

    stacked = torch.cat(outputs, dim=0)
    return stacked.numpy()

In [10]:
def rmse(true, pred):
    """Return the root mean squared error between ``true`` and ``pred``."""
    residual = true - pred
    mean_squared = np.mean(residual ** 2)
    return np.sqrt(mean_squared)

In [11]:
def custom_evaluate(pred: np.ndarray, target_table, metrics) -> dict:
    """Custom evaluation function to replace task.evaluate.

    Args:
        pred: Predictions, one per row of ``target_table``; shape ``(N,)`` or
            ``(N, 1)``.
        target_table: Table whose ``df`` holds the ``task.target_col`` column.
        metrics: Callables ``metric_fn(target, pred)``; results are keyed by
            each callable's ``__name__``.

    Returns:
        A dict mapping metric name to metric value.

    Raises:
        ValueError: If ``pred`` and the target column differ in length.
    """
    # Extract target values from the target table
    target = target_table.df[task.target_col].to_numpy()

    # An (N, 1) prediction would pass the length check below and then
    # broadcast against the (N,) target into an (N, N) matrix, giving
    # silently wrong metrics, so flatten it first.
    pred = np.asarray(pred)
    if pred.ndim == 2 and pred.shape[1] == 1:
        pred = pred.reshape(-1)

    # Check for length mismatch
    if len(pred) != len(target):
        raise ValueError(
            f"The length of pred and target must be the same (got "
            f"{len(pred)} and {len(target)}, respectively)."
        )

    # Calculate metrics
    results = {}
    for metric_fn in metrics:
        name = metric_fn.__name__
        if name == "rmse":
            # RMSE is computed inline instead of calling metric_fn.
            results[name] = np.sqrt(np.mean((target - pred) ** 2))
        else:
            results[name] = metric_fn(target, pred)

    return results

In [12]:
def training_function(model, optimizer, epochs):
    """Train for ``epochs`` epochs and restore the best validation checkpoint.

    After every epoch the model is evaluated on the validation split with
    ``custom_evaluate`` (used here instead of ``task.evaluate``). The weights
    with the best ``tune_metric`` are kept and loaded back at the end.
    Progress is printed every 10 epochs.

    Args:
        model: Model to train (modified in place).
        optimizer: Optimizer over the parameters of ``model``.
        epochs: Number of training epochs.

    Returns:
        The validation metrics of the model that is left loaded.
    """
    state_dict = None
    best_val_metric = -math.inf if higher_is_better else math.inf
    for epoch in range(1, epochs + 1):
        train_loss = train(model, optimizer)
        val_pred = test(model, loader_dict["val"])
        val_metrics = custom_evaluate(val_pred, val_table, task.metrics)
        if epoch % 10 == 0:
            print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

        current = val_metrics[tune_metric]
        if higher_is_better:
            improved = current > best_val_metric
        else:
            improved = current < best_val_metric
        if improved:
            best_val_metric = current
            state_dict = copy.deepcopy(model.state_dict())

    # No checkpoint exists when epochs == 0 or when no epoch ever improved on
    # the initial value (e.g. the metric was NaN); keep the current weights
    # then instead of failing inside load_state_dict(None).
    if state_dict is not None:
        model.load_state_dict(state_dict)
    val_pred = test(model, loader_dict["val"])
    val_metrics = custom_evaluate(val_pred, val_table, task.metrics)
    print(f"Best Val metrics for parameters {optimizer}, are: {val_metrics}")
    return val_metrics

## Cross validation cycle

In [13]:
# #cross validation cycle:
# #possible learning rates: [0.01, 0.001, 0.0001, 0.00001]
# #possible batch sizes: [64, 256, 512]
# #possible number of layers: [1, 2, 3]
# #possible weight decay: [0.0001, 0.001, 0.01]

# for lr in [0.01, 0.001, 0.0001, 0.00001]:#0.001
#     #for batch_size in [64, 256, 512]:
#         for num_layers in [1, 2, 3]:#1
#             #for weight_decay in [0.0001, 0.001, 0.01]:
#                 model = Model(
#                     data=data,
#                     col_stats_dict=col_stats_dict,
#                     num_layers=num_layers,
#                     channels=128,
#                     out_channels=1,
#                     aggr="sum",
#                     norm="batch_norm",
#                 ).to(device)
#                 print(f"Training with lr={lr}, num_layers={num_layers}")
#                 optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.001)
#                 training_function(model, optimizer, epochs=10) # Set epochs to a smaller number for testing

# Training

In [None]:
# Final model and optimizer.
model = Model(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=1,
    channels=128,
    out_channels=1,
    aggr="mean",
    norm="batch_norm",
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)
epochs = 3

# Train, keeping a copy of the weights with the best validation tune_metric.
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf
for epoch in range(1, epochs + 1):
    train_loss = train(model, optimizer)
    val_pred = test(model, loader_dict["val"])
    # custom_evaluate is used here instead of task.evaluate.
    val_metrics = custom_evaluate(val_pred, val_table, task.metrics)
    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
        not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(model.state_dict())

# Restore the best checkpoint. state_dict is still None if no epoch improved
# on the initial value, in which case the current weights are kept.
if state_dict is not None:
    model.load_state_dict(state_dict)
val_pred = test(model, loader_dict["val"])
val_metrics = custom_evaluate(val_pred, val_table, task.metrics)
print(f"Best Val metrics: {val_metrics}")

# Evaluate on the test split. The table is fetched once, with
# mask_input_cols=False (presumably so that the target column read by
# custom_evaluate is present; confirm against the RelBench documentation).
test_table = task.get_table("test", mask_input_cols=False)
test_pred = test(model, loader_dict["test"])
test_metrics = custom_evaluate(test_pred, test_table, task.metrics)
print(f"Best test metrics: {test_metrics}")

  0%|          | 0/15 [00:09<?, ?it/s]


KeyboardInterrupt: 

# Import a predefined model to use it

In [None]:
# model.load_state_dict(torch.load('best_model_GAT_head2.pth', map_location=torch.device('cpu')))

  model.load_state_dict(torch.load('best_model_GAT_head2.pth', map_location=torch.device('cpu')))


RuntimeError: Error(s) in loading state_dict for Model:
	Missing key(s) in state_dict: "encoder.encoders.qualifying.encoder.encoder_dict.categorical.offset", "encoder.encoders.qualifying.encoder.encoder_dict.categorical.emb.weight". 
	Unexpected key(s) in state_dict: "encoder.encoders.constructor_standings.encoder.encoder_dict.categorical.offset", "encoder.encoders.constructor_standings.encoder.encoder_dict.categorical.emb.weight". 
	size mismatch for encoder.encoders.constructor_standings.encoder.encoder_dict.numerical.weight: copying a param with shape torch.Size([2, 128]) from checkpoint, the shape in current model is torch.Size([3, 128]).
	size mismatch for encoder.encoders.constructor_standings.encoder.encoder_dict.numerical.bias: copying a param with shape torch.Size([2, 128]) from checkpoint, the shape in current model is torch.Size([3, 128]).
	size mismatch for encoder.encoders.constructor_standings.encoder.encoder_dict.numerical.mean: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([3]).
	size mismatch for encoder.encoders.constructor_standings.encoder.encoder_dict.numerical.std: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([3]).
	size mismatch for encoder.encoders.qualifying.encoder.encoder_dict.numerical.weight: copying a param with shape torch.Size([2, 128]) from checkpoint, the shape in current model is torch.Size([1, 128]).
	size mismatch for encoder.encoders.qualifying.encoder.encoder_dict.numerical.bias: copying a param with shape torch.Size([2, 128]) from checkpoint, the shape in current model is torch.Size([1, 128]).
	size mismatch for encoder.encoders.qualifying.encoder.encoder_dict.numerical.mean: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).
	size mismatch for encoder.encoders.qualifying.encoder.encoder_dict.numerical.std: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([1]).