# ESM2 Protein Sequence Likelihood Calculator

This notebook computes the per-position pseudo-log-likelihood (PPPL) for protein sequences using the ESM2 language model. It can process either a single sequence or multiple sequences from a CSV file.

## How to use:

1. For a single sequence:
   - Set the `sequence` variable in the cell below.
   - Run all cells.

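2. For multiple sequences in a CSV file:
   - Set the `input_csv` variable in the cell below.
   - Optionally set `output_csv`, `num_sequences`, and the other parameters (`model`, `max_length`, `skip_nham_aa`).
   - Run all cells.
   - Note that a non-empty `sequence` takes precedence: if it is set, the CSV file is ignored.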

Note: Ensure your input CSV file has a column named "sequence" containing the protein sequences.

In [None]:
import argparse
from tqdm import tqdm
import pandas as pd
import torch

# For notebook use, we'll set default values instead of using argparse
# max_length: sequences read from the CSV are truncated to this many residues
#   before scoring.
# model: short model label, resolved to a torch.hub entry point by
#   determine_esm_details().
# num_sequences: when not None, only the first `num_sequences` CSV rows are
#   considered.
# skip_nham_aa: when True, CSV rows whose "Nham_aa" value is not 1 are skipped.
max_length = 1022
model = "ESM2_650M"
num_sequences = None
skip_nham_aa = False

# The processing cell tests `sequence` first and `input_csv` second. Give all
# three I/O settings an explicit None default so that leaving one of them
# commented out below selects the other mode instead of raising NameError.
sequence = None
input_csv = None
output_csv = None

# Uncomment and modify these lines to set your parameters
# sequence = "TEYKLVVVGAGGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVVIDGETCLLDILDTAGQEEYSAMRDQYMRTGEGFLCVFAINNTKSFEDIHHYREQIKRVKDSEDVPMVLVGNKCDLPSRTVDTKQAQDLARSYGIPFIETSAKTRQGVDDAFYTLVREIRKHKEKMSKDGKKKKKKSKTKCVIM"
input_csv = "KRAS_data/KRAS_data_BindingPCA_DARPin_K55_muts.csv"
output_csv = "PCA_DARP_in_K55_pLLs.csv"
model = "ESM2_650M"
num_sequences = 100

In [None]:
def determine_esm_details(model_name):
    """Translate a short ESM2 model label into its hub entry-point name.

    Args:
        model_name: One of "ESM2_15B", "ESM2_3B" or "ESM2_650M".

    Returns:
        The corresponding entry-point identifier string.

    Raises:
        ValueError: If ``model_name`` is not one of the supported labels.
    """
    supported = (
        ("ESM2_15B", "esm2_t48_15B_UR50D"),
        ("ESM2_3B", "esm2_t36_3B_UR50D"),
        ("ESM2_650M", "esm2_t33_650M_UR50D"),
    )
    for label, hub_entry in supported:
        if model_name == label:
            return hub_entry
    raise ValueError(f"Unknown model: {model_name}")

In [None]:
def rename_column_to_sequence(df, original_column_name):
    """Rename one column of ``df`` to "sequence", modifying ``df`` in place.

    Args:
        df: DataFrame whose column should be renamed.
        original_column_name: Current label of the column holding sequences.

    Returns:
        The same DataFrame object that was passed in.
    """
    column_mapping = {original_column_name: "sequence"}
    df.rename(columns=column_mapping, inplace=True)
    return df

In [None]:
# Load the model
# `model` starts out as the configured model-name string but is rebound to the
# loaded network further down. Only refresh `model_name` while `model` is still
# a string, so that re-running this cell does not hand the network object to
# determine_esm_details() (which would raise ValueError).
if isinstance(model, str):
    model_name = model
model_details = determine_esm_details(model_name)

print(f"==> Loading model {model_name}")
model, alphabet = torch.hub.load("facebookresearch/esm:main", model_details)
batch_converter = alphabet.get_batch_converter()
# Inference only: switch off training-time behaviour such as dropout.
model.eval()

# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"Using device: {device}")

In [None]:
def compute_pppl(sequence, verbose=False):
    """Return the mean masked-token log-probability of a protein sequence.

    Every position is masked in turn, the model is run on the masked input,
    and the log-probability assigned to the true residue at that position is
    recorded. The result is the average of those log-probabilities, so it is
    a log-likelihood per position (not a perplexity, despite the name).

    Uses the notebook-level ``model``, ``alphabet``, ``batch_converter`` and
    ``device`` objects created by the model-loading cell.

    Args:
        sequence: Protein sequence as a string of residue characters.
        verbose: When True, print the token tensors, the per-position
            log-probability tensors and the collected scores.

    Returns:
        The sum of per-position log-probabilities divided by ``len(sequence)``.

    Raises:
        ValueError: If ``sequence`` is empty (the average would be undefined).
    """
    if not sequence:
        raise ValueError("Cannot compute PPPL for an empty sequence")

    # A single no_grad context covers every forward pass below.
    with torch.no_grad():
        data = [("protein1", sequence)]
        _, _, batch_tokens = batch_converter(data)
        batch_tokens = batch_tokens.to(device)
        if verbose:
            print(batch_tokens)

        # compute probabilities at each position
        # Token positions are offset by one relative to residues (residue k is
        # read from token index k + 1) -- presumably because the batch
        # converter prepends a start token; confirm against the alphabet.
        log_probs = []
        for i in range(1, len(sequence) + 1):
            batch_tokens_masked = batch_tokens.clone()
            batch_tokens_masked[0, i] = alphabet.mask_idx
            if verbose:
                print(batch_tokens_masked)
            # Only the logits are needed, so no hidden representations are
            # requested from the model.
            outputs = model(batch_tokens_masked, return_contacts=False)
            token_probs = torch.log_softmax(outputs["logits"], dim=-1)
            if verbose:
                print(token_probs)
            # Log-probability of the true residue at the masked position.
            true_idx = alphabet.get_idx(sequence[i - 1])
            log_probs.append(token_probs[0, i, true_idx].item())
        if verbose:
            print(log_probs)
        return sum(log_probs) / len(sequence)

In [None]:
# `sequence` only exists if the user defined it in the configuration cell;
# look it up defensively so a missing definition selects CSV mode instead of
# raising NameError.
single_sequence = globals().get("sequence")

# Process single sequence
if single_sequence:
    print(f"Computing likelihood for single sequence: {single_sequence}")
    pppl = compute_pppl(single_sequence, verbose=True)
    print(f"PPPL: {pppl}")

# Process multiple sequences from CSV
elif input_csv:
    df = pd.read_csv(input_csv)
    if "sequence" not in df.columns:
        raise ValueError("Input CSV must contain a 'sequence' column")

    ll_list = []
    # Positions of the rows that were actually scored, so that each score is
    # written back next to the row it was computed from.
    scored_positions = []
    for i, row in enumerate(tqdm(df.itertuples(), total=len(df))):
        # Only the first `num_sequences` rows of the file are considered.
        if num_sequences is not None and i >= num_sequences:
            break

        seq = getattr(row, "sequence")
        nham_aa = getattr(row, "Nham_aa", None)

        # Missing values (read by pandas as NaN) and empty strings cannot be
        # scored; skip them rather than crashing.
        if not isinstance(seq, str) or not seq:
            continue
        # Skip sequences containing '*'.
        if '*' in seq:
            continue
        if skip_nham_aa and nham_aa != 1:
            continue

        this_pppl = compute_pppl(seq[:max_length])
        ll_list.append(this_pppl)
        scored_positions.append(i)

    # Keep exactly the scored rows; copy so the new column is added to an
    # independent frame rather than to a view of the original.
    df = df.iloc[scored_positions].copy()
    # `model` now refers to the loaded network, so use the name string for the
    # output column label.
    df[f"{model_name}_pppl"] = ll_list
    df.to_csv(output_csv, index=False)
    print(f"Done. Output saved to {output_csv}")
else:
    print("Please provide either a sequence or an input CSV file.")

## Next Steps

- Analyze the results in the output CSV file.
- Visualize the PPPL scores if working with multiple sequences.
- Compare PPPL scores across different protein variants or families.