# Training

In [1]:
import pandas as pd
import optuna
import numpy as np
import torch
from dataset import Dataset
from torch_geometric_temporal.nn import MTGNN
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
from config import CONFIG
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


## Dataset

In [2]:
# Build the full dataset. The first argument is presumably the per-sample
# input window length (the model below is given `dataset.input_length + 1`
# as its sequence length) -- confirm against `Dataset`.
dataset = Dataset(
    CONFIG.model_config.seq_length - 1,
    countries=CONFIG.countries,
)
# Randomly hold out 100 samples for validation; everything else is used for
# training.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    lengths=[len(dataset) - 100, 100],
)

In [None]:
# One graph node per column of the dataset's dataframe.
num_nodes = len(dataset.dataframe.columns)

In [None]:
def train(model, optimizer, criterion, loader, epochs, *, verbose=False):
    """Fit ``model`` on ``loader`` and return the mean batch loss of the last epoch.

    Args:
        model: Module invoked as ``model(x)`` for every batch.
        optimizer: Optimizer whose ``zero_grad``/``step`` apply the updates.
        criterion: Loss callable invoked as ``criterion(outputs, y)``.
        loader: Iterable yielding ``(x, y)`` batches. It is iterated once per
            epoch and does not need to support ``len()``.
        epochs: Number of passes over ``loader``; must be at least 1.
        verbose: When true, print the mean loss after every epoch.

    Returns:
        float: Sum of the per-batch losses of the final epoch divided by the
        number of batches seen in that epoch.

    Raises:
        ValueError: If ``epochs`` is smaller than 1, or if an epoch yields no
            batches (there would be no loss to report).
    """
    if epochs < 1:
        raise ValueError(f"epochs must be at least 1, got {epochs}")

    # Make sure train-only layers (e.g. dropout) are active even if the model
    # was previously switched to eval mode.
    model.train()

    mean_loss = 0.0
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for x, y in loader:
            optimizer.zero_grad()
            outputs = model(x)

            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Count batches ourselves so loaders without __len__ work too.
            num_batches += 1

        if num_batches == 0:
            raise ValueError("loader produced no batches; cannot compute a loss")

        mean_loss = running_loss / num_batches
        if verbose:
            print(f"[{epoch + 1}] loss: {mean_loss:.3f}")

    return mean_loss

In [None]:
def _mean_eval_loss(model, criterion, loader):
    """Return the mean per-batch loss of ``model`` on ``loader`` without updating it."""
    was_training = model.training
    model.eval()  # disable dropout while scoring
    total_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                total_loss += criterion(model(x), y).item()
                num_batches += 1
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)
    if num_batches == 0:
        raise ValueError("validation loader produced no batches")
    return total_loss / num_batches


def objective(trial):
    """Optuna objective: train an MTGNN with sampled hyper-parameters.

    The model is trained for 50 epochs on ``train_dataset`` and then scored on
    the held-out ``val_dataset``; no gradient steps are taken on the
    validation data.

    Args:
        trial: Optuna trial used to sample the hyper-parameters.

    Returns:
        float: Mean L1 loss per batch on the validation split (to be minimized).
    """
    # Even batch sizes in [2, 8]. `suggest_int(..., step=2)` replaces the
    # deprecated `suggest_discrete_uniform`.
    batch_size = trial.suggest_int("batch_size", 2, 8, step=2)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)
    # Validation must use the held-out split, not the training data.
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    model = MTGNN(
        gcn_true=True,
        build_adj=True,
        num_nodes=num_nodes,
        seq_length=dataset.input_length + 1,
        kernel_set=[1],
        kernel_size=1,
        gcn_depth=trial.suggest_int("gcn_depth", 1, 8),
        dropout=0.3,
        # Number of neighbors in the generated adjacency matrix; it cannot
        # exceed the number of nodes, so cap the search range accordingly.
        subgraph_size=trial.suggest_int("subgraph_size", 1, min(8, num_nodes)),
        node_dim=4,
        dilation_exponential=1,  # Using small enough sequences that we do not need dilation.
        conv_channels=trial.suggest_int("conv_channels", 1, 32),
        residual_channels=trial.suggest_int("residual_channels", 1, 32),
        skip_channels=trial.suggest_int("skip_channels", 1, 32),
        end_channels=trial.suggest_int("end_channels", 1, 32),
        in_dim=1,  # Number of features per node (1 in our case)
        out_dim=9,  # Correspond to the seq length in y
        layers=trial.suggest_int("layers", 1, 8),
        propalpha=0.05,
        tanhalpha=3,
        layer_norm_affline=True,  # (sic) spelling of the library's keyword
    )

    criterion = torch.nn.L1Loss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=CONFIG.momentum)

    train(model, optimizer, criterion, train_loader, 50)

    # Score on unseen data without touching the weights.
    return _mean_eval_loss(model, criterion, val_loader)

In [None]:
# Search for the hyper-parameters that minimize the value returned by
# `objective`, using 100 trials.
study = optuna.create_study(direction="minimize")
study.optimize(func=objective, n_trials=100)

# Last expression of the cell: display the best hyper-parameters found.
study.best_params