In [None]:
# Install the third-party packages this notebook imports (captum, pytorch-lightning).
# NOTE(review): versions are unpinned. ESMFinetune below overrides
# `training_epoch_end` / `validation_epoch_end`; these hooks are presumably only
# available in pytorch-lightning 1.x -- confirm and pin the version that was used.
# NOTE(review): `pretrained` (used by ESMFinetune) is neither installed nor imported
# in any visible cell; presumably it comes from the fair-esm package -- confirm.
%pip install captum
%pip install pytorch-lightning

# **Imports**

In [None]:
import pickle

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from captum.attr import IntegratedGradients
from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from torch import nn

In [None]:
# @title finetuning model mean
class ESMFinetune(pl.LightningModule):
    """ESM-1 (12 layer) encoder with a mean-pooled linear head emitting one logit per sequence.

    The layer-12 token representations are standardised with precomputed
    statistics, layer-normalised, averaged over the positions selected by
    ``non_mask`` and projected to a single logit.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): `pretrained` is not imported in the visible cells;
        # presumably `from esm import pretrained` -- confirm.
        model, _ = pretrained.load_model_and_alphabet("esm1_t12_85M_UR50S")
        self.model = model
        self.clf_head = nn.Linear(768, 1)

        # The ESM 12 model does not have a layer norm before MLM, so the 768
        # feature output has spikes. The original authors note they found no
        # difference in performance by adding this normalisation.
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open("../Training/ESM12_Layer12_Norm.pkl", "rb") as f:
            final_scaling = pickle.load(f)
        # Registered as non-persistent buffers so they follow the module across
        # devices (`.cuda()`, `.to(...)`) instead of being pinned to "cuda", while
        # staying out of the state_dict so existing checkpoints load unchanged.
        self.register_buffer("scaling_mean", torch.tensor(final_scaling["mean"]), persistent=False)
        self.register_buffer("scaling_std", torch.tensor(final_scaling["std"]), persistent=False)
        self.final_ln = nn.LayerNorm(768)
        self.lr = 2e-5

    def forward(self, toks, lens, non_mask):
        """Return one logit per sequence.

        Args:
            toks: token tensor passed to the ESM model.
            lens: per-sequence divisor for the pooled sum (indexed as ``lens[:, None]``).
            non_mask: per-position weights (indexed as ``non_mask[:, :, None]``).

        Returns:
            Tensor of shape ``(batch,)`` with the raw (pre-sigmoid) logits.
        """
        x = self.model(toks, repr_layers=[12])
        x = x["representations"][12]
        x = (x - self.scaling_mean) / self.scaling_std
        x = self.final_ln(x)
        # Masked mean over the sequence axis.
        x_mean = (x * non_mask[:, :, None]).sum(1) / lens[:, None]
        x = self.clf_head(x_mean)
        # Only drop the trailing feature axis: a bare squeeze() would also drop
        # the batch axis for a batch of one and break the loss / indexing.
        return x.squeeze(-1)

    def configure_optimizers(self):
        """AdamW with a smaller learning rate for the encoder than for the head."""
        head_params = list(self.clf_head.parameters()) + list(self.final_ln.parameters())
        grouped_parameters = [
            {"params": list(self.model.parameters()), 'lr': 3e-6},
            {"params": head_params, 'lr': 2e-5},
        ]
        optimizer = torch.optim.AdamW(grouped_parameters, lr=self.lr)
        return optimizer

    def training_step(self, batch, batch_idx):
        """Compute and log the binary cross-entropy loss of one training batch."""
        x, l, n, y, _ = batch
        y_pred = self.forward(x, l, n)
        loss = F.binary_cross_entropy_with_logits(y_pred, y)
        self.log('train_loss_batch', loss)
        return {'loss': loss}

    def training_epoch_end(self, outputs):
        """Log the mean of the per-batch training losses."""
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        self.log('train_loss', avg_loss, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Compute loss and the number of correct predictions for one validation batch."""
        x, l, n, y, _ = batch
        y_pred = self.forward(x, l, n)
        # A positive logit is counted as predicting the positive class.
        correct = ((y_pred > 0) == y).sum()
        count = y.size(0)
        loss = F.binary_cross_entropy_with_logits(y_pred, y)
        self.log('val_loss_batch', loss)
        return {'loss': loss, 'correct': correct, "count": count}

    def validation_epoch_end(self, outputs):
        """Log the mean validation loss and the overall validation accuracy."""
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        self.log('val_loss', avg_loss, prog_bar=True)
        avg_acc = torch.tensor([x['correct'] for x in outputs]).sum() / torch.tensor([x['count'] for x in outputs]).sum()
        self.log('val_acc', avg_acc, prog_bar=True)

# **Attribution Computation**

In [None]:
import pickle
# reference https://captum.ai/tutorials/Bert_SQUAD_Interpret
def predict(toks, lengths, np_mask):
    """Call the module-level classifier `clf` on one batch and return its output unchanged."""
    outputs = clf(toks, lengths, np_mask)
    return outputs

def custom_forward(toks, lengths, np_mask):
    """Forward function handed to Captum: sigmoid of `predict`'s output."""
    return torch.sigmoid(predict(toks, lengths, np_mask))

def summarize_attributions(attributions):
    """Sum attributions over the last axis, drop a leading size-1 axis and scale by the norm."""
    collapsed = attributions.sum(dim=-1).squeeze(0)
    return collapsed / torch.norm(collapsed)

# For each of the 5 cross-validation checkpoints: compute layer integrated
# gradients w.r.t. the token embedding layer for every sequence, together with
# the sigmoid prediction, and pickle both per split.
for split_i in range(5):
    print(split_i)
    path = f"./models/{split_i}PSISplit.ckpt"
    # eval() so that any train-only behaviour (e.g. dropout) is disabled while
    # attributing; `clf` is the global read by predict()/custom_forward().
    clf = ESMFinetune.load_from_checkpoint(path).cuda().eval()
    clf.zero_grad()
    lig = LayerIntegratedGradients(custom_forward, clf.model.embed_tokens)

    # NOTE(review): the CSV frame is immediately overwritten by the frame built
    # from `fasta_dict`, which is not defined in any visible cell -- confirm
    # which of the two data sources is intended and drop the other.
    data_df = pd.read_csv("../Datasets/NESG/NESG_testset.csv")
    data_df = pd.DataFrame(fasta_dict.items(), columns=['sid', 'fasta'])

    # NOTE(review): NewAlphabet, FastaBatchedDataset and `alphabet` are not
    # defined in the visible cells; presumably project helpers -- confirm.
    newalphabet = NewAlphabet(alphabet)
    embed_dataset = FastaBatchedDataset(data_df)
    embed_batches = embed_dataset.get_batch_indices(2048, extra_toks_per_seq=1)
    embed_dataloader = torch.utils.data.DataLoader(embed_dataset, collate_fn=newalphabet.get_batch_converter(), batch_sampler=embed_batches)

    score_vises_dict = {}
    attribution_dict = {}
    pred_dict = {}
    for j, (toks, lengths, np_mask, labels) in enumerate(embed_dataloader):
        if j % 10 == 0:
            print(j, "/", len(embed_dataloader))

        # Baseline: every position is padding except the leading CLS token.
        baseline_toks = torch.full((toks.size(0), toks.size(1)),
                                   newalphabet.alphabet.padding_idx,
                                   dtype=torch.int64)
        baseline_toks[:, 0] = newalphabet.alphabet.cls_idx

        # Move the batch to the GPU once and reuse it for both passes.
        toks_cuda = toks.to("cuda")
        lengths_cuda = lengths.to("cuda")
        np_mask_cuda = np_mask.to("cuda")

        attributions, delta = lig.attribute(inputs=toks_cuda,
                                            baselines=baseline_toks.to("cuda"),
                                            n_steps=50,
                                            additional_forward_args=(lengths_cuda, np_mask_cuda),
                                            internal_batch_size=8,
                                            return_convergence_delta=True)

        # Plain prediction pass: no gradients needed, so skip graph construction.
        with torch.no_grad():
            preds = custom_forward(toks_cuda, lengths_cuda, np_mask_cuda)

        for i in range(preds.shape[0]):
            seq_len = int(lengths[i])
            # Sum over the embedding axis, then keep the `seq_len` positions
            # that follow index 0 (the CLS slot set in the baseline above).
            attribution_dict[labels[i]] = attributions[i].sum(dim=-1).squeeze(0).cpu().numpy()[1:1 + seq_len]
            pred_dict[labels[i]] = preds[i].cpu().detach().numpy()

    with open(f"{split_i}_attrs.pkl", "wb") as f:
        pickle.dump({"attributions": attribution_dict, "preds": pred_dict}, f)

In [None]:

# Combine the per-split results written by the attribution loop:
# attributions are summed over splits and then L1-normalised per sequence,
# predictions are averaged over splits.
n_splits = 5
attr_dict = {}
pred_dict = {}
for i in range(n_splits):
    # Fix: index the file by the loop variable `i`. The previous code used the
    # stale `split_i` left over from the attribution loop, so the same (last)
    # split was loaded five times.
    with open(f"{i}_attrs.pkl", "rb") as f:
        attrs = pickle.load(f)
    for k, split_attr in attrs['attributions'].items():
        if k in attr_dict:
            attr_dict[k] += split_attr
        else:
            attr_dict[k] = split_attr
    for k, split_pred in attrs['preds'].items():
        if k in pred_dict:
            pred_dict[k] += split_pred
        else:
            pred_dict[k] = split_pred

# Scale each sequence so its absolute attributions sum to 1.
for k in attr_dict:
    attr_dict[k] = attr_dict[k] / np.abs(attr_dict[k]).sum()

# Mean prediction over the splits.
for k in pred_dict:
    pred_dict[k] = pred_dict[k] / n_splits

def read_fasta(fastafile):
    """Parse a file with sequences in FASTA format and store in a dict.

    Args:
        fastafile: path of the FASTA file to read.

    Returns:
        dict mapping each header line (with every '>' removed) to the
        concatenation of the stripped sequence lines that follow it. An empty
        file yields an empty dict. As before, a record with an empty sequence
        is only kept when it is the last record, and sequence lines that
        precede the first header are stored under the key ``None``.
    """
    res = {}
    seq_id, chunks = None, []
    with open(fastafile, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith('>'):
                seq = ''.join(chunks)
                if seq:
                    res[seq_id] = seq
                seq_id = line.replace('>', '')
                chunks = []
            else:
                chunks.append(line)

    seq = ''.join(chunks)
    # Only flush the trailing record when there is one; previously an empty
    # file produced the bogus entry {None: ''}.
    if seq_id is not None or seq:
        res[seq_id] = seq
    return res
# Sequences of the NESG test set keyed by FASTA header; consumed by aa_avg()
# to look up the residue at each attributed position.
seq_dict = read_fasta(f"../Datasets/NESG/NESG_testset.fasta")

# **Plots**

In [None]:
# Keys whose first-position attribution magnitude exceeds 0.005; appended to
# by every call of length_avg().
start_imp = []
def length_avg(attrs):
    """Bin absolute attributions by relative position (percent of sequence length).

    Returns three dicts keyed 0..100: summed magnitudes, number of
    contributions and the list of individual magnitudes per bin. As a side
    effect, keys with a large first-position attribution go to `start_imp`.
    """
    bins = range(0, 101)
    sum_dict = dict.fromkeys(bins, 0)
    count_dict = dict.fromkeys(bins, 0)
    data_dict = {b: [] for b in bins}
    for key in attrs:
        values = attrs[key]
        if abs(values[0]) > 0.005:
            start_imp.append(key)
        n_pos = values.shape[0]
        for pos in range(n_pos):
            magnitude = abs(values[pos])
            bucket = int((pos / n_pos) * 100)
            sum_dict[bucket] += magnitude
            count_dict[bucket] += 1
            data_dict[bucket].append(magnitude)
    return sum_dict, count_dict, data_dict

def length_avg2(attrs):
    """Variant of the percent-position binning that spreads short sequences over several bins.

    When a sequence has so few positions that int(100 / length) > 1, each
    position's magnitude (and its count of 1) is split evenly over that many
    consecutive bins; otherwise each position lands in a single bin.
    Returns (sum_dict, count_dict, data_dict), all keyed 0..100.
    """
    bins = range(0, 101)
    sum_dict = dict.fromkeys(bins, 0)
    count_dict = dict.fromkeys(bins, 0)
    data_dict = {b: [] for b in bins}
    for key in attrs:
        values = attrs[key]
        n_pos = values.shape[0]
        spread = int(100 / n_pos)
        for pos in range(n_pos):
            first_bin = int((pos / n_pos) * 100)
            magnitude = abs(values[pos])
            if spread > 1:
                targets = range(first_bin, first_bin + spread)
                share = magnitude / spread
                weight = 1 / spread
            else:
                targets = (first_bin,)
                share = magnitude
                weight = 1
            for target in targets:
                sum_dict[target] += share
                count_dict[target] += weight
                data_dict[target].append(share)
    return sum_dict, count_dict, data_dict

def length_avg_abs(attrs):
    """Group absolute attributions by absolute position index.

    Returns a dict mapping each position index to the list of magnitudes
    observed at that index across all entries of `attrs`.
    """
    data_dict = {}
    for key in attrs:
        values = attrs[key]
        for pos in range(values.shape[0]):
            data_dict.setdefault(pos, []).append(abs(values[pos]))
    return data_dict

def length_avg_label(attrs, label):
    """Percent-position binning restricted to sequences whose solubility equals `label`.

    The solubility of each key is looked up in the module-level `label_df`
    (columns `sid` and `solubility`). Returns (sum_dict, count_dict,
    data_dict), all keyed 0..100.
    """
    bins = range(0, 101)
    sum_dict = dict.fromkeys(bins, 0)
    count_dict = dict.fromkeys(bins, 0)
    data_dict = {b: [] for b in bins}
    for key in attrs:
        if label_df[label_df.sid == key].solubility.item() != label:
            continue
        values = attrs[key]
        n_pos = values.shape[0]
        for pos in range(n_pos):
            bucket = int((pos / n_pos) * 100)
            magnitude = abs(values[pos])
            sum_dict[bucket] += magnitude
            count_dict[bucket] += 1
            data_dict[bucket].append(magnitude)
    return sum_dict, count_dict, data_dict

# Per-amino-acid SWI scores; the key order also fixes the order of the
# dictionaries returned by aa_avg().
swi_weights = {
    'A': 0.8356471476582918,
    'C': 0.5208088354857734,
    'E': 0.9876987431418378,
    'D': 0.9079044671339564,
    'G': 0.7997168496420723,
    'F': 0.5849790194237692,
    'I': 0.6784124413866582,
    'H': 0.8947913996466419,
    'K': 0.9267104557513497,
    'M': 0.6296623675420369,
    'L': 0.6554221515081433,
    'N': 0.8597433107431216,
    'Q': 0.789434648348208,
    'P': 0.8235328714705341,
    'S': 0.7440908318492778,
    'R': 0.7712466317693457,
    'T': 0.8096922697856334,
    'W': 0.6374678690957594,
    'V': 0.7357837119163659,
    'Y': 0.6112801822947587,
}

def aa_avg(attrs, seqs):
    """Accumulate signed attributions per amino acid.

    For every key, position j of `attrs[key]` is credited to residue
    `seqs[key][j]`; residues that are not keys of `swi_weights` are skipped.
    Returns (sum_dict, count_dict) keyed by amino-acid letter.
    """
    sum_dict = dict.fromkeys(swi_weights, 0)
    count_dict = dict.fromkeys(swi_weights, 0)
    for key in attrs:
        values, residues = attrs[key], seqs[key]
        assert values.shape[0] == len(residues)
        for pos in range(values.shape[0]):
            residue = residues[pos]
            if residue in swi_weights:
                sum_dict[residue] += values[pos]
                count_dict[residue] += 1
    return sum_dict, count_dict

In [None]:
import matplotlib.pyplot as plt

# Mean absolute attribution (+/- one standard deviation) as a function of the
# relative position in the sequence.
# Fix: use the ensemble attributions `attr_dict` (sequence id -> array). `attrs`
# is the raw pickle payload of the last loaded split, i.e. a dict with the keys
# 'attributions' and 'preds', which length_avg() cannot index by position.
len_sum, len_count, len_list = length_avg(attr_dict)
# The small epsilon keeps empty bins from dividing by zero.
netsol_lengths = {k: len_sum[k]/(len_count[k] + 1e-5) for k in len_sum}
# Empty bins get NaN explicitly instead of triggering numpy's empty-slice warning.
netsol_stds = {k: np.array(len_list[k]).std() if len(len_list[k]) else np.nan for k in len_sum}

# The last bin (100) is dropped from the plot, as before.
positions = list(netsol_lengths.keys())[:-1]
y = np.array(list(netsol_lengths.values())[:-1])
error = np.array(list(netsol_stds.values())[:-1])

fig, ax = plt.subplots()
ax.plot(positions, y)
ax.fill_between(positions, y - error, y + error, alpha=0.3)
ax.set_ylabel('Importance')
ax.set_xlabel('Position as % of length')
ax.set_title("Importance vs Length")
fig.savefig('importancevlength.png')
plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import spearmanr

# Compare the mean signed attribution per amino acid with the SWI scores.
# Fix: use the ensemble attributions `attr_dict` (sequence id -> array) rather
# than `attrs`, the raw pickle payload with keys 'attributions' and 'preds'.
aa_sum, aa_count = aa_avg(attr_dict, seq_dict)
# Amino acids that never occur are skipped instead of dividing by zero.
netsol_scores = {k: aa_sum[k]/aa_count[k] for k in aa_sum if aa_count[k] > 0}

# Keep both score lists aligned on the same amino acids, in the same order.
amino_acids = list(netsol_scores.keys())
netsol_values = [netsol_scores[aa] for aa in amino_acids]
swi_values = [swi_weights[aa] for aa in amino_acids]

print(spearmanr(netsol_values, swi_values))

minx = min(netsol_values)
maxx = max(netsol_values)
miny = min(list(swi_weights.values()))
maxy = max(list(swi_weights.values()))

# Straight line joining (minx, miny) and (maxx, maxy), drawn as a visual guide.
x = np.linspace(minx, maxx, 100)
y = (miny-maxy) / (minx - maxx) * (x - minx) + miny

fig, ax = plt.subplots()
ax.scatter(netsol_values, swi_values)

for aa, score, swi in zip(amino_acids, netsol_values, swi_values):
    # Small offsets keep the letter from overlapping its marker.
    ax.annotate(aa, (score + 0.0001, swi + 0.01))

# The previous label 'y=2x+1' did not describe this line; no legend is drawn.
ax.plot(x, y, '-r', label='min-max diagonal', alpha=0.5)
ax.set_xlabel("NetSolP Scores")
ax.set_ylabel("SWI Scores")
ax.set_ylim(miny-0.05, maxy+0.05)
ax.set_title("Amino acid score comparison")
ax.grid()
plt.savefig('scorevsaminoacid_scatter.png')
plt.show()

# **Helpers for formatting**

In [None]:
# helpful for choosing the ideal color range
# Print the largest and smallest attribution value found in `attribution_dict`
# (useful when picking the colour scaling in _get_color).
per_sequence_arrays = list(attribution_dict.values())
max_attr = [values.max() for values in per_sequence_arrays]
min_attr = [values.min() for values in per_sequence_arrays]
print(max(max_attr), min(min_attr))

In [None]:
from IPython.core.display import HTML, display

def _get_color(attr):
    # clip values to prevent CSS errors (Values should be from [-1,1])
    attr = max(-1, min(1, attr))
    if attr > 0:
        hue = 120
        sat = 75
        # Attributions lie between -1 and 1 but for better coloring change scaling based on dataset
        lig = 100 - int(80 * attr)
    else:
        hue = 0
        sat = 75
        lig = 100 - int(-100 * attr)
    return "hsl({}, {}%, {}%)".format(hue, sat, lig)

def _get_color_sol(attr):
    if attr > 0:
        hue = 240
        sat = 100
        lig = 95
    else:
        hue = 30
        sat = 100
        lig = 95
    return "hsl({}, {}%, {}%)".format(hue, sat, lig)

def format_word_importances(words, importances, solubility):
    """Render a sequence of characters as one HTML table cell of coloured <mark> tags.

    "-" entries are emitted fully transparent and do not consume an
    importance value; every other entry is coloured by the next value of
    `importances`.
    """
    pieces = ["<td nowrap>"]
    cursor = 0  # index of the next unused importance value
    for word in words:
        if word != "-":
            shade = _get_color(importances[cursor])
            cursor += 1
            markup = '<mark style="background-color: {color}; opacity:1.0; \
                        line-height:1.75"><font color="black"> {word}\
                        </font></mark>'.format(color=shade, word=word)
        else:
            # We ignore the - character by setting opacity to 0
            shade = _get_color_sol(solubility)
            markup = '<mark style="background-color: {color}; opacity:0.0; \
                        line-height:1.75"><font color="black"> {word}\
                        </font></mark>'.format(color=shade, word=word)
        pieces.append(markup)

    pieces.append("</td>")
    return "".join(pieces)

def format_classname(s, t=-1):
    """Wrap `s` in a <td> cell, highlighted green when t == 1 and red when t == 0.

    Any other `t` (including the default) yields a plain, unhighlighted cell.
    """
    if t == 1:
        shade = _get_color(0.3)
    elif t == 0:
        shade = _get_color(-0.3)
    else:
        return f"<td>{s}</td>"
    unwrapped_tag = '<mark style="background-color: {color}; opacity:1.0; \
                        line-height:1.75"><font color="black"> {word}\
                        </font></mark>'.format(color=shade, word=s)
    return f"<td>{unwrapped_tag}</td>"

In [None]:
# Build one HTML table row per entry of `msa_tsv`: id, label, ensemble
# prediction and the attribution-coloured alignment.
# Fix: `width: 100%` now sits inside the style attribute (it was a stray,
# invalid attribute after the closing quote).
dom = ['<table style="font-family:\'Courier New\', monospace; width: 100%">']
rows = [
    "<tr>"
    "<th>ID</th>"
    "<th>Label</th>"
    "<th>Prediction</th>"
    "<th>MSA</th>"
    "</tr>"
]

# NOTE(review): `msa_tsv` is not defined in the visible cells; it is accessed
# through its `sid`, `solubility` and `msa` columns.
# NOTE(review): `attribution_dict` holds the attributions of the last processed
# split while `pred_dict` is the ensemble average -- confirm this mix is intended.
for idx in range(len(msa_tsv)):
    sid = msa_tsv.sid[idx]
    solubility = msa_tsv.solubility[idx]
    rows.append(
        "".join(
            [
                "<tr>",
                format_classname(sid, solubility),
                format_classname(solubility),
                format_classname(round(pred_dict[sid].item(), 3)),
                format_word_importances(msa_tsv.msa[idx], attribution_dict[sid], solubility),
                # Fix: close the row; this used to be a second opening "<tr>".
                "</tr>",
            ]
        )
    )
dom.append("".join(rows))
dom.append("</table>")
html = HTML("".join(dom))
display(html)

In [None]:
# Save the rendered table markup to disk. "FILENAME" is a placeholder path.
# The context manager closes the file even if the write fails.
with open("FILENAME", "w") as html_file:
    html_file.write("".join(dom))