In [1]:
from Bio import SeqIO
from Bio.Seq import Seq
import pandas as pd
import numpy as np
import torch
import esm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from numpy import dot
from numpy.linalg import norm
from Bio import SeqIO
from scipy.special import softmax
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

  from .autonotebook import tqdm as notebook_tqdm


## Evolving FMD One Mutation at a Time

In [2]:
# Goal: predict the evolutionary path of the virus that causes foot-and-mouth disease (FMD)
# Method:
# > feed the starting sequence to a protein language model (PLM)
# > extract the logit scores to select the position most likely to change
# > mutate that position to the most likely amino acid > check the probability of the resulting sequence
# > add it to the network and continue, OR kill off the path and backtrack


In [3]:
# Starting protein sequence for the mutation walk further down, where it is
# copied into `new_seq` before the first step. The length is printed as a sanity check.
reference_seq = 'TTSAGESADPVTATVENYGGETQVQRRQHTDIAFILDRFVKVKPKEQVNVLDLMQIPAHTLVGALLRTATYYFSDLELAVKHEGDLTWVPNGAPETALDNTTNPTAYHKEPLTRLALPYTAPHRVLATVYNGSSKYGDTSTNNVRGDLQVLAQKAERTLPTSFNFGAIKATRVTELLYRMKRAETYCPRPLLAIQPSDARHKQRIVAPAKQ'
print(len(reference_seq))

211


In [4]:
# !pip install nbformat
%run utils.ipynb

['<cls>', '<pad>', '<eos>', '<unk>', 'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-', '<null_1>', '<mask>']


In [5]:
# Index of the first amino-acid token in the alphabet printed above; entries 0-3 are the
# special tokens <cls>, <pad>, <eos>, <unk>. It is subtracted from token ids when the
# mutation-walk loop indexes into `logits`.
token_offset = 4 # amino acid characters are located at indices 4-23 in the alphabet

In [8]:
# Greedy mutation walk: each iteration scores the current sequence, applies one
# substitution, and feeds the mutated sequence into the next iteration.
# NOTE(review): get_tokens, get_logit_scores, get_most_likely_mutation and mutate_seq are
# not defined in the visible cells; presumably they come from utils.ipynb (%run above) — confirm.
steps=6
new_seq = reference_seq
for i in range(steps):
    print(f"\nStep {i}")
    # batch_len is returned but not used in this cell.
    batch_tokens, batch_len = get_tokens("base_seq",new_seq)
    ln = len(batch_tokens[0,1:-1]) # exclude special tokens
    logits = get_logit_scores(batch_tokens,ln)
    # Token ids of the current sequence with the first and last token dropped.
    ref_tokens = batch_tokens[0,1:-1]
    # For every row of `logits`, pick the column of the residue currently at that position.
    # Assumes `logits` has one row per residue and its columns follow the alphabet order
    # starting at the first amino acid (hence `- token_offset`) — TODO confirm in utils.
    ref_logits = logits[torch.arange(logits.size(0)),ref_tokens - token_offset]
    # Mean of the selected scores; whether these are true log-probabilities depends on
    # what get_logit_scores returns — confirm.
    mean_log_p_current = ref_logits.mean().item()
    print(f"Mean log p of current sequence: {mean_log_p_current}")
    # print(f"Logits shape: {logits.shape}") # expected: (211,20)
    # print(f"Reference sequence's logits shape: {ref_logits.shape}") # expected: (211)

    pos,aa_char = get_most_likely_mutation(logits,ref_logits,new_seq)
    print(f"Position to mutate: {pos}")
    print(f"Reference amino acid: {list(new_seq)[pos]}")
    print(f"New amino acid: {aa_char}")
    # The region of interest is hard-coded to indices 138..143; "reference" here means the
    # current sequence before this step's mutation, not the original reference_seq.
    print(f"ROI in reference sequence: {new_seq[138:144]}")
    new_seq = mutate_seq(new_seq,pos,aa_char)
    print(f"ROI in mutated sequence: {new_seq[138:144]}")


Step 0
Mean log p of current sequence: -0.2114058881998062
Invalid amino acid candidate for mutation as it is the same as the current amino acid: {ref_aa}>{aa_char}
Using the second best fit amino acid for this position.
Position to mutate: 139
Reference amino acid: S
New amino acid: K
ROI in reference sequence: TSTNNV
ROI in mutated sequence: TKTNNV

Step 1
Mean log p of current sequence: -0.20803409814834595
Invalid amino acid candidate for mutation as it is the same as the current amino acid: {ref_aa}>{aa_char}
Using the second best fit amino acid for this position.
Position to mutate: 141
Reference amino acid: N
New amino acid: T
ROI in reference sequence: TKTNNV
ROI in mutated sequence: TKTTNV

Step 2
Mean log p of current sequence: -0.20574155449867249
Invalid amino acid candidate for mutation as it is the same as the current amino acid: {ref_aa}>{aa_char}
Using the second best fit amino acid for this position.
Position to mutate: 140
Reference amino acid: T
New amino acid: V
RO

In [100]:
# Single-point mutations - use the ESM model to evaluate a pool of possible mutations (look at the logit scores for each position to choose the most probable change), and build a graph at the same time to allow backtracking
# Note: we do want multiple mutations by the end, but they are accumulated one at a time
# Keep track of where I am in the "walk" / possible paths, and kill off any improbable paths I end up in
# Ensure the protein is still functional! e.g. still ... (TODO: complete this note with the functional criterion)

## Terminology
DMS: deep mutational scanning, a technique used to study the impact of mutations on protein structure and function

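Grammaticality: the language-model probability of a mutation given its sequence context (how plausible the mutated sequence is); the distance between the original and mutated embeddings is the companion quantity, *semantic change* - see Hie 2020 "Learning mutational semantics"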


In [18]:
# F: using the original seq and the desired mutation(s), return the mutated sequence
def mutate_sequence(reference_sequence,mutations):
    """Apply point substitutions to a sequence and return the mutated copy.

    Parameters
    ----------
    reference_sequence : str
        Sequence to mutate; it is not modified.
    mutations : iterable of str
        Substitutions written as ``<orig><1-based position><mutant>``, e.g. ``"S140K"``.
        Entries containing ``'ins'``, ``'del'`` or ``'X'`` are skipped.
        The original residue letter is not checked against the sequence.

    Returns
    -------
    str
        The sequence with every accepted substitution applied, in order.

    Raises
    ------
    ValueError
        If a substitution's position lies outside ``1..len(sequence)``. Without
        this check, slicing would silently extend or scramble the sequence.
    """
    mutated_seq = reference_sequence
    for mutation in mutations:
        # Only plain substitutions are handled; insertions, deletions and
        # ambiguous residues are ignored.
        if 'ins' in mutation or 'del' in mutation or "X" in mutation:
            continue
        mutant_amino = mutation[-1]
        mutant_pos = int(mutation[1:-1])  # 1-based position
        if not 1 <= mutant_pos <= len(mutated_seq):
            raise ValueError(
                f"Mutation {mutation!r} targets position {mutant_pos}, "
                f"outside the sequence (length {len(mutated_seq)})"
            )
        mutated_seq = mutated_seq[:mutant_pos-1]+mutant_amino+mutated_seq[mutant_pos:]
    return mutated_seq

# F: generates a list of sequences where every position in the protein sequence is mutated to every possible amino acid by default
# Note: this is only for single-residue mutation - need to update to deal with multiple mutations per sequence
def DMS(reference,start=0,end = None):
    """Build a deep-mutational-scanning library of single-residue variants.

    Every position ``i`` with ``start <= i <= end`` (0-based, ``end`` inclusive)
    is replaced in turn by each of the 20 standard amino acids. This includes
    the residue already present, so each scanned position yields 20 records.

    Parameters
    ----------
    reference : str
        Sequence to scan.
    start : int, optional
        First 0-based position to scan (default 0).
    end : int or None, optional
        Last 0-based position to scan, inclusive. ``None`` means
        ``len(reference)``, i.e. scan through the final residue.

    Returns
    -------
    list of SeqRecord
        One record per variant, with id ``<orig><1-based position><mutant>``.
    """
    upper = len(reference) if end is None else end
    amino_acids = ["A","R","N","D","C","Q","E","G","H","I","L","K","M","F","P","S","T","W","Y","V"]
    return [
        SeqRecord(
            Seq(reference[:idx] + substitute + reference[idx + 1:]),
            id=residue + str(idx + 1) + substitute,
        )
        for idx, residue in enumerate(reference)
        if start <= idx <= upper
        for substitute in amino_acids
    ]

# Scan 0-based positions 138..143 (DMS treats `end` as inclusive): 6 positions x 20 amino acids = 120 records.
# NOTE(review): `reference_protein` is not defined in any visible cell (the sequence defined
# earlier is named `reference_seq`); presumably it comes from utils.ipynb or a previous
# kernel session — confirm, otherwise this cell fails on Restart & Run All.
seqs_of_mutations = DMS(reference_protein,138,143)
len(seqs_of_mutations)

120

In [19]:
# F: use PLM to extract embedding and logits for a given mutation
def embed_sequence(sequence,model,device,model_layers,batch_converter):
    """Run one sequence through the language model and collect its outputs.

    Parameters
    ----------
    sequence : str
        Sequence to embed; it is submitted as a single-entry batch labelled 'base'.
    model : callable
        Called as ``model(tokens, repr_layers=[...], return_contacts=False)``; its
        result must provide ``["representations"][model_layers]`` and ``["logits"]``.
    device : torch device
        Target device for the tokens. They are only moved when CUDA is available.
    model_layers : int
        Layer whose representations are extracted.
    batch_converter : callable
        Turns ``[(label, sequence)]`` into ``(labels, strings, tokens)``.

    Returns
    -------
    tuple
        ``(results, base_logits, base_mean_embedding, full_embedding)``:
        the raw model output, the log-softmax of the first batch entry's logits
        (on CPU), the mean of the per-token representations, and the per-token
        representations themselves (both on CPU).

    Notes
    -----
    ``alphabet`` is read from the notebook's global scope, not passed in.
    The ``[1:n - 1]`` slice drops the first and last non-padding token —
    presumably the <cls>/<eos> special tokens; confirm against the alphabet.
    """
    # Tokenise the single sequence and count its non-padding tokens.
    _, _, tokens = batch_converter([('base', sequence)])
    n_tokens = (tokens != alphabet.padding_idx).sum(1)[0]

    # Move the tokens to the requested device when a GPU is present.
    if torch.cuda.is_available():
        tokens = tokens.to(device=device, non_blocking=True)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        results = model(tokens, repr_layers=[model_layers], return_contacts=False)
    del tokens

    # Per-token representations with the first and last token removed.
    residue_repr = results["representations"][model_layers][0][1:n_tokens - 1]
    full_embedding = residue_repr.cpu()
    base_mean_embedding = residue_repr.mean(0).cpu()

    # Log-softmax along dim 1 of the first (only) batch entry's logits, computed on CPU.
    logits_cpu = results["logits"][0].to(device="cpu")
    base_logits = torch.nn.functional.log_softmax(logits_cpu, dim=1)
    return results, base_logits, base_mean_embedding, full_embedding

# F: process embeddings and logits for sequence and return as a dictionary
def process_protein_sequence(sequence,model,model_layers,batch_converter):
    """Embed a sequence and package the outputs as plain Python lists.

    Delegates to ``embed_sequence`` using ``device`` from the notebook's global
    scope, which must be defined before this is called.

    Parameters
    ----------
    sequence : str
        Sequence to embed.
    model, model_layers, batch_converter
        Forwarded unchanged to ``embed_sequence``.

    Returns
    -------
    dict
        ``{"Mean_Embedding": [...], "Logits": [...]}``, both converted with ``.tolist()``.
    """
    _, log_probs, mean_embedding, _ = embed_sequence(
        sequence, model, device, model_layers, batch_converter
    )
    # The per-residue embedding is currently not included in the output.
    return {
        "Mean_Embedding": mean_embedding.tolist(),
        "Logits": log_probs.tolist(),
    }

In [None]:
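# F: from per-position log-probabilities, return the summed log-probability of the mutant residues (grammaticality) and the summed mutant-vs-original log-ratio (evolutionary index) for a list of mutations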
def grammaticality_and_evolutionary_index(word_pos_prob, seq, mutations):
    """Score a set of point substitutions against per-position log-probabilities.

    Parameters
    ----------
    word_pos_prob : mapping
        Looked up as ``word_pos_prob[(amino_acid, zero_based_position)]``. The
        values are summed and subtracted, i.e. treated as log-probabilities —
        presumably derived from the model's log-softmax output; confirm with the caller.
    seq : str
        Sequence the mutations refer to; used to validate the original residue.
    mutations : sequence of str
        Substitutions written as ``<orig><1-based position><mutant>``, e.g. ``"S140K"``.
        Entries containing ``'ins'``, ``'del'`` or ``'X'`` are skipped.

    Returns
    -------
    tuple
        ``(grammaticality, evolutionary_index)``: the sum of the mutant residues'
        log-probabilities, and the sum of ``log p(mutant) - log p(original)``.
        ``(0, 0)`` is returned when ``mutations`` is empty.

    Raises
    ------
    AssertionError
        If the original residue in a mutation does not match ``seq`` at that position.
    """
    if len(mutations) == 0:
        print('No mutations detected')
        return 0, 0
    mut_probs = []
    ev_ratios = []
    print('Mutations: ', mutations)
    for mutation in mutations:
        # Only plain substitutions are scored; insertions, deletions and
        # ambiguous residues are ignored.
        if 'ins' in mutation or 'del' in mutation or "X" in mutation:
            continue

        # Split "<orig><pos><mut>" into its parts; the position becomes 0-based.
        aa_orig = mutation[0]
        aa_pos = int(mutation[1:-1]) - 1
        aa_mut = mutation[-1]
        if seq[aa_pos] != aa_orig:
            print(mutation)
        assert seq[aa_pos] == aa_orig, (
            f"{mutation}: expected {aa_orig!r} at position {aa_pos + 1}, "
            f"found {seq[aa_pos]!r}"
        )

        prob_change = word_pos_prob[(aa_mut, aa_pos)]
        prob_original = word_pos_prob[(aa_orig, aa_pos)]
        # In log space the probability ratio is a difference...
        ev_ratios.append(prob_change - prob_original)
        # ...and the joint probability is a sum.
        mut_probs.append(prob_change)
    return np.sum(mut_probs), np.sum(ev_ratios)