In [1]:
import os
import sys
# Relative path from this notebook to the repository root. Later cells build
# model/cache paths by plain string concatenation onto it, so the trailing
# slash is significant.
home_dir = "../../"

# Register the repository root on the import path exactly once, so that
# re-running this cell does not keep appending duplicates.
module_path = os.path.abspath(home_dir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

import pandas as pd
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from vespa.predict.config import DEVICE, EMBEDDING_HALF_PREC, MODEL_PATH_DICT

# Echo the half-precision flag and the device, one per line, so the
# configuration of this run is recorded in the notebook output.
print(EMBEDDING_HALF_PREC, DEVICE, sep="\n")

False
cpu


In [11]:
# Step 1: embedding generation
from vespa.predict.embedding import T5_Embed

# NOTE: EMBEDDING_HALF_PREC has to be False for the t5_emb model to run
# (see the value printed by the config cell).
t5_emb = T5_Embed(cache_dir=f"{home_dir}models/vespa_marquet/cache")

# The literal 0 selects the embedding mode (EMBED == 0, per the original note).
model, tokenizer = t5_emb.prott5.get_model(0)

def compute_embedding(seq):
    """Compute per-residue embeddings for a single protein sequence.

    Uses the notebook-level ``tokenizer`` and ``model`` returned by
    ``t5_emb.prott5.get_model(0)`` and runs on ``DEVICE``.

    Args:
        seq (str): amino-acid sequence, one character per residue.

    Returns:
        numpy.ndarray: 2-D array with one row per residue, i.e.
        ``(len(seq), embedding_dim)``; the original author's note gives
        ``embedding_dim`` as 1024. The result is always 2-D, including for a
        one-residue sequence.
    """
    seq_len = len(seq)
    # The tokenizer is given the residues separated by single spaces.
    spaced_seq = ' '.join(seq)
    token_encoding = tokenizer([spaced_seq], add_special_tokens=True, padding='longest', return_tensors="pt")
    input_ids = token_encoding['input_ids'].to(DEVICE)
    attention_mask = token_encoding['attention_mask'].to(DEVICE)

    with torch.no_grad():
        embedding_repr = model(input_ids, attention_mask=attention_mask)  # 1 x n_tokens x embedding_dim
        # Index 0 removes the batch axis; the slice keeps only the first
        # seq_len positions, dropping any special token appended after them.
        emb = embedding_repr.last_hidden_state[0, :seq_len]

    # Deliberately no .squeeze(): the batch axis is already gone, and squeezing
    # would also collapse the residue axis when seq_len == 1, handing a 1-D
    # vector to callers that expect (seq_len, embedding_dim).
    return emb.detach().cpu().numpy()

In [17]:
# Step 2: conservation prediction
from pathlib import Path

from vespa.predict.conspred import ProtT5Cons

# The checkpoint location is looked up in the package config under "CONSCNN".
checkpoint_path = Path(MODEL_PATH_DICT["CONSCNN"])
conspred = ProtT5Cons(checkpoint_path)

def compute_conservation(embedding):
    """Predict conservation-class probabilities from a per-residue embedding.

    Args:
        embedding: array-like accepted by ``torch.tensor`` holding the
            per-residue embedding of one sequence (as returned by
            ``compute_embedding``).

    Returns:
        numpy.ndarray: output of ``conspred.model.extract_probabilities`` with
        the leading batch axis removed. The original author's note records the
        shape as ``(9, seq_len)`` -- TODO confirm against the model code.
    """
    # Add a batch axis and place the input on the configured device, mirroring
    # what compute_embedding does with its inputs.
    # NOTE(review): assumes ProtT5Cons puts its predictor on DEVICE -- confirm.
    batch = torch.tensor(embedding).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        Yhat = conspred.predictor(batch)
        prob = conspred.model.extract_probabilities(Yhat)

    # Drop the batch axis and hand back a CPU numpy array.
    return prob.squeeze(0).detach().cpu().numpy()

In [19]:
# Step 3: computing log-odds
from vespa.predict.logodds import T5_condProbas

# Uses the same cache directory as the embedding model.
t5_condProbas = T5_condProbas(cache_dir=f"{home_dir}models/vespa_marquet/cache")

def get_log_odds(seq_dict, mutation_generator):
    """Return the DMISS (deep mutational in-silico scanning) log-odds data.

    Runs the two-stage lookup on the notebook-level ``t5_condProbas`` object:
    first the conditional probabilities for the requested mutations, then the
    log-odds derived from them.

    Args:
        seq_dict: mapping of protein id to sequence.
        mutation_generator: mutation generator built for the same ``seq_dict``.

    Returns:
        Whatever ``t5_condProbas.get_log_odds`` returns; the original note
        describes it as ``seq_len x 20`` per protein -- TODO confirm.
    """
    return t5_condProbas.get_log_odds(
        t5_condProbas.get_proba_dict(seq_dict, mutation_generator)
    )

In [41]:
from vespa.predict.vespa import VespaPred

# Whether to enable the `vespa` predictor in addition to `vespal`; the
# prediction cell branches on this flag to decide if log-odds are needed.
is_vespa = True
vespa_predictor = VespaPred(vespal=True, vespa=is_vespa)

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [43]:
# End-to-end example: score a hand-written list of mutations for one sequence
# by chaining embedding -> conservation -> (optional) log-odds -> predictor.
from vespa.predict.utils import MutationGenerator


# Example input: a single protein id, its sequence, and the mutations to score.
protid = "prot"
seq_dict = {protid: "MENFQYSVQLSDQXWA"}
mutations_dict = {protid: ["M1C", "E2S"]}
seq = seq_dict[protid]

# MutationGenerator is fed from a file, so the requested mutations are written
# to a scratch file in the model cache directory, one "<id>_<mutation>" per line.
temp_mutations_filepath = home_dir+"models/vespa_marquet/cache/temp_mutations.txt"
# A placeholder id is used for the whole pipeline and swapped back at the end.
# NOTE(review): the reason is not visible here -- possibly to keep the id free
# of characters that clash with the "_" separator in the file; confirm.
temp_protid = "protid"
with open(temp_mutations_filepath, "w") as f:
    for mutation in mutations_dict[protid]:
        f.write(f"{temp_protid}_{mutation}\n")
temp_seq_dict = {temp_protid: seq}

mutations_file_path = Path(temp_mutations_filepath)
# one_based_file=True: presumably marks positions in the file (e.g. "M1C") as
# 1-based -- confirm against MutationGenerator.
mutation_generator = MutationGenerator(temp_seq_dict, file_path=mutations_file_path, one_based_file=True)


# Steps 1 and 2: per-residue embedding, then conservation probabilities keyed
# by the placeholder id so they line up with the mutation generator.
embedding = compute_embedding(seq) # shape: seq_len, 1024
conservation = compute_conservation(embedding)
conservation_dict = {temp_protid: conservation}

# Step 3 (log-odds) is computed and passed on only when the `vespa` predictor
# is enabled; otherwise only conservation is supplied.
if is_vespa:
    log_odds = get_log_odds(temp_seq_dict, mutation_generator)
    predictions = vespa_predictor.generate_predictions(mutation_generator, conservation_dict, log_odds)
else: 
    predictions = vespa_predictor.generate_predictions(mutation_generator, conservation_dict)

# Re-key the result from the placeholder id back to the real protein id.
# The recorded output labels the mutations with 0-based positions (M0C, E1S).
predictions[protid] = predictions.pop(temp_protid)
predictions # this result is exactly the same as the one in vespa_outs_from_cmd/0.csv

tensor([[19,  9, 17, 15, 16, 18,  7,  6, 16,  4,  7, 10, 16, 23, 21,  3,  1]]) tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])


Extract Sequence Logodds: 1it [00:00,  1.11it/s]
Logodds Lookup: 100%|██████████| 2/2 [00:00<00:00, 28728.11it/s]
Blosum Lookup: 100%|██████████| 2/2 [00:00<00:00, 20213.51it/s]
Conservation Lookup: 100%|██████████| 2/2 [00:00<00:00, 39568.91it/s]


Generate Model Predictions
Predictions Done; Generate output


Info Generation: 100%|██████████| 2/2 [00:00<00:00, 49636.73it/s]


{'prot': [('M0C', {'VESPAl': 0.5799001754919151, 'VESPA': 0.5011233500844459}),
  ('E1S', {'VESPAl': 0.4177296331552022, 'VESPA': 0.32330920524589424})]}