# GNN model training

In [None]:
from pkgutil import find_loader
import sys
import platform
import time
from os.path import exists, isdir
import toml

import torch
import torch_geometric
from torch_geometric.loader import DataLoader

# pkgutil.find_loader is deprecated since Python 3.12 (removal planned for 3.14);
# importlib.util.find_spec is the documented replacement and also returns None
# when the top-level package cannot be found.
import importlib.util

# Make the sources under ../src/ importable when gnn_eads cannot be found.
if importlib.util.find_spec("gnn_eads") is None:
    sys.path.insert(0, "../src/")
from gnn_eads.constants import FG_RAW_GROUPS, loss_dict, pool_seq_dict, conv_layer, sigma_dict, pool_dict
from gnn_eads.functions import create_loaders, scale_target, train_loop, test_loop, get_id
from gnn_eads.processed_datasets import create_post_processed_datasets
from gnn_eads.nets import FlexibleNet
from gnn_eads.post_training import create_model_report
from gnn_eads.create_graph_datasets import create_paths, create_graph_datasets

## Load hyperparameters of the learning process

Hyperparameters are all the settings that are fixed before the model training starts (i.e., everything other than the learnable model parameters). They can be categorized as model-related or process-related: model-related hyperparameters include, for example, the activation function and the depth of the hidden layers, while process-related ones include the batch size, the number of epochs and the loss function used for the model optimization.

The hyperparameters, together with the graph settings and the data path, are provided as input via a toml file. This folder contains a TEMPLATE.toml file and the configuration file of the best model presented in the work.

In [None]:
# Read the whole run configuration from the toml file of the best model.
HYPERPARAMS = toml.load("best_model.toml")

# Root entry of the "data" table, used below to build the dataset paths.
data_path = HYPERPARAMS["data"]["root"]

# Remaining configuration tables, kept under short names for the next cells.
graph_settings, train, architecture = (
    HYPERPARAMS["graph"],
    HYPERPARAMS["train"],
    HYPERPARAMS["architecture"],
)

## Create graphs from raw DFT FG-dataset

In [None]:
# Identifier derived from the graph settings; it is part of the "pre_<id>" folder name.
graph_identifier = get_id(graph_settings)
family_paths = create_paths(FG_RAW_GROUPS, data_path, graph_identifier)

# Only the amides folder is probed to decide whether the graphs still
# have to be generated from the raw data.
if not exists(data_path + "/amides/pre_" + graph_identifier):
    print("Creating graphs from raw data ...")
    create_graph_datasets(graph_settings, family_paths)

# In both cases the dataset is then built from the processed graphs.
FG_dataset = create_post_processed_datasets(graph_identifier, family_paths)

## Data Splitting and target scaling

The FG-dataset is split into train, validation and test sets via a stratified data split approach.
The target scaling parameters must be computed without using the test set, as including it would lead to "data leakage".
Here, we apply the target scaling with the `scale_target` function, providing the optional parameter mode="std" in order to apply standardization. Normalization can be applied instead by providing the parameter mode="norm".

In [None]:
# Split the dataset into train/validation/test loaders according to the
# "train" section of the configuration.
unscaled_loaders = create_loaders(
    FG_dataset,
    batch_size=train["batch_size"],
    split=train["splits"],
    test=train["test_set"],
)

# Scale the targets; the call also returns the two scaling parameters,
# which are reused during training and in the final report.
train_loader, val_loader, test_loader = unscaled_loaders
train_loader, val_loader, test_loader, mean, std = scale_target(
    train_loader,
    val_loader,
    test_loader,
    mode=train["target_scaling"],
    test=train["test_set"],
)

## Visualize graphs

In [None]:
from gnn_eads.graph_tools import plotter

# Index 10 is an arbitrary choice; change it to inspect a different graph.
sample_index = 10
plotter(train_loader.dataset[sample_index])

### Device selection (GPU/CPU)

Having a CUDA-capable GPU is optimal for working with Deep Learning models, as its architecture can be exploited to speed up the training.

In [None]:
# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# device_dict collects hardware information that is later handed to the model report.
if device == "cuda":
    device_dict = {
        "name": torch.cuda.get_device_name(0),
        "CudaDNN_enabled": torch.backends.cudnn.enabled,
        "CUDNN_version": torch.backends.cudnn.version(),
        "CUDA_version": torch.version.cuda,
    }
    print("Device name: {} (GPU)".format(device_dict["name"]))
else:
    device_dict = {"name": "CPU"}
    print("Device name: CPU")

### GNN model instantiation

Instantiate the model object and move it to the available device (GPU or CPU).

In [None]:
# Build the network from the "architecture" section of the configuration.
# String entries (activation, convolution, pooling) are mapped to the
# corresponding objects through the lookup dictionaries from gnn_eads.constants.
model = FlexibleNet(
    dim=architecture["dim"],
    N_linear=architecture["n_linear"],
    N_conv=architecture["n_conv"],
    adj_conv=architecture["adj_conv"],
    sigma=sigma_dict[architecture["sigma"]],
    bias=architecture["bias"],
    conv=conv_layer[architecture["conv_layer"]],
    pool=pool_dict[architecture["pool_layer"]],
    pool_ratio=architecture["pool_ratio"],
    pool_heads=architecture["pool_heads"],
    pool_seq=pool_seq_dict[architecture["pool_seq"]],
    pool_layer_norm=architecture["pool_layer_norm"],
)
# Move the model to the selected device (GPU or CPU).
model = model.to(device)

## GNN Training

### Optimizer

The optimizer used for the training is Adam, an algorithm for first-order gradient-based optimization of
stochastic objective functions, based on adaptive estimates of lower-order moments.

In [None]:
# Adam settings are all taken from the "train" section of the configuration.
adam_settings = {
    "lr": train["lr0"],
    "eps": train["eps"],
    "weight_decay": train["weight_decay"],
    "amsgrad": train["amsgrad"],
}
optimizer = torch.optim.Adam(model.parameters(), **adam_settings)

### Learning Rate (LR) Scheduler

The scheduler helps steer the learning rate during the training, providing faster convergence and higher accuracy. The scheduler used here is the "Reduce On Loss Plateau Decay".

In [None]:
# Scheduler settings come from the "train" section; 'min' mode means the
# monitored quantity is expected to decrease.
plateau_settings = {
    "mode": 'min',
    "factor": train["factor"],
    "patience": train["patience"],
    "min_lr": train["minlr"],
}
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **plateau_settings)

### Run Training 

In [None]:
# Per-epoch histories; train/val/test lists are handed to the model report afterwards.
loss_list, train_list, val_list, test_list = [], [], [], []
t0 = time.time()
for epoch in range(1, train["epochs"] + 1):
    torch.cuda.empty_cache()
    # Read the learning rate before the scheduler step so the printed value
    # is the one used during this epoch.
    lr = lr_scheduler.optimizer.param_groups[0]['lr']
    loss, train_MAE = train_loop(model, device, train_loader, optimizer, loss_dict[train["loss_function"]])
    val_MAE = test_loop(model, val_loader, device, std)
    # The scheduler monitors the validation MAE.
    lr_scheduler.step(val_MAE)
    # The training MAE is multiplied by std before being reported and stored;
    # validation/test MAE are used as returned by test_loop.
    train_MAE_scaled = train_MAE * std
    if not train["test_set"]:
        print('Epoch {:03d}: LR={:.7f}  Train MAE: {:.6f} eV  Validation MAE: {:.6f} eV '
              .format(epoch, lr, train_MAE_scaled, val_MAE))
    else:
        test_MAE = test_loop(model, test_loader, device, std, mean)
        print('Epoch {:03d}: LR={:.7f}  Train MAE: {:.4f} eV  Validation MAE: {:.4f} eV '
              'Test MAE: {:.4f} eV'.format(epoch, lr, train_MAE_scaled, val_MAE, test_MAE))
        test_list.append(test_MAE)
    loss_list.append(loss)
    train_list.append(train_MAE_scaled)
    val_list.append(val_MAE)
print("-----------------------------------------------------------------------------------------")
# Elapsed wall-clock time in minutes, also stored for the report.
training_time = (time.time() - t0) / 60
print("Training time: {:.2f} min".format(training_time))
device_dict["training_time"] = training_time


### Save model and performance analysis

Depending on whether a test set is used, the information stored in the model report folder will be different. If a test set is used to test the final model, more files will be generated (such as the learning curve, the error distribution plot, etc.).

In [None]:
# Group the inputs of the report by role before the call.
report_loaders = (train_loader, val_loader, test_loader)
scaling_params = (mean, std)
mae_histories = (train_list, val_list, test_list)

# Provide a name different from models present in the directory "models"
create_model_report(
    "TEST",
    HYPERPARAMS,
    model,
    report_loaders,
    scaling_params,
    mae_histories,
    device_dict,
)
                               