## Interpretation of Hybrid VirProBERT attention values for multiclass classification

### Training Dataset: UniRef90 - Coronaviridae Spike protein sequences aligned using MAFFT
### Interpretation: SARS-CoV-2 Spike protein sequences

**Models**: hybrid_attention_msl256_s64_fnn_2l_d1024_lr1e-4

**Positional Embedding**: Sin-Cos

**Maximum Sequence Length**: -

**Classification**: Multi-class

**\# classes**: 8

In [1]:
import os
import sys

# Extend the import search path with ancestors of the working directory.
# The number of levels up is listed in the order the entries are appended
# (3, 4, 2, then 1 level above the current working directory).
for levels_up in (3, 4, 2, 1):
    sys.path.append(os.path.join(os.getcwd(), *([".."] * levels_up)))
sys.path

['/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/interpretability/sarscov2-aligned',
 '/opt/conda/lib/python38.zip',
 '/opt/conda/lib/python3.8',
 '/opt/conda/lib/python3.8/lib-dynload',
 '',
 '/home/blessyantony/.local/lib/python3.8/site-packages',
 '/opt/conda/lib/python3.8/site-packages',
 '/opt/conda/lib/python3.8/site-packages/IPython/extensions',
 '/home/blessyantony/.ipython',
 '/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/interpretability/sarscov2-aligned/../../..',
 '/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/interpretability/sarscov2-aligned/../../../..',
 '/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/interpretability/sarscov2-aligned/../..',
 '/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/interpretability/sarscov2-aligned/..']

In [2]:
from models.nlp.transformer import transformer
from models.nlp.hybrid import transformer_attention
from datasets.protein_sequence_dataset import ProteinSequenceDataset
from src.utils import utils, dataset_utils, nn_utils, constants

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pandas as pd
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
# Shared colormap built from seaborn's "vlag" palette; `as_cmap=True` requests
# a colormap object rather than a list of colors.
cmap = sns.color_palette("vlag", as_cmap=True)

from sklearn.metrics import roc_curve, accuracy_score, f1_score, auc, precision_recall_curve
from statistics import mean

# from captum.attr import LayerIntegratedGradients, TokenReferenceBase, LayerGradientXActivation, LayerDeepLift, LayerLRP

In [3]:
# Class display name -> list of raw label values grouped under it.
# NOTE: the key spelling "Horshoe bat" is a runtime label string and is kept as-is.
label_groupings = {
    class_name: [raw_label]
    for class_name, raw_label in (
        ("Chicken", "gallus gallus"),
        ("Human", "homo sapiens"),
        ("Cat", "felis catus"),
        ("Pig", "sus scrofa"),
        ("Gray wolf", "canis lupus"),
        ("Horshoe bat", "rhinolophus sp."),
        ("Ferret", "mustela putorius"),
        ("Chinese rufous horseshoe bat", "rhinolophus sinicus"),
    )
}

# Options handed to the dataset loader for reading and batching sequences.
sequence_settings = dict(
    sequence_col="aligned_seq",
    batch_size=16,
    max_sequence_length=256,
    truncate=False,
    split_sequence=False,
    feature_type="token",
)

# Options describing which column holds the label and how labels are grouped.
label_settings = dict(
    label_col="virus_host_name",
    exclude_labels=["nan"],
    label_groupings=label_groupings,
)

# Model hyper-parameters: the nested `pre_train_settings` describe the
# pre-trained encoder, the outer keys describe the fine-tuned model.
model = dict(
    pre_train_settings=dict(
        embedding="linear",
        n_heads=8,
        depth=6,
        input_dim=512,  # input embedding dimension
        hidden_dim=1024,
        max_seq_len=256,
    ),
    loss="FocalLoss",
    n_heads=8,
    depth=2,
    stride=64,
    n_classes=8,
    input_dim=512,  # input embedding dimension
    hidden_dim=1024,
)

### Load the datasets

In [4]:
def print_dataset_loader(dataset_loader):
    """Print the shape and contents of the first batch of a loader.

    Args:
        dataset_loader: Iterable whose first item unpacks into a
            ``(sequence, label)`` pair; both items must expose ``.shape``.

    Returns:
        None. A blank line followed by four report lines is written to stdout.
    """
    print()
    first_batch = next(iter(dataset_loader))
    sequence, label = first_batch
    for title, tensor in (("Sequence", sequence), ("Label", label)):
        print(f"{title} tensor size = {tensor.shape}")
        print(f"{title} = {tensor}")

In [5]:
# Aligned dataset CSV, resolved four directories above the working directory.
input_file_path = os.path.join(
    os.getcwd(),
    "..", "..", "..", "..",
    "input/data/coronaviridae/20240313/uniref/alignment/coronaviridae_s_uniref90_embl_hosts_pruned_metadata_corrected_species_virus_host_vertebrates_w_seq_t0.01_c8_aligned.csv",
)
uniref90_coronaviridae_aligned_df = pd.read_csv(input_file_path)

# Keep the row(s) whose `uniref90_id` is "WIV04" for the single-sequence analysis.
is_wiv04 = uniref90_coronaviridae_aligned_df["uniref90_id"] == "WIV04"
wiv04_seq_df = uniref90_coronaviridae_aligned_df[is_wiv04]
uniref90_coronaviridae_aligned_df

Unnamed: 0,uniref90_id,aligned_seq,seq,virus_name,virus_host_name,human_binary_label
0,WIV04,--------------MFVFLVLLPLVSS--------Q----------...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,WIV04(MN996528.1) Wuhan variant index virus,homo sapiens,homo sapiens
1,UniRef90_A0A7U3RIT3,--------------MFVFLVLVPLVSS--------Q----------...,MFVFLVLVPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,Severe acute respiratory syndrome coronavirus 2,homo sapiens,homo sapiens
2,UniRef90_A0A7U3HGG2,--------------MFVFLVLLPLVSS--------Q----------...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,Severe acute respiratory syndrome coronavirus 2,homo sapiens,homo sapiens
3,UniRef90_A0A7U3EEN6,--------------MFVFLVLLPLVSS--------Q----------...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,Severe acute respiratory syndrome coronavirus 2,homo sapiens,homo sapiens
4,UniRef90_A0A7U3HDM5,--------------MFVFLVLLPLVSS--------Q----------...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,Severe acute respiratory syndrome coronavirus 2,homo sapiens,homo sapiens
...,...,...,...,...,...,...
677,UniRef90_S5FZ76,---------------------------TLKQ---------------...,TLKQCDASAGYYSSSPIRPSDGVHSVTGFYRPVKTCCIKYTYPSNT...,Infectious bronchitis virus,gallus gallus,NOT homo sapiens
678,UniRef90_U5WLM9,--------------MLLLVTLFGLASG-------------------...,MLLLVTLFGLASGCSLPLTVSCPRGLPFTLQINTTSVTVEWYRVSP...,Sarbecovirus,rhinolophus sinicus,NOT homo sapiens
679,UniRef90_A0A169QA14,--MILHF-IMKVMPILIMVVFILL----------------------...,MILHFIMKVMPILIMVVFILLVYTNTHSSEWLLLFYFLISGVFCLY...,Infectious bronchitis virus,gallus gallus,NOT homo sapiens
680,UniRef90_E7DBM7,----------------------------------------------...,CSRRQFENYNQIEKVHVH,Feline coronavirus,felis catus,NOT homo sapiens


In [6]:
# Build the index -> label map and the batched loader for the full dataset,
# then print the first batch as a sanity check.
index_label_map, dataset_loader = dataset_utils.load_dataset_with_df(
    uniref90_coronaviridae_aligned_df,
    sequence_settings,
    label_settings,
    label_col=label_settings["label_col"],
    classification_type="multi",
)
print_dataset_loader(dataset_loader)

Grouping labels using config : {'Chicken': ['gallus gallus'], 'Human': ['homo sapiens'], 'Cat': ['felis catus'], 'Pig': ['sus scrofa'], 'Gray wolf': ['canis lupus'], 'Horshoe bat': ['rhinolophus sp.'], 'Ferret': ['mustela putorius'], 'Chinese rufous horseshoe bat': ['rhinolophus sinicus']}
label_idx_map={'Cat': 0, 'Chicken': 1, 'Chinese rufous horseshoe bat': 2, 'Ferret': 3, 'Gray wolf': 4, 'Horshoe bat': 5, 'Human': 6, 'Pig': 7}
idx_label_map={0: 'Cat', 1: 'Chicken', 2: 'Chinese rufous horseshoe bat', 3: 'Ferret', 4: 'Gray wolf', 5: 'Horshoe bat', 6: 'Human', 7: 'Pig'}

Sequence tensor size = torch.Size([16, 2418])
Sequence = tensor([[27., 27., 27.,  ...,  0.,  0.,  0.],
        [27., 27., 27.,  ..., 27., 27., 27.],
        [27., 27., 27.,  ..., 27., 27., 27.],
        ...,
        [27., 27., 27.,  ..., 27., 27., 27.],
        [27., 27., 27.,  ..., 27., 27., 27.],
        [27., 27., 27.,  ..., 27., 27., 27.]], dtype=torch.float64)
Label tensor size = torch.Size([16])
Label = tensor([1

### Load the pre-trained and fine-tuned model

In [7]:
# `pre_train_encoder_settings` aliases model["pre_train_settings"], so the
# vocabulary size added here is also visible through `model`.
pre_train_encoder_settings = model["pre_train_settings"]
pre_train_encoder_settings.update(vocab_size=constants.VOCAB_SIZE)
pre_trained_encoder_model = transformer.get_transformer_encoder(
    pre_train_encoder_settings
)

TransformerEncoder(
  (embedding): EmbeddingLayer(
    (token_embedding): Embedding(28, 512, padding_idx=0)
    (positional_embedding): PositionalEncoding()
  )
  (encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderLayer(
        (self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True)
          (W_K): Linear(in_features=512, out_features=512, bias=True)
          (W_V): Linear(in_features=512, out_features=512, bias=True)
          (W_O): Linear(in_features=512, out_features=512, bias=True)
        )
        (feed_forward): FeedForwardLayer(
          (W_1): Linear(in_features=512, out_features=1024, bias=True)
          (W_2): Linear(in_features=1024, out_features=512, bias=True)
        )
        (residual_connections): ModuleList(
          (0): ResidualConnectionLayer(
            (norm): NormalizationLayer()
          )
          (1): ResidualConnectionLayer(
            (norm): NormalizationLayer()
          )
        )
 

In [8]:
# Attach the encoder and its chunk length (the encoder's max sequence length)
# to the model settings before building the prediction model.
model.update(
    pre_trained_model=pre_trained_encoder_model,
    chunk_len=pre_train_encoder_settings["max_seq_len"],
)
prediction_model = transformer_attention.get_model(model)

VirProBERT(
  (pre_trained_model): TransformerEncoder(
    (embedding): EmbeddingLayer(
      (token_embedding): Embedding(28, 512, padding_idx=0)
      (positional_embedding): PositionalEncoding()
    )
    (encoder): Encoder(
      (layers): ModuleList(
        (0): EncoderLayer(
          (self_attn): MultiHeadAttention(
            (W_Q): Linear(in_features=512, out_features=512, bias=True)
            (W_K): Linear(in_features=512, out_features=512, bias=True)
            (W_V): Linear(in_features=512, out_features=512, bias=True)
            (W_O): Linear(in_features=512, out_features=512, bias=True)
          )
          (feed_forward): FeedForwardLayer(
            (W_1): Linear(in_features=512, out_features=1024, bias=True)
            (W_2): Linear(in_features=1024, out_features=512, bias=True)
          )
          (residual_connections): ModuleList(
            (0): ResidualConnectionLayer(
              (norm): NormalizationLayer()
            )
            (1): ResidualCo

In [9]:
# Checkpoint of the fine-tuned model, resolved four directories above the
# working directory.
model_path = os.path.join(
    os.getcwd(),
    "..", "..", "..", "..",
    "output/raw/coronaviridae_s_prot_uniref90_embl_vertebrates_aligned_t0.01_c8/20240701/host_multi/fine_tuning_hybrid/mlm_tfenc_l6_h8_lr1e-4_uniref90viridae_msl256_hybrid_attention_s64_fnn_2l_d1024_lr1e-4_itr0.pth",
)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
load_device = nn_utils.get_device()
prediction_model.load_state_dict(torch.load(model_path, map_location=load_device))

<All keys matched successfully>

### t-SNE Analysis

In [10]:
def compute_embeddings(model, dataset_loader):
    """Run every batch through `model` and collect the resulting embeddings.

    After each forward pass the embedding is read from `model.embedding`; the
    value returned by the forward pass itself is not used.

    Args:
        model: Callable exposing `eval()` and, after a forward pass, an
            `embedding` tensor. Assumed to hold one embedding vector per
            input sequence -- TODO confirm against the model implementation.
        dataset_loader: Iterable yielding `(sequence_batch, label_batch)`
            tensor pairs.

    Returns:
        pandas.DataFrame with one row per sequence: the embedding dimensions
        in integer-named columns plus a "label" column. The row index is the
        per-batch index (batches are concatenated without re-indexing).

    Raises:
        ValueError: from `pd.concat` if `dataset_loader` yields no batches.
    """
    model.eval()
    seq_dfs = []
    # Inference only: the embeddings are detached anyway, so skip autograd
    # bookkeeping during the forward passes.
    with torch.no_grad():
        for seq, label in dataset_loader:
            model(seq)
            labels = label.detach().cpu().numpy().reshape(-1)
            seq_encoding = model.embedding.detach().cpu().numpy()
            # Reshape to (batch, features) instead of squeeze(): squeeze()
            # also drops the batch axis when a batch holds a single sequence,
            # which turned that sequence into a column of rows.
            seq_df = pd.DataFrame(seq_encoding.reshape(labels.shape[0], -1))
            seq_df["label"] = labels
            seq_dfs.append(seq_df)
    df = pd.concat(seq_dfs)
    print(df.shape)
    return df

def view_tsne_representation(rep_df, index_label_map):
    """Project embeddings to 2-D with t-SNE and scatter-plot them by label.

    Args:
        rep_df: DataFrame with a "label" column; every other column is
            treated as an embedding dimension (previously the first 512
            integer-named columns were hard-coded).
        index_label_map: Mapping from the values in `rep_df["label"]` to the
            names shown in the plot legend.

    Returns:
        Tuple `(tsne_model, X_tsne_emb)`: the fitted TSNE estimator and a
        DataFrame holding the two t-SNE coordinates (columns 0 and 1) plus
        the mapped "label" column.
    """
    # Local import: TSNE is not imported in the notebook's import cell, so
    # the function previously failed with NameError on a fresh kernel.
    from sklearn.manifold import TSNE

    columns = rep_df.columns
    print(columns)
    feature_columns = [column for column in columns if column != "label"]
    X = rep_df[feature_columns]
    # Fit once: the original called fit(X) and then fit_transform(X), which
    # ran the whole optimisation twice and discarded the first result.
    tsne_model = TSNE(n_components=2, verbose=1, init="pca", learning_rate="auto")
    X_tsne_emb = pd.DataFrame(tsne_model.fit_transform(X))
    print(X_tsne_emb.shape)
    print(X_tsne_emb)
    # `.values` sidesteps index alignment: rep_df's index may repeat across
    # batches while X_tsne_emb has a fresh 0..n-1 index.
    X_tsne_emb["label"] = rep_df["label"].values
    X_tsne_emb["label"] = X_tsne_emb["label"].map(index_label_map)

    sns.scatterplot(data = X_tsne_emb, x=0, y=1, hue="label")
    plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
    plt.show()
    return tsne_model, X_tsne_emb

In [11]:
# emb_df = compute_embeddings(prediction_model, dataset_loader)

### Attention value Analysis

In [12]:
def analyze_attention_of_sequence(model, seq):
    """Run one forward pass of `model` on `seq` in eval mode.

    No attention values are extracted here yet; the function only reports
    the sequence length and performs the forward pass.

    Args:
        model: Callable exposing `eval()`.
        seq: Input tensor; its last dimension is taken as the sequence
            length -- assumes a (batch, length) layout, TODO confirm.

    Returns:
        Whatever `model(seq)` returns (the original discarded it and
        returned None).
    """
    # `seq_len` was never defined, so every call raised NameError.
    seq_len = seq.shape[-1]
    print(f"sequence length = {seq_len}")
    model.eval()
    output = model(seq)
    return output
    

In [13]:
# Show the subset of the dataset whose `uniref90_id` is "WIV04".
wiv04_seq_df

Unnamed: 0,uniref90_id,aligned_seq,seq,virus_name,virus_host_name,human_binary_label
0,WIV04,--------------MFVFLVLLPLVSS--------Q----------...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,WIV04(MN996528.1) Wuhan variant index virus,homo sapiens,homo sapiens


In [14]:
# Number of characters in the aligned WIV04 sequence (`[0]` selects the row labelled 0).
len(wiv04_seq_df["aligned_seq"][0])

2418

In [60]:
# NOTE: this changes the shared `sequence_settings` dict in place, so any cell
# executed afterwards sees batch_size == 1.
sequence_settings.update(batch_size=1, max_sequence_length=256)

# Loader over the WIV04 subset only; the returned label map is discarded.
_, wiv04_seq_df_dataset_loader = dataset_utils.load_dataset_with_df(
    wiv04_seq_df,
    sequence_settings,
    label_settings,
    label_col=label_settings["label_col"],
    classification_type="multi",
)

Grouping labels using config : {'Chicken': ['gallus gallus'], 'Human': ['homo sapiens'], 'Cat': ['felis catus'], 'Pig': ['sus scrofa'], 'Gray wolf': ['canis lupus'], 'Horshoe bat': ['rhinolophus sp.'], 'Ferret': ['mustela putorius'], 'Chinese rufous horseshoe bat': ['rhinolophus sinicus']}
label_idx_map={'Cat': 0, 'Chicken': 1, 'Chinese rufous horseshoe bat': 2, 'Ferret': 3, 'Gray wolf': 4, 'Horshoe bat': 5, 'Human': 6, 'Pig': 7}
idx_label_map={0: 'Cat', 1: 'Chicken', 2: 'Chinese rufous horseshoe bat', 3: 'Ferret', 4: 'Gray wolf', 5: 'Horshoe bat', 6: 'Human', 7: 'Pig'}


In [16]:
#prediction_model.eval()
#output = prediction_model(next(iter(wiv04_seq_df_dataset_loader)))

In [17]:
# Fetch the first batch from the WIV04 loader and unpack it into sequence and label.
seq, label = next(iter(wiv04_seq_df_dataset_loader))

In [18]:
# Shape of the sequence batch fetched from the WIV04 loader.
seq.shape

torch.Size([1, 256])

In [19]:
# Display the fetched sequence tensor.
seq

tensor([[27., 27., 27., 27., 27., 27., 27., 27., 27., 27., 27., 27., 27., 27.,
         13., 14., 22., 14., 11., 22., 11., 11., 15., 11., 22., 17., 17., 27.,
         27., 27., 27., 27., 27., 27., 27.,  6., 27., 27., 27., 27., 27., 27.,
         27., 27., 27., 27., 27., 27., 27., 27., 27., 27.,  5., 22.,  3., 27.,
         27., 27., 27., 27., 27., 27., 27., 11., 19., 19.,  2., 19.,  6., 11.,
         15., 15.,  1., 21., 27., 27., 27., 27., 19.,  3., 17., 14., 19.,  2.,
          8., 22., 21., 21., 15.,  4., 12., 22., 14.,  2., 17., 17., 22., 11.,
          9., 17., 19.,  6.,  4., 11., 14., 11., 15., 14., 27., 27., 27., 27.,
         27., 27., 27., 27., 27., 27., 27., 27., 27., 27., 27., 27., 14., 17.,
          3., 22., 19., 20., 14.,  9.,  1., 10.,  9., 22., 17.,  8., 19., 27.,
         27., 27., 27., 27., 27., 27., 27., 27., 27., 27., 27., 27., 27., 27.,
          3.,  8., 19., 12.,  2., 14.,  4.,  3., 15., 22., 11., 15., 14.,  3.,
         27.,  4.,  8., 22., 21., 14.,  1., 17., 19.

In [20]:
# Show the current sequence settings (includes any in-place changes made by earlier cells).
sequence_settings

{'sequence_col': 'aligned_seq',
 'batch_size': 1,
 'max_sequence_length': 256,
 'truncate': False,
 'split_sequence': False,
 'feature_type': 'token'}

In [28]:
# Length of the first raw sequence stored in the dataset behind the WIV04 loader.
wiv04_dataset = wiv04_seq_df_dataset_loader.dataset
len(wiv04_dataset.data[wiv04_dataset.sequence_col][0])

2418

In [61]:
# Shape of the sequence tensor of the first dataset item (taken from the
# dataset directly, not through the loader's batching).
wiv04_seq_df_dataset_loader.dataset[0][0].shape

torch.Size([2418])

In [36]:
# Shape of the sequence tensor in the first batch of the full-dataset loader.
next(iter(dataset_loader))[0].shape

torch.Size([16, 2418])

In [44]:
# Shape of the sequence tensor in the first batch of the WIV04 loader.
next(iter(wiv04_seq_df_dataset_loader))[0].shape

torch.Size([2, 2418])

In [66]:
# First dataset item, unpacked into its sequence and label parts (taken from
# the dataset directly, so the loader's batching is not applied).
sequences, labels = wiv04_seq_df_dataset_loader.dataset[0]
# sequences = [seq.clone().detach() for seq in sequences]

In [67]:
# Display the sequence tensor of the first dataset item.
sequences

tensor([27., 27., 27.,  ..., 19., 27., 27.], dtype=torch.float64)

In [69]:
# Pad `sequences` on the right with zeros up to a target length of 256.
# NOTE(review): when the sequence is longer than 256 the pad amount is
# negative -- presumably this crops instead of padding; confirm against the
# PyTorch padding documentation.
nn.ConstantPad1d((0, 256 - sequences.shape[0]), 0)(sequences)

torch.Size([256])

In [74]:
# pad_sequence pads a list of tensors to the length of the longest one and
# stacks them (batch dimension first here); the list holds a single sequence,
# so only the batch dimension is added.
from torch.nn.utils.rnn import pad_sequence
pad_sequence([sequences], batch_first=True, padding_value=0).shape

torch.Size([1, 2418])