## **Training Graph Neural Networks with a Custom `ModelAssessment` Class**

This notebook provides a detailed walkthrough of a custom-built `ModelAssessment` class for training **Graph Neural Networks (GNNs)** using PyTorch.  
It supports multiple architectures (**GCN**, **GAT**, **GraphSAGE**, and **GIN**) and includes a full training pipeline with evaluation and **early stopping**.

---

### Setup and Requirements

Make sure the following packages and custom modules are installed:

```bash
pip install torch torchvision torchaudio tqdm
```

You will also need the following **custom modules** available in your working directory:

- `datasets/` (e.g., `IMDB`, `DD`, `PROTEINS`)
- `src/models.py` (GNN architectures)
- `utils/utils.py` (helper functions)

---

### What the `ModelAssessment` Class Does

The `ModelAssessment` class handles the **entire training lifecycle**, including:

- **Dataset loading** (e.g., IMDB or DD)
- **Model initialization** (GCN, GAT, GraphSAGE, GIN)
- **Optional graph subsampling and augmentation**
- **Training loop** with optimizer & scheduler
- **Evaluation** with accuracy tracking and early stopping
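
The optional subsampling step can be illustrated with a minimal sketch of random node removal. The function name `subsample_nodes` and the edge-list representation are assumptions made for illustration; the actual augmentation code used by `ModelAssessment` may differ.

```python
import random

def subsample_nodes(num_nodes, edges, rate, rng):
    """Drop a fraction `rate` of the nodes at random and keep only edges between survivors."""
    n_drop = int(round(rate * num_nodes))
    dropped = set(rng.sample(range(num_nodes), n_drop))
    kept = [v for v in range(num_nodes) if v not in dropped]
    # Relabel the surviving nodes 0..k-1 so the edge list stays compact.
    new_id = {old: new for new, old in enumerate(kept)}
    new_edges = [(new_id[u], new_id[v]) for u, v in edges
                 if u in new_id and v in new_id]
    return len(kept), new_edges

# Toy 5-node cycle; with rate=0.2 exactly one node is removed.
toy_rng = random.Random(42)
cycle_edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)]
n_kept, sub_edges = subsample_nodes(5, cycle_edges, rate=0.2, rng=toy_rng)
print(n_kept, sub_edges)
```

Removing edges instead of nodes would follow the same pattern, sampling from the edge list rather than from the node indices.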

In [None]:
import sys
sys.path.append('../datasets')
from datasets.manager import IMDBBinary, DD, PROTEINS
from datasets.dataset import *
from Pipeline.Model_Assessment import ModelAssessment, plot_gnn_comparison


### **Training Setup**

To understand the full training process used for benchmarking the different GNN architectures, please refer to the notebook *toy_examples.ipynb*, which outlines each step in detail, along with the properties of the core classes.

We begin this benchmark by defining the list of hyperparameters that will be used to generate the **search grid** for model selection.

---

In [None]:
params = {
    "model_type": "GCN",  # one of "GCN", "GAT", "GIN", "GraphSAGE"
    "n_graph_subsampling": 0,  # number of subsampling runs per training graph (e.g. 5 increases the training data fivefold)
    "graph_node_subsampling": True,  # True: remove nodes at random; False: remove edges at random (subsampling/augmentation)
    "graph_subsampling_rate": 0.2,  # graph subsampling rate
    "dataset": "DD",
    "pooling_type": "mean",
    "seed": 42,
    "n_folds": 10,
    "cuda": True,
    "lr": [0.001, 0.01, 0.1],
    "epochs": 50,
    "weight_decay": 5e-4,
    "batch_size": 32,
    "dropout": 0,  # dropout rate
    "num_lay": [2, 3, 5],
    "num_agg_layer": 2,  # number of graph aggregation layers
    "hidden_agg_lay_size": [16, 32, 64],  # size of the hidden graph aggregation layers
    "fc_hidden_size": 128,  # size of the fully connected layer after readout
    "threads": 10,  # number of subprocesses used for data loading
    "random_walk": True,
    "walk_length": 20,  # length of each random walk
    "num_walk": 10,  # number of random walks
    "p": 0.65,  # probability of returning to the previous vertex (staying in the local neighborhood)
    "q": 0.35,  # probability of moving away from the previous vertex (exploring new regions)
    "print_logger": 10,  # printing rate
    "eps": 0.0,  # for GIN only
    "early_stopping": False,  # whether to use early stopping
}
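
The list-valued entries in `params` (`lr`, `num_lay`, `hidden_agg_lay_size`) define the search grid. As a rough sketch of how such a dictionary can be expanded into individual configurations, assuming that every list-valued entry is a searched hyperparameter (the helper `expand_grid` is hypothetical and not necessarily how `ModelAssessment` builds its grid):

```python
from itertools import product

def expand_grid(param_dict):
    """Yield one flat config per combination of the list-valued entries."""
    searched = {k: v for k, v in param_dict.items() if isinstance(v, list)}
    fixed = {k: v for k, v in param_dict.items() if not isinstance(v, list)}
    keys = list(searched)
    for values in product(*(searched[k] for k in keys)):
        yield {**fixed, **dict(zip(keys, values))}

# Reduced copy of the dictionary above, for illustration only.
demo_params = {
    "model_type": "GCN",
    "lr": [0.001, 0.01, 0.1],
    "num_lay": [2, 3, 5],
    "hidden_agg_lay_size": [16, 32, 64],
    "epochs": 50,
}
configs = list(expand_grid(demo_params))
print(len(configs))  # 3 * 3 * 3 = 27 combinations
```

With three values for each of the three searched hyperparameters, a full grid search evaluates 27 configurations per model and dataset.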

### **Benchmarking GNNs**

First, we initialize a dictionary to store the performance of each GNN architecture across the different datasets.  
Performance is measured using the **accuracy metric**, averaged and accompanied by its standard deviation across the different folds of a cross-validation procedure.
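
As a small illustration of this aggregation, the sketch below computes the mean and standard deviation over per-fold accuracies. The fold values are placeholders rather than results, and the use of the sample standard deviation is an assumption; `ModelAssessment.assess()` may compute it differently.

```python
import statistics

def summarize_folds(fold_accuracies):
    """Return (mean, sample standard deviation) of the per-fold accuracies."""
    return statistics.mean(fold_accuracies), statistics.stdev(fold_accuracies)

# Placeholder accuracies for 5 folds (illustrative values only, not results).
fold_acc = [0.70, 0.72, 0.74, 0.76, 0.78]
fold_mean, fold_std = summarize_folds(fold_acc)
print(f"{fold_mean:.3f} +/- {fold_std:.3f}")
```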

This notebook trains models on **three datasets**:

- **D&D**: A protein-structure graph dataset with 2 classes.
- **PROTEINS**: Another protein graph dataset with 2 classes.
- **IMDB-BINARY**: A social network dataset used for binary graph classification.

We conduct several experiments to evaluate the effect of various training and model selection techniques. The benchmark begins with a **standard evaluation setup** based on the experimental framework from [Errica et al., 2020](https://arxiv.org/abs/1912.09893).

We then compare the results of each method across the datasets under **three different experimental settings**:

- **With and without Early Stopping** during training.
- **With and without Node Degree features** (i.e., including the node degree as an additional feature).
- Using either **Grid Search** or **Random Search** during hyperparameter tuning.

The results of these experiments will be visualized and analyzed to draw insights into the performance of different architectures and the impact of these common training strategies.

We begin by initializing the results dictionary below:

In [3]:
metrics = ["benchmark", "random", "early", "node_features"]
datasets = ["DD", "IMDB", "PROTEINS"]

acc = {metric: {ds: [] for ds in datasets} for metric in metrics}
std_acc = {metric: {ds: [] for ds in datasets} for metric in metrics}

# Example access:
# acc["benchmark"]["DD"].append(0.81)
# std_acc["early"]["PROTEINS"].append(0.05)

### **Model Selection**

As explained in *toy_examples.ipynb*, we compare the following GNN architectures:

- **GCN**: Graph Convolutional Network  
- **GAT**: Graph Attention Network  
- **GraphSAGE**: Neighborhood aggregation-based GNN  
- **GIN**: Graph Isomorphism Network

Each model is trained with a specific set of core hyperparameters, including:

- Number of convolutional layers  
- Embedding dimension  
- Learning rate  

---

### **GNN Performance Comparison on 3 Datasets for Graph Classification**

We now begin the training phase by initializing the datasets.

In [None]:
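# Initialize the datasets.
# Note: these assignments shadow the imported class names, so re-running this cell
# will likely fail unless the import cell is executed again first.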
DD = DD()
IMDB = IMDBBinary()
PROTEINS = PROTEINS()

#### Model Training — All Architectures on All Datasets

In this section, we loop through all selected **datasets** (`DD`, `IMDB`, `PROTEINS`) and **model architectures** (`GCN`, `GAT`, `GIN`, `GraphSAGE`, `Baseline`) to train and evaluate each combination.

The training is performed using **grid search** for hyperparameter selection, and the resulting mean accuracy and standard deviation are recorded for each run.


In [None]:
# List of datasets
datasets = {
    "DD": DD,
    "IMDB": IMDB,
    "PROTEINS": PROTEINS 
}

# List of model types to train & Evaluate
model_types = ["GCN", "GAT", "GIN", "GraphSAGE", "Baseline"]

for dataset_name, dataset_obj in datasets.items():
    print(f"\n Training on {dataset_name} dataset")

    for model in model_types:
        print(f"  ➤ Model: {model}")
        
        params_list = params.copy()
        params_list["model_type"] = model
        
        # Initialize and run ModelAssessment
        model_assessment = ModelAssessment(dataset_obj, params_list, random_search=False)
        mean, std = model_assessment.assess()
        
        # Store the results
        acc["benchmark"][dataset_name].append(mean)
        std_acc["benchmark"][dataset_name].append(std)


### **Experiment 1: Evaluating GNN Performance With and Without Early Stopping**

In this experiment, we compare the performance of Graph Neural Networks (GNNs) when trained **with** and **without early stopping**.

**Early stopping** is a regularization technique that halts training when the model’s performance on a validation set no longer improves, helping to prevent overfitting and reduce unnecessary computation.
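
A minimal sketch of a patience-based stopping rule is shown below. The class `EarlyStopper`, its `patience` and `min_delta` arguments, and the validation history are all hypothetical; they illustrate the general technique rather than the rule implemented inside `ModelAssessment`.

```python
class EarlyStopper:
    """Stop when validation accuracy has not improved for `patience` consecutive epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_accuracy):
        """Record one epoch and return True when training should stop."""
        if val_accuracy > self.best + self.min_delta:
            self.best = val_accuracy
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Placeholder validation accuracies (illustrative values only, not results).
val_history = [0.60, 0.65, 0.64, 0.65, 0.70]
stopper = EarlyStopper(patience=2)
stopped_at = None
for epoch, val_acc in enumerate(val_history):
    if stopper.step(val_acc):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)
```

In this toy run, training halts before the final epoch, which shows the trade-off of a short patience: computation is saved, but a later improvement can be missed.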

---

In [None]:
# Update params: enable early stopping
params["early_stopping"] = True

datasets = {
    "DD": DD,
    "IMDB": IMDB,
    "PROTEINS": PROTEINS 
}

# List of model types to train & Evaluate
model_types = ["GCN", "GAT", "GIN", "GraphSAGE", "Baseline"]

for dataset_name, dataset_obj in datasets.items():
    print(f"\n Training on {dataset_name} dataset")

    for model in model_types:
        print(f"  ➤ Model: {model}")
        
        params_list = params.copy()
        params_list["model_type"] = model
        
        # Initialize and run ModelAssessment
        model_assessment = ModelAssessment(dataset_obj, params_list, random_search=False)
        mean, std = model_assessment.assess()
        
        # Store the results of the early-stopping run
        acc["early"][dataset_name].append(mean)
        std_acc["early"][dataset_name].append(std)


### **Experiment 2: Evaluating GNN Performance With and Without Node Degree Features**

In this experiment, we evaluate the impact of adding **node degree information** to the feature matrices of the graphs.

We compare the performance of each GNN architecture when trained:
- **With node degrees** included as an additional feature per node.
- **Without node degrees**, using only the original node features.

This allows us to assess whether incorporating structural information like node connectivity can improve model accuracy across datasets.
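
To make the idea concrete, here is a minimal sketch of appending the node degree to each node's feature vector. It assumes an undirected graph given as an edge list with each edge listed once; the helper `add_degree_feature` is hypothetical and independent of how the `Node_Degree` option is implemented in the pipeline.

```python
def add_degree_feature(node_features, edge_list):
    """Append each node's degree as an extra column of its feature vector."""
    degree = [0] * len(node_features)
    for u, v in edge_list:  # undirected edges, each listed once
        degree[u] += 1
        degree[v] += 1
    return [row + [float(d)] for row, d in zip(node_features, degree)]

# Toy path graph 0 - 1 - 2 with two-dimensional one-hot features.
toy_features = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
toy_edges = [(0, 1), (1, 2)]
augmented = add_degree_feature(toy_features, toy_edges)
print(augmented)  # [[1.0, 0.0, 1.0], [0.0, 1.0, 2.0], [1.0, 0.0, 1.0]]
```

Nodes 0 and 2 share the same original features as each other, and the degree column only separates them from node 1, which illustrates what kind of structural signal this feature can and cannot add.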

---

In [None]:
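# Update params: add node degree features (early stopping is switched back off
# so that this run differs from the benchmark only in the node features)
params["early_stopping"] = False
params["Node_Degree"] = True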

datasets = {
    "DD": DD,
    "IMDB": IMDB,
    "PROTEINS": PROTEINS 
}

# List of model types to train & Evaluate
model_types = ["GCN", "GAT", "GIN", "GraphSAGE", "Baseline"]

for dataset_name, dataset_obj in datasets.items():
    print(f"\n Training on {dataset_name} dataset")

    for model in model_types:
        print(f"  ➤ Model: {model}")
        
        params_list = params.copy()
        params_list["model_type"] = model
        
        # Initialize and run ModelAssessment
        model_assessment = ModelAssessment(dataset_obj, params_list, random_search=False)
        mean, std = model_assessment.assess()
        
        # Store the results of the node-degree run
        acc["node_features"][dataset_name].append(mean)
        std_acc["node_features"][dataset_name].append(std)


### **Experiment 3: Comparing Grid Search and Random Search for Hyperparameter Tuning**

In this experiment, we compare two strategies for **hyperparameter tuning** during model selection:

- **Grid Search**: systematically explores all combinations of specified hyperparameter values.
- **Random Search**: samples a subset of random combinations from the hyperparameter space.

This comparison allows us to evaluate the trade-off between **exploration efficiency** and **computational cost**, and to assess whether Random Search can achieve comparable performance with fewer evaluations.
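
The difference between the two strategies can be sketched as follows: grid search enumerates every combination, while random search draws only a fixed number of them. The helper `sample_configs`, the budget of 9 samples, and sampling without replacement are assumptions for illustration, not a description of what `random_search=True` does internally.

```python
import random
from itertools import product

def sample_configs(search_space, n_samples, seed):
    """Draw `n_samples` distinct combinations from the full grid, without replacement."""
    keys = list(search_space)
    grid = [dict(zip(keys, values))
            for values in product(*(search_space[k] for k in keys))]
    rng = random.Random(seed)
    return rng.sample(grid, min(n_samples, len(grid)))

demo_space = {
    "lr": [0.001, 0.01, 0.1],
    "num_lay": [2, 3, 5],
    "hidden_agg_lay_size": [16, 32, 64],
}
sampled = sample_configs(demo_space, n_samples=9, seed=42)
print(len(sampled))  # 9 of the 27 grid points are evaluated
```

Under these assumptions, random search would train a third as many configurations as the full grid, at the risk of never visiting the best one.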

---

In [None]:
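# Update params: restore the benchmark settings so that only the search strategy changes
params["early_stopping"] = False
params.pop("Node_Degree", None)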

datasets = {
    "DD": DD,
    "IMDB": IMDB,
    "PROTEINS": PROTEINS 
}

# List of model types to train & Evaluate
model_types = ["GCN", "GAT", "GIN", "GraphSAGE", "Baseline"]

for dataset_name, dataset_obj in datasets.items():
    print(f"\n Training on {dataset_name} dataset")

    for model in model_types:
        print(f"  ➤ Model: {model}")
        
        params_list = params.copy()
        params_list["model_type"] = model
        
        # Initialize and run ModelAssessment
        model_assessment = ModelAssessment(dataset_obj, params_list, random_search=True)  # use random search here
        mean, std = model_assessment.assess()
        
        # Store the results of the random-search run
        acc["random"][dataset_name].append(mean)
        std_acc["random"][dataset_name].append(std)


### **Visualizing Experimental Results with Boxplots**

In this section, we visualize the outcomes of our different experiments using **boxplots**.  
These plots illustrate the **performance distribution** (in terms of accuracy) of each GNN architecture across the **three datasets** (`D&D`, `PROTEINS`, `IMDB-BINARY`) under the various **experimental settings**:

- With and without **Early Stopping**
- With and without **Node Degree Features**
- Using **Grid Search** vs. **Random Search** for model selection

Boxplots provide a clear comparison of the **variance**, **robustness**, and **average accuracy** of each method, allowing us to draw conclusions on which settings contribute most to model performance and generalization.

---


In [None]:
models = ["GCN", "GAT", "GIN", "GraphSAGE", "Baseline"]

# Plot the results of Experiment 1
plot_gnn_comparison(models, datasets, acc["benchmark"], std_acc["benchmark"], acc["early"], std_acc["early"],
                    "Comparative Study of GNNs - Impact of Early Stopping on Training",
                    "Without Early Stopping", "With Early Stopping")

# Plot the results of Experiment 2
plot_gnn_comparison(models, datasets, acc["benchmark"], std_acc["benchmark"], acc["node_features"], std_acc["node_features"],
                    "Comparative Study of GNNs - Impact of Node Features on Training",
                    "Without Node Features", "With Node Features")

# Plot the results of Experiment 3
plot_gnn_comparison(models, datasets, acc["benchmark"], std_acc["benchmark"], acc["random"], std_acc["random"],
                    "Comparative Study of GNNs - Impact of Random Search on Training",
                    "Without Random Search", "With Random Search")