In [91]:
#@title Imports
import torch
import random
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import r2_score
from sklearn.svm import SVR
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold
from sklearn.utils import resample
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from scipy.stats import spearmanr
from sklearn.metrics import ndcg_score
from sklearn.svm import SVR
from sklearn.linear_model import Ridge
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, RobustScaler, MinMaxScaler, normalize
from sklearn.model_selection import KFold
from sklearn.utils import resample
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.ticker as ticker
from scipy.stats import spearmanr
from scipy.stats import pearsonr
from sklearn.metrics import ndcg_score
from sklearn.kernel_ridge import KernelRidge
from sklearn.svm import SVR, SVC
from sklearn.linear_model import Ridge
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling
from datasets import Dataset
from evaluate import load
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig
# Use the CUDA device when torch reports one as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [92]:
#@title Load sequences from csv
def _read_sequences_and_labels(csv_path):
    """Read one CSV file and return ``(frame, sequences, labels)``.

    The 'sequences' and 'labels' columns are returned as plain Python lists;
    the DataFrame itself is returned as well so `df` keeps being rebound.
    """
    frame = pd.read_csv(csv_path)
    return frame, frame['sequences'].tolist(), frame['labels'].tolist()


# LazBF: sequences file and sample file.
df, LazBF_sequences, LazBF_labels = _read_sequences_and_labels('./LazBF_sequences.csv')
df, LazBF_sample, LazBF_sample_labels = _read_sequences_and_labels('./LazBF_sample.csv')

# LazDEF: sequences file and sample file.
df, LazDEF_sequences, LazDEF_labels = _read_sequences_and_labels('./LazDEF_sequences.csv')
df, LazDEF_sample, LazDEF_sample_labels = _read_sequences_and_labels('./LazDEF_sample.csv')

In [93]:
# Sequence classifiers restored from local checkpoints, then moved to `device`
# and switched to evaluation mode.
LazBF_model = AutoModelForSequenceClassification.from_pretrained('./LazBF_ft/checkpoint-9766')
LazBF_model = LazBF_model.to(device).eval()

LazDEF_model = AutoModelForSequenceClassification.from_pretrained('./LazDEF_ft/checkpoint-9766')
LazDEF_model = LazDEF_model.to(device).eval()

# Tokenizer loaded from the ESM-2 checkpoint named below; shared by both models
# throughout the notebook.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

In [None]:
#@title Trainers
# Arguments shared by both Trainer objects. In the cells visible in this
# notebook the trainers are only used through `.evaluate(...)`.
# NOTE(review): recent transformers releases rename `evaluation_strategy` to
# `eval_strategy` — confirm against the installed version before changing it.
training_args = TrainingArguments(
    output_dir="esm_finetuned",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-4,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    num_train_epochs=2,
    weight_decay=0.01,
    push_to_hub=False,
    fp16=True,
    load_best_model_at_end=True,
    gradient_accumulation_steps=2,
)

# `datasets.load_metric` is deprecated and missing from newer `datasets`
# releases. `evaluate.load` (imported at the top of the notebook as `load`)
# is its replacement and exposes the same `.compute(predictions=, references=)`.
metric = load('accuracy')

def compute_metrics(eval_pred):
    """Score predictions with the notebook-level `metric`.

    Args:
        eval_pred: A pair ``(predictions, labels)``. `predictions` holds one
            row of per-class scores per example; the arg-max along axis 1 is
            taken as the predicted class.

    Returns:
        Whatever ``metric.compute`` returns for the predicted classes compared
        against `labels`.
    """
    scores, references = eval_pred
    predicted_classes = np.argmax(scores, axis=1)
    return metric.compute(predictions=predicted_classes, references=references)

def _build_trainer(classifier):
    """Wrap `classifier` in a Trainer using the shared arguments and metric."""
    return Trainer(
        model=classifier,
        args=training_args,
        compute_metrics=compute_metrics
    )


# One trainer per classifier; both share `training_args` and `compute_metrics`.
LazBF_trainer = _build_trainer(LazBF_model)

LazDEF_trainer = _build_trainer(LazDEF_model)

In [None]:
# Tokenize each sample into a Dataset and attach its labels as a "labels" column.
lbf = Dataset.from_dict(tokenizer(LazBF_sample)).add_column("labels", LazBF_sample_labels)

ldef = Dataset.from_dict(tokenizer(LazDEF_sample)).add_column("labels", LazDEF_sample_labels)

In [None]:
# LazBF classifier evaluated on the tokenized LazBF sample.
LazBF_trainer.evaluate(lbf)

In [9]:
# LazBF classifier evaluated on the tokenized LazDEF sample.
LazBF_trainer.evaluate(ldef)

{'eval_loss': 3.8547000885009766,
 'eval_accuracy': 0.5089,
 'eval_runtime': 9.5973,
 'eval_samples_per_second': 5209.791,
 'eval_steps_per_second': 40.741}

In [10]:
# LazDEF classifier evaluated on the tokenized LazBF sample.
LazDEF_trainer.evaluate(lbf)

{'eval_loss': 1.1414541006088257,
 'eval_accuracy': 0.69678,
 'eval_runtime': 8.9048,
 'eval_samples_per_second': 5614.929,
 'eval_steps_per_second': 43.909}

In [11]:
# LazDEF classifier evaluated on the tokenized LazDEF sample.
LazDEF_trainer.evaluate(ldef)

{'eval_loss': 0.025831105187535286,
 'eval_accuracy': 0.99184,
 'eval_runtime': 9.495,
 'eval_samples_per_second': 5265.948,
 'eval_steps_per_second': 41.18}

---

In [12]:
#@title Interpretability

from torch import tensor
import matplotlib.colors as mcolors
from transformers.pipelines import TextClassificationPipeline
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, IntegratedGradients
from matplotlib.font_manager import FontProperties

import matplotlib.pyplot as plt
# Softmax over dim=1; in this cell it is only referenced from a commented-out line.
sm = torch.nn.Softmax(dim=1)


class ExplainableTransformerPipeline():
    """Layer-integrated-gradients attributions for a sequence classifier.

    Token ids always come from the notebook-level `tokenizer`. The classifier
    is the `model` handed to the constructor or, when none is given, the
    notebook-level `model` looked up at call time (the original behaviour).
    """

    def __init__(self, name, device, model=None):
        """Store the configuration.

        Args:
            name: Free-form label; stored but not otherwise used by the class.
            device: Device on which input and baseline tensors are created.
            model: Optional classifier to explain. ``None`` (the default)
                means "use the notebook-level ``model`` when a method runs".
        """
        self.__name = name
        self.__device = device
        self.__model = model

    def _resolve_model(self):
        """Return the injected classifier, else the notebook-level `model`."""
        if self.__model is not None:
            return self.__model
        return model

    def forward_func(self, inputs: torch.Tensor, position = 0):
        """Run the classifier on `inputs` and return element `position` of its output."""
        pred = self._resolve_model().forward(inputs)
        return pred[position]

    def visualize(self, inputs: torch.Tensor, attributes: torch.Tensor, prediction):
        """Draw one coloured tile per token, shaded by its summed attribution.

        Args:
            inputs: Token-id tensor whose first row holds the ids that were
                attributed (as built by `generate_inputs`).
            attributes: Attribution tensor; it is summed over its last axis
                and the first row of the result is plotted.
            prediction: Indexable model output; ``prediction[0][0]`` is printed.

        The first and last token are dropped from the plot.
        """
        attr_sum = attributes.sum(-1)
        y = np.array([float(at) for at in attr_sum[0]])

        # Decode the ids that were actually passed in. The previous code read
        # an undefined global `sample` here.
        words = [tokenizer.decode(i) for i in inputs[0].tolist()]

        # Colour scale. These two definitions had been commented out, which
        # left `cmap` and `norm` undefined; the scale spans all scores,
        # including the two tokens that are not drawn.
        cmap = plt.get_cmap('cividis')
        norm = plt.Normalize(y.min(), y.max())

        letters = np.array(words[1:-1])
        colors = np.array([mcolors.to_hex(cmap(norm(datapoint))) for datapoint in y[1:-1]])
        positions = np.arange(len(letters))

        print("Model's prediction:", str(float(prediction[0][0])))

        fig, ax = plt.subplots()

        for i, letter in enumerate(letters):
            rect = plt.Rectangle((positions[i], 0), 1, 1, color=colors[i])
            ax.add_patch(rect)
            ax.text(positions[i]+0.5, 0.5, letter, ha='center', va='center', fontsize=100)

        ax.set_xlim([0, len(letters)])
        ax.set_ylim([0, 1])
        ax.set_xticks([])
        ax.set_yticks([])
        fig.set_size_inches(200, 10)
        plt.show()

    def explain(self, text: str, label):
        """Return one summed attribution score per token of `text`.

        Args:
            text: Sequence to tokenize and explain.
            label: Target passed to the attribution method.

        Returns:
            1-D numpy array of attributions summed over the last axis, with
            the first and last token removed.
        """
        inputs = self.generate_inputs(text)
        baseline = self.generate_baseline(inputs.shape[1])
        # The layer whose integrated gradients are computed is the embedding layer.
        lig = LayerIntegratedGradients(self.forward_func, self._resolve_model().esm.embeddings)
        # The convergence delta and a separate forward pass used to be computed
        # here and then discarded; neither is needed for the returned scores.
        attributes = lig.attribute(inputs=inputs,
                                   baselines=baseline,
                                   target=label)
        attr_sum = attributes.sum(-1)
        return attr_sum.detach().cpu().numpy()[0][1:-1]

    def generate_inputs(self, text: str) -> torch.Tensor:
        """Encode `text` (with special tokens) into a (1, n_tokens) id tensor on the configured device."""
        return torch.tensor(tokenizer.encode(text, add_special_tokens=True), device=self.__device).unsqueeze(0)

    def generate_baseline(self, sequence_len: int) -> torch.Tensor:
        """Build a (1, sequence_len) baseline: cls id, pad ids, then eos id."""
        return torch.tensor([tokenizer.cls_token_id] + [tokenizer.pad_token_id] * (sequence_len - 2) + [tokenizer.eos_token_id], device = self.__device).unsqueeze(0)

In [5]:
#@title Amino acid dictionary
# One-letter amino-acid code -> index 0..19, following the order of the string below.
amino_acids = {residue: index for index, residue in enumerate('RHKDESTNQCGPAVILMFYW')}

In [14]:
# Zero-initialised accumulators for the averaged attributions (LazBF model).
avg_contribBF = np.zeros((20,))            # summed per amino acid
avg_positionBF = np.zeros((11,))           # summed per sequence position
avg_position_contribBF = np.zeros((20, 11))  # summed per (amino acid, position)

avg_position_contribBF_counter = np.zeros((20, 11))  # occurrences per (amino acid, position)

# Define model and aa_counter
# `model` is the notebook-level name the explainer reads at call time; the
# first constructor argument is only a stored label.
model = LazBF_model
exp_model = ExplainableTransformerPipeline('distilbert', device)
aa_counts = np.zeros((20,))

for peptide in tqdm(LazBF_sample):
  contributions = exp_model.explain(peptide, 1)
  for i, letter in enumerate(peptide):
    aa_index = amino_acids[letter]
    # Add to position avg
    avg_positionBF[i] += contributions[i]
    # Add to pos x AA average
    avg_position_contribBF[aa_index][i] += contributions[i]
    avg_position_contribBF_counter[aa_index][i] += 1
    # Add to aa average
    avg_contribBF[aa_index] += contributions[i]
    # Count amino acid types
    aa_counts[aa_index] += 1

# Turn the sums into means. Combinations that never occur divide 0 by 0; they
# are deliberately left as NaN (later cells replace NaN via np.nan_to_num), so
# the expected RuntimeWarning is silenced instead of being printed.
with np.errstate(divide='ignore', invalid='ignore'):
  avg_position_contribBF = avg_position_contribBF / avg_position_contribBF_counter
  avg_contribBF = avg_contribBF / aa_counts
# Average over the number of peptides actually processed rather than a
# hard-coded 50000.
avg_positionBF = avg_positionBF / len(LazBF_sample)

100%|██████████| 50000/50000 [1:52:10<00:00,  7.43it/s]
  avg_position_contribBF = avg_position_contribBF / avg_position_contribBF_counter


In [15]:
# np.save('./drive/MyDrive/avg_pos_contribBF', avg_positionBF)
# np.save('./drive/MyDrive/avg_aa_contribBF', avg_contribBF)
# np.save('./drive/MyDrive/avg_posxaa_contribBF', avg_position_contribBF)

In [16]:
# Zero-initialised accumulators for the averaged attributions (LazDEF model).
avg_contribDEF = np.zeros((20,))            # summed per amino acid
avg_positionDEF = np.zeros((11,))           # summed per sequence position
avg_position_contribDEF = np.zeros((20, 11))  # summed per (amino acid, position)
avg_position_contribDEF_count = np.zeros((20, 11))  # occurrences per (amino acid, position)

# Define model and aa_counter
# `model` is the notebook-level name the explainer reads at call time; the
# first constructor argument is only a stored label.
model = LazDEF_model
exp_model = ExplainableTransformerPipeline('distilbert', device)
aa_counts = np.zeros((20,))

# NOTE(review): the LazDEF model is scored on LazBF_sample here, not on
# LazDEF_sample. That keeps the arrays shape-compatible with the LazBF cell,
# but confirm that explaining both models on the same peptides is intended.
for peptide in tqdm(LazBF_sample):
  contributions = exp_model.explain(peptide, 1)
  for i, letter in enumerate(peptide):
    aa_index = amino_acids[letter]
    # Add to position avg
    avg_positionDEF[i] += contributions[i]
    # Add to pos x AA average
    avg_position_contribDEF[aa_index][i] += contributions[i]
    avg_position_contribDEF_count[aa_index][i] += 1
    # Add to aa average
    avg_contribDEF[aa_index] += contributions[i]
    # Count amino acid types
    aa_counts[aa_index] += 1

# Turn the sums into means. Combinations that never occur divide 0 by 0; they
# are deliberately left as NaN (later cells replace NaN via np.nan_to_num), so
# the expected RuntimeWarning is silenced instead of being printed.
with np.errstate(divide='ignore', invalid='ignore'):
  avg_position_contribDEF = avg_position_contribDEF / avg_position_contribDEF_count
  avg_contribDEF = avg_contribDEF / aa_counts
# Average over the number of peptides actually processed rather than a
# hard-coded 50000.
avg_positionDEF = avg_positionDEF / len(LazBF_sample)

100%|██████████| 50000/50000 [1:50:43<00:00,  7.53it/s]
  avg_position_contribDEF = avg_position_contribDEF / avg_position_contribDEF_count


In [17]:
# np.save('./drive/MyDrive/avg_pos_contribDEF', avg_positionDEF)
# np.save('./drive/MyDrive/avg_aa_contribDEF', avg_contribDEF)
# np.save('./drive/MyDrive/avg_posxaa_contribDEF', avg_position_contribDEF)

In [2]:
# Reload the saved averages: per position, per amino acid, and per
# (amino acid, position), first for LazBF and then for LazDEF.
avg_positionBF, avg_contribBF, avg_position_contribBF = (
    np.load(saved_path) for saved_path in (
        './drive/MyDrive/avg_pos_contribBF.npy',
        './drive/MyDrive/avg_aa_contribBF.npy',
        './drive/MyDrive/avg_posxaa_contribBF.npy',
    )
)

avg_positionDEF, avg_contribDEF, avg_position_contribDEF = (
    np.load(saved_path) for saved_path in (
        './drive/MyDrive/avg_pos_contribDEF.npy',
        './drive/MyDrive/avg_aa_contribDEF.npy',
        './drive/MyDrive/avg_posxaa_contribDEF.npy',
    )
)

In [9]:
from scipy.stats import spearmanr
# Spearman rank correlation between the per-amino-acid averages of the two
# models; NaN entries are replaced by 0.0 first.
bf_per_amino_acid = np.nan_to_num(avg_contribBF, nan=0.0)
def_per_amino_acid = np.nan_to_num(avg_contribDEF, nan=0.0)
spearmanr(bf_per_amino_acid, def_per_amino_acid)

SignificanceResult(statistic=0.8105263157894737, pvalue=1.4687607586992212e-05)

In [10]:
from scipy.stats import spearmanr
# Spearman rank correlation between the per-position averages of the two
# models; NaN entries are replaced by 0.0 first.
def_per_position = np.nan_to_num(avg_positionDEF, nan=0.0)
bf_per_position = np.nan_to_num(avg_positionBF, nan=0.0)
spearmanr(def_per_position, bf_per_position)

SignificanceResult(statistic=0.8000000000000002, pvalue=0.0031104283103858483)

In [3]:
from scipy.stats import spearmanr
# Spearman rank correlation between the flattened (amino acid, position)
# averages of the two models; NaN entries are replaced by 0.0 first.
bf_per_cell = np.nan_to_num(avg_position_contribBF, nan=0.0).flatten()
def_per_cell = np.nan_to_num(avg_position_contribDEF, nan=0.0).flatten()
spearmanr(bf_per_cell, def_per_cell)

SignificanceResult(statistic=0.5867492909461007, pvalue=9.61098530273787e-22)

In [118]:
#@title Helper functions
# Font keyword arguments selecting Times New Roman; not referenced by the
# helper functions defined in this cell.
csfont = {'fontname':'Times New Roman'}

def full_last_layer(model, sequence):
  """Plot and save the attention map of every head in the model's last layer.

  Args:
    model: Model whose ``forward(input_ids, output_attentions=True)`` result
      exposes ``.attentions``.
    sequence: Sequence string, tokenized with the notebook-level `tokenizer`.

  Each head is shown and written to ``./{sequence}_head_{k}_alt.png``. Tick
  font sizes, the file name and the layer number in the title are derived from
  the call (they used to be hard-coded for one 13-residue sequence).
  NOTE(review): `layer11` writes files with the same ``_head_{k}_alt.png``
  suffix, so running both helpers on one sequence overwrites the images.
  """
  encoded = tokenizer(sequence, return_tensors='pt').to(device)
  # Only the attention weights are read, so no autograd graph is needed.
  with torch.no_grad():
    output = model.forward(encoded.input_ids, output_attentions=True)
  layer_number = len(output.attentions)
  tick_labels = ["[BOS]"] + list(sequence) + ["[EOS]"]
  # Smaller font for the two special-token labels than for the residues.
  sizes = [9] + len(sequence)*[14] + [9]
  for i, head in enumerate(output.attentions[-1][0]):
      matr = head.cpu().detach().numpy()
      plt.imshow(matr, interpolation='nearest')
      ticks = np.arange(0, matr.shape[1])
      plt.xticks(ticks, tick_labels)
      plt.yticks(ticks, tick_labels)
      for label, size in zip(plt.xticks()[1], sizes):
        label.set_fontsize(size)
      for label, size in zip(plt.yticks()[1], sizes):
        label.set_fontsize(size)
      plt.title(f"Layer {layer_number}, Head {i+1}", fontsize=17)
      plt.colorbar()
      #plt.text(-4, -0.9, 'c', fontsize=20) #, transform=ax.transAxes)
      plt.savefig(f'./{sequence}_head_{i+1}_alt.png', dpi=400, bbox_inches='tight', pad_inches=0)
      plt.show()
def full_attention(model, sequence):
  """Show the attention map of every head in every layer for `sequence`.

  Args:
    model: Model whose ``forward(input_ids, output_attentions=True)`` result
      exposes ``.attentions``.
    sequence: Sequence string, tokenized with the notebook-level `tokenizer`.

  A "Layer L, head H" line (both zero-based) is printed before each figure.
  """
  encoded = tokenizer(sequence, return_tensors='pt').to(device)
  output = model.forward(encoded.input_ids, output_attentions=True)
  # The residues, padded with one blank label on each side.
  tick_labels = list(" "+sequence+" ")
  for layer, att in enumerate(output.attentions):
    for i, head in enumerate(att[0]):
      print(f'Layer {layer}, head {i}')
      matr = head.cpu().detach().numpy()
      plt.imshow(matr, interpolation='nearest')
      tick_positions = np.arange(0, matr.shape[1])
      plt.xticks(tick_positions, tick_labels)
      plt.yticks(tick_positions, tick_labels)
      plt.show()

def per_layer_attention(model, sequence):
  """Plot and save, for every layer, the attention map averaged over heads.

  Args:
    model: Model whose ``forward(input_ids, output_attentions=True)`` result
      exposes ``.attentions``.
    sequence: Sequence string, tokenized with the notebook-level `tokenizer`.

  Each layer is shown and written to ``./{sequence}_layer_{k}_alt.png``. Tick
  font sizes and the file name are derived from `sequence`; they used to be
  hard-coded for one 11-residue sequence, which raised IndexError for longer
  inputs and mislabelled the saved files.
  """
  encoded = tokenizer(sequence, return_tensors='pt').to(device)
  # Only the attention weights are read, so no autograd graph is needed.
  with torch.no_grad():
    output = model.forward(encoded.input_ids, output_attentions=True)
  tick_labels = ["[BOS]"] + list(sequence) + ["[EOS]"]
  # Smaller font for the two special-token labels than for the residues.
  sizes = [10] + len(sequence)*[15] + [10]
  for i, att in enumerate(output.attentions):
    heads = np.array([head.cpu().detach().numpy() for head in att[0]])
    mean_attention = np.mean(heads, axis=0)
    # Fixed colour range so the layers are comparable with each other.
    plt.imshow(mean_attention, interpolation='nearest', vmin=0, vmax=0.3)
    ticks = np.arange(0, mean_attention.shape[1])

    plt.xticks(ticks, tick_labels)
    for label, size in zip(plt.xticks()[1], sizes):
      label.set_fontsize(size)

    plt.yticks(ticks, tick_labels)
    for label, size in zip(plt.yticks()[1], sizes):
      label.set_fontsize(size)

    plt.title(f'Average Attention for Layer {i+1}', fontsize=17)
    #plt.text(-4, -0.9, 'a', fontsize=20) #, transform=ax.transAxes)
    plt.colorbar()

    plt.savefig(f'./{sequence}_layer_{i+1}_alt.png', dpi=400, bbox_inches='tight', pad_inches=0)
    plt.show()

def oned_attention(model, sequence):
  """Print and show a one-row attention summary for every layer.

  For each layer the head matrices are averaged over heads and the result is
  averaged again over its first axis, leaving one value per token.

  Args:
    model: Model whose ``forward(input_ids, output_attentions=True)`` result
      exposes ``.attentions``.
    sequence: Sequence string, tokenized with the notebook-level `tokenizer`.
  """
  encoded = tokenizer(sequence, return_tensors='pt').to(device)
  output = model.forward(encoded.input_ids, output_attentions=True)
  for att in output.attentions:
    heads = np.array([head.cpu().detach().numpy() for head in att[0]])
    head_mean = np.mean(heads, axis=0)
    per_token = np.mean(head_mean, axis=0)
    # Shape the vector as a single row so imshow draws it as a strip.
    row = per_token[:, np.newaxis].T
    print(row)
    plt.imshow(row, interpolation='nearest')
    tick_positions = np.arange(0, len(sequence)+2)
    tick_labels = list(" "+sequence+" ")
    plt.xticks(tick_positions, tick_labels)
    plt.show()

def layer11(model, sequence):
  """Plot and save the attention map of every head in the second-to-last layer.

  Args:
    model: Model whose ``forward(input_ids, output_attentions=True)`` result
      exposes ``.attentions``.
    sequence: Sequence string, tokenized with the notebook-level `tokenizer`.

  Each head is shown and written to ``./{sequence}_head_{k}_alt.png``. Tick
  font sizes, the file name and the layer number in the title are derived from
  the call (they used to be hard-coded for one 11-residue sequence and a
  12-layer model).
  NOTE(review): `full_last_layer` also writes ``*_head_{k}_alt.png`` files, so
  the two helpers can overwrite each other's images.
  """
  encoded = tokenizer(sequence, return_tensors='pt').to(device)
  # Only the attention weights are read, so no autograd graph is needed.
  with torch.no_grad():
    output = model.forward(encoded.input_ids, output_attentions=True)
  # attentions[-2] is the second-to-last layer; number it from 1 for the title.
  layer_number = len(output.attentions) - 1
  tick_labels = ["[BOS]"] + list(sequence) + ["[EOS]"]
  # Smaller font for the two special-token labels than for the residues.
  sizes = [10] + len(sequence)*[15] + [10]
  for i, head in enumerate(output.attentions[-2][0]):
      matr = head.cpu().detach().numpy()
      plt.imshow(matr, interpolation='nearest', vmin=0, vmax=0.8)
      ticks = np.arange(0, matr.shape[1])
      plt.xticks(ticks, tick_labels)
      plt.yticks(ticks, tick_labels)

      for label, size in zip(plt.xticks()[1], sizes):
        label.set_fontsize(size)
      for label, size in zip(plt.yticks()[1], sizes):
        label.set_fontsize(size)

      plt.title(f"Layer {layer_number}, Head {i+1}", fontsize=17)
      #plt.colorbar()
      plt.text(-4, -0.9, 'b', fontsize=20) #, transform=ax.transAxes)
      plt.savefig(f'./{sequence}_head_{i+1}_alt.png', dpi=400, bbox_inches='tight', pad_inches=0)
      plt.show()

In [None]:
# Head-averaged attention maps per layer for one sequence with the LazBF
# model; the trailing comment keeps an alternative sequence.
per_layer_attention(LazBF_model, "FVCHPSRWVGA") #'VIGGRTCDGTRYY')