<a href="https://colab.research.google.com/github/MassimoGregorioTotaro/ESM-zs-prediction/blob/main/ESM_variant_prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **ESM zero-shot variant prediction**
For more details see: [Github](https://github.com/facebookresearch/esm/tree/main/esm), [Paper](https://doi.org/10.1101/2021.07.09.450648)

#### **Tips and Instructions**
- click the little ▶ play icon to the left of each cell below.
- once the install and module setup cells are run, you shouldn't run them again.
- There are three running modes, chosen in the predict cell:
  - 'seq vs seq', which expects a target sequence aligned position by position with the input one (same length); every position where the two differ is evaluated as an amino acid substitution;
  - 'deep mutational scan', which expects a space-separated list of residue numbers (e.g. `1 100`); at each listed position the algorithm tries all 19 other amino acids and estimates a score for every individual substitution;
  - 'aa substitutions', where the mutations to be evaluated are written out explicitly as space-separated wildtype-position-mutant codes (e.g. `M1W D100R`).


In [None]:
#@title install (run only once per session)

# Copyright (c) 2023, Massimo G. Totaro All rights reserved.
# Redistribution and use in source and binary forms, with or without 
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice, 
#    this list of conditions and the following disclaimer in the documentation 
#    and/or other materials provided with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Install the Python dependencies quietly. `!` is an IPython shell escape, so this
# line only works inside a notebook, not as a plain Python script.
!pip install -q biopython fair-esm torch

# Mount Google Drive at /content/drive (Colab prompts for authorization).
# NOTE(review): nothing visible in this notebook reads from or writes to the
# mounted drive — confirm the mount is still needed.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
#@title model setup and training (run only once per session; modify only if you know what you're doing)
import string

from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer
import pandas as pd
from tqdm import tqdm
from typing import List, Tuple
import torch
import warnings
# Suppress every Python warning for the rest of the session, presumably to keep the
# notebook output clean.
# NOTE(review): this also hides deprecation and runtime warnings raised by pandas,
# torch and esm, which can mask real problems.
warnings.filterwarnings('ignore')

def remove_insertions(sequence: str) -> str:
    """Return *sequence* with insertion characters stripped out.

    Every lowercase ASCII letter and every '.' or '*' character is deleted;
    all remaining characters keep their original order. Needed to load
    aligned sequences in an MSA.
    """
    # The third argument of str.maketrans lists characters to delete outright.
    drop_table = str.maketrans("", "", string.ascii_lowercase + ".*")
    return sequence.translate(drop_table)


def label_row(row, sequence, token_probs, alphabet, offset_idx):
    """Score one substitution from a precomputed table of log-probabilities.

    Args:
        row: mutation code "<wt><position><mt>", e.g. "M1W"; the position is
            numbered so that the first character of ``sequence`` is ``offset_idx``.
        sequence: wildtype sequence the mutation refers to.
        token_probs: log-probabilities indexed as ``[0, token_position, vocab_idx]``.
            Token position 0 is assumed to be the BOS token, so residue ``i`` of
            ``sequence`` is read from token position ``i + 1``.
        alphabet: object providing ``get_idx(char)`` -> vocabulary index.
        offset_idx: residue number of the first character of ``sequence``.

    Returns:
        log p(mt) - log p(wt) at the mutated position, as a Python scalar
        (via ``Tensor.item()``).

    Raises:
        ValueError: if the position lies outside ``sequence`` or the listed
            wildtype does not match the sequence.
    """
    wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1]
    # A negative idx would silently index from the end of the sequence (and read the
    # BOS slot of token_probs), so reject out-of-range positions explicitly.
    if not 0 <= idx < len(sequence):
        raise ValueError(f"Position {idx + offset_idx} in {row!r} is outside the sequence")
    # Raise instead of assert so the check survives `python -O`.
    if sequence[idx] != wt:
        raise ValueError("The listed wildtype does not match the provided sequence")

    wt_encoded, mt_encoded = alphabet.get_idx(wt), alphabet.get_idx(mt)

    # add 1 for BOS
    score = token_probs[0, 1 + idx, mt_encoded] - token_probs[0, 1 + idx, wt_encoded]
    return score.item()


def compute_pppl(row, sequence, model, alphabet, offset_idx):
    """Pseudo-log-likelihood of the mutant obtained by applying ``row`` to ``sequence``.

    Every residue of the mutant sequence is masked in turn, and the model's
    log-probability of the true residue at the masked position is summed.

    Args:
        row: mutation code "<wt><position><mt>", e.g. "M1W"; the position is
            numbered so that the first character of ``sequence`` is ``offset_idx``.
        sequence: wildtype sequence the mutation refers to.
        model: callable returning a dict with a ``"logits"`` tensor indexed as
            ``[0, token_position, vocab_idx]``; its parameters determine the device.
        alphabet: provides ``get_batch_converter()``, ``get_idx(char)`` and ``mask_idx``.
        offset_idx: residue number of the first character of ``sequence``.

    Returns:
        Sum of the per-residue masked log-probabilities (a Python float).

    Raises:
        ValueError: if the position lies outside ``sequence`` or the listed
            wildtype does not match the sequence.
    """
    wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1]
    # A negative idx would silently wrap to the end of the sequence.
    if not 0 <= idx < len(sequence):
        raise ValueError(f"Position {idx + offset_idx} in {row!r} is outside the sequence")
    if sequence[idx] != wt:
        raise ValueError("The listed wildtype does not match the provided sequence")

    mutant = sequence[:idx] + mt + sequence[idx + 1:]

    batch_converter = alphabet.get_batch_converter()
    _, _, batch_tokens = batch_converter([("protein1", mutant)])
    # Follow the model's device rather than assuming CUDA, so CPU-only runtimes work.
    batch_tokens = batch_tokens.to(next(model.parameters()).device)

    log_probs = []
    # Token 0 is BOS, so residue j of `mutant` sits at token position j + 1.
    # Score every residue and compare each masked token with its own residue.
    for pos in range(1, len(mutant) + 1):
        masked = batch_tokens.clone()
        masked[0, pos] = alphabet.mask_idx
        with torch.no_grad():
            token_probs = torch.log_softmax(model(masked)["logits"], dim=-1)
        log_probs.append(token_probs[0, pos, alphabet.get_idx(mutant[pos - 1])].item())
    return sum(log_probs)

model_name = 'esm2_t33_650M_UR50D' #@param ['esm1v_t33_650M_UR90S_1', 'esm1v_t33_650M_UR90S_2', 'esm1v_t33_650M_UR90S_3', 'esm1v_t33_650M_UR90S_4', 'esm1v_t33_650M_UR90S_5', 'esm2_t33_650M_UR50D', 'esm2_t36_3B_UR50D', 'esm2_t48_15B_UR50D']

# Load the model through esm.pretrained (torch.hub.load no longer works for this).
# getattr looks the loader up by name instead of eval()-ing a built string: same
# result for every listed model, but no arbitrary code runs if the form value is edited.
model, alphabet = getattr(pretrained, model_name)()
model.eval()  # inference mode (torch nn.Module.eval)
scoring_strategy = "wt-marginals" #@param ["wt-marginals", "pseudo-ppl", "masked-marginals"]
# Use the GPU when present; the predict cell follows whatever device the model is on.
if torch.cuda.is_available():
    model = model.cuda()
    print("Transferred model to GPU")

In [None]:
#@title predict

mode = 'deep mutational scan'  #@param ['seq vs seq', 'deep mutational scan', 'aa substitutions']
sequence = "" #@param {type:"string"}
target = "" #@param {type:"string"}
offset_idx = 1 #\@param {type:"integer"}
substitutions = list()

if not sequence:
    raise ValueError("Please provide the wildtype sequence")

# --- 1. Build the substitution codes ("<wt><resi><mt>") for the chosen mode ---
if mode == 'seq vs seq':
    # zip() would silently ignore the tail of the longer sequence.
    if len(sequence) != len(target):
        raise ValueError(
            f"'seq vs seq' needs sequences of equal length "
            f"(got {len(sequence)} and {len(target)})"
        )
    for resi, (src, trg) in enumerate(zip(sequence, target), offset_idx):
        if src != trg:
            substitutions.append(f"{src}{resi}{trg}")
elif mode == 'deep mutational scan':
    # dict.fromkeys drops repeated residue numbers while keeping the input order.
    for resi in dict.fromkeys(map(int, target.split())):
        pos = resi - offset_idx
        # A negative pos would silently wrap to the end of the sequence.
        if not 0 <= pos < len(sequence):
            raise ValueError(f"Residue {resi} is outside the sequence")
        src = sequence[pos]
        for trg in "ACDEFGHIKLMNPQRSTVWY".replace(src, ''):
            substitutions.append(f"{src}{resi}{trg}")
elif mode == 'aa substitutions':
    substitutions = target.split()
else:
    raise RuntimeError("Unrecognised running mode")

if not substitutions:
    raise ValueError("No substitutions to evaluate: check the sequence/target inputs")

# --- 2. Score every substitution with the selected strategy ---
df = pd.DataFrame(substitutions, columns=['\0'])
mutation_col = df.columns[0]

batch_converter = alphabet.get_batch_converter()

data = [("protein1", sequence),]

batch_labels, batch_strs, batch_tokens = batch_converter(data)
# Follow the model's device instead of assuming CUDA: the setup cell only moves the
# model to the GPU when one is available.
batch_tokens = batch_tokens.to(next(model.parameters()).device)

if scoring_strategy == "wt-marginals":
    with torch.no_grad():
        token_probs = torch.log_softmax(model(batch_tokens)["logits"], dim=-1)
    df[model_name] = df[mutation_col].map(
        lambda mutation: label_row(mutation, sequence, token_probs, alphabet, offset_idx)
    )
elif scoring_strategy == "masked-marginals":
    all_token_probs = []
    for i in tqdm(range(batch_tokens.size(1))):
        batch_tokens_masked = batch_tokens.clone()
        batch_tokens_masked[0, i] = alphabet.mask_idx
        with torch.no_grad():
            token_probs = torch.log_softmax(model(batch_tokens_masked)["logits"], dim=-1)
        all_token_probs.append(token_probs[:, i])  # vocab size
    token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)
    df[model_name] = df[mutation_col].map(
        lambda mutation: label_row(mutation, sequence, token_probs, alphabet, offset_idx)
    )
elif scoring_strategy == "pseudo-ppl":
    tqdm.pandas()
    df[model_name] = df[mutation_col].progress_apply(
        lambda mutation: compute_pppl(mutation, sequence, model, alphabet, offset_idx)
    )
else:
    raise RuntimeError("Unrecognised scoring strategy")

# --- 3. Arrange the results for display ---
if mode == 'aa substitutions':
    df = df.sort_values(model_name, ascending=False)
elif mode == 'deep mutational scan':
    # One (mutation, score) column pair per residue, residues left to right in
    # ascending order, each pair sorted from highest to lowest score.
    residue_numbers = df[mutation_col].str.extract(r'(\d+)', expand=False).astype(int)
    per_residue = [
        group.sort_values(model_name, ascending=False).reset_index(drop=True)
        for _, group in df.groupby(residue_numbers, sort=True)
    ]
    df = pd.concat(per_residue, axis=1)
    df = df.set_axis(range(df.shape[1]), axis='columns')

# Styler.hide_index()/hide_columns() were deprecated in pandas 1.4 and removed in 2.0;
# Styler.hide(axis=...) replaces them. Keep the old calls only as a fallback.
styler = df.style
if hasattr(styler, "hide"):
    styler = styler.hide(axis="index").hide(axis="columns")
else:
    styler = styler.hide_index().hide_columns()
styler.background_gradient(cmap="RdYlGn", vmax=8, vmin=-8)