<a href="https://colab.research.google.com/github/Andrea-1704/Pytorch_Geometric_tutorial/blob/main/train_model_baseline_f1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Obiettivo
Questo notebook segue il lavoro di Ablation sulle tabelle. L'obiettivo questa volta è quella di capire, per ogni tabella che consideremo nel dataset e per ogni sua features, quanto pesa la rimozione di qualche feature da quella tabella al fine di ridurre le colonne che la rete deve considerare per ogni specifico nodo.


# Libraries to install

In [1]:
# !pip install torch==2.6.0+cu118 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# !pip install pyg-lib -f https://data.pyg.org/whl/torch-2.6.0+cu118.html
# !pip install torch-scatter -f https://data.pyg.org/whl/torch-2.6.0+cu118.html
# !pip install torch-sparse -f https://data.pyg.org/whl/torch-2.6.0+cu118.html
# !pip install torch-cluster -f https://data.pyg.org/whl/torch-2.6.0+cu118.html
# !pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.6.0+cu118.html
# !pip install torch-geometric==2.6.0 -f https://data.pyg.org/whl/torch-2.6.0+cu118.html

# !pip install pytorch_frame[full]==1.2.2
# !pip install relbench[full]==1.0.0

# Import

In [2]:
import os
import torch
import relbench
import numpy as np
from torch.nn import BCEWithLogitsLoss, L1Loss
from relbench.datasets import get_dataset
from relbench.tasks import get_task
import math
from tqdm import tqdm
import torch_geometric
import torch_frame
from torch_geometric.seed import seed_everything
from relbench.modeling.utils import get_stype_proposal
from collections import defaultdict
import requests
from io import StringIO
from torch_frame.config.text_embedder import TextEmbedderConfig
from relbench.modeling.graph import make_pkey_fkey_graph
from torch.nn import BCEWithLogitsLoss
import copy
from typing import Any, Dict, List
from torch import Tensor
from torch.nn import Embedding, ModuleDict
from torch_frame.data.stats import StatType
from torch_geometric.data import HeteroData
from torch_geometric.nn import MLP
from torch_geometric.typing import NodeType
from relbench.modeling.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder
from relbench.modeling.graph import get_node_train_table_input, make_pkey_fkey_graph
from torch_geometric.loader import NeighborLoader
import pyg_lib
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

#Utility functions

In [3]:
@torch.no_grad()
def alignment_check(loader: NeighborLoader, expected_node_ids: torch.Tensor):
    node_id_list = []

    for batch in loader:
        batch = batch.to(device)

        node_id_list.append(batch[task.entity_table].n_id.cpu())

    actual_node_ids = torch.cat(node_id_list, dim=0)

    assert len(actual_node_ids) == len(expected_node_ids), "Mismatch nella lunghezza"

    if not torch.equal(actual_node_ids, expected_node_ids):
        raise ValueError("Ordine dei nodi predetti diverso da val_table!")

    return

In [4]:
def evaluate_performance(pred: np.ndarray, target_table, metrics) -> dict:
    """Custom evaluation function to replace task.evaluate."""
    target = target_table.df[task.target_col].to_numpy()

    if len(pred) != len(target):
        raise ValueError(
            f"The length of pred and target must be the same (got "
            f"{len(pred)} and {len(target)}, respectively)."
        )

    results = {}
    for metric_fn in metrics:
        if metric_fn.__name__ == "rmse":
            results["rmse"] = np.sqrt(np.mean((target - pred)**2))
        else:
            results[metric_fn.__name__] = metric_fn(target, pred)

    return results

In [5]:
import numpy as np

def evaluate_on_train_during_training() -> float:
    model.eval()
    pred_list, target_list = [], []

    for batch in loader_dict["train"]:
        batch = batch.to(device)
        pred = model(batch, task.entity_table)
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
        target_list.append(batch[task.entity_table].y.detach().cpu())

    pred_all = torch.cat(pred_list, dim=0).numpy()
    target_all = torch.cat(target_list, dim=0).numpy()

    mae = np.mean(np.abs(pred_all - target_all))
    return mae


In [6]:
def rmse(true, pred):
    """Calculate the Root Mean Squared Error (RMSE)."""
    return np.sqrt(np.mean((true - pred)**2))

In [7]:
@torch.no_grad()
def evaluate_on_full_train(model, loader) -> float:
    model.eval()
    pred_list, target_list = [], []

    for batch in loader:
        batch = batch.to(device)
        pred = model(batch, task.entity_table)
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.cpu())
        target_list.append(batch[task.entity_table].y.cpu())

    pred_all = torch.cat(pred_list, dim=0).numpy()
    target_all = torch.cat(target_list, dim=0).numpy()

    mae = np.mean(np.abs(pred_all - target_all))
    return mae


In [8]:
class EarlyStopping:
    def __init__(self, patience=10, delta=0, verbose=False, path='checkpoint.pt'):
        """
        Args:
            patience (int): Quanto aspettare senza miglioramenti prima di fermare.
            delta (float): Miglioramento minimo richiesto per considerare un miglioramento.
            verbose (bool): Se stampare informazioni.
            path (str): Dove salvare il modello migliore.
        """
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.verbose = verbose
        self.path = path

    def __call__(self, val_loss, model):

        score = -val_loss  # Perché vogliamo MINIMIZZARE la loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)

        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} / {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Salva il modello migliore'''
        if self.verbose:
            print(f'Validation loss migliorata ({self.val_loss_min:.6f} --> {val_loss:.6f}). Salvo modello...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


# Dataset and task creation

In [9]:
dataset = get_dataset("rel-f1", download=True)
task = get_task("rel-f1", "driver-position", download=True)

train_table = task.get_table("train") #date  driverId  qualifying
val_table = task.get_table("val") #date  driverId  qualifying
test_table = task.get_table("test") # date  driverId

out_channels = 1
loss_fn = L1Loss()
# this is the mae loss and is used when have regressions tasks.
tune_metric = "mae"
higher_is_better = False

seed_everything(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
root_dir = "./data"

db = dataset.get_db() #get all tables
col_to_stype_dict = get_stype_proposal(db)
entity_table = task.entity_table
#this is used to get the stype of the columns

Downloading file 'rel-f1/db.zip' from 'https://relbench.stanford.edu/download/rel-f1/db.zip' to '/root/.cache/relbench'.
100%|████████████████████████████████████████| 704k/704k [00:00<00:00, 618MB/s]
Unzipping contents of '/root/.cache/relbench/rel-f1/db.zip' to '/root/.cache/relbench/rel-f1/.'
Downloading file 'rel-f1/tasks/driver-position.zip' from 'https://relbench.stanford.edu/download/rel-f1/tasks/driver-position.zip' to '/root/.cache/relbench'.
100%|█████████████████████████████████████| 36.5k/36.5k [00:00<00:00, 44.9MB/s]
Unzipping contents of '/root/.cache/relbench/rel-f1/tasks/driver-position.zip' to '/root/.cache/relbench/rel-f1/tasks/.'


cuda
Loading Database object from /root/.cache/relbench/rel-f1/db...
Done in 0.05 seconds.


In [10]:
list(db.table_dict.keys())

['standings',
 'constructor_results',
 'constructor_standings',
 'constructors',
 'qualifying',
 'results',
 'circuits',
 'races',
 'drivers']

# Embedder

In [11]:
class LightweightGloveEmbedder:
    def __init__(self, device=None):
        self.device = device
        self.embeddings = defaultdict(lambda: np.zeros(300))
        self._load_embeddings()

    def _load_embeddings(self):
      try:
          path = "glove.6B.300d.txt"
          with open(path, encoding="utf-8") as f:
              for line in f:
                  parts = line.strip().split()
                  word = parts[0]
                  vector = np.array(parts[1:], dtype=np.float32)
                  self.embeddings[word] = vector
          #print(f"Loaded {len(self.embeddings)} GloVe embeddings.")
      except Exception as e:
          print(f"Failed to load GloVe: {e}")

    def __call__(self, sentences):
        results = []
        for text in sentences:
            words = text.lower().split()
            vectors = [self.embeddings[w] for w in words if w in self.embeddings]
            if vectors:
                avg_vector = np.mean(vectors, axis=0)
            else:
                #print("non trovato")
                #print(f"Numero parole in embedding: {len(self.embeddings)}")

                avg_vector = np.zeros(300)
            results.append(avg_vector)

        tensor = torch.tensor(np.array(results), dtype=torch.float32)
        return tensor.to(self.device) if self.device else tensor

In [12]:
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=LightweightGloveEmbedder(device=device), batch_size=256
)

# data, col_stats_dict = make_pkey_fkey_graph(
#     db,
#     col_to_stype_dict=col_to_stype_dict,
#     text_embedder_cfg=text_embedder_cfg,
#     cache_dir=os.path.join(
#         root_dir, f"rel-f1_materialized_cache"
#     ),
# )

# Graph Loader

In [13]:
# qui i parametri di train_table, val_table, test_table, task e data sono
#parametri globali

def loader_dict_fn(batch_size, num_neighbours, data):
    loader_dict = {}

    for split, table in [
        ("train", train_table),
        ("val", val_table),
        ("test", test_table),
    ]:
        table_input = get_node_train_table_input(
            table=table,
            task=task,
        )

        loader_dict[split] = NeighborLoader(
            data,
            num_neighbors=[num_neighbours for _ in range(2)],
            time_attr="time",
            input_nodes=table_input.nodes,
            input_time=table_input.time,
            transform=table_input.transform,
            batch_size=batch_size,
            temporal_strategy="uniform",
            shuffle=split == "train",
            num_workers=0,
            persistent_workers=False,
        )

    return loader_dict


# Model

In [14]:
class Model(torch.nn.Module):

    def __init__(
        self,
        data: HeteroData,
        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
        num_layers: int,
        channels: int,
        out_channels: int,
        aggr: str,
        norm: str,
        shallow_list: List[NodeType] = [],
        id_awareness: bool = False,
    ):
        super().__init__()

        self.encoder = HeteroEncoder(
            channels=channels,
            node_to_col_names_dict={
                node_type: data[node_type].tf.col_names_dict
                for node_type in data.node_types
            },
            node_to_col_stats=col_stats_dict,
        )

        self.temporal_encoder = HeteroTemporalEncoder(
            node_types=[
                node_type for node_type in data.node_types if "time" in data[node_type]
            ],
            channels=channels,
        )

        self.gnn = HeteroGraphSAGE(
            node_types=data.node_types,
            edge_types=data.edge_types,
            channels=channels,
            aggr=aggr,
            num_layers=num_layers,
        )
        self.head = MLP(
            channels,
            out_channels=out_channels,
            norm=norm,
            num_layers=1, ###################################################
        )
        self.embedding_dict = ModuleDict(
            {
                node: Embedding(data.num_nodes_dict[node], channels)
                for node in shallow_list
            }
        )

        self.id_awareness_emb = None
        if id_awareness:
            self.id_awareness_emb = torch.nn.Embedding(1, channels)
        self.reset_parameters()

    def reset_parameters(self):
        self.encoder.reset_parameters()
        self.temporal_encoder.reset_parameters()
        self.gnn.reset_parameters()
        self.head.reset_parameters()
        for embedding in self.embedding_dict.values():
            torch.nn.init.normal_(embedding.weight, std=0.1)
        if self.id_awareness_emb is not None:
            self.id_awareness_emb.reset_parameters()

    def forward(
        self,
        batch: HeteroData,
        entity_table: NodeType,
    ) -> Tensor:
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
            batch.num_sampled_nodes_dict,
            batch.num_sampled_edges_dict,
        )

        return self.head(x_dict[entity_table][: seed_time.size(0)])

    def forward_dst_readout(
        self,
        batch: HeteroData,
        entity_table: NodeType,
        dst_table: NodeType,
    ) -> Tensor:
        if self.id_awareness_emb is None:
            raise RuntimeError(
                "id_awareness must be set True to use forward_dst_readout"
            )
        seed_time = batch[entity_table].seed_time
        x_dict = self.encoder(batch.tf_dict)
        x_dict[entity_table][: seed_time.size(0)] += self.id_awareness_emb.weight

        rel_time_dict = self.temporal_encoder(
            seed_time, batch.time_dict, batch.batch_dict
        )

        for node_type, rel_time in rel_time_dict.items():
            x_dict[node_type] = x_dict[node_type] + rel_time

        for node_type, embedding in self.embedding_dict.items():
            x_dict[node_type] = x_dict[node_type] + embedding(batch[node_type].n_id)

        x_dict = self.gnn(
            x_dict,
            batch.edge_index_dict,
        )

        return self.head(x_dict[dst_table])

# Training functions

Ora necessito di modificare la funzione di train per prendere anche il valore del loader_dict: utile per tuning dei parametri (vedi il codice della funzione di tuning).

In [15]:
def train(model, optimizer, loader_dict) -> float:
    model.train()

    loss_accum = count_accum = 0
    for batch in tqdm(loader_dict["train"]):
        batch = batch.to(device)

        optimizer.zero_grad()
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred

        loss = loss_fn(pred.float(), batch[entity_table].y.float())
        loss.backward()
        optimizer.step()

        loss_accum += loss.detach().item() * pred.size(0)
        count_accum += pred.size(0)

    return loss_accum / count_accum


@torch.no_grad()
def test(model, loader: NeighborLoader) -> np.ndarray:
    model.eval()

    pred_list = []
    for batch in loader:
        batch = batch.to(device)
        pred = model(
            batch,
            task.entity_table,
        )
        pred = pred.view(-1) if pred.size(1) == 1 else pred
        pred_list.append(pred.detach().cpu())
    return torch.cat(pred_list, dim=0).numpy()

In [16]:
print(task.target_col)

position


## Scheduler tuning

In [17]:
import torch
import copy
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, ReduceLROnPlateau, OneCycleLR


def get_scheduler(name, optimizer, loader_len, epochs):
    if name == "cosine":
        return CosineAnnealingLR(optimizer, T_max=epochs)

    elif name == "linear_warmup":
        def lr_lambda(epoch):
            warmup_epochs = 5
            if epoch < warmup_epochs:
                return (epoch + 1) / warmup_epochs
            return max(0.1, 1 - (epoch - warmup_epochs) / (epochs - warmup_epochs))
        return LambdaLR(optimizer, lr_lambda=lr_lambda)

    elif name == "plateau":
        return ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=10, verbose=True)

    elif name == "onecycle":
        return OneCycleLR(optimizer, max_lr=0.001, steps_per_epoch=loader_len, epochs=epochs)

    else:
        raise ValueError(f"Unknown scheduler: {name}")


def evaluate_scheduler(scheduler_name, model_init_fn, loader_dict_fn, train_fn, test_fn,
                        evaluate_fn, task, val_table, epochs, device):
    model = model_init_fn().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=0)
    loader_dict = loader_dict_fn()
    scheduler = get_scheduler(scheduler_name, optimizer, len(loader_dict["train"]), epochs)

    for epoch in range(epochs):
        train_fn(model, optimizer, loader_dict)
        if scheduler_name == "plateau":
            val_pred = test_fn(model, loader_dict["val"])
            val_mae = evaluate_fn(val_pred, val_table, task.metrics)["mae"]
            scheduler.step(val_mae)
        elif scheduler_name == "onecycle":
            # Step per batch
            for _ in loader_dict["train"]:
                scheduler.step()
        else:
            scheduler.step()

    val_pred = test_fn(model, loader_dict["val"])
    val_mae = evaluate_fn(val_pred, val_table, task.metrics)["mae"]
    return val_mae


def compare_schedulers(model_init_fn, loader_dict_fn, train_fn, test_fn,
                        evaluate_fn, task, val_table, epochs, device):
    schedulers = ["cosine", "linear_warmup", "plateau", "onecycle"]
    results = {}

    for name in schedulers:
        print(f"Testing scheduler: {name}")
        val_mae = evaluate_scheduler(name, model_init_fn, loader_dict_fn, train_fn, test_fn,
                                     evaluate_fn, task, val_table, epochs, device)
        results[name] = val_mae
        print(f"{name} val_mae: {val_mae:.4f}")

    best_scheduler = min(results, key=results.get)
    print(f"\nBest scheduler: {best_scheduler} (val_mae: {results[best_scheduler]:.4f})")
    return results


In [18]:
# compare_schedulers(
#     model_init_fn=lambda: Model(
#             data=data,
#             col_stats_dict=col_stats_dict,
#             num_layers=2,
#             channels=128,
#             out_channels=1,
#             aggr="max",
#             norm="batch_norm",
#     ).to(device),
#     loader_dict_fn=lambda: loader_dict_fn(batch_size=512, num_neighbours=256, data=data),
#     train_fn=train,
#     test_fn=test,
#     evaluate_fn=evaluate_performance,
#     task=task,
#     val_table=val_table,
#     epochs=50,
#     device=device
# )


Il migliore scheduler è quindi cosine:

Best scheduler: cosine (val_mae: 2.9038)
{'cosine': 2.9038224850325243,
 'linear_warmup': 2.915358568750865,
 'plateau': 3.30474050643847,
 'onecycle': 3.1991489658215557}

# Ablation on tables

In [19]:
def build_graph_excluding_tables(dataset, tables_to_remove, root_dir, text_embedder_cfg):
    import copy
    from relbench.modeling.graph import make_pkey_fkey_graph

    # 1. Deepcopy del db originale
    original_db = dataset.get_db()
    db = copy.deepcopy(original_db)

    # 2. Rimuovi tutte le tabelle
    for table_to_remove in tables_to_remove:
        if table_to_remove in db.table_dict:
            del db.table_dict[table_to_remove]

    # 3. Rimuovi fkey che puntano a tabelle escluse
    for table in db.table_dict.values():
        table.fkey_col_to_pkey_table = {
            col: tgt for col, tgt in table.fkey_col_to_pkey_table.items()
            if tgt not in tables_to_remove
        }
        # Rimuoviamo anche le colonne fkey nei df
        table.df = table.df.drop(
            columns=[
                col for col, tgt in table.fkey_col_to_pkey_table.items()
                if tgt not in db.table_dict
            ],
            errors="ignore"
        )

    # 4. Filtra anche lo stype
    full_stype = get_stype_proposal(original_db)
    filtered_stype = {
        tab: stype for tab, stype in full_stype.items()
        if tab not in tables_to_remove
    }

    # 5. Costruzione del grafo UNA sola volta
    cache_name = "_".join(sorted(tables_to_remove))
    data, col_stats_dict = make_pkey_fkey_graph(
        db,
        col_to_stype_dict=filtered_stype,
        text_embedder_cfg=text_embedder_cfg,
        cache_dir=os.path.join(root_dir, f"ablation_cache_{cache_name}")
    )

    all_tables = list(db.table_dict.keys())
    print(f"Tabelle rimanenti nel grafo: {all_tables}")

    return data, col_stats_dict, db


# Ablation on features

## Analysis

In [20]:
import copy
import os
from relbench.modeling.graph import make_pkey_fkey_graph
from tqdm import tqdm

def run_feature_ablation(dataset, task, root_dir, text_embedder_cfg):
    db_original = dataset.get_db()
    results = []

    for table_name, table in db_original.table_dict.items():
        df = table.df
        pkey = table.pkey_col
        fkeys = set(table.fkey_col_to_pkey_table.keys())
        y_col = task.target_col if table_name == task.entity_table else None

        # Candidati: tutte le colonne meno chiavi e target
        feature_cols = [col for col in df.columns
                        if col not in fkeys and col != pkey and col != y_col]

        for feature in tqdm(feature_cols, desc=f"{table_name}"):
            # 1. Copia il db
            db = copy.deepcopy(db_original)

            # 2. Rimuovi la colonna dal df e dallo stype
            db.table_dict[table_name].df = db.table_dict[table_name].df.drop(columns=[feature])

            # 3. Ricostruisci stype
            stype_dict = get_stype_proposal(db)

            # 4. Costruzione grafo
            try:
                data, col_stats_dict = make_pkey_fkey_graph(
                    db,
                    col_to_stype_dict=stype_dict,
                    text_embedder_cfg=text_embedder_cfg,
                    cache_dir=os.path.join(root_dir, f"ablation_cache_{table_name}_{feature}")
                )
            except Exception as e:
                print(f"Skipping {table_name}.{feature} due to error: {e}")
                continue

            # 5. Build model and loader
            model = Model(
                data=data,
                col_stats_dict=col_stats_dict,
                num_layers=2,
                channels=128,
                out_channels=1,
                aggr="max",
                norm="batch_norm",
            ).to(device)

            optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=0)
            loader_dict = loader_dict_fn(batch_size=512, num_neighbours=256, data=data)

            # 6. Training breve
            for _ in range(15):
                train(model, optimizer, loader_dict)

            val_pred = test(model, loader_dict["val"])
            val_mae = evaluate_performance(val_pred, val_table, task.metrics)["mae"]

            results.append({
                "table": table_name,
                "feature": feature,
                "val_mae": val_mae,
            })

    return sorted(results, key=lambda x: x["val_mae"], reverse=True)


In [21]:
# run_feature_ablation(dataset, task, "root", text_embedder_cfg)

[{'table': 'results',
  'feature': 'positionOrder',
  'val_mae': 3.950804795244175},


 {'table': 'standings', 'feature': 'position', 'val_mae': 3.8903072947092507},


 {'table': 'races', 'feature': 'round', 'val_mae': 3.8874379688688494},


 {'table': 'circuits', 'feature': 'country', 'val_mae': 3.8866149727790136},


 {'table': 'constructor_standings',
  'feature': 'points',
  'val_mae': 3.8830483684399644},


 {'table': 'results', 'feature': 'position', 'val_mae': 3.881930207282444},


 {'table': 'circuits', 'feature': 'name', 'val_mae': 3.8806671545198146},


 {'table': 'circuits', 'feature': 'circuitRef', 'val_mae': 3.8796486551951155},


 {'table': 'results', 'feature': 'rank', 'val_mae': 3.8729131191511033},


 {'table': 'drivers', 'feature': 'surname', 'val_mae': 3.8721600948529957},


 {'table': 'drivers', 'feature': 'code', 'val_mae': 3.8716865850434594},


 {'table': 'constructors',
  'feature': 'nationality',
  'val_mae': 3.8713674863816583},


 {'table': 'circuits', 'feature': 'alt', 'val_mae': 3.870047071733392},


 {'table': 'results', 'feature': 'grid', 'val_mae': 3.8692150931399745},


 {'table': 'standings', 'feature': 'points', 'val_mae': 3.8686533158989693},


 {'table': 'drivers', 'feature': 'dob', 'val_mae': 3.8663228321011736},


 {'table': 'standings', 'feature': 'wins', 'val_mae': 3.866027179684891},


 {'table': 'constructor_standings',
  'feature': 'position',
  'val_mae': 3.8660021299669247},


 {'table': 'results',
  'feature': 'milliseconds',
  'val_mae': 3.8608299076995136},



 {'table': 'results', 'feature': 'statusId', 'val_mae': 3.8595401444113406},


 {'table': 'races', 'feature': 'year', 'val_mae': 3.8587325821738285},


 {'table': 'qualifying', 'feature': 'number', 'val_mae': 3.858693720423228},


 {'table': 'results', 'feature': 'number', 'val_mae': 3.8582406295643863},


 {'table': 'results', 'feature': 'points', 'val_mae': 3.857474676194634},


 {'table': 'qualifying', 'feature': 'position', 'val_mae': 3.857085906559416},


 {'table': 'circuits', 'feature': 'location', 'val_mae': 3.8563255510094487},


 {'table': 'constructor_results',
  'feature': 'points',
  'val_mae': 3.8561439144667102},


 {'table': 'races', 'feature': 'time', 'val_mae': 3.8559820535745155},


 {'table': 'results', 'feature': 'fastestLap', 'val_mae': 3.855821519799446},


 {'table': 'drivers', 'feature': 'nationality', 'val_mae': 3.8555740881380274},


 {'table': 'results', 'feature': 'laps', 'val_mae': 3.8554142874881436},


 {'table': 'drivers', 'feature': 'driverRef', 'val_mae': 3.8553647816619163},


 {'table': 'constructor_standings',
  'feature': 'wins',
  'val_mae': 3.85485139335882},


 {'table': 'drivers', 'feature': 'forename', 'val_mae': 3.854089958761721},


 {'table': 'circuits', 'feature': 'lat', 'val_mae': 3.8531440858452335},


 {'table': 'constructors', 'feature': 'name', 'val_mae': 3.8531423841385974},


 {'table': 'circuits', 'feature': 'lng', 'val_mae': 3.852716157463446},


 {'table': 'races', 'feature': 'name', 'val_mae': 3.851620068419513},


 {'table': 'constructors',
  'feature': 'constructorRef',
  'val_mae': 3.8499111845721066}]

## Main results

Features essenziali:


| Tabella     | Feature         | MAE  | ΔMAE (≈) | Impatto        |
| ----------- | --------------- | ---- | -------- | -------------- |
| `results`   | `positionOrder` | 3.95 | +0.08    |  **Critica** |
| `standings` | `position`      | 3.89 | +0.02    |  Utile       |
| `races`     | `round`         | 3.88 | +0.01    |  Utile       |


Features meno importanti:


| Tabella        | Feature          | MAE    | ΔMAE (≈) | Impatto         |
| -------------- | ---------------- | ------ | -------- | --------------- |
| `constructors` | `constructorRef` | 3.8499 | −0.02    | Forse rumore |
| `races`        | `name`           | 3.8516 | −0.02    | Inutile      |
| `circuits`     | `lng`            | 3.8527 | −0.02    | Inutile      |


results.positionOrder è la feature più importante del dataset.

Feature come position, round, points... contribuiscono sensibilmente.

Alcune feature testuali o descrittive (name, location, constructorRef) non aiutano, e anzi potrebbero essere rumorose.



## Function for removing features

In [22]:
import copy
import os
from relbench.modeling.graph import make_pkey_fkey_graph

def get_graph_without_features(db, features_to_remove, root_dir, text_embedder_cfg):
    """
    Rimuove un insieme di feature specificate da diverse tabelle e restituisce il grafo aggiornato.

    :param dataset: Il dataset originale (oggetto relbench)
    :param features_to_remove: Lista di tuple (table_name, feature_name)
    :param root_dir: Cartella cache
    :param text_embedder_cfg: Configurazione dell'embedding testuale
    :return: (data, col_stats_dict)
     """
    # db_original = dataset.get_db()
    # db = copy.deepcopy(db_original)

    # Rimuove le feature indicate
    for table_name, feature in features_to_remove:
        if table_name in db.table_dict:
            table = db.table_dict[table_name]
            if feature in table.df.columns:
                table.df = table.df.drop(columns=[feature])
        else:
          print(f"feature {feature}  di tabella {table_name} non presente")

    # Ricostruisce lo stype_dict aggiornato
    stype_dict = get_stype_proposal(db)

    # Costruzione grafo aggiornato
    cache_name = "__".join([f"{t}_{f}" for t, f in features_to_remove])
    data, col_stats_dict = make_pkey_fkey_graph(
        db,
        col_to_stype_dict=stype_dict,
        text_embedder_cfg=text_embedder_cfg,
        cache_dir=os.path.join(root_dir, f"feature_ablation_cache_{cache_name}")
    )

    return data, col_stats_dict


# Main

## Tables ablation

In [23]:
print(f"Testing without table: circuits and constructor_result")
tables_to_remove = ["circuits","constructor_results"]

data, col_stats_dict, db = build_graph_excluding_tables(
        dataset, tables_to_remove=tables_to_remove,
        root_dir=root_dir,
        text_embedder_cfg=text_embedder_cfg
)

loader_dict = loader_dict_fn(batch_size=512, num_neighbours=256, data=data)


Testing without table: circuits and constructor_result


Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00,  6.14it/s]
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 250.57it/s]
Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 243.83it/s]
  ser = pd.to_datetime(ser, format=time_format)
Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 219.34it/s]
  ser = pd.to_datetime(ser, format=self.format, errors='coerce')
Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 724.53it/s]
Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 290.75it/s]
Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 247.69it/s]
Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 245.99it/s]
Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 296.91it/s]


Tabelle rimanenti nel grafo: ['standings', 'constructor_standings', 'constructors', 'qualifying', 'results', 'races', 'drivers']


## Features ablation

In [24]:
features_to_remove = [
    ("constructors", "constructorRef"),
    ("constructors", "name"),
    ("races", "name"),
    ("circuits", "lat"),#vediamo se non presente
    ("circuits", "lng"),#idem
    ("drivers", "forename"),
    ("drivers", "driverRef"),
    ("drivers", "nationality"),
    ("constructor_standings", "wins"),
    ("results", "laps"),
    ("results", "points")
]

data, col_stats_dict = get_graph_without_features(db, features_to_remove, root_dir, text_embedder_cfg)


feature lat  di tabella circuits non presente
feature lng  di tabella circuits non presente


Embedding raw data in mini-batch: 100%|██████████| 1/1 [00:00<00:00, 230.33it/s]
  ser = pd.to_datetime(ser, format=time_format)
  ser = pd.to_datetime(ser, format=self.format, errors='coerce')
Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 480.68it/s]
Embedding raw data in mini-batch: 100%|██████████| 4/4 [00:00<00:00, 215.41it/s]


## Main

In [26]:
loader_dict = loader_dict_fn(batch_size=512, num_neighbours=256, data=data)

model = Model(
        data=data,
        col_stats_dict=col_stats_dict,
        num_layers=2,
        channels=128,
        out_channels=1,
        aggr="max",
        norm="batch_norm",
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=0)

# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#         optimizer,
#         mode="min",
#         factor=0.5,
#         patience=10,
#         verbose=True
# )
epochs = 600

scheduler = CosineAnnealingLR(optimizer, T_max=epochs)

early_stopping = EarlyStopping(
    patience=30,
    delta=0.0,
    verbose=True,
    path="best_basic_model.pt"
)


state_dict = None
test_table = task.get_table("test", mask_input_cols=False)
best_val_metric = -math.inf if higher_is_better else math.inf
best_test_metric = -math.inf if higher_is_better else math.inf


for epoch in range(1, epochs + 1):
    train_loss = train(model, optimizer, loader_dict=loader_dict)

    train_pred = test(model, loader_dict["train"])
    train_metrics = evaluate_performance(train_pred, train_table, task.metrics)
    train_mae_preciso = evaluate_on_full_train(model, loader_dict["train"])

    val_pred = test(model, loader_dict["val"])
    val_metrics = evaluate_performance(val_pred, val_table, task.metrics)

    test_pred = test(model, loader_dict["test"])
    test_metrics = evaluate_performance(test_pred, test_table, task.metrics)

    scheduler.step(val_metrics[tune_metric])

    if (higher_is_better and val_metrics[tune_metric] > best_val_metric) or (
            not higher_is_better and val_metrics[tune_metric] < best_val_metric
    ):
        best_val_metric = val_metrics[tune_metric]
        state_dict = copy.deepcopy(model.state_dict())

    #test:
    if (higher_is_better and test_metrics[tune_metric] > best_test_metric) or (
            not higher_is_better and test_metrics[tune_metric] < best_test_metric
    ):
        best_test_metric = test_metrics[tune_metric]
        state_dict_test = copy.deepcopy(model.state_dict())

    current_lr = optimizer.param_groups[0]["lr"]
    print(f"Epoch: {epoch:02d}, Train mae: {train_mae_preciso:.2f}, Validation MAE: {val_metrics[tune_metric]:.2f}, Test MAE: {test_metrics[tune_metric]:.2f}, LR: {current_lr:.6f}")

    early_stopping(val_metrics[tune_metric], model)

    if early_stopping.early_stop:
        print(f"Early stopping triggered at epoch {epoch}")
        break
print(f"best validation results: {best_val_metric}")
print(f"best test results: {best_test_metric}")

100%|██████████| 15/15 [00:02<00:00,  5.43it/s]


Epoch: 01, Train mae: 10.52, Validation MAE: 7.70, Test MAE: 8.56, LR: 0.000500
Validation loss migliorata (inf --> 7.699859). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  6.14it/s]


Epoch: 02, Train mae: 10.00, Validation MAE: 7.17, Test MAE: 8.03, LR: 0.000500
Validation loss migliorata (7.699859 --> 7.165066). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.98it/s]


Epoch: 03, Train mae: 9.58, Validation MAE: 6.75, Test MAE: 7.59, LR: 0.000500
Validation loss migliorata (7.165066 --> 6.745498). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.95it/s]


Epoch: 04, Train mae: 9.16, Validation MAE: 6.35, Test MAE: 7.18, LR: 0.000500
Validation loss migliorata (6.745498 --> 6.348452). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.96it/s]


Epoch: 05, Train mae: 8.76, Validation MAE: 5.97, Test MAE: 6.80, LR: 0.000500
Validation loss migliorata (6.348452 --> 5.970862). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.57it/s]


Epoch: 06, Train mae: 8.37, Validation MAE: 5.61, Test MAE: 6.45, LR: 0.000500
Validation loss migliorata (5.970862 --> 5.608460). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.97it/s]


Epoch: 07, Train mae: 7.99, Validation MAE: 5.29, Test MAE: 6.12, LR: 0.000500
Validation loss migliorata (5.608460 --> 5.285204). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.19it/s]


Epoch: 08, Train mae: 7.64, Validation MAE: 4.99, Test MAE: 5.81, LR: 0.000500
Validation loss migliorata (5.285204 --> 4.989032). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.98it/s]


Epoch: 09, Train mae: 7.31, Validation MAE: 4.73, Test MAE: 5.52, LR: 0.000500
Validation loss migliorata (4.989032 --> 4.725163). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.23it/s]


Epoch: 10, Train mae: 7.01, Validation MAE: 4.50, Test MAE: 5.25, LR: 0.000500
Validation loss migliorata (4.725163 --> 4.503823). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.84it/s]


Epoch: 11, Train mae: 6.74, Validation MAE: 4.31, Test MAE: 5.03, LR: 0.000500
Validation loss migliorata (4.503823 --> 4.313928). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.38it/s]


Epoch: 12, Train mae: 6.50, Validation MAE: 4.16, Test MAE: 4.83, LR: 0.000500
Validation loss migliorata (4.313928 --> 4.163872). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.61it/s]


Epoch: 13, Train mae: 6.31, Validation MAE: 4.05, Test MAE: 4.68, LR: 0.000500
Validation loss migliorata (4.163872 --> 4.052688). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.49it/s]


Epoch: 14, Train mae: 6.13, Validation MAE: 3.96, Test MAE: 4.55, LR: 0.000500
Validation loss migliorata (4.052688 --> 3.963899). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.63it/s]


Epoch: 15, Train mae: 5.99, Validation MAE: 3.90, Test MAE: 4.46, LR: 0.000500
Validation loss migliorata (3.963899 --> 3.898593). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.55it/s]


Epoch: 16, Train mae: 5.87, Validation MAE: 3.86, Test MAE: 4.39, LR: 0.000500
Validation loss migliorata (3.898593 --> 3.863406). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.76it/s]


Epoch: 17, Train mae: 5.78, Validation MAE: 3.86, Test MAE: 4.35, LR: 0.000500
Validation loss migliorata (3.863406 --> 3.856908). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.77it/s]


Epoch: 18, Train mae: 5.71, Validation MAE: 3.86, Test MAE: 4.33, LR: 0.000500
EarlyStopping counter: 1 / 30


100%|██████████| 15/15 [00:02<00:00,  5.76it/s]


Epoch: 19, Train mae: 5.66, Validation MAE: 3.88, Test MAE: 4.32, LR: 0.000500
EarlyStopping counter: 2 / 30


100%|██████████| 15/15 [00:02<00:00,  5.77it/s]


Epoch: 20, Train mae: 5.62, Validation MAE: 3.90, Test MAE: 4.31, LR: 0.000500
EarlyStopping counter: 3 / 30


100%|██████████| 15/15 [00:02<00:00,  5.63it/s]


Epoch: 21, Train mae: 5.59, Validation MAE: 3.89, Test MAE: 4.29, LR: 0.000500
EarlyStopping counter: 4 / 30


100%|██████████| 15/15 [00:02<00:00,  5.71it/s]


Epoch: 22, Train mae: 5.55, Validation MAE: 3.77, Test MAE: 4.20, LR: 0.000500
Validation loss migliorata (3.856908 --> 3.772424). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.29it/s]


Epoch: 23, Train mae: 5.45, Validation MAE: 3.43, Test MAE: 3.96, LR: 0.000500
Validation loss migliorata (3.772424 --> 3.430688). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.75it/s]


Epoch: 24, Train mae: 5.28, Validation MAE: 3.36, Test MAE: 3.88, LR: 0.000500
Validation loss migliorata (3.430688 --> 3.359099). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.10it/s]


Epoch: 25, Train mae: 5.13, Validation MAE: 3.14, Test MAE: 3.84, LR: 0.000500
Validation loss migliorata (3.359099 --> 3.144816). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.78it/s]


Epoch: 26, Train mae: 4.97, Validation MAE: 3.04, Test MAE: 3.90, LR: 0.000500
Validation loss migliorata (3.144816 --> 3.036825). Salvo modello...


100%|██████████| 15/15 [00:03<00:00,  4.94it/s]


Epoch: 27, Train mae: 4.89, Validation MAE: 2.96, Test MAE: 4.01, LR: 0.000500
Validation loss migliorata (3.036825 --> 2.961598). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.76it/s]


Epoch: 28, Train mae: 4.81, Validation MAE: 2.85, Test MAE: 4.12, LR: 0.000500
Validation loss migliorata (2.961598 --> 2.847815). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.14it/s]


Epoch: 29, Train mae: 4.74, Validation MAE: 2.94, Test MAE: 4.01, LR: 0.000500
EarlyStopping counter: 1 / 30


100%|██████████| 15/15 [00:02<00:00,  5.93it/s]


Epoch: 30, Train mae: 4.68, Validation MAE: 2.83, Test MAE: 4.16, LR: 0.000500
Validation loss migliorata (2.847815 --> 2.826951). Salvo modello...


100%|██████████| 15/15 [00:02<00:00,  5.52it/s]


Epoch: 31, Train mae: 4.67, Validation MAE: 2.98, Test MAE: 4.03, LR: 0.000500
EarlyStopping counter: 1 / 30


100%|██████████| 15/15 [00:02<00:00,  5.88it/s]


Epoch: 32, Train mae: 4.64, Validation MAE: 2.90, Test MAE: 4.11, LR: 0.000500
EarlyStopping counter: 2 / 30


100%|██████████| 15/15 [00:02<00:00,  5.82it/s]


Epoch: 33, Train mae: 4.65, Validation MAE: 2.84, Test MAE: 4.26, LR: 0.000500
EarlyStopping counter: 3 / 30


100%|██████████| 15/15 [00:02<00:00,  5.90it/s]


Epoch: 34, Train mae: 4.55, Validation MAE: 2.87, Test MAE: 4.17, LR: 0.000500
EarlyStopping counter: 4 / 30


100%|██████████| 15/15 [00:02<00:00,  5.84it/s]


Epoch: 35, Train mae: 4.46, Validation MAE: 2.88, Test MAE: 4.20, LR: 0.000500
EarlyStopping counter: 5 / 30


100%|██████████| 15/15 [00:02<00:00,  5.50it/s]


Epoch: 36, Train mae: 4.42, Validation MAE: 2.95, Test MAE: 4.23, LR: 0.000500
EarlyStopping counter: 6 / 30


100%|██████████| 15/15 [00:02<00:00,  5.65it/s]


Epoch: 37, Train mae: 4.40, Validation MAE: 2.88, Test MAE: 4.31, LR: 0.000500
EarlyStopping counter: 7 / 30


100%|██████████| 15/15 [00:02<00:00,  5.24it/s]


Epoch: 38, Train mae: 4.36, Validation MAE: 3.15, Test MAE: 4.24, LR: 0.000500
EarlyStopping counter: 8 / 30


100%|██████████| 15/15 [00:02<00:00,  5.78it/s]


Epoch: 39, Train mae: 4.34, Validation MAE: 2.98, Test MAE: 4.28, LR: 0.000500
EarlyStopping counter: 9 / 30


100%|██████████| 15/15 [00:03<00:00,  4.98it/s]


Epoch: 40, Train mae: 4.24, Validation MAE: 3.02, Test MAE: 4.46, LR: 0.000500
EarlyStopping counter: 10 / 30


100%|██████████| 15/15 [00:02<00:00,  5.67it/s]


Epoch: 41, Train mae: 4.19, Validation MAE: 3.02, Test MAE: 4.55, LR: 0.000500
EarlyStopping counter: 11 / 30


100%|██████████| 15/15 [00:02<00:00,  5.02it/s]


Epoch: 42, Train mae: 4.12, Validation MAE: 3.06, Test MAE: 4.56, LR: 0.000500
EarlyStopping counter: 12 / 30


100%|██████████| 15/15 [00:02<00:00,  5.70it/s]


Epoch: 43, Train mae: 4.09, Validation MAE: 3.02, Test MAE: 4.63, LR: 0.000500
EarlyStopping counter: 13 / 30


100%|██████████| 15/15 [00:02<00:00,  5.29it/s]


Epoch: 44, Train mae: 4.11, Validation MAE: 3.25, Test MAE: 4.49, LR: 0.000500
EarlyStopping counter: 14 / 30


100%|██████████| 15/15 [00:02<00:00,  5.73it/s]


Epoch: 45, Train mae: 3.96, Validation MAE: 3.04, Test MAE: 4.68, LR: 0.000500
EarlyStopping counter: 15 / 30


100%|██████████| 15/15 [00:02<00:00,  5.47it/s]


Epoch: 46, Train mae: 3.88, Validation MAE: 3.03, Test MAE: 4.76, LR: 0.000500
EarlyStopping counter: 16 / 30


100%|██████████| 15/15 [00:02<00:00,  5.75it/s]


Epoch: 47, Train mae: 3.82, Validation MAE: 3.05, Test MAE: 4.51, LR: 0.000500
EarlyStopping counter: 17 / 30


100%|██████████| 15/15 [00:02<00:00,  5.60it/s]


Epoch: 48, Train mae: 3.80, Validation MAE: 3.03, Test MAE: 4.44, LR: 0.000500
EarlyStopping counter: 18 / 30


100%|██████████| 15/15 [00:02<00:00,  5.70it/s]


Epoch: 49, Train mae: 3.80, Validation MAE: 3.06, Test MAE: 4.46, LR: 0.000500
EarlyStopping counter: 19 / 30


100%|██████████| 15/15 [00:02<00:00,  5.80it/s]


Epoch: 50, Train mae: 3.60, Validation MAE: 3.32, Test MAE: 4.54, LR: 0.000500
EarlyStopping counter: 20 / 30


100%|██████████| 15/15 [00:02<00:00,  5.35it/s]


Epoch: 51, Train mae: 3.72, Validation MAE: 3.09, Test MAE: 4.82, LR: 0.000500
EarlyStopping counter: 21 / 30


100%|██████████| 15/15 [00:02<00:00,  5.73it/s]


Epoch: 52, Train mae: 3.51, Validation MAE: 3.17, Test MAE: 4.51, LR: 0.000500
EarlyStopping counter: 22 / 30


100%|██████████| 15/15 [00:02<00:00,  5.08it/s]


Epoch: 53, Train mae: 3.39, Validation MAE: 3.32, Test MAE: 4.61, LR: 0.000500
EarlyStopping counter: 23 / 30


100%|██████████| 15/15 [00:02<00:00,  5.64it/s]


Epoch: 54, Train mae: 3.59, Validation MAE: 3.22, Test MAE: 4.70, LR: 0.000500
EarlyStopping counter: 24 / 30


100%|██████████| 15/15 [00:03<00:00,  4.82it/s]


Epoch: 55, Train mae: 3.50, Validation MAE: 3.29, Test MAE: 4.76, LR: 0.000500
EarlyStopping counter: 25 / 30


100%|██████████| 15/15 [00:02<00:00,  5.66it/s]


Epoch: 56, Train mae: 3.21, Validation MAE: 3.26, Test MAE: 4.78, LR: 0.000500
EarlyStopping counter: 26 / 30


100%|██████████| 15/15 [00:02<00:00,  5.22it/s]


Epoch: 57, Train mae: 3.31, Validation MAE: 3.15, Test MAE: 4.44, LR: 0.000500
EarlyStopping counter: 27 / 30


100%|██████████| 15/15 [00:02<00:00,  5.71it/s]


Epoch: 58, Train mae: 3.23, Validation MAE: 3.50, Test MAE: 4.80, LR: 0.000500
EarlyStopping counter: 28 / 30


100%|██████████| 15/15 [00:02<00:00,  5.44it/s]


Epoch: 59, Train mae: 3.17, Validation MAE: 3.23, Test MAE: 4.72, LR: 0.000500
EarlyStopping counter: 29 / 30


100%|██████████| 15/15 [00:02<00:00,  5.79it/s]


Epoch: 60, Train mae: 3.12, Validation MAE: 3.26, Test MAE: 4.70, LR: 0.000500
EarlyStopping counter: 30 / 30
Early stopping triggered at epoch 60
best validation results: 2.826950886955083
best test results: 3.8405526310937446


