In [None]:
!pip install torch_geometric

In [None]:
import os
import gc
import sys
import torch
import pandas as pd
import matplotlib.pyplot as plt
import logging
from tqdm import tqdm
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
import argparse
import torch.nn.functional as F
from torch_geometric.utils import dropout_edge
import torch.nn as nn


In [None]:
helper_scripts_path = '/kaggle/input/myhackatonhelperscripts'

# Put the helper-module folder at the front of Python's module search path so the
# project modules imported below (preprocessor, utils, models_EDandBatch_norm) resolve.
if os.path.exists(helper_scripts_path):
    sys.path.insert(0, helper_scripts_path)
    print(f"Successfully added '{helper_scripts_path}' to sys.path.")
    print(f"Contents of '{helper_scripts_path}': {os.listdir(helper_scripts_path)}")  # Verify
else:
    print(f"WARNING: Helper scripts path not found: {helper_scripts_path}")
    # Derive the dataset name from the path actually searched so the hint cannot drift
    # out of sync with `helper_scripts_path`.
    print(f"Please ensure '{os.path.basename(helper_scripts_path)}' dataset is correctly added to the notebook.")

# Import the project helpers. Later cells need these names: set_seed is called right
# below and GNN builds the model. A failed import is therefore fatal. Print the hint,
# then re-raise, instead of continuing into a confusing NameError further down.
try:
    from preprocessor import MultiDatasetLoader
    from utils import set_seed
    from models_EDandBatch_norm import GNN
    print("Successfully imported modules.")
except ImportError as e:
    print(f"ERROR importing module: {e}")
    print("Please check that the .py files exist directly under the helper_scripts_path and have no syntax errors.")
    raise

# Set the random seed
set_seed()

In [None]:
def add_zeros(data):
    """Attach placeholder node features to a graph object.

    Sets ``data.x`` to a 1-D ``torch.long`` tensor of zeros with one entry per
    node (``data.num_nodes``). The object is mutated in place, and the same
    object is returned, so the function can be used inside comprehensions.
    """
    node_count = data.num_nodes
    data.x = torch.zeros((node_count,), dtype=torch.long)
    return data

In [None]:
# One training epoch, with optional GCOD (baseline_mode 4) per-sample u updates.

def _gcod_optimize_u(criterion, output, targets, u_init, n_steps, lr_u):
    """Optimise per-sample GCOD ``u`` values by plain gradient descent on L2.

    Args:
        criterion: GCOD loss object exposing ``compute_L2(logits, targets, u)``.
        output: Model logits for the batch. They are detached here, so L2 never
            backpropagates into the model parameters.
        targets: Ground-truth labels for the batch.
        u_init: Starting ``u`` values for the batch. They are copied, and the
            caller's tensor is not modified.
        n_steps: Number of gradient-descent steps (``gcod_T_u``).
        lr_u: Step size for ``u`` (``gcod_lr_u``).

    Returns:
        A detached tensor with the optimised ``u`` values, clamped to [0, 1].
    """
    logits_fixed = output.detach()
    u_batch = u_init.clone().detach().requires_grad_(True)
    for _ in range(n_steps):
        if u_batch.grad is not None:
            u_batch.grad.zero_()
        loss_u = criterion.compute_L2(logits_fixed, targets, u_batch)
        loss_u.backward()
        with torch.no_grad():
            u_batch -= lr_u * u_batch.grad
            u_batch.clamp_(0, 1)  # keep u inside [0, 1]
    return u_batch.detach()


def train(data_loader, model, optimizer, criterion, device, save_checkpoints, checkpoint_path, current_epoch,
          args_namespace, u_values_global, current_baseline_mode, scheduler=None):
    """Train ``model`` for one epoch over ``data_loader``.

    Args:
        data_loader: Iterable of batches exposing ``.to(device)`` and ``.y``. For
            GCOD, batches must also expose ``.original_idx``, the index of each graph
            in ``u_values_global``.
        model: Module mapping a batch to class logits.
        optimizer: Optimizer over the model parameters.
        criterion: Loss function.
            - When ``current_baseline_mode == 4`` it must expose
              ``compute_L2``. Its call must return a tuple whose first item is the
              loss to backpropagate.
            - Otherwise it is called as ``criterion(output, targets)``.
        device: Device that batches are moved to.
        save_checkpoints: If True, save ``model.state_dict()`` after the epoch.
        checkpoint_path: Prefix of the checkpoint file
            (``{checkpoint_path}_epoch_{current_epoch + 1}.pth``).
        current_epoch: Zero-based epoch number, used only in the checkpoint filename.
        args_namespace: Configuration. GCOD reads ``gcod_T_u`` and ``gcod_lr_u`` from it.
        u_values_global: Tensor with one ``u`` per training graph. Used in GCOD mode
            only, and updated in place.
        current_baseline_mode: Loss selector; 4 enables the GCOD logic.
        scheduler: Optional LR scheduler. Only a ``OneCycleLR`` instance is stepped
            here (once per batch). Any other value is ignored and left to the caller.

    Returns:
        ``(avg_loss, accuracy)``:
            - avg_loss is the mean per-batch loss.
            - accuracy is the fraction of correctly classified graphs over the epoch.

    Raises:
        ValueError: If GCOD mode is requested without ``u_values_global``.
    """
    if current_baseline_mode == 4 and u_values_global is None:
        raise ValueError("GCOD (baseline_mode 4) requires u_values_global, but it is None.")

    model.train()
    total_loss_accum = 0.0
    correct_preds = 0
    total_samples_processed = 0

    for data in tqdm(data_loader, desc="Iterating training graphs", unit="batch"):
        data = data.to(device)
        optimizer.zero_grad()

        output = model(data)
        # The parameter update below does not change `output`, so predictions are computed once.
        n_correct = (output.argmax(dim=1) == data.y).sum().item()
        n_samples = data.y.size(0)

        u_batch_optimized = None
        if current_baseline_mode == 4:
            # Index u on its own device; the optimised values are moved to `device` and back.
            u_device = u_values_global.device
            batch_indices = data.original_idx.to(device=u_device, dtype=torch.long)
            u_init = u_values_global[batch_indices].to(device)

            u_batch_optimized = _gcod_optimize_u(
                criterion, output, data.y, u_init,
                n_steps=args_namespace.gcod_T_u, lr_u=args_namespace.gcod_lr_u,
            )

            batch_accuracy = n_correct / n_samples if n_samples > 0 else 0.0
            # The GCOD criterion returns (loss_for_theta, L1, L2, L3); only the first item is backpropagated.
            actual_loss_for_bp = criterion(output, data.y, u_batch_optimized, batch_accuracy)[0]

            with torch.no_grad():
                u_values_global[batch_indices] = u_batch_optimized.to(u_device)
        else:
            actual_loss_for_bp = criterion(output, data.y)

        try:
            actual_loss_for_bp.backward()
            optimizer.step()
            # OneCycleLR is a per-batch schedule. Epoch-level schedulers are stepped by the caller.
            if isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR):
                scheduler.step()
        except IndexError:
            edge_max_val = data.edge_index.max().item() if data.edge_index.numel() > 0 else 'N/A'
            print(f"Error in batch with {data.num_nodes} nodes, edge_max={edge_max_val}")
            print(f"Batch info: x.shape={data.x.shape}, edge_index.shape={data.edge_index.shape}")
            if current_baseline_mode == 4:
                u_shape = u_batch_optimized.shape if u_batch_optimized is not None else 'N/A'
                print(f"GCOD context: u_batch_optimized shape: {u_shape}")
            raise

        total_loss_accum += actual_loss_for_bp.item()
        correct_preds += n_correct
        total_samples_processed += n_samples

    if save_checkpoints:
        checkpoint_file = f"{checkpoint_path}_epoch_{current_epoch + 1}.pth"
        torch.save(model.state_dict(), checkpoint_file)
        print(f"Checkpoint saved at {checkpoint_file}")

    avg_loss = total_loss_accum / len(data_loader) if len(data_loader) > 0 else 0.0
    accuracy = correct_preds / total_samples_processed if total_samples_processed > 0 else 0.0
    return avg_loss, accuracy

In [None]:
# Evaluation / inference over a loader.
def evaluate(data_loader, model, criterion, device, calculate_accuracy=False, args_namespace=None, u_values_global_eval=None):
    """Run ``model`` over ``data_loader`` with gradients disabled.

    Args:
        data_loader: Iterable of batches exposing ``.to(device)``, plus ``.y`` when
            ``calculate_accuracy`` is True.
        model: Module mapping a batch to class logits; it is put in eval mode.
        criterion: Loss used when ``calculate_accuracy`` is True.
        device: Device that batches are moved to.
        calculate_accuracy: If True, score against ``data.y``. Otherwise, only
            collect predictions.
        args_namespace: Optional configuration. When it is truthy and
            ``baseline_mode == 4``, the loss is ``criterion.compute_L1`` with every
            ``u`` set to zero.
        u_values_global_eval: Not used by the current implementation.

    Returns:
        - With ``calculate_accuracy``: ``(avg_loss, accuracy)``. avg_loss is the
          mean per-batch loss (0.0 for an empty loader).
        - Without it: a flat list of predicted class indices in loader order.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    loss_sum = 0
    predicted_labels = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Iterating eval graphs", unit="batch"):
            batch = batch.to(device)
            logits = model(batch)
            batch_pred = logits.argmax(dim=1)

            if not calculate_accuracy:
                predicted_labels.extend(batch_pred.cpu().numpy())
                continue

            n_correct += (batch_pred == batch.y).sum().item()
            n_seen += batch.y.size(0)

            if args_namespace and args_namespace.baseline_mode == 4:
                # GCOD: score with L1 using u = 0 for every sample (u_values_global_eval is not used).
                zero_u = torch.zeros(batch.y.size(0), device=device, dtype=torch.float)
                batch_loss = criterion.compute_L1(logits, batch.y, zero_u)
            else:
                batch_loss = criterion(logits, batch.y)
            loss_sum += batch_loss.item()

    if not calculate_accuracy:
        return predicted_labels

    accuracy = n_correct / n_seen if n_seen > 0 else 0.0
    avg_loss = loss_sum / len(data_loader) if len(data_loader) > 0 else 0.0
    return avg_loss, accuracy

In [None]:
def plot_training_progress(train_losses, train_accuracies, val_losses, val_accuracies, output_dir):
    """
    Draw per-epoch loss and accuracy curves side by side, then save and show them.

    The left panel compares training and validation loss; the right panel
    compares training and validation accuracy. The x-axis is the 1-based epoch
    number, derived from ``len(train_losses)``. The figure is written to
    ``<output_dir>/training_progress.png`` (the directory is created if needed),
    then displayed and closed.

    Args:
        train_losses: Training loss per epoch.
        train_accuracies: Training accuracy per epoch.
        val_losses: Validation loss per epoch.
        val_accuracies: Validation accuracy per epoch.
        output_dir: Directory that receives the PNG.
    """
    epoch_axis = range(1, len(train_losses) + 1)
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 6))

    # One entry per panel: axis, y-label, title, then (values, label, colour) for train and for val.
    panels = (
        (loss_ax, 'Loss', 'Training and Validation Loss per Epoch',
         (train_losses, "Training Loss", 'blue'),
         (val_losses, "Validation Loss", 'red')),
        (acc_ax, 'Accuracy', 'Training and Validation Accuracy per Epoch',
         (train_accuracies, "Training Accuracy", 'green'),
         (val_accuracies, "Validation Accuracy", 'orange')),
    )
    for ax, y_label, title, (tr_vals, tr_label, tr_color), (va_vals, va_label, va_color) in panels:
        ax.plot(epoch_axis, tr_vals, label=tr_label, color=tr_color, marker='o')
        ax.plot(epoch_axis, va_vals, label=va_label, color=va_color, marker='s')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

    os.makedirs(output_dir, exist_ok=True)
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "training_progress.png"))
    plt.show()
    plt.close(fig)

In [None]:
def save_predictions(predictions, test_path, dataset_name=None):
    """Write test-set predictions to ``<cwd>/submission/testset_<name>.csv``.

    The CSV has two columns. ``id`` is the position of each prediction
    (0..len(predictions)-1), and ``pred`` is the predicted label.

    Args:
        predictions: Sequence of predicted labels, in test-set order.
        test_path: Path to the test file. When ``dataset_name`` is not given,
            ``<name>`` is the name of the directory that contains this path. For
            example, ``<dir>/<name>/<file>`` yields ``<name>``.
        dataset_name: Optional explicit ``<name>`` for the output file. Use it when
            ``test_path`` is not inside a per-dataset directory; otherwise that
            parent directory's name ends up in the filename.

    Returns:
        The path of the written CSV file.
    """
    submission_folder = os.path.join(os.getcwd(), "submission")
    os.makedirs(submission_folder, exist_ok=True)

    if dataset_name is None:
        dataset_name = os.path.basename(os.path.dirname(test_path))
    output_csv_path = os.path.join(submission_folder, f"testset_{dataset_name}.csv")

    output_df = pd.DataFrame({
        "id": list(range(len(predictions))),
        "pred": predictions,
    })
    output_df.to_csv(output_csv_path, index=False)
    print(f"Predictions saved to {output_csv_path}")
    return output_csv_path

In [None]:
def get_user_input(prompt, default=None, required=False, type_cast=str):
    """Prompt on stdin until a usable value is entered.

    Args:
        prompt: Text shown to the user. When ``default`` is given it is appended
            as ``[default]``.
        default: Value returned when the user just presses Enter. It is returned
            as-is and is not passed through ``type_cast``. A default satisfies
            ``required``.
        required: If True and there is no ``default``, an empty answer is
            rejected and the user is asked again.
        type_cast: Callable that converts the raw string (e.g. ``int``, ``float``).

    Returns:
        - ``type_cast(answer)`` for a non-empty answer.
        - ``default`` for an empty answer when a default exists.
        - ``None`` for an empty answer to an optional field with no default.
    """
    # functools.partial and other callables may lack __name__.
    cast_name = getattr(type_cast, "__name__", repr(type_cast))
    suffix = f" [{default}]" if default is not None else ""

    while True:
        user_input = input(f"{prompt}{suffix}: ")

        if user_input == "":
            if default is not None:
                return default
            if required:
                print("This field is required. Please enter a value.")
                continue
            return None

        try:
            return type_cast(user_input)
        except (ValueError, TypeError):
            print(f"Invalid input. Please enter a valid {cast_name}.")

In [None]:
def get_arguments():
    """Build the run configuration.

    All settings are hard-coded here; there is no command-line parsing. Edit the
    values below to change an experiment. The result is an ``argparse.Namespace``,
    so the rest of the notebook uses attribute access (``args.epochs``).
    """
    return argparse.Namespace(
        # --- Dataset selection ---
        dataset='D',                     # Choose: A, B, C, D
        train_mode=1,                    # 1=single dataset, 2=all datasets

        # --- Model config ---
        # gnn='gin',                     # gin, gin-virtual, gcn, gcn-virtual (currently disabled)
        num_layer=3,
        emb_dim=218,
        drop_ratio=0.7,                  # Dropout ratio
        virtual_node=True,               # True to use virtual node, False otherwise
        residual=True,                   # True to use residual connections, False otherwise
        JK="last",                       # Jumping Knowledge: "last", "sum", "cat"
        graph_pooling="mean",            # "sum", "mean", "max", "attention", "set2set"
        edge_drop_ratio=0.3,
        batch_norm=True,
        layer_norm=False,                # NOTE(review): not passed to GNN in the model-setup cell

        # --- Training config ---
        batch_size=64,
        epochs=250,
        baseline_mode=4,                 # 1=CE, 2=Noisy CE, 3=GCE, 4=GCOD
        noise_prob=0.2,                  # p for Noisy CE (baseline_mode 2)
        gce_q=0.4,                       # q for GCE (baseline_mode 3)
        initial_lr=5e-3,

        # --- Early stopping ---
        early_stopping=True,             # Enable/disable early stopping
        patience=25,                     # epochs without val-accuracy improvement before stopping

        # --- GCOD loss hyperparameters ---
        gcod_lambda_p=2.0,               # Weight for prediction penalty in GCOD (alpha_train)
        gcod_T_u=15,                     # Number of optimization iterations for u in GCOD
        gcod_lr_u=0.1,                   # Learning rate for optimizing u in GCOD

        # --- LR scheduler ---
        use_scheduler=True,
        # Names the code may handle: 'StepLR', 'ReduceLROnPlateau', 'CosineAnnealingLR',
        # 'ExponentialLR', 'OneCycleLR'. Check the scheduler-setup cell for which are constructed.
        scheduler_type='ReduceLROnPlateau',

        # StepLR parameters
        step_size=30,                    # Period of learning rate decay for StepLR
        gamma=0.5,                       # Multiplicative factor of learning rate decay

        # ReduceLROnPlateau parameters
        patience_lr=10,                  # Epochs with no improvement before the LR is reduced
        factor=0.5,                      # Factor by which the learning rate will be reduced
        min_lr=1e-7,                     # Lower bound on the learning rate

        # --- System ---
        device=0,                        # CUDA device index (CPU is used when CUDA is unavailable)
        num_checkpoints=10,              # number of evenly spaced epoch checkpoints to save
    )

In [None]:
def populate_args(args):
    """Print a header, then every attribute of ``args`` as ``name: value``, one per line."""
    print("Arguments received:")
    lines = [f"{name}: {val}" for name, val in vars(args).items()]
    for line in lines:
        print(line)
# Build the hard-coded configuration and print it for a quick visual check.
args = get_arguments()
populate_args(args)

In [None]:
class NoisyCrossEntropyLoss(torch.nn.Module):
    """Cross-entropy weighted per sample by a factor derived from ``p_noisy``.

    NOTE(review): ``one_hot(targets).sum(dim=1)`` is 1 for every sample, so the
    weight always reduces to the constant ``1 - p_noisy``. As written, this loss
    equals mean cross-entropy multiplied by ``(1 - p_noisy)``. Confirm this is
    the intended formulation.
    """

    def __init__(self, p_noisy):
        super().__init__()
        self.p = p_noisy
        # Per-sample losses so the weights can be applied before averaging.
        self.ce = torch.nn.CrossEntropyLoss(reduction='none')

    def forward(self, logits, targets):
        per_sample = self.ce(logits, targets)
        n_classes = logits.size(1)
        onehot_row_sums = torch.nn.functional.one_hot(targets, num_classes=n_classes).float().sum(dim=1)
        keep = 1 - self.p
        weights = keep + self.p * (1 - onehot_row_sums)
        return torch.mean(per_sample * weights)

GCOD loss (old version)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GCODLoss(nn.Module):
    """
    Graph Centroid Outlier Discounting (GCOD) Loss Function
    Based on the NCOD method adapted for graph classification.
    The model parameters (theta) are updated using L1 + L3.
    The sample-specific parameters (u) are updated using L2.
    """
    def __init__(self, num_classes, alpha_train=0.01):
        """
        Args:
            num_classes (int): Number of classes.
            alpha_train (float): Coefficient for the feedback term in L1. The
                notebook passes gcod_lambda_p here.
        """
        super(GCODLoss, self).__init__()
        self.num_classes = num_classes
        self.alpha_train = alpha_train
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')  # per-sample CE

    def _ensure_u_shape(self, u_params, batch_size, target_ndim):
        """Return ``u_params`` with one value per sample, in the layout the caller needs.

        Args:
            u_params (Tensor): Shape ``[batch_size]`` or ``[batch_size, 1]``.
            batch_size (int): Expected number of samples.
            target_ndim (int): 1 gives shape ``[batch_size]``; 2 gives
                ``[batch_size, 1]``. Any other value returns ``u_params`` unchanged,
                after validation.

        Unlike a bare ``squeeze()``, the batch dimension is kept when
        ``batch_size == 1``.

        Raises:
            ValueError: If ``u_params`` does not hold exactly one value per sample.
        """
        if (u_params.dim() == 0 or u_params.shape[0] != batch_size
                or u_params.numel() != batch_size):
            raise ValueError(
                f"u_params shape {tuple(u_params.shape)} does not provide one value per sample "
                f"for batch_size {batch_size}")

        if target_ndim == 1:
            return u_params.reshape(batch_size)
        if target_ndim == 2:
            return u_params.reshape(batch_size, 1)
        return u_params

    def compute_L1(self, logits, targets, u_params):
        """
        Computes L1 = CE(f_θ(Z_B)) + α_train * u_B * (y_B ⋅ ỹ_B), averaged over the batch.

        Here ỹ_B is the softmax of the logits and y_B is the one-hot target.

        Args:
            logits (Tensor): Model output logits, shape [batch_size, num_classes].
            targets (Tensor): Ground-truth labels, shape [batch_size].
            u_params (Tensor): Per-sample u values, shape [batch_size] or [batch_size, 1].
        Returns:
            Tensor: Scalar L1 loss for the batch (0.0 for an empty batch).
        """
        batch_size = logits.size(0)
        if batch_size == 0:
            return torch.tensor(0.0, device=logits.device, requires_grad=logits.requires_grad)

        y_onehot = F.one_hot(targets, num_classes=self.num_classes).float()
        y_soft = F.softmax(logits, dim=1)

        ce_loss_values = self.ce_loss(logits, targets)  # [batch_size]
        current_u_params = self._ensure_u_shape(u_params, batch_size, target_ndim=1)

        # Softmax probability of the true class, scaled by u and alpha_train.
        feedback_term_values = self.alpha_train * current_u_params * (y_onehot * y_soft).sum(dim=1)

        return (ce_loss_values + feedback_term_values).mean()

    def compute_L2(self, logits, targets, u_params):
        """
        Computes L2 = (1/|C|) * ||ỹ_B + u_B * y_B - y_B||², summed over the whole batch.

        Args:
            logits (Tensor): Model output logits, shape [batch_size, num_classes].
            targets (Tensor): Ground-truth labels, shape [batch_size].
            u_params (Tensor): Per-sample u values, shape [batch_size] or [batch_size, 1].
        Returns:
            Tensor: Scalar L2 loss for the batch (0.0 for an empty batch).
        """
        batch_size = logits.size(0)
        if batch_size == 0:
            return torch.tensor(0.0, device=logits.device, requires_grad=logits.requires_grad)

        y_onehot = F.one_hot(targets, num_classes=self.num_classes).float()
        y_soft = F.softmax(logits, dim=1)

        current_u_params_unsqueezed = self._ensure_u_shape(u_params, batch_size, target_ndim=2)
        term = y_soft + current_u_params_unsqueezed * y_onehot - y_onehot  # [batch_size, num_classes]

        # Squared Frobenius norm as a plain sum of squares. This avoids taking a sqrt and then squaring.
        return term.pow(2).sum() / self.num_classes

    def compute_L3(self, logits, targets, u_params, l3_coeff):
        """
        Computes L3 = l3_coeff * KL-term between L and σ(-log(u_B)).

        Here L = log(σ(logit_true_class)), and l3_coeff is typically
        (1 - training_accuracy).

        NOTE(review): ``F.kl_div(input, target)`` computes ``target * (log(target) - input)``
        element-wise, with input = L and target = σ(-log(u_B)). That is the reverse of
        "D_KL(L || σ(-log u))". Confirm which direction the GCOD formulation intends.

        Args:
            logits (Tensor): Model output logits, shape [batch_size, num_classes].
            targets (Tensor): Ground-truth labels, shape [batch_size].
            u_params (Tensor): Per-sample u values, shape [batch_size] or [batch_size, 1].
            l3_coeff (float): Coefficient for the KL term.
        Returns:
            Tensor: Scalar L3 loss for the batch (0.0 for an empty batch).
        """
        batch_size = logits.size(0)
        if batch_size == 0:
            return torch.tensor(0.0, device=logits.device, requires_grad=logits.requires_grad)

        y_onehot = F.one_hot(targets, num_classes=self.num_classes).float()

        # Logit of the true class for each sample: [batch_size]
        diag_elements = (logits * y_onehot).sum(dim=1)
        L_log_probs = F.logsigmoid(diag_elements)  # log-probabilities, [batch_size]

        current_u_params = self._ensure_u_shape(u_params, batch_size, target_ndim=1)
        # σ(-log u). The 1e-8 keeps log finite when u == 0.
        target_probs_for_kl = torch.sigmoid(-torch.log(current_u_params + 1e-8))

        # reduction='mean' averages over the batch; log_target=False means the target holds probabilities.
        kl_div = F.kl_div(L_log_probs, target_probs_for_kl, reduction='mean', log_target=False)
        return l3_coeff * kl_div

    def forward(self, logits, targets, u_params, training_accuracy):
        """
        Calculates the GCOD loss components.

        Args:
            logits (Tensor): Model output logits.
            targets (Tensor): Ground-truth labels.
            u_params (Tensor): Per-sample u values for the batch.
            training_accuracy (float): Training accuracy in [0, 1]. It sets the L3
                coefficient (1 - training_accuracy).
        Returns:
            tuple: (total_loss_for_theta, L1, L2, L3), where
                total_loss_for_theta = L1 + L3. L2 is included for reporting; the
                training loop calls compute_L2 separately to optimise u.
        """
        calculated_L1 = self.compute_L1(logits, targets, u_params)
        calculated_L2 = self.compute_L2(logits, targets, u_params)

        l3_coefficient = (1.0 - training_accuracy)
        calculated_L3 = self.compute_L3(logits, targets, u_params, l3_coefficient)

        total_loss_for_theta = calculated_L1 + calculated_L3
        return total_loss_for_theta, calculated_L1, calculated_L2, calculated_L3

In [None]:
# Pipeline banner, configuration echo and device selection.
separator = "=" * 60
print(separator, "Enhanced GNN Training Pipeline", separator, sep="\n")

# Get configuration
args = get_arguments()

config_lines = [f"  {key}: {value}" for key, value in vars(args).items()]
print("\nConfiguration:")
for line in config_lines:
    print(line)

# Use CUDA device `args.device` when available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{args.device}")
else:
    device = torch.device("cpu")
print(f"\nUsing device: {device}")

In [None]:
print("\n" + "=" * 40)
print("LOADING DATA")
print("=" * 40)

base_path = '/kaggle/input/deep-dataset-preprocessed/processed_data_separate'


def _load_split(ds_name, split, start_idx=None):
    """Load ``{base_path}/{ds_name}_{split}_graphs.pt`` and give every graph zero node features.

    When ``start_idx`` is an int, each graph also gets ``original_idx``, numbered
    consecutively from ``start_idx``. GCOD training uses it to index per-sample u values.
    """
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files this way.
    graphs = torch.load(f'{base_path}/{ds_name}_{split}_graphs.pt', weights_only=False)
    loaded = []
    for offset, graph in enumerate(graphs):
        graph = add_zeros(graph)
        if start_idx is not None:
            graph.original_idx = start_idx + offset
        loaded.append(graph)
    return loaded


if args.train_mode == 1:
    # Single dataset mode
    dataset_name = args.dataset
    train_dataset = _load_split(dataset_name, 'train', start_idx=0)
    val_dataset = _load_split(dataset_name, 'val', start_idx=0)
    print(f"Using single dataset: {dataset_name}")
else:
    # All datasets mode. Indices continue across datasets, so each graph has a unique original_idx.
    train_dataset, val_dataset = [], []
    for ds_name in ['A', 'B', 'C', 'D']:
        train_dataset += _load_split(ds_name, 'train', start_idx=len(train_dataset))
        val_dataset += _load_split(ds_name, 'val', start_idx=len(val_dataset))
    print("Using all datasets for training")

# The test set always comes from args.dataset and needs no original_idx.
test_dataset = _load_split(args.dataset, 'test')

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)

Model setup

Scheduler setup

In [None]:
print("\n" + "="*40)
print("MODEL SETUP")
print("="*40)

NUM_CLASSES = 6  # Assuming 6 classes based on original notebook

# Initialize model
model = GNN(num_class=NUM_CLASSES,
            num_layer=args.num_layer,
            emb_dim=args.emb_dim,
            drop_ratio=args.drop_ratio,
            virtual_node=args.virtual_node,
            residual=args.residual,
            JK=args.JK,
            graph_pooling=args.graph_pooling,
            edge_drop_ratio=args.edge_drop_ratio,
            batch_norm=args.batch_norm
            )
model = model.to(device)


class GeneralizedCrossEntropyLoss(nn.Module):
    """Generalized Cross Entropy (Zhang and Sabuncu, 2018): mean over the batch of (1 - p_y^q) / q.

    Here p_y is the softmax probability of the target class. As q approaches 0 the
    loss approaches cross-entropy. At q = 1 it is 1 - p_y, which is MAE up to a
    constant factor.
    """

    def __init__(self, q=0.7):
        super().__init__()
        if not (0 < q <= 1):
            raise ValueError(f"GCE q must be in (0, 1], got {q}")
        self.q = q

    def forward(self, logits, targets):
        probs = F.softmax(logits, dim=1)
        # clamp avoids 0 ** q issues when a probability underflows
        p_true = probs.gather(1, targets.unsqueeze(1)).squeeze(1).clamp_min(1e-7)
        return ((1.0 - p_true.pow(self.q)) / self.q).mean()


# Loss selection (args.baseline_mode): 2=Noisy CE, 3=GCE, 4=GCOD, anything else=CE
if args.baseline_mode == 2:
    criterion = NoisyCrossEntropyLoss(args.noise_prob)
    print(f"Using Noisy Cross Entropy Loss (p={args.noise_prob})")
elif args.baseline_mode == 3:
    criterion = GeneralizedCrossEntropyLoss(q=args.gce_q)
    print(f"Using Generalized Cross Entropy (GCE) Loss (q={args.gce_q})")
elif args.baseline_mode == 4:
    criterion = GCODLoss(
        num_classes=NUM_CLASSES,
        alpha_train=args.gcod_lambda_p  # gcod_lambda_p is the L1 feedback coefficient
    )
    print(f"Using GCOD Loss. Effective alpha_train (lambda_p for L1)={args.gcod_lambda_p}, "
          f"T_u={args.gcod_T_u}, lr_u={args.gcod_lr_u}")
else:
    criterion = torch.nn.CrossEntropyLoss()
    print("Using standard Cross Entropy Loss")

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Setup logging and checkpoints
exp_name = f"gin_dataset{args.dataset}_mode{args.train_mode}"
logs_dir = os.path.join("logs", exp_name)
checkpoints_dir = os.path.join("checkpoints", exp_name)
os.makedirs(logs_dir, exist_ok=True)
os.makedirs(checkpoints_dir, exist_ok=True)

# force=True replaces handlers from a previous run of this cell.
# Without it, basicConfig is a no-op and logs keep going to the old file.
log_file = os.path.join(logs_dir, "training.log")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ],
    force=True,
)

# This is a fixed location, not under checkpoints_dir, so runs on different datasets overwrite it.
best_model_path = '/kaggle/working/checkpoints/best_model.pth'
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

In [None]:
# Learning Rate Scheduler Setup
print("\n" + "="*40)
print("SCHEDULER SETUP")
print("="*40)

# Update optimizer with initial learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=args.initial_lr)

scheduler = None
if args.use_scheduler:
    if args.scheduler_type == 'StepLR':
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, 
            step_size=args.step_size, 
            gamma=args.gamma
        )
        print(f"Using StepLR scheduler: step_size={args.step_size}, gamma={args.gamma}")
        
    elif args.scheduler_type == 'ReduceLROnPlateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 
            mode='max',  # We want to reduce LR when validation loss stops decreasing
            factor=args.factor,
            patience=args.patience_lr,
            min_lr=args.min_lr,
        )
        print(f"Using ReduceLROnPlateau scheduler: factor={args.factor}, patience={args.patience_lr}, min_lr={args.min_lr}")
        
    else:
        print(f"Unknown scheduler type: {args.scheduler_type}. No scheduler will be used.")
        args.use_scheduler = False
else:
    print("No learning rate scheduler will be used.")

print(f"Initial learning rate: {args.initial_lr}")

Train loop

In [None]:
print("\n" + "="*40)
print("TRAINING")
print("="*40)

# GCOD keeps one u value per training graph, indexed by data.original_idx (set when loading data).
u_values_for_train = None
if args.baseline_mode == 4:
    if 'train_dataset' in globals() and len(train_dataset) > 0:
        u_values_for_train = torch.zeros(len(train_dataset), device=device, requires_grad=False)
        print(f"Initialized u_values_for_train for GCOD with size: {u_values_for_train.size()}")
    else:
        print("Warning: train_dataset not found or empty when trying to initialize u_values_for_train for GCOD.")

best_val_accuracy = 0.0
train_losses_list = []
train_accuracies_list = []
val_losses_list = []
val_accuracies_list = []
learning_rates = []

# Early stopping variables
if args.early_stopping:
    epochs_without_improvement = 0
    print(f"Early stopping enabled with patience: {args.patience}")
else:
    print("Early stopping disabled")

# Epochs (1-based) at which train() saves a checkpoint, spread evenly over the run.
if args.num_checkpoints > 1:
    checkpoint_intervals = [int((i + 1) * args.epochs / args.num_checkpoints)
                            for i in range(args.num_checkpoints)]
else:
    checkpoint_intervals = [args.epochs]

for epoch in range(args.epochs):
    print(f"\nEpoch {epoch + 1}/{args.epochs}")
    print("-" * 30)

    # Learning rate at the start of this epoch
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)

    # Training
    train_loss, train_acc = train(
        train_loader, model, optimizer, criterion, device,
        save_checkpoints=(epoch + 1 in checkpoint_intervals),
        checkpoint_path=os.path.join(checkpoints_dir, "checkpoint"),
        current_epoch=epoch,
        args_namespace=args,
        u_values_global=u_values_for_train,
        current_baseline_mode=args.baseline_mode,
        # Pass the scheduler object, not its name. train() steps it per batch only for OneCycleLR.
        scheduler=scheduler,
    )

    # Validation (for GCOD, evaluate() scores with L1 at u = 0)
    val_loss, val_acc = evaluate(
        val_loader, model, criterion, device, calculate_accuracy=True,
        args_namespace=args,
        u_values_global_eval=None,
    )

    # Log results
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    print(f"Learning Rate: {current_lr:.2e}")

    logging.info(f"Epoch {epoch + 1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, "
                 f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}, LR={current_lr:.2e}")

    train_losses_list.append(train_loss)
    train_accuracies_list.append(train_acc)
    val_losses_list.append(val_loss)
    val_accuracies_list.append(val_acc)

    # Save best model (selected on validation accuracy)
    if val_acc > best_val_accuracy:
        best_val_accuracy = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"★ New best model saved! Val Acc: {val_acc:.4f}")

        if args.early_stopping:
            epochs_without_improvement = 0
    elif args.early_stopping:
        epochs_without_improvement += 1
        print(f"No improvement for {epochs_without_improvement} epoch(s)")

        if epochs_without_improvement >= args.patience:
            print(f"\nEarly stopping triggered! No improvement for {args.patience} epochs.")
            print(f"Best validation accuracy: {best_val_accuracy:.4f}")
            break

    # Learning rate scheduler step
    if scheduler is not None:
        if args.scheduler_type == 'ReduceLROnPlateau':
            # Monitors validation accuracy (the scheduler is created with mode='max')
            scheduler.step(val_acc)
        elif args.scheduler_type == 'OneCycleLR':
            # OneCycleLR steps every batch inside train()
            pass
        else:
            # Other schedulers step once per epoch
            scheduler.step()

        new_lr = optimizer.param_groups[0]['lr']
        if new_lr != current_lr:
            print(f"Learning rate changed: {current_lr:.2e} → {new_lr:.2e}")

print(f"\nBest validation accuracy: {best_val_accuracy:.4f}")

Plot

In [None]:
# Per-epoch curves collected by the training loop; the PNG is saved under logs_dir.
plot_training_progress(
    train_losses=train_losses_list,
    train_accuracies=train_accuracies_list,
    val_losses=val_losses_list,
    val_accuracies=val_accuracies_list,
    output_dir=logs_dir,
)

Test

In [None]:
print("\n" + "="*40)
print("TESTING")
print("="*40)

# Load the best checkpoint written by the training loop.
# - It only exists if validation accuracy ever rose above 0; otherwise keep the in-memory model.
# - map_location makes the load independent of the device the checkpoint was saved from.
if os.path.exists(best_model_path):
    model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
    print(f"Loaded best model from: {best_model_path}")
else:
    print(f"WARNING: {best_model_path} not found; predicting with the final in-memory model.")

predictions = evaluate(
    test_loader, model, criterion, device, calculate_accuracy=False,
    args_namespace=args,
    u_values_global_eval=None,
)

# save_predictions names the CSV after the directory containing `test_path` (testset_<dir>.csv).
# A path whose parent directory is the dataset letter therefore gives testset_<dataset>.csv.
save_predictions(predictions, os.path.join(args.dataset, "test.json.gz"))

# Cleanup for memory
del train_dataset, val_dataset, test_dataset
del train_loader, val_loader, test_loader
gc.collect()

print("\n" + "="*60)
print("TRAINING COMPLETED SUCCESSFULLY!")
print("="*60)
print(f"Best validation accuracy: {best_val_accuracy:.4f}")
print(f"Predictions saved for dataset {args.dataset}")
print(f"Logs and plots saved in: {logs_dir}")