## Initial processing to prepare the sequences for ESM-2

In [1]:
import pandas as pd

# Load the variant table as strings, keep only rows with every field needed to
# build and compare sequences, and remove exact duplicate rows.
df = (
    pd.read_csv("../data/intermediate/variant_protein_sequences.txt", sep="\t", dtype="str")
    .dropna(subset=["wt_seq", "ref_aa", "alt_aa", "pos_aa", "mut_seq"], how="any")
    .drop_duplicates()
    # Strip trailing stop-codon symbols ("*") so only residues are scored
    .assign(
        wt_seq=lambda frame: frame["wt_seq"].str.rstrip("*"),
        mut_seq=lambda frame: frame["mut_seq"].str.rstrip("*"),
    )
)

#### Trim each sequence to a window of ±127 residues around `pos_aa` (at most 255 residues, clipped at the protein ends)

In [2]:
# Positions arrive as strings; coerce them to nullable integers. Rows whose
# position cannot be parsed have no defined window or sequence ID, so drop them
# here instead of letting a missing position reach windowing and ESM-2.
df["pos_aa"] = pd.to_numeric(df["pos_aa"], errors="coerce").astype("Int64")
df = df.dropna(subset=["pos_aa"])

flank = 127  # residues kept on each side of the variant position

def extract_subsequence(full_seq, pos, flank=flank):
    """
    Return the window of `full_seq` centred on the 1-based position `pos`.

    The window holds `flank` residues on each side of `pos` (at most
    2 * flank + 1 residues) and is clipped at both ends of the sequence.

    Parameters
    ----------
    full_seq : str or missing
        Full protein sequence.
    pos : int or missing
        1-based residue position of the variant.
    flank : int
        Number of residues to keep on each side of `pos`.

    Returns
    -------
    str or None
        The windowed subsequence, or None when the sequence or the position
        is missing.
    """
    # pd.isna rather than `is None`: a position from an Int64 column is
    # missing as pd.NA, which `is None` does not catch. pd.NA arithmetic then
    # makes max()/min() raise.
    if pd.isna(full_seq) or pd.isna(pos):
        return None

    pos_idx = int(pos) - 1  # 0-based index of the variant residue
    start = max(0, pos_idx - flank)
    end = min(len(full_seq), pos_idx + flank + 1)  # +1: slice end is exclusive

    return full_seq[start:end]

# Window the wild-type and mutant sequences around the same variant position,
# then drop rows that became identical once trimmed.
df["wt_seq"] = [
    extract_subsequence(seq, pos) for seq, pos in zip(df["wt_seq"], df["pos_aa"])
]
df["mut_seq"] = [
    extract_subsequence(seq, pos) for seq, pos in zip(df["mut_seq"], df["pos_aa"])
]

df = df.drop_duplicates()

#### Generate wild-type and mutant protein IDs for all variants in the variant set (these IDs are used for merging later)

In [3]:
variant_set = (
    df[["ID", "ensg", "ref_aa", "alt_aa", "pos_aa"]]
    .copy()
    .rename(columns={"ID": "variant_ID"})
)

# IDs are "<ensg>-<pos_aa>-<ref_aa>-<suffix>": the suffix is "wt" for the
# wild-type window and the alternate residue for the mutant window.
_id_prefix = (
    variant_set["ensg"].astype(str)
    + "-" + variant_set["pos_aa"].astype(str)
    + "-" + variant_set["ref_aa"].astype(str)
)
variant_set["wt_ID"] = _id_prefix + "-wt"
variant_set["mut_ID"] = _id_prefix + "-" + variant_set["alt_aa"].astype(str)
del _id_prefix

variant_set

Unnamed: 0,variant_ID,ensg,ref_aa,alt_aa,pos_aa,wt_ID,mut_ID
0,10-100057090-C-T,ENSG00000120054,D,N,312,ENSG00000120054-312-D-wt,ENSG00000120054-312-D-N
1,10-100069757-C-T,ENSG00000120054,G,D,178,ENSG00000120054-178-G-wt,ENSG00000120054-178-G-D
2,10-100076062-G-A,ENSG00000120054,A,V,90,ENSG00000120054-90-A-wt,ENSG00000120054-90-A-V
3,10-100081405-G-A,ENSG00000120054,P,L,74,ENSG00000120054-74-P-wt,ENSG00000120054-74-P-L
4,10-100152307-T-C,ENSG00000107566,I,V,291,ENSG00000107566-291-I-wt,ENSG00000107566-291-I-V
...,...,...,...,...,...,...,...
408702,19-48318735-G-A,ENSG00000105479,H,Y,50,ENSG00000105479-50-H-wt,ENSG00000105479-50-H-Y
408703,19-48318737-A-G,ENSG00000105479,V,A,49,ENSG00000105479-49-V-wt,ENSG00000105479-49-V-A
408704,19-48318744-T-G,ENSG00000105479,K,Q,47,ENSG00000105479-47-K-wt,ENSG00000105479-47-K-Q
408705,19-48318746-C-G,ENSG00000105479,S,T,46,ENSG00000105479-46-S-wt,ENSG00000105479-46-S-T


#### Build one table of all wild-type and mutant sequences to embed with ESM-2

In [4]:
# Build one table of sequences to embed. IDs follow the same
# "<ensg>-<pos_aa>-<ref_aa>-<wt|alt_aa>" scheme as variant_set's wt_ID/mut_ID,
# so ESM-2 results can be merged back by ID.
# NOTE(review): IDs are keyed by gene (ensg), not transcript. If the input
# holds several isoforms of one gene, only the first window per ID is used.
# Confirm this is intended.
id_cols = ["ensg", "pos_aa", "ref_aa"]

# Wild-type windows: one per (gene, position, reference residue).
wt = df.drop_duplicates(subset=["ensg", "ref_aa", "pos_aa"], keep="first")
wt = pd.DataFrame({
    "ID": wt[id_cols].astype(str).agg("-".join, axis=1) + "-wt",
    "seq": wt["wt_seq"],
})

# Mutant windows: one per input row (may repeat an ID, see below).
mut = pd.DataFrame({
    "ID": df[id_cols + ["alt_aa"]].astype(str).agg("-".join, axis=1),
    "seq": df["mut_seq"],
})

# Different nucleotide variants can cause the same protein change and hence
# the same mutant ID. Embed each ID once. Later merges also keep only the
# first row per ID, so keep="first" leaves downstream metrics unchanged.
sequences = (
    pd.concat([wt, mut], ignore_index=True)
    .drop_duplicates(subset=["ID"], keep="first")
    .reset_index(drop=True)
)
del df  # free memory; later cells only use `sequences` and `variant_set`

sequences

Unnamed: 0,ID,seq
0,ENSG00000120054-312-D-wt,LPDNWKSQVEPETRAVIRWMHSFNFVLSANLHGGAVVANYPYDKSF...
1,ENSG00000120054-178-G-wt,IGRSVEGRHLYVLEFSDHPGIHEPLEPEVKYVGNMHGNEALGRELM...
2,ENSG00000120054-90-A-wt,MSDLLSVFLHLLLLFKLVAPVTFRHHRYDDLVRTLYKVQNECPGIT...
3,ENSG00000120054-74-P-wt,MSDLLSVFLHLLLLFKLVAPVTFRHHRYDDLVRTLYKVQNECPGIT...
4,ENSG00000107566-291-I-wt,GLTIQAVRVTKPKIPEAIRRNFELMEAEKTKLLIAAQKQKVVEKEA...
...,...,...
691451,ENSG00000105479-50-H-Y,MPLGRLAGSARSEEGSEAFLEGMVDWELSRLQRQCKVMEGERRAYS...
691452,ENSG00000105479-49-V-A,MPLGRLAGSARSEEGSEAFLEGMVDWELSRLQRQCKVMEGERRAYS...
691453,ENSG00000105479-47-K-Q,MPLGRLAGSARSEEGSEAFLEGMVDWELSRLQRQCKVMEGERRAYS...
691454,ENSG00000105479-46-S-T,MPLGRLAGSARSEEGSEAFLEGMVDWELSRLQRQCKVMEGERRAYT...


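## Compute ESM-2 embeddings and perplexity

Each sequence is scored in a single forward pass without masking. The reported perplexity is therefore exp(−mean log-probability of the observed residues), not the masked pseudo-perplexity.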

In [None]:
import os
import torch
import torch.nn.functional as F
from transformers import EsmForMaskedLM, EsmTokenizer
import math

OUTPUT_PATH = "../data/intermediate/esm2_output.pkl"
MODEL_NAME = "facebook/esm2_t33_650M_UR50D"

# Prefer the GPU when one is present. The model is only used for inference,
# so it is put in eval mode right after loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = EsmTokenizer.from_pretrained(MODEL_NAME)
model = EsmForMaskedLM.from_pretrained(MODEL_NAME).to(device).eval()

def compute_output(seq):
    """
    Compute a mean-pooled ESM-2 embedding and likelihood metrics for one sequence.

    Parameters
    ----------
    seq : str
        Amino-acid sequence, one letter per residue.

    Returns
    -------
    protein_emb : numpy.ndarray
        Last entry of the model's hidden states, averaged over every token
        position. This includes the special tokens added at both ends by
        add_special_tokens=True.
    mean_log_likelihood : float
        Mean log-probability of the observed token at each position,
        excluding the first and last positions (the special tokens).
    perplexity : float
        exp(-mean_log_likelihood).

    Notes
    -----
    The sequence is scored in ONE forward pass without masking, so every
    position sees its own residue. This is not the masked pseudo-perplexity,
    which needs one forward pass per masked position.
    """
    spaced_seq = " ".join(seq)
    # No truncation flag: without max_length, and with no model maximum
    # defined, transformers ignores truncation=True and only emits a warning.
    # Inputs are therefore passed whole.
    inputs = tokenizer(spaced_seq, return_tensors="pt", add_special_tokens=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # Mean-pooled embedding over the token axis (dim=1) of the last hidden state
    last_hidden_state = outputs.hidden_states[-1]
    protein_emb = last_hidden_state.mean(dim=1).squeeze(0).cpu().numpy()

    # Log-probability the model assigns to the token actually present at each
    # position. gather keeps the index on the same device as the logits.
    log_probs = F.log_softmax(outputs.logits, dim=-1)
    input_ids = inputs["input_ids"]
    true_token_log_probs = (
        log_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1).squeeze(0)
    )
    # Skip the first and last positions (special tokens) so only residues count
    valid_log_probs = true_token_log_probs[1:-1]
    mean_log_likelihood = valid_log_probs.mean().item()
    perplexity = math.exp(-mean_log_likelihood)

    return protein_emb, mean_log_likelihood, perplexity

# Embedding every sequence takes hours. Reuse the saved result, but only when
# it was produced by the same model for exactly the same IDs and sequences.
# Pickles written without a model tag are recomputed once to be safe.
PROGRESS_EVERY = 1_000  # print progress every N sequences instead of every row


def _load_matching_esm2_cache(path, expected, model_name):
    """Return the pickled ESM-2 table at `path` if it matches, else None.

    A cache matches when it carries attrs["esm2_model_name"] == model_name and
    its ID/seq columns equal those of `expected`, row for row.
    """
    if not os.path.exists(path):
        return None
    cached = pd.read_pickle(path)  # trusted file: written by this notebook
    if cached.attrs.get("esm2_model_name") != model_name:
        return None
    key_cols = ["ID", "seq"]
    if not set(key_cols).issubset(cached.columns):
        return None
    same_inputs = cached[key_cols].reset_index(drop=True).equals(
        expected[key_cols].reset_index(drop=True)
    )
    return cached if same_inputs else None


cached_sequences = _load_matching_esm2_cache(OUTPUT_PATH, sequences, MODEL_NAME)
if cached_sequences is not None:
    sequences = cached_sequences
    print(f"Loaded cached ESM-2 output from {OUTPUT_PATH}")
else:
    protein_embeddings = []
    mll_values = []
    ppl_values = []

    total = len(sequences)
    for idx, seq in enumerate(sequences["seq"], start=1):
        emb, mll, ppl = compute_output(seq)
        protein_embeddings.append(emb)
        mll_values.append(mll)
        ppl_values.append(ppl)
        if idx % PROGRESS_EVERY == 0 or idx == total:
            print(f"Processed {idx}/{total} sequences ({idx/total*100:.2f}%)", end="\r")

    sequences["esm2_protein_embedding"] = protein_embeddings
    sequences["esm2_mean_log_likelihood"] = mll_values
    sequences["esm2_perplexity"] = ppl_values
    # Tag the result so a later run can tell which model produced it.
    sequences.attrs["esm2_model_name"] = MODEL_NAME

    os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
    sequences.to_pickle(OUTPUT_PATH)
    print(f"\nAll sequences processed and saved to {OUTPUT_PATH}")


  from .autonotebook import tqdm as notebook_tqdm
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Processed 691456/691456 sequences (100.00%)
All sequences processed and saved to ../data/intermediate/esm2_output.pkl


## Merge ESM-2 embeddings and perplexity metrics for wild-type and mutant sequences

This step merges the ESM-2 outputs (embeddings, mean log-likelihood and perplexity) from the shared `sequences` dataframe into `variant_set`. Each variant has a wild-type (`wt_ID`) and a mutant (`mut_ID`) identifier, which are used to join the corresponding metrics.

In [6]:
# One row per sequence ID (first occurrence wins), so each merge is many-to-one
esm2_data = sequences[
    ["ID", "esm2_protein_embedding", "esm2_mean_log_likelihood", "esm2_perplexity"]
].drop_duplicates(subset=["ID"])

# For each side, rename the shared metric columns with a side prefix and join
# on that side's ID column (wt_ID for wild-type, mut_ID for mutant).
side_columns = {
    "wt_ID": {
        "esm2_protein_embedding": "wt_protein_embedding",
        "esm2_mean_log_likelihood": "wt_esm2_mean_log_likelihood",
        "esm2_perplexity": "wt_esm2_perplexity",
    },
    "mut_ID": {
        "esm2_protein_embedding": "mut_protein_embedding",
        "esm2_mean_log_likelihood": "mut_esm2_mean_log_likelihood",
        "esm2_perplexity": "mut_esm2_perplexity",
    },
}
for key_col, metric_names in side_columns.items():
    side_metrics = esm2_data.rename(columns={"ID": key_col, **metric_names})
    variant_set = variant_set.merge(side_metrics, on=key_col, how="left")

## Compute differences in perplexity and mean log-likelihood (mutant - wild-type)

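This section calculates the change in ESM-2 metrics between the mutant and wild-type protein sequences (mutant − wild-type). A negative log-likelihood difference means the model finds the mutant window less likely than the wild type. Since perplexity = exp(−mean log-likelihood), this corresponds to a positive perplexity difference. These deltas help assess how mutations affect sequence likelihood.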

In [7]:
# Mutant minus wild-type for each scalar metric; the per-side columns are no
# longer needed once the differences exist.
variant_set = variant_set.assign(
    esm2_mean_log_likelihood_difference=(
        variant_set["mut_esm2_mean_log_likelihood"]
        - variant_set["wt_esm2_mean_log_likelihood"]
    ),
    esm2_perplexity_difference=(
        variant_set["mut_esm2_perplexity"] - variant_set["wt_esm2_perplexity"]
    ),
).drop(columns=[
    "wt_esm2_mean_log_likelihood", "wt_esm2_perplexity",
    "mut_esm2_mean_log_likelihood", "mut_esm2_perplexity",
])


## Compute cosine similarity and average absolute difference between embeddings

This step compares the high-dimensional protein embeddings from ESM-2 between the wild-type and mutant sequences. Cosine similarity captures directional changes in embedding space, while the average absolute difference reflects the magnitude of overall change.

In [8]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

def compute_cosine_similarity(wt_embedding, mut_embedding):
    """
    Cosine similarity between a wild-type and a mutant embedding vector.

    Parameters
    ----------
    wt_embedding, mut_embedding : array-like or missing
        1-D embedding vectors; each is reshaped to a single row.

    Returns
    -------
    float
        The cosine similarity, or NaN when either embedding is missing
        (None or a scalar NA/NaN).
    """
    # A left merge leaves NaN (not None) for an unmatched ID. Treat any scalar
    # missing value as "no embedding" instead of passing it to sklearn, which
    # would raise on it. np.ndim guards pd.isna from receiving a whole vector.
    if any(np.ndim(emb) == 0 and pd.isna(emb) for emb in (wt_embedding, mut_embedding)):
        return np.nan
    wt_embedding = np.array(wt_embedding).reshape(1, -1)
    mut_embedding = np.array(mut_embedding).reshape(1, -1)
    return cosine_similarity(wt_embedding, mut_embedding)[0][0]

def compute_average_absolute_difference(wt_embedding, mut_embedding):
    """
    Mean absolute element-wise difference between two embedding vectors.

    Parameters
    ----------
    wt_embedding, mut_embedding : array-like or missing
        Embedding vectors of equal length.

    Returns
    -------
    float
        mean(|wt - mut|), or NaN when either embedding is missing
        (None or a scalar NA/NaN).
    """
    # A left merge leaves NaN (not None) for an unmatched ID. Without this
    # check a NaN would silently broadcast against the other vector.
    if any(np.ndim(emb) == 0 and pd.isna(emb) for emb in (wt_embedding, mut_embedding)):
        return np.nan
    wt_embedding = np.array(wt_embedding)
    mut_embedding = np.array(mut_embedding)
    return np.mean(np.abs(wt_embedding - mut_embedding))

# Pair each variant's wild-type and mutant embeddings and score them both ways
variant_set["esm2_cosine_similarity"] = [
    compute_cosine_similarity(wt_emb, mut_emb)
    for wt_emb, mut_emb in zip(variant_set["wt_protein_embedding"], variant_set["mut_protein_embedding"])
]
variant_set["esm2_avg_abs_difference"] = [
    compute_average_absolute_difference(wt_emb, mut_emb)
    for wt_emb, mut_emb in zip(variant_set["wt_protein_embedding"], variant_set["mut_protein_embedding"])
]

# The raw vectors are large and no longer needed once the scores exist
variant_set = variant_set.drop(
    columns=["wt_protein_embedding", "mut_protein_embedding"]
).drop_duplicates()


In [9]:
# Keep the variant ID and the four comparison metrics. De-duplicate AFTER
# selecting them: rows that differ only in dropped columns (e.g. ensg, wt_ID,
# mut_ID) would otherwise survive and be written as identical lines.
# NOTE(review): a variant mapped to several genes/positions can still appear
# more than once with different metric values. Confirm downstream expects that.
output_columns = [
    "ID",
    "esm2_mean_log_likelihood_difference",
    "esm2_perplexity_difference",
    "esm2_cosine_similarity",
    "esm2_avg_abs_difference",
]
variant_set = (
    variant_set
    .rename(columns={"variant_ID": "ID"})
    .loc[:, output_columns]
    .drop_duplicates()
)

output_file = "../data/intermediate/esm2_wt_mut_comparison_metrics.txt"
os.makedirs(os.path.dirname(output_file), exist_ok=True)
variant_set.to_csv(output_file, sep="\t", index=False)

In [10]:
# Preview the final per-variant table (ID plus the ESM-2 comparison metrics)
variant_set

Unnamed: 0,ID,esm2_mean_log_likelihood_difference,esm2_perplexity_difference,esm2_cosine_similarity,esm2_avg_abs_difference
0,10-100057090-C-T,-0.001133,0.001560,0.999975,0.001065
1,10-100069757-C-T,0.000762,-0.001011,0.999961,0.001273
2,10-100076062-G-A,0.006985,-0.008888,0.999914,0.001779
3,10-100081405-G-A,-0.009512,0.012117,0.999850,0.002473
4,10-100152307-T-C,-0.002086,0.002513,0.999942,0.001543
...,...,...,...,...,...
408702,19-48318735-G-A,-0.022139,0.031477,0.999893,0.002803
408703,19-48318737-A-G,-0.008040,0.011297,0.999868,0.003154
408704,19-48318744-T-G,-0.002208,0.003130,0.999860,0.003399
408705,19-48318746-C-G,-0.000361,0.000516,0.999955,0.001909
