# Get Categorical Jacobian from gLM2
Adapted from @sokrypton's [ESM2 categorical Jacobian notebook](https://colab.research.google.com/github/sokrypton/ColabBio/blob/main/categorical_jacobian/esm2.ipynb),  
with minor edits to run gLM2 instead of ESM2.


In [None]:
# Hugging Face hub id of the gLM2 checkpoint to load; its basename is used to
# name the files written to output/.
MODEL_NAME = "tattabio/gLM2_650M"

In [None]:
# Query sequence: ModAC
# Single-letter amino-acid codes. The sequence is uppercased and every
# non-letter character is stripped before use, so pasted whitespace is harmless.
SEQUENCE = "MFLKVRAEKRLGNFRLNVDFEMGRDYCVLLGPTGAGKSVFLELIAGIVKPDRGEVRLNGADITPLPPERRGIGFVPQDYALFPHLSVYRNIAYGLRNVERVERDRRVREMAEKLGIAHLLDRKPARLSGGERQRVALARALVIQPRLLLLDEPLSAVDLKTKGVLMEELRFVQREFDVPILHVTHDLIEAAMLADEVAVMLNGRIVEKGKLKELFSAKNGEVAEFLSARNLLLKVSKILDMRLLFSALLALLSSIILLFVLLPVAATVTLQLFNFDEFLKAASDPAVWKVVLTTYYAALISTLIAVIFGTPLAYILARKSFPGKSVVEGIVDLPVVIPHTVAGIALLVVFGSSGLIGSFSPLKFVDALPGIVVAMLFVSVPIYINQAKEGFASVDVRLEHVARTLGSSPLRVFFTVSLPLSVRHIVAGAIMSWARGISEFGAVVVIAYYPMIAPTLIYERYLSEGLSAAMPVAAILILLSLAVFVALRIIVGREDVSEGQG"


In [None]:
import os
import torch
import string
import matplotlib.pyplot as plt
from scipy.special import softmax
import pandas as pd
import numpy as np
import bokeh.plotting
from bokeh.transform import linear_cmap
from bokeh.plotting import figure, show
from bokeh.palettes import viridis
from transformers import AutoTokenizer, AutoModelForMaskedLM
from matplotlib.colors import to_hex
import tqdm.notebook

# Render bokeh `show(...)` figures inline in the notebook.
bokeh.io.output_notebook()


## Utility functions.


In [None]:
def load_gLM2_model(model_name, device="cuda"):
    """Load a masked-LM checkpoint plus its tokenizer.

    Both are fetched with trust_remote_code=True, i.e. the checkpoint's own
    modeling/tokenizer code is executed. The model is switched to eval mode and
    moved to `device`.

    Returns:
        (model, tokenizer)
    """
    lm = AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
    lm = lm.eval().to(device)
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    return lm, tok

def pssm_to_dataframe(pssm, esm_alphabet):
  """Reshape an (L, A) position matrix into a long-form table for plotting.

  Rows are labelled with 1-based position strings and columns with the letters
  of `esm_alphabet`. The result has one row per (position, letter) pair and the
  columns 'Position', 'Amino Acid' and 'Probability'.
  """
  positions = list(map(str, range(1, pssm.shape[0] + 1)))
  wide = pd.DataFrame(pssm, index=positions, columns=list(esm_alphabet))
  tidy = wide.stack().reset_index()
  tidy.columns = ['Position', 'Amino Acid', 'Probability']
  return tidy

def contact_to_dataframe(con):
  """Reshape an (L, L) contact map into long form.

  Both axes are labelled with 1-based position strings. The result has one row
  per (i, j) pair and the columns 'i', 'j' and 'value'.
  """
  labels = list(map(str, range(1, con.shape[0] + 1)))
  tidy = pd.DataFrame(con, index=labels, columns=labels).stack().reset_index()
  tidy.columns = ['i', 'j', 'value']
  return tidy

def pair_to_dataframe(pair, esm_alphabet):
  """Reshape an (A, A) letter-by-letter matrix into long form.

  Both axes are labelled with the letters of `esm_alphabet`. The result has one
  row per letter pair and the columns 'aa_i', 'aa_j' and 'value'.
  """
  letters = list(esm_alphabet)
  tidy = pd.DataFrame(pair, index=letters, columns=letters).stack().reset_index()
  tidy.columns = ['aa_i', 'aa_j', 'value']
  return tidy

def parse_fasta(filename, a3m=True):
    '''Parse a FASTA/A3M file into parallel lists of headers and sequences.

    Args:
        filename: path of the file to read.
        a3m: if True, lowercase letters are deleted from sequence lines (in
            A3M these are insertions that do not align to the query); if
            False, sequence lines are uppercased instead.

    Returns:
        (header, sequence): header lines without the leading '>', and for
        each header the concatenation of its sequence lines.

    Raises:
        ValueError: if sequence data appears before the first '>' header.

    Blank lines are ignored.
    '''
    if a3m:
        # for a3m files the lowercase letters are removed
        # as these do not align to the query sequence
        rm_lc = str.maketrans(dict.fromkeys(string.ascii_lowercase))

    header, sequence = [], []
    # `with` guarantees the file is closed even if parsing raises.
    with open(filename, "r") as lines:
        for line in lines:
            line = line.rstrip()
            if not line:
                # blank lines (e.g. between records) used to crash on line[0]
                continue
            if line.startswith(">"):
                header.append(line[1:])
                sequence.append([])
                continue
            if not sequence:
                raise ValueError(
                    f"{filename}: sequence data found before the first '>' header")
            if a3m:
                line = line.translate(rm_lc)
            else:
                line = line.upper()
            sequence[-1].append(line)

    sequence = [''.join(seq) for seq in sequence]
    return header, sequence


def mk_msa(seqs, alphabet=None):
    '''One-hot encode a multiple sequence alignment.

    Each character is replaced by its index in `alphabet` (default: the 20
    amino acids followed by '-'). Characters missing from the alphabet map to
    the last state. Sequences are expected to be aligned (equal length).

    Returns:
        float array of shape (num_seqs, length, len(alphabet)).
    '''
    alphabet = "ARNDCQEGHILKMFPSTWYV-" if alphabet is None else alphabet
    n_states = len(alphabet)
    fallback = n_states - 1
    index_of = dict(zip(alphabet, range(n_states)))
    codes = np.array([[index_of.get(ch, fallback) for ch in row] for row in seqs])
    identity = np.eye(n_states)
    return identity[codes]


def _do_apc(x, rm_diag=True):
    '''given matrix do apc correction'''
    if rm_diag:
        np.fill_diagonal(x, 0.0)
    a1 = x.sum(0, keepdims=True)
    a2 = x.sum(1, keepdims=True)
    y = x - (a1*a2)/x.sum()
    np.fill_diagonal(y, 0.0)
    return y


def inv_cov(Y):
    '''Estimate contacts from the regularized inverse covariance of a one-hot MSA.

    Y is a 3-D one-hot array indexed (sequence, position, state), e.g. the
    output of mk_msa.

    Returns a dict with:
        "c":   covariance of the flattened (position * state) features,
        "ic":  inverse of c after adding a diagonal shrinkage term,
        "raw": (L, L) Frobenius norms of the ic blocks over the first 20 states,
        "apc": _do_apc applied to "raw".
    '''
    n_seq, length, n_states = Y.shape
    features = Y.reshape(n_seq, -1)
    cov = np.cov(features.T)
    # diagonal shrinkage whose strength decays as 1/sqrt(number of sequences)
    ridge = 4.5 / np.sqrt(n_seq) * np.eye(cov.shape[0])
    precision = np.linalg.inv(cov + ridge)
    blocks = precision.reshape(length, n_states, length, n_states)[:, :20, :, :20]
    raw = np.sqrt(np.square(blocks).sum((1, 3)))
    return {"c": cov, "ic": precision, "raw": raw, "apc": _do_apc(raw)}

def do_apc(x, rm=1):
    '''Remove background signal from a square matrix, working on a copy.

    rm == 0: return an unmodified copy (diagonal untouched).
    rm == 1: average product correction, x - outer(row_sums, col_sums) / total.
    otherwise: drop the `rm` leading singular components, i.e. keep only
               sum_{k >= rm} s_k * u_k v_k^T.
    For every rm other than 0 the diagonal of the result is set to 0.
    '''
    mat = np.copy(x)
    if rm == 0:
        return mat

    if rm == 1:
        col_sums = mat.sum(0, keepdims=True)
        row_sums = mat.sum(1, keepdims=True)
        result = mat - (col_sums * row_sums) / mat.sum()
    else:
        left, sing, right = np.linalg.svd(mat)
        result = sing[rm:] * left[:, rm:] @ right[rm:, :]

    np.fill_diagonal(result, 0)
    return result


def get_contacts(x, symm=True, center=True, rm=1):
    '''Collapse a categorical Jacobian (L, A, L, A) into an (L, L) contact map.

    Steps:
      1. optionally subtract the mean along each of the four axes in turn;
      2. take the Frobenius norm over the two token axes (1 and 3);
      3. zero the diagonal and apply do_apc(..., rm=rm);
      4. optionally average the map with its transpose.
    '''
    jac = x.copy()
    if center:
        for axis in (0, 1, 2, 3):
            jac -= jac.mean(axis, keepdims=True)

    norms = np.sqrt(np.square(jac).sum(axis=(1, 3)))
    np.fill_diagonal(norms, 0)
    corrected = do_apc(norms, rm=rm)
    return (corrected + corrected.T) / 2 if symm else corrected


In [None]:
# Colour palettes (lists of hex strings) for the bokeh heatmaps.
cmap = plt.colormaps["bwr_r"]
bwr_r = [to_hex(cmap(i)) for i in np.linspace(0, 1, 256)]
cmap = plt.colormaps["gray_r"]
gray = [to_hex(cmap(i)) for i in np.linspace(0, 1, 256)]


TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Load onto DEVICE rather than relying on load_gLM2_model's "cuda" default, so
# the notebook also runs on machines without a GPU.
model, tokenizer = load_gLM2_model(MODEL_NAME, device=DEVICE)

# Tokens with ids 4..24 (21 tokens).
# NOTE(review): the model outputs are sliced to ids 4..23 (20 columns), so this
# assumes every ALPHABET letter has an id in 4..23. TODO confirm for the gLM2
# vocabulary.
esm_alphabet = tokenizer.convert_ids_to_tokens(range(4, 25))
ALPHABET = "AFILVMWYDEKRHNQSTGPC"
# Column index of each ALPHABET letter among the model's token columns; used to
# reorder outputs into ALPHABET order.
ALPHABET_map = [esm_alphabet.index(a) for a in ALPHABET]
mask_idx = tokenizer.convert_tokens_to_ids('<mask>')

def jac_to_con(jac, symm=True, center=True, diag="remove", apc=True):
  """Turn a categorical Jacobian into a residue-residue contact map.

  `jac` is a 4-D array indexed (position_i, token_i, position_j, token_j).
  Any token axis of length 20 is reordered into ALPHABET order using the
  module-level ALPHABET_map.

  Args:
    symm: symmetrize. When both token axes have length 20, the full Jacobian
      is averaged with its (2, 3, 0, 1) transpose. Otherwise the final (L, L)
      norm map is averaged with its transpose.
    center: subtract the mean along every axis whose length is > 1.
    diag: "remove" zeroes the diagonal (before and after APC). "normalize"
      divides entry (i, j) by sqrt(c_ii * c_jj). Any other value leaves the
      diagonal alone.
    apc: subtract the average product correction.

  Returns:
    {"jac": the reordered/symmetrized/centered copy of `jac`,
     "contacts": map of Frobenius norms over the two token axes}
  """
  out = jac.copy()
  _, n_tok_i, _, n_tok_j = out.shape
  full_square = n_tok_i == 20 and n_tok_j == 20

  if n_tok_i == 20:
    out = out[:, ALPHABET_map, :, :]
  if n_tok_j == 20:
    out = out[:, :, :, ALPHABET_map]
  if symm and full_square:
    out = (out + out.transpose(2, 3, 0, 1)) / 2

  if center:
    for axis, size in enumerate(out.shape):
      if size > 1:
        out -= out.mean(axis, keepdims=True)

  contacts = np.sqrt(np.square(out).sum((1, 3)))
  if symm and not full_square:
    contacts = (contacts + contacts.T) / 2

  if diag == "remove":
    np.fill_diagonal(contacts, 0)
  elif diag == "normalize":
    self_terms = np.diag(contacts)
    contacts = contacts / np.sqrt(self_terms[:, None] * self_terms[None, :])

  if apc:
    background = contacts.sum(0, keepdims=True) * contacts.sum(1, keepdims=True) / contacts.sum()
    contacts = contacts - background

  # removed again: APC re-introduces non-zero values on the diagonal
  if diag == "remove":
    np.fill_diagonal(contacts, 0)

  return {"jac": out, "contacts": contacts}

In [None]:
sequence = SEQUENCE.upper()
# keep only letters (drops whitespace, digits, punctuation pasted into SEQUENCE)
sequence = ''.join([i for i in sequence if i.isalpha()])

# Number of masked copies of the sequence batched per forward pass by
# get_logits; longer sequences use smaller batches. The largest threshold is
# tested first so that every branch is reachable.
if len(sequence) > 2400:
  PARALLEL = 1
elif len(sequence) > 1500:
  PARALLEL = 10
else:
  PARALLEL = 20

model_name = os.path.basename(MODEL_NAME)
os.makedirs("output",exist_ok=True)
with open("output/README.txt","w") as handle:
  # file names match the ones written by the conservation/coevolution cells
  handle.write(f"conservation_logits_{model_name}.txt = (L, A) matrix\n")
  handle.write(f"coevolution_{model_name}.txt = (L, L) matrix\n")
  handle.write(f"jac_{model_name}.npy = ((L*L-L)/2, A, A) tensor\n")
  handle.write("jac index can be recreated with np.triu_indices(L,1)\n")
  handle.write(f"[A]lphabet: {ALPHABET}\n")
  handle.write(f"sequence: {sequence}\n")

In [None]:
def get_logits(seq, p=1, return_jac=False):
  """Masked-marginal logits for every position of `seq`.

  Position n is replaced by `mask_idx`, and the model's logits for token ids
  4..23 are read at that position. `p` masked copies of the sequence (each
  with a different masked position) are batched per forward pass; p=None
  batches all positions into one pass.

  Returns:
    return_jac=False: (L, 20) float32 array of logits at each masked position.
    return_jac=True: (L, 1, L, 20) float32 array. Entry [n, 0] is the change
      of the logits at every position caused by masking position n (the
      "fast" categorical-Jacobian approximation).
  """
  x, ln = tokenizer([seq], return_tensors='pt')['input_ids'].to(DEVICE), len(seq)
  if p is None: p = ln

  def f(tokens):
    return model(tokens).logits[:, :, 4:24].detach().cpu().numpy()

  # torch.cuda.amp.autocast is deprecated. torch.autocast("cuda") is the
  # replacement; it is only enabled when actually running on a GPU, which
  # avoids the "CUDA is not available" warning on CPU.
  use_amp = DEVICE.type == "cuda"
  with torch.no_grad(), torch.autocast(device_type="cuda", enabled=use_amp):
    logits = np.zeros((ln, 20), dtype=np.float32)
    if return_jac:
      jac = np.zeros((ln, 1, ln, 20), dtype=np.float32)
      fx = f(x)[0]  # logits of the unmasked sequence
    with tqdm.notebook.tqdm(total=ln, bar_format=TQDM_BAR_FORMAT) as pbar:
      for n in range(0, ln, p):
        m = min(n + p, ln)
        # one copy of the sequence per position in [n, m); x already lives on DEVICE
        x_h = torch.tile(torch.clone(x), [m - n, 1])
        for i in range(m - n):
          # assumes token index n+i is residue n+i (no BOS token) — TODO confirm
          x_h[i, n + i] = mask_idx
        fx_h = f(x_h)
        for i in range(m - n):
          logits[n + i] = fx_h[i, n + i]
          if return_jac:
            jac[n + i] = fx_h[i, None] - fx
        pbar.update(m - n)

  return jac if return_jac else logits


def get_categorical_jacobian(seq, layer=None, fast=False):
  """Categorical Jacobian: change of the model output (∂out) per input substitution (∂in).

  For each position n, the input token is replaced by each of the token ids
  4..23, or only by `mask_idx` when fast=True. The resulting output is compared
  with the output for the unmodified sequence. The output is the logits for
  token ids 4..23 when layer is None; otherwise it is
  `hidden_states[layer]`.

  Returns:
    float32 array of shape (L, 20, L, D), or (L, 1, L, D) when fast=True.
    D is the last dimension of the output (20 for logits).
    Entry [n, a, m] = f(seq with position n set to token a)[m] - f(seq)[m].
  """
  x, ln = tokenizer([seq], return_tensors='pt')['input_ids'].to(DEVICE), len(seq)

  # torch.cuda.amp.autocast is deprecated. Use torch.autocast("cuda"), enabled
  # only when running on a GPU.
  use_amp = DEVICE.type == "cuda"
  with torch.no_grad(), torch.autocast(device_type="cuda", enabled=use_amp):
    if layer is None:
      f = lambda x: model(x).logits[:, :, 4:24].detach().cpu().numpy()
    else:
      f = lambda x: model(x, output_hidden_states=True).hidden_states[layer].detach().cpu().numpy()

    fx = f(x)[0]
    fx_h = np.zeros([ln, 1 if fast else 20, ln, fx.shape[-1]], dtype=np.float32)
    # fast: a single row (the mask token); otherwise 20 rows, one per token id
    x = x if fast else torch.tile(x, [20, 1])
    with tqdm.notebook.tqdm(total=ln, bar_format=TQDM_BAR_FORMAT) as pbar:
      for n in range(ln):  # for each position
        x_h = torch.clone(x)

        # fast: mask position n; otherwise row k gets token id 4 + k
        # (assumes token index n is residue n, i.e. no BOS token — TODO confirm)
        x_h[:, n] = mask_idx if fast else torch.arange(4, 24, device=x_h.device)
        fx_h[n] = f(x_h)
        pbar.update(1)

  # note: direction here differs from manuscript
  # positive = good
  # negative = bad
  return fx_h - fx


# Compute Conservation

In [None]:
model_name = os.path.basename(MODEL_NAME)
logits = get_logits(sequence, p=PARALLEL)
# reorder the 20 logit columns into ALPHABET order
logits = logits[:,ALPHABET_map]
np.savetxt(f"output/conservation_logits_{model_name}.txt",logits)
# per-position probability distribution over ALPHABET
pssm = softmax(logits,-1)
df = pssm_to_dataframe(pssm, ALPHABET)

# plot pssm
num_colors = 256  # You can adjust this number
palette = viridis(num_colors)
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
p = figure(title="CONSERVATION",
           x_range=[str(x) for x in range(1,len(sequence)+1)],
           y_range=list(ALPHABET)[::-1],
           width=900, height=400,
           tools=TOOLS, toolbar_location='below',
           tooltips=[('Position', '@Position'), ('Amino Acid', '@{Amino Acid}'), ('Probability', '@Probability')])

r = p.rect(x="Position", y="Amino Acid", width=1, height=1, source=df,
           fill_color=linear_cmap('Probability', palette, low=0, high=1),
           line_color=None)
p.xaxis.visible = False  # Hide the x-axis
show(p)

# Compute Coevolution

In [None]:
#@markdown Set output `layer` and postprocessing options such as to `center`, [`symm`]etrize, remove [`diag`]onal and to perform average product correction `apc`.
#@markdown The `fast` approximation only perturbs the mask token.
fast = False # @param {type:"boolean"}
layer = None # @param ["0","1","2","3","4","5","6","7","8","9","10","11","12","13","14","15","16","17","18","19","20","21","22","23","24","25","26","27","28","29","30","31","32","33","None"] {type:"raw"}
center = True # @param {type:"boolean"}
symm = True # @param {type:"boolean"}
diag = "remove" # @param ["remove", "normalize", "none"]
apc = True # @param {type:"boolean"}
settings = dict(layer=layer,
                sequence=sequence,
                fast=fast)
# For the fast logit Jacobian, get_logits computes the same (L, 1, L, 20)
# differences but batches PARALLEL positions per forward pass.
if fast and layer is None:
  jac = get_logits(sequence, p=PARALLEL, return_jac=True)
else:
  jac = get_categorical_jacobian(sequence, layer=layer, fast=fast)
settings_ = settings.copy()
model_name = os.path.basename(MODEL_NAME)
con = jac_to_con(jac, symm=symm, center=center, diag=diag, apc=apc)

np.savetxt(f"output/coevolution_{model_name}.txt",con["contacts"])
# Save the upper-triangle (i < j) token-pair blocks only for the logit Jacobian
# (layer=None). That is the ((L*L-L)/2, A, A) tensor described in
# output/README.txt; for a hidden layer the last axis is the hidden size, not A.
if layer is None:
  i,j = np.triu_indices(len(sequence),1)
  np.save(f"output/jac_{model_name}.npy",con["jac"][i,:,j,:].astype(np.float16))

df = contact_to_dataframe(con["contacts"])
TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
p = figure(title="COEVOLUTION",
          x_range=[str(x) for x in range(1,len(sequence)+1)],
          y_range=[str(x) for x in range(1,len(sequence)+1)][::-1],
          width=800, height=800,
          tools=TOOLS, toolbar_location='below',
          tooltips=[('i', '@i'), ('j', '@j'), ('value', '@value')])

r = p.rect(x="i", y="j", width=1, height=1, source=df,
          fill_color=linear_cmap('value', gray, low=df.value.min(), high=df.value.max()),
          line_color=None)
p.xaxis.visible = False  # Hide the x-axis
p.yaxis.visible = False  # Hide the y-axis
show(p)

In [None]:
#@markdown ##select pair of residues to investigate
#@markdown Note: 1-indexed (first position is 1)

position_i = 402 # @param {type:"integer"}
position_j = 93 # @param {type:"integer"}
if layer is None:
  if fast:
    # fast=True perturbs only the mask token, so con["jac"] has no 20x20
    # substitution block per residue pair to display
    print("this function is only supported when `fast=False`")
  elif not (1 <= position_i <= len(sequence) and 1 <= position_j <= len(sequence)):
    # 0 or negative positions would otherwise silently wrap around to the end
    print(f"positions must be between 1 and {len(sequence)}")
  else:
    i = position_i - 1
    j = position_j - 1
    df = pair_to_dataframe(con["jac"][i,:,j,:], ALPHABET)

    # plot the token-by-token coevolution block for residues i and j
    TOOLS = "hover,save,pan,box_zoom,reset,wheel_zoom"
    p = figure(title=f"coevolution between {position_i} {position_j}",
              x_range=list(ALPHABET),
              y_range=list(ALPHABET)[::-1],
              width=400, height=400,
              tools=TOOLS, toolbar_location='below',
              tooltips=[('aa_i', '@aa_i'), ('aa_j', '@aa_j'), ('value', '@value')])
    p.xaxis.axis_label = f"{sequence[i]}{position_i}"
    p.yaxis.axis_label = f"{sequence[j]}{position_j}"

    r = p.rect(x="aa_i", y="aa_j", width=1, height=1, source=df,
               fill_color=linear_cmap('value', bwr_r, low=-3.0, high=3.0),
               line_color=None, dilate=True)
    show(p)
else:
  print("this function is only supported when `layer=None`")