In [1]:
import unittest
import torch
import esm
from tqdm import tqdm # for progress bar

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the pretrained ESM-2 650M model together with its token alphabet.
# Source: https://github.com/facebookresearch/esm/tree/main?tab=readme-ov-file#esmfold
model, alphabet = esm.pretrained.load_model_and_alphabet("esm2_t33_650M_UR50D")
# Inference only: switch to eval mode (disables dropout) as the ESM README
# quick-start does, so repeated runs give deterministic logits.
model.eval()
# Callable that turns [(label, sequence), ...] into (labels, strings, token tensor).
batch_converter = alphabet.get_batch_converter()

In [3]:
# Show the full token vocabulary; in the output below, indices 4..23 are the
# 20 standard amino-acid letters (the 4:24 slice used by get_logit_scores).
print(alphabet.all_toks)

['<cls>', '<pad>', '<eos>', '<unk>', 'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-', '<null_1>', '<mask>']


In [4]:
def get_tokens(id,seq,model=model,alphabet=alphabet,batch_converter=batch_converter):
    """Tokenise a single labelled sequence.

    Args:
        id: Label attached to the sequence in the one-item batch.
        seq: Sequence string to tokenise.
        model: Unused here; kept so the signature stays compatible.
        alphabet: Alphabet object providing ``padding_idx``.
        batch_converter: Callable mapping ``[(label, seq)]`` to
            ``(labels, strings, tokens)``.

    Returns:
        Tuple ``(batch_tokens, batch_len)`` where ``batch_len`` is the number
        of non-padding tokens in the first (only) row of ``batch_tokens``.
    """
    _, _, batch_tokens = batch_converter([(id, seq)])
    # Count every token that is not padding, per row, then take row 0.
    not_padding = batch_tokens != alphabet.padding_idx
    batch_len = not_padding.sum(1)[0]
    return batch_tokens, batch_len

In [5]:
def get_logit_scores(batch_tokens,ln,start_pos=138, end_pos=143,model=model,):
    """Run the model on one tokenised sequence and return per-position log-probabilities.

    Args:
        batch_tokens: Token tensor for a single sequence (the batch dimension
            is squeezed away after the forward pass).
        ln: Number of positions to keep after dropping the first token.
            NOTE(review): the ``batch_len`` returned by ``get_tokens`` counts
            every non-padding token, so passing it here also keeps trailing
            special-token rows -- confirm what callers pass.
        start_pos: Unused here; kept for signature compatibility.
        end_pos: Unused here; kept for signature compatibility.
        model: Model whose forward pass returns a mapping with a ``"logits"`` entry.

    Returns:
        CPU tensor with one row per kept position and 20 columns (token
        indices 4..23), log-softmax-normalised across those 20 columns.
    """
    # Run on whatever device the model lives on. The previous code referenced
    # an undefined global ``device`` and never checked where the model was.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        batch_tokens = batch_tokens.to(device=first_param.device, non_blocking=True)

    with torch.no_grad():
        logits_raw = model(batch_tokens)["logits"].squeeze(0)
        # Drop row 0 (the leading special token) and keep token columns 4..23.
        logits_target = logits_raw[1:(ln+1),4:24]

    # Normalise over the 20 kept columns so each row is a log-probability vector.
    lsoftmax = torch.nn.LogSoftmax(dim=1)
    logits = lsoftmax(logits_target)

    # Hand back a CPU tensor so downstream NumPy-style processing works
    # regardless of where the model ran (no-op when already on CPU).
    return logits.cpu()

In [6]:
def get_most_likely_mutation(logits,ref_logits,ref_seq,strategy="lc_pos_hc_aa",start_pos=138,end_pos=143,top_position=0):
    """Dispatch to the mutation-selection strategy named by ``strategy``.

    Args:
        logits: Forwarded unchanged to the strategy.
        ref_logits: Forwarded unchanged to the strategy.
        ref_seq: Forwarded unchanged to the strategy.
        strategy: Key of the strategy to run; only ``"lc_pos_hc_aa"`` is registered.
        start_pos: Forwarded unchanged to the strategy.
        end_pos: Forwarded unchanged to the strategy.
        top_position: Currently unused.

    Returns:
        Whatever the selected strategy returns.

    Raises:
        KeyError: If ``strategy`` is not a registered key.
    """
    # Name -> implementation registry; extend here when new strategies land.
    registry = dict(lc_pos_hc_aa=lc_pos_hc_aa_strategy)
    chosen = registry[strategy]
    return chosen(logits, ref_logits, ref_seq, start_pos, end_pos)
    
def index_to_char(aa,alphabet=alphabet,token_offset=4):
    """Translate an amino-acid column index into its token string.

    Args:
        aa: Index relative to the first amino-acid token.
        alphabet: Object exposing the ``all_toks`` token list.
        token_offset: Number of leading tokens to skip in ``all_toks``
            before the amino-acid tokens start (default 4).

    Returns:
        The entry of ``alphabet.all_toks`` at ``aa + token_offset``.
    """
    vocab_index = aa + token_offset
    return alphabet.all_toks[vocab_index]

def get_aa_char(ref_seq,pos,aa):
    """Return the candidate amino-acid letter if it differs from the reference residue.

    Args:
        ref_seq: Reference sequence; indexed at ``pos``.
        pos: Position in ``ref_seq`` being considered for mutation.
        aa: Amino-acid index, converted to a letter via ``index_to_char``.

    Returns:
        The candidate letter when it differs from ``ref_seq[pos]``; ``None``
        when the candidate equals the residue already present.
    """
    ref_aa = ref_seq[pos]
    aa_char = index_to_char(aa)
    if ref_aa != aa_char:
        # f-string so the actual residues are reported (the placeholders were
        # previously printed literally).
        print(f"The top amino acid candidate for mutation is valid for this position: {ref_aa}>{aa_char}")
        return aa_char
    print(f"Invalid amino acid candidate for mutation as it is the same as the current amino acid: {ref_aa}>{aa_char}")
    return None

In [7]:
# Strategy 1: Least confident position, most confident amino acid
    # find position with lowest logit for current amino acid
    # at that position, find amino acid with highest logit
    # if current aa != aa with highest logit, mutate
    # else, 
        # if next highest aa logit > mutate
        # else, find next lowest logit aa
def lc_pos_hc_aa_strategy(logits,ref_logits,ref_seq,start_pos,end_pos):
    """Pick the least-confident position in a window and its best replacement residue.

    Within ``[start_pos, end_pos]`` (inclusive) the position with the lowest
    value in ``ref_logits`` is selected. At that position the amino acid with
    the highest value in ``logits`` is proposed; if it equals the residue
    already in ``ref_seq``, the second-highest is used instead.

    Args:
        logits: 2-D scores indexed ``[position, amino_acid]``.
        ref_logits: Per-position scores; assumed 1-D (one value per position)
            -- TODO confirm against the caller.
        ref_seq: Reference sequence the positions refer to.
        start_pos: First position of the window.
        end_pos: Last position of the window (inclusive).

    Returns:
        Tuple ``(position, aa_char)``: the chosen position as a plain ``int``
        index into the full sequence, and the proposed amino-acid letter.

    Raises:
        ValueError: If the window selects no positions.
    """
    # Torch-only implementation: the notebook never imports NumPy, and
    # ``Tensor.numpy()`` fails for CUDA / grad-tracking tensors.
    window_logits = torch.as_tensor(logits)[start_pos:end_pos+1, :]
    window_ref_logits = torch.as_tensor(ref_logits)[start_pos:end_pos+1]
    if window_ref_logits.numel() == 0:
        raise ValueError(
            f"No positions in window [{start_pos}, {end_pos}] to choose from."
        )

    # Offset inside the window, then shifted back to full-sequence coordinates.
    least_confident_offset = int(torch.argmin(window_ref_logits))
    adjusted_pos_index = least_confident_offset + start_pos

    # Best two candidates, highest first; the runner-up is the fallback when
    # the best candidate is the residue already present.
    candidate_scores = window_logits[least_confident_offset]
    top_k = min(2, candidate_scores.numel())
    top_2_least_conf_pos = torch.topk(candidate_scores, k=top_k).indices.tolist()

    aa_char = get_aa_char(ref_seq, adjusted_pos_index, top_2_least_conf_pos[0])
    if not aa_char:
        print("Using the second best fit amino acid for this position.")
        aa_char = index_to_char(top_2_least_conf_pos[1])

    return adjusted_pos_index,aa_char

# Strategy 2: Most confident amino acid 
    # find amino acid with highest logit across all positions
    # if in same position, current aa != aa with highest position, mutate
    # else, find next highest logit aa

# Strategy 3: Most likely to mutate position, most confident amino acid
    # find position with highest number of aa logits above a threshold 
    # at that position, find amino acid with highest logit
    # if current aa != aa with highest logit, mutate
    # else, 
        # if next highest aa logit > (next most likely to mutate pos's highest aa logit && current aa != aa with highest logit), mutate
        # else, find next most likely to mutate pos

In [8]:
def mutate_seq(reference_seq,pos,aa):
    """Return a copy of ``reference_seq`` with the element at ``pos`` replaced by ``aa``.

    Args:
        reference_seq: Sequence string to copy; it is not modified.
        pos: Index of the residue to replace (standard list indexing applies).
        aa: Replacement string placed at ``pos``.

    Returns:
        The mutated sequence as a new string.
    """
    residues = [*reference_seq]
    residues[pos] = aa
    return "".join(residues)

In [9]:
# Network class
# def add_seq(seq):
#     return 

In [10]:
# Sequence class
# from Bio.SeqRecord import SeqRecord