# GNN training: step-by-step tutorial

Here we present the typical workflow used to train GAME-Net, from the generation of the graph FG-dataset to the training itself, including the post-processing of the results.

## 0) Imports

In [None]:
import sys
sys.path.append('../../src/')
import time
from os.path import exists
import toml

import torch

from gnn_eads.constants import FG_RAW_GROUPS, sigma_dict, pool_dict, pool_seq_dict, conv_layer, loss_dict
from gnn_eads.functions import create_loaders, scale_target, train_loop, test_loop, get_id
from gnn_eads.processed_datasets import create_post_processed_datasets
from gnn_eads.nets import FlexibleNet
from gnn_eads.post_training import create_model_report
from gnn_eads.create_graph_datasets import create_graph_datasets
from gnn_eads.paths import create_paths
from gnn_eads.graph_tools import plotter

## 1) Load hyperparameters 

Hyperparameters are the variables set before training the model (i.e., everything that is not trainable). They can be categorized as model-related or process-related: model-related hyperparameters define the model architecture (e.g., depth and width of the layers, bias), while process-related ones define the training workflow (e.g., number of epochs, loss function, optimizer, batch size).

The hyperparameters, together with the graph settings and the data path, are provided as input via a .toml file. The folder `input_train_GNN` contains a `TEMPLATE.toml` file, whose settings are used in this tutorial.

Note: before loading the .toml file, open it and set the `root` entry of the `[data]` section to the folder where you store the data.

In [None]:
HYPERPARAMS = toml.load("input_train_GNN/TEMPLATE.toml")  
data_path = HYPERPARAMS["data"]["root"]    
graph_settings = HYPERPARAMS["graph"]
train = HYPERPARAMS["train"]
architecture = HYPERPARAMS["architecture"]

## 2) Create graphs from DFT data

Based on the graph representation settings provided in the input .toml file (Voronoi tolerance, metal scaling factor, and inclusion of second-order metal atoms), the next cell creates the graph FG-dataset from the DFT data. The process involves two steps:
1. Converting all the DFT data to graphs, which are saved as "pre_xx_bool_yy.dat" files in the folder of each chemical family of the FG-dataset. These are plain-text files containing the information needed to then generate the graphs in a format suitable for PyTorch Geometric.
2. Generating the graph FG-dataset by processing the information in the pre_xx_bool_yy.dat files and filtering out incorrect graph representations. The resulting `FG_dataset` object is a container of the chemical families in the FG-dataset, which are saved as "post_xx_bool_yy.dat" files.

In [None]:
graph_identifier = get_id(graph_settings)
family_paths = create_paths(FG_RAW_GROUPS, data_path, graph_identifier)
# Step 1 is skipped if the pre-processed files already exist (the "amides" family is used as the check)
if not exists(data_path + "/amides/pre_" + graph_identifier):
    print("Creating graphs from raw data ...")
    create_graph_datasets(graph_settings, family_paths)
# Step 2: build the graph FG-dataset from the pre-processed files
FG_dataset = create_post_processed_datasets(graph_identifier, family_paths)
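To give an intuition of how the graph settings affect connectivity, here is a minimal sketch under loudly stated assumptions: two atoms are connected when their distance is below the sum of their atomic radii plus a tolerance, and metal radii are multiplied by a scaling factor. The actual `gnn_eads` code also relies on a Voronoi analysis (hence the "Voronoi tolerance" setting), which this sketch replaces with a plain distance cutoff. The radii, the helper names, and the "every atom must be reachable" filter used to discard wrong graphs are all hypothetical.

```python
import math
from collections import deque

# Hypothetical radii in Angstrom, for illustration only
RADII = {"C": 0.76, "H": 0.31, "O": 0.66, "Ru": 1.46}
METALS = {"Ru"}


def build_edges(elements, positions, tolerance=0.25, metal_scaling=1.5):
    """Connect atoms i, j if d_ij <= r_i + r_j + tolerance (simplified, no Voronoi step)."""
    def radius(element):
        # Assumption: the metal scaling factor enlarges the radius of metal atoms only
        return RADII[element] * (metal_scaling if element in METALS else 1.0)

    edges = []
    for i in range(len(elements)):
        for j in range(i + 1, len(elements)):
            cutoff = radius(elements[i]) + radius(elements[j]) + tolerance
            if math.dist(positions[i], positions[j]) <= cutoff:
                edges.append((i, j))
    return edges


def is_connected(n_nodes, edges):
    """Hypothetical filter: keep a graph only if every atom is reachable from atom 0."""
    if n_nodes == 0:
        return False
    neighbours = {i: [] for i in range(n_nodes)}
    for i, j in edges:
        neighbours[i].append(j)
        neighbours[j].append(i)
    seen, queue = {0}, deque([0])
    while queue:  # breadth-first search
        for nb in neighbours[queue.popleft()]:
            if nb not in seen:
                seen.add(nb)
                queue.append(nb)
    return len(seen) == n_nodes


# Made-up geometry: CO on top of a single Ru atom
elements = ["Ru", "C", "O"]
positions = [(0.0, 0.0, 0.0), (0.0, 0.0, 2.0), (0.0, 0.0, 3.15)]
edges = build_edges(elements, positions)
print(edges, is_connected(len(elements), edges))  # [(0, 1), (1, 2)] True
```

Increasing `tolerance` or `metal_scaling` in this sketch adds edges (here, eventually a Ru-O edge), which is the kind of trade-off the graph settings control.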

In [None]:
print(type(FG_dataset))
data_points = [len(FG_dataset[i]) for i in range(len(FG_dataset))]
total_data_points = sum(data_points)
print("Total number of data points: ", total_data_points)

## 3) Data splitting and target scaling

The FG-dataset is split into training, validation and test sets via a stratified data split, and then target scaling is applied.
The scaling parameters must be computed without using the test set; otherwise, information from the test set would leak into the training process ("data leakage").
Target scaling is performed by the `scale_target` function, whose `mode` argument is read from the `target_scaling` entry of the input file: `mode="std"` applies standardization.

In [None]:
train_loader, val_loader, test_loader = create_loaders(FG_dataset,
                                                       batch_size=train["batch_size"],
                                                       split=train["splits"], 
                                                       test=train["test_set"])
train_loader, val_loader, test_loader, mean, std = scale_target(train_loader,
                                                                val_loader,
                                                                test_loader,
                                                                mode=train["target_scaling"],
                                                                test=train["test_set"])
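The idea behind a stratified split followed by leakage-free standardization can be sketched in plain Python. This is not how `create_loaders` and `scale_target` are implemented; it assumes the samples are grouped by chemical family, that each family is split with the same proportions, and that the mean and (population) standard deviation are computed on the training targets only and then applied to all three sets. The function names and fractions are hypothetical, and target values stand in for whole graphs.

```python
import random
import statistics


def stratified_split(families, fractions=(0.8, 0.1, 0.1), seed=42):
    """Split each family separately so that every subset keeps the family proportions."""
    rng = random.Random(seed)
    train, val, test = [], [], []
    for samples in families.values():
        samples = samples[:]  # shuffle a copy, not the caller's list
        rng.shuffle(samples)
        n_train = round(fractions[0] * len(samples))
        n_val = round(fractions[1] * len(samples))
        train += samples[:n_train]
        val += samples[n_train:n_train + n_val]
        test += samples[n_train + n_val:]
    return train, val, test


def standardize(train, val, test):
    """Scale all targets with mean/std computed on the training set only (no test-set leakage)."""
    mean = statistics.fmean(train)
    std = statistics.pstdev(train)

    def scale(values):
        return [(y - mean) / std for y in values]

    return scale(train), scale(val), scale(test), mean, std


# Made-up targets for two families of different size
families = {"amides": [float(i) for i in range(10)],
            "amines": [float(i) for i in range(20, 40)]}
train, val, test = stratified_split(families)
train_s, val_s, test_s, mean, std = standardize(train, val, test)
print(len(train), len(val), len(test))  # 24 3 3
```

Only the training targets enter `mean` and `std`; the validation and test targets are merely transformed with them.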

## 4) Graph inspection

To get an idea of what the graph objects look like, here we show the visualization and the mathematical representation of a sample graph.

In [None]:
random_graph = train_loader.dataset[991]  # Change index to see different graphs
plotter(random_graph, dpi=150) 

This graph is represented mathematically by (i) the node feature matrix, which encodes only the chemical element of each atom via one-hot encoding, (ii) the edge list, which defines the connectivity, and (iii) its scaled DFT energy.

In [None]:
# Node feature matrix (one-hot encoded elements)
random_graph.x

In [None]:
# Connectivity
random_graph.edge_index

In [None]:
# Graph label
random_graph.y

In [None]:
print(random_graph)
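The first two components can be built by hand for a toy adsorbate to see their layout. The sketch below assumes a hypothetical element vocabulary `ELEMENTS` (the real one is defined inside `gnn_eads`) and uses plain lists instead of tensors to stay dependency-free. It follows the PyTorch Geometric convention for `edge_index`: a 2 × num_edges array in COO format, where each undirected bond is stored in both directions.

```python
ELEMENTS = ["C", "H", "O", "N", "S", "Ru"]  # hypothetical feature vocabulary


def one_hot(elements, vocabulary=ELEMENTS):
    """Node feature matrix: one row per atom, a single 1.0 in the column of its element."""
    return [[1.0 if element == v else 0.0 for v in vocabulary] for element in elements]


def to_edge_index(bonds):
    """COO connectivity with shape [2, num_edges]; each undirected bond appears in both directions."""
    sources, targets = [], []
    for i, j in bonds:
        sources += [i, j]
        targets += [j, i]
    return [sources, targets]


# Made-up example: CO bonded to one Ru atom (not a sample from the FG-dataset)
x = one_hot(["Ru", "C", "O"])
edge_index = to_edge_index([(0, 1), (1, 2)])
print(x[1])         # [1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
print(edge_index)   # [[0, 1, 1, 2], [1, 0, 2, 1]]
```

Wrapping these lists with `torch.tensor(x, dtype=torch.float)` and `torch.tensor(edge_index, dtype=torch.long)` gives tensors with a layout analogous to `random_graph.x` and `random_graph.edge_index`.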

## 5) Device selection

A CUDA-capable GPU is recommended for training deep learning models: its parallel architecture speeds up the large number of matrix multiplications involved in training.

In [None]:
device_dict = {}
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    print("Device name: {} (GPU)".format(torch.cuda.get_device_name(0)))
    device_dict["name"] = torch.cuda.get_device_name(0)
    device_dict["CudaDNN_enabled"] = torch.backends.cudnn.enabled
    device_dict["CUDNN_version"] = torch.backends.cudnn.version()
    device_dict["CUDA_version"] = torch.version.cuda
else:
    print("Device name: CPU")
    device_dict["name"] = "CPU" 

## 6) GNN model instantiation

Instantiate the model object representing the graph neural network architecture and move it to the training device. The `FlexibleNet` class allows building different architectures with the same class, depending on the settings of the `architecture` section of the input file.

In [None]:
model = FlexibleNet(dim=architecture["dim"],
                    N_linear=architecture["n_linear"],
                    N_conv=architecture["n_conv"],
                    adj_conv=architecture["adj_conv"],
                    sigma=sigma_dict[architecture["sigma"]],
                    bias=architecture["bias"],
                    conv=conv_layer[architecture["conv_layer"]],
                    pool=pool_dict[architecture["pool_layer"]],
                    pool_ratio=architecture["pool_ratio"],
                    pool_heads=architecture["pool_heads"],
                    pool_seq=pool_seq_dict[architecture["pool_seq"]],
                    pool_layer_norm=architecture["pool_layer_norm"]).to(device)
print(model)
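In the cell above, strings from the input file are turned into Python objects through lookup dictionaries imported from `gnn_eads.constants` (`sigma_dict`, `conv_layer`, `pool_dict`, `pool_seq_dict`). The sketch below illustrates this registry pattern with a hypothetical `ACTIVATIONS` dictionary whose keys and functions are not the real contents of `sigma_dict`. The `lookup` helper is also hypothetical: it turns a plain `KeyError` from a typo in the .toml file into an error message listing the valid options.

```python
import math

# Hypothetical registry: maps the string used in the input file to a callable
ACTIVATIONS = {
    "relu": lambda z: max(0.0, z),
    "tanh": math.tanh,
    "identity": lambda z: z,
}


def lookup(registry, key, section="architecture"):
    """Resolve a config string to the registered object, with a readable error for typos."""
    try:
        return registry[key]
    except KeyError:
        valid = ", ".join(sorted(registry))
        raise ValueError(f"Unknown {section} option {key!r}; valid options: {valid}") from None


sigma = lookup(ACTIVATIONS, "relu")
print(sigma(-2.0), sigma(3.0))  # 0.0 3.0
```

Keeping the mapping in dictionaries means a new activation, convolution, or pooling layer only needs a new entry, without touching the model-building code.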

## 7) Define optimizer and learning rate scheduler

The optimizer used for training is Adam, an algorithm for first-order gradient-based optimization of stochastic objective functions based on adaptive estimates of lower-order moments.

In [None]:
optimizer = torch.optim.Adam(model.parameters(),
                             lr=train["lr0"],
                             eps=train["eps"], 
                             weight_decay=train["weight_decay"],
                             amsgrad=train["amsgrad"])
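To make the role of `lr`, `eps`, `weight_decay` and `amsgrad` concrete, here is a scalar sketch of the Adam update rule, written to follow the formulation documented for `torch.optim.Adam`: `weight_decay` adds an L2 term to the gradient, and `amsgrad` uses the running maximum of the second-moment estimate. It is an illustration on a one-dimensional function, not the PyTorch implementation, and `adam_minimize` is a hypothetical name.

```python
import math


def adam_minimize(grad_fn, x0, lr=0.001, betas=(0.9, 0.999), eps=1e-8,
                  weight_decay=0.0, amsgrad=False, steps=1000):
    """Minimize a 1-D function with Adam, given its gradient grad_fn."""
    b1, b2 = betas
    x, m, v, v_max = x0, 0.0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_fn(x) + weight_decay * x  # L2 penalty added to the gradient
        m = b1 * m + (1 - b1) * g          # first-moment (mean) estimate
        v = b2 * v + (1 - b2) * g * g      # second-moment (uncentered variance) estimate
        if amsgrad:
            v_max = max(v_max, v)          # AMSGrad: never let the denominator shrink
            v_used = v_max
        else:
            v_used = v
        m_hat = m / (1 - b1 ** t)          # bias corrections for zero initialization
        v_hat = v_used / (1 - b2 ** t)
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return x


# f(x) = (x - 3)^2 has gradient 2(x - 3) and its minimum at x = 3
x_min = adam_minimize(lambda x: 2 * (x - 3.0), x0=0.0, lr=0.1)
print(round(x_min, 3))
```

Because each step is divided by the square root of the second-moment estimate, the step size adapts per parameter: in the first iterations every step has a magnitude close to `lr`, regardless of the gradient scale.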

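A learning rate scheduler steers the learning rate during training, which can speed up convergence and improve accuracy. Here we use `ReduceLROnPlateau`: with `mode='min'`, it multiplies the learning rate by `factor` when the monitored quantity (here the validation MAE) has not improved for more than `patience` epochs, without going below `min_lr`.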

In [None]:
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                          mode='min',
                                                          factor=train["factor"],
                                                          patience=train["patience"],
                                                          min_lr=train["minlr"])  
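The core logic of the scheduler can be sketched in a few lines. This is a simplified re-implementation for `mode='min'` with the default relative threshold of `1e-4`; it ignores the `cooldown` and `eps` options of `torch.optim.lr_scheduler.ReduceLROnPlateau`, and the class name `PlateauScheduler` is hypothetical.

```python
import math


class PlateauScheduler:
    """Simplified ReduceLROnPlateau for mode='min' (relative threshold, no cooldown)."""

    def __init__(self, lr, factor=0.1, patience=10, min_lr=0.0, threshold=1e-4):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.threshold = threshold
        self.best = math.inf
        self.num_bad_epochs = 0

    def step(self, metric):
        if metric < self.best * (1 - self.threshold):  # significant improvement
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:        # plateau: decay the learning rate
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.num_bad_epochs = 0
        return self.lr


# Validation MAE stops improving after the second epoch
scheduler = PlateauScheduler(lr=0.1, factor=0.5, patience=2, min_lr=0.01)
print([scheduler.step(mae) for mae in [1.0, 0.9, 0.9, 0.9, 0.9]])
```

Note that the learning rate is reduced only after `patience + 1` consecutive epochs without improvement (here at the fifth epoch), and it never drops below `min_lr`.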

## 8) Run training

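Everything is set up. Training a deep learning model is an iterative process: the model goes through the training set multiple times (epochs) to learn the patterns in the data. In each epoch, the model is trained on the training set and then evaluated on the validation set, whose MAE drives the learning rate scheduler; if a test set is used, the model is also evaluated on it.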

In [None]:
loss_list, train_list, val_list, test_list, lr_list = [], [], [], [], []         
t0 = time.time() 
for epoch in range(1, train["epochs"]+1):
    torch.cuda.empty_cache()
    # Current learning rate (logged for monitoring)
    lr = lr_scheduler.optimizer.param_groups[0]['lr']
    # Train iteration        
    loss, train_MAE = train_loop(model, device, train_loader, optimizer, loss_dict[train["loss_function"]])  
    # Validation iteration to update learning rate
    val_MAE = test_loop(model, val_loader, device, std)  
    lr_scheduler.step(val_MAE)
    # Test iteration
    if train["test_set"]:
        test_MAE = test_loop(model, test_loader, device, std, mean)         
        print('Epoch {:03d}: LR={:.7f}  Train MAE: {:.4f} eV  Validation MAE: {:.4f} eV '             
              'Test MAE: {:.4f} eV'.format(epoch, lr, train_MAE*std, val_MAE, test_MAE))
        test_list.append(test_MAE)
    else:
        print('Epoch {:03d}: LR={:.7f}  Train MAE: {:.6f} eV  Validation MAE: {:.6f} eV '
              .format(epoch, lr, train_MAE*std, val_MAE))  
    # Save information       
    loss_list.append(loss)
    train_list.append(train_MAE * std)
    val_list.append(val_MAE)
    lr_list.append(lr)
print("-----------------------------------------------------------------------------------------")
training_time = (time.time() - t0)/60  
print("Training time: {:.2f} min".format(training_time))
device_dict["training_time"] = training_time
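The loop multiplies `train_MAE` by `std` because the model is trained on standardized targets. Standardization, y' = (y − mean) / std, is an affine map, so the mean cancels in every difference and the MAE in scaled units times `std` equals the MAE in eV. The short check below uses made-up energies and predictions and computes `mean` and `std` from those values only, for illustration.

```python
import statistics


def mae(pred, true):
    """Mean absolute error."""
    return statistics.fmean(abs(p - t) for p, t in zip(pred, true))


# Made-up adsorption energies (eV) and model predictions, for illustration only
y_true = [-1.20, -0.35, -2.10, -0.80]
y_pred = [-1.05, -0.50, -2.00, -0.95]

mean = statistics.fmean(y_true)
std = statistics.pstdev(y_true)


def scale(values):
    return [(y - mean) / std for y in values]


mae_scaled = mae(scale(y_pred), scale(y_true))  # what the loss sees during training
print(f"{mae_scaled * std:.4f} eV vs {mae(y_pred, y_true):.4f} eV")
```

Both numbers coincide, which is why only `std` (and not `mean`) is needed to report the training MAE in eV.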


## 9) Save model and performance analysis

The content of the model report folder depends on whether a test set is used: if the final model is evaluated on a test set, additional files are generated (e.g., learning curve, error distribution plot).
`create_model_report` saves the results in a folder named after its first argument (here `"TEMPLATE_test"`), inside the directory given as its second argument (here `"../../models"`).

In [None]:
create_model_report("TEMPLATE_test",
                    "../../models",
                    HYPERPARAMS,
                    model,
                    (train_loader, val_loader, test_loader), 
                    (mean, std),  
                    (train_list, val_list, test_list, lr_list), 
                    device_dict)                               