# Infusion Evaluation System

Used to evaluate the performance of infused vs. non-infused models over several training runs, comparing them across different metrics.

### ||RUN ON RESTART||

In [None]:
# Load dependencies

from utils import build_multilabel_dataset, multilabel_split, prep_infused_sweetnet

import os
import pickle

from glycowork.ml.processing import split_data_to_train
from glycowork.ml import model_training


In [None]:
# Load embeddings
#
# Loads the pickled embedding dictionary that the training cell passes to
# `prep_infused_sweetnet(embeddings_dict=...)`. `glm_embeddings` is bound to
# None up front, so the name is always defined: if the file is missing or
# cannot be unpickled, downstream code receives None instead of raising a
# NameError.
# NOTE: pickle.load can execute arbitrary code -- only load files you created.

pickle_file_path = 'glm_embeddings_1.pkl'

# Defined before loading so the name exists even when loading fails.
glm_embeddings = None

# --- Load the Pickle File ---
if os.path.exists(pickle_file_path):
    print(f"Loading embeddings from: {pickle_file_path}")
    try:
        # Open the file in binary read mode ('rb')
        with open(pickle_file_path, 'rb') as file_handle:
            glm_embeddings = pickle.load(file_handle)
    except Exception as e:
        # Top-level notebook boundary: report the problem and leave None.
        print(f"An error occurred while loading the pickle file: {e}")
    else:
        print("Embeddings loaded successfully!")
else:
    print(f"Error: File not found at '{pickle_file_path}'. Please check the filename and path.")

## Evaluation Loop
Change the parameters in the cells below for each trial run.


In [None]:
# Build the subset of the dataset the model is trained on.
# min_class_size presumably filters out rare classes -- see
# utils.build_multilabel_dataset to confirm.
glycans, labels, label_names = build_multilabel_dataset(
    glycan_dataset='df_disease',
    glycan_class='disease_association',
    min_class_size=6,
)

In [None]:
# Collects each run's metric history, keyed by run identifier.
# Only run this cell when starting an entirely new run: it discards any
# histories already collected in this session.
all_training_histories = dict()

In [None]:
# Run settings

# Free-form label for the configuration under evaluation
# (e.g. 'baseline', 'infused_train', 'infused', 'rand_static').
# It is used in the per-run key and in the results file name below.
config_description = 'rand_static'

# Seed for the dataset split; increment by 1 for each trial.
trial_seed = 5

# Base name (without ".pkl") of the file the run data is saved to.
# It is derived from the label so the file name cannot drift out of sync
# with the configuration.
saved_run_data = f"evaluation_run_{config_description}"

In [None]:
# Split the dataset into training, validation, and test sets. 70% of the
# data goes to training; random_state=trial_seed gives each trial its own split.
(
    train_glycans, val_glycans, test_glycans,
    train_labels, val_labels, test_labels,
) = multilabel_split(glycans, labels, train_size=0.7, random_state=trial_seed)

# Wrap the training and validation splits in dataloaders.
dataloaders = split_data_to_train(
    glycan_list_train=train_glycans,
    glycan_list_val=val_glycans,
    labels_train=train_labels,
    labels_val=val_labels,
    batch_size=128,  # 32 or 128 seem to work well on this system
    drop_last=False,
    augment_prob=0.0,  # raise above 0 to enable augmentation during training
    generalization_prob=0.2,  # adjust to change generalization during training
)

In [None]:
# model training
#
# This cell:
#   1. builds a fresh model and trains it on `dataloaders`;
#   2. records this run's metrics under "<config_description>_<trial_seed>";
#   3. persists every collected history to "<saved_run_data>.pkl".

classes = len(labels[0])  # number of classes in the dataset (label vector width)

model = prep_infused_sweetnet(
    initialization_method='random',  # random or external
    num_classes=classes,
    embeddings_dict=glm_embeddings,
    trainable_embeddings=True,  # True or False
)

optimizer_ft, scheduler, criterion = model_training.training_setup(model, 0.0005, num_classes=classes)

model_ft, current_run_metrics = model_training.train_model(
    model, dataloaders, criterion, optimizer_ft, scheduler,
    num_epochs=100, mode='multilabel', return_metrics=True,
)

run_identifier = f"{config_description}_{trial_seed}"
if run_identifier in all_training_histories:
    # Re-running the same seed/config silently replaced the old entry before.
    print(f"Warning: overwriting existing history for run '{run_identifier}'")
all_training_histories[run_identifier] = current_run_metrics

saved_run_data_path = f"{saved_run_data}.pkl"

# If `all_training_histories` was re-initialised (e.g. after a kernel
# restart), dumping it as-is would erase runs saved in earlier sessions.
# So merge the on-disk histories back in first; in-memory entries win on
# key clashes. If the existing file cannot be read, the exception
# propagates and the file is left untouched instead of being overwritten.
if os.path.exists(saved_run_data_path):
    with open(saved_run_data_path, 'rb') as f:
        previous_histories = pickle.load(f)
    for key, history in previous_histories.items():
        all_training_histories.setdefault(key, history)

# Write to a temporary file and then atomically swap it into place, so an
# interrupted save cannot leave a truncated results file behind.
tmp_path = f"{saved_run_data_path}.tmp"
with open(tmp_path, 'wb') as f:
    pickle.dump(all_training_histories, f)
os.replace(tmp_path, saved_run_data_path)
print(f"Saved training histories to {saved_run_data_path}")

In [None]:
# Show every run history collected so far.
print(str(all_training_histories))

In [None]:
# Load trial data
#
# Reloads a previously saved results pickle for inspection. It is presumably
# a run-id -> metrics dict written by the training cell; verify for older
# files. `user_data_string_from_input` is a legacy name: it holds the
# unpickled object, not a string. It is bound to None up front so the name
# is always defined, even when loading fails.
# NOTE: pickle.load can execute arbitrary code -- only load files you created.

pickle_file_path = 'evaluation_run_1.pkl'

# Defined before loading so the printing cell never hits a NameError.
user_data_string_from_input = None

# --- Load the Pickle File ---
if os.path.exists(pickle_file_path):
    print(f"Loading data from: {pickle_file_path}")
    try:
        # Open the file in binary read mode ('rb')
        with open(pickle_file_path, 'rb') as file_handle:
            user_data_string_from_input = pickle.load(file_handle)
    except Exception as e:
        # Top-level notebook boundary: report the problem and leave None.
        print(f"An error occurred while loading the pickle file: {e}")
    else:
        print("Data loaded successfully!")
else:
    print(f"Error: File not found at '{pickle_file_path}'. Please check the filename and path.")

In [None]:
# Show the contents of the loaded results file.
print(str(user_data_string_from_input))