In [1]:
from Bio import SeqIO
from Bio.Seq import Seq
import pandas as pd
import numpy as np
import torch
import esm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from numpy import dot
from numpy.linalg import norm
from Bio import SeqIO
from scipy.special import softmax
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

# Compression Functions

In [2]:
import bz2
import pickle
import _pickle as cPickle
def compressed_pickle(title, data):
  with bz2.BZ2File(title + '.pbz2', 'w') as f:
    cPickle.dump(data, f)

def decompress_pickle(file):
  data = bz2.BZ2File(file, 'rb')
  data = cPickle.load(data)
  return data

# Genbank Annotation Functions

In [3]:
def makeOrfTable(genbank_record):
    orfs=[]
    for feature in genbank_record.features:
        if feature.type =="CDS":
            orf = feature.qualifiers['gene'][0]
            for i, locations in enumerate(feature.location.parts):
                orfs.append([orf, locations.start, locations.end, i, locations])
    orfs = pd.DataFrame(orfs)
    orfs.columns = ['ORF','Start','End','Part','Locations']
    orfs = orfs.set_index("ORF")
    return orfs

def makeMatProteinTable(genbank_record):
    proteins=[]
    for feature in genbank_record.features:
        if feature.type =="mat_peptide":
            protein = feature.qualifiers['product'][0]
            orf = feature.qualifiers['gene'][0]
            for i, locations in enumerate(feature.location.parts):
                proteins.append([protein, orf ,locations.start, locations.end, i, locations])
    proteins = pd.DataFrame(proteins)
    proteins.columns = ['Protein',"ORF",'Start','End','Part','Locations']
    proteins = proteins.set_index("Protein")
    return proteins

# Mutation Functions

In [4]:
def mutate_sequence(reference_sequence,mutations):
    mutated_seq = reference_sequence
    for mutation in mutations:
        if 'ins' not in mutation and 'del' not in mutation and "X" not in mutation:
            mutant_amino = mutation[-1]
            mutant_pos = int(mutation[1:-1])
            mutated_seq = mutated_seq[:mutant_pos-1]+mutant_amino+mutated_seq[mutant_pos:]
    return mutated_seq

# This function generates a list of sequences where every position in the protein sequence is mutated to every possible amino acid by default
# For my application, I only want a list of mutated sequences for the 6 positions I want to mutate 
def DMS(reference,start=0,end = None):
  if end == None:
    end = len(reference)
  seq_list = []
  amino_acids = ["A","R","N","D","C","Q","E","G","H","I","L","K","M","F","P","S","T","W","Y","V"]
  for i,ref_amino_acid in enumerate(reference):

      if i>=start and i<=end:
        for mutant_amino_acid in amino_acids:
            mutated_seq = reference[:i]+mutant_amino_acid+reference[i+1:]
            seq = SeqRecord(Seq(mutated_seq), id=ref_amino_acid+str(i+1)+mutant_amino_acid)
            seq_list.append(seq)DMS
  return seq_list

# Translation Functions

In [5]:
def iterative_translate(sequence,truncate_proteins=False):
    amino_acid = ""
    for i in range(0,len(sequence)-2,3):
        codon = str(sequence[i:i+3])
        codon = codon.replace("?", "N")
        if "-" in codon:
            if codon == "---":
                amino_acid +="-"
            else:
                amino_acid+= "X"
        else:
            amino_acid += str(Seq(codon).translate())
    if truncate_proteins == True:
        if "*" in amino_acid:
            amino_acid = amino_acid[:amino_acid.index("*")]
    return amino_acid

def translate_with_genbank(sequence,ref):
    orfs = makeOrfTable(ref)
    translated_sequence = {orfs.index[i]+":"+str(orfs.iloc[i].Part):{"Sequence":"".join(list(iterative_translate("".join(orfs.iloc[i].Locations.extract(sequence)),truncate_proteins=True))),"ORF":orfs.index[i]} for i in range(len(orfs))}
    return translated_sequence

def translate_mat_proteins_with_genbank(sequence,ref):
    proteins = makeMatProteinTable(ref)
    proteins = proteins.drop_duplicates(subset=["ORF",'Start','End','Part',],keep="first")
    proteins_dict={}
    for i in range(len(proteins)):
        protein = "".join(list(iterative_translate("".join(proteins.iloc[i].Locations.extract(sequence)),truncate_proteins=True)))
        if proteins.index[i] in proteins_dict:
            proteins_dict[proteins.index[i]]["Sequence"] = proteins_dict[proteins.index[i]]["Sequence"]+protein
        else:
            proteins_dict[proteins.index[i]] = {"Sequence":protein, "ORF":proteins.iloc[i].ORF, "Part":proteins.iloc[i].Part}
    # translated_sequence = {proteins.index[i]:{"Sequence":"".join(list(iterative_translate("".join(proteins.iloc[i].Locations.extract(sequence)),truncate_proteins=True))), "ORF":proteins.iloc[i].ORF} }
    return proteins_dict

# Embedding Functions

In [6]:
def embed_sequence(sequence,model,device,model_layers,batch_converter):
    #Sequences to embed (We only embed the reference and use the probabilities from that to generate the scores)
    sequence_data = [('base', sequence)]

    #Get tokens etc
    batch_labels, batch_strs, batch_tokens = batch_converter(sequence_data)
    batch_len = (batch_tokens != alphabet.padding_idx).sum(1)[0]

    #Move tokens to GPU
    if torch.cuda.is_available():
        batch_tokens = batch_tokens.to(device=device, non_blocking=True)

    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[model_layers], return_contacts=False)
    del batch_tokens

    #Embed Sequences
    token_representation = results["representations"][model_layers][0]
    full_embedding = token_representation[1:batch_len - 1].cpu()
    base_mean_embedding  = token_representation[1 : batch_len - 1].mean(0).cpu()

    #Get Embedding and probabilities for reference sequence (Should be first sequence in data)
    lsoftmax = torch.nn.LogSoftmax(dim=1)
    base_logits = lsoftmax((results["logits"][0]).to(device="cpu"))
    return results, base_logits, base_mean_embedding,full_embedding

def process_protein_sequence(sequence,model,model_layers,batch_converter):
    #Embed Sequence
    base_seq = sequence
    results,base_logits, base_mean_embedding, full_embedding = embed_sequence(base_seq,model,device,model_layers,batch_converter)
    results_dict = {}
    results_dict["Mean_Embedding"] = base_mean_embedding.tolist()
    # results_dict["Full_Embedding"] = full_embedding.tolist()
    results_dict["Logits"] = base_logits.tolist()
    return results_dict

# Scoring Functions

In [7]:
def grammaticality_and_evolutionary_index(word_pos_prob, seq, mutations):
    if len(mutations) == 0:
        print('No mutations detected')
        return 0, 0
    mut_probs = []
    ev_ratios = []
    current_support = -1
    print('Mutations: ', mutations)
    for mutation in mutations:
        #Ignore insertions
        if 'ins' not in mutation and 'del' not in mutation and "X" not in mutation:
            #Split mutation 
            aa_orig = mutation[0]
            aa_pos = int(mutation[1:-1]) - 1
            aa_mut = mutation[-1]
            if (seq[aa_pos] != aa_orig):
                print(mutation)
            assert(seq[aa_pos] == aa_orig)

            #Get probabilities for changes
            prob_change = word_pos_prob[(aa_mut, aa_pos)]
            prob_original = word_pos_prob[(aa_orig, aa_pos)]
            #Log probabilities to allow for subtraction
            ev_ratio = prob_change - prob_original
            ev_ratios.append(ev_ratio)

            #Log probabilities to allow for sum rather than product
            mut_probs.append(word_pos_prob[(aa_mut, aa_pos)])
    return np.sum(mut_probs), np.sum(ev_ratios)

# Genbank Funnctions

In [8]:
from time import sleep
# from tqdm import tqdm
from tqdm.notebook import tqdm

def process_sequence_genbank(sequence,genbank,model,model_layers):
    #Translate nucleotide to proteins using genbank
    Coding_Regions= translate_with_genbank(sequence,genbank)
    Mature_Proteins= translate_mat_proteins_with_genbank(sequence,genbank)
    polyprotein_orfs =set([Mature_Proteins[prot]["ORF"] for prot in Mature_Proteins.keys()])
    Filtered_Coding_Regions = {**Coding_Regions}
    for orf in Coding_Regions.keys():
        if Coding_Regions[orf]["ORF"] in polyprotein_orfs:
            del Filtered_Coding_Regions[orf]
    Merged_Coding_Regions = {**Filtered_Coding_Regions,**Mature_Proteins}
    #Embed Sequence
    for key,value in Merged_Coding_Regions.items():
        base_seq = Merged_Coding_Regions[key]["Sequence"]
        results,base_logits, base_mean_embedding, full_embedding = embed_sequence(base_seq,model,device,model_layers,batch_converter)
        word_pos_prob = {}
        for pos in range(len(base_seq)):
            for word in alphabet.all_toks:
                word_idx = alphabet.get_idx(word)
                prob = base_logits[pos + 1, word_idx]
                word_pos_prob[(word, pos)] = prob
        value["Mean_Embedding"] = base_mean_embedding.tolist()
        # value["Full_Embedding"] = full_embedding.tolist()
        value["Logits"] = base_logits.tolist()

    all_embeddings = [np.array(Merged_Coding_Regions [key]["Mean_Embedding"]) for key in Merged_Coding_Regions.keys()]
#     Merged_Coding_Regions ["Sum_Embedding"] = list(np.sum(all_embeddings,axis=0))
#     Merged_Coding_Regions ["Concatenated_Embedding"] = list(np.concatenate(all_embeddings))
    return Merged_Coding_Regions


def get_sequence_grammaticality(sequence,sequence_logits):   
    prob_list = []
    sequence_logits = torch.FloatTensor(sequence_logits)
    for pos in range(len(sequence)):
        word_idx = alphabet.get_idx(sequence[pos])
        word = sequence_logits[(pos + 1,word_idx)]
        prob_list.append(word)
    base_grammaticality =np.sum(prob_list)
    return base_grammaticality


def process_and_dms_sequence_genbank(sequence,genbank,model,model_layers,specify_orf=""):
    #Translate nucleotide to proteins using genbank
    Coding_Regions= translate_with_genbank(sequence,genbank)
    Mature_Proteins= translate_mat_proteins_with_genbank(sequence,genbank)
    polyprotein_orfs =set([Mature_Proteins[prot]["ORF"] for prot in Mature_Proteins.keys()])
    Filtered_Coding_Regions = {**Coding_Regions}
    for orf in Coding_Regions.keys():
        if Coding_Regions[orf]["ORF"] in polyprotein_orfs:
            del Filtered_Coding_Regions[orf]
    Merged_Coding_Regions = {**Filtered_Coding_Regions,**Mature_Proteins}
    embeddings = {}
    if specify_orf !="":
        Merged_Coding_Regions = {specify_orf:Merged_Coding_Regions[specify_orf]}
    #Embed Sequence
    for key,value in Merged_Coding_Regions.items():
        embeddings[key] = {}
        base_seq = Merged_Coding_Regions[key]["Sequence"]
        results,base_logits, base_mean_embedding, full_embedding = embed_sequence(base_seq,model,device,model_layers,batch_converter)
        word_pos_prob = {}
        for pos in range(len(base_seq)):
            for word in alphabet.all_toks:
                word_idx = alphabet.get_idx(word)
                prob = base_logits[pos + 1, word_idx]
                word_pos_prob[(word, pos)] = prob
        embeddings[key]["Reference"] = {"Mean_Embedding":base_mean_embedding.tolist(),
                                        "Logits":base_logits.tolist(),
                                        "sequence_grammaticality":get_sequence_grammaticality(base_seq,base_logits)
                                     }
        # Now DMS the sequence and embed and measure to reference
        sequences = DMS(base_seq)
        for fasta in tqdm(sequences):
            name, sequence = fasta.id, str(fasta.seq)
#             print(key,name)
            mutations = [name]
            embeddings[key][name] = process_protein_sequence(sequence,model,model_layers,batch_converter)
            # L1/Manhattan Distance between mean embeddings used for the semantic change
            semantic_change = float(sum(abs(target-base) for target, base in zip(embeddings[key]["Reference"]["Mean_Embedding"],
                                                                                 embeddings[key][name] ["Mean_Embedding"])))
            gm, ev = grammaticality_and_evolutionary_index(word_pos_prob, base_seq, mutations)
#             print('Semantic score: ', semantic_change)
#             print('Grammaticality: ', gm)
#             print('Relative Grammaticality: ', ev)
            embeddings[key][name]["label"] = name
            embeddings[key][name]["semantic_score"] = semantic_change
            #Probability of mutation, given the reference sequence
            embeddings[key][name]["grammaticality"] = gm
            embeddings[key][name]["relative_grammaticality"] = ev
            #Probability of whole sequence
            embeddings[key][name]['sequence_grammaticality'] = get_sequence_grammaticality(sequence,embeddings[key][name]['Logits'])
#             print('Sequence Grammaticality: ', embeddings[key][name]['sequence_grammaticality'])
            #Probability ratio between the mutant sequence and the reference sequence
            embeddings[key][name]['relative_sequence_grammaticality'] = embeddings[key][name]['sequence_grammaticality']-embeddings[key]["Reference"]['sequence_grammaticality']
#             print('Relative Sequence Grammaticality: ', embeddings[key][name]['relative_sequence_grammaticality'])
            embeddings[key][name]["probability"] = np.exp(gm)
#             print(embeddings[key][name]['grammaticality'])
    return embeddings


def get_mutations(seq1, seq2):
    mutations = []
    for i in range(len(seq1)):
        if seq1[i] != seq2[i]:
            if seq1[i] != '-' and seq2[i] == '-':
                mutations.append('{}{}del'.format(seq1[i], i + 1))
            else:
                mutations.append('{}{}{}'.format(seq1[i] , i + 1, seq2[i]))
    return mutations

In [9]:
def single_protein_DMS(key,protein_sequence,model,model_layers,batch_converter,alphabet,device,start,end):
    embeddings = {}
    embeddings[key] = {}
    base_seq = protein_sequence
    results,base_logits, base_mean_embedding, full_embedding = embed_sequence(base_seq,model,device,model_layers,batch_converter)
    word_pos_prob = {}
    for pos in range(len(base_seq)):
        for word in alphabet.all_toks:
            word_idx = alphabet.get_idx(word)
            prob = base_logits[pos + 1, word_idx]
            word_pos_prob[(word, pos)] = prob
    embeddings[key]["Reference"] = {"Mean_Embedding":base_mean_embedding.tolist(),
                                    "Logits":base_logits.tolist(),
                                    "sequence_grammaticality":get_sequence_grammaticality(base_seq,base_logits)
                                }
    # Now DMS the sequence and embed and measure to reference
    sequences = DMS(protein_sequence,start,end)
    for fasta in tqdm(sequences):
        name, sequence = fasta.id, str(fasta.seq)
#             print(key,name)
        mutations = [name]
        embeddings[key][name] = process_protein_sequence(sequence,model,model_layers,batch_converter)
        # L1/Manhattan Distance between mean embeddings used for the semantic change
        semantic_change = float(sum(abs(target-base) for target, base in zip(embeddings[key]["Reference"]["Mean_Embedding"],
                                                                            embeddings[key][name] ["Mean_Embedding"])))
        gm, ev = grammaticality_and_evolutionary_index(word_pos_prob, base_seq, mutations)
#             print('Semantic score: ', semantic_change)
#             print('Grammaticality: ', gm)
#             print('Relative Grammaticality: ', ev)
        embeddings[key][name]["label"] = name
        embeddings[key][name]["semantic_score"] = semantic_change
        #Probability of mutation, given the reference sequence
        embeddings[key][name]["grammaticality"] = gm
        embeddings[key][name]["relative_grammaticality"] = ev
        #Probability of whole sequence
        embeddings[key][name]['sequence_grammaticality'] = get_sequence_grammaticality(sequence,embeddings[key][name]['Logits'])
#             print('Sequence Grammaticality: ', embeddings[key][name]['sequence_grammaticality'])
        #Probability ratio between the mutant sequence and the reference sequence
        embeddings[key][name]['relative_sequence_grammaticality'] = embeddings[key][name]['sequence_grammaticality']-embeddings[key]["Reference"]['sequence_grammaticality']
#             print('Relative Sequence Grammaticality: ', embeddings[key][name]['relative_sequence_grammaticality'])
        embeddings[key][name]["probability"] = np.exp(gm)
#             print(embeddings[key][name]['grammaticality'])
    return embeddings

In [10]:
torch.cuda.is_available()

False

# Load Model into GPU

In [11]:
model, alphabet = esm.pretrained.load_model_and_alphabet("esm2_t33_650M_UR50D") # 'esm2_t36_3B_UR50D' is too large for my system atm
model.eval()
batch_converter = alphabet.get_batch_converter()
device = torch.device("cuda")
if torch.cuda.is_available():
    model =  model.to(device)
    print("Transferred model to GPU")

# Download Reference Sequence and Embed

In [12]:
model_layers = 33 # smaller model only has 22 layers, not 36 

In [13]:
reference_protein = 'TTSAGESADPVTATVENYGGETQVQRRQHTDIAFILDRFVKVKPKEQVNVLDLMQIPAHTLVGALLRTATYYFSDLELAVKHEGDLTWVPNGAPETALDNTTNPTAYHKEPLTRLALPYTAPHRVLATVYNGSSKYGDTSTNNVRGDLQVLAQKAERTLPTSFNFGAIKATRVTELLYRMKRAETYCPRPLLAIQPSDARHKQRIVAPAKQ'

In [14]:
dms_results = single_protein_DMS('FMDV_Reference_VP1',reference_protein,model,model_layers,batch_converter,alphabet,device,0,6) # cap to the first 6 positions instead of None

  0%|          | 0/140 [00:00<?, ?it/s]

Mutations:  ['T1A']
Mutations:  ['T1R']
Mutations:  ['T1N']
Mutations:  ['T1D']
Mutations:  ['T1C']
Mutations:  ['T1Q']
Mutations:  ['T1E']
Mutations:  ['T1G']
Mutations:  ['T1H']
Mutations:  ['T1I']
Mutations:  ['T1L']
Mutations:  ['T1K']
Mutations:  ['T1M']
Mutations:  ['T1F']
Mutations:  ['T1P']
Mutations:  ['T1S']
Mutations:  ['T1T']
Mutations:  ['T1W']
Mutations:  ['T1Y']
Mutations:  ['T1V']
Mutations:  ['T2A']
Mutations:  ['T2R']
Mutations:  ['T2N']
Mutations:  ['T2D']
Mutations:  ['T2C']
Mutations:  ['T2Q']
Mutations:  ['T2E']
Mutations:  ['T2G']
Mutations:  ['T2H']
Mutations:  ['T2I']
Mutations:  ['T2L']
Mutations:  ['T2K']
Mutations:  ['T2M']
Mutations:  ['T2F']
Mutations:  ['T2P']
Mutations:  ['T2S']
Mutations:  ['T2T']
Mutations:  ['T2W']
Mutations:  ['T2Y']
Mutations:  ['T2V']
Mutations:  ['S3A']
Mutations:  ['S3R']
Mutations:  ['S3N']
Mutations:  ['S3D']
Mutations:  ['S3C']
Mutations:  ['S3Q']
Mutations:  ['S3E']
Mutations:  ['S3G']
Mutations:  ['S3H']
Mutations:  ['S3I']


In [15]:
dms_results['FMDV_Reference_VP1'].keys()

dict_keys(['Reference', 'T1A', 'T1R', 'T1N', 'T1D', 'T1C', 'T1Q', 'T1E', 'T1G', 'T1H', 'T1I', 'T1L', 'T1K', 'T1M', 'T1F', 'T1P', 'T1S', 'T1T', 'T1W', 'T1Y', 'T1V', 'T2A', 'T2R', 'T2N', 'T2D', 'T2C', 'T2Q', 'T2E', 'T2G', 'T2H', 'T2I', 'T2L', 'T2K', 'T2M', 'T2F', 'T2P', 'T2S', 'T2T', 'T2W', 'T2Y', 'T2V', 'S3A', 'S3R', 'S3N', 'S3D', 'S3C', 'S3Q', 'S3E', 'S3G', 'S3H', 'S3I', 'S3L', 'S3K', 'S3M', 'S3F', 'S3P', 'S3S', 'S3T', 'S3W', 'S3Y', 'S3V', 'A4A', 'A4R', 'A4N', 'A4D', 'A4C', 'A4Q', 'A4E', 'A4G', 'A4H', 'A4I', 'A4L', 'A4K', 'A4M', 'A4F', 'A4P', 'A4S', 'A4T', 'A4W', 'A4Y', 'A4V', 'G5A', 'G5R', 'G5N', 'G5D', 'G5C', 'G5Q', 'G5E', 'G5G', 'G5H', 'G5I', 'G5L', 'G5K', 'G5M', 'G5F', 'G5P', 'G5S', 'G5T', 'G5W', 'G5Y', 'G5V', 'E6A', 'E6R', 'E6N', 'E6D', 'E6C', 'E6Q', 'E6E', 'E6G', 'E6H', 'E6I', 'E6L', 'E6K', 'E6M', 'E6F', 'E6P', 'E6S', 'E6T', 'E6W', 'E6Y', 'E6V', 'S7A', 'S7R', 'S7N', 'S7D', 'S7C', 'S7Q', 'S7E', 'S7G', 'S7H', 'S7I', 'S7L', 'S7K', 'S7M', 'S7F', 'S7P', 'S7S', 'S7T', 'S7W', 'S7Y', 'S7

In [16]:
compressed_pickle('FMDV_Reference_VP1',dms_results)


In [17]:
dms_results=decompress_pickle('FMDV_Reference_VP1.pbz2')

In [18]:
mutations_list = list(dms_results['FMDV_Reference_VP1'].keys())
columns = ['label', 'semantic_score', 'grammaticality', 'relative_grammaticality', 'sequence_grammaticality', 'relative_sequence_grammaticality', 'probability']
table = []
for key in mutations_list:
    if key != 'Reference':
        row = pd.DataFrame([dms_results['FMDV_Reference_VP1'][key].get(c) for c in columns]).T
        row.columns = columns
        table.append(row)

In [19]:
dms_table = pd.concat(table)

# Annotate table

In [20]:
dms_table['ref'] = dms_table.label.str[0]
dms_table['alt'] = dms_table.label.str[-1]
dms_table['position'] = dms_table.label.str[1:-1].astype(int)

#Keep Reference scores
reference_s_table = dms_table[dms_table.ref == dms_table.alt]
#Filter non mutations
dms_table = dms_table[dms_table.ref != dms_table.alt]


dms_table = dms_table.sort_values('semantic_score')
dms_table['semantic_rank'] = dms_table.reset_index().index.astype(int) + 1
dms_table = dms_table.sort_values('grammaticality')
dms_table['grammatical_rank'] =dms_table .reset_index().index.astype(int) + 1
dms_table['acquisition_priority'] = dms_table['semantic_rank'] + dms_table['grammatical_rank']

dms_table = dms_table.sort_values('sequence_grammaticality')
dms_table['sequence_grammatical_rank'] =dms_table.reset_index().index.astype(int) + 1
dms_table['sequence_acquisition_priority'] = dms_table['semantic_rank'] + dms_table['sequence_grammatical_rank']



In [21]:
def fmdv_domain_annotation(position):
    if position>=137 and position <=143: # 6 positions we want to focus on mutations of?
        return "GH Loop"
    elif position>=145 and position <=147:
        return "RDG Motif"
    else:
        return ""


In [22]:
dms_table["Domain"] = ""
dms_table["Domain"] = [ fmdv_domain_annotation(pos) for pos in dms_table["position"] ]

In [23]:
dms_table.sort_values('position').to_csv('FMDV_Reference_VP1.csv',index=False)