# In [None]:
"""
Contains functions for training and testing a PyTorch model.
"""

import numpy as np
from tqdm import tqdm
import time  
import matplotlib.pyplot as plt

import rdkit
from rdkit import Chem

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_scatter import scatter

import torch
import torch.nn as nn
import torch.nn.functional as F


#### @title [RUN] Helper functions for managing experiments, training, and evaluating models.

def train(model, loader, optimizer, device):
    """Train ``model`` for one epoch and return the mean per-graph loss.

    Args:
        model: Callable module; ``model(batch)`` must return a mapping that
            contains the scalar training loss under the key ``'bce_loss'``.
        loader: Iterable of batches. Each batch must provide ``.to(device)``
            and ``.num_graphs``.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device every batch is moved to before the forward pass.

    Returns:
        The loss averaged over the graphs actually processed this epoch, or
        ``0.0`` if the loader yielded no graphs.
    """
    model.train()
    loss_sum = 0.0
    graphs_seen = 0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        output = model(batch)
        loss = output['bce_loss']
        loss.backward()
        # Weight each batch loss by its graph count so that batches of
        # different sizes contribute proportionally.
        # NOTE(review): this assumes 'bce_loss' is a per-batch mean - confirm
        # against the model's forward().
        loss_sum += loss.item() * batch.num_graphs
        graphs_seen += batch.num_graphs
        optimizer.step()

    # Normalise by the graphs that were really seen, not len(loader.dataset):
    # the two differ when the loader drops the last batch or samples a subset.
    if graphs_seen == 0:
        return 0.0
    return loss_sum / graphs_seen


def eval(model, loader, device):
    """Return the accuracy of ``model`` over every batch in ``loader``.

    Note: this function shadows the builtin ``eval`` inside this module; the
    name is kept because callers depend on it.

    Args:
        model: Callable module; ``model(batch)`` must return a mapping that
            contains the predicted labels under the key ``'preds'``.
        loader: Iterable of batches. Each batch must provide ``.to(device)``
            and a target tensor ``.y``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Number of predictions equal to their target divided by the number of
        targets seen (``batch.y.size(0)`` summed over batches), or ``0.0``
        if the loader yielded no targets.
    """
    model.eval()
    total_correct = 0
    total = 0

    # Gradients are never needed during evaluation.
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            predictions = model(batch)['preds']
            targets = batch.y

            # Guard against silent broadcasting: comparing e.g. [N, 1] with
            # [N] would yield an [N, N] matrix and inflate the correct count.
            if (predictions.shape != targets.shape
                    and predictions.numel() == targets.numel()):
                predictions = predictions.reshape(targets.shape)

            total_correct += (predictions == targets).sum().item()
            total += targets.size(0)

    if total == 0:
        return 0.0
    return total_correct / total



def _plot_training_curves(train_losses, val_accuracies, test_accuracies):
    """Show training loss, validation accuracy and test accuracy per epoch."""
    epochs = range(1, len(train_losses) + 1)
    # (series, line format, legend label, y-axis label, panel title)
    panels = [
        (train_losses, 'b', 'Training Loss', 'Loss', 'Training Loss'),
        (val_accuracies, 'r', 'Validation Accuracy', 'Accuracy', 'Validation Accuracy'),
        (test_accuracies, 'g', 'Test Accuracy', 'Accuracy', 'Test Accuracy'),
    ]

    plt.figure(figsize=(14, 5))
    for position, (series, fmt, label, ylabel, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.plot(epochs, series, fmt, label=label)
        plt.xlabel('Epochs')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()

    plt.tight_layout()
    plt.show()


def run_experiment(model, model_name, train_loader, val_loader, test_loader, n_epochs=100, lr=0.00001):
    """Train ``model``, track validation/test accuracy and plot the curves.

    The model is trained with Adam for ``n_epochs`` epochs. After each epoch
    it is evaluated on the validation set; whenever the validation accuracy
    reaches a new best, the test accuracy is re-evaluated and remembered.

    Args:
        model: Model to train; it is moved to CUDA when available, else CPU.
        model_name: Label used in log messages and in ``perf_per_epoch``.
        train_loader: Loader passed to ``train``; must expose ``.dataset``.
        val_loader: Loader used for the per-epoch validation accuracy.
        test_loader: Loader used for the test accuracy.
        n_epochs: Number of training epochs.
        lr: Initial Adam learning rate. With the default (1e-5) the rate is
            already at the scheduler's ``min_lr`` floor, so it never decays;
            pass a larger value to make the plateau scheduler effective.

    Returns:
        Tuple ``(best_val_accuracy, corresponding_test_accuracy, train_time,
        perf_per_epoch)`` where ``train_time`` is in minutes and
        ``perf_per_epoch`` holds one ``(test_accuracy, val_accuracy, epoch,
        model_name)`` tuple per epoch.
    """
    print(f"Running experiment for {model_name}, training on {len(train_loader.dataset)} samples for {n_epochs} epochs.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("\nModel architecture:")
    print(model)
    total_param = sum(param.numel() for param in model.parameters())
    print(f'Total parameters: {total_param}')
    model = model.to(device)

    # Adam optimizer; the learning rate comes from the ``lr`` argument.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # LR scheduler which decays the LR when the validation metric stops
    # improving. The monitored metric is an accuracy (higher is better), so
    # the scheduler must run in 'max' mode.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.9, patience=5, min_lr=0.00001)

    print("\nStart training:")
    best_val_accuracy = 0  # Start at 0 so the first comparison is numeric
    corresponding_test_accuracy = 0  # Test accuracy at the best validation epoch

    train_losses = []  # Training loss per epoch
    val_accuracies = []  # Validation accuracy per epoch
    test_accuracies = []  # Test accuracy (at best validation so far) per epoch
    perf_per_epoch = []  # Track performance per epoch
    start_time = time.time()
    for epoch in range(1, n_epochs + 1):

        # Read the learning rate in effect for this epoch (for logging only)
        current_lr = optimizer.param_groups[0]['lr']

        # Train model for one epoch, return avg. training loss
        avg_epoch_loss = train(model, train_loader, optimizer, device)
        train_losses.append(avg_epoch_loss)

        # Evaluate model on validation set
        val_accuracy = eval(model, val_loader, device)
        val_accuracies.append(val_accuracy)

        # Only when validation accuracy improves, evaluate on the test set
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            corresponding_test_accuracy = eval(model, test_loader, device)

        test_accuracies.append(corresponding_test_accuracy)

        if epoch % 10 == 0:
            # Print stats every 10 epochs
            print(f'Epoch: {epoch:03d}, LR: {current_lr:.6f}, Loss: {avg_epoch_loss:.7f}, Val Accuracy: {val_accuracy:.7f}, Test Accuracy: {corresponding_test_accuracy:.7f}')

        scheduler.step(val_accuracy)
        perf_per_epoch.append((corresponding_test_accuracy, val_accuracy, epoch, model_name))

    train_time = (time.time() - start_time) / 60
    print(f"\nDone! Training took {train_time:.2f} mins. Best validation accuracy: {best_val_accuracy:.7f}, corresponding test accuracy: {corresponding_test_accuracy:.7f}.")

    # Plot the per-epoch curves
    _plot_training_curves(train_losses, val_accuracies, test_accuracies)

    return best_val_accuracy, corresponding_test_accuracy, train_time, perf_per_epoch



