## t-SNE Analysis of the VirProBERT Few-Shot Classifier
### Dataset: EMBL mapping, Vertebrates, Non-IDV

**Models**: VirProBERT, Few-Shot Classifier

**Maximum Sequence Length**: 2048

In [1]:
import sys
import os

# Make the project's `src` package (imported in the next cell) importable by adding
# the notebook directory's ancestors 5, 4, 3 and 2 levels up, in that order.
# Paths already on sys.path are skipped so re-running this cell does not keep
# appending duplicates.
for _levels_up in (5, 4, 3, 2):
    _candidate = os.path.join(os.getcwd(), *[".."] * _levels_up)
    if _candidate not in sys.path:
        sys.path.append(_candidate)
sys.path

['/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/few_shot_learning',
 '/opt/conda/lib/python38.zip',
 '/opt/conda/lib/python3.8',
 '/opt/conda/lib/python3.8/lib-dynload',
 '',
 '/home/blessyantony/.local/lib/python3.8/site-packages',
 '/opt/conda/lib/python3.8/site-packages',
 '/opt/conda/lib/python3.8/site-packages/IPython/extensions',
 '/home/blessyantony/.ipython',
 '/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/few_shot_learning/../../../../..',
 '/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/few_shot_learning/../../../..',
 '/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/few_shot_learning/../../..',
 '/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/few_shot_learning/../..']

In [2]:
from src.utils import utils, nn_utils, dataset_utils

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.manifold import TSNE
import pandas as pd
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path

from src.utils import utils, nn_utils
from src.models.nlp.transformer import transformer
from src.transfer_learning.fine_tuning import host_prediction

In [3]:
# All data/model paths hang off the repository root, three directories above this notebook.
_repo_root = os.path.join(os.getcwd(), "..", "..", "..")

# Fine-tuned VirProBERT host-prediction checkpoint (its weights are loaded further below).
virprobert_model_file_path = os.path.join(_repo_root, "output/raw/uniref90_embl_vertebrates_non_idv_t0.01_c5/20240611/host_multi/fine_tuning/host_prediction_fnn_2l_d1024_lr1e-4_itr0.pth")
# Few-shot classifier checkpoint.
# NOTE(review): not loaded in the cells shown here — confirm whether the analysis should use it.
virprobert_few_shot_classifier_model_file_path = os.path.join(_repo_root, "output/raw/uniref90_embl_vertebrates_non_idv/20240625/host_multi/few_shot_learning/fsl_tr_w3s5q10_te_w3s5q-1_e25b32_msl2048_split70-10-20_virprobert_itr4.pth")

# Input sequences with their virus-host labels.
input_dir = os.path.join(_repo_root, "input/data/uniref90/20240131")
input_file_names = [
    "uniref90_viridae_embl_hosts_pruned_metadata_species_vertebrates_w_seq_non_idv_lt_1percent_prevalence.csv",
]

In [4]:
# Architecture of the pre-trained VirProBERT model. The checkpoint loaded later is read with
# load_state_dict, so these sizes presumably have to match the saved weights.
virprobert_settings = {
    "mlm_encoder_settings": dict(
        embedding="linear",
        n_heads=8,
        depth=6,
        input_dim=512,  # input embedding dimension
        hidden_dim=1024,
        vocab_size=28,
    ),
    "host_prediction_settings": dict(
        depth=2,
        n_classes=5,
        input_dim=512,  # input embedding dimension
        hidden_dim=1024,
    ),
}

# How raw sequences are tokenized, padded/truncated and batched.
sequence_settings = dict(
    id_col="uniref90_id",
    sequence_col="seq",
    max_sequence_length=2048,  # alternatives noted earlier: 6630, 1024, 1115
    truncate=True,
    split_sequence=False,
    pad_token_val=0,
    feature_type="token",
    batch_size=8,
)

label_settings = dict(label_col="virus_host_name")

# The encoder and the host-prediction head both receive the same maximum sequence length.
max_seq_len = sequence_settings["max_sequence_length"]
for _section in ("mlm_encoder_settings", "host_prediction_settings"):
    virprobert_settings[_section]["max_seq_len"] = max_seq_len

In [5]:
# Build the transformer encoder first, then hand it to the host-prediction model builder
# as its "pre_trained_model" setting.
mlm_encoder_model = transformer.get_transformer_encoder(
    virprobert_settings["mlm_encoder_settings"]
)
virprobert_settings["host_prediction_settings"].update(pre_trained_model=mlm_encoder_model)
pre_trained_model = host_prediction.get_host_prediction_model(
    virprobert_settings["host_prediction_settings"]
)

TransformerEncoder(
  (embedding): EmbeddingLayer(
    (token_embedding): Embedding(28, 512, padding_idx=0)
    (positional_embedding): PositionalEncoding()
  )
  (encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderLayer(
        (self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True)
          (W_K): Linear(in_features=512, out_features=512, bias=True)
          (W_V): Linear(in_features=512, out_features=512, bias=True)
          (W_O): Linear(in_features=512, out_features=512, bias=True)
        )
        (feed_forward): FeedForwardLayer(
          (W_1): Linear(in_features=512, out_features=1024, bias=True)
          (W_2): Linear(in_features=1024, out_features=512, bias=True)
        )
        (residual_connections): ModuleList(
          (0): ResidualConnectionLayer(
            (norm): NormalizationLayer()
          )
          (1): ResidualConnectionLayer(
            (norm): NormalizationLayer()
          )
        )
 

In [6]:
# Restore the fine-tuned host-prediction weights, mapped onto the device chosen by nn_utils.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
pre_trained_model.load_state_dict(
    torch.load(
        virprobert_model_file_path,
        map_location=nn_utils.get_device(),
    )
)

<All keys matched successfully>

In [7]:
def filter_dataset(df, filter_threshold, per_class_samples=0, label_col=None):
    """Keep labels with enough samples and optionally cap the number of rows per label.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataset containing the label column.
    filter_threshold : int
        Minimum number of rows a label must have in ``df`` to be kept.
    per_class_samples : int or None, optional
        Maximum number of rows kept per label (the first rows in ``df`` order).
        ``0`` or ``None`` (the default) keeps every row of each retained label;
        previously the default silently returned an empty frame.
    label_col : str, optional
        Name of the label column. Defaults to ``label_settings["label_col"]``.

    Returns
    -------
    pd.DataFrame
        A new frame whose rows are grouped by label, with labels in
        ``value_counts`` (descending frequency) order. Original index values are kept.
    """
    if label_col is None:
        label_col = label_settings["label_col"]
    print(f"Before filter df size = {df.shape}")
    label_counts = df[label_col].value_counts()
    print(f"Before filter # of unique values = {len(label_counts)}")
    selected_labels = [label for label, count in label_counts.items() if count >= filter_threshold]
    print(f"After filter on min of samples # of unique values = {len(selected_labels)}")

    df = df[df[label_col].isin(selected_labels)]
    print(f"After filter df size = {df.shape}")

    if not selected_labels:
        # pd.concat([]) raises ValueError; return an empty frame with the same columns instead.
        df = df.copy()
    else:
        # A falsy cap (0/None) means "no cap": slicing with None keeps all rows of the label.
        cap = per_class_samples if per_class_samples else None
        df = pd.concat([df[df[label_col] == label].iloc[:cap] for label in selected_labels])
    print(f"After filter on per_class_samples df size = {df.shape}")

    return df

def load_dataset(input_dir, input_file_names, sequence_settings, filter_threshold, per_class_samples):
    """Read, filter and label-encode the dataset, then wrap it in a dataset loader.

    Reads only the sequence and label columns, filters via ``filter_dataset``,
    replaces each label value with its integer index from ``utils.get_label_vocabulary``
    and builds the loader with ``dataset_utils.get_dataset_loader``.

    Returns
    -------
    tuple
        ``(encoded_df, dataset_loader, idx_label_map)``; ``idx_label_map`` comes
        straight from ``utils.get_label_vocabulary``.
    """
    label_col = label_settings["label_col"]
    raw_df = dataset_utils.read_dataset(
        input_dir,
        input_file_names,
        cols=[sequence_settings["sequence_col"], label_col],
    )
    filtered_df = filter_dataset(raw_df, filter_threshold, per_class_samples)

    # Vocabulary is built from labels in order of first appearance in the filtered frame.
    label_idx_map, idx_label_map = utils.get_label_vocabulary(list(filtered_df[label_col].unique()))
    encoded_df = filtered_df.replace({label_col: label_idx_map})

    dataset_loader = dataset_utils.get_dataset_loader(encoded_df, sequence_settings, label_col)
    print(encoded_df.head())
    return encoded_df, dataset_loader, idx_label_map

def compute_dataset_representations(model, dataset_loader):
    """Embed every batch from ``dataset_loader`` with ``model`` and collect the results.

    Parameters
    ----------
    model : object
        Must provide ``eval()`` and ``get_embedding(seq)``. The embedding is assumed
        to be a tensor whose first dimension is the batch (TODO confirm); any
        remaining dimensions are flattened into feature columns.
    dataset_loader : iterable
        Yields ``(seq, label)`` tensor pairs.

    Returns
    -------
    pd.DataFrame
        One row per sequence: integer-named feature columns plus a ``"label"``
        column, with a fresh 0..N-1 index. An empty loader yields an empty frame
        with only the ``"label"`` column.
    """
    model.eval()
    seq_dfs = []
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        for seq, label in dataset_loader:
            seq_encoding = model.get_embedding(seq)
            batch_size = seq_encoding.shape[0]
            # reshape (not squeeze) so a final batch of size 1 keeps its batch dimension
            # instead of turning its features into rows.
            seq_df = pd.DataFrame(seq_encoding.reshape(batch_size, -1).cpu().detach().numpy())
            seq_df["label"] = label.reshape(-1).cpu().detach().numpy()
            seq_dfs.append(seq_df)
    if seq_dfs:
        # ignore_index avoids the per-batch 0..B-1 index being repeated across batches.
        df = pd.concat(seq_dfs, ignore_index=True)
    else:
        df = pd.DataFrame(columns=["label"])
    print(df.shape)
    return df

def print_dataset_loader(dataset_loader):
    """Print the shape and contents of the sequences, then the labels, of the first batch."""
    first_batch = next(iter(dataset_loader))
    sequences, labels = first_batch
    for tensor in (sequences, labels):
        print(tensor.shape)
        print(tensor)

In [8]:
# Keep hosts with at least 15 sequences, then take the first 5 sequences of each kept host.
df, dataset_loader, idx_label_map = load_dataset(
    input_dir,
    input_file_names,
    sequence_settings,
    filter_threshold=15,
    per_class_samples=5,
)
print_dataset_loader(dataset_loader)
len(idx_label_map)  # number of distinct host labels retained

input file: /home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/few_shot_learning/../../../input/data/uniref90/20240131/uniref90_viridae_embl_hosts_pruned_metadata_species_vertebrates_w_seq_non_idv_lt_1percent_prevalence.csv, size = (16074, 2)
Size of input dataset = (16074, 2)
Before filter df size = (16074, 2)
Before filter # of unique values = 1299
After filter on min of samples # of unique values = 192
After filter df size = (12112, 2)
After filter on per_class_samples df size = (960, 2)
      virus_host_name                                                seq
3201               42  MKMFVLVGFVLFVVASATTTVNINVTTNGNHNVTSSNSNVLLQNRT...
3202               42  MKTQLYILILYFLGVSSSQETTALLDPDRFCLQTDFSRILVFPKFR...
3203               42  MKTLVSICFFITLFILTNSDPSCYDGLVENSRKNLDRPNSLAAYDL...
3204               42  MGIHALNYIASNFETDDLVPTLFGACGVFAFLIIIGTVLFVCSGRM...
3205               42  MNSTLLVISNPENQFTIDFILSGYINNTHYSIIVKDIKEESDGRFD...
torch.Size([8, 2048])
tensor([[13.,  7.,  3.,  ...,  0.,  0

192

In [None]:
# Embed every sampled sequence with the pre-trained model: one row per sequence plus a "label" column.
virprobert_rep_df = compute_dataset_representations(
    model=pre_trained_model,
    dataset_loader=dataset_loader,
)

In [None]:
# Per-label row counts; labels are integer indices here because load_dataset replaced the host names.
df.loc[:, label_settings["label_col"]].value_counts()