In [1]:
!pip install tqdm scipy numpy jax h5py



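This notebook computes a protein contact map from the categorical Jacobian of an ESM-2 model, then saves it as a PNG heatmap and an HDF5 file. Code adapted from the pLMs-interpretability repository: https://github.com/zzhangzzhang/pLMs-interpretability/tree/main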

In [2]:
import time
import os 
import torch 
import json 
from tqdm import tqdm 
import numpy as np
from scipy.special import softmax
from utils import *
import random 
import matplotlib.pyplot as plt
from itertools import islice
from tqdm import tqdm
import h5py
# Pick the best available accelerator: Apple MPS, then CUDA GPU 0, else CPU.
# (Previously MPS was forced unconditionally, discarding the CUDA/CPU choice and
# failing on machines without Apple-silicon support.)
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [7]:
def _select_device():
    """Return the preferred torch device: Apple MPS, then CUDA GPU 0, else CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")


def _save_contact_map_png(contact_map, title, axis_label, path):
    """Render ``contact_map`` as a viridis heatmap and write it to ``path`` as a PNG."""
    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        image = ax.imshow(contact_map, cmap="viridis", interpolation="nearest")
        fig.colorbar(image, ax=ax, label="Contact Strength")
        ax.set_title(title, fontsize=14)
        ax.set_xlabel(axis_label, fontsize=10)
        ax.set_ylabel(axis_label, fontsize=10)
        fig.tight_layout()  # ensure the layout adjusts to fit the labels
        fig.savefig(path)
    finally:
        # Close (not merely clear) the figure. plt.clf() left an empty figure open
        # in pyplot's registry: it was displayed as a stray
        # "<Figure ... with 0 Axes>" output and accumulated once per model run.
        plt.close(fig)


def do_model_run(model_name, sequence, pid, out_dir="out"):
    """Compute, plot, and save a categorical-Jacobian contact map for one ESM model.

    Loads ``model_name`` from the ``facebookresearch/esm:main`` torch.hub repo and
    moves it to the best available device (MPS, then CUDA, else CPU). It then
    computes the categorical Jacobian of ``sequence`` with
    ``get_categorical_jacobian`` and reduces it to a contact map with
    ``get_contacts``. Both helpers come from the local ``utils`` module.

    Files written inside ``out_dir``:
      * ``{pid}_{model_name}_contact_map.png``: heatmap of the contact map.
      * ``{pid}_{model_name}.h5``: HDF5 file with the map stored as dataset ``pid``.

    Args:
        model_name: torch.hub entry-point name, e.g. ``"esm2_t33_650M_UR50D"``.
        sequence: protein sequence string passed to ``get_categorical_jacobian``.
        pid: identifier used in file names, the plot title and the HDF5 dataset name.
        out_dir: output directory, created if missing. Defaults to ``"out"``.
    """
    os.makedirs(out_dir, exist_ok=True)
    stem = os.path.join(out_dir, f"{pid}_{model_name}")

    model, alphabet = torch.hub.load("facebookresearch/esm:main", model_name)
    model = model.to(_select_device()).eval()
    # Optional tweak kept from the original experiment (disabled):
    # model.args.token_dropout = False

    # Time only the Jacobian + contact extraction, not model loading.
    start = time.perf_counter()
    cjm = get_categorical_jacobian(sequence, alphabet, model)
    contact_map = get_contacts(cjm)
    elapsed_time = time.perf_counter() - start
    print(f"{pid}_{model_name} took {elapsed_time} seconds")

    _save_contact_map_png(
        contact_map,
        title=f"{model_name} {pid} Contact Map Heatmap",
        axis_label=f"Position: {sequence}",
        path=f"{stem}_contact_map.png",
    )

    # Save the raw map so it can be reloaded without re-running the model.
    with h5py.File(f"{stem}.h5", "w") as f:
        f.create_dataset(pid, data=contact_map)
    print(f"finished {pid}_{model_name}")


In [8]:
# Input protein. The identifier appears to be PDB ID + chain ("1TX4", chain A); confirm.
pid3 = "1TX4-A"
sequence = "SNAIQQSDGSMIIDGSANLRDLNKMFNWELDTEDARTFNGLILEHLEEIPDEGTICEIDGLLITILEVGDNMIKQAKVVKL"

# Alternative inputs from earlier runs, kept for reference:
# pid = "3LAEA"
# sequence2 = "RPPLPNQQFGVSLQHLQEKNPEQEPIPIVLRETVAYLQAHALTTEGIFRRSANTQVVREVQQKYNMGLPVDFDQYNALHLPAVILKTFLRELPEPLLTFDLYPHVVGFLNIDESQRVPATLQVLQTLPEENYQVLRFLTAFLVQISAHSDQNKMTNTNLAVVFGPNLLWAKDAAITLKAINPINTFTKFLLDHQGELF"
# sequence3 = "PLPNQQFGVSLQHLQEKNPEQEPIPIVLRETVAYLQAHALTTEGIFRRSANTQVVREVQQKYNMGLPVDFDQYNALHLPAVILKTFLRELPEPLLTFDLYPHVVGFLNIDESQRVPATLQVLQTLPEENYQVLRFLTAFLVQISAHSDQNKMTNTNLAVVFGPNLLWAKDAAITLKAINPINTFTKFLLDHQGELF"

# Fail fast on unexpected characters before spending minutes loading/running models.
# Allowed: the 20 standard amino acids plus the ambiguity/rare codes X, B, U, Z, O.
VALID_RESIDUES = set("ACDEFGHIKLMNPQRSTVWYXBUZO")
invalid_residues = sorted(set(sequence) - VALID_RESIDUES)
assert not invalid_residues, f"unexpected residue characters in {pid3}: {invalid_residues}"
print(len(sequence))

# Other ESM-2 sizes that can be swapped in:
# "esm2_t6_8M_UR50D", "esm2_t12_35M_UR50D", "esm2_t30_150M_UR50D",
# "esm2_t33_650M_UR50D", "esm2_t36_3B_UR50D"
model_names = ["esm2_t33_650M_UR50D"]

for model_name in model_names:
    do_model_run(model_name, sequence, pid3)

81


Using cache found in /Users/nicholas/.cache/torch/hub/facebookresearch_esm_main


1TX4-A_esm2_t33_650M_UR50D took 133.4346218109131 seconds
finished 1TX4-A_esm2_t33_650M_UR50D


<Figure size 800x800 with 0 Axes>