In [2]:
import os
import sys
home_dir = "../../"  # repository root, relative to this notebook's working directory
# Resolve the root once; `models.*` imports below are resolved against it.
module_path = os.path.abspath(home_dir)
# Appending (not prepending) means entries already on sys.path are searched first.
if module_path not in sys.path: sys.path.append(module_path)
    
# import sys
# home_dir = ""
# sys.path.append("../variant_effect_analysis")


import torch
import pandas as pd
import numpy as np
import time
import models.aa_common.pickle_utils as pickle_utils
from models.aa_common.data_loader import get_population_freq_proteomic_SNVs, get_population_freq_proteomic_SNVs_fasta_iterator

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Variant table (columns shown in the log: prot_acc_version, pos, wt, mut, ...) and an
# iterator over the FASTA records (`.id`, `.seq`) that the scoring loop consumes below.
variants_df, fasta_iterator = (
    get_population_freq_proteomic_SNVs(home_dir),
    get_population_freq_proteomic_SNVs_fasta_iterator(home_dir),
)


Log: Loading data ...
raw data: (2882721, 8)
Index(['prot_acc_version', 'pos', 'wt', 'mut', 'wt_population',
       'mut_poulation', 'wt_freq', 'mt_freq'],
      dtype='object')
After combining common (18279), rare (29383) and sampled-singletons (47662), data: (95324, 8)

Log: Loading combined fasta iterator ...


In [4]:
print("\nLog: Model loading ...")
# perf_counter is a monotonic clock intended for measuring durations (time.time can jump).
start = time.perf_counter()
from models.bioembeddings_dallago.lm_heads.prottrans_lms_factory import load_prottrans_model
model = load_prottrans_model("prottrans_bert_bfd")
tokenizer = model._tokenizer

end = time.perf_counter()
print(f"Time taken to load model: {end-start} s")

# A small sanity check on a toy sequence: per-residue logits, already a numpy array.
logits = model.embed("SEQVENCE")
# For this 8-residue input the recorded output is (8, 30): one row per residue (no
# <cls>/<sep> rows) and one column per vocabulary entry.
print(np.array(logits).shape)


Log: Model loading ...


If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of the model checkpoint at /home/akabir4/.cache/bio_embeddings/prottrans_bert_bfd/model_directory were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Time taken to load model: 39.28939962387085 s
(8, 30)


In [6]:
# Look up "A" exactly as `execute` does (model.aa_prefix + residue) so this check
# exercises the same token mapping used for scoring. If a prefix is used it may be
# "▁" (U+2581), which is not an underscore "_". The vocabulary printed below lists bare
# one-letter residues, so an id of 1 ([UNK]) here would mean the prefix does not match
# this vocabulary and every variant score would be computed on [UNK] logits.
print(tokenizer.convert_tokens_to_ids(model.aa_prefix + "A"))
tokenizer.get_vocab()

6


{'[PAD]': 0,
 '[UNK]': 1,
 '[CLS]': 2,
 '[SEP]': 3,
 '[MASK]': 4,
 'L': 5,
 'A': 6,
 'G': 7,
 'V': 8,
 'E': 9,
 'S': 10,
 'I': 11,
 'K': 12,
 'R': 13,
 'D': 14,
 'T': 15,
 'P': 16,
 'N': 17,
 'Q': 18,
 'F': 19,
 'Y': 20,
 'M': 21,
 'H': 22,
 'C': 23,
 'W': 24,
 'X': 25,
 'U': 26,
 'B': 27,
 'Z': 28,
 'O': 29}

In [7]:
# Everything this notebook writes lives under the bioembeddings_dallago model folder.
model_path = home_dir + "models/bioembeddings_dallago/"
# Per-protein logits cache (one pickle per protein, keyed by model name) ...
logits_output_path = model_path + f"lm_outputs_{model.name}/"
# ... and the final prediction tables.
model_output_path = model_path + "outputs/"
os.makedirs(logits_output_path, exist_ok=True)
os.makedirs(model_output_path, exist_ok=True)

In [8]:
def compute_model_logits(prot_acc_version, seq):
    """Return per-residue logits for `seq`, backed by an on-disk pickle cache.

    The cache file is ``{logits_output_path}{prot_acc_version}.pkl``. On a cache hit the
    stored object is returned as-is. On a miss the model is run on `seq` under
    ``torch.no_grad()``, the result is pickled to the cache file, and then returned.

    Args:
        prot_acc_version: protein accession (with version); used as the cache key.
        seq: amino-acid sequence string passed to ``model.embed``.

    Returns:
        The output of ``model.embed`` (or its cached copy). The sample call earlier in
        this notebook yields a numpy array of shape (8, 30) for an 8-residue sequence,
        i.e. presumably (len(seq), vocab_size).
    """
    cache_file = f"{logits_output_path}{prot_acc_version}.pkl"
    if os.path.exists(cache_file):
        # NOTE(review): load_pickle presumably wraps pickle.load, which can execute
        # arbitrary code; only read cache files this notebook wrote itself.
        return pickle_utils.load_pickle(cache_file)

    with torch.no_grad():
        fresh_logits = model.embed(seq)
        pickle_utils.save_as_pickle(fresh_logits, cache_file)
    return fresh_logits

In [9]:
def execute(data):
    """Score every known variant of each protein in `data` using the LM logits.

    For each ``(prot_acc_version, seq)`` pair, the per-residue logits come from
    ``compute_model_logits``. Every row of ``variants_df`` for that protein is then scored
    as ``logits[pos - 1][mut] - logits[pos - 1][wt]``, where ``pos`` is the 1-based
    residue position stored in the table.

    Args:
        data: list of ``(prot_acc_version, seq)`` tuples (one chunk of FASTA records).

    Returns:
        DataFrame with all ``variants_df`` columns for the scored rows plus a ``pred``
        column (mutant-minus-wild-type logit). It has no rows or columns when `data` is
        empty or no variants match.

    Raises:
        IndexError: if a variant position lies outside the logits of its sequence.
    """
    preds = []
    for prot_acc_version, seq in data:
        output_logits = compute_model_logits(prot_acc_version, seq)  # l x vocab_size=30
        # Iterate the matching rows directly rather than re-fetching them by index label:
        # `variants_df.loc[label]` returns a DataFrame (not a single row) when labels
        # repeat, which can happen if the loader concatenated frames without resetting
        # the index. The matched rows can differ between runs because singletons are
        # sampled when variants_df is built.
        prot_variants = variants_df[variants_df["prot_acc_version"] == prot_acc_version]
        for record in prot_variants.to_dict("records"):
            wt_tok_idx = tokenizer.convert_tokens_to_ids(model.aa_prefix + record["wt"])
            mt_tok_idx = tokenizer.convert_tokens_to_ids(model.aa_prefix + record["mut"])
            # NCBI protein positions are 1-indexed and the logits have no <cls> row at
            # position 0, so residue `pos` is row pos - 1.
            pos = record["pos"] - 1
            # A negative index would silently wrap around to the end of the sequence.
            if not 0 <= pos < len(output_logits):
                raise IndexError(
                    f"{prot_acc_version}: variant position {record['pos']} is outside "
                    f"the sequence (length {len(output_logits)})"
                )
            record["pred"] = output_logits[pos][mt_tok_idx] - output_logits[pos][wt_tok_idx]
            preds.append(record)
    return pd.DataFrame(preds)

In [10]:
# if __name__=="__main__": # main worker
# Limit the run to the first MAX_CHUNKS chunks while debugging; set to None for a full run.
# A truncated run is written to a separate "_debug" file so it never overwrites full results.
MAX_CHUNKS = 3
start = time.perf_counter()  # monotonic clock for elapsed time
is_cuda = torch.cuda.is_available()
pred_dfs = []

data = [(seq_record.id, str(seq_record.seq)) for seq_record in fasta_iterator]

# NOTE(review): `execute` scores proteins one at a time, so chunk_size only controls how
# work is grouped (e.g. per MPI task), not the model's batch size.
chunk_size = 32 if is_cuda else 1
data_chunks = [data[x:x+chunk_size] for x in range(0, len(data), chunk_size)]
if MAX_CHUNKS is not None:
    data_chunks = data_chunks[:MAX_CHUNKS]
first_chunk_size = len(data_chunks[0]) if data_chunks else 0
print(f"#-of chunks: {len(data_chunks)}, 1st chunk size: {first_chunk_size}")


# sequential run and debugging
for i, data_chunk in enumerate(data_chunks, start=1):
    pred_df = execute(data_chunk)
    print(f"Finished chunk {i}/{len(data_chunks)}: {pred_df.shape}")
    pred_dfs.append(pred_df)

# mpi run
# from mpi4py.futures import MPIPoolExecutor
# executor = MPIPoolExecutor()
# for i, pred_df in enumerate(executor.map(execute, data_chunks, unordered=True), start=1):
#     print(f"Finished chunk {i}/{len(data_chunks)}: {pred_df.shape}")
#     pred_dfs.append(pred_df)
# executor.shutdown()


# ignore_index: each chunk's frame is indexed from 0, so avoid repeated labels.
result_df = pd.concat(pred_dfs, ignore_index=True) if pred_dfs else pd.DataFrame()
output_suffix = "" if MAX_CHUNKS is None else "_debug"
print("Saving predictions ...")
result_df.to_csv(f"{model_output_path}popu_freq_preds_{model.name}{output_suffix}.csv", sep="\t", index=False, header=True)
print(result_df.shape)
print(result_df.head())

print(f"Time taken: {time.perf_counter()-start} seconds")

#-of chunks: 3, 1st chunk size: 1
Finished 0/3th chunk: (8, 9)
Finished 1/3th chunk: (5, 9)
Finished 2/3th chunk: (4, 9)
Saving predictions ...
(17, 9)
  prot_acc_version  pos wt mut  wt_population  mut_poulation   wt_freq  \
0      NP_112509.3  236  D   N           4470              6  0.998660   
1      NP_112509.3  195  R   H          55576            141  0.997469   
2      NP_112509.3  374  H   R         202922           1874  0.990849   
3      NP_112509.3  254  P   L          66634            206  0.996918   
4      NP_112509.3  259  P   S          10680              1  0.999906   

    mt_freq      pred  
0  0.001340 -4.396527  
1  0.002531 -5.587605  
2  0.009151  0.419850  
3  0.003082 -4.399530  
4  0.000094 -4.352335  
Time taken: 14.284538745880127 seconds
