In [None]:
import os
os.chdir('../..')
from tqdm import tqdm
import pickle
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.manifold import TSNE

from utils.config import*
from src.task_2.helpers import*
from src.task_2.tokenizer import*

import warnings
warnings.filterwarnings("ignore")


# Load

In [2]:
# Data + config
config = load_config()
seed = config['seed']
data = load_file('data/glycan_embedding/df_glycan.pkl')

In [3]:
# Models
# TODO uniformize naming conventions between the two models
sweetnet_config = config['models']['sweetnet']
roberta_config = config['models']['roberta']
sweetnet = load_model(sweetnet_config['training']['save_dir'] + '/Sweetnet_Family.pt', 'SweetNet', config=sweetnet_config)
roberta = load_model(roberta_config['training']['output_dir'], 'RoBERTa', config=roberta_config)
# Roberta tokenizer
wrapper = HuggingFaceTokenizerWrapper()
wrapper.load(roberta_config['tokenizer']['path'])
tokenizer = wrapper.get_tokenizer()

# Embeddings

In [7]:
encs = tokenizer(data['glycan'].tolist(), padding='max_length', max_length=tokenizer.model_max_length, truncation=True, return_tensors='pt')

In [30]:
data_test = {'input_ids' : encs['input_ids'][0].unsqueeze(0).to(roberta.device),\
              'attention_mask' : encs['attention_mask'][0].unsqueeze(0).to(roberta.device)}

In [64]:
with torch.no_grad():
    output = roberta(**data_test)

In [74]:
# retrieve the last hidden state
embedding = output.hidden_states[-1].cpu().numpy()
attention_mask = data_test['attention_mask'].unsqueeze(-1).numpy()

In [85]:
(embedding*attention_mask).sum(axis=1).squeeze().shape

(256,)

In [89]:
sum_mask = (attention_mask.sum(axis=1)).squeeze()

In [90]:
sum_mask 

array(12)

In [None]:
# Build sentence embedding from mask word embedding


## Compute embeddings

In [None]:

def get_embeddings(data: pd.DataFrame, model, tokenizer=None, save_path=None) -> pd.DataFrame: 

    from glycowork.ml.models import SweetNet
    from glycowork.ml.inference import glycans_to_emb
    from transformers import RobertaForMaskedLM

    assert('glycan' in data.columns)

    if isinstance(model, SweetNet):
        embeddings = glycans_to_emb(data['glycan'].values, model)
        if save_path:
            dt = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
            with open(os.path.join(save_path, f'embeddings_{dt}.pkl'), 'wb') as f:
                pickle.dump(embeddings, f)
        return embeddings
    
    elif isinstance(model, RobertaForMaskedLM):

        assert(tokenizer is not None)

        errors_g, embeddings = [], []
        encodings = tokenizer(data['glycan'].tolist(),\
                              truncation=True,\
                              padding='max_length',\
                              max_length=tokenizer.model_max_length,\
                             return_tensors='pt')
        encodings = {k: v.to(DEVICE) for k, v in encodings.items()}

        for g, encoding in tqdm(zip(data['glycan'].values, encodings['input_ids'].tolist())):
            try : 
                with torch.no_grad():
                    embed = model(**encoding)
                last_hidden_state = embed[-1][-1]
                # Average token embedding to build sequence embedding
                glycan_embed = last_hidden_state.squeeze(0).mean(dim=0).numpy()
                embeddings.append(glycan_embed)
            except Exception as e:
                print(f"Error with {g} : {e}")
                errors_g.append(g)
        
        if save_path:
            dt = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
            with open(os.path.join(save_path, f'embeddings_{dt}.pkl'), 'wb') as f:
                pickle.dump(embeddings, f)
        
        return embeddings, errors_g
    else:
        raise NotImplementedError()


In [5]:
os.makedirs('models/embeddings/SweetNet', exist_ok=True)
embeddings_sweetnet = get_embeddings(data, sweetnet, save_path='models/embeddings/SweetNet')



In [6]:
embeddings_sweetnet

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,118,119,120,121,122,123,124,125,126,127
0,-0.239049,-1.205554,2.505064,-2.204096,-1.664683,-1.788511,-1.561113,-0.461428,-0.637427,-0.611985,...,-1.380926,1.291588,3.231824,-1.089608,-0.278298,-1.629229,0.745970,-1.162513,-0.956747,-0.602077
1,-0.528192,-1.131496,1.589969,-1.757410,-0.918702,-0.714965,-0.662135,-1.089974,-0.566921,-0.317467,...,-0.387416,1.024946,2.244558,-0.733578,-1.347370,-1.255779,-0.328468,2.809101,-0.764077,2.894966
2,2.133796,-0.015784,-0.860791,-0.323467,-0.193998,-0.250146,-0.160487,-0.072700,-0.512950,0.287123,...,-0.352306,1.462230,0.445271,-0.385861,-0.960686,-0.577423,-0.462516,-1.020753,-0.526767,2.492939
3,0.024131,-0.798205,1.159252,-0.872954,0.882341,-0.346515,-1.095960,0.972724,-1.133452,-1.039080,...,0.089598,-0.750147,1.849371,-1.229478,-1.216052,-0.970517,-0.432945,1.118007,-0.068015,-0.280603
4,-0.293180,-1.272951,-0.727258,-0.122571,-0.034667,1.133042,-0.651254,2.912364,-0.884565,0.425860,...,0.362923,0.245477,0.673627,-1.123325,-0.799653,-1.161224,0.962933,-0.421276,-1.176906,0.747134
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
50584,0.182961,2.211320,1.805225,1.480693,1.445918,0.361411,-0.356336,0.911873,-0.327498,1.320318,...,1.688902,1.070147,1.047303,1.543294,0.785552,1.175742,1.371868,1.143094,0.849715,0.742523
50585,0.357590,0.946273,1.268209,0.283762,3.105134,2.102657,-0.898391,1.466275,-0.227159,1.035146,...,0.801446,0.807440,1.796256,0.166234,-0.323527,0.401960,1.001705,0.834673,0.467638,0.175889
50586,0.230199,1.042811,1.529540,0.703689,2.225735,1.142767,-0.695822,1.781696,0.190462,0.516006,...,0.951339,0.521067,1.875985,0.725284,0.020722,0.748264,0.863091,0.745802,0.791014,0.215019
50587,-0.028394,1.833864,1.221431,0.805910,1.913851,0.948037,-0.490566,1.633394,-0.326149,-0.071438,...,1.729324,1.294404,1.559716,1.259737,0.920063,1.105738,0.657258,1.619102,0.777057,0.887049


In [None]:
os.makedirs('models/embeddings/RoBERTa', exist_ok=True)

## Plot embeddings

In [None]:
def plot_embeddings(embed:np.ndarray, data:pd.DataFrame, hue:str, limit:int = 5, errors=None, seed=42):

    assert(hue in data.columns)
    assert(embed.shape[0] == data.shape[0])
    if errors:
        data = data[~data['glycan'].isin(errors)].reset_index(drop=True)
    
    tsne_embeds = TSNE(n_components=2, random_state=seed).fit_transform(embed)
    df_tsne = pd.DataFrame(tsne_embeds, columns=['x', 'y'])  
    df_tsne['glycan'] = data['glycan'].tolist()

    # Select the most relevant categories to see the clusters
    df_tsne['hue'] = data[hue].tolist()
    df_tsne = df_tsne.explode('hue').drop_duplicates(subset=['glycan', 'hue']).reset_index(drop=True)
    top_hues = df_tsne['hue'].value_counts().nlargest(limit).index.tolist()
    df_tsne = df_tsne[df_tsne['hue'].isin(top_hues)].reset_index(drop=True)

    sns.set_theme(rc = {'figure.figsize':(10, 10)}, font_scale=2)
    fig = sns.scatterplot(data=df_tsne, x='x', y='y', hue=hue, palette='colorblind', s=40, rasterized=True)
    fig.set_title('TSNE of Glycan Embeddings')

    return tsne_embeds

In [None]:
plot_embeddings(embeddings_roberta, data, hue='Kingdom', limit=5, errors=errors)