In [1]:
import sys
import numpy as np
sys.path.append('../datasets')
from datasets.manager import IMDBBinary, DD
import torch 
import itertools
from tqdm import tqdm

#from utils.utils import visualise_graph, get_adjacency_and_features
from utils.utils import get_adjacency_and_features, create_batch_from_loader
#from src.gnn import GNNClassifier

from datasets.dataset import *
from train import Training

import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

import random

import numpy as np
import time
import statistics
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

from src.models import GCN, GAT, GraphDenseNet, GraphSAGE, GIN
from datasets.dataloader import DataLoader

In [2]:
params = {
    "model_type": "GraphSAGE",  # one of "GCN", "GAT", "GIN", "GraphSAGE"
    "n_graph_subsampling": 0,  # number of subsampling runs per training graph (e.g. 5 runs -> 5x more graph data); 0 disables it
    "graph_node_subsampling": True,  # True: augment by randomly removing nodes; False: augment by randomly removing edges
    "graph_subsampling_rate": 0.2,  # graph subsampling rate (presumably the fraction of nodes/edges removed — confirm)
    "dataset": "DD",  # "DD" or "IMDB"
    "pooling_type": "mean",  # readout/pooling type passed to the model as pool_type
    "seed": 42,
    "n_folds": 10,  # NOTE(review): not read by the Train class in this notebook
    "cuda": True,  # use "cuda:0" when a GPU is available, otherwise fall back to CPU
    "lr": 0.001,
    "epochs": 50,
    "weight_decay": 5e-4,
    "batch_size": 32,  # NOTE(review): the Train class iterates one graph at a time and does not read this
    "dropout": 0,  # dropout rate of the layers
    "num_lay": 5,  # NOTE(review): not read by the Train class in this notebook
    "num_agg_layer": 2,  # number of graph aggregation layers
    "hidden_agg_lay_size": 64,  # size of each hidden graph aggregation layer
    "fc_hidden_size": 128,  # size of the fully-connected layer after readout
    "threads": 10,  # intended number of data-loading subprocesses; NOTE(review): the Train loader hard-codes num_workers=0
    "random_walk": True,  # NOTE(review): random-walk keys below are not read by the Train class in this notebook
    "walk_length": 20,  # length of each random walk
    "num_walk": 10,  # number of random walks
    "p": 0.65,  # walk "return" parameter (how likely the walk goes back to the previous vertex) — confirm against the walk implementation
    "q": 0.35,  # walk "in-out" parameter (how likely the walk moves away to explore new vertices) — confirm against the walk implementation
    "print_logger": 10,  # print a training log line every N iterations
    "eps": 0.0,  # for GIN only; NOTE(review): not forwarded to the GIN constructor by the Train class
    "early_stopping_patience": 5,  # epochs without accuracy improvement before fit() stops (fit's default when absent)
}

In [3]:
class Train:
    """
    Trainer class for training and evaluating a GNN graph-classification model.

    Args:
        params (dict): Dictionary of training and model parameters. Keys read:
            "dataset", "model_type", "cuda", "seed", "lr", "weight_decay",
            "epochs", "print_logger", "num_agg_layer", "hidden_agg_lay_size",
            "fc_hidden_size", "dropout", "pooling_type". Optional keys:
            "early_stopping_patience" (default 5), "lr_milestones"
            (default [20, 30]) and "lr_gamma" (default 0.1).

    Note:
        - The dataset is loaded according to params["dataset"]; only "IMDB"
          and "DD" are supported.
        - fit() trains and evaluates on the same loader (no train/test split),
          so the reported accuracy is measured on the training data.
    """

    # Values accepted for params["model_type"].
    SUPPORTED_MODELS = ("GCN", "GAT", "GraphSAGE", "GIN")

    def __init__(self, params):
        """
        Load the dataset, pick the device and build the model.

        Raises:
            ValueError: If params["dataset"] is neither "IMDB" nor "DD", or
                params["model_type"] is not supported.
        """
        self.params = params

        # Seed every RNG in use so weight init and loader shuffling are reproducible.
        seed = params.get("seed")
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        if params["dataset"] == "IMDB":
            self.dataset = IMDBBinary()
        elif params["dataset"] == "DD":
            self.dataset = DD()
        else:
            raise ValueError(
                f"Unsupported dataset {params['dataset']!r}; expected 'IMDB' or 'DD'."
            )

        # Extract the graph data and their corresponding labels
        self.x_dataset, self.y_dataset = self.dataset.dataset.get_data(), self.dataset.dataset.get_targets()

        # Select device (GPU if available and requested, else CPU)
        if self.params["cuda"] and torch.cuda.is_available():
            self.device = "cuda:0"
        else:
            self.device = "cpu"

        self.model = self.get_model()  # Instantiate the model
        self.loss_fn = F.cross_entropy  # Loss function for classification tasks

    def get_model(self):
        """
        Instantiate the model named by params["model_type"].

        Returns:
            torch.nn.Module: The GNN model moved to self.device.

        Raises:
            ValueError: If params["model_type"] is not in SUPPORTED_MODELS.
        """
        model_type = self.params["model_type"]
        if model_type not in self.SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported model_type {model_type!r}; expected one of {self.SUPPORTED_MODELS}."
            )

        # All supported architectures are built with the same keyword arguments.
        model_classes = {"GCN": GCN, "GAT": GAT, "GraphSAGE": GraphSAGE, "GIN": GIN}
        input_features = self.x_dataset[0].x.shape[1]  # Number of input node features

        model = model_classes[model_type](
            n_feat=input_features,
            n_class=2,  # hard-coded: assumes a binary classification dataset
            n_layer=self.params['num_agg_layer'],
            agg_hidden=self.params['hidden_agg_lay_size'],
            fc_hidden=self.params['fc_hidden_size'],
            dropout=self.params['dropout'],
            pool_type=self.params['pooling_type'],
            device=self.device
        )
        return model.to(self.device)

    def loaders_train_test_setup(self):
        """
        Set up data loading, optimizer, and learning rate scheduler.

        Returns:
            Tuple: (DataLoader, optimizer, scheduler)
        """
        # A DataLoader over dataset indices: each batch is a one-element list holding an index.
        loader = torch.utils.data.DataLoader(
            range(len(self.x_dataset)),
            batch_size=1,
            shuffle=True,
            num_workers=0,
            pin_memory=True,
            drop_last=False,
            collate_fn=lambda x: x  # x will be a list of one index
        )

        # Count and display number of trainable parameters
        c = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print('N trainable parameters:', c)

        # Define Adam optimizer with weight decay
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            lr=self.params["lr"],
            weight_decay=self.params["weight_decay"],
            betas=(0.5, 0.999)
        )

        # Multiply the LR by `gamma` at each milestone epoch (defaults: x0.1 at epochs 20 and 30).
        milestones = self.params.get("lr_milestones", [20, 30])
        gamma = self.params.get("lr_gamma", 0.1)
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones, gamma=gamma)

        return loader, optimizer, scheduler

    def _to_target(self, y):
        """Convert one raw label (int, numpy scalar or tensor) to a shape-(1,) int64 tensor on self.device."""
        # reshape(-1) gives exactly one batch dimension whether y is a scalar or already 1-element.
        return torch.as_tensor(y, device=self.device).long().reshape(-1)

    def _forward(self, x):
        """Run the model on one graph, building the model-specific inputs."""
        if self.params["model_type"] == "GraphSAGE":
            # beware, GraphSAGE does not take adjacency matrix as input
            # NOTE(review): x is passed as-is (not moved to self.device); presumably the model handles it.
            return self.model(x)
        A, f = get_adjacency_and_features(x)
        return self.model(f.to(self.device), A.to(self.device))

    def train(self, train_loader, optimizer, scheduler, epoch):
        """
        Run one training epoch over the dataset.

        Args:
            train_loader (DataLoader): yields one-element lists of dataset indices
            optimizer (Optimizer)
            scheduler (Scheduler): stepped once at the end of the epoch
            epoch (int): Current epoch index (used for logging)

        Returns:
            float: Average time per iteration in seconds (0.0 for an empty loader)
        """
        self.model.train()
        train_loss, n_samples = 0.0, 0
        total_time_iter = 0.0
        n_batches = len(train_loader)
        start = time.time()

        for batch_idx, data_batch in enumerate(train_loader):
            idx = data_batch[0]  # Extract index

            x = self.x_dataset[idx]
            y = self._to_target(self.y_dataset[idx])

            optimizer.zero_grad()
            output = self._forward(x)  # Forward pass

            # Assumes the model returns unbatched logits of shape (n_class,); add the batch dim.
            loss = self.loss_fn(output.unsqueeze(0), y)

            loss.backward()
            optimizer.step()

            # Timing and logging
            time_iter = time.time() - start
            total_time_iter += time_iter
            train_loss += loss.item()
            n_samples += 1

            if batch_idx % self.params["print_logger"] == 0 or batch_idx == n_batches - 1:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} (avg: {:.6f}) \tsec/iter: {:.4f}'.format(
                    epoch, n_samples, len(train_loader.dataset),
                    100. * (batch_idx + 1) / n_batches,
                    loss.item(), train_loss / n_samples, total_time_iter / n_samples
                ))

            start = time.time()  # Reset timer

        scheduler.step()  # Adjust learning rate
        return total_time_iter / n_samples if n_samples else 0.0

    def evaluate(self, test_loader, epoch=None):
        """
        Evaluate the model on the given loader.

        Args:
            test_loader (iterable): yields one-element lists of dataset indices
            epoch (int, optional): Epoch index shown in the log line

        Returns:
            float: Accuracy in percent (0.0 when the loader is empty)
        """
        self.model.eval()
        correct, n_samples = 0, 0

        with torch.no_grad():
            for data_batch in test_loader:
                idx = data_batch[0]
                x = self.x_dataset[idx]
                y = self._to_target(self.y_dataset[idx])

                output = self._forward(x)

                # Prediction: binary (single logit) or multi-class
                if output.shape[-1] == 1:
                    pred = (torch.sigmoid(output) > 0.5).long()
                else:
                    pred = output.argmax(dim=-1)

                correct += (pred == y).sum().item()
                n_samples += 1

        acc = 100. * correct / n_samples if n_samples else 0.0
        epoch_label = f" (epoch {epoch})" if epoch is not None else ""
        print(f'Test set{epoch_label}: Accuracy: {correct}/{n_samples} ({acc:.2f}%)\n')

        return acc

    def fit(self):
        """
        Run the full training and evaluation loop with early stopping.

        Returns:
            list: [dataset name, dataset name (again), best accuracy achieved]
        """
        loader, optimizer, scheduler = self.loaders_train_test_setup()
        total_time = 0.0
        best_acc = 0.0
        patience_counter = 0
        patience = self.params.get("early_stopping_patience", 5)
        epochs_run = 0

        for epoch in tqdm(range(self.params["epochs"]), desc="Epochs", position=1, leave=False):
            # Wall-clock time of the whole training epoch (train() itself returns a per-iteration average).
            epoch_start = time.time()
            self.train(loader, optimizer, scheduler, epoch)
            total_time += time.time() - epoch_start
            epochs_run += 1

            acc = self.evaluate(loader, epoch=epoch)  # Same loader used for train/test (no split)

            # Early stopping logic
            if acc > best_acc:
                best_acc = acc
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"Early stopping triggered after {epoch + 1} epochs.")
                break

        print(f'Best Accuracy: {best_acc:.2f}%')
        if epochs_run:
            print(f'Average training time per epoch: {total_time / epochs_run:.2f} seconds')

        return [self.params["dataset"], self.params["dataset"], best_acc]


In [4]:
# IMDB = IMDBBinary()
# DD = DD()

In [5]:
# Build the trainer: loads the configured dataset and instantiates the model.
trainer = Train(params=params)

In [6]:
# Run training with early stopping; the returned list is displayed as the cell output.
fit_results = trainer.fit()
fit_results

N trainable parameters: 36482










KeyboardInterrupt: 

In [1]:
# Inspect the first graph of the dataset to sanity-check its node and edge tensors.
x_data = trainer.x_dataset
model = trainer.model
x0 = x_data[0]
x = x0.x
# Shift edge indices down by one: the raw indices appear to be 1-based (after the shift the
# next cell reports min 0 / max 326 for 327 nodes). NOTE(review): fix this in the dataset loader.
edge_index = x0.edge_index - 1
print(x.shape, edge_index.shape, sep="\n")

NameError: name 'trainer' is not defined

In [7]:
# Sanity check before feeding the model ("Avant modèle" = "before the model"):
# every edge endpoint must be a valid row index into the node-feature matrix x.
print("Avant modèle :")
print("edge_index min :", edge_index.min().item())
print("edge_index max :", edge_index.max().item())
print("x.shape[0] :", x.shape[0])
assert edge_index.min().item() >= 0, "edge_index contains negative node indices"
assert edge_index.max().item() < x.shape[0], "edge_index references nodes beyond x.shape[0]"

Avant modèle :
edge_index min : 0
edge_index max : 326
x.shape[0] : 327
