In [2]:
!pip install tqdm scipy numpy jax h5py

Collecting jax
  Downloading jax-0.4.35-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.4.35,>=0.4.34 (from jax)
  Downloading jaxlib-0.4.35-cp311-cp311-manylinux2014_x86_64.whl.metadata (983 bytes)
Collecting ml-dtypes>=0.4.0 (from jax)
  Downloading ml_dtypes-0.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)
Downloading jax-0.4.35-py3-none-any.whl (2.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m45.4 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hDownloading jaxlib-0.4.35-cp311-cp311-manylinux2014_x86_64.whl (87.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.3/87.3 MB[0m [31m36.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading ml_dtypes-0.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.5/4.5 MB[0m [31m131.7 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected p

Code adapted from: https://github.com/zzhangzzhang/pLMs-interpretability/tree/main

In [1]:
import time
import os 
import torch 
import json 
from tqdm import tqdm 
import numpy as np
from scipy.special import softmax
from utils import *
import random 
import matplotlib.pyplot as plt
from itertools import islice
from tqdm import tqdm
import h5py
# Pick the first CUDA GPU when PyTorch reports one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [4]:
# Record the GPU resources and the time needed to process one Jacobian matrix.
# Save the Jacobian matrix / contact map.
def get_categorical_jacobian(seq, alphabet, model):
  """Compute the categorical Jacobian of a model's logits for one sequence.

  Each position of `seq` is substituted, in turn, with every token id in
  4..23 (the original code comments these as "all 20 aa"), and the change in
  the model's output logits relative to the unmodified sequence is recorded.

  Args:
    seq: the sequence string; `ln = len(seq)`.
    alphabet: object whose `get_batch_converter()` returns a callable that
      maps `[(label, seq)]` to a tuple whose last element is the token tensor.
      The indexing below assumes residue i is stored at token column i + 1
      (i.e. one leading special token) -- TODO confirm for other alphabets.
    model: callable taking the token tensor and returning a mapping with a
      "logits" entry; logits are sliced to positions 1..ln and vocabulary
      columns 4..23.

  Returns:
    numpy float64 array of shape (ln, 20, ln, 20). Entry [n, a, i, b] is the
    change in the logit for vocabulary column b at position i when position n
    is replaced by substitution a.
  """
  tokens, ln = alphabet.get_batch_converter()([("seq", seq)])[-1], len(seq)

  # Run on the device the model actually lives on instead of depending on a
  # notebook-global `device`, which may be missing or point somewhere else.
  params = getattr(model, "parameters", None)
  first_param = next(iter(params()), None) if callable(params) else None
  run_device = first_param.device if first_param is not None else tokens.device

  def logits_of(batch):
    # Keep only the residue positions and the 20 substituted vocabulary columns.
    return model(batch)["logits"][..., 1:(ln + 1), 4:24].cpu().numpy()

  with torch.no_grad():
    fx = logits_of(tokens.to(run_device))[0]           # (ln, 20) reference logits
    batch = torch.tile(tokens, [20, 1]).to(run_device)  # one row per substitution
    # Loop-invariant: the 20 substituted token ids, created once on-device.
    substitutions = torch.arange(4, 24, device=run_device)
    fx_h = np.zeros((ln, 20, ln, 20))
    for n in range(ln):  # for each position
      mutated = torch.clone(batch)
      mutated[:, n + 1] = substitutions  # row a carries substitution a at position n
      fx_h[n] = logits_of(mutated)
    # fx (ln, 20) broadcasts over the two leading (position, substitution) axes.
    return fx_h - fx
def get_contacts(x, symm=True, center=True, rm=1):
  """Reduce a (L,A,L,A) jacobian to a (L,L) contact map.

  Args:
    x: 4-D jacobian array; it is copied, not modified.
    symm: when true, average the corrected map with its transpose.
    center: when true, subtract the mean along each of the four axes in turn.
    rm: forwarded to `do_apc` to select the correction applied.

  Returns:
    2-D array obtained by taking the root-sum-of-squares over axes 1 and 3,
    zeroing the diagonal, applying `do_apc`, and optionally symmetrizing.
  """
  jac = x.copy()
  if center:
    for axis in (0, 1, 2, 3):
      jac -= jac.mean(axis, keepdims=True)

  # Norm over the two A axes leaves one value per pair of positions.
  squared = np.square(jac)
  norms = np.sqrt(squared.sum((1, 3)))
  np.fill_diagonal(norms, 0)

  corrected = do_apc(norms, rm=rm)
  if not symm:
    return corrected
  return (corrected + corrected.T)/2
def do_apc(x, rm=1):
  '''Apply average-product correction (APC) or strip leading SVD components.

  Args:
    x: 2-D matrix; it is copied, never modified in place.
    rm: how much to remove.
      rm=0  -> return the copy unchanged (the diagonal is NOT zeroed).
      rm=1  -> APC: subtract (column sums * row sums) / total sum.
               If x sums to zero this division is undefined and yields
               non-finite values.
      other -> drop the first `rm` singular components (largest singular
               values) and rebuild the matrix from the rest.

  Returns:
    The corrected matrix with the same shape as x; for rm != 0 its main
    diagonal is set to 0.
  '''
  x = np.copy(x)
  if rm == 0:
    return x
  if rm == 1:
    col_sums = x.sum(0, keepdims=True)
    row_sums = x.sum(1, keepdims=True)
    y = x - (col_sums * row_sums) / x.sum()
  else:
    # Reduced SVD: identical to the full SVD for square input, and keeps the
    # shapes of u, s and v compatible for rectangular input as well.
    u, s, v = np.linalg.svd(x, full_matrices=False)
    # Scale the kept left vectors by their singular values, then project back.
    y = (u[:, rm:] * s[rm:]) @ v[rm:, :]
  np.fill_diagonal(y, 0)
  return y
def do_model_run(model_name, sequence, pid, out_dir="out"):
    """Compute, time, plot and save the contact map of one sequence for one model.

    Loads `model_name` from the "facebookresearch/esm:main" torch.hub repo,
    computes the categorical Jacobian of `sequence`, converts it to a contact
    map, and writes two files into `out_dir`:
      * `{pid}_{model_name}_contact_map.png` -- heatmap of the contact map
      * `{pid}_{model_name}.h5`              -- HDF5 file with one dataset named `pid`

    Args:
        model_name: torch.hub entry-point name of the model to load.
        sequence: sequence string passed to `get_categorical_jacobian`.
        pid: identifier used in log messages, file names and the HDF5 dataset name.
        out_dir: directory for the output files; created if it does not exist.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    # LOAD MODEL AND PUT TO GPU / Validation: higher similarity ok: 50%
    model, alphabet = torch.hub.load("facebookresearch/esm:main", model_name)
    # put model on GPU if available
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # NOTE(review): this is assigned after the CUDA availability query above,
    # so it probably does not influence which GPU this process uses. Kept
    # unchanged to preserve existing behaviour -- confirm whether it is needed.
    os.environ["CUDA_VISIBLE_DEVICES"] = '2'
    model = model.to(device)
    model = model.eval()
    #model.args.token_dropout = False

    # Make sure the output directory exists before the expensive computation,
    # so a missing folder fails fast instead of after the Jacobian is computed.
    os.makedirs(out_dir, exist_ok=True)

    # Time only the Jacobian + contact-map computation (not model loading or I/O).
    start_time = time.perf_counter()
    print(f"starting {pid}_{model_name}")
    cjm = get_categorical_jacobian(sequence, alphabet, model)
    contact_map = get_contacts(cjm)
    elapsed_time = time.perf_counter() - start_time
    print(f"{pid}_{model_name} took {elapsed_time} seconds")

    # Explicit figure/axes so the figure can be closed afterwards; plt.clf()
    # only cleared it, leaving an empty figure open (and displayed) per call.
    fig, ax = plt.subplots(figsize=(8, 8))
    image = ax.imshow(contact_map, cmap="viridis", interpolation="nearest")
    fig.colorbar(image, ax=ax, label="Contact Strength")
    ax.set_title(f"{model_name} {pid} Contact Map Heatmap")
    ax.set_xlabel("Position")
    ax.set_ylabel("Position")
    fig.savefig(os.path.join(out_dir, f"{pid}_{model_name}_contact_map.png"))
    plt.close(fig)

    # Save the result
    with h5py.File(os.path.join(out_dir, f"{pid}_{model_name}.h5"), 'w') as f:
        f.create_dataset(pid, data=contact_map)
    print(f"finished {pid}_{model_name}")

    

In [5]:
# Driver cell: choose one protein (identifier + sequence) and run every model
# listed in `model_names` on it via do_model_run.
#
# Alternative inputs are kept below as commented-out lines.
# pid = "3LAEA"
pid3 = "1TX4-A"
# sequence = "SNAIQQSDGSMIIDGSANLRDLNKMFNWELDTEDARTFNGLILEHLEEIPDEGTICEIDGLLITILEVGDNMIKQAKVVKL"
# sequence2 = "RPPLPNQQFGVSLQHLQEKNPEQEPIPIVLRETVAYLQAHALTTEGIFRRSANTQVVREVQQKYNMGLPVDFDQYNALHLPAVILKTFLRELPEPLLTFDLYPHVVGFLNIDESQRVPATLQVLQTLPEENYQVLRFLTAFLVQISAHSDQNKMTNTNLAVVFGPNLLWAKDAAITLKAINPINTFTKFLLDHQGELF"
# sequence3 is sequence2 without its first two residues ("RP").
sequence3 = "PLPNQQFGVSLQHLQEKNPEQEPIPIVLRETVAYLQAHALTTEGIFRRSANTQVVREVQQKYNMGLPVDFDQYNALHLPAVILKTFLRELPEPLLTFDLYPHVVGFLNIDESQRVPATLQVLQTLPEENYQVLRFLTAFLVQISAHSDQNKMTNTNLAVVFGPNLLWAKDAAITLKAINPINTFTKFLLDHQGELF"
# Print the sequence length as a quick sanity check of the input.
print(len(sequence3))
# Full list of model names (commented out); only the 650M model is run below.
# model_names = ["esm2_t6_8M_UR50D", "esm2_t12_35M_UR50D", "esm2_t30_150M_UR50D", "esm2_t33_650M_UR50D", "esm2_t36_3B_UR50D"]
model_names = ["esm2_t33_650M_UR50D"]

# Each call loads the model, computes the contact map, and writes a PNG and an
# HDF5 file under out/ (see do_model_run).
for model_name in model_names:
    do_model_run(model_name, sequence3, pid3)

196


Using cache found in /Users/cw_liao/.cache/torch/hub/facebookresearch_esm_main


1TX4-A_esm2_t33_650M_UR50D took 1233.877678155899 seconds
finished 1TX4-A_esm2_t33_650M_UR50D


<Figure size 800x800 with 0 Axes>