In [None]:
# !pip install connectome-interpreter --no-deps

In [None]:
import pandas as pd
import scipy as sp
from connectome_interpreter import *
import plotly.graph_objects as go
import numpy as np

In [None]:
# Hex-grid location table for the medulla columnar cells, read from the
# Nern2024 data folder of the connectome_interpreter repository.
HEX_LOCATION_URL = 'https://raw.githubusercontent.com/YijieYin/connectome_interpreter/refs/heads/main/connectome_interpreter/data/Nern2024/ME-columnar-cells-hex-location.csv'

# Cell-type columns kept from the wide table; hex1_id, hex2_id, x and y are
# dropped once x and y have been folded into the "coords" string.
COLUMNAR_CELL_TYPES = [
    "L1",
    "L2",
    "L3",
    "L5",
    "Mi1",
    "Mi4",
    "Mi9",
    "C2",
    "C3",
    "Tm1",
    "Tm2",
    "Tm4",
    "Tm9",
    "Tm20",
    "T1",
]

# Built as a single chain so the cell is idempotent and never assigns into a
# filtered slice (the pattern that triggers pandas' SettingWithCopyWarning).
cols = (
    pd.read_csv(HEX_LOCATION_URL)
    # "x,y" string used as the column identifier throughout the notebook
    .assign(coords=lambda d: d["x"].astype(str) + "," + d["y"].astype(str))
    .loc[:, COLUMNAR_CELL_TYPES + ["coords"]]
    # wide -> long: one row per (coords, cell_type) pair
    .melt(id_vars=["coords"], var_name="cell_type", value_name="bodyId")
    # positions with no bodyId for a given cell type are discarded
    .dropna(subset=["bodyId"])
    # explicit 64-bit integers so the result does not depend on the
    # platform's default integer width
    .astype({"bodyId": "int64"})
)
cols

In [None]:
# Lookup table: bodyId -> "x,y" coordinate string (later entries win on duplicate bodyIds)
bodyid_to_coords = dict(zip(cols["bodyId"], cols["coords"]))

In [None]:
# NOTE(review): machine-specific absolute path; this cell only runs where that file exists.
meta = pd.read_csv('/Users/burnettl/Documents/GitHub/interpret_connectome/data/neuprint_meta_optic.csv')

# Attach the "x,y" coordinate string to every row whose bodyId is in the
# lookup table; rows with an unknown bodyId get NaN.
meta['coords'] = meta['bodyId'].map(bodyid_to_coords)

# Dictionaries mapping the 'idx' column to coordinates, cell type and bodyId.
idx_to_coords, idx_to_type, idx_to_root = (
    meta.set_index('idx')[column].to_dict()
    for column in ('coords', 'cell_type', 'bodyId')
)

In [None]:
# Sparse matrix stored in .npz format (presumably the input-proportion
# connectivity matrix indexed by meta.idx; verify against the data source).
# NOTE(review): machine-specific absolute path.
inprop_path = '/Users/burnettl/Documents/GitHub/interpret_connectome/data/neuprint_inprop_optic.npz'
inprop = sp.sparse.load_npz(inprop_path)

In [None]:
# Input side: indices of every L1 cell. Output side: indices of every cell in meta.
inidx = meta.loc[meta["cell_type"] == "L1", "idx"]
outidx = meta["idx"]

# Both sides are grouped by cell type (idx_to_type is passed twice).
# display_threshold is not passed, so result_summary uses its own default.
df = result_summary(
    inprop,
    inidx,
    outidx,
    idx_to_type,
    idx_to_type,
    sort_within="row",
    threshold_axis="column",
)

In [None]:
# Input side: every L1 cell; output side: every Tm3 cell.
inidx = meta.loc[meta["cell_type"] == "L1", "idx"]
outidx = meta.loc[meta["cell_type"] == "Tm3", "idx"]

# Inputs are grouped by "x,y" coordinate string, outputs by bodyId.
df = result_summary(
    inprop,
    inidx,
    outidx,
    idx_to_coords,
    idx_to_root,
    display_threshold=1e-4,
)

In [None]:
# Take the first column of the summary table as a Series, selected by position.
# Its index is df's row labels (presumably the "x,y" coordinate strings that
# came from idx_to_coords; verify), which is the format hex_plot_from_series expects.
df_to_plot = df.iloc[:, 0]
df_to_plot

In [None]:
def hex_plot_from_series(
    df: pd.Series,
    style: dict = None,
    sizing: dict = None,
    col_coords=None,
    cmin: float = 0.1,
    cmax: float = None,
) -> go.Figure:
    """
    Generate a hexagonal heat map of one value per medulla column.

    Every column position in `col_coords` is drawn as an empty white hexagon;
    the positions named in the index of `df` are drawn on top, coloured by
    their value on a 'blues' colour scale.

    Parameters
    ----------
    df : pd.Series
        A Series where the index is formatted as 'x,y' integer coordinates
        (e.g. '-12,34') and the values are the data to plot.
    style : dict, default=None
        Styling overrides. Recognised keys: "font_type", "markerlinecolor",
        "linecolor", "papercolor". Keys that are not given fall back to the
        defaults, so a partial dict is accepted.
    sizing : dict, default=None
        Size overrides. Recognised keys: "fig_width", "fig_height",
        "fig_margin" (mm), "fsize_ticks_pt", "fsize_title_pt" (pt),
        "markersize", "ticklen", "tickwidth", "axislinewidth",
        "markerlinewidth", "cbar_thickness", "cbar_len". Keys that are not
        given fall back to the defaults.
    col_coords : pd.DataFrame or path-like, default=None
        Background column positions with "x" and "y" columns. Anything that is
        not a DataFrame is handed to `pd.read_csv`. When None, the original
        hard-coded local copy of ME-column-coords.csv is read.
    cmin : float, default=0.1
        Lower bound of the colour scale. Pass a smaller value when the data
        lie below 0.1, otherwise the lower bound can exceed `cmax`.
    cmax : float, default=None
        Upper bound of the colour scale. When None, the largest non-NaN value
        of `df` is used.

    Returns
    -------
    fig : go.Figure
    """
    # Defaults are merged with whatever the caller supplies, so passing only
    # the keys to change no longer raises KeyError. The caller's dicts are
    # never mutated.
    default_style = {
        "font_type": "arial",
        "markerlinecolor": "rgba(0,0,0,0)",  # transparent
        "linecolor": "black",
        "papercolor": "rgba(255,255,255,255)",
    }
    default_sizing = {
        "fig_width": 250,  # units = mm
        "fig_height": 210,  # units = mm
        "fig_margin": 0,
        "fsize_ticks_pt": 20,
        "fsize_title_pt": 20,
        "markersize": 18,
        "ticklen": 15,
        "tickwidth": 5,
        "axislinewidth": 3,
        "markerlinewidth": 0.9,
        "cbar_thickness": 20,
        "cbar_len": 0.75,
    }
    style = {**default_style, **(style or {})}
    sizing = {**default_sizing, **(sizing or {})}

    # Convert figure size (mm) and font sizes (pt) to pixels.
    pixelsperinch = 72  # for svg and pdf
    pixelspermm = pixelsperinch / 25.4
    area_width = (sizing["fig_width"] - sizing["fig_margin"]) * pixelspermm
    area_height = (sizing["fig_height"] - sizing["fig_margin"]) * pixelspermm
    fsize_ticks_px = sizing["fsize_ticks_pt"] * (1 / 72) * pixelsperinch
    fsize_title_px = sizing["fsize_title_pt"] * (1 / 72) * pixelsperinch

    # Convert index values (formatted as '-12,34') into separate x and y coordinates.
    coords = [tuple(map(int, label.split(","))) for label in df.index]
    if coords:
        x_vals, y_vals = zip(*coords)
    else:
        # An empty Series still yields a figure with only the background grid.
        x_vals, y_vals = (), ()

    # Upper colour bound: ignore NaNs so a single missing value does not turn
    # the whole colour range into NaN.
    if cmax is None:
        numeric_values = np.asarray(df.values, dtype=float)
        valid_values = numeric_values[~np.isnan(numeric_values)]
        cmax = valid_values.max() if valid_values.size else None
    # With no usable data there is no range to pin; leave both bounds unset.
    color_range = {} if cmax is None else {"cmin": cmin, "cmax": cmax}

    # Coordinates of all columns in the medulla (background grid).
    if col_coords is None:
        # NOTE(review): machine-specific fallback kept for backward
        # compatibility; pass `col_coords` to avoid depending on it.
        col_coords = "/Users/burnettl/Documents/GitHub/connectome_interpreter/connectome_interpreter/data/Nern2024/ME-column-coords.csv"
    if not isinstance(col_coords, pd.DataFrame):
        col_coords = pd.read_csv(col_coords)

    # Initiate the plot: fixed size, no margins, axes hidden.
    fig = go.Figure()
    fig.update_layout(
        autosize=False,
        height=area_height,
        width=area_width,
        margin={"l": 0, "r": 0, "b": 0, "t": 0, "pad": 0},
        paper_bgcolor=style["papercolor"],
        plot_bgcolor=style["papercolor"],
    )
    fig.update_xaxes(
        showgrid=False, showticklabels=False, showline=False, visible=False
    )
    fig.update_yaxes(
        showgrid=False, showticklabels=False, showline=False, visible=False
    )

    # Plotly marker symbol number used to draw hexagons.
    symbol_number = 15

    # Empty white 'background' hexagons, one per column position.
    fig.add_trace(
        go.Scatter(
            x=col_coords["x"],
            y=col_coords["y"],
            mode="markers",
            marker_symbol=symbol_number,
            marker={
                "size": sizing["markersize"],
                "color": "white",
                "line": {"width": sizing["markerlinewidth"], "color": "lightgrey"},
            },
            showlegend=False,
        )
    )

    # Text styling shared by the colour bar's tick labels and title.
    cbar_tickfont = {
        "size": fsize_ticks_px,
        "family": style["font_type"],
        "color": style["linecolor"],
    }
    cbar_titlefont = {
        "family": style["font_type"],
        "size": fsize_title_px,
        "color": style["linecolor"],
    }

    # Data hexagons, coloured by value, with a vertical colour bar.
    fig.add_trace(
        go.Scatter(
            x=x_vals,
            y=y_vals,
            mode="markers",
            marker_symbol=symbol_number,
            marker={
                **color_range,
                "size": sizing["markersize"],
                "color": df.values,
                "line": {
                    "width": sizing["markerlinewidth"],
                    "color": style["markerlinecolor"],
                },
                "colorbar": {
                    "orientation": "v",
                    "outlinecolor": style["linecolor"],
                    "outlinewidth": sizing["axislinewidth"],
                    "thickness": sizing["cbar_thickness"],
                    "len": sizing["cbar_len"],
                    "tickmode": "array",
                    "ticklen": sizing["ticklen"],
                    "tickwidth": sizing["tickwidth"],
                    "tickcolor": style["linecolor"],
                    "tickfont": cbar_tickfont,
                    "title": {"font": cbar_titlefont, "side": "right"},
                },
                "colorscale": 'blues',
            },
            showlegend=False,
        )
    )

    return fig

### Example use case 1 - columnar position of L1 cells that directly give input to a single Tm3 cell (bodyId = 100143)

In [None]:
# style and sizing are left out, so the function's built-in defaults are used
fig = hex_plot_from_series(df_to_plot)
fig.show()

### Example use case 2 - columnar position of L1 cells that directly give input to all of the Tm3 cells.


In [None]:
# Average each row of df across all of its columns: one value per row label.
df_to_plot_all_cells = df.mean(axis="columns")
df_to_plot_all_cells

In [None]:
# mean across all columns of df (style and sizing fall back to the defaults)
fig = hex_plot_from_series(df_to_plot_all_cells)
fig.show()

In [None]:
# sum
# This cell previously re-plotted df_to_plot_all_cells (the row-wise mean), so
# on a fresh Restart & Run All it duplicated the figure above. The sum is now
# computed here, under its own name, so the mean Series stays untouched.
df_to_plot_summed = df.sum(axis=1)
fig = hex_plot_from_series(df_to_plot_summed, sizing=None, style=None)
fig.show()

In [None]:
# TODO (use case 3): for a single Tm3 cell, plot the columns from which it receives input, across all of the columnar cell types.