# Infusion Evaluation System

Used to evaluate the performance of infused vs. non-infused models over several training runs, comparing them across different metrics.

### ||RUN ON RESTART||

In [None]:
# Load dependencies

import os
import pickle
import random
from pprint import pprint

import numpy as np
import torch
from glycowork.ml import model_training
from glycowork.ml.processing import split_data_to_train

from utils import build_multilabel_dataset, multilabel_split, prep_infused_sweetnet
from glycowork.ml import model_training


In [None]:
# Load embeddings: a dict mapping glycan tokens to pre-computed GLM vectors,
# used by prep_infused_sweetnet(initialization_method='external').
# NOTE: pickle.load can execute arbitrary code - only load files you created or trust.

pickle_file_path = 'glm_embeddings_1.pkl'

# Default to None so later cells can test `glm_embeddings is None` instead of
# crashing with a NameError when the file is missing or cannot be read.
glm_embeddings = None

# --- Load the Pickle File ---
if os.path.exists(pickle_file_path):
    print(f"Loading embeddings from: {pickle_file_path}")
    try:
        # Binary read mode is required for pickle files
        with open(pickle_file_path, 'rb') as file_handle:
            glm_embeddings = pickle.load(file_handle)

        print("Embeddings loaded successfully!")

    except Exception as e:
        print(f"An error occurred while loading the pickle file: {e}")
else:
    print(f"Error: File not found at '{pickle_file_path}'. Please check the filename and path.")

## Evaluation Loop
Change the run settings in the cells below for each trial run.


In [30]:
# Pull the tissue-sample annotations and keep only labels that have
# at least `min_class_size` glycans (rarer classes are filtered out).
dataset_options = {
    'glycan_dataset': 'df_tissue',
    'glycan_class': 'tissue_sample',
    'min_class_size': 6,
}
glycans, labels, label_names = build_multilabel_dataset(**dataset_options)

Found 262 unique individual classes/labels.
Number of unique glycans left after filtering rare classes (size >= 6): 5248/6560
Number of unique labels left after filtering: 96


In [None]:
# Fresh container for per-run metrics, keyed by run identifier.
# Run this cell ONLY when starting an entirely new evaluation: re-running it
# discards every trial collected so far in this kernel session.
all_training_histories = dict()

In [31]:
# --- Run settings (edit for each trial) ---

# Base name (without extension) of the pickle that collects every run's history
saved_run_data = "evaluation_run_dump"

# Increment by 1 for each trial; also used as random_state for the data split
trial_seed = 1

# Run label, used as the prefix of the run identifier
# (expected values: baseline, infused_train, or infused)
# NOTE(review): 'Kindomd' is not one of the listed labels - confirm it is intended.
config_description = 'Kindomd'

learning_rate = 0.005

In [32]:
# Partition into train / validation / test (70% train); seeding with trial_seed
# makes each trial's split reproducible.
(train_glycans, val_glycans, test_glycans,
 train_labels, val_labels, test_labels) = multilabel_split(
    glycans, labels, train_size=0.7, random_state=trial_seed
)

# Wrap the train and validation partitions in dataloaders for training
dataloaders = split_data_to_train(
    glycan_list_train=train_glycans,
    glycan_list_val=val_glycans,
    labels_train=train_labels,
    labels_val=val_labels,
    batch_size=128,            # 32 or 128 seem to work well on this system
    drop_last=False,
    augment_prob=0.0,          # raise to enable training-time augmentation
    generalization_prob=0.2,   # probability of generalization during training
)

Split complete!
Train set size: 3673
Validation set size: 787
Test set size: 788


In [33]:
# --- Train one model for the current trial and record its metrics ---

classes = len(labels[0])  # number of label columns = classifier output size
saved_run_data_path = f"{saved_run_data}.pkl"
run_identifier = f"{config_description}_{trial_seed}"

# Resolve the histories container BEFORE training so a missing variable cannot
# throw away a finished run. If the kernel was restarted without re-running the
# initialisation cell, resume from the existing dump: starting from an empty
# dict here would overwrite all earlier trials on save.
try:
    all_training_histories
except NameError:
    if os.path.exists(saved_run_data_path):
        with open(saved_run_data_path, 'rb') as f:
            all_training_histories = pickle.load(f)
        print(f"Resumed {len(all_training_histories)} saved run(s) from {saved_run_data_path}")
    else:
        all_training_histories = {}

if run_identifier in all_training_histories:
    print(f"Warning: results for '{run_identifier}' already exist and will be overwritten. "
          "Increment trial_seed if this is a new trial.")

# Seed the global RNGs so a given trial_seed also fixes the weight initialisation
# (and any torch/numpy/random-driven shuffling or augmentation), not just the split.
random.seed(trial_seed)
np.random.seed(trial_seed)
torch.manual_seed(trial_seed)

model = prep_infused_sweetnet(
    initialization_method='random',  # 'random' or 'external'
    num_classes=classes,
    embeddings_dict=glm_embeddings,
    trainable_embeddings=True,  # True or False
)

optimizer_ft, scheduler, criterion = model_training.training_setup(model, learning_rate, num_classes=classes)

model_ft, current_run_metrics = model_training.train_model(
    model, dataloaders, criterion, optimizer_ft, scheduler,
    num_epochs=10, mode='multilabel', return_metrics=True,
)

all_training_histories[run_identifier] = current_run_metrics

# Write to a temporary file and swap it in with os.replace, so an interrupted
# dump cannot leave a truncated file that loses the earlier trials.
tmp_run_data_path = f"{saved_run_data_path}.tmp"
with open(tmp_run_data_path, 'wb') as f:
    pickle.dump(all_training_histories, f)
os.replace(tmp_run_data_path, saved_run_data_path)
print(f"Saved training histories to {saved_run_data_path}")

SweetNet model instantiated with lib_size=2565, num_classes=96, hidden_dim=320.
Handling 'random' initialization method (training from scratch).
SweetNet item_embedding layer set to trainable: True.
Epoch 0/9
----------
train Loss: 4.4966 LRAP: 0.0000 NDCG: 0.1701
val Loss: 4.3607 LRAP: 0.0000 NDCG: 0.1764
Validation loss decreased (0.000000 --> 4.360734).  Saving model ...

Epoch 1/9
----------
train Loss: 3.9080 LRAP: 0.0008 NDCG: 0.2086
val Loss: 4.3394 LRAP: 0.0025 NDCG: 0.2004
Validation loss decreased (4.360734 --> 4.339397).  Saving model ...

Epoch 2/9
----------
train Loss: 3.7470 LRAP: 0.0071 NDCG: 0.2227
val Loss: 3.6905 LRAP: 0.0064 NDCG: 0.2268
Validation loss decreased (4.339397 --> 3.690483).  Saving model ...

Epoch 3/9
----------
train Loss: 3.6604 LRAP: 0.0106 NDCG: 0.2297
val Loss: 3.5760 LRAP: 0.0076 NDCG: 0.2448
Validation loss decreased (3.690483 --> 3.576006).  Saving model ...

Epoch 4/9
----------
train Loss: 3.5641 LRAP: 0.0090 NDCG: 0.2335
val Loss: 3.4915 LR

In [None]:
# Inspect all runs recorded so far (keys are "<config_description>_<trial_seed>");
# pprint wraps the nested metrics over several lines instead of one long line.
pprint(all_training_histories, compact=True)

In [None]:
# Load trial data written by the training cell (dict: run identifier -> metrics).
# NOTE: despite its name, `user_data_string_from_input` holds that dict, not a
# string; the name is kept because the next cell refers to it.
# NOTE: pickle.load can execute arbitrary code - only load dumps you created yourself.

pickle_file_path = 'evaluation_run_dump.pkl'

# Default to None so inspecting the result after a failed load does not raise NameError.
user_data_string_from_input = None

# --- Load the Pickle File ---
if os.path.exists(pickle_file_path):
    print(f"Loading data from: {pickle_file_path}")
    try:
        # Binary read mode is required for pickle files
        with open(pickle_file_path, 'rb') as file_handle:
            user_data_string_from_input = pickle.load(file_handle)

        print("Data loaded successfully!")

    except Exception as e:
        print(f"An error occurred while loading the pickle file: {e}")
else:
    print(f"Error: File not found at '{pickle_file_path}'. Please check the filename and path.")

In [None]:
# Inspect the run histories loaded from the dump file; pprint keeps the nested
# per-run metrics readable instead of printing them on one long line.
pprint(user_data_string_from_input, compact=True)

In [None]:

# Print some keys (tokens) to choose a `token_to_check` for the Sanity Checker.
# The embeddings were loaded into `glm_embeddings`; `loaded_embeddings` was never defined.
if glm_embeddings is None:
    print("No embeddings loaded - run the 'Load embeddings' cell first.")
else:
    print(list(glm_embeddings.keys())[:10])

In [None]:
# --- Embedding Sanity Checker ---
# Checks that prep_infused_sweetnet(initialization_method='external') copies the
# GLM vector for a token into the row of item_embedding at that token's lib index.

# 1. Choose a token to check
token_to_check = '!GlcNAc'

from glycowork.glycan_data.loader import lib

glycowork_lib = lib

if token_to_check not in glycowork_lib:
    print(f"Error: Token '{token_to_check}' not found in glycowork_lib. Choose another token.")
elif glm_embeddings is None or token_to_check not in glm_embeddings:
    print(f"Error: Token '{token_to_check}' not found in glm_embeddings dictionary or dictionary not loaded.")
else:
    print(f"--- Checking embedding for token: '{token_to_check}' ---")

    # 2. Get the index and the vector from the dictionary
    token_index = glycowork_lib[token_to_check]
    vector_from_dict = glm_embeddings[token_to_check]
    print(f"Index for '{token_to_check}': {token_index}")
    print(f"Vector from glm_embeddings dict (first 5 elements): {vector_from_dict[:5]}")

    # 3. Prepare a model instance using the 'external' method
    print("\nPreparing a temporary model instance with external embeddings...")
    try:
        temp_model = prep_infused_sweetnet(
            num_classes=len(labels[0]),  # any valid class count; only the embedding layer is inspected
            initialization_method='external',
            embeddings_dict=glm_embeddings,
            trainable_embeddings=False,  # irrelevant when checking the initial state
            hidden_dim=vector_from_dict.shape[0],  # layer width must equal the embedding dimension
            libr=glycowork_lib
        )

        # 4. Get the vector from the model's embedding layer
        with torch.no_grad():  # read-only inspection, no gradients needed
            model_embedding_layer = temp_model.item_embedding
            # Ensure the index is valid for the layer before reading the row
            if token_index < model_embedding_layer.weight.shape[0]:
                vector_from_model = model_embedding_layer.weight[token_index].cpu().numpy()
                print(f"Vector from model's layer (index {token_index}, first 5 elements): {vector_from_model[:5]}")

                # 5. Compare the vectors (allclose, since these are floats)
                if np.allclose(vector_from_dict, vector_from_model, atol=1e-6):
                    print(f"\nSUCCESS: Vectors for '{token_to_check}' match between dictionary and model layer.")
                else:
                    print(f"\nFAILURE: Vectors for '{token_to_check}' DO NOT match.")
                    # np.allclose already converted vector_from_dict successfully, so asarray is safe here
                    max_abs_diff = np.max(np.abs(np.asarray(vector_from_dict) - vector_from_model))
                    print(f"Max absolute difference: {max_abs_diff}")
            else:
                print(f"Error: Index {token_index} is out of bounds for the model's embedding layer (size {model_embedding_layer.weight.shape[0]})")

    except Exception as e:
        # Include the exception type: messages such as a bare KeyError are unreadable without it
        print(f"\nAn error occurred during model preparation or vector comparison: {type(e).__name__}: {e}")

In [43]:
# Re-create a fresh, randomly initialised model (same arguments as the training
# cell above) so its embedding rows can be inspected before training.
classes = len(labels[0])  # one output unit per label column

sweetnet_kwargs = dict(
    initialization_method='random',  # 'random' or 'external'
    num_classes=classes,
    embeddings_dict=glm_embeddings,
    trainable_embeddings=True,  # True or False
)
model = prep_infused_sweetnet(**sweetnet_kwargs)

SweetNet model instantiated with lib_size=2565, num_classes=96, hidden_dim=320.
Handling 'random' initialization method (training from scratch).
SweetNet item_embedding layer set to trainable: True.


In [44]:
# Peek at a few rows of the freshly initialised embedding table.
for row_index in (3, 10, 42):
    print(model.item_embedding.weight.data[row_index])

tensor([-1.4525,  0.3075,  1.4349,  0.3794, -1.7839,  0.3885,  1.2002, -2.1629,
         0.5717, -0.6538,  1.5382,  0.6430, -0.1368,  0.7357, -0.9986,  1.1028,
         1.1380,  1.6883, -0.4352, -0.2213,  0.3439, -1.3767,  0.3486,  0.4144,
         1.0753,  1.1903,  0.6938,  2.1160, -1.0378,  1.6642,  0.7488, -0.6732,
         0.8369, -0.4839, -0.0633,  0.1429, -0.1653, -0.5991,  0.4693, -1.3164,
        -1.2417,  0.5104, -0.0872, -0.3140, -1.1179,  0.6367,  1.5762, -0.5266,
        -0.0107, -0.1291, -1.1057, -1.3789, -1.7769, -0.0630, -1.4919,  0.3969,
         0.5373, -0.9781, -0.0049,  0.6047,  0.3725, -0.1161, -0.7335,  1.4174,
         1.1244,  0.2782,  0.3702, -0.5335,  0.7336, -0.7596,  2.1775,  1.5075,
         0.0878,  0.1304,  1.0939, -0.9240, -0.3225,  2.2266, -0.0524, -0.3247,
         0.9508,  1.1304,  0.8215, -0.5020, -0.2016,  0.5250,  0.4758,  0.3991,
        -0.7320,  0.6299,  1.4124,  1.0406, -1.6504,  0.3421, -1.1384, -0.5006,
        -2.8388, -0.1219,  0.4013, -1.39

In [None]:
# Train the freshly created model. Note: this scratch cell does NOT store its
# metrics in all_training_histories or write them to disk.
optimizer_ft, scheduler, criterion = model_training.training_setup(
    model, learning_rate, num_classes=classes
)
model_ft, current_run_metrics = model_training.train_model(
    model,
    dataloaders,
    criterion,
    optimizer_ft,
    scheduler,
    num_epochs=10,
    mode='multilabel',
    return_metrics=True,
)

Epoch 0/9
----------
train Loss: 4.4550 LRAP: 0.0000 NDCG: 0.1664
val Loss: 4.2388 LRAP: 0.0000 NDCG: 0.1875
Validation loss decreased (0.000000 --> 4.238810).  Saving model ...

Epoch 1/9
----------
train Loss: 3.9701 LRAP: 0.0005 NDCG: 0.1916
val Loss: 3.9027 LRAP: 0.0089 NDCG: 0.2076
Validation loss decreased (4.238810 --> 3.902740).  Saving model ...

Epoch 2/9
----------
train Loss: 3.8381 LRAP: 0.0033 NDCG: 0.2049
val Loss: 3.7819 LRAP: 0.0089 NDCG: 0.2345
Validation loss decreased (3.902740 --> 3.781872).  Saving model ...

Epoch 3/9
----------
train Loss: 3.7427 LRAP: 0.0054 NDCG: 0.2100
val Loss: 3.6963 LRAP: 0.0089 NDCG: 0.2336
Validation loss decreased (3.781872 --> 3.696251).  Saving model ...

Epoch 4/9
----------
train Loss: 3.6751 LRAP: 0.0063 NDCG: 0.2203
val Loss: 3.8380 LRAP: 0.0292 NDCG: 0.2333
EarlyStopping counter: 1 out of 50

Epoch 5/9
----------
train Loss: 3.6109 LRAP: 0.0084 NDCG: 0.2132
val Loss: 3.9702 LRAP: 0.0038 NDCG: 0.2085
EarlyStopping counter: 2 out o

In [42]:
# Row 3 of the trained model's embedding table, for comparison with the untrained values.
trained_embedding_row = model_ft.item_embedding.weight.data[3]
print(trained_embedding_row)

tensor([-4.2740e-01, -2.0471e+00, -1.4139e+00, -1.0333e+00, -1.1648e+00,
        -1.2794e+00,  2.2036e+00, -1.0167e+00, -2.5458e+00,  5.7299e-01,
         7.6437e-01,  3.6878e-01, -1.6637e+00,  1.8681e+00, -2.1661e-01,
         3.9253e-01, -4.1259e-02,  7.3389e-01,  3.6420e-01, -5.1863e-01,
        -3.9298e-01, -2.4760e-01,  8.1618e-01, -7.1385e-01, -3.0701e-02,
         1.7010e+00, -1.3852e-01,  6.1011e-01,  1.7432e+00,  6.6272e-01,
         2.5717e-01, -4.0061e-01,  1.3637e+00, -1.7826e+00, -7.0668e-01,
         1.4412e+00, -5.5786e-01, -3.9710e-01, -7.9893e-01,  1.2903e+00,
        -5.6667e-02,  7.5806e-01, -2.0977e+00,  1.4146e+00,  1.6569e-01,
        -3.8940e-02, -3.0771e-01,  4.4825e-01,  4.0977e-01,  7.6066e-01,
        -1.0494e+00,  1.8453e+00, -7.0803e-01, -4.2688e-01,  9.6372e-01,
         1.5386e+00,  7.5607e-01, -1.3538e-01, -1.2501e+00, -5.6516e-01,
         9.0650e-01,  1.5950e+00, -3.3288e-01,  8.4733e-01,  7.3296e-01,
         2.6609e-01,  1.1626e-01, -3.9144e-01,  1.3