In [574]:
import numpy as np
import scipy
from clock_project.simulation.magnitude_quantification import calculate_non_stationarity, calculate_ENS, random_nucleotide_distribution, calculate_information, entropy_calculation
from cogent3.maths.measure import jsd
from clock_project.simulation.wts import generate_rate_matrix
import scipy
import json
from clock_project.genome_analysis.yapeng_check_BV import get_bounds_violation, load_param_values
from cogent3.util.deserialise import deserialise_object
import numpy as np
import os
import glob
from cogent3 import get_app

import plotly.express as px



In [575]:
base_dir = '/Users/gulugulu/Desktop/honours/data_local/whole_genome_mammal87/triads_model_fitting_350_threshold'
gene_paths = glob.glob(os.path.join(base_dir, '*/'))
alignment_dir = '/Users/gulugulu/repos/PuningAnalysis/data/ensembl_ortholog_sequences/homologies_alignment_common_name_350_threshold'

In [576]:
def get_jsd_diff(triads_info):
    jsd_diff_dict = {}
    for identifier, info in triads_info.items():
        triads_info_value = info['triads_info_small_tree']
        triads_names = info['triads_species_names']
        jsd_dict = triads_info_value['jsd_dict']

        jsd_diff = abs(jsd_dict['JSD_difference'][triads_names['ingroup1']] - jsd_dict['JSD_difference'][triads_names['ingroup2']])
        jsd_diff_dict[identifier] = jsd_diff
    return jsd_diff_dict
        

In [577]:
from scipy.stats import wasserstein_distance

def get_ingroup_jsd(triads_info):
    ingroup_jsd_dict = {}
    for identifier, info in triads_info.items():
        triads_info_value = info['triads_info_small_tree']
        triads_names = info['triads_species_names']
        jsd_dict = triads_info_value['jsd_dict']

        ingroup_jsd = jsd_dict['Ingroup_JSD']
        ingroup_jsd_dict[identifier] = ingroup_jsd
    return ingroup_jsd_dict

def get_ingroup_wst(triads_info):
    ingroup_wst_dict = {}
    for identifier, info in triads_info.items():
        triads_info_value = info['triads_info_small_tree']
        triads_names = info['triads_species_names']
        nuc_freqs_dict = triads_info_value['nuc_freqs_dict']
        ingroup1_freq = nuc_freqs_dict['ingroup1']
        ingroup2_freq = nuc_freqs_dict['ingroup2']
        internal_node_freq = nuc_freqs_dict['internal_node']
        wst1 = wasserstein_distance(ingroup1_freq, internal_node_freq)
        wst2 = wasserstein_distance(ingroup2_freq, internal_node_freq)
        ingroup_wst = wasserstein_distance(ingroup1_freq, ingroup2_freq)
        ingroup_wst_dict[identifier] = ingroup_wst

    return ingroup_wst_dict

def get_wst_diff(triads_info):
    wst_diff_dict = {}
    for identifier, info in triads_info.items():
        triads_info_value = info['triads_info_small_tree']
        nuc_freqs_dict = triads_info_value['nuc_freqs_dict']
        ingroup1_freq = nuc_freqs_dict['ingroup1']
        ingroup2_freq = nuc_freqs_dict['ingroup2']
        internal_node_freq = nuc_freqs_dict['internal_node']
        wst1 = wasserstein_distance(ingroup1_freq, internal_node_freq)
        wst2 = wasserstein_distance(ingroup2_freq, internal_node_freq)
        wst_diff = abs(wst1-wst2)
        wst_diff_dict[identifier] = wst_diff

    return wst_diff_dict

In [578]:
def get_ingroup_ens_diff(triads_info):
    ens_ingroup_dict = {}
    for identifier, info in triads_info.items():
        triads_info_value = info['triads_info_small_tree']
        triads_names = info['triads_species_names']
        ens_dict = triads_info_value['ens']
        ens_ingroup = np.log(ens_dict[triads_names['ingroup1']]/ ens_dict[triads_names['ingroup2']])
        ens_ingroup_dict[identifier] = ens_ingroup
    return ens_ingroup_dict

In [579]:
def get_ingroup_ens_absdiff(triads_info):
    ens_ingroup_dict = {}
    for identifier, info in triads_info.items():
        triads_info_value = info['triads_info_small_tree']
        triads_names = info['triads_species_names']
        ens_dict = triads_info_value['ens']
        ens_ingroup = abs(ens_dict[triads_names['ingroup1']]- ens_dict[triads_names['ingroup2']])
        ens_ingroup_dict[identifier] = ens_ingroup
    return ens_ingroup_dict


In [580]:
def get_ingroup_ens_Hellinger(triads_info):
    ens_ingroup_dict = {}
    for identifier, info in triads_info.items():
        triads_info_value = info['triads_info_small_tree']
        triads_names = info['triads_species_names']
        ens_dict = triads_info_value['ens']
        ens_ingroup = np.sqrt(2*(np.sqrt(ens_dict[triads_names['ingroup1']])- np.sqrt(ens_dict[triads_names['ingroup2']]))**2)
        ens_ingroup_dict[identifier] = ens_ingroup
    return ens_ingroup_dict

In [581]:
def get_nabla_absdiff(triads_info):
    nabla_diff_dict = {}
    for identifier, info in triads_info.items():
        triads_info_value = info['triads_info_small_tree']
        triads_names = info['triads_species_names']
        nabla_dict = triads_info_value['nabla_values']
        nbala_diff = abs(nabla_dict[triads_names['ingroup1']] - nabla_dict[triads_names['ingroup2']])
        nabla_diff_dict[identifier] = nbala_diff
    return nabla_diff_dict

In [582]:
def get_ingroup_nabla(triads_info):
    nabla_ingroup_dict = {}
    for identifier, info in triads_info.items():
        triads_info_value = info['triads_info_small_tree']
        triads_names = info['triads_species_names']
        nabla_dict = triads_info_value['nabla_values']
        nabla_difference = np.log(nabla_dict[triads_names['ingroup1']]/nabla_dict[triads_names['ingroup2']])
        nabla_ingroup_dict[identifier] = nabla_difference
    return nabla_ingroup_dict

In [583]:
from clock_project.maths.evolutionary_rate import calculate_stationary_distribution

def get_STI(triads_info, N):
    matrix_ingroup_dict = {}
    nuc_freq_ingroup_dict = {}
    STI_dict = {}
    chi_squared_dict = {}

    for identifier, info in triads_info.items():
        triads_info_value = info['triads_info_small_tree']
        triads_names = info['triads_species_names']
        matrix_dict = triads_info_value['matrices']
        nuc_freqs_dict = triads_info_value['nuc_freqs_dict']
        nuc_freqs_ingroup_pair = {
            triads_names['ingroup1']: nuc_freqs_dict['ingroup1'],
            triads_names['ingroup2']: nuc_freqs_dict['ingroup2']
        }
        nuc_freq_ingroup_dict[identifier] = nuc_freqs_ingroup_pair
        matrix_ingroup_pair = {
            triads_names['ingroup1']: matrix_dict[triads_names['ingroup1']],
            triads_names['ingroup2']: matrix_dict[triads_names['ingroup2']]
        }
        matrix_ingroup_dict[identifier] = matrix_ingroup_pair
        stationary_distribution_dict = {
            triads_names['ingroup1']: calculate_stationary_distribution(matrix_dict[triads_names['ingroup1']]),
            triads_names['ingroup2']: calculate_stationary_distribution(matrix_dict[triads_names['ingroup2']])
        }
        
        # Calculate STI values for each ingroup species
        for species, freqs in nuc_freqs_ingroup_pair.items():
            π = stationary_distribution_dict[species]
            p = freqs
            ΔC = p[1] - π[1]  # C
            ΔG = p[3] - π[3]  # G
            ΔA = p[2] - π[2]  # A
            ΔT = p[0] - π[0]  # T
            
            STI1 = ΔC + ΔG
            STI2 = ΔA - ΔT
            STI3 = ΔC - ΔG
            
            chi_squared = N * sum(((p[i] - π[i]) ** 2) / π[i] if π[i] > 0 else 0 for i in range(4))
            
            if identifier not in STI_dict:
                STI_dict[identifier] = {}
            if identifier not in chi_squared_dict:
                chi_squared_dict[identifier] = {}
                
            STI_dict[identifier][species] = (STI1, STI2, STI3)
            chi_squared_dict[identifier][species] = chi_squared

    return STI_dict, chi_squared_dict


In [584]:
alignment_length_dict = {}
alignment_dir_paths = glob.glob(os.path.join(alignment_dir, '*.json'))
for path in alignment_dir_paths:
    file_name = os.path.basename(path).rsplit('.', 1)[0]
    alignment = deserialise_object(json.load(open(path, 'r')))
    alignment_length_dict[file_name] = alignment.get_lengths()[0]


In [585]:
def process_path_STI(path):
    file_name = os.path.basename(path).rsplit('.', 1)[0]
    triads_data_path = os.path.join(path, 'triads_info_dict.json')
    triads_info = load_json_data(triads_data_path)
    alignment_length = alignment_length_dict[file_name]
    STI_dict, chi_squared_dict = get_STI(triads_info, alignment_length)
    return STI_dict, chi_squared_dict
    

In [586]:
def remove_outliers_iqr(data1, data2):
    def compute_iqr_bounds(data):
        Q1 = np.percentile(data, 25)
        Q3 = np.percentile(data, 75)
        IQR = Q3 - Q1
        lower_bound = Q1 - 1.5 * IQR
        upper_bound = Q3 + 1.5 * IQR
        return lower_bound, upper_bound

    # Calculate IQR bounds for both lists
    lower_bound1, upper_bound1 = compute_iqr_bounds(data1)
    lower_bound2, upper_bound2 = compute_iqr_bounds(data2)

    # Filter out pairs where either value is an outlier
    filtered_data1 = []
    filtered_data2 = []

    for val1, val2 in zip(data1, data2):
        if (lower_bound1 <= val1 <= upper_bound1) and (lower_bound2 <= val2 <= upper_bound2):
            filtered_data1.append(val1)
            filtered_data2.append(val2)

    return filtered_data1, filtered_data2

In [587]:
import os
import json
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Function to load JSON data
def load_json_data(path):
    with open(path, 'r') as file:
        return json.load(file)

# Function to compute required values
def compute_values(path):
    nabla_log_ratio_list = []
    ens_log_ratio_list = []

    triads_data_path = os.path.join(path, 'triads_info_dict.json')
    triads_info = load_json_data(triads_data_path)
    nabla_log_ratio_dict = get_ingroup_nabla(triads_info)
    ens_log_ratio_dict = get_ingroup_ens_diff(triads_info)
    ens_abs_diff_dict = get_ingroup_ens_absdiff(triads_info)
    jsd_diff_dict = get_jsd_diff(triads_info)
    ingroup_jsd_dict = get_ingroup_jsd(triads_info)
    ingroup_wst_dict  = get_ingroup_wst(triads_info)
    wst_diff_dict = get_wst_diff(triads_info)
    ens_hellinger_dict = get_ingroup_ens_Hellinger(triads_info)
    nabla_absdiff_dict = get_nabla_absdiff(triads_info)
    nabla_log_ratio_list = list(nabla_log_ratio_dict.values())
    ens_abs_diff_list = list(ens_abs_diff_dict.values())
    ens_log_ratio_list = list(ens_log_ratio_dict.values())
    jsd_diff_list = list(jsd_diff_dict.values())
    ingroup_jsd_list = list(ingroup_jsd_dict.values())
    ingroup_wst_list = list(ingroup_wst_dict.values())
    nabla_absdiff_list = list(nabla_absdiff_dict.values())
    ens_hellinger_list = list(ens_hellinger_dict.values())
    wst_diff_list = list(wst_diff_dict.values())

    return nabla_log_ratio_list, ens_log_ratio_list, ens_abs_diff_list, jsd_diff_list, ingroup_jsd_list, ingroup_wst_list, wst_diff_list, nabla_absdiff_list, ens_hellinger_list

# Initialize the dictionary to store the data
gene_data_dict = {}

# Populate the dictionary with data for each gene
for path in gene_paths:
    gene_name = os.path.basename(path.rstrip('/'))
    nabla_log_ratio_list, ens_log_ratio_list, ens_abs_diff_list, jsd_diff_list, ingroup_jsd_list, ingroup_wst_list, wst_diff_list, nabla_absdiff_list, ens_hellinger_list = compute_values(path)
    gene_data_dict[gene_name] = {
        'nabla_log_ratio': nabla_log_ratio_list,
        'ens_log_ratio': ens_log_ratio_list,
        'ens_abs_diff': ens_abs_diff_list,
        'jsd_diff': jsd_diff_list, 
        'ingroup_jsd': np.sqrt(ingroup_jsd_list),
        'ingroup_wst': ingroup_wst_list,
        'wst_diff': wst_diff_list,
        'nabla_absdiff': nabla_absdiff_list,
        'ens_hellinger': ens_hellinger_list
    }






In [588]:
for path in gene_paths:
    gene_name = os.path.basename(path.rstrip('/'))
    triads_data_path = os.path.join(path, 'triads_info_dict.json')
    triads_info = load_json_data(triads_data_path)
    for identifier, info in triads_info.items():
        ens_dict = info['triads_info_small_tree']['ens']
        for value in ens_dict.values():
            if value > 1:
                print(gene_name, identifier)

    
    

ENSG00000117322 59
ENSG00000214655 240


In [589]:

correlation_list1 = []
p_value_list1 = []
for gene, lists in gene_data_dict.items():
    nabla_log_ratio_list, ens_log_ratio_list = remove_outliers_iqr(lists['nabla_log_ratio'], lists['ens_log_ratio'])
    #Add the correlation factor in the list
    cor, p_value = scipy.stats.pearsonr(np.abs(list(ens_log_ratio_list)), np.abs(list(nabla_log_ratio_list)))
    correlation_list1.append(cor)
    p_value_list1.append(p_value)

fig1 = px.box(correlation_list1, title= 'Pearson Correlation Test Nabla Log Ratio Vs. ENS Log Ratio')
# Update layout with labels and title
fig1.update_layout(
    template='plotly_white',
    margin=dict(l=20, r=20, t=50, b=20),
    autosize=True,
    yaxis_title_font={'size': 20},  
    xaxis_title_font={'size': 20}, 
    xaxis_title='Correlation Coefficient',
    yaxis_title='Value',
    width=None
)
fig1.show()
# fig1.write_image('nabla_ens_cc.pdf')

In [590]:
gene_names = list(gene_data_dict.keys())
# Create the bar chart
fig = go.Figure(go.Bar(
    x=correlation_list1,
    y=gene_names,
    orientation='h',  # Horizontal bar chart
    marker=dict(color=p_value_list1, coloraxis="coloraxis")  # Color based on the correlation factor
))
fig.add_shape(
    type="line",
    x0=0.2, y0=0, x1=0.2, y1=len(gene_names),
    line=dict(color="yellow", width=3, dash="dashdot"),
)

fig.add_shape(
    type="line",
    x0=-0.2, y0=0, x1=-0.2, y1=len(gene_names),
    line=dict(color="yellow", width=3, dash="dashdot"),
)

# Update layout for better visualization
fig.update_layout(
    xaxis_title="Correlation Coefficient",
    yaxis_title="Gene ID",
    coloraxis=dict(colorscale='Viridis'),  # Red for higher, blue for lower values
    height=1200  # Adjust height based on the number of genes to avoid squeezing
)

fig.show()
# fig.write_image('nabla_ens_cc_bar.pdf')

In [591]:
count = len([x for x in correlation_list1 if abs(x) > 0.2])
count

21

In [592]:
count = len([i for i in range(len(p_value_list1)) if p_value_list1[i] < 0.0003937 and correlation_list1[i] > 0.2])
count

8

In [593]:
correlation_list4 = {}
p_value_list4 = {}
for gene, lists in gene_data_dict.items():
    wst_diff_list, ens_abs_diff_list = remove_outliers_iqr(lists['wst_diff'], lists['ens_abs_diff'])

    #Add the correlation factor in the list
    cor, p_value = scipy.stats.pearsonr(wst_diff_list, ens_abs_diff_list)
    correlation_list4[gene] = cor
    p_value_list4[gene] = p_value 

fig4 = px.box(correlation_list4.values(), title= 'Pearson Correlation Test WST Difference Vs. ENS difference')
# Update layout with labels and title
fig4.update_layout(
    template='plotly_white',
    margin=dict(l=20, r=20, t=50, b=20),
    autosize=True,
    yaxis_title_font={'size': 20},  
    xaxis_title_font={'size': 20}, 
    xaxis_title='Correlation Coefficient',
    yaxis_title='Value',
    width=None
)
fig4.show()
# fig2.write_image('jsd_ens_diff_cc.pdf')

In [594]:
gene_names = list(gene_data_dict.keys())
# Create the bar chart
fig = go.Figure(go.Bar(
    x=list(correlation_list4.values()),
    y=gene_names,
    orientation='h',  # Horizontal bar chart
    marker=dict(color=list(p_value_list4.values()), coloraxis="coloraxis")  # Color based on the correlation factor
))
fig.add_shape(
    type="line",
    x0=0.2, y0=0, x1=0.2, y1=len(gene_names),
    line=dict(color="yellow", width=3, dash="dashdot"),
)


# Update layout for better visualization
fig.update_layout(
    xaxis_title="Correlation Coefficient",
    yaxis_title="Gene ID",
    coloraxis=dict(colorscale='Viridis'),  # Red for higher, blue for lower values
    height=1200  # Adjust height based on the number of genes to avoid squeezing
)

fig.show()
# fig.write_image('jsd_ens_diff_cc_bar.pdf')

In [595]:
count = len([x for x in p_value_list4.values() if x < 0.05])
count


103

In [596]:
count = len([x for x in correlation_list4.values() if x > 0.3])
count

76

In [597]:
correlation_list2 = {}
p_value_list2 = {}
for gene, lists in gene_data_dict.items():
    jsd_diff_list, ens_abs_diff_list = remove_outliers_iqr(lists['jsd_diff'], lists['ens_abs_diff'])

    #Add the correlation factor in the list
    cor, p_value = scipy.stats.pearsonr(jsd_diff_list, ens_abs_diff_list)
    correlation_list2[gene] = cor
    p_value_list2[gene] = p_value 

fig2 = px.box(correlation_list2.values(), title= 'Pearson Correlation Test JSD Difference Vs. ENS difference')
# Update layout with labels and title
fig2.update_layout(
    template='plotly_white',
    margin=dict(l=20, r=20, t=50, b=20),
    autosize=True,
    yaxis_title_font={'size': 20},  
    xaxis_title_font={'size': 20}, 
    xaxis_title='Correlation Coefficient',
    yaxis_title='Value',
    width=None
)
fig2.show()
# fig2.write_image('jsd_ens_diff_cc.pdf')

In [598]:
gene_names = list(gene_data_dict.keys())
# Create the bar chart
fig = go.Figure(go.Bar(
    x=list(correlation_list2.values()),
    y=gene_names,
    orientation='h',  # Horizontal bar chart
    marker=dict(color=list(p_value_list2.values()), coloraxis="coloraxis")  # Color based on the correlation factor
))
fig.add_shape(
    type="line",
    x0=0.2, y0=0, x1=0.2, y1=len(gene_names),
    line=dict(color="yellow", width=3, dash="dashdot"),
)


# Update layout for better visualization
fig.update_layout(
    xaxis_title="Correlation Coefficient",
    yaxis_title="Gene ID",
    coloraxis=dict(colorscale='Viridis'),  # Red for higher, blue for lower values
    height=1200  # Adjust height based on the number of genes to avoid squeezing
)

fig.show()
# fig.write_image('jsd_ens_diff_cc_bar.pdf')

In [599]:
count = len([x for x in p_value_list2.values() if x < 0.05])
count

93

In [600]:
count = len([x for x in correlation_list2.values() if x > 0.3])
count

65

In [601]:
gene_list = ['ENSG00000048707',
 'ENSG00000143479',
 'ENSG00000087365',
 'ENSG00000116183',
 'ENSG00000116991',
 'ENSG00000110075',
 'ENSG00000265203',
 'ENSG00000107560',
 'ENSG00000150347',
 'ENSG00000143278',
 'ENSG00000166189',
 'ENSG00000186603',
 'ENSG00000119285',
 'ENSG00000154309',
 'ENSG00000117322',
 'ENSG00000197147',
 'ENSG00000138161',
 'ENSG00000166349',
 'ENSG00000109927',
 'ENSG00000149782',
 'ENSG00000187554',
 'ENSG00000135775',
 'ENSG00000064309',
 'ENSG00000107862',
 'ENSG00000116128',
 'ENSG00000042781',
 'ENSG00000116539',
 'ENSG00000117000',
 'ENSG00000185875',
 'ENSG00000065613',
 'ENSG00000126705',
 'ENSG00000170322',
 'ENSG00000127124',
 'ENSG00000091664',
 'ENSG00000135372',
 'ENSG00000162711',
 'ENSG00000143669',
 'ENSG00000198730',
 'ENSG00000129159',
 'ENSG00000171492',
 'ENSG00000283703',
 'ENSG00000198198',
 'ENSG00000160703',
 'ENSG00000187486',
 'ENSG00000181333',
 'ENSG00000166341']

In [602]:
correlation_list3 = {}
p_value_list3 = {}
for gene, lists in gene_data_dict.items():
    ingroup_jsm_list, ens_abs_diff_list = remove_outliers_iqr(lists['ingroup_jsd'], lists['ens_abs_diff'])

    #Add the correlation factor in the list
    cor, p_value = scipy.stats.pearsonr(ingroup_jsm_list, ens_abs_diff_list)
    correlation_list3[gene] = cor
    p_value_list3[gene] = p_value 

fig3 = px.box(correlation_list3.values(), title= 'Pearson Correlation Test In-group JSD Vs. ENS difference')
# Update layout with labels and title
fig3.update_layout(
    template='plotly_white',
    margin=dict(l=20, r=20, t=50, b=20),
    autosize=True,
    yaxis_title_font={'size': 20},  
    xaxis_title_font={'size': 20}, 
    xaxis_title='Correlation Coefficient',
    yaxis_title='Value',
    width=None
)
fig3.show()
# fig2.write_image('jsd_ens_diff_cc.pdf')



In [603]:
# correlation_list3_subset = {gene: correlation_list3[gene] for gene in gene_list}
# p_value_list3_subset = {gene: p_value_list3[gene] for gene in gene_list}

In [604]:
count = len([x for x in p_value_list3.values() if x < 0.05])
count

120

In [605]:
count = len([x for x in correlation_list3.values() if x > 0.3])
count

87

In [606]:
gene_names = list(gene_data_dict.keys())
# Create the bar chart
fig = go.Figure(go.Bar(
    x=list(correlation_list3.values()),
    y=gene_names,
    orientation='h',  # Horizontal bar chart
    marker=dict(color=list(p_value_list3.values()), coloraxis="coloraxis")  # Color based on the correlation factor
))
fig.add_shape(
    type="line",
    x0=0.3, y0=0, x1=0.3, y1=len(gene_names),
    line=dict(color="yellow", width=3, dash="dashdot"),
)


# Update layout for better visualization
fig.update_layout(
    xaxis_title="Correlation Coefficient",
    yaxis_title="Gene ID",
    coloraxis=dict(colorscale='Viridis'),  # Red for higher, blue for lower values
    height=1200  # Adjust height based on the number of genes to avoid squeezing
)

fig.show()
# fig.write_image('ingroup_jsd_ens_cc_bar.pdf')

In [607]:
ancester_distirbution_dict = json.load(open('/Users/gulugulu/Desktop/honours/data_local/whole_genome_mammal87/triads_350_threshold/internal_root_distributions.json', 'r'))

information_dict = {}
for gene_id, distirbutions_info in ancester_distirbution_dict.items():
    information_dict[gene_id] = {}
    for identifier, distirbution in distirbutions_info.items():
        information_dict[gene_id][identifier] = calculate_information(distirbution)

information_list = [value for gene_id in information_dict for value in information_dict[gene_id].values()]
average_information_dict = {}
for gene, informations in information_dict.items():
    average_information_dict[gene] = np.average(list(informations.values()))



In [608]:
# import plotly.express as px
# corr_dict = {gene: correlation_list3[gene] for gene in average_information_dict.keys()}
# information_correlation_fig = px.scatter(x = list(average_information_dict.values()), y = list(corr_dict.values()), labels={'x':'Information (Average)', 'y':'Correlation Coefficient'}, trendline="ols", title= None)
# # Update layout with labels and title
# information_correlation_fig.update_layout(
#     template='plotly_white',
#     margin=dict(l=20, r=20, t=50, b=20),
#     autosize=True,
#     yaxis_title_font={'size': 20},  
#     xaxis_title_font={'size': 20}, 
#     width=None 
# )
# information_correlation_fig.show()

In [609]:


def plot_gene_data(gene_data_dict, xcol, ycol):
    keys = list(gene_data_dict.keys())
    rows = int(len(keys) ** 0.5) + 1  # Calculate the number of rows for subplots
    cols = (len(keys) + rows - 1) // rows  # Calculate the number of columns

    fig = make_subplots(rows=rows, cols=cols, subplot_titles=[f'{key}' for key in keys])
    
    # Populate subplots
    for index, key in enumerate(keys, start=1):
        gene_data = gene_data_dict[key]
        x_value, y_value = remove_outliers_iqr(gene_data[xcol], gene_data[ycol])

        row = (index - 1) // cols + 1
        col = (index - 1) % cols + 1
        
        fig.add_trace(
            go.Scatter(
                x=x_value,
                y=y_value,
                mode='markers',
                name=f'{key}'
            ),
            row=row,
            col=col
        )
        
        # Adding a trend line
        fig.add_trace(
            go.Scatter(
                x=x_value,
                y=np.poly1d(np.polyfit(x_value, y_value, 1))(x_value),
                mode='lines',
                name=f'Trend {key}',
                line=dict(color='red')
            ),
            row=row,
            col=col
        )
        
        # Update axis properties
        fig.update_xaxes(title_text=xcol if row == rows else "", row=row, col=col)
        fig.update_yaxes(title_text=ycol if col == 1 else "", row=row, col=col)
    
    fig.update_layout(
        height=300 * rows,  # Set a reasonable height based on the number of rows
        width=300 * cols,   # Set a reasonable width based on the number of columns
        showlegend=False
    )
    
    return fig

In [619]:
# Usage example
fig = plot_gene_data(gene_data_dict, 'ingroup_jsd', 'ens_abs_diff')

fig.update_layout(
title_text="Scatter Plots of Ingroup JSD vs. ENS Difference",)
fig.show()

In [611]:
path = '/Users/gulugulu/Desktop/honours/data_local/whole_genome_mammal87/triads_model_fitting_350_threshold/ENSG00000065613'


In [612]:
triads_data_path = os.path.join(path, 'triads_info_dict.json')
triads_info = load_json_data(triads_data_path)
ens_ingroup_list = []
ingroup_jsd_list = []
for identifier, info in triads_info.items():
    triads_info_value = info['triads_info_small_tree']
    triads_names = info['triads_species_names']
    ingroup_jsd = triads_info_value['jsd_dict']['Ingroup_JSD']
    ens_dict = triads_info_value['ens']
    ens_ingroup = abs(ens_dict[triads_names['ingroup1']] - ens_dict[triads_names['ingroup2']])
    ens_ingroup_list.append(ens_ingroup)
    ingroup_jsd_list.append(ingroup_jsd)


ens_ingroup_list2, ingroup_jsd_list2 = remove_outliers_iqr(ens_ingroup_list, ingroup_jsd_list)


In [613]:
indices_ens_diff2 = list(range(1, len(ens_ingroup_list2) + 1))
indices_ingroup_jsd2 = list(range(1, len(ingroup_jsd_list2) + 1))

list_pair = []
for i in range(len(indices_ens_diff2)):
    jsd, ens = ingroup_jsd_list2[i], ens_ingroup_list2[i]
    list_pair.append((jsd, ens))

# Create the scatter plot
fig1 = go.Figure()

# Add ENS Differences scatter plot
fig1.add_trace(go.Scatter(
    x=indices_ingroup_jsd2, y=sorted(np.sqrt(ens_ingroup_list2)), mode='markers',
    marker=dict(size=4), name='ENS Differences'))

fig1.add_trace(go.Scatter(
    x=indices_ingroup_jsd2, y=sorted(np.sqrt(ingroup_jsd_list2)), mode='markers',
    marker=dict(size=4), name='Ingroup JSM'))


# Update layout for clear visualization
fig1.update_layout(
    title='Uniform distirbution of each property',
    xaxis_title='Index',
    yaxis_title='Values',
    showlegend=True,
    
)

fig1.show()

In [614]:
sorted_data = sorted(list_pair, key=lambda x: x[1])

In [615]:
x_values = [x[0] for x in sorted_data]
y_values = [x[0] for x in sorted_data]

In [616]:
# Create the scatter plot
fig1 = go.Figure()

# Add ENS Differences scatter plot
fig1.add_trace(go.Scatter(
    x=np.sqrt(x_values), y=sorted(np.sqrt(y_values)), mode='markers',
    marker=dict(size=4), name='ENS Differences'))



# Update layout for clear visualization
fig1.update_layout(
    title='Uniform distirbution of each property',
    xaxis_title='Index',
    yaxis_title='Values',
    showlegend=True,
    
)

fig1.show()

In [617]:
import numpy as np
import scipy.stats as stats

def permute_test_correlation(x, y, n_permutations=10000):
    # Calculate the actual correlation
    actual_corr, _ = stats.pearsonr(x, y)
    
    # To hold the permuted correlations
    permuted_corrs = []

    original_y = np.copy(y)
    
    # Permutation test
    for _ in range(n_permutations):
        # Shuffle one of the lists
        np.random.shuffle(y)
        # Compute the correlation of the permuted data
        perm_corr, _ = stats.pearsonr(x, y)
        permuted_corrs.append(perm_corr)

        y[:] = original_y
    
    # Compute p-value: proportion of permuted correlations as extreme as the actual one
    p_value = np.sum(np.abs(permuted_corrs) >= np.abs(actual_corr)) / n_permutations
    
    return actual_corr, p_value, permuted_corrs





In [618]:
# Run the permutation test
actual_corr, p_value = permute_test_correlation(distance_dict1, distance_dict2)
print("Actual Correlation:", actual_corr)
print("P-value:", p_value)


ValueError: too many values to unpack (expected 2)