# Imports and Settings

In [6]:
import sys
import os
import argparse
import torch
import torch.nn.functional as F
import pandas as pd
import matplotlib.pyplot as plt
import logging
from tqdm import tqdm

from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv as PyG_GINConv, global_add_pool

# Set up the path to the src directory so the pre-made teacher modules can be imported.
src_path = os.path.join(os.path.abspath('.'), 'hackaton', 'src')
sys.path.insert(0, src_path)

try:
    from loadData import GraphDataset
    from utils import set_seed
    # from conv import GINConv as OriginalRepoGINConv
    # from models import GNN as OriginalRepoGNN
    print("Successfully imported GraphDataset and set_seed.")
except ImportError as e:
    print(f"ERROR importing module: {e}")
    print("Please check that the .py files exist and have no syntax errors.")
    # print("Current sys.path:", sys.path)
    # Re-raise: every later cell needs GraphDataset / set_seed, so continuing would
    # only fail a few lines further down with a less informative NameError.
    raise


# Call set_seed early, before any model or data objects are created
# (set_seed is the project's helper from utils).
set_seed(42)  # Change the seed value here for a different (still reproducible) run.

Successfully imported GraphDataset and set_seed.


# Useful functions

In [2]:
def add_zeros(data):
    """Replace the node features of ``data`` with a single shared integer feature.

    ``data.x`` becomes a ``torch.long`` tensor of zeros with one entry per node
    (``data.num_nodes`` entries), i.e. an all-zero index vector suitable for a
    one-row embedding table. The same object is returned so the function can be
    used directly as a dataset ``transform``.
    """
    node_count = data.num_nodes
    data.x = torch.zeros(node_count, dtype=torch.long)
    return data

In [None]:
class ImprovedGIN(torch.nn.Module):
    """GIN-based graph classifier with residual connections between GIN layers.

    Pipeline: shared node embedding -> ``num_gnn_layers`` x (GINConv -> BatchNorm
    -> residual add of the previous layer's output (from the second layer on)
    -> ReLU -> dropout) -> sum pooling per graph -> dropout
    -> Linear(hidden_dim, hidden_dim // 2) -> ReLU -> dropout
    -> Linear(hidden_dim // 2, output_dim).

    Args:
        input_dim: Size of the learned node embedding fed to the first GIN layer.
        hidden_dim: Width of every GIN layer and input width of the classifier head.
        output_dim: Number of logits produced per graph.
        num_gnn_layers: Number of GIN layers applied in ``forward``. At least one
            GIN layer is always constructed, even when this is < 1.
        dropout_rate: Dropout probability after every GIN layer and in the head.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_gnn_layers=3, dropout_rate=0.5):
        super().__init__()
        # A single embedding row: node features are expected to be all-zero indices
        # (as produced by add_zeros), so every node starts from the same learned vector.
        self.embedding = torch.nn.Embedding(num_embeddings=1, embedding_dim=input_dim)
        self.dropout_rate = dropout_rate
        self.num_gnn_layers = num_gnn_layers

        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        # Layer 0 maps input_dim -> hidden_dim; every later layer keeps hidden_dim.
        # Creation order (per layer: MLP, GINConv, BatchNorm; then the head) fixes the
        # order in which the seeded RNG initialises parameters.
        for layer_idx in range(max(1, self.num_gnn_layers)):
            in_features = input_dim if layer_idx == 0 else hidden_dim
            gin_mlp = torch.nn.Sequential(
                torch.nn.Linear(in_features, hidden_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_dim, hidden_dim),
            )
            # train_eps=True: the (1 + eps) weight on each node's own features is learned.
            self.convs.append(PyG_GINConv(gin_mlp, train_eps=True))
            self.batch_norms.append(torch.nn.BatchNorm1d(hidden_dim))

        self.global_pool = global_add_pool  # sum of node representations per graph

        # Two-layer classification head.
        self.fc1 = torch.nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc2 = torch.nn.Linear(hidden_dim // 2, output_dim)

    def forward(self, data):
        """Return unnormalised class scores, one row of ``output_dim`` logits per graph."""
        edge_index, batch = data.edge_index, data.batch
        h = self.embedding(data.x)

        skip = None  # previous layer's post-dropout output, added as a residual
        for layer_idx in range(self.num_gnn_layers):
            h = self.convs[layer_idx](h, edge_index)
            h = self.batch_norms[layer_idx](h)
            if skip is not None:
                h = h + skip
            h = F.dropout(F.relu(h), p=self.dropout_rate, training=self.training)
            skip = h

        pooled = F.dropout(self.global_pool(h, batch), p=self.dropout_rate, training=self.training)
        hidden = F.dropout(F.relu(self.fc1(pooled)), p=self.dropout_rate, training=self.training)
        return self.fc2(hidden)


print("ImprovedGIN model class defined.")

In [None]:
# ImprovedGIN is defined once, in the previous cell. A second, byte-for-byte
# identical copy that re-bound the same name used to live here; it was removed
# so that edits to the model cannot silently diverge between two definitions
# (the later copy would always win). Edit the class in the previous cell.

# Core functions

In [None]:
def train(data_loader, model, optimizer, criterion, device,
          save_checkpoints_periodically, periodic_checkpoint_dir_base,
          dataset_name_for_checkpoint, current_epoch):
    """Run one training epoch and optionally write a checkpoint afterwards.

    Batches whose ``y`` is ``None`` or empty are skipped. After the epoch, when
    ``save_checkpoints_periodically`` is true, the model's ``state_dict`` is saved to
    ``<periodic_checkpoint_dir_base>/<dataset>/model_<dataset>_epoch_<current_epoch + 1>.pth``.

    Args:
        current_epoch: Zero-based epoch index (shown 1-based in progress/filenames).

    Returns:
        ``(avg_loss, accuracy)`` over the labelled samples seen; both are 0 when
        no batch was processed. ``avg_loss`` weights each batch's loss by
        ``num_graphs``, assuming the criterion returns a per-batch mean.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    progress = tqdm(data_loader, desc=f"Epoch {current_epoch+1} Training", unit="batch", leave=False)
    for batch in progress:
        batch = batch.to(device)
        # Unlabelled or empty batches contribute nothing to the objective.
        if batch.y is None or batch.y.numel() == 0:
            continue

        optimizer.zero_grad()
        logits = model(batch)
        loss = criterion(logits, batch.y)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item() * batch.num_graphs
        n_correct += (logits.argmax(dim=1) == batch.y).sum().item()
        n_seen += batch.y.size(0)

    avg_loss = loss_sum / n_seen if n_seen > 0 else 0
    accuracy = n_correct / n_seen if n_seen > 0 else 0

    if save_checkpoints_periodically:
        # e.g. checkpoints/A/model_A_epoch_10.pth
        ckpt_dir = os.path.join(periodic_checkpoint_dir_base, dataset_name_for_checkpoint)
        os.makedirs(ckpt_dir, exist_ok=True)
        checkpoint_file = os.path.join(
            ckpt_dir, f"model_{dataset_name_for_checkpoint}_epoch_{current_epoch + 1}.pth"
        )
        torch.save(model.state_dict(), checkpoint_file)
        logging.info(f"Periodic checkpoint saved: {checkpoint_file}")

    return avg_loss, accuracy

print("train function defined.")

In [None]:
def evaluate(data_loader, model, device, criterion_for_loss=None, calculate_accuracy_and_loss=False):
    """Run ``model`` over ``data_loader`` without gradients.

    Two modes:

    * ``calculate_accuracy_and_loss=True`` (validation): returns
      ``(avg_loss, accuracy)``. Batches with ``y`` of ``None`` or empty are skipped.
      Samples labelled ``-1`` are treated as unlabelled and excluded from the
      accuracy, the loss and the denominator alike. ``avg_loss`` is 0.0 when no
      ``criterion_for_loss`` is given; both metrics are 0.0 when no labelled
      sample was seen.
    * otherwise (test): returns a list with the predicted class of every graph,
      in loader order.
    """
    model.eval()
    correct_preds = 0
    total_samples = 0
    total_loss = 0
    predictions_list = []

    with torch.no_grad():
        for data in tqdm(data_loader, desc="Evaluating", unit="batch", leave=False):
            data = data.to(device)
            labels = data.y
            # Metrics need labels; skip batches that have none.
            if calculate_accuracy_and_loss and (labels is None or labels.numel() == 0):
                continue

            output = model(data)
            pred = output.argmax(dim=1)

            if not calculate_accuracy_and_loss:
                # Test mode: only collect predictions.
                predictions_list.extend(pred.cpu().numpy())
                continue

            # Mask out unlabelled samples (-1) so they neither reach the criterion
            # (CrossEntropyLoss rejects -1 targets) nor count as wrong predictions.
            labelled = labels != -1
            num_labelled = int(labelled.sum().item())
            if num_labelled == 0:
                continue

            correct_preds += (pred[labelled] == labels[labelled]).sum().item()
            total_samples += num_labelled
            if criterion_for_loss is not None:
                # Criterion is assumed to return a batch mean; re-weight by batch size.
                batch_loss = criterion_for_loss(output[labelled], labels[labelled])
                total_loss += batch_loss.item() * num_labelled

    if calculate_accuracy_and_loss:
        accuracy = correct_preds / total_samples if total_samples > 0 else 0.0
        avg_loss = (
            total_loss / total_samples
            if total_samples > 0 and criterion_for_loss is not None
            else 0.0
        )
        return avg_loss, accuracy
    return predictions_list

print("evaluate function defined.")

In [None]:
def save_predictions(predictions_to_save, test_path_arg, root_submission_folder="submission", id_col_name="id", pred_col_name="pred"):
    """Write test-set predictions to ``<root_submission_folder>/testset_<name>.csv``.

    ``<name>`` is the directory containing ``test_path_arg`` (e.g. ``"A"`` for
    ``./datasets/A/test.json.gz``). It is ``"direct_file_dataset"`` when the path has
    no parent directory, and ``"datasets"`` when ``test_path_arg`` is empty/None or
    cannot be parsed. A relative ``root_submission_folder`` is resolved against the
    current working directory; an absolute one is used as-is.

    The CSV has columns ``id_col_name`` (0..N-1) and ``pred_col_name``, so predictions
    are assumed to be in the same order as the test graphs.
    """
    submission_folder_main_path = os.path.join(os.getcwd(), root_submission_folder)
    os.makedirs(submission_folder_main_path, exist_ok=True)

    dataset_folder_name = "datasets"
    if test_path_arg:
        try:
            # Normalise first so "." / ".." components and doubled separators cannot
            # hide the real parent directory (e.g. "A/./test.json.gz" -> "A").
            parent_dir = os.path.dirname(os.path.normpath(test_path_arg))
            dataset_folder_name = os.path.basename(parent_dir) or "direct_file_dataset"
        except (TypeError, ValueError) as e:
            print(f"Error parsing dataset name from test_path: {e}, using default.")

    output_csv_filename = f"testset_{dataset_folder_name}.csv"
    output_csv_full_path = os.path.join(submission_folder_main_path, output_csv_filename)

    output_df = pd.DataFrame({
        id_col_name: list(range(len(predictions_to_save))),
        pred_col_name: predictions_to_save,
    })

    output_df.to_csv(output_csv_full_path, index=False)
    logging.info(f"Predictions saved to {output_csv_full_path}")
    print(f"Predictions saved to {output_csv_full_path}")

print("save_predictions function defined.")

In [None]:
def plot_training_progress(losses_data, accuracies_data, plot_title_prefix, output_dir_for_plots, dataset_name_for_plot):
    """Plot per-epoch loss and accuracy side by side, save the figure, and show it.

    The PNG is written to
    ``<output_dir_for_plots>/<prefix lower-cased, spaces -> '_'>_<dataset>_progress.png``.
    Epochs on the x-axis are numbered 1..len(losses_data) for both panels.
    """
    epoch_axis = range(1, len(losses_data) + 1)
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # (axis, series, line format, metric name) for each panel, left to right.
    panels = (
        (loss_ax, losses_data, 'b-o', 'Loss'),
        (acc_ax, accuracies_data, 'g-o', 'Accuracy'),
    )
    for ax, series, line_fmt, metric in panels:
        ax.plot(epoch_axis, series, line_fmt, label=f"{plot_title_prefix} {metric}")
        ax.set_title(f'{plot_title_prefix} {metric} ({dataset_name_for_plot})')
        ax.set_xlabel('Epochs')
        ax.set_ylabel(metric)
        ax.legend()
        ax.grid(True)

    os.makedirs(output_dir_for_plots, exist_ok=True)
    plot_filename = f"{plot_title_prefix.lower().replace(' ', '_')}_{dataset_name_for_plot}_progress.png"
    plot_path = os.path.join(output_dir_for_plots, plot_filename)
    fig.savefig(plot_path)
    plt.show()  # display inline in the notebook
    plt.close(fig)  # release the figure's memory
    logging.info(f"Plot saved: {plot_path}")

print("plot_training_progress function defined.")

In [None]:
class ExperimentConfig:
    """All paths and hyperparameters for one run on a single dataset.

    Args:
        dataset_letter: Dataset sub-folder to use (e.g. 'A', 'B', 'C', 'D').
            Defaults to 'A'.
        base_dataset_path: Folder containing one sub-folder per dataset, each
            holding ``train.json.gz`` and ``test.json.gz``. Defaults to "./data".

    Every other setting is a plain attribute; override it on the instance after
    construction if needed (e.g. ``config.train_path = None`` to skip training and
    only test, which requires an existing best-model checkpoint).
    """

    def __init__(self, dataset_letter='A', base_dataset_path="./data"):
        # --- Crucial: Dataset Paths ---
        self.dataset_letter = dataset_letter
        self.base_dataset_path = base_dataset_path
        self.train_path = os.path.join(self.base_dataset_path, self.dataset_letter, "train.json.gz")
        self.test_path = os.path.join(self.base_dataset_path, self.dataset_letter, "test.json.gz")

        # --- Model Hyperparameters (for ImprovedGIN) ---
        self.input_dim = 300         # Size of the learned node embedding (nodes only carry the add_zeros index)
        self.hidden_dim = 128        # Hidden dimension for GIN layers
        self.num_gnn_layers = 3      # Number of GIN layers in ImprovedGIN
        self.dropout_rate = 0.3      # Dropout rate used in ImprovedGIN

        # --- Training Hyperparameters ---
        self.epochs = 20             # Number of training epochs
        self.batch_size = 32
        self.lr = 0.001              # Learning rate

        # --- Loss Function ---
        # "CE" -> torch.nn.CrossEntropyLoss, "NoisyCE" -> NoisyCrossEntropyLoss.
        # NOTE(review): NoisyCrossEntropyLoss is not defined or imported in the cells
        # shown here — confirm it is provided before selecting "NoisyCE".
        self.loss_type = "CE"
        self.num_classes = 6         # Model output size; also passed to NoisyCrossEntropyLoss
        self.noise_prob = 0.1        # p_noisy for NoisyCrossEntropyLoss (only used when loss_type is "NoisyCE")

        # --- Checkpoints & Logging & Submission ---
        # How many periodic checkpoints to save, spread over the training epochs
        # (the final epoch is always saved). The best-validation model is saved separately.
        self.num_periodic_checkpoints = 3
        self.base_checkpoint_dir = "/kaggle/working/checkpoints"
        self.base_log_dir = "/kaggle/working/logs"
        self.base_submission_dir = "/kaggle/working/submission"

        # --- Device ---
        self.force_cpu = False  # Set to True to force CPU even if CUDA is available

# Build the configuration shared by every following cell.
config = ExperimentConfig()

# --- Validate paths ---
# A missing training file is only reported (the training phase is skipped later);
# a missing test file is fatal because no predictions could be produced.
if config.train_path:
    if not os.path.exists(config.train_path):
        print(f"WARNING: Train path does not exist: {config.train_path}")
        # config.train_path = None  # Optionally disable training if path missing
if not os.path.exists(config.test_path):
    raise FileNotFoundError(f"CRITICAL: Test path does not exist: {config.test_path}")

print(f"Configuration for Dataset '{config.dataset_letter}':")
for setting, setting_value in vars(config).items():
    print(f"  {setting}: {setting_value}")

In [None]:
# ---- 0. Setup Device ----
device = torch.device('cuda' if torch.cuda.is_available() and not config.force_cpu else 'cpu')
print(f"========= Running on Device: {device} =========")

# ---- 1. Determine Dataset Name for file/folder naming ----
dataset_name_fs = config.dataset_letter  # Filesystem-safe dataset name, taken directly from the config
print(f"========= Processing Dataset: {dataset_name_fs} =========")

# ---- 2. Setup Output Directories ----
# Periodic checkpoints: base_checkpoint_dir / dataset_name_fs / model_<dataset>_epoch_<N>.pth
# Best (validation) checkpoint: base_checkpoint_dir / model_<dataset>_best.pth
# Logs: base_log_dir / dataset_name_fs / experiment.log (plots in .../plots)
# Submission: base_submission_dir / testset_<dataset>.csv
dataset_periodic_checkpoint_storage_dir = os.path.join(config.base_checkpoint_dir, dataset_name_fs)
os.makedirs(dataset_periodic_checkpoint_storage_dir, exist_ok=True)
best_model_overall_checkpoint_path = os.path.join(config.base_checkpoint_dir, f"model_{dataset_name_fs}_best.pth")

dataset_specific_log_dir = os.path.join(config.base_log_dir, dataset_name_fs)
os.makedirs(dataset_specific_log_dir, exist_ok=True)
main_log_file_path = os.path.join(dataset_specific_log_dir, "experiment.log")

# ---- 3. Setup Logging ----
# Detach AND close handlers left over from a previous run of this cell: detaching
# prevents duplicate log lines, closing releases the old FileHandler's open file.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
    handler.close()

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(main_log_file_path, mode='w'),  # Overwrite log file each run for this dataset
        logging.StreamHandler()  # Also print to console
    ]
)
logging.info(f"Logging setup for dataset {dataset_name_fs}. Log file: {main_log_file_path}")
logging.info(f"Using device: {device}")
for key, value in vars(config).items():  # Log the config
    logging.info(f"CONFIG - {key}: {value}")

# ---- 4. Initialize Model, Optimizer, Criterion ----
model = ImprovedGIN(
    input_dim=config.input_dim,
    hidden_dim=config.hidden_dim,
    output_dim=config.num_classes,
    num_gnn_layers=config.num_gnn_layers,
    dropout_rate=config.dropout_rate
).to(device)
logging.info(f"Model: ImprovedGIN initialized: layers={config.num_gnn_layers}, hidden_dim={config.hidden_dim}, dropout={config.dropout_rate}")

optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
logging.info(f"Optimizer: Adam with LR={config.lr}")

if config.loss_type.upper() == "NOISYCE":
    # NoisyCrossEntropyLoss is not defined or imported in the cells shown here;
    # fail with an actionable message instead of a bare NameError.
    if "NoisyCrossEntropyLoss" not in globals():
        raise NameError(
            "config.loss_type is 'NoisyCE' but NoisyCrossEntropyLoss is not defined. "
            "Define or import it before running this cell, or set config.loss_type = 'CE'."
        )
    criterion = NoisyCrossEntropyLoss(p_noisy=config.noise_prob, num_classes=config.num_classes)
    logging.info(f"Using NoisyCrossEntropyLoss with p_noisy={config.noise_prob}")
else:
    # Anything other than "NoisyCE" uses plain cross-entropy; make a typo visible.
    if config.loss_type.upper() != "CE":
        logging.warning(f"Unknown loss_type '{config.loss_type}'; falling back to standard CrossEntropyLoss.")
    criterion = torch.nn.CrossEntropyLoss()
    logging.info("Using standard CrossEntropyLoss")

# ---- 5. Data Loading ----
# Checked again here (not only in the config cell) so this cell fails clearly when run on its own.
if not os.path.exists(config.test_path):
    logging.error(f"Test path {config.test_path} not found. Exiting.")
    raise FileNotFoundError(f"Test path {config.test_path} not found.")
# TODO: modify config.test_path or remove the use_processed modifications to loadData
test_dataset = GraphDataset(config.test_path, transform=add_zeros, use_processed=True)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)
logging.info(f"Loaded test data from {config.test_path} ({len(test_dataset)} graphs)")

# ---- 6. Training Phase ----
if config.train_path and os.path.exists(config.train_path):
    logging.info(f"Loading training data from {config.train_path}")
    full_train_dataset_obj = GraphDataset(config.train_path, transform=add_zeros)
    logging.info(f"Full training dataset loaded ({len(full_train_dataset_obj)} graphs)")

    val_split_ratio = 0.2  # 20% for validation
    num_total_train_graphs = len(full_train_dataset_obj)
    num_val_graphs = int(val_split_ratio * num_total_train_graphs)
    num_train_graphs = num_total_train_graphs - num_val_graphs

    split_generator = torch.Generator().manual_seed(42)  # For reproducible splits
    train_dataset_split_obj, val_dataset_split_obj = torch.utils.data.random_split(
        full_train_dataset_obj, [num_train_graphs, num_val_graphs], generator=split_generator
    )

    train_loader = DataLoader(train_dataset_split_obj, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset_split_obj, batch_size=config.batch_size, shuffle=False)
    logging.info(f"Training data split: {len(train_dataset_split_obj)} train, {len(val_dataset_split_obj)} val graphs.")

    best_val_accuracy = -1.0  # Below any real accuracy, so the first epoch always yields a best checkpoint
    train_losses_hist, train_accuracies_hist = [], []
    val_losses_hist, val_accuracies_hist = [], []

    # 1-based epochs after which a periodic checkpoint is written: num_periodic_checkpoints
    # evenly spaced points ending at the final epoch, e.g. 20 epochs / 3 -> {6, 13, 20}.
    # Integer floor keeps the schedule exact when epochs is a multiple of the count.
    epochs_for_periodic_checkpoints = set()
    if config.num_periodic_checkpoints > 0 and config.epochs > 0:
        for i in range(1, config.num_periodic_checkpoints + 1):
            checkpoint_epoch = (i * config.epochs) // config.num_periodic_checkpoints
            if checkpoint_epoch > 0:  # fewer epochs than checkpoints -> skip epoch 0
                epochs_for_periodic_checkpoints.add(checkpoint_epoch)
    epochs_for_periodic_checkpoints.add(config.epochs)  # Always save at the final epoch

    logging.info(f"Starting training for {config.epochs} epochs...")
    for epoch in range(config.epochs):
        save_this_epoch_periodically = (epoch + 1) in epochs_for_periodic_checkpoints

        avg_epoch_train_loss, epoch_train_acc = train(
            train_loader, model, optimizer, criterion, device,
            save_checkpoints_periodically=save_this_epoch_periodically,
            periodic_checkpoint_dir_base=config.base_checkpoint_dir,
            dataset_name_for_checkpoint=dataset_name_fs,
            current_epoch=epoch
        )

        avg_epoch_val_loss, epoch_val_acc = evaluate(
            val_loader, model, device,
            criterion_for_loss=criterion, calculate_accuracy_and_loss=True
        )

        logging.info(f"Epoch {epoch+1}/{config.epochs} -> Train Loss: {avg_epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f} | Val Loss: {avg_epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

        train_losses_hist.append(avg_epoch_train_loss)
        train_accuracies_hist.append(epoch_train_acc)
        val_losses_hist.append(avg_epoch_val_loss)
        val_accuracies_hist.append(epoch_val_acc)

        # Strict ">" keeps the earliest epoch among ties.
        if epoch_val_acc > best_val_accuracy:
            best_val_accuracy = epoch_val_acc
            torch.save(model.state_dict(), best_model_overall_checkpoint_path)
            logging.info(f"*** New best validation accuracy: {best_val_accuracy:.4f}. Best model saved to {best_model_overall_checkpoint_path} ***")

    logging.info("Training finished.")
    # Plotting training and validation progress
    plots_output_directory = os.path.join(dataset_specific_log_dir, "plots")
    plot_training_progress(train_losses_hist, train_accuracies_hist, "Training", plots_output_directory, dataset_name_fs)
    plot_training_progress(val_losses_hist, val_accuracies_hist, "Validation", plots_output_directory, dataset_name_fs)
    logging.info(f"Training & validation plots saved to {plots_output_directory}")

    # Clean up large objects from memory
    del full_train_dataset_obj, train_dataset_split_obj, val_dataset_split_obj, train_loader, val_loader
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

elif not config.train_path:
    logging.info("config.train_path is None. Skipping training. Model will be loaded for testing.")
else:  # train_path was specified but not found
    logging.warning(f"Training path {config.train_path} was specified but not found. Skipping training. Attempting to load best model for testing.")


# ---- 7. Testing Phase ----
if not os.path.exists(best_model_overall_checkpoint_path):
    logging.error(f"Best model checkpoint {best_model_overall_checkpoint_path} not found. Cannot proceed with testing. Ensure training was run or a valid checkpoint exists.")
    raise FileNotFoundError(f"Best model checkpoint {best_model_overall_checkpoint_path} not found.")

logging.info(f"Loading best model for testing from: {best_model_overall_checkpoint_path}")
try:
    # map_location remaps tensors saved on another device (e.g. GPU -> CPU) onto `device`.
    model.to(device)
    model.load_state_dict(torch.load(best_model_overall_checkpoint_path, map_location=device))
    logging.info("Best model loaded successfully for testing.")
except Exception as e:
    logging.error(f"Failed to load best model from {best_model_overall_checkpoint_path}: {e}")
    raise

final_test_predictions = evaluate(test_loader, model, device, calculate_accuracy_and_loss=False)

save_predictions(final_test_predictions, config.test_path, root_submission_folder=config.base_submission_dir)
logging.info(f"========= Experiment for dataset {dataset_name_fs} finished. =========")

# Clean up the console handler(s) installed by basicConfig for potential re-runs of the cell.
# No `console_handler` variable exists, so the handlers are looked up on the root logger.
# FileHandler subclasses StreamHandler, hence the explicit exclusion: the log file keeps working.
for root_handler in logging.getLogger().handlers[:]:
    if isinstance(root_handler, logging.StreamHandler) and not isinstance(root_handler, logging.FileHandler):
        logging.getLogger().removeHandler(root_handler)