# Graph Neural Network Training Pipeline

Multi-Dataset Graph Classification with Noise-Robust Training

## 1. Setup and Dependencies

In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m22.5 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [2]:
import os
import gc
import sys
import torch
import pandas as pd
import matplotlib.pyplot as plt
import logging
from tqdm import tqdm
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
import argparse

In [3]:
helper_scripts_path = '/kaggle/input/d/leonardosandri/myhackatonhelperscripts/'

if os.path.exists(helper_scripts_path):
    # Prepend so these helper modules take precedence over same-named installed modules;
    # skip when already present so re-running the cell does not stack duplicate entries.
    if helper_scripts_path not in sys.path:
        sys.path.insert(0, helper_scripts_path)
    print(f"Successfully added '{helper_scripts_path}' to sys.path.")
    print(f"Contents of '{helper_scripts_path}': {os.listdir(helper_scripts_path)}") # Verify
else:
    print(f"WARNING: Helper scripts path not found: {helper_scripts_path}")
    print("Please ensure 'myhackatonhelperscripts' dataset is correctly added to the notebook.")

# Import the project helpers (data loader, seeding utility, model definition)
try:
    from preprocessor import MultiDatasetLoader
    from utils import set_seed
    from models_EDandBatch_norm import GNN
    print("Successfully imported modules.")
except ImportError as e:
    print(f"ERROR importing module: {e}")
    print("Please check that the .py files exist directly under the helper_scripts_path and have no syntax errors.")
    # Re-raise: the rest of the notebook needs set_seed/GNN, so continuing would only
    # fail a few lines later with a less informative NameError.
    raise

# Set the random seed
set_seed()

Successfully added '/kaggle/input/d/leonardosandri/myhackatonhelperscripts/' to sys.path.
Contents of '/kaggle/input/d/leonardosandri/myhackatonhelperscripts/': ['models_edge_drop.py', 'zipthefolder.py', 'loadData.py', 'utils.py', 'models_EDandBatch_norm.py', 'models.py', 'conv.py', 'preprocessor.py', '__init__.py']
Successfully imported modules.


## 2. Data Preprocessing Functions


In [4]:
def add_zeros(data):
    """
    Give a graph an all-zero integer node-feature vector.

    Sets ``data.x`` to a 1-D ``torch.long`` tensor of zeros with one entry per node
    (``data.num_nodes``), overwriting any existing ``x``. The object is modified in
    place and also returned, so the function can be used inside a comprehension.
    """
    zero_features = torch.zeros(data.num_nodes, dtype=torch.long)
    data.x = zero_features
    return data

## 3. Training and Evaluation Functions

In [5]:
def train(data_loader, model, optimizer, criterion, device, save_checkpoints, checkpoint_path, current_epoch, scheduler=None, args=None):
    """
    Run one training epoch over ``data_loader`` and optionally save a checkpoint.

    Args:
        data_loader: Iterable of batches supporting ``len()``; each batch must provide
            ``.to(device)`` and targets in ``.y``.
        model: Module called as ``model(batch)``; its output is treated as logits whose
            argmax over dim 1 is the predicted class.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(logits, targets)``.
        device: Device each batch is moved to.
        save_checkpoints: If True, save ``model.state_dict()`` after the epoch.
        checkpoint_path: Prefix of the checkpoint file (``{prefix}_epoch_{n}.pth``).
        current_epoch: Zero-based epoch index, used only in the checkpoint file name.
        scheduler: Optional LR scheduler. It is stepped here, once per batch, only for a
            one-cycle schedule; all other schedulers are left to the caller.
        args: Optional config. When given, ``args.scheduler_type == 'OneCycleLR'`` selects
            per-batch stepping; when omitted, the scheduler's class decides (instead of
            crashing on ``None.scheduler_type``).

    Returns:
        Tuple ``(mean of per-batch losses, accuracy over all samples)``.
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    # Decide once whether the scheduler advances per batch.
    if scheduler is None:
        step_per_batch = False
    elif args is not None:
        step_per_batch = args.scheduler_type == 'OneCycleLR'
    else:
        step_per_batch = isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR)

    for data in tqdm(data_loader, desc="Iterating training graphs", unit="batch"):
        data = data.to(device)
        optimizer.zero_grad()

        try:
            output = model(data)
        except IndexError:
            # Diagnostics aimed at out-of-range edge indices. They are guarded so that a
            # batch without edges (or without node features) cannot raise a new error here
            # and hide the original IndexError, which is re-raised with its traceback.
            edge_index = data.edge_index
            edge_max = edge_index.max().item() if edge_index.numel() > 0 else None
            x_shape = getattr(data.x, "shape", None)
            print(f"Error in batch with {data.num_nodes} nodes, edge_max={edge_max}")
            print(f"Batch info: x.shape={x_shape}, edge_index.shape={edge_index.shape}")
            raise

        loss = criterion(output, data.y)
        loss.backward()
        optimizer.step()

        # One-cycle schedules advance after every optimizer step, not every epoch.
        if step_per_batch:
            scheduler.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += (pred == data.y).sum().item()
        total += data.y.size(0)

    # Save checkpoints if required
    if save_checkpoints:
        checkpoint_file = f"{checkpoint_path}_epoch_{current_epoch + 1}.pth"
        torch.save(model.state_dict(), checkpoint_file)
        print(f"Checkpoint saved at {checkpoint_file}")

    return total_loss / len(data_loader), correct / total

In [6]:
def evaluate(data_loader, model, criterion, device, calculate_accuracy=False):
    """
    Run ``model`` over ``data_loader`` without gradients.

    Args:
        data_loader: Iterable of batches supporting ``len()``; each batch provides
            ``.to(device)`` and, when scoring, targets in ``.y``.
        model: Module called as ``model(batch)``; argmax over dim 1 is the prediction.
        criterion: Loss called as ``criterion(logits, targets)``; only used when
            ``calculate_accuracy`` is True.
        device: Device each batch is moved to.
        calculate_accuracy: Selects the mode (see Returns).

    Returns:
        If ``calculate_accuracy``: tuple ``(mean of per-batch losses, accuracy)``.
        Otherwise: list of predicted labels (numpy scalars), in loader order.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    loss_sum = 0
    collected = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Iterating eval graphs", unit="batch"):
            batch = batch.to(device)
            logits = model(batch)
            predicted = logits.argmax(dim=1)

            if not calculate_accuracy:
                # Prediction mode: just gather labels on the CPU.
                collected.extend(predicted.cpu().numpy())
                continue

            n_correct += (predicted == batch.y).sum().item()
            n_seen += batch.y.size(0)
            loss_sum += criterion(logits, batch.y).item()

    if not calculate_accuracy:
        return collected
    return loss_sum / len(data_loader), n_correct / n_seen

## 4. Utility Functions

In [7]:
def save_predictions(predictions, test_path):
    """
    Write predictions to ``submission/testset_<name>.csv`` under the current directory.

    ``<name>`` is the name of the directory containing ``test_path`` (for example
    ``.../A/test.json.gz`` -> ``A``). When ``test_path`` has no directory component,
    such as a bare dataset name like ``"A"``, its own base name is used instead.
    Previously that case produced the nameless file ``testset_.csv``.

    Args:
        predictions: Sequence of predicted labels; row ``i`` gets ``id = i``.
        test_path: Path of the test file, or a bare dataset name.
    """
    submission_folder = os.path.join(os.getcwd(), "submission")
    parent_dir_name = os.path.basename(os.path.dirname(test_path))
    test_dir_name = parent_dir_name or os.path.basename(test_path)

    os.makedirs(submission_folder, exist_ok=True)

    output_csv_path = os.path.join(submission_folder, f"testset_{test_dir_name}.csv")

    output_df = pd.DataFrame({
        "id": list(range(len(predictions))),
        "pred": predictions
    })

    output_df.to_csv(output_csv_path, index=False)
    print(f"Predictions saved to {output_csv_path}")

In [8]:
def plot_training_progress(train_losses, train_accuracies, val_losses, val_accuracies,
                           learning_rates, output_dir=None):
    """
    Plot training/validation curves and the learning-rate schedule, and save them as PNGs.

    Args:
        train_losses: List of training losses per epoch
        train_accuracies: List of training accuracies per epoch
        val_losses: List of validation losses per epoch
        val_accuracies: List of validation accuracies per epoch
        learning_rates: List of learning rates per epoch. Entries beyond the number of
            completed epochs (e.g. one recorded for an interrupted epoch) are ignored.
        output_dir: Directory to save the plots (created if missing).

    The legacy call form
    ``plot_training_progress(train_losses, train_accuracies, val_losses, val_accuracies, output_dir)``
    is still accepted; the learning-rate plots are then skipped.

    Saves ``training_progress_with_lr.png`` and ``accuracy_vs_lr.png`` into ``output_dir``.
    """
    if output_dir is None:
        # Legacy 5-argument call: the fifth positional argument is the output directory.
        output_dir, learning_rates = learning_rates, None

    n_epochs = len(train_losses)
    if n_epochs == 0:
        print("No completed epochs to plot.")
        return
    epochs = range(1, n_epochs + 1)

    # Keep exactly one LR per plotted epoch so x and y lengths always match.
    lrs = list(learning_rates)[:n_epochs] if learning_rates is not None else []
    has_lr = len(lrs) == n_epochs

    os.makedirs(output_dir, exist_ok=True)

    # Figure 1: loss, accuracy and (when available) learning-rate curves side by side.
    n_panels = 3 if has_lr else 2
    fig, axes = plt.subplots(1, n_panels, figsize=(20 if has_lr else 14, 6))
    ax1, ax2 = axes[0], axes[1]

    # Plot losses
    ax1.plot(epochs, train_losses, label="Training Loss", color='blue', marker='o', markersize=3)
    ax1.plot(epochs, val_losses, label="Validation Loss", color='red', marker='s', markersize=3)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Validation Loss per Epoch')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot accuracies
    ax2.plot(epochs, train_accuracies, label="Training Accuracy", color='green', marker='o', markersize=3)
    ax2.plot(epochs, val_accuracies, label="Validation Accuracy", color='orange', marker='s', markersize=3)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Training and Validation Accuracy per Epoch')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Plot learning rate
    if has_lr:
        ax3 = axes[2]
        ax3.plot(epochs, lrs, label="Learning Rate", color='purple', marker='d', markersize=3)
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('Learning Rate')
        ax3.set_title('Learning Rate Schedule')
        ax3.set_yscale('log')  # Use log scale for better visualization
        ax3.legend()
        ax3.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "training_progress_with_lr.png"), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)

    # Figure 2: validation accuracy (left axis) against the LR schedule (right, log axis),
    # with the best epoch highlighted.
    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    handles = ax.plot(epochs, val_accuracies, 'b-', label='Validation Accuracy', linewidth=2, alpha=0.8)
    if has_lr:
        ax_lr = ax.twinx()
        handles += ax_lr.plot(epochs, lrs, 'r--', label='Learning Rate', linewidth=2, alpha=0.8)
        ax_lr.set_ylabel('Learning Rate', color='r')
        ax_lr.set_yscale('log')

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Validation Accuracy', color='b')

    # Find and mark best validation accuracy (first epoch reaching the maximum)
    best_idx = max(range(len(val_accuracies)), key=lambda i: val_accuracies[i])
    best_epoch = epochs[best_idx]
    best_acc = val_accuracies[best_idx]
    best_point = ax.scatter([best_epoch], [best_acc], color='gold', s=100, zorder=5,
                            label=f'Best: Epoch {best_epoch}, Acc: {best_acc:.4f}')
    handles.append(best_point)

    ax.set_title('Validation Accuracy vs Learning Rate Schedule')
    ax.grid(True, alpha=0.3)

    # Combine legends from both y-axes
    ax.legend(handles, [h.get_label() for h in handles], loc='center left')

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "accuracy_vs_lr.png"), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)

## 5. Configuration and Arguments

In [9]:
def get_user_input(prompt, default=None, required=False, type_cast=str):
    """
    Interactively read a value from stdin, re-prompting until it is usable.

    Args:
        prompt: Text shown to the user; the default is appended in brackets.
        default: Value returned as-is (not passed through ``type_cast``) when the user
            just presses Enter. A default satisfies even a required field.
        required: If True and no ``default`` is available, an empty answer is rejected
            and the user is asked again.
        type_cast: Callable converting non-empty input (e.g. ``int``); a ``ValueError``
            from it triggers a re-prompt.

    Returns:
        The converted input, ``default`` for an empty answer, or ``None`` for an empty
        answer to an optional field without a default.
    """
    while True:
        user_input = input(f"{prompt} [{default}]: ")

        if user_input == "":
            # Check the default first: previously the required check ran first, so a
            # required field could never fall back to its default.
            if default is not None:
                return default
            if required:
                print("This field is required. Please enter a value.")
                continue
            return None

        try:
            return type_cast(user_input)
        except ValueError:
            type_name = getattr(type_cast, "__name__", "value")
            print(f"Invalid input. Please enter a valid {type_name}.")

In [10]:
def get_arguments():
    """
    Return the hard-coded training configuration as an ``argparse.Namespace``.

    Attribute order is the order written below (it is the order used when printing).
    Edit the values here to change the experiment.
    """
    return argparse.Namespace(
        # Dataset selection
        dataset='A',              # Choose: A, B, C, D
        train_mode=1,             # 1=single dataset, 2=all datasets

        # Model config
        num_layer=3,
        emb_dim=256,
        drop_ratio=0.3,           # Dropout ratio
        virtual_node=True,        # True to use virtual node, False otherwise
        residual=True,            # True to use residual connections, False otherwise
        JK="last",                # Jumping Knowledge: "last", "sum", "cat"
        edge_drop_ratio=0.15,
        batch_norm=True,
        graph_pooling="mean",     # "sum", "mean", "max", "attention", "set2set"

        # Training config
        batch_size=64,
        epochs=200,
        baseline_mode=3,          # 1=CE, 2=Noisy CE, 3=GCE
        noise_prob=0.2,
        gce_q=0.5,
        initial_lr=1e-3,

        # LR scheduler config
        use_scheduler=True,
        scheduler_type='ReduceLROnPlateau',  # 'StepLR', 'ReduceLROnPlateau', 'CosineAnnealingLR', 'ExponentialLR', 'OneCycleLR'
        # StepLR
        step_size=30,             # Period of learning rate decay
        gamma=0.5,                # Multiplicative factor of learning rate decay
        # ReduceLROnPlateau
        patience_lr=10,           # Epochs without improvement before the LR is reduced
        factor=0.5,               # Factor by which the LR is reduced
        min_lr=1e-7,              # Lower bound on the LR
        # CosineAnnealingLR
        T_max=50,                 # Maximum number of iterations for cosine annealing
        eta_min=1e-6,             # Minimum LR
        # ExponentialLR
        gamma_exp=0.95,           # Multiplicative LR decay factor
        # OneCycleLR
        max_lr=1e-3,              # Upper LR boundary
        pct_start=0.3,            # Fraction of the cycle spent increasing the LR

        # Early stopping config
        early_stopping=True,      # Enable/disable early stopping
        patience=25,              # Epochs to wait without improvement

        # System config
        device=0,
        num_checkpoints=3,
    )

In [11]:
def populate_args(args):
    """Print a header, then one ``name: value`` line per attribute of ``args``."""
    print("Arguments received:")
    for line in (f"{name}: {value}" for name, value in vars(args).items()):
        print(line)


# Build the configuration and echo it.
args = get_arguments()
populate_args(args)

Arguments received:
dataset: A
train_mode: 1
num_layer: 3
emb_dim: 256
drop_ratio: 0.3
virtual_node: True
residual: True
JK: last
edge_drop_ratio: 0.15
batch_norm: True
graph_pooling: mean
batch_size: 64
epochs: 200
baseline_mode: 3
noise_prob: 0.2
gce_q: 0.5
initial_lr: 0.001
use_scheduler: True
scheduler_type: ReduceLROnPlateau
step_size: 30
gamma: 0.5
patience_lr: 10
factor: 0.5
min_lr: 1e-07
T_max: 50
eta_min: 1e-06
gamma_exp: 0.95
max_lr: 0.001
pct_start: 0.3
early_stopping: True
patience: 25
device: 0
num_checkpoints: 3


## 6. Loss Function Definition

In [12]:
class NoisyCrossEntropyLoss(torch.nn.Module):
    """
    Cross-entropy with per-sample weights ``(1 - p) + p * (1 - sum(one_hot(target)))``.

    NOTE(review): every one-hot row sums to exactly 1, so the second term is always 0
    and every weight equals ``1 - p``. The loss is therefore ``(1 - p) * mean(CE)``,
    a constant rescaling of plain cross-entropy with no sample-dependent noise
    handling. Confirm the intended weighting before relying on this as a
    noise-robust loss.
    """

    def __init__(self, p_noisy):
        super().__init__()
        self.p = p_noisy  # mixing coefficient of the weight formula
        self.ce = torch.nn.CrossEntropyLoss(reduction='none')  # keep per-sample losses

    def forward(self, logits, targets):
        per_sample_ce = self.ce(logits, targets)
        target_one_hot = torch.nn.functional.one_hot(targets, num_classes=logits.size(1)).float()
        off_target_mass = 1 - target_one_hot.sum(dim=1)
        sample_weights = (1 - self.p) + self.p * off_target_mass
        return (per_sample_ce * sample_weights).mean()

In [13]:
class GeneralizedCrossEntropyLoss(torch.nn.Module):
    """
    Generalized Cross Entropy (GCE) loss: the mean over samples of ``(1 - p_t^q) / q``,
    where ``p_t`` is the softmax probability assigned to the target class.

    q is a hyperparameter with 0 < q <= 1. As q -> 0 the loss approaches standard
    cross-entropy; q = 1 gives ``1 - p_t``.
    """

    def __init__(self, q=0.7):
        super(GeneralizedCrossEntropyLoss, self).__init__()
        if not (0 < q <= 1):
            # q = 0 would divide by zero; the CE limit is only reached as q -> 0.
            raise ValueError("q should be in (0, 1]")
        self.q = q

    def forward(self, logits, targets):
        """
        Args:
            logits: Unnormalized class scores, classes along dim 1.
            targets: Integer class index for each row of ``logits``.

        Returns:
            Scalar mean GCE loss.
        """
        # Work in log space: p_t^q = exp(q * log p_t). log_softmax stays finite even when
        # the softmax probability underflows to exactly 0, where the direct p_t ** q has
        # an infinite derivative and back-propagates NaN gradients.
        log_probs = torch.log_softmax(logits, dim=1)
        # gather indexes on logits' own device (no separate CPU arange index).
        target_log_probs = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        target_probs_q = torch.exp(self.q * target_log_probs)

        # GCE loss: (1 - p_t^q) / q
        loss = (1 - target_probs_q) / self.q
        return loss.mean()

## 7. Model Creation

### 7.1 Config section


In [14]:
banner = "=" * 60
print(banner)
print("Enhanced GNN Training Pipeline")
print(banner)

# Re-create the configuration so this section does not depend on earlier cells' state.
args = get_arguments()

print("\nConfiguration:")
for key, value in vars(args).items():
    print(f"  {key}: {value}")

# Use the configured CUDA device when a GPU is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{args.device}")
else:
    device = torch.device("cpu")
print(f"\nUsing device: {device}")

Enhanced GNN Training Pipeline

Configuration:
  dataset: A
  train_mode: 1
  num_layer: 3
  emb_dim: 256
  drop_ratio: 0.3
  virtual_node: True
  residual: True
  JK: last
  edge_drop_ratio: 0.15
  batch_norm: True
  graph_pooling: mean
  batch_size: 64
  epochs: 200
  baseline_mode: 3
  noise_prob: 0.2
  gce_q: 0.5
  initial_lr: 0.001
  use_scheduler: True
  scheduler_type: ReduceLROnPlateau
  step_size: 30
  gamma: 0.5
  patience_lr: 10
  factor: 0.5
  min_lr: 1e-07
  T_max: 50
  eta_min: 1e-06
  gamma_exp: 0.95
  max_lr: 0.001
  pct_start: 0.3
  early_stopping: True
  patience: 25
  device: 0
  num_checkpoints: 3

Using device: cuda:0


### 7.2 Data Loading

In [15]:
print("\n" + "="*40)
print("LOADING DATA")
print("="*40)

base_path = '/kaggle/input/deep-dataset-preprocessed/processed_data_separate'


def _load_split(ds_name, split):
    """Load ``{ds_name}_{split}_graphs.pt`` from ``base_path`` and zero-fill node features."""
    # weights_only=False because the files hold pickled graph objects rather than plain
    # tensors; this executes pickle code, so only load files you trust.
    graphs = torch.load(f'{base_path}/{ds_name}_{split}_graphs.pt', weights_only=False)
    # Applied in every mode: previously the all-datasets branch skipped add_zeros, so its
    # graphs did not get the node features the single-dataset branch provides.
    return [add_zeros(data) for data in graphs]


# Prepare training/validation data based on mode
if args.train_mode == 1:
    # Single dataset mode
    dataset_name = args.dataset
    train_dataset = _load_split(dataset_name, 'train')
    val_dataset = _load_split(dataset_name, 'val')
    test_dataset = _load_split(dataset_name, 'test')
    print(f"Using single dataset: {dataset_name}")
else:
    # All datasets mode: train/validate on A-D, test on the configured dataset
    train_dataset = []
    val_dataset = []
    test_dataset = _load_split(args.dataset, 'test')

    for ds_name in ['A', 'B', 'C', 'D']:
        train_dataset.extend(_load_split(ds_name, 'train'))
        val_dataset.extend(_load_split(ds_name, 'val'))

    print("Using all datasets for training")

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Create data loaders (only the training set is shuffled)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)


LOADING DATA
Using single dataset: A
Train samples: 10152
Val samples: 1128
Test samples: 2340


### 7.3 Model Setup

In [16]:

print("\n" + "="*40)
print("MODEL SETUP")
print("="*40)

# Initialize model
model = GNN(num_class=6, # Assuming 6 classes based on original notebook
            num_layer=args.num_layer,
            emb_dim=args.emb_dim,
            drop_ratio=args.drop_ratio,
            virtual_node=args.virtual_node,
            residual=args.residual,
            JK=args.JK,
            graph_pooling=args.graph_pooling,
            edge_drop_ratio = args.edge_drop_ratio,
            batch_norm=args.batch_norm
           )

model = model.to(device)

# Setup optimizer and loss
optimizer = torch.optim.Adam(model.parameters(), lr=args.initial_lr)

if args.baseline_mode == 2:
    criterion = NoisyCrossEntropyLoss(args.noise_prob)
    print(f"Using Noisy Cross Entropy Loss (p={args.noise_prob})")
elif args.baseline_mode == 3:
    criterion = GeneralizedCrossEntropyLoss(q=args.gce_q)
    print(f"Using Generalized Cross Entropy (GCE) Loss (q={args.gce_q})")
else:
    criterion = torch.nn.CrossEntropyLoss()
    print("Using standard Cross Entropy Loss")

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Setup logging and checkpoints
exp_name = f"gin_dataset{args.dataset}_mode{args.train_mode}"
logs_dir = os.path.join("logs", exp_name)
checkpoints_dir = os.path.join("checkpoints", exp_name)
os.makedirs(logs_dir, exist_ok=True)
os.makedirs(checkpoints_dir, exist_ok=True)

# Setup logging.
# force=True replaces any handlers already attached to the root logger (e.g. from an
# earlier run of this cell or from the notebook environment); without it basicConfig is
# a silent no-op and nothing would be written to log_file.
log_file = os.path.join(logs_dir, "training.log")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ],
    force=True,
)

# The best model is written to a fixed path in Kaggle's working directory rather than
# checkpoints_dir, so every experiment in the session overwrites the same file.
best_model_path = '/kaggle/working/best_model.pth'


MODEL SETUP
Using Generalized Cross Entropy (GCE) Loss (q=0.5)
Model parameters: 1,330,953


## 8. Main training loop

### 8.1 Learning rate settings

In [17]:
# Learning Rate Scheduler Setup
print("\n" + "=" * 40)
print("SCHEDULER SETUP")
print("=" * 40)

# Re-create the optimizer so the scheduler attaches to a fresh Adam state at args.initial_lr.
optimizer = torch.optim.Adam(model.parameters(), lr=args.initial_lr)

lr_sched = torch.optim.lr_scheduler
scheduler = None

if not args.use_scheduler:
    print("No learning rate scheduler will be used.")
elif args.scheduler_type == 'StepLR':
    scheduler = lr_sched.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
    print(f"Using StepLR scheduler: step_size={args.step_size}, gamma={args.gamma}")
elif args.scheduler_type == 'ReduceLROnPlateau':
    # mode='min': the training loop feeds validation loss, so the LR drops when it plateaus.
    scheduler = lr_sched.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=args.factor,
        patience=args.patience_lr,
        min_lr=args.min_lr,
    )
    print(f"Using ReduceLROnPlateau scheduler: factor={args.factor}, patience={args.patience_lr}, min_lr={args.min_lr}")
elif args.scheduler_type == 'CosineAnnealingLR':
    scheduler = lr_sched.CosineAnnealingLR(optimizer, T_max=args.T_max, eta_min=args.eta_min)
    print(f"Using CosineAnnealingLR scheduler: T_max={args.T_max}, eta_min={args.eta_min}")
elif args.scheduler_type == 'ExponentialLR':
    scheduler = lr_sched.ExponentialLR(optimizer, gamma=args.gamma_exp)
    print(f"Using ExponentialLR scheduler: gamma={args.gamma_exp}")
elif args.scheduler_type == 'OneCycleLR':
    # OneCycleLR is stepped once per batch, so its horizon is counted in optimizer steps.
    total_steps = len(train_loader) * args.epochs
    scheduler = lr_sched.OneCycleLR(
        optimizer,
        max_lr=args.max_lr,
        total_steps=total_steps,
        pct_start=args.pct_start,
    )
    print(f"Using OneCycleLR scheduler: max_lr={args.max_lr}, total_steps={total_steps}, pct_start={args.pct_start}")
else:
    print(f"Unknown scheduler type: {args.scheduler_type}. No scheduler will be used.")
    args.use_scheduler = False

print(f"Initial learning rate: {args.initial_lr}")


SCHEDULER SETUP
Using ReduceLROnPlateau scheduler: factor=0.5, patience=10, min_lr=1e-07
Initial learning rate: 0.001


### 8.2 Training Loop


In [None]:
print("\n" + "="*40)
print("TRAINING")
print("="*40)

best_val_accuracy = 0.0
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []
learning_rates = []  # LR in effect during each *completed* epoch

# Early stopping variables (counter always defined; only used when enabled)
epochs_without_improvement = 0
if args.early_stopping:
    print(f"Early stopping enabled with patience: {args.patience}")
else:
    print("Early stopping disabled")

# 1-based epoch numbers after which a checkpoint is written, spread evenly over the run
if args.num_checkpoints > 1:
    checkpoint_intervals = [int((i + 1) * args.epochs / args.num_checkpoints)
                            for i in range(args.num_checkpoints)]
else:
    checkpoint_intervals = [args.epochs]

for epoch in range(args.epochs):
    print(f"\nEpoch {epoch + 1}/{args.epochs}")
    print("-" * 30)

    # LR at the start of this epoch. It is appended together with the other metrics
    # below, so an interrupted epoch cannot leave learning_rates longer than the
    # loss/accuracy lists (which would break plotting).
    current_lr = optimizer.param_groups[0]['lr']

    # Training
    train_loss, train_acc = train(
        train_loader, model, optimizer, criterion, device,
        save_checkpoints=(epoch + 1 in checkpoint_intervals),
        checkpoint_path=os.path.join(checkpoints_dir, "checkpoint"),
        current_epoch=epoch,
        scheduler=scheduler,
        args=args
    )

    # Validation
    val_loss, val_acc = evaluate(val_loader, model, criterion, device, calculate_accuracy=True)

    # Log results
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    print(f"Learning Rate: {current_lr:.2e}")

    logging.info(f"Epoch {epoch + 1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, "
                 f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}, LR={current_lr:.2e}")

    # Store metrics (all five lists grow together, once per completed epoch)
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    learning_rates.append(current_lr)

    # Save best model (selected on validation accuracy)
    if val_acc > best_val_accuracy:
        best_val_accuracy = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"★ New best model saved! Val Acc: {val_acc:.4f}")

        # Reset early stopping counter
        epochs_without_improvement = 0

    elif args.early_stopping:
        # No improvement
        epochs_without_improvement += 1
        print(f"No improvement for {epochs_without_improvement} epoch(s)")

        # Check if we should stop early
        if epochs_without_improvement >= args.patience:
            print(f"\nEarly stopping triggered! No improvement for {args.patience} epochs.")
            print(f"Best validation accuracy: {best_val_accuracy:.4f}")
            break

    # Learning rate scheduler step
    if scheduler is not None:
        if args.scheduler_type == 'ReduceLROnPlateau':
            # ReduceLROnPlateau needs the metric to monitor
            scheduler.step(val_loss)
        elif args.scheduler_type == 'OneCycleLR':
            # OneCycleLR steps every batch inside train(), not here
            pass
        else:
            # Other schedulers step every epoch
            scheduler.step()

        # Check if learning rate changed
        new_lr = optimizer.param_groups[0]['lr']
        if new_lr != current_lr:
            print(f"Learning rate changed: {current_lr:.2e} → {new_lr:.2e}")

print(f"\nBest validation accuracy: {best_val_accuracy:.4f}")


TRAINING
Early stopping enabled with patience: 25

Epoch 1/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.34batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:00<00:00, 18.29batch/s]


Train Loss: 1.0736, Train Acc: 0.3378
Val Loss: 1.1378, Val Acc: 0.3023
Learning Rate: 1.00e-03
★ New best model saved! Val Acc: 0.3023

Epoch 2/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:16<00:00,  9.92batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:00<00:00, 18.19batch/s]


Train Loss: 1.0046, Train Acc: 0.3784
Val Loss: 0.9918, Val Acc: 0.3945
Learning Rate: 1.00e-03
★ New best model saved! Val Acc: 0.3945

Epoch 3/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:16<00:00,  9.80batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.90batch/s]


Train Loss: 0.9753, Train Acc: 0.4015
Val Loss: 1.0126, Val Acc: 0.3652
Learning Rate: 1.00e-03
No improvement for 1 epoch(s)

Epoch 4/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:16<00:00,  9.60batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.64batch/s]


Train Loss: 0.9452, Train Acc: 0.4249
Val Loss: 1.1261, Val Acc: 0.3076
Learning Rate: 1.00e-03
No improvement for 2 epoch(s)

Epoch 5/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:16<00:00,  9.49batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.42batch/s]


Train Loss: 0.9130, Train Acc: 0.4516
Val Loss: 1.4353, Val Acc: 0.2207
Learning Rate: 1.00e-03
No improvement for 3 epoch(s)

Epoch 6/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.32batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.16batch/s]


Train Loss: 0.8954, Train Acc: 0.4613
Val Loss: 0.9273, Val Acc: 0.4397
Learning Rate: 1.00e-03
★ New best model saved! Val Acc: 0.4397

Epoch 7/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.10batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.91batch/s]


Train Loss: 0.8794, Train Acc: 0.4722
Val Loss: 1.0472, Val Acc: 0.3493
Learning Rate: 1.00e-03
No improvement for 1 epoch(s)

Epoch 8/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  8.93batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.63batch/s]


Train Loss: 0.8627, Train Acc: 0.4830
Val Loss: 1.0004, Val Acc: 0.3803
Learning Rate: 1.00e-03
No improvement for 2 epoch(s)

Epoch 9/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.05batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.00batch/s]


Train Loss: 0.8434, Train Acc: 0.4985
Val Loss: 1.0120, Val Acc: 0.3537
Learning Rate: 1.00e-03
No improvement for 3 epoch(s)

Epoch 10/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.14batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.13batch/s]


Train Loss: 0.8336, Train Acc: 0.5048
Val Loss: 0.9455, Val Acc: 0.4344
Learning Rate: 1.00e-03
No improvement for 4 epoch(s)

Epoch 11/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.12batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.95batch/s]


Train Loss: 0.8248, Train Acc: 0.5142
Val Loss: 1.0337, Val Acc: 0.3546
Learning Rate: 1.00e-03
No improvement for 5 epoch(s)

Epoch 12/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.03batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.89batch/s]


Train Loss: 0.8109, Train Acc: 0.5233
Val Loss: 0.8922, Val Acc: 0.4548
Learning Rate: 1.00e-03
★ New best model saved! Val Acc: 0.4548

Epoch 13/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.04batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.85batch/s]


Train Loss: 0.7911, Train Acc: 0.5361
Val Loss: 1.0909, Val Acc: 0.3440
Learning Rate: 1.00e-03
No improvement for 1 epoch(s)

Epoch 14/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.06batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.99batch/s]


Train Loss: 0.7837, Train Acc: 0.5423
Val Loss: 0.8948, Val Acc: 0.4619
Learning Rate: 1.00e-03
★ New best model saved! Val Acc: 0.4619

Epoch 15/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.08batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.99batch/s]


Train Loss: 0.7720, Train Acc: 0.5499
Val Loss: 0.8687, Val Acc: 0.5035
Learning Rate: 1.00e-03
★ New best model saved! Val Acc: 0.5035

Epoch 16/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.10batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.02batch/s]


Train Loss: 0.7581, Train Acc: 0.5594
Val Loss: 0.8696, Val Acc: 0.4690
Learning Rate: 1.00e-03
No improvement for 1 epoch(s)

Epoch 17/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.09batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.96batch/s]


Train Loss: 0.7494, Train Acc: 0.5628
Val Loss: 0.8389, Val Acc: 0.4894
Learning Rate: 1.00e-03
No improvement for 2 epoch(s)

Epoch 18/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.11batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.00batch/s]


Train Loss: 0.7366, Train Acc: 0.5707
Val Loss: 0.9218, Val Acc: 0.4441
Learning Rate: 1.00e-03
No improvement for 3 epoch(s)

Epoch 19/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.07batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.00batch/s]


Train Loss: 0.7289, Train Acc: 0.5785
Val Loss: 0.9062, Val Acc: 0.4388
Learning Rate: 1.00e-03
No improvement for 4 epoch(s)

Epoch 20/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.07batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.87batch/s]


Train Loss: 0.7227, Train Acc: 0.5868
Val Loss: 0.7780, Val Acc: 0.5443
Learning Rate: 1.00e-03
★ New best model saved! Val Acc: 0.5443

Epoch 21/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.06batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.97batch/s]


Train Loss: 0.7155, Train Acc: 0.5919
Val Loss: 0.8720, Val Acc: 0.4681
Learning Rate: 1.00e-03
No improvement for 1 epoch(s)

Epoch 22/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.06batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.91batch/s]


Train Loss: 0.7060, Train Acc: 0.5932
Val Loss: 0.7309, Val Acc: 0.5771
Learning Rate: 1.00e-03
★ New best model saved! Val Acc: 0.5771

Epoch 23/200
------------------------------


Iterating training graphs:  36%|███▋      | 58/159 [00:06<00:11,  9.01batch/s]

In [None]:
# Plot loss, accuracy and learning-rate curves; figures are saved under logs_dir.
plot_training_progress(
    train_losses,
    train_accuracies,
    val_losses,
    val_accuracies,
    learning_rates,
    logs_dir,
)

### 8.3 Testing and Predictions

In [None]:
print("\n" + "="*40)
print("TESTING")
print("="*40)

# Load best model and make predictions.
# map_location restores a checkpoint saved on GPU onto whatever `device` is active now
# (e.g. a CPU-only session); weights_only=True because the file holds only a state_dict.
model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
print(f"Loaded best model from: {best_model_path}")

predictions = evaluate(test_loader, model, criterion, device, calculate_accuracy=False)

# Save predictions
save_predictions(predictions, args.dataset)

# Cleanup for memory (re-running this cell afterwards requires re-running data loading)
del train_dataset, val_dataset, test_dataset
del train_loader, val_loader, test_loader
gc.collect()

print("\n" + "="*60)
print("TRAINING COMPLETED SUCCESSFULLY!")
print("="*60)
print(f"Best validation accuracy: {best_val_accuracy:.4f}")
print(f"Predictions saved for dataset {args.dataset}")
print(f"Logs and plots saved in: {logs_dir}")