# HIERARCHICAL GATING: Two-Stage Sequential Gating Pipeline

This notebook implements a two-stage hierarchical gating strategy:

**Stage 1:** Gate entire population on `avg_speed_um_s` vs `self_intersections` (scaled)
- Creates 4 gates in population-relative space
- Adds columns: `gate_id`, `gate_name`

**Stage 2:** Hierarchically gate a subset on `anomalous_exponent` vs `cum_displacement_um` (unscaled)
- Only evaluates rows from Gate 1's `low_speed_low_intersections` population
- Adds columns: `gate2_id`, `gate2_name`

**Final:** Manual classifier combining `gate_name` + `gate2_name` → overall population names (`final_population`)


## 🎨 Setup: Colors & Imports


In [None]:
# --- Colour palettes used by the figures below ------------------------------

# Full "Raiders" palette: two text colours (dark, light) followed by four accents.
raiders = ['#140f19', '#ebe1ce', '#ed1e24', '#fbd10c', '#6cc176', '#95d6d7']
raiders_text = raiders[:2]      # dark / light text colours
raiders_colors = raiders[2:]    # the four accent colours

# Colour-blind-friendly categorical palette (10 colours).
colorblind_colors = [
    '#0173B2', '#DE8F05', '#029E73', '#CC78BC', '#CA9161',
    '#FBAFE4', '#949494', '#ECE133', '#56B4E9', '#D55E00',
]

# Raiders accents with the yellow swapped for the colour-blind orange
# (used for the curated gallery and the gate overlays).
raiders_colors2 = [raiders_colors[0], colorblind_colors[1], *raiders_colors[2:]]

# Four fill colours for the Stage-1 gates: the last four colour-blind colours.
gate_fill_colors = colorblind_colors[6:]


In [None]:
from __future__ import division, unicode_literals, print_function
import sys
sys.path.append('../src')

import SPTnano as spt

# Settings exported by SPTnano's config module, bound to short notebook-level names.
_config = spt.config

master, saved_data = _config.MASTER, _config.SAVED_DATA

pixelsize_microns = _config.PIXELSIZE_MICRONS
time_between_frames = _config.TIME_BETWEEN_FRAMES
orderofconditions = _config.ORDEROFCONDITIONS
features = _config.FEATURES
min_track_length = _config.TIME_WINDOW  # the configured window length is reused as the minimum track length

from IPython.display import Markdown, display
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pandas import DataFrame, Series
import pims
import trackpy as tp
import os
import glob
import nd2
import seaborn as sns

%matplotlib inline

mpl.rc('figure', figsize=(10, 5))
mpl.rc('image', cmap='gray')
sns.set_context("notebook", rc={"xtick.labelsize": 10, "ytick.labelsize": 10})


## Load Data


In [None]:
import polars as pl

driveletter = 'F:'

# Input locations, all rooted at `driveletter`.
MASTER_DIR = driveletter + "/Analyzed/"
SAVE_DIR = driveletter + "/Analyzed/FIXED_WINDOW_UIDS_20251030_155925"

# Instant (presumably per-frame) and windowed tables, read in that order.
instant_df, windowed_df = (
    pl.read_parquet(os.path.join(SAVE_DIR, parquet_name))
    for parquet_name in ('master_instant_df_FIXED.parquet', 'master_windowed_df_FIXED.parquet')
)

# Stored folder paths reference drive Z:; rewrite them for the current drive.
for_current_drive = pl.col('foldername').str.replace('Z:', driveletter)
instant_df = instant_df.with_columns(for_current_drive)
windowed_df = windowed_df.with_columns(for_current_drive)

print(f"Updated drive letter from Z: to {driveletter} in foldername columns")

# Discard windows whose fit was flagged as bad, then keep only the
# instant rows whose window survived.
print("Before removing bad fits:")
print(f"windowed_df shape: {windowed_df.shape}")

windowed_df = windowed_df.filter(pl.col('bad_fit_flag') == False)  # noqa: E712 (Polars expression, not a Python bool)

unique_window_uids = windowed_df.select('window_uid').unique()
print(f"Number of unique window_uids: {unique_window_uids.height}")

instant_df = instant_df.filter(
    pl.col('window_uid').is_in(unique_window_uids.get_column('window_uid'))
)

print(f"\n‚úÖ Data loaded:")
print(f"   Windowed: {windowed_df.shape}")
print(f"   Instant: {instant_df.shape}")


### (Optional) Gallery view of curated tracks for Gate 1

**Note:** This section is for visualizing a small curated subset of tracks to understand behaviors before gating.
Skip if you want to go straight to gating the full population.


In [None]:
# Optional: Create curated subset with behavior labels for visualization
# Skip this if you don't have pre-labeled tracks

# Example structure (customize with your window_UIDs):
# Hand-curated window_uids grouped by hand-assigned behaviour label.
# The dict keys become the `behavior_type` values used by the curated gallery and Plot 1.
# Entries commented out with '#' inside the lists are windows deliberately left out.
chosen_window_UIDs = {
    "transport": ['eekrw_131_7338_R1_1_2849.0_2908.0', 
'eekrw_132_7896_R1_10_301.0_360.0', 'eekrw_132_7896_R1_11_331.0_390.0', 'eekrw_132_7896_R1_9_271.0_330.0',
 'eekrw_72_27796_R1_4_2092.0_2151.0', 'eekrw_72_27796_R1_5_2122.0_2181.0',
'eekrw_64_26280_R1_0_2401.0_2461.0',
'eekrw_119_3224_R1_5_151.0_210.0', 'eekrw_119_3224_R1_6_181.0_240.0',
'eekrw_122_4897_R1_0_3896.0_3955.0', 'eekrw_122_4897_R1_1_3926.0_3985.0', 'eekrw_122_4897_R1_2_3956.0_4015.0',
'eekrw_124_5697_R1_0_822.0_882.0', 'eekrw_124_5697_R1_1_853.0_912.0',
# 'eekrw_124_5718_R1_0_1196.0_1255.0',
 'eekrw_124_5718_R1_1_1226.0_1285.0',
'eekrw_64_26137_R1_0_3723.0_3784.0',
#  'eekrw_64_26137_R1_1_3755.0_3814.0', 
# 'eekrw_64_26137_R1_2_3785.0_3844.0',
# NOTE(review): the next two windows repeat eekrw_124_5697 entries already listed above.
# The uid -> category mapping built from this dict collapses them, so they do not double-count.
'eekrw_124_5697_R1_0_822.0_882.0', 'eekrw_124_5697_R1_1_853.0_912.0',
'eekrw_63_26099_R1_0_2180.0_2239.0', 
'eekrw_123_5096_R1_0_1923.0_1983.0'],
    "bound": ['eemrw_33_3980_R1_10_1876.0_1936.0', 'eemrw_33_3980_R1_8_1816.0_1875.0', 'eemrw_33_3980_R1_9_1846.0_1905.0',
'eekrw_64_26165_R1_0_4231.0_4294.0', 'eekrw_64_26165_R1_1_4263.0_4327.0',
'eemrw_33_3513_R1_16_3418.0_3477.0', 'eemrw_33_3513_R1_17_3448.0_3507.0', 'eemrw_33_3513_R1_18_3478.0_3537.0',
'eemrw_42_5444_R1_23_996.0_1055.0', 'eemrw_42_5444_R1_24_1026.0_1085.0', 'eemrw_42_5444_R1_25_1056.0_1115.0',
'eemrw_34_4506_R1_0_2043.0_2106.0', 'eemrw_34_4506_R1_1_2077.0_2137.0',
'eekrw_77_28613_R1_35_3643.0_3702.0', 'eekrw_77_28613_R1_36_3673.0_3732.0', 'eekrw_77_28613_R1_37_3703.0_3762.0',
'eemrw_34_4188_R1_1_5333.0_5392.0',
'eemrw_33_3399_R1_3_2262.0_2322.0',
'eemrw_42_5293_R1_23_691.0_750.0', 'eemrw_42_5293_R1_24_721.0_780.0', 'eemrw_42_5293_R1_25_751.0_810.0',
'eekrw_66_26462_R1_10_3647.0_3706.0', 'eekrw_66_26462_R1_11_3677.0_3736.0', 'eekrw_66_26462_R1_12_3707.0_3766.0',
'eeh2h_eeh2x_22_3800_R2_R2_1_5583.0_5642.0', 'eeh2h_eeh2x_22_3800_R2_R2_2_5613.0_5672.0', 'eeh2h_eeh2x_22_3800_R2_R2_3_5643.0_5702.0'],
    "transient": ['eeh2h_38_12054_R1_0_2432.0_2493.0',
'eekrw_123_5265_R1_1_5141.0_5200.0',
'eekrw_73_28118_R1_0_1708.0_1768.0', 'eekrw_73_28118_R1_1_1739.0_1798.0', 'eekrw_73_28118_R1_2_1769.0_1828.0',
'eeh2h_30_10132_R1_0_2059.0_2118.0',
'eemrw_37_4775_R1_0_2880.0_2950.0', 'eemrw_37_4775_R1_1_2919.0_2983.0',
'eemrw_34_4032_R1_0_2811.0_2872.0',
'eemrw_33_3474_R1_0_2638.0_2699.0',
'eemrw_40_5120_R1_0_3095.0_3155.0',
'eeh2h_eeh2x_31_7315_R2_R2_0_5427.0_5492.0',
'eeh2h_eeh2x_4_10006_R1_R1_0_5714.0_5773.0']}

# Invert {category: [uid, ...]} into {uid: category}.
# A uid repeated within one category is harmless. A uid listed under two
# different categories is a curation mistake, so raise instead of silently
# keeping whichever category happened to come last.
uid_to_category = {}
for category, uids in chosen_window_UIDs.items():
    for uid in uids:
        assigned = uid_to_category.setdefault(uid, category)
        if assigned != category:
            raise ValueError(
                f"window_uid {uid!r} is listed under both {assigned!r} and {category!r}"
            )

all_chosen_uids = list(uid_to_category.keys())

# Every row left after the is_in() filter has a key in uid_to_category, so a plain
# replace() maps all of them. This avoids `default=`, which is deprecated in Polars >= 1.0.
filtered_windowed_df = windowed_df.filter(pl.col('window_uid').is_in(all_chosen_uids))
filtered_windowed_df = filtered_windowed_df.with_columns(
    pl.col('window_uid').replace(uid_to_category).alias('behavior_type')
)

filtered_instant_df = instant_df.filter(pl.col('window_uid').is_in(all_chosen_uids))
filtered_instant_df = filtered_instant_df.with_columns(
    pl.col('window_uid').replace(uid_to_category).alias('behavior_type')
)

# Curated windows can disappear upstream (e.g. removed by the bad-fit filter); say which.
found_uids = set(filtered_windowed_df.get_column('window_uid').to_list())
missing_uids = [uid for uid in all_chosen_uids if uid not in found_uids]
if missing_uids:
    print(f"Warning: {len(missing_uids)} curated window_uid(s) not found in windowed_df: {missing_uids}")

print(f"Filtered to {len(filtered_windowed_df)} curated tracks")
print(filtered_windowed_df.group_by('behavior_type').agg(pl.len().alias('count')))


In [None]:
# Gallery of the curated tracks, coloured by their hand-assigned behavior_type.
# Needs `filtered_instant_df` from the curated-subset cell above; skip this cell
# if you did not build that subset.

# Figures go to the same drive as the data, so changing `driveletter` moves both.
saved_dir = driveletter + '/plots/finals/'
os.makedirs(saved_dir, exist_ok=True)  # make sure the export folder exists before saving

gallery_result = spt.gallery_of_tracks_v4(
    filtered_instant_df,
    color_by="behavior_type",
    num_tracks=25,
    order=['transport', 'bound', 'transient'],
    custom_colors=raiders_colors2,
    track_length_frames=60,
    spacing_factor=1.0,
    line_width=1.0,
    figsize=(12, 8),
    text_size=5,
    show_annotations=True,
    annotation="{window_uid}",
    annotation_color="w",
    transparent_background=True,
    save_path=saved_dir + "gallery_curated_tracks_gate1.svg",
)

print(f"Gallery created with {gallery_result['total_tracks']} tracks")


In [None]:
# Cell 6: Markdown
## STAGE 1: First Gating (Speed vs Intersections)
### Fit scaler on full population

# Cell 7: Python
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.pipeline import Pipeline

feat_x_gate1 = 'avg_speed_um_s'
feat_y_gate1 = 'self_intersections'

# z-score each feature, then rescale it to [0, 1]. The scaler is fitted on the *whole*
# population, so the Stage-1 gate coordinates are population-relative.
scaler_gate1 = Pipeline(steps=[
    ('std', StandardScaler()),
    ('minmax', MinMaxScaler()),
])
scaler_gate1.fit(windowed_df.select([feat_x_gate1, feat_y_gate1]))

print("=" * 80, "‚úÖ SCALER FITTED ON FULL POPULATION", "=" * 80, sep="\n")
print(f"   Features: {feat_x_gate1}, {feat_y_gate1}")
print(f"   Total tracks: {len(windowed_df):,}")

# Stage-1 ROI manager; it receives scaler_gate1, so gates are presumably
# interpreted in the scaled [0, 1] space.
roi_manager_gate1 = spt.ROIManager(
    df=windowed_df, x_col=feat_x_gate1, y_col=feat_y_gate1, scaler=scaler_gate1,
)
print(f"‚úÖ Gate 1 ROIManager created for {len(windowed_df):,} tracks")

# Cell 9: Python - Define Gates
roi_manager_gate1.clear_gates()

# Stage-1 rectangles in scaled space, added in this order:
#   (name, x_min, x_max, y_min, y_max)
# The intersection split (0.14) is shared by all four gates, but the speed split
# is 0.2 in the low-intersection band and 0.3 in the high-intersection band.
# NOTE(review): confirm the differing speed cut-offs are intentional.
for rect_name, rx0, rx1, ry0, ry1 in (
    ("low_speed_low_intersections",   0.0, 0.2, 0.0,  0.14),
    ("high_speed_low_intersections",  0.2, 1.0, 0.0,  0.14),
    ("low_speed_high_intersections",  0.0, 0.3, 0.14, 1.0),
    ("high_speed_high_intersections", 0.3, 1.0, 0.14, 1.0),
):
    roi_manager_gate1.add_rectangle_gate(
        x_min=rx0, x_max=rx1, y_min=ry0, y_max=ry1, name=rect_name,
    )

print("‚úÖ Gate 1 gates defined")

# Cell 10: Python - Apply Gate 1
# Assign every window to a Stage-1 gate (adds the gate_id / gate_name columns).
windowed_df_classified = roi_manager_gate1.classify_data(
    windowed_df, x_col=feat_x_gate1, y_col=feat_y_gate1,
)
summary_gate1 = roi_manager_gate1.get_gate_summary(windowed_df_classified)

print("\n" + "=" * 80, "‚úÖ GATE 1: FULL DATASET CLASSIFIED", "=" * 80, sep="\n")
print(f"Total tracks: {summary_gate1['total_points']:,}")
print(f"\nüìä Gate 1 breakdown:")
for gate in summary_gate1['gates']:
    label, n_in_gate, share = gate['name'], gate['count'], gate['percent']
    print(f"   {label:35s}: {n_in_gate:8,} tracks ({share:5.1f}%)")

In [None]:
# Cell 11: Markdown
## 🎯 STAGE 2: Hierarchical Gating (Anomalous Exponent vs Displacement)
### Only gate the `low_speed_low_intersections` population

# Cell 12: Python - Setup Gate 2
feat_x_gate2 = 'anomalous_exponent'
feat_y_gate2 = 'cum_displacement_um'

# Stage 2 only refines this Stage-1 gate. Look its numeric gate_id up from the
# classified data instead of hard-coding 0. The hard-coded value relied on the gate
# being added first, so reordering Stage-1 gates would silently point Stage 2 at the
# wrong population.
parent_gate_name = 'low_speed_low_intersections'
parent_gate_ids_found = (
    windowed_df_classified
    .filter(pl.col('gate_name') == parent_gate_name)
    .get_column('gate_id')
    .unique()
    .to_list()
)
if len(parent_gate_ids_found) != 1:
    raise ValueError(
        f"Expected exactly one gate_id for {parent_gate_name!r}, found {parent_gate_ids_found}"
    )
parent_gate_id = parent_gate_ids_found[0]

parent_pop = windowed_df_classified.filter(pl.col('gate_id') == parent_gate_id)
print(f"Parent gate: {parent_gate_name} (gate_id={parent_gate_id})")
print(f"Parent population: {len(parent_pop):,} tracks")

# Cell 13: Python - Create ROI Manager for Gate 2
# Stage-2 manager: no scaler, so polygons are given in raw feature units.
roi_manager_gate2 = spt.ROIManager(
    df=windowed_df_classified,
    x_col=feat_x_gate2,
    y_col=feat_y_gate2,
    scaler=None,
)
print("‚úÖ Gate 2 ROIManager created (unscaled mode)")

# Two quadrilaterals in (anomalous_exponent, cum_displacement_um) space.
# They share the diagonal edge (1.1, 0.0) -> (2.25, 8.0), which separates a
# high-alpha side from a low-alpha side. Together they cover alpha in [-0.5, 3.0]
# and displacement in [0, 8.0].
# (An earlier version capped displacement at 7.0 and started the low-alpha side at alpha = 0.)
coords_high_a_low_d = [(1.1, 0.0), (2.25, 8.0), (3.0, 8.0), (3.0, 0.0)]
coords_low_a_high_d = [(-0.5, 0.0), (-0.5, 8.0), (2.25, 8.0), (1.1, 0.0)]

for polygon_coords, polygon_name in (
    (coords_high_a_low_d, 'high_a_low_displacement'),
    (coords_low_a_high_d, 'low_a_high_displacement'),
):
    roi_manager_gate2.add_polygon_gate(polygon_coords, name=polygon_name)

print("‚úÖ Gate 2 gates defined")

# Cell 15: Python - Apply Gate 2 HIERARCHICALLY
# Hierarchical Stage 2: only rows whose Stage-1 `gate_id` equals `parent_gate_id`
# are tested against the Gate 2 polygons. All other rows are given
# `not_applicable_label` (presumably written into the new `gate2_name` column).
windowed_df_classified = roi_manager_gate2.classify_data(
    windowed_df_classified,
    x_col=feat_x_gate2,
    y_col=feat_y_gate2,
    gate_col_name='gate2_id',
    gate_name_col='gate2_name',
    parent_gate_col='gate_id',
    parent_gate_ids=parent_gate_id,
    not_applicable_label='not_in_parent_gate',
)

print("\n‚úÖ GATE 2: HIERARCHICAL GATING COMPLETE")
gate2_tally = (
    windowed_df_classified
    .group_by('gate2_name')
    .agg(count=pl.len())
    .sort(by='count', descending=True)
)
print(gate2_tally)

In [None]:
# Cell 21:  Map to instant_df
# Propagate the per-window gate labels to every instant row of that window.
columns_to_map = ['gate_id', 'gate_name', 'gate2_id', 'gate2_name']

gate_mapping_df = windowed_df_classified.select(['window_uid'] + columns_to_map)

# A left join against a key that is duplicated on the right silently duplicates
# rows of instant_df, so insist on exactly one mapping row per window_uid.
n_duplicate_uids = gate_mapping_df.height - gate_mapping_df.get_column('window_uid').n_unique()
if n_duplicate_uids:
    raise ValueError(
        f"window_uid is not unique in windowed_df_classified ({n_duplicate_uids} duplicate rows)"
    )

instant_df_classified = instant_df.join(
    gate_mapping_df,
    on='window_uid',
    how='left',
)

print("‚úÖ Mapped gates to instant_df")

# Cell 22: Markdown
## 🏷️ FINAL: Manual Population Classifier

# Cell 23: Python - Check combinations
# How many windows fall into each (Stage-1, Stage-2) gate combination, most common first.
crosstab = (
    windowed_df_classified
    .group_by('gate_name', 'gate2_name')
    .agg(count=pl.len())
    .sort(by='count', descending=True)
)
print("üìä Gate combinations:")
print(crosstab)

# Cell 24: Python - Manual classifier function
# Stage-1 gate whose windows are subdivided by Stage 2.
_PARENT_GATE1 = 'low_speed_low_intersections'

# Final label for each Stage-2 result inside the parent gate.
_STAGE2_LABELS = {
    'high_a_low_displacement': 'superdiffusive_transport',
    'low_a_high_displacement': 'subdiffusive_motion',
    'ungated': 'ungatedgate2',
}

# Final label for every other Stage-1 gate (Stage 2 is not consulted).
_STAGE1_LABELS = {
    'high_speed_low_intersections': 'fast_exploratory',
    'low_speed_high_intersections': 'bound_stationary',
    'high_speed_high_intersections': 'fast_exploratory',
    'ungated': 'ungatedgate1',
}


def classify_population(gate_name, gate2_name):
    """Map a (Stage-1, Stage-2) gate-name pair to a final population label.

    Windows in the parent gate ('low_speed_low_intersections') are labelled by
    their Stage-2 gate; every other Stage-1 gate maps straight to a label.
    Any combination not listed in the lookup tables, including missing (None)
    names, yields 'other'.
    """
    if gate_name == _PARENT_GATE1:
        table, key = _STAGE2_LABELS, gate2_name
    else:
        table, key = _STAGE1_LABELS, gate_name
    try:
        return table.get(key, 'other')
    except TypeError:  # unhashable label: cannot match any known gate name
        return 'other'

# Apply
def _with_final_population(frame):
    """Return `frame` with a `final_population` column derived via classify_population()."""
    return frame.with_columns(
        pl.struct(['gate_name', 'gate2_name'])
        .map_elements(
            lambda row: classify_population(row['gate_name'], row['gate2_name']),
            return_dtype=pl.Utf8,
        )
        .alias('final_population')
    )


windowed_df_final = _with_final_population(windowed_df_classified)
instant_df_final = _with_final_population(instant_df_classified)

print("‚úÖ Final populations:")
print(
    windowed_df_final
    .group_by('final_population')
    .agg(count=pl.len())
    .sort(by='count', descending=True)
)

# Cell 25: Python - Save
# Write both final tables into a fresh, timestamped folder on the analysis drive.
run_stamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")
OUTPUT_DIR = f"{driveletter}/Analyzed/HIERARCHICAL_GATES_{run_stamp}"
os.makedirs(OUTPUT_DIR, exist_ok=True)

for final_frame, parquet_name in (
    (windowed_df_final, 'windowed_df_hierarchical_gates.parquet'),
    (instant_df_final, 'instant_df_hierarchical_gates.parquet'),
):
    final_frame.write_parquet(os.path.join(OUTPUT_DIR, parquet_name))

print(f"‚úÖ SAVED to {OUTPUT_DIR}")

In [None]:
# Restrict the *classified* frames to the conditions plotted below:
# cell type 'ES', location 'ES', molecules HTT / kinesin / myosin, genotypes 20H20S / RUES2.
# A row is kept only if it passes every criterion. The filter is idempotent, so
# re-running this cell (or its duplicate further down) is harmless.
# NOTE(review): only *_classified are filtered. windowed_df_final / instant_df_final
# (and the parquet files saved above) keep the unfiltered data, so the plots that
# read *_final (Plots 7-9) are NOT restricted to these conditions.
types = ['ES']
locations = ['ES']
mols = ['HTT', 'kinesin', 'myosin']
genos = ['20H20S', 'RUES2']

selection_criteria = {
    'type': types,
    'location': locations,
    'mol': mols,
    'geno': genos,
    # Optional extra restrictions:
    # 'group': ['HTT150'],
    # 'gate_name': ['low_speed_low_intersections'],
}
selection_predicates = [
    pl.col(column).is_in(allowed) for column, allowed in selection_criteria.items()
]

instant_df_classified = instant_df_classified.filter(*selection_predicates)
windowed_df_classified = windowed_df_classified.filter(*selection_predicates)

In [None]:
#  Configure Polars display options for full output
pl.Config.set_tbl_rows(-1)  # Show all rows (no limit)
pl.Config.set_tbl_cols(-1)  # Show all columns (no limit)
pl.Config.set_tbl_width_chars(1000)  # Increase table width


# Count unique cells per location and condition
location_cell_counts = (instant_df_classified
    .group_by(['mol', 'cell', 'type', 'location','geno', 'group', 'replicate', 'gate_name'])
    .agg([
        pl.col('filename').n_unique().alias('n_cells'),
        pl.col('unique_id').n_unique().alias('n_tracks'),
        pl.len().alias('n_datapoints')
    ])
    .sort(['mol','cell', 'type', 'location','geno', 'group', 'replicate', 'gate_name'])
)

print("Number of cells per location and condition:")
print(location_cell_counts)

# # Export to CSV
# location_cell_counts.write_csv(driveletter + '/plots/' + "location_cell_counts.csv")
# print("Exported location_cell_counts to 'location_cell_counts.csv'")


---

## 📊 GATE 1 VISUALIZATIONS




In [None]:
### Plot 1: curated tracks in scaled Gate 1 space (no gates drawn here; Plot 2 adds them)

# Output folder and category orders shared by the Gate 1 plots.
saved_dir = 'F:/plots/finals/'
order = ['transport', 'bound', 'transient']  # behaviour categories of the curated subset
gateorder_gate1 = [
    'low_speed_low_intersections',
    'high_speed_low_intersections',
    'low_speed_high_intersections',
    'high_speed_high_intersections',
]

# Project through the Stage-1 scaler so both axes run 0..1, like the gate coordinates.
gate1_axes = dict(
    scale_data=True,
    scaler=scaler_gate1,
    scale_method='standard_minmax',
    xlim=(0, 1),
    ylim=(0, 1),
    log_scale=False,
)
figure_export = dict(
    transparent_background=True,
    line_color='k',
    export_format='svg',
    save_path=saved_dir,
    dpi=300,
)

# For the full dataset, pass windowed_df_classified and color_by='mol' instead.
fig, ax = spt.plot_xy_heatmap(
    filtered_windowed_df,
    x_col=feat_x_gate1,
    y_col=feat_y_gate1,
    plot_type='scatter',
    color_by='behavior_type',
    order=order,
    cmap='colorblind',
    small_multiples=False,
    contour_levels=None,
    contour_cmap='viridis',
    figsize=(8, 6),
    s=30,
    alpha=0.9,
    **gate1_axes,
    **figure_export,
)

print("‚úÖ Gate 1 scatter plot created")



In [None]:
### Plot 2: curated tracks in scaled Gate 1 space with the Gate 1 rectangles overlaid

gateorder_gate1 = [
    'low_speed_low_intersections',
    'high_speed_low_intersections',
    'low_speed_high_intersections',
    'high_speed_high_intersections',
]

# Dashed outlines with a faint fill; labels placed automatically.
gate1_overlay = dict(
    gates=roi_manager_gate1,
    gate_order=gateorder_gate1,
    gate_colors=raiders_colors2,
    gate_linestyle='--',
    gate_alpha=0.1,
    gate_edge_alpha=1.0,
    gate_linewidth=1.3,
    gate_label_position='auto',
    gate_text_size=10,
    gate_label_border=False,
)

fig, ax = spt.plot_xy_heatmap(
    filtered_windowed_df,       # or windowed_df_classified
    x_col=feat_x_gate1,
    y_col=feat_y_gate1,
    plot_type='scatter',
    color_by='behavior_type',   # or 'mol'
    order=order,
    cmap='colorblind',
    small_multiples=False,
    contour_levels=None,
    contour_cmap='viridis',
    figsize=(8, 6.2),
    s=30,
    alpha=0.9,
    xlim=(0, 1),
    ylim=(0, 1),
    log_scale=False,
    scale_data=True,
    scaler=scaler_gate1,
    scale_method='standard_minmax',
    transparent_background=True,
    line_color='k',
    export_format='svg',
    save_path=saved_dir,
    dpi=300,
    **gate1_overlay,
)

print("‚úÖ Gate 1 scatter plot with gates overlaid")

In [None]:
# Placeholder: the next cell filters the classified data down to ES cells
# for the visualizations that follow.
# (This line used to be bare text, which is a SyntaxError when the cell runs.)


In [None]:
# Restrict the *classified* frames to the conditions plotted below:
# cell type 'ES', location 'ES', molecules HTT / kinesin / myosin, genotypes 20H20S / RUES2.
# A row is kept only if it passes every criterion. The filter is idempotent, so
# repeating it here after the identical cell further up is harmless.
# NOTE(review): only *_classified are filtered. windowed_df_final / instant_df_final
# keep the unfiltered data, so the plots that read *_final (Plots 7-9) are NOT
# restricted to these conditions.
types = ['ES']
locations = ['ES']
mols = ['HTT', 'kinesin', 'myosin']
genos = ['20H20S', 'RUES2']

selection_criteria = {
    'type': types,
    'location': locations,
    'mol': mols,
    'geno': genos,
    # Optional extra restrictions:
    # 'group': ['HTT150'],
    # 'gate_name': ['low_speed_low_intersections'],
}
selection_predicates = [
    pl.col(column).is_in(allowed) for column, allowed in selection_criteria.items()
]

instant_df_classified = instant_df_classified.filter(*selection_predicates)
windowed_df_classified = windowed_df_classified.filter(*selection_predicates)

In [None]:
#  Configure Polars display options for full output
pl.Config.set_tbl_rows(-1)  # Show all rows (no limit)
pl.Config.set_tbl_cols(-1)  # Show all columns (no limit)
pl.Config.set_tbl_width_chars(1000)  # Increase table width


# Count unique cells per location and condition
location_cell_counts = (instant_df_classified
    .group_by(['mol', 'cell', 'type', 'location','geno', 'group', 'replicate', 'gate_name'])
    .agg([
        pl.col('filename').n_unique().alias('n_cells'),
        pl.col('unique_id').n_unique().alias('n_tracks'),
        pl.len().alias('n_datapoints')
    ])
    .sort(['mol','cell', 'type', 'location','geno', 'group', 'replicate', 'gate_name'])
)

print("Number of cells per location and condition:")
print(location_cell_counts)

# # Export to CSV
# location_cell_counts.write_csv(driveletter + '/plots/' + "location_cell_counts.csv")
# print("Exported location_cell_counts to 'location_cell_counts.csv'")


In [None]:
### Plot 3: hexbin density of windowed_df_classified (as filtered by the cells above) with Gate 1 overlaid

gate1_overlay = dict(
    gates=roi_manager_gate1,
    gate_order=gateorder_gate1,
    gate_colors=raiders_colors2,
    gate_linestyle='--',
    gate_alpha=0.1,
    gate_edge_alpha=1.0,
    gate_linewidth=1.3,
    gate_label_position='auto',
    gate_text_size=10,
    gate_label_border=False,
)

fig, ax = spt.plot_xy_heatmap(
    windowed_df_classified,
    x_col=feat_x_gate1,
    y_col=feat_y_gate1,
    plot_type='hexbin',
    color_by='mol',  # or 'gate_name' to colour by Stage-1 gate
    small_multiples=False,
    contour_levels=None,
    figsize=(8, 6),
    s=0.5,
    alpha=0.4,
    log_scale=False,
    scale_data=True,
    scaler=scaler_gate1,
    scale_method='standard_minmax',
    xlim=(0, 1),
    ylim=(0, 1),
    transparent_background=True,
    line_color='k',
    export_format='svg',
    save_path=saved_dir,
    dpi=300,
    **gate1_overlay,
)

print(
    f"‚úÖ Hexbin plot created showing {len(windowed_df_classified):,} tracks",
    f"   Colored by: mol",
    f"   Gates: {len(roi_manager_gate1.rois)} overlaid",
    sep="\n",
)



---

## GATE 2 VISUALIZATIONS




In [None]:
### Plot 4 (optional): interactive Plotly view for checking / redrawing the Gate 2 polygons

# Only the Stage-1 parent population, in raw (unscaled) Gate 2 units.
parent_pop_df = windowed_df_classified.filter(pl.col('gate_id') == parent_gate_id)

interactive_opts = dict(
    color_by='mol',
    scale_data=False,   # Stage 2 works on unscaled values
    scaler=None,
    point_size=2,
    opacity=0.2,
    contour_density=False,
    n_contours=10,
    height=700,
    width=900,
)
fig, roi_manager_gate2_interactive = spt.interactive_roi_gating_with_capture(
    parent_pop_df, x_col=feat_x_gate2, y_col=feat_y_gate2, **interactive_opts
)

display(fig)

for hint in (
    "üí° Use this to visually check your Gate 2 polygons",
    "   You can draw new polygons and use roi_manager_gate2_interactive.capture_rois()",
):
    print(hint)


In [None]:
### Plot 5: Stage-1 parent population in raw Gate 2 units, coloured by Gate 2 assignment

# Only rows that Stage 1 placed in the parent gate are shown.
parent_pop_df = windowed_df_classified.filter(pl.col('gate_id') == parent_gate_id)

gate2_overlay = dict(
    gates=roi_manager_gate2,
    gate_colors=raiders_colors2,
    gate_linestyle='--',
    gate_alpha=0.1,
    gate_edge_alpha=1.0,
    gate_linewidth=1.3,
    gate_label_position='auto',
    gate_text_size=10,
    gate_label_border=False,
)

fig, ax = spt.plot_xy_heatmap(
    parent_pop_df,
    x_col=feat_x_gate2,
    y_col=feat_y_gate2,
    plot_type='scatter',
    color_by='gate2_name',
    small_multiples=False,
    contour_levels=None,
    figsize=(8, 6),
    s=10,
    alpha=0.6,
    log_scale=False,
    scale_data=False,  # Gate 2 polygons are defined in unscaled units
    transparent_background=True,
    line_color='k',
    export_format='svg',
    save_path=saved_dir,
    dpi=300,
    **gate2_overlay,
)

print(
    f"‚úÖ Gate 2 scatter plot created",
    f"   Showing parent population: {len(parent_pop_df):,} tracks",
    f"   Gates: {len(roi_manager_gate2.rois)} overlaid (unscaled)",
    sep="\n",
)



In [None]:
### Plot 6: hexbin density of the Stage-1 parent population with the Gate 2 polygons overlaid

parent_pop_df = windowed_df_classified.filter(pl.col('gate_id') == parent_gate_id)

gate2_overlay = dict(
    gates=roi_manager_gate2,
    gate_colors=raiders_colors2,
    gate_linestyle='--',
    gate_alpha=0.05,
    gate_edge_alpha=1.0,
    gate_linewidth=1.5,
    gate_label_position='auto',
    gate_text_size=10,
    gate_label_border=False,
)

fig, ax = spt.plot_xy_heatmap(
    parent_pop_df,
    x_col=feat_x_gate2,
    y_col=feat_y_gate2,
    plot_type='hexbin',
    color_by=None,      # density only
    small_multiples=False,
    contour_levels=None,
    figsize=(8, 6),
    alpha=0.8,
    log_scale=True,     # log scale to see density variations
    scale_data=False,
    transparent_background=True,
    line_color='k',
    export_format='svg',
    save_path=saved_dir,
    dpi=300,
    **gate2_overlay,
)

print(f"‚úÖ Gate 2 hexbin density plot created")



---

## FINAL POPULATION VISUALIZATIONS

**Visualize the final classified populations after hierarchical gating:**


In [None]:
### Plot 7: final populations in Gate 1 space (the Gate 2 view is drawn by the next cell)

# plot_xy_heatmap builds and returns its own figure (sized by `figsize`), so no
# plt.subplots grid is created here. A pre-made grid was never drawn into and
# only rendered as an empty figure.

import matplotlib.pyplot as plt

# Final populations in scaled Stage-1 space, with the Stage-1 gates outlined.
# NOTE(review): unlike Plots 1-3 this call omits scale_method='standard_minmax';
# confirm the function's default matches.
fig1, ax1 = spt.plot_xy_heatmap(
    windowed_df_final,
    x_col=feat_x_gate1,
    y_col=feat_y_gate1,
    plot_type='scatter',
    color_by='final_population',  # Color by final classification
    small_multiples=False,
    transparent_background=True,
    line_color='k',
    export_format='svg',
    save_path=saved_dir + "gate1_final_populations",
    figsize=(8, 6),
    s=5,
    alpha=0.5,
    log_scale=False,
    scale_data=True,
    scaler=scaler_gate1,
    xlim=(0, 1),
    ylim=(0, 1),
    dpi=300,
    gates=roi_manager_gate1,
    gate_colors=raiders_colors2,
    gate_linestyle='--',
    gate_alpha=0.05,
    gate_edge_alpha=0.8,
    gate_linewidth=1.2,
)

print("‚úÖ Gate 1 view with final populations")


In [None]:

# Final populations inside the Stage-1 parent gate, in raw Gate 2 units, with the
# Gate 2 polygons overlaid. This is a separate figure from the Gate 1 view above.
parent_pop_final = windowed_df_final.filter(pl.col('gate_id') == parent_gate_id)

gate2_outline = dict(
    gates=roi_manager_gate2,
    gate_colors=raiders_colors2,
    gate_linestyle='--',
    gate_alpha=0.1,
    gate_edge_alpha=1.0,
    gate_linewidth=1.3,
)

fig2, ax2 = spt.plot_xy_heatmap(
    parent_pop_final,
    x_col=feat_x_gate2,
    y_col=feat_y_gate2,
    plot_type='scatter',
    color_by='final_population',
    small_multiples=False,
    figsize=(8, 6),
    s=10,
    alpha=0.6,
    log_scale=False,
    scale_data=False,
    transparent_background=True,
    line_color='k',
    export_format='svg',
    save_path=saved_dir + "gate2_final_populations",
    dpi=300,
    **gate2_outline,
)

print("‚úÖ Gate 2 view with final populations")



In [None]:
### Plot 8: Gallery of final populations

# Representative tracks from each final population.
# These names must be labels that classify_population() actually returns
# (superdiffusive_transport, subdiffusive_motion, bound_stationary, fast_exploratory, ...).
# Names it never emits select no rows, which leaves the gallery empty.
populations_to_show = [
    'superdiffusive_transport',
    'subdiffusive_motion',
    'bound_stationary',
    'fast_exploratory',
]

# Instant rows for the chosen populations only.
gallery_df = instant_df_final.filter(pl.col('final_population').is_in(populations_to_show))

present = set(gallery_df.get_column('final_population').unique().to_list())
if not present:
    raise ValueError(
        f"None of {populations_to_show} occur in instant_df_final['final_population']"
    )
absent = [p for p in populations_to_show if p not in present]
if absent:
    print(f"Note: no tracks for {absent}; they are left out of the gallery")

# Keep each population's colour fixed even if some population is absent.
shown_populations = [p for p in populations_to_show if p in present]
shown_colors = [
    colorblind_colors[i] for i, p in enumerate(populations_to_show) if p in present
]

gallery_result = spt.gallery_of_tracks_v4(
    gallery_df,
    color_by="final_population",
    num_tracks=30,  # Tracks per population
    order=shown_populations,
    custom_colors=shown_colors,
    track_length_frames=60,
    spacing_factor=1.0,
    line_width=1.0,
    figsize=(15, 10),
    text_size=5,
    show_annotations=True,
    annotation="{window_uid}",
    annotation_color="w",
    transparent_background=True,
    save_path=saved_dir + "gallery_final_populations.svg",
)

print(f"‚úÖ Gallery created with {gallery_result['total_tracks']} tracks")
print(f"   Populations: {gallery_result['categories']}")
print(f"   Category counts: {gallery_result['category_counts']}")



In [None]:
### Plot 9: Summary statistics plots

# Horizontal bar chart of how many windows landed in each final population.
import matplotlib.pyplot as plt
import seaborn as sns

# Window counts per final population, most frequent first.
dist = (
    windowed_df_final
    .group_by('final_population')
    .agg(count=pl.len())
    .sort(by='count', descending=True)
    .to_pandas()
)

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.barh(dist['final_population'], dist['count'], color=colorblind_colors[:len(dist)])

# Label each bar with its count, nudged right by 1% of the longest bar.
label_offset = max(dist['count'], default=0) * 0.01
for bar_index, n_tracks in enumerate(dist['count']):
    ax.text(n_tracks + label_offset, bar_index, f'{n_tracks:,}', va='center', fontsize=10)

ax.set_xlabel('Number of Tracks', fontsize=12)
ax.set_title('Final Population Distribution', fontsize=14, fontweight='bold')
for spine_side in ('top', 'right'):
    ax.spines[spine_side].set_visible(False)

plt.tight_layout()
plt.savefig(saved_dir + 'final_population_barplot.svg', dpi=300, bbox_inches='tight')
plt.show()

print("‚úÖ Population distribution bar plot created")

