In [1]:
import sys
sys.path.append('..')
from transformers import AutoModelForMaskedLM, AutoTokenizer
from prosst.structure.quantizer import PdbQuantizer
from Bio import SeqIO
import torch
import pandas as pd
from scipy.stats import spearmanr

In [2]:
import os

# Send both HTTP and HTTPS traffic through the same proxy for the rest of the
# kernel session (the model download below goes over the network).
# NOTE(review): 127.0.0.1 is a loopback address, so this only works when a
# proxy is actually listening on port 15777 of the local machine — adjust the
# address or skip this cell otherwise.
local_proxy = "http://127.0.0.1:15777"
os.environ.update({"http_proxy": local_proxy, "https_proxy": local_proxy})

Load ProSST from Hugging Face.
(You may need to configure proxy settings if you are in a region that cannot access the Hugging Face Hub.)

In [3]:
# The masked-LM weights and the tokenizer are fetched from the same Hub
# repository. trust_remote_code=True allows transformers to run model code
# shipped in that repository, so only use it with repositories you trust.
hub_repo = "AI4Protein/ProSST-2048"
tokenizer = AutoTokenizer.from_pretrained(hub_repo, trust_remote_code=True)
deprot = AutoModelForMaskedLM.from_pretrained(hub_repo, trust_remote_code=True)

Load the structure quantizer

In [4]:
# Build the structure quantizer with its default settings; it is called on a
# PDB file path further down to turn the structure into integer tokens.
processor = PdbQuantizer()

Read protein sequence

In [5]:
# SeqIO.read expects the FASTA file to hold exactly one record; only its
# sequence is kept, converted to a plain Python string.
fasta_record = SeqIO.read('example_data/GRB2_HUMAN_Faure_2021.fasta', 'fasta')
residue_sequence = str(fasta_record.seq)
    

Quantize the structure

In [6]:
# Quantize the PDB structure. The result is iterated as integers in the next
# cell (each element gets an offset added) — presumably one structure token
# per residue, in sequence order; TODO confirm against PdbQuantizer.
structure_sequence = processor("example_data/GRB2_HUMAN_Faure_2021.pdb")

Shift the quantized structure sequence by 3 to make room for the three special tokens ([CLS], [SEP] and [PAD])

In [7]:
# Three ids are set aside for the special tokens ([CLS], [SEP], [PAD]), so
# every quantized structure token is moved up by that amount.
special_token_count = 3
structure_sequence_offset = [token + special_token_count for token in structure_sequence]

Prepare model input

In [8]:
# Tokenize the amino-acid sequence as a batch of one; the tokenizer output
# supplies both the token ids and the attention mask as PyTorch tensors.
tokenized_res = tokenizer([residue_sequence], return_tensors='pt')
input_ids = tokenized_res['input_ids']
attention_mask = tokenized_res['attention_mask']

# Wrap the shifted structure tokens with ids 1 and 2 (presumably the
# structure-side [CLS] and [SEP] ids — TODO confirm) and add a batch axis.
structure_input_ids = torch.tensor([1, *structure_sequence_offset, 2], dtype=torch.long).unsqueeze(0)

# The residue tokens and the structure tokens are fed to the model side by
# side, so they must line up position by position. A PDB file with missing or
# extra residues relative to the FASTA sequence would otherwise fail obscurely
# inside the model or silently misalign the scores.
assert input_ids.shape == structure_input_ids.shape, (
    f"residue tokens {tuple(input_ids.shape)} and structure tokens "
    f"{tuple(structure_input_ids.shape)} are not aligned; check that the "
    "FASTA and PDB files cover the same residues"
)

Inference

In [9]:
# Single forward pass; gradients are not needed for zero-shot scoring.
with torch.no_grad():
    outputs = deprot(
        input_ids=input_ids,
        attention_mask=attention_mask,
        ss_input_ids=structure_input_ids
    )

# Drop the first and last positions (the special tokens wrapping the sequence)
# so that row i corresponds to residue i, then normalise over the vocabulary.
# Despite its name, `logits` therefore holds log-probabilities; the name is
# kept because the scoring cell reads it.
# squeeze(0) removes only the batch axis. A bare squeeze() would also collapse
# any other size-1 axis (e.g. a single-residue sequence) and break the 2-D
# [position, token] indexing used for scoring.
logits = torch.log_softmax(outputs.logits[:, 1:-1], dim=-1).squeeze(0)

Score mutants

In [10]:
# Load the mutation table; the `mutant` column is pulled out as a plain list
# for the scoring loop, while `df` is kept for the final correlation.
dms_csv = "example_data/GRB2_HUMAN_Faure_2021.csv"
df = pd.read_csv(dms_csv)
mutants = df["mutant"].to_list()

In [11]:
vocab = tokenizer.get_vocab()


def _substitution_score(substitution):
    """Score one substitution written as <wild type><1-based position><mutant>.

    Returns the log-probability of the mutant residue minus that of the
    wild-type residue at the substituted position, as a Python float.
    """
    wild_type = substitution[0]
    variant = substitution[-1]
    # Positions in the mutant string are 1-based; rows of `logits` are 0-based.
    position = int(substitution[1:-1]) - 1
    delta = logits[position, vocab[variant]] - logits[position, vocab[wild_type]]
    return delta.item()


# A mutant may combine several substitutions separated by ":"; its score is
# the sum of the individual substitution scores, added in the order written.
pred_scores = [
    sum(_substitution_score(substitution) for substitution in mutant.split(":"))
    for mutant in mutants
]

Compute the Spearman correlation

In [12]:
# Rank correlation between the predicted scores and the DMS_score column of
# the mutation table; left as the cell's last expression so the notebook
# displays the result.
spearmanr(pred_scores, df['DMS_score'])

SignificanceResult(statistic=0.6997442598613315, pvalue=0.0)