# Jupyter notebook to run the IS-DMS prediction and create Figure 6

In [2]:
%%capture

!pip install fair-esm 
!pip install torch
!pip install biopython

In [3]:
import argparse
import pathlib
import string
import torch
from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer
import pandas as pd
from tqdm import tqdm
from Bio import SeqIO
import itertools
from typing import List, Tuple
import numpy as np
import pickle
import matplotlib.pyplot as plt
import seaborn as sns

# Functions #

In [5]:
def generate_mutations(seq, alphabet, zero_based_indexing):
    """List every single-residue substitution of ``seq`` as ``<wt><position><mt>``.

    At each position the wild-type residue is paired with every *other*
    symbol of ``alphabet``; the wild type itself is excluded.

    Args:
        seq: Wild-type sequence, one character per residue.
        alphabet: Allowed residue symbols (any iterable of single characters).
            It is copied, never modified.
        zero_based_indexing: If true, positions in the mutation strings start
            at 0, otherwise at 1.

    Returns:
        List of mutation strings, ordered by position and then by the order
        of ``alphabet``.

    Raises:
        ValueError: If ``seq`` contains a symbol that is not in ``alphabet``
            (e.g. the non-standard residues U, Z, O, B when the alphabet is
            the 20 standard amino acids). The message names the offending
            residue and its position.
    """
    symbols = list(alphabet)
    off = 0 if zero_based_indexing else 1

    mutations = []
    for i, wt in enumerate(seq):
        if wt not in symbols:
            # Fail with a message that points at the problem instead of the
            # bare "list.remove(x): x not in list" raised by list.remove.
            raise ValueError(
                f"Residue {wt!r} at position {i + off} is not in the mutation alphabet; "
                "remove or replace non-standard symbols before generating mutations."
            )
        substitutions = symbols.copy()
        substitutions.remove(wt)
        mutations.extend(wt + str(i + off) + mt for mt in substitutions)

    return mutations

def generate_mutations_(seq, alphabet, zero_based_indexing):
    """Pair every residue of ``seq`` with every symbol of ``alphabet``.

    Each entry has the form ``<wt><position><symbol>``. The wild-type symbol
    is not filtered out, so entries such as ``A1A`` are included whenever the
    residue itself is part of ``alphabet``.

    Args:
        seq: Wild-type sequence, one character per residue.
        alphabet: Residue symbols to pair with each position; must provide
            a ``copy()`` method (e.g. a list).
        zero_based_indexing: If true, positions start at 0, otherwise at 1.

    Returns:
        List of strings ordered by position, then by ``alphabet`` order.
    """
    first_position = 0 if zero_based_indexing else 1
    return [
        residue + str(position) + symbol
        for position, residue in enumerate(seq, start=first_position)
        for symbol in alphabet.copy()
    ]

def remove_insertions(sequence: str) -> str:
    """Strip insertion markers from an aligned sequence.

    Lowercase ASCII letters, ``.`` and ``*`` are deleted; every other
    character (uppercase residues, ``-`` gaps, ...) is kept in place. Needed
    to load aligned sequences from an MSA.
    """
    insertion_chars = string.ascii_lowercase + ".*"
    # Three-argument maketrans builds a table that deletes the listed characters.
    return sequence.translate(str.maketrans("", "", insertion_chars))


def read_msa(filename: str, nseq: int) -> List[Tuple[str, str]]:
    """Read the first ``nseq`` records of an MSA file with insertions removed.

    The file is parsed with the Biopython FASTA parser, but it must be in
    a3m format for ``remove_insertions`` to work properly.

    Returns:
        A list of ``(description, sequence)`` tuples in file order.
    """
    records = SeqIO.parse(filename, "fasta")
    msa = []
    for record in itertools.islice(records, nseq):
        msa.append((record.description, remove_insertions(str(record.seq))))
    return msa


def create_parser():
    """Build the command-line parser for the ESM-1v variant-labelling script.

    Returns:
        An ``argparse.ArgumentParser`` exposing the model locations, the base
        sequence, the DMS input/output files, the mutation column, the
        position offset, the scoring strategy, the MSA options and ``--nogpu``.
    """
    parser = argparse.ArgumentParser(
        description="Label a deep mutational scan with predictions from an ensemble of ESM-1v models."  # noqa
    )

    # (flag, keyword arguments) in the order the options are registered.
    # fmt: off
    option_specs = [
        ("--model-location", dict(
            type=str,
            nargs="+",
            help="PyTorch model file OR name of pretrained model to download (see README for models)",
        )),
        ("--sequence", dict(
            type=str,
            help="Base sequence to which mutations were applied",
        )),
        ("--dms-input", dict(
            type=pathlib.Path,
            help="CSV file containing the deep mutational scan",
        )),
        ("--mutation-col", dict(
            type=str,
            default="mutant",
            help="column in the deep mutational scan labeling the mutation as 'AiB'",
        )),
        ("--dms-output", dict(
            type=pathlib.Path,
            help="Output file containing the deep mutational scan along with predictions",
        )),
        ("--offset-idx", dict(
            type=int,
            default=0,
            help="Offset of the mutation positions in `--mutation-col`",
        )),
        ("--scoring-strategy", dict(
            type=str,
            default="wt-marginals",
            choices=["wt-marginals", "pseudo-ppl", "masked-marginals"],
            help="",
        )),
        ("--msa-path", dict(
            type=pathlib.Path,
            help="path to MSA in a3m format (required for MSA Transformer)",
        )),
        ("--msa-samples", dict(
            type=int,
            default=400,
            help="number of sequences to select from the start of the MSA",
        )),
        ("--nogpu", dict(
            action="store_true",
            help="Do not use GPU even if available",
        )),
    ]
    # fmt: on
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser


def label_row(row, sequence, token_probs, alphabet, offset_idx):
    """Score one substitution as ``token_probs[mutant] - token_probs[wild type]``.

    Args:
        row: Mutation string ``<wt><position><mt>``, e.g. ``"A24G"``.
        sequence: Wild-type sequence the mutation refers to.
        token_probs: Array/tensor indexed as ``[0, token_position, token_index]``.
            Token position 0 is the BOS token, so residue ``idx`` is read at
            ``1 + idx``. Callers in this file pass log-softmax outputs.
        alphabet: Object whose ``get_idx(symbol)`` returns the token index.
        offset_idx: Value subtracted from the position in ``row`` to obtain
            the 0-based index into ``sequence``.

    Returns:
        The score as a Python float.

    Raises:
        AssertionError: If the resulting index is negative, or if the
            wild-type residue in ``row`` differs from ``sequence`` at that
            index.
    """
    wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1]
    # A negative index would silently wrap around to the end of the sequence
    # and read the BOS row of token_probs, so reject it explicitly.
    assert idx >= 0, (
        f"Mutation {row!r} with offset_idx={offset_idx} gives a negative sequence index ({idx})"
    )
    # The message is built as a string: passing print(...) as the assert
    # message would attach None to the exception.
    assert sequence[idx] == wt, (
        "The listed wildtype does not match the provided sequence: "
        f"mutation {row!r} expects {wt!r} at index {idx}, but the sequence has {sequence[idx]!r}"
    )

    wt_encoded, mt_encoded = alphabet.get_idx(wt), alphabet.get_idx(mt)

    # add 1 for BOS
    score = token_probs[0, 1 + idx, mt_encoded] - token_probs[0, 1 + idx, wt_encoded]
    return score.item()

def label_row_likelihood(row, sequence, token_probs, alphabet, offset_idx):
    """Return ``token_probs`` for the mutant residue at the mutated position.

    Unlike ``label_row`` the wild-type value is not subtracted.

    Args:
        row: Mutation string ``<wt><position><mt>``, e.g. ``"A24G"``.
        sequence: Wild-type sequence the mutation refers to.
        token_probs: Array/tensor indexed as ``[0, token_position, token_index]``.
            Token position 0 is the BOS token, so residue ``idx`` is read at
            ``1 + idx``.
        alphabet: Object whose ``get_idx(symbol)`` returns the token index.
        offset_idx: Value subtracted from the position in ``row`` to obtain
            the 0-based index into ``sequence``.

    Returns:
        The value for the mutant residue as a Python float.

    Raises:
        AssertionError: If the resulting index is negative, or if the
            wild-type residue in ``row`` differs from ``sequence`` at that
            index.
    """
    wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1]
    # A negative index would silently wrap around to the end of the sequence
    # and read the BOS row of token_probs, so reject it explicitly.
    assert idx >= 0, (
        f"Mutation {row!r} with offset_idx={offset_idx} gives a negative sequence index ({idx})"
    )
    # The message is built as a string: passing print(...) as the assert
    # message would attach None to the exception.
    assert sequence[idx] == wt, (
        "The listed wildtype does not match the provided sequence: "
        f"mutation {row!r} expects {wt!r} at index {idx}, but the sequence has {sequence[idx]!r}"
    )

    mt_encoded = alphabet.get_idx(mt)

    # add 1 for BOS
    score = token_probs[0, 1 + idx, mt_encoded]
    return score.item()


def compute_pppl(row, sequence, model, alphabet, offset_idx):
    """Pseudo-log-likelihood of ``sequence`` carrying the substitution in ``row``.

    The substitution is applied, then every residue is masked in turn and the
    log-probability the model assigns to the true residue at the masked
    position is accumulated.

    Args:
        row: Mutation string ``<wt><position><mt>``, e.g. ``"A24G"``.
        sequence: Wild-type sequence.
        model: Callable torch module returning a dict with a ``"logits"``
            entry of shape ``(batch, tokens, vocab)``.
        alphabet: Provides ``get_batch_converter()``, ``get_idx()`` and
            ``mask_idx``. The converter is assumed to prepend one BOS token,
            as the other scoring functions in this file assume.
        offset_idx: Value subtracted from the position in ``row`` to obtain
            the 0-based index into ``sequence``.

    Returns:
        Sum of the per-residue masked log-probabilities (float).

    Raises:
        AssertionError: If the wild-type residue in ``row`` does not match
            ``sequence``.
    """
    wt, idx, mt = row[0], int(row[1:-1]) - offset_idx, row[-1]
    assert sequence[idx] == wt, "The listed wildtype does not match the provided sequence"

    # modify the sequence
    sequence = sequence[:idx] + mt + sequence[(idx + 1) :]

    # encode the sequence
    batch_converter = alphabet.get_batch_converter()
    _, _, batch_tokens = batch_converter([("protein1", sequence)])

    # Run on whatever device the model lives on instead of assuming CUDA.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")
    batch_tokens = batch_tokens.to(device)

    # compute probabilities at each position
    # Token 0 is BOS, so residue k of the sequence sits at token k + 1.
    # Looping over tokens 1..len(sequence) covers every residue, and the
    # target at token i is sequence[i - 1].
    log_probs = []
    for i in range(1, len(sequence) + 1):
        batch_tokens_masked = batch_tokens.clone()
        batch_tokens_masked[0, i] = alphabet.mask_idx
        with torch.no_grad():
            token_probs = torch.log_softmax(model(batch_tokens_masked)["logits"], dim=-1)
        log_probs.append(token_probs[0, i, alphabet.get_idx(sequence[i - 1])].item())  # vocab size
    return sum(log_probs)


def _masked_marginal_log_probs(model, batch_tokens, alphabet):
    """Log-softmax at every token position, computed with that position masked."""
    all_token_probs = []
    for i in tqdm(range(batch_tokens.size(1))):
        batch_tokens_masked = batch_tokens.clone()
        batch_tokens_masked[0, i] = alphabet.mask_idx
        with torch.no_grad():
            token_probs = torch.log_softmax(model(batch_tokens_masked)["logits"], dim=-1)
        all_token_probs.append(token_probs[:, i])  # vocab size
    return torch.cat(all_token_probs, dim=0).unsqueeze(0)


def main(args):
    """Score every mutation of a deep mutational scan with each listed model.

    Args:
        args: Mapping with the keys
            ``dms_input`` (CSV to read), ``dms_output`` (CSV to write),
            ``model_location`` (iterable of model files/names; one output
            column is added per entry, named after the entry),
            ``sequence`` (wild-type sequence), ``mutation_col`` (column with
            ``<wt><position><mt>`` strings), ``offset_idx`` (position offset),
            ``scoring_strategy`` (``"wt-marginals"``, ``"masked-marginals"``
            or ``"pseudo-ppl"``) and, optionally, ``nogpu`` (truthy to stay
            on the CPU).

    Raises:
        ValueError: If ``scoring_strategy`` is not one of the three supported
            values (previously such a value was silently ignored).

    Side effects:
        Writes the input table plus the score columns to ``dms_output``
        (including the DataFrame index, as before).
    """
    # Load the deep mutational scan
    df = pd.read_csv(args["dms_input"])

    # Use the GPU only when one is present and the caller did not opt out;
    # an unconditional .cuda() fails on CPU-only machines.
    use_gpu = torch.cuda.is_available() and not args.get("nogpu", False)
    device = torch.device("cuda" if use_gpu else "cpu")

    # inference for each model
    for model_location in args["model_location"]:
        strategy = args["scoring_strategy"]
        if strategy not in ("wt-marginals", "masked-marginals", "pseudo-ppl"):
            raise ValueError(f"Unknown scoring strategy: {strategy!r}")

        model, alphabet = pretrained.load_model_and_alphabet(model_location)
        model.eval()

        model = model.to(device)
        if use_gpu:
            print("Transferred model to GPU")

        batch_converter = alphabet.get_batch_converter()

        data = [("protein1", args["sequence"]), ]
        batch_labels, batch_strs, batch_tokens = batch_converter(data)
        batch_tokens = batch_tokens.to(device)

        if strategy == "pseudo-ppl":
            tqdm.pandas()
            df[model_location] = df.progress_apply(
                lambda row: compute_pppl(
                    row[args["mutation_col"]], args["sequence"], model, alphabet, args["offset_idx"]
                ),
                axis=1,
            )
            continue

        if strategy == "wt-marginals":
            # One forward pass over the unmasked wild-type sequence.
            with torch.no_grad():
                token_probs = torch.log_softmax(model(batch_tokens)["logits"], dim=-1)
        else:  # masked-marginals
            # One forward pass per token position, with that position masked.
            token_probs = _masked_marginal_log_probs(model, batch_tokens, alphabet)

        # Both marginal strategies label rows the same way.
        df[model_location] = df.apply(
            lambda row: label_row(
                row[args["mutation_col"]],
                args["sequence"],
                token_probs,
                alphabet,
                args["offset_idx"],
            ),
            axis=1,
        )

    df.to_csv(args["dms_output"])

def main_only_likelihood(args):
    """Label every mutation with the masked log-likelihood of the mutant residue.

    Same flow as the masked-marginals strategy of ``main``, but rows are
    labelled with ``label_row_likelihood`` (mutant value only, no wild-type
    subtraction).

    Args:
        args: Mapping with the keys ``dms_input``, ``dms_output``,
            ``model_location`` (iterable; one output column per entry),
            ``sequence``, ``mutation_col``, ``offset_idx`` and, optionally,
            ``nogpu`` (truthy to stay on the CPU).

    Side effects:
        Writes the input table plus the score columns to ``dms_output``
        (including the DataFrame index, as before).
    """
    # Load the deep mutational scan
    df = pd.read_csv(args["dms_input"])

    # Use the GPU only when one is present and the caller did not opt out;
    # an unconditional .cuda() fails on CPU-only machines.
    use_gpu = torch.cuda.is_available() and not args.get("nogpu", False)
    device = torch.device("cuda" if use_gpu else "cpu")

    # inference for each model
    for model_location in args["model_location"]:
        model, alphabet = pretrained.load_model_and_alphabet(model_location)
        model.eval()
        model = model.to(device)
        if use_gpu:
            print("Transferred model to GPU")

        batch_converter = alphabet.get_batch_converter()
        data = [("protein1", args["sequence"]), ]
        batch_labels, batch_strs, batch_tokens = batch_converter(data)
        batch_tokens = batch_tokens.to(device)

        # One forward pass per token position, with that position masked.
        all_token_probs = []
        for i in tqdm(range(batch_tokens.size(1))):
            batch_tokens_masked = batch_tokens.clone()
            batch_tokens_masked[0, i] = alphabet.mask_idx
            with torch.no_grad():
                token_probs = torch.log_softmax(model(batch_tokens_masked)["logits"], dim=-1)
            all_token_probs.append(token_probs[:, i])  # vocab size
        token_probs = torch.cat(all_token_probs, dim=0).unsqueeze(0)

        df[model_location] = df.apply(
            lambda row: label_row_likelihood(
                row[args["mutation_col"]],
                args["sequence"],
                token_probs,
                alphabet,
                args["offset_idx"],
            ),
            axis=1,
        )

    df.to_csv(args["dms_output"])

def alphabet_to_mat_position(x, alphabet):
    """Look up the matrix column assigned to residue ``x``.

    ``alphabet`` maps residue symbols to column positions; a symbol that is
    not a key raises ``KeyError``.
    """
    position = alphabet[x]
    return position

def mat_position_to_alphabet(x, alphabet):
    """Look up the residue symbol assigned to matrix column ``x``.

    ``alphabet`` maps residue symbols to column positions; the mapping is
    inverted on each call. If several symbols share a position the last one
    in the mapping wins, and an unknown position raises ``KeyError``.
    """
    position_to_symbol = dict(zip(alphabet.values(), alphabet.keys()))
    return position_to_symbol[x]





# IS-DMS creation #

In [None]:
# Map each of the 20 standard residues to its column in the mutation matrix
alph = {
    'A': 0, 'R': 1, 'N': 2, 'D': 3, 'C': 4,
    'Q': 5, 'E': 6, 'G': 7, 'H': 8, 'I': 9,
    'L': 10, 'K': 11, 'M': 12, 'F': 13, 'P': 14,
    'S': 15, 'T': 16, 'W': 17, 'Y': 18, 'V': 19,
}

# Replace the placeholder with the protein sequence to scan
Sequence_to_be_ran = "Paste your sequence here"
Length_of_protein = len(Sequence_to_be_ran)

# Every possible substitution at every position, numbered from 1
mutations_to_hyperactive = generate_mutations(
    Sequence_to_be_ran,
    alphabet = [
        'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'L', 'M',
        'N', 'P', 'Q', 'R', 'S', 'T', 'K', 'V', 'Y', 'W',
    ],
    zero_based_indexing = False,
)

# One row per mutation; 'bla' is an empty placeholder column
sequence_data_frame = pd.DataFrame({'mutant': mutations_to_hyperactive, 'bla': np.nan})

# Save the table as the CSV file that main() reads
sequence_data_frame.to_csv('Hyperactive_DMS.csv')

# Dictionary with the values main() needs. The five ESM-1v checkpoints
# (esm1v_t33_650M_UR90S_#.pt) must have been downloaded beforehand and their
# locations listed under "model_location".
args_ens_1 = {
    "model_location": [
        "/content/drive/Shareddrives/ESM_for_variant_prediction/esm1v_t33_650M_UR90S_1.pt",
        "/content/drive/Shareddrives/ESM_for_variant_prediction/esm1v_t33_650M_UR90S_2.pt",
        "/content/drive/Shareddrives/ESM_for_variant_prediction/esm1v_t33_650M_UR90S_3.pt",
        "/content/drive/Shareddrives/ESM_for_variant_prediction/esm1v_t33_650M_UR90S_4.pt",
        "/content/drive/Shareddrives/ESM_for_variant_prediction/esm1v_t33_650M_UR90S_5.pt",
    ],
    "sequence": Sequence_to_be_ran,
    "dms_input": "Hyperactive_DMS.csv",
    "mutation_col": "mutant",
    "dms_output": "Hyperactive_DMS.csv",
    "offset_idx": 1,
    "scoring_strategy": "masked-marginals",
}

# Run the ESM-1v variant prediction
main(args_ens_1)

# main() writes the score of every mutation, one column per model, back to Hyperactive_DMS.csv

#Create figure 6 #

In [6]:
# Load the scored scan and drop the helper columns. errors='ignore' keeps the
# cell working when the saved-index columns ('Unnamed: ...') are not present.
Seq_df = pd.read_csv('Hyperactive_DMS.csv').drop(
    columns=['Unnamed: 0.1', 'Unnamed: 0', 'bla'], errors='ignore'
)

# Rename the score columns from the checkpoint paths to the bare model names
model_dir = "/content/drive/Shareddrives/ESM_for_variant_prediction"
model_names = [f"esm1v_t33_650M_UR90S_{k}" for k in range(1, 6)]
Seq_df = Seq_df.rename(columns={f"{model_dir}/{name}.pt": name for name in model_names})

# Mean score of each mutation over all the models
Seq_df['ensemble'] = Seq_df[model_names].mean(axis=1)

# 0-based position of each mutation (the mutation strings are 1-based)
Seq_df['values'] = [int(i[1:-1]) - 1 for i in Seq_df['mutant'].values]

# Matrix column of the residue each mutation changes to
Seq_df['mutation_mat_position'] = [alphabet_to_mat_position(i[-1], alph) for i in Seq_df['mutant'].values]

# Per position: the ensemble scores (temp_1) and the matching matrix columns (temp_2)
temp_1 = np.array(list(Seq_df[['values', 'ensemble']].groupby('values').ensemble.apply(list).reset_index()['ensemble'].values))
temp_2 = np.array(list(Seq_df[['values', 'mutation_mat_position']].groupby('values').mutation_mat_position.apply(list).reset_index()['mutation_mat_position'].values))

# (positions x 20 residues) matrix of scores; cells without a mutation
# (the wild-type residue) stay NaN
full_mutations_matrix_HPB = np.zeros((Length_of_protein, 20))*np.nan

# Fill the matrix position by position
for j in range(Length_of_protein):
    for score, column in zip(temp_1[j, :], temp_2[j, :]):
        full_mutations_matrix_HPB[j, column] = score

# Delete temp_1 and temp_2
del temp_1, temp_2

# Create the figure and axis to plot the mutation scores
fig, ax = plt.subplots(1, figsize = (12, 5))

# Shade the different domains of the piggyBac: (start, end, colour, legend label)
# NOTE(review): 'Catalitic' looks like a typo for 'Catalytic'; the labels are
# left unchanged so the rendered figure stays the same.
domains = [
    (0, 116, 'grey', 'N-terminal'),
    (116, 262, 'yellow', 'DDBD'),
    (262, 371, 'green', 'Catalitic_domain'),
    (371, 432, 'blue', 'Insertion_domain'),
    (432, 456, 'green', 'Catalitic_domain'),
    (456, 534, 'yellow', 'DDBD'),
    (534, 552, 'grey', '--'),
    (552, 593, 'red', 'CRD'),
]
for start, end, colour, name in domains:
    ax.axvspan(xmin=start, xmax=end, ymin=0, linewidth=0, color=colour, alpha = 0.1, label = name);

# Ticks at the domain boundaries, labelled with 1-based residue numbers
ax.set_xticks([0, 116, 262, 371, 432, 456, 534, 552, 593])
ax.set_xticklabels([1, 117 , 263, 372, 433 , 457, 535 , 553, 594 ])

# One scatter series per target residue (matrix column)
for i in range(20):
    ax.scatter(np.arange(0, full_mutations_matrix_HPB.shape[0]), full_mutations_matrix_HPB[:, i], s = 0.5, color = 'blue', alpha = 0.5)

# Set the legend and spines
ax.spines[['right', 'top', 'bottom']].set_visible(False)
legend = ax.legend();
# Legend.legendHandles was renamed to legend_handles in Matplotlib 3.7 and the
# old name was later removed; use whichever attribute exists.
legend_handles = getattr(legend, 'legend_handles', None)
if legend_handles is None:
    legend_handles = legend.legendHandles
for lh in legend_handles:
    lh.set_alpha(1)

# Red dashed line marking the 0 value on the graph
ax.axhline(y=0, color='r', linestyle='--', alpha = 0.5);


## Plot heatmap with fitness scores

In [None]:
# Residue letters in matrix-column order
alphabet = ''.join(alph.keys())

# Rows of the mutation matrix to show in the heatmap, with all their scores.
# NOTE(review): these are 0-based matrix rows and are also used as the y-axis
# labels; confirm whether 1-based residue numbers were intended.
residues = [1,2,5,6]
matrix = full_mutations_matrix_HPB[residues, :]

# Create a heatmap of the mutations
fig, ax = plt.subplots(figsize = (30, 100))
im = ax.imshow(matrix, cmap='RdYlBu')

ax.set_xticks(np.arange(len(alphabet)))
# One tick per plotted row: the number of ticks must equal the number of
# labels, otherwise set_yticklabels raises a ValueError.
ax.set_yticks(np.arange(len(residues)))
ax.set_xticklabels(list(alphabet))
ax.set_yticklabels(residues)

# Rotate the x-axis labels for better readability
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor");

cbar = ax.figure.colorbar(im, ax=ax, fraction = 0.01)

# Write the rounded score inside every cell
for i in range(len(residues)):
    for j in range(len(alphabet)):
        ax.text(j, i, round(matrix[i, j], 2), ha="center", va="center", color="black", size = 20)