# Install requirements

In [None]:
# Install torch 2.0.1 built for CUDA 11.8; the PyG extension wheels on the next
# line are fetched from the index matching exactly that torch+CUDA build, so the
# two pins must be changed together.
!pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.0.1+cu118.html
!pip install torch_geometric

# Clone Git Repository

In [None]:
# Fetch the project code and make it the working directory; the repo presumably
# provides the `source` package imported further down (confirm).
!git clone https://github.com/matzamp/DL_Hackaton.git
%cd DL_Hackaton/

# Download the dataset from Google Drive

In [None]:
# Download the shared Google Drive folder into ./datasets (relative to the
# repository directory entered by the previous cell).
!gdown --folder https://drive.google.com/drive/folders/1Z-1JkPJ6q4C6jX4brvq1VRbJH5RPUCAk -O datasets

# Imports

In [None]:
import os
import torch
import pandas as pd
import matplotlib.pyplot as plt
import logging
from tqdm import tqdm
from torch_geometric.loader import DataLoader
from torch.utils.data import random_split
import torch.nn as nn
# Load utility functions from cloned repository
from source.loadData import GraphDataset
from source.utils import *
from source.models import EdgeCentricGNNImproved, GINEClassifier
from source.noisy_loss import NoisyCrossEntropyLoss
import torch_scatter
import argparse
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import f1_score

# Set the random seed for reproducibility. `set_seed` is presumably provided by
# the `from source.utils import *` star import; which RNGs it seeds and the
# default seed value are defined there, not here.
set_seed()

# Hyperparameters Setup

In [None]:
# Hyperparameters.
# NOTE: the compute device is deliberately not configured here; the setup cell
# further down selects it automatically (CUDA if available, else CPU).
dropout = 0.3
batch_size = 32
epochs = 150
lr = 0.001
node_dim = 1
hidden_dim = 64
edge_dim = 7
num_classes = 6
num_gnn_blocks = 1
num_gine_layers_per_block = 1
use_dense_skip = True
use_global_transformer = False
num_transformer_layers = 1
num_transformer_heads = 4

# Get training/test paths

In [None]:
# Build the run arguments via the repo helpers (presumably from source.utils).
# NOTE(review): later cells read args.train_path (optional) and args.test_path;
# populate_args is assumed to fill these in (interactively or with defaults) -
# confirm in source/utils.py.
args = get_arguments()
populate_args(args)

In [None]:
# Detect (via the repo helper is_gzipped_folder) whether each data folder holds
# gzipped files. train_path is optional (test-only runs): use the same
# truthiness check as the train-data cell, and default the flag to False so the
# name is always defined.
is_zipped_train = is_gzipped_folder(args.train_path) if args.train_path else False
is_zipped_test = is_gzipped_folder(args.test_path)

# Load train data

In [None]:
# Build the full labelled dataset only when a training path was supplied;
# val_size/train_size describe an 80/20 split that is materialised later.
if args.train_path:
    full_dataset = GraphDataset(
        args.train_path,
        transform=add_ones,
        is_zipped=is_zipped_train,
    )
    val_size = int(0.2 * len(full_dataset))
    train_size = len(full_dataset) - val_size

# Setup folders for logs and checkpoints

In [None]:
# Prefer the default CUDA device when one is visible, otherwise run on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Logs and checkpoints are created relative to the notebook's working directory.
script_dir = os.getcwd()

In [None]:
# Work out the dataset id (e.g. "A".."D"). The train path wins when present
# (training runs); otherwise fall back to the test path (evaluation-only runs).
path_for_id_extraction = args.train_path or args.test_path
if path_for_id_extraction:
    dataset_id = get_dataset_id_from_path(path_for_id_extraction)
else:
    print("CRITICAL ERROR: Neither train_path nor test_path is defined. Cannot determine dataset_id.")
    dataset_id = "error_no_path"  # sentinel; should not normally be reached

print(f"Determined Dataset ID: {dataset_id} (derived from path: {path_for_id_extraction})")

# Automatic extraction failed: ask the user to supply the id instead.
if dataset_id == "error_no_path" or dataset_id.startswith("unknown"):
    print(f"CRITICAL WARNING: Dataset ID extraction failed or yielded an issue: '{dataset_id}'. Insert manually the dataset ID.")
    args_dataset_id = get_dataset_type()
    populate_args(args_dataset_id)
    dataset_id = args_dataset_id.dataset_id

In [None]:
logs_folder = os.path.join(script_dir, "logs", dataset_id)
log_file = os.path.join(logs_folder, "training.log")
# log_file lives directly inside logs_folder, so creating the folder suffices.
os.makedirs(logs_folder, exist_ok=True)

# Configure the root logger to append to training.log. force=True removes and
# closes any handlers already attached to the root logger first, so re-running
# this cell does not duplicate output.
logging.basicConfig(filename=log_file, level=logging.INFO, format='%(asctime)s - %(message)s', filemode='a', force=True)

# Also echo records to the console. This handler gets no formatter, so the
# console shows only the message text (no timestamp).
logger = logging.getLogger()
logger.addHandler(logging.StreamHandler())

checkpoints_folder = os.path.join(script_dir, "checkpoints")
os.makedirs(checkpoints_folder, exist_ok=True)

# Choose the noise rate and the model based on the dataset

The noise rate `p_noisy` passed to the robust loss is 0.2 for datasets A and C,
and 0.4 otherwise (B and D).

The lighter model (`GINEClassifier`) is used for A and B;
the more complex model (`EdgeCentricGNNImproved`) is used for C and D.

In [None]:
# Noise parameter for the robust loss: 0.2 for datasets A and C, 0.4 otherwise.
p_noisy = 0.2 if dataset_id in ("A", "C") else 0.4
# Light GINE model for A and B; the edge-centric model for everything else.
use_lightModel = dataset_id in ("A", "B")
criterion_robust = NoisyCrossEntropyLoss(p_noisy)

# Train and Evaluation functions

In [None]:
def train(data_loader, model, optimizer,
          criterion, device, checkpoint_path_base, epoch,
          save_intermediate_checkpoints_at_epoch=None):
    """Run one training epoch over ``data_loader``.

    Args:
        data_loader: batches exposing ``.to(device)`` and labels in ``.y``.
        model: network updated in place; its output is argmax-ed along dim 1
            to compute accuracy and macro-F1.
        optimizer: stepped once per batch.
        criterion: called as ``criterion(out, y)``; must return a scalar since
            ``backward()`` is called on it (assumed to be the batch mean).
        device: target device for each batch.
        checkpoint_path_base: path prefix for the optional intermediate
            checkpoint.
        epoch: 0-based epoch index (shown 1-based in the progress bar).
        save_intermediate_checkpoints_at_epoch: 1-based epoch number; when it
            equals ``epoch + 1`` the weights are saved to
            ``f"{checkpoint_path_base}_epoch{epoch+1}.pth"`` after the epoch.

    Returns:
        ``(avg_loss, accuracy, macro_f1)``. avg_loss is a per-sample mean
        (each batch loss weighted by its batch size), the same convention as
        ``evaluate``; all three are 0.0 for an empty loader.
    """
    model.train()

    correct, total_loss, total_samples = 0, 0.0, 0
    all_preds, all_labels = [], []

    for data in tqdm(data_loader, desc=f"Epoch {epoch+1} Training", unit="batch"):
        data = data.to(device)

        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        # Clip the global gradient norm to 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        n_in_batch = data.y.size(0)
        # Weight the (mean) batch loss by batch size so a smaller last batch
        # does not skew the epoch average.
        total_loss += loss.item() * n_in_batch

        preds = out.argmax(dim=1)
        correct += (preds == data.y).sum().item()
        total_samples += n_in_batch

        # Keep labels/predictions for the epoch-level macro-F1.
        all_labels.extend(data.y.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())

    if total_samples > 0:
        avg_loss = total_loss / total_samples
        acc = correct / total_samples
        f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    else:
        avg_loss, acc, f1 = 0.0, 0.0, 0.0

    if save_intermediate_checkpoints_at_epoch and epoch + 1 == save_intermediate_checkpoints_at_epoch:
        torch.save(model.state_dict(), f"{checkpoint_path_base}_epoch{epoch+1}.pth")
        print(f"Saved intermediate checkpoints for epoch {epoch+1}")

    return avg_loss, acc, f1

In [None]:
def evaluate(data_loader, model, device, criterion_eval, calculate_metrics=False):
    """Run ``model`` over ``data_loader`` without gradients.

    Args:
        data_loader: batches exposing ``.to(device)`` and, when metrics are
            requested, labels in ``.y``.
        model: network whose output is argmax-ed along dim 1 for predictions.
        device: target device for each batch.
        criterion_eval: called as ``criterion_eval(out, y)``; may return a
            scalar (treated as the batch mean) or per-sample losses. Only used
            when ``calculate_metrics`` is True.
        calculate_metrics: when False, only predictions are collected.

    Returns:
        With ``calculate_metrics`` True: ``(avg_loss, accuracy, macro_f1)``
        averaged per sample, or ``(0.0, 0.0, 0.0)`` if no samples were seen.
        Otherwise: the flat list of predicted classes.
    """
    model.eval()

    predictions, targets = [], []
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating", unit="batch"):
            batch = batch.to(device)
            logits = model(batch)
            batch_pred = logits.argmax(dim=1)

            if not calculate_metrics:
                predictions.extend(batch_pred.cpu().numpy())
                continue

            labels = batch.y
            n_in_batch = labels.size(0)
            targets.extend(labels.cpu().numpy())
            predictions.extend(batch_pred.cpu().numpy())
            n_correct += (batch_pred == labels).sum().item()
            n_seen += n_in_batch

            losses = criterion_eval(logits, labels)
            # A scalar loss is taken to be a batch mean, so scale it back to a
            # batch sum; a vector of per-sample losses is summed directly.
            loss_sum += losses.item() * n_in_batch if losses.ndim == 0 else losses.sum().item()

    if not calculate_metrics:
        return predictions

    if n_seen == 0:
        return 0.0, 0.0, 0.0
    return (
        loss_sum / n_seen,
        n_correct / n_seen,
        f1_score(targets, predictions, average='macro', zero_division=0),
    )

In [None]:
 # Initialize weights properly
def init_weights(m):
    """Xavier-uniform the weight of an ``nn.Linear`` and zero its bias.

    Intended for ``model.apply(init_weights)``; modules of any other type are
    left untouched, as is a Linear layer's missing bias (``bias=False``).
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# Instantiate the model and run the main training loop (only if a train path is provided)

In [None]:
# Build the classifier for this dataset; each branch maps the shared
# hyperparameters onto that model's own keyword names.
if not use_lightModel:
    model = EdgeCentricGNNImproved(
        node_dim=node_dim,
        edge_dim=edge_dim,
        hidden_dim=hidden_dim,
        output_dim=num_classes,
        dropout=dropout,
    )
else:
    model = GINEClassifier(
        node_feat_dim=node_dim,
        edge_feat_dim=edge_dim,
        hidden_dim=hidden_dim,
        num_classes=num_classes,
        num_gnn_blocks=num_gnn_blocks,
        num_gine_layers_per_block=num_gine_layers_per_block,
        use_dense_skip=use_dense_skip,
        use_global_transformer=use_global_transformer,
        num_transformer_layers=num_transformer_layers,
        num_transformer_heads=num_transformer_heads,
        gnn_dropout=dropout,
        transformer_dropout=dropout,
        classifier_dropout=dropout,
    )
model = model.to(device)

if args.train_path:
    import datetime  # timestamps for the per-interval summary files below

    # Fixed-seed generator so the train/validation split is identical across runs.
    generator = torch.Generator().manual_seed(12)
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=generator)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=1e-4,
        eps=1e-8,
    )
    # Stepped with the validation macro-F1 every epoch, hence mode='max'.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=10, verbose=True, min_lr=1e-6)

    # Re-initialise Linear layers in place; the optimizer already references
    # these same parameter tensors, so creating it first is fine.
    model.apply(init_weights)

    num_epochs_to_run = epochs  # use the hyperparameter
    best_f1_accuracy = 0.0  # best validation macro-F1 seen so far

    train_losses_hist, train_accs_hist, train_f1s_hist = [], [], []
    val_losses_hist, val_accs_hist, val_f1s_hist = [], [], []

    intermediate_checkpoint_base = os.path.join(checkpoints_folder, f"model_{dataset_id}")

    # 1-based epoch numbers at which train() saves an extra snapshot (none by default).
    save_intermediate_epochs = []

    logging.info(f"Will save intermediate checkpoints at epochs: {save_intermediate_epochs}")

    for epoch_idx in range(num_epochs_to_run):
        current_epoch_num = epoch_idx + 1  # 1-based, as shown in the logs
        save_at_this_epoch = current_epoch_num if current_epoch_num in save_intermediate_epochs else None

        train_loss, train_acc, train_f1 = train(
            train_loader, model, optimizer,
            criterion_robust, device,
            intermediate_checkpoint_base,
            epoch_idx,
            save_intermediate_checkpoints_at_epoch=save_at_this_epoch
        )

        val_loss, val_acc, val_f1 = evaluate(val_loader, model, device, criterion_robust, calculate_metrics=True)

        scheduler.step(val_f1)

        train_losses_hist.append(train_loss)
        train_accs_hist.append(train_acc)
        train_f1s_hist.append(train_f1)

        val_losses_hist.append(val_loss)
        val_accs_hist.append(val_acc)
        val_f1s_hist.append(val_f1)

        # --- Main logging (every epoch): root logger -> training.log and console ---
        log_msg = (f"Epoch {current_epoch_num}/{num_epochs_to_run} | "
                   f"Train Loss: {train_loss:.4f} | "
                   f"Train Acc: {train_acc:.4f} | "
                   f"Train F1: {train_f1:.4f} | "
                   f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f}")
        logging.info(log_msg)

        # --- 10-epoch summary: one file per interval, plus a final partial interval ---
        if current_epoch_num % 10 == 0 or current_epoch_num == num_epochs_to_run:
            # First epoch of the 10-epoch interval containing this epoch; one
            # formula covers full and partial intervals
            # (10 -> 1-10, 20 -> 11-20, 153 -> 151-153).
            range_start = ((current_epoch_num - 1) // 10) * 10 + 1
            range_end = current_epoch_num

            interval_log_filename = f"epoch_{range_start:03d}-{range_end:03d}_summary.log"
            interval_log_filepath = os.path.join(logs_folder, interval_log_filename)

            summary_log_msg = (f"SUMMARY for Epoch {current_epoch_num} (Interval {range_start}-{range_end}) | "
                               f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train F1: {train_f1:.4f} | "
                               f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f}")

            timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            try:
                with open(interval_log_filepath, 'w') as f_interval_summary:  # overwrite/create
                    f_interval_summary.write(f"{timestamp} - {summary_log_msg}\n")
                logging.info(f"Saved 10-epoch interval summary to: {interval_log_filepath}")
            except OSError as e:
                logging.error(f"Failed to write to interval summary log {interval_log_filepath}: {e}")

        if val_f1 > best_f1_accuracy:
            best_f1_accuracy = val_f1
            # NOTE: the file name uses the 0-based epoch_idx, i.e. one less
            # than the epoch number printed in the logs.
            best_model_path = os.path.join(checkpoints_folder, f"model_{dataset_id}_epoch_{epoch_idx}.pth")
            torch.save(model.state_dict(), best_model_path)
            best_save_msg = f"Best models (val_f1: {best_f1_accuracy:.4f}) saved to {best_model_path}"
            logging.info(best_save_msg)

    plot_output_dir = os.path.join(logs_folder, "plots")  # dataset-specific, under logs_folder
    os.makedirs(plot_output_dir, exist_ok=True)

    plot_training_progress(train_losses_hist, train_accs_hist, train_f1s_hist,
                           val_losses_hist, val_accs_hist, val_f1s_hist,
                           plot_output_dir,
                           plot_title_prefix=f"Dataset {dataset_id} Combined Training")

else:
    logging.info("No training path provided. Skipping training.")

# Inference function

In [None]:
def inference(data_loader, model, device):
    """Predict class indices for every batch in ``data_loader``.

    The model runs in eval mode without gradients. Each batch's output is
    argmax-ed along dim 1.

    Returns:
        A flat list of predicted class indices in loader order (presumably one
        per graph).
    """
    model.eval()
    predictions = []
    with torch.no_grad():
        for data in tqdm(data_loader, desc="Iterating eval graphs", unit="batch"):
            data = data.to(device)
            out = model(data)
            predictions.extend(out.argmax(dim=1).cpu().numpy())
    return predictions

# Free training data (garbage collection)

In [None]:
import gc

# Drop training-only objects before loading the test set. On a test-only run
# these names were never created, so each one is removed independently and
# missing names are simply skipped (a bare try/del would stop at the first
# missing name and leave the rest alive).
for _name in ("train_dataset", "train_loader", "full_dataset", "val_dataset", "val_loader"):
    globals().pop(_name, None)
gc.collect()

# Load the test data and the best model, then predict on the test data

In [None]:
# Test graphs use the same add_ones transform as training and keep their file
# order (no shuffling).
test_dataset = GraphDataset(
    args.test_path,
    transform=add_ones,
    is_zipped=is_zipped_test,
)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

In [None]:
# Checkpoint used for the final predictions, one per known dataset.
# NOTE: for datasets A-D this overrides any best_model_path produced by the
# training cell earlier in this session.
pretrained_checkpoints = {
    "A": "./checkpoints/model_A_epoch_132.pth",
    "B": "./checkpoints/model_B_epoch_77.pth",
    "C": "./checkpoints/model_C_epoch_57.pth",
    "D": "./checkpoints/model_D_epoch_73.pth",
}
if dataset_id in pretrained_checkpoints:
    best_model_path = pretrained_checkpoints[dataset_id]
elif "best_model_path" in globals():
    # Unknown dataset id: fall back to the best checkpoint saved by training
    # in this session.
    print(f"Warning: no pretrained checkpoint for dataset_id={dataset_id!r}; using {best_model_path}")
else:
    # Fail here with a clear message instead of a NameError when loading.
    raise ValueError(
        f"Error: dataset_id {dataset_id!r} has no pretrained checkpoint "
        "and no model was trained in this session"
    )

In [None]:
# map_location lets a checkpoint saved on GPU load on a CPU-only machine (and
# places the tensors on the device the model already lives on).
model.load_state_dict(torch.load(best_model_path, map_location=device))
predictions = inference(test_loader, model, device)
save_predictions(predictions, args.test_path, dataset_id)