# Distilled ProSST Scoring Notebook

This notebook demonstrates how to:
1. Load a protein sequence from a FASTA file.
2. Quantize a PDB structure.
3. Run ProSST inference.
4. Score mutational effects from a CSV of mutants.

Adjust the file paths in **Step 2** to your own sequence, PDB, and CSV.

In [1]:
# Step 1: Imports
import torch
import pandas as pd
from Bio import SeqIO
from scipy.stats import spearmanr
from transformers import AutoModelForMaskedLM, AutoTokenizer
from prosst.structure.quantizer import PdbQuantizer




In [None]:
# Step 2: Define file paths
# Adjust these to point to your own files.
FASTA_FILE = 'example_data/GRB2_HUMAN_Faure_2021.fasta'  # Your sequence file
PDB_FILE = 'example_data/GRB2_HUMAN_Faure_2021.pdb'       # Your PDB file
MUT_CSV = 'example_data/GRB2_HUMAN_Faure_2021.csv'       # CSV file with mutants

### Step 3: Load ProSST model & tokenizer
If you are behind a corporate firewall or in a region that cannot access Hugging Face, you may need to configure a proxy.
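
As a minimal sketch of one common way to configure a proxy, the standard proxy environment variables can be set before the first download. The helper name `configure_proxy` and the proxy address are hypothetical placeholders, not part of ProSST or Hugging Face; whether your HTTP stack honors these variables depends on your environment.

```python
import os


def configure_proxy(proxy_url: str) -> None:
    """Point the standard proxy environment variables at proxy_url.

    Hypothetical helper for illustration only.
    """
    # Both upper- and lower-case spellings are set because different
    # HTTP clients look for different variants.
    for name in ("HTTP_PROXY", "HTTPS_PROXY", "http_proxy", "https_proxy"):
        os.environ[name] = proxy_url


# Placeholder address (assumption): replace with your own proxy.
configure_proxy("http://127.0.0.1:7890")
print(os.environ["HTTPS_PROXY"])
```

Setting the variables as early as possible, ideally before importing `transformers`, avoids depending on when a given client reads them.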

In [None]:
model = AutoModelForMaskedLM.from_pretrained("AI4Protein/ProSST-2048", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("AI4Protein/ProSST-2048", trust_remote_code=True)

processor = PdbQuantizer()  # For converting PDB to quantized structural tokens

### Step 4: Read the protein sequence and quantize the PDB
We will:
1. Read the sequence from FASTA.
2. Quantize the structure from PDB.
3. Offset the structure tokens by `+3` so they do not collide with the IDs reserved for the special tokens (`[CLS]`, `[SEP]`, `[PAD]`).
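
The offset and a basic length check can be sketched on toy data as follows. The structure codes are made up, the helper names are hypothetical, and the assumption that IDs 0 to 2 are the reserved ones (with `[PAD] = 0`) is inferred from the `[CLS] = 1` and `[SEP] = 2` values used later in this notebook.

```python
def offset_structure_tokens(tokens, n_special=3):
    """Shift quantized structure codes so IDs 0..n_special-1 stay reserved."""
    return [t + n_special for t in tokens]


def check_lengths(residue_sequence, structure_tokens):
    """Raise if the sequence and the structure disagree on residue count."""
    if len(residue_sequence) != len(structure_tokens):
        raise ValueError(
            f"Sequence has {len(residue_sequence)} residues but the structure "
            f"has {len(structure_tokens)} tokens; check that the FASTA and PDB match."
        )


# Toy data (assumption): a 4-residue sequence and made-up structure codes.
toy_sequence = "MEAI"
toy_structure = [0, 17, 2047, 5]

check_lengths(toy_sequence, toy_structure)
toy_offset = offset_structure_tokens(toy_structure)
print(toy_offset)  # [3, 20, 2050, 8]
```

A mismatch here usually means the PDB covers a different residue range than the FASTA sequence, which would silently misalign the scores in later steps.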

In [None]:
# Read sequence
residue_sequence = str(SeqIO.read(FASTA_FILE, 'fasta').seq)

# Quantize structure
structure_sequence = processor(PDB_FILE)

# Offset structure codes for special tokens
structure_sequence_offset = [s + 3 for s in structure_sequence]

print(f"Residue sequence length: {len(residue_sequence)}")
print(f"Quantized structure length: {len(structure_sequence)}")

### Step 5: Prepare model input tensors
We'll tokenize the residue sequence, then build the input IDs for the structure (including `[CLS] = 1` and `[SEP] = 2`).
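
A small sketch of the same construction with plain Python lists, assuming the tokenizer adds exactly one special token on each side of the residue sequence. The helper names and the toy token IDs are hypothetical; only `[CLS] = 1` and `[SEP] = 2` come from this notebook.

```python
CLS_ID, SEP_ID = 1, 2  # special-token IDs stated in this notebook


def build_structure_input(offset_tokens):
    """Wrap already-offset structure tokens as [CLS] + tokens + [SEP]."""
    return [CLS_ID, *offset_tokens, SEP_ID]


def assert_aligned(residue_input_ids, structure_input_ids):
    """Both input streams must have the same length, position by position."""
    if len(residue_input_ids) != len(structure_input_ids):
        raise ValueError(
            f"Residue input has {len(residue_input_ids)} tokens, "
            f"structure input has {len(structure_input_ids)}."
        )


# Toy data (assumption): a 3-residue protein with made-up token IDs.
toy_residue_ids = [1, 10, 11, 12, 2]
toy_structure_ids = build_structure_input([3, 20, 8])

assert_aligned(toy_residue_ids, toy_structure_ids)
print(toy_structure_ids)  # [1, 3, 20, 8, 2]
```

In the notebook itself, the equivalent check is that the two shapes printed by the cell below are identical.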

In [None]:
tokenized_res = tokenizer([residue_sequence], return_tensors='pt')
input_ids = tokenized_res['input_ids']
attention_mask = tokenized_res['attention_mask']

# Build structure input: [CLS] + structure_sequence + [SEP]
structure_input_ids = torch.tensor(
    [1, *structure_sequence_offset, 2],  # 1 = [CLS], 2 = [SEP]
    dtype=torch.long
).unsqueeze(0)

print("Sequence input size:", input_ids.shape)
print("Structure input size:", structure_input_ids.shape)

### Step 6: Run inference on the sequence
We run a single forward pass without gradient tracking to obtain the logits: one score per vocabulary entry at each position.
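
The cell below also converts the logits to log-probabilities and drops the two special-token positions. To illustrate what that conversion does, here is a pure-Python sketch on a made-up logit table; it is an illustration of log-softmax, not the `torch` implementation.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_norm = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_norm for x in row]


# Toy logits (assumption): 4 positions ([CLS], two residues, [SEP])
# over a 3-entry vocabulary.
toy_logits = [
    [0.0, 0.0, 0.0],
    [2.0, 1.0, 0.1],
    [0.5, 0.5, 3.0],
    [0.0, 0.0, 0.0],
]

# Keep residue positions only, mirroring the [1:-1] slice.
residue_logits = toy_logits[1:-1]
log_probs = [log_softmax(row) for row in residue_logits]

for row in log_probs:
    # Exponentiated log-probabilities form a distribution at each position.
    print(sum(math.exp(v) for v in row))
```

Because each row is normalized independently, differences between two entries in the same row are log-odds, which is what the scoring step relies on.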

In [None]:
with torch.no_grad():
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        ss_input_ids=structure_input_ids
    )

# Convert logits to log-probabilities.
# outputs.logits has shape [batch, seq_len, vocab_size].
# The model input carries one extra token on each side, so slice [1:-1]
# to keep residue positions only; squeeze(0) drops just the batch dimension.
logits = torch.log_softmax(outputs.logits[:, 1:-1], dim=-1).squeeze(0)  # log-probabilities, despite the name
print("Log-probability matrix shape (positions x vocabulary):", logits.shape)

### Step 7: Score your mutants
The CSV file should have a `mutant` column describing each variant, e.g. `A100V` (a single substitution) or `A100V:M101T` (several substitutions separated by `:`). Positions are 1-based.
For each substitution, we subtract the log probability of the wild-type amino acid from that of the mutant amino acid at the same position, and sum the differences when a variant contains several substitutions.
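
The parsing and scoring logic can be sketched on toy data as follows. This is a simplified illustration: the log-probability table is a list of dictionaries keyed by amino-acid letter with made-up values, rather than the tensor indexed through the tokenizer vocabulary, and the helper names are hypothetical. It also checks that the stated wild-type residue matches the sequence, which the notebook cell does not do.

```python
import re

# One substitution: wild-type letter, 1-based position, mutant letter.
MUTATION_RE = re.compile(r"^([A-Z])(\d+)([A-Z])$")


def parse_mutant(mutant):
    """Split 'A100V:M101T' into (wt, zero_based_index, mt) tuples."""
    substitutions = []
    for sub in mutant.split(":"):
        match = MUTATION_RE.match(sub.strip())
        if match is None:
            raise ValueError(f"Cannot parse mutation {sub!r}")
        wt, pos, mt = match.groups()
        substitutions.append((wt, int(pos) - 1, mt))
    return substitutions


def score_mutant(mutant, sequence, log_probs):
    """Sum logP(mutant) - logP(wild type) over all substitutions."""
    total = 0.0
    for wt, idx, mt in parse_mutant(mutant):
        if sequence[idx] != wt:
            raise ValueError(
                f"{mutant}: expected {wt} at position {idx + 1}, "
                f"found {sequence[idx]}"
            )
        total += log_probs[idx][mt] - log_probs[idx][wt]
    return total


# Toy data (assumption): 2-residue sequence, made-up log-probabilities.
toy_sequence = "AM"
toy_log_probs = [
    {"A": -0.5, "V": -1.5, "M": -3.0, "T": -4.0},
    {"A": -2.0, "V": -2.5, "M": -0.25, "T": -1.25},
]

single = score_mutant("A1V", toy_sequence, toy_log_probs)
double = score_mutant("A1V:M2T", toy_sequence, toy_log_probs)
print(single, double)  # -1.0 -2.0
```

A wild-type mismatch typically signals an off-by-one numbering difference between the CSV and the FASTA sequence, so failing loudly is safer than scoring the wrong position.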

In [None]:
df = pd.read_csv(MUT_CSV)
mutants = df['mutant'].tolist()

vocab = tokenizer.get_vocab()
pred_scores = []

for mutant in mutants:
    total_mut_score = 0.0
    # Handle compound mutants like "A100V:M101T"
    for sub_mutant in mutant.split(":"):
        wt_aa = sub_mutant[0]
        pos   = int(sub_mutant[1:-1]) - 1  # zero-based index
        mt_aa = sub_mutant[-1]
        
        # Score = logP(mutant) - logP(wt)
        delta = logits[pos, vocab[mt_aa]] - logits[pos, vocab[wt_aa]]
        total_mut_score += delta.item()
    pred_scores.append(total_mut_score)

df['PredictedScore'] = pred_scores
df.head()

### (Optional) Step 8: Compute Spearman correlation
If your CSV contains an experimental `DMS_score` column (or another numeric measurement), you can compute its Spearman rank correlation with the predicted scores.
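
To make explicit what the Spearman correlation measures, here is a hand-rolled pure-Python sketch: it ranks both lists (averaging ranks for ties) and computes the Pearson correlation of the ranks. This is an illustration only, with made-up numbers and hypothetical helper names; the notebook itself uses `scipy.stats.spearmanr`.

```python
def average_ranks(values):
    """1-based ranks, with tied values sharing their average rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with position i.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1  # mean of the 1-based ranks i+1 .. j+1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks


def pearson(x, y):
    """Pearson correlation of two equal-length lists."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = sum((a - mx) ** 2 for a in x) ** 0.5
    sy = sum((b - my) ** 2 for b in y) ** 0.5
    return cov / (sx * sy)


def spearman(x, y):
    """Spearman correlation as the Pearson correlation of the ranks."""
    return pearson(average_ranks(x), average_ranks(y))


# Toy data (assumption): predictions and measurements in the same order.
predicted = [-2.0, -0.5, -1.0, 0.3]
measured = [0.1, 0.6, 0.4, 0.9]
rho = spearman(predicted, measured)
print(rho)  # close to 1.0: the two lists rank the variants identically
```

Only the ordering matters, so the predicted scores do not need to be on the same scale as the experimental measurements.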

In [None]:
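if 'DMS_score' in df.columns:
    # spearmanr returns the correlation coefficient and its p-value
    rho, p_value = spearmanr(df['PredictedScore'], df['DMS_score'])
    print(f"Spearman correlation: {rho:.4f} (p-value: {p_value:.3g})")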
else:
    print("No 'DMS_score' column found; skipping correlation calculation.")