In [1]:
import sys
sys.path.append('..')
from transformers import AutoModelForMaskedLM, AutoTokenizer
from prosst.structure.get_sst_seq import SSTPredictor
from Bio import SeqIO
import torch
import pandas as pd
from scipy.stats import spearmanr

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os

# Route both HTTP and HTTPS traffic through the same proxy on the loopback address.
os.environ.update(
    dict.fromkeys(("http_proxy", "https_proxy"), "http://127.0.0.1:15777")
)

Load ProSST from Hugging Face. 
(You may need to configure the proxy settings above if you are in a region that cannot access Hugging Face directly.)

In [2]:
# The masked-LM model and its tokenizer are loaded from the same checkpoint id.
# trust_remote_code=True is passed to both calls, as the checkpoint relies on
# code that is not part of the transformers package itself.
PROSST_CHECKPOINT = "AI4Protein/ProSST-2048"
prosst_model = AutoModelForMaskedLM.from_pretrained(
    PROSST_CHECKPOINT,
    trust_remote_code=True,
)
prosst_tokenizer = AutoTokenizer.from_pretrained(
    PROSST_CHECKPOINT,
    trust_remote_code=True,
)

Load the structure quantizer

In [3]:
# Build the structure quantizer used later to turn a PDB file into structure tokens.
# NOTE(review): the 2048 vocabulary size presumably has to match the
# "ProSST-2048" checkpoint used elsewhere in this notebook — confirm.
predictor = SSTPredictor(structure_vocab_size=2048)

---------- Load Model on cuda ----------
MODEL: 5.90M parameters


Read protein sequence

In [4]:
# Parse the FASTA file as a single record, then keep only its sequence as a plain string.
fasta_record = SeqIO.read('example_data/GRB2_HUMAN_Faure_2021.fasta', 'fasta')
residue_sequence = str(fasta_record.seq)
    

Quantize the structure

In [5]:
# Quantize the PDB structure; the call's result is indexed with [0] to get the
# entry for this file, whose '2048_sst_seq' field holds the structure tokens.
pdb_predictions = predictor.predict_from_pdb("example_data/GRB2_HUMAN_Faure_2021.pdb")
structure_sequence = pdb_predictions[0]['2048_sst_seq']

---------- Building Subgraphs ----------


100%|██████████| 1/1 [00:00<00:00,  1.21it/s]
100%|██████████| 1/1 [00:01<00:00,  1.24s/it]


Shift the quantized structure sequence by 3 to make room for the 3 special tokens ([CLS], [SEP] and [PAD])

In [6]:
# Every structure token id is moved up by a fixed offset so that the lowest ids
# stay free for the special tokens.
NUM_SPECIAL_TOKENS = 3
structure_sequence_offset = [
    token_id + NUM_SPECIAL_TOKENS for token_id in structure_sequence
]

Prepare model input

In [7]:
# Tokenize the residue sequence as a batch of one and get PyTorch tensors back.
tokenized_res = prosst_tokenizer([residue_sequence], return_tensors='pt')
input_ids = tokenized_res['input_ids']
attention_mask = tokenized_res['attention_mask']

# Surround the shifted structure tokens with boundary ids and add a batch axis.
# NOTE(review): 1 and 2 are presumably the [CLS] and [SEP] ids of the structure
# vocabulary — confirm against the model's configuration.
structure_input_ids = torch.tensor([1, *structure_sequence_offset, 2], dtype=torch.long).unsqueeze(0)

# The residue track and the structure track are passed to the model side by side,
# so they must line up token for token. A mismatch (e.g. a PDB file that lacks
# some residues of the FASTA sequence) is reported here with a clear message
# instead of surfacing later as an obscure shape error or misaligned scores.
if structure_input_ids.shape != input_ids.shape:
    raise ValueError(
        f"Residue tokens have shape {tuple(input_ids.shape)} but structure tokens "
        f"have shape {tuple(structure_input_ids.shape)}; the FASTA sequence and the "
        "PDB structure must describe the same residues."
    )

Inference

In [8]:
# Inference only: run the forward pass without tracking gradients.
with torch.no_grad():
    outputs = prosst_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        ss_input_ids=structure_input_ids
    )
# Drop the first and last positions (the boundary tokens around the sequence) and
# normalise the remaining scores into log-probabilities over the last axis.
# squeeze(0) removes only the leading batch axis; a bare squeeze() would also
# collapse the position axis for a one-residue sequence and break the 2-D
# indexing `logits[idx, token]` used when scoring mutants.
# Note: despite its name, `logits` holds log-probabilities, one row per residue.
logits = torch.log_softmax(outputs.logits[:, 1:-1], dim=-1).squeeze(0)

Score mutants

In [9]:
# Read the table of measured variants; its 'mutant' column lists the mutation
# strings that are scored in the next step.
DMS_CSV_PATH = "example_data/GRB2_HUMAN_Faure_2021.csv"
df = pd.read_csv(DMS_CSV_PATH)
mutants = df['mutant'].to_list()

In [10]:
# Token string -> column index in the model's output distribution.
vocab = prosst_tokenizer.get_vocab()

# One score per entry of `mutants`, in the same order.
pred_scores = []
for mutant in mutants:
    mutant_score = 0
    # A multi-site mutant is written as colon-separated single substitutions
    # (e.g. "A12G:K45R"); its score is the sum of the per-site scores.
    for sub_mutant in mutant.split(":"):
        # "<wild type letter><1-based position><mutant letter>"
        wt, idx, mt = sub_mutant[0], int(sub_mutant[1:-1]) - 1, sub_mutant[-1]

        # Reject out-of-range positions explicitly: a position of 0 would give
        # idx == -1 and silently score the LAST residue instead of failing.
        if not 0 <= idx < len(residue_sequence):
            raise ValueError(
                f"Mutation {sub_mutant!r} (in {mutant!r}) is outside the sequence "
                f"of length {len(residue_sequence)}."
            )
        # The stated wild-type letter must agree with the loaded sequence,
        # otherwise the table and the sequence are not numbered consistently.
        if residue_sequence[idx] != wt:
            raise ValueError(
                f"Mutation {sub_mutant!r} (in {mutant!r}) expects wild-type residue "
                f"{wt!r} at position {idx + 1}, but the sequence has "
                f"{residue_sequence[idx]!r}."
            )

        # Log-probability of the mutant residue minus that of the wild type
        # at the mutated position.
        pred = logits[idx, vocab[mt]] - logits[idx, vocab[wt]]
        mutant_score += pred.item()
    pred_scores.append(mutant_score)

Compute the Spearman correlation

In [11]:
# Rank correlation between the predicted scores and the 'DMS_score' column of
# the table; left as the cell's last expression so the result is displayed.
spearmanr(pred_scores, df['DMS_score'])

SignificanceResult(statistic=0.7182950462783597, pvalue=0.0)