## t-SNE Analysis of the VirProBERT Few-Shot Classifier
### Dataset: EMBL mapping, Vertebrates, Non-IDV

**Models**: VirProBERT, Few-Shot Classifier

**Maximum Sequence Length**: segment length 256, stride 64

In [1]:
import sys
import os
# Make project packages importable from this notebook's directory by adding its
# ancestors 5, 4, 3 and 2 levels up (same order as before). Each entry is added
# only once, so re-running this cell no longer keeps growing sys.path.
# NOTE(review): presumably needed for the `src.*` / `datasets.*` imports in the
# next cell — confirm which ancestor levels are actually required.
for _depth in (5, 4, 3, 2):
    _ancestor = os.path.join(os.getcwd(), *([".."] * _depth))
    if _ancestor not in sys.path:
        sys.path.append(_ancestor)
sys.path

['/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/few_shot_learning',
 '/opt/conda/lib/python38.zip',
 '/opt/conda/lib/python3.8',
 '/opt/conda/lib/python3.8/lib-dynload',
 '',
 '/home/blessyantony/.local/lib/python3.8/site-packages',
 '/opt/conda/lib/python3.8/site-packages',
 '/opt/conda/lib/python3.8/site-packages/IPython/extensions',
 '/home/blessyantony/.ipython',
 '/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/few_shot_learning/../../../../..',
 '/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/few_shot_learning/../../../..',
 '/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/few_shot_learning/../../..',
 '/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/few_shot_learning/../..']

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.manifold import TSNE
import pandas as pd
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
from torch.utils.data import DataLoader
from datasets.collations.padding import Padding

from src.utils import constants, nn_utils, utils
from src.models.virprobert import VirProBERT
from src.models.baseline.nlp.transformer.transformer import TransformerEncoder
from datasets.protein_sequence_dataset import ProteinSequenceDataset

In [3]:
# Every artifact below lives under the directory three levels above the notebook's
# working directory.
project_root_dir = os.path.join(os.getcwd(), "..", "..", "..")

# Fine-tuned VirProBERT checkpoint (used for the embeddings below).
virprobert_model_file_path = os.path.join(
    project_root_dir,
    "output/raw/uniref90_embl_vertebrates_non_idv_t0.01_c5_seq_len_in_99.9percentile/20240826/host_multi/fine_tuning_hybrid_cls/mlm_tfenc_l6_h8_lr1e-4_uniref90viridae_msl256s64allemb_vs30cls_batchnorm_hybrid_attention_msl256s64ae_fnn_2l_d1024_lr1e-4_itr4.pth",
)
# Few-shot classifier checkpoint.
# NOTE(review): this path is not loaded by any cell in this notebook.
virprobert_few_shot_classifier_model_file_path = os.path.join(
    project_root_dir,
    "output/raw/uniref90_embl_vertebrates_non_idv/20240928/host_multi/few_shot_learning/fsl_tr_w3s5q10_te_w3s5q-1_e100b32_split70-10-20_hybrid-attention_sl256st64vs30cls_fnn_2l_d1024_lr1e-4_itr4.pth",
)

# Input data directory and the CSV file(s) read from it.
input_dir = os.path.join(project_root_dir, "input/data/uniref90/20240131")
input_file_names = [
    "uniref90_viridae_embl_hosts_pruned_metadata_species_vertebrates_w_seq_non_idv_lt_1_gte_0.05_prcnt_prevalence_seq_len_in_99prcntile.csv",
]

In [4]:
# Sequence / loader settings. max_sequence_length is reused below as the MLM
# encoder's max_seq_len and as VirProBERT's segment_len.
sequence_settings = dict(
    id_col="uniref90_id",
    sequence_col="seq",
    max_sequence_length=256,
    truncate=False,
    split_sequence=False,
    feature_type="token",
    batch_size=4,
)

label_settings = dict(label_col="virus_host_name")

max_seq_len = sequence_settings["max_sequence_length"]

# Model hyper-parameters; these must match the checkpoint being loaded.
virprobert_settings = dict(
    mlm_encoder_settings=dict(
        n_heads=8,
        depth=6,
        input_dim=512,  # input embedding dimension
        hidden_dim=1024,
        vocab_size=constants.VOCAB_SIZE,
        max_seq_len=max_seq_len,
    ),
    host_prediction_settings=dict(
        n_mlp_layers=2,
        n_classes=5,  # output size of the fine-tuned checkpoint's classification head
        input_dim=512,  # input embedding dimension
        hidden_dim=1024,
        cls_token=True,
        n_heads=8,
        stride=64,
        data_parallel=False,
        segment_len=max_seq_len,
    ),
)

In [5]:
# Build the pre-trained MLM transformer encoder and hand it to VirProBERT through
# the "pre_trained_model" entry of its host-prediction settings.
mlm_encoder_model = TransformerEncoder.get_transformer_encoder(
    virprobert_settings["mlm_encoder_settings"],
    cls_token=True,
)
host_prediction_settings = virprobert_settings["host_prediction_settings"]
host_prediction_settings["pre_trained_model"] = mlm_encoder_model
virprobert_model = VirProBERT.get_model(model_params=host_prediction_settings)

TransformerEncoder(
  (embedding): EmbeddingLayer(
    (token_embedding): Embedding(30, 512)
    (positional_embedding): PositionalEncoding()
  )
  (encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderLayer(
        (self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True)
          (W_K): Linear(in_features=512, out_features=512, bias=True)
          (W_V): Linear(in_features=512, out_features=512, bias=True)
          (W_O): Linear(in_features=512, out_features=512, bias=True)
        )
        (feed_forward): FeedForwardLayer(
          (W_1): Linear(in_features=512, out_features=1024, bias=True)
          (W_2): Linear(in_features=1024, out_features=512, bias=True)
        )
        (residual_connections): ModuleList(
          (0): ResidualConnectionLayer(
            (layer_norm): LayerNormalization()
          )
          (1): ResidualConnectionLayer(
            (layer_norm): LayerNormalization()
          )
        )
    

In [6]:
# Load the fine-tuned weights (tensors mapped to nn_utils.get_device()) and switch
# the model to eval mode. The trailing expression displays the model.
# NOTE(review): torch.load unpickles the checkpoint — only load trusted files.
virprobert_state_dict = torch.load(virprobert_model_file_path, map_location=nn_utils.get_device())
virprobert_model.load_state_dict(virprobert_state_dict)
del virprobert_state_dict  # release the extra reference to the loaded tensors
virprobert_model.eval()

VirProBERT(
  (linear_ip): Linear(in_features=512, out_features=1024, bias=True)
  (linear_hidden): Linear(in_features=1024, out_features=1024, bias=True)
  (linear_hidden_n): ModuleList(
    (0): Linear(in_features=1024, out_features=1024, bias=True)
    (1): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (linear_op): Linear(in_features=1024, out_features=5, bias=True)
  (batch_norm_ip): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (batch_norm_hidden): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (batch_norm_hidden_n): ModuleList(
    (0): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pre_trained_model): TransformerEncoder(
    (embedding): EmbeddingLayer(
      (token_embedding): Embedding(30, 512)
      (positional_embedding): PositionalEncoding()
    )
    (en

In [7]:
def filter_dataset(df, filter_threshold, per_class_samples=0, label_col=None):
    """Keep labels with enough rows and optionally cap the number of rows per label.

    Args:
        df: input DataFrame containing the label column.
        filter_threshold: minimum number of rows a label must have in ``df`` to be kept.
        per_class_samples: maximum number of rows kept per label (the first rows of each
            label in ``df`` order). ``0`` (the default) or any non-positive value means
            "no cap", i.e. keep every row of the retained labels.
        label_col: name of the label column; defaults to ``label_settings["label_col"]``.

    Returns:
        A new DataFrame (a copy, so callers may modify it freely). When a cap is applied,
        rows are grouped by label in descending label-frequency order.
    """
    if label_col is None:
        label_col = label_settings["label_col"]

    print(f"Before filter df size = {df.shape}")
    label_count_map = dict(df[label_col].value_counts())
    print(f"Before filter # of unique values = {len(label_count_map)}")
    label_count_map = {label: count for label, count in label_count_map.items() if count >= filter_threshold}
    print(f"After filter on min of samples # of unique values = {len(label_count_map)}")

    df = df[df[label_col].isin(list(label_count_map))]
    print(f"After filter df size = {df.shape}")

    # Skip the cap when it is disabled (previously 0 emptied the frame) or when no
    # label survived (pd.concat([]) would raise).
    if per_class_samples and per_class_samples > 0 and label_count_map:
        grouped = df.groupby(label_col, sort=False)
        df = pd.concat([grouped.get_group(label).head(per_class_samples) for label in label_count_map])
    print(f"After filter on per_class_samples df size = {df.shape}")

    return df.copy()


# read datasets using config properties
def read_dataset(input_dir, input_file_names, cols):
    datasets = []
    for input_file_name in input_file_names:
        input_file_path = os.path.join(input_dir, input_file_name)
        df = pd.read_csv(input_file_path, usecols=cols)
        print(f"input file: {input_file_path}, size = {df.shape}")
        datasets.append(df)

    df = pd.concat(datasets)
    print(f"Size of input dataset = {df.shape}")
    return df


def get_token_dataset_loader(df, sequence_settings, label_col, shuffle=True):
    """Wrap ``df`` in a ``ProteinSequenceDataset`` and return a padded ``DataLoader``.

    Args:
        df: DataFrame holding the sequence column and ``label_col``.
        sequence_settings: dict providing ``sequence_col``, ``batch_size``,
            ``max_sequence_length`` and ``truncate``.
        label_col: name of the label column in ``df``.
        shuffle: whether the DataLoader shuffles samples (default True, as before).
            Pass False for a reproducible batch order, e.g. when only extracting
            embeddings.

    Returns:
        A ``torch.utils.data.DataLoader`` that uses ``Padding`` as its collate function.
    """
    max_seq_len = sequence_settings["max_sequence_length"]
    dataset = ProteinSequenceDataset(
        df,
        sequence_settings["sequence_col"],
        label_col,
        sequence_settings["truncate"],
        max_seq_len,
    )
    # NOTE(review): sequence_settings["split_sequence"] is not used here (it was read
    # into an unused local before). Also, the printed batch is 2372 tokens wide even
    # though max_sequence_length is 256, which suggests Padding does not truncate —
    # confirm the intended semantics.
    collate_func = Padding(max_seq_len)

    return DataLoader(
        dataset=dataset,
        batch_size=sequence_settings["batch_size"],
        shuffle=shuffle,
        collate_fn=collate_func,
    )

def load_dataset(input_dir, input_file_names, sequence_settings, filter_threshold, per_class_samples):
    """Read, filter and label-encode the dataset, then wrap it in a DataLoader.

    Args:
        input_dir: directory containing ``input_file_names``.
        input_file_names: CSV file names to read and concatenate.
        sequence_settings: sequence/loader settings; ``sequence_col`` selects the
            sequence column and the rest is passed to ``get_token_dataset_loader``.
        filter_threshold: minimum rows per label (passed to ``filter_dataset``).
        per_class_samples: per-label row cap (passed to ``filter_dataset``).

    Returns:
        ``(df, dataset_loader, idx_label_map)``, where:
        - ``df`` is the filtered frame with its label column encoded via the
          vocabulary from ``utils.get_label_vocabulary``;
        - ``dataset_loader`` is the DataLoader over ``df``;
        - ``idx_label_map`` is the second map returned by that helper (presumably
          index -> label name).
    """
    label_col = label_settings["label_col"]
    raw_df = read_dataset(input_dir, input_file_names, cols=[sequence_settings["sequence_col"], label_col])
    filtered_df = filter_dataset(raw_df, filter_threshold, per_class_samples)
    # Assumes get_label_vocabulary returns dict-like maps covering every label it is given.
    label_idx_map, idx_label_map = utils.get_label_vocabulary(list(filtered_df[label_col].unique()))
    # Encode into a new frame with map() instead of replace(..., inplace=True). This
    # never mutates a frame that may be a slice of another, and it produces the
    # integer column directly. replace() relies on an implicit object->int downcast,
    # which is deprecated since pandas 2.2.
    df = filtered_df.assign(**{label_col: filtered_df[label_col].map(label_idx_map)})
    dataset_loader = get_token_dataset_loader(df, sequence_settings, label_col)
    print(df.head())
    return df, dataset_loader, idx_label_map

def compute_dataset_representations(model, dataset_loader):
    """Embed every sequence produced by ``dataset_loader`` with ``model``.

    Args:
        model: object exposing ``eval()`` and ``get_embedding(seq)`` (e.g. VirProBERT).
            Assumes ``get_embedding`` returns a tensor whose first dimension is the
            batch size — TODO confirm for VirProBERT.
        dataset_loader: iterable of ``(seq, label)`` tensor batches.

    Returns:
        DataFrame with one row per sequence. The flattened embedding is in integer-named
        columns, plus a ``label`` column. The index is a fresh 0..N-1 range.

    Raises:
        ValueError: if ``dataset_loader`` yields no batches.
    """
    model.eval()
    seq_dfs = []
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for seq, label in dataset_loader:
            seq_encoding = model.get_embedding(seq).detach().cpu()
            # reshape(batch, -1) instead of squeeze(): squeeze() also drops the batch
            # dimension when a batch holds a single sequence, which corrupts the frame.
            seq_encoding = seq_encoding.reshape(seq_encoding.shape[0], -1)
            seq_df = pd.DataFrame(seq_encoding.numpy())
            seq_df["label"] = label.detach().cpu().reshape(-1).numpy()
            seq_dfs.append(seq_df)
    if not seq_dfs:
        raise ValueError("dataset_loader yielded no batches")
    # ignore_index: without it every batch restarts at index 0, giving duplicate labels.
    df = pd.concat(seq_dfs, ignore_index=True)
    print(df.shape)
    return df

def print_dataset_loader(dataset_loader):
    """Print the shape and contents of the sequences, then the labels, of the first batch."""
    first_batch = next(iter(dataset_loader))
    sequence_batch, label_batch = first_batch
    for tensor in (sequence_batch, label_batch):
        print(tensor.shape)
        print(tensor)

In [8]:
# Keep hosts with at least 15 sequences, then take the first 5 sequences of each host.
min_samples_per_host = 15
samples_per_host = 5
df, dataset_loader, idx_label_map = load_dataset(
    input_dir,
    input_file_names,
    sequence_settings,
    filter_threshold=min_samples_per_host,
    per_class_samples=samples_per_host,
)
print_dataset_loader(dataset_loader)
# Number of distinct host labels retained.
len(idx_label_map)

input file: /home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/few_shot_learning/../../../input/data/uniref90/20240131/uniref90_viridae_embl_hosts_pruned_metadata_species_vertebrates_w_seq_non_idv_lt_1_gte_0.05_prcnt_prevalence_seq_len_in_99prcntile.csv, size = (11116, 2)
Size of input dataset = (11116, 2)
Before filter df size = (11116, 2)
Before filter # of unique values = 143
After filter on min of samples # of unique values = 143
After filter df size = (11116, 2)
After filter on per_class_samples df size = (715, 2)
      virus_host_name                                                seq
2026               33  MKMFVLVGFVLFVVASATTTVNINVTTNGNHNVTSSNSNVLLQNRT...
2027               33  MKTQLYILILYFLGVSSSQETTALLDPDRFCLQTDFSRILVFPKFR...
2028               33  MKTLVSICFFITLFILTNSDPSCYDGLVENSRKNLDRPNSLAAYDL...
2029               33  MGIHALNYIASNFETDDLVPTLFGACGVFAFLIIIGTVLFVCSGRM...
2030               33  MNSTLLVISNPENQFTIDFILSGYINNTHYSIIVKDIKEESDGRFD...
torch.Size([4, 2372])
tensor([

143

In [None]:
# Embed every sequence in the loader with the fine-tuned VirProBERT model.
# NOTE(review): the few-shot classifier checkpoint path defined earlier is never
# loaded in this notebook, so these embeddings come from VirProBERT, not the
# few-shot model named in the title.
virprobert_rep_df = compute_dataset_representations(model=virprobert_model, dataset_loader=dataset_loader)

In [None]:
# Row count per (integer-encoded) host label in the loaded dataset.
host_label_col = label_settings["label_col"]
df[host_label_col].value_counts()