In [8]:
import os
import numpy as np
import torch
from tqdm import tqdm # for progress bar
import Bio.PDB

In [None]:
def process_fasta_files(folder_name):
    """Read every FASTA file in ``folder_name`` into ``(id, sequence)`` pairs.

    One entry is produced per regular file. The id is the text between the
    leading ``>`` and the first ``|`` of the header line; the sequence is the
    upper-cased concatenation of all sequence lines that follow that header
    (FASTA sequences are commonly wrapped over several lines). If a file holds
    more than one record, the last record wins.

    Args:
        folder_name: Path of the directory containing the FASTA files.

    Returns:
        A tuple ``(data, lengths)`` where ``data`` is a list of
        ``(id, sequence)`` tuples and ``lengths[i] == len(data[i][1])``.
        Entries are ordered by file name.
    """
    data = []
    lengths = []
    # os.listdir order is arbitrary; sort so repeated runs give the same order.
    for file_name in sorted(os.listdir(folder_name)):
        path = os.path.join(folder_name, file_name)
        if not os.path.isfile(path):
            # Skip sub-directories (e.g. Jupyter's .ipynb_checkpoints), which
            # open() cannot read.
            continue
        seq_id = ""
        seq_chunks = []
        with open(path, 'r') as fasta_file:
            for line in fasta_file:
                line = line.strip()
                if not line:
                    # A blank line must not wipe the sequence read so far.
                    continue
                if line.startswith(">"):
                    seq_id = line.split("|")[0][1:]
                    # New record: start its sequence from scratch.
                    seq_chunks = []
                else:
                    seq_chunks.append(line.upper())
        seq = "".join(seq_chunks)
        data.append((seq_id, seq))
        lengths.append(len(seq))
    return data,lengths

In [11]:
def process_pdb_files(folder_name):
    """Parse every file in ``folder_name`` with Bio.PDB and keep its first model.

    Args:
        folder_name: Path of the directory containing the PDB files.

    Returns:
        A dict mapping the upper-cased file name (extension removed) to the
        first model (``structure[0]``) of the parsed structure. Sub-directories
        are ignored; an empty folder yields an empty dict.
    """
    # Build paths with os.path.join and keep regular files only, so that
    # sub-directories such as .ipynb_checkpoints are not handed to the parser.
    pdb_paths = [
        os.path.join(folder_name, entry)
        for entry in sorted(os.listdir(folder_name))
    ]
    pdb_paths = [path for path in pdb_paths if os.path.isfile(path)]

    models = {}
    if not pdb_paths:
        return models

    # One parser instance is enough; it is reused for every file.
    parser = Bio.PDB.PDBParser()
    for path in pdb_paths:
        protein_name = os.path.splitext(os.path.basename(path))[0].upper()
        structure = parser.get_structure(protein_name, path)
        models[protein_name] = structure[0]
    return models

In [2]:
#Source: https://github.com/zzhangzzhang/pLMs-interpretability/blob/main/jac/utils.py
def do_apc(x, rm=1):
  '''Apply an average-product style correction to a matrix.

  rm selects how much is removed from a copy of x (x itself is not modified):
    rm == 0 : nothing is removed; the copy is returned as-is
              (its diagonal is left untouched)
    rm == 1 : subtract (column sums * row sums) / total sum (APC)
    other   : rebuild the matrix from its SVD without the first rm
              singular components
  For every rm other than 0 the diagonal of the result is set to 0.
  '''
  mat = np.copy(x)
  if rm == 0:
    return mat

  if rm == 1:
    col_sums = mat.sum(0,keepdims=True)
    row_sums = mat.sum(1,keepdims=True)
    total = mat.sum()
    corrected = mat - (col_sums*row_sums)/total
  else:
    # drop the leading rm singular components and recompose the rest
    left,sing,right = np.linalg.svd(mat)
    scaled_left = sing[rm:] * left[:,rm:]
    corrected = scaled_left @ right[rm:,:]

  np.fill_diagonal(corrected,0)
  return corrected

In [3]:
# Source: https://github.com/zzhangzzhang/pLMs-interpretability/blob/main/jac/01_jac_calculate_visualise.ipynb 
def get_categorical_jacobian(x,ln,model,device='cpu'):
  '''Categorical Jacobian (d out / d in) of the model logits w.r.t. input tokens.

  Each of the ln positions is substituted in turn by every one of the 20
  token ids 4..23, and the change in the model's logits (restricted to
  positions 1..ln and classes 4..23) relative to the unmodified input is
  recorded.

  Args:
    x: token id tensor for a single sequence.
       Assumes shape (1, T) with T >= ln+1 and one leading special token,
       since position n is written at column n+1 -- TODO confirm with caller.
    ln: number of sequence positions to perturb.
    model: callable returning a mapping with a "logits" tensor.
    device: device the inputs are moved to before calling the model.

  Returns:
    numpy array of shape (ln, 20, ln, 20); entry [n, a, m, b] is the logit of
    class b at position m after setting position n to token a, minus the same
    logit for the original input.
  '''
  with torch.no_grad():
    f = lambda tokens: model(tokens)["logits"][...,1:(ln+1),4:24].cpu().numpy()
    fx = f(x.to(device))[0]
    x = torch.tile(x,[20,1]).to(device)
    # build the 20 substitution tokens once, on the same device as the batch
    aa_tokens = torch.arange(4,24,device=x.device)
    fx_h = np.zeros((ln,20,ln,20))
    with tqdm(total=ln) as pbar:
      for n in range(ln): # for each position
        x_h = torch.clone(x)
        x_h[:,n+1] = aa_tokens # mutate to all 20 aa
        fx_h[n] = f(x_h)
        # advance once per position (must stay inside the loop)
        pbar.update(1)
    return fx_h - fx

In [4]:
# Source: https://github.com/zzhangzzhang/pLMs-interpretability/blob/main/jac/utils.py
def get_contacts(x, symm=True, center=True, rm=1):
  '''Reduce a 4-D jacobian (L,A,L,A) to a 2-D contact map (L,L).

  Steps, all performed on a copy of x:
    1. if center, subtract the mean along each of the four axes in turn
    2. take the root of the summed squares over axes 1 and 3
    3. zero the diagonal and pass the result through do_apc(..., rm=rm)
    4. if symm, average the result with its transpose
  '''
  jac = x.copy()
  if center:
    for axis in range(4):
      # in-place on the copy, one axis after the other
      np.subtract(jac, jac.mean(axis,keepdims=True), out=jac)

  sq_sum = np.square(jac).sum((1,3))
  coupling = np.sqrt(sq_sum)
  np.fill_diagonal(coupling,0)

  contacts = do_apc(coupling, rm=rm)
  if not symm:
    return contacts
  return (contacts + contacts.T)/2