In [1]:
import pandas as pd

import plotly.express as px
from ipywidgets import interact, interactive, fixed, interact_manual, VBox
import ipywidgets as widgets

import plotly.graph_objects as go
import plotly.figure_factory as ff

import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage
import random
import string
from itables import init_notebook_mode

In [2]:
# Locus selector; its value is substituted into the name of the CSV file to load.
locus = "all"

# Load the AFND export and discard its first column.
# (presumably a row index written out when the CSV was saved -- TODO confirm)
afnd_a = pd.read_csv("AFND_data_locus_{}.csv".format(locus))
afnd_a = afnd_a.drop(columns=afnd_a.columns[0])

# Distinct alleles present in the table, plus the allele / original-name pairs.
# reset_index() (without drop=True) keeps the previous row labels in an
# "index" column of the resulting frame.
alleles = afnd_a["Allele"].unique()
alleles_df = (
    afnd_a[["Allele", "Allele - original"]]
    .drop_duplicates()
    .reset_index()
)

# One row per distinct combination of population metadata.
populations_df = (
    afnd_a[["Population", "country", "continent", "Sample Size", "Location"]]
    .drop_duplicates()
    .reset_index()
)

# Distinct population names, countries and continents, in order of first appearance.
populations = populations_df["Population"].unique()
countries = populations_df["country"].unique()
continents = populations_df["continent"].unique()


In [3]:
# Observed (population, allele) frequency rows taken from the source table.
pop_freq_info = (
    afnd_a[["Population", "Allele", "Allele Frequency", "% of individuals that have the allele"]]
    .drop_duplicates()
    .reset_index()
)

# Pair every population row with every allele row (cross join), attach the
# observed frequencies where they exist, and treat pairs without an observed
# frequency as frequency 0.
full_df = (
    populations_df
    .merge(alleles_df, how="cross")
    .merge(pop_freq_info, how="outer", on=["Population", "Allele"])
    .reset_index()
    .assign(**{"Allele Frequency": lambda frame: frame["Allele Frequency"].fillna(0)})
)

# Keep only the columns used downstream and collapse exact duplicate rows.
# The final reset_index() stores the pre-deduplication row labels in an
# "index" column.
full_df = (
    full_df[["Population", "country", "continent", "Sample Size", "Allele", "Allele Frequency"]]
    .drop_duplicates()
    .reset_index()
)


In [4]:
# Build the population x allele frequency matrix consumed by the clustering cells.
# Row i of X_freq is the allele-frequency vector of populations[i], with the
# frequencies ordered by allele name.

selected_data = full_df

# Sort once and group, instead of re-filtering the whole frame for every
# population. Sorting on two keys is stable, so each group's order matches
# what a per-population filter followed by the same sort would produce.
freq_by_population = (
    selected_data
    .sort_values(by=["Population", "Allele"])
    .groupby("Population", sort=False)["Allele Frequency"]
    .agg(list)
)
# Preserve the order of `populations`; a population without rows gets an empty vector.
distributions = {p: freq_by_population.get(p, []) for p in populations}

# Every population must contribute a vector of the same length, otherwise the
# matrix below is ragged and the distance / linkage computations are meaningless.
vector_lengths = {len(freqs) for freqs in distributions.values()}
if len(vector_lengths) > 1:
    raise ValueError(
        "Populations have allele-frequency vectors of different lengths ({}); "
        "check full_df for duplicated or missing (Population, Allele) rows."
        .format(sorted(vector_lengths))
    )
X_freq = np.array(list(distributions.values()))

# The locus is taken to be the first character of the allele name.
# `selected_data` is the same object as `full_df`, so it gains this column too.
full_df["Locus"] = full_df["Allele"].str[0]
full_df = full_df.sort_values(["Population", "Locus", "Allele"])

# Clustering configuration shared by the plotting cells.
observations = X_freq
dist_function = "jensenshannon"  # metric name passed to scipy's pdist for the heatmap
linkage_method = "ward"          # method name passed to scipy's linkage
color_threshold = 1.5            # passed to create_dendrogram as color_threshold


In [16]:


# Stand-alone dendrogram of all populations, drawn very wide so that the
# rotated population labels on the x axis have room.
# NOTE(review): distfun is the identity, so linkage() receives `observations`
# itself and `dist_function` is not used for this tree -- confirm this is intended.
fig = ff.create_dendrogram(
    observations,
    orientation="bottom",
    labels=populations,
    distfun=lambda x: x,
    linkagefun=lambda x: linkage(x, linkage_method),
    color_threshold=color_threshold,
)
fig.update_layout(width=6000, height=500)
fig.update_xaxes(tickangle=45)
fig.show()

In [15]:
# Clustered distance heatmap: a side dendrogram on the left and the pairwise
# `dist_function` distance matrix, with rows and columns reordered to follow
# the dendrogram leaves.
# NOTE(review): distfun is the identity, so linkage() receives `observations`
# itself, while the heatmap values come from pdist(observations, dist_function)
# -- confirm that ordering one metric's heatmap by this tree is intended.
dendrogram_options = dict(
    linkagefun=lambda x: linkage(x, linkage_method),
    distfun=lambda x: x,
    color_threshold=color_threshold,
)

# Upper dendrogram: provides the base figure and the x tick positions.
# Its traces are hidden and attached to the secondary y axis.
fig = ff.create_dendrogram(
    observations,
    orientation="bottom",
    labels=populations,
    **dendrogram_options,
)
fig.update_traces(visible=False)
fig.update_traces(yaxis="y2")

# Side dendrogram: drawn against the secondary x axis. No labels are passed,
# so its tick text holds the row positions of the leaves.
dendro_side = ff.create_dendrogram(
    observations,
    orientation="right",
    **dendrogram_options,
)
dendro_side.update_traces(xaxis="x2")
fig.add_traces(list(dendro_side.data))

# Leaf order of the side dendrogram, as integer row indices into `observations`.
dendro_leaves = [int(leaf) for leaf in dendro_side.layout.yaxis.ticktext]

# Pairwise distances between populations, reordered on both axes by leaf order.
data_dist = pdist(observations, dist_function)
heat_data = squareform(data_dist)
heat_data = heat_data[np.ix_(dendro_leaves, dendro_leaves)]

# Place the heatmap cells on the dendrogram tick positions so they line up
# with the tree leaves.
heatmap = [
    go.Heatmap(
        x=fig.layout.xaxis.tickvals,
        y=dendro_side.layout.yaxis.tickvals,
        z=heat_data,
        colorscale="Blues",
    )
]
fig.add_traces(heatmap)

# Axis styling shared by all four axes: no grid, axis line, zero line or tick marks.
bare_axis = dict(mirror=False, showgrid=False, showline=False, zeroline=False, ticks="")
unlabeled_axis = dict(bare_axis, showticklabels=False)

fig.update_layout(
    width=1000,
    height=1000,
    showlegend=False,
    hovermode="closest",
    paper_bgcolor="rgba(0,0,0,0)",
    plot_bgcolor="rgba(0,0,0,0)",
    # Main x axis keeps its tick labels but renders them fully transparent.
    xaxis=dict(bare_axis, domain=[.15, 1], tickfont=dict(color="rgba(0,0,0,0)")),
    # Left strip reserved for the side dendrogram.
    xaxis2=dict(unlabeled_axis, domain=[0, .15]),
    yaxis=dict(unlabeled_axis, domain=[0, 1]),
    # Band assigned to the (hidden) upper dendrogram.
    yaxis2=dict(unlabeled_axis, domain=[.825, .975]),
)

fig.show()