## Interpretation of TF models trained on UniRef90 dataset for multiclass classification

In [1]:
import sys
import os
# Make project packages importable from this notebook by appending
# <cwd>/../.. and then <cwd>/.. to the module search path (in that order).
sys.path.extend(
    os.path.join(os.getcwd(), *parents)
    for parents in (("..", ".."), ("..",))
)
# Bare expression so the notebook displays the resulting search path.
sys.path

['C:\\Dev\\git\\zoonosis\\src\\jupyter_notebooks\\interpretation',
 'C:\\Users\\bless\\anaconda3\\python39.zip',
 'C:\\Users\\bless\\anaconda3\\DLLs',
 'C:\\Users\\bless\\anaconda3\\lib',
 'C:\\Users\\bless\\anaconda3',
 '',
 'C:\\Users\\bless\\anaconda3\\lib\\site-packages',
 'C:\\Users\\bless\\anaconda3\\lib\\site-packages\\locket-0.2.1-py3.9.egg',
 'C:\\Users\\bless\\anaconda3\\lib\\site-packages\\win32',
 'C:\\Users\\bless\\anaconda3\\lib\\site-packages\\win32\\lib',
 'C:\\Users\\bless\\anaconda3\\lib\\site-packages\\Pythonwin',
 'C:\\Users\\bless\\anaconda3\\lib\\site-packages\\IPython\\extensions',
 'C:\\Users\\bless\\.ipython',
 'C:\\Dev\\git\\zoonosis\\src\\jupyter_notebooks\\interpretation\\..\\..',
 'C:\\Dev\\git\\zoonosis\\src\\jupyter_notebooks\\interpretation\\..']

In [2]:
from prediction.models.nlp import transformer
from src.utils import utils, nn_utils

import torch
import torch.nn.functional as F

import pandas as pd
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.manifold import TSNE



ModuleNotFoundError: No module named 'torch'

In [None]:
# ---------------------------------------------------------------------------
# Input files
# ---------------------------------------------------------------------------
input_dir = "/home/grads/blessyantony/dev/git/zoonosis/input/data/uniref90/splits/s79221635"
train_file_names = ["uniref90_final.csv_tr0.8_train.csv"]
test_file_names = ["uniref90_final.csv_tr0.8_test.csv"]

# ---------------------------------------------------------------------------
# Labels: display name -> list of raw host names folded into that class.
# "*" under "Others" is presumably a catch-all handled by
# utils.transform_labels -- TODO confirm.
# ---------------------------------------------------------------------------
label_groupings = {
    "Human": ["Homo sapiens"],
    "Desert warthog": ["Phacochoerus aethiopicus"],
    "Lesser bandicoot rat": ["Bandicota bengalensis"],
    "Horse": ["Equus caballus"],
    "Goat": ["Capra hircus"],
    "Red junglefowl": ["Gallus gallus"],
    "Wood mouse": ["Apodemus sylvaticus"],
    "Cattle": ["Bos taurus"],
    "Others": ["*"],
}
host_classes = [
    "Homo sapiens",
    "Phacochoerus aethiopicus",
    "Bandicota bengalensis",
    "Equus caballus",
    "Capra hircus",
    "Gallus gallus",
    "Apodemus sylvaticus",
    "Bos taurus",
    "Others",
]

# ---------------------------------------------------------------------------
# Token vocabulary: one-letter residue code -> integer id, starting at 1.
# Id 0 is not assigned to any residue; it is the pad value configured below.
# ---------------------------------------------------------------------------
amino_acid_idx_map = {
    residue: idx
    for idx, residue in enumerate("ARNDCQEGHILKMFPOSUTWYVBZXJ", start=1)
}
# Reverse lookup (integer id -> residue code), used for plot tick labels.
idx_amino_acid_map = dict(zip(amino_acid_idx_map.values(), amino_acid_idx_map))

# ---------------------------------------------------------------------------
# Data-loader settings
# ---------------------------------------------------------------------------
train_sequence_settings = dict(
    sequence_col="seq",
    batch_size=8,
    max_sequence_length=1024,
    pad_sequence_val=0,
    truncate=True,
)

# Same settings as training, except one sequence per batch so that each
# test sample can be inspected individually.
test_sequence_settings = {**train_sequence_settings, "batch_size": 1}

label_settings = dict(
    label_col="virus_host_name",
    exclude_labels=["nan"],
    label_groupings=label_groupings,
)

# ---------------------------------------------------------------------------
# Transformer hyper-parameters passed to transformer.get_transformer_model.
# n_tokens=27 presumably covers the 26 residue ids plus the pad id 0, and
# n_classes=9 matches the nine entries of label_groupings -- TODO confirm.
# ---------------------------------------------------------------------------
model = dict(
    max_seq_len=1024,
    loss="CrossEntropyLoss",
    with_convolution=False,
    n_heads=8,
    depth=6,
    n_tokens=27,
    n_classes=9,
    n_epochs=10,
    dim=512,
    weight_initialization="normal",
)

### Load the datasets

In [None]:
def load_dataset(input_dir, input_file_names, sequence_settings):
    """Read CSV files from disk and wrap them in a dataset loader.

    Only the sequence column (from ``sequence_settings``) and the label
    column (from the notebook-level ``label_settings``) are read. Labels are
    transformed for multiclass classification before the loader is built.

    Args:
        input_dir: Directory containing the input files.
        input_file_names: File names inside ``input_dir`` to read.
        sequence_settings: Sequence/loader settings; must contain
            ``"sequence_col"`` and is forwarded to the loader factory.

    Returns:
        Tuple ``(index_label_map, dataset_loader)`` as produced by
        ``utils.transform_labels`` and ``nn_utils.get_dataset_loader``.
    """
    label_col = label_settings["label_col"]
    wanted_cols = [sequence_settings["sequence_col"], label_col]

    raw_df = utils.read_dataset(input_dir, input_file_names, cols=wanted_cols)
    labelled_df, index_label_map = utils.transform_labels(
        raw_df, label_settings, classification_type="multi"
    )
    loader = nn_utils.get_dataset_loader(labelled_df, sequence_settings, label_col)
    return index_label_map, loader

def load_dataset_with_df(df, sequence_settings):
    """Wrap an already-loaded DataFrame in a dataset loader.

    In-memory counterpart of ``load_dataset``: the frame is reduced to the
    sequence and label columns, labels are transformed for multiclass
    classification, and a loader is built from the result.

    Args:
        df: DataFrame containing at least the sequence column named in
            ``sequence_settings`` and the label column named in the
            notebook-level ``label_settings``.
        sequence_settings: Sequence/loader settings; must contain
            ``"sequence_col"`` and is forwarded to the loader factory.

    Returns:
        Tuple ``(index_label_map, dataset_loader)``.
    """
    label_col = label_settings["label_col"]
    selected_df = df[[sequence_settings["sequence_col"], label_col]]
    labelled_df, index_label_map = utils.transform_labels(
        selected_df, label_settings, classification_type="multi"
    )
    loader = nn_utils.get_dataset_loader(labelled_df, sequence_settings, label_col)
    return index_label_map, loader

def print_dataset_loader(dataset_loader):
    """Print shape and contents of the first batch yielded by a loader.

    The first item of ``dataset_loader`` must unpack into a
    ``(sequence, label)`` pair; for each of the two, its ``.shape`` is
    printed followed by the object itself (sequence first, then label).
    """
    first_batch = next(iter(dataset_loader))
    sequence, label = first_batch
    for batch_part in (sequence, label):
        print(batch_part.shape)
        print(batch_part)

### Training-based interpretation
#### Encoding visualization - all viruses, all hosts

In [None]:
def compute_dataset_representations(nlp_model, dataset_loader):
    """Compute one mean-pooled encoding vector per sequence in a loader.

    For every batch the model is run in eval mode and the tensor exposed as
    ``nlp_model.encoder.encoding`` (presumably cached by the forward pass --
    TODO confirm) is averaged over dim 1. Assuming the encoding is laid out
    as (batch, tokens, dim), this gives one ``dim``-sized vector per sequence.

    Args:
        nlp_model: Model exposing ``eval()``, ``__call__(seq)`` and
            ``encoder.encoding``.
        dataset_loader: Iterable of ``(seq, label)`` batches.

    Returns:
        DataFrame with one row per sequence: integer columns ``0..dim-1``
        holding the pooled encoding, plus a ``"label"`` column. Row indices
        restart at 0 for every batch (batches are concatenated as-is).

    Raises:
        ValueError: If ``dataset_loader`` yields no batches (raised by
            ``pd.concat``).
    """
    nlp_model.eval()
    seq_dfs = []
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for seq, label in dataset_loader:
            nlp_model(seq)
            # Pool over dim 1 only. A bare squeeze() would also drop the batch
            # axis when the batch holds a single sequence, turning the
            # (1, dim) row into a (dim,) vector and hence a dim x 1 frame.
            seq_encoding = torch.mean(nlp_model.encoder.encoding, dim=1)
            seq_df = pd.DataFrame(seq_encoding.cpu().detach().numpy())
            # reshape(-1) keeps a 1-D label vector even for a batch of one.
            seq_df["label"] = label.reshape(-1).cpu().detach().numpy()
            seq_dfs.append(seq_df)
    df = pd.concat(seq_dfs)
    print(df.shape)
    return df


def visualize_dataset(rep_df):
    """Project sequence representations to 2-D with t-SNE.

    Despite the name, nothing is plotted here: the function fits t-SNE and
    returns the embedding so that the caller can plot it.

    Args:
        rep_df: DataFrame of encodings with a ``"label"`` column; every
            other column is treated as a feature.

    Returns:
        Tuple ``(tsne_model, X_emb)`` where ``X_emb`` has columns ``0`` and
        ``1`` (the 2-D coordinates) plus ``"label"`` copied positionally
        from ``rep_df``.
    """
    columns = rep_df.columns
    print(columns)
    # Use every non-label column as a feature rather than hard-coding the
    # encoding width (512).
    X = rep_df.drop(columns="label")
    # Fit once: fit_transform both fits the model and returns the embedding.
    # Calling fit() first and fit_transform() afterwards ran t-SNE twice.
    tsne_model = TSNE(n_components=2, verbose=1, init="pca", learning_rate="auto")
    X_emb = pd.DataFrame(tsne_model.fit_transform(X))
    print(X_emb.shape)
    print(X_emb)
    # .values: align by position, since rep_df's index may contain repeats.
    X_emb["label"] = rep_df["label"].values
    return tsne_model, X_emb
    
def visualize_prediction(nlp_model, seq, label, rep_df):
    """Plot one sample's encoding among reference encodings using t-SNE.

    The sample is run through the model, its mean-pooled encoding is
    appended to the reference representations (with the "Others" class
    removed) and everything is embedded in 2-D. The sample appears in the
    scatter plot under the hue ``"prediction-<predicted class name>"``.

    Relies on the notebook-level ``index_label_map`` to translate class
    indices to names.

    Args:
        nlp_model: Model exposing ``eval()``, ``__call__(seq)`` and
            ``encoder.encoding``.
        seq: Input batch holding a single sequence (``.item()`` is called
            on the prediction and on ``label``).
        label: True label of the sample (single element).
        rep_df: Reference representations with a ``"label"`` column of
            class indices; every other column is treated as a feature.
    """
    nlp_model.eval()
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        output = nlp_model(seq)
        # Mean over dim 1 of the encoder output (assumed to be the token axis).
        seq_encoding = torch.mean(nlp_model.encoder.encoding, dim=1)

    seq_df = pd.DataFrame(seq_encoding.cpu().detach().numpy())
    sample_pred = torch.argmax(F.softmax(output, dim=1), dim=1)
    print(f"Label {label} = {index_label_map[label.item()]}")
    sample_pred_mapped = index_label_map[sample_pred.item()]
    print(f"Prediction {sample_pred}= {sample_pred_mapped}")
    # Tag the sample with its predicted class so it stands out in the legend.
    seq_df["label"] = "prediction-" + sample_pred_mapped

    rep_df_copy = rep_df.copy()
    rep_df_copy["label"] = rep_df["label"].map(index_label_map)
    rep_df_copy = rep_df_copy[rep_df_copy["label"] != "Others"]
    rep_seq_df = pd.concat([rep_df_copy, seq_df])
    print(f"rep_seq_df shape = {rep_seq_df.shape}")
    # Every non-label column is a feature; avoids hard-coding the width 512.
    X = rep_seq_df.drop(columns="label")
    print(f"X shape = {X.shape}")

    # Single fit_transform: fit() followed by fit_transform() ran t-SNE twice.
    tsne_model = TSNE(n_components=2, verbose=1, init="pca", learning_rate="auto")
    X_emb = pd.DataFrame(tsne_model.fit_transform(X))
    print(f"X_emb shape = {X_emb.shape}")
    X_emb["label"] = rep_seq_df["label"].values
    print(f"X_emb shape = {X_emb.shape}")
    sns.scatterplot(data = X_emb, x=0, y=1, hue="label")
    plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
    plt.show()

### Testing-based interpretation
#### Attention based interpretation

In [None]:
def compute_mean_attn_values(nlp_model):
    """Return the attention of encoder layer index 5 averaged over dim 0.

    Reads the tensor stored at
    ``nlp_model.encoder.layers[5].self_attn.self_attn``, drops all size-1
    dimensions and averages over the first remaining dimension.

    Assumes the stored tensor is (1, heads, L, L) from a single-sequence
    forward pass, so that the average is taken over heads -- TODO confirm.
    Index 5 is the sixth layer, i.e. the last one for the depth-6 model
    configured in this notebook.
    """
    layer_attn = nlp_model.encoder.layers[5].self_attn.self_attn
    return layer_attn.squeeze().mean(dim=0)


def plot_mean_attention_values(x, seq=None, seq_len=None):
    """Draw a heatmap of the top-left ``seq_len`` x ``seq_len`` attention block.

    Both axes are labelled with the residue letters of the first
    ``seq_len`` tokens of ``seq`` (looked up in the notebook-level
    ``idx_amino_acid_map``).

    Note: the tick label sizes are set through ``plt.rcParams`` and
    therefore stay in effect for later plots as well.

    Args:
        x: 2-D attention tensor.
        seq: Token-id tensor of the sequence; required despite the default.
        seq_len: Number of leading positions to display.
    """
    token_ids = seq.cpu().detach().numpy().squeeze()[:seq_len]
    residue_letters = [idx_amino_acid_map[token_id] for token_id in token_ids]

    plt.rcParams.update({'xtick.labelsize': 5, 'ytick.labelsize': 5})
    plt.figure(figsize=(12,12))

    attn_matrix = x.cpu().detach().numpy()
    visible_block = attn_matrix[:seq_len, :seq_len]
    sns.heatmap(data=visible_block, xticklabels=residue_letters, yticklabels=residue_letters)
    plt.show()


def plot_mean_of_mean_attention_values(x, seq=None, seq_len=None, seq_max_length=None):
    """Scatter-plot one attention value per position and print the top 10.

    ``x`` is averaged over dim 0, giving one value per position. Token ids
    are mapped to residue letters through the notebook-level
    ``idx_amino_acid_map``. Ids without an entry there (such as pad id 0)
    become NaN and are dropped.

    Args:
        x: 2-D attention tensor whose second dimension has
            ``seq_max_length`` entries.
        seq: Token-id tensor with ``seq_max_length`` tokens after squeezing;
            required despite the default.
        seq_len: Unused.
        seq_max_length: Padded sequence length; positions are numbered
            ``0 .. seq_max_length - 1``.
    """
    token_ids = seq.cpu().detach().numpy().squeeze()
    per_position_attn = torch.mean(x, dim=0).cpu().detach().numpy()

    attn_df = pd.DataFrame({"tokens": token_ids, "attn_vals": per_position_attn, "pos": range(seq_max_length)})
    attn_df["tokens"] = attn_df["tokens"].map(idx_amino_acid_map)
    attn_df = attn_df.dropna()

    top_attn_df = attn_df.sort_values(by="attn_vals", ascending=False).head(10)
    print("Top 10 tokens + positions with highest attention values for the whole sequence")
    print(top_attn_df)

    plt.rcParams.update({'xtick.labelsize': 8, 'ytick.labelsize': 8})
    plt.figure(figsize=(12,6))
    sns.scatterplot(data=attn_df, x="pos", y="attn_vals", hue="tokens")
    plt.show()
    
def analyze_attention_of_prediction(nlp_model, sample_seq, sample_label, seq_max_length):
    """Predict one sample and plot the attention the model assigned to it.

    Prints the true and predicted class names (through the notebook-level
    ``index_label_map``), then draws the attention heatmap and the
    per-position attention scatter plot for the sample.

    Args:
        nlp_model: Trained model; switched to eval mode here.
        sample_seq: Token-id tensor holding a single sequence. Its length
            is taken as the count of non-zero ids, which assumes 0 is used
            only for padding.
        sample_label: True label (single element).
        seq_max_length: Padded sequence length, forwarded to the
            per-position plot.
    """
    # Plain int, consistent with analyze_attention_of_df and safe for slicing.
    seq_len = torch.count_nonzero(sample_seq).item()
    print(sample_seq.shape)
    print(f"seq_len = {seq_len}")

    nlp_model.eval()
    # One forward pass serves both the prediction and the attention read-out;
    # the model was previously run twice and the first result thrown away.
    with torch.no_grad():
        output = nlp_model(sample_seq)
    sample_pred = torch.argmax(F.softmax(output, dim=1), dim=1)
    print(f"Label = {index_label_map[sample_label.item()]}")
    print(f"Prediction = {index_label_map[sample_pred.item()]}")
    mean_attn_values = compute_mean_attn_values(nlp_model)

    plot_mean_attention_values(mean_attn_values, seq=sample_seq, seq_len=seq_len)
    plot_mean_of_mean_attention_values(mean_attn_values, seq=sample_seq, seq_len=seq_len, seq_max_length=seq_max_length)
    
def analyze_attention_of_df(nlp_model, dataset_loader, seq_max_length):
    """Plot a (sample x position) heatmap of attention over a whole loader.

    Each batch contributes one row: the attention matrix returned by
    ``compute_mean_attn_values`` averaged over dim 0. Columns are cropped
    to the longest sequence seen.

    Args:
        nlp_model: Trained model; switched to eval mode here, as the other
            analysis helpers do.
        dataset_loader: Iterable of ``(seq, label)`` batches. Sequence
            length is the count of non-zero ids in the batch, so this
            assumes one sequence per batch and pad id 0.
        seq_max_length: Unused; kept for call compatibility.

    Raises:
        ValueError: If ``dataset_loader`` yields no batches (raised by
            ``np.concatenate``).
    """
    nlp_model.eval()
    attn_rows = []
    max_seq_len_actual = 0
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        for seq, _label in dataset_loader:
            max_seq_len_actual = max(max_seq_len_actual, torch.count_nonzero(seq).item())
            # Forward pass is run only so the attention tensor can be read.
            nlp_model(seq)
            mean_attn_values = compute_mean_attn_values(nlp_model)
            # keepdim=True keeps a leading axis of 1 so rows can be stacked.
            mean_of_mean = torch.mean(mean_attn_values, dim=0, keepdim=True)
            attn_rows.append(mean_of_mean.cpu().detach().numpy())
    print("max_seq_len_actual = ", max_seq_len_actual)
    attn_df = np.concatenate(attn_rows, axis=0)
    plt.figure(figsize=(12,12))
    sns.heatmap(data=attn_df[:,:max_seq_len_actual])
    plt.show()

 ### Analysis Pipeline

In [None]:
def analyse_model(model, train_dataset_loader, test_dataset_loader, seq, label, seq_max_length, viz_train=False, viz_test=False):
    """Run the interpretation pipeline for one trained model.

    Steps:
      1. ``viz_train``: embed the training set and plot the given sample's
         predicted position among the training encodings.
      2. ``viz_test``: embed the test set with t-SNE and plot it.
      3. Always: attention plots for the single sample, then the attention
         heatmap over the whole test loader.

    Args:
        model: Trained network. Inside this function the name shadows the
            notebook-level ``model`` configuration dict.
        train_dataset_loader: Loader used for step 1 only.
        test_dataset_loader: Loader used for steps 2 and 3.
        seq: Single-sequence batch to inspect.
        label: True label of ``seq``.
        seq_max_length: Padded sequence length.
        viz_train: Enable step 1.
        viz_test: Enable step 2.
    """
    if viz_train:
        train_rep_df = compute_dataset_representations(model, train_dataset_loader)
        visualize_prediction(model, seq, label, train_rep_df)
    if viz_test:
        test_rep_df = compute_dataset_representations(model, test_dataset_loader)
        # visualize_dataset only computes the embedding. Its result used to be
        # discarded here, so viz_test=True produced no figure at all.
        _, test_emb = visualize_dataset(test_rep_df)
        # Class indices as strings so seaborn treats the hue as categorical.
        test_emb["label"] = test_emb["label"].astype(str)
        sns.scatterplot(data=test_emb, x=0, y=1, hue="label")
        plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
        plt.show()

    analyze_attention_of_prediction(model, seq, label, seq_max_length)

    analyze_attention_of_df(model, test_dataset_loader, seq_max_length)

#### UniRef90 Datasets
19k
all viruses, all hosts, all proteins, without duplicates and single hosts

In [None]:
# Build the training loader (batch_size 8 per train_sequence_settings) and show its first batch.
# index_label_map is a notebook-level global that the analysis functions read to turn
# class indices into label names.
index_label_map, train_dataset_loader = load_dataset(input_dir, train_file_names, train_sequence_settings)
print_dataset_loader(train_dataset_loader)

In [None]:
# Build the test loader (batch_size 1 per test_sequence_settings) and show its first batch.
# NOTE: this rebinds the global index_label_map to the mapping derived from the test split.
index_label_map, test_dataset_loader = load_dataset(input_dir, test_file_names, test_sequence_settings)
print_dataset_loader(test_dataset_loader)
# First (seq, label) batch yielded by test_dataset_loader; it is only a random
# sample if the loader shuffles -- TODO confirm in nn_utils.get_dataset_loader.
test_seq, test_label = next(iter(test_dataset_loader))

#### UniProtKB Coronavirus only

In [None]:
# Load the Coronaviridae sequences (file name indicates the top 7 hosts) and preview the first rows.
uniref90_coronaviruses_df = pd.read_csv("/home/grads/blessyantony/dev/git/zoonosis/input/data/coronaviridae/coronaviridae_top_7_hosts.csv")
uniref90_coronaviruses_df.head()

In [None]:
# Number of rows per raw host string.
uniref90_coronaviruses_df["virus_host"].value_counts()

In [None]:
# Keep only the rows whose host is human; the raw host string includes the common name and TaxID.
uniref90_coronaviruses_humans_df = uniref90_coronaviruses_df[uniref90_coronaviruses_df["virus_host"] == "Homo sapiens (Human) [TaxID: 9606]"]
print(uniref90_coronaviruses_humans_df.shape)
uniref90_coronaviruses_humans_df.head()

In [None]:
# Normalise the host to the plain species name used in label_groupings and
# rename the column to the label column expected by label_settings.
# assign/rename return a new frame instead of mutating the boolean-filtered
# slice in place, which avoids pandas' chained-assignment
# (SettingWithCopyWarning) pitfall.
uniref90_coronaviruses_humans_df = (
    uniref90_coronaviruses_humans_df
    .assign(virus_host="Homo sapiens")
    .rename(columns={"virus_host": "virus_host_name"})
)
uniref90_coronaviruses_humans_df.head()

In [None]:
# Distribution of sequence lengths for the human-host coronavirus sequences,
# with the extremes printed for reference.
sns.histplot(uniref90_coronaviruses_humans_df["seq_len"])
print(f"min seq len = {min(uniref90_coronaviruses_humans_df['seq_len'])}")
print(f"max seq len = {max(uniref90_coronaviruses_humans_df['seq_len'])}")
plt.show()

In [None]:
# Build a batch_size-1 loader over the human-host coronavirus sequences.
# NOTE(review): the index->label map returned for this frame is discarded, and later
# cells decode labels with the global index_label_map from the UniRef90 test split.
# Confirm that utils.transform_labels assigns the same indices for both.
_, coronavirus_dataset_loader = load_dataset_with_df(uniref90_coronaviruses_humans_df, test_sequence_settings)
print_dataset_loader(coronavirus_dataset_loader)
# First (seq, label) batch yielded by coronavirus_dataset_loader; it is only a random
# sample if the loader shuffles -- TODO confirm in nn_utils.get_dataset_loader.
coronavirus_seq, coronavirus_label = next(iter(coronavirus_dataset_loader))

### Load the trained model

#### Model: TF - PosEmb_SINCOS - MSL_1024 - d_512
#### Manual Seed = 0

In [None]:
# Checkpoint of the transformer trained with manual seed 0.
model_path = "/home/grads/blessyantony/dev/git/zoonosis/output/raw/uniref90/20230531/host_multi-seed0/transformer-crossentropy_itr4.pth"

# Rebuild the architecture from the `model` config dict, then restore the weights.
nlp_model = transformer.get_transformer_model(model)
# map_location lets a checkpoint saved on a GPU be loaded when the current
# device differs (for example a CPU-only machine).
nlp_model.load_state_dict(torch.load(model_path, map_location=nn_utils.get_device()))
nlp_model = nlp_model.to(nn_utils.get_device())


In [None]:
# Currently loaded nlp_model on UniRef90: viz_train=True plots the test sample among the
# training encodings (t-SNE), then the attention analyses run on the test sample and test loader.
analyse_model(nlp_model, train_dataset_loader, test_dataset_loader, test_seq, test_label, seq_max_length=1024, viz_train=True, viz_test=False)

In [None]:
# Same model on the coronavirus data: both viz flags are off, so only the attention analyses
# run (on the coronavirus sample and loader); train_dataset_loader is passed but not used.
analyse_model(nlp_model, train_dataset_loader, coronavirus_dataset_loader, coronavirus_seq, coronavirus_label, seq_max_length=1024, viz_train=False, viz_test=False)

#### Model: TF - PosEmb_SINCOS - MSL_1024 - d_512
#### Manual Seed = 170638

In [None]:
# Checkpoint of the transformer trained with manual seed 170638.
model_path = "/home/grads/blessyantony/dev/git/zoonosis/output/raw/uniref90/20230531/host_multi-seed170638/transformer-crossentropy_itr4.pth"

# Rebuild the architecture from the `model` config dict, then restore the weights.
nlp_model = transformer.get_transformer_model(model)
# map_location lets a checkpoint saved on a GPU be loaded when the current
# device differs (for example a CPU-only machine).
nlp_model.load_state_dict(torch.load(model_path, map_location=nn_utils.get_device()))
nlp_model = nlp_model.to(nn_utils.get_device())


In [None]:
# Currently loaded nlp_model on UniRef90: viz_train=True plots the test sample among the
# training encodings (t-SNE), then the attention analyses run on the test sample and test loader.
analyse_model(nlp_model, train_dataset_loader, test_dataset_loader, test_seq, test_label, seq_max_length=1024, viz_train=True, viz_test=False)

In [None]:
# Same model on the coronavirus data: both viz flags are off, so only the attention analyses
# run (on the coronavirus sample and loader); train_dataset_loader is passed but not used.
analyse_model(nlp_model, train_dataset_loader, coronavirus_dataset_loader, coronavirus_seq, coronavirus_label, seq_max_length=1024, viz_train=False, viz_test=False)

#### Model: TF - PosEmb_SINCOS - MSL_1024 - d_512
#### Manual Seed = 745540

In [None]:
# Checkpoint of the transformer trained with manual seed 745540.
model_path = "/home/grads/blessyantony/dev/git/zoonosis/output/raw/uniref90/20230531/host_multi-seed745540/transformer-crossentropy_itr4.pth"

# Rebuild the architecture from the `model` config dict, then restore the weights.
nlp_model = transformer.get_transformer_model(model)
# map_location lets a checkpoint saved on a GPU be loaded when the current
# device differs (for example a CPU-only machine).
nlp_model.load_state_dict(torch.load(model_path, map_location=nn_utils.get_device()))
nlp_model = nlp_model.to(nn_utils.get_device())


In [None]:
# Currently loaded nlp_model on UniRef90: viz_train=True plots the test sample among the
# training encodings (t-SNE), then the attention analyses run on the test sample and test loader.
analyse_model(nlp_model, train_dataset_loader, test_dataset_loader, test_seq, test_label, seq_max_length=1024, viz_train=True, viz_test=False)

In [None]:
# Same model on the coronavirus data: both viz flags are off, so only the attention analyses
# run (on the coronavirus sample and loader); train_dataset_loader is passed but not used.
analyse_model(nlp_model, train_dataset_loader, coronavirus_dataset_loader, coronavirus_seq, coronavirus_label, seq_max_length=1024, viz_train=False, viz_test=False)

In [None]:
from Bio import motifs
from Bio.Seq import Seq
from Bio.Alphabet import generic_protein

In [None]:
# Collect the first 900 residues of every human-host coronavirus sequence that
# is at least 900 long, so that all motif instances share the same length.
instances = [
    Seq(sequence[:900])
    for sequence in uniref90_coronaviruses_humans_df["seq"]
    if len(sequence) >= 900
]

In [None]:
# Build a motif object from the equal-length sequence prefixes in `instances`.
# NOTE(review): generic_protein comes from Bio.Alphabet, which was removed in
# Biopython 1.78; this cell presumably only runs on older releases -- confirm
# the pinned Biopython version.
m = motifs.create(instances, alphabet=generic_protein)