In [None]:
#from aPhN-SA_Activation import set_1
#%pip install statsmodels

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from scipy.optimize import curve_fit
from venn import venn
import seaborn as sns
from matplotlib.colors import LogNorm
import statsmodels.api as sm

In [None]:
# White seaborn base theme; the matplotlib overrides below are layered on top of it
sns.set_theme(style='white')

# Matplotlib rcParams overrides (house style): no grid, opaque white legends,
# Helvetica as the sans-serif font, small tick labels, 300-dpi figures
andy_theme = {
    # grid (off; dashed style kept in case it is re-enabled)
    'axes.grid': False,
    'grid.linestyle': '--',
    # legend
    'legend.framealpha': 1,
    'legend.facecolor': 'white',
    'legend.shadow': False,
    'legend.fontsize': 14,
    'legend.title_fontsize': 14,
    # fonts and label sizes
    'font.sans-serif': 'Helvetica',
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'axes.labelsize': 12,
    'axes.titlesize': 16,
    # output resolution
    'figure.dpi': 300,
}
plt.rcParams.update(andy_theme)

# If matplotlib cannot find the Helvetica font, uncomment the next two lines
#plt.rcParams['font.family'] = 'DejaVu Serif'
#plt.rcParams['font.sans-serif'] = ['Arial']

## 1. FIRST ORDER ANALYSES

### Load the datasets with neurons and connections.

* This notebook reads the FlyWire **CSV** files (`.csv.gz`) and the **aPhN-SA lists** (`set_1.csv`, `set_2.csv`, `set_3.csv`) from the absolute paths hard-coded in the loading cells below; edit those paths to match your machine.
* These files comprise three CSVs containing manually curated aPhN-SA lists and four connectome datasets from FlyWire:
  1. **`classification.csv.gz`**
  2. **`connections.csv.gz`**
  3. **`neuropil_synapse_table.csv.gz`**
  4. **`neurons.csv.gz`**
* **Axon lists** were curated manually as described in the paper.
* **Connectome datasets** were downloaded from the FlyWire website using **snapshot 783** (previous snapshot 630).
* We focus on putative sensory axons from the Drosophila **pharyngeal nerve** in this analysis.


In [None]:
# Connections dataset and additional data sets
# NOTE(review): machine-specific absolute paths; edit these two constants to run elsewhere.
FLYWIRE_DIR = ('/Users/yaolab/Library/CloudStorage/OneDrive-UniversityofFlorida/'
               'YaoLabUF/YaoLab/Drosophila_brain_model')
SETS_DIR = '/Users/yaolab/Downloads/taste-connectome-main/aPhN-SA_v3'

# Load the connections dataset
# columns: pre_root_id, post_root_id, neuropil, syn_count, nt_type
connections = pd.read_csv(f'{FLYWIRE_DIR}/connections.csv.gz')

# Neuropil synapses: keep root_id plus the input/output synapse totals, renamed
# with underscores. Built as a chain (no in-place mutation) so the cell is re-runnable.
neuropil_synapse = (
    pd.read_csv(f'{FLYWIRE_DIR}/neuropil_synapse_table.csv.gz')
    [['root_id', 'input synapses', 'output synapses']]
    .rename(columns={'input synapses': 'input_synapses',
                     'output synapses': 'output_synapses'})
)

# Classification table: parse the file once and derive both column subsets
#   classification:       root_id, side
#   classification_other: root_id, super_class, class
classification_all = pd.read_csv(
    f'{FLYWIRE_DIR}/classification.csv.gz',
    usecols=['root_id', 'side', 'super_class', 'class'],
)
classification = classification_all[['root_id', 'side']]
classification_other = classification_all[['root_id', 'super_class', 'class']]
del classification_all

# Load data about each neuron; keep only root_id, nt_type
neurons = pd.read_csv(f'{FLYWIRE_DIR}/neurons.csv.gz')[['root_id', 'nt_type']]

# Merge transmitter type, side and synapse totals into one per-neuron table
neurons_data = pd.merge(
    neurons,
    pd.merge(classification, neuropil_synapse, on='root_id', how='outer'),
    on='root_id',
    how='outer',
)

# Load putative PSO lists (each CSV has a 'root_id' column)
set_1 = pd.read_csv(f'{SETS_DIR}/set_1.csv')
set_2 = pd.read_csv(f'{SETS_DIR}/set_2.csv')
set_3 = pd.read_csv(f'{SETS_DIR}/set_3.csv')

### Find downstream connections of aPhN-SAs
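- Includes all neurons downstream of aPhN-SAs; set-to-set connections are filtered out later.
- Minimum of 5 synapses between the two neurons. Here the threshold is applied to each row of `connections` (rows are split by `neuropil`), so synapses are not summed across neuropils. This differs from the hop analyses below, which sum synapses per neuron pair first.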

In [None]:
# Define function to get outputs of aPhN-SAs

# Neuropils whose connections neuronal_outputs labels 'local' (SEZ-related regions).
# NOTE(review): 'CAN' has no _L/_R suffix while 'FLA_L'/'FLA_R' do; confirm that
# 'CAN' matches the neuropil labels actually present in connections.csv.gz.
SEZ_NEUROPILS = ('GNG', 'PRW', 'SAD', 'FLA_L', 'FLA_R', 'CAN')


def neuronal_outputs(aph1_sa, min_syn=5, connections_df=None):
    """Return the downstream connections of a set of aPhN-SA neurons.

    Parameters
    ----------
    aph1_sa : pandas.DataFrame
        Must contain a 'root_id' column with the presynaptic neuron ids.
    min_syn : int, default 5
        Minimum ``syn_count`` a single connections row must have to be kept.
        The threshold is applied per row (rows carry a neuropil), not to the
        summed synapse count of a neuron pair.
    connections_df : pandas.DataFrame, optional
        Connection table with columns pre_root_id, post_root_id, neuropil,
        syn_count, nt_type. Defaults to the notebook-level ``connections``.

    Returns
    -------
    pandas.DataFrame
        The matching connection rows (pre_root_id, post_root_id, neuropil,
        syn_count, nt_type) plus 'location_of_connection': 'local' when the
        neuropil is in SEZ_NEUROPILS, otherwise 'outside_SEZ'.
    """
    if connections_df is None:
        connections_df = connections

    connectivity = (
        pd.merge(
            aph1_sa['root_id'],
            connections_df[['pre_root_id', 'post_root_id', 'neuropil', 'syn_count', 'nt_type']],
            left_on='root_id',
            right_on='pre_root_id',
            how='inner',
        )
        .query('syn_count >= @min_syn')
        # 'root_id' only duplicates pre_root_id after the join
        .drop(columns='root_id')
    )

    # Vectorized location tag; NaN neuropils fall into 'outside_SEZ'
    connectivity['location_of_connection'] = np.where(
        connectivity['neuropil'].isin(SEZ_NEUROPILS), 'local', 'outside_SEZ'
    )
    return connectivity

In [None]:
# Downstream connections (>= 5 synapses, tagged local / outside_SEZ) of sets 1 and 2
set_1_outputs, set_2_outputs = map(neuronal_outputs, (set_1, set_2))

In [None]:
#fig.show(renderer="browser")
#fig.write_html("sankey_diagram.html")

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import deque, defaultdict

# ───────────────────────────────────────────────────────────────────────────────
# 0. LOAD ALL YOUR DATASETS
# ───────────────────────────────────────────────────────────────────────────────
# NOTE(review): machine-specific absolute paths; edit these constants to run elsewhere.
FLYWIRE_DIR = ('/Users/yaolab/Library/CloudStorage/OneDrive-UniversityofFlorida/'
               'YaoLabUF/YaoLab/Drosophila_brain_model')
SETS_DIR = '/Users/yaolab/Downloads/taste-connectome-main/aPhN-SA_v3'

# connections: pre_root_id, post_root_id, neuropil, syn_count, nt_type
connections = pd.read_csv(f'{FLYWIRE_DIR}/connections.csv.gz')

# neuropil synapse table: root_id, input_synapses, output_synapses
neuropil_synapse = (
    pd.read_csv(f'{FLYWIRE_DIR}/neuropil_synapse_table.csv.gz')
    [['root_id', 'input synapses', 'output synapses']]
    .rename(columns={'input synapses': 'input_synapses',
                     'output synapses': 'output_synapses'})
)

# classification.csv.gz is parsed once and split into the two subsets used here:
#   classification_side:  root_id, side
#   classification_other: root_id, super_class
classification_all = pd.read_csv(
    f'{FLYWIRE_DIR}/classification.csv.gz',
    usecols=['root_id', 'side', 'super_class'],
)
classification_side = classification_all[['root_id', 'side']]
classification_other = classification_all[['root_id', 'super_class']]
del classification_all

# root_id -> super_class lookup used by the hop search below
super_map = classification_other.set_index('root_id')['super_class'].to_dict()

# neurons table: root_id, nt_type
neurons = pd.read_csv(f'{FLYWIRE_DIR}/neurons.csv.gz')[['root_id', 'nt_type']]

# merged neurons_data (if you need it downstream)
neurons_data = pd.merge(
    neurons,
    pd.merge(classification_side, neuropil_synapse, on='root_id', how='outer'),
    on='root_id',
    how='outer'
)

# your three PSO lists
set_1 = pd.read_csv(f'{SETS_DIR}/set_1.csv')
set_2 = pd.read_csv(f'{SETS_DIR}/set_2.csv')
set_3 = pd.read_csv(f'{SETS_DIR}/set_3.csv')

# ───────────────────────────────────────────────────────────────────────────────
# 1. BUILD THRESHOLDED ADJACENCY LIST (≥5 synapses)
# ───────────────────────────────────────────────────────────────────────────────
# Synapses are summed over every row of a (pre, post) pair before thresholding.
edge_df = (
    connections
    .groupby(['pre_root_id', 'post_root_id'], as_index=False)['syn_count'].sum()
    .loc[lambda pairs: pairs['syn_count'] >= 5, ['pre_root_id', 'post_root_id']]
)

# adj[pre] -> set of postsynaptic root ids
adj = defaultdict(set)
for u, v in zip(edge_df['pre_root_id'], edge_df['post_root_id']):
    adj[int(u)].add(int(v))

# ───────────────────────────────────────────────────────────────────────────────
# 2. DEFINE YOUR SETS (DCSO, aPhN1, aPhN2)
# ───────────────────────────────────────────────────────────────────────────────
# Set label -> Series of integer root ids, in plotting order
sets = {
    name: frame['root_id'].astype(int)
    for name, frame in (('DCSO', set_1), ('aPhN1', set_2), ('aPhN2', set_3))
}

# ───────────────────────────────────────────────────────────────────────────────
# 3. BFS HELPER TO COMPUTE MINIMAL HOPS TO A TARGET CLASS
# ───────────────────────────────────────────────────────────────────────────────
def compute_hops(src_ids, target_class, max_hops=3, adjacency=None, class_map=None):
    """Count source neurons by the minimal number of hops to a target super_class.

    For each source id, a breadth-first search over the thresholded adjacency
    list finds the fewest synaptic steps needed to reach any neuron whose
    super_class equals ``target_class``. The source itself is not checked,
    so hop 0 is never reported.

    Parameters
    ----------
    src_ids : iterable of int
        Root ids of the source neurons; repeated ids are counted once.
    target_class : str
        super_class label to search for (e.g. 'motor', 'endocrine').
    max_hops : int, default 3
        Search depth limit. Sources with no outgoing edges, or with no target
        within ``max_hops`` steps, fall into the f'>{max_hops}' bin.
    adjacency : mapping int -> iterable of int, optional
        Defaults to the notebook-level ``adj``. It is only read, never modified.
    class_map : mapping int -> str, optional
        root_id -> super_class. Defaults to the notebook-level ``super_map``.

    Returns
    -------
    pandas.DataFrame
        Columns 'hop' (labels '1' .. str(max_hops), then f'>{max_hops}') and
        'count' (number of sources in each bin).
    """
    if adjacency is None:
        adjacency = adj
    if class_map is None:
        class_map = super_map

    def _min_hops(src):
        # Nodes are dequeued in non-decreasing distance, so the first target
        # reached is at the minimal hop count. Returns None if none is reachable.
        if src not in adjacency:
            return None
        visited = {src}
        queue = deque([(src, 0)])
        while queue:
            node, dist = queue.popleft()
            if dist >= max_hops:
                continue
            # .get() instead of [] so a defaultdict adjacency is not grown with
            # an empty entry for every visited neuron that has no outputs
            for nei in adjacency.get(node, ()):
                if nei in visited:
                    continue
                visited.add(nei)
                if class_map.get(nei) == target_class:
                    return dist + 1
                queue.append((nei, dist + 1))
        return None

    labels = [str(h) for h in range(1, max_hops + 1)] + [f'>{max_hops}']
    counts = dict.fromkeys(labels, 0)
    # dict.fromkeys deduplicates repeated source ids while keeping their order
    for src in dict.fromkeys(src_ids):
        hops = _min_hops(src)
        counts[labels[-1] if hops is None else str(hops)] += 1

    return pd.DataFrame({'hop': labels, 'count': [counts[lab] for lab in labels]})

# ───────────────────────────────────────────────────────────────────────────────
# 4. RUN FOR “motor” AND “endocrine” & PLOT STACKED BARS
# ───────────────────────────────────────────────────────────────────────────────
custom_colors = [
    '#EE7733',  # Vibrant orange
    '#009E73',  # Bluish green
    '#33BBEE',  # Cyan
    '#CC3311'   # Red
]
hop_order = ['1', '2', '3', '>3']  # must match the labels returned by compute_hops

for target in ['motor','endocrine']:
    # One row per (set, hop bin) with the number of source neurons in that bin
    df_stack = pd.concat(
        [compute_hops(ids, target).assign(set=label) for label, ids in sets.items()],
        ignore_index=True,
    )
    df_stack['hop'] = pd.Categorical(df_stack['hop'], hop_order, ordered=True)
    df_stack['set'] = pd.Categorical(df_stack['set'], list(sets.keys()), ordered=True)

    pivot = df_stack.pivot(index='set', columns='hop', values='count').fillna(0)

    fig, ax = plt.subplots(figsize=(12, 8))
    # Float accumulator: fillna() can upcast the counts to float, and an int
    # array cannot be incremented in place with floats.
    bottom = np.zeros(len(pivot))
    for color, hop in zip(custom_colors, hop_order):
        heights = pivot[hop].to_numpy()
        ax.bar(pivot.index, heights, bottom=bottom,
               color=color,
               label=f"{hop} hop{'s' if hop!='1' else ''}")
        bottom += heights

    ax.set_title(f"Hops from PSO-SA Sets to {target.capitalize()} Neurons", fontsize=20)
    ax.set_xlabel("PSO-SA Set", fontsize=14)
    ax.set_ylabel("Number of Cells", fontsize=14)
    ax.tick_params(labelsize=12)
    ax.legend(title="Hops", frameon=False, prop={'size':12})

    # Lay out *before* saving so the exported SVG matches the displayed figure
    fig.tight_layout()
    filename = f"hops_to_{target}.svg"
    fig.savefig(filename, format='svg', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # one figure is created per target; release it once shown


In [None]:
# The hop plots are built *and* exported to hops_to_<target>.svg by the cell above.
# Do not re-save `fig` here. After that loop it refers only to the last figure
# drawn (endocrine), so saving it under every target name would overwrite
# hops_to_motor.svg with the endocrine plot. This cell only reports missing exports.
import os

for target in ['motor','endocrine']:
    filename = f"hops_to_{target}.svg"
    if not os.path.exists(filename):
        print(f"{filename} not found - re-run the previous cell to regenerate it.")


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

# ----------------------------------------------------------------------------
# 0. Load FlyWire connectome and super_class classification
# ----------------------------------------------------------------------------
FLYWIRE_DIR = ('/Users/yaolab/Library/CloudStorage/OneDrive-UniversityofFlorida/'
               'YaoLabUF/YaoLab/Drosophila_brain_model')
connections = pd.read_csv(f'{FLYWIRE_DIR}/connections.csv.gz')
classification = pd.read_csv(f'{FLYWIRE_DIR}/classification.csv.gz')[['root_id', 'super_class']]

# ----------------------------------------------------------------------------
# 1. Load your three PSO sets (each CSV has a 'root_id' column)
# ----------------------------------------------------------------------------
SETS_DIR = '/Users/yaolab/Downloads/taste-connectome-main/aPhN-SA_v3'
set_1, set_2, set_3 = (pd.read_csv(f'{SETS_DIR}/set_{n}.csv') for n in (1, 2, 3))

# Plot label -> neuron table, in plotting order
sets = dict(zip(('DCSO', 'aPhN1', 'aPhN2'), (set_1, set_2, set_3)))

# ----------------------------------------------------------------------------
# Helper: filter connections by source IDs, threshold syn_count, attach superclass
# ----------------------------------------------------------------------------
def build_hop_df(src_ids, connections, classification, min_syn=5):
    """One synaptic hop downstream of ``src_ids``.

    Synapse counts are summed per (pre_root_id, post_root_id) pair over all
    rows of ``connections``. Pairs below ``min_syn`` are dropped. The
    postsynaptic neuron's super_class is then attached (NaN when the neuron
    is absent from ``classification``).

    Returns a DataFrame with columns
    pre_root_id, post_root_id, output_super_class, syn_count.
    """
    from_sources = connections.loc[connections['pre_root_id'].isin(src_ids)]
    pair_totals = (
        from_sources
        .groupby(['pre_root_id', 'post_root_id'], as_index=False)['syn_count']
        .sum()
    )
    strong = pair_totals[pair_totals['syn_count'] >= min_syn]
    post_class = classification.rename(
        columns={'root_id': 'post_root_id', 'super_class': 'output_super_class'}
    )
    annotated = strong.merge(post_class, on='post_root_id', how='left')
    return annotated[['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count']]

# ----------------------------------------------------------------------------
# Core: build and show a Sankey diagram with maximal vertical spacing
# ----------------------------------------------------------------------------
def plot_sankey_dynamic(grn_df, title, connections, classification, min_syn=5):
    """Draw a 3-hop Sankey of synaptic output by super_class and save it as SVG.

    Columns, left to right: the source set (node ``title``), then the
    super_class of its 1st-, 2nd- and 3rd-order downstream partners. Each hop
    is built with ``build_hop_df`` (pairs with >= ``min_syn`` summed synapses).
    Link widths are synapse counts. Each downstream edge is counted once, even
    when its presynaptic neuron receives input from several upstream neurons.

    Parameters
    ----------
    grn_df : pandas.DataFrame
        Source neurons; must have a 'root_id' column.
    title : str
        Label of the source node, figure title, and SVG file stem
        ('/' replaced by '_').
    connections, classification : pandas.DataFrame
        Passed through to ``build_hop_df``.
    min_syn : int, default 5
        Synapse threshold per (pre, post) pair.

    Side effects: writes f"{title}.svg" with ``fig.write_image`` and shows the figure.
    """
    def _hop(src_ids):
        # Partners absent from `classification` get an explicit label. A NaN
        # super_class breaks sorted() on mixed str/float labels and the colour lookup.
        hop = build_hop_df(src_ids, connections, classification, min_syn)
        return hop.assign(output_super_class=hop['output_super_class'].fillna('unclassified'))

    def _layer_flow(upstream, downstream, src_prefix, dst_prefix):
        # Synapses from upstream-layer classes to downstream-layer classes.
        # Each presynaptic neuron's class is looked up once. A row-wise merge with
        # `upstream` would repeat a downstream edge once per upstream edge reaching
        # its presynaptic neuron, inflating the link widths.
        pre_class = (
            upstream[['post_root_id', 'output_super_class']]
            .drop_duplicates('post_root_id')
            .rename(columns={'post_root_id': 'pre_root_id', 'output_super_class': 'source'})
        )
        flow = (
            downstream.merge(pre_class, on='pre_root_id', how='inner')
            .rename(columns={'output_super_class': 'target'})
            .groupby(['source', 'target'], as_index=False)['syn_count'].sum()
            .rename(columns={'syn_count': 'count'})
        )
        flow['source'] = flow['source'].map(lambda c: f"{src_prefix}{c}")
        flow['target'] = flow['target'].map(lambda c: f"{dst_prefix}{c}")
        return flow

    # 1) build hops
    df1 = _hop(grn_df['root_id'])
    df2 = _hop(df1['post_root_id'].unique())
    df3 = _hop(df2['post_root_id'].unique())

    # 2) summarize flows as (source label, target label, count)
    hop1_totals = df1.groupby('output_super_class', as_index=False)['syn_count'].sum()
    flow1 = pd.DataFrame({
        'source': title,
        'target': hop1_totals['output_super_class'].map(lambda c: f"1: {c}"),
        'count': hop1_totals['syn_count'],
    })
    flow2 = _layer_flow(df1, df2, '1: ', '2: ')
    flow3 = _layer_flow(df2, df3, '2: ', '3: ')

    # 3) build node labels, one column per hop
    col1 = [title]
    col2 = [f"1: {c}" for c in sorted(df1['output_super_class'].unique())]
    col3 = [f"2: {c}" for c in sorted(df2['output_super_class'].unique())]
    col4 = [f"3: {c}" for c in sorted(df3['output_super_class'].unique())]
    nodes = col1 + col2 + col3 + col4
    idx   = {n:i for i,n in enumerate(nodes)}

    # 4) colour each class consistently across columns; the source node is grey
    palette = px.colors.qualitative.Safe
    all_classes = sorted(
        set(df1['output_super_class']) |
        set(df2['output_super_class']) |
        set(df3['output_super_class'])
    )
    color_map = {cls: palette[i % len(palette)] for i,cls in enumerate(all_classes)}
    node_colors = ['lightgrey' if n==title else color_map[n.split(': ',1)[1]] for n in nodes]

    # 5) assemble links, tinted with a half-transparent source-node colour
    source, target, value, link_colors = [], [], [], []
    for flow in (flow1, flow2, flow3):
        for src, dst, cnt in zip(flow['source'], flow['target'], flow['count']):
            s, t = idx[src], idx[dst]
            source.append(s)
            target.append(t)
            value.append(cnt)
            # 'rgb(r, g, b)' -> 'rgba(r, g, b,0.5)'; named colours pass through unchanged
            link_colors.append(node_colors[s].replace('rgb','rgba').replace(')',',0.5)'))

    # 6) hover info: total link value into / out of each node
    incoming = dict.fromkeys(nodes,0)
    outgoing = dict.fromkeys(nodes,0)
    for s,t,v in zip(source,target,value):
        outgoing[nodes[s]] += v
        incoming[nodes[t]]  += v
    customdata = [f"Incoming: {incoming[n]}<br>Outgoing: {outgoing[n]}" for n in nodes]

    # 7) x positions: one column per hop
    x = [0.0]*len(col1) + [0.33]*len(col2) + [0.66]*len(col3) + [1.0]*len(col4)

    # 8) y positions: spread evenly down each column (a lone node is centred)
    y = []
    for col in (col1, col2, col3, col4):
        y.extend([0.5] if len(col) == 1 else np.linspace(0, 1, len(col)))

    # 9) draw Sankey; 'snap' lets plotly adjust the given positions to keep node padding
    fig = go.Figure(go.Sankey(
        arrangement='snap',
        node=dict(
            label=nodes,
            x=x, y=y,
            color=node_colors,
            pad=15, thickness=20,
            line=dict(color='black', width=0.5),
            customdata=customdata,
            hovertemplate='%{customdata}<extra>%{label}</extra>'
        ),
        link=dict(source=source, target=target, value=value, color=link_colors)
    ))
    fig.update_layout(title_text=f"{title}", font_size=14)
    fig.write_image(f"{title.replace('/','_')}.svg", width=1200, height=800, scale=2)
    fig.show()

# ----------------------------------------------------------------------------
# 10. Generate a Sankey (and its SVG export) for every PSO set
# ----------------------------------------------------------------------------
for label in sets:
    df = sets[label]
    plot_sankey_dynamic(df, label, connections, classification)


In [None]:
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px

# ----------------------------------------------------------------------------
# 0. Load FlyWire connectome and classification
# ----------------------------------------------------------------------------
FLYWIRE_DIR = ('/Users/yaolab/Library/CloudStorage/OneDrive-UniversityofFlorida/'
               'YaoLabUF/YaoLab/Drosophila_brain_model')
connections = pd.read_csv(f'{FLYWIRE_DIR}/connections.csv.gz')
classification = pd.read_csv(f'{FLYWIRE_DIR}/classification.csv.gz')[['root_id', 'super_class']]

# ----------------------------------------------------------------------------
# 1. Load your three PSO lists (each CSV has a 'root_id' column)
# ----------------------------------------------------------------------------
SETS_DIR = '/Users/yaolab/Downloads/taste-connectome-main/aPhN-SA_v3'
set_1, set_2, set_3 = (pd.read_csv(f'{SETS_DIR}/set_{n}.csv') for n in (1, 2, 3))

# Plot label -> neuron table, in plotting order
sets = dict(zip(('DCSO', 'aPhN1', 'aPhN2'), (set_1, set_2, set_3)))

# ----------------------------------------------------------------------------
# Helper: filter connections by source IDs, threshold syn_count, attach superclass
# ----------------------------------------------------------------------------
def build_hop_df(src_ids, connections, classification, min_syn=5):
    """One synaptic hop downstream of ``src_ids``.

    Synapse counts are summed per (pre_root_id, post_root_id) pair over all
    rows of ``connections``. Pairs below ``min_syn`` are dropped. The
    postsynaptic neuron's super_class is then attached (NaN when the neuron
    is absent from ``classification``).

    Returns a DataFrame with columns
    pre_root_id, post_root_id, output_super_class, syn_count.
    """
    from_sources = connections.loc[connections['pre_root_id'].isin(src_ids)]
    pair_totals = (
        from_sources
        .groupby(['pre_root_id', 'post_root_id'], as_index=False)['syn_count']
        .sum()
    )
    strong = pair_totals[pair_totals['syn_count'] >= min_syn]
    post_class = classification.rename(
        columns={'root_id': 'post_root_id', 'super_class': 'output_super_class'}
    )
    annotated = strong.merge(post_class, on='post_root_id', how='left')
    return annotated[['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count']]

# ----------------------------------------------------------------------------
# Core: build & show a 3-hop Sankey, but filter df2 by motor criteria
# ----------------------------------------------------------------------------
def plot_sankey_dynamic(grn_df, title, connections, classification, min_syn=5):
    """Show a 3-hop Sankey whose 2nd hop is restricted to motor-related edges.

    - Hop 1: all partners of ``grn_df['root_id']`` (pairs with >= ``min_syn``
      summed synapses, via ``build_hop_df``).
    - Hop 2: only edges whose postsynaptic neuron is itself 'motor', or is the
      presynaptic neuron of a 'motor'-targeting edge in the unfiltered 3rd hop.
    - Hop 3: all partners of the retained 2nd-hop targets.

    Link widths are synapse counts. Each edge is counted once, even when its
    presynaptic neuron receives input from several upstream neurons.

    NOTE(review): hop 3 is recomputed from the filtered hop 2 without a class
    filter, so the last column shows every super_class, not only 'motor'.
    Confirm this is intended.
    """
    def _hop(src_ids):
        # Partners absent from `classification` get an explicit label. A NaN
        # super_class breaks sorted() on mixed str/float labels and the colour lookup.
        hop = build_hop_df(src_ids, connections, classification, min_syn)
        return hop.assign(output_super_class=hop['output_super_class'].fillna('unclassified'))

    def _layer_flow(upstream, downstream, src_prefix, dst_prefix):
        # Synapses from upstream-layer classes to downstream-layer classes.
        # Each presynaptic neuron's class is looked up once. A row-wise merge with
        # `upstream` would repeat a downstream edge once per upstream edge reaching
        # its presynaptic neuron, inflating the link widths.
        pre_class = (
            upstream[['post_root_id', 'output_super_class']]
            .drop_duplicates('post_root_id')
            .rename(columns={'post_root_id': 'pre_root_id', 'output_super_class': 'source'})
        )
        flow = (
            downstream.merge(pre_class, on='pre_root_id', how='inner')
            .rename(columns={'output_super_class': 'target'})
            .groupby(['source', 'target'], as_index=False)['syn_count'].sum()
            .rename(columns={'syn_count': 'count'})
        )
        flow['source'] = flow['source'].map(lambda c: f"{src_prefix}{c}")
        flow['target'] = flow['target'].map(lambda c: f"{dst_prefix}{c}")
        return flow

    # 1st hop
    df1 = _hop(grn_df['root_id'])

    # full 2nd & 3rd hops
    df2_full = _hop(df1['post_root_id'].unique())
    df3_full = _hop(df2_full['post_root_id'].unique())

    # keep 2nd-hop edges whose target either is motor or feeds a motor neuron in hop 3
    motor_pre_ids = df3_full.loc[df3_full['output_super_class']=='motor','pre_root_id'].unique()
    df2_direct = df2_full[df2_full['output_super_class']=='motor']
    df2_feed   = df2_full[df2_full['post_root_id'].isin(motor_pre_ids)]
    df2 = pd.concat([df2_direct, df2_feed], ignore_index=True).drop_duplicates()

    # recompute 3rd hop from the filtered df2
    df3 = _hop(df2['post_root_id'].unique())

    # --- summarize flows for each hop as (source, target, count) ---
    flow1 = (
        df1.groupby('output_super_class', as_index=False)['syn_count']
           .sum().rename(columns={'syn_count':'count'})
    )
    flow1 = flow1.assign(
        source=title,
        target=flow1['output_super_class'].map(lambda c: f"1: {c}"),
    )
    flow2 = _layer_flow(df1, df2, '1: ', '2: ')
    flow3 = _layer_flow(df2, df3, '2: ', '3: ')

    # --- assemble nodes & index mapping ---
    col1 = [title]
    col2 = sorted(flow1['target'].unique())
    col3 = sorted(flow2['target'].unique())
    col4 = sorted(flow3['target'].unique())
    nodes = col1 + col2 + col3 + col4
    idx = {n:i for i,n in enumerate(nodes)}

    # --- assign colors via Safe palette; the source node is grey ---
    all_classes = sorted({n.split(': ',1)[1] for n in col2+col3+col4})
    palette = px.colors.qualitative.Safe
    color_map = {cls: palette[i % len(palette)] for i,cls in enumerate(all_classes)}
    node_colors = ['lightgrey' if n==title else color_map[n.split(': ',1)[1]] for n in nodes]

    # --- build Sankey link lists, tinted with a half-transparent source colour ---
    source, target, value, link_colors = [], [], [], []
    for flow in (flow1, flow2, flow3):
        for src, dst, cnt in zip(flow['source'], flow['target'], flow['count']):
            s, t = idx[src], idx[dst]
            source.append(s)
            target.append(t)
            value.append(cnt)
            link_colors.append(node_colors[s].replace('rgb(', 'rgba(').replace(')', ',0.5)'))

    # --- compute hover info: total link value into / out of each node ---
    incoming = dict.fromkeys(nodes, 0)
    outgoing = dict.fromkeys(nodes, 0)
    for s,t,v in zip(source, target, value):
        outgoing[nodes[s]] += v
        incoming[nodes[t]]  += v
    customdata = [f"Incoming: {incoming[n]}<br>Outgoing: {outgoing[n]}" for n in nodes]

    # --- x-axis positions for the 4 columns ---
    x = [0.0]*len(col1) + [0.33]*len(col2) + [0.66]*len(col3) + [1.0]*len(col4)

    # --- plot Sankey ---
    fig = go.Figure(go.Sankey(
        arrangement='snap',
        node=dict(
            label=nodes,
            x=x,
            color=node_colors,
            pad=15,
            thickness=20,
            line=dict(color='black', width=0.5),
            customdata=customdata,
            hovertemplate='%{customdata}<extra>%{label}</extra>'
        ),
        link=dict(source=source, target=target, value=value, color=link_colors)
    ))
    fig.update_layout(title_text=title, font_size=14)
    fig.show()

# ----------------------------------------------------------------------------
# Draw the motor-filtered Sankey for every PSO set
# ----------------------------------------------------------------------------
for label in sets:
    df = sets[label]
    plot_sankey_dynamic(df, label, connections, classification)


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from collections import defaultdict

# ----------------------------------------------------------------------------
# 0. Load FlyWire connectome and classification
# ----------------------------------------------------------------------------
FLYWIRE_DIR = ('/Users/yaolab/Library/CloudStorage/OneDrive-UniversityofFlorida/'
               'YaoLabUF/YaoLab/Drosophila_brain_model')
connections = pd.read_csv(f'{FLYWIRE_DIR}/connections.csv.gz')
classification = pd.read_csv(f'{FLYWIRE_DIR}/classification.csv.gz')[['root_id', 'super_class']]

# ----------------------------------------------------------------------------
# 1. Load your three PSO lists (each CSV has a 'root_id' column)
# ----------------------------------------------------------------------------
SETS_DIR = '/Users/yaolab/Downloads/taste-connectome-main/aPhN-SA_v3'
set_1, set_2, set_3 = (pd.read_csv(f'{SETS_DIR}/set_{n}.csv') for n in (1, 2, 3))

# Plot label -> neuron table, in plotting order
sets = dict(zip(('DCSO', 'aPhN1', 'aPhN2'), (set_1, set_2, set_3)))

# ----------------------------------------------------------------------------
# Helper: filter connections by source IDs, threshold syn_count, attach superclass
# ----------------------------------------------------------------------------
def build_hop_df(src_ids, connections, classification, min_syn=5):
    """One synaptic hop downstream of ``src_ids``.

    Synapse counts are summed per (pre_root_id, post_root_id) pair over all
    rows of ``connections``. Pairs below ``min_syn`` are dropped. The
    postsynaptic neuron's super_class is then attached (NaN when the neuron
    is absent from ``classification``).

    Returns a DataFrame with columns
    pre_root_id, post_root_id, output_super_class, syn_count.
    """
    from_sources = connections.loc[connections['pre_root_id'].isin(src_ids)]
    pair_totals = (
        from_sources
        .groupby(['pre_root_id', 'post_root_id'], as_index=False)['syn_count']
        .sum()
    )
    strong = pair_totals[pair_totals['syn_count'] >= min_syn]
    post_class = classification.rename(
        columns={'root_id': 'post_root_id', 'super_class': 'output_super_class'}
    )
    annotated = strong.merge(post_class, on='post_root_id', how='left')
    return annotated[['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count']]

# ----------------------------------------------------------------------------
# Core: build & show a 3-hop Sankey with vertical spacing
# ----------------------------------------------------------------------------
def plot_sankey_dynamic(grn_df, title, connections, classification, min_syn=5):
    """Show a 3-hop Sankey of synaptic output by super_class.

    Columns, left to right: the source set (node ``title``), then the
    super_class of its 1st-, 2nd- and 3rd-order downstream partners. Each hop
    is built with ``build_hop_df`` (pairs with >= ``min_syn`` summed synapses).
    Link widths are synapse counts. Each downstream edge is counted once, even
    when its presynaptic neuron receives input from several upstream neurons.
    Nodes are spread evenly down each column.
    """
    def _hop(src_ids):
        # Partners absent from `classification` get an explicit label. A NaN
        # super_class breaks sorted() on mixed str/float labels and the colour lookup.
        hop = build_hop_df(src_ids, connections, classification, min_syn)
        return hop.assign(output_super_class=hop['output_super_class'].fillna('unclassified'))

    def _layer_flow(upstream, downstream, src_prefix, dst_prefix):
        # Synapses from upstream-layer classes to downstream-layer classes.
        # Each presynaptic neuron's class is looked up once. A row-wise merge with
        # `upstream` would repeat a downstream edge once per upstream edge reaching
        # its presynaptic neuron, inflating the link widths.
        pre_class = (
            upstream[['post_root_id', 'output_super_class']]
            .drop_duplicates('post_root_id')
            .rename(columns={'post_root_id': 'pre_root_id', 'output_super_class': 'source'})
        )
        flow = (
            downstream.merge(pre_class, on='pre_root_id', how='inner')
            .rename(columns={'output_super_class': 'target'})
            .groupby(['source', 'target'], as_index=False)['syn_count'].sum()
            .rename(columns={'syn_count': 'count'})
        )
        flow['source'] = flow['source'].map(lambda c: f"{src_prefix}{c}")
        flow['target'] = flow['target'].map(lambda c: f"{dst_prefix}{c}")
        return flow

    # hop-1, hop-2, hop-3
    df1 = _hop(grn_df['root_id'])
    df2 = _hop(df1['post_root_id'].unique())
    df3 = _hop(df2['post_root_id'].unique())

    # summarize flows as (source label, target label, count)
    hop1_totals = df1.groupby('output_super_class', as_index=False)['syn_count'].sum()
    flow1 = pd.DataFrame({
        'source': title,
        'target': hop1_totals['output_super_class'].map(lambda c: f"1: {c}"),
        'count': hop1_totals['syn_count'],
    })
    flow2 = _layer_flow(df1, df2, '1: ', '2: ')
    flow3 = _layer_flow(df2, df3, '2: ', '3: ')

    # assemble nodes, one column per hop
    col1 = [title]
    col2 = [f"1: {c}" for c in sorted(df1['output_super_class'].unique())]
    col3 = [f"2: {c}" for c in sorted(df2['output_super_class'].unique())]
    col4 = [f"3: {c}" for c in sorted(df3['output_super_class'].unique())]
    nodes = col1 + col2 + col3 + col4
    idx   = {n:i for i,n in enumerate(nodes)}

    # colors: one per class, consistent across columns; the source node is grey
    palette = px.colors.qualitative.Safe
    all_classes = sorted({
        *df1['output_super_class'],
        *df2['output_super_class'],
        *df3['output_super_class']
    })
    color_map = {cls: palette[i % len(palette)] for i,cls in enumerate(all_classes)}
    node_colors = [
        'lightgrey' if n==title else color_map[n.split(': ',1)[1]]
        for n in nodes
    ]

    # build links, tinted with a half-transparent source-node colour
    source, target, value, link_colors = [], [], [], []
    for flow in (flow1, flow2, flow3):
        for src, dst, cnt in zip(flow['source'], flow['target'], flow['count']):
            s, t = idx[src], idx[dst]
            source.append(s)
            target.append(t)
            value.append(cnt)
            # 'rgb(r, g, b)' -> 'rgba(r, g, b,0.5)'; named colours pass through unchanged
            link_colors.append(node_colors[s].replace('rgb','rgba').replace(')',',0.5)'))

    # hover info: total link value into / out of each node
    incoming = dict.fromkeys(nodes,0)
    outgoing = dict.fromkeys(nodes,0)
    for s,t,v in zip(source,target,value):
        outgoing[nodes[s]] += v
        incoming[nodes[t]]  += v
    customdata = [f"Incoming: {incoming[n]}<br>Outgoing: {outgoing[n]}" for n in nodes]

    # x positions
    x = [0.0]*len(col1) + [0.33]*len(col2) + [0.66]*len(col3) + [1.0]*len(col4)
    # y positions: spread evenly down each column; a lone node is centred and an
    # empty column contributes nothing, so y stays aligned with `nodes`
    y = []
    for col in (col1, col2, col3, col4):
        y.extend([0.5] if len(col) == 1 else np.linspace(0, 1, len(col)))

    # plot
    fig = go.Figure(go.Sankey(
        arrangement='snap',
        node=dict(
            label=nodes,
            x=x, y=y,
            color=node_colors,
            pad=15, thickness=20,
            line=dict(color='black',width=0.5),
            customdata=customdata,
            hovertemplate='%{customdata}<extra>%{label}</extra>'
        ),
        link=dict(source=source, target=target, value=value, color=link_colors)
    ))
    fig.update_layout(title_text=title, font_size=14)
    fig.show()

# ----------------------------------------------------------------------------
# Draw the evenly spaced Sankey for every PSO set
# ----------------------------------------------------------------------------
for label in sets:
    df = sets[label]
    plot_sankey_dynamic(df, label, connections, classification)


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from collections import defaultdict

# ----------------------------------------------------------------------------
# 0. Load FlyWire connectome and classification
# ----------------------------------------------------------------------------
# The base directories default to the original machine-specific locations.
# Set FLYWIRE_DATA_DIR / PSO_SET_DIR to run elsewhere, e.g. pointing both at
# this notebook's own folder.
import os
from pathlib import Path

FLYWIRE_DATA_DIR = Path(os.environ.get(
    'FLYWIRE_DATA_DIR',
    '/Users/yaolab/Library/CloudStorage/OneDrive-UniversityofFlorida/'
    'YaoLabUF/YaoLab/Drosophila_brain_model'
))
PSO_SET_DIR = Path(os.environ.get(
    'PSO_SET_DIR',
    '/Users/yaolab/Downloads/taste-connectome-main/aPhN-SA_v3'
))

connections = pd.read_csv(FLYWIRE_DATA_DIR / 'connections.csv.gz')
# Only the neuron ID and its super-class are used downstream.
classification = pd.read_csv(
    FLYWIRE_DATA_DIR / 'classification.csv.gz'
)[['root_id', 'super_class']]

# ----------------------------------------------------------------------------
# 1. Load your three PSO lists (each CSV has a 'root_id' column)
# ----------------------------------------------------------------------------
set_1 = pd.read_csv(PSO_SET_DIR / 'set_1.csv')
set_2 = pd.read_csv(PSO_SET_DIR / 'set_2.csv')
set_3 = pd.read_csv(PSO_SET_DIR / 'set_3.csv')

# Display label -> neuron list; the labels become the Sankey titles.
sets = {
    'DCSO':  set_1,
    'aPhN1': set_2,
    'aPhN2': set_3,
}

# ----------------------------------------------------------------------------
# Helper: filter connections by source IDs, threshold syn_count, attach superclass
# ----------------------------------------------------------------------------
def build_hop_df(src_ids, connections, classification, min_syn=5,
                 unknown_label='unknown'):
    """Expand one synaptic hop from ``src_ids`` and label each target's super-class.

    Parameters
    ----------
    src_ids : array-like
        Presynaptic root IDs to expand from.
    connections : pandas.DataFrame
        Must contain 'pre_root_id', 'post_root_id' and 'syn_count'. Rows that
        share a (pre, post) pair are summed into a single edge.
    classification : pandas.DataFrame
        Must contain 'root_id' and 'super_class'.
    min_syn : int, default 5
        Edges whose summed synapse count is below this are dropped.
    unknown_label : str, default 'unknown'
        Super-class given to targets with no (or a missing) classification.
        This keeps 'output_super_class' free of NaN, which cannot be sorted
        together with the string labels when Sankey nodes are built.

    Returns
    -------
    pandas.DataFrame
        Columns 'pre_root_id', 'post_root_id', 'output_super_class', 'syn_count'.
    """
    outgoing = connections[connections['pre_root_id'].isin(src_ids)]
    summed = (
        outgoing.groupby(['pre_root_id', 'post_root_id'], as_index=False)['syn_count']
                .sum()
    )
    summed = summed[summed['syn_count'] >= min_syn]
    target_classes = classification.rename(columns={
        'root_id': 'post_root_id',
        'super_class': 'output_super_class',
    })
    merged = summed.merge(target_classes, on='post_root_id', how='left')
    merged['output_super_class'] = merged['output_super_class'].fillna(unknown_label)
    return merged[['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count']]

# ----------------------------------------------------------------------------
# Helpers for the Sankey builder below
# ----------------------------------------------------------------------------
def _summarize_flow(upstream, downstream, src_tag, tgt_tag):
    """Sum ``downstream`` synapses per (presynaptic class, postsynaptic class).

    The presynaptic neuron's class is taken from ``upstream``'s postsynaptic
    neurons. ``upstream`` is first collapsed to one row per neuron, so each
    downstream edge is counted exactly once. Merging edge table to edge table
    would repeat a downstream edge once per upstream edge converging on its
    presynaptic neuron, inflating the counts.

    Returns a DataFrame with 'source' and 'target' (prefixed with ``src_tag`` /
    ``tgt_tag``) and 'count' columns.
    """
    pre_class = (
        upstream[['post_root_id', 'output_super_class']]
        .drop_duplicates()
        .rename(columns={'post_root_id': 'pre_root_id',
                         'output_super_class': 'source'})
    )
    flow = (
        downstream.merge(pre_class, on='pre_root_id')
                  .groupby(['source', 'output_super_class'], as_index=False)['syn_count']
                  .sum()
                  .rename(columns={'output_super_class': 'target',
                                   'syn_count': 'count'})
    )
    flow['source'] = flow['source'].map(lambda c: src_tag + c)
    flow['target'] = flow['target'].map(lambda c: tgt_tag + c)
    return flow


def _build_sankey_figure(title, df1, df2, df3):
    """Build a 4-column Sankey: ``title`` -> hop-1 -> hop-2 -> hop-3 super-classes.

    ``df1``/``df2``/``df3`` are hop tables from ``build_hop_df``. Link widths
    are summed synapse counts. Returns the plotly Figure without showing it.
    """
    # --- summarize flows ---
    flow1 = (
        df1.groupby('output_super_class', as_index=False)['syn_count']
           .sum().rename(columns={'syn_count': 'count'})
           .assign(source=title,
                   target=lambda d: d['output_super_class'].map(lambda c: '1: ' + c))
    )
    flow2 = _summarize_flow(df1, df2, '1: ', '2: ')
    flow3 = _summarize_flow(df2, df3, '2: ', '3: ')

    # --- node columns & indices ---
    # NaN super-classes are skipped: the groupbys above already drop them, and
    # sorting NaN together with strings raises TypeError.
    columns = [[title]]
    for hop, hop_df in enumerate((df1, df2, df3), start=1):
        hop_classes = sorted(hop_df['output_super_class'].dropna().unique())
        columns.append([f"{hop}: {c}" for c in hop_classes])
    nodes = [n for col in columns for n in col]
    idx = {n: i for i, n in enumerate(nodes)}

    # --- colors: one per super-class, shared across hops ---
    classes = sorted({n.split(': ', 1)[1] for col in columns[1:] for n in col})
    palette = px.colors.qualitative.Safe
    cmap = {cls: palette[i % len(palette)] for i, cls in enumerate(classes)}
    node_colors = ['lightgrey' if n == title else cmap[n.split(': ', 1)[1]]
                   for n in nodes]

    # --- links, tinted at 50% alpha with their source node's color ---
    source, target, value, link_colors = [], [], [], []
    for flow in (flow1, flow2, flow3):
        for s_name, t_name, v in zip(flow['source'], flow['target'], flow['count']):
            s = idx[s_name]
            source.append(s)
            target.append(idx[t_name])
            value.append(v)
            link_colors.append(
                node_colors[s].replace('rgb(', 'rgba(').replace(')', ',0.5)'))

    # --- hover info: total link value into / out of each node ---
    incoming = dict.fromkeys(nodes, 0)
    outgoing = dict.fromkeys(nodes, 0)
    for s, t, v in zip(source, target, value):
        outgoing[nodes[s]] += v
        incoming[nodes[t]] += v
    customdata = [f"Incoming: {incoming[n]}<br>Outgoing: {outgoing[n]}" for n in nodes]

    # --- layout: one fixed x per column, nodes spread evenly over y ---
    x, y = [], []
    for x_pos, col in zip((0.0, 0.33, 0.66, 1.0), columns):
        x.extend([x_pos] * len(col))
        # A lone node is centred. An empty column must add no y entries,
        # otherwise y would no longer line up with x and the node list.
        y.extend([0.5] if len(col) == 1 else np.linspace(0, 1, len(col)).tolist())

    # --- plot Sankey ---
    fig = go.Figure(go.Sankey(
        arrangement='snap',
        node=dict(
            label=nodes,
            x=x, y=y,
            color=node_colors,
            pad=15, thickness=20,
            line=dict(color='black', width=0.5),
            customdata=customdata,
            hovertemplate='%{customdata}<extra>%{label}</extra>'
        ),
        link=dict(source=source, target=target, value=value, color=link_colors)
    ))
    fig.update_layout(title_text=title, font_size=14)
    return fig


# ----------------------------------------------------------------------------
# Core: build & show a 3-hop Sankey with df1 pruned to hop-1 edges whose
#       target either is motor (hop1→motor) or reaches motor in hop2
# ----------------------------------------------------------------------------
def plot_sankey_dynamic(grn_df, title, connections, classification, min_syn=5):
    """Show a 3-hop Sankey of ``grn_df``'s outputs, pruned to motor-bound hop-1 targets.

    A hop-1 edge is kept when its target neuron either is a motor neuron or
    synapses directly onto one (motor at hop 2). Hops 2 and 3 are then
    re-expanded, unpruned, from the surviving hop-1 targets.

    Parameters
    ----------
    grn_df : pandas.DataFrame
        Source neurons; must have a 'root_id' column.
    title : str
        Figure title and label of the leftmost node.
    connections, classification : pandas.DataFrame
        Passed through to ``build_hop_df``.
    min_syn : int, default 5
        Minimum summed synapses per neuron pair for an edge to be kept.
    """
    df1_full = build_hop_df(grn_df['root_id'], connections, classification, min_syn)
    df2_full = build_hop_df(df1_full['post_root_id'].unique(), connections, classification, min_syn)

    # Hop-1 targets that have a motor neuron among their own outputs are
    # df2's *presynaptic* IDs. Taking hop-3 presynaptic IDs instead would
    # compare hop-1 targets against hop-2 targets, a different layer.
    hop1_to_motor = df2_full.loc[df2_full['output_super_class'] == 'motor', 'pre_root_id'].unique()

    mask_direct = df1_full['output_super_class'] == 'motor'
    mask_indirect = df1_full['post_root_id'].isin(hop1_to_motor)
    df1 = df1_full[mask_direct | mask_indirect].reset_index(drop=True)

    # Rebuild hop-2 & hop-3 from the pruned hop-1 targets.
    df2 = build_hop_df(df1['post_root_id'].unique(), connections, classification, min_syn)
    df3 = build_hop_df(df2['post_root_id'].unique(), connections, classification, min_syn)

    _build_sankey_figure(title, df1, df2, df3).show()

# ----------------------------------------------------------------------------
# Draw one filtered Sankey per PSO set; the dict key doubles as the title
# ----------------------------------------------------------------------------
for label in sets:
    df = sets[label]
    plot_sankey_dynamic(df, label, connections, classification)


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

# ----------------------------------------------------------------------------
# 0. Load FlyWire connectome and classification
# ----------------------------------------------------------------------------
# The base directories default to the original machine-specific locations.
# Set FLYWIRE_DATA_DIR / PSO_SET_DIR to run elsewhere, e.g. pointing both at
# this notebook's own folder.
import os
from pathlib import Path

FLYWIRE_DATA_DIR = Path(os.environ.get(
    'FLYWIRE_DATA_DIR',
    '/Users/yaolab/Library/CloudStorage/OneDrive-UniversityofFlorida/'
    'YaoLabUF/YaoLab/Drosophila_brain_model'
))
PSO_SET_DIR = Path(os.environ.get(
    'PSO_SET_DIR',
    '/Users/yaolab/Downloads/taste-connectome-main/aPhN-SA_v3'
))

connections = pd.read_csv(FLYWIRE_DATA_DIR / 'connections.csv.gz')
# Only the neuron ID and its super-class are used downstream.
classification = pd.read_csv(
    FLYWIRE_DATA_DIR / 'classification.csv.gz'
)[['root_id', 'super_class']]

# ----------------------------------------------------------------------------
# 1. Load your three PSO lists (each CSV has a 'root_id' column)
# ----------------------------------------------------------------------------
set_1 = pd.read_csv(PSO_SET_DIR / 'set_1.csv')
set_2 = pd.read_csv(PSO_SET_DIR / 'set_2.csv')
set_3 = pd.read_csv(PSO_SET_DIR / 'set_3.csv')

# Display label -> neuron list; the labels become the Sankey titles.
sets = {
    'DCSO':  set_1,
    'aPhN1': set_2,
    'aPhN2': set_3,
}

# ----------------------------------------------------------------------------
# Helper: filter connections by source IDs, threshold syn_count, attach superclass
# ----------------------------------------------------------------------------
def build_hop_df(src_ids, connections, classification, min_syn=5,
                 unknown_label='unknown'):
    """Expand one synaptic hop from ``src_ids`` and label each target's super-class.

    Parameters
    ----------
    src_ids : array-like
        Presynaptic root IDs to expand from.
    connections : pandas.DataFrame
        Must contain 'pre_root_id', 'post_root_id' and 'syn_count'. Rows that
        share a (pre, post) pair are summed into a single edge.
    classification : pandas.DataFrame
        Must contain 'root_id' and 'super_class'.
    min_syn : int, default 5
        Edges whose summed synapse count is below this are dropped.
    unknown_label : str, default 'unknown'
        Super-class given to targets with no (or a missing) classification.
        This keeps 'output_super_class' free of NaN, which cannot be sorted
        together with the string labels when Sankey nodes are built.

    Returns
    -------
    pandas.DataFrame
        Columns 'pre_root_id', 'post_root_id', 'output_super_class', 'syn_count'.
    """
    outgoing = connections[connections['pre_root_id'].isin(src_ids)]
    summed = (
        outgoing.groupby(['pre_root_id', 'post_root_id'], as_index=False)['syn_count']
                .sum()
    )
    summed = summed[summed['syn_count'] >= min_syn]
    target_classes = classification.rename(columns={
        'root_id': 'post_root_id',
        'super_class': 'output_super_class',
    })
    merged = summed.merge(target_classes, on='post_root_id', how='left')
    merged['output_super_class'] = merged['output_super_class'].fillna(unknown_label)
    return merged[['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count']]

# ----------------------------------------------------------------------------
# Helpers for the Sankey builder below
# ----------------------------------------------------------------------------
def _summarize_flow(upstream, downstream, src_tag, tgt_tag):
    """Sum ``downstream`` synapses per (presynaptic class, postsynaptic class).

    The presynaptic neuron's class is taken from ``upstream``'s postsynaptic
    neurons. ``upstream`` is first collapsed to one row per neuron, so each
    downstream edge is counted exactly once. Merging edge table to edge table
    would repeat a downstream edge once per upstream edge converging on its
    presynaptic neuron, inflating the counts.

    Returns a DataFrame with 'source' and 'target' (prefixed with ``src_tag`` /
    ``tgt_tag``) and 'count' columns.
    """
    pre_class = (
        upstream[['post_root_id', 'output_super_class']]
        .drop_duplicates()
        .rename(columns={'post_root_id': 'pre_root_id',
                         'output_super_class': 'source'})
    )
    flow = (
        downstream.merge(pre_class, on='pre_root_id')
                  .groupby(['source', 'output_super_class'], as_index=False)['syn_count']
                  .sum()
                  .rename(columns={'output_super_class': 'target',
                                   'syn_count': 'count'})
    )
    flow['source'] = flow['source'].map(lambda c: src_tag + c)
    flow['target'] = flow['target'].map(lambda c: tgt_tag + c)
    return flow


def _build_sankey_figure(title, df1, df2, df3):
    """Build a 4-column Sankey: ``title`` -> hop-1 -> hop-2 -> hop-3 super-classes.

    ``df1``/``df2``/``df3`` are hop tables from ``build_hop_df``. Link widths
    are summed synapse counts. Returns the plotly Figure without showing it.
    """
    # --- summarize flows ---
    flow1 = (
        df1.groupby('output_super_class', as_index=False)['syn_count']
           .sum().rename(columns={'syn_count': 'count'})
           .assign(source=title,
                   target=lambda d: d['output_super_class'].map(lambda c: '1: ' + c))
    )
    flow2 = _summarize_flow(df1, df2, '1: ', '2: ')
    flow3 = _summarize_flow(df2, df3, '2: ', '3: ')

    # --- node columns & indices ---
    # NaN super-classes are skipped: the groupbys above already drop them, and
    # sorting NaN together with strings raises TypeError.
    columns = [[title]]
    for hop, hop_df in enumerate((df1, df2, df3), start=1):
        hop_classes = sorted(hop_df['output_super_class'].dropna().unique())
        columns.append([f"{hop}: {c}" for c in hop_classes])
    nodes = [n for col in columns for n in col]
    idx = {n: i for i, n in enumerate(nodes)}

    # --- colors: one per super-class, shared across hops ---
    classes = sorted({n.split(': ', 1)[1] for col in columns[1:] for n in col})
    palette = px.colors.qualitative.Safe
    cmap = {cls: palette[i % len(palette)] for i, cls in enumerate(classes)}
    node_colors = ['lightgrey' if n == title else cmap[n.split(': ', 1)[1]]
                   for n in nodes]

    # --- links, tinted at 50% alpha with their source node's color ---
    source, target, value, link_colors = [], [], [], []
    for flow in (flow1, flow2, flow3):
        for s_name, t_name, v in zip(flow['source'], flow['target'], flow['count']):
            s = idx[s_name]
            source.append(s)
            target.append(idx[t_name])
            value.append(v)
            link_colors.append(
                node_colors[s].replace('rgb(', 'rgba(').replace(')', ',0.5)'))

    # --- hover info: total link value into / out of each node ---
    incoming = dict.fromkeys(nodes, 0)
    outgoing = dict.fromkeys(nodes, 0)
    for s, t, v in zip(source, target, value):
        outgoing[nodes[s]] += v
        incoming[nodes[t]] += v
    customdata = [f"Incoming: {incoming[n]}<br>Outgoing: {outgoing[n]}" for n in nodes]

    # --- layout: one fixed x per column, nodes spread evenly over y ---
    x, y = [], []
    for x_pos, col in zip((0.0, 0.33, 0.66, 1.0), columns):
        x.extend([x_pos] * len(col))
        # A lone node is centred. An empty column must add no y entries,
        # otherwise y would no longer line up with x and the node list.
        y.extend([0.5] if len(col) == 1 else np.linspace(0, 1, len(col)).tolist())

    # --- plot Sankey ---
    fig = go.Figure(go.Sankey(
        arrangement='snap',
        node=dict(
            label=nodes,
            x=x, y=y,
            color=node_colors,
            pad=15, thickness=20,
            line=dict(color='black', width=0.5),
            customdata=customdata,
            hovertemplate='%{customdata}<extra>%{label}</extra>'
        ),
        link=dict(source=source, target=target, value=value, color=link_colors)
    ))
    fig.update_layout(title_text=title, font_size=14)
    return fig


# ----------------------------------------------------------------------------
# Core: 3-hop Sankey, prune df1 by direct/indirect/third-hop motor connection
# ----------------------------------------------------------------------------
def plot_sankey_dynamic(grn_df, title, connections, classification, min_syn=5):
    """Show a 3-hop Sankey of ``grn_df``'s outputs, keeping motor-bound hop-1 targets.

    A hop-1 edge is kept when its target neuron
      1. is a motor neuron (direct),
      2. synapses onto a motor neuron (motor at hop 2), or
      3. synapses onto a neuron that synapses onto a motor neuron (motor at hop 3).
    Hops 2 and 3 are then re-expanded, unpruned, from the surviving hop-1 targets.

    Parameters
    ----------
    grn_df : pandas.DataFrame
        Source neurons; must have a 'root_id' column.
    title : str
        Figure title and label of the leftmost node.
    connections, classification : pandas.DataFrame
        Passed through to ``build_hop_df``.
    min_syn : int, default 5
        Minimum summed synapses per neuron pair for an edge to be kept.
    """
    df1_full = build_hop_df(grn_df['root_id'], connections, classification, min_syn)
    df2_full = build_hop_df(df1_full['post_root_id'].unique(), connections, classification, min_syn)
    df3_full = build_hop_df(df2_full['post_root_id'].unique(), connections, classification, min_syn)

    # 1. Direct: the hop-1 target is itself motor.
    mask_direct = df1_full['output_super_class'] == 'motor'

    # 2. Hop-1 targets with a motor neuron among their own outputs.
    motor_in_hop2 = df2_full.loc[df2_full['output_super_class'] == 'motor', 'pre_root_id'].unique()
    mask_hop2 = df1_full['post_root_id'].isin(motor_in_hop2)

    # 3. Hop-2 neurons with a motor output, traced back to the hop-1 targets feeding them.
    motor_in_hop3 = df3_full.loc[df3_full['output_super_class'] == 'motor', 'pre_root_id'].unique()
    hop2_roots_to_motor3 = df2_full.loc[df2_full['post_root_id'].isin(motor_in_hop3), 'pre_root_id'].unique()
    mask_hop3 = df1_full['post_root_id'].isin(hop2_roots_to_motor3)

    df1 = df1_full[mask_direct | mask_hop2 | mask_hop3].reset_index(drop=True)

    # Rebuild hop-2 & hop-3 from the pruned hop-1 targets.
    df2 = build_hop_df(df1['post_root_id'].unique(), connections, classification, min_syn)
    df3 = build_hop_df(df2['post_root_id'].unique(), connections, classification, min_syn)

    _build_sankey_figure(title, df1, df2, df3).show()

# ----------------------------------------------------------------------------
# Draw one filtered Sankey per PSO set; the dict key doubles as the title
# ----------------------------------------------------------------------------
for label in sets:
    df = sets[label]
    plot_sankey_dynamic(df, label, connections, classification)


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

# ----------------------------------------------------------------------------
# 0. Load FlyWire connectome and classification
# ----------------------------------------------------------------------------
# The base directories default to the original machine-specific locations.
# Set FLYWIRE_DATA_DIR / PSO_SET_DIR to run elsewhere, e.g. pointing both at
# this notebook's own folder.
import os
from pathlib import Path

FLYWIRE_DATA_DIR = Path(os.environ.get(
    'FLYWIRE_DATA_DIR',
    '/Users/yaolab/Library/CloudStorage/OneDrive-UniversityofFlorida/'
    'YaoLabUF/YaoLab/Drosophila_brain_model'
))
PSO_SET_DIR = Path(os.environ.get(
    'PSO_SET_DIR',
    '/Users/yaolab/Downloads/taste-connectome-main/aPhN-SA_v3'
))

connections = pd.read_csv(FLYWIRE_DATA_DIR / 'connections.csv.gz')
# Only the neuron ID and its super-class are used downstream.
classification = pd.read_csv(
    FLYWIRE_DATA_DIR / 'classification.csv.gz'
)[['root_id', 'super_class']]

# ----------------------------------------------------------------------------
# 1. Load your three PSO lists (each CSV has a 'root_id' column)
# ----------------------------------------------------------------------------
set_1 = pd.read_csv(PSO_SET_DIR / 'set_1.csv')
set_2 = pd.read_csv(PSO_SET_DIR / 'set_2.csv')
set_3 = pd.read_csv(PSO_SET_DIR / 'set_3.csv')

# Display label -> neuron list; the labels become the Sankey titles.
sets = {
    'DCSO':  set_1,
    'aPhN1': set_2,
    'aPhN2': set_3,
}

# ----------------------------------------------------------------------------
# Helper: filter connections by source IDs, threshold syn_count, attach superclass
# ----------------------------------------------------------------------------
def build_hop_df(src_ids, connections, classification, min_syn=5,
                 unknown_label='unknown'):
    """Expand one synaptic hop from ``src_ids`` and label each target's super-class.

    Parameters
    ----------
    src_ids : array-like
        Presynaptic root IDs to expand from.
    connections : pandas.DataFrame
        Must contain 'pre_root_id', 'post_root_id' and 'syn_count'. Rows that
        share a (pre, post) pair are summed into a single edge.
    classification : pandas.DataFrame
        Must contain 'root_id' and 'super_class'.
    min_syn : int, default 5
        Edges whose summed synapse count is below this are dropped.
    unknown_label : str, default 'unknown'
        Super-class given to targets with no (or a missing) classification.
        This keeps 'output_super_class' free of NaN, which cannot be sorted
        together with the string labels when Sankey nodes are built.

    Returns
    -------
    pandas.DataFrame
        Columns 'pre_root_id', 'post_root_id', 'output_super_class', 'syn_count'.
    """
    outgoing = connections[connections['pre_root_id'].isin(src_ids)]
    summed = (
        outgoing.groupby(['pre_root_id', 'post_root_id'], as_index=False)['syn_count']
                .sum()
    )
    summed = summed[summed['syn_count'] >= min_syn]
    target_classes = classification.rename(columns={
        'root_id': 'post_root_id',
        'super_class': 'output_super_class',
    })
    merged = summed.merge(target_classes, on='post_root_id', how='left')
    merged['output_super_class'] = merged['output_super_class'].fillna(unknown_label)
    return merged[['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count']]

# ----------------------------------------------------------------------------
# Helpers for the Sankey builder below
# ----------------------------------------------------------------------------
def _summarize_flow(upstream, downstream, src_tag, tgt_tag):
    """Sum ``downstream`` synapses per (presynaptic class, postsynaptic class).

    The presynaptic neuron's class is taken from ``upstream``'s postsynaptic
    neurons. ``upstream`` is first collapsed to one row per neuron, so each
    downstream edge is counted exactly once. Merging edge table to edge table
    would repeat a downstream edge once per upstream edge converging on its
    presynaptic neuron, inflating the counts.

    Returns a DataFrame with 'source' and 'target' (prefixed with ``src_tag`` /
    ``tgt_tag``) and 'count' columns.
    """
    pre_class = (
        upstream[['post_root_id', 'output_super_class']]
        .drop_duplicates()
        .rename(columns={'post_root_id': 'pre_root_id',
                         'output_super_class': 'source'})
    )
    flow = (
        downstream.merge(pre_class, on='pre_root_id')
                  .groupby(['source', 'output_super_class'], as_index=False)['syn_count']
                  .sum()
                  .rename(columns={'output_super_class': 'target',
                                   'syn_count': 'count'})
    )
    flow['source'] = flow['source'].map(lambda c: src_tag + c)
    flow['target'] = flow['target'].map(lambda c: tgt_tag + c)
    return flow


def _build_sankey_figure(title, df1, df2, df3):
    """Build a 4-column Sankey: ``title`` -> hop-1 -> hop-2 -> hop-3 super-classes.

    ``df1``/``df2``/``df3`` are hop tables from ``build_hop_df``. Link widths
    are summed synapse counts. Returns the plotly Figure without showing it.
    """
    # --- summarize flows ---
    flow1 = (
        df1.groupby('output_super_class', as_index=False)['syn_count']
           .sum().rename(columns={'syn_count': 'count'})
           .assign(source=title,
                   target=lambda d: d['output_super_class'].map(lambda c: '1: ' + c))
    )
    flow2 = _summarize_flow(df1, df2, '1: ', '2: ')
    flow3 = _summarize_flow(df2, df3, '2: ', '3: ')

    # --- node columns & indices ---
    # NaN super-classes are skipped: the groupbys above already drop them, and
    # sorting NaN together with strings raises TypeError.
    columns = [[title]]
    for hop, hop_df in enumerate((df1, df2, df3), start=1):
        hop_classes = sorted(hop_df['output_super_class'].dropna().unique())
        columns.append([f"{hop}: {c}" for c in hop_classes])
    nodes = [n for col in columns for n in col]
    idx = {n: i for i, n in enumerate(nodes)}

    # --- colors: one per super-class, shared across hops ---
    classes = sorted({n.split(': ', 1)[1] for col in columns[1:] for n in col})
    palette = px.colors.qualitative.Safe
    cmap = {cls: palette[i % len(palette)] for i, cls in enumerate(classes)}
    node_colors = ['lightgrey' if n == title else cmap[n.split(': ', 1)[1]]
                   for n in nodes]

    # --- links, tinted at 50% alpha with their source node's color ---
    source, target, value, link_colors = [], [], [], []
    for flow in (flow1, flow2, flow3):
        for s_name, t_name, v in zip(flow['source'], flow['target'], flow['count']):
            s = idx[s_name]
            source.append(s)
            target.append(idx[t_name])
            value.append(v)
            link_colors.append(
                node_colors[s].replace('rgb(', 'rgba(').replace(')', ',0.5)'))

    # --- hover info: total link value into / out of each node ---
    incoming = dict.fromkeys(nodes, 0)
    outgoing = dict.fromkeys(nodes, 0)
    for s, t, v in zip(source, target, value):
        outgoing[nodes[s]] += v
        incoming[nodes[t]] += v
    customdata = [f"Incoming: {incoming[n]}<br>Outgoing: {outgoing[n]}" for n in nodes]

    # --- layout: one fixed x per column, nodes spread evenly over y ---
    x, y = [], []
    for x_pos, col in zip((0.0, 0.33, 0.66, 1.0), columns):
        x.extend([x_pos] * len(col))
        # A lone node is centred. An empty column must add no y entries,
        # otherwise y would no longer line up with x and the node list.
        y.extend([0.5] if len(col) == 1 else np.linspace(0, 1, len(col)).tolist())

    # --- plot Sankey ---
    fig = go.Figure(go.Sankey(
        arrangement='snap',
        node=dict(
            label=nodes,
            x=x, y=y,
            color=node_colors,
            pad=15, thickness=20,
            line=dict(color='black', width=0.5),
            customdata=customdata,
            hovertemplate='%{customdata}<extra>%{label}</extra>'
        ),
        link=dict(source=source, target=target, value=value, color=link_colors)
    ))
    fig.update_layout(title_text=title, font_size=14)
    return fig


# ----------------------------------------------------------------------------
# Core: 3-hop Sankey with filtering for motor at each hop
# ----------------------------------------------------------------------------
def plot_sankey_dynamic(grn_df, title, connections, classification, min_syn=5):
    """Show a 3-hop Sankey of ``grn_df``'s outputs, pruned towards motor at hops 1 and 2.

    Hop-1 edges are kept when the target is motor, synapses onto a motor
    neuron, or synapses onto a neuron that synapses onto a motor neuron.
    Hop-2 edges (re-expanded from the kept hop-1 targets) are kept when the
    target is motor or synapses onto a motor neuron. Hop 3 is then
    re-expanded, unpruned, from the kept hop-2 targets.

    Parameters
    ----------
    grn_df : pandas.DataFrame
        Source neurons; must have a 'root_id' column.
    title : str
        Figure title and label of the leftmost node.
    connections, classification : pandas.DataFrame
        Passed through to ``build_hop_df``.
    min_syn : int, default 5
        Minimum summed synapses per neuron pair for an edge to be kept.
    """
    df1_full = build_hop_df(grn_df['root_id'], connections, classification, min_syn)
    df2_full = build_hop_df(df1_full['post_root_id'].unique(), connections, classification, min_syn)
    df3_full = build_hop_df(df2_full['post_root_id'].unique(), connections, classification, min_syn)

    # ---- Filter df1: direct motor, motor at hop 2, or motor at hop 3 ----
    mask_direct1 = df1_full['output_super_class'] == 'motor'
    motor_in_hop2 = df2_full.loc[df2_full['output_super_class'] == 'motor', 'pre_root_id'].unique()
    mask_hop2_1 = df1_full['post_root_id'].isin(motor_in_hop2)
    motor_in_hop3 = df3_full.loc[df3_full['output_super_class'] == 'motor', 'pre_root_id'].unique()
    hop2_roots_to_motor3 = df2_full.loc[df2_full['post_root_id'].isin(motor_in_hop3), 'pre_root_id'].unique()
    mask_hop3_1 = df1_full['post_root_id'].isin(hop2_roots_to_motor3)
    df1 = df1_full[mask_direct1 | mask_hop2_1 | mask_hop3_1].reset_index(drop=True)

    # ---- Re-expand hops 2 and 3 from the pruned hop-1 targets ----
    df2_cand = build_hop_df(df1['post_root_id'].unique(), connections, classification, min_syn)
    df3_cand = build_hop_df(df2_cand['post_root_id'].unique(), connections, classification, min_syn)

    # ---- Filter df2: target is motor, or target synapses onto a motor neuron ----
    mask_direct2 = df2_cand['output_super_class'] == 'motor'
    motor_in_df3 = df3_cand.loc[df3_cand['output_super_class'] == 'motor', 'pre_root_id'].unique()
    mask_downstream2 = df2_cand['post_root_id'].isin(motor_in_df3)
    df2 = df2_cand[mask_direct2 | mask_downstream2].reset_index(drop=True)

    # ---- Rebuild df3 from the pruned df2 ----
    df3 = build_hop_df(df2['post_root_id'].unique(), connections, classification, min_syn)

    _build_sankey_figure(title, df1, df2, df3).show()

# ----------------------------------------------------------------------------
# Draw one filtered Sankey per PSO set; the dict key doubles as the title
# ----------------------------------------------------------------------------
for label in sets:
    df = sets[label]
    plot_sankey_dynamic(df, label, connections, classification)


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

# ----------------------------------------------------------------------------
# 0. Load FlyWire connectome and classification
# ----------------------------------------------------------------------------
# The base directories default to the original machine-specific locations.
# Set FLYWIRE_DATA_DIR / PSO_SET_DIR to run elsewhere, e.g. pointing both at
# this notebook's own folder.
import os
from pathlib import Path

FLYWIRE_DATA_DIR = Path(os.environ.get(
    'FLYWIRE_DATA_DIR',
    '/Users/yaolab/Library/CloudStorage/OneDrive-UniversityofFlorida/'
    'YaoLabUF/YaoLab/Drosophila_brain_model'
))
PSO_SET_DIR = Path(os.environ.get(
    'PSO_SET_DIR',
    '/Users/yaolab/Downloads/taste-connectome-main/aPhN-SA_v3'
))

connections = pd.read_csv(FLYWIRE_DATA_DIR / 'connections.csv.gz')
# Only the neuron ID and its super-class are used downstream.
classification = pd.read_csv(
    FLYWIRE_DATA_DIR / 'classification.csv.gz'
)[['root_id', 'super_class']]

# ----------------------------------------------------------------------------
# 1. Load your three PSO lists (each CSV has a 'root_id' column)
# ----------------------------------------------------------------------------
set_1 = pd.read_csv(PSO_SET_DIR / 'set_1.csv')
set_2 = pd.read_csv(PSO_SET_DIR / 'set_2.csv')
set_3 = pd.read_csv(PSO_SET_DIR / 'set_3.csv')

# Display label -> neuron list; the labels become the Sankey titles.
sets = {
    'DCSO':  set_1,
    'aPhN1': set_2,
    'aPhN2': set_3,
}

# ----------------------------------------------------------------------------
# Helper: filter connections by source IDs, threshold syn_count, attach superclass
# ----------------------------------------------------------------------------
def build_hop_df(src_ids, connections, classification, min_syn=5,
                 unknown_label='unknown'):
    """Expand one synaptic hop from ``src_ids`` and label each target's super-class.

    Parameters
    ----------
    src_ids : array-like
        Presynaptic root IDs to expand from.
    connections : pandas.DataFrame
        Must contain 'pre_root_id', 'post_root_id' and 'syn_count'. Rows that
        share a (pre, post) pair are summed into a single edge.
    classification : pandas.DataFrame
        Must contain 'root_id' and 'super_class'.
    min_syn : int, default 5
        Edges whose summed synapse count is below this are dropped.
    unknown_label : str, default 'unknown'
        Super-class given to targets with no (or a missing) classification.
        This keeps 'output_super_class' free of NaN, which cannot be sorted
        together with the string labels when Sankey nodes are built.

    Returns
    -------
    pandas.DataFrame
        Columns 'pre_root_id', 'post_root_id', 'output_super_class', 'syn_count'.
    """
    outgoing = connections[connections['pre_root_id'].isin(src_ids)]
    summed = (
        outgoing.groupby(['pre_root_id', 'post_root_id'], as_index=False)['syn_count']
                .sum()
    )
    summed = summed[summed['syn_count'] >= min_syn]
    target_classes = classification.rename(columns={
        'root_id': 'post_root_id',
        'super_class': 'output_super_class',
    })
    merged = summed.merge(target_classes, on='post_root_id', how='left')
    merged['output_super_class'] = merged['output_super_class'].fillna(unknown_label)
    return merged[['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count']]

# ----------------------------------------------------------------------------
# Helpers for the Sankey builder below
# ----------------------------------------------------------------------------
def _summarize_flow(upstream, downstream, src_tag, tgt_tag):
    """Sum ``downstream`` synapses per (presynaptic class, postsynaptic class).

    The presynaptic neuron's class is taken from ``upstream``'s postsynaptic
    neurons. ``upstream`` is first collapsed to one row per neuron, so each
    downstream edge is counted exactly once. Merging edge table to edge table
    would repeat a downstream edge once per upstream edge converging on its
    presynaptic neuron, inflating the counts.

    Returns a DataFrame with 'source' and 'target' (prefixed with ``src_tag`` /
    ``tgt_tag``) and 'count' columns.
    """
    pre_class = (
        upstream[['post_root_id', 'output_super_class']]
        .drop_duplicates()
        .rename(columns={'post_root_id': 'pre_root_id',
                         'output_super_class': 'source'})
    )
    flow = (
        downstream.merge(pre_class, on='pre_root_id')
                  .groupby(['source', 'output_super_class'], as_index=False)['syn_count']
                  .sum()
                  .rename(columns={'output_super_class': 'target',
                                   'syn_count': 'count'})
    )
    flow['source'] = flow['source'].map(lambda c: src_tag + c)
    flow['target'] = flow['target'].map(lambda c: tgt_tag + c)
    return flow


def _build_sankey_figure(title, df1, df2, df3):
    """Build a 4-column Sankey: ``title`` -> hop-1 -> hop-2 -> hop-3 super-classes.

    ``df1``/``df2``/``df3`` are hop tables from ``build_hop_df``. Link widths
    are summed synapse counts. Returns the plotly Figure without showing it.
    """
    # --- summarize flows ---
    flow1 = (
        df1.groupby('output_super_class', as_index=False)['syn_count']
           .sum().rename(columns={'syn_count': 'count'})
           .assign(source=title,
                   target=lambda d: d['output_super_class'].map(lambda c: '1: ' + c))
    )
    flow2 = _summarize_flow(df1, df2, '1: ', '2: ')
    flow3 = _summarize_flow(df2, df3, '2: ', '3: ')

    # --- node columns & indices ---
    # NaN super-classes are skipped: the groupbys above already drop them, and
    # sorting NaN together with strings raises TypeError.
    columns = [[title]]
    for hop, hop_df in enumerate((df1, df2, df3), start=1):
        hop_classes = sorted(hop_df['output_super_class'].dropna().unique())
        columns.append([f"{hop}: {c}" for c in hop_classes])
    nodes = [n for col in columns for n in col]
    idx = {n: i for i, n in enumerate(nodes)}

    # --- colors: one per super-class, shared across hops ---
    classes = sorted({n.split(': ', 1)[1] for col in columns[1:] for n in col})
    palette = px.colors.qualitative.Safe
    cmap = {cls: palette[i % len(palette)] for i, cls in enumerate(classes)}
    node_colors = ['lightgrey' if n == title else cmap[n.split(': ', 1)[1]]
                   for n in nodes]

    # --- links, tinted at 50% alpha with their source node's color ---
    source, target, value, link_colors = [], [], [], []
    for flow in (flow1, flow2, flow3):
        for s_name, t_name, v in zip(flow['source'], flow['target'], flow['count']):
            s = idx[s_name]
            source.append(s)
            target.append(idx[t_name])
            value.append(v)
            link_colors.append(
                node_colors[s].replace('rgb(', 'rgba(').replace(')', ',0.5)'))

    # --- hover info: total link value into / out of each node ---
    incoming = dict.fromkeys(nodes, 0)
    outgoing = dict.fromkeys(nodes, 0)
    for s, t, v in zip(source, target, value):
        outgoing[nodes[s]] += v
        incoming[nodes[t]] += v
    customdata = [f"Incoming: {incoming[n]}<br>Outgoing: {outgoing[n]}" for n in nodes]

    # --- layout: one fixed x per column, nodes spread evenly over y ---
    x, y = [], []
    for x_pos, col in zip((0.0, 0.33, 0.66, 1.0), columns):
        x.extend([x_pos] * len(col))
        # A lone node is centred. An empty column must add no y entries,
        # otherwise y would no longer line up with x and the node list.
        y.extend([0.5] if len(col) == 1 else np.linspace(0, 1, len(col)).tolist())

    # --- plot Sankey ---
    fig = go.Figure(go.Sankey(
        arrangement='snap',
        node=dict(
            label=nodes,
            x=x, y=y,
            color=node_colors,
            pad=15, thickness=20,
            line=dict(color='black', width=0.5),
            customdata=customdata,
            hovertemplate='%{customdata}<extra>%{label}</extra>'
        ),
        link=dict(source=source, target=target, value=value, color=link_colors)
    ))
    fig.update_layout(title_text=title, font_size=14)
    return fig


# ----------------------------------------------------------------------------
# Core: 3-hop Sankey, filtering at each hop for motor connectivity
# ----------------------------------------------------------------------------
def plot_sankey_dynamic(grn_df, title, connections, classification, min_syn=5):
    """Save (as SVG) and show a 3-hop Sankey of ``grn_df``'s outputs, pruned to motor paths.

    - Hop-1 edges are kept when the target is motor, synapses onto a motor
      neuron, or synapses onto a neuron that synapses onto a motor neuron.
    - Hop-2 edges (re-expanded from the kept hop-1 targets) are kept when the
      target is motor or synapses onto a motor neuron.
    - Hop-3 edges (re-expanded from the kept hop-2 targets) are kept only when
      the target is motor.

    The figure is written to ``<title with spaces replaced by '_'>.svg`` in the
    current working directory before being shown.

    Parameters
    ----------
    grn_df : pandas.DataFrame
        Source neurons; must have a 'root_id' column.
    title : str
        Figure title, label of the leftmost node, and SVG file stem.
    connections, classification : pandas.DataFrame
        Passed through to ``build_hop_df``.
    min_syn : int, default 5
        Minimum summed synapses per neuron pair for an edge to be kept.
    """
    df1_full = build_hop_df(grn_df['root_id'], connections, classification, min_syn)
    df2_full = build_hop_df(df1_full['post_root_id'].unique(), connections, classification, min_syn)
    df3_full = build_hop_df(df2_full['post_root_id'].unique(), connections, classification, min_syn)

    # ---- Filter df1: direct motor, motor at hop 2, or motor at hop 3 ----
    mask_direct1 = df1_full['output_super_class'] == 'motor'
    motor_in_hop2 = df2_full.loc[df2_full['output_super_class'] == 'motor', 'pre_root_id'].unique()
    mask_hop2_1 = df1_full['post_root_id'].isin(motor_in_hop2)
    motor_in_hop3 = df3_full.loc[df3_full['output_super_class'] == 'motor', 'pre_root_id'].unique()
    hop2_roots_to_motor3 = df2_full.loc[df2_full['post_root_id'].isin(motor_in_hop3), 'pre_root_id'].unique()
    mask_hop3_1 = df1_full['post_root_id'].isin(hop2_roots_to_motor3)
    df1 = df1_full[mask_direct1 | mask_hop2_1 | mask_hop3_1].reset_index(drop=True)

    # ---- Re-expand hops 2 and 3 from the pruned hop-1 targets ----
    df2_cand = build_hop_df(df1['post_root_id'].unique(), connections, classification, min_syn)
    df3_cand = build_hop_df(df2_cand['post_root_id'].unique(), connections, classification, min_syn)

    # ---- Filter df2: target is motor, or target synapses onto a motor neuron ----
    mask_direct2 = df2_cand['output_super_class'] == 'motor'
    motor_in_df3 = df3_cand.loc[df3_cand['output_super_class'] == 'motor', 'pre_root_id'].unique()
    mask_downstream2 = df2_cand['post_root_id'].isin(motor_in_df3)
    df2 = df2_cand[mask_direct2 | mask_downstream2].reset_index(drop=True)

    # ---- Rebuild df3 from the pruned df2, keeping only motor targets ----
    df3 = build_hop_df(df2['post_root_id'].unique(), connections, classification, min_syn)
    df3 = df3[df3['output_super_class'] == 'motor'].reset_index(drop=True)

    fig = _build_sankey_figure(title, df1, df2, df3)

    # ---- Save as SVG (needs plotly's static image export backend, e.g. kaleido) ----
    svg_filename = f"{title.replace(' ', '_')}.svg"
    fig.write_image(svg_filename)
    print(f"Saved Sankey for {title} as {svg_filename}")
    fig.show()

# ----------------------------------------------------------------------------
# Draw (and save) one filtered Sankey per PSO set; the dict key doubles as the title
# ----------------------------------------------------------------------------
for label in sets:
    df = sets[label]
    plot_sankey_dynamic(df, label, connections, classification)


In [None]:
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px

# ----------------------------------------------------------------------------
# 0. Load FlyWire connectome and classification
# ----------------------------------------------------------------------------
import os
from pathlib import Path

# Input folders. The defaults are the original author's machine-specific
# locations; set BRAIN_MODEL_DIR / PSO_SETS_DIR in the environment to run the
# notebook on another machine without editing code.
BRAIN_MODEL_DIR = Path(os.environ.get(
    'BRAIN_MODEL_DIR',
    '/Users/yaolab/Library/CloudStorage/OneDrive-UniversityofFlorida/'
    'YaoLabUF/YaoLab/Drosophila_brain_model'
))
PSO_SETS_DIR = Path(os.environ.get(
    'PSO_SETS_DIR',
    '/Users/yaolab/Downloads/taste-connectome-main/aPhN-SA_v3'
))

connections = pd.read_csv(BRAIN_MODEL_DIR / 'connections.csv.gz')
classification = pd.read_csv(BRAIN_MODEL_DIR / 'classification.csv.gz')[['root_id', 'super_class']]

# ----------------------------------------------------------------------------
# 1. Load your three PSO lists (each CSV has a 'root_id' column)
# ----------------------------------------------------------------------------
set_1 = pd.read_csv(PSO_SETS_DIR / 'set_1.csv')
set_2 = pd.read_csv(PSO_SETS_DIR / 'set_2.csv')
set_3 = pd.read_csv(PSO_SETS_DIR / 'set_3.csv')

sets = {
    'DCSO':  set_1,
    'aPhN1': set_2,
    'aPhN2': set_3,
}

# Every downstream step reads grn_df['root_id']; fail here with a clear message
# rather than with a KeyError deep inside plot_sankey_dynamic.
_missing_root_id = [name for name, frame in sets.items() if 'root_id' not in frame.columns]
assert not _missing_root_id, f"PSO list(s) without a 'root_id' column: {_missing_root_id}"

# ----------------------------------------------------------------------------
# Helper: filter connections by source IDs, threshold syn_count, attach superclass
# ----------------------------------------------------------------------------
def build_hop_df(src_ids, connections, classification, min_syn=5):
    """Return one hop of the connectome starting from ``src_ids``.

    Synapse counts are summed per (pre_root_id, post_root_id) pair over all
    matching rows of ``connections``. Pairs whose total is below ``min_syn`` are
    dropped. Each target is labelled with its ``super_class`` from
    ``classification``; the label is NaN when the target has no classification row.

    Returns a DataFrame with columns
    ['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count'].
    """
    outgoing = connections.loc[connections['pre_root_id'].isin(src_ids)]
    pair_totals = outgoing.groupby(['pre_root_id', 'post_root_id'], as_index=False)['syn_count'].sum()
    strong_pairs = pair_totals[pair_totals['syn_count'] >= min_syn]
    target_labels = classification.rename(
        columns={'root_id': 'post_root_id', 'super_class': 'output_super_class'}
    )
    labelled = strong_pairs.merge(target_labels, on='post_root_id', how='left')
    return labelled[['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count']]

# ----------------------------------------------------------------------------
# Core: build & show a 3-hop Sankey, but filter df2 by endocrine criteria
# ----------------------------------------------------------------------------
def plot_sankey_dynamic(grn_df, title, connections, classification, min_syn=5,
                        target_class='endocrine'):
    """Show a 3-hop Sankey whose 2nd hop is restricted to ``target_class``-bound edges.

    Hop 1 (seed neurons -> 2nd-order neurons) is kept whole. A hop-2 edge
    (2nd-order -> 3rd-order) is kept when its target is a ``target_class``
    neuron, or when that target synapses onto one in hop 3. Hop 3 is then
    rebuilt, unfiltered, from the surviving hop-2 targets.

    Parameters
    ----------
    grn_df : DataFrame with a 'root_id' column listing the seed neurons.
    title : str, label of the root node and figure title.
    connections, classification : FlyWire connection / classification tables.
    min_syn : int, minimum summed synapses for an edge to be kept (default 5).
    target_class : str, super_class used for filtering (default 'endocrine').
    """
    df1 = build_hop_df(grn_df['root_id'], connections, classification, min_syn)
    df2_full = build_hop_df(df1['post_root_id'].unique(), connections, classification, min_syn)
    df3_full = build_hop_df(df2_full['post_root_id'].unique(), connections, classification, min_syn)

    # 3rd-order neurons (hop-3 pre ids == hop-2 post ids) that synapse onto the target class.
    feeds_target = df3_full.loc[df3_full['output_super_class'] == target_class, 'pre_root_id'].unique()
    mask_direct = df2_full['output_super_class'] == target_class
    mask_feed = df2_full['post_root_id'].isin(feeds_target)
    df2 = df2_full[mask_direct | mask_feed].reset_index(drop=True)

    # Recompute hop 3 from the filtered hop 2.
    df3 = build_hop_df(df2['post_root_id'].unique(), connections, classification, min_syn)

    _render_sankey(_hop_flows(df1, df2, df3, title), title)


def _hop_flows(df1, df2, df3, title):
    """Aggregate hop edge tables into Sankey flows with hop-prefixed node names.

    Returns (flow1, flow2, flow3), each a DataFrame with 'source', 'target' and
    'count' (summed syn_count) columns. Node names carry a '1: ', '2: ' or '3: '
    prefix so the same super_class can appear in several columns. Edges whose
    super_class is NaN are dropped by groupby's default ``dropna=True``.
    """
    flow1 = (
        df1.groupby('output_super_class', as_index=False)['syn_count']
           .sum().rename(columns={'syn_count': 'count'})
           .assign(source=title,
                   target=lambda d: '1: ' + d['output_super_class'])
    )

    def _chain(up, down, a, b):
        # Join hop-a edges to hop-b edges through the shared neuron
        # (post of hop a == pre of hop b); sum hop-b synapses per class pair.
        joined = pd.merge(up, down,
                          left_on='post_root_id', right_on='pre_root_id',
                          suffixes=(f'_{a}', f'_{b}'))
        flow = (
            joined.groupby([f'output_super_class_{a}', f'output_super_class_{b}'], as_index=False)
                  [f'syn_count_{b}'].sum()
                  .rename(columns={
                      f'output_super_class_{a}': 'source',
                      f'output_super_class_{b}': 'target',
                      f'syn_count_{b}': 'count',
                  })
        )
        flow['source'] = flow['source'].apply(lambda c: f'{a}: ' + c)
        flow['target'] = flow['target'].apply(lambda c: f'{b}: ' + c)
        return flow

    return flow1, _chain(df1, df2, 1, 2), _chain(df2, df3, 2, 3)


def _render_sankey(flows, title, arrangement='snap'):
    """Draw and show a four-column Sankey: root, hop-1, hop-2 and hop-3 classes.

    flows : (flow1, flow2, flow3) as returned by _hop_flows.
    arrangement : plotly Sankey node arrangement ('snap', 'freeform', ...).
    """
    flow1, flow2, flow3 = flows

    # A column holds every node named on either side of its adjacent flows,
    # so every link endpoint is guaranteed to have an index.
    columns = [
        [title],
        sorted(set(flow1['target']) | set(flow2['source'])),
        sorted(set(flow2['target']) | set(flow3['source'])),
        sorted(set(flow3['target'])),
    ]
    nodes = [name for column in columns for name in column]
    idx = {name: i for i, name in enumerate(nodes)}

    # One Safe-palette colour per super_class, shared across the hop columns.
    classes = sorted({name.split(': ', 1)[1] for name in nodes[1:]})
    palette = px.colors.qualitative.Safe
    cmap = {cls: palette[i % len(palette)] for i, cls in enumerate(classes)}
    node_colors = ['lightgrey' if name == title else cmap[name.split(': ', 1)[1]]
                   for name in nodes]

    source, target, value, link_colors = [], [], [], []
    incoming = dict.fromkeys(nodes, 0)
    outgoing = dict.fromkeys(nodes, 0)
    for flow in flows:
        for src, dst, count in zip(flow['source'], flow['target'], flow['count']):
            source.append(idx[src])
            target.append(idx[dst])
            value.append(count)
            # Links take their source node's colour at 50% opacity
            # ('rgb(...)' -> 'rgba(...,0.5)'); 'lightgrey' passes through unchanged.
            link_colors.append(node_colors[idx[src]].replace('rgb(', 'rgba(').replace(')', ',0.5)'))
            outgoing[src] += count
            incoming[dst] += count
    customdata = [f"Incoming: {incoming[n]}<br>Outgoing: {outgoing[n]}" for n in nodes]

    # plotly.js applies manual node positions only when BOTH x and y are given,
    # and treats a coordinate of exactly 0 as unset, so keep every value inside
    # (0, 1). Each column's nodes are spread over equal vertical bands.
    col_x = (0.001, 0.33, 0.66, 0.999)
    x = [cx for cx, column in zip(col_x, columns) for _ in column]
    y = [(i + 0.5) / len(column) for column in columns for i in range(len(column))]

    fig = go.Figure(go.Sankey(
        arrangement=arrangement,
        node=dict(
            label=nodes,
            x=x, y=y,
            color=node_colors,
            pad=15, thickness=20,
            line=dict(color='black', width=0.5),
            customdata=customdata,
            hovertemplate='%{customdata}<extra>%{label}</extra>'
        ),
        link=dict(source=source, target=target, value=value, color=link_colors)
    ))
    fig.update_layout(title_text=title, font_size=14)
    fig.show()

# ----------------------------------------------------------------------------
# Generate the filtered Sankey for each PSO set
# ----------------------------------------------------------------------------
for label in sets:
    # One filtered Sankey per PSO list; the dict key doubles as the plot title.
    df = sets[label]
    plot_sankey_dynamic(df, label, connections, classification)


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from collections import defaultdict

# ----------------------------------------------------------------------------
# 0. Load FlyWire connectome and classification
# ----------------------------------------------------------------------------
import os
from pathlib import Path

# Input folders. The defaults are the original author's machine-specific
# locations; set BRAIN_MODEL_DIR / PSO_SETS_DIR in the environment to run the
# notebook on another machine without editing code.
BRAIN_MODEL_DIR = Path(os.environ.get(
    'BRAIN_MODEL_DIR',
    '/Users/yaolab/Library/CloudStorage/OneDrive-UniversityofFlorida/'
    'YaoLabUF/YaoLab/Drosophila_brain_model'
))
PSO_SETS_DIR = Path(os.environ.get(
    'PSO_SETS_DIR',
    '/Users/yaolab/Downloads/taste-connectome-main/aPhN-SA_v3'
))

connections = pd.read_csv(BRAIN_MODEL_DIR / 'connections.csv.gz')
classification = pd.read_csv(BRAIN_MODEL_DIR / 'classification.csv.gz')[['root_id', 'super_class']]

# ----------------------------------------------------------------------------
# 1. Load your three PSO lists (each CSV has a 'root_id' column)
# ----------------------------------------------------------------------------
set_1 = pd.read_csv(PSO_SETS_DIR / 'set_1.csv')
set_2 = pd.read_csv(PSO_SETS_DIR / 'set_2.csv')
set_3 = pd.read_csv(PSO_SETS_DIR / 'set_3.csv')

sets = {
    'DCSO':  set_1,
    'aPhN1': set_2,
    'aPhN2': set_3,
}

# Every downstream step reads grn_df['root_id']; fail here with a clear message
# rather than with a KeyError deep inside plot_sankey_dynamic.
_missing_root_id = [name for name, frame in sets.items() if 'root_id' not in frame.columns]
assert not _missing_root_id, f"PSO list(s) without a 'root_id' column: {_missing_root_id}"

# ----------------------------------------------------------------------------
# Helper: filter connections by source IDs, threshold syn_count, attach superclass
# ----------------------------------------------------------------------------
def build_hop_df(src_ids, connections, classification, min_syn=5):
    """Return one hop of the connectome starting from ``src_ids``.

    Synapse counts are summed per (pre_root_id, post_root_id) pair over all
    matching rows of ``connections``. Pairs whose total is below ``min_syn`` are
    dropped. Each target is labelled with its ``super_class`` from
    ``classification``; the label is NaN when the target has no classification row.

    Returns a DataFrame with columns
    ['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count'].
    """
    outgoing = connections.loc[connections['pre_root_id'].isin(src_ids)]
    pair_totals = outgoing.groupby(['pre_root_id', 'post_root_id'], as_index=False)['syn_count'].sum()
    strong_pairs = pair_totals[pair_totals['syn_count'] >= min_syn]
    target_labels = classification.rename(
        columns={'root_id': 'post_root_id', 'super_class': 'output_super_class'}
    )
    labelled = strong_pairs.merge(target_labels, on='post_root_id', how='left')
    return labelled[['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count']]

# ----------------------------------------------------------------------------
# Core: build & show a 3-hop Sankey with vertical spacing
# ----------------------------------------------------------------------------
def plot_sankey_dynamic(grn_df, title, connections, classification, min_syn=5):
    """Show an unfiltered 3-hop Sankey of super_class flows downstream of ``grn_df``.

    Parameters
    ----------
    grn_df : DataFrame with a 'root_id' column listing the seed neurons.
    title : str, label of the root node and figure title.
    connections, classification : FlyWire connection / classification tables.
    min_syn : int, minimum summed synapses for an edge to be kept (default 5).
    """
    df1 = build_hop_df(grn_df['root_id'], connections, classification, min_syn)
    df2 = build_hop_df(df1['post_root_id'].unique(), connections, classification, min_syn)
    df3 = build_hop_df(df2['post_root_id'].unique(), connections, classification, min_syn)

    # 'freeform' lets nodes be dragged anywhere after the initial layout.
    _render_sankey(_hop_flows(df1, df2, df3, title), title, arrangement='freeform')


def _hop_flows(df1, df2, df3, title):
    """Aggregate hop edge tables into Sankey flows with hop-prefixed node names.

    Returns (flow1, flow2, flow3), each a DataFrame with 'source', 'target' and
    'count' (summed syn_count) columns. Node names carry a '1: ', '2: ' or '3: '
    prefix so the same super_class can appear in several columns. Edges whose
    super_class is NaN are dropped by groupby's default ``dropna=True``, so they
    can no longer crash the sorting of node names.
    """
    flow1 = (
        df1.groupby('output_super_class', as_index=False)['syn_count']
           .sum().rename(columns={'syn_count': 'count'})
           .assign(source=title,
                   target=lambda d: '1: ' + d['output_super_class'])
    )

    def _chain(up, down, a, b):
        # Join hop-a edges to hop-b edges through the shared neuron
        # (post of hop a == pre of hop b); sum hop-b synapses per class pair.
        joined = pd.merge(up, down,
                          left_on='post_root_id', right_on='pre_root_id',
                          suffixes=(f'_{a}', f'_{b}'))
        flow = (
            joined.groupby([f'output_super_class_{a}', f'output_super_class_{b}'], as_index=False)
                  [f'syn_count_{b}'].sum()
                  .rename(columns={
                      f'output_super_class_{a}': 'source',
                      f'output_super_class_{b}': 'target',
                      f'syn_count_{b}': 'count',
                  })
        )
        flow['source'] = flow['source'].apply(lambda c: f'{a}: ' + c)
        flow['target'] = flow['target'].apply(lambda c: f'{b}: ' + c)
        return flow

    return flow1, _chain(df1, df2, 1, 2), _chain(df2, df3, 2, 3)


def _render_sankey(flows, title, arrangement='snap'):
    """Draw and show a four-column Sankey: root, hop-1, hop-2 and hop-3 classes.

    flows : (flow1, flow2, flow3) as returned by _hop_flows.
    arrangement : plotly Sankey node arrangement ('snap', 'freeform', ...).
    """
    flow1, flow2, flow3 = flows

    # A column holds every node named on either side of its adjacent flows,
    # so every link endpoint is guaranteed to have an index.
    columns = [
        [title],
        sorted(set(flow1['target']) | set(flow2['source'])),
        sorted(set(flow2['target']) | set(flow3['source'])),
        sorted(set(flow3['target'])),
    ]
    nodes = [name for column in columns for name in column]
    idx = {name: i for i, name in enumerate(nodes)}

    # One Safe-palette colour per super_class, shared across the hop columns.
    classes = sorted({name.split(': ', 1)[1] for name in nodes[1:]})
    palette = px.colors.qualitative.Safe
    cmap = {cls: palette[i % len(palette)] for i, cls in enumerate(classes)}
    node_colors = ['lightgrey' if name == title else cmap[name.split(': ', 1)[1]]
                   for name in nodes]

    source, target, value, link_colors = [], [], [], []
    incoming = dict.fromkeys(nodes, 0)
    outgoing = dict.fromkeys(nodes, 0)
    for flow in flows:
        for src, dst, count in zip(flow['source'], flow['target'], flow['count']):
            source.append(idx[src])
            target.append(idx[dst])
            value.append(count)
            # Links take their source node's colour at 50% opacity
            # ('rgb(...)' -> 'rgba(...,0.5)'); 'lightgrey' passes through unchanged.
            link_colors.append(node_colors[idx[src]].replace('rgb(', 'rgba(').replace(')', ',0.5)'))
            outgoing[src] += count
            incoming[dst] += count
    customdata = [f"Incoming: {incoming[n]}<br>Outgoing: {outgoing[n]}" for n in nodes]

    # plotly.js applies manual node positions only when BOTH x and y are given,
    # and treats a coordinate of exactly 0 as unset, so keep every value inside
    # (0, 1). Each column's nodes are spread over equal vertical bands.
    col_x = (0.001, 0.33, 0.66, 0.999)
    x = [cx for cx, column in zip(col_x, columns) for _ in column]
    y = [(i + 0.5) / len(column) for column in columns for i in range(len(column))]

    fig = go.Figure(go.Sankey(
        arrangement=arrangement,
        node=dict(
            label=nodes,
            x=x, y=y,
            color=node_colors,
            pad=15, thickness=20,
            line=dict(color='black', width=0.5),
            customdata=customdata,
            hovertemplate='%{customdata}<extra>%{label}</extra>'
        ),
        link=dict(source=source, target=target, value=value, color=link_colors)
    ))
    fig.update_layout(title_text=title, font_size=14)
    fig.show()

# ----------------------------------------------------------------------------
# Generate Sankey for each PSO set
# ----------------------------------------------------------------------------
for label in sets:
    # One unfiltered Sankey per PSO list; the dict key doubles as the plot title.
    df = sets[label]
    plot_sankey_dynamic(df, label, connections, classification)


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from collections import defaultdict

# ----------------------------------------------------------------------------
# 0. Load FlyWire connectome and classification
# ----------------------------------------------------------------------------
import os
from pathlib import Path

# Input folders. The defaults are the original author's machine-specific
# locations; set BRAIN_MODEL_DIR / PSO_SETS_DIR in the environment to run the
# notebook on another machine without editing code.
BRAIN_MODEL_DIR = Path(os.environ.get(
    'BRAIN_MODEL_DIR',
    '/Users/yaolab/Library/CloudStorage/OneDrive-UniversityofFlorida/'
    'YaoLabUF/YaoLab/Drosophila_brain_model'
))
PSO_SETS_DIR = Path(os.environ.get(
    'PSO_SETS_DIR',
    '/Users/yaolab/Downloads/taste-connectome-main/aPhN-SA_v3'
))

connections = pd.read_csv(BRAIN_MODEL_DIR / 'connections.csv.gz')
classification = pd.read_csv(BRAIN_MODEL_DIR / 'classification.csv.gz')[['root_id', 'super_class']]

# ----------------------------------------------------------------------------
# 1. Load your three PSO lists (each CSV has a 'root_id' column)
# ----------------------------------------------------------------------------
set_1 = pd.read_csv(PSO_SETS_DIR / 'set_1.csv')
set_2 = pd.read_csv(PSO_SETS_DIR / 'set_2.csv')
set_3 = pd.read_csv(PSO_SETS_DIR / 'set_3.csv')

sets = {
    'DCSO':  set_1,
    'aPhN1': set_2,
    'aPhN2': set_3,
}

# Every downstream step reads grn_df['root_id']; fail here with a clear message
# rather than with a KeyError deep inside plot_sankey_dynamic.
_missing_root_id = [name for name, frame in sets.items() if 'root_id' not in frame.columns]
assert not _missing_root_id, f"PSO list(s) without a 'root_id' column: {_missing_root_id}"

# ----------------------------------------------------------------------------
# Helper: filter connections by source IDs, threshold syn_count, attach superclass
# ----------------------------------------------------------------------------
def build_hop_df(src_ids, connections, classification, min_syn=5):
    """Return one hop of the connectome starting from ``src_ids``.

    Synapse counts are summed per (pre_root_id, post_root_id) pair over all
    matching rows of ``connections``. Pairs whose total is below ``min_syn`` are
    dropped. Each target is labelled with its ``super_class`` from
    ``classification``; the label is NaN when the target has no classification row.

    Returns a DataFrame with columns
    ['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count'].
    """
    outgoing = connections.loc[connections['pre_root_id'].isin(src_ids)]
    pair_totals = outgoing.groupby(['pre_root_id', 'post_root_id'], as_index=False)['syn_count'].sum()
    strong_pairs = pair_totals[pair_totals['syn_count'] >= min_syn]
    target_labels = classification.rename(
        columns={'root_id': 'post_root_id', 'super_class': 'output_super_class'}
    )
    labelled = strong_pairs.merge(target_labels, on='post_root_id', how='left')
    return labelled[['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count']]

# ----------------------------------------------------------------------------
# Core: build & show a 3-hop Sankey with df1 pruned to only those that
#       either directly go to endocrine (hop1→endocrine) or reach endocrine in hop2
# ----------------------------------------------------------------------------
def plot_sankey_dynamic(grn_df, title, connections, classification, min_syn=5,
                        target_class='endocrine'):
    """Show a 3-hop Sankey restricted to hop-1 edges that lead to ``target_class``.

    A hop-1 edge (seed neuron -> 2nd-order neuron) is kept when either
      a) the 2nd-order neuron is itself a ``target_class`` neuron (direct), or
      b) the 2nd-order neuron synapses onto a ``target_class`` neuron in hop 2
         (indirect).
    Hops 2 and 3 are then rebuilt, unfiltered, from the surviving 2nd-order neurons.

    Parameters
    ----------
    grn_df : DataFrame with a 'root_id' column listing the seed neurons.
    title : str, label of the root node and figure title.
    connections, classification : FlyWire connection / classification tables.
    min_syn : int, minimum summed synapses for an edge to be kept (default 5).
    target_class : str, super_class used for pruning (default 'endocrine').
    """
    df1_full = build_hop_df(grn_df['root_id'], connections, classification, min_syn)
    df2_full = build_hop_df(df1_full['post_root_id'].unique(), connections, classification, min_syn)

    # 2nd-order neurons that synapse directly onto a target-class neuron.
    # Hop-2 pre ids are the same neurons as hop-1 post ids, so the isin() below
    # compares like with like. Hop-3 pre ids would be 3rd-order neurons, which
    # is the wrong hop level.
    feeds_target = df2_full.loc[df2_full['output_super_class'] == target_class, 'pre_root_id'].unique()

    mask_direct = df1_full['output_super_class'] == target_class
    mask_indirect = df1_full['post_root_id'].isin(feeds_target)
    df1 = df1_full[mask_direct | mask_indirect].reset_index(drop=True)

    # Rebuild hop-2 & hop-3 from the pruned hop 1.
    df2 = build_hop_df(df1['post_root_id'].unique(), connections, classification, min_syn)
    df3 = build_hop_df(df2['post_root_id'].unique(), connections, classification, min_syn)

    _render_sankey(_hop_flows(df1, df2, df3, title), title)


def _hop_flows(df1, df2, df3, title):
    """Aggregate hop edge tables into Sankey flows with hop-prefixed node names.

    Returns (flow1, flow2, flow3), each a DataFrame with 'source', 'target' and
    'count' (summed syn_count) columns. Node names carry a '1: ', '2: ' or '3: '
    prefix so the same super_class can appear in several columns. Edges whose
    super_class is NaN are dropped by groupby's default ``dropna=True``.
    """
    flow1 = (
        df1.groupby('output_super_class', as_index=False)['syn_count']
           .sum().rename(columns={'syn_count': 'count'})
           .assign(source=title,
                   target=lambda d: '1: ' + d['output_super_class'])
    )

    def _chain(up, down, a, b):
        # Join hop-a edges to hop-b edges through the shared neuron
        # (post of hop a == pre of hop b); sum hop-b synapses per class pair.
        joined = pd.merge(up, down,
                          left_on='post_root_id', right_on='pre_root_id',
                          suffixes=(f'_{a}', f'_{b}'))
        flow = (
            joined.groupby([f'output_super_class_{a}', f'output_super_class_{b}'], as_index=False)
                  [f'syn_count_{b}'].sum()
                  .rename(columns={
                      f'output_super_class_{a}': 'source',
                      f'output_super_class_{b}': 'target',
                      f'syn_count_{b}': 'count',
                  })
        )
        flow['source'] = flow['source'].apply(lambda c: f'{a}: ' + c)
        flow['target'] = flow['target'].apply(lambda c: f'{b}: ' + c)
        return flow

    return flow1, _chain(df1, df2, 1, 2), _chain(df2, df3, 2, 3)


def _render_sankey(flows, title, arrangement='snap'):
    """Draw and show a four-column Sankey: root, hop-1, hop-2 and hop-3 classes.

    flows : (flow1, flow2, flow3) as returned by _hop_flows.
    arrangement : plotly Sankey node arrangement ('snap', 'freeform', ...).
    """
    flow1, flow2, flow3 = flows

    # A column holds every node named on either side of its adjacent flows,
    # so every link endpoint is guaranteed to have an index.
    columns = [
        [title],
        sorted(set(flow1['target']) | set(flow2['source'])),
        sorted(set(flow2['target']) | set(flow3['source'])),
        sorted(set(flow3['target'])),
    ]
    nodes = [name for column in columns for name in column]
    idx = {name: i for i, name in enumerate(nodes)}

    # One Safe-palette colour per super_class, shared across the hop columns.
    classes = sorted({name.split(': ', 1)[1] for name in nodes[1:]})
    palette = px.colors.qualitative.Safe
    cmap = {cls: palette[i % len(palette)] for i, cls in enumerate(classes)}
    node_colors = ['lightgrey' if name == title else cmap[name.split(': ', 1)[1]]
                   for name in nodes]

    source, target, value, link_colors = [], [], [], []
    incoming = dict.fromkeys(nodes, 0)
    outgoing = dict.fromkeys(nodes, 0)
    for flow in flows:
        for src, dst, count in zip(flow['source'], flow['target'], flow['count']):
            source.append(idx[src])
            target.append(idx[dst])
            value.append(count)
            # Links take their source node's colour at 50% opacity
            # ('rgb(...)' -> 'rgba(...,0.5)'); 'lightgrey' passes through unchanged.
            link_colors.append(node_colors[idx[src]].replace('rgb(', 'rgba(').replace(')', ',0.5)'))
            outgoing[src] += count
            incoming[dst] += count
    customdata = [f"Incoming: {incoming[n]}<br>Outgoing: {outgoing[n]}" for n in nodes]

    # plotly.js applies manual node positions only when BOTH x and y are given,
    # and treats a coordinate of exactly 0 as unset, so keep every value inside
    # (0, 1). Each column's nodes are spread over equal vertical bands.
    col_x = (0.001, 0.33, 0.66, 0.999)
    x = [cx for cx, column in zip(col_x, columns) for _ in column]
    y = [(i + 0.5) / len(column) for column in columns for i in range(len(column))]

    fig = go.Figure(go.Sankey(
        arrangement=arrangement,
        node=dict(
            label=nodes,
            x=x, y=y,
            color=node_colors,
            pad=15, thickness=20,
            line=dict(color='black', width=0.5),
            customdata=customdata,
            hovertemplate='%{customdata}<extra>%{label}</extra>'
        ),
        link=dict(source=source, target=target, value=value, color=link_colors)
    ))
    fig.update_layout(title_text=title, font_size=14)
    fig.show()

# ----------------------------------------------------------------------------
# Generate filtered Sankey for each PSO set
# ----------------------------------------------------------------------------
for label in sets:
    # One pruned Sankey per PSO list; the dict key doubles as the plot title.
    df = sets[label]
    plot_sankey_dynamic(df, label, connections, classification)


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

# ----------------------------------------------------------------------------
# 0. Load FlyWire connectome and classification
# ----------------------------------------------------------------------------
import os
from pathlib import Path

# Input folders. The defaults are the original author's machine-specific
# locations; set BRAIN_MODEL_DIR / PSO_SETS_DIR in the environment to run the
# notebook on another machine without editing code.
BRAIN_MODEL_DIR = Path(os.environ.get(
    'BRAIN_MODEL_DIR',
    '/Users/yaolab/Library/CloudStorage/OneDrive-UniversityofFlorida/'
    'YaoLabUF/YaoLab/Drosophila_brain_model'
))
PSO_SETS_DIR = Path(os.environ.get(
    'PSO_SETS_DIR',
    '/Users/yaolab/Downloads/taste-connectome-main/aPhN-SA_v3'
))

connections = pd.read_csv(BRAIN_MODEL_DIR / 'connections.csv.gz')
classification = pd.read_csv(BRAIN_MODEL_DIR / 'classification.csv.gz')[['root_id', 'super_class']]

# ----------------------------------------------------------------------------
# 1. Load your three PSO lists (each CSV has a 'root_id' column)
# ----------------------------------------------------------------------------
set_1 = pd.read_csv(PSO_SETS_DIR / 'set_1.csv')
set_2 = pd.read_csv(PSO_SETS_DIR / 'set_2.csv')
set_3 = pd.read_csv(PSO_SETS_DIR / 'set_3.csv')

sets = {
    'DCSO':  set_1,
    'aPhN1': set_2,
    'aPhN2': set_3,
}

# Every downstream step reads grn_df['root_id']; fail here with a clear message
# rather than with a KeyError deep inside plot_sankey_dynamic.
_missing_root_id = [name for name, frame in sets.items() if 'root_id' not in frame.columns]
assert not _missing_root_id, f"PSO list(s) without a 'root_id' column: {_missing_root_id}"

# ----------------------------------------------------------------------------
# Helper: filter connections by source IDs, threshold syn_count, attach superclass
# ----------------------------------------------------------------------------
def build_hop_df(src_ids, connections, classification, min_syn=5):
    """Return one hop of the connectome starting from ``src_ids``.

    Synapse counts are summed per (pre_root_id, post_root_id) pair over all
    matching rows of ``connections``. Pairs whose total is below ``min_syn`` are
    dropped. Each target is labelled with its ``super_class`` from
    ``classification``; the label is NaN when the target has no classification row.

    Returns a DataFrame with columns
    ['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count'].
    """
    outgoing = connections.loc[connections['pre_root_id'].isin(src_ids)]
    pair_totals = outgoing.groupby(['pre_root_id', 'post_root_id'], as_index=False)['syn_count'].sum()
    strong_pairs = pair_totals[pair_totals['syn_count'] >= min_syn]
    target_labels = classification.rename(
        columns={'root_id': 'post_root_id', 'super_class': 'output_super_class'}
    )
    labelled = strong_pairs.merge(target_labels, on='post_root_id', how='left')
    return labelled[['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count']]

# ----------------------------------------------------------------------------
# Core: 3-hop Sankey, prune df1 by direct/indirect/third-hop endocrine connection
# ----------------------------------------------------------------------------
def plot_sankey_dynamic(grn_df, title, connections, classification, min_syn=5,
                        target_class='endocrine'):
    """Show a 3-hop Sankey restricted to hop-1 edges that reach ``target_class``.

    A hop-1 edge (seed neuron -> 2nd-order neuron) is kept when the 2nd-order
    neuron meets any of the following:
      1. it is itself a ``target_class`` neuron (direct);
      2. it synapses onto a ``target_class`` neuron (via hop 2);
      3. it synapses onto a 3rd-order neuron that synapses onto a
         ``target_class`` neuron (via hop 3).
    Hops 2 and 3 are then rebuilt, unfiltered, from the surviving 2nd-order neurons.

    Parameters
    ----------
    grn_df : DataFrame with a 'root_id' column listing the seed neurons.
    title : str, label of the root node and figure title.
    connections, classification : FlyWire connection / classification tables.
    min_syn : int, minimum summed synapses for an edge to be kept (default 5).
    target_class : str, super_class used for pruning (default 'endocrine').
    """
    df1_full = build_hop_df(grn_df['root_id'], connections, classification, min_syn)
    df2_full = build_hop_df(df1_full['post_root_id'].unique(), connections, classification, min_syn)
    df3_full = build_hop_df(df2_full['post_root_id'].unique(), connections, classification, min_syn)

    # 1. Direct: the 2nd-order neuron is a target-class neuron.
    mask_direct = df1_full['output_super_class'] == target_class

    # 2. Via hop 2: 2nd-order neurons (hop-2 pre ids) with an edge onto the target class.
    target_in_hop2 = df2_full.loc[df2_full['output_super_class'] == target_class, 'pre_root_id'].unique()
    mask_hop2 = df1_full['post_root_id'].isin(target_in_hop2)

    # 3. Via hop 3: 3rd-order neurons feeding the target class, traced back to
    #    the 2nd-order neurons (hop-2 pre ids) that synapse onto them.
    target_in_hop3 = df3_full.loc[df3_full['output_super_class'] == target_class, 'pre_root_id'].unique()
    hop2_roots_to_target3 = df2_full.loc[df2_full['post_root_id'].isin(target_in_hop3), 'pre_root_id'].unique()
    mask_hop3 = df1_full['post_root_id'].isin(hop2_roots_to_target3)

    df1 = df1_full[mask_direct | mask_hop2 | mask_hop3].reset_index(drop=True)

    # Rebuild hop-2 & hop-3 from the pruned hop 1.
    df2 = build_hop_df(df1['post_root_id'].unique(), connections, classification, min_syn)
    df3 = build_hop_df(df2['post_root_id'].unique(), connections, classification, min_syn)

    _render_sankey(_hop_flows(df1, df2, df3, title), title)


def _hop_flows(df1, df2, df3, title):
    """Aggregate hop edge tables into Sankey flows with hop-prefixed node names.

    Returns (flow1, flow2, flow3), each a DataFrame with 'source', 'target' and
    'count' (summed syn_count) columns. Node names carry a '1: ', '2: ' or '3: '
    prefix so the same super_class can appear in several columns. Edges whose
    super_class is NaN are dropped by groupby's default ``dropna=True``.
    """
    flow1 = (
        df1.groupby('output_super_class', as_index=False)['syn_count']
           .sum().rename(columns={'syn_count': 'count'})
           .assign(source=title,
                   target=lambda d: '1: ' + d['output_super_class'])
    )

    def _chain(up, down, a, b):
        # Join hop-a edges to hop-b edges through the shared neuron
        # (post of hop a == pre of hop b); sum hop-b synapses per class pair.
        joined = pd.merge(up, down,
                          left_on='post_root_id', right_on='pre_root_id',
                          suffixes=(f'_{a}', f'_{b}'))
        flow = (
            joined.groupby([f'output_super_class_{a}', f'output_super_class_{b}'], as_index=False)
                  [f'syn_count_{b}'].sum()
                  .rename(columns={
                      f'output_super_class_{a}': 'source',
                      f'output_super_class_{b}': 'target',
                      f'syn_count_{b}': 'count',
                  })
        )
        flow['source'] = flow['source'].apply(lambda c: f'{a}: ' + c)
        flow['target'] = flow['target'].apply(lambda c: f'{b}: ' + c)
        return flow

    return flow1, _chain(df1, df2, 1, 2), _chain(df2, df3, 2, 3)


def _render_sankey(flows, title, arrangement='snap'):
    """Draw and show a four-column Sankey: root, hop-1, hop-2 and hop-3 classes.

    flows : (flow1, flow2, flow3) as returned by _hop_flows.
    arrangement : plotly Sankey node arrangement ('snap', 'freeform', ...).
    """
    flow1, flow2, flow3 = flows

    # A column holds every node named on either side of its adjacent flows,
    # so every link endpoint is guaranteed to have an index.
    columns = [
        [title],
        sorted(set(flow1['target']) | set(flow2['source'])),
        sorted(set(flow2['target']) | set(flow3['source'])),
        sorted(set(flow3['target'])),
    ]
    nodes = [name for column in columns for name in column]
    idx = {name: i for i, name in enumerate(nodes)}

    # One Safe-palette colour per super_class, shared across the hop columns.
    classes = sorted({name.split(': ', 1)[1] for name in nodes[1:]})
    palette = px.colors.qualitative.Safe
    cmap = {cls: palette[i % len(palette)] for i, cls in enumerate(classes)}
    node_colors = ['lightgrey' if name == title else cmap[name.split(': ', 1)[1]]
                   for name in nodes]

    source, target, value, link_colors = [], [], [], []
    incoming = dict.fromkeys(nodes, 0)
    outgoing = dict.fromkeys(nodes, 0)
    for flow in flows:
        for src, dst, count in zip(flow['source'], flow['target'], flow['count']):
            source.append(idx[src])
            target.append(idx[dst])
            value.append(count)
            # Links take their source node's colour at 50% opacity
            # ('rgb(...)' -> 'rgba(...,0.5)'); 'lightgrey' passes through unchanged.
            link_colors.append(node_colors[idx[src]].replace('rgb(', 'rgba(').replace(')', ',0.5)'))
            outgoing[src] += count
            incoming[dst] += count
    customdata = [f"Incoming: {incoming[n]}<br>Outgoing: {outgoing[n]}" for n in nodes]

    # plotly.js applies manual node positions only when BOTH x and y are given,
    # and treats a coordinate of exactly 0 as unset, so keep every value inside
    # (0, 1). Each column's nodes are spread over equal vertical bands.
    col_x = (0.001, 0.33, 0.66, 0.999)
    x = [cx for cx, column in zip(col_x, columns) for _ in column]
    y = [(i + 0.5) / len(column) for column in columns for i in range(len(column))]

    fig = go.Figure(go.Sankey(
        arrangement=arrangement,
        node=dict(
            label=nodes,
            x=x, y=y,
            color=node_colors,
            pad=15, thickness=20,
            line=dict(color='black', width=0.5),
            customdata=customdata,
            hovertemplate='%{customdata}<extra>%{label}</extra>'
        ),
        link=dict(source=source, target=target, value=value, color=link_colors)
    ))
    fig.update_layout(title_text=title, font_size=14)
    fig.show()

# ----------------------------------------------------------------------------
# Generate filtered Sankey for each PSO set
# ----------------------------------------------------------------------------
for label in sets:
    # One pruned Sankey per PSO list; the dict key doubles as the plot title.
    df = sets[label]
    plot_sankey_dynamic(df, label, connections, classification)


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

# ----------------------------------------------------------------------------
# 0. Load FlyWire connectome and classification
# ----------------------------------------------------------------------------
import os
from pathlib import Path

# Input folders. The defaults are the original author's machine-specific
# locations; set BRAIN_MODEL_DIR / PSO_SETS_DIR in the environment to run the
# notebook on another machine without editing code.
BRAIN_MODEL_DIR = Path(os.environ.get(
    'BRAIN_MODEL_DIR',
    '/Users/yaolab/Library/CloudStorage/OneDrive-UniversityofFlorida/'
    'YaoLabUF/YaoLab/Drosophila_brain_model'
))
PSO_SETS_DIR = Path(os.environ.get(
    'PSO_SETS_DIR',
    '/Users/yaolab/Downloads/taste-connectome-main/aPhN-SA_v3'
))

connections = pd.read_csv(BRAIN_MODEL_DIR / 'connections.csv.gz')
classification = pd.read_csv(BRAIN_MODEL_DIR / 'classification.csv.gz')[['root_id', 'super_class']]

# ----------------------------------------------------------------------------
# 1. Load your three PSO lists (each CSV has a 'root_id' column)
# ----------------------------------------------------------------------------
set_1 = pd.read_csv(PSO_SETS_DIR / 'set_1.csv')
set_2 = pd.read_csv(PSO_SETS_DIR / 'set_2.csv')
set_3 = pd.read_csv(PSO_SETS_DIR / 'set_3.csv')

sets = {
    'DCSO':  set_1,
    'aPhN1': set_2,
    'aPhN2': set_3,
}

# Every downstream step reads grn_df['root_id']; fail here with a clear message
# rather than with a KeyError deep inside plot_sankey_dynamic.
_missing_root_id = [name for name, frame in sets.items() if 'root_id' not in frame.columns]
assert not _missing_root_id, f"PSO list(s) without a 'root_id' column: {_missing_root_id}"

# ----------------------------------------------------------------------------
# Helper: filter connections by source IDs, threshold syn_count, attach superclass
# ----------------------------------------------------------------------------
def build_hop_df(src_ids, connections, classification, min_syn=5):
    """Return one hop of the connectome starting from ``src_ids``.

    Synapse counts are summed per (pre_root_id, post_root_id) pair over all
    matching rows of ``connections``. Pairs whose total is below ``min_syn`` are
    dropped. Each target is labelled with its ``super_class`` from
    ``classification``; the label is NaN when the target has no classification row.

    Returns a DataFrame with columns
    ['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count'].
    """
    outgoing = connections.loc[connections['pre_root_id'].isin(src_ids)]
    pair_totals = outgoing.groupby(['pre_root_id', 'post_root_id'], as_index=False)['syn_count'].sum()
    strong_pairs = pair_totals[pair_totals['syn_count'] >= min_syn]
    target_labels = classification.rename(
        columns={'root_id': 'post_root_id', 'super_class': 'output_super_class'}
    )
    labelled = strong_pairs.merge(target_labels, on='post_root_id', how='left')
    return labelled[['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count']]

# ----------------------------------------------------------------------------
# Core: 3-hop Sankey with filtering for endocrine at each hop
# ----------------------------------------------------------------------------
def plot_sankey_dynamic(grn_df, title, connections, classification, min_syn=5):
    """Draw a 3-hop Sankey of synaptic output from a PSO set, pruned toward endocrine neurons.

    Which edges are kept at each hop:

    * Hop 1: the target is endocrine, or reaches an endocrine neuron at hop 2 or hop 3.
    * Hop 2: the target is endocrine or synapses onto an endocrine neuron.
    * Hop 3: every output (>= ``min_syn``) of the surviving hop-2 targets.

    Link widths are summed synapse counts between the super classes on either
    side of a hop. Each synaptic edge is counted exactly once.

    Parameters
    ----------
    grn_df : pandas.DataFrame
        Source neurons; must have a 'root_id' column.
    title : str
        Label of the leftmost (source) node and the figure title.
    connections : pandas.DataFrame
        Edge table with 'pre_root_id', 'post_root_id' and 'syn_count' columns.
    classification : pandas.DataFrame
        Lookup with 'root_id' and 'super_class' columns.
    min_syn : int, default 5
        Minimum summed synapse count for a (pre, post) pair to be kept.

    Returns
    -------
    None
        The figure is displayed with ``fig.show()``.
    """
    # plotly is otherwise only imported in a later cell; importing here keeps this
    # cell runnable on a fresh kernel.
    import plotly.express as px
    import plotly.graph_objects as go

    # Compute all hops fully
    df1_full = build_hop_df(grn_df['root_id'], connections, classification, min_syn)
    df2_full = build_hop_df(df1_full['post_root_id'].unique(), connections, classification, min_syn)
    df3_full = build_hop_df(df2_full['post_root_id'].unique(), connections, classification, min_syn)

    # ---- Filter df1 (1st hop): direct endocrine, or endocrine reached in hop2 or hop3 ----
    mask_direct1 = df1_full['output_super_class'] == 'endocrine'
    endocrine_in_hop2 = df2_full.loc[df2_full['output_super_class'] == 'endocrine', 'pre_root_id'].unique()
    mask_hop2_1 = df1_full['post_root_id'].isin(endocrine_in_hop2)
    endocrine_in_hop3 = df3_full.loc[df3_full['output_super_class'] == 'endocrine', 'pre_root_id'].unique()
    hop2_roots_to_endocrine3 = df2_full.loc[df2_full['post_root_id'].isin(endocrine_in_hop3), 'pre_root_id'].unique()
    mask_hop3_1 = df1_full['post_root_id'].isin(hop2_roots_to_endocrine3)
    df1 = df1_full[mask_direct1 | mask_hop2_1 | mask_hop3_1].reset_index(drop=True)

    # ---- Rebuild df2 and df3 from pruned df1 ----
    df2_full2 = build_hop_df(df1['post_root_id'].unique(), connections, classification, min_syn)
    df3_full2 = build_hop_df(df2_full2['post_root_id'].unique(), connections, classification, min_syn)

    # ---- Filter df2 (direct to endocrine OR downstream to endocrine in df3) ----
    mask_direct2 = df2_full2['output_super_class'] == 'endocrine'
    endocrine_in_df3 = df3_full2.loc[df3_full2['output_super_class'] == 'endocrine', 'pre_root_id'].unique()
    mask_downstream2 = df2_full2['post_root_id'].isin(endocrine_in_df3)
    df2 = df2_full2[mask_direct2 | mask_downstream2].reset_index(drop=True)

    # ---- Rebuild df3 from pruned df2 ----
    df3 = build_hop_df(df2['post_root_id'].unique(), connections, classification, min_syn)

    # Targets missing from `classification` carry NaN from build_hop_df's left merge.
    # Label them explicitly so they become a node and sort alongside string class
    # names (sorting NaN with str raises TypeError).
    df1, df2, df3 = (
        d.fillna({'output_super_class': 'unclassified'}) for d in (df1, df2, df3)
    )

    def _class_flows(upstream, downstream, src_prefix, tgt_prefix):
        # Sum `downstream` synapses by (class of its presynaptic neuron, class of its target).
        # The presynaptic class comes from `upstream`, reduced to one row per neuron:
        # several upstream edges may converge on the same neuron, and joining on them
        # directly would count each of its downstream edges once per converging input.
        pre_class = (
            upstream[['post_root_id', 'output_super_class']]
            .drop_duplicates('post_root_id')
            .rename(columns={'post_root_id': 'pre_root_id', 'output_super_class': 'source'})
        )
        flows = (
            downstream.merge(pre_class, on='pre_root_id')
            .groupby(['source', 'output_super_class'], as_index=False)['syn_count'].sum()
            .rename(columns={'output_super_class': 'target', 'syn_count': 'count'})
        )
        flows['source'] = flows['source'].map(lambda c: src_prefix + c)
        flows['target'] = flows['target'].map(lambda c: tgt_prefix + c)
        return flows

    # --- summarize flows ---
    flow1 = (
        df1.groupby('output_super_class', as_index=False)['syn_count']
           .sum().rename(columns={'syn_count': 'count'})
           .assign(source=title,
                   target=lambda d: '1: ' + d['output_super_class'])
    )
    flow2 = _class_flows(df1, df2, '1: ', '2: ')
    flow3 = _class_flows(df2, df3, '2: ', '3: ')

    # --- assemble node lists & indices ---
    col1 = [title]
    col2 = [f"1: {c}" for c in sorted(df1['output_super_class'].unique())]
    col3 = [f"2: {c}" for c in sorted(df2['output_super_class'].unique())]
    col4 = [f"3: {c}" for c in sorted(df3['output_super_class'].unique())]
    nodes = col1 + col2 + col3 + col4
    idx = {n: i for i, n in enumerate(nodes)}

    # --- color mapping: one palette colour per super class, shared across hops ---
    classes = sorted({n.split(': ', 1)[1] for n in col2 + col3 + col4})
    palette = px.colors.qualitative.Safe
    cmap = {cls: palette[i % len(palette)] for i, cls in enumerate(classes)}
    node_colors = [
        'lightgrey' if n == title else cmap[n.split(': ', 1)[1]]
        for n in nodes
    ]

    # --- build Sankey link arrays; links take a translucent copy of their source colour ---
    source, target, value, link_colors = [], [], [], []
    for flows in (flow1, flow2, flow3):
        for src, tgt, cnt in zip(flows['source'], flows['target'], flows['count']):
            s, t = idx[src], idx[tgt]
            source.append(s)
            target.append(t)
            value.append(cnt)
            link_colors.append(node_colors[s].replace('rgb(', 'rgba(').replace(')', ',0.5)'))

    # --- hover info ---
    incoming = dict.fromkeys(nodes, 0)
    outgoing = dict.fromkeys(nodes, 0)
    for s, t, v in zip(source, target, value):
        outgoing[nodes[s]] += v
        incoming[nodes[t]] += v
    customdata = [f"Incoming: {incoming[n]}<br>Outgoing: {outgoing[n]}" for n in nodes]

    # --- layout positions: one x per hop column, nodes spread evenly in y ---
    x, y = [], []
    for x_pos, column in zip((0.0, 0.33, 0.66, 1.0), (col1, col2, col3, col4)):
        x.extend([x_pos] * len(column))
        if len(column) > 1:
            y.extend(np.linspace(0, 1, len(column)).tolist())
        else:
            # Exactly one y per node (none for an empty column) keeps y aligned with nodes.
            y.extend([0.5] * len(column))

    # --- plot Sankey ---
    fig = go.Figure(go.Sankey(
        arrangement='snap',
        node=dict(
            label=nodes,
            x=x, y=y,
            color=node_colors,
            pad=15, thickness=20,
            line=dict(color='black', width=0.5),
            customdata=customdata,
            hovertemplate='%{customdata}<extra>%{label}</extra>'
        ),
        link=dict(source=source, target=target, value=value, color=link_colors)
    ))
    fig.update_layout(title_text=title, font_size=14)
    fig.show()

# ----------------------------------------------------------------------------
# Generate filtered Sankey for each PSO set
# ----------------------------------------------------------------------------
# Draw one endocrine-filtered Sankey per PSO set; the dict key becomes the figure title.
for label in sets:
    df = sets[label]
    plot_sankey_dynamic(df, label, connections, classification)


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

# ----------------------------------------------------------------------------
# 0. Load FlyWire connectome and classification
# ----------------------------------------------------------------------------
# Locations of the FlyWire connectome tables and the curated PSO lists.
# Edit these two paths (rather than the individual read_csv calls) to run elsewhere.
from pathlib import Path

CONNECTOME_DIR = Path(
    '/Users/yaolab/Library/CloudStorage/OneDrive-UniversityofFlorida/'
    'YaoLabUF/YaoLab/Drosophila_brain_model'
)
PSO_DIR = Path('/Users/yaolab/Downloads/taste-connectome-main/aPhN-SA_v3')

connections = pd.read_csv(CONNECTOME_DIR / 'connections.csv.gz')
# Only the id -> super class lookup is used below, so skip parsing other columns;
# the trailing selection pins the column order regardless of file layout.
classification = pd.read_csv(
    CONNECTOME_DIR / 'classification.csv.gz',
    usecols=['root_id', 'super_class'],
)[['root_id', 'super_class']]

# ----------------------------------------------------------------------------
# 1. Load your three PSO lists (each CSV has a 'root_id' column)
# ----------------------------------------------------------------------------
set_1, set_2, set_3 = (pd.read_csv(PSO_DIR / f'set_{i}.csv') for i in (1, 2, 3))

sets = {
    'DCSO':  set_1,
    'aPhN1': set_2,
    'aPhN2': set_3,
}

# ----------------------------------------------------------------------------
# Helper: filter connections by source IDs, threshold syn_count, attach superclass
# ----------------------------------------------------------------------------
def build_hop_df(src_ids, connections, classification, min_syn=5):
    """Return the thresholded, super-class-annotated outgoing edges of ``src_ids``.

    Rows of ``connections`` whose presynaptic neuron is in ``src_ids`` are summed
    per (pre, post) pair. Pairs with fewer than ``min_syn`` synapses are dropped,
    and each postsynaptic neuron is labelled via ``classification``.

    Parameters
    ----------
    src_ids : iterable of root ids
        Presynaptic neurons to expand.
    connections : pandas.DataFrame
        Edge table with 'pre_root_id', 'post_root_id' and 'syn_count' columns.
    classification : pandas.DataFrame
        Lookup with 'root_id' and 'super_class' columns.
    min_syn : int, default 5
        Minimum summed synapse count for a pair to be kept.

    Returns
    -------
    pandas.DataFrame
        Columns 'pre_root_id', 'post_root_id', 'output_super_class', 'syn_count'.
        'output_super_class' is NaN for targets absent from ``classification``
        (left merge).
    """
    outgoing = connections.loc[connections['pre_root_id'].isin(src_ids)]

    # Several rows can exist per pair (e.g. one per neuropil); collapse them.
    pair_weights = (
        outgoing.groupby(['pre_root_id', 'post_root_id'], as_index=False)['syn_count'].sum()
    )
    strong_pairs = pair_weights[pair_weights['syn_count'] >= min_syn]

    target_labels = classification.rename(
        columns={'root_id': 'post_root_id', 'super_class': 'output_super_class'}
    )
    annotated = strong_pairs.merge(target_labels, on='post_root_id', how='left')

    column_order = ['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count']
    return annotated[column_order]

# ----------------------------------------------------------------------------
# Core: 3-hop Sankey, filtering at each hop for endocrine connectivity
# ----------------------------------------------------------------------------
def plot_sankey_dynamic(grn_df, title, connections, classification, min_syn=5):
    """Draw, save and show a 3-hop Sankey from a PSO set to endocrine neurons.

    Which edges are kept at each hop:

    * Hop 1: the target is endocrine, or reaches an endocrine neuron at hop 2 or hop 3.
    * Hop 2: the target is endocrine or synapses onto an endocrine neuron.
    * Hop 3: only edges onto endocrine neurons.

    Link widths are summed synapse counts between the super classes on either
    side of a hop. Each synaptic edge is counted exactly once.

    Parameters
    ----------
    grn_df : pandas.DataFrame
        Source neurons; must have a 'root_id' column.
    title : str
        Label of the leftmost (source) node and the figure title. Spaces are
        replaced by underscores to build the SVG filename.
    connections : pandas.DataFrame
        Edge table with 'pre_root_id', 'post_root_id' and 'syn_count' columns.
    classification : pandas.DataFrame
        Lookup with 'root_id' and 'super_class' columns.
    min_syn : int, default 5
        Minimum summed synapse count for a (pre, post) pair to be kept.

    Returns
    -------
    None
        Writes ``<title>.svg`` (relative to the working directory) and shows the figure.
    """
    # Compute all hops fully
    df1_full = build_hop_df(grn_df['root_id'], connections, classification, min_syn)
    df2_full = build_hop_df(df1_full['post_root_id'].unique(), connections, classification, min_syn)
    df3_full = build_hop_df(df2_full['post_root_id'].unique(), connections, classification, min_syn)

    # ---- Filter df1 (1st hop): direct to endocrine, or to df2/df3 roots that connect to endocrine
    mask_direct1 = df1_full['output_super_class'] == 'endocrine'
    endocrine_in_hop2 = df2_full.loc[df2_full['output_super_class'] == 'endocrine', 'pre_root_id'].unique()
    mask_hop2_1 = df1_full['post_root_id'].isin(endocrine_in_hop2)
    endocrine_in_hop3 = df3_full.loc[df3_full['output_super_class'] == 'endocrine', 'pre_root_id'].unique()
    hop2_roots_to_endocrine3 = df2_full.loc[df2_full['post_root_id'].isin(endocrine_in_hop3), 'pre_root_id'].unique()
    mask_hop3_1 = df1_full['post_root_id'].isin(hop2_roots_to_endocrine3)
    df1 = df1_full[mask_direct1 | mask_hop2_1 | mask_hop3_1].reset_index(drop=True)

    # ---- Rebuild df2 and df3 from pruned df1 ----
    df2_full2 = build_hop_df(df1['post_root_id'].unique(), connections, classification, min_syn)
    df3_full2 = build_hop_df(df2_full2['post_root_id'].unique(), connections, classification, min_syn)

    # ---- Filter df2: direct to endocrine, or to df3 roots that connect to endocrine
    mask_direct2 = df2_full2['output_super_class'] == 'endocrine'
    endocrine_in_df3 = df3_full2.loc[df3_full2['output_super_class'] == 'endocrine', 'pre_root_id'].unique()
    mask_downstream2 = df2_full2['post_root_id'].isin(endocrine_in_df3)
    df2 = df2_full2[mask_direct2 | mask_downstream2].reset_index(drop=True)

    # ---- Rebuild df3 from pruned df2 ----
    df3 = build_hop_df(df2['post_root_id'].unique(), connections, classification, min_syn)

    # ---- FINAL: Only keep df3 rows that connect to endocrine
    df3 = df3[df3['output_super_class'] == 'endocrine'].reset_index(drop=True)

    # Targets missing from `classification` carry NaN from build_hop_df's left merge.
    # Label them explicitly so they become a node and sort alongside string class
    # names (sorting NaN with str raises TypeError).
    df1, df2, df3 = (
        d.fillna({'output_super_class': 'unclassified'}) for d in (df1, df2, df3)
    )

    def _class_flows(upstream, downstream, src_prefix, tgt_prefix):
        # Sum `downstream` synapses by (class of its presynaptic neuron, class of its target).
        # The presynaptic class comes from `upstream`, reduced to one row per neuron:
        # several upstream edges may converge on the same neuron, and joining on them
        # directly would count each of its downstream edges once per converging input.
        pre_class = (
            upstream[['post_root_id', 'output_super_class']]
            .drop_duplicates('post_root_id')
            .rename(columns={'post_root_id': 'pre_root_id', 'output_super_class': 'source'})
        )
        flows = (
            downstream.merge(pre_class, on='pre_root_id')
            .groupby(['source', 'output_super_class'], as_index=False)['syn_count'].sum()
            .rename(columns={'output_super_class': 'target', 'syn_count': 'count'})
        )
        flows['source'] = flows['source'].map(lambda c: src_prefix + c)
        flows['target'] = flows['target'].map(lambda c: tgt_prefix + c)
        return flows

    # --- summarize flows ---
    flow1 = (
        df1.groupby('output_super_class', as_index=False)['syn_count']
           .sum().rename(columns={'syn_count': 'count'})
           .assign(source=title,
                   target=lambda d: '1: ' + d['output_super_class'])
    )
    flow2 = _class_flows(df1, df2, '1: ', '2: ')
    flow3 = _class_flows(df2, df3, '2: ', '3: ')

    # --- assemble node lists & indices ---
    col1 = [title]
    col2 = [f"1: {c}" for c in sorted(df1['output_super_class'].unique())]
    col3 = [f"2: {c}" for c in sorted(df2['output_super_class'].unique())]
    col4 = [f"3: {c}" for c in sorted(df3['output_super_class'].unique())]
    nodes = col1 + col2 + col3 + col4
    idx = {n: i for i, n in enumerate(nodes)}

    # --- color mapping: one palette colour per super class, shared across hops ---
    classes = sorted({n.split(': ', 1)[1] for n in col2 + col3 + col4})
    palette = px.colors.qualitative.Safe
    cmap = {cls: palette[i % len(palette)] for i, cls in enumerate(classes)}
    node_colors = [
        'lightgrey' if n == title else cmap[n.split(': ', 1)[1]]
        for n in nodes
    ]

    # --- build Sankey link arrays; links take a translucent copy of their source colour ---
    source, target, value, link_colors = [], [], [], []
    for flows in (flow1, flow2, flow3):
        for src, tgt, cnt in zip(flows['source'], flows['target'], flows['count']):
            s, t = idx[src], idx[tgt]
            source.append(s)
            target.append(t)
            value.append(cnt)
            link_colors.append(node_colors[s].replace('rgb(', 'rgba(').replace(')', ',0.5)'))

    # --- hover info ---
    incoming = dict.fromkeys(nodes, 0)
    outgoing = dict.fromkeys(nodes, 0)
    for s, t, v in zip(source, target, value):
        outgoing[nodes[s]] += v
        incoming[nodes[t]] += v
    customdata = [f"Incoming: {incoming[n]}<br>Outgoing: {outgoing[n]}" for n in nodes]

    # --- layout positions: one x per hop column, nodes spread evenly in y ---
    x, y = [], []
    for x_pos, column in zip((0.0, 0.33, 0.66, 1.0), (col1, col2, col3, col4)):
        x.extend([x_pos] * len(column))
        if len(column) > 1:
            y.extend(np.linspace(0, 1, len(column)).tolist())
        else:
            # Exactly one y per node (none for an empty column) keeps y aligned with nodes.
            y.extend([0.5] * len(column))

    # --- plot Sankey ---
    fig = go.Figure(go.Sankey(
        arrangement='snap',
        node=dict(
            label=nodes,
            x=x, y=y,
            color=node_colors,
            pad=15, thickness=20,
            line=dict(color='black', width=0.5),
            customdata=customdata,
            hovertemplate='%{customdata}<extra>%{label}</extra>'
        ),
        link=dict(source=source, target=target, value=value, color=link_colors)
    ))
    fig.update_layout(title_text=title, font_size=14)

    # ---- SAVE AS SVG (static export needs plotly's image engine, e.g. kaleido) ----
    svg_filename = f"{title.replace(' ', '_')}.svg"
    fig.write_image(svg_filename)
    print(f"Saved Sankey for {title} as {svg_filename}")
    fig.show()

# ----------------------------------------------------------------------------
# Generate filtered Sankey for each PSO set
# ----------------------------------------------------------------------------
# Draw and save one endocrine-filtered Sankey per PSO set; the dict key becomes the title.
for label in sets:
    df = sets[label]
    plot_sankey_dynamic(df, label, connections, classification)
