# Reproducing zero-shot results of ESM
- mutagenesis: BLAT_ECOLX_Ranganathan2015
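- model: ESM-1v (seeds 1 to 5); note that the scoring cell below currently runs only ESM-2 (`esm2_t30_150M_UR50D`), with the five ESM-1v seeds commented out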

## Download and load data

In [16]:
# `wildtype_esm` is the sequence fed to the model; `wildtype_full` is the
# full-length sequence whose 1-based coordinates the mutant strings in
# raw_df.csv use (positions are shifted by `offset` when loading the data).
wildtype_esm = "HPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW"
wildtype_full = "MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW"

# Sequence to score. Set to `wildtype_full` to score the full-length protein;
# `offset` below then becomes 0 automatically.
wildtype = wildtype_esm

# Number of leading residues of `wildtype_full` that are absent from `wildtype`.
offset = wildtype_full.find(wildtype)
# str.find returns -1 when there is no match, which would silently shift every
# mutant position by one instead of failing.
assert offset >= 0, "wildtype is not a substring of wildtype_full"
print(offset)

23


In [19]:
import os
import urllib.request
import csv

raw_df_path = "data/raw_df.csv"

if not os.path.exists(raw_df_path):
    print("Downloading raw_df.csv...")
    # urlretrieve does not create missing parent directories.
    os.makedirs(os.path.dirname(raw_df_path), exist_ok=True)
    # Download under a temporary name and rename afterwards, so an interrupted
    # download cannot leave a truncated file that the check above would accept.
    partial_path = raw_df_path + ".part"
    urllib.request.urlretrieve("https://dl.fbaipublicfiles.com/fair-esm/examples/variant-prediction/data/raw_df.csv", partial_path)
    os.replace(partial_path, raw_df_path)

# mutants[k] is a list of (0-based position in `wildtype`, mutant amino acid);
# target_values[k] is the measured value ("gt" column) for the same variant.
mutants = []
target_values = []

# The csv module expects files opened with newline="".
with open(raw_df_path, "r", newline="") as fd:
    reader = csv.DictReader(fd)
    for row in reader:
        if row["protein_name"] != "BLAT_ECOLX_Ranganathan2015":
            continue
        variant = []
        # Multi-mutants are ':'-separated, each like "<wt><1-based pos><mt>".
        for mutation in row["mutant"].split(":"):
            original = mutation[0]
            pos = int(mutation[1:-1]) - 1 - offset
            amino_acid = mutation[-1]
            # A negative index would silently wrap around to the end of the
            # sequence, so check bounds before comparing residues.
            assert 0 <= pos < len(wildtype), f"Mutation {mutation} lies outside the scored wildtype (index {pos})"
            assert wildtype[pos] == original, f"Wildtype amino acid at position {pos} is {wildtype[pos]}, not {original}"
            variant.append((pos, amino_acid))
        mutants.append(variant)
        target_values.append(float(row["gt"]))

print(f"Loaded {len(mutants)} mutants")

Loaded 4788 mutants


## Computing zero-shot scores

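- Load the foundation model(s)
- Compute the zero-shot score with the masked-marginal strategy (strategy (a), Eq. (1)): mask every mutated position $i \in M$ in the wild-type sequence and sum $\log p(x_i = x_i^{\text{mt}} \mid x_{\setminus M}) - \log p(x_i = x_i^{\text{wt}} \mid x_{\setminus M})$ over $i \in M$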

In [21]:
import esm
import torch
from tqdm.auto import tqdm

# NOTE: the cache file name is fixed. Delete it (or change this path) after
# editing `ensemble` below, otherwise stale scores are loaded.
scores_path = "data/scores-esm2_t30.pt"

try:
    scores = torch.load(scores_path, map_location="cpu")
except FileNotFoundError:
    ensemble = [
        # esm.pretrained.esm1v_t33_650M_UR90S_1,
        # esm.pretrained.esm1v_t33_650M_UR90S_2,
        # esm.pretrained.esm1v_t33_650M_UR90S_3,
        # esm.pretrained.esm1v_t33_650M_UR90S_4,
        # esm.pretrained.esm1v_t33_650M_UR90S_5,
        esm.pretrained.esm2_t30_150M_UR50D,
    ]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # scores[j, i]: masked-marginal score of mutant j under ensemble member i.
    scores = torch.zeros(len(mutants), len(ensemble))
    with torch.no_grad():
        for i, model_factory in enumerate(ensemble):
            model, alphabet = model_factory()
            # Inference mode (disables dropout), as in fair-esm's documented usage.
            model.eval()
            # fp16 is only reliably supported/fast on GPU; keep fp32 on CPU.
            if device.type == "cuda":
                model.half()
            model.to(device)
            batch_converter = alphabet.get_batch_converter()

            # Tokenize wildtype sequence
            _, _, wt_tokens = batch_converter([("", wildtype)])

            for j in tqdm(range(len(mutants)), desc=f"Model {i + 1}"):
                # Mask wildtype sequence at all mutated positions
                tokens = wt_tokens.clone()
                for (pos, _) in mutants[j]:
                    tokens[:, pos + 1] = alphabet.mask_idx  # offset by 1 due to CLS token

                # Upcast before log_softmax so the normalisation is not done in fp16.
                logits = model(tokens.to(device))["logits"].float()
                # Drop the leading CLS and trailing EOS positions.
                log_probs = torch.log_softmax(logits, dim=-1)[0, 1:-1].cpu()

                # Compute masked marginal score: sum of log p(mutant) - log p(wildtype)
                score = torch.tensor(0.0)
                for (pos, aa) in mutants[j]:
                    score += (
                        log_probs[pos, alphabet.get_idx(aa)] -
                        log_probs[pos, alphabet.get_idx(wildtype[pos])]
                    )
                scores[j, i] = score
            del model  # Free memory
            if device.type == "cuda":
                torch.cuda.empty_cache()
    torch.save(scores, scores_path)

Model 1:   0%|          | 0/4788 [00:00<?, ?it/s]

## Evaluation
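- Absolute Spearman rank correlation, `abs(spearmanr(target_values, scores))`, between the measured values (`gt`) and the zero-shot scores, reported for each model and for the ensemble mean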

In [22]:
from scipy.stats import spearmanr

# Report |Spearman rho| between measured values and zero-shot scores,
# first for each ensemble member (one column of `scores` each) ...
n_models = scores.shape[-1]
for model_number in range(1, n_models + 1):
    rho = spearmanr(target_values, scores[:, model_number - 1]).correlation
    print(f"Model {model_number}: {abs(rho):.3f}")

# ... then for the ensemble, i.e. the mean score across members.
ensemble_scores = scores.mean(dim=-1)
correlation = abs(spearmanr(target_values, ensemble_scores).correlation)
print(f"Ensemble: {correlation:.3f}")

Model 1: 0.683
Ensemble: 0.683
