# Categorical Jacobian from ESM2

This Jupyter notebook is adapted from [this Colab notebook](https://colab.research.google.com/github/sokrypton/ColabBio/blob/main/categorical_jacobian/esm2.ipynb), which accompanies [this paper](https://doi.org/10.1073/pnas.2406285121).

It lets you run the same workflow locally (e.g., when you don't have enough GPU access on Colab). If the original Colab works for you, it is probably the simpler option.

In [1]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.special import softmax
from pathlib import Path
from urllib.request import urlretrieve
from matplotlib.colors import to_hex
from tqdm.auto import tqdm

import bokeh.io
import bokeh.plotting
from bokeh.models import BasicTicker, PrintfTickFormatter
from bokeh.palettes import viridis, RdBu
from bokeh.transform import linear_cmap
from bokeh.plotting import figure, show

import ipywidgets as widgets
from IPython.display import display

from esm_helpers import (
    make_palettes,
    pssm_to_dataframe,
    contact_to_dataframe,
    pair_to_dataframe,
    get_logits,
    load_model,
    get_categorical_jacobian,
    jac_to_con,
)

# Render Bokeh figures inline in this notebook's output cells.
bokeh.io.output_notebook()

## Select ESM model

In [2]:
# ---- choose model ----
# ESM checkpoints that can be downloaded/loaded below; the 650M ESM2 model is preselected.
ESM_MODEL_NAMES = [
    "esm2_t48_15B_UR50D",
    "esm2_t36_3B_UR50D",
    "esm2_t33_650M_UR50D",
    "esm2_t30_150M_UR50D",
    "esm2_t12_35M_UR50D",
    "esm2_t6_8M_UR50D",
    "esm1b_t33_650M_UR50S",
]
model_widget = widgets.Dropdown(
    options=ESM_MODEL_NAMES,
    value="esm2_t33_650M_UR50D",
    description="Model:",
    layout=widgets.Layout(width="90%"),
)
display(model_widget)

Dropdown(description='Model:', index=2, layout=Layout(width='90%'), options=('esm2_t48_15B_UR50D', 'esm2_t36_3…

## Download model files

In [3]:
# ---- Download model files ----
model_name = model_widget.value
print(f"Selected model: {model_name}")

# ---- paths ----
# Use torch.hub's own checkpoint directory (it honours TORCH_HOME and
# torch.hub.set_dir) instead of a hard-coded ~/.cache path. load_model
# presumably goes through torch.hub (see its "Using cache found in .../torch/hub/..."
# message), so this keeps the files checked here and the files it loads in sync.
# With no overrides this resolves to the same ~/.cache/torch/hub/checkpoints.
cache_dir = Path(torch.hub.get_dir()) / "checkpoints"
cache_dir.mkdir(parents=True, exist_ok=True)

model_pt = cache_dir / f"{model_name}.pt"
regression_pt = cache_dir / f"{model_name}-contact-regression.pt"

MODEL_URL = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"
REGRESSION_URL = f"https://dl.fbaipublicfiles.com/fair-esm/regression/{model_name}-contact-regression.pt"

def download_if_missing(url: str, dst: Path):
    """Download ``url`` to ``dst`` unless a non-empty ``dst`` already exists.

    The data is first written to a sibling ``<name>.part`` file and only moved
    onto ``dst`` once ``urlretrieve`` has returned. An interrupted or failed
    download therefore never leaves a truncated ``dst``, which the size check
    below would otherwise accept as "already present" on the next run.

    Args:
        url: Source URL (any scheme ``urllib.request.urlretrieve`` accepts).
        dst: Final destination path.

    Raises:
        Whatever ``urlretrieve`` raises (e.g. ``urllib.error.URLError``); the
        partial ``.part`` file is removed before the exception propagates.
    """
    if dst.exists() and dst.stat().st_size > 0:
        print(f"OK: {dst.name} already present")
        return
    print(f"Downloading: {dst.name}")
    partial = dst.with_name(dst.name + ".part")
    try:
        urlretrieve(url, partial.as_posix())
        # Atomic on the same filesystem: dst is either absent or complete.
        partial.replace(dst)
    finally:
        # Only still present if the download or the rename failed.
        if partial.exists():
            partial.unlink()

# ---- download/check files ----
# Fetch the model weights and the contact-regression head; cached files are skipped.
for _url, _dst in ((MODEL_URL, model_pt), (REGRESSION_URL, regression_pt)):
    download_if_missing(_url, _dst)
print("All files are ready.")

Selected model: esm2_t33_650M_UR50D
OK: esm2_t33_650M_UR50D.pt already present
OK: esm2_t33_650M_UR50D-contact-regression.pt already present
All files are ready.


In [4]:
# ---- Set up model, alphabet, and palettes ----
# load model together with its tokenizer alphabet and the torch device it runs on
model, alphabet, device = load_model(model_name)

# alphabets
# ESM tokens 4..23 as single characters, in ESM's own order.
esm_alphabet = [*"".join(alphabet.all_toks[4:24])]
# Amino-acid order used for every plot/output axis in this notebook.
ALPHABET = "AFILVMWYDEKRHNQSTGPC"
# Position of each ALPHABET letter inside esm_alphabet (raises ValueError if one is missing).
ALPHABET_map = list(map(esm_alphabet.index, ALPHABET))

# palettes
bwr_r, gray = make_palettes()

Using cache found in /grid/koo/home/nagai/.cache/torch/hub/facebookresearch_esm_main


## Select sequence

In [5]:
# ---- select sequence ----
# Paste the protein sequence to analyse; the next cell keeps only its letters.
sequence_layout = widgets.Layout(width="95%", height="100px")
sequence_box = widgets.Textarea(
    description="Sequence:",
    value="MKAKELREKSVEELNTELLNLLREQFNLRMQAASGQLQQSHLLKQVRRDVARVKTLLNEKAGA",
    layout=sequence_layout,
)
display(sequence_box)

Textarea(value='MKAKELREKSVEELNTELLNLLREQFNLRMQAASGQLQQSHLLKQVRRDVARVKTLLNEKAGA', description='Sequence:', lay…

In [6]:
# ---- prepare output folder ----
# get cleaned sequence: upper-case, keep ASCII letters only
# (drops whitespace, digits, '*', '-', and any non-ASCII characters)
sequence = "".join(c for c in sequence_box.value.upper() if c.isascii() and c.isalpha())
if not sequence:
    raise ValueError("The sequence box contains no amino-acid letters.")
print(f"Sequence length: {len(sequence)}")

# choose parallelism: smaller batches for longer sequences
if len(sequence) > 2400:
    PARALLEL = 1
elif len(sequence) > 1500:
    PARALLEL = 10
else:
    PARALLEL = 20

# output folder + README
# File names mirror what the conservation and coevolution cells actually write.
os.makedirs("output", exist_ok=True)
with open("output/README.txt", "w") as handle:
    handle.write(f"conservation_logits_{model_name}.txt = (L, A) matrix\n")
    handle.write(f"coevolution_{model_name}.txt = (L, L) matrix\n")
    handle.write(f"jac_{model_name}.npy = ((L*L-L)/2, A, A) tensor\n")
    handle.write("jac index via np.triu_indices(L,1)\n")
    handle.write(f"A axis order: {ALPHABET}\n")
    handle.write(f"sequence: {sequence}\n")


Sequence length: 63


## Compute conservation

In [7]:
# run conservation
# Per-position logits, with columns reordered into ALPHABET order via ALPHABET_map
# (presumably get_logits returns them in esm_alphabet order), saved as text and
# converted to probabilities for the heatmap below.
logits = get_logits(model, alphabet, device, sequence, p=PARALLEL)[:, ALPHABET_map]
np.savetxt(f"output/conservation_logits_{model_name}.txt", logits)
pssm = softmax(logits, axis=-1)
df = pssm_to_dataframe(pssm, ALPHABET)

  0%|          | 0/63 [elapsed: 00:00 remaining: ?]

In [8]:
# plot pssm
# Heatmap of per-position amino-acid probabilities; rows follow ALPHABET top-to-bottom.
num_colors = 256  # number of palette steps; lower it for a coarser colour scale
palette = viridis(num_colors)
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
p = figure(title="CONSERVATION",
           x_range=[str(x) for x in range(1, len(sequence) + 1)],
           y_range=list(ALPHABET)[::-1],
           width=900, height=400,
           tools=TOOLS, toolbar_location='below',
           tooltips=[('Position', '@Position'), ('Amino Acid', '@{Amino Acid}'), ('Probability', '@Probability')])

p.rect(x="Position", y="Amino Acid", width=1, height=1, source=df,
       fill_color=linear_cmap('Probability', palette, low=0, high=1),
       line_color=None)
p.xaxis.visible = False  # Hide the x-axis; the position is shown in the hover tooltip
show(p)

## Compute coevolution

In [9]:
# --- Select variables for coevolution ----

# ---- Make output folder ----
os.makedirs("output", exist_ok=True)

# ---- raw widgets (no descriptions) ----
# ESM checkpoint names encode the transformer depth as "t<N>"
# (e.g. esm2_t33_650M_UR50D -> 33, esm2_t6_8M_UR50D -> 6). Offer layers 0..N,
# which reproduces the previous hard-coded range(34) for the 33-layer models
# while staying correct for the other checkpoints in the model dropdown.
n_layers = int(model_name.split("_")[1].lstrip("t"))
fast_w = widgets.Checkbox(value=False, indent=False)
layer_w = widgets.Dropdown(
    options=[("Default (logits)", None)] + [(str(i), i) for i in range(n_layers + 1)],
    value=None,
    layout=widgets.Layout(width="260px"),
)
center_w = widgets.Checkbox(value=True, indent=False)
symm_w = widgets.Checkbox(value=True, indent=False)
diag_w = widgets.Dropdown(
    options=["remove", "normalize", "none"],
    value="remove",
    layout=widgets.Layout(width="220px"),
)
apc_w = widgets.Checkbox(value=True, indent=False)

def labeled(label, widget, label_width="60px"):
    """Return ``widget`` preceded by a fixed-width text label, vertically centred in one row."""
    caption = widgets.Label(label, layout=widgets.Layout(width=label_width))
    row_layout = widgets.Layout(align_items="center")
    return widgets.HBox([caption, widget], layout=row_layout)

# Stack the labelled controls vertically, left-aligned.
_coevolution_rows = [
    ("fast", fast_w),
    ("layer", layer_w),
    ("diag", diag_w),
    ("center", center_w),
    ("symm", symm_w),
    ("apc", apc_w),
]
ui = widgets.VBox(
    [labeled(text, control) for text, control in _coevolution_rows],
    layout=widgets.Layout(align_items="flex-start"),
)

display(ui)

# Cache holders: only created on the first run, so re-running this cell keeps any
# previously computed values; the next cell decides whether to recompute `jac`.
for _cache_name in ("settings_", "jac", "con"):
    globals().setdefault(_cache_name, None)

VBox(children=(HBox(children=(Label(value='fast', layout=Layout(width='60px')), Checkbox(value=False, indent=F…

In [10]:
# ---- Plot coevolution ----
# ---- read values from widgets ----
fast = fast_w.value
layer = layer_w.value
center = center_w.value
symm = symm_w.value
diag = diag_w.value
apc = apc_w.value

# ---- caching logic ----
# The raw jacobian depends on the model, layer, sequence and `fast`; center/diag/
# apc/symm are applied afterwards by jac_to_con, so toggling them reuses the cache.
# model_name is part of the key so switching models never reuses a stale jacobian.
settings = dict(model_name=model_name, layer=layer, sequence=sequence, fast=fast)

if (jac is None) or (settings_ != settings):
    if fast and layer is None:
        # Same argument list as the conservation cell's get_logits call.
        # NOTE(review): confirm esm_helpers.get_logits(..., return_jac=True) returns
        # only the jacobian (not a (logits, jac) tuple).
        jac = get_logits(model, alphabet, device, sequence, p=PARALLEL, return_jac=True)
    else:
        jac = get_categorical_jacobian(model, alphabet, device, sequence, layer=layer, fast=fast)
    settings_ = settings.copy()
    print("Computed jacobian.")
else:
    print("Reused cached jacobian.")

# ---- contacts ----
con = jac_to_con(jac, esm_alphabet, ALPHABET, center=center, diag=diag, apc=apc, symm=symm)

# ---- save outputs ----
np.savetxt(f"output/coevolution_{model_name}.txt", con["contacts"])
# Only save the per-pair jacobian for the logits output (layer=None): that is the
# case the README's ((L*L-L)/2, A, A) layout and the pair-plot cell assume.
if layer is None:
    i, j = np.triu_indices(len(sequence), 1)
    # Advanced indices separated by a slice move the pair axis to the front,
    # giving ((L*L-L)/2, A, A) assuming con["jac"] is (L, A, L, A).
    np.save(f"output/jac_{model_name}.npy", con["jac"][i, :, j, :].astype(np.float16))

# ---- plot ----
df = contact_to_dataframe(con["contacts"])

TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
p = figure(
    title="COEVOLUTION",
    x_range=[str(x) for x in range(1, len(sequence) + 1)],
    y_range=[str(x) for x in range(1, len(sequence) + 1)][::-1],
    width=800, height=800,
    tools=TOOLS, toolbar_location="below",
    tooltips=[("i", "@i"), ("j", "@j"), ("value", "@value")]
)

p.rect(
    x="i", y="j", width=1, height=1, source=df,
    fill_color=linear_cmap("value", gray, low=df.value.min(), high=df.value.max()),
    line_color=None
)

p.xaxis.visible = False
p.yaxis.visible = False
bokeh.io.show(p)


  0%|          | 0/63 [elapsed: 00:00 remaining: ?]

Computed jacobian.


## Select a pair of residues to investigate

In [11]:
# ---- Select variables for pair ----
# 1-based residue positions, bounded to [1, len(sequence)].
pos_i_w = widgets.BoundedIntText(value=54, min=1, max=len(sequence), step=1)
pos_j_w = widgets.BoundedIntText(value=18, min=1, max=len(sequence), step=1)

# Reuse the `labeled` helper defined in the coevolution-settings cell rather than
# redefining it (a second definition silently shadows the first); only the label
# width differs here.
ui_pair = widgets.VBox(
    [
        labeled("position_i", pos_i_w, label_width="80px"),
        labeled("position_j", pos_j_w, label_width="80px"),
    ],
    layout=widgets.Layout(align_items="flex-start"),
)

display(ui_pair)

VBox(children=(HBox(children=(Label(value='position_i', layout=Layout(width='80px')), BoundedIntText(value=54,…

In [12]:
# ---- Plot pair ----
# Show the (A, A) block of the categorical jacobian for one residue pair.
# Needs the full jacobian from the coevolution cell (layer=None, fast=False).
position_i = pos_i_w.value
position_j = pos_j_w.value

if layer is not None:
    print("This function is only supported when layer=None.")
elif fast:
    print("This function is only supported when fast=False (layer=None).")
elif con is None:
    print("Run coevolution first (so `con` exists).")
else:
    # widget positions are 1-based; array indices are 0-based
    i, j = position_i - 1, position_j - 1
    if not all(0 <= k < len(sequence) for k in (i, j)):
        raise ValueError("position_i/position_j out of range for the current sequence.")

    df_pair = pair_to_dataframe(con["jac"][i, :, j, :], ALPHABET)

    TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
    pair_tooltips = [("aa_i", "@aa_i"), ("aa_j", "@aa_j"), ("value", "@value")]
    p = figure(
        title=f"coevolution between {position_i} and {position_j}",
        x_range=list(ALPHABET),
        y_range=list(ALPHABET)[::-1],
        width=400,
        height=400,
        tools=TOOLS,
        toolbar_location="below",
        tooltips=pair_tooltips,
    )

    # axis labels: wild-type residue followed by its 1-based position
    p.xaxis.axis_label = f"{sequence[i]}{position_i}"
    p.yaxis.axis_label = f"{sequence[j]}{position_j}"

    pair_colors = linear_cmap("value", bwr_r, low=-3.0, high=3.0)
    p.rect(
        x="aa_i",
        y="aa_j",
        width=1,
        height=1,
        source=df_pair,
        fill_color=pair_colors,
        line_color=None,
    )

    bokeh.io.show(p)