# Zero-shot mutant prediction with Prime

This tutorial demonstrates how to predict the effects of protein mutations with a pretrained Prime model, without any task-specific training (zero-shot).

In this example, we load the base Prime model (not tuned on homologous sequences) and predict the effect of each mutation in the PAI1_HUMAN_Huttinger_2021 experiment, which is the dataset the code below reads.

We provide:

- The wild-type sequence, as a FASTA file.
- The mutant list, as a CSV file.

Goal: obtain a predicted score for each mutant.


## Import the necessary modules

In [2]:
from transformers import AutoTokenizer, AutoModel
import torch
import pandas as pd
from Bio import SeqIO
from tqdm.notebook import tqdm

def read_seq(seq_file):
    for record in SeqIO.parse(seq_file, "fasta"):
        return str(record.seq)

## Read wildtype sequence and mutant list

In [19]:
wild_type = "../proteingym_v1.0_fasta/fasta/PAI1_HUMAN_Huttinger_2021.fasta"
mutant = "../proteingym_v1.0_fasta/mutant/PAI1_HUMAN_Huttinger_2021.csv"

In [20]:
sequence = read_seq(wild_type)
df = pd.read_csv(mutant)

In [21]:
sequence

'MQMSPALTCLVLGLALVFGEGSAVHHPPSYVAHLASDFGVRVFQQVAQASKDRNVVFSPYGVASVLAMLQLTTGGETQQQIQAAMGFKIDDKGMAPALRHLYKELMGPWNKDEISTTDAIFVQRDLKLVQGFMPHFFRLFRSTVKQVDFSEVERARFIINDWVKTHTKGMISNLLGKGAVDQLTRLVLVNALYFNGQWKTPFPDSSTHRRLFHKSDGSTVSVPMMAQTNKFNYTEFTTPDGHYYDILELPYHGDTLSMFIAAPYEKEVPLSALTNILSAQLISHWKGNMTRLPRLLVLPKFSLETEVDLRKPLENLGMTDMFRQFQADFTSLSDQEPLHVAQALQKVKIEVNESGTVASSSTAVIVSARMAPEEIIMDRPFLFVVRHNPTGTVLFMGQVMEP'

In [22]:
df.head()

| Unnamed: 0 | mutant | score |
| --- | --- | --- |
| 0 | A119D | 0.157541 |
| 1 | A119E | -0.992784 |
| 2 | A119G | -0.28124 |
| 3 | A119K | -0.652502 |
| 4 | A119L | -0.615973 |


The 'score' column holds the measured score of each mutant. It is used only for evaluation. In practice it is usually not available, and the 'mutant' column alone is enough to run the prediction.
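
Each entry in the 'mutant' column follows the pattern used by the scoring loop later in this tutorial: wild-type residue, 1-based position, mutant residue, with multiple substitutions joined by ":". The sketch below shows one way to parse such strings and check them against the wild-type sequence before scoring. The helper name `parse_mutant` and the toy sequence are hypothetical and are not part of Prime or ProteinGym.

```python
def parse_mutant(mutant, sequence):
    """Split a mutant string such as "A2D:K4E" into (wt, zero_based_idx, mt) tuples.

    Hypothetical helper for illustration only. Raises ValueError when a position
    is out of range or the stated wild-type residue does not match the sequence.
    """
    substitutions = []
    for sub in mutant.split(":"):
        wt, pos, mt = sub[0], int(sub[1:-1]), sub[-1]
        idx = pos - 1  # positions in the mutant string are 1-based
        if not 0 <= idx < len(sequence):
            raise ValueError(f"{sub}: position {pos} is outside the sequence")
        if sequence[idx] != wt:
            raise ValueError(
                f"{sub}: expected {wt} at position {pos}, found {sequence[idx]}"
            )
        substitutions.append((wt, idx, mt))
    return substitutions


toy_sequence = "MAGKL"  # toy sequence, not a real protein
print(parse_mutant("A2D:K4E", toy_sequence))
```

Running a check like this over the whole 'mutant' column can catch an off-by-one offset between the CSV and the FASTA file before any model call is made.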

## Load model and tokenizer

In [7]:
model_path = "AI4Protein/Prime_690M"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.eval()
model = model.to(device)



### Compute the logits of the wild type sequence

In [23]:
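with torch.no_grad():
    tokenized_results = tokenizer(sequence, return_tensors="pt")
    input_ids = tokenized_results.input_ids.to(device)
    attention_mask = tokenized_results.attention_mask.to(device)
    # Drop the first and last positions (assumed to be the tokenizer's special tokens)
    # and convert the logits to per-position log-probabilities.
    logits = model(input_ids, attention_mask=attention_mask).logits[0, 1:-1, :].log_softmax(dim=-1)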
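
The last line of the cell above does two things: the slice `[1:-1]` discards the first and last output positions, and `log_softmax(dim=-1)` normalizes each remaining row into log-probabilities over the vocabulary. The pure-Python sketch below reproduces both steps on made-up numbers so the shapes are easy to follow. It assumes, as the slice suggests, that the tokenizer adds one special token at each end; the toy logits are invented for illustration.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


# Toy logits: 4 positions (start token, 2 residues, end token) x 3 vocabulary entries.
toy_logits = [
    [0.0, 0.0, 0.0],  # start special token (assumed)
    [2.0, 1.0, 0.0],  # residue 1
    [0.5, 0.5, 3.0],  # residue 2
    [0.0, 0.0, 0.0],  # end special token (assumed)
]

# Keep only the residue positions, then normalize each row.
residue_log_probs = [log_softmax(row) for row in toy_logits[1:-1]]

for row in residue_log_probs:
    total_prob = sum(math.exp(v) for v in row)
    print([round(v, 3) for v in row], "probabilities sum to", round(total_prob, 6))
```

After the slice, row `i` corresponds to residue `i + 1` of the sequence, which is why the scoring step can index the result with a zero-based residue position.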

### Score mutants

In [25]:
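vocab = tokenizer.get_vocab()  # build the token-to-index mapping once, outside the loop
scores = []
for mutant in tqdm(df["mutant"]):
    score = 0
    # Multi-site mutants are joined by ":"; their per-site scores are summed.
    for sub_mutant in mutant.split(":"):
        # Format: wild-type residue, 1-based position, mutant residue (e.g. A119D)
        wt, idx, mt = sub_mutant[0], int(sub_mutant[1:-1]) - 1, sub_mutant[-1]
        # Log-probability of the mutant residue minus that of the wild-type residue
        score += (logits[idx, vocab[mt]] - logits[idx, vocab[wt]]).item()
    scores.append(score)
df["predict_score"] = scores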
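
The loop above scores each mutant as a sum of log-probability differences: for every substitution, the log-probability of the mutant residue minus that of the wild-type residue at the same position, all taken from the single forward pass on the wild-type sequence. The sketch below applies the same arithmetic to a made-up table so the computation can be checked by hand. The function name `score_mutant`, the three-letter vocabulary, and all numbers are hypothetical.

```python
# Toy log-probabilities for a 3-residue sequence over a 3-letter alphabet.
# All numbers are invented for illustration only.
toy_vocab = {"A": 0, "G": 1, "K": 2}
toy_log_probs = [
    [-0.2, -2.0, -3.0],  # position 1, wild type A
    [-1.5, -0.4, -2.5],  # position 2, wild type G
    [-3.0, -2.0, -0.1],  # position 3, wild type K
]


def score_mutant(mutant, log_probs, vocab):
    """Sum of (log p(mutant residue) - log p(wild-type residue)) over substitutions."""
    total = 0.0
    for sub in mutant.split(":"):
        wt, idx, mt = sub[0], int(sub[1:-1]) - 1, sub[-1]
        total += log_probs[idx][vocab[mt]] - log_probs[idx][vocab[wt]]
    return total


single = score_mutant("A1G", toy_log_probs, toy_vocab)
double = score_mutant("A1G:K3A", toy_log_probs, toy_vocab)
print(single, double)
```

A negative score means the model considers the mutant residue less likely than the wild-type residue at that position. Because multi-site scores are simple sums computed from the wild-type context, this scheme treats substitutions as independent and does not model interactions between them.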


### Check the results

In [26]:
df.head()

| Unnamed: 0 | mutant | score | predict_score |
| --- | --- | --- | --- |
| 0 | A119D | 0.157541 | -9.170906 |
| 1 | A119E | -0.992784 | -9.041076 |
| 2 | A119G | -0.28124 | -4.39185 |
| 3 | A119K | -0.652502 | -4.568374 |
| 4 | A119L | -0.615973 | -7.744344 |


### Evaluation (optional)

In [27]:
from scipy.stats import spearmanr

In [28]:
spearmanr(df["score"], df["predict_score"])

SignificanceResult(statistic=0.46896927252072046, pvalue=1.5783500358946215e-290)
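
Spearman's correlation compares rankings rather than raw values, which is why it is a suitable metric here even though `predict_score` and `score` are on different scales. As a rough illustration of what `spearmanr` measures, the sketch below computes the Pearson correlation of average ranks in plain Python. The function names and toy values are hypothetical, and this is a teaching sketch rather than a replacement for the SciPy implementation.

```python
def average_ranks(values):
    """Return 1-based ranks, giving tied values the average of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j to cover the whole group of tied values.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks


def spearman(x, y):
    """Pearson correlation of the average ranks of x and y."""
    rx, ry = average_ranks(x), average_ranks(y)
    n = len(rx)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    var_x = sum((a - mx) ** 2 for a in rx)
    var_y = sum((b - my) ** 2 for b in ry)
    return cov / (var_x * var_y) ** 0.5


measured = [0.1, -0.9, -0.3, -0.6]    # toy measured scores
predicted = [-1.0, -9.0, -4.0, -8.0]  # toy predictions on a different scale
print(spearman(measured, predicted))
```

In the toy data the two lists have identical orderings, so the correlation is perfect despite the different scales. The value of about 0.47 reported above indicates that the zero-shot predictions recover part, but not all, of the measured ranking.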