In [50]:
import os
import sys
home_dir = "../../"
module_path = os.path.abspath(os.path.join(home_dir))
if module_path not in sys.path:
    sys.path.append(module_path)

import torch
import numpy as np
import pandas as pd
from models.aa_common.data_loader import get_pmd_dbnsfp_dataset
import models.tape_rao.model_utils as model_utils
import utils.pickle_utils as pickle_utils

In [4]:
task = "pmd"
variants_df, protid_seq_dict = get_pmd_dbnsfp_dataset(home_dir)

Index(['mut_id', 'md5', 'pmd_id', 'nr', 'prot_acc_version', 'snp_id',
       'mut_real', 'wt', 'mut', 'prot_pos', 'chrom', 'chrom_pos', 'ref_allele',
       'alt_allele', 'function', 'source', 'crossref', 'function_summarized',
       'class', 'SIFT_score', 'Polyphen2_HVAR_score', 'MetaRNN_score',
       'REVEL_score', 'MVP_score', 'CADD_raw_score',
       'integrated_fitCons_score', 'phyloP17way_primate_score',
       'phastCons17way_primate_score', 'bStatistic_score'],
      dtype='object')
(7179, 29)
Effect       3818
No-effect    1777
Knock-out    1584
Name: class, dtype: int64
#-unique prots:  2056


In [6]:
from tape import ProteinBertModel, TAPETokenizer

model_name = "protbert"
model = ProteinBertModel.from_pretrained('bert-base')
tokenizer = TAPETokenizer(vocab='iupac') 
model_task_out_dir, model_logits_out_dir = model_utils.create_output_directories(model_name, task, home_dir)


Log: Creating output directories ...


In [51]:
def get_embedding(seq, filename):
    filepath = f"{model_logits_out_dir}{filename}.pkl"

    if os.path.exists(filepath):
        print(f"Model logits already exists: {filename}")
        embedding = pickle_utils.load_pickle(filepath) 
    else: 
        print(f"Computing model logits: {filename}")
        with torch.no_grad():
            token_ids = torch.tensor(np.array([tokenizer.encode(seq)]))
            embedding = model(token_ids)[0].squeeze(0).detach().numpy()
        pickle_utils.save_as_pickle(embedding, filepath)
    return embedding

def compute_variant_effect_score(protid, seq, one_indexed_mut_pos, wt_aa, mt_aa):
    wt_seq = list(seq)
    mt_seq = list(seq)
    mt_seq[one_indexed_mut_pos] = mt_aa

    wt_filename = f"{protid}"
    mt_filename = f"{protid}_{str(one_indexed_mut_pos)}_{mt_aa}"
    
    wt_embedding = get_embedding(wt_seq, wt_filename)[1:-1] # 1st and last tokens are <cls>=2 and <sep>=3
    mt_embedding = get_embedding(mt_seq, mt_filename)[1:-1]
    # print(wt_embedding.shape, mt_embedding.shape)

    effect_score = abs(mt_embedding - wt_embedding).sum() / (768*len(seq)) # embedding_dim = 768
    # print(effect_score)
    return effect_score
    

In [55]:
preds = []
for i, tuple in enumerate(variants_df.itertuples()):
    protid, seq, one_indexed_mut_pos, wt_aa, mt_aa = tuple.prot_acc_version, protid_seq_dict[tuple.prot_acc_version], tuple.prot_pos, tuple.wt, tuple.mut
    print(protid, one_indexed_mut_pos, wt_aa, mt_aa)
    effect_score = compute_variant_effect_score(protid, seq, one_indexed_mut_pos, wt_aa, mt_aa)

    row = variants_df.loc[i]
    row = dict(row)
    row["pred"] = effect_score
    preds.append(row)
    if i==1: break

preds_df = pd.DataFrame(preds)   
preds_df

A000006_2 271 C Y
Model logits already exists: A000006_2
Model logits already exists: A000006_2_271_Y
A000006_2 62 N S
Model logits already exists: A000006_2
Model logits already exists: A000006_2_62_S


Unnamed: 0,mut_id,md5,pmd_id,nr,prot_acc_version,snp_id,mut_real,wt,mut,prot_pos,...,Polyphen2_HVAR_score,MetaRNN_score,REVEL_score,MVP_score,CADD_raw_score,integrated_fitCons_score,phyloP17way_primate_score,phastCons17way_primate_score,bStatistic_score,pred
0,168938,ef9a941d6a8a5d1b12be46de47ffd9ea,A000006,2,A000006_2,rs121913562,C271Y,C,Y,271,...,1.0,0.992647,0.836,0.894457,3.703299,0.487112,0.599,0.993,810.0,0.024019
1,168942,ef9a941d6a8a5d1b12be46de47ffd9ea,A000006,2,A000006_2,rs121913566,N62S,N,S,62,...,1.0,0.991474,0.946,0.996622,3.705575,0.487112,0.665,0.991,807.0,0.039788
