# Infusion Evaluation System

Used to Evaluate the performance of Infused, vs noninfused models over several training runs to compare different metrics. 

### ||RUN ON RESTART||

In [1]:
# Load dependencies

from utils import build_multilabel_dataset, multilabel_split, prep_infused_sweetnet

import os
import pickle

from glycowork.ml.processing import split_data_to_train
from glycowork.ml import model_training


In [3]:
# Load embeddings

pickle_file_path = 'glm_embeddings_1.pkl'

# --- Load the Pickle File ---
if os.path.exists(pickle_file_path):
    print(f"Loading embeddings from: {pickle_file_path}")
    try:
        # Open the file in binary read mode ('rb')
        with open(pickle_file_path, 'rb') as file_handle:
            # Load the object(s) from the pickle file
            glm_embeddings = pickle.load(file_handle)

        print("Embeddings loaded successfully!")        

    except Exception as e:
        print(f"An error occurred while loading the pickle file: {e}")
else:
    print(f"Error: File not found at '{pickle_file_path}'. Please check the filename and path.")

Loading embeddings from: glm_embeddings_1.pkl
Embeddings loaded successfully!


## Evaluation Loop
Change parameters here for each trial run.


In [4]:
# Load part of dataset to train the model on

glycans, labels, label_names = build_multilabel_dataset(glycan_dataset='df_disease', 
                                                        glycan_class='disease_association', 
                                                        min_class_size=6)

Found 60 unique individual classes/labels.
Number of unique glycans left after filtering rare classes (size >= 6): 1458/1648
Number of unique labels left after filtering: 18


In [5]:
# Split the dataset into training, validation, and test sets
train_glycans, val_glycans, test_glycans, \
    train_labels, val_labels, test_labels = multilabel_split(glycans, labels, train_size=0.7, 
                                                             random_state=42)

# Load into dataloders for training and validation
glycan_loaders = split_data_to_train(
    glycan_list_train = train_glycans, glycan_list_val = val_glycans, labels_train = train_labels, labels_val = val_labels,
    batch_size = 128,  # 32 or 128 seem to work well on this system
    drop_last = False,
    augment_prob = 0.0,  # Adjust if you want augmentation for training
    generalization_prob = 0.2  # Adjust if you want generalization for training
)

Split complete!
Train set size: 1020
Validation set size: 219
Test set size: 219


In [None]:
# model training 

classes = len(labels[0]) # number of classes in the dataset
dataloaders = glycan_loaders


model =  prep_infused_sweetnet(
            initialization_method = 'external',
            num_classes = classes,
            embeddings_dict = glm_embeddings, 
            trainable_embeddings = True
            ) 

optimizer_ft, scheduler, criterion = model_training.training_setup(model, 0.0005, num_classes = classes)

model_ft = model_training.train_model(model, dataloaders, criterion, optimizer_ft, scheduler,
                   num_epochs = 100, mode = 'multilabel',)

SweetNet model instantiated with lib_size=2565, num_classes=18, hidden_dim=320.
Handling 'external' initialization method.
SweetNet item_embedding layer set to trainable: True.
Epoch 0/99
----------
train Loss: 3.7704 LRAP: 0.0020 NDCG: 0.1450
val Loss: 3.8337 LRAP: 0.0000 NDCG: -0.1150
Validation loss decreased (0.000000 --> 3.833693).  Saving model ...

Epoch 1/99
----------
train Loss: 3.6634 LRAP: 0.0059 NDCG: 0.2769
val Loss: 3.7367 LRAP: 0.0228 NDCG: 0.5497
Validation loss decreased (3.833693 --> 3.736732).  Saving model ...

Epoch 2/99
----------
train Loss: 3.5717 LRAP: 0.0069 NDCG: 0.3490
val Loss: 3.6126 LRAP: 0.0228 NDCG: 0.5497
Validation loss decreased (3.736732 --> 3.612585).  Saving model ...

Epoch 3/99
----------
train Loss: 3.4672 LRAP: 0.0157 NDCG: 0.4563
val Loss: 3.6224 LRAP: 0.0228 NDCG: 0.5402
EarlyStopping counter: 1 out of 50

Epoch 4/99
----------
train Loss: 3.2888 LRAP: 0.0176 NDCG: 0.5166
val Loss: 3.3311 LRAP: 0.0228 NDCG: 0.5546
Validation loss decreased 