# Graph Neural Network Training Pipeline

Multi-Dataset Graph Classification with Noise-Robust Training

## 1. Setup and Dependencies

In [None]:
!pip install torch_geometric

In [None]:
import os
import gc
import sys
import torch
import pandas as pd
import matplotlib.pyplot as plt
import logging
from tqdm import tqdm
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
import argparse

In [None]:
# Location of the Kaggle dataset that holds the project's helper modules.
helper_scripts_path = '/kaggle/input/d/leonardosandri/myhackatonhelperscripts/'

if os.path.exists(helper_scripts_path):
    # Add this path to the beginning of Python's search list so the helper
    # modules are found before any same-named installed package.
    sys.path.insert(0, helper_scripts_path)
    print(f"Successfully added '{helper_scripts_path}' to sys.path.")
    print(f"Contents of '{helper_scripts_path}': {os.listdir(helper_scripts_path)}") # Verify
else:
    print(f"WARNING: Helper scripts path not found: {helper_scripts_path}")
    print("Please ensure 'myhackathonhelperscripts' dataset is correctly added to the notebook.")

# Start import of utils modules
try:
    from preprocessor import MultiDatasetLoader
    from utils import set_seed
    # from conv import GINConv as OriginalRepoGINConv
    from models import GNN
    print("Successfully imported modules.")
except ImportError as e:
    print(f"ERROR importing module: {e}")
    print("Please check that the .py files exist directly under the helper_scripts_path and have no syntax errors.")
    # print("Current sys.path:", sys.path)
    # Re-raise so the cell stops here with the real cause; otherwise the
    # failure resurfaces later as an unrelated NameError (set_seed / GNN).
    raise

# Set the random seed
set_seed()

## 2. Data Preprocessing Functions


In [None]:
def add_zeros(data):
    """Overwrite ``data.x`` with one zero (dtype ``torch.long``) per node.

    The object is modified in place and also returned, so the function can be
    used directly inside a comprehension or as a transform.
    """
    node_count = data.num_nodes
    placeholder_features = torch.zeros(node_count, dtype=torch.long)
    data.x = placeholder_features
    return data

## 3. Training and Evaluation Functions

In [None]:
def train(data_loader, model, optimizer, criterion, device, save_checkpoints, checkpoint_path, current_epoch):
    """Run one training epoch and optionally save a checkpoint afterwards.

    Args:
        data_loader: Sized iterable of batches. Each batch must support
            ``to(device)`` and expose a ``y`` label tensor; the whole batch is
            passed to ``model``.
        model: Module called as ``model(batch)``; its output is treated as
            per-class scores (``argmax(dim=1)`` gives the prediction).
        optimizer: Optimizer that is zeroed and stepped once per batch.
        criterion: Callable ``criterion(output, batch.y)`` returning a scalar loss.
        device: Device every batch is moved to before the forward pass.
        save_checkpoints: When true, the model ``state_dict`` is saved once the
            epoch has finished.
        checkpoint_path: Filename prefix; ``_epoch_<n>.pth`` is appended.
        current_epoch: Zero-based epoch index (the filename uses ``current_epoch + 1``).

    Returns:
        Tuple ``(mean_loss, accuracy)`` where ``mean_loss`` is the average of
        the per-batch losses and ``accuracy`` is correct predictions divided by
        the number of labels seen. Both are ``0.0`` when the loader is empty.

    Raises:
        IndexError: Re-raised from the forward pass after printing batch
            diagnostics.
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for data in tqdm(data_loader, desc="Iterating training graphs", unit="batch"):
        data = data.to(device)
        optimizer.zero_grad()
        try:
            output = model(data)
        except IndexError:
            # Print the batch shape information before propagating, so the
            # offending batch can be identified from the notebook output.
            print(f"Error in batch with {data.num_nodes} nodes, edge_max={data.edge_index.max()}")
            print(f"Batch info: x.shape={data.x.shape}, edge_index.shape={data.edge_index.shape}")
            raise  # bare raise keeps the original traceback intact
        loss = criterion(output, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += (pred == data.y).sum().item()
        total += data.y.size(0)

    # Save checkpoints if required
    if save_checkpoints:
        checkpoint_file = f"{checkpoint_path}_epoch_{current_epoch + 1}.pth"
        torch.save(model.state_dict(), checkpoint_file)
        print(f"Checkpoint saved at {checkpoint_file}")

    # Guard both divisions so an empty loader yields zeros instead of a
    # ZeroDivisionError.
    num_batches = len(data_loader)
    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = correct / total if total else 0.0
    return mean_loss, accuracy

In [None]:
def evaluate(data_loader, model, device, calculate_accuracy=False):
    """Run the model over a loader without gradients.

    Args:
        data_loader: Sized iterable of batches supporting ``to(device)``; in
            accuracy mode each batch must also expose a ``y`` label tensor.
        model: Module called as ``model(batch)``; ``argmax(dim=1)`` of its
            output is the predicted class.
        device: Device every batch is moved to.
        calculate_accuracy: Selects the return value (see below).

    Returns:
        If ``calculate_accuracy`` is true: tuple ``(mean_loss, accuracy)``,
        where ``mean_loss`` is the average per-batch cross-entropy and
        ``accuracy`` is correct predictions over labels seen (both ``0.0`` for
        an empty loader). Otherwise: a flat list with one predicted class per
        sample, in loader order.
    """
    model.eval()
    correct = 0
    total = 0
    predictions = []
    total_loss = 0
    # Validation loss is always plain cross-entropy, whatever loss is used
    # for training.
    criterion = torch.nn.CrossEntropyLoss()
    with torch.no_grad():
        for data in tqdm(data_loader, desc="Iterating eval graphs", unit="batch"):
            data = data.to(device)
            output = model(data)
            pred = output.argmax(dim=1)

            if calculate_accuracy:
                correct += (pred == data.y).sum().item()
                total += data.y.size(0)
                total_loss += criterion(output, data.y).item()
            else:
                predictions.extend(pred.cpu().numpy())
    if calculate_accuracy:
        # Guard both divisions so an empty loader yields zeros instead of a
        # ZeroDivisionError.
        num_batches = len(data_loader)
        mean_loss = total_loss / num_batches if num_batches else 0.0
        accuracy = correct / total if total else 0.0
        return mean_loss, accuracy
    return predictions

## 4. Utility Functions

In [None]:
def save_predictions(predictions, test_path):
    """Write predictions to ``submission/testset_<name>.csv`` under the CWD.

    ``<name>`` is the name of the directory that contains ``test_path``
    (``data/B/test.json.gz`` -> ``B``). When ``test_path`` has no usable
    directory component, for example a bare dataset name such as ``"A"``, its
    own base name is used instead, so the file is never written as
    ``testset_.csv``.

    Args:
        predictions: Sequence of predicted labels; row ``i`` gets id ``i``.
        test_path: Path of the test file, or a bare dataset name.
    """
    submission_folder = os.path.join(os.getcwd(), "submission")

    test_dir_name = os.path.basename(os.path.dirname(test_path))
    if not test_dir_name:
        # No parent directory name available: fall back to the path's own name.
        test_dir_name = os.path.basename(test_path)

    os.makedirs(submission_folder, exist_ok=True)

    output_csv_path = os.path.join(submission_folder, f"testset_{test_dir_name}.csv")

    output_df = pd.DataFrame({
        "id": list(range(len(predictions))),
        "pred": predictions,
    })

    output_df.to_csv(output_csv_path, index=False)
    print(f"Predictions saved to {output_csv_path}")

In [None]:
def plot_training_progress(train_losses, train_accuracies, val_losses, val_accuracies, output_dir):
    """
    Draw loss and accuracy curves side by side and save them as a PNG.

    Args:
        train_losses: Training loss for each epoch
        train_accuracies: Training accuracy for each epoch
        val_losses: Validation loss for each epoch
        val_accuracies: Validation accuracy for each epoch
        output_dir: Directory (created if missing) that receives
            ``training_progress.png``
    """
    epoch_axis = range(1, len(train_losses) + 1)

    def draw_panel(position, curves, y_label, title):
        # One subplot: every curve in `curves` plotted against the epoch axis.
        plt.subplot(1, 2, position)
        for values, label, color, marker in curves:
            plt.plot(epoch_axis, values, label=label, color=color, marker=marker)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()
        plt.grid(True)

    plt.figure(figsize=(15, 6))

    # Left panel: losses
    draw_panel(
        1,
        [
            (train_losses, "Training Loss", 'blue', 'o'),
            (val_losses, "Validation Loss", 'red', 's'),
        ],
        'Loss',
        'Training and Validation Loss per Epoch',
    )

    # Right panel: accuracies
    draw_panel(
        2,
        [
            (train_accuracies, "Training Accuracy", 'green', 'o'),
            (val_accuracies, "Validation Accuracy", 'orange', 's'),
        ],
        'Accuracy',
        'Training and Validation Accuracy per Epoch',
    )

    # Persist the figure, then display it and release it
    os.makedirs(output_dir, exist_ok=True)
    plt.tight_layout()
    figure_path = os.path.join(output_dir, "training_progress.png")
    plt.savefig(figure_path)
    plt.show()
    plt.close()

## 5. Configuration and Arguments

In [None]:
def get_user_input(prompt, default=None, required=False, type_cast=str):
    """Prompt on stdin until an acceptable answer is given.

    An empty answer is rejected when ``required`` is true; otherwise it yields
    ``default`` (which may be ``None``) without conversion. A non-empty answer
    is passed through ``type_cast``; a ``ValueError`` from the conversion
    triggers a new prompt.
    """
    while True:
        answer = input(f"{prompt} [{default}]: ")

        if answer != "":
            try:
                return type_cast(answer)
            except ValueError:
                print(f"Invalid input. Please enter a valid {type_cast.__name__}.")
            continue

        if required:
            print("This field is required. Please enter a value.")
            continue

        return default

In [None]:
def get_arguments():
    """Return the hard-coded training configuration as an ``argparse.Namespace``."""
    return argparse.Namespace(
        # Dataset selection
        dataset='A',       # Choose: A, B, C, D
        train_mode=1,      # 1=single dataset, 2=all datasets

        # Model config
        gnn='gin',         # gin, gin-virtual, gcn, gcn-virtual
        drop_ratio=0.0,
        num_layer=5,
        emb_dim=300,

        # Training config
        batch_size=32,
        epochs=10,
        baseline_mode=1,   # 1=CE, 2=Noisy CE
        noise_prob=0.2,

        # System config
        device=0,
        num_checkpoints=3,
    )

In [None]:
def populate_args(args):
    """Print a header followed by one ``name: value`` line per attribute of ``args``."""
    print("Arguments received:")
    settings = vars(args)
    for name in settings:
        print(f"{name}: {settings[name]}")
# Build the configuration and echo it so the run's settings appear in the
# notebook output.
args = get_arguments()
populate_args(args)

## 6. Loss Function Definition

In [None]:
class NoisyCrossEntropyLoss(torch.nn.Module):
    """Cross-entropy whose per-sample terms are scaled by a weight built from ``p_noisy``.

    NOTE(review): a one-hot row sums to 1, so the weight below evaluates to
    ``1 - p`` for every sample, i.e. the result is ``(1 - p)`` times the mean
    cross-entropy. Confirm this is the intended formula.
    """

    def __init__(self, p_noisy):
        super().__init__()
        self.p = p_noisy
        # reduction='none' keeps one loss value per sample so it can be weighted.
        self.ce = torch.nn.CrossEntropyLoss(reduction='none')

    def forward(self, logits, targets):
        """Return the mean of the weighted per-sample cross-entropy losses."""
        per_sample_loss = self.ce(logits, targets)
        class_count = logits.size(1)
        one_hot_targets = torch.nn.functional.one_hot(targets, num_classes=class_count).float()
        one_hot_mass = one_hot_targets.sum(dim=1)
        weights = (1 - self.p) + self.p * (1 - one_hot_mass)
        weighted = per_sample_loss * weights
        return weighted.mean()

## 7. Main Training Pipeline

### 7.1 Config Section


In [None]:
# Banner for the notebook output
for banner_line in ("=" * 60, "Enhanced GNN Training Pipeline", "=" * 60):
    print(banner_line)

# Get configuration
args = get_arguments()

# Echo every setting, indented under a heading
print("\nConfiguration:")
for option, setting in vars(args).items():
    print(f"  {option}: {setting}")

# Setup device: the configured CUDA index when CUDA is available, else CPU
if torch.cuda.is_available():
    device = torch.device(f"cuda:{args.device}")
else:
    device = torch.device("cpu")
print(f"\nUsing device: {device}")

### 7.2 Data Loading

In [None]:
print("\n" + "="*40)
print("LOADING DATA")
print("="*40)

base_path = '/kaggle/input/deep-dataset-preprocessed/processed_data_separate'

# Prepare training/validation data based on mode.
# Every graph goes through add_zeros in BOTH modes so the model always
# receives the same node-feature tensor regardless of train_mode.
if args.train_mode == 1:
    # Single dataset mode
    dataset_name = args.dataset
    train_dataset = torch.load(f'{base_path}/{dataset_name}_train_graphs.pt', weights_only=False)
    train_dataset = [add_zeros(data) for data in train_dataset]

    val_dataset = torch.load(f'{base_path}/{dataset_name}_val_graphs.pt', weights_only=False)
    val_dataset = [add_zeros(data) for data in val_dataset]

    test_dataset = torch.load(f'{base_path}/{dataset_name}_test_graphs.pt', weights_only=False)
    test_dataset = [add_zeros(data) for data in test_dataset]
    print(f"Using single dataset: {dataset_name}")
else:
    # All datasets mode: train/validate on A-D combined, test on the
    # dataset named in the configuration.
    train_dataset = []
    val_dataset = []
    test_dataset = torch.load(f'{base_path}/{args.dataset}_test_graphs.pt', weights_only=False)  # Test on specified dataset
    test_dataset = [add_zeros(data) for data in test_dataset]

    for ds_name in ['A', 'B', 'C', 'D']:
        train_dataset.extend(
            add_zeros(data)
            for data in torch.load(f'{base_path}/{ds_name}_train_graphs.pt', weights_only=False)
        )
        val_dataset.extend(
            add_zeros(data)
            for data in torch.load(f'{base_path}/{ds_name}_val_graphs.pt', weights_only=False)
        )

    print("Using all datasets for training")

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Create data loaders (only the training loader shuffles, so validation and
# test predictions keep the order of the stored graphs)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)

### 7.3 Model Setup

In [None]:

print("\n" + "=" * 40)
print("MODEL SETUP")
print("=" * 40)

# Translate the configured name into the two constructor options that vary;
# every other GNN argument is shared by all variants.
if args.gnn == 'gin':
    gnn_type, use_virtual_node = 'gin', False
elif args.gnn == 'gin-virtual':
    gnn_type, use_virtual_node = 'gin', True
elif args.gnn == 'gcn':
    gnn_type, use_virtual_node = 'gcn', False
elif args.gnn == 'gcn-virtual':
    gnn_type, use_virtual_node = 'gcn', True
else:
    raise ValueError(f'Invalid GNN type: {args.gnn}')

# Initialize model and move it to the selected device
model = GNN(
    gnn_type=gnn_type,
    num_class=6,
    num_layer=args.num_layer,
    emb_dim=args.emb_dim,
    drop_ratio=args.drop_ratio,
    virtual_node=use_virtual_node,
)
model = model.to(device)

# Setup optimizer and loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

if args.baseline_mode != 2:
    criterion = torch.nn.CrossEntropyLoss()
    print("Using standard Cross Entropy Loss")
else:
    criterion = NoisyCrossEntropyLoss(args.noise_prob)
    print(f"Using Noisy Cross Entropy Loss (p={args.noise_prob})")

parameter_count = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {parameter_count:,}")

# Per-experiment output folders for logs and checkpoints
exp_name = f"{args.gnn}_dataset{args.dataset}_mode{args.train_mode}"
logs_dir = os.path.join("logs", exp_name)
checkpoints_dir = os.path.join("checkpoints", exp_name)
for folder in (logs_dir, checkpoints_dir):
    os.makedirs(folder, exist_ok=True)

# Setup logging: records go both to the experiment's log file and to the stream handler
log_file = os.path.join(logs_dir, "training.log")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
)

best_model_path = os.path.join(checkpoints_dir, "best_model.pth")

### 7.4 Training Loop


In [None]:
print("\n" + "="*40)
print("TRAINING")
print("="*40)

best_val_accuracy = 0.0
# Tracks whether this run has written best_model_path yet. Without it, a run
# whose validation accuracy never exceeds 0.0 would save nothing, and the
# later load of best_model_path would fail or pick up a stale file.
best_model_saved = False
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

# Calculate checkpoint intervals: the 1-based epoch numbers after which a
# checkpoint is written, spread evenly and always ending on the final epoch.
if args.num_checkpoints > 1:
    checkpoint_intervals = [int((i + 1) * args.epochs / args.num_checkpoints)
                          for i in range(args.num_checkpoints)]
else:
    checkpoint_intervals = [args.epochs]

for epoch in range(args.epochs):
    print(f"\nEpoch {epoch + 1}/{args.epochs}")
    print("-" * 30)

    # Training
    train_loss, train_acc = train(
        train_loader, model, optimizer, criterion, device,
        save_checkpoints=(epoch + 1 in checkpoint_intervals),
        checkpoint_path=os.path.join(checkpoints_dir, "checkpoint"),
        current_epoch=epoch
    )

    # Validation
    val_loss, val_acc = evaluate(val_loader, model, device, calculate_accuracy=True)

    # Log results
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    logging.info(f"Epoch {epoch + 1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, "
                f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")

    # Store metrics
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    # Save best model: always after the first epoch of this run, afterwards
    # only on a strict improvement of validation accuracy.
    if not best_model_saved or val_acc > best_val_accuracy:
        best_val_accuracy = val_acc
        torch.save(model.state_dict(), best_model_path)
        best_model_saved = True
        print(f"★ New best model saved! Val Acc: {val_acc:.4f}")

print(f"\nBest validation accuracy: {best_val_accuracy:.4f}")

In [None]:
# Plot the per-epoch loss and accuracy curves collected by the training loop;
# the figure is written into logs_dir and also displayed inline.
plot_training_progress(train_losses, train_accuracies, val_losses, val_accuracies, logs_dir)

### 7.5 Testing and Predictions

In [None]:
print(f"\n{'=' * 40}")
print("TESTING")
print("=" * 40)

# Restore the weights stored at best_model_path before predicting
best_state = torch.load(best_model_path)
model.load_state_dict(best_state)
print(f"Loaded best model from: {best_model_path}")

# Prediction mode: evaluate returns one predicted class per test graph
predictions = evaluate(test_loader, model, device, calculate_accuracy=False)

# Save predictions
save_predictions(predictions, args.dataset)

# Drop the datasets and loaders, then collect, to free memory
del train_dataset, val_dataset, test_dataset, train_loader, val_loader, test_loader
gc.collect()

# Closing summary
print(f"\n{'=' * 60}")
print("TRAINING COMPLETED SUCCESSFULLY!")
print("=" * 60)
print(f"Best validation accuracy: {best_val_accuracy:.4f}")
print(f"Predictions saved for dataset {args.dataset}")
print(f"Logs and plots saved in: {logs_dir}")