## Interpretation of Hybrid VirProBERT attention values for multiclass classification

### Training Dataset: UniRef90 - Coronaviridae spike protein sequences aligned using MAFFT
### Interpretation: SARS-CoV-2 Spike protein sequences

**Positional Embedding**: Sin-Cos

**Maximum Sequence Length**: -

**Classification**: Multi-class

**\# classes**: 8

In [1]:
import sys
import os
# Make the project packages importable: append the ancestor directories of the
# notebook's working directory (3, 4, 2, then 1 level up, in that order) to the
# module search path, then display the resulting path.
sys.path.extend(
    os.path.join(os.getcwd(), *parents)
    for parents in (
        ("..", "..", ".."),
        ("..", "..", "..", ".."),
        ("..", ".."),
        ("..",),
    )
)
sys.path

['/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/interpretability/coronaviridae-aligned',
 '/opt/conda/lib/python38.zip',
 '/opt/conda/lib/python3.8',
 '/opt/conda/lib/python3.8/lib-dynload',
 '',
 '/home/blessyantony/.local/lib/python3.8/site-packages',
 '/opt/conda/lib/python3.8/site-packages',
 '/opt/conda/lib/python3.8/site-packages/IPython/extensions',
 '/home/blessyantony/.ipython',
 '/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/interpretability/coronaviridae-aligned/../../..',
 '/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/interpretability/coronaviridae-aligned/../../../..',
 '/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/interpretability/coronaviridae-aligned/../..',
 '/home/blessyantony/dev/git/zoonosis/src/jupyter_notebooks/interpretability/coronaviridae-aligned/..']

In [2]:
from models.nlp.transformer import transformer
from models.nlp.hybrid import transformer_attention
from datasets.protein_sequence_dataset import ProteinSequenceDataset
from src.utils import utils, dataset_utils, nn_utils, constants

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pandas as pd
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
cmap = sns.color_palette("vlag", as_cmap=True)

from sklearn.metrics import roc_curve, accuracy_score, f1_score, auc, precision_recall_curve
from statistics import mean

# from captum.attr import LayerIntegratedGradients, TokenReferenceBase, LayerGradientXActivation, LayerDeepLift, LayerLRP

In [3]:
# Host-label groupings: display label -> list of raw `virus_host_name` values.
# NOTE(review): "Horshoe bat" is misspelled but kept verbatim because it is the
# label name that appears in idx_label_map and in plots.
label_groupings = {
    "Chicken": ["gallus gallus"],
    "Human": ["homo sapiens"],
    "Cat": ["felis catus"],
    "Pig": ["sus scrofa"],
    "Gray wolf": ["canis lupus"],
    "Horshoe bat": ["rhinolophus sp."],
    "Ferret": ["mustela putorius"],
    "Chinese rufous horseshoe bat": ["rhinolophus sinicus"],
}

# Settings passed to dataset_utils.load_dataset_with_df.
sequence_settings = {
    "sequence_col": "aligned_seq",
    "batch_size": 16,
    "max_sequence_length": 128,  # also used as the hybrid model's segment_len
    "truncate": False,
    "split_sequence": False,
    "feature_type": "token",
}

label_settings = {
    "label_col": "virus_host_name",
    "exclude_labels": ["nan"],
    "label_groupings": label_groupings,
}

# Hybrid model configuration. "pre_train_settings" describes the pre-trained
# encoder; presumably it must match the checkpoint that is loaded (l6_h8 ... msl128).
model = {
    "pre_train_settings": {
        "n_heads": 8,
        "depth": 6,
        "input_dim": 512,  # input embedding dimension
        "hidden_dim": 1024,
        "max_seq_len": 129,  # presumably max_sequence_length (128) + 1 CLS token — confirm
    },
    "loss": "FocalLoss",
    "n_heads": 8,
    "depth": 2,
    "stride": 64,  # presumably the step between consecutive 128-token segments
    # Derived rather than hard-coded so it cannot drift from the label groupings (8).
    "n_classes": len(label_groupings),
    "input_dim": 512,  # input embedding dimension
    "hidden_dim": 1024,
    "cls_token": True,
}

### Load the datasets

In [4]:
def print_dataset_loader(dataset_loader):
    """Print the first batch of ``dataset_loader`` as a quick sanity check.

    Prints a blank line, then the shape and contents of the sequence tensor
    followed by the shape and contents of the label tensor.

    Args:
        dataset_loader: iterable whose first item is a ``(sequence, label)`` pair
            of objects exposing ``.shape``.
    """
    print()
    seq_batch, label_batch = next(iter(dataset_loader))
    report_lines = (
        f"Sequence tensor size = {seq_batch.shape}",
        f"Sequence = {seq_batch}",
        f"Label tensor size = {label_batch.shape}",
        f"Label = {label_batch}",
    )
    for line in report_lines:
        print(line)

In [5]:
# UniRef90 Coronaviridae spike-protein sequences (MAFFT-aligned) with host metadata.
input_file_path = os.path.join(os.getcwd(), "..", "..", "..", "..", "input/data/coronaviridae/20240313/uniref/alignment/coronaviridae_s_uniref90_embl_hosts_pruned_metadata_corrected_species_virus_host_vertebrates_w_seq_t0.01_c8_aligned.csv")
uniref90_coronaviridae_aligned_df = pd.read_csv(input_file_path)

# WIV04 reference record. `.copy()` makes it an independent frame: without it,
# the in-place label transform done later by load_dataset_with_df on this slice
# raises pandas' SettingWithCopyWarning.
wiv04_seq_df = uniref90_coronaviridae_aligned_df[uniref90_coronaviridae_aligned_df["uniref90_id"] == "WIV04"].copy()
# Later cells assume exactly one WIV04 sequence.
assert len(wiv04_seq_df) == 1, f"expected one WIV04 row, found {len(wiv04_seq_df)}"
uniref90_coronaviridae_aligned_df

Unnamed: 0,uniref90_id,aligned_seq,seq,virus_name,virus_host_name,human_binary_label
0,WIV04,--------------MFVFLVLLPLVSS--------Q----------...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,WIV04(MN996528.1) Wuhan variant index virus,homo sapiens,homo sapiens
1,UniRef90_A0A7U3RIT3,--------------MFVFLVLVPLVSS--------Q----------...,MFVFLVLVPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,Severe acute respiratory syndrome coronavirus 2,homo sapiens,homo sapiens
2,UniRef90_A0A7U3HGG2,--------------MFVFLVLLPLVSS--------Q----------...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,Severe acute respiratory syndrome coronavirus 2,homo sapiens,homo sapiens
3,UniRef90_A0A7U3EEN6,--------------MFVFLVLLPLVSS--------Q----------...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,Severe acute respiratory syndrome coronavirus 2,homo sapiens,homo sapiens
4,UniRef90_A0A7U3HDM5,--------------MFVFLVLLPLVSS--------Q----------...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,Severe acute respiratory syndrome coronavirus 2,homo sapiens,homo sapiens
...,...,...,...,...,...,...
677,UniRef90_S5FZ76,---------------------------TLKQ---------------...,TLKQCDASAGYYSSSPIRPSDGVHSVTGFYRPVKTCCIKYTYPSNT...,Infectious bronchitis virus,gallus gallus,NOT homo sapiens
678,UniRef90_U5WLM9,--------------MLLLVTLFGLASG-------------------...,MLLLVTLFGLASGCSLPLTVSCPRGLPFTLQINTTSVTVEWYRVSP...,Sarbecovirus,rhinolophus sinicus,NOT homo sapiens
679,UniRef90_A0A169QA14,--MILHF-IMKVMPILIMVVFILL----------------------...,MILHFIMKVMPILIMVVFILLVYTNTHSSEWLLLFYFLISGVFCLY...,Infectious bronchitis virus,gallus gallus,NOT homo sapiens
680,UniRef90_E7DBM7,----------------------------------------------...,CSRRQFENYNQIEKVHVH,Feline coronavirus,felis catus,NOT homo sapiens


In [6]:
# Build the multi-class loader over the full aligned dataset, then print its
# first batch as a sanity check.
index_label_map, dataset_loader = dataset_utils.load_dataset_with_df(
    uniref90_coronaviridae_aligned_df,
    sequence_settings,
    label_settings,
    label_col=label_settings["label_col"],
    classification_type="multi",
)
print_dataset_loader(dataset_loader)

Grouping labels using config : {'Chicken': ['gallus gallus'], 'Human': ['homo sapiens'], 'Cat': ['felis catus'], 'Pig': ['sus scrofa'], 'Gray wolf': ['canis lupus'], 'Horshoe bat': ['rhinolophus sp.'], 'Ferret': ['mustela putorius'], 'Chinese rufous horseshoe bat': ['rhinolophus sinicus']}
label_idx_map={'Cat': 0, 'Chicken': 1, 'Chinese rufous horseshoe bat': 2, 'Ferret': 3, 'Gray wolf': 4, 'Horshoe bat': 5, 'Human': 6, 'Pig': 7}
idx_label_map={0: 'Cat', 1: 'Chicken', 2: 'Chinese rufous horseshoe bat', 3: 'Ferret', 4: 'Gray wolf', 5: 'Horshoe bat', 6: 'Human', 7: 'Pig'}

Sequence tensor size = torch.Size([16, 2418])
Sequence = tensor([[27., 27., 27.,  ..., 27., 27., 27.],
        [27., 27., 27.,  ..., 27., 27., 27.],
        [27., 27., 27.,  ..., 19., 27., 27.],
        ...,
        [27., 27., 27.,  ..., 27., 27., 27.],
        [27., 27., 27.,  ..., 27., 27., 27.],
        [27., 27., 27.,  ..., 19., 27., 27.]], dtype=torch.float64)
Label tensor size = torch.Size([16])
Label = tensor([1

### Load the pre-trained and fine-tuned model

In [7]:
# NOTE: this is an alias, not a copy, so the vocab_size added below also lands
# in model["pre_train_settings"].
pre_train_encoder_settings = model["pre_train_settings"]
# Hard-coded instead of constants.VOCAB_SIZE; presumably matches the "vs30" checkpoint.
pre_train_encoder_settings.update(vocab_size=30)
pre_trained_encoder_model = transformer.get_transformer_encoder(pre_train_encoder_settings)

TransformerEncoder(
  (embedding): EmbeddingLayer(
    (token_embedding): Embedding(30, 512, padding_idx=0)
    (positional_embedding): PositionalEncoding()
  )
  (encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderLayer(
        (self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True)
          (W_K): Linear(in_features=512, out_features=512, bias=True)
          (W_V): Linear(in_features=512, out_features=512, bias=True)
          (W_O): Linear(in_features=512, out_features=512, bias=True)
        )
        (feed_forward): FeedForwardLayer(
          (W_1): Linear(in_features=512, out_features=1024, bias=True)
          (W_2): Linear(in_features=1024, out_features=512, bias=True)
        )
        (residual_connections): ModuleList(
          (0): ResidualConnectionLayer(
            (norm): NormalizationLayer()
          )
          (1): ResidualConnectionLayer(
            (norm): NormalizationLayer()
          )
        )
 

In [8]:
# Attach the pre-trained encoder and the segment length to the hybrid-model
# config, then build the fine-tuning model from it.
model.update(
    pre_trained_model=pre_trained_encoder_model,
    segment_len=sequence_settings["max_sequence_length"],
)
prediction_model = transformer_attention.get_model(model)

TransformerAttention(
  (pre_trained_model): TransformerEncoder(
    (embedding): EmbeddingLayer(
      (token_embedding): Embedding(30, 512, padding_idx=0)
      (positional_embedding): PositionalEncoding()
    )
    (encoder): Encoder(
      (layers): ModuleList(
        (0): EncoderLayer(
          (self_attn): MultiHeadAttention(
            (W_Q): Linear(in_features=512, out_features=512, bias=True)
            (W_K): Linear(in_features=512, out_features=512, bias=True)
            (W_V): Linear(in_features=512, out_features=512, bias=True)
            (W_O): Linear(in_features=512, out_features=512, bias=True)
          )
          (feed_forward): FeedForwardLayer(
            (W_1): Linear(in_features=512, out_features=1024, bias=True)
            (W_2): Linear(in_features=1024, out_features=512, bias=True)
          )
          (residual_connections): ModuleList(
            (0): ResidualConnectionLayer(
              (norm): NormalizationLayer()
            )
            (1): 

In [9]:
# Fine-tuned hybrid-model checkpoint.
model_path = os.path.join(os.getcwd(), "..", "..", "..", "..", "output/raw/coronaviridae_s_prot_uniref90_embl_vertebrates_aligned_t0.01_c8/20240711/host_multi/fine_tuning_hybrid_cls_vs30/mlm_tfenc_l6_h8_lr1e-4_uniref90viridae_msl128b1024vs30_hybrid_attention_s64_fnn_2l_d1024_lr1e-4_itr4.pth")
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint_state = torch.load(model_path, map_location=nn_utils.get_device())
prediction_model.load_state_dict(checkpoint_state)

<All keys matched successfully>

### t-SNE Analysis

In [10]:
def compute_embeddings(model, dataset_loader):
    """Collect one embedding row per sequence from ``model`` over ``dataset_loader``.

    The model is switched to eval mode and run on every batch. After each forward
    pass the embedding is read from ``model.embedding``, which the forward pass
    presumably populates (verify against the model class). The original author
    describes it as the per-dimension mean over all tokens of the sequence.

    Args:
        model: callable torch module exposing an ``embedding`` tensor after a
            forward pass, with one row per sequence in the batch.
        dataset_loader: iterable yielding ``(sequence_batch, label_batch)`` pairs.

    Returns:
        pandas.DataFrame with one row per sequence. It has one integer-named
        column per embedding dimension plus a ``"label"`` column.
    """
    model.eval()
    seq_dfs = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for seq, label in dataset_loader:
            model(seq)  # result unused; the embedding is read from the attribute below
            labels = label.reshape(-1).cpu().numpy()
            # reshape (not squeeze) keeps the batch axis when a batch holds one sequence
            seq_encoding = model.embedding.reshape(len(labels), -1)
            seq_df = pd.DataFrame(seq_encoding.detach().cpu().numpy())
            seq_df["label"] = labels
            print(seq_df.shape)
            seq_dfs.append(seq_df)
    df = pd.concat(seq_dfs)
    print(df.shape)
    return df

def view_tsne_representation(rep_df, index_label_map, embedding_dim=512, random_state=None):
    """Project embeddings to 2-D with t-SNE and draw a scatter plot coloured by label.

    Args:
        rep_df: DataFrame with embedding columns named ``0 .. embedding_dim-1``
            and a ``"label"`` column of class indices.
        index_label_map: mapping from class index to display label.
        embedding_dim: number of embedding columns to use (default 512, the
            model's input_dim).
        random_state: optional seed forwarded to ``TSNE`` for reproducible layouts.

    Returns:
        ``(tsne_model, X_tsne_emb)``. ``X_tsne_emb`` has 2-D coordinates in
        columns ``0`` and ``1`` plus the mapped ``"label"`` column.
    """
    # Imported locally: TSNE is not imported anywhere else in this notebook.
    from sklearn.manifold import TSNE

    X = rep_df[list(range(embedding_dim))]
    tsne_model = TSNE(n_components=2, verbose=1, init="pca", learning_rate="auto", random_state=random_state)
    # fit_transform fits once and returns the embedding. Calling fit() first would
    # run the whole t-SNE optimisation twice.
    X_tsne_emb = pd.DataFrame(tsne_model.fit_transform(X))
    print(X_tsne_emb.shape)
    X_tsne_emb["label"] = rep_df["label"].values
    X_tsne_emb["label"] = X_tsne_emb["label"].map(index_label_map)

    sns.scatterplot(data=X_tsne_emb, x=0, y=1, hue="label")
    plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
    plt.show()
    return tsne_model, X_tsne_emb

In [11]:
#emb_df = compute_embeddings(prediction_model, dataset_loader)

### Attention value Analysis

In [12]:
def analyze_attention_of_sequence(model, seq):
    """Run ``model`` on ``seq`` in eval mode and return its output.

    Prints the sequence length first. Attention weights are not returned here;
    later cells read them from the model's attention modules after a forward pass.

    Args:
        model: callable torch module.
        seq: input tensor; assumed (batch, seq_len) — TODO confirm.

    Returns:
        Whatever ``model(seq)`` returns.
    """
    # Previously this read an undefined global `seq_len` (NameError).
    seq_len = seq.shape[-1]
    print(f"sequence length = {seq_len}")
    model.eval()
    output = model(seq)
    return output
    

In [13]:
# The WIV04 reference row selected when the dataset was loaded.
wiv04_seq_df

Unnamed: 0,uniref90_id,aligned_seq,seq,virus_name,virus_host_name,human_binary_label
0,WIV04,--------------MFVFLVLLPLVSS--------Q----------...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,WIV04(MN996528.1) Wuhan variant index virus,homo sapiens,homo sapiens


In [14]:
# Length of the aligned WIV04 sequence (alignment columns, gaps included).
# Positional .iloc[0] is used because label-based [0] only works while WIV04 is row 0.
len(wiv04_seq_df["aligned_seq"].iloc[0])

2418

In [15]:
# Single-sequence loader for WIV04. The settings are copied rather than edited in
# place, so the shared `sequence_settings` used for the full-dataset loader keeps
# its batch size when earlier cells are re-run.
wiv04_sequence_settings = {**sequence_settings, "batch_size": 1, "max_sequence_length": 128}

_, wiv04_seq_df_dataset_loader = dataset_utils.load_dataset_with_df(wiv04_seq_df, wiv04_sequence_settings, label_settings, label_col=label_settings["label_col"], classification_type="multi")

Grouping labels using config : {'Chicken': ['gallus gallus'], 'Human': ['homo sapiens'], 'Cat': ['felis catus'], 'Pig': ['sus scrofa'], 'Gray wolf': ['canis lupus'], 'Horshoe bat': ['rhinolophus sp.'], 'Ferret': ['mustela putorius'], 'Chinese rufous horseshoe bat': ['rhinolophus sinicus']}
label_idx_map={'Cat': 0, 'Chicken': 1, 'Chinese rufous horseshoe bat': 2, 'Ferret': 3, 'Gray wolf': 4, 'Horshoe bat': 5, 'Human': 6, 'Pig': 7}
idx_label_map={0: 'Cat', 1: 'Chicken', 2: 'Chinese rufous horseshoe bat', 3: 'Ferret', 4: 'Gray wolf', 5: 'Horshoe bat', 6: 'Human', 7: 'Pig'}


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self._setitem_single_block(indexer, value, name)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  iloc._setitem_with_indexer(indexer, value, self.name)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[label_col] = df[label_col].transform(lambda x: label_idx_map[x] if x in label_idx_map else 0)


In [16]:
# Switch the fine-tuned model to evaluation mode, then take the single WIV04 batch.
prediction_model.eval()
wiv04_batch = next(iter(wiv04_seq_df_dataset_loader))
input, label = wiv04_batch  # NOTE: `input` shadows the builtin; later cells rely on this name

In [17]:
# Forward pass on the WIV04 batch. Later cells read attention weights from the
# model's attention modules, presumably populated by this call.
# NOTE(review): in the saved run this raised "IndexError: index out of range in self",
# which typically comes from an embedding lookup with an index >= num_embeddings
# (e.g. vs. vocab_size=30 set above). TODO: investigate before trusting later cells.
output = prediction_model(input)

IndexError: index out of range in self

In [None]:
# Model output for the WIV04 batch (softmax is applied in a later cell, so presumably unnormalized class scores).
output

In [None]:
# Encoded host-class index of the WIV04 sequence (see idx_label_map printed by the loader).
label

In [None]:
# Per-class probabilities for WIV04; assumes output is (batch, n_classes) — TODO confirm.
torch.softmax(output, dim=1)

In [None]:
# Shape of the attention weights stored on the last pre-trained encoder layer (indexed per segment further below).
prediction_model.pre_trained_model.encoder.layers[-1].self_attn.self_attn.size()

In [None]:
# Shape of the hybrid model's segment-level attention weights after dropping singleton axes.
prediction_model.self_attn.self_attn.squeeze().size()

In [None]:
# Map each alignment column of WIV04 to its 1-based residue position in the
# unaligned sequence. Gap columns ("-") are skipped, except those at segment
# starts (every `stride` columns); these get the most recent residue position so
# that every segment start has a label.
wiv04_aligned_seq = wiv04_seq_df["aligned_seq"].iloc[0]  # positional, not label-based
segment_stride = model["stride"]  # 64

wiv04_position_mapping = {}
residue_pos = 0
for aln_idx, token in enumerate(wiv04_aligned_seq):
    if token != "-":
        residue_pos += 1
        wiv04_position_mapping[aln_idx] = residue_pos
    elif aln_idx % segment_stride == 0:
        wiv04_position_mapping[aln_idx] = residue_pos

In [None]:
# Number of alignment columns that received a WIV04 residue-position label.
len(wiv04_position_mapping)

In [None]:
# Alignment column index -> 1-based WIV04 residue position.
wiv04_position_mapping

In [None]:
# Label each segment j (starting at alignment column 64*j) as "<j>: <start>-<end>",
# using WIV04 residue positions at column 64*j and column 64*j + 128. Labelling
# stops at the first segment whose end column has no mapping (past the alignment).
# NOTE(review): column 64*j + 128 is one past a 128-column segment; confirm
# whether the end label was meant to come from column 64*j + 127.
wiv04_aligned_len = len(wiv04_seq_df["aligned_seq"][0])
pos_mapping = {}
for j, i in enumerate(range(0, wiv04_aligned_len, 64)):
    end_col = i + 128
    if i not in wiv04_position_mapping or end_col not in wiv04_position_mapping:
        break
    pos_mapping[j] = f"{j}: {wiv04_position_mapping[i]}-{wiv04_position_mapping[end_col]}"

In [None]:
# Segment index -> "<segment>: <start>-<end>" label used on the inter-segment heatmap axes.
pos_mapping

In [None]:
# Segment-level attention of the hybrid model with singleton axes removed; indexed per head below.
inter_seg_attn = torch.squeeze(prediction_model.self_attn.self_attn)

In [None]:
# Inter-segment attention: one heatmap per head (4x2 grid = 8 heads), with
# rows/columns labelled by the WIV04 residue range of each segment.
# rc_context scopes the large fonts to this figure instead of mutating global
# rcParams, and no stray empty figure is created (the old plt.clf() created one).
with plt.rc_context({"xtick.labelsize": 40, "ytick.labelsize": 40, "font.size": 40}):
    fig, axs = plt.subplots(4, 2, figsize=(80, 100), sharex=False, sharey=True)
    for head_idx, ax in enumerate(axs.flat):  # row-major: (0,0), (0,1), (1,0), ...
        head_attn_df = pd.DataFrame(inter_seg_attn[head_idx].squeeze().detach().cpu().numpy())
        head_attn_df = head_attn_df.rename(columns=pos_mapping, index=pos_mapping)
        sns.heatmap(head_attn_df, ax=ax, linewidth=.1)
        ax.set_title(f"Head {head_idx}")
    plt.tight_layout(pad=.1)
    plt.show()

In [None]:
# Segment whose token-level (intra-segment) attention is inspected below.
intra_seg_index = 20
# Attention weights stored on the last pre-trained encoder layer, indexed by segment.
last_layer_attn = prediction_model.pre_trained_model.encoder.layers[-1].self_attn.self_attn
intra_seg_attn = last_layer_attn[intra_seg_index].squeeze()

In [None]:
# Shape of the selected segment's attention; assumed (n_heads, tokens, tokens) — TODO confirm.
intra_seg_attn.shape

In [None]:
# Token offset within the selected segment -> WIV04 residue position, or "-" for
# alignment columns without a position. Covers alignment columns
# [64*index, 64*index + 128] inclusive, i.e. 129 offsets.
# NOTE(review): if the attention matrix is CLS + 128 tokens with CLS first, these
# labels are shifted by one; confirm against intra_seg_attn.shape.
intra_seg_start = intra_seg_index * 64
intra_seg_end = intra_seg_index * 64 + 128

intra_seg_pos_map = {
    offset: wiv04_position_mapping.get(aln_idx, "-")
    for offset, aln_idx in enumerate(range(intra_seg_start, intra_seg_end + 1))
}

intra_seg_pos_map

In [None]:
# Intra-segment attention of the selected segment: one heatmap per head
# (4x2 grid = 8 heads), axes labelled by WIV04 residue position ("-" for gaps).
# rc_context scopes the large fonts to this figure and no stray empty figure is created.
with plt.rc_context({"xtick.labelsize": 40, "ytick.labelsize": 40, "font.size": 40}):
    fig, axs = plt.subplots(4, 2, figsize=(80, 100), sharex=False, sharey=False)
    for head_idx, ax in enumerate(axs.flat):  # row-major: (0,0), (0,1), (1,0), ...
        head_attn_df = pd.DataFrame(intra_seg_attn[head_idx].squeeze().detach().cpu().numpy())
        head_attn_df = head_attn_df.rename(columns=intra_seg_pos_map, index=intra_seg_pos_map)
        sns.heatmap(head_attn_df, ax=ax, linewidth=.1)
        ax.set_title(f"Head {head_idx}")
    plt.tight_layout(pad=.1)
    plt.show()

In [None]:
# Mean intra-segment attention over all heads (dim 0 indexes heads, as in the
# per-head plot), with axes labelled by WIV04 residue position.
with plt.rc_context({"xtick.labelsize": 20, "ytick.labelsize": 20, "font.size": 20}):
    fig, ax = plt.subplots(1, 1, figsize=(30, 30), sharex=False, sharey=False)
    mean_attn_df = pd.DataFrame(intra_seg_attn.mean(dim=0).detach().cpu().numpy())
    mean_attn_df = mean_attn_df.rename(columns=intra_seg_pos_map, index=intra_seg_pos_map)
    sns.heatmap(mean_attn_df, ax=ax, linewidth=.1)
    ax.set_title(f"Mean attention over heads, segment {intra_seg_index}")
    plt.tight_layout(pad=.1)
    plt.show()

In [None]:
# Token offsets of the selected segment that carry a WIV04 residue position (label other than "-").
non_dash_indices = list(filter(lambda offset: intra_seg_pos_map[offset] != "-", intra_seg_pos_map))

In [None]:
# Mean intra-segment attention over heads, keeping only query rows whose token
# has a WIV04 residue position.
# NOTE(review): only rows are filtered; columns still include "-" tokens. Use
# .iloc[non_dash_indices, non_dash_indices] if a square, gap-free map was intended.
with plt.rc_context({"xtick.labelsize": 30, "ytick.labelsize": 30, "font.size": 30}):
    fig, ax = plt.subplots(1, 1, figsize=(40, 40), sharex=False, sharey=False)
    mean_attn_df = pd.DataFrame(intra_seg_attn.mean(dim=0).detach().cpu().numpy())
    mean_attn_df = (
        mean_attn_df.iloc[non_dash_indices]
        .rename(columns=intra_seg_pos_map, index=intra_seg_pos_map)
    )
    sns.heatmap(mean_attn_df, ax=ax, linewidth=.1)
    ax.set_title(f"Mean attention over heads, segment {intra_seg_index} (rows labelled '-' removed)")
    plt.tight_layout(pad=.1)
    plt.show()