In [1]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import os
import seaborn as sns
from scipy.stats import gaussian_kde
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.Align.Applications import MuscleCommandline
from Bio.Align.Applications import ClustalOmegaCommandline
from pymsaviz import MsaViz, get_msa_testdata
import torch
import pwlf
import math
from torch import nn
from Bio import AlignIO

In [2]:
# Root directory of the per-protein ESM-2 outputs. The heatmap cell below reads
# '<directory>/attention_matrices_mean_max_perLayer/<prot>.pt' from it.
# NOTE(review): absolute, cluster-specific path -- consider making it configurable.
directory = '/oak/stanford/groups/rbaltman/esm_embeddings/esm2_t33_650M_uniprot_human'
# FASTA file that is parsed to collect the protein identifiers to process.
fasta_file = '/oak/stanford/groups/rbaltman/esm_embeddings/uniprot_human_full.fasta'

In [3]:
# Collect one identifier per FASTA record: the record id is split on '|' and
# the second field is kept. For UniProt-style ids such as 'sp|P12345|NAME_HUMAN'
# this is presumably the accession -- TODO confirm against the FASTA headers.
seqs = SeqIO.parse(fasta_file, "fasta")
protName = [record.id.split('|')[1] for record in seqs]

In [None]:
## Create column-normalized attention heatmaps for all proteins.
#
# For every protein id in `protName`:
#   1. load the saved attention tensor '<directory>/attention_matrices_mean_max_perLayer/<prot>.pt';
#   2. for each of the 33 layers, sum the attention matrix over dim 0 to get one
#      value per column, divide by the largest column sum, round to 2 decimals;
#   3. save the 33 per-layer vectors to '../data/heatmap/<prot>.npy'.
# Proteins whose tensor cannot be loaded are reported and skipped.

# Create the output directory up front so np.save cannot fail on a missing
# folder after the (expensive) load/normalise work has been done.
os.makedirs('../data/heatmap', exist_ok=True)

for prot in protName:
    heatmap = []
    filename = f'{directory}/attention_matrices_mean_max_perLayer/{prot}.pt'
    try:
        # NOTE(review): torch.load unpickles the file -- only use it on trusted files.
        # map_location='cpu' keeps the .numpy() call below valid even when the
        # tensor was saved from a GPU.
        data = torch.load(filename, map_location='cpu')
    except Exception as e:
        # Best effort: report which protein failed and carry on with the rest.
        # (The original printed the undefined name `protId`, which raised a
        # NameError inside this handler and aborted the whole loop.)
        print (prot)
        print (e)
        continue

    for layer in range(0,33):
        # Drop the first and last position on both axes -- presumably the
        # special start/end tokens; TODO confirm against how the tensors were saved.
        attn_matrix = data[0,0,layer,1:-1,1:-1]
        col_list = torch.sum(attn_matrix, dim=0)
        # Tensor.max() instead of the Python builtin max(), which iterates the
        # tensor element by element.
        col_list_norm = col_list / col_list.max()
        heatmap.append(np.round(col_list_norm.numpy(),2))
    np.save(f'../data/heatmap/{prot}.npy', heatmap, allow_pickle=True)
    


In [None]:
## Get LoC and HA sites
#
# For every protein that has a saved heatmap, each of the 33 layer vectors is
# sorted in descending order and fitted with a two-segment piecewise-linear
# model (pwlf). Per layer this records the descending index order, the position
# of the fitted breakpoint (as an index into the sorted vector) and the angle,
# in degrees, between the two fitted segments.
#
# NOTE(review): all per-protein containers below are re-created on every
# iteration and nothing in this loop stores them under the protein id, so after
# the loop only the values of the last processed protein remain.

for prot in protName:
    heatmap_filename = f'../data/heatmap/{prot}.npy'
    if not os.path.exists(heatmap_filename):
        # No heatmap was written for this protein; nothing to fit.
        continue

    heatmap = np.load(heatmap_filename)
    prot_length = heatmap.shape[1]

    # Fresh per-protein accumulators.
    theta_list, val_data, pred_data = [], [], []
    layer_descIndices, indices_layer, layer_break = {}, {}, {}

    # Positions rescaled to [0, 1) so the fit is independent of protein length.
    x = [i/prot_length for i in range(prot_length)]

    for layer in range(33):
        vec = heatmap[layer]

        # Column indices ordered from highest to lowest value, and the values
        # themselves in that order.
        sorted_indices = np.argsort(vec)
        sorted_indices_desc = sorted_indices[::-1]
        sorted_values_desc = vec[sorted_indices_desc]
        indices_layer[layer] = layer_descIndices[layer] = sorted_indices_desc

        # Two-segment piecewise-linear fit of the sorted curve.
        pwlf_inst = pwlf.PiecewiseLinFit(x, sorted_values_desc)
        breaks = pwlf_inst.fit(2)

        # breaks[1] is presumably the interior breakpoint (pwlf is expected to
        # return the end points as well -- confirm); scale it back to an index.
        layer_break[layer] = math.floor(breaks[1]*prot_length)

        y_hat = pwlf_inst.predict(x)
        slopes = pwlf_inst.slopes
        m1 = float(slopes[0])
        m2 = float(slopes[1])
        print (f'Layer {layer}, {m1}, {m2}')

        # Angle between two lines with slopes m1 and m2:
        # tan(theta) = (m2 - m1) / (1 + m1*m2).
        tan_theta = (m2-m1)/(1+(m2*m1))
        theta_deg = math.degrees(math.atan(tan_theta))

        theta_list.append(theta_deg)
        val_data.append(sorted_values_desc)
        pred_data.append(y_hat)

# Persist the per-protein results as pickles.
#
# NOTE(review): `pid_highAttend` and `pid_impLayer` are not assigned anywhere
# in the visible notebook, so on a fresh kernel this cell raises NameError.
# They must be built (keyed by protein id) before this point.
#
# Both names are resolved here, *before* any file is opened in 'wb' mode, so a
# missing variable fails without truncating a previously written pickle.
pickle_outputs = (
    ('../data/prot_HA.pkl', pid_highAttend),
    ('../data/prot_impLayer.pkl', pid_impLayer),
)
for pickle_path, payload in pickle_outputs:
    with open(pickle_path, 'wb') as file:
        pickle.dump(payload, file)