In [1]:
import numpy as np
import os
import glob
import re
from cogent3.maths.measure import jsd

In [3]:
base_dir = "/Users/gulugulu/Desktop/honours/data_local/whole_genome_mammal87/ensembl_download/genomes"

all_fasta_paths = {}

for species in os.listdir(base_dir):
    fasta_dir = os.path.join(base_dir, species, 'fasta')
    fasta_files = glob.glob(os.path.join(fasta_dir, '*.fa.gz'))
    all_fasta_paths[species] = fasta_files

In [8]:
del all_fasta_paths['.DS_Store']

In [13]:
species_key = list(all_fasta_paths.keys())

In [14]:
import json

with open('/Users/gulugulu/repos/PuningAnalysis/results/nuc_counts.json', 'r') as file:
    nuc_counts = json.load(file)


In [15]:
def extract_info(path):
    match = re.search(r'dna\.(.+?)\.fa\.gz', path)
    if match:
        return match.group(1)
    else:
        return "unknown"



In [16]:
assigned_nuc_counts = {}

for species, paths in all_fasta_paths.items():
    assigned_nuc_counts[species] = {}
    for path, nucleotide_count in zip(paths, nuc_counts[species]):
        info = extract_info(path)
        assigned_nuc_counts[species][info] = nucleotide_count

KeyError: 'ovis_aries_rambouillet'

In [8]:
nuc_freqs = {}
for species, counts in nuc_counts.items():
    nuc_freqs[species] = []
    for count in counts:
        sum_count = sum(count.values())
        freq = {
                'A': count['A'] / sum_count,
                'C': count['C'] / sum_count,
                'G': count['G'] / sum_count,
                'T': count['T'] / sum_count
            }
        nuc_freqs[species].append(freq)


In [9]:
assigned_nuc_freqs = {}

for species, paths in all_fasta_paths.items():
    assigned_nuc_freqs[species] = {}
    for path, nucleotide_freq in zip(paths, nuc_freqs[species]):
        info = extract_info(path)
        assigned_nuc_freqs[species][info] = nucleotide_freq

In [10]:
def calculate_nucleotide_frequency(nucleotide_data):
    # Initialize a dictionary to store the results
    nuc_frequencies = {}
    
    # Iterate over each species in the dictionary
    for species, counts_list in nucleotide_data.items():
        # Initialize the sum for each nucleotide type
        summed_nucleotides = {'A': 0, 'C': 0, 'G': 0, 'T': 0}
        
        # Sum the counts for each nucleotide type
        for counts in counts_list:
            summed_nucleotides['A'] += counts['A']
            summed_nucleotides['C'] += counts['C']
            summed_nucleotides['G'] += counts['G']
            summed_nucleotides['T'] += counts['T']
        
        total_nucleotides = sum(summed_nucleotides.values())

        nuc_frequency = {
                'A': summed_nucleotides['A'] / total_nucleotides,
                'C': summed_nucleotides['C'] / total_nucleotides,
                'G': summed_nucleotides['G'] / total_nucleotides,
                'T': summed_nucleotides['T'] / total_nucleotides
            }
        
        nuc_frequencies[species] = nuc_frequency
    
    return nuc_frequencies



In [11]:
# Calculate the summed counts
total_nucleotide_frequency = calculate_nucleotide_frequency(nuc_counts)

In [13]:
def convert_dict_to_list(species):
    nucleotide_frequencies = total_nucleotide_frequency[species]

    # Extract the frequencies in the order A, C, G, T
    frequency_list = [
        nucleotide_frequencies['A'],
        nucleotide_frequencies['C'],
        nucleotide_frequencies['G'],
        nucleotide_frequencies['T']
    ]

    return frequency_list

In [14]:
def calculate_pairwise_jsd(nucleotide_frequency_dict):
    pairwise_jsd = {}
    species_keys = list(nucleotide_frequency_dict.keys())
    for species_1, nuc_freq in nucleotide_frequency_dict.items():
        for species_2 in species_keys:
                dist1 = convert_dict_to_list(species_1)
                dist2 = convert_dict_to_list(species_2)
                jsd_value = jsd(dist1, dist2)
                key = (species_1, species_2)
                pairwise_jsd[key] = jsd_value
    return pairwise_jsd

In [15]:
all_species_pairwise_jsd = calculate_pairwise_jsd(total_nucleotide_frequency)

In [17]:
from cogent3 import load_tree
import pathlib
from ensembl_lite import Species


path = pathlib.Path("../data/dataset2_ensemble_trees/raw_data/tree_nh_file/91_eutherian_mammals_EPO-Extended_default.nh")
tree_87_mammals= load_tree(path, format=None, underscore_unmunge=False)

def make_db_prefixes():
    return [n.lower().replace(" ", "_") for n in Species.get_species_names()]

db_prefixes = make_db_prefixes()

def find_db_prefix(name):
    for db_prefix in db_prefixes:
        if name.startswith(db_prefix):
            return db_prefix
    return None

db_names = []
probs = []
for tip_name in tree_87_mammals.get_tip_names():
    db_name = find_db_prefix(tip_name.lower())
    if db_name is None:
        probs.append(tip_name)
    else:
        db_names.append(db_name)



In [18]:
pairwise_distance = tree_87_mammals.tip_to_tip_distances()

In [19]:
pairwise_distance_value = pairwise_distance[0]

In [20]:
indices_to_remove = sorted([9, 31, 42, 58], reverse=True)  # Reverse sort to prevent index shifting during removal
new_pairwise_distance_value = np.delete(pairwise_distance_value, indices_to_remove, axis=0)

new_pairwise_distance_value = np.delete(new_pairwise_distance_value, indices_to_remove, axis=1)

pairwise_name = np.array([[(db_names[i], db_names[j]) for j in range(len(db_names))] for i in range(len(db_names))], dtype=object)



In [21]:
pairwise_distance_dict = {(db_names[i], db_names[j]): new_pairwise_distance_value[i][j] for i in range(len(db_names)) for j in range(len(db_names))}


In [22]:
all_distance_jsd_dict = {}

for key in pairwise_distance_dict.keys():
    distance = pairwise_distance_dict[key]
    jsd = all_species_pairwise_jsd[key]
    all_distance_jsd_dict[key] = [distance, jsd]


In [23]:
# New dictionary to store non-redundant data
filtered_distance_jsd_dict = {}

# Iterate over the original dictionary
for (species1, species2), value in all_distance_jsd_dict.items():
    # Sort the tuple to ensure consistent ordering
    sorted_tuple = tuple(sorted((species1, species2)))

    # Check if the sorted tuple is already in the new dictionary
    if sorted_tuple not in filtered_distance_jsd_dict:
        # If not present, add it to the new dictionary
        filtered_distance_jsd_dict[sorted_tuple] = value

correct_order_jsd = []
correct_order_distance = []

for value in filtered_distance_jsd_dict.values():
    correct_order_jsd.append(value[1])
    correct_order_distance.append(value[0])


In [37]:
import plotly.express as px
import pandas as pd

data = {
    'JSD Value': [values[1] for values in filtered_distance_jsd_dict.values()],
    'Distance Value': [values[0] for values in filtered_distance_jsd_dict.values()]
}
df = pd.DataFrame(data)

fig = px.scatter(df, x='JSD Value', y='Distance Value',
                 title='JSD Vs Genetic Distance',
                 labels={'JSD Value': 'JSD Value', 'Distance Value': 'Genetic Distance Value'},
                 width=1000, height=600)




In [38]:
# Save the plot to a PDF file
fig.write_image("../results/all_species_scatter_plot.pdf")
fig.show()

In [39]:
# Update layout if needed
fig.update_layout(
    xaxis=dict(
        title='JSD Value',
        tickmode='array',
        title_font=dict(size=18, family='Arial, sans-serif', color='black')
    ),
    yaxis=dict(
        title='Genetic Distance Value',
        tickmode='array',
        title_font=dict(size=18, family='Arial, sans-serif', color='black')
    )
)

x_max = 0.0004
fig.update_xaxes(range=[-0.00003, x_max])

# Save the plot to a PDF file
fig.show()


In [40]:
probs = ['myotis_lucifugus', 'rhinolophus_ferrumequinum', 'equus_caballus']
species_tuple = [(species_1, species_2) for species_1 in probs for species_2 in probs]

In [41]:
sub_distance_jsd_dict = {k: all_distance_jsd_dict[k] for k in species_tuple}

In [42]:
# New dictionary to store non-redundant data
filtered_sub_distance_jsd_dict = {}

# Iterate over the original dictionary
for (species1, species2), value in sub_distance_jsd_dict.items():
    # Sort the tuple to ensure consistent ordering
    sorted_tuple = tuple(sorted((species1, species2)))

    # Check if the sorted tuple is already in the new dictionary
    if sorted_tuple not in filtered_sub_distance_jsd_dict:
        # If not present, add it to the new dictionary
        filtered_sub_distance_jsd_dict[sorted_tuple] = value

In [43]:
species_pair = ['Self vs Self' , 'Microbat vs Greater horseshoe bat', 'Horse vs Microbat', 'Self vs Self', 'Horse vs Greater horseshoe bat', 'Self vs Self']
data2 = {
    'Species Pair': species_pair,
    'JSD Value': [values[1] for values in filtered_sub_distance_jsd_dict.values()],
    'Distance Value': [values[0] for values in filtered_sub_distance_jsd_dict.values()]
}
df2 = pd.DataFrame(data2)

fig2 = px.scatter(df2, x='JSD Value', y='Distance Value',text='Species Pair',
                 title='JSD Vs Genetic Distance',
                 labels={'JSD Value': 'JSD Value ', 'Distance Value': 'Genetic Distance Value'})

# Update traces to show labels below points and make them bold
fig2.update_traces(textposition='bottom center', textfont={'color':'black', 'size':15, 'family':'Arial, sans-serif'})

# Update layout if needed
fig2.update_layout(
    xaxis=dict(
        title='JSD Value',
        tickmode='array',
        title_font=dict(size=18, family='Arial, sans-serif', color='black')
    ),
    yaxis=dict(
        title='Genetic Distance Value',
        tickmode='array',
        title_font=dict(size=18, family='Arial, sans-serif', color='black')
    )
)

x_max = df2['JSD Value'].max()
x_range_buffer = x_max *0.3
fig2.update_xaxes(range=[-0.00003, x_max + x_range_buffer])

fig2.write_image("../results/horse_microbat_horsesheobat.pdf")


fig2.show()

In [44]:
# from cogent3 import open_
# from cogent3.parse.fasta import MinimalFastaParser
# from collections import Counter


# def counts_nucs(path):
#     with open_(path) as infile:
#         data = infile.readlines()

#     parser = MinimalFastaParser(data)
#     nuc_counts = Counter()
#     for label, seq in parser:
#         nuc_counts.update(seq)
#         del seq # delete the sequence instance

#     # discard non-canonical nucleotides
#     return {b: nuc_counts[b] for b in "ACGT"}



In [45]:
# def counts_labels(path):
#     with open_(path) as infile:
#         data = infile.readlines()

#     parser = MinimalFastaParser(data)
#     label_counts = Counter()

#     for label, seq in parser:
#         first_word_of_label = label.split()[0]
#         label_counts.update({first_word_of_label: 1})
#         del seq 

#     # discard non-canonical nucleotides
#     return label_counts