## Workflow overview (from exploratory checks to final portrait figure)

### Task (what this notebook does)
This notebook builds **multi-hop synaptic flow summaries** from curated **sensory neuron sets** and visualizes them as **Sankey diagrams**. The earlier cells are **exploratory** (single-panel Sankeys for individual sets and minor variants), while the last section assembles the **final multi-panel portrait figure** (all Sankeys stacked and exported as one SVG for Illustrator).

---

### Abbreviations / naming
- **FlyWire**: the connectome dataset used here.
- **GRN**: gustatory receptor neuron (input lists loaded from CSVs; each has a `root_id` column).
- **PhN**: pharyngeal nerve (here: **PhN-SA_v2** sets; 6 CSVs). NOTE: the PhN contains the stomodeal nerve (StN) gut sensory afferents.
- **PSO**: presumed pharyngeal sensory organ group (here: **aPhN-SA** sets; 3 CSVs labeled `DCSO`, `aPhN1`, `aPhN2`).
- **MxLbN**: maxillary labellar nerve group (here: 4 GRN lists: `Sugar/Water`, `Bitter`, `Ir94e`, `Taste Peg`).
- **SA**: sensory axon (used in dataset naming).
- **`root_id`**: unique neuron identifier (node ID).
- **`pre_root_id` / `post_root_id`**: presynaptic / postsynaptic neuron IDs in `connections`.
- **`syn_count`**: number of synapses between a pre→post pair.
- **Superclass / `super_class`**: broad anatomical/functional category (used to label nodes and color them).
- **Hop**: one step of directed connectivity (pre→post).  
  - **Hop 1**: directly downstream of the input set  
  - **Hop 2**: downstream of Hop 1 targets  
  - **Hop 3**: downstream of Hop 2 targets

---

### Inputs (tables and lists)
- **Connectome edges**: `flywire_data/connections.csv.(gz|zip)` with columns like `pre_root_id`, `post_root_id`, `syn_count`.
- **Neuron metadata**: `flywire_data/classification.csv.gz` (subset used: `root_id`, `super_class`).
- **Seed neuron sets (CSV lists)**:
  - **PhN-SA_v2**: `input/PhN/set_1.csv` … `set_6.csv`
  - **PSO/aPhN-SA**: `input/aPhN-SA/set_1.csv` … `set_3.csv` (mapped to `DCSO`, `aPhN1`, `aPhN2`)
  - **MxLbN-SA**: `input/MxLbN-SA/*.csv` (4 stimulus/modality lists)
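
Below is a minimal sketch of loading these inputs, assuming the folder layout listed above. The notebook reads `connections.csv.zip` in the PhN-SA_v2 cell and `connections.csv.gz` elsewhere. The hypothetical helper `resolve_connections_path` picks whichever of the two exists. `load_seed_ids` (also hypothetical) returns a seed list's `root_id` values as a set. pandas infers gzip or zip compression from the file extension, and a `.zip` must contain a single CSV.

```python
import os
import pandas as pd

# Hypothetical helper: return the first connections file that exists in data_dir.
def resolve_connections_path(data_dir):
    for name in ("connections.csv.gz", "connections.csv.zip"):
        path = os.path.join(data_dir, name)
        if os.path.exists(path):
            return path
    raise FileNotFoundError(f"no connections.csv.(gz|zip) in {data_dir}")

# Hypothetical helper: read one seed CSV and return its root_id values as a set.
def load_seed_ids(csv_path):
    return set(pd.read_csv(csv_path, usecols=["root_id"])["root_id"])

# Example usage with the notebook's paths (only the three columns the Sankeys need):
# data_dir = "./Giakoumas-et-al/flywire_data"
# connections = pd.read_csv(resolve_connections_path(data_dir),
#                           usecols=["pre_root_id", "post_root_id", "syn_count"])
# seeds = load_seed_ids("./Giakoumas-et-al/input/PhN/set_1.csv")
```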

---

### Core computation (how a Sankey is built)
1. **Filter edges by seed set**  
   Select rows where `pre_root_id ∈ seed_root_ids`.

2. **Aggregate to pair-level synapses**  
   Group by `(pre_root_id, post_root_id)` and sum `syn_count` so each directed pair has a total weight.

3. **Threshold weak connections**  
   Keep only pairs with **total `syn_count ≥ min_syn`** (default used throughout: `min_syn = 5`).

4. **Annotate targets with superclass**  
   Merge `post_root_id` with `classification` to get `output_super_class`.

5. **Repeat for downstream hops (1→3)**  
   Hop 1 uses the seed IDs; Hop 2 uses Hop 1 target IDs; Hop 3 uses Hop 2 target IDs.

6. **Summarize flows at the superclass level**  
   Convert hop-level edges into flow tables:
   - **Seed → Hop1 superclass totals**
   - **Hop1 superclass → Hop2 superclass totals**
   - **Hop2 superclass → Hop3 superclass totals**

7. **Plot Sankey with consistent labeling and colors**  
   Nodes are labeled by hop, e.g. `1: central`, `2: motor`, `3: endocrine`.  
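   Colors are assigned by `super_class`, so a superclass has one color within each Sankey. Only the PhN-SA_v2 exploratory cell uses a global map (`GLOBAL_COLOR_MAP`, built from every superclass in `classification`). The other cells, including `make_sankey_trace`, rebuild the palette from the superclasses present in that panel. As a result, the same superclass can get different colors in different panels of the final figure.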
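
The steps above can be traced on a toy edge table. This minimal sketch re-implements the notebook's `build_hop_df` logic, using a boolean mask in place of `.query()`, together with the Hop-1 → Hop-2 flow merge. The neuron IDs, superclass labels and synapse counts are invented for illustration, and `min_syn=5` matches the notebook's default.

```python
import pandas as pd

# Toy connectome (invented): several rows may share a pre→post pair, so step 2 sums them.
connections = pd.DataFrame({
    "pre_root_id":  [1, 1, 2, 2, 10, 10, 11],
    "post_root_id": [10, 10, 10, 11, 20, 21, 20],
    "syn_count":    [3, 4, 6, 2, 8, 5, 9],
})
classification = pd.DataFrame({
    "root_id":     [10, 11, 20, 21],
    "super_class": ["central", "central", "motor", "endocrine"],
})

def build_hop_df(src_ids, connections, classification, min_syn=5):
    # Step 1: keep edges whose presynaptic neuron is in the source set.
    df = connections[connections["pre_root_id"].isin(src_ids)]
    # Steps 2-3: total synapses per directed pair, then drop weak pairs.
    summed = df.groupby(["pre_root_id", "post_root_id"], as_index=False)["syn_count"].sum()
    summed = summed[summed["syn_count"] >= min_syn]
    # Step 4: label each target neuron with its superclass.
    labels = classification.rename(
        columns={"root_id": "post_root_id", "super_class": "output_super_class"}
    )
    return summed.merge(labels, on="post_root_id", how="left")

# Step 5: chain hops (seed set → Hop 1 → Hop 2).
seeds = [1, 2]
df1 = build_hop_df(seeds, connections, classification)
df2 = build_hop_df(df1["post_root_id"].unique(), connections, classification)

# Step 6: superclass-level flow tables, joined on the intermediate neuron.
flow1 = df1.groupby("output_super_class")["syn_count"].sum()
m12 = df1.merge(df2, left_on="post_root_id", right_on="pre_root_id", suffixes=("_1", "_2"))
flow2 = m12.groupby(["output_super_class_1", "output_super_class_2"])["syn_count_2"].sum()
print(flow1.to_dict())
print(flow2.to_dict())
```

In this toy case the 2→11 pair (2 synapses) falls below `min_syn`. As a result, neuron 11 and its 9-synapse edge to neuron 20 never enter Hop 2. Neuron 10 receives above-threshold input from both seeds, so the merge counts each of its outgoing edges twice. `flow2` therefore reports 16 synapses for central → motor, although the single 10→20 pair carries 8.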

---

### Notebook structure (exploratory → final)
- **Exploratory section(s)**:
  - Generate **single Sankey diagrams** per set for:
    - the 6 **PhN-SA_v2** sets,
    - the 3 **PSO/aPhN-SA** sets,
    - the 4 **MxLbN-SA** sets.
  - These cells help validate thresholds, hop logic, labels, and layout.

- **Final figure section (for Supplementary Fig. S3)**:
  - Defines `make_sankey_trace(...)` to return a Plotly Sankey trace.
  - Loads **all 13 datasets** (6 PhN + 3 PSO + 4 MxLbN).
  - Uses `plotly.subplots.make_subplots` to stack them into a **13×1 portrait layout**.
  - Exports a **single SVG** (intended for Adobe Illustrator edits).

---

### Outputs
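- Each exploratory Sankey is shown interactively. The PhN-SA_v2 and PSO cells also save one SVG per set (1200×800, `scale=2`) to the working directory, e.g. `Set1.svg` or `DCSO.svg`.
- The final portrait figure is written to:
  - `./Giakoumas-et-al/output/figures/fig_S3/all_sankeys_portrait.svg`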


In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

# ----------------------------------------------------------------------------
# 0. Load FlyWire connectome and super_class classification
# ----------------------------------------------------------------------------
connections = pd.read_csv(
    './Giakoumas-et-al/flywire_data/connections.csv.zip'
)
classification = pd.read_csv(
    './Giakoumas-et-al/flywire_data/classification.csv.gz'
)[['root_id','super_class']]

# ----------------------------------------------------------------------------
# 1. Build a GLOBAL color map for every superclass once, up-front
# ----------------------------------------------------------------------------
all_classes_global = sorted(classification['super_class'].unique())
palette = px.colors.qualitative.Safe
GLOBAL_COLOR_MAP = {
    cls: palette[i % len(palette)]
    for i, cls in enumerate(all_classes_global)
}

# ----------------------------------------------------------------------------
# 2. Load your six PhN-SA_v2 sets (each CSV has a 'root_id' column)
# ----------------------------------------------------------------------------
sets = {
    f'Set{i}': pd.read_csv(
        f'./Giakoumas-et-al/input/PhN/set_{i}.csv'
    )
    for i in range(1, 7)
}

# ----------------------------------------------------------------------------
# Helper: filter connections by source IDs, threshold syn_count, attach superclass
# ----------------------------------------------------------------------------
def build_hop_df(src_ids, connections, classification, min_syn=5):
    df = connections[connections['pre_root_id'].isin(src_ids)]
    summed = (
        df.groupby(['pre_root_id','post_root_id'], as_index=False)
          .agg({'syn_count':'sum'})
          .query('syn_count >= @min_syn')
    )
    merged = pd.merge(
        summed,
        classification.rename(
            columns={'root_id':'post_root_id','super_class':'output_super_class'}
        ),
        on='post_root_id', how='left'
    )
    return merged[['pre_root_id','post_root_id','output_super_class','syn_count']]

# ----------------------------------------------------------------------------
# Core: build and show a Sankey diagram with consistent colors
# ----------------------------------------------------------------------------
def plot_sankey_dynamic(grn_df, title, connections, classification, min_syn=5):
    # hops
    df1 = build_hop_df(grn_df['root_id'],            connections, classification, min_syn)
    df2 = build_hop_df(df1['post_root_id'].unique(), connections, classification, min_syn)
    df3 = build_hop_df(df2['post_root_id'].unique(), connections, classification, min_syn)

    # summarize flows
    flow1 = (
        df1.groupby('output_super_class')['syn_count']
           .sum().reset_index(name='count')
           .assign(source=title)
    )
    m12 = pd.merge(df1, df2,
                   left_on='post_root_id', right_on='pre_root_id',
                   suffixes=('_1','_2'))
    flow2 = (
        m12.groupby(['output_super_class_1','output_super_class_2'])['syn_count_2']
          .sum().reset_index(name='count')
    )
    m23 = pd.merge(df2, df3,
                   left_on='post_root_id', right_on='pre_root_id',
                   suffixes=('_2','_3'))
    flow3 = (
        m23.groupby(['output_super_class_2','output_super_class_3'])['syn_count_3']
          .sum().reset_index(name='count')
    )

    # nodes
    col1 = [title]
    col2 = [f"1: {c}" for c in sorted(df1['output_super_class'].unique())]
    col3 = [f"2: {c}" for c in sorted(df2['output_super_class'].unique())]
    col4 = [f"3: {c}" for c in sorted(df3['output_super_class'].unique())]
    nodes = col1 + col2 + col3 + col4
    idx   = {node:i for i,node in enumerate(nodes)}

    # colors: reuse the GLOBAL_COLOR_MAP
    node_colors = [
        'lightgrey' if n == title else GLOBAL_COLOR_MAP[n.split(': ',1)[1]]
        for n in nodes
    ]

    # links
    source, target, value, link_colors = [], [], [], []
    def add_links(df, src_col, tgt_col):
        for _, r in df.iterrows():
            s = idx[r[src_col]]
            t = idx[r[tgt_col]]
            source.append(s)
            target.append(t)
            value.append(r['count'])
            link_colors.append(node_colors[s].replace('rgb','rgba').replace(')',',0.5)'))

    # flow1 → 1:class
    flow1 = flow1.rename(columns={'source':'src','output_super_class':'dst'})
    flow1['dst'] = flow1['dst'].map(lambda c: f"1: {c}")
    add_links(flow1, 'src', 'dst')

    # flow2 → 2:class
    flow2 = flow2.rename(columns={
        'output_super_class_1':'src','output_super_class_2':'dst'
    })
    flow2['src'] = flow2['src'].map(lambda c: f"1: {c}")
    flow2['dst'] = flow2['dst'].map(lambda c: f"2: {c}")
    add_links(flow2, 'src', 'dst')

    # flow3 → 3:class
    flow3 = flow3.rename(columns={
        'output_super_class_2':'src','output_super_class_3':'dst'
    })
    flow3['src'] = flow3['src'].map(lambda c: f"2: {c}")
    flow3['dst'] = flow3['dst'].map(lambda c: f"3: {c}")
    add_links(flow3, 'src', 'dst')

    # hover
    incoming = dict.fromkeys(nodes, 0)
    outgoing = dict.fromkeys(nodes, 0)
    for s,t,v in zip(source, target, value):
        outgoing[nodes[s]] += v
        incoming[nodes[t]]  += v
    customdata = [f"Incoming: {incoming[n]}<br>Outgoing: {outgoing[n]}" for n in nodes]

    # positions
    x = [0.0]*len(col1) + [0.33]*len(col2) + [0.66]*len(col3) + [1.0]*len(col4)
    y = []
    for col in (col1, col2, col3, col4):
        n = len(col)
        y.extend([0.5] if n==1 else list(np.linspace(0,1,n)))

    # draw
    fig = go.Figure(go.Sankey(
        arrangement='snap',
        node=dict(
            label=nodes,
            x=x, y=y,
            color=node_colors,
            pad=15,
            thickness=20,
            line=dict(color='black', width=0.5),
            customdata=customdata,
            hovertemplate='%{customdata}<extra>%{label}</extra>'
        ),
        link=dict(source=source, target=target, value=value, color=link_colors)
    ))
    fig.update_layout(title_text=f"{title}", font_size=14)
    fig.write_image(f"{title.replace('/','_')}.svg", width=1200, height=800, scale=2)
    fig.show()

# ----------------------------------------------------------------------------
# 3. Generate a Sankey for each of the six PhN-SA_v2 sets
# ----------------------------------------------------------------------------
for label, df in sets.items():
    plot_sankey_dynamic(df, label, connections, classification)


### Generate a Sankey for each PSO set

In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import plotly.express as px

# ----------------------------------------------------------------------------
# 0. Load FlyWire connectome and super_class classification
# ----------------------------------------------------------------------------
connections = pd.read_csv(
    './Giakoumas-et-al/flywire_data/connections.csv.gz'
)
classification = pd.read_csv(
    './Giakoumas-et-al/flywire_data/classification.csv.gz'
)[['root_id','super_class']]

# ----------------------------------------------------------------------------
# 1. Load your three PSO sets (each CSV has a 'root_id' column)
# ----------------------------------------------------------------------------
set_1 = pd.read_csv('./Giakoumas-et-al/input/aPhN-SA/set_1.csv')
set_2 = pd.read_csv('./Giakoumas-et-al/input/aPhN-SA/set_2.csv')
set_3 = pd.read_csv('./Giakoumas-et-al/input/aPhN-SA/set_3.csv')

sets = {
    'DCSO':  set_1,
    'aPhN1': set_2,
    'aPhN2': set_3,
}

# ----------------------------------------------------------------------------
# Helper: filter connections by source IDs, threshold syn_count, attach superclass
# ----------------------------------------------------------------------------
def build_hop_df(src_ids, connections, classification, min_syn=5):
    df = connections[connections['pre_root_id'].isin(src_ids)]
    summed = (
        df.groupby(['pre_root_id','post_root_id'], as_index=False)
          .agg({'syn_count':'sum'})
          .query('syn_count >= @min_syn')
    )
    merged = pd.merge(
        summed,
        classification.rename(
            columns={'root_id':'post_root_id','super_class':'output_super_class'}
        ),
        on='post_root_id', how='left'
    )
    return merged[['pre_root_id','post_root_id','output_super_class','syn_count']]

# ----------------------------------------------------------------------------
# Core: build and show a Sankey diagram with maximal vertical spacing
# ----------------------------------------------------------------------------
def plot_sankey_dynamic(grn_df, title, connections, classification, min_syn=5):
    # 1) build hops
    df1 = build_hop_df(grn_df['root_id'],            connections, classification, min_syn)
    df2 = build_hop_df(df1['post_root_id'].unique(), connections, classification, min_syn)
    df3 = build_hop_df(df2['post_root_id'].unique(), connections, classification, min_syn)

    # 2) summarize flows
    flow1 = (
        df1.groupby('output_super_class')['syn_count']
           .sum().reset_index(name='count')
           .assign(source=title)
    )
    m12 = pd.merge(df1, df2, left_on='post_root_id', right_on='pre_root_id', suffixes=('_1','_2'))
    flow2 = (
        m12.groupby(['output_super_class_1','output_super_class_2'])['syn_count_2']
          .sum().reset_index(name='count')
    )
    m23 = pd.merge(df2, df3, left_on='post_root_id', right_on='pre_root_id', suffixes=('_2','_3'))
    flow3 = (
        m23.groupby(['output_super_class_2','output_super_class_3'])['syn_count_3']
          .sum().reset_index(name='count')
    )

    # 3) build node labels
    col1 = [title]
    col2 = [f"1: {c}" for c in sorted(df1['output_super_class'].unique())]
    col3 = [f"2: {c}" for c in sorted(df2['output_super_class'].unique())]
    col4 = [f"3: {c}" for c in sorted(df3['output_super_class'].unique())]
    nodes = col1 + col2 + col3 + col4
    idx   = {n:i for i,n in enumerate(nodes)}

    # 4) color mapping
    palette = px.colors.qualitative.Safe
    all_classes = sorted(
        set(df1['output_super_class']) |
        set(df2['output_super_class']) |
        set(df3['output_super_class'])
    )
    color_map = {cls: palette[i % len(palette)] for i,cls in enumerate(all_classes)}
    node_colors = ['lightgrey' if n==title else color_map[n.split(': ',1)[1]] for n in nodes]

    # 5) assemble links
    source, target, value, link_colors = [], [], [], []
    def add_links(flow_df, src_col, tgt_col):
        for _, r in flow_df.iterrows():
            s = idx[r[src_col]]
            t = idx[r[tgt_col]]
            source.append(s)
            target.append(t)
            value.append(r['count'])
            link_colors.append(node_colors[s].replace('rgb','rgba').replace(')',',0.5)'))

    flow1 = flow1.rename(columns={'source':'src','output_super_class':'dst'})
    flow1['dst'] = flow1['dst'].map(lambda c: f"1: {c}")
    add_links(flow1, 'src','dst')

    flow2 = flow2.rename(columns={'output_super_class_1':'src','output_super_class_2':'dst'})
    flow2['src'] = flow2['src'].map(lambda c: f"1: {c}")
    flow2['dst'] = flow2['dst'].map(lambda c: f"2: {c}")
    add_links(flow2, 'src','dst')

    flow3 = flow3.rename(columns={'output_super_class_2':'src','output_super_class_3':'dst'})
    flow3['src'] = flow3['src'].map(lambda c: f"2: {c}")
    flow3['dst'] = flow3['dst'].map(lambda c: f"3: {c}")
    add_links(flow3, 'src','dst')

    # 6) hover info
    incoming = dict.fromkeys(nodes,0)
    outgoing = dict.fromkeys(nodes,0)
    for s,t,v in zip(source,target,value):
        outgoing[nodes[s]] += v
        incoming[nodes[t]]  += v
    customdata = [f"Incoming: {incoming[n]}<br>Outgoing: {outgoing[n]}" for n in nodes]

    # 7) x positions
    x = [0.0]*len(col1) + [0.33]*len(col2) + [0.66]*len(col3) + [1.0]*len(col4)

    # 8) y positions: spread evenly down the column
    y = []
    for col in (col1, col2, col3, col4):
        n = len(col)
        if n == 1:
            y.append(0.5)
        else:
            y.extend(np.linspace(0, 1, n))

    # 9) draw Sankey ('snap' arrangement: nodes start at x/y and snap to keep `pad` spacing)
    fig = go.Figure(go.Sankey(
        arrangement='snap',
        node=dict(
            label=nodes,
            x=x, y=y,
            color=node_colors,
            pad=15, thickness=20,
            line=dict(color='black', width=0.5),
            customdata=customdata,
            hovertemplate='%{customdata}<extra>%{label}</extra>'
        ),
        link=dict(source=source, target=target, value=value, color=link_colors)
    ))
    fig.update_layout(title_text=f"{title}", font_size=14)
    fig.write_image(f"{title.replace('/','_')}.svg", width=1200, height=800, scale=2)
    fig.show()

# ----------------------------------------------------------------------------
# 10. Generate a Sankey for each PSO set
# ----------------------------------------------------------------------------
for label, df in sets.items():
    plot_sankey_dynamic(df, label, connections, classification)


### Generate Sankey for each MxLbN set

In [None]:
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px

# ----------------------------------------------------------------------------
# 0. Load FlyWire connectome and classification
# ----------------------------------------------------------------------------
connections = pd.read_csv(
    './Giakoumas-et-al/flywire_data/connections.csv.gz'
)
classification = pd.read_csv(
    './Giakoumas-et-al/flywire_data/classification.csv.gz'
)[['root_id','super_class']]

# ----------------------------------------------------------------------------
# 1. Load your four MxLbN GRN lists (each CSV has a 'root_id' column).
# ----------------------------------------------------------------------------
sugar_water = pd.read_csv("./Giakoumas-et-al/input/MxLbN-SA/sugar_water_GRNs.csv")
bitter      = pd.read_csv("./Giakoumas-et-al/input/MxLbN-SA/bitter_GRNs.csv")
ir94e       = pd.read_csv("./Giakoumas-et-al/input/MxLbN-SA/Ir94e_GRNs.csv")
taste_peg   = pd.read_csv("./Giakoumas-et-al/input/MxLbN-SA/taste_peg_GRNs.csv")

sets = {
    'Sugar/Water': sugar_water,
    'Bitter':      bitter,
    'Ir94e':       ir94e,
    'Taste Peg':   taste_peg,
}

# ----------------------------------------------------------------------------
# Helper: filter connections by source IDs, threshold syn_count, attach superclass
# ----------------------------------------------------------------------------
def build_hop_df(src_ids, connections, classification, min_syn=5):
    df = connections[connections['pre_root_id'].isin(src_ids)]
    summed = (
        df.groupby(['pre_root_id','post_root_id'], as_index=False)
          .agg({'syn_count':'sum'})
          .query('syn_count >= @min_syn')
    )
    merged = pd.merge(
        summed,
        classification.rename(columns={'root_id':'post_root_id','super_class':'output_super_class'}),
        on='post_root_id', how='left'
    )
    return merged[['pre_root_id','post_root_id','output_super_class','syn_count']]

# ----------------------------------------------------------------------------
# Core: build and show a Sankey diagram using Plotly's Safe palette
# ----------------------------------------------------------------------------
def plot_sankey_dynamic(grn_df, title, connections, classification, min_syn=5):
    # build hop-level dataframes
    df1 = build_hop_df(grn_df['root_id'], connections, classification, min_syn)
    df2 = build_hop_df(df1['post_root_id'].unique(), connections, classification, min_syn)
    df3 = build_hop_df(df2['post_root_id'].unique(), connections, classification, min_syn)

    # summarize flows
    flow1 = (
        df1.groupby('output_super_class')['syn_count']
           .sum()
           .reset_index(name='count')
           .assign(source=title)
    )
    m12 = pd.merge(
        df1, df2,
        left_on='post_root_id', right_on='pre_root_id',
        suffixes=('_1','_2')
    )
    flow2 = (
        m12.groupby(['output_super_class_1','output_super_class_2'])['syn_count_2']
           .sum()
           .reset_index(name='count')
    )
    m23 = pd.merge(
        df2, df3,
        left_on='post_root_id', right_on='pre_root_id',
        suffixes=('_2','_3')
    )
    flow3 = (
        m23.groupby(['output_super_class_2','output_super_class_3'])['syn_count_3']
           .sum()
           .reset_index(name='count')
    )

    # assemble nodes
    col1 = [title]
    col2 = [f"1: {c}" for c in sorted(df1['output_super_class'].unique())]
    col3 = [f"2: {c}" for c in sorted(df2['output_super_class'].unique())]
    col4 = [f"3: {c}" for c in sorted(df3['output_super_class'].unique())]
    nodes = col1 + col2 + col3 + col4
    idx   = {node:i for i,node in enumerate(nodes)}

    # assign colors via Safe palette
    palette = px.colors.qualitative.Safe
    all_classes = sorted(
        set(df1['output_super_class'])
      | set(df2['output_super_class'])
      | set(df3['output_super_class'])
    )
    color_map = {cls: palette[i % len(palette)] for i,cls in enumerate(all_classes)}
    node_colors = [
        'lightgrey' if n == title else color_map[n.split(': ',1)[1]]
        for n in nodes
    ]

    # build link lists
    source, target, value, link_colors = [], [], [], []
    # flow1 links
    for _, r in flow1.iterrows():
        s = idx[title]
        t = idx[f"1: {r['output_super_class']}"]
        source.append(s); target.append(t); value.append(r['count'])
        link_colors.append(node_colors[s].replace('rgb','rgba').replace(')',',0.5)'))
    # flow2 links
    for _, r in flow2.iterrows():
        s = idx[f"1: {r['output_super_class_1']}"]
        t = idx[f"2: {r['output_super_class_2']}"]
        source.append(s); target.append(t); value.append(r['count'])
        link_colors.append(node_colors[s].replace('rgb','rgba').replace(')',',0.5)'))
    # flow3 links
    for _, r in flow3.iterrows():
        s = idx[f"2: {r['output_super_class_2']}"]
        t = idx[f"3: {r['output_super_class_3']}"]
        source.append(s); target.append(t); value.append(r['count'])
        link_colors.append(node_colors[s].replace('rgb','rgba').replace(')',',0.5)'))

    # compute hover info
    incoming, outgoing = dict.fromkeys(nodes,0), dict.fromkeys(nodes,0)
    for s,t,v in zip(source,target,value):
        outgoing[nodes[s]] += v
        incoming[nodes[t]]  += v
    customdata = [
        f"Incoming: {incoming[n]}<br>Outgoing: {outgoing[n]}"
        for n in nodes
    ]

    # set x-positions
    x = [0.0]*len(col1) + [0.33]*len(col2) + [0.66]*len(col3) + [1.0]*len(col4)

    # plot Sankey
    fig = go.Figure(go.Sankey(
        arrangement='snap',
        node=dict(
            label=nodes,
            x=x,
            color=node_colors,
            pad=15,
            thickness=20,
            line=dict(color='black', width=0.5),
            customdata=customdata,
            hovertemplate='%{customdata}<extra>%{label}</extra>'
        ),
        link=dict(source=source, target=target, value=value, color=link_colors)
    ))
    fig.update_layout(title_text=title, font_size=14)
    fig.show()

# ----------------------------------------------------------------------------
# Generate Sankey for each MxLbN set
# ----------------------------------------------------------------------------
for label, df in sets.items():
    plot_sankey_dynamic(df, label, connections, classification)


### Generate the three PSO Sankeys as one stacked figure

In [None]:
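# Only the three PSO (aPhN-SA) sets, stacked 3×1.
# Needs `connections`/`classification` from the cells above and
# `make_sankey_trace(...)`, which is defined in the final-figure section
# below; run that definition cell first.
import pandas as pd
from plotly.subplots import make_subplots

psos = [
    ('DCSO',  pd.read_csv('./Giakoumas-et-al/input/aPhN-SA/set_1.csv')),
    ('aPhN1', pd.read_csv('./Giakoumas-et-al/input/aPhN-SA/set_2.csv')),
    ('aPhN2', pd.read_csv('./Giakoumas-et-al/input/aPhN-SA/set_3.csv')),
]

fig = make_subplots(
    rows=len(psos), cols=1,
    vertical_spacing=0.02,
    specs=[[{"type": "sankey"}] for _ in psos]
)
for i, (label, df) in enumerate(psos, start=1):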
    sankey_trace = make_sankey_trace(
        grn_df         = df,
        title          = label,
        connections    = connections,
        classification = classification,
        min_syn        = 5
    )
    fig.add_trace(sankey_trace, row=i, col=1)

fig.update_layout(
    height=800,    # 3 panels × 250 px each = 750 px + small margins
    width=600,
    margin=dict(l=20, r=20, t=30, b=20),
    font=dict(size=10),
    title="PSO Sets Sankeys (stacked)",
)
fig.show()


### Main figure for Supplementary Figure S3 to be edited in Adobe Illustrator: all 13 Sankey diagrams in one portrait figure

In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

def make_sankey_trace(grn_df, title, connections, classification, min_syn=5):
    # Build hop‐level dataframes without using .query()
    def build_hop_df(src_ids):
        df = connections[connections['pre_root_id'].isin(src_ids)]
        summed = (
            df.groupby(['pre_root_id', 'post_root_id'], as_index=False)
              .agg({'syn_count': 'sum'})
        )
        summed = summed[summed['syn_count'] >= min_syn]
        merged = pd.merge(
            summed,
            classification.rename(
                columns={'root_id': 'post_root_id', 'super_class': 'output_super_class'}
            ),
            on='post_root_id', how='left'
        )
        return merged[['pre_root_id', 'post_root_id', 'output_super_class', 'syn_count']]

    # 1) Build the three hops
    df1 = build_hop_df(grn_df['root_id'])
    df2 = build_hop_df(df1['post_root_id'].unique())
    df3 = build_hop_df(df2['post_root_id'].unique())

    # 2) Summarize flows
    flow1 = (
        df1.groupby('output_super_class')['syn_count']
           .sum().reset_index(name='count')
           .assign(source=title)
    )
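    # Join hop tables on the intermediate neuron: each Hop-2 edge is counted
    # once per above-threshold Hop-1 edge ending on its presynaptic neuron.
    m12 = pd.merge(
        df1, df2,
        left_on='post_root_id', right_on='pre_root_id',
        suffixes=('_1', '_2')
    )
    flow2 = (
        m12.groupby(['output_super_class_1', 'output_super_class_2'])['syn_count_2']
          .sum().reset_index(name='count')
    )
    # Same join for Hop 2 → Hop 3 (each Hop-3 edge counted once per Hop-2
    # edge ending on its presynaptic neuron).
    m23 = pd.merge(
        df2, df3,
        left_on='post_root_id', right_on='pre_root_id',
        suffixes=('_2', '_3')
    )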
    flow3 = (
        m23.groupby(['output_super_class_2', 'output_super_class_3'])['syn_count_3']
          .sum().reset_index(name='count')
    )

    # 3) Node labels
    col1 = [title]
    col2 = [f"1: {c}" for c in sorted(df1['output_super_class'].unique())]
    col3 = [f"2: {c}" for c in sorted(df2['output_super_class'].unique())]
    col4 = [f"3: {c}" for c in sorted(df3['output_super_class'].unique())]
    nodes = col1 + col2 + col3 + col4
    idx = {n: i for i, n in enumerate(nodes)}

    # 4) Node colors
    palette = px.colors.qualitative.Safe
    all_classes = sorted(
        set(df1['output_super_class']) |
        set(df2['output_super_class']) |
        set(df3['output_super_class'])
    )
    color_map = {cls: palette[i % len(palette)] for i, cls in enumerate(all_classes)}
    node_colors = [
        'lightgrey' if n == title else color_map[n.split(': ', 1)[1]]
        for n in nodes
    ]

    # 5) Build link lists
    source, target, value, link_colors = [], [], [], []

    # flow1: seed set → "1: <superclass>" nodes
    flow1 = flow1.rename(columns={'source': 'src', 'output_super_class': 'dst'})
    flow1['dst'] = flow1['dst'].map(lambda c: f"1: {c}")
    for _, r in flow1.iterrows():
        s = idx[r['src']]
        t = idx[r['dst']]
        source.append(s)
        target.append(t)
        value.append(r['count'])
        link_colors.append(node_colors[s].replace('rgb', 'rgba').replace(')', ',0.5)'))

    # flow2: "1: <superclass>" → "2: <superclass>"
    flow2 = flow2.rename(columns={'output_super_class_1': 'src', 'output_super_class_2': 'dst'})
    flow2['src'] = flow2['src'].map(lambda c: f"1: {c}")
    flow2['dst'] = flow2['dst'].map(lambda c: f"2: {c}")
    for _, r in flow2.iterrows():
        s = idx[r['src']]
        t = idx[r['dst']]
        source.append(s)
        target.append(t)
        value.append(r['count'])
        link_colors.append(node_colors[s].replace('rgb', 'rgba').replace(')', ',0.5)'))

    # flow3: "2: <superclass>" → "3: <superclass>"
    flow3 = flow3.rename(columns={'output_super_class_2': 'src', 'output_super_class_3': 'dst'})
    flow3['src'] = flow3['src'].map(lambda c: f"2: {c}")
    flow3['dst'] = flow3['dst'].map(lambda c: f"3: {c}")
    for _, r in flow3.iterrows():
        s = idx[r['src']]
        t = idx[r['dst']]
        source.append(s)
        target.append(t)
        value.append(r['count'])
        link_colors.append(node_colors[s].replace('rgb', 'rgba').replace(')', ',0.5)'))

    # 6) Compute hover info
    incoming = dict.fromkeys(nodes, 0)
    outgoing = dict.fromkeys(nodes, 0)
    for s_, t_, v_ in zip(source, target, value):
        outgoing[nodes[s_]] += v_
        incoming[nodes[t_]] += v_
    customdata = [f"Incoming: {incoming[n]}<br>Outgoing: {outgoing[n]}" for n in nodes]

    # 7) X positions
    x0 = [0.0] * len(col1)
    x1 = [0.33] * len(col2)
    x2 = [0.66] * len(col3)
    x3 = [1.0] * len(col4)
    x = x0 + x1 + x2 + x3

    # 8) Y positions
    y = []
    for col in (col1, col2, col3, col4):
        n = len(col)
        if n == 1:
            y.append(0.5)
        else:
            y.extend(list(np.linspace(0, 1, n)))

    # 9) Build and return a bare go.Sankey trace (not a go.Figure) so it can be added to a subplot grid
    sankey = go.Sankey(
        name=title,
        arrangement='snap',
        node=dict(
            label=nodes,
            x=x,
            y=y,
            color=node_colors,
            pad=8,
            thickness=12,
            line=dict(color='black', width=0.3),
            customdata=customdata,
            hovertemplate='%{customdata}<extra>%{label}</extra>'
        ),
        link=dict(source=source, target=target, value=value, color=link_colors)
    )

    return sankey


In [None]:
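# Stack all 13 Sankey traces in a single portrait figure.
# Uses `connections` and `classification` loaded in the exploratory cells
# and `make_sankey_trace` from the previous cell.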

# (A) Gather all dataframes:
sets_phn_sa = {
    f'PhN-SA_v2_{i}': pd.read_csv(
        f'./Giakoumas-et-al/input/PhN/set_{i}.csv'
    )
    for i in range(1, 7)
}
sets_pso = {
    'DCSO':  pd.read_csv('./Giakoumas-et-al/input/aPhN-SA/set_1.csv'),
    'aPhN1': pd.read_csv('./Giakoumas-et-al/input/aPhN-SA/set_2.csv'),
    'aPhN2': pd.read_csv('./Giakoumas-et-al/input/aPhN-SA/set_3.csv'),
}
sets_mxlbn = {
    'Sugar/Water': pd.read_csv("./Giakoumas-et-al/input/MxLbN-SA/sugar_water_GRNs.csv"),
    'Bitter':      pd.read_csv("./Giakoumas-et-al/input/MxLbN-SA/bitter_GRNs.csv"),
    'Ir94e':       pd.read_csv("./Giakoumas-et-al/input/MxLbN-SA/Ir94e_GRNs.csv"),
    'Taste Peg':   pd.read_csv("./Giakoumas-et-al/input/MxLbN-SA/taste_peg_GRNs.csv"),
}

all_sets = []
all_sets += list(sets_phn_sa.items())
all_sets += list(sets_pso.items())
all_sets += list(sets_mxlbn.items())

n_plots = len(all_sets)  # 13

# (B) Create subplots (13 rows × 1 column, each type="sankey")
fig = make_subplots(
    rows = n_plots,
    cols = 1,
    shared_xaxes = False,
    shared_yaxes = False,
    vertical_spacing = 0.02,
    specs = [[{"type": "sankey"}] for _ in range(n_plots)],
)

# (C) Add each Sankey trace into its own row
for i, (label, df) in enumerate(all_sets, start=1):
    sankey_trace = make_sankey_trace(
        grn_df         = df,
        title          = label,
        connections    = connections,
        classification = classification,
        min_syn        = 5
    )
    fig.add_trace(sankey_trace, row=i, col=1)

# (D) Layout tweaks for portrait
fig.update_layout(
    height = 300 * n_plots,    # 300 px per panel = 3900 px for 13 panels
    width  = 534,              # ≈ one third of the original 1600 px width
    margin = dict(l=20, r=20, t=40, b=20),
    font   = dict(size=10),
    title  = "All Sankey Panels in One Portrait Figure",
)

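fig.show()

# Export as a single SVG. Static export relies on the kaleido package;
# create the output folder first in case it does not exist yet.
import os
out_path = "./Giakoumas-et-al/output/figures/fig_S3/all_sankeys_portrait.svg"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
fig.write_image(out_path)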
