In [None]:
from Bio import SeqIO
from Bio import AlignIO
from Bio import Phylo
import matplotlib.pyplot as plt
from Bio.Align import MultipleSeqAlignment
from Bio.Seq import Seq
from adjustText import adjust_text
from Bio import SeqIO
from collections import Counter
import pandas as pd
from itertools import product
import re



In [None]:
def root_tree_with_first_sequence(tree, root_sequence):
    """
    Re-roots the tree, in place, on the provided outgroup sequence.

    Args:
    tree: The tree object to be rooted. It must provide a
        ``root_with_outgroup`` method (as Bio.Phylo tree objects do).
    root_sequence (str): The sequence_id to be used as the outgroup/root.

    Returns:
    tree: The same tree object that was passed in, after re-rooting.
    """
    # The rooting call modifies the tree itself; its own return value carries
    # no information (Bio.Phylo's implementation returns None), so the rooted
    # tree is handed back instead of forwarding that value.
    tree.root_with_outgroup(root_sequence)

    return tree



In [None]:
def get_reconstructed_ancestral_sequences(ancestral_state_file):
    """
    Builds one sequence per node from a tab-separated ancestral state table.

    Parameters:
        ancestral_state_file (str): The path to the ancestral state file. The first
            8 lines are skipped; the following line must supply the column names,
            which have to include 'Node' and 'State'.

    Returns:
        dict: Maps each node name to the string obtained by concatenating the
            'State' values of that node's rows in the order they appear in the file.
    """
    # Load the table, ignoring the 8 leading lines that precede the column header
    state_table = pd.read_csv(ancestral_state_file, sep='\t', skiprows=8)

    # One entry per node: join the per-row states of that node into a single string
    return {
        node_name: ''.join(node_rows['State'])
        for node_name, node_rows in state_table.groupby('Node')
    }


In [None]:
# Counts the k-mers (e.g. trinucleotides) of a sequence; the counts are used in the normalization step while computing the trinucleotide mutation rate
def get_kmer_frequency(sequence, k):
    """
    Count the occurrences of every possible k-mer in a given sequence.

    Gap characters ("-") are removed before counting, and the sequence is
    upper-cased so that lower-case input is counted the same way as upper-case
    input. Windows containing any character other than A, C, G or T are skipped.

    Args:
        sequence (str): The input nucleotide sequence (may contain gaps).
        k (int): The length of the k-mer.

    Returns:
        kmers_frequency_dict: A dictionary containing all possible k-mers over
        A/C/G/T as keys and their occurrence counts in the input sequence as
        values (0 for k-mers that never occur).
    """
    nucleotides = ["C", "T", "G", "A"]
    valid_bases = set(nucleotides)

    # Normalise case and remove gaps so that counting is done on the ungapped sequence
    sequence = str(sequence.upper()).replace("-", "")

    # Start every possible k-mer at 0 so that absent k-mers are still present as keys
    kmers_frequency_dict = {''.join(combo): 0 for combo in product(nucleotides, repeat=k)}

    # Slide a window of width k over the sequence; a sequence shorter than k yields no windows
    for start in range(len(sequence) - k + 1):
        kmer = sequence[start:start + k]
        # Only windows made up purely of A, C, G, T are counted
        if valid_bases.issuperset(kmer):
            kmers_frequency_dict[kmer] += 1

    return kmers_frequency_dict

In [None]:
def find_substitution_mutations(ancestor_seq, sample_seq):
    """
    To compute the substitution mutations of a sequence relative to its nearest reconstructed ancestral sequence

    Both sequences are upper-cased before comparison, so a difference in letter
    case alone is not reported as a mutation. The two sequences are compared
    column by column; if their lengths differ, the extra columns of the longer
    one are ignored.

    A mutation is skipped when:
      - it lies in the first or last column (no complete trinucleotide context),
      - the ancestral or sample base is not one of A, C, G, T,
      - the ancestral trinucleotide around it contains anything other than A, C, G, T
        (e.g. a gap or an ambiguous base).

    Args:
        ancestor_seq (str): The ancestor sequence.
        sample_seq (str): The sample sequence.

    Returns:
        mutations: A list of mutations, where each mutation is a list
        ``[mutation, trinucleotide, rate]``: the mutation written as
        ``<ancestral base><1-based alignment position><sample base>``, the
        ancestral trinucleotide centred on that position, and 1 divided by the
        number of times that trinucleotide occurs in the ancestor sequence.
    """
    valid_bases = {"C", "T", "G", "A"}

    # Normalise case once, so that the comparison, the context and the frequency
    # table all see the same letters
    ancestor_seq = str(ancestor_seq.upper())
    sample_seq = str(sample_seq.upper())

    # Initialize the list to store the mutations
    mutations = []
    # Occurrence count of every trinucleotide in the (ungapped) ancestor sequence
    trinucleotide_mutation_frequency = get_kmer_frequency(ancestor_seq, 3)

    # Iterate through each position and base in the two sequences
    for i, (ref_base, sample_base) in enumerate(zip(ancestor_seq, sample_seq)):
        if ref_base == sample_base:
            continue
        # Ignore mutations where either base is not a regular nucleotide
        if ref_base not in valid_bases or sample_base not in valid_bases:
            continue
        # The first column has no left neighbour, so no trinucleotide context exists
        if i == 0:
            continue
        # Trinucleotide context: the ancestral base with one neighbour on each side
        trinucleotide = ancestor_seq[i - 1:i + 2]
        # A short slice means the last column; other characters mean gaps/ambiguity codes
        if len(trinucleotide) != 3 or not valid_bases.issuperset(trinucleotide):
            continue

        # Position of the mutation in the alignment is index + 1
        position = i + 1
        mutation = f"{ref_base}{position}{sample_base}"
        # The context was taken from the ancestor itself, so its count is at least 1
        mutations.append([mutation, trinucleotide, 1 / trinucleotide_mutation_frequency[trinucleotide]])

    return mutations


In [None]:
def find_ancestral_node(tree, node_name):
    """
    To find the direct parent (nearest ancestor) of a named node in the tree

    Args:
        tree: The phylogenetic tree to search in.
        node_name: The name of the node whose nearest ancestor is wanted. Clade
            names are compared after dropping everything from the first "/" onwards.

    Returns:
        ancestral_node: The first clade, in the tree's traversal order, that has the
        target node among the items it iterates over; None if the node is not
        found or no clade contains it.
    """
    # Locate the clade carrying the requested name
    target_node = None
    for candidate in tree.find_clades():
        if str(candidate.name).split("/")[0] == node_name:
            target_node = candidate
            break

    if target_node is None:
        print(f"Node '{node_name}' not found.")
        return None

    # Walk the tree again and stop at the first clade that contains the target
    return next(
        (clade for clade in tree.find_clades() if target_node in clade),
        None,
    )



In [None]:

def get_mutation_data_for_internal_nodes(ancestral_sequences, alignment_file, tree):
    """
    To compute the mutations for the internal nodes of the tree

    Nodes are processed in ascending order of the first number found in their
    name. "Node1" is compared against the first record of the alignment file;
    every other node is compared against the sequence of its parent in the tree.
    A node's ancestral mutations are read from its parent's entries, so the
    parent has to have been processed earlier in that order.

    Args:
        ancestral_sequences (dict): Dictionary of reconstructed ancestral sequences.
        alignment_file (str): The path to the multiple sequence alignment file (FASTA).
        tree: The phylogenetic tree.

    Returns:
        - denovo_mutations: A dictionary mapping each internal node to its denovo mutations.
        - ancestral_mutations: A dictionary mapping each internal node to the mutations
          accumulated along its ancestors (parent's denovo + parent's ancestral mutations).
        - ancestral_node_data: A dictionary mapping each internal node to its nearest
          ancestral node (for "Node1", the id of the reference record).
    """
    denovo_mutations = {}
    ancestral_mutations = {}
    ancestral_node_data = {}

    def node_number(node):
        # Sort key: the first run of digits in the node name, as an integer
        return int(re.search(r'\d+', node).group())

    for node in sorted(ancestral_sequences, key=node_number):
        if node != "Node1":
            # Regular internal node: compare against its parent in the tree
            ancestral_node = find_ancestral_node(tree, node)
            # Clade names may carry extra "/"-separated values; keep only the node name
            parent_name = str(ancestral_node).split("/")[0]
            denovo_mutations[node] = find_substitution_mutations(
                ancestral_sequences[parent_name], ancestral_sequences[node]
            )
            ancestral_mutations[node] = denovo_mutations[parent_name] + ancestral_mutations[parent_name]
            ancestral_node_data[node] = ancestral_node
            continue

        # Initial node: the first alignment record (the reference) acts as its ancestor
        reference_record = next(SeqIO.parse(alignment_file, "fasta"), None)
        reference_id = reference_record.id
        reference_seq = str(reference_record.seq).upper()
        denovo_mutations[node] = find_substitution_mutations(reference_seq, ancestral_sequences[node])
        ancestral_mutations[node] = []
        ancestral_node_data["Node1"] = reference_id

    return denovo_mutations, ancestral_mutations, ancestral_node_data


In [None]:

def get_mutation_data_for_leaf_nodes(alignment_file, ancestral_sequences, denovo_mutations, ancestral_mutations, tree, ancestral_node_data):
    """
    To compute the mutations for the leaf nodes of the tree

    The alignment is read in a single streaming pass: the first record is
    treated as the reference sequence and gets empty mutation lists, and every
    following record is compared against the sequence of its parent node.
    The three dictionaries passed in are updated in place and also returned.

    Parameters:
    - alignment_file (str): Path to the multiple sequence alignment file (FASTA).
    - ancestral_sequences (dict): Dictionary of reconstructed ancestral sequences.
    - denovo_mutations (dict): Dictionary of denovo mutations; must already hold the internal nodes.
    - ancestral_mutations (dict): Dictionary of ancestral mutations; must already hold the internal nodes.
    - tree: the Phylogenetic tree object.
    - ancestral_node_data (dict): Dictionary to store ancestral node data.

    Returns:
    - denovo_mutations (dict): Updated dictionary of denovo mutations.
    - ancestral_mutations (dict): Updated dictionary of ancestral mutations.
    - ancestral_node_data (dict): Updated dictionary of ancestral node data for each leaf node.

    Raises:
    - ValueError: if the alignment file contains no records.
    """
    # Parse the file once and consume it lazily, instead of re-reading the whole
    # alignment from the start for every leaf node
    records = SeqIO.parse(alignment_file, "fasta")

    # As the first record is the reference sequence, we need to handle it separately
    reference_record = next(records, None)
    if reference_record is None:
        raise ValueError(f"No sequences found in alignment file '{alignment_file}'.")
    denovo_mutations[reference_record.id] = []
    ancestral_mutations[reference_record.id] = []

    # Compute the mutations for each remaining leaf node
    for record in records:
        leaf_node = record.id
        # Find and store the nearest ancestral node of the leaf node
        ancestral_node = find_ancestral_node(tree, leaf_node)
        ancestral_node_data[leaf_node] = ancestral_node
        # Node names in the tree can carry extra "/"-separated values (e.g. support
        # values added during tree construction), so keep only the node name itself
        parent_name = str(ancestral_node).split("/")[0]
        # Compute the mutations against the nearest ancestral sequence
        denovo_mutations[leaf_node] = find_substitution_mutations(
            ancestral_sequences[parent_name],
            str(record.seq).upper(),
        )
        # Mutations inherited through the ancestors: the parent's own plus the parent's inherited ones
        ancestral_mutations[leaf_node] = denovo_mutations[parent_name] + ancestral_mutations[parent_name]
    return denovo_mutations, ancestral_mutations, ancestral_node_data



We load the precomputed MSA file and the precomputed tree file. Because IQ-TREE does not automatically root the tree on the outgroup, we root it explicitly. We also use the ancestral states at each ancestral node of the tree to reconstruct the most probable ancestral sequences.

In [None]:
# Inputs: the multiple sequence alignment and the id of the outgroup/reference genome
alignment_file = 'SARS-CoV-2_18000+1_msa.fasta'
outgroup_id = "NC_045512.2"

# Read the phylogenetic tree (Newick format) and root it on the outgroup/reference sequence
tree = Phylo.read("Fast IQTREE/fast.treefile", "newick")
root_tree_with_first_sequence(tree, outgroup_id)

# Rebuild the ancestral sequences of the tree from the ancestral state file
ancestral_sequences = get_reconstructed_ancestral_sequences("Fast IQTREE/fast.state")

Next, we compute the mutations by comparing each sequence to its nearest reconstructed ancestor. We first compute the mutations for the internal nodes, and then for the leaf nodes.

In [None]:
# Internal nodes first: the leaf pass below looks up the entries of each leaf's parent node
denovo_mutations, ancestral_mutations, ancestral_node_data = get_mutation_data_for_internal_nodes(
    ancestral_sequences=ancestral_sequences,
    alignment_file=alignment_file,
    tree=tree,
)

# Leaf nodes second: the same three dictionaries are extended with the leaf entries
denovo_mutations, ancestral_mutations, ancestral_node_data = get_mutation_data_for_leaf_nodes(
    alignment_file=alignment_file,
    ancestral_sequences=ancestral_sequences,
    denovo_mutations=denovo_mutations,
    ancestral_mutations=ancestral_mutations,
    tree=tree,
    ancestral_node_data=ancestral_node_data,
)


We are only concerned with the mutations in the leaf nodes, which represent the sampled genomes. Hence, we filter out the data for the internal nodes and the outgroup genome.

In [None]:
# Keep only the leaf nodes (sampled genomes): drop the internal "Node..." entries and the outgroup genome
denovo_mutations_in_sampled_genomes = {key: value for key, value in denovo_mutations.items() if not key.startswith("Node")}
denovo_mutations_in_sampled_genomes.pop(outgroup_id)
# The ancestral dictionary must be filtered from ancestral_mutations (it was previously
# built from denovo_mutations, which made it a duplicate of the de novo data)
ancestral_mutations_in_sampled_genomes = {key: value for key, value in ancestral_mutations.items() if not key.startswith("Node")}
ancestral_mutations_in_sampled_genomes.pop(outgroup_id)


We save these dictionaries for later use

In [None]:
import pickle

# Write each dictionary to its own pickle file in the working directory
for pickle_path, payload in (
    ('denovo_mutations_in_sampled_genomes.pkl', denovo_mutations_in_sampled_genomes),
    ("ancestral_mutations_in_sampled_genomes.pkl", ancestral_mutations_in_sampled_genomes),
):
    with open(pickle_path, 'wb') as f:
        pickle.dump(payload, f)

Next, we use the de novo mutations to compute the overall mutation profile of the 18000 SARS-CoV-2 sampled genomes

In [None]:

def get_mutation_stats_table(mutation_data, total_genomes=18000):
    """
    Calculates the mutation count and the scaled trinucleotide mutation rate for each mutation type.

    For every substitution type (e.g. C->T) the table holds:
      - 'Mutation Count': how many mutations of that type occur over all genomes.
      - 'Mean Trinucleotide Mutation Rate': the sum of the per-mutation normalized
        rates (third element of each mutation entry) multiplied by
        10000 / (total_genomes * 16).
        NOTE(review): the factor 16 is presumably the number of trinucleotide
        contexts per substitution type - confirm.

    Args:
        mutation_data (dict): the dictionary containing de novo mutation data for each sampled genome.
            Each value is a list of entries [mutation, trinucleotide, rate], where the
            first and last characters of ``mutation`` are the reference and alternate base.
        total_genomes (int): number of sampled genomes used to average the rates.
            Defaults to 18000, the size of the data set analysed here.

    Returns:
        overall_mutation_df: A DataFrame indexed by mutation type (with T written as U),
        sorted by mutation count in descending order.

    Raises:
        ValueError: if total_genomes is not a positive number.
    """
    if total_genomes <= 0:
        raise ValueError("total_genomes must be a positive number")

    # One [count, summed scaled rate] pair for each of the 12 possible substitution types
    nucleotides = ["C", "T", "G", "A"]
    mutation_count_dict = {
        f"{reference_base}->{alternate_base}": [0, 0]
        for reference_base in nucleotides
        for alternate_base in nucleotides
        if reference_base != alternate_base
    }

    # Scale factor applied to every per-mutation rate (rates are reported in units of 10^-4)
    scale = 10000 / (total_genomes * 16)

    # Accumulate the count and the scaled rate for each mutation type
    for mutations in mutation_data.values():
        for mutation in mutations:
            mutation_type = mutation[0][0] + "->" + mutation[0][-1]
            mutation_count_dict[mutation_type][0] += 1
            mutation_count_dict[mutation_type][1] += mutation[2] * scale

    # Creating the DataFrame: one row per mutation type after transposing
    overall_mutation_df = pd.DataFrame(mutation_count_dict, index=['Mutation Count', 'Mean Trinucleotide Mutation Rate']).T

    # Naming the index
    overall_mutation_df.index.name = 'Mutation Type'

    # Replacing the T with U in the index for the Mutation Type because we are dealing with RNA virus
    overall_mutation_df.index = overall_mutation_df.index.str.replace('T', 'U')

    # Sorting the DataFrame based on the Mutation Count
    overall_mutation_df = overall_mutation_df.sort_values(by='Mutation Count', ascending=False)

    return overall_mutation_df

# Build the per-mutation-type summary table from the de novo mutations of the sampled genomes
overall_mutation_df=get_mutation_stats_table(denovo_mutations_in_sampled_genomes)

# Total number of mutations, summed over all mutation types
print("Total mutation count:", overall_mutation_df['Mutation Count'].sum())
# Bare expression as the last statement - presumably there so the notebook cell displays the table
overall_mutation_df

In [None]:
def plot_trinucleotide_mutation(denovo_mutations):
    """
    Plots the normalized trinucleotide mutation rate for each mutation type as a single bar chart.

    Every bar is one (trinucleotide, alternate base) pair. Its height is the sum of the
    per-mutation normalized rates (third element of each mutation entry) for that pair,
    multiplied by 10000 and divided by the number of genomes in ``denovo_mutations``.
    Bars are drawn in groups of 16 that share one substitution type; each group has its
    own colour and a heading naming the substitution (with T written as U).

    The lists of bar labels and bar heights are also printed, and the figure is shown
    with ``plt.show()``. Nothing is returned.

    Args:
        denovo_mutations (dict): The dictionary containing the denovo mutations data for the sampled genomes.
            Each value is a list of entries [mutation, trinucleotide, rate], where the last
            character of ``mutation`` is the alternate base.

    """
    # Generate all possible trinucleotide combinations
    # A new base is inserted at the middle of every existing string in each round, with
    # the inserted base as the outermost loop; after the last round every consecutive
    # run of 16 strings therefore shares the same middle base.
    flank_length=1
    nucleotides=["C", "T", "G", "A"]
    trinucleotide_list=nucleotides.copy()
    for i in range(flank_length*2):
        new_list=[]
        for base in nucleotides:
            for old_index in trinucleotide_list.copy():
                new_index=old_index[:int(len(old_index)/2)]+base+old_index[int(len(old_index)/2):]
                new_list.append(new_index)
        trinucleotide_list=new_list        
        
    # Initialize the dictionary to store the mutation entries for each trinucleotide and each
    # alternate base (every base except the middle base of the trinucleotide)
    combined_mutation_data={}
    for index in trinucleotide_list:
        combined_mutation_data[index]={}
        for nucleotide in nucleotides:
            if(index[flank_length]!=nucleotide):
                combined_mutation_data[index][nucleotide]=[]
                
    # Collect every mutation entry under its trinucleotide context and alternate base
    for genome in denovo_mutations:
        for mutation_in_genome in denovo_mutations[genome]:
            trinucleotide=mutation_in_genome[1]
            alternate_base=mutation_in_genome[0][-1]
            combined_mutation_data[trinucleotide][alternate_base].append(mutation_in_genome)
            
    # Generate the mutation signature list from the combined mutation data to help in plotting
    mutation_signature_list=[]
    # Subheadings to be displayed at the top of every 16 bars (representing the trinucleotide mutation rates) are the mutation types associated with these trinucleotide mutation rates
    subheadings=[]
    # Iterating through the combined mutation data, 16 trinucleotides (one middle base) at a time, to generate the mutation signature list
    for i in range(0,len(combined_mutation_data),4**(flank_length*2)):
            sublist=list(combined_mutation_data.keys())[i:i+4**(flank_length*2)]
            if(flank_length!=0):
                if(sublist[flank_length][1] in nucleotides):
                    for new_base in ["A","C","G","T"]:
                        # One heading per substitution type of this group's middle base
                        if(sublist[0][flank_length]!=new_base):
                            subheadings.append((sublist[0][flank_length]+"->"+new_base).replace("T","U"))
                        for polynucleotide in sublist:
                            if(polynucleotide[flank_length]!=new_base):
                                # Entry: [trinucleotide, alternate base, summed rate * 10000 / number of genomes, raw mutation count]
                                mutation_signature_list.append([polynucleotide, new_base,10000*sum(i[2] for i in combined_mutation_data[polynucleotide][new_base])/(len(denovo_mutations)),len(combined_mutation_data[polynucleotide][new_base])])
            else:
                # NOTE(review): flank_length is fixed to 1 above, so this branch is never executed
                if(sublist[flank_length] in nucleotides):
                    for new_base in ["A","C","G","T"]:
                        if(sublist[flank_length]!=new_base):
                            subheadings.append((sublist[0][flank_length]+"->"+new_base).replace("T","U"))
                        for polynucleotide in sublist:
                            if(polynucleotide[flank_length]!=new_base):
                                mutation_signature_list.append([polynucleotide, new_base, len(combined_mutation_data[polynucleotide][new_base])])
                


    # Extracting bar labels and bar heights
    mutation_types = [mutation[0].replace("T","U")+"->"+mutation[1].replace("T","U") for mutation in mutation_signature_list] # Replace 'T' with 'U' for RNA
    # Despite the name, these are the scaled rates (third element of each entry), not raw counts
    mutation_counts = [mutation[2] for mutation in mutation_signature_list]
    # Same list object as mutation_counts; used for the value labels on top of the bars
    mutation_count_fraction=mutation_counts
    print(mutation_types)
    print(mutation_counts)

    # Define colors for every 16 bars (12 colours, each repeated 16 times)
    colors = ['#5fbceb', '#06070a', '#d33c33', '#cacaca', '#accc6f', '#e7cac4', '#a5d4e6', '#5e6164', '#cd7430', '#eae8eb', '#dff488', '#f1dcdb']
    color_cycle = [color for color in colors for _ in range(16)]


    # Plotting the bar plot with cycling colors and gaps
    plt.figure(figsize=(30, 7))  
    for i in range(len(mutation_types)):
            plt.bar(mutation_types[i], mutation_counts[i], color=color_cycle[i], edgecolor='none')  # Set edgecolor to 'none'

            # Add heading at the top middle of every 16 bars
            if (i + 1) % 16 == 8:  # Assuming the middle point is the 8th bar in each group
                bbox_props = dict(boxstyle="round,pad=0.5", edgecolor=color_cycle[i], facecolor=color_cycle[i], alpha=0.5)
                # int((i+1)/16) is the index of the current group of 16 bars
                plt.text(i, max(mutation_counts)+1, subheadings[int((i+1)/16)],
                        ha='center', va='bottom', fontsize=12, color='black', bbox=bbox_props)
            
            # Add value (i.e. trinucleotide mutation rate, rounded to 2 decimals) at the top of each bar
            plt.text(i, mutation_counts[i], "  "+str(round(mutation_count_fraction[i],2)), ha='center', va='bottom', fontsize=10, rotation=90)

    plt.xlabel('Trinuclotides', fontsize=12)
    plt.ylabel(r'Average Trinucleotide Mutation Rate (Normalized) ($10^{-4}$) $\rightarrow$', fontsize=12)
    # Rotate x-axis labels for better visibility
    plt.xticks(rotation=90, ha='center', fontsize=10)  
    # Remove outline at top and right
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.tight_layout()

    # Display the plot
    plt.show()

# Draw the overall trinucleotide mutation profile of the sampled genomes
plot_trinucleotide_mutation(denovo_mutations_in_sampled_genomes)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import logomaker

def plot_trinucleotide_mutations_per_strand_bar_plot(mutation_data, reference_mutation_type, color1, complementary_mutation_type, color2):
    """ 
    In this function we compare two complementary mutation types based on their trinucleotide mutation rates.

    A figure with 2 rows and 3 columns of subplots is created. The top row shows the
    reference mutation type and the bottom row the complementary mutation type:
      - column 0: bar plot of the trinucleotide mutation rates; both rows use the same
        trinucleotide order, with each complementary trinucleotide placed at the position
        of its reverse complement,
      - column 1: horizontal bar plot of the same rates sorted in descending order,
      - column 2: sequence logo of the Position Weight Matrix (PWM) computed from the
        trinucleotide mutation rates.

    The rate of a (trinucleotide, alternate base) pair is the sum of the per-mutation
    normalized rates (third element of each mutation entry) divided by the number of
    genomes in ``mutation_data``; it is multiplied by 10000 for the bar plots.

    Args: 
    - mutation_data (dict): The dictionary containing the denovo mutations data for the sampled genomes.
    - reference_mutation_type (str): The reference mutation type. Only its first character
      (reference base) and last character (alternate base) are used for selecting the data;
      the full string is used as the subplot title.
    - color1 (str): The color for representing the reference_mutation_type.
    - complementary_mutation_type (str): The complementary mutation type, interpreted the
      same way as reference_mutation_type.
    - color2 (str): The color for representing the complementary_mutation_type.

    """

    # Generate all possible trinucleotide combinations
    # (a new base is inserted at the middle of every existing string in each of the two rounds)
    nucleotides=["C", "T", "G", "A"] 
    trinucleotide_list=nucleotides.copy()
    for i in range(1*2):
        new_list=[]
        for base in nucleotides:
            for old_index in trinucleotide_list.copy():
                new_index=old_index[:int(len(old_index)/2)]+base+old_index[int(len(old_index)/2):]
                new_list.append(new_index)
        trinucleotide_list=new_list        

    # Initialize the dictionaries to store the trinucleotide mutation rates for the different mutation types
    # NOTE(review): cumulative_count is filled with empty lists here but is not read again in this function
    count_data={}
    cumulative_count={}
    for trinucleotide in trinucleotide_list:
        count_data[trinucleotide]={}
        cumulative_count[trinucleotide]={}
        for nucleotide in nucleotides:
            if(trinucleotide[1]!=nucleotide):
                count_data[trinucleotide][nucleotide]=0
                cumulative_count[trinucleotide][nucleotide]=[]

    # Compute the trinucleotide mutation rates for each mutation type:
    # each mutation adds its normalized rate divided by the number of genomes
    for this_genome in mutation_data:
        # This membership check is always true, since this_genome is a key of mutation_data
        if(this_genome in mutation_data):
            trinucleotide_mutations_in_this_genome=mutation_data[this_genome]
            for trinucleotide_mutation in trinucleotide_mutations_in_this_genome:
                trinucleotide=trinucleotide_mutation[1]
                alternate_base=trinucleotide_mutation[0][-1]
                trinucleotide_muttion_rate=trinucleotide_mutation[2]
                count_data[trinucleotide][alternate_base]+=trinucleotide_muttion_rate/(len(mutation_data))



    # WE PLOT 6 SUBPLOTS(2 ROWS AND 3 COLUMNS)

    # Specify the width ratios for each column
    width_ratios = [2, 1,1]  
    # Create subplots with different column widths
    fig, ax = plt.subplots(2, 3, figsize=(20, 10), gridspec_kw={'width_ratios': width_ratios})

    # Compute the trinucleotide mutation rates for the reference and complementary mutation types.
    # The dictionary is keyed by the reference-type trinucleotide; each value is a list
    # [reference rate, complementary rate], both multiplied by 10000.
    # For the reference mutation type
    trinucleotide_mutation_rates={}
    for trinucleotide in count_data:
        if(trinucleotide[1]==reference_mutation_type[0]):   
            for alternate_base in count_data[trinucleotide]:
                if(alternate_base==reference_mutation_type[-1]):   
                    trinucleotide_mutation_rates[trinucleotide]=[count_data[trinucleotide][alternate_base]*10000]
    # For the complementary mutation type: the rate is appended under the reverse complement of
    # the trinucleotide, which must already be a key from the reference pass (KeyError otherwise)
    for trinucleotide in count_data:
        if(trinucleotide[1]==complementary_mutation_type[0]):   
            for alternate_base in count_data[trinucleotide]:
                if(alternate_base==complementary_mutation_type[-1]):   
                    trinucleotide_mutation_rates[str(Seq(trinucleotide).reverse_complement())].append(count_data[trinucleotide][alternate_base]*10000)    



    #PLOTTING THE TRINUCLEOTIDE MUTATION RATES BAR PLOTS FOR THE REFERENCE AND COMPLEMENTARY MUTATION TYPES
    # Plotting the bar plot for the reference mutation type
    bars1 = ax[0,0].bar([f"{polynucleotide}_{str(Seq(polynucleotide).reverse_complement())}" for polynucleotide in trinucleotide_mutation_rates],
                    [trinucleotide_mutation_rates[polynucleotide][0] for polynucleotide in trinucleotide_mutation_rates],
                    color=color1)
    ax[0,0].set_ylabel(r'Average Trinucleotide Mutation Rate (Normalized) ($10^{-4}$) $\rightarrow$', fontsize=8)
    # Set the title as the reference mutation type
    ax[0,0].set_title(reference_mutation_type.replace("T","U"), fontweight='bold', fontsize=18) #Replace 'T' with 'U' for RNA
    # Set y-axis limit to 1.5 times the maximum trinucleotide mutation rate
    ax[0,0].set_ylim(0, 1.5*max(trinucleotide_mutation_rates[polynucleotide][0] for polynucleotide in trinucleotide_mutation_rates))  
    # Set x-axis labels: "trinucleotide (reverse complement)"
    ax[0,0].set_xticklabels([f"{polynucleotide} ({str(Seq(polynucleotide).reverse_complement())})".replace("T","U") for polynucleotide in trinucleotide_mutation_rates], rotation=270) #Replace 'T' with 'U' for RNA
    # Plotting the bar plot for the complementary mutation type (same bar order as the reference plot)
    bars2 = ax[1,0].bar([f"{polynucleotide}_{str(Seq(polynucleotide).reverse_complement())}" for polynucleotide in trinucleotide_mutation_rates],
                    [trinucleotide_mutation_rates[polynucleotide][1] for polynucleotide in trinucleotide_mutation_rates],
                    color=color2)
    ax[1,0].set_ylabel(r'Average Trinucleotide Mutation Rate (Normalized) ($10^{-4}$) $\rightarrow$', fontsize=8)
    # Set the title as the complementary mutation type
    ax[1,0].set_title(complementary_mutation_type.replace("T","U"), fontweight='bold', fontsize=18) #Replace 'T' with 'U' for RNA
    # Set x-axis label
    ax[1,0].set_xlabel('Trinucleotides ->')
    # Set y-axis limit to 1.5 times the maximum trinucleotide mutation rate
    ax[1,0].set_ylim(0, 1.5*max(trinucleotide_mutation_rates[polynucleotide][1] for polynucleotide in trinucleotide_mutation_rates))  
    # Set x-axis labels: "reverse complement (trinucleotide)", i.e. the complementary-type context first
    ax[1,0].set_xticklabels([f"{str(Seq(polynucleotide).reverse_complement())} ({polynucleotide})".replace("T","U") for polynucleotide in trinucleotide_mutation_rates], rotation=270) #Replace 'T' with 'U' for RNA
    
    # Add a horizontal line depicting the average trinucleotide mutation rate (mean bar height)
    ax[0,0].axhline(np.mean([bar.get_height() for bar in bars1]), color='black', linestyle='--', linewidth=1)
    ax[1,0].axhline(np.mean([bar.get_height() for bar in bars2]), color='black', linestyle='--', linewidth=1)
    # Function to mention the associated Trinucleotide Mutation Rate on top of each bar
    def add_value_labels(ax, bars, counts):
        # counts[k] is the text written above the k-th bar
        bar_count = -1
        for bar in bars:
            bar_count += 1
            height = bar.get_height()
            ax.annotate('{}'.format(counts[bar_count]),
                        xy=(bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 3), 
                        textcoords="offset points",
                        ha='center', va='bottom', rotation=90)  
    # Add Trinucleotide Mutation rate value (rounded to 2 decimals) at the top of each bar in both plots
    add_value_labels(ax[0,0], bars1, [round(trinucleotide_mutation_rates[polynucleotide][0],2) for polynucleotide in trinucleotide_mutation_rates])
    add_value_labels(ax[1,0], bars2, [round(trinucleotide_mutation_rates[polynucleotide][1],2) for polynucleotide in trinucleotide_mutation_rates])



    #WE PLOT THE POSITION WEIGHT MATRIX FOR THE REFERENCE MUTATION TYPE
    # Unscaled rates (not multiplied by 10000) of the reference mutation type, keyed by trinucleotide
    combined_mutation_dict={}
    for trinucleotide in count_data:
        if(trinucleotide[1]==reference_mutation_type[0]):
            for alternate_base in count_data[trinucleotide]:
                if(alternate_base==reference_mutation_type[-1]):
                    combined_mutation_dict[trinucleotide]=count_data[trinucleotide][alternate_base]
    # Initialize the Position Weight Matrix(PWM) for the reference mutation type
    # (rows: the 3 trinucleotide positions; columns: A, C, G, T)
    pwm = np.zeros((3, 4))
    # Compute the Position Weight Matrix(PWM) for the reference mutation type:
    # each trinucleotide adds its rate to the cell of the base found at each position
    for trinucleotide in combined_mutation_dict:
        rate=combined_mutation_dict[trinucleotide]
        if(rate>0):
            for i in range(0, len(trinucleotide)):
                if trinucleotide[i] == "A":
                    pwm[i, 0] += rate
                elif trinucleotide[i] == "C":
                    pwm[i, 1] += rate
                elif trinucleotide[i] == "G":
                    pwm[i, 2] += rate
                elif trinucleotide[i] == "T":
                    pwm[i, 3] += rate
    # Normalize the PWM base frequencies so that every row (position) sums to 1
    normalized_pwm = pwm / pwm.sum(axis=1, keepdims=True)
    # Create a DataFrame from the Normalized PWM (the T column is labelled U)
    pwm_df = pd.DataFrame(normalized_pwm, index=[i for i in range(1, 4)], columns=["A", "C", "G", "U"])


    # Create a Logo object to Visualize the Position Weight Matrix(PWM)
    logo_reference_strand = logomaker.Logo(pwm_df, ax=ax[0,2], shade_below=0.5, fade_below=0.5)
    # Style for better visualization
    logo_reference_strand.style_spines(visible=False)
    logo_reference_strand.style_spines(spines=['left', 'bottom'], visible=True)
    logo_reference_strand.style_xticks(rotation=90, fmt='%d', anchor=0)
    logo_reference_strand.ax.xaxis.set_ticks_position('none')
    logo_reference_strand.ax.xaxis.set_tick_params(pad=-1)
    # add it to the plot
    # NOTE(review): draw() is called for this logo only; the complementary logo below has no matching call
    logo_reference_strand.draw()



    #PLOT THE SORTED TRINUCLEOTIDE MUTATION RATES FOR THE REFERENCE MUTATION TYPE
    # Sort the trinucleotide mutation rates in descending order
    sorted_trinucleotide_mutation_rate_data = dict(sorted(combined_mutation_dict.items(), key=lambda item: item[1], reverse=True))
    # Extract trinucleotide labels and corresponding mutation rates (multiplied by 10000) for plotting
    trinucelotides = [f"{item} ({str(Seq(item).reverse_complement())})".replace("T","U") for item in sorted_trinucleotide_mutation_rate_data]
    trinucelotide_mutation_rates = [sorted_trinucleotide_mutation_rate_data[item]*10000 for item in sorted_trinucleotide_mutation_rate_data]
    # Plotting
    ax[0, 1].barh(trinucelotides, trinucelotide_mutation_rates, color=color1)  
    #Styling
    ax[0, 1].set_title(reference_mutation_type.replace("T","U"), fontweight='bold', fontsize=18)
    ax[0, 1].invert_yaxis()  # Invert y-axis for readability
    ax[0, 1].tick_params(axis='x', rotation=90)  # Rotate x-ticks vertically
    
    
    
    #WE PLOT THE POSITION WEIGHT MATRIX FOR THE COMPLEMENTARY MUTATION TYPE
    # Unscaled rates of the complementary mutation type, keyed by trinucleotide
    combined_mutation_dict={}
    for trinucleotide in count_data:
        if(trinucleotide[1]==complementary_mutation_type[0]):
            for alternate_base in count_data[trinucleotide]:
                if(alternate_base==complementary_mutation_type[-1]):
                    combined_mutation_dict[trinucleotide]=count_data[trinucleotide][alternate_base]
    # Initialize the Position Weight Matrix(PWM) for the complementary mutation type
    pwm = np.zeros((3, 4))
    # Compute the Position Weight Matrix(PWM) for the complementary mutation type based on trinucleotide mutation rates
    for trinucleotide in combined_mutation_dict:
        rate=combined_mutation_dict[trinucleotide]
        if(rate>0):
            for i in range(0, len(trinucleotide)):
                if trinucleotide[i] == "A":
                    pwm[i, 0] += rate
                elif trinucleotide[i] == "C":
                    pwm[i, 1] += rate
                elif trinucleotide[i] == "G":
                    pwm[i, 2] += rate
                elif trinucleotide[i] == "T":
                    pwm[i, 3] += rate
    # Normalize the PWM base frequencies so that every row (position) sums to 1
    normalized_pwm = pwm / pwm.sum(axis=1, keepdims=True)
    # Create a DataFrame from the Normalized PWM (the T column is labelled U)
    pwm_df = pd.DataFrame(normalized_pwm, index=[i for i in range(1, 4)], columns=["A", "C", "G", "U"])
    # Create a Logo object to Visualize the Position Weight Matrix(PWM)
    logo_opposite_strand = logomaker.Logo(pwm_df, ax=ax[1,2], shade_below=0.5, fade_below=0.5)
    # Style for better visualization
    logo_opposite_strand.style_spines(visible=False)
    logo_opposite_strand.style_spines(spines=['left', 'bottom'], visible=True)
    logo_opposite_strand.style_xticks(rotation=90, fmt='%d', anchor=0)
    logo_opposite_strand.ax.xaxis.set_ticks_position('none')
    logo_opposite_strand.ax.xaxis.set_tick_params(pad=-1)



    #PLOT THE SORTED TRINUCLEOTIDE MUTATION RATES FOR THE COMPLEMENTARY MUTATION TYPE
    # Sort the trinucleotide mutation rates in descending order
    sorted_data = dict(sorted(combined_mutation_dict.items(), key=lambda item: item[1], reverse=True))
    
    # Extract trinucleotides and corresponding trinucleotide mutation rates (multiplied by 10000) for plotting
    trinucleotides = [f"{item} ({str(Seq(item).reverse_complement())})".replace("T","U") for item in sorted_data]
    trinucelotide_mutation_rates = [sorted_data[item]*10000 for item in sorted_data]

    # Plotting
    ax[1, 1].barh(trinucleotides, trinucelotide_mutation_rates, color=color2)  
    ax[1, 1].set_xlabel(r'Average Trinucleotide Mutation Rate (Normalized) ($10^{-4}$) $\rightarrow$')  
    ax[1, 1].invert_yaxis()  # Invert y-axis for readability
    ax[1, 1].tick_params(axis='x', rotation=90)  # Rotate x-ticks vertically
    ax[1, 1].set_title(complementary_mutation_type.replace("T", "U"), fontweight='bold', fontsize=18)


    # Adjust the space between subplots
    fig.subplots_adjust(hspace=0.45, wspace=0.3)  


# Compare three pairs of complementary mutation types (C->T vs G->A, A->G vs T->C, G->T vs C->A).
# The function only reads the first and last character of each label to pick the bases;
# the "$\rightarrow$" in between is mathtext used for the subplot titles.
plot_trinucleotide_mutations_per_strand_bar_plot(denovo_mutations_in_sampled_genomes,  r"C$\rightarrow$T", '#d33c33', r"G$\rightarrow$A", '#a5d4e6')
plot_trinucleotide_mutations_per_strand_bar_plot(denovo_mutations_in_sampled_genomes, r"A$\rightarrow$G", '#dff488', r"T$\rightarrow$C", '#accc6f')
plot_trinucleotide_mutations_per_strand_bar_plot(denovo_mutations_in_sampled_genomes, r"G$\rightarrow$T", '#cd7430', r"C$\rightarrow$A", '#5fbceb')
