In [None]:
import os

# Work from two levels above the notebook's start directory: the relative
# `data/...` and `models/...` paths used by the cells below are resolved from
# there (presumably the repository root that also holds `utils/` and `src/`).
# The start directory is remembered on first execution so that re-running this
# cell does not climb two more levels every time.
if '_NOTEBOOK_START_DIR' not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, '../..'))

from sklearn.manifold import TSNE

# NOTE(review): `np`, `pd` and `sns` are used further down but never imported
# explicitly; presumably they arrive through these star imports — confirm, and
# prefer explicit imports once the exported names are known.
from utils.config import *
from src.task_2.loading_helpers import *
from src.task_2.tokenizer import *
from src.task_2.inference import get_embeddings

import warnings
# Silences every warning raised after this point (blanket filter).
warnings.filterwarnings("ignore")


  from .autonotebook import tqdm as notebook_tqdm


# Load

In [2]:
# Data + config
# Location of the pickled glycan table, relative to the working directory.
glycan_table_path = 'data/glycan_embedding/df_glycan.pkl'

config = load_config()
seed = config['seed']  # seed value stored under the 'seed' key of the config
data = load_file(glycan_table_path)

In [3]:
# Models
# TODO uniformize naming conventions between the two models
sweetnet_config = config['models']['sweetnet']
roberta_config = config['models']['roberta']

# Build the SweetNet checkpoint path with os.path.join instead of string
# concatenation, so a trailing separator (or a path-like value) in `save_dir`
# is handled correctly.
sweetnet_checkpoint = os.path.join(sweetnet_config['training']['save_dir'], 'Sweetnet_Family.pt')
sweetnet = load_model(sweetnet_checkpoint, 'SweetNet', config=sweetnet_config)
roberta = load_model(roberta_config['training']['output_dir'], 'RoBERTa', config=roberta_config)

# Roberta tokenizer: loaded from the path given in the RoBERTa config section.
wrapper = HuggingFaceTokenizerWrapper()
wrapper.load(roberta_config['tokenizer']['path'])
tokenizer = wrapper.get_tokenizer()

# Embeddings

## Compute embeddings

In [164]:
# RoBERTa embeddings: make sure the output folder exists, then embed the
# glycans with the RoBERTa model and its tokenizer.
# The recorded run of this cell took about 1h25 (791 iterations).
roberta_embeddings_dir = 'models/embeddings/RoBERTa'
os.makedirs(roberta_embeddings_dir, exist_ok=True)
embeddings_roberta = get_embeddings(
    data,
    roberta,
    tokenizer=tokenizer,
    save_path=roberta_embeddings_dir,
)

100%|██████████| 791/791 [1:25:41<00:00,  6.50s/it]


In [5]:
# SweetNet embeddings: same procedure as for RoBERTa, without a tokenizer.
sweetnet_embeddings_dir = 'models/embeddings/SweetNet'
os.makedirs(sweetnet_embeddings_dir, exist_ok=True)
embeddings_sweetnet = get_embeddings(
    data,
    sweetnet,
    save_path=sweetnet_embeddings_dir,
)



## Plot embeddings

In [None]:
def plot_embeddings(embed:np.ndarray, data:pd.DataFrame, hue:str, limit:int = 5, errors=None, seed=42) -> np.ndarray:
    """Project embeddings to 2-D with t-SNE and scatter-plot them coloured by `hue`.

    Parameters
    ----------
    embed : np.ndarray
        One embedding per row. Rows must line up, in order, with the rows of
        `data` that remain after the `errors` glycans are removed.
    data : pd.DataFrame
        Must contain a 'glycan' column and the column named by `hue`. Cells of
        the `hue` column may hold a single label or a list of labels; lists are
        expanded so the glycan is drawn once per distinct label.
    hue : str
        Name of the column in `data` used to colour the points.
    limit : int
        Only the `limit` most frequent labels are drawn.
    errors : list-like, optional
        Glycans to drop from `data` before it is aligned with `embed`.
    seed : int
        Passed to TSNE as `random_state`.

    Returns
    -------
    np.ndarray
        The 2-D t-SNE coordinates for every row of `embed` (not only the rows
        that end up being plotted).

    Raises
    ------
    AssertionError
        If `hue` is not a column of `data`, or if `embed` and the filtered
        `data` do not have the same number of rows.
    """
    assert hue in data.columns, f"column {hue!r} not found in data"

    # Drop the failed glycans BEFORE comparing lengths: checking first made any
    # non-empty `errors` either trip the assertion or break the 'glycan'
    # assignment further down. `len()` avoids the ambiguous truth value of
    # array-like inputs.
    if errors is not None and len(errors) > 0:
        data = data[~data['glycan'].isin(errors)].reset_index(drop=True)
    assert embed.shape[0] == data.shape[0], (
        f"{embed.shape[0]} embeddings for {data.shape[0]} glycans after removing errors"
    )

    tsne_embeds = TSNE(n_components=2, random_state=seed).fit_transform(embed)
    df_tsne = pd.DataFrame(tsne_embeds, columns=['x', 'y'])
    df_tsne['glycan'] = data['glycan'].tolist()

    # Select the most relevant categories to see the clusters
    df_tsne['hue'] = data[hue].tolist()
    df_tsne = df_tsne.explode('hue').drop_duplicates(subset=['glycan', 'hue']).reset_index(drop=True)
    top_hues = df_tsne['hue'].value_counts().nlargest(limit).index.tolist()
    df_tsne = df_tsne[df_tsne['hue'].isin(top_hues)].reset_index(drop=True)

    sns.set_theme(rc = {'figure.figsize':(10, 10)}, font_scale=2)
    # The labels live in the literal 'hue' column of df_tsne; passing the
    # caller's column name (e.g. 'Kingdom') would reference a missing column.
    ax = sns.scatterplot(data=df_tsne, x='x', y='y', hue='hue', palette='colorblind', s=40, rasterized=True)
    ax.set_title('TSNE of Glycan Embeddings')

    # Show the caller's column name in the legend instead of the internal 'hue'.
    legend = ax.get_legend()
    if legend is not None:
        legend.set_title(hue)

    return tsne_embeds

In [None]:
# t-SNE view of the RoBERTa embeddings, coloured by the 5 most frequent 'Kingdom' labels.
# The seed read from the config is forwarded so the projection follows the
# project-wide setting instead of silently using the function's default.
# NOTE(review): `errors` is not assigned in any cell visible here — confirm
# where it is defined (presumably the glycans that failed during embedding)
# so the notebook survives Restart & Run All.
# The trailing `;` hides the repr of the returned coordinate array.
plot_embeddings(embeddings_roberta, data, hue='Kingdom', limit=5, errors=errors, seed=seed);