# GLM-Infused SweetNet Experiments

Experimenting with a modified version of SweetNet that allows it to take pre-trained embeddings as input. To get there I need a way to take the embeddings I’ve gotten from roman and transform them into nice inputs for the model, and a way to set the initial features using these inputs. 


## Importing and exploring the GLM Embedding data

In [None]:
# quick thing to load a pickle file

import pickle
import os # To check if file exists

pickle_file_path = 'glm_embeddings_1.pkl'

# --- Load the Pickle File ---
if os.path.exists(pickle_file_path):
    print(f"Loading embeddings from: {pickle_file_path}")
    try:
        # Open the file in binary read mode ('rb')
        with open(pickle_file_path, 'rb') as file_handle:
            # Load the object(s) from the pickle file
            loaded_embeddings = pickle.load(file_handle)

        print("Embeddings loaded successfully!")        

    except Exception as e:
        print(f"An error occurred while loading the pickle file: {e}")
else:
    print(f"Error: File not found at '{pickle_file_path}'. Please check the filename and path.")

Loading embeddings from: glm_embeddings_1.pkl
Embeddings loaded successfully!
Type of loaded object: <class 'dict'>
Number of items (if dictionary): 2565
Example keys (first 5): ['!GlcNAc', '-10', '-12', '-2', '-4']


In [6]:
# lets do some quick exploration

# --- Explore the loaded data ---
print(f"Type of loaded object: {type(loaded_embeddings)}")

# Common formats for embeddings: dictionary or numpy array
if isinstance(loaded_embeddings, dict):
    print(f"Number of items (if dictionary): {len(loaded_embeddings)}")
    # print some keys to see what they look like
    print(f"Example keys (first 5): {list(loaded_embeddings.keys())[:5]}")
elif hasattr(loaded_embeddings, 'shape'):
    print(f"Shape (if array/tensor): {loaded_embeddings.shape}")
    if hasattr(loaded_embeddings, 'dtype'):
         print(f"Data type (if array/tensor): {loaded_embeddings.dtype}")

Type of loaded object: <class 'dict'>
Number of items (if dictionary): 2565
Example keys (first 5): ['!GlcNAc', '-10', '-12', '-2', '-4']


In [7]:
print(list(loaded_embeddings.keys())[5:30]) # Print more keys

['-6', '-8', '0dHex', '1,4-Anhydro-Gal-ol', '1,5-Anhydro-D-AltNAc-ol', '1,5-Anhydro-D-FucN-ol', '1,5-Anhydro-D-Rha4NAc-ol', '1,5-Anhydro-Gal-ol', '1,5-Anhydro-GalNAc-ol', '1,5-Anhydro-Glc-ol', '1,5-Anhydro-Glc-onic', '1,5-Anhydro-GlcN2S-ol', '1,5-Anhydro-GlcN2S6S-ol', '1,5-Anhydro-GlcNAc-ol', '1,5-Anhydro-GlcNAc-onic', '1,5-Anhydro-Man-ol', '1,5-Anhydro-ManNAc-ol', '1,5-Anhydro-Xyl-ol', '1,5-Anhydro-Xyl2F-ol', '1-1', '1-2', '1-3', '1-4', '1-5', '1-6']


In [9]:
example_key = '!GlcNAc' 
if example_key in loaded_embeddings:
    embedding_vector = loaded_embeddings[example_key]
    print(f"Type of value for '{example_key}': {type(embedding_vector)}")
    if hasattr(embedding_vector, 'shape'):
        print(f"Shape of value: {embedding_vector.shape}") # This gives dimensionality!
        print(f"Dtype of value: {embedding_vector.dtype}")
    print(embedding_vector) # Print the vector itself if it's not too long
else:
    print(f"Key '{example_key}' not found.")

Type of value for '!GlcNAc': <class 'numpy.ndarray'>
Shape of value: (320,)
Dtype of value: float32
[ 9.33886290e-01 -7.57189512e-01 -5.22765040e-01  4.93726492e-01
  3.03156078e-01 -1.72754931e+00  2.03015614e+00 -1.13539708e+00
 -8.32044244e-01 -6.09763384e-01 -5.63947335e-02 -2.68140852e-01
 -6.37493312e-01  1.45667583e-01 -7.75620103e-01 -1.39048725e-01
  1.06042847e-01 -3.74972522e-01  7.91566074e-01 -1.03034627e+00
 -1.12639211e-01 -3.78986076e-03  5.92547238e-01  2.81559825e-01
 -5.21002829e-01  9.35327411e-01  2.56601274e-01 -3.91364455e-01
  2.72188634e-02  5.00928342e-01 -5.55309415e-01  1.28289807e+00
 -6.45282388e-01  5.19899249e-01  6.10100806e-01  1.84122849e+00
  3.11432898e-01 -7.64928609e-02 -1.05589128e+00  6.50653005e-01
  9.70111132e-01  7.40227938e-01  8.39829683e-01 -3.04328918e-01
 -1.06630003e+00  4.53770608e-01  4.27673876e-01 -6.02427721e-01
  4.39536482e-01 -1.16493046e+00 -2.04154789e-01  1.13036299e+00
  2.51586974e-01  1.04393315e+00  2.60879964e-01  4.638

In [10]:
# let's look at the keys a bit more closely

import collections

key_types = collections.defaultdict(int)
for key in loaded_embeddings.keys():
    if '-' in key and not any(char.isalpha() for char in key):
        key_types['linkage_or_modification'] += 1
    elif key[0].isalpha():
        key_types['monosaccharide'] += 1
    else:
        key_types['other'] += 1

print(key_types)

defaultdict(<class 'int'>, {'other': 122, 'linkage_or_modification': 36, 'monosaccharide': 2407})


In [None]:
# Let's explore those Other keys 

other_keys = []
for key in loaded_embeddings.keys():
    if '-' in key and not any(char.isalpha() for char in key):
        pass # linkage_or_modification
    elif key[0].isalpha():
        pass # monosaccharide
    else:
        other_keys.append(key)

print(f"Number of 'other' keys: {len(other_keys)}")
print(f"Examples of 'other' keys: {other_keys[:20]}") # Print the first 20

Number of 'other' keys: 122
Examples of 'other' keys: ['!GlcNAc', '0dHex', '1,4-Anhydro-Gal-ol', '1,5-Anhydro-D-AltNAc-ol', '1,5-Anhydro-D-FucN-ol', '1,5-Anhydro-D-Rha4NAc-ol', '1,5-Anhydro-Gal-ol', '1,5-Anhydro-GalNAc-ol', '1,5-Anhydro-Glc-ol', '1,5-Anhydro-Glc-onic', '1,5-Anhydro-GlcN2S-ol', '1,5-Anhydro-GlcN2S6S-ol', '1,5-Anhydro-GlcNAc-ol', '1,5-Anhydro-GlcNAc-onic', '1,5-Anhydro-Man-ol', '1,5-Anhydro-ManNAc-ol', '1,5-Anhydro-Xyl-ol', '1,5-Anhydro-Xyl2F-ol', '1b-4', '1dAlt-ol']


In [13]:
# Let's look at 50 more keys

print(f"More Examples of 'other' keys: {other_keys[20:70]}")

More Examples of 'other' keys: ['1dEry-ol', '2,3-Anhydro-All', '2,3-Anhydro-Man', '2,3-Anhydro-Rib', '2,5-Anhydro-D-Alt-ol', '2,5-Anhydro-D-Alt3S-ol', '2,5-Anhydro-D-Tal', '2,5-Anhydro-Glc', '2,5-Anhydro-L-Man-ol', '2,5-Anhydro-Man', '2,5-Anhydro-Man-ol', '2,5-Anhydro-Man1S-ol', '2,5-Anhydro-Man3S-ol', '2,5-Anhydro-Man6S', '2,5-Anhydro-Tal-ol', '2,5-Anhydro-Tal6P', '2,6-Anhydro-Glc5NAc-ol', '2,6-Anhydro-L-Gul-ol', '2,6-Anhydro-L-Gul-onic', '2,6-Anhydro-Man-ol', '2,6-Anhydro-Tal5NAc-ol', '2,7-Anhydro-Kdo', '2,7-Anhydro-Kdof', '2dAraHexA', '3,6-Anhydro-Fruf', '3,6-Anhydro-Gal', '3,6-Anhydro-Gal2S', '3,6-Anhydro-Glc', '3,6-Anhydro-L-Gal', '3,6-Anhydro-L-Gal2Me', '3-Anhydro-Gal', '3-Anhydro-Gal2S', '3dFuc', '3dGal', '3dLyxHep-ulosaric', '4,7-Anhydro-Kdo', '4,7-Anhydro-KdoOPEtN', '4,8-Anhydro-Kdo', '4d8dNeu5Ac', '4dAraHex', '4dEry-ol', '4dFuc', '4dGal', '4dNeu5Ac', '4dThrHexNAcA4en', '4eLeg5Ac7Ac', '5dAraf', '5dAraf3Me', '5dLyxf3CFo', '5dLyxf3CMe']


In [20]:
# Let's explore those monosaccharide keys
monosaccharide = []
for key in loaded_embeddings.keys():
    if '-' in key and not any(char.isalpha() for char in key):
        pass # linkage_or_modification
    elif key[0].isalpha():
        monosaccharide.append(key)
    else:
        pass # other

print(f"Number of 'monosaccharide' keys: {len(monosaccharide)}")
print(f"Examples of 'monosaccharide' keys: {monosaccharide[:50]}") # Print the first 50

Number of 'monosaccharide' keys: 2407
Examples of 'monosaccharide' keys: ['Abe', 'Abe1PP', 'Abe2Ac', 'AbeOAc', 'Acarbose', 'AcefA', 'Aci5Ac7Ac', 'AcoNAc', 'All', 'All-ol', 'All1S2S3S4S', 'All2Ac3Ac', 'All2S3S4S', 'All3Ac', 'All6Ac', 'AllN', 'AllN1P', 'AllNAc', 'AllNAc6Me', 'AllOMe', 'Alt', 'AltA', 'AltA2N', 'AltA2S', 'AltAN', 'AltNAc', 'AltNAcA', 'AltNAcA1Prop', 'Altf', 'AltfOAc', 'Amikacin', 'Api', 'ApiOAc', 'ApiOMe-ol', 'Apif', 'Ara', 'Ara-ol', 'Ara1Cer2Ac', 'Ara1Me', 'Ara1N4P', 'Ara1P4N', 'Ara1PP', 'Ara1PP2NAc', 'Ara1PP4N', 'Ara1PP4NFo', 'Ara2Ac', 'Ara2Ac3Ac', 'Ara2Ac3Ac4Ac', 'Ara2Ac4Ac', 'Ara2Ac5P-ol']


In [21]:
# To be throughough, let's look at 50 Linkage or Modification keys as well
linkage_or_modification = []
for key in loaded_embeddings.keys():
    if '-' in key and not any(char.isalpha() for char in key):
        linkage_or_modification.append(key)
    elif key[0].isalpha():
        pass # monosaccharide
    else:
        pass # other

print(f"Number of 'linkage_or_modification' keys: {len(linkage_or_modification)}")
print(f"Examples of 'linkage_or_modification' keys: {linkage_or_modification[:50]}") # Print the first 50

Number of 'linkage_or_modification' keys: 36
Examples of 'linkage_or_modification' keys: ['-10', '-12', '-2', '-4', '-6', '-8', '1-1', '1-2', '1-3', '1-4', '1-5', '1-6', '1-?', '2-3', '2-4', '2-5', '2-6', '3-1', '3-5', '4-1', '4-5', '5-1', '5-2', '5-3', '5-4', '5-5', '5-6', '6-1', '6-3', '6-4', '?1-2', '?1-3', '?1-4', '?1-6', '?1-?', '?2-?']


### Load the glycowork library

I'll load the glycowork library and compare it to the keys in the embedding file

In [22]:
from glycowork.glycan_data import loader

glycowork_vocabulary = loader.lib

print(f"Number of items in glycowork vocabulary: {len(glycowork_vocabulary)}")
print(f"Example keys from glycowork vocabulary (first 20): {list(glycowork_vocabulary.keys())[:20]}")

Number of items in glycowork vocabulary: 2565
Example keys from glycowork vocabulary (first 20): ['!GlcNAc', '-10', '-12', '-2', '-4', '-6', '-8', '0dHex', '1,4-Anhydro-Gal-ol', '1,5-Anhydro-D-AltNAc-ol', '1,5-Anhydro-D-FucN-ol', '1,5-Anhydro-D-Rha4NAc-ol', '1,5-Anhydro-Gal-ol', '1,5-Anhydro-GalNAc-ol', '1,5-Anhydro-Glc-ol', '1,5-Anhydro-Glc-onic', '1,5-Anhydro-GlcN2S-ol', '1,5-Anhydro-GlcN2S6S-ol', '1,5-Anhydro-GlcNAc-ol', '1,5-Anhydro-GlcNAc-onic']


Nice, they seem to correspond one to one!

That saves me a lot of work down the line (Thanks Roman)

In [26]:
# let's look at one of the keys in the glycowork vocabulary to see what they return
example_glycowork_key = '-10'
if example_glycowork_key in glycowork_vocabulary:
    glycowork_value = glycowork_vocabulary[example_glycowork_key]
    print(f"Type of value for '{example_glycowork_key}': {type(glycowork_value)}")
    print(glycowork_value)   

Type of value for '-10': <class 'int'>
1


### Filter and Transform embeddings

## SweetNet copy from models.py for experimentation

In [1]:
# SweetNet class

from typing import Dict, Optional, Tuple, Union, Literal

import numpy as np
try:
    import torch
    import torch.nn.functional as F
    from torch_geometric.nn import GraphConv
    from torch_geometric.nn import global_mean_pool as gap
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
except ImportError:
  raise ImportError("<torch or torch_geometric missing; did you do 'pip install glycowork[ml]'?>")
from glycowork.glycan_data.loader import lib, download_model 

class SweetNet(torch.nn.Module):
    def __init__(self, lib_size: int, # number of unique tokens for graph nodes
                 num_classes: int = 1, # number of output classes (>1 for multilabel)
                 hidden_dim: int = 128 # dimension of hidden layers
                ) -> None:
        "given glycan graphs as input, predicts properties via a graph neural network"
        print("Using SweetNet from notebook cell!") # Check to see if I am running this in the notebook
        super(SweetNet, self).__init__()
        # Convolution operations on the graph
        self.conv1 = GraphConv(hidden_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.conv3 = GraphConv(hidden_dim, hidden_dim)

        # Node embedding
        self.item_embedding = torch.nn.Embedding(num_embeddings=lib_size+1, embedding_dim=hidden_dim)
        # Fully connected part
        self.lin1 = torch.nn.Linear(hidden_dim, 1024)
        self.lin2 = torch.nn.Linear(1024, 128)
        self.lin3 = torch.nn.Linear(128, num_classes)
        self.bn1 = torch.nn.BatchNorm1d(1024)
        self.bn2 = torch.nn.BatchNorm1d(128)
        self.act1 = torch.nn.LeakyReLU()
        self.act2 = torch.nn.LeakyReLU()

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, batch: torch.Tensor,
                inference: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # Getting node features
        x = self.item_embedding(x)
        x = x.squeeze(1)

        # Graph convolution operations
        x = F.leaky_relu(self.conv1(x, edge_index))
        x = F.leaky_relu(self.conv2(x, edge_index))
        x = F.leaky_relu(self.conv3(x, edge_index))
        x = gap(x, batch)

        # Fully connected part
        x = self.act1(self.bn1(self.lin1(x)))
        x_out = self.bn2(self.lin2(x))
        x = F.dropout(self.act2(x_out), p = 0.5, training = self.training)

        x = self.lin3(x).squeeze(1)

        if inference:
          return x, x_out
        else:
          return x



In [3]:
# Init_weights function

def init_weights(model: torch.nn.Module, # neural network for analyzing glycans
                mode: str = 'sparse', # initialization algorithm: 'sparse', 'kaiming', 'xavier'
                sparsity: float = 0.1 # proportion of sparsity after initialization
               ) -> None:
    "initializes linear layers of PyTorch model with a weight initialization"
    print("Using init_weights from notebook cell!") # Check to see if I am running this in the notebook
    if isinstance(model, torch.nn.Linear):
        if mode == 'sparse':
            torch.nn.init.sparse_(model.weight, sparsity = sparsity)
        elif mode == 'kaiming':
            torch.nn.init.kaiming_uniform_(model.weight)
        elif mode == 'xavier':
            torch.nn.init.xavier_uniform_(model.weight)
        else:
            print("This initialization option is not supported.")

In [4]:
# prep_model function

def prep_model(model_type: Literal["SweetNet", "LectinOracle", "LectinOracle_flex", "NSequonPred"], # type of model to create
              num_classes: int, # number of unique classes for classification
              libr: Optional[Dict[str, int]] = None, # dictionary of form glycoletter:index
              trained: bool = False, # whether to use pretrained model
              hidden_dim: int = 128 # hidden dimension for the model (SweetNet/LectinOracle only)
             ) -> torch.nn.Module: # initialized PyTorch model
    "wrapper to instantiate model, initialize it, and put it on the GPU"
    print("Using prep_model from notebook cell!") # Check to see if I am running this in the notebook
    if libr is None:
      libr = lib
    if model_type == 'SweetNet':
      model = SweetNet(len(libr), num_classes = num_classes, hidden_dim = hidden_dim)
      model = model.apply(lambda module: init_weights(module, mode = 'sparse'))
      if trained:
        if hidden_dim != 128:
          raise ValueError("Hidden dimension must be 128 for pretrained model")
        model_path = download_model("glycowork_sweetnet_species.pt")
        model.load_state_dict(torch.load(model_path, map_location = device, weights_only = True))
      model = model.to(device)
    elif model_type == 'LectinOracle':
      model = LectinOracle(len(libr), num_classes = num_classes, input_size_prot = int(10*hidden_dim))
      model = model.apply(lambda module: init_weights(module, mode = 'xavier'))
      if trained:
        model_path = download_model("glycowork_lectinoracle.pt")
        model.load_state_dict(torch.load(model_path, map_location = device, weights_only = True))
      model = model.to(device)
    elif model_type == 'LectinOracle_flex':
      model = LectinOracle_flex(len(libr), num_classes = num_classes)
      model = model.apply(lambda module: init_weights(module, mode = 'xavier'))
      if trained:
        model_path = download_model("glycowork_lectinoracle_flex.pt")
        model.load_state_dict(torch.load(model_path, map_location = device, weights_only = True))
      model = model.to(device)
    elif model_type == 'NSequonPred':
      model = NSequonPred()
      model = model.apply(lambda module: init_weights(module, mode = 'xavier'))
      if trained:
        model_path = download_model("NSequonPred_batch32.pt")
        model.load_state_dict(torch.load(model_path, map_location = device, weights_only = True))
      model = model.to(device)
    else:
      print("Invalid Model Type")
    return model
    

## Testing using same framework as iteration 0 (basic kingdom sweetnet)

In [5]:

# testing the modified SweetNet model on the GlycoWork dataset 
from glycowork.glycan_data.loader import df_species
from glycowork.ml.train_test_split import hierarchy_filter
from glycowork.ml.processing import split_data_to_train
from glycowork.ml import model_training

# silence the avalanche of "undefined" warnings
import warnings
from sklearn.exceptions import UndefinedMetricWarning
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

train_x, val_x, train_y, val_y, id_val, class_list, class_converter = hierarchy_filter(df_species,
                                                                                       rank = 'Kingdom')

dataloaders = split_data_to_train(train_x, val_x, train_y, val_y)




In [None]:
# Lets split out the training function so I don't have to load the data each time

model = prep_model('SweetNet', len(class_list))
optimizer_ft, scheduler, criterion = model_training.training_setup(model, 0.0005, num_classes = len(class_list))
model_ft = model_training.train_model(model, dataloaders, criterion, optimizer_ft, scheduler,
                   num_epochs = 10, return_metrics = True,)

Using prep_model from notebook cell!
Using SweetNet from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Using init_weights from notebook cell!
Epoch 0/9
----------
train Loss: 1.9425 Accuracy: 0.6660 MCC: 0.4531
val Loss: 1.4610 Accuracy: 0.7855 MCC: 0