# Automated Infusion Evaluation System

This notebook automates running multiple trials for each model configuration (baseline-trainable, baseline-fixed, infused-trainable, infused-fixed) so that their performance can be evaluated and compared systematically. For every run it collects the validation metrics, aggregates them, and saves the results.


### ||RUN ON RESTART||

In [None]:
# Load dependencies

from utils import build_multilabel_dataset, multilabel_split, prep_infused_sweetnet, seed_everything

import os
import pickle

from glycowork.ml.processing import split_data_to_train
from glycowork.ml import model_training


In [None]:
# Read the embeddings pickle into `glm_embeddings`

pickle_file_path = 'glm_embeddings_1.pkl'

# --- Report a missing file first; otherwise try to unpickle it ---
if not os.path.exists(pickle_file_path):
    print(f"Error: File not found at '{pickle_file_path}'. Please check the filename and path.")
else:
    print(f"Loading embeddings from: {pickle_file_path}")
    try:
        # Pickle data is binary, so the file is opened with mode 'rb'
        with open(pickle_file_path, 'rb') as file_handle:
            glm_embeddings = pickle.load(file_handle)
        print("Embeddings loaded successfully!")
    except Exception as e:
        # Failures are only reported; `glm_embeddings` stays unassigned in that case
        print(f"An error occurred while loading the pickle file: {e}")

## Experimental setup
Change the parameters in this section to define each experiment.


In [None]:
# Build the (glycans, labels, label_names) triple the models are trained on.
# NOTE(review): `min_class_size` presumably drops classes with fewer than
# 6 members - confirm against utils.build_multilabel_dataset.

glycans, labels, label_names = build_multilabel_dataset(
    glycan_dataset='df_disease',
    glycan_class='disease_association',
    min_class_size=6,
)

In [None]:
# --- Global parameters & experiment configuration ---
EXPERIMENT_NAME = "infusion_trial_1"  # label embedded in the output file names below
NUM_RUNS = 10  # trials per configuration (e.g., 5 or 10)
EPOCHS = 100  # training epochs per run
BATCH_SIZE = 128  # 32 or 128 seems to work well
# Fraction of the data used for training.
# NOTE(review): the remainder is presumably split evenly into validation and
# test (15% each for 0.7) by multilabel_split - confirm in utils.
TRAIN_SIZE = 0.7
LEARNING_RATE = 0.005  # learning rate passed to the training setup
DROP_LAST = False  # drop the last batch when it is smaller than the batch size
AUGMENT_PROB = 0.0  # raise above 0 to enable augmentation during training
GENERALIZATION_PROB = 0.2  # adjust to control generalization during training
BASE_RANDOM_STATE = 42  # initial seed for the whole experiment sequence
RESULTS_SUMMARY_CSV_PATH = f"summary_{EXPERIMENT_NAME}.csv"
ALL_EPOCH_DATA_PKL_PATH = f"epoch_data_{EXPERIMENT_NAME}.pkl"
CLASSES = len(labels[0])  # number of classes = length of one label vector


# --- Experiment configurations ---
# One entry per parameter set to evaluate; extending this list also allows
# basic hyperparameter tuning.
experiment_configs = [
    dict(name="baseline_trainable", initialization_method="random", trainable_embeddings=True),
    dict(name="baseline_fixed", initialization_method="random", trainable_embeddings=False),
    dict(name="infused_trainable", initialization_method="external", trainable_embeddings=True),
    dict(name="infused_fixed", initialization_method="external", trainable_embeddings=False),
    # add more configurations as needed
]


# Seed everything once up front for reproducibility
seed_everything(BASE_RANDOM_STATE)

# Session-wide result containers, filled by the experiment loop
all_run_summary_results = {}  # best metrics per run, keyed by "<config>_<metric>"
all_run_epoch_histories = {}  # full epoch-wise histories, keyed by "<config>_<run>"

print(f"Experiment *{EXPERIMENT_NAME}* set up with {len(experiment_configs)} configurations.")

## Run Experiment Loop

In [None]:
print(f"Initializing {EXPERIMENT_NAME} experiment with {len(experiment_configs)} configurations and {NUM_RUNS} runs per configuration.")


for i in range(NUM_RUNS):

    # Print the current run number (1-based for display)
    print("----------------")
    print(f"Run {i+1}/{NUM_RUNS}")
    print()

    # Give every run its own seed by offsetting the base seed with the run index
    random_state = BASE_RANDOM_STATE + i

    # Set the random seed for reproducibility for this run
    seed_everything(random_state)

    # --- Split the data once per run, outside the configuration loop, so that
    # --- every configuration of this run trains on the same split
    # The test split is produced here but not used anywhere in this loop.
    train_glycans, val_glycans, test_glycans, \
    train_labels, val_labels, test_labels = multilabel_split(glycans, labels, train_size=TRAIN_SIZE,
                                                                random_state=random_state)

    # Load the train/validation splits into dataloaders
    dataloaders = split_data_to_train(
        glycan_list_train = train_glycans, glycan_list_val = val_glycans, labels_train = train_labels, labels_val = val_labels,
        batch_size = BATCH_SIZE,
        drop_last = DROP_LAST,
        augment_prob = AUGMENT_PROB,
        generalization_prob = GENERALIZATION_PROB
    )

    # --- Loop through each configuration & train a fresh model ---
    for config in experiment_configs:

        # Extract the configuration parameters
        config_name = config["name"]
        initialization_method = config["initialization_method"]
        trainable_embeddings = config["trainable_embeddings"]

        # Print the current configuration being used
        print("----------------")
        print(f"Running configuration: {config_name}")
        print()

        # --- Model Training ---
        # Initialize the model with the specified parameters
        model = prep_infused_sweetnet(
                    initialization_method = initialization_method,
                    num_classes = CLASSES,
                    embeddings_dict = glm_embeddings,
                    trainable_embeddings = trainable_embeddings
                    )

        print()

        # Build optimizer, scheduler and loss for this model
        optimizer, scheduler, criterion = model_training.training_setup(model, LEARNING_RATE, num_classes = CLASSES)

        # Run the training process; return_metrics=True also yields the epoch-wise metrics
        model_ft, current_run_metrics = model_training.train_model(model, dataloaders, criterion, optimizer, scheduler,
                        num_epochs = EPOCHS, mode = 'multilabel', return_metrics = True)

        print()

        # Store the epoch-wise metrics under "<config>_<run number>"
        config_run_identifier = f"{config_name}_{i+1}"
        all_run_epoch_histories[config_run_identifier] = current_run_metrics

        # --- Save the best metrics from this run to the summary results dictionary ---
        # One summary column per (configuration, metric) pair
        loss_key = f"{config_name}_loss"
        lrap_key = f"{config_name}_lrap"
        ndcg_key = f"{config_name}_ndcg"
        metric_keys = [loss_key, lrap_key, ndcg_key]

        # Best validation value over all epochs: lowest loss, highest LRAP / NDCG
        best_loss = min(current_run_metrics['val']['loss'])
        best_lrap = max(current_run_metrics['val']['lrap'])
        best_ndcg = max(current_run_metrics['val']['ndcg'])
        best_metrics = [best_loss, best_lrap, best_ndcg]

        # setdefault creates the column on first use, then the value is appended
        for key, metric in zip(metric_keys, best_metrics):
            all_run_summary_results.setdefault(key, []).append(metric)

    # --- Export metrics at end of each run, in case of early termination ---

    # Save all epoch-wise metrics collected so far to a pickle file (overwritten each run)
    with open(ALL_EPOCH_DATA_PKL_PATH, 'wb') as f:
        pickle.dump(all_run_epoch_histories, f)
    print(f"Saved training histories to {ALL_EPOCH_DATA_PKL_PATH}")

    # Save the summary results to a CSV file: one column per metric key, one row per run
    with open(RESULTS_SUMMARY_CSV_PATH, 'w') as f:
        # Write the header
        f.write(','.join(all_run_summary_results) + '\n')

        # Write the data. zip(*columns) transposes the per-metric lists into
        # per-run rows without an index variable, so the outer run index `i`
        # is no longer shadowed. It stops at the shortest column, so ragged
        # columns or an empty result dict no longer raise
        # IndexError / StopIteration.
        for row in zip(*all_run_summary_results.values()):
            f.write(','.join(str(value) for value in row) + '\n')
    print(f"Saved summary results to {RESULTS_SUMMARY_CSV_PATH}")

        

In [None]:
# Developing function to test model here before migrating to utils.py

import torch
import torch.utils.data
import torch.nn 
from typing import Dict # Import Dict for type hinting
from glycowork.ml.model_training import sigmoid
# No trailing comma here: without parentheses it is a SyntaxError
from sklearn.metrics import label_ranking_average_precision_score, ndcg_score

def test_model(model: torch.nn.Module, 
               dataloader: torch.utils.data.DataLoader, 
               criterion: torch.nn.Module) -> dict[str, float]:
    """
    Evaluates a multi-label model on a test set.

    This is still a development stub: the evaluation logic has not been
    written yet, so calling it raises instead of silently returning None.

    Parameters
    ----------
    model: torch.nn.Module
        The trained model to evaluate.
    dataloader: torch.utils.data.DataLoader
        DataLoader containing the test split.
    criterion: torch.nn.Module
        The loss function to calculate average loss during evaluation.

    Returns
    -------
    dict
        A dictionary containing calculated evaluation metrics
        for the multi-label task.

    Raises
    ------
    NotImplementedError
        Always, until the evaluation logic is implemented.
    """

    # TODO: implement the evaluation pass (presumably using the `sigmoid`,
    # LRAP and NDCG helpers imported above) before migrating to utils.py.
    raise NotImplementedError("test_model is still under development and has no implementation yet.")

In [None]:
# Read the trial-data pickle into `user_data_string_from_input`

pickle_file_path = 'evaluation_run_1.pkl'

# --- Report a missing file first; otherwise try to unpickle it ---
if not os.path.exists(pickle_file_path):
    print(f"Error: File not found at '{pickle_file_path}'. Please check the filename and path.")
else:
    print(f"Loading data from: {pickle_file_path}")
    try:
        # Pickle data is binary, so the file is opened with mode 'rb'
        with open(pickle_file_path, 'rb') as file_handle:
            user_data_string_from_input = pickle.load(file_handle)
        print("Data loaded successfully!")
    except Exception as e:
        # Failures are only reported; the target variable stays unassigned in that case
        print(f"An error occurred while loading the pickle file: {e}")

In [None]:
# Display the object unpickled from the trial-data file. The name is only
# bound when that load succeeded, so a failed load surfaces here as NameError.
print(user_data_string_from_input)