# Infusion Evaluation System

Evaluates the performance of infused vs. non-infused models over several training runs so that different metrics can be compared.

### ||RUN ON RESTART||

In [None]:
# Load dependencies

from utils import build_multilabel_dataset, multilabel_split, prep_infused_sweetnet

import os
import pickle

from glycowork.ml.processing import split_data_to_train
from glycowork.ml import model_training


In [None]:
# Load embeddings
# Reads the pickled embeddings object from disk into `glm_embeddings`.

pickle_file_path = 'glm_embeddings_1.pkl'

# Bind the name up front so that a missing file or a failed load leaves
# `glm_embeddings` as None instead of undefined; the sanity-checker cell
# explicitly tests `glm_embeddings is None`.
glm_embeddings = None

# --- Load the Pickle File ---
# NOTE: unpickling can execute arbitrary code; only load files you trust.
if os.path.exists(pickle_file_path):
    print(f"Loading embeddings from: {pickle_file_path}")
    try:
        # Pickle files must be opened in binary read mode ('rb')
        with open(pickle_file_path, 'rb') as file_handle:
            glm_embeddings = pickle.load(file_handle)
    except Exception as e:
        # Best-effort load: report the problem and keep the notebook running.
        print(f"An error occurred while loading the pickle file: {e}")
    else:
        print("Embeddings loaded successfully!")
else:
    print(f"Error: File not found at '{pickle_file_path}'. Please check the filename and path.")

## Evaluation Loop
Change parameters here for each trial run.


In [None]:
# Build the glycans / labels / label names used for training.
# Source table is 'df_species' with 'Kingdom' as the label class;
# min_class_size is forwarded as 2 (its exact filtering rule lives in utils).
glycans, labels, label_names = build_multilabel_dataset(
    glycan_dataset='df_species',
    glycan_class='Kingdom',
    min_class_size=2,
)

In [None]:
# Container for the metrics of every trial; the training cell stores each
# run's metrics here under its run identifier.
# Running this cell resets the container, so only run it when starting an
# entirely new series of runs.
all_training_histories = dict()

In [None]:
# Run settings

# Base name (without extension) of the file the run data is saved to.
saved_run_data = "evaluation_run_dump"

# Seed for this trial; increment by 1 for each new trial.
trial_seed = 1

# Label for the configuration being evaluated: baseline, infused_train, or infused.
# NOTE(review): the current value matches none of the labels listed above - confirm it is intended.
config_description = 'Kindomd'

# Learning rate handed to the training setup.
learning_rate = 0.005

In [None]:
# Split the dataset into training, validation, and test sets.
# `ratio` is passed as train_size; the trial seed makes the split repeatable.
ratio = 0.593
ratiod = 1 - (1 - ratio) / 2
(
    train_glycans, val_glycans, test_glycans,
    train_labels, val_labels, test_labels,
) = multilabel_split(
    glycans,
    labels,
    train_size = ratio,
    random_state = trial_seed,
    no_test = False,
)
# The dataloader construction below is disabled: it is a bare string literal
# and has no effect at runtime.
# NOTE(review): later cells pass `dataloaders` to train_model, so it has to be
# created somewhere before training runs.
"""
# Load into dataloders for training and validation
dataloaders = split_data_to_train(
    glycan_list_train = train_glycans, glycan_list_val = val_glycans, labels_train = train_labels, labels_val = val_labels,
    batch_size = 128,  # 32 or 128 seem to work well on this system
    drop_last = False,
    augment_prob = 0.0,  # Adjust if you want augmentation for training
    generalization_prob = 0.2  # Adjust if you want generalization for training
)
"""

In [None]:
# Derive the train_size fraction for the follow-up split from `ratio`.
ratio2 = 1 - (1 - ratio) / (1 + ratio)
print(f"Ratio for test set: {ratio2}")

In [None]:

# Re-split the current training portion using ratio2 as train_size and
# no_test=True; the results overwrite the variables from the first split.
(
    train_glycans, val_glycans, test_glycans,
    train_labels, val_labels, test_labels,
) = multilabel_split(
    train_glycans,
    train_labels,
    train_size=ratio2,
    random_state=trial_seed,
    no_test = True,
)

In [None]:
# model training
# Trains one model for the current settings, records its metrics under
# "<config_description>_<trial_seed>" and saves all histories to disk.

classes = len(labels[0]) # number of classes in the dataset

# Build the dataloaders from the current train/validation split.
# `dataloaders` was previously never defined before being passed to
# train_model (its only construction was disabled), which raised a NameError.
dataloaders = split_data_to_train(
    glycan_list_train = train_glycans, glycan_list_val = val_glycans,
    labels_train = train_labels, labels_val = val_labels,
    batch_size = 128,  # 32 or 128 seem to work well on this system
    drop_last = False,
    augment_prob = 0.0,  # Adjust if you want augmentation for training
    generalization_prob = 0.2  # Adjust if you want generalization for training
)

model =  prep_infused_sweetnet(
            initialization_method = 'random', # random or external
            num_classes = classes,
            embeddings_dict = glm_embeddings, 
            trainable_embeddings = True, # True or False
            ) 

optimizer_ft, scheduler, criterion = model_training.training_setup(model, learning_rate, num_classes = classes)

model_ft, current_run_metrics = model_training.train_model(model, dataloaders, criterion, optimizer_ft, scheduler,
                   num_epochs = 100, mode = 'multilabel', return_metrics = True)

# One entry per (configuration, seed) pair; re-running with the same
# settings overwrites the earlier entry.
run_identifier = f"{config_description}_{trial_seed}"
all_training_histories[run_identifier] = current_run_metrics

saved_run_data_path = f"{saved_run_data}.pkl"

# Save the entire collection at the end (or periodically)
with open(saved_run_data_path, 'wb') as f:
    pickle.dump(all_training_histories, f)
print(f"Saved training histories to {saved_run_data_path}")

In [None]:
# Show the metrics recorded so far for every run (keyed by run identifier).
print(all_training_histories)

In [None]:
# Load trial data
# Reads a previously saved run dump back into `user_data_string_from_input`.

pickle_file_path = 'evaluation_run_dump.pkl'

# Bind the name up front so the following print cell shows None rather than
# raising a NameError when the file is missing or cannot be unpickled.
user_data_string_from_input = None

# --- Load the Pickle File ---
# NOTE: unpickling can execute arbitrary code; only load files you trust.
if os.path.exists(pickle_file_path):
    print(f"Loading data from: {pickle_file_path}")
    try:
        # Pickle files must be opened in binary read mode ('rb')
        with open(pickle_file_path, 'rb') as file_handle:
            user_data_string_from_input = pickle.load(file_handle)
    except Exception as e:
        # Best-effort load: report the problem and keep the notebook running.
        print(f"An error occurred while loading the pickle file: {e}")
    else:
        print("Data loaded successfully!")
else:
    print(f"Error: File not found at '{pickle_file_path}'. Please check the filename and path.")

In [None]:
# Show the object read back from the saved run dump.
print(user_data_string_from_input)

In [None]:

# Print the first ten keys of the embeddings dictionary so one can be picked
# as the token for the Sanity Checker.
# Uses `glm_embeddings`, the name the embeddings are loaded into; the
# previously referenced `loaded_embeddings` is not defined in this notebook.
print(list(glm_embeddings.keys())[:10])

In [None]:
# --- Embedding Sanity Checker ---
# Compares one token's vector in glm_embeddings with the row at that token's
# index in the embedding layer of a model prepared with the 'external' method.

import numpy as np
import torch
from glycowork.glycan_data.loader import  lib

# 1. Token whose embedding is checked
token_to_check = '!GlcNAc' 

glycowork_lib = lib

# 2. The token has to be present in both the library and the embeddings dict
if token_to_check not in glycowork_lib:
    print(f"Error: Token '{token_to_check}' not found in glycowork_lib. Choose another token.")
elif glm_embeddings is None or token_to_check not in glm_embeddings:
    print(f"Error: Token '{token_to_check}' not found in glm_embeddings dictionary or dictionary not loaded.")
else:
    print(f"--- Checking embedding for token: '{token_to_check}' ---")

    # 3. Look up the library index and the reference vector
    token_index = glycowork_lib[token_to_check]
    vector_from_dict = glm_embeddings[token_to_check]
    print(f"Index for '{token_to_check}': {token_index}")
    print(f"Vector from glm_embeddings dict (first 5 elements): {vector_from_dict[:5]}")

    # 4. Build a throwaway model whose embeddings come from the dictionary
    print("\nPreparing a temporary model instance with external embeddings...")
    try:
        temp_model = prep_infused_sweetnet(
            num_classes=len(labels[0]),  # any valid class count works here
            initialization_method='external',
            embeddings_dict=glm_embeddings,
            trainable_embeddings=False,  # irrelevant for inspecting the initial state
            hidden_dim=vector_from_dict.shape[0],  # hidden_dim set to the embedding length
            libr=glycowork_lib
        )

        # 5. Read the matching row from the model's embedding layer
        with torch.no_grad():  # inspection only, no gradients needed
            model_embedding_layer = temp_model.item_embedding
            if not token_index < model_embedding_layer.weight.shape[0]:
                # The index does not address a row of the embedding matrix
                print(f"Error: Index {token_index} is out of bounds for the model's embedding layer (size {model_embedding_layer.weight.shape[0]})")
            else:
                vector_from_model = model_embedding_layer.weight[token_index].cpu().numpy()
                print(f"Vector from model's layer (index {token_index}, first 5 elements): {vector_from_model[:5]}")

                # 6. Tolerance-based comparison, since these are floats
                if not np.allclose(vector_from_dict, vector_from_model, atol=1e-6):
                    print(f"\nFAILURE: Vectors for '{token_to_check}' DO NOT match.")
                else:
                    print(f"\nSUCCESS: Vectors for '{token_to_check}' match between dictionary and model layer.")

    except Exception as e:
        print(f"\nAn error occurred during model preparation or vector comparison: {e}")

In [None]:
# Class count is taken from the length of the first label entry.
classes = len(labels[0])

# Model whose embeddings are initialised from glm_embeddings ('external')
# and requested as non-trainable.
model = prep_infused_sweetnet(
    initialization_method = 'external',  # random or external
    num_classes = classes,
    embeddings_dict = glm_embeddings,
    trainable_embeddings = False,  # True or False
)

In [None]:
# Print three rows of the embedding matrix (indices 3, 10 and 42) for inspection.
for row_index in (3, 10, 42):
    print(model.item_embedding.weight.data[row_index])

In [None]:
# Short 10-epoch training run of `model`.
optimizer_ft, scheduler, criterion = model_training.training_setup(
    model,
    learning_rate,
    num_classes = classes,
)

model_ft, current_run_metrics = model_training.train_model(
    model,
    dataloaders,
    criterion,
    optimizer_ft,
    scheduler,
    num_epochs = 10,
    mode = 'multilabel',
    return_metrics = True,
)

In [None]:
# Print row 3 of the trained model's embedding matrix; the same row is printed
# before training in an earlier cell, so the two outputs can be compared.
print(model_ft.item_embedding.weight.data[3])