# Graph Neural Network Training Pipeline

Multi-Dataset Graph Classification with Noise-Robust Training

## 1. Setup and Dependencies

In [4]:
# %pip installs into the running kernel's environment; the version is pinned
# to the release this notebook was last executed with, for reproducibility.
%pip install torch_geometric==2.6.1

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [5]:
import os
import gc
import sys
import torch
import pandas as pd
import matplotlib.pyplot as plt
import logging
from tqdm import tqdm
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
import argparse

In [6]:
helper_scripts_path = '/kaggle/input/d/leonardosandri/myhackatonhelperscripts/'

# Make the helper modules stored under helper_scripts_path importable.
if os.path.exists(helper_scripts_path):
    # Insert at the front of Python's search list; the membership check keeps
    # sys.path free of duplicate entries when this cell is re-run.
    if helper_scripts_path not in sys.path:
        sys.path.insert(0, helper_scripts_path)
    print(f"Successfully added '{helper_scripts_path}' to sys.path.")
    print(f"Contents of '{helper_scripts_path}': {os.listdir(helper_scripts_path)}") # Verify
else:
    print(f"WARNING: Helper scripts path not found: {helper_scripts_path}")
    # The dataset name printed here matches the slug used in helper_scripts_path.
    print("Please ensure 'myhackatonhelperscripts' dataset is correctly added to the notebook.")

# Import the project-local helper modules.
try:
    from preprocessor import MultiDatasetLoader
    from utils import set_seed
    from models_EDandBatch_norm import GNN
    print("Successfully imported modules.")
except ImportError as e:
    print(f"ERROR importing module: {e}")
    print("Please check that the .py files exist directly under the helper_scripts_path and have no syntax errors.")
    # Re-raise: every later cell needs these names, so continuing would only
    # replace this clear error with a confusing NameError further down.
    raise

# Set the random seed (set_seed is called with its default arguments).
set_seed()

Successfully added '/kaggle/input/d/leonardosandri/myhackatonhelperscripts/' to sys.path.
Contents of '/kaggle/input/d/leonardosandri/myhackatonhelperscripts/': ['models_edge_drop.py', 'zipthefolder.py', 'loadData.py', 'utils.py', 'models_EDandBatch_norm.py', 'models.py', 'conv.py', 'preprocessor.py', '__init__.py']
Successfully imported modules.


## 2. Data Preprocessing Functions


In [7]:
def add_zeros(data):
    """Attach an all-zero integer feature vector to a graph, in place.

    Args:
        data: Object exposing ``num_nodes``; its ``x`` attribute is overwritten.
            (Presumably a torch_geometric ``Data`` graph -- confirm with caller.)

    Returns:
        The same ``data`` object, with ``x`` set to a 1-D ``torch.long`` tensor
        of ``num_nodes`` zeros.
    """
    node_count = data.num_nodes
    placeholder_features = torch.zeros((node_count,), dtype=torch.int64)
    data.x = placeholder_features
    return data

## 3. Training and Evaluation Functions

In [8]:
def train(data_loader, model, optimizer, criterion, device, save_checkpoints, checkpoint_path, current_epoch):
    """Run one optimisation pass over ``data_loader``.

    Args:
        data_loader: Sized iterable of batches. Each batch must support
            ``.to(device)`` and expose ``.y``; ``.num_nodes``, ``.x`` and
            ``.edge_index`` are read only for the error diagnostics.
        model: Module called as ``model(batch)``; its output is arg-maxed over dim 1.
        optimizer: Optimizer stepped once per batch.
        criterion: Callable ``criterion(output, batch.y)`` returning a scalar loss.
        device: Device each batch is moved to.
        save_checkpoints: When truthy, the model ``state_dict`` is saved after the pass.
        checkpoint_path: Prefix of the checkpoint file name.
        current_epoch: Zero-based epoch index; the file name uses ``current_epoch + 1``.

    Returns:
        ``(mean loss per batch, fraction of correctly predicted samples)``.

    Raises:
        IndexError: Re-raised from the forward pass after printing batch sizes.
        ZeroDivisionError: If ``data_loader`` yields no batches.
    """
    model.train()
    running_loss, hits, seen = 0, 0, 0

    for data in tqdm(data_loader, desc="Iterating training graphs", unit="batch"):
        data = data.to(device)
        optimizer.zero_grad()
        try:
            output = model(data)
        except IndexError:
            # Print the batch dimensions before propagating, to help locate
            # the offending index.
            print(f"Error in batch with {data.num_nodes} nodes, edge_max={data.edge_index.max()}")
            print(f"Batch info: x.shape={data.x.shape}, edge_index.shape={data.edge_index.shape}")
            raise

        batch_loss = criterion(output, data.y)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        hits += output.argmax(dim=1).eq(data.y).sum().item()
        seen += data.y.size(0)

    # Optional end-of-epoch checkpoint.
    if save_checkpoints:
        checkpoint_file = f"{checkpoint_path}_epoch_{current_epoch + 1}.pth"
        torch.save(model.state_dict(), checkpoint_file)
        print(f"Checkpoint saved at {checkpoint_file}")

    mean_loss = running_loss / len(data_loader)
    accuracy = hits / seen
    return mean_loss, accuracy

In [9]:
def evaluate(data_loader, model, criterion, device, calculate_accuracy=False):
    """Run ``model`` over ``data_loader`` without gradients.

    Args:
        data_loader: Sized iterable of batches supporting ``.to(device)``;
            ``.y`` is read only when ``calculate_accuracy`` is true.
        model: Module called as ``model(batch)``; switched to eval mode.
        criterion: Loss callable, used only when ``calculate_accuracy`` is true.
        device: Device each batch is moved to.
        calculate_accuracy: Selects the return shape (see below).

    Returns:
        If ``calculate_accuracy`` is true: ``(mean loss per batch, accuracy)``.
        Otherwise: a flat list with the arg-max class of every sample, as
        NumPy integer scalars.

    Raises:
        ZeroDivisionError: With ``calculate_accuracy`` true and an empty loader.
    """
    model.eval()
    hits, seen, loss_sum = 0, 0, 0
    predictions = []

    with torch.no_grad():
        for data in tqdm(data_loader, desc="Iterating eval graphs", unit="batch"):
            data = data.to(device)
            output = model(data)
            pred = output.argmax(dim=1)

            if not calculate_accuracy:
                # Inference mode: only collect the predicted classes.
                predictions.extend(pred.cpu().numpy())
                continue

            hits += pred.eq(data.y).sum().item()
            seen += data.y.size(0)
            loss_sum += criterion(output, data.y).item()

    if not calculate_accuracy:
        return predictions

    accuracy = hits / seen
    return loss_sum / len(data_loader), accuracy

## 4. Utility Functions

In [10]:
def save_predictions(predictions, test_path):
    """Write predictions to ``submission/testset_<dir>.csv`` under the working directory.

    Args:
        predictions: Sequence of predicted labels; row ``i`` gets ``id == i``.
        test_path: Path of the test file. Only the name of its parent
            directory is used, to label the output file.

    Side effects:
        Creates ``<cwd>/submission`` if needed, writes the CSV (columns
        ``id`` and ``pred``, no index column) and prints the output path.
    """
    # The directory that contains the test file names the CSV.
    parent_dir = os.path.dirname(test_path)
    test_dir_name = os.path.basename(parent_dir)

    submission_folder = os.path.join(os.getcwd(), "submission")
    os.makedirs(submission_folder, exist_ok=True)
    output_csv_path = os.path.join(submission_folder, f"testset_{test_dir_name}.csv")

    columns = {
        "id": [position for position in range(len(predictions))],
        "pred": predictions,
    }
    pd.DataFrame(columns).to_csv(output_csv_path, index=False)
    print(f"Predictions saved to {output_csv_path}")

In [11]:
def plot_training_progress(train_losses, train_accuracies, val_losses, val_accuracies, output_dir):
    """
    Draw loss and accuracy curves side by side, save them, and show the figure.

    Args:
        train_losses: Training loss per epoch; its length sets the epoch axis.
        train_accuracies: Training accuracy per epoch.
        val_losses: Validation loss per epoch.
        val_accuracies: Validation accuracy per epoch.
        output_dir: Directory (created if missing) that receives
            ``training_progress.png``.
    """
    epochs = range(1, len(train_losses) + 1)
    plt.figure(figsize=(15, 6))

    # One entry per panel, left to right:
    # (train series, train label, train colour,
    #  val series, val label, val colour, y-axis label, title)
    panels = (
        (train_losses, "Training Loss", 'blue',
         val_losses, "Validation Loss", 'red',
         'Loss', 'Training and Validation Loss per Epoch'),
        (train_accuracies, "Training Accuracy", 'green',
         val_accuracies, "Validation Accuracy", 'orange',
         'Accuracy', 'Training and Validation Accuracy per Epoch'),
    )

    for position, panel in enumerate(panels, start=1):
        (train_series, train_label, train_color,
         val_series, val_label, val_color, y_label, title) = panel
        plt.subplot(1, 2, position)
        plt.plot(epochs, train_series, label=train_label, color=train_color, marker='o')
        plt.plot(epochs, val_series, label=val_label, color=val_color, marker='s')
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()
        plt.grid(True)

    # Persist to disk first, then display and release the figure.
    os.makedirs(output_dir, exist_ok=True)
    plt.tight_layout()
    target_file = os.path.join(output_dir, "training_progress.png")
    plt.savefig(target_file)
    plt.show()
    plt.close()

## 5. Configuration and Arguments

In [12]:
def get_user_input(prompt, default=None, required=False, type_cast=str):
    """Prompt on stdin until a usable answer is obtained.

    Args:
        prompt: Text shown to the user; the default is appended in brackets.
        default: Value returned as-is (not passed through ``type_cast``) when
            the user enters nothing and the field is not required.
        required: When true, an empty answer is rejected and the prompt
            repeats, even if a default is given.
        type_cast: Callable applied to non-empty answers; a ``ValueError``
            from it triggers a retry.

    Returns:
        The converted answer, or ``default`` (possibly ``None``) for an empty
        answer to an optional field.
    """
    while True:
        answer = input(f"{prompt} [{default}]: ")

        if answer == "":
            if required:
                print("This field is required. Please enter a value.")
                continue
            # Optional field left blank: fall back to the default, which is
            # None when no default was supplied.
            return default

        try:
            return type_cast(answer)
        except ValueError:
            print(f"Invalid input. Please enter a valid {type_cast.__name__}.")

In [13]:
def get_arguments():
    """Return the hard-coded run configuration as an ``argparse.Namespace``.

    Edit the values below to change the experiment; nothing is read from the
    command line.
    """
    return argparse.Namespace(
        # --- Dataset selection ---
        dataset='A',            # Choose: A, B, C, D
        train_mode=1,           # 1=single dataset, 2=all datasets

        # --- Model config ---
        num_layer=3,
        emb_dim=256,
        drop_ratio=0.3,         # Dropout ratio
        virtual_node=True,      # True to use virtual node, False otherwise
        residual=True,          # True to use residual connections, False otherwise
        JK="last",              # Jumping Knowledge: "last", "sum", "cat"
        edge_drop_ratio=0.15,
        batch_norm=True,
        graph_pooling="mean",   # "sum", "mean", "max", "attention", "set2set"

        # --- Training config ---
        batch_size=64,
        epochs=200,
        baseline_mode=3,        # 1=CE, 2=Noisy CE, 3=GCE
        noise_prob=0.2,
        gce_q=0.5,

        # --- Early stopping config ---
        early_stopping=True,    # Enable/disable early stopping
        patience=25,            # Number of epochs to wait without improvement

        # --- System config ---
        device=0,
        num_checkpoints=3,
    )

In [14]:
def populate_args(args):
    """Print every attribute of ``args`` as ``name: value``, one per line.

    Despite its name this function does not modify ``args``; it only echoes
    the configuration to stdout.
    """
    print("Arguments received:")
    settings = vars(args)
    for key in settings:
        value = settings[key]
        print(f"{key}: {value}")
# Build the configuration once and echo it so the settings used for this run
# are visible in the cell output.
args = get_arguments()
populate_args(args)

Arguments received:
dataset: A
train_mode: 1
num_layer: 3
emb_dim: 256
drop_ratio: 0.3
virtual_node: True
residual: True
JK: last
edge_drop_ratio: 0.15
batch_norm: True
graph_pooling: mean
batch_size: 64
epochs: 200
baseline_mode: 3
noise_prob: 0.2
gce_q: 0.5
early_stopping: True
patience: 25
device: 0
num_checkpoints: 3


## 6. Loss Function Definition

In [15]:
class NoisyCrossEntropyLoss(torch.nn.Module):
    """Cross entropy scaled by ``1 - p_noisy``.

    The previous implementation computed a per-sample weight
    ``(1 - p) + p * (1 - one_hot(targets).sum(dim=1))``. Each one-hot row sums
    to exactly 1, so that weight was always the constant ``1 - p``. This
    version applies the constant directly, which yields the same loss values
    without allocating the one-hot tensor.

    NOTE(review): because the factor is identical for every sample, this loss
    differs from plain cross entropy only by a constant scale. If a
    sample-dependent noise correction was intended, the weighting needs to be
    redesigned.

    Args:
        p_noisy: Assumed label-noise probability; stored as ``self.p``.
    """

    def __init__(self, p_noisy):
        super().__init__()
        self.p = p_noisy
        # Per-sample losses are needed so the scaling happens before the mean.
        self.ce = torch.nn.CrossEntropyLoss(reduction='none')

    def forward(self, logits, targets):
        """Return the batch mean of ``(1 - p) * CE(logits, targets)``.

        Args:
            logits: Class scores, one row per sample.
            targets: Integer class index per sample.
        """
        losses = self.ce(logits, targets)
        return ((1 - self.p) * losses).mean()

In [16]:
class GeneralizedCrossEntropyLoss(torch.nn.Module):
    """Generalized Cross Entropy (GCE) loss: batch mean of ``(1 - p_t**q) / q``.

    ``p_t`` is the softmax probability assigned to the target class.
    ``q`` is a hyperparameter with ``0 < q <= 1``: as ``q -> 0`` the loss
    approaches standard cross entropy, and at ``q == 1`` it equals ``1 - p_t``.

    Args:
        q: Exponent in ``(0, 1]``.

    Raises:
        ValueError: If ``q`` is outside ``(0, 1]`` (``q == 0`` would divide by zero).
    """

    def __init__(self, q=0.7):
        super(GeneralizedCrossEntropyLoss, self).__init__()
        if not (0 < q <= 1):
            # While the limit q->0 is CE, for q=0 direct computation is 1/0.
            raise ValueError("q should be in (0, 1]")
        self.q = q

    def forward(self, logits, targets):
        """Compute the mean GCE loss for a batch.

        Args:
            logits: Class scores, one row per sample.
            targets: Integer class index per sample.

        Returns:
            Scalar tensor with the mean loss.
        """
        # Work in log space: p_t**q == exp(q * log p_t). Raising the softmax
        # output to a power q < 1 has an infinite derivative at p_t == 0, which
        # turns into NaN gradients as soon as a target probability underflows
        # to zero. The exp/log_softmax form keeps the gradient finite there.
        log_probs = torch.log_softmax(logits, dim=1)

        # Build the row index on the logits' device rather than on the CPU.
        row_index = torch.arange(targets.size(0), device=logits.device)
        target_log_probs = log_probs[row_index, targets]

        # GCE loss: (1 - p_t^q) / q
        loss = (1 - torch.exp(self.q * target_log_probs)) / self.q
        return loss.mean()

## 7. Model creation

### 7.1 Config section


In [17]:
# Banner for the run.
print("=" * 60, "Enhanced GNN Training Pipeline", "=" * 60, sep="\n")

# Fresh copy of the hard-coded configuration.
args = get_arguments()

print("\nConfiguration:")
for key, value in vars(args).items():
    print(f"  {key}: {value}")

# Use the configured CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{args.device}")
else:
    device = torch.device("cpu")
print(f"\nUsing device: {device}")

Enhanced GNN Training Pipeline

Configuration:
  dataset: A
  train_mode: 1
  num_layer: 3
  emb_dim: 256
  drop_ratio: 0.3
  virtual_node: True
  residual: True
  JK: last
  edge_drop_ratio: 0.15
  batch_norm: True
  graph_pooling: mean
  batch_size: 64
  epochs: 200
  baseline_mode: 3
  noise_prob: 0.2
  gce_q: 0.5
  early_stopping: True
  patience: 25
  device: 0
  num_checkpoints: 3

Using device: cuda:0


### 7.2 Data Loading

In [18]:
print("\n" + "="*40)
print("LOADING DATA")
print("="*40)

base_path = '/kaggle/input/deep-dataset-preprocessed/processed_data_separate'


def _load_graphs(ds_name, split):
    """Load ``<base_path>/<ds_name>_<split>_graphs.pt`` and give every graph placeholder features.

    Args:
        ds_name: Dataset identifier used in the file name (e.g. 'A').
        split: One of 'train', 'val' or 'test'.

    Returns:
        List of the loaded graphs, each passed through ``add_zeros``.
    """
    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only load files from a trusted source.
    graphs = torch.load(f'{base_path}/{ds_name}_{split}_graphs.pt', weights_only=False)
    return [add_zeros(graph) for graph in graphs]


# Prepare training/validation data based on mode
if args.train_mode == 1:
    # Single dataset mode
    dataset_name = args.dataset
    train_dataset = _load_graphs(dataset_name, 'train')
    val_dataset = _load_graphs(dataset_name, 'val')
    test_dataset = _load_graphs(dataset_name, 'test')
    print(f"Using single dataset: {dataset_name}")
else:
    # All datasets mode: train/validate on the union, test on the specified dataset.
    # Graphs go through add_zeros here too, exactly as in single-dataset mode;
    # previously this branch skipped it, leaving the graphs without the
    # placeholder features the other mode provides.
    train_dataset = []
    val_dataset = []
    test_dataset = _load_graphs(args.dataset, 'test')

    for ds_name in ['A', 'B', 'C', 'D']:
        train_dataset.extend(_load_graphs(ds_name, 'train'))
        val_dataset.extend(_load_graphs(ds_name, 'val'))

    print("Using all datasets for training")

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Create data loaders (only the training loader shuffles)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)


LOADING DATA
Using single dataset: A
Train samples: 10152
Val samples: 1128
Test samples: 2340


### 7.3 Model Setup

In [19]:

print("\n" + "="*40)
print("MODEL SETUP")
print("="*40)

# Initialize model
model = GNN(num_class=6, # Assuming 6 classes based on original notebook
            num_layer=args.num_layer,
            emb_dim=args.emb_dim,
            drop_ratio=args.drop_ratio,
            virtual_node=args.virtual_node,
            residual=args.residual,
            JK=args.JK,
            graph_pooling=args.graph_pooling,
            edge_drop_ratio = args.edge_drop_ratio,
            batch_norm=args.batch_norm
           )

model = model.to(device)

# Setup optimizer and loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# baseline_mode selects the loss: 2 = noisy CE, 3 = GCE, anything else = plain CE.
if args.baseline_mode == 2:
    criterion = NoisyCrossEntropyLoss(args.noise_prob)
    print(f"Using Noisy Cross Entropy Loss (p={args.noise_prob})")
elif args.baseline_mode == 3:
    criterion = GeneralizedCrossEntropyLoss(q=args.gce_q)
    print(f"Using Generalized Cross Entropy (GCE) Loss (q={args.gce_q})")
else:
    criterion = torch.nn.CrossEntropyLoss()
    print("Using standard Cross Entropy Loss")

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Setup logging and checkpoints
# The experiment name uses a fixed "gin" prefix; it is not derived from the configuration.
exp_name = f"gin_dataset{args.dataset}_mode{args.train_mode}"
logs_dir = os.path.join("logs", exp_name)
checkpoints_dir = os.path.join("checkpoints", exp_name)
os.makedirs(logs_dir, exist_ok=True)
os.makedirs(checkpoints_dir, exist_ok=True)

# Setup logging
log_file = os.path.join(logs_dir, "training.log")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ],
    # basicConfig does nothing when the root logger already has handlers
    # (as can happen in a notebook kernel or when this cell is re-run);
    # force=True replaces them so records actually reach log_file.
    force=True,
)

# The best model is written to a fixed location rather than under checkpoints_dir.
best_model_path = '/kaggle/working/best_model.pth'


MODEL SETUP
Using Generalized Cross Entropy (GCE) Loss (q=0.5)
Model parameters: 1,330,953


## 8. Main training loop

### 8.1 Training loop


In [None]:
# Main training driver: trains for up to args.epochs epochs, validates after
# each one, keeps the weights with the highest validation accuracy at
# best_model_path, and optionally stops early when accuracy stalls.
print("\n" + "="*40)
print("TRAINING")
print("="*40)

# Per-epoch metric history and the best validation accuracy seen so far.
best_val_accuracy = 0.0
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

# Early stopping variables
# (the counter is only created when early stopping is enabled; every later
# use is guarded by the same flag)
if args.early_stopping:
    epochs_without_improvement = 0
    print(f"Early stopping enabled with patience: {args.patience}")
else:
    print("Early stopping disabled")

# Calculate checkpoint intervals
# Checkpoints are spread evenly over the run as 1-based epoch numbers,
# e.g. epochs=200 and num_checkpoints=3 give epochs 66, 133 and 200.
# With num_checkpoints <= 1 only the final epoch is checkpointed.
if args.num_checkpoints > 1:
    checkpoint_intervals = [int((i + 1) * args.epochs / args.num_checkpoints) 
                          for i in range(args.num_checkpoints)]
else:
    checkpoint_intervals = [args.epochs]

for epoch in range(args.epochs):
    print(f"\nEpoch {epoch + 1}/{args.epochs}")
    print("-" * 30)
    
    # Training
    # train() also writes the periodic checkpoint when this (1-based) epoch
    # number is one of the checkpoint epochs.
    train_loss, train_acc = train(
        train_loader, model, optimizer, criterion, device,
        save_checkpoints=(epoch + 1 in checkpoint_intervals),
        checkpoint_path=os.path.join(checkpoints_dir, "checkpoint"),
        current_epoch=epoch
    )
    
    # Validation
    # The same criterion as in training is used, so both losses are comparable.
    val_loss, val_acc = evaluate(val_loader, model, criterion, device, calculate_accuracy=True)
    
    # Log results
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    
    logging.info(f"Epoch {epoch + 1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, "
                f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}")
    
    # Store metrics
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    
    # Save best model
    # Only a strict improvement counts; an epoch that merely ties the best
    # accuracy is treated as "no improvement" below.
    if val_acc > best_val_accuracy:
        best_val_accuracy = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"★ New best model saved! Val Acc: {val_acc:.4f}")

        # Reset early stopping counter
        if args.early_stopping:
            epochs_without_improvement = 0

    else:
        # No improvement
        if args.early_stopping:
            epochs_without_improvement += 1
            print(f"No improvement for {epochs_without_improvement} epoch(s)")
            
            # Check if we should stop early
            # (stops once args.patience consecutive epochs failed to beat the best accuracy)
            if epochs_without_improvement >= args.patience:
                print(f"\nEarly stopping triggered! No improvement for {args.patience} epochs.")
                print(f"Best validation accuracy: {best_val_accuracy:.4f}")
                break

print(f"\nBest validation accuracy: {best_val_accuracy:.4f}")


TRAINING
Early stopping enabled with patience: 25

Epoch 1/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.30batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:00<00:00, 18.27batch/s]


Train Loss: 1.1188, Train Acc: 0.2837
Val Loss: 1.0594, Val Acc: 0.3502
★ New best model saved! Val Acc: 0.3502

Epoch 2/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:16<00:00,  9.91batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:00<00:00, 18.15batch/s]


Train Loss: 1.0424, Train Acc: 0.3658
Val Loss: 1.0274, Val Acc: 0.3590
★ New best model saved! Val Acc: 0.3590

Epoch 3/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:16<00:00,  9.80batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.85batch/s]


Train Loss: 1.0006, Train Acc: 0.3929
Val Loss: 1.0014, Val Acc: 0.3750
★ New best model saved! Val Acc: 0.3750

Epoch 4/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:16<00:00,  9.66batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.83batch/s]


Train Loss: 0.9706, Train Acc: 0.4123
Val Loss: 0.9456, Val Acc: 0.4105
★ New best model saved! Val Acc: 0.4105

Epoch 5/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:16<00:00,  9.51batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.52batch/s]


Train Loss: 0.9456, Train Acc: 0.4296
Val Loss: 1.1056, Val Acc: 0.3342
No improvement for 1 epoch(s)

Epoch 6/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:16<00:00,  9.37batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.08batch/s]


Train Loss: 0.9181, Train Acc: 0.4451
Val Loss: 0.9807, Val Acc: 0.4043
No improvement for 2 epoch(s)

Epoch 7/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.18batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.62batch/s]


Train Loss: 0.9000, Train Acc: 0.4602
Val Loss: 0.9318, Val Acc: 0.4309
★ New best model saved! Val Acc: 0.4309

Epoch 8/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  8.97batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.86batch/s]


Train Loss: 0.8825, Train Acc: 0.4741
Val Loss: 0.9096, Val Acc: 0.4424
★ New best model saved! Val Acc: 0.4424

Epoch 9/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.00batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.46batch/s]


Train Loss: 0.8783, Train Acc: 0.4785
Val Loss: 1.0921, Val Acc: 0.3289
No improvement for 1 epoch(s)

Epoch 10/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.09batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.72batch/s]


Train Loss: 0.8700, Train Acc: 0.4822
Val Loss: 1.0117, Val Acc: 0.3883
No improvement for 2 epoch(s)

Epoch 11/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.10batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.79batch/s]


Train Loss: 0.8602, Train Acc: 0.4885
Val Loss: 0.9503, Val Acc: 0.4264
No improvement for 3 epoch(s)

Epoch 12/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.10batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.62batch/s]


Train Loss: 0.8530, Train Acc: 0.4970
Val Loss: 0.9931, Val Acc: 0.3998
No improvement for 4 epoch(s)

Epoch 13/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.03batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.69batch/s]


Train Loss: 0.8419, Train Acc: 0.5029
Val Loss: 1.1644, Val Acc: 0.3422
No improvement for 5 epoch(s)

Epoch 14/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.13batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.05batch/s]


Train Loss: 0.8348, Train Acc: 0.5072
Val Loss: 0.8639, Val Acc: 0.4707
★ New best model saved! Val Acc: 0.4707

Epoch 15/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.14batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.10batch/s]


Train Loss: 0.8347, Train Acc: 0.5048
Val Loss: 0.8906, Val Acc: 0.4707
No improvement for 1 epoch(s)

Epoch 16/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.18batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.04batch/s]


Train Loss: 0.8205, Train Acc: 0.5154
Val Loss: 0.8320, Val Acc: 0.5044
★ New best model saved! Val Acc: 0.5044

Epoch 17/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.19batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.08batch/s]


Train Loss: 0.8233, Train Acc: 0.5157
Val Loss: 0.9333, Val Acc: 0.4530
No improvement for 1 epoch(s)

Epoch 18/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.12batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.91batch/s]


Train Loss: 0.8202, Train Acc: 0.5159
Val Loss: 1.2788, Val Acc: 0.2828
No improvement for 2 epoch(s)

Epoch 19/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.08batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.89batch/s]


Train Loss: 0.8077, Train Acc: 0.5273
Val Loss: 1.0768, Val Acc: 0.3697
No improvement for 3 epoch(s)

Epoch 20/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.07batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.71batch/s]


Train Loss: 0.8098, Train Acc: 0.5227
Val Loss: 1.4328, Val Acc: 0.2234
No improvement for 4 epoch(s)

Epoch 21/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.13batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.15batch/s]


Train Loss: 0.8016, Train Acc: 0.5312
Val Loss: 0.8453, Val Acc: 0.4956
No improvement for 5 epoch(s)

Epoch 22/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.16batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.04batch/s]


Train Loss: 0.7953, Train Acc: 0.5344
Val Loss: 1.0638, Val Acc: 0.3475
No improvement for 6 epoch(s)

Epoch 23/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.16batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.91batch/s]


Train Loss: 0.7943, Train Acc: 0.5329
Val Loss: 1.1250, Val Acc: 0.3404
No improvement for 7 epoch(s)

Epoch 24/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.16batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.73batch/s]


Train Loss: 0.7897, Train Acc: 0.5369
Val Loss: 0.8124, Val Acc: 0.5124
★ New best model saved! Val Acc: 0.5124

Epoch 25/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.14batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.00batch/s]


Train Loss: 0.7790, Train Acc: 0.5430
Val Loss: 0.9149, Val Acc: 0.4504
No improvement for 1 epoch(s)

Epoch 26/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.14batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.98batch/s]


Train Loss: 0.7777, Train Acc: 0.5473
Val Loss: 0.9321, Val Acc: 0.4512
No improvement for 2 epoch(s)

Epoch 27/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.15batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.99batch/s]


Train Loss: 0.7763, Train Acc: 0.5500
Val Loss: 0.8949, Val Acc: 0.4557
No improvement for 3 epoch(s)

Epoch 28/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.13batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.96batch/s]


Train Loss: 0.7690, Train Acc: 0.5513
Val Loss: 1.0796, Val Acc: 0.3768
No improvement for 4 epoch(s)

Epoch 29/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.15batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.08batch/s]


Train Loss: 0.7646, Train Acc: 0.5552
Val Loss: 1.0012, Val Acc: 0.4193
No improvement for 5 epoch(s)

Epoch 30/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.15batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.09batch/s]


Train Loss: 0.7603, Train Acc: 0.5611
Val Loss: 0.8989, Val Acc: 0.4654
No improvement for 6 epoch(s)

Epoch 31/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.13batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.07batch/s]


Train Loss: 0.7560, Train Acc: 0.5625
Val Loss: 0.9032, Val Acc: 0.4504
No improvement for 7 epoch(s)

Epoch 32/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.15batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.04batch/s]


Train Loss: 0.7496, Train Acc: 0.5652
Val Loss: 0.9267, Val Acc: 0.4441
No improvement for 8 epoch(s)

Epoch 33/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.10batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.86batch/s]


Train Loss: 0.7501, Train Acc: 0.5643
Val Loss: 0.7824, Val Acc: 0.5346
★ New best model saved! Val Acc: 0.5346

Epoch 34/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.08batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.49batch/s]


Train Loss: 0.7441, Train Acc: 0.5705
Val Loss: 0.8730, Val Acc: 0.4902
No improvement for 1 epoch(s)

Epoch 35/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.03batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.69batch/s]


Train Loss: 0.7471, Train Acc: 0.5642
Val Loss: 0.7742, Val Acc: 0.5355
★ New best model saved! Val Acc: 0.5355

Epoch 36/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.04batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.80batch/s]


Train Loss: 0.7401, Train Acc: 0.5722
Val Loss: 0.9576, Val Acc: 0.4450
No improvement for 1 epoch(s)

Epoch 37/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.05batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.67batch/s]


Train Loss: 0.7333, Train Acc: 0.5761
Val Loss: 0.8853, Val Acc: 0.4840
No improvement for 2 epoch(s)

Epoch 38/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.07batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.01batch/s]


Train Loss: 0.7389, Train Acc: 0.5720
Val Loss: 0.8550, Val Acc: 0.4902
No improvement for 3 epoch(s)

Epoch 39/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.16batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.04batch/s]


Train Loss: 0.7264, Train Acc: 0.5800
Val Loss: 1.0766, Val Acc: 0.3502
No improvement for 4 epoch(s)

Epoch 40/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.14batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 17.03batch/s]


Train Loss: 0.7238, Train Acc: 0.5808
Val Loss: 0.8576, Val Acc: 0.4920
No improvement for 5 epoch(s)

Epoch 41/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.14batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.76batch/s]


Train Loss: 0.7247, Train Acc: 0.5840
Val Loss: 0.9737, Val Acc: 0.4087
No improvement for 6 epoch(s)

Epoch 42/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.09batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.86batch/s]


Train Loss: 0.7230, Train Acc: 0.5834
Val Loss: 0.8674, Val Acc: 0.5027
No improvement for 7 epoch(s)

Epoch 43/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.06batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.76batch/s]


Train Loss: 0.7251, Train Acc: 0.5823
Val Loss: 0.7964, Val Acc: 0.5266
No improvement for 8 epoch(s)

Epoch 44/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.08batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.64batch/s]


Train Loss: 0.7201, Train Acc: 0.5868
Val Loss: 0.7799, Val Acc: 0.5426
★ New best model saved! Val Acc: 0.5426

Epoch 45/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.04batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.77batch/s]


Train Loss: 0.7240, Train Acc: 0.5830
Val Loss: 1.0891, Val Acc: 0.3741
No improvement for 1 epoch(s)

Epoch 46/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.06batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.74batch/s]


Train Loss: 0.7230, Train Acc: 0.5831
Val Loss: 0.8392, Val Acc: 0.5080
No improvement for 2 epoch(s)

Epoch 47/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.05batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.72batch/s]


Train Loss: 0.7152, Train Acc: 0.5884
Val Loss: 0.7843, Val Acc: 0.5363
No improvement for 3 epoch(s)

Epoch 48/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.04batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.79batch/s]


Train Loss: 0.7070, Train Acc: 0.5923
Val Loss: 1.2438, Val Acc: 0.3218
No improvement for 4 epoch(s)

Epoch 49/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.06batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.71batch/s]


Train Loss: 0.7175, Train Acc: 0.5865
Val Loss: 0.9361, Val Acc: 0.4397
No improvement for 5 epoch(s)

Epoch 50/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.05batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.77batch/s]


Train Loss: 0.7106, Train Acc: 0.5929
Val Loss: 0.8493, Val Acc: 0.4894
No improvement for 6 epoch(s)

Epoch 51/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.06batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.68batch/s]


Train Loss: 0.7050, Train Acc: 0.5948
Val Loss: 0.8179, Val Acc: 0.5089
No improvement for 7 epoch(s)

Epoch 52/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.06batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.75batch/s]


Train Loss: 0.7009, Train Acc: 0.6018
Val Loss: 1.0075, Val Acc: 0.4025
No improvement for 8 epoch(s)

Epoch 53/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.07batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.82batch/s]


Train Loss: 0.6976, Train Acc: 0.6008
Val Loss: 0.9469, Val Acc: 0.4530
No improvement for 9 epoch(s)

Epoch 54/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.07batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.74batch/s]


Train Loss: 0.6958, Train Acc: 0.6019
Val Loss: 1.0443, Val Acc: 0.3972
No improvement for 10 epoch(s)

Epoch 55/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.06batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.76batch/s]


Train Loss: 0.7011, Train Acc: 0.5961
Val Loss: 0.8238, Val Acc: 0.5195
No improvement for 11 epoch(s)

Epoch 56/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.06batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.73batch/s]


Train Loss: 0.6982, Train Acc: 0.6022
Val Loss: 0.9048, Val Acc: 0.4601
No improvement for 12 epoch(s)

Epoch 57/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.05batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.76batch/s]


Train Loss: 0.6954, Train Acc: 0.6021
Val Loss: 0.8597, Val Acc: 0.5000
No improvement for 13 epoch(s)

Epoch 58/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.05batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.68batch/s]


Train Loss: 0.6909, Train Acc: 0.6062
Val Loss: 0.9003, Val Acc: 0.4566
No improvement for 14 epoch(s)

Epoch 59/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.06batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.79batch/s]


Train Loss: 0.6888, Train Acc: 0.6072
Val Loss: 0.8095, Val Acc: 0.5195
No improvement for 15 epoch(s)

Epoch 60/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.05batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.76batch/s]


Train Loss: 0.6838, Train Acc: 0.6109
Val Loss: 0.7360, Val Acc: 0.5665
★ New best model saved! Val Acc: 0.5665

Epoch 61/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.06batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.70batch/s]


Train Loss: 0.6820, Train Acc: 0.6095
Val Loss: 1.0199, Val Acc: 0.4069
No improvement for 1 epoch(s)

Epoch 62/200
------------------------------


Iterating training graphs: 100%|██████████| 159/159 [00:17<00:00,  9.06batch/s]
Iterating eval graphs: 100%|██████████| 18/18 [00:01<00:00, 16.82batch/s]


Train Loss: 0.6840, Train Acc: 0.6099
Val Loss: 0.7430, Val Acc: 0.5674
★ New best model saved! Val Acc: 0.5674

Epoch 63/200
------------------------------


Iterating training graphs:   4%|▍         | 6/159 [00:00<00:16,  9.53batch/s]

In [None]:
# Visualise the loss/accuracy histories gathered during training.
# logs_dir is passed through to the helper (the final summary reports that
# plots are saved there).
plot_training_progress(
    train_losses,
    train_accuracies,
    val_losses,
    val_accuracies,
    logs_dir,
)

### 7.5 Testing and predictions

In [None]:
print("\n" + "="*40)
print("TESTING")
print("="*40)

# Load best model and make predictions.
# map_location=device remaps the stored tensors onto the device used for
# inference, so a checkpoint written during a GPU run can still be loaded in a
# CPU-only session (and vice versa) instead of failing during deserialization.
model.load_state_dict(torch.load(best_model_path, map_location=device))
print(f"Loaded best model from: {best_model_path}")

# calculate_accuracy=False asks evaluate() for predictions only
# (presumably because the test split carries no labels — confirm in evaluate()).
predictions = evaluate(test_loader, model, criterion, device, calculate_accuracy=False)

# Save predictions
save_predictions(predictions, args.dataset)

# Cleanup for memory: drop the dataset/loader references so the garbage
# collector can reclaim them.
# NOTE: after this point these names no longer exist, so re-running this cell
# (or any earlier cell that uses them) requires re-running the data-loading cells first.
del train_dataset, val_dataset, test_dataset
del train_loader, val_loader, test_loader
gc.collect()

print("\n" + "="*60)
print("TRAINING COMPLETED SUCCESSFULLY!")
print("="*60)
print(f"Best validation accuracy: {best_val_accuracy:.4f}")
print(f"Predictions saved for dataset {args.dataset}")
print(f"Logs and plots saved in: {logs_dir}")