In [None]:
import pandas as pd
import scipy as sp
from connectome_interpreter import *

### Test notebook: examples of how to use the `external_map.hex_heatmap()` function

Because the column assignments are not yet saved with the metadata, load a CSV that assigns the individual bodyIds of 15 columnar cell types in the medulla to individual columns. The assignment data are from Nern et al. 2024.

In [None]:
# Columnar cell types whose per-column bodyId assignment (Nern et al. 2024) we use.
COLUMNAR_CELL_TYPES = [
    "L1", "L2", "L3", "L5",
    "Mi1", "Mi4", "Mi9",
    "C2", "C3",
    "Tm1", "Tm2", "Tm4", "Tm9", "Tm20",
    "T1",
]

# Wide table: one row per hex column, one column per cell type holding that neuron's bodyId.
cols_wide = pd.read_csv('https://raw.githubusercontent.com/YijieYin/connectome_interpreter/refs/heads/main/connectome_interpreter/data/Nern2024/ME-columnar-cells-hex-location.csv')

# Long table: one row per (hex column, cell type) that actually has a neuron assigned.
cols = (
    cols_wide
    # "x,y" label identifying the hex column
    .assign(coords=lambda d: d["x"].astype(str) + "," + d["y"].astype(str))
    # keep only the cell-type columns plus coords (drops hex1_id, hex2_id, x, y)
    .loc[:, COLUMNAR_CELL_TYPES + ["coords"]]
    .melt(id_vars=["coords"], var_name="cell_type", value_name="bodyId")
    # unassigned slots are NaN and cannot be cast to int
    .dropna(subset=["bodyId"])
    # cast on a fresh frame rather than assigning into a filtered slice,
    # which would raise SettingWithCopyWarning
    .astype({"bodyId": int})
)

In [None]:
# Lookup: neuron bodyId -> "x,y" hex-column label (a repeated bodyId keeps its last row).
bodyid_to_coords = dict(zip(cols['bodyId'], cols['coords']))

Add the column information for the individual bodyIds of these 15 cell types to the metadata of the optic lobe connectome dataset.

In [None]:
meta = pd.read_csv('../../interpret_connectome/data/neuprint_meta_optic.csv')

# Attach the hex-column label; bodyIds without a column assignment map to NaN.
meta['coords'] = meta['bodyId'].map(bodyid_to_coords)

# Lookups from 'idx' (presumably the row/column index into the connectivity
# matrices — TODO confirm) to the grouping labels passed to result_summary.
meta_by_idx = meta.set_index('idx')
idx_to_coords = meta_by_idx['coords'].to_dict()
idx_to_type = meta_by_idx['cell_type'].to_dict()
idx_to_root = meta_by_idx['bodyId'].to_dict()

In [None]:
# `import scipy as sp` does not load the `sparse` submodule on older SciPy releases,
# so `sp.sparse` could raise AttributeError. Importing it explicitly registers it
# on the same module object, making `sp.sparse` available.
import scipy.sparse

# Direct input-proportion connectivity matrix for the optic lobe (SciPy sparse .npz).
inprop = sp.sparse.load_npz('../../interpret_connectome/data/neuprint_inprop_optic.npz')

In [None]:
# Downstream (target) cell type whose inputs are examined.
target_cell_type = 'Tm3'
# 'idx' of every neuron of the target type
is_target = meta['cell_type'] == target_cell_type
outidx = meta.loc[is_target, 'idx']

In [None]:
# Upstream cell type: we plot where its neurons that innervate the target type are located.
upstream_cell_type = 'Mi1'
# 'idx' of every neuron of the upstream type
is_upstream = meta['cell_type'] == upstream_cell_type
inidx = meta.loc[is_upstream, 'idx']

### Plot the columns that contain Mi1 neurons that directly innervate a single Tm3 cell (bodyId = 101077)

In [None]:
# Direct connectivity from upstream neurons (`upstream_cell_type`, here Mi1) to each
# individual target neuron (`target_cell_type`, here Tm3).
# Rows = medulla columns (labels from idx_to_coords);
# columns = individual target bodyIds (labels from idx_to_root).
# display_threshold=1e-4 presumably hides weaker entries — TODO confirm in result_summary docs.
df_per_cell = result_summary(
    inprop, inidx, outidx, idx_to_coords, idx_to_root,
    display_threshold=1e-4,
)

In [None]:
# bodyId of the single target neuron to plot. It must match one of df_per_cell's
# column labels (looked up as a string here). To pick the i-th neuron by position
# instead, use df_per_cell.iloc[:, i].
bid = "101077"
df_to_plot = df_per_cell[bid]

fig = hex_heatmap(df_to_plot, sizing=None, style=None)
fig.show()

### Plot the columns that contain Mi1 neurons that directly innervate any Tm3 cell.

In [None]:
# Same direct connectivity, but targets grouped by cell type (idx_to_type) instead of bodyId.
df_type = result_summary(
    inprop, inidx, outidx, idx_to_coords, idx_to_type,
    display_threshold=1e-4,
)

In [None]:
# Build the label from the configured types so it stays correct if they change.
print(f'{upstream_cell_type} to {target_cell_type} - direct')
df_to_plot = df_type[target_cell_type]
fig = hex_heatmap(df_to_plot, sizing=None, style=None)
fig.show()

In [None]:
# %%capture
# ^ Uncomment the line above to silence this cell's output; a %%cell magic only works on a cell's first line.
# First download the precomputed connectivity (Google Drive folder id below) with gdown.
# NOTE(review): presumably this creates the 'maleCNS_neuprint_optic_neuron' folder read by read_precomputed() — TODO confirm.
!gdown --folder 1dqICZerhL4cvBHknhu7SAonKJQOBdqd9

### Plot the columns that contain Mi1 neurons that directly or indirectly (summed over all path lengths) innervate a single Tm3 cell (bodyId = 101077)

In [None]:
# Load the downloaded precomputed connectivity.
# steps_cpu is a list of matrices: index 0 is direct connectivity, index 1 is one hop, and so on.
steps_cpu = read_precomputed('maleCNS_neuprint_optic_neuron')

# Sum all of the matrices in steps_cpu (direct plus every indirect path length).
n_steps = len(steps_cpu)
stepsn = add_first_n_matrices(steps_cpu, n_steps)

In [None]:
# Per-target-neuron summary of the summed (direct + indirect) connectivity.
# NOTE(review): threshold is 1e-3 here vs 1e-4 in the other summaries — confirm this is intended.
df_indirect_per_cell = result_summary(
    stepsn, inidx, outidx, idx_to_coords, idx_to_root,
    display_threshold=1e-3,
)

In [None]:
# Same single target neuron (`bid`), now with summed direct + indirect connectivity.
df_to_plot = df_indirect_per_cell[bid]
fig = hex_heatmap(df_to_plot, style=None, sizing=None)
fig.show()

In [None]:
# Summed (direct + indirect) connectivity with targets grouped by cell type.
df_type_indirect = result_summary(
    stepsn, inidx, outidx, idx_to_coords, idx_to_type,
    display_threshold=1e-4,
)

In [None]:
# Build the label from the configured types so it stays correct if they change.
print(f'{upstream_cell_type} to {target_cell_type} - indirect')
df_to_plot = df_type_indirect[target_cell_type]
fig = hex_heatmap(df_to_plot, sizing=None, style=None)
fig.show()