In [2]:
import json
from collections import defaultdict, Counter
from string import digits

import matplotlib.pyplot as plt
import pandas as pd
from Bio import SeqIO, Phylo, Seq

In [55]:
def CDS_list(reference):
    """Return every sequence position covered by a CDS feature of ``reference``.

    Positions are gathered feature by feature, in the order the features appear and in
    the order each feature's location yields them. Positions shared by overlapping CDS
    features therefore appear more than once. Non-CDS features are ignored.
    """
    return [pos for feat in reference.features if feat.type == 'CDS' for pos in feat.location]

def before_and_after(reference, nr, location, nucl):
    """Count the nucleotides found ``nr`` positions before or after each ``nucl`` inside CDSs.

    Parameters
    ----------
    reference : str
        GenBank file, parsed with ``SeqIO.read(reference, "genbank")``.
    nr : int
        Offset, in nucleotides, of the neighbouring position to inspect.
    location : str
        ``"before"`` inspects position ``i - nr``; ``"after"`` inspects ``i + nr``.
        Any other value yields an empty Counter.
    nucl : str
        Focal nucleotide (a single character) whose neighbours are counted.

    Returns
    -------
    collections.Counter
        Maps each neighbouring nucleotide to its number of occurrences. A focal position is
        used only if it lies in a CDS and its neighbour lies in a CDS. Neighbours equal to
        "N" are skipped.
    """
    ref = SeqIO.read(reference, "genbank")
    # A set turns each membership test into O(1); a list scan per genome position is quadratic.
    cds_positions = set(CDS_list(ref))
    seq = str(ref.seq)

    if location == "before":
        offset = -nr
    elif location == "after":
        offset = nr
    else:
        return Counter()

    all_motifs = []
    for i, nu in enumerate(seq):
        if nu != nucl or i not in cds_positions:
            continue
        j = i + offset
        # Check and record the SAME neighbouring position. The earlier version mixed i±1 and
        # i±nr. Negative j can never be in cds_positions, so no wrap-around is possible.
        if j in cds_positions and seq[j] != "N":
            all_motifs.append(seq[j])
    return Counter(all_motifs)


def CDS_finder(reference):
    """Map each CDS feature's first ``gene`` qualifier to the list of positions it covers.

    Only features whose ``type`` is ``'CDS'`` are considered. If two CDS features carry
    the same gene name, the later one wins. A CDS feature without a ``gene`` qualifier
    raises KeyError.
    """
    return {
        feat.qualifiers['gene'][0]: list(feat.location)
        for feat in reference.features
        if feat.type == 'CDS'
    }


# Reference GenBank record, parsed once at module level. mutation_recursive reads this
# global (via CDS_finder(ref_file)) to obtain CDS coordinates per gene.
ref_file = SeqIO.read("data/areference.gbk", "genbank")


def mutation_recursive(node, dictionary_=None, new_=None, gene_cds=None):
    """Collect, per named node, the nucleotide mutations on its branch that hit no amino-acid change.

    Walks the tree dict (as found under ``f['tree']`` in the loaded JSON) depth-first.
    For a node whose ``branch_attrs`` has a ``mutations`` entry, each nucleotide mutation
    string in ``mutations['nuc']`` (e.g. ``"A123G"``, 1-based position) is kept unless:

    * either allele is not one of A/C/G/T (deletions ``-``, ``*``, N and all other IUPAC
      ambiguity codes), or
    * its position lies in a codon that also carries an amino-acid mutation on the same
      branch (listed under a gene key of ``mutations``, e.g. ``"K41R"``).

    Parameters
    ----------
    node : dict
        Tree node; may contain ``name``, ``branch_attrs`` and ``children``.
    dictionary_ : dict, optional
        Accumulator shared across the recursion; created when omitted.
    new_ : list, optional
        Value stored for a node whose branch has no ``mutations`` entry; defaults to ``[]``.
    gene_cds : dict, optional
        Gene name -> list of CDS positions, as returned by ``CDS_finder``. When omitted, it
        is computed once from the module-level ``ref_file`` and reused for every descendant.

    Returns
    -------
    dict
        Node name -> list of retained nucleotide mutation strings.
    """
    if dictionary_ is None:
        dictionary_ = dict()
    if gene_cds is None:
        gene_cds = CDS_finder(ref_file)
    if new_ is None:
        new_ = []

    branch_attrs = node.get('branch_attrs', {})
    if 'mutations' in branch_attrs:
        mutations = branch_attrs['mutations']
        new_ = []
        if 'nuc' in mutations:
            # 1-based genome positions of every codon touched by an amino-acid change.
            # NOTE(review): assumes each CDS is a contiguous forward-strand interval starting
            # at gene_cds[gene][0] (0-based). Under that assumption, codon k spans 1-based
            # positions start+3k-2 .. start+3k.
            coding_positions = set()
            for gene, loc in gene_cds.items():
                for aa_mut in mutations.get(gene, []):
                    codon_end = int(aa_mut[1:-1]) * 3 + loc[0]
                    coding_positions.update((codon_end - 2, codon_end - 1, codon_end))
            for mut in mutations['nuc']:
                ref_base, alt_base = mut[0], mut[-1]
                # Whitelisting A/C/G/T also drops ambiguity codes (K, S, W, B, H, V) that a
                # blacklist of specific letters let through.
                if ref_base in 'ACGT' and alt_base in 'ACGT' and int(mut[1:-1]) not in coding_positions:
                    new_.append(mut)

    if 'name' in node:
        dictionary_[node['name']] = new_
    for child in node.get('children', []):
        mutation_recursive(child, dictionary_, new_=None, gene_cds=gene_cds)
    return dictionary_

# ref_file is already parsed earlier in this cell (before mutation_recursive is defined);
# parsing the GenBank reference a second time here was redundant.
with open("data/rsv_a_genome.json") as file_:
    f = json.load(file_)
# The JSON is fully in memory, so walk the tree after the file handle is closed.
mut_by_node = mutation_recursive(f['tree'])

def find_relevant_nucl(sequencefile, dictionarysynonymous, motif, where):
    """Count the nucleotides adjacent to mutations of one type in reconstructed node sequences.

    Parameters
    ----------
    sequencefile : str, path-like, file handle, or iterable of records
        FASTA of reconstructed sequences, parsed with ``SeqIO.parse(..., "fasta")``.
        An already-parsed iterable of records (objects with ``.id`` and ``.seq``) is also
        accepted, so callers can avoid re-reading the file for every motif.
    dictionarysynonymous : dict
        Record id -> list of nucleotide mutation strings such as ``"T123C"`` (1-based
        position), e.g. the output of ``mutation_recursive``. Records without an entry are
        skipped.
    motif : str
        Two-letter ``"<from><to>"`` code. Only mutations whose first and last characters
        form this code are used (``"TC"`` selects T->C).
    where : str
        ``"before"`` inspects the position preceding the mutated site; ``"after"`` inspects
        the position following it.

    Returns
    -------
    collections.Counter
        Neighbouring nucleotide -> count. A neighbour is not counted if it falls outside
        the sequence, or if it is not A/C/G/T (gaps, N and other ambiguity codes).

    Raises
    ------
    ValueError
        If ``where`` is neither ``"before"`` nor ``"after"``.
    """
    offsets = {"before": -1, "after": 1}
    if where not in offsets:
        raise ValueError(f"where must be 'before' or 'after', got {where!r}")
    step = offsets[where]

    # Filenames, path-likes and open handles go through SeqIO; anything else is treated as
    # an iterable of already-parsed records.
    if isinstance(sequencefile, (str, bytes)) or hasattr(sequencefile, "__fspath__") or hasattr(sequencefile, "read"):
        records = SeqIO.parse(sequencefile, "fasta")
    else:
        records = sequencefile

    output = []
    for record in records:
        seq = record.seq
        for mut in dictionarysynonymous.get(record.id, []):
            if f'{mut[0]}{mut[-1]}' != motif:
                continue
            # Mutation positions are 1-based; sequence indices are 0-based.
            neighbour = int(mut[1:-1]) - 1 + step
            # Bounds check: index -1 would otherwise silently read the LAST base, and an
            # index past the end would raise IndexError.
            if 0 <= neighbour < len(seq):
                nucleotide = seq[neighbour]
                # Only unambiguous bases. The reference background counts used for scaling
                # have no entries for other symbols, which would otherwise cause a division by zero.
                if nucleotide in ('A', 'C', 'G', 'T'):
                    output.append(nucleotide)
    return Counter(output)


# T->C: for each T->C substitution, count the nucleotide preceding it in the reconstructed
# sequences. Normalize by how often that nucleotide precedes a T inside reference CDSs,
# and by total tree length.
all_ = find_relevant_nucl("data/reconstructed_sequences.fasta", mut_by_node, "TC", "before")
scale = before_and_after("data/areference.gbk", 1, "before", "T")

# Number of reconstructed sequences; only used by the alternative normalization noted below.
length = sum(1 for _ in SeqIO.parse("data/reconstructed_sequences.fasta", "fasta"))

tree_ = Phylo.read("data/a_tree.nwk", "newick")
treelength = tree_.total_branch_length()
# Rates are divided by the TOTAL BRANCH LENGTH of the tree, not by the number of branches.
# Alternative per-sequence normalization: float(all_[x]) / (scale[x] * length)
# Contexts absent from the reference background (scale[x] == 0) are skipped to avoid
# a ZeroDivisionError.
scaled_ = {x: float(all_[x]) / (scale[x] * treelength) for x in all_ if scale[x]}
print("nucleotide before T to C", scaled_)

# C->T: same computation for the reverse transition.
all_ = find_relevant_nucl("data/reconstructed_sequences.fasta", mut_by_node, "CT", "before")
scale = before_and_after("data/areference.gbk", 1, "before", "C")
scaled_ = {x: float(all_[x]) / (scale[x] * treelength) for x in all_ if scale[x]}
print("nucleotide before C to T", scaled_)

1.1469072199999948
nucleotide before T to C {'A': 0.8858364517246887, 'C': 1.7750104997364882, 'G': 0.8921508800984707, 'T': 0.7986962836899302}
nucleotide before C to T {'T': 1.3640732062049785, 'C': 2.559275717665431, 'A': 1.4251365484155716, 'G': 1.1260229491513583}


In [68]:
# DataFrame of normalized context rates (not raw counts). Rows are the nucleotide found just
# before the mutated site; columns are "<from><to>" substitution types.
combos = ["CT", "CA", "CG", "TC", "TA", "TG", "AC", "AG", "AT", "GC", "GT", "GA"]
# The reference background depends only on the source nucleotide. Compute it once per
# nucleotide instead of re-parsing the GenBank file for each of the 12 combos.
background = {
    nt: before_and_after("data/areference.gbk", 1, "before", nt)
    for nt in sorted({combo[0] for combo in combos})
}
all_dictionaries = dict()
for mut in combos:
    all_ = find_relevant_nucl("data/reconstructed_sequences.fasta", mut_by_node, mut, "before")
    scale = background[mut[0]]
    # Skip contexts absent from the reference background (scale[x] == 0) to avoid a
    # ZeroDivisionError. Counter returns 0 for missing keys without inserting them.
    scaled_ = {x: float(all_[x]) / (scale[x] * treelength) for x in all_ if scale[x]}
    all_dictionaries[mut] = scaled_

# One row per combo, then transpose so the combos become the columns.
df = pd.DataFrame(list(all_dictionaries.values()), index=list(all_dictionaries)).T
print(df)

         CT        CA        CG        TC        TA        TG        AC  \
T  1.364073  0.302975  0.019193  0.798696  0.177805  0.038033  0.035426   
C  2.559276  0.401455  0.010455  1.775010  0.320409  0.052456  0.082893   
A  1.425137  0.131470  0.003155  0.885836  0.050861  0.012715  0.040429   
G  1.126023  0.107240  0.006994  0.892151  0.068507  0.026469  0.010061   

         AG        AT        GC        GT        GA  
T  0.489455  0.101940  0.004529  0.106441  0.777925  
C  0.358435  0.290125  0.009797  0.215528  4.075445  
A  0.435468  0.036533  0.001015  0.036541  0.920632  
G  0.334232  0.070424  0.009634  0.038537  0.674406  
