# Mutation Effects on Proteins with ESM-2

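$$
S_{\text{masked marginal}} = \sum_{i \in M} \left[ \log p(x_i = x_{i}^{\text{mt}}|x_{-M}) - \log p(x_i = x_{i}^{\text{wt}}|x_{-M}) \right]
$$

Here $M$ is the set of mutated positions, $x_{-M}$ is the sequence with every position in $M$ replaced by the mask token, and $x_{i}^{\text{wt}}$ and $x_{i}^{\text{mt}}$ are the wild-type and mutant residues at position $i$.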

In [104]:
import torch
from transformers import AutoTokenizer, EsmForMaskedLM
from typing import List, Tuple

def masked_marginal_scoring(
    tokenizer: "AutoTokenizer",
    model: "EsmForMaskedLM",
    sequence: str,
    mutations: List[Tuple[int, str, str]]
) -> float:
    """Score a set of substitutions with the masked-marginal heuristic.

    Every mutated position is replaced by the tokenizer's mask token, the
    model is run once on the masked sequence, and the score is the sum over
    all mutations of ``log p(mt) - log p(wt)`` read at that mutation's own
    masked position.

    Args:
        tokenizer: Object exposing ``mask_token``, ``mask_token_id`` and
            ``convert_tokens_to_ids``; it is called as
            ``tokenizer(text, return_tensors="pt")``.
        model: Callable invoked as ``model(**inputs)`` whose result has a
            ``logits`` attribute indexed as ``[batch, token, vocab]``.
        sequence: Wild-type sequence, one character per residue.
        mutations: ``(pos, wt, mt)`` triples where ``pos`` is a 0-based index
            into ``sequence``. The triples may be given in any order.

    Returns:
        The summed log-probability difference, or ``0.0`` when ``mutations``
        is empty.

    Raises:
        IndexError: If a position lies outside ``sequence``.
        ValueError: If ``wt`` does not match the residue at ``pos``, or if
            the same position appears in more than one mutation.
        RuntimeError: If the tokenized input does not contain exactly one
            mask token per mutated position.
    """
    if not mutations:
        return 0.0

    residues = list(sequence)

    # Validate everything before touching the sequence so that error
    # messages always report the real residue, never an inserted mask token.
    seen_positions = set()
    for pos, wt, mt in mutations:
        if not 0 <= pos < len(residues):
            raise IndexError(
                f"Position {pos} is outside the sequence (length {len(residues)})."
            )
        if pos in seen_positions:
            raise ValueError(f"Position {pos} appears in more than one mutation.")
        if residues[pos] != wt:
            raise ValueError(f"The amino acid at position {pos} is {residues[pos]}, not {wt}.")
        seen_positions.add(pos)

    for pos in seen_positions:
        residues[pos] = tokenizer.mask_token

    # Tokenize the masked sequence and run the model once.
    inputs = tokenizer("".join(residues), return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # torch.where yields the mask token indices in ascending order.
    mask_indices = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0]
    if len(mask_indices) != len(seen_positions):
        raise RuntimeError(
            f"Expected {len(seen_positions)} mask tokens in the tokenized input, "
            f"found {len(mask_indices)}."
        )

    # The i-th mask token belongs to the i-th smallest mutated position, so
    # look rows up by position instead of relying on the caller's ordering.
    row_for_position = {pos: row for row, pos in enumerate(sorted(seen_positions))}

    # log_softmax avoids the underflow to -inf that log(softmax(x)) can hit.
    mask_log_probs = torch.nn.functional.log_softmax(logits[0, mask_indices], dim=-1)

    score = 0.0
    for pos, wt, mt in mutations:
        log_probs = mask_log_probs[row_for_position[pos]]
        wt_token_id = tokenizer.convert_tokens_to_ids(wt)
        mt_token_id = tokenizer.convert_tokens_to_ids(mt)
        score += log_probs[mt_token_id].item() - log_probs[wt_token_id].item()

    return score

# Test the function
# Load the ESM-2 checkpoint once; both objects are passed explicitly to
# masked_marginal_scoring below.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
sequence = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
# Each mutation is (position, wild-type residue, mutant residue); the position
# is used directly as a 0-based index into `sequence`.
# Earlier mutation set, kept for reference:
# mutations = [(67, 'V', 'R'), (82, 'R', 'D'), (83, 'E', 'A')]
mutations = [(68, 'L', 'R'), (83, 'E', 'D'), (84, 'K', 'A')]

score = masked_marginal_scoring(tokenizer, model, sequence, mutations)
# Bare expression: shown as the notebook cell's output.
score


-0.9924547672271729

In [75]:
from transformers import AutoTokenizer, EsmForMaskedLM
import torch

# Load the ESM-2 tokenizer and masked-LM weights. mask_protein_sequence below
# reads `tokenizer` and `model` as module-level globals.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")

def mask_protein_sequence(protein_sequence: str, mask_positions: list):
    """Return the model's logits at the masked token positions.

    Relies on the module-level ``tokenizer`` and ``model``.

    Args:
        protein_sequence: Sequence passed to the tokenizer.
        mask_positions: Indices into the tokenized ``input_ids`` row (not
            into the raw string) that are overwritten with the mask token id.

    Returns:
        The slice ``logits[0, mask_positions]``: one row of vocabulary logits
        per entry of ``mask_positions``.
    """
    encoded = tokenizer(protein_sequence, return_tensors="pt")

    # Row 0 is a view of the input ids, so these writes modify `encoded`.
    token_row = encoded["input_ids"][0]
    mask_id = tokenizer.mask_token_id
    for position in mask_positions:
        token_row[position] = mask_id

    # Forward pass without gradient tracking.
    with torch.no_grad():
        model_output = model(**encoded)

    return model_output.logits[0, mask_positions]

# Test the function with a protein sequence and mask positions
protein_sequence = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
# These are indices into the tokenized input_ids (the function writes the mask
# id straight into input_ids[0]), not indices into the raw string.
mask_positions = [67, 82, 83]
# Bare call: the returned logits tensor is shown as the cell output.
mask_protein_sequence(protein_sequence, mask_positions)


tensor([[-1.3275e+01, -2.3676e+01, -1.2741e+01, -2.3688e+01,  9.1518e-01,
          6.5759e-01,  1.0361e+00,  9.8178e-01,  2.6509e-01,  4.4436e-01,
          2.3581e-01,  2.7433e-01,  8.4145e-01, -1.0545e-02,  1.8591e-02,
          1.0111e+00, -1.5097e-02, -2.5256e-01,  1.6846e-01,  9.4608e-02,
         -4.9925e-01, -6.6282e-01, -1.0709e+00, -1.3224e+00, -8.5838e+00,
         -1.1722e+01, -1.2004e+01, -1.2563e+01, -1.5926e+01, -1.6301e+01,
         -1.6292e+01, -1.6336e+01, -2.3675e+01],
        [-1.1887e+01, -2.1386e+01, -1.2544e+01, -2.1378e+01,  8.9595e-01,
          5.8057e-01,  1.3578e+00,  5.4902e-01,  1.0946e+00,  1.2902e+00,
          4.9994e-01,  6.9993e-01,  3.3177e-01,  1.4841e+00,  1.1987e+00,
          9.9457e-01,  2.5640e-01,  5.9584e-01,  1.7739e-01, -2.5779e-02,
         -7.8020e-01, -4.5620e-01, -1.1412e+00, -1.8741e+00, -6.9884e+00,
         -1.1527e+01, -1.1931e+01, -1.2505e+01, -1.5916e+01, -1.6279e+01,
         -1.6230e+01, -1.6286e+01, -2.1387e+01],
        [-1.22

In [79]:
from transformers import AutoTokenizer, EsmForMaskedLM
import torch

# Load the ESM-2 tokenizer and masked-LM weights. mask_protein_sequence below
# reads `tokenizer` and `model` as module-level globals.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")

def mask_protein_sequence(protein_sequence: str, mask_positions: list):
    """Return per-token logits at the masked positions as dictionaries.

    Relies on the module-level ``tokenizer`` and ``model``.

    Args:
        protein_sequence: Sequence passed to the tokenizer.
        mask_positions: Indices into the tokenized ``input_ids`` row (not
            into the raw string) that are overwritten with the mask token id.

    Returns:
        A list with one dict per entry of ``mask_positions``, mapping each
        vocabulary token string to its logit at that position.
    """
    encoded = tokenizer(protein_sequence, return_tensors="pt")

    # Overwrite the chosen token slots with the mask id, in place.
    for position in mask_positions:
        encoded["input_ids"][0][position] = tokenizer.mask_token_id

    with torch.no_grad():
        all_logits = model(**encoded).logits

    masked_rows = all_logits[0, mask_positions]

    # get_vocab() maps token -> id; invert it so ids can be turned into tokens.
    id_to_token = {token_id: token for token, token_id in tokenizer.get_vocab().items()}
    vocab_size = len(id_to_token)

    result = []
    for row in masked_rows:
        values = row.tolist()
        result.append({id_to_token[i]: values[i] for i in range(vocab_size)})
    return result

# Test the function with a protein sequence and mask positions
protein_sequence = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
# These are indices into the tokenized input_ids (the function writes the mask
# id straight into input_ids[0]), not indices into the raw string.
mask_positions = [67, 82, 83]
# Bare call: the returned list of {token: logit} dicts is shown as the cell output.
mask_protein_sequence(protein_sequence, mask_positions)


[{'<cls>': -13.275341033935547,
  '<pad>': -23.676189422607422,
  '<eos>': -12.74148941040039,
  '<unk>': -23.687597274780273,
  'L': 0.9151788949966431,
  'A': 0.6575865745544434,
  'G': 1.0361005067825317,
  'V': 0.9817789793014526,
  'S': 0.2650868892669678,
  'E': 0.4443625509738922,
  'R': 0.23581159114837646,
  'T': 0.2743259072303772,
  'I': 0.8414493799209595,
  'D': -0.01054537296295166,
  'P': 0.018591076135635376,
  'K': 1.0111300945281982,
  'Q': -0.015097111463546753,
  'N': -0.25255557894706726,
  'F': 0.16846412420272827,
  'Y': 0.0946081280708313,
  'M': -0.4992474913597107,
  'H': -0.6628150939941406,
  'W': -1.070947289466858,
  'C': -1.322406530380249,
  'X': -8.58376693725586,
  'B': -11.722338676452637,
  'U': -12.004006385803223,
  'Z': -12.56307315826416,
  'O': -15.926107406616211,
  '.': -16.30080223083496,
  '-': -16.292280197143555,
  '<null_1>': -16.336050033569336,
  '<mask>': -23.67510223388672},
 {'<cls>': -11.88650894165039,
  '<pad>': -21.3856201171875,

In [82]:
from transformers import AutoTokenizer, EsmForMaskedLM
import torch

# Load the ESM-2 tokenizer and masked-LM weights. mask_protein_sequence below
# reads `tokenizer` and `model` as module-level globals.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")

def mask_protein_sequence(protein_sequence: str, mask_positions: list):
    """Return per-token probabilities at the masked positions as dictionaries.

    Relies on the module-level ``tokenizer`` and ``model``.

    Args:
        protein_sequence: Sequence passed to the tokenizer.
        mask_positions: Indices into the tokenized ``input_ids`` row (not
            into the raw string) that are overwritten with the mask token id.

    Returns:
        A list with one dict per entry of ``mask_positions``, mapping each
        vocabulary token string to its softmax probability at that position.
    """
    encoded = tokenizer(protein_sequence, return_tensors="pt")

    # Overwrite the chosen token slots with the mask id, in place.
    for position in mask_positions:
        encoded["input_ids"][0][position] = tokenizer.mask_token_id

    with torch.no_grad():
        all_logits = model(**encoded).logits

    # Softmax over the vocabulary axis turns each masked row into a distribution.
    masked_probs = torch.nn.functional.softmax(all_logits[0, mask_positions], dim=-1)

    # get_vocab() maps token -> id; invert it so ids can be turned into tokens.
    id_to_token = {token_id: token for token, token_id in tokenizer.get_vocab().items()}

    return [
        {id_to_token[token_id]: value for token_id, value in enumerate(row)}
        for row in masked_probs.tolist()
    ]

# Test the function with a protein sequence and mask positions
protein_sequence = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
# These are indices into the tokenized input_ids (the function writes the mask
# id straight into input_ids[0]), not indices into the raw string.
mask_positions = [67, 82, 83]
# Bare call: the returned list of {token: probability} dicts is shown as the cell output.
mask_protein_sequence(protein_sequence, mask_positions)


[{'<cls>': 6.083488557351302e-08,
  '<pad>': 1.8497865999361762e-12,
  '<eos>': 1.0375320869115967e-07,
  '<unk>': 1.8288044689729333e-12,
  'L': 0.08851505070924759,
  'A': 0.06841419637203217,
  'G': 0.09989246726036072,
  'V': 0.09461090713739395,
  'S': 0.04620466008782387,
  'E': 0.055276963859796524,
  'R': 0.04487161338329315,
  'T': 0.0466335266828537,
  'I': 0.08222366869449615,
  'D': 0.03507358580827713,
  'P': 0.036110538989305496,
  'K': 0.09742899239063263,
  'Q': 0.034914303570985794,
  'N': 0.02753445692360401,
  'F': 0.041949138045310974,
  'Y': 0.03896258771419525,
  'M': 0.02151491306722164,
  'H': 0.018268506973981857,
  'W': 0.01214656513184309,
  'C': 0.009445960633456707,
  'X': 6.632502390857553e-06,
  'B': 2.874836013688764e-07,
  'U': 2.169133779261756e-07,
  'Z': 1.2401856963606406e-07,
  'O': 4.294765876267093e-09,
  '.': 2.9526474598640107e-09,
  '-': 2.9779174681721088e-09,
  '<null_1>': 2.850385927288812e-09,
  '<mask>': 1.8517988791683093e-12},
 {'<cls>'

In [84]:
from transformers import AutoTokenizer, EsmForMaskedLM
import torch

# Load the ESM-2 tokenizer and masked-LM weights. mask_protein_sequence below
# reads `tokenizer` and `model` as module-level globals.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")

def mask_protein_sequence(protein_sequence: str, mask_positions: list):
    """Return per-token log-probabilities at the masked positions as dictionaries.

    Relies on the module-level ``tokenizer`` and ``model``.

    Args:
        protein_sequence: Sequence passed to the tokenizer.
        mask_positions: Indices into the tokenized ``input_ids`` row (not
            into the raw string) that are overwritten with the mask token id.

    Returns:
        A list with one dict per entry of ``mask_positions``, mapping each
        vocabulary token string to its log-probability at that position.
    """
    # Tokenize the protein sequence
    inputs = tokenizer(protein_sequence, return_tensors="pt")

    # Mask the specified positions (in-place write into the input ids)
    for pos in mask_positions:
        inputs["input_ids"][0][pos] = tokenizer.mask_token_id

    # Get the logits
    with torch.no_grad():
        logits = model(**inputs).logits

    # Get the logits for the masked positions
    mask_logits = logits[0, mask_positions]

    # log_softmax computes log-probabilities directly; taking torch.log of a
    # softmax output can underflow to -inf for very unlikely tokens.
    log_probabilities = torch.nn.functional.log_softmax(mask_logits, dim=-1)

    # get_vocab() maps token -> id; invert it so ids can be turned into tokens.
    id_to_token = {token_id: token for token, token_id in tokenizer.get_vocab().items()}

    return [
        {id_to_token[token_id]: value for token_id, value in enumerate(row)}
        for row in log_probabilities.tolist()
    ]

# Test the function with a protein sequence and mask positions
protein_sequence = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
# These are indices into the tokenized input_ids (the function writes the mask
# id straight into input_ids[0]), not indices into the raw string.
mask_positions = [67, 82, 83]
# Bare call: the returned list of {token: log-probability} dicts is shown as the cell output.
mask_protein_sequence(protein_sequence, mask_positions)


[{'<cls>': -16.615102767944336,
  '<pad>': -27.01595115661621,
  '<eos>': -16.08125114440918,
  '<unk>': -27.027359008789062,
  'L': -2.4245827198028564,
  'A': -2.6821749210357666,
  'G': -2.3036611080169678,
  'V': -2.3579823970794678,
  'S': -3.074674606323242,
  'E': -2.8953990936279297,
  'R': -3.103949785232544,
  'T': -3.0654356479644775,
  'I': -2.498311996459961,
  'D': -3.350306987762451,
  'P': -3.3211705684661865,
  'K': -2.3286314010620117,
  'Q': -3.354858636856079,
  'N': -3.5923171043395996,
  'F': -3.171297311782837,
  'Y': -3.2451534271240234,
  'M': -3.8390090465545654,
  'H': -4.00257682800293,
  'W': -4.410708904266357,
  'C': -4.662168025970459,
  'X': -11.923528671264648,
  'B': -15.062100410461426,
  'U': -15.343768119812012,
  'Z': -15.90283489227295,
  'O': -19.265869140625,
  '.': -19.64056396484375,
  '-': -19.632041931152344,
  '<null_1>': -19.675811767578125,
  '<mask>': -27.014863967895508},
 {'<cls>': -15.54505729675293,
  '<pad>': -25.044166564941406,
 

In [107]:
# You can add the following function in your local environment:

def compare_log_probs(protein_sequence: str, mask_positions: list, vocab_pairs: list):
    """Return ``log p(mt) - log p(wt)`` for each masked position.

    Calls ``mask_protein_sequence`` and treats each returned element as a
    mapping from token string to log-probability.

    Args:
        protein_sequence: Sequence forwarded to ``mask_protein_sequence``.
        mask_positions: Positions forwarded to ``mask_protein_sequence``.
        vocab_pairs: ``(wt, mt)`` token pairs; the i-th pair is evaluated
            against the i-th entry returned for ``mask_positions``.

    Returns:
        A list with one difference per entry of ``vocab_pairs``.
    """
    per_position = mask_protein_sequence(protein_sequence, mask_positions)

    # Index by position so a surplus pair raises IndexError rather than being dropped.
    return [
        per_position[index][mutant] - per_position[index][wild_type]
        for index, (wild_type, mutant) in enumerate(vocab_pairs)
    ]

# Test the function with a protein sequence, mask positions, and vocabulary pairs
protein_sequence = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
# Indices into the tokenized input_ids. NOTE(review): these look like the
# 0-based residues 68, 83, 84 shifted by one for a leading special token;
# confirm against the tokenizer.
mask_positions = [69, 84, 85]
# (wild-type, mutant) pair for each masked position, in the same order.
vocab_pairs = [('L', 'R'), ('E', 'D'), ('K', 'A')]

print(compare_log_probs(protein_sequence, mask_positions, vocab_pairs))


[-0.8849828243255615, -0.1657564640045166, 0.05828452110290527]


In [108]:
# Summing the per-position differences gives the masked-marginal score; the
# value shown below matches the masked_marginal_scoring output earlier on.
sum(compare_log_probs(protein_sequence, mask_positions, vocab_pairs))

-0.9924547672271729