# Sunburst Plot for ROI Breakdown of Connectivity Locations
   - This script has been modified from an original version provided by Emily Joyce
   - Modifications include:
        - different queries for a direct vs one-hop relationship
        - link to ROI_hierarchy.csv for use with script
        - dedicated variables for body ID lists to ease updating or changing the values
        - dedicated token variable to ease updates or changes to another auth token
   - Links to more information regarding sunburst plots and how to set up your system to run the script
        - Sunburst plot basics: https://plotly.com/python/sunburst-charts/#basic-sunburst-plot-with-plotlyexpress
        - Getting Started/setup: https://plotly.com/python/getting-started/#jupyterlab-support-python-35

In [None]:
from neuprint import Client, fetch_custom
import pandas as pd
import plotly.express as px
import plotly as pio
import json
import numpy as np

In [None]:
# Paste your auth token between the quotes, e.g. token = 'abc123def456'
token = ""

# Client connection used by the queries in this notebook
c = Client(
    'neuprint-test.janelia.org',
    dataset='hemibrain',
    token=token,
    verify=True,
)


#### If you do not have ROI_hierarchy.csv saved in the same folder as this file, you can specify its path, or you can obtain the CSV at https://github.com/alvaradocx/CAT-work/blob/main/ROI_hierarchy.csv

## Direct Relationship
Use the cell below for a direct relationship (a->b)
- this can also be used to look at the relationship between (a->b) or (b->c) in an (a->b->c) relationship

In [None]:
# Add lists for all bodies of interest.
# Body list for bodies in the starting position (most upstream):
sbody_list = []

# Body list for the last/ending bodies:
ebody_list = []

# Minimum number of synapses (connection weight) required for a->b.
# This must be a number: it is interpolated into `w.weight >= ...` below,
# and a list placeholder would make the comparison match nothing.
# 0 applies no filter.
minw = 0

# Add your query here. The pieces are adjacent string literals, so the
# query is sent as a single line.
direct = (
    f" WITH {sbody_list} AS START, {ebody_list} AS END "
    "MATCH (a:Neuron)-[w:ConnectsTo]->(b:Neuron) "
    f"WHERE a.bodyId IN START AND b.bodyId IN END AND w.weight >= {minw} "
    "RETURN a.bodyId, a.type, w.weight, w.roiInfo, b.bodyId, b.type"
)

# Specify the roiInfo column whose data you want to plot,
# in the form: column = 'column_name'
column = 'w.roiInfo'

## One Hop Relationship
Use the cell below for a one hop relationship (a->b->c)

In [None]:
# Add lists for all bodies of interest along the way a->b->c.
# Body list for bodies in the starting position (most upstream):
sbody_list = []
# Body list for all bodies in the interneuron/mid position:
ibody_list = []
# Body list for the last/ending bodies:
ebody_list = []

# Minimum number of synapses (connection weight) required for a->b.
# Both thresholds must be numbers: they are interpolated into the
# `weight >= ...` comparisons below, and a list placeholder would make
# the comparison match nothing. 0 applies no filter.
minw = 0
# Minimum number of synapses (connection weight) required for b->c.
minw2 = 0

# Add your query here. The pieces are adjacent string literals, so the
# query is sent as a single line.
one_hop = (
    f" WITH {sbody_list} AS START,{ibody_list} AS MID, {ebody_list} AS END "
    "MATCH (a:Neuron)-[w:ConnectsTo]->(b:Neuron)-[ww:ConnectsTo]->(c:Neuron) "
    "WHERE a.bodyId IN START AND b.bodyId IN MID AND c.bodyId IN END "
    f"AND w.weight >= {minw} and ww.weight >= {minw2} "
    "RETURN a.bodyId, a.type, w.weight, w.roiInfo, b.bodyId, b.type, "
    "ww.weight, ww.roiInfo, c.bodyId, c.type"
)

# Specify the roiInfo column whose data you want to plot,
# in the form: column = 'column_name'
# ('w.roiInfo' is returned for a->b, 'ww.roiInfo' for b->c)
column = 'w.roiInfo'


In [None]:
# Execute the query -- pass either `direct` or `one_hop` here
df = fetch_custom(direct)

# Parse the JSON text of the chosen roiInfo column into dicts and keep the
# result in a new 'json1' column
df['json1'] = df[column].map(json.loads)

In [None]:
def _synapse_count(syn):
    """Return the 'pre' count of one ROI entry, falling back to 'post'."""
    pre = syn.get('pre')
    if pre is not None:
        return pre
    # 'pre' is missing (or null), e.g. when the synapse sits on the edge of
    # an ROI, so use the 'post' count instead.
    return syn['post']


def _tally_synapses_by_roi(df_col):
    """Sum the synapse counts of every roiInfo dict, keyed by ROI name."""
    totals = {}
    for roi_dict in df_col:
        for roi, syn in roi_dict.items():
            totals[roi] = totals.get(roi, 0) + _synapse_count(syn)
    return totals


def create__connectivity_plot(df_col, hierarchy_csv='ROI_hierarchy.csv'):
    """Build a sunburst plot of synapse counts per ROI.

    Reads the ROI hierarchy spreadsheet, adds a 'synapse_count' column
    holding the total number of synapses found in each ROI across all
    roiInfo dicts in ``df_col``, and plots the ROIs that received at least
    one synapse, grouped under their global ROI.

    Parameters
    ----------
    df_col : iterable of dict
        One parsed roiInfo dict per connection. Each maps an ROI name to a
        dict whose 'pre' (or, failing that, 'post') value is the synapse
        count for that ROI.
    hierarchy_csv : str, optional
        Path to the ROI hierarchy spreadsheet. The 'ROI_all', 'Global' and
        'ROI' columns are read from it. Defaults to 'ROI_hierarchy.csv' in
        the current working directory.

    Returns
    -------
    The figure produced by ``px.sunburst``.

    Raises
    ------
    KeyError
        If an ROI entry has no usable 'pre' count and no 'post' key.
    """
    # read hierarchy of ROIs spreadsheet
    hierarchy = pd.read_csv(hierarchy_csv)

    # Tally once, then assign per row: every row whose 'ROI_all' matches a
    # tallied ROI gets that ROI's total; all other rows get 0.
    totals = _tally_synapses_by_roi(df_col)
    hierarchy['synapse_count'] = [
        totals.get(roi, 0) for roi in hierarchy['ROI_all']
    ]

    # remove rows with a synapse count of 0
    hierarchy = hierarchy[hierarchy['synapse_count'] != 0]

    # Remove rows where the global ROI is the same as ROI_all, unless that
    # row is the only remaining one for its global ROI. A boolean mask is
    # used so the frame is never modified while it is being iterated.
    rows_per_global = hierarchy['Global'].map(hierarchy['Global'].value_counts())
    is_redundant = (rows_per_global > 1) & (
        hierarchy['ROI_all'] == hierarchy['Global']
    )
    hierarchy = hierarchy[~is_redundant]

    # this is only plotting Global ROIs and their direct sub ROIs.
    sbplot = px.sunburst(
        hierarchy, path=['Global', 'ROI'], values='synapse_count'
    )

    return sbplot
    
    

In [None]:
# Build the sunburst figure from the parsed roiInfo dicts in 'json1'
sbplot1 = create__connectivity_plot(df_col=df['json1'])
# Leave the figure as the cell's final expression so the notebook renders it
sbplot1

## Image Export
- you can export an interactive version of the plot or a static version
- static export can be done in JupyterLab. 
    - Hover over the plot and click the camera icon in the top right to download your plot as a png <font color=red>(recommended over advanced)</font>
- more advanced static image export formats and instructions can be located at: https://plotly.com/python/static-image-export/

In [None]:
# Interactive export: set output_file to the name you want for your plot.
# A bare ".html" has no base name and is saved as a hidden dotfile on
# Unix-like systems, so a usable default is provided here.
output_file = "sunburst_plot.html"
sbplot1.write_html(output_file)