In [23]:
import os
import time
import pandas as pd
import numpy as np
import networkx as nx
import collections
from scipy import sparse as sp
from scipy.stats import rankdata

import itertools
from itertools import combinations, combinations_with_replacement, cycle
from functools import reduce

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from util import *

import colorcet as cc

import bokeh
from bokeh.io import output_notebook, output_file, show, save
from bokeh.plotting import figure
from bokeh.models import (Rect, MultiLine, Circle, Span, Label,
                          GraphRenderer, StaticLayoutProvider,
                          NodesAndLinkedEdges,
                          HoverTool, TapTool, ColumnDataSource,
                          LinearColorMapper, LogColorMapper, CategoricalColorMapper,
                          CategoricalMarkerMapper,
                          BoxSelectTool,
                          ColorBar, BasicTicker, BoxZoomTool, FactorRange,
                          Range1d)
from bokeh.models import CategoricalTicker, FixedTicker, BoxAnnotation
from bokeh.models import Arrow, NormalHead, OpenHead, VeeHead, LabelSet

from bokeh.transform import transform, factor_cmap, linear_cmap, log_cmap
from bokeh.layouts import row, column, gridplot
output_notebook()

In [24]:
# Load the node tables for the oviIN-input network and for the whole hemibrain,
# then restrict the hemibrain table to bodyIds that also occur in the oviIN-input table.
ovi_nodes_path = 'oviIN/preprocessed_inputs-v1.2.1/preprocessed_nodes.csv'
hemibrain_nodes_path = 'hemibrain/preprocessed-v1.2/preprocessed_nodes.csv'

# index_col=1: the second CSV column becomes the index (rendered as 'key' in the output)
df_in = pd.read_csv(ovi_nodes_path, index_col=1)
ids = df_in['id']  # bodyIds of the oviIN-input neurons

wb = pd.read_csv(hemibrain_nodes_path)
wb = wb.loc[wb['id'].isin(ids)]
wb

Unnamed: 0,id,0.75,0.05,0.1,0.25,0.5,1.0,0.0,instance,celltype,...,size,status,cropped,statusLabel,cellBodyFiber,somaRadius,somaLocation,inputRois,outputRois,roiInfo
32,263674097,19,3,8,14,16,22,3,LHPD2a5_a_R,LHPD2a5_a,...,408560985,Traced,False,Roughly traced,PDL06,268.5,"[5386, 20096, 4080]","['CRE(-ROB,-RUB)(R)', 'CRE(R)', 'INP', 'LH(R)'...","['INP', 'LH(R)', 'SCL(R)', 'SIP(R)', 'SMP(R)',...","{'SNP(R)': {'pre': 121, 'post': 330, 'downstre..."
53,266187480,30,3,8,4,23,35,3,SMP349_R,SMP349,...,563941715,Traced,False,Traced,PDM07,238.5,"[18808, 27714, 4256]","['SIP(R)', 'SLP(R)', 'SMP(R)', 'SNP(R)']","['SIP(R)', 'SLP(R)', 'SMP(R)', 'SNP(R)']","{'SNP(R)': {'pre': 190, 'post': 854, 'downstre..."
55,266187559,32,3,8,4,25,37,3,SLP399_R,SLP399,...,539797068,Traced,False,Roughly traced,PDM07,290.5,"[17838, 26568, 3924]","['LH(R)', 'SLP(R)', 'SMP(R)', 'SNP(R)']","['SLP(R)', 'SMP(R)', 'SNP(R)']","{'SNP(R)': {'pre': 216, 'post': 744, 'downstre..."
70,267214250,42,3,8,22,33,47,3,pC1b_R,pC1b,...,3805489752,Traced,False,Traced,PDM09,446.5,"[18931, 10896, 14728]","['AOTU(R)', 'AVLP(R)', 'ICL(R)', 'INP', 'SCL(R...","['AVLP(R)', 'INP', 'SCL(R)', 'SIP(R)', 'SLP(R)...","{'SNP(R)': {'pre': 547, 'post': 2474, 'downstr..."
72,267223104,44,6,7,19,34,49,4,SMP025_R,SMP025,...,383479545,Traced,False,Roughly traced,ADL09,286.0,"[3281, 26379, 16668]","['SIP(R)', 'SLP(R)', 'SMP(R)', 'SNP(R)']","['SIP(R)', 'SLP(R)', 'SMP(R)', 'SNP(R)']","{'SNP(R)': {'pre': 97, 'post': 410, 'downstrea..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21721,5901225755,2285,1,1,98,1221,3714,1,,,...,809456817,Traced,False,Roughly traced,,,,"['ATL(R)', 'IB', 'ICL(R)', 'INP', 'IPS(R)', 'L...","['ATL(R)', 'IB', 'ICL(R)', 'INP', 'IPS(R)', 'L...","{'INP': {'pre': 123, 'post': 475, 'downstream'..."
21725,5901227238,2672,7,33,107,1170,3208,1,,,...,773291310,Traced,False,Roughly traced,,,,"['AL(R)', 'AL-VP2(R)', 'EPA(R)', 'IPS(R)', 'LA...","['EPA(R)', 'IPS(R)', 'LAL(-GA)(R)', 'LAL(R)', ...","{'LX(R)': {'pre': 42, 'post': 557, 'downstream..."
21731,5901232053,155,3,8,4,65,178,3,SMP272(PDL21)_L,SMP272,...,1421485085,Traced,False,Roughly traced,,,,"['CRE(-RUB)(L)', 'CRE(L)', 'INP', 'LAL(L)', 'L...","['CRE(-RUB)(L)', 'CRE(L)', 'INP', 'SCL(L)', 'S...","{'SNP(L)': {'pre': 464, 'post': 967, 'downstre..."
21732,6400000773,328,3,8,4,55,401,3,SMP411_R,SMP411,...,503262274,Traced,False,Roughly traced,PDM09,321.5,"[22433, 11755, 18464]","['INP', 'LH(R)', 'MB(+ACA)(R)', 'PLP(R)', 'SCL...","['MB(+ACA)(R)', 'PLP(R)', 'SIP(R)', 'SLP(R)', ...","{'SNP(R)': {'pre': 180, 'post': 467, 'downstre..."


In [25]:
def joint_marginal(df, c1, c2, include_fraction=False):
    """Given a dataframe and two columns, return a dataframe with the joint and marginal counts.

    Parameters
    ----------
    df : pandas.DataFrame
        Table containing both columns.
    c1, c2 : hashable
        Column labels to cross-tabulate.
    include_fraction : bool, default False
        If True, also add:
        - ``joint_fraction``: joint count / sum of all joint counts,
        - ``{c1}_fraction``: joint count / marginal count of the row's c1 value,
        - ``{c2}_fraction``: joint count / marginal count of the row's c2 value.

    Returns
    -------
    pandas.DataFrame
        One row per observed (c1, c2) pair with columns ``c1``, ``c2``,
        ``joint_count``, ``{c1}_count``, ``{c2}_count`` (plus the fraction
        columns when requested). Pairs containing a missing value are not
        counted, because ``value_counts`` drops NaN by default.
    """
    j = df.value_counts([c1, c2]).rename("joint_count").reset_index()

    # Attach marginal counts with Series.map. df[col].value_counts() always has a
    # flat index (DataFrame.value_counts(col) may return a one-level MultiIndex,
    # depending on the pandas version), and map keeps j's rows and index unchanged.
    j[f"{c1}_count"] = j[c1].map(df[c1].value_counts())
    j[f"{c2}_count"] = j[c2].map(df[c2].value_counts())

    if include_fraction:
        j["joint_fraction"] = j["joint_count"] / j["joint_count"].sum()
        j[f"{c1}_fraction"] = j["joint_count"] / j[f"{c1}_count"]
        j[f"{c2}_fraction"] = j["joint_count"] / j[f"{c2}_count"]
    return j


In [26]:
# Pulled from Prof G's code on github (https://github.com/Gutierrez-lab/oviIN-analyses-gabrielle/blob/main/modular_sandbox.ipynb)
def modularity_merge(df1, df2, suf1, suf2, on='id', how='inner'):
    """Given two modularity dataframes, merge them along shared body IDs.

    Parameters
    ----------
    df1, df2 : pandas.DataFrame
        Tables to join. ``on`` may name a column or an index level in each.
    suf1, suf2 : str
        Suffixes appended to overlapping column names from df1 and df2
        respectively (e.g. '_ovi', '_whole').
    on : hashable, default 'id'
        Key column / index-level name shared by both tables.
    how : str, default 'inner'
        Join type passed to ``DataFrame.merge``; 'inner' keeps only shared IDs.

    Returns
    -------
    pandas.DataFrame
        The merged table. When ``on`` is an index level in both inputs, pandas
        keeps it as the result's index.
    """
    return df1.merge(df2, on=on, how=how, suffixes=(suf1, suf2))

In [27]:
# Inspect the oviIN-input node table as loaded (pandas truncates the rendered rows)
df_in

Unnamed: 0_level_0,id,0.0,0.05,0.1,0.5,0.75,1.0,instance,celltype,pre,...,status,cropped,statusLabel,cellBodyFiber,somaRadius,somaLocation,roiInfo,notes,inputRois,outputRois
key,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,1003215282,1,1,1,1,1,1,CL229_R,CL229,100,...,Traced,False,Roughly traced,PDM19,301.0,"[23044, 14981, 11600]","{'INP': {'pre': 87, 'post': 351, 'downstream':...",,"['EPA(R)', 'GOR(R)', 'IB', 'ICL(R)', 'INP', 'S...","['GOR(R)', 'IB', 'ICL(R)', 'INP', 'SCL(R)', 'S..."
2,1005952640,2,1,1,2,2,2,IB058_R,IB058,664,...,Traced,False,Roughly traced,PVL20,,,"{'INP': {'pre': 464, 'post': 1327, 'downstream...",,"['ATL(R)', 'IB', 'ICL(R)', 'INP', 'PLP(R)', 'S...","['ATL(R)', 'IB', 'ICL(R)', 'INP', 'PLP(R)', 'S..."
3,1006928515,1,1,1,3,3,3,CL300_R,CL300,86,...,Traced,False,Roughly traced,PVL13,236.0,"[12083, 10523, 16816]","{'INP': {'pre': 79, 'post': 126, 'downstream':...",,"['ATL(R)', 'IB', 'ICL(R)', 'INP', 'SCL(R)', 'S...","['ATL(R)', 'IB', 'ICL(R)', 'INP', 'SCL(R)', 'S..."
4,1007260806,2,1,1,4,4,4,CL301_R,CL301,119,...,Traced,False,Roughly traced,PVL13,236.0,"[13524, 10108, 16480]","{'INP': {'pre': 40, 'post': 128, 'downstream':...",,"['GOR(R)', 'IB', 'ICL(R)', 'INP', 'PLP(R)', 'S...","['IB', 'ICL(R)', 'INP', 'PLP(R)', 'SCL(R)', 'S..."
5,1008024276,3,2,2,5,5,5,FB5N_R,FB5N,499,...,Traced,False,Roughly traced,AVM08,472.5,"[19178, 29711, 37312]","{'SNP(L)': {'post': 5, 'upstream': 5, 'mito': ...",SMPCREFB5_4,"['CRE(-ROB,-RUB)(R)', 'CRE(R)', 'CX', 'FB', 'F...","['CRE(-ROB,-RUB)(R)', 'CRE(R)', 'CX', 'FB', 'F..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2507,987273073,3,8,8,409,604,629,(PVL05)_L,,283,...,Traced,False,Roughly traced,,,,"{'SNP(R)': {'pre': 65, 'post': 52, 'downstream...",,"['CRE(-ROB,-RUB)(R)', 'CRE(-RUB)(L)', 'CRE(L)'...","['CRE(-ROB,-RUB)(R)', 'CRE(-RUB)(L)', 'CRE(L)'..."
2508,987842109,3,9,23,533,780,815,,,2,...,Orphan,,Orphan hotknife,,,,"{'SNP(R)': {'pre': 2, 'post': 13, 'downstream'...",,"['SMP(R)', 'SNP(R)']","['SMP(R)', 'SNP(R)']"
2509,988567837,2,3,4,16,58,63,FB4G_R,FB4G,785,...,Traced,False,Roughly traced,AVM08,,,"{'SNP(R)': {'pre': 6, 'post': 73, 'downstream'...",CRELALFB4_3,"['CRE(-ROB,-RUB)(R)', 'CRE(R)', 'CX', 'FB', 'F...","['CRE(-ROB,-RUB)(R)', 'CRE(R)', 'CX', 'FB', 'F..."
2510,988909130,2,3,4,389,559,572,FB5V_R,FB5V,269,...,Traced,False,Roughly traced,AVM10,296.5,"[13226, 32024, 18600]","{'SNP(R)': {'pre': 1, 'post': 28, 'downstream'...",CRELALFB5,"['AB(R)', 'CRE(-ROB,-RUB)(R)', 'CRE(R)', 'CX',...","['CRE(-ROB,-RUB)(R)', 'CRE(R)', 'CX', 'FB', 'F..."


In [28]:
# Defining base res and suffixes
res = '0.0'        # clustering column (resolution) to compare across the two tables
df1_suf = '_ovi'   # suffix for the oviIN-input table's column
df2_suf = '_whole' # suffix for the whole-hemibrain table's column

# Index both tables by bodyId. Guarded and reassigned (not inplace) so the cell can
# be re-run: after the first run 'id' is already the index and set_index('id')
# would raise KeyError.
if 'id' in df_in.columns:
    df_in = df_in.set_index('id')
if 'id' in wb.columns:
    wb = wb.set_index('id')

# Using merge function from code cell above
mod_merge_df = modularity_merge(df_in[[res]], wb[[res]], df1_suf, df2_suf)
mod_merge_df

Unnamed: 0_level_0,0.0_ovi,0.0_whole
id,Unnamed: 1_level_1,Unnamed: 2_level_1
1003215282,1,6
1005952640,2,1
1006928515,1,1
1007260806,2,1
1008024276,3,2
...,...,...
987117151,2,1
987273073,3,3
988567837,2,2
988909130,2,2


In [29]:
# Cluster columns to compare, e.g. '0.0_ovi' (oviIN-input) vs. '0.0_whole' (whole hemibrain)
chi1 = res + df1_suf
chi2 = res + df2_suf

jm = joint_marginal(mod_merge_df, chi1, chi2, include_fraction=True)

# Bokeh categorical axes need string factors, so keep string copies of the cluster ids.
jm["x"] = jm[chi1].apply(str)
jm["y"] = jm[chi2].apply(str)

# x axis: every chi1 cluster that actually occurs, in numeric order. Deriving the
# factors from the data (instead of assuming ids run 1..max) keeps every rect on-axis.
x_factors = [str(c) for c in sorted(jm[chi1].unique())]

# y axis: order chi2 clusters to get a more "diagonal" plot. For each chi2 cluster take
# the row where it has its largest fraction (its dominant chi1 cluster), then sort the
# chi2 clusters by that chi1 cluster and, within ties, by joint count descending.
dominant = (
    jm.sort_values([f"{chi2}_fraction"], ascending=False)
      .groupby(chi2)
      .agg({chi1: "first", f"{chi2}_fraction": "first", "joint_count": "first"})
)
yrange = dominant.sort_values([chi1, "joint_count"], ascending=[True, False]).index
y_factors = [str(y) for y in yrange]

# make a bokeh figure
f = figure(title=f"Clusters in {chi2} vs. clusters in {chi1}",
           x_range=FactorRange(factors=x_factors),
           y_range=FactorRange(factors=y_factors),
           width=600, height=700)

# One rect per (chi1, chi2) pair: width = share of the chi1 cluster, height = share of
# the chi2 cluster accounted for by that pair.
f.rect(x="x", y="y", width=f"{chi1}_fraction", height=f"{chi2}_fraction", source=jm)
# Bokeh's default (numbro) tooltip formatter: {0.0%} renders 0.123 as "12.3%".
# joint_fraction is relative to all neurons present in both tables, not the whole hemibrain.
f.add_tools(HoverTool(tooltips={"Neurons": "@joint_count (@joint_fraction{0.0%} of all shared neurons)",
                                f"Fraction of {chi2}": f"@{{{chi2}_fraction}}{{0.0%}}",
                                f"Fraction of {chi1}": f"@{{{chi1}_fraction}}{{0.0%}}"}))
f.xaxis.axis_label = 'Cluster in ' + chi1
f.yaxis.axis_label = 'Cluster in ' + chi2

show(f)