# Masked Protein Modeling and Visualization

## Get Top $n$ Predictions for a Masked Amino Acid

In [22]:
from transformers import AutoTokenizer, EsmForMaskedLM
import torch

# One checkpoint name drives both the tokenizer and the masked-LM weights
model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = EsmForMaskedLM.from_pretrained(model_name)

# Sequence with a single residue (a V) swapped for the mask token
masked_input = "MA<mask>ESRVTQEEIKKEPEK"
inputs = tokenizer(masked_input, return_tensors="pt")

# Forward pass only; gradients are not needed for inference
with torch.no_grad():
    outputs = model(**inputs)
predictions = outputs.logits

# Locate the mask token inside the tokenized input
mask_token_index = inputs["input_ids"][0].eq(tokenizer.mask_token_id).nonzero(as_tuple=True)

# Ids of the n highest-scoring tokens at the masked position
n = 5
top_5_predicted_token_ids = predictions[0, mask_token_index].topk(n).indices[0]

# Decode each id back to its token, one per line
for token_id in top_5_predicted_token_ids:
    print(tokenizer.decode([token_id]))


E
K
S
D
A


## Get Top $n$ Predictions for a Subset of Masked Amino Acids

In [23]:
from transformers import AutoTokenizer, EsmForMaskedLM
import torch

# Loaded (tokenizer, model) pairs keyed by checkpoint name, so repeated calls
# reuse the weights instead of re-reading them on every invocation.
_ESM_MASKED_LM_CACHE = {}


def _load_esm_masked_lm(model_name):
    """Return a cached ``(tokenizer, model)`` pair for ``model_name``."""
    if model_name not in _ESM_MASKED_LM_CACHE:
        _ESM_MASKED_LM_CACHE[model_name] = (
            AutoTokenizer.from_pretrained(model_name),
            EsmForMaskedLM.from_pretrained(model_name),
        )
    return _ESM_MASKED_LM_CACHE[model_name]


def predict_masked_positions(sequence, positions, n, model_name="facebook/esm2_t6_8M_UR50D"):
    """Predict the top ``n`` tokens for each position, masking one position at a time.

    Args:
        sequence: Amino-acid sequence as a string of one-letter codes.
        positions: 0-based indices into ``sequence``. Each index is masked in
            its own copy of the sequence; the other residues stay visible.
        n: Number of highest-scoring tokens to return per position.
        model_name: Checkpoint to load for the tokenizer and masked-LM model.

    Returns:
        A list with one entry per element of ``positions`` (same order). Each
        entry is a list of ``n`` decoded tokens, highest logit first.

    Raises:
        IndexError: If a position is outside ``0 <= pos < len(sequence)``.
    """
    positions = list(positions)

    # Validate before touching the model: a negative or too-large index would
    # otherwise slice silently and produce a corrupted masked sequence
    # (e.g. pos=-1 appends the *whole* sequence after the mask).
    for pos in positions:
        if not 0 <= pos < len(sequence):
            raise IndexError(
                f"position {pos} is out of range for a sequence of length {len(sequence)}"
            )
    if not positions:
        return []

    tokenizer, model = _load_esm_masked_lm(model_name)

    # Store the top n predictions for each position
    predictions = []

    for pos in positions:
        # Create a new sequence with only this position masked
        masked_sequence = sequence[:pos] + tokenizer.mask_token + sequence[pos + 1:]
        inputs = tokenizer(masked_sequence, return_tensors="pt")

        # Inference only, so skip gradient tracking
        with torch.no_grad():
            logits = model(**inputs).logits

        # Exactly one mask token was inserted, so this selects a single row
        is_mask = inputs["input_ids"][0] == tokenizer.mask_token_id
        mask_logits = logits[0, is_mask]

        # Ids of the n highest-scoring tokens, best first
        top_n_predicted_token_ids = torch.topk(mask_logits, n).indices[0].tolist()

        # Convert the token ids to tokens and store them
        predictions.append(
            [tokenizer.decode([token_id]) for token_id in top_n_predicted_token_ids]
        )

    return predictions


In [24]:
sequence = "MAVESRVTQEEIKKEPEK"
positions = [3, 4, 7]  # 0-based indices: masks the 4th, 5th, and 8th residues (E, S, T)
n = 5  # top 5 predictions

predict_masked_positions(sequence, positions, n)


[['L', 'S', 'V', 'K', 'A'],
 ['K', 'E', 'V', 'L', 'G'],
 ['K', 'E', 'L', 'V', 'S']]

## Get all Sequences Predicted by Previous Cell

In [25]:
import itertools

def get_predicted_sequences(sequence, positions, predictions):
    """Enumerate every sequence obtainable from the per-position predictions.

    Args:
        sequence: The original sequence string.
        positions: Indices in ``sequence`` that were masked.
        predictions: One list of candidate tokens per entry of ``positions``.

    Returns:
        A list of strings, one per element of the Cartesian product of
        ``predictions``, in ``itertools.product`` order.
    """
    def _substitute(choice):
        # Work on a fresh mutable copy so the original string is untouched
        residues = list(sequence)
        for idx, replacement in zip(positions, choice):
            residues[idx] = replacement
        return ''.join(residues)

    return [_substitute(choice) for choice in itertools.product(*predictions)]


In [26]:
sequence = "MAVESRVTQEEIKKEPEK" # original sequence
positions = [3, 4, 7]  # 0-based indices of the residues to mask
n = 2  # top n predictions kept per masked position

# Get the predictions for the masked positions
predictions = predict_masked_positions(sequence, positions, n)

# Generate all possible sequences
sequences = get_predicted_sequences(sequence, positions, predictions)

# Print the sequences (n^m of them, where m = len(positions); here 2^3 = 8)
for seq in sequences:
    print(seq)


MAVLKRVKQEEIKKEPEK
MAVLKRVEQEEIKKEPEK
MAVLERVKQEEIKKEPEK
MAVLERVEQEEIKKEPEK
MAVSKRVKQEEIKKEPEK
MAVSKRVEQEEIKKEPEK
MAVSERVKQEEIKKEPEK
MAVSERVEQEEIKKEPEK


## Predict Folded Atomic Coordinates

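Now we can predict the 3D atomic coordinates of every atom in a protein sequence. The cell below folds the original sequence `MAVESRVTQEEIKKEPEK`; substitute any of the sequences predicted above to fold a variant instead.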

In [27]:
import torch
from transformers import AutoTokenizer, EsmForProteinFolding

model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")

# Fold the original (unmasked) sequence; swap in a predicted variant to compare.
inputs = tokenizer(["MAVESRVTQEEIKKEPEK"], return_tensors="pt", add_special_tokens=False)

# Inference only: without no_grad the forward pass records an autograd graph
# for the whole folding model, which costs memory and is never used.
with torch.no_grad():
    outputs = model(**inputs)
folded_positions = outputs.positions

Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.weight', 'esm.contact_head.regression.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [28]:
# Inspect the predicted coordinates returned by the folding model
folded_positions

tensor([[[[[-1.1769e+01, -5.2021e+00, -2.6651e+01],
           [-1.0810e+01, -6.0191e+00, -2.5914e+01],
           [-1.0552e+01, -5.4380e+00, -2.4527e+01],
           ...,
           [-0.0000e+00, -0.0000e+00, -0.0000e+00],
           [-0.0000e+00, -0.0000e+00, -0.0000e+00],
           [-0.0000e+00, -0.0000e+00, -0.0000e+00]],

          [[-9.7011e+00, -5.2572e+00, -2.3114e+01],
           [-9.8805e+00, -4.7423e+00, -2.1759e+01],
           [-8.5372e+00, -4.4074e+00, -2.1117e+01],
           ...,
           [-0.0000e+00, -0.0000e+00, -0.0000e+00],
           [-0.0000e+00, -0.0000e+00, -0.0000e+00],
           [-0.0000e+00, -0.0000e+00, -0.0000e+00]],

          [[-8.6835e+00, -3.6361e+00, -1.9533e+01],
           [-7.6022e+00, -4.1379e+00, -1.8692e+01],
           [-7.9442e+00, -3.9118e+00, -1.7221e+01],
           ...,
           [-0.0000e+00, -0.0000e+00, -0.0000e+00],
           [-0.0000e+00, -0.0000e+00, -0.0000e+00],
           [-0.0000e+00, -0.0000e+00, -0.0000e+00]],

          

## Visualize the Protein Sequence and Highlight the Masked Amino Acids

Now we can plot the structure folded above. Repeating this for each predicted sequence lets you compare how the topology changes as the masked amino acids are substituted. The atoms belonging to the residues at (0-based) positions `[3, 4, 7]` are highlighted in red, and each atom is labeled with the one-letter code of the amino acid it belongs to.

In [29]:
import torch
import plotly.graph_objects as go
import numpy as np

# ESMFold returns one set of coordinates per structure-module iteration, so
# `folded_positions` is 5-D: (iterations, batch, residues, atom slots, 3).
# Keep only the final iteration of the first (and only) sequence; flattening
# the whole tensor would overlay every intermediate structure in one plot.
# NOTE(review): layout taken from the 5-level nesting of the printed tensor and
# the ESMFold output convention -- confirm if the model/version changes.
# detach the tensor, we don't need gradients here
final_positions = folded_positions[-1, 0].detach().cpu()

# define which residues are to be colored red
protein_sequence = "MAVESRVTQEEIKKEPEK"
residue_indices = [3, 4, 7]  # 0-based indices into protein_sequence

assert final_positions.shape[0] == len(protein_sequence), (
    "expected one row of atom slots per residue of protein_sequence"
)

# Unused atom slots are zero-padded; keep only the slots that hold a real atom
atom_exists = ~torch.all(final_positions == 0, dim=-1)    # (residues, atom slots)
folded_positions_filtered = final_positions[atom_exists]  # (n_atoms, 3), residue order kept

# extract individual coordinates
x = folded_positions_filtered[:, 0].numpy()
y = folded_positions_filtered[:, 1].numpy()
z = folded_positions_filtered[:, 2].numpy()

# Label and colour each atom by the residue it belongs to. The per-residue atom
# count comes from the tensor itself, so labels always line up with the points
# (a hard-coded atom-count table does not match the padded slot layout).
atoms_per_residue = atom_exists.sum(dim=1).tolist()
highlighted = set(residue_indices)

color_indices = []
labels = []
colors = []
for i, (residue, atom_count) in enumerate(zip(protein_sequence, atoms_per_residue)):
    for _ in range(atom_count):
        if i in highlighted:
            color_indices.append(len(labels))
        labels.append(residue)
        colors.append('red' if i in highlighted else 'blue')

assert len(labels) == len(x)

# distance threshold: atom pairs closer than this are joined by a line
epsilon = 0.0  # change to desired value (0.0 draws no lines)

# prepare data for lines: all pairs i < j, compared in one vectorized pass
coords = np.stack([x, y, z], axis=1)
pair_i, pair_j = np.triu_indices(len(coords), k=1)
is_close = np.linalg.norm(coords[pair_i] - coords[pair_j], axis=1) < epsilon

Xe, Ye, Ze = [], [], []
for i, j in zip(pair_i[is_close], pair_j[is_close]):
    # None breaks the line so each pair is drawn as a separate segment
    Xe += [x[i], x[j], None]
    Ye += [y[i], y[j], None]
    Ze += [z[i], z[j], None]

trace_edges = go.Scatter3d(x=Xe, y=Ye, z=Ze, mode='lines', line=dict(color='black', width=2))

# create a 3D scatter plot
trace_nodes = go.Scatter3d(
    x=x,
    y=y,
    z=z,
    mode='markers+text',
    marker=dict(
        size=5,
        color=colors,                # one colour per atom: red = highlighted residue
        opacity=0.8
    ),
    text=labels,  # Add labels
    textposition="top center"
)

fig = go.Figure(data=[trace_edges, trace_nodes])

# set labels for axes
fig.update_layout(scene = dict(
                    xaxis_title='X',
                    yaxis_title='Y',
                    zaxis_title='Z'))

fig.show()
