# 📘 Training Graph Neural Networks with a Custom `Train` Class

This notebook provides an in-depth walkthrough of a custom `Train` class designed for training Graph Neural Networks (GNNs) using PyTorch. It supports multiple model types including GCN, GAT, GraphSAGE, and GIN, and facilitates full training loops with evaluation and early stopping.

---

## 🔧 **1. Setup and Requirements**

Ensure you have the following modules and packages installed:

```python
!pip install torch torchvision torchaudio tqdm
# Make sure you also have your custom modules:
# - datasets
# - src.models
# - utils.utils
```

## 📦 **2. Overview of the `Train` Class**

The `Train` class handles the full training lifecycle:

- Dataset loading (IMDB or DD)

- Model initialization (GCN, GAT, GraphSAGE, GIN)

- Graph subsampling (optional)

- Training loop with optimizer and scheduler

- Evaluation

## 🧠 **3. Initialization**

```python
from Pipeline.train import Train
```

```python
params = {
    "model_type": "GCN",             # one of "GCN", "GAT", "GIN", "GraphSAGE"
    "n_graph_subsampling": 0,        # number of subsampling runs per training graph;
                                     # e.g. 5 enlarges the training set 5 times
    "graph_node_subsampling": True,  # True: subsample/augment by randomly removing nodes
                                     # False: subsample/augment by randomly removing edges
    "graph_subsampling_rate": 0.2,   # graph subsampling rate
    "dataset": "DD",
    "pooling_type": "mean",
    "seed": 42,
    "n_folds": 10,
    "cuda": True,
    "lr": 0.001,
    "epochs": 50,
    "weight_decay": 5e-4,
    "batch_size": 32,
    "dropout": 0,                    # dropout rate of each layer
    "num_lay": 5,
    "num_agg_layer": 2,              # number of graph aggregation layers
    "hidden_agg_lay_size": 64,       # size of the hidden graph aggregation layer
    "fc_hidden_size": 128,           # size of the fully-connected layer after readout
    "threads": 10,                   # number of subprocesses used for data loading
    "random_walk": True,
    "walk_length": 20,               # length of each random walk
    "num_walk": 10,                  # number of random walks
    "p": 0.65,                       # probability of returning to the previous vertex (staying local)
    "q": 0.35,                       # probability of moving away from the previous vertex (exploring)
    "print_logger": 10,              # logging interval
    "eps": 0.0,                      # for GIN only
}
```
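To make the two subsampling modes concrete, here is a minimal, library-free sketch of what `graph_node_subsampling` and `graph_subsampling_rate` could mean on a toy edge list. The function `subsample_graph` and its signature are hypothetical and written for illustration only; the actual `Train` class may implement subsampling differently.

```python
import random


def subsample_graph(num_nodes, edges, rate, remove_nodes=True, rng=None):
    """Hypothetical helper: drop a fraction `rate` of nodes (or edges) at random.

    Returns (kept_nodes, kept_edges). This is an illustrative sketch,
    not the code used inside the Train class.
    """
    rng = rng or random.Random()
    if remove_nodes:
        # Node mode: remove nodes, then discard every edge touching a removed node.
        n_drop = int(num_nodes * rate)
        dropped = set(rng.sample(range(num_nodes), n_drop))
        kept_nodes = [v for v in range(num_nodes) if v not in dropped]
        kept_edges = [(u, v) for u, v in edges
                      if u not in dropped and v not in dropped]
    else:
        # Edge mode: all nodes survive, only a random subset of edges is removed.
        n_drop = int(len(edges) * rate)
        dropped_idx = set(rng.sample(range(len(edges)), n_drop))
        kept_nodes = list(range(num_nodes))
        kept_edges = [e for i, e in enumerate(edges) if i not in dropped_idx]
    return kept_nodes, kept_edges


# Toy graph: a ring of 10 nodes with 10 edges.
ring_edges = [(i, (i + 1) % 10) for i in range(10)]
rng = random.Random(42)

nodes_kept, edges_after_node_drop = subsample_graph(10, ring_edges, 0.2, True, rng)
all_nodes, edges_after_edge_drop = subsample_graph(10, ring_edges, 0.2, False, rng)
```

With a rate of 0.2, node mode keeps 8 of the 10 nodes and loses every edge attached to the removed ones, while edge mode keeps all 10 nodes and exactly 8 edges. Node removal therefore tends to perturb the graph more strongly at the same rate.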

```python
trainer = Train(params)
```

## 🧱 **4. Model Selection**

Based on `params["model_type"]`, one of the following models is instantiated:

- GCN: Graph Convolutional Network

- GAT: Graph Attention Network

- GraphSAGE: Neighborhood aggregation

- GIN: Graph Isomorphism Network

Each model is passed the following key hyperparameters:

- Number of layers

- Hidden layer sizes

- Dropout

- Pooling method
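The pooling method controls how per-node embeddings are collapsed into one graph-level vector (the readout) before the fully-connected layer. The sketch below shows the idea in plain Python. The `readout` function is hypothetical, and the `"sum"` and `"max"` options are assumptions added for comparison; only `"mean"` appears in the parameters above.

```python
def readout(node_embeddings, pooling_type="mean"):
    """Hypothetical readout: reduce a list of node vectors to one graph vector."""
    if not node_embeddings:
        raise ValueError("graph has no nodes to pool over")
    # Transpose so that each tuple holds one feature across all nodes.
    columns = list(zip(*node_embeddings))
    if pooling_type == "mean":
        return [sum(col) / len(col) for col in columns]
    if pooling_type == "sum":   # assumed option, not confirmed by the notebook
        return [sum(col) for col in columns]
    if pooling_type == "max":   # assumed option, not confirmed by the notebook
        return [max(col) for col in columns]
    raise ValueError(f"unknown pooling_type: {pooling_type!r}")


# Three nodes, each with a 2-dimensional embedding.
embeddings = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
graph_vector_mean = readout(embeddings, "mean")
graph_vector_sum = readout(embeddings, "sum")
graph_vector_max = readout(embeddings, "max")
```

Mean pooling keeps the graph vector on the same scale regardless of graph size, which can matter when the graphs in a dataset vary widely in node count.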

## 🚀 **5. Training and Evaluation**

Calling `.fit()` on a `Train` object runs the full training and evaluation loop. It:

- Monitors accuracy

- Applies early stopping after `patience` epochs without improvement

It returns `['IMDB', 'IMDB', best_accuracy]`.

```python
trainer.fit()
```
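The early stopping rule can be sketched as a small counter that resets whenever accuracy improves. The `EarlyStopping` class, the `patience` value of 3, and the accuracy values below are all hypothetical and chosen for illustration; they are not taken from the `Train` implementation, which may track a different metric or use a different patience.

```python
class EarlyStopping:
    """Hypothetical helper: stop when accuracy has not improved for `patience` epochs."""

    def __init__(self, patience):
        self.patience = patience
        self.best_accuracy = float("-inf")
        self.epochs_without_improvement = 0

    def step(self, accuracy):
        """Record one epoch's accuracy; return True when training should stop."""
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Made-up validation accuracies, one per epoch.
accuracies = [0.61, 0.66, 0.70, 0.69, 0.70, 0.68, 0.72]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, acc in enumerate(accuracies):
    if stopper.step(acc):
        stopped_at = epoch
        break
```

In this run the best accuracy of 0.70 is reached at epoch 2 and the loop halts at epoch 5, so the later value of 0.72 is never seen. That is the usual trade-off of a small patience: less wasted compute in exchange for the risk of stopping before a late improvement.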