In [1]:
# Install required dependencies
%pip install transformers torch matplotlib numpy ipywidgets

Note: you may need to restart the kernel to use updated packages.


In [2]:
from transformers import AutoTokenizer, EsmForMaskedLM
import torch
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display

In [3]:
# Automatically re-import edited Python modules before each cell runs
# (this does NOT restart the kernel). It only affects code imported from
# module files; functions defined directly in notebook cells are unaffected.
%load_ext autoreload
%autoreload 2

In [4]:


# Loaded (tokenizer, model) pairs keyed by checkpoint name. Loading weights is
# the slowest step and the interactive sliders call generate_heatmap on every
# change, so each checkpoint is loaded at most once per kernel session.
ESM_MODEL_CACHE = {}


def _load_esm(model_name):
    """Return the cached ``(tokenizer, model)`` pair for ``model_name``, loading it on first use."""
    if model_name not in ESM_MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = EsmForMaskedLM.from_pretrained(model_name)
        ESM_MODEL_CACHE[model_name] = (tokenizer, model)
    return ESM_MODEL_CACHE[model_name]


def _compute_llr_matrix(input_ids, tokenizer, model, start_pos, end_pos, amino_acids):
    """Compute masked-marginal LLRs: one row per amino acid, one column per position.

    ``start_pos``/``end_pos`` index ``input_ids`` directly. This assumes the
    tokenizer prepends exactly one special token and emits one token per residue,
    which is what the ``- 2`` length adjustment in ``generate_heatmap`` implies.
    """
    # Token ids of the candidate substitutions are loop-invariant; look them up once.
    aa_token_ids = torch.tensor([tokenizer.convert_tokens_to_ids(aa) for aa in amino_acids])
    heatmap = np.zeros((len(amino_acids), end_pos - start_pos + 1))

    for column, position in enumerate(range(start_pos, end_pos + 1)):
        # Hide the residue at `position` so the model must predict it from context.
        masked_input_ids = input_ids.clone()
        masked_input_ids[0, position] = tokenizer.mask_token_id

        with torch.no_grad():
            logits = model(masked_input_ids).logits
            # log_softmax is numerically stable; log(softmax(x)) yields -inf
            # whenever a probability underflows to 0.
            log_probs = torch.nn.functional.log_softmax(logits[0, position], dim=-1)
            wt_token_id = input_ids[0, position].item()
            heatmap[:, column] = (log_probs[aa_token_ids] - log_probs[wt_token_id]).numpy()

    return heatmap


def _plot_llr_heatmap(heatmap, residues, amino_acids):
    """Draw ``heatmap`` with wild-type residues on the x-axis and substitutions on the y-axis."""
    fig, ax = plt.subplots(figsize=(15, 5))
    image = ax.imshow(heatmap, cmap="viridis", aspect="auto")
    ax.set_xticks(range(len(residues)))
    ax.set_xticklabels(residues)
    ax.set_yticks(range(len(amino_acids)))
    ax.set_yticklabels(amino_acids)
    ax.set_xlabel("Position in Protein Sequence")
    ax.set_ylabel("Amino Acid Mutations")
    ax.set_title("Predicted Effects of Mutations on Protein Sequence (LLR)")
    fig.colorbar(image, ax=ax, label="Log Likelihood Ratio (LLR)")
    plt.show()


def generate_heatmap(protein_sequence, start_pos=1, end_pos=None,
                     model_name="facebook/esm2_t6_8M_UR50D", plot=True):
    """Score every single-residue substitution in a window of a protein with ESM.

    For each position in ``start_pos..end_pos`` (1-based, inclusive), that
    residue is masked and the model's log-probabilities at the masked slot are
    read off. The score for substituting amino acid ``a`` is
    ``log P(a) - log P(wild type)``: the log-likelihood ratio (LLR).

    Args:
        protein_sequence: Amino-acid sequence, one letter per residue.
        start_pos: First residue to score (1-based).
        end_pos: Last residue to score (1-based, inclusive). Defaults to the
            last residue of the tokenized sequence.
        model_name: Hugging Face checkpoint to load. It is cached after the first call.
        plot: When True (the default), draw the heatmap with matplotlib.

    Returns:
        ``np.ndarray`` of shape ``(20, end_pos - start_pos + 1)``. Rows follow
        ``"ACDEFGHIKLMNPQRSTVWY"`` and columns follow sequence positions. Note that
        if this call is the last expression in a cell, Jupyter will display the array.

    Raises:
        ValueError: If the window is empty or falls outside the sequence.
    """
    tokenizer, model = _load_esm(model_name)

    input_ids = tokenizer.encode(protein_sequence, return_tensors="pt")
    sequence_length = input_ids.shape[1] - 2  # Excluding the special tokens

    if end_pos is None:
        end_pos = sequence_length
    # Out-of-range positions would otherwise mask a special token or raise an opaque IndexError.
    if not 1 <= start_pos <= end_pos <= sequence_length:
        raise ValueError(
            f"Need 1 <= start_pos <= end_pos <= {sequence_length}; "
            f"got start_pos={start_pos}, end_pos={end_pos}"
        )

    amino_acids = list("ACDEFGHIKLMNPQRSTVWY")
    heatmap = _compute_llr_matrix(input_ids, tokenizer, model, start_pos, end_pos, amino_acids)

    if plot:
        _plot_llr_heatmap(heatmap, list(protein_sequence[start_pos - 1:end_pos]), amino_acids)
    return heatmap

def interactive_heatmap(protein_sequence):
    """Show Start/End sliders that redraw the mutation-effect heatmap for a sub-range.

    Each committed slider change calls ``generate_heatmap`` for residues
    ``start..end`` (1-based, inclusive). A redraw runs one model forward pass
    per residue in the window. The sliders therefore use
    ``continuous_update=False``, so they fire only when released rather than on
    every intermediate drag value.

    Args:
        protein_sequence: Amino-acid sequence to visualise, one letter per residue.

    Raises:
        ValueError: If ``protein_sequence`` is empty, because the sliders would have no valid range.
    """
    if not protein_sequence:
        raise ValueError("protein_sequence must be a non-empty string")
    length = len(protein_sequence)

    # Define interactive widgets
    start_slider = widgets.IntSlider(value=1, min=1, max=length, step=1, description='Start:',
                                     continuous_update=False)
    end_slider = widgets.IntSlider(value=length, min=1, max=length, step=1, description='End:',
                                   continuous_update=False)

    ui = widgets.HBox([start_slider, end_slider])

    def update_heatmap(start, end):
        # Report an inverted window instead of silently leaving the output area blank.
        if start > end:
            print(f"Start ({start}) must not be greater than End ({end}).")
            return
        generate_heatmap(protein_sequence, start, end)

    out = widgets.interactive_output(update_heatmap, {'start': start_slider, 'end': end_slider})

    # Display the interactive widgets
    display(ui, out)


In [5]:
# Example usage: two demo sequences are defined, and only `app` is handed to the
# slider widget below.
# NOTE(review): `protein_sequence` is not used in this cell; drop it if no later cell needs it.
protein_sequence = (
    "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKIL"
    "ATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"
)
app = "KLVFFAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHHGVVEV"

interactive_heatmap(app)

HBox(children=(IntSlider(value=1, description='Start:', max=52, min=1), IntSlider(value=52, description='End:'…

Output()