<a href="https://colab.research.google.com/github/Andrea-1704/Pytorch_Geometric_tutorial/blob/main/train_model_baseline_f1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this notebook we try to implement a **pre-training** stage based on self-supervision.
The model should learn to recognize whether a node belongs to the real graph or to a corrupted (negative) version of it.
Therefore:

- If a node's embedding is similar to the global summary of the real graph → POSITIVE
- If a node comes from a "corrupted" (shuffled) graph → NEGATIVE

The basic idea is to take **x**, the original input node features, and compute **z = GNN(x, edge_index)**, the representations of the nodes that actually exist in the graph. We then compute **summary = READOUT(z)**, a global representation of the graph that captures its general topological information.

Next we perturb x to obtain **x_corrupted**, typically by shuffling the node features: we simply change the order of the feature rows, so that each node receives another node's features while the edges stay the same. As before, we compute **z_corrupted = GNN(x_corrupted, edge_index)**.

We then compute the scores **pos_logits = discriminator(z, summary)** and **neg_logits = discriminator(z_corrupted, summary)**. The goal is for the sigmoid of the first to be close to one and the sigmoid of the second to be close to zero.

The task is therefore a binary classification: samples derived from the real structure (z) should be classified as positive, and samples derived from the perturbed structure (z_corrupted) as negative.
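The steps above can be sketched end to end on a toy graph. This is a minimal, self-contained illustration in plain PyTorch and is not the notebook's model. The ring graph, the one-layer mean-aggregation encoder `encode` and the sizes are assumptions made for this example. The readout uses a sigmoid of the mean embedding, as in the original DGI formulation (the notebook later uses tanh).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy homogeneous graph (assumption): 6 nodes, 8 input features, undirected ring.
num_nodes, in_dim, hidden = 6, 8, 16
x = torch.randn(num_nodes, in_dim)
src = torch.arange(num_nodes)
dst = (src + 1) % num_nodes
edge_index = torch.cat([torch.stack([src, dst]), torch.stack([dst, src])], dim=1)

# Hypothetical one-layer encoder: mean of neighbour features, then a linear map.
lin = torch.nn.Linear(in_dim, hidden)

def encode(feats):
    agg = torch.zeros_like(feats).index_add_(0, edge_index[1], feats[edge_index[0]])
    deg = torch.bincount(edge_index[1], minlength=num_nodes).clamp(min=1).unsqueeze(-1)
    return torch.relu(lin(agg / deg))

# Bilinear discriminator weight W (hidden x hidden).
W = torch.nn.Parameter(torch.empty(hidden, hidden))
torch.nn.init.xavier_uniform_(W)

z = encode(x)                                # positive embeddings [N, hidden]
summary = torch.sigmoid(z.mean(dim=0))       # READOUT -> [hidden]
x_corrupted = x[torch.randperm(num_nodes)]   # shuffle feature rows, keep the edges
z_corrupted = encode(x_corrupted)            # negative embeddings [N, hidden]

pos_logits = z @ W @ summary                 # [N], pushed towards large positive values
neg_logits = z_corrupted @ W @ summary       # [N], pushed towards large negative values

# Binary classification: real nodes -> label 1, corrupted nodes -> label 0.
logits = torch.cat([pos_logits, neg_logits])
labels = torch.cat([torch.ones(num_nodes), torch.zeros(num_nodes)])
loss = F.binary_cross_entropy_with_logits(logits, labels)
print(loss.item())
```

Minimizing `loss` updates both the encoder (`lin`) and `W`. This is why a better discriminator also pushes the encoder towards more informative embeddings.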

In [144]:
# !pip install torch-scatter -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# !pip install torch-sparse -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# !pip install torch-cluster -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# !pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# !pip install torch-geometric==2.6.0 -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
# !pip install pyg-lib -f https://data.pyg.org/whl/torch-2.4.0+cpu.html

# !pip install pytorch_frame[full]==1.2.2
# !pip install relbench[full]==1.0.0
# !pip uninstall -y pyg_lib torch  # Uninstall current versions
# !pip install torch==2.6.0  # Reinstall your desired PyTorch version
# !pip install --no-cache-dir git+https://github.com/pyg-team/pyg-lib.git # Install pyg-lib; --no-cache-dir ensures a fresh install

Updated installation commands for running on Colab (PyTorch 2.6.0, CUDA 11.8):

In [145]:
# !pip install torch==2.6.0+cu118 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# !pip install pyg-lib -f https://data.pyg.org/whl/torch-2.6.0+cu118.html
# !pip install torch-scatter -f https://data.pyg.org/whl/torch-2.6.0+cu118.html
# !pip install torch-sparse -f https://data.pyg.org/whl/torch-2.6.0+cu118.html
# !pip install torch-cluster -f https://data.pyg.org/whl/torch-2.6.0+cu118.html
# !pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.6.0+cu118.html
# !pip install torch-geometric==2.6.0 -f https://data.pyg.org/whl/torch-2.6.0+cu118.html

# !pip install pytorch_frame[full]==1.2.2
# !pip install relbench[full]==1.0.0

In [146]:
import os
import copy
import math
import requests
from io import StringIO
from collections import defaultdict
from typing import Any, Dict, List

import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import torch
from torch import nn, Tensor
from torch.nn import BCEWithLogitsLoss, L1Loss, Embedding, ModuleDict
# for the learning-rate scheduler
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
from sklearn.metrics import mean_squared_error

import torch_geometric
import torch_frame
import pyg_lib
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import MLP, Linear
from torch_geometric.seed import seed_everything
from torch_geometric.typing import NodeType
from torch_geometric.utils import degree, softmax

import relbench
from relbench.datasets import get_dataset
from relbench.tasks import get_task
from relbench.modeling.utils import get_stype_proposal
from relbench.modeling.graph import get_node_train_table_input, make_pkey_fkey_graph
from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder

# Dataset and task creation

In [147]:
dataset = get_dataset("rel-f1", download=True)
task = get_task("rel-f1", "driver-position", download=True)

train_table = task.get_table("train")
val_table = task.get_table("val")
test_table = task.get_table("test")

out_channels = 1
# One output channel, because we predict a single value.
loss_fn = L1Loss()
# L1Loss is the mean absolute error (MAE), a standard choice for regression tasks.
tune_metric = "mae"
higher_is_better = False

seed_everything(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
root_dir = "./data"

db = dataset.get_db()
col_to_stype_dict = get_stype_proposal(db)
# Infer the semantic type (stype) of every column in the database.

cuda


# Utility functions

In [None]:
def plot_validation_metrics(metric_histories, model_names=None, metric_name="MAE", informationsTitle=""):
    """
    Plotta l'andamento del metric_name per più modelli nel tempo.

    Args:
        metric_histories (list of lists): Lista di liste, ognuna rappresenta i valori di metriche per un modello.
        model_names (list of str): Nomi dei modelli (opzionale).
        metric_name (str): Nome della metrica da visualizzare.
        informationsTitle (str): info aggiungitive da mettere nel titolo (conf generale dei parametri ecc).
    """
    plt.figure(figsize=(9, 5))

    if model_names is None:
        model_names = [f"Model {i+1}" for i in range(len(metric_histories))]

    for metrics, name in zip(metric_histories, model_names):
        plt.plot(metrics, marker='o', label=f'{name} {metric_name}')

    plt.xlabel("Epoch")
    plt.ylabel(metric_name)
    plt.title(f"{metric_name} over Epochs for Multiple Models {informationsTitle}")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()


In [None]:
def custom_evaluate(pred: np.ndarray, target_table, metrics) -> dict:
    """Custom evaluation function to replace task.evaluate."""

    # Extract target values from the target table
    target = target_table.df[task.target_col].to_numpy()

    # Check for length mismatch
    if len(pred) != len(target):
        raise ValueError(
            f"The length of pred and target must be the same (got "
            f"{len(pred)} and {len(target)}, respectively)."
        )

    # Calculate metrics
    results = {}
    for metric_fn in metrics:
        if metric_fn.__name__ == "rmse":  # Handle RMSE specifically
            results["rmse"] = np.sqrt(np.mean((target - pred)**2))
        else:  # Handle other metrics (if any)
            results[metric_fn.__name__] = metric_fn(target, pred)

    return results

In [None]:
def rmse(true, pred):
    """Calculate the Root Mean Squared Error (RMSE)."""
    return np.sqrt(np.mean((true - pred)**2)) # Calculate RMSE manually

# Embedder

In [148]:
# import torch
# from typing import List, Optional
# from sentence_transformers import SentenceTransformer
# from torch import Tensor


# class GloveTextEmbedding:
#     def __init__(self, device: Optional[torch.device
#                                        ] = None):
#         self.model = SentenceTransformer(
#             "sentence-transformers/average_word_embeddings_glove.6B.300d",
#             device=device,
#         )

#     def __call__(self, sentences: List[str]) -> Tensor:
#         return torch.from_numpy(self.model.encode(sentences))


class LightweightGloveEmbedder:
    def __init__(self, device=None):
        self.device = device
        self.embeddings = defaultdict(lambda: np.zeros(300))
        self._load_embeddings()

    def _load_embeddings(self):
        try:
            # Download the plain-text file directly (no zip archive to extract)
            url = "https://huggingface.co/stanfordnlp/glove/resolve/main/glove.6B.300d.txt"
            response = requests.get(url)
            response.raise_for_status()

            for line in StringIO(response.text):
                parts = line.split()
                word = parts[0]
                vector = np.array(parts[1:], dtype=np.float32)
                self.embeddings[word] = vector
        except Exception as e:
            print(f"Warning: Couldn't load GloVe embeddings ({str(e)}). Using zero vectors.")

    def __call__(self, sentences):
        results = []
        for text in sentences:
            words = text.lower().split()
            vectors = [self.embeddings[w] for w in words if w in self.embeddings]
            if vectors:
                avg_vector = np.mean(vectors, axis=0)
            else:
                avg_vector = np.zeros(300)
            results.append(avg_vector)

        tensor = torch.tensor(np.array(results), dtype=torch.float32)
        return tensor.to(self.device) if self.device else tensor

In [149]:
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=LightweightGloveEmbedder(device=device), batch_size=256
)

data, col_stats_dict = make_pkey_fkey_graph(
    # If this fails, try: !pip install --upgrade torch torchvision transformers
    db,
    col_to_stype_dict=col_to_stype_dict,  # specified column types
    text_embedder_cfg=text_embedder_cfg,  # our chosen text encoder
    cache_dir=os.path.join(
        root_dir, "rel-f1_materialized_cache"
    ),  # store the materialized graph for convenience
)  # builds the heterogeneous graph in the format RelBench expects





In [150]:
loader_dict = {}

for split, table in [
    ("train", train_table),
    ("val", val_table),
    ("test", test_table),
]:
    table_input = get_node_train_table_input(
        table=table,
        task=task,
    )
    # table_input exposes, among other fields:
    #   nodes:     the input (seed) nodes, as (entity node type, node indices)
    #   time:      the timestamp of each input node
    #   transform: the transformation applied to each sampled batch
    entity_table = table_input.nodes[0]  # node type of the seed (entity) nodes
    # Populate loader_dict with one loader per split: "train", "val", and "test".
    loader_dict[split] = NeighborLoader(
        data,
        num_neighbors=[
            128 for i in range(2)
        ],  # sample 2-hop subgraphs, up to 128 neighbours per node at each hop
        time_attr="time",
        input_nodes=table_input.nodes,
        input_time=table_input.time,
        transform=table_input.transform,
        batch_size=512,
        temporal_strategy="uniform",
        shuffle=split == "train",
        num_workers=0,
        persistent_workers=False,
    )  # temporal neighbour loader for this split

Note that each `batch` is a `HeteroData` object that behaves like a dictionary whose keys are node types or edge types. A node-type entry holds the data of that node type: its features, stored as a TensorFrame in `tf`, and the sampled node IDs in `n_id`. An edge-type entry holds its `edge_index`, a tensor of shape 2 × number of edges. The first row lists the source nodes, and the second row lists the destination nodes they are connected to by that relation.
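The toy example below illustrates the `edge_index` layout. The node and edge type names and the sizes are hypothetical, and the reverse edge type is built by hand. Each relation maps to a `[2, num_edges]` tensor, and counting the indices in a row gives per-type degrees.

```python
import torch

# Toy stand-in for batch.edge_index_dict (hypothetical node/edge type names).
# Row 0 = source indices, row 1 = destination indices, both local to the sampled subgraph.
num_nodes = {"drivers": 3, "results": 4}
edge_index_dict = {
    ("results", "f2p_driverId", "drivers"): torch.tensor([[0, 1, 2, 3],
                                                           [0, 0, 1, 2]]),
}

# Add a reverse edge type by swapping the two rows (naming chosen for this example).
for (src_type, rel, dst_type), ei in list(edge_index_dict.items()):
    edge_index_dict[(dst_type, "rev_" + rel, src_type)] = ei.flip(0)

# In-degree of every node, summed over all edge types that point to its type.
in_degree = {nt: torch.zeros(n, dtype=torch.long) for nt, n in num_nodes.items()}
for (_, _, dst_type), ei in edge_index_dict.items():
    in_degree[dst_type] += torch.bincount(ei[1], minlength=num_nodes[dst_type])

print(in_degree)
```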

# Pre-training

The discriminator scores each (node embedding, summary) pair with a bilinear product, i.e. a dot product between the projected embedding `z W` and the `summary` (the global embedding of the original graph). From this score it predicts whether the node is real or corrupted.

The goal of this pre-training is to build a **summary** of the whole graph, computed here from the average of its node embeddings (more precisely, the mean of the per-node-type means, passed through tanh). We then maximize the similarity between the summary and the embeddings of the nodes actually present in the graph (**positive nodes**). At the same time we minimize it for "corrupted" nodes, i.e. nodes deliberately modified (**negative nodes**). The hope is that, before the actual training starts, the model parameters already encode relevant information about the graph, so that they can already recognize some structure in the node embeddings.

In other words, we hope that the initial model weights are not "random" but already carry some semantics about the nodes. The node embeddings should already contain relevant information about the graph structure, the "logic" of the nodes, and the important relations between them.

Discriminator class. It is not used for the downstream task, but improving it drives the GNN encoder to improve.
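A note on the discriminator loss: writing it as `-log(sigmoid(s) + 1e-15)` caps the loss and flattens its gradient for confidently wrong scores, because the epsilon dominates once the sigmoid underflows. The identities `-log(sigmoid(s)) = softplus(-s)` and `-log(1 - sigmoid(s)) = softplus(s)` avoid this problem. They are also exactly what `binary_cross_entropy_with_logits` computes for targets 1 and 0. The sketch below compares the two forms on a few arbitrary scores.

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([-50.0, -5.0, 0.0, 5.0, 50.0])

# Epsilon-clamped form: once sigmoid(s) << 1e-15 the loss saturates near -log(1e-15) ~ 34.5.
naive_pos = -torch.log(torch.sigmoid(scores) + 1e-15)

# Stable forms for positive (target 1) and negative (target 0) samples.
stable_pos = F.softplus(-scores)
stable_neg = F.softplus(scores)

# The same values through BCE-with-logits, per element.
bce_pos = F.binary_cross_entropy_with_logits(scores, torch.ones_like(scores), reduction="none")
bce_neg = F.binary_cross_entropy_with_logits(scores, torch.zeros_like(scores), reduction="none")

print(naive_pos[0].item(), stable_pos[0].item())
```

For moderate scores the two forms agree. They only diverge where the clamped version stops providing a useful gradient.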

In [None]:
import torch.nn.functional as F


class DGIHead(torch.nn.Module):
    """Bilinear DGI discriminator; the loss is summed over node types."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, x_dict, corrupted_x_dict, summary):
        summary_proj = summary.t()                    # [hidden, 1], transposed graph summary
        loss = summary.new_zeros(())
        for node_type in x_dict:
            z = x_dict[node_type]                     # positive embeddings [N, hidden]
            z_corrupt = corrupted_x_dict[node_type]   # negative embeddings [N, hidden]
            if z.shape[0] == 0:
                continue

            # Bilinear similarity z W s^T; view(-1) keeps shape [N] even when N == 1
            pos_score = torch.matmul(torch.matmul(z, self.weight), summary_proj).view(-1)
            neg_score = torch.matmul(torch.matmul(z_corrupt, self.weight), summary_proj).view(-1)

            # Numerically stable form of
            # -log(sigmoid(pos_score)).mean() - log(1 - sigmoid(neg_score)).mean()
            pos_loss = F.binary_cross_entropy_with_logits(pos_score, torch.ones_like(pos_score))
            neg_loss = F.binary_cross_entropy_with_logits(neg_score, torch.zeros_like(neg_score))

            loss = loss + pos_loss + neg_loss

        return loss


To understand `forward`, recall that `x_dict` is the original dictionary, containing the actual embeddings for each node type. `corrupted_x_dict` has the same structure but contains the negative embeddings.

Below we define a function that computes a summary of the graph (one vector per node type).

In [None]:
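@torch.no_grad()
def compute_summary(x_dict):
    # Per-node-type summary: tanh of the mean embedding of that type.
    # Under no_grad no gradient flows through it; the DGI loop instead uses
    # Model.pretrain_dgi_forward, which computes its own differentiable summary.
    summary_dict = {}
    for k, z in x_dict.items():
        # If there are no nodes of this type, use a zero vector
        if z.numel() == 0:
            summary_dict[k] = torch.zeros(z.shape[1], device=z.device)
        else:
            summary_dict[k] = torch.tanh(z.mean(dim=0))
    return summary_dict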
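How the graph summary is aggregated matters. `Model.pretrain_dgi_forward` averages the per-node-type means, so every node type has the same weight regardless of how many nodes it has. A plain mean over all nodes is instead dominated by the largest tables. The toy comparison below uses two hypothetical node types of very different sizes.

```python
import torch

# Toy embeddings for two hypothetical node types.
z_dict = {
    "races": torch.full((1, 2), 4.0),    # 1 node
    "results": torch.zeros((99, 2)),     # 99 nodes
}

# Mean of per-type means (what pretrain_dgi_forward computes before tanh).
per_type = torch.cat([z.mean(dim=0, keepdim=True) for z in z_dict.values() if z.size(0) > 0])
summary_per_type = per_type.mean(dim=0)

# Plain mean over every node in the batch.
summary_all_nodes = torch.cat(list(z_dict.values())).mean(dim=0)

print(summary_per_type, summary_all_nodes)
```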


Below is a function that takes a heterogeneous graph batch and returns a copy in which the node features have been corrupted, i.e. randomly shuffled among the nodes of the same type. Each node receives the features of another node of the same type; for example, a row of the races table receives the features of a different race. The edges are left unchanged.

We would like to shuffle x, the node features, directly. In the batch, however, the features are not a plain tensor. Each node type has a TensorFrame `tf`, which the encoder reads through `batch.tf_dict`, and an `n_id` vector that maps each sampled node to its global ID. Permuting `n_id` alone only affects the optional shallow embeddings (`embedding_dict`). To actually corrupt the encoder input, we therefore apply the same random permutation to the rows of `tf` and to `n_id`.
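The corruption itself is just a row permutation applied consistently to everything that describes a node. In the plain-tensor sketch below, `feats` is a hypothetical stand-in for the rows of a node type's TensorFrame, and `n_id` for its global IDs.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for one node type: feature rows and global node ids.
feats = torch.arange(12, dtype=torch.float32).view(4, 3)
n_id = torch.tensor([10, 11, 12, 13])

perm = torch.randperm(feats.size(0))
corrupted_feats = feats[perm]   # node i now carries the features of node perm[i]
corrupted_n_id = n_id[perm]     # ids stay aligned with the rows they came from

# Edges are not touched: the same edge_index now connects nodes with shuffled features.
print(perm, corrupted_feats)
```

Because this is a permutation, the set of feature rows is unchanged. Only the assignment of rows to positions in the graph differs, which is what the discriminator has to detect.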

In [None]:
def corrupt_features(batch: HeteroData) -> HeteroData:
    # Deep-copy the batch so the original graph is not modified
    corrupted = copy.deepcopy(batch)
    for node_type in corrupted.node_types:
        store = corrupted[node_type]
        if "n_id" not in store:
            continue
        n_id = store.n_id
        # One random permutation per node type; the edges stay untouched
        perm = torch.randperm(n_id.size(0), device=n_id.device)
        # Shuffle the encoder input (the rows of the TensorFrame) ...
        if "tf" in store:
            store.tf = store.tf[perm]
        # ... and keep the ids (used by shallow embeddings) aligned with the rows
        store.n_id = n_id[perm]
    return corrupted


Function that runs the pre-training for a few epochs.

In [None]:
def train_dgi(model, discriminator, optimizer, loader, device, entity_table, epochs=20):
    model.train()
    discriminator.train()

    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in pbar:
            batch = batch.to(device)

            # Corrupt the batch
            corrupted_batch = corrupt_features(batch)
            corrupted_batch = corrupted_batch.to(device)

            # Get positive and negative embeddings + summary
            z_dict, summary = model.pretrain_dgi_forward(batch, entity_table=entity_table)
            corrupted_z_dict, _ = model.pretrain_dgi_forward(corrupted_batch, entity_table=entity_table)
            if z_dict is None or corrupted_z_dict is None:
                continue  # no node embeddings in this batch, nothing to train on

            # Compute DGI loss
            loss = discriminator(z_dict, corrupted_z_dict, summary)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1
            pbar.set_postfix({"DGI Loss": loss.item()})

        print(f"Epoch {epoch+1}, Average DGI Loss: {epoch_loss / max(1, num_batches):.4f}")
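`train_dgi` steps a single optimizer, so that optimizer has to include the discriminator's parameters as well as the model's. Otherwise the discriminator weight `W` never leaves its initialisation. The minimal sketch below uses hypothetical stand-in modules: `nn.Linear` plays the GNN and `nn.Bilinear` replaces `DGIHead`, since it also computes a bilinear form z^T A s (plus a bias).

```python
import itertools
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins: a linear "encoder" and a bilinear discriminator.
encoder = torch.nn.Linear(8, 16)
discriminator = torch.nn.Bilinear(16, 16, 1)

# A single optimizer over both parameter sets.
optimizer = torch.optim.Adam(
    itertools.chain(encoder.parameters(), discriminator.parameters()), lr=1e-3
)
n_params = sum(len(group["params"]) for group in optimizer.param_groups)

# One toy step on positive pairs only, to check that the discriminator weight moves.
z = encoder(torch.randn(4, 8))                               # [4, 16]
summary = torch.sigmoid(z.mean(dim=0, keepdim=True))         # [1, 16]
logits = discriminator(z, summary.expand_as(z)).squeeze(-1)  # [4]
loss = F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))

before = discriminator.weight.detach().clone()
optimizer.zero_grad()
loss.backward()
optimizer.step()
changed = not torch.equal(before, discriminator.weight.detach())
print(n_params, changed)
```

In this notebook, that corresponds to passing both `model.parameters()` and the `DGIHead` parameters to the optimizer given to `train_dgi`.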


## Graphormer

In [154]:
# Module-level cache of the spatial bias.
# NOTE: the cache is filled on the first call (first batch, first layer) and reused for every
# later batch and layer, even though each sampled subgraph has different local node indices and
# sizes; the bias (and the node offsets) are therefore exact only for that first batch. Reset both
# caches to None to force a recomputation, at the cost of an all-pairs shortest-path run per call.
_spatial_bias_cache = None
_node_offset_cache = None
from collections import defaultdict


def compute_spatial_bias(edge_index_dict, x_dict):
    global _spatial_bias_cache, _node_offset_cache
    if _spatial_bias_cache is not None:
        return _spatial_bias_cache, _node_offset_cache
    # Build a directed graph with NetworkX
    G = nx.DiGraph()

    node_offset = {}
    curr_offset = 0

    # Add the nodes with a per-type offset so that global indices are unique
    for node_type, x in x_dict.items():
        node_offset[node_type] = curr_offset
        for i in range(x.size(0)):
            G.add_node(curr_offset + i, type=node_type)
        curr_offset += x.size(0)

    # Add the edges, shifted by the same offsets
    for (src_type, _, dst_type), edge_index in edge_index_dict.items():
        src_offset = node_offset[src_type]
        dst_offset = node_offset[dst_type]
        src, dst = edge_index
        for s, d in zip(src.tolist(), dst.tolist()):
            G.add_edge(src_offset + s, dst_offset + d)

    # Unreachable pairs keep the default value -1
    spatial_bias = defaultdict(lambda: -1)

    for node in G.nodes():
        # Unweighted graph: these are hop counts (same as a BFS)
        lengths = nx.single_source_dijkstra_path_length(G, node)
        for target, dist in lengths.items():
            spatial_bias[(node, target)] = dist

    _spatial_bias_cache = spatial_bias
    _node_offset_cache = node_offset

    return spatial_bias, node_offset
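Because the graph is unweighted, `nx.single_source_dijkstra_path_length` returns hop counts, i.e. the same distances a breadth-first search gives. Below is a dependency-free sketch of the same spatial encoding on a tiny directed graph, with unreachable pairs read back as -1. The helper name `bfs_hop_distances` is hypothetical.

```python
from collections import deque


def bfs_hop_distances(num_nodes, edges):
    """All-pairs hop distances on a directed graph; unreachable pairs are simply absent."""
    adj = [[] for _ in range(num_nodes)]
    for s, d in edges:
        adj[s].append(d)
    dist = {}
    for source in range(num_nodes):
        seen = {source: 0}
        queue = deque([source])
        while queue:
            u = queue.popleft()
            for v in adj[u]:
                if v not in seen:
                    seen[v] = seen[u] + 1
                    queue.append(v)
        for target, hops in seen.items():
            dist[(source, target)] = hops
    return dist


# Directed chain 0 -> 1 -> 2, plus an isolated node 3.
dist = bfs_hop_distances(4, [(0, 1), (1, 2)])
print(dist.get((0, 2), -1), dist.get((2, 0), -1), dist.get((0, 3), -1))
```

Running one BFS per node costs O(N·(N + E)) overall, which is why recomputing the bias for every sampled batch is expensive. Note also that, for an edge s → d, the attention layer looks up the distance from d back to s. That distance is finite only if the graph contains a path in that direction, for example a reverse edge.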

In [155]:
class HeteroGraphormerLayerComplete(nn.Module):
    def __init__(self, channels, edge_types, device, num_heads=4, dropout=0.1):
        super().__init__()
        self.device = device
        self.num_heads = num_heads
        self.channels = channels
        self.head_dim = channels // num_heads

        assert self.channels % num_heads == 0, "channels must be divisible by num_heads"

        self.q_lin = Linear(channels, channels)
        self.k_lin = Linear(channels, channels)
        self.v_lin = Linear(channels, channels)
        self.out_lin = Linear(channels, channels)

        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(channels)

        # Register one learnable scalar bias per edge type
        self.edge_type_bias = nn.ParameterDict({
            "__".join(edge_type): nn.Parameter(torch.randn(1))
            for edge_type in edge_types
        })

    def compute_total_degrees(self, x_dict, edge_index_dict):
        # Total (in + out) degree of every node, summed over all edge types.
        # Start from zeros for every node type, so types with only incoming, only outgoing,
        # or no edges in the batch still get a tensor with one entry per node.
        device = self.device
        in_deg = {nt: torch.zeros(x.size(0), device=device) for nt, x in x_dict.items()}
        out_deg = {nt: torch.zeros(x.size(0), device=device) for nt, x in x_dict.items()}
        for (src_type, _, dst_type), edge_index in edge_index_dict.items():
            src, dst = edge_index[0], edge_index[1]
            out_deg[src_type] += degree(src, num_nodes=x_dict[src_type].size(0))
            in_deg[dst_type] += degree(dst, num_nodes=x_dict[dst_type].size(0))

        total_deg = {
            node_type: in_deg[node_type] + out_deg[node_type]
            for node_type in x_dict
        }

        return total_deg

    def forward(self, x_dict, edge_index_dict):
        #print(edge_index_dict)
        self.spatial_bias, self.node_offset = compute_spatial_bias(edge_index_dict, x_dict)

        out_dict = {k: torch.zeros_like(v) for k, v in x_dict.items()}
        for edge_type, edge_index in edge_index_dict.items():
            src_type, _, dst_type = edge_type
            x_src, x_dst = x_dict[src_type], x_dict[dst_type]

            #src, dst = edge_index
            src = edge_index[0]
            dst = edge_index[1]

            Q = self.q_lin(x_dst).view(-1, self.num_heads, self.head_dim)
            K = self.k_lin(x_src).view(-1, self.num_heads, self.head_dim)
            V = self.v_lin(x_src).view(-1, self.num_heads, self.head_dim)

            attn_scores = (Q[dst] * K[src]).sum(dim=-1) / self.head_dim**0.5
            src_offset = self.node_offset[src_type]
            dst_offset = self.node_offset[dst_type]

            spatial_bias_vals = []
            for s, d in zip(src.tolist(), dst.tolist()):
                global_s = src_offset + s
                global_d = dst_offset + d
                dist = self.spatial_bias.get((global_d, global_s), -1.0)
                spatial_bias_vals.append(dist)

            spatial_bias_tensor = torch.tensor(spatial_bias_vals, dtype=torch.float, device=self.device)
            attn_scores = attn_scores + spatial_bias_tensor.unsqueeze(-1)  # broadcast over heads


            bias_name = "__".join(edge_type)
            attn_scores = attn_scores + self.edge_type_bias[bias_name]

            attn_weights = softmax(attn_scores, dst)
            attn_weights = self.dropout(attn_weights)

            out = V[src] * attn_weights.unsqueeze(-1)
            out = out.view(-1, self.channels)

            out_dict[dst_type].index_add_(0, dst, out)

        # Degree centrality
        total_deg = self.compute_total_degrees(x_dict, edge_index_dict)

        for node_type in out_dict:
            # Note: the same scalar is added to every channel of a node, so the LayerNorm below
            # (which subtracts the per-node mean over channels) cancels this term exactly.
            degree_embed = total_deg[node_type].view(-1, 1)           # column vector [N, 1]
            degree_embed = degree_embed.expand(-1, self.channels)     # expand along the channel dimension
            out_dict[node_type] = out_dict[node_type] + degree_embed


        for node_type in out_dict:
            out_dict[node_type] = self.norm(out_dict[node_type] + x_dict[node_type])

        return out_dict
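In `HeteroGraphormerLayerComplete` the degree term is added identically to every channel of a node and then passed through `LayerNorm`. Because `LayerNorm` subtracts the per-node mean over channels, the term cancels (up to floating-point error). Graphormer's centrality encoding instead adds learnable embeddings indexed by node degree. The sketch below checks the cancellation on random data and shows a learnable variant. `max_degree` is an arbitrary clamp chosen here, and a single total-degree embedding is a simplification of Graphormer's separate in-/out-degree embeddings.

```python
import torch
from torch import nn

torch.manual_seed(0)
channels, max_degree = 8, 64

norm = nn.LayerNorm(channels)
h = torch.randn(5, channels)                      # toy node states
deg = torch.tensor([0.0, 1.0, 3.0, 10.0, 200.0])  # toy total degrees

# Raw degree added to every channel = a constant shift per row ...
shifted = h + deg.view(-1, 1).expand(-1, channels)
# ... which LayerNorm removes when it subtracts the row mean.
print(torch.allclose(norm(shifted), norm(h), atol=1e-3))

# Learnable alternative (sketch): an embedding looked up by the clamped degree,
# so each degree value contributes a different vector across channels.
degree_emb = nn.Embedding(max_degree + 1, channels)
centrality = degree_emb(deg.long().clamp(max=max_degree))  # [5, channels]
out = norm(h + centrality)
```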


In [156]:
class HeteroGraphormer(torch.nn.Module):
    def __init__(self, node_types, edge_types, channels, num_layers=2):
        super().__init__()
        self.layers = torch.nn.ModuleList([
            HeteroGraphormerLayerComplete(channels, edge_types, device) for _ in range(num_layers)
        ])

    def forward(self, x_dict, edge_index_dict, *args, **kwargs):
        for layer in self.layers:
            x_dict = layer(x_dict, edge_index_dict)
        return x_dict

    def reset_parameters(self):
        for layer in self.layers:
            if hasattr(layer, "reset_parameters"):
                layer.reset_parameters()

# Model

In [None]:
class Model(torch.nn.Module):

    def __init__(
        self,
        data: HeteroData,  # "data" is the graph we created with make_pkey_fkey_graph
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        # List of node types to add shallow embeddings to input
        shallow_list: List[NodeType] = [],
        # ID awareness
        id_awareness: bool = False,
    ):
        super().__init__()

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
        )
        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
        )
        self.gnn = HeteroGraphormer(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            num_layers=num_layers,
        )
        self.head = MLP(
            channels,  # input dimension; out_channels is 1 because we do regression
            out_channels=out_channels,
            norm=norm,
            num_layers=1,
        )
        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data.num_nodes_dict[node], channels)
                for node in shallow_list
            }
        )

        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.reset_parameters()

    def reset_parameters(self):
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()
            


    def pretrain_dgi_forward(self, batch: HeteroData, entity_table: NodeType):
        # Encode the raw table features into x_dict
        x_dict = self.encoder(batch.tf_dict)

        if self.id_awareness_emb is not None:
            for node_type in x_dict:
                x_dict[node_type] = x_dict[node_type] + self.id_awareness_emb.weight

        # Add the temporal encoding, if available
        if hasattr(batch[entity_table], "seed_time"):
            rel_time_dict = self.temporal_encoder(
                batch[entity_table].seed_time,
                batch.time_dict,
                batch.batch_dict
            )
            for node_type, rel_time in rel_time_dict.items():
                x_dict[node_type] = x_dict[node_type] + rel_time

        # Add the shallow embeddings for the node types that enable them
        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        # Run the GNN and get the updated embeddings (z_dict)
        z_dict = self.gnn(x_dict, batch.edge_index_dict)

        # Skip empty node types so they do not break the summary
        valid_z_list = []
        for node_type, z in z_dict.items():
            if z.size(0) != 0:
                valid_z_list.append(z.mean(dim=0, keepdim=True))

        # If no node type has embeddings, return None
        if len(valid_z_list) == 0:
            return None, None

        # Summary: mean of the per-type means, squashed with tanh
        summary = torch.cat(valid_z_list, dim=0).mean(dim=0, keepdim=True)
        summary = torch.tanh(summary)

        return z_dict, summary



    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        seed_time = batch[entity_table].seed_time
        # Timestamps of the seed nodes, i.e. the nodes we make predictions for
        # (not their sampled neighbours).
        x_dict = self.encoder(batch.tf_dict)
        # Dictionary with an initial embedding for every node of every type.

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )
        # Add relative temporal information with the HeteroTemporalEncoder.
        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        # Add the shallow (per-node) embeddings, if any.
        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        # for edge_type, edge_index in batch.edge_index_dict.items():
        #     print("model edge_tipe: ", edge_type)
        #     print("model edge_index: ", edge_index)
        #print("model x_dict : ", x_dict['constructors'])

        x_dict = self.gnn(
            x_dict,#feature of nodes
            batch.edge_index_dict,
            batch.num_sampled_nodes_dict,
            batch.num_sampled_edges_dict,
        )#apply the gnn

        return self.head(x_dict[entity_table][: seed_time.size(0)])#final prediction

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout"
            )
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)
        # Add ID-awareness to the root node
        x_dict[entity_table][: seed_time.size(0)] += self.id_awareness_emb.weight

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        return self.head(x_dict[dst_table])

### Scheduler

In [158]:
def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps, num_cycles=0.5):
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2 * progress)))

    return LambdaLR(optimizer, lr_lambda)
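To see the shape of the schedule, the multiplier that `LambdaLR` applies to the base learning rate can be evaluated directly. The sketch copies the formula of `lr_lambda` into a standalone function (`cosine_warmup_factor` is a hypothetical name) and prints it at a few steps, for a 100-step warmup over 1000 total steps.

```python
import math


def cosine_warmup_factor(step, warmup_steps, total_steps, num_cycles=0.5):
    # Same formula as lr_lambda inside get_cosine_schedule_with_warmup.
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2 * progress)))


warmup, total = 100, 1000
for step in [0, 50, 100, 550, 1000]:
    print(step, round(cosine_warmup_factor(step, warmup, total), 4))
```

The factor ramps linearly from 0 to 1 during warmup, then follows half a cosine down to 0. `train` calls `scheduler.step()` after every batch, so `total_steps` should count batches, roughly `epochs * len(loader_dict["train"])`, not epochs.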


In [159]:
def train(model, optimizer, scheduler) -> float:
    model.train()

    loss_accum = count_accum = 0
    for batch in tqdm(loader_dict["train"]):
        batch = batch.to(device)

        optimizer.zero_grad()
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred

        loss = loss_fn(pred.float(), batch[task.entity_table].y.float())
        loss.backward()
        optimizer.step()
        scheduler.step()

        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    return loss_accum / count_accum


@torch.no_grad()
def test(model, loader: NeighborLoader) -> np.ndarray:
    model.eval()

    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0).numpy()

In [162]:
def training_function(model, optimizer, epochs, scheduler=None):
    # train() steps a scheduler every batch: build the same warmup + cosine schedule used below if none is passed
    if scheduler is None:
        total_steps = epochs * len(loader_dict["train"])
        scheduler = get_cosine_schedule_with_warmup(optimizer, int(0.1 * total_steps), total_steps)
    state_dict = None
    best_val_metric = -math.inf if higher_is_better else math.inf
    for epoch in range(1, epochs + 1):
        train_loss = train(model, optimizer, scheduler)
        val_pred = test(model, loader_dict["val"])
        #val_metrics = task.evaluate(val_pred, val_table)
        val_metrics = custom_evaluate(val_pred, val_table, task.metrics)
        if epoch % 10 == 0:
            print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

        if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
            not higher_is_better and val_metrics[tune_metric] < best_val_metric
        ):
            best_val_metric = val_metrics[tune_metric]
            state_dict = copy.deepcopy(model.state_dict())

    # state_dict stays None if no epoch improved (e.g. every validation metric was NaN)
    if state_dict is not None:
        model.load_state_dict(state_dict)
    val_pred = test(model, loader_dict["val"])
    val_metrics = custom_evaluate(val_pred, val_table, task.metrics)
    print(f"Best val metrics with optimizer {optimizer}: {val_metrics}")
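In the checkpoint-selection test, a NaN metric never counts as better, because every comparison with NaN is `False`, so a diverged run never produces a checkpoint. The helper below is a hypothetical sketch (not part of the original notebook) that states this rule explicitly, so the same logic could be shared by `training_function` and the training cell.

```python
import math

def is_improvement(value, best, higher_is_better):
    # NaN or infinite metrics (e.g. from a diverged run) never count as improvements
    if not math.isfinite(value):
        return False
    return value > best if higher_is_better else value < best

# Example: tracking the best MAE, where lower is better
best = math.inf
for mae in [3.2, float("nan"), 2.9, 3.0]:
    if is_improvement(mae, best, higher_is_better=False):
        best = mae
print(best)  # 2.9
```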

## Hyperparameter search (validation set)

In [163]:
# # hyperparameter grid search, each configuration scored on the fixed validation split:
# #possible learning rates: [0.01, 0.001, 0.0001, 0.00001]
# #possible batch sizes: [64, 256, 512]
# #possible number of layers: [1, 2, 3]
# #possible weight decay: [0.0001, 0.001, 0.01]

# for lr in [0.01, 0.001, 0.0001, 0.00001]:#0.001
#     #for batch_size in [64, 256, 512]:
#         for num_layers in [1, 2, 3]:#1
#             #for weight_decay in [0.0001, 0.001, 0.01]:
#                 model = Model(
#                     data=data,
#                     col_stats_dict=col_stats_dict,
#                     num_layers=num_layers,
#                     channels=128,
#                     out_channels=1,
#                     aggr="sum",
#                     norm="batch_norm",
#                 ).to(device)
#                 print(f"Training with lr={lr}, num_layers={num_layers}")
#                 optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.001)
#                 training_function(model, optimizer, epochs=10) # Set epochs to a smaller number for testing
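The commented-out loops enumerate a grid in which only `lr` and `num_layers` actually vary; weight decay is fixed at 0.001 in the `torch.optim.Adam` call and the batch-size loop is disabled. A minimal sketch of the same grid written flat with `itertools.product`, which makes it easy to re-enable the other axes. It only builds the list of configurations; training each one would go through `Model(...)` and `training_function` as in the cell above.

```python
from itertools import product

# Search space from the cell above; add "batch_size" or more weight_decay values to widen it
grid = {
    "lr": [0.01, 0.001, 0.0001, 0.00001],
    "num_layers": [1, 2, 3],
    "weight_decay": [0.001],  # fixed in the original optimizer call
}

# One dict per combination, keys in the same order as the grid
configs = [dict(zip(grid, values)) for values in product(*grid.values())]
print(len(configs))  # 12
print(configs[0])    # {'lr': 0.01, 'num_layers': 1, 'weight_decay': 0.001}
```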

# Training

In [None]:
model = Model(
    data=data,
    col_stats_dict=col_stats_dict,
    num_layers=1,
    channels=128,
    out_channels=1,
    aggr="sum",
    norm="batch_norm",
).to(device)

# copy the model weights so we can later check whether pre-training actually modifies `model`:
initial_state = copy.deepcopy(model.state_dict())

discriminator = DGIHead(128).to(device)

# optimize both the model's and the discriminator's parameters
optimizer_dgi = torch.optim.Adam(
    list(model.parameters()) + list(discriminator.parameters()),
    lr=0.001,
    weight_decay=1e-5,
)

train_dgi(model, discriminator, optimizer_dgi, loader_dict["train"], device, entity_table=entity_table, epochs=20)

after_state = model.state_dict()

# check whether they changed
changed_params = [k for k in initial_state if not torch.equal(initial_state[k], after_state[k])]
print("Modified parameters:", changed_params if changed_params else "No parameter changed")
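`torch.equal` also reports a parameter as changed when it has become NaN, which is plausible here given the `nan` DGI losses in the output of this cell. A sketch of a stricter check that separates changed, unchanged and non-finite entries; `summarize_param_changes` is a hypothetical helper, demonstrated on a toy `nn.Linear` rather than the notebook's model. Integer buffers such as batch-norm's `num_batches_tracked` skip the finiteness test.

```python
import copy
import torch
import torch.nn as nn

def summarize_param_changes(initial_state, after_state):
    # Split state_dict entries into changed, unchanged, and containing NaN/Inf
    changed, unchanged, non_finite = [], [], []
    for name, before in initial_state.items():
        after = after_state[name]
        if after.is_floating_point() and not torch.isfinite(after).all():
            non_finite.append(name)
        elif torch.equal(before, after):
            unchanged.append(name)
        else:
            changed.append(name)
    return changed, unchanged, non_finite

# Toy example: a linear layer whose weight gets overwritten with NaN
layer = nn.Linear(4, 2)
initial = copy.deepcopy(layer.state_dict())
with torch.no_grad():
    layer.weight.fill_(float("nan"))
changed, unchanged, non_finite = summarize_param_changes(initial, layer.state_dict())
print(changed, unchanged, non_finite)  # [] ['bias'] ['weight']
```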


# downstream: supervised training on the task, starting from the pre-trained weights
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)
epochs = 100
total_steps = epochs * len(loader_dict["train"])
warmup_steps = int(0.1 * total_steps)  # 10% warmup
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
state_dict = None
best_val_metric = -math.inf if higher_is_better else math.inf

# keep the history of the validation metric (MAE) across epochs:
val_metr_history = []

for epoch in range(1, epochs + 1):
    train_loss = train(model, optimizer, scheduler)
    val_pred = test(model, loader_dict["val"])
    #val_metrics = task.evaluate(val_pred, val_table)
    val_metrics = custom_evaluate(val_pred, val_table, task.metrics)

    val_metr_history.append(val_metrics[tune_metric])

    print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, Val metrics: {val_metrics}")

    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
            not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(model.state_dict())


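# state_dict stays None if no epoch improved (e.g. all validation metrics were NaN)
if state_dict is not None:
    model.load_state_dict(state_dict)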
val_pred = test(model, loader_dict["val"])
val_metrics = custom_evaluate(val_pred, val_table, task.metrics)
print(f"Best Val metrics: {val_metrics}")

test_table = task.get_table("test", mask_input_cols=False)
test_pred = test(model, loader_dict["test"])
test_metrics = custom_evaluate(test_pred, test_table, task.metrics)
print(f"Best test metrics: {test_metrics}")

plot_validation_metrics([val_metr_history], ["basic model"], metric_name=tune_metric)


Epoch 1, DGI Loss: nan
Epoch 2, DGI Loss: nan
Epoch 3, DGI Loss: nan
Epoch 4, DGI Loss: nan
Epoch 5, DGI Loss: nan
Epoch 6, DGI Loss: nan
Epoch 7, DGI Loss: nan
Epoch 8, DGI Loss: nan
Epoch 9, DGI Loss: nan
Epoch 10, DGI Loss: nan
Epoch 11, DGI Loss: nan
Epoch 12, DGI Loss: nan
Epoch 13, DGI Loss: nan
Epoch 14, DGI Loss: nan
Epoch 15, DGI Loss: nan
Epoch 16, DGI Loss: nan
Epoch 17, DGI Loss: nan


KeyboardInterrupt: 
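Every logged DGI loss is `nan` from the first epoch, so pre-training learns nothing useful and can push NaN into the weights. `train_dgi` and `DGIHead` are defined elsewhere in the notebook, so the cause cannot be confirmed from this section; two common culprits are computing `log(sigmoid(...))` by hand and NaN values in the input features. Below is a minimal, self-contained sketch of the DGI objective described at the top of the notebook: corruption by shuffling node features, a sigmoid-of-mean readout, and a bilinear discriminator. It avoids the first issue by feeding raw logits to `binary_cross_entropy_with_logits` and fails fast on non-finite inputs or losses. `DenseGCN` (a one-layer GCN on a dense toy adjacency standing in for the notebook's heterogeneous GNN) and `BilinearDiscriminator` are hypothetical names, not the notebook's classes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DenseGCN(nn.Module):
    # One graph-convolution layer on a dense, row-normalized adjacency
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin = nn.Linear(in_dim, out_dim)
        self.act = nn.PReLU()

    def forward(self, x, adj):
        return self.act(adj @ self.lin(x))

class BilinearDiscriminator(nn.Module):
    # Scores each (node embedding h_i, summary s) pair with the logit h_i^T W s
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(dim, dim))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, h, summary):
        return h @ (self.weight @ summary)  # shape [num_nodes]

def dgi_loss(encoder, discriminator, x, adj):
    # Fail fast instead of silently training on NaN/Inf features
    if not torch.isfinite(x).all():
        raise ValueError("input features contain NaN or Inf")
    z = encoder(x, adj)                                       # real graph
    z_corrupted = encoder(x[torch.randperm(x.size(0))], adj)  # features shuffled across nodes, same edges
    summary = torch.sigmoid(z.mean(dim=0))                    # readout: global summary
    pos = discriminator(z, summary)
    neg = discriminator(z_corrupted, summary)
    # BCE on raw logits is numerically stable (no separate sigmoid followed by log)
    return (F.binary_cross_entropy_with_logits(pos, torch.ones_like(pos))
            + F.binary_cross_entropy_with_logits(neg, torch.zeros_like(neg)))

# Toy graph: 100 nodes, random symmetric edges plus self-loops
torch.manual_seed(0)
num_nodes, in_dim, hid_dim = 100, 16, 32
a = (torch.rand(num_nodes, num_nodes) < 0.05).float()
a = ((a + a.t()) > 0).float()
a.fill_diagonal_(1.0)
adj = a / a.sum(dim=1, keepdim=True)
x = torch.randn(num_nodes, in_dim)

encoder = DenseGCN(in_dim, hid_dim)
discriminator = BilinearDiscriminator(hid_dim)
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(discriminator.parameters()), lr=1e-3
)

losses = []
for epoch in range(1, 21):
    optimizer.zero_grad()
    loss = dgi_loss(encoder, discriminator, x, adj)
    if not torch.isfinite(loss):
        raise RuntimeError(f"non-finite DGI loss at epoch {epoch}")
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
print(losses[0], losses[-1])
```

The corruption must keep the edges fixed while shuffling features: with an encoder that ignores the graph, shuffled embeddings would just be a permutation of the real ones and the discriminator could not separate them.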

# Load a previously saved model

In [None]:
# model.load_state_dict(torch.load('best_model_GAT_head2.pth', map_location=torch.device('cpu')))