### Testing "local" vs. "global" effect
It is possible that the nearest neighbor being connected merely indicates that the presynaptic neuron has a high out-degree, which then increases the connection probabilities with _all_ other neurons. We call this the "global" morphology effect.
Alternatively, there can be a spatially limited effect, where the connection probability is increased in a specific area around the nearest neighbor. We call this a "local" morphology effect. 

The total effect can be a mixture of both, and in this notebook we conduct further analyses to separate them.

To that end, we analyze connection probabilities not just as a function of distance, but as a function of the horizontal and vertical offsets between the pre- and postsynaptic neurons.
The global effect will have no specific spatial structure, i.e., the increase in connection probability if the nearest neighbor is connected should be the same at each offset.
Additionally, we also analyze the connectivity of a **configuration model** fitted to the data, that is, a model that shuffles the connections while preserving the in- and out-degrees of all neurons. As the global effect is based only on some neurons having higher degrees than others, the configuration model will capture it completely.
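To make the degree-preservation argument concrete, here is a minimal sketch of a configuration-model-style shuffle. It assumes the connectome is given as two arrays of pre- and postsynaptic neuron indices (one entry per edge). This is a generic stub-matching construction chosen for illustration; it is not necessarily how the configuration model in the data files was generated, and the function name `shuffle_preserving_degrees` is hypothetical.

```python
import numpy


def shuffle_preserving_degrees(pre, post, seed=None):
    """Randomly rewire a directed edge list while keeping all degrees.

    Each edge keeps its presynaptic neuron (its "out-stub"), and the
    postsynaptic neurons ("in-stubs") are randomly permuted across edges.
    Hence every neuron keeps its out-degree and its in-degree, while the
    spatial placement of connections is destroyed. This simple version
    does not prevent self-loops or duplicate edges.
    """
    rng = numpy.random.default_rng(seed)
    pre = numpy.asarray(pre)
    post = numpy.asarray(post)
    return pre.copy(), rng.permutation(post)


# Hypothetical toy connectome with 4 neurons; neuron 0 has a high out-degree
pre = numpy.array([0, 0, 0, 1, 2, 3])
post = numpy.array([1, 2, 3, 2, 3, 0])
new_pre, new_post = shuffle_preserving_degrees(pre, post, seed=0)

n_neurons = 4
out_deg = numpy.bincount(new_pre, minlength=n_neurons)  # same as before the shuffle
in_deg = numpy.bincount(new_post, minlength=n_neurons)  # same as before the shuffle
```

Because neuron 0 keeps its three outgoing edges, any "global" effect tied to its high out-degree survives the shuffle, whereas spatially specific structure does not.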

## Setup
We begin by importing relevant packages and setting up paths.

### Customization options
We have set up this notebook and the paths below for the analysis of the excitatory subgraph of the MICrONS connectome, version 1412.

But you can use it to analyze other connectomes as well, as long as they are formatted in the .hdf5-based format of Connectome-Utilities.

One such example is full_microns_PA.h5, which is created by the notebook "Fit other models to microns" and contains a preferential-attachment model fitted to the MICrONS data.

### Caching the results
This analysis is somewhat computationally expensive (~20 minutes for connectomes of 50k neurons). Hence, we split it into two parts: first, we calculate a pandas.DataFrame that holds pre-digested data; then we plot the data according to customizable specifications.

The output of the first, expensive step is saved to a file so that it does not have to be recalculated: in future executions it is read from the file instead, unless you set the "force_recalculation" flag to True.


In [None]:
import numpy
import pandas
import h5py
import os

import conntility

from scipy.spatial.distance import cdist
import tqdm

from matplotlib import pyplot as plt
from ipywidgets import widgets, interactive
from IPython.display import display

# Location of the file holding the connectome information
# Obtain the file at https://doi.org/10.5281/zenodo.16744240 (MICrONS data)
# or at https://doi.org/10.5281/zenodo.16744766 (Potential connectivity of an SSCX model)
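fn_connectivity = "microns_mm3_connectome_v1412.h5"
# NOTE: The next line overrides the assignment above with a local file path.
# Remove or comment it out to analyze the MICrONS data as described above.
fn_connectivity = "/Users/mwr/Documents/artefacts/connectomes/rat_struc_l23_with_controls.h5"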

# Location of the file for storing/caching results of the computations. If it does not exist, it will be created.
# The file below already has the relevant results pre-calculated.
fn_digest_dataframe = "/Users/mwr/Documents/artefacts/connectomes/rat_struc_l23_digest.h5"

# If you set this flag to True, expensive calculations will be repeated even if the result already exists in the cache file above.
# Note: If you set this to True you will have to update "fn_digest_dataframe" to a different, local path. This is because the 
# default location of "fn_digest_dataframe" is on a read-only file system.
force_recalculation = False


## Selecting the connection matrix
We read a connection matrix from an .h5 file. However, a file can contain more than one connection matrix. 

**Select the one to analyze from the dropdown.**

If you are using the recommended file, then you can choose between "condensed" and "full". The "full" matrix represents the connectome as a multigraph, i.e., multiple synapses from one neuron to another are represented by multiple edges. In "condensed", only one edge is placed per connected pair, but it has a property "count" that lists the number of corresponding edges in the "full" connectome.

Which of the two you select makes no difference, as we convert the multigraph version to a condensed view anyway.
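Conceptually, the conversion from the "full" multigraph to the "condensed" view groups all synapses of the same pre/post pair into a single edge and records their number. The following minimal pandas sketch assumes a hypothetical edge table with columns "row" (presynaptic) and "col" (postsynaptic); it illustrates the idea and is not Conntility's own implementation of `compress()`.

```python
import pandas

# Hypothetical multigraph edge list: one row per synapse ("row" -> "col")
multi_edges = pandas.DataFrame({
    "row": [0, 0, 0, 1, 2, 2],
    "col": [1, 1, 2, 0, 0, 0],
})

# Condensed view: one row per connected pair, "count" = number of synapses
condensed = (
    multi_edges.groupby(["row", "col"])
    .size()
    .rename("count")
    .reset_index()
)
print(condensed)
```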

If you are using one of the files containing control connectomes, then you can select the instance of the control.

In [11]:
with h5py.File(fn_connectivity, "r") as h5:
    # Collect all "<prefix>/<group>" paths of connection matrices in the file
    contents = []
    for prefix in h5.keys():
        contents.extend(prefix + "/" + _k for _k in h5[prefix].keys())

sel_matrix = widgets.Dropdown(options=contents, index=0, description="Select matrix")
display(sel_matrix)


We load the selected connectivity matrix.
If it is a multigraph, we compress it into a condensed (non-multigraph) representation.

In [12]:
selected = str(sel_matrix.value)
M = conntility.ConnectivityMatrix.from_h5(fn_connectivity, 
                                          prefix=sel_matrix.value.split("/")[0],
                                          group_name=sel_matrix.value.split("/")[1])
if M.is_multigraph:
    M = M.compress()
print(f"Loaded {len(M)} nodes with {len(M.edges)} edges from {selected}")

col_y = "y"
col_xz = ["x", "z"]
# This flag indicates whether y is a "depth", i.e., counted from the top of L1 (True), or a reverse depth, counted from the bottom of L6 (False).
# This is relevant purely for plotting purposes at the end.
y_is_depth = False


Loaded 20037 nodes with 67124018 edges from connectivity/data


### Generate spatial bins
We generate spatial bins for the offset of neuron pairs along the y-axis and for their distance in the x/z-plane.
For the y-bins, we ensure that they are symmetric around 0 (0 is a bin edge).

We found 50 um to be a good bin size, but you can change it. In that case, you will have to set the "force_recalculation" flag to True.

In [13]:
# It is possible to adjust the bin size of the final plots here.
bin_sz = 50.0 # um


def make_spatial_bins(M_h, cols, bin_sz):
    _data = M_h.vertices[cols]
    delta = _data.max() - _data.min()

    sz = numpy.sqrt((delta.values ** 2).sum())
    if len(delta) == 1: # case 1d: negative and positive bins
        bins = numpy.arange(0, (bin_sz * numpy.ceil(sz / bin_sz)) + bin_sz, bin_sz)
        bins = numpy.hstack([-bins[:0:-1], bins])
    else: # case 2d: Only positive bins, but exclude 0 dist
        bins = numpy.arange(0, (bin_sz * numpy.ceil(sz / bin_sz)) + bin_sz, bin_sz)
        bins = numpy.hstack([[0, 1E-12], bins[1:]])
    return bins

dbins_xz = make_spatial_bins(M, col_xz, bin_sz)
binid_xz = numpy.arange(0, len(dbins_xz) + 1)

dbins_y = make_spatial_bins(M, [col_y], bin_sz)
binid_y = numpy.arange(0, len(dbins_y) + 1)

bin_centers_y = 0.5 * (dbins_y[:-1] + dbins_y[1:])
bin_centers_xz = 0.5 * (dbins_xz[1:-1] + dbins_xz[2:])
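To see how these bins are used later, the sketch below rebuilds the bin edges on toy numbers and applies `numpy.digitize` with the same offsets as the main calculation (`- 2` for the xz-distance, `- 1` for the y-offset). The distances and offsets are made up, and all names carry a `toy_` prefix so that running the cell does not overwrite the notebook's variables. Note how an xz-distance of exactly 0 ends up in bin -1, which is dropped later.

```python
import numpy

toy_bin_sz = 50.0

# xz-distance bins as built by make_spatial_bins in the 2d case: [0, 1e-12, 50, 100, 150, 200]
toy_bins_xz = numpy.hstack([[0, 1E-12], numpy.arange(toy_bin_sz, 200.0 + toy_bin_sz, toy_bin_sz)])
toy_centers_xz = 0.5 * (toy_bins_xz[1:-1] + toy_bins_xz[2:])

# y-offset bins as built in the 1d case: symmetric edges [-100, -50, 0, 50, 100]
toy_half = numpy.arange(0, 100.0 + toy_bin_sz, toy_bin_sz)
toy_bins_y = numpy.hstack([-toy_half[:0:-1], toy_half])
toy_centers_y = 0.5 * (toy_bins_y[:-1] + toy_bins_y[1:])

toy_dist_xz = numpy.array([0.0, 10.0, 49.9, 50.0, 120.0])
toy_binid_xz = numpy.digitize(toy_dist_xz, toy_bins_xz) - 2  # distance 0 maps to -1 (excluded later)

toy_dy = numpy.array([-75.0, -10.0, 0.0, 30.0])  # y(post) - y(pre)
toy_binid_y = numpy.digitize(toy_dy, toy_bins_y) - 1

print(toy_binid_xz, toy_centers_xz[toy_binid_xz[toy_binid_xz >= 0]])
print(toy_binid_y, toy_centers_y[toy_binid_y])
```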


## Find nearest neighbors and calculate connectivity with nearest neighbors

We use a KDTree to quickly find the nearest neighbors of all neurons.

Then we calculate three sparse matrices:
  - The first simply holds the number of edges from neuron i to neuron j
  - The second holds the number of edges from i to the nearest neighbor of j
  - The third holds the number of edges from the nearest neighbor of i to j
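The indexing trick behind the second and third matrices can be checked on a toy example. In this sketch (four hypothetical neurons on a line and a made-up edge-count matrix), `KDTree.query(..., k=2)` returns each neuron itself plus its nearest neighbor; selecting rows with the neighbor ids then gives counts nn(i) -> j, and selecting columns gives counts i -> nn(j).

```python
import numpy
from scipy.sparse import csr_matrix
from scipy.spatial import KDTree

# Four hypothetical neurons on a line, coordinates (x, z, y) in um
toy_coords = numpy.array([[0.0, 0.0, 0.0],
                          [1.0, 0.0, 0.0],
                          [5.0, 0.0, 0.0],
                          [7.0, 0.0, 0.0]])
_, toy_nn = KDTree(toy_coords).query(toy_coords, k=2)
toy_nn_id = toy_nn[:, 1]  # column 0 is the query point itself (distance 0)

# Made-up edge counts: entry [i, j] = number of synapses i -> j
toy_dense = numpy.array([[0, 2, 0, 1],
                         [0, 0, 3, 0],
                         [1, 0, 0, 0],
                         [0, 0, 0, 0]])
toy_counts = csr_matrix(toy_dense)

toy_nnpre = toy_counts[toy_nn_id].toarray()      # [i, j] = synapses nn(i) -> j
toy_nnpost = toy_counts[:, toy_nn_id].toarray()  # [i, j] = synapses i -> nn(j)
```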

In [14]:
from scipy.spatial import KDTree

_coords = col_xz + [col_y]
tree = KDTree(M.vertices[_coords].values)

_, nn_id = tree.query(M.vertices[_coords], k=2)
nn_id = nn_id[:, 1]  # nn_id[:, 0] is the original node, which has distance 0. nn_id[:, 1] is neighbor

# Lookup from pre / post ids to number of edges
pair_to_edge_count = M.edges.set_index(pandas.MultiIndex.from_frame(M._edge_indices))["count"]

# Edge counts i -> j
edge_count_mat = M.matrix.tocsr()
# Edge counts nn(i) -> j
edge_count_nnpre_mat = edge_count_mat[nn_id]
# Edge counts i -> nn(j)
edge_count_nnpost_mat = edge_count_mat[:, nn_id]

## Main calculation

Here we calculate a DataFrame with the following structure:

Each row represents a possible combination of:
  - spatial bin of the offset in the xz plane of a pair of neurons (i - j)
  - spatial bin of their offset along the y-axis
  - number of edges (synapses) from i to j
  - number of edges (synapses) from i to the nearest neighbor of j

The first four columns list the values of these four properties. **A fifth column then counts the number of pairs of neurons for that combination**.

Additionally, we calculate a second DataFrame that is very similar, except that the fourth column instead represents the number of edges (synapses) from the nearest neighbor of i to j.
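A minimal sketch of how such a table arises, assuming (as in the code below) that per-pair values are collected in a DataFrame and condensed with `DataFrame.value_counts`. The five pairs are made up; the name of the resulting "count" column relies on pandas >= 2.0.

```python
import pandas

# Hypothetical per-pair values for five neuron pairs (i, j)
toy_pairs = pandas.DataFrame({
    "xz": [0, 0, 0, 1, 1],          # xz-distance bin id
    "y": [2, 2, 2, 2, 3],           # y-offset bin id
    "edges_pair": [0, 1, 0, 0, 0],  # synapses i -> j
    "edges_nn": [1, 1, 1, 0, 0],    # synapses i -> nn(j)
})

# One row per observed combination, plus a "count" column with the number of pairs
toy_digest = toy_pairs.value_counts().reset_index()
print(toy_digest)
```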


For this calculation we have to consider the distances between all pairs of neurons. For large connectomes this can result in a very large matrix. To avoid running out of memory on weaker machines, we conduct the analysis in chunks of 2500 neurons. In the next cell, we define the function we run for each chunk. 
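Because the digest only contains counts, per-chunk results can simply be added up. A small check of this on a made-up random table, using `Series.add(..., fill_value=0)` as in the loop over chunks, shows that the chunked result matches a single pass over all rows:

```python
import numpy
import pandas

toy_rng = numpy.random.default_rng(0)
# Hypothetical per-pair table (bin ids and edge counts drawn at random)
toy_table = pandas.DataFrame({
    "xz": toy_rng.integers(0, 3, 1000),
    "y": toy_rng.integers(0, 3, 1000),
    "edges_pair": toy_rng.integers(0, 2, 1000),
})

# Single pass over all rows at once
toy_full = toy_table.value_counts()

# Chunked pass: count each chunk separately and add up the partial results.
# fill_value=0 handles combinations that are missing from one of the operands.
toy_acc = None
for start in range(0, len(toy_table), 250):
    part = toy_table.iloc[start:start + 250].value_counts()
    toy_acc = part if toy_acc is None else toy_acc.add(part, fill_value=0)
```

Only one chunk of the pairwise data has to be held in memory at any time, while the accumulated digest stays small.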

In [15]:
# Execute the analysis for a given set of _rows_ of the connectivity matrix
def for_pre_chunk(chunk_pre):
    # Which offset bin the pairs fall into
    Dxz = cdist(M.vertices.iloc[chunk_pre][col_xz], M.vertices[col_xz]) # PRE X POST
    Dxz = numpy.digitize(Dxz, dbins_xz) - 2  # -2 means distance = 0 will be bin id -1. That is the one to exclude.

    Dy = -M.vertices.iloc[chunk_pre][[col_y]].values + M.vertices[[col_y]].values.transpose() # PRE X POST
    Dy = numpy.digitize(Dy, dbins_y) - 1  # NOTE: Values are post - pre, i.e. the delta y along the direction of connection
    # Number of edges i -> j
    edge_count = edge_count_mat[chunk_pre].toarray().flatten()

    # Number of edges nn(i) -> j
    edge_count_nnpre = edge_count_nnpre_mat[chunk_pre].toarray().flatten()
    # Mask of valid pairs: True where nn(i) != j (pairs where j is itself the nearest neighbor of i are excluded)
    collision_pre = (nn_id[chunk_pre].reshape((-1, 1)) - numpy.arange(edge_count_mat.shape[1]).reshape((1, -1))) != 0
    collision_pre = collision_pre.flatten()

    # Number of edges i -> nn(j)
    edge_count_nnpost = edge_count_nnpost_mat[chunk_pre].toarray().flatten()
    # Mask of valid pairs: True where i != nn(j) (pairs where i is itself the nearest neighbor of j are excluded)
    collision_post = (chunk_pre.reshape((-1, 1)) - nn_id.reshape((1, -1))) != 0
    collision_post = collision_post.flatten()

    # Count instances of each
    ret_incoming = pandas.DataFrame({
        "xz": Dxz.flatten()[collision_pre],
        "y": Dy.flatten()[collision_pre],
        "edges_pair": edge_count[collision_pre],
        "edges_nn": edge_count_nnpre[collision_pre],
    }).value_counts()
    
    ret_outgoing = pandas.DataFrame({
        "xz": Dxz.flatten()[collision_post],
        "y": Dy.flatten()[collision_post],
        "edges_pair": edge_count[collision_post],
        "edges_nn": edge_count_nnpost[collision_post],
    }).value_counts()
    return ret_incoming, ret_outgoing

### Try the cache first

Here, we test whether the result can already be found in the cache ("digest") file. If so and "force_recalculation" is False, we load it.

Otherwise, we iterate over chunks of neurons performing the costly calculation. This may take 20-30 minutes.

In [16]:
digest_exists = False
if not force_recalculation:
    if os.path.isfile(fn_digest_dataframe):
        with h5py.File(fn_digest_dataframe, "r") as h5:
            if selected + "/incoming" in h5 and selected + "/outgoing" in h5:
                digest_exists = True

if digest_exists:
    full_df_incoming = pandas.read_hdf(fn_digest_dataframe, selected + "/incoming")
    full_df_outgoing = pandas.read_hdf(fn_digest_dataframe, selected + "/outgoing")
else:
    chunk_sz = 2500
    chunking = numpy.arange(0, len(M) + chunk_sz, chunk_sz)

    chunk = numpy.arange(chunking[0], numpy.minimum(chunking[1], len(M)))
    full_df_incoming, full_df_outgoing = for_pre_chunk(chunk)

    for a, b in tqdm.tqdm(list(zip(chunking[1:-1], chunking[2:]))):
        chunk = numpy.arange(a, numpy.minimum(b, len(M)))
        new_df_in, new_df_out = for_pre_chunk(chunk)
        full_df_incoming = full_df_incoming.add(new_df_in, fill_value=0)
        full_df_outgoing = full_df_outgoing.add(new_df_out, fill_value=0)
    
    full_df_incoming = full_df_incoming.drop(-1, axis=0).reset_index()
    full_df_outgoing = full_df_outgoing.drop(-1, axis=0).reset_index()


    assert (full_df_incoming[["xz", "y"]] >= 0).all().all()
    assert (full_df_outgoing[["xz", "y"]] >= 0).all().all()

    # Bin ids must be valid indices into the arrays of bin centers used below
    assert (full_df_incoming["xz"] < len(bin_centers_xz)).all()
    assert (full_df_incoming["y"] < len(bin_centers_y)).all()
    assert (full_df_outgoing["xz"] < len(bin_centers_xz)).all()
    assert (full_df_outgoing["y"] < len(bin_centers_y)).all()

    full_df_incoming["xz"] = bin_centers_xz[full_df_incoming["xz"]]
    full_df_incoming["y"] = bin_centers_y[full_df_incoming["y"]]

    full_df_outgoing["xz"] = bin_centers_xz[full_df_outgoing["xz"]]
    full_df_outgoing["y"] = bin_centers_y[full_df_outgoing["y"]]

    full_df_incoming.to_hdf(fn_digest_dataframe, key=(selected + "/incoming"))
    full_df_outgoing.to_hdf(fn_digest_dataframe, key=(selected + "/outgoing"))


## Plotting the results

Here, we plot the results. As in the notebook "microns morphology effect", we plot the overall (prior) connection probabilities in spatial bins, and the (posterior) connection probability conditional on the nearest neighbor of a neuron being connected.

Using the pre-digested representation of the data created above, a plot can be generated quickly.

We make the plot customizable:
  - thresh_pair: By default, a pair of neurons is considered connected if there is at least 1 synapse between them. You can increase this to require 2, 3, or more synapses to investigate the spatial structure of stronger connections.
  - thresh_nn: Similar to the above. This is the synapse-count threshold for considering the nearest neighbor connected.
  - required_count: Minimum number of neuron pairs with a connected nearest neighbor that a spatial bin must contain for its connection probabilities to be shown (otherwise the bin is set to NaN). It can be argued that a connection probability calculated from a single pair of neurons is meaningless; increase this value to avoid that.
  - clim_max: Upper limit of the color scale for the prior and posterior panels. The panels showing the raw difference use the symmetric range [-clim_max, clim_max].
  - show_relative: If unchecked, the raw difference of posterior and prior connection probability is plotted. Otherwise, their relative difference is plotted, i.e., the Michelson contrast (posterior - prior) / (posterior + prior), which is bounded between -1 and 1.
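For a single spatial bin, the plotted quantities can be worked out by hand. The sketch below uses made-up counts with thresh_pair = thresh_nn = 1 and mirrors the logic of `p_prior_post_fun` in the next cell; the `required_count` check is omitted.

```python
import pandas

# Hypothetical digest rows for a single spatial bin:
# number of pairs ("count") per combination of edges_pair and edges_nn
toy_bin = pandas.DataFrame({
    "edges_pair": [0, 1, 0, 1],
    "edges_nn": [0, 0, 1, 1],
    "count": [900, 20, 60, 20],
})
toy_thresh_pair, toy_thresh_nn = 1, 1
pair_connected = toy_bin["edges_pair"] >= toy_thresh_pair
nn_connected = toy_bin["edges_nn"] >= toy_thresh_nn

# Prior: fraction of all pairs in the bin that are connected
prior = toy_bin.loc[pair_connected, "count"].sum() / toy_bin["count"].sum()
# Posterior: fraction of connected pairs among those whose nearest neighbor is connected
posterior = (toy_bin.loc[pair_connected & nn_connected, "count"].sum()
             / toy_bin.loc[nn_connected, "count"].sum())

difference = posterior - prior                        # raw difference
contrast = (posterior - prior) / (posterior + prior)  # Michelson contrast, in [-1, 1]
```

Here 40 of 1000 pairs are connected (prior 0.04), but 20 of the 80 pairs with a connected nearest neighbor are (posterior 0.25).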

In [None]:
def p_prior_post_fun(df_in, thresh_pair=1, thresh_nn=1, required_count=1):
    v_pair = df_in["edges_pair"] >= thresh_pair
    v_nn = df_in["edges_nn"] >= thresh_nn

    if df_in.loc[v_nn, "count"].sum() < required_count:
        return pandas.Series({
            "prior": numpy.nan,
            "posterior": numpy.nan
        })

    prior = df_in.loc[v_pair, "count"].sum() / df_in["count"].sum()
    
    df_in = df_in.loc[v_nn]
    v_pair = v_pair[v_nn]
    posterior = 0.0
    if numpy.any(v_pair):
        posterior = df_in.loc[v_pair]["count"].sum() / df_in["count"].sum()
    return pandas.Series({
        "prior": prior,
        "posterior": posterior
    })

def make_extent(df):
    delta_xz = df.columns[-1] - df.columns[-2]
    delta_y = df.index[-1] - df.index[-2]
    
    extent = [df.columns[0] - delta_xz/2, df.columns[-1] + delta_xz/2,
             df.index[-1] + delta_y/2, df.index[0] - delta_y/2]
    return extent

def show_results(res_df_incoming, res_df_outgoing, clim=[0, 0.1],
                 show_relative=False):
    fig = plt.figure(figsize=(4, 6))

    i = 1
    for prob_type in ["prior", "posterior"]:
        for df, df_str in zip([res_df_incoming, res_df_outgoing],
                              ["Incoming", "Outgoing"]):
            ax = fig.add_subplot(3, 2, i)
            img = df[prob_type].sort_index().unstack("xz")
            pltimg = ax.imshow(img, extent=make_extent(img), clim=clim)
            ax.set_frame_on(False)
            ax.set_title(f"{df_str} connections", fontsize=10)
            # dy = y(post) - y(pre). If y is a depth, then values > 0 indicate a _downwards_ connection.
            # Hence, for "Incoming" we want values > 0 towards the top of the plot. If y is not a depth,
            # then the other way around.
            if y_is_depth == (df_str == "Incoming"):
                ax.set_ylim(sorted(ax.get_ylim()))
            ax.set_xticks([])
            if numpy.mod(i, 2) == 0:
                plt.colorbar(pltimg, label=f"{prob_type} prob.")
                ax.set_yticks([])
            i += 1

    clim_diff = [-clim[1], clim[1]]
    for df, df_str in zip([res_df_incoming, res_df_outgoing],
                              ["Incoming", "Outgoing"]):
        ax = fig.add_subplot(3, 2, i)
        img = df["posterior"].subtract(df["prior"], fill_value=0)
        if show_relative:
            img = img.divide(df["prior"].add(df["posterior"], fill_value=0), fill_value=0)
            clim_diff = [-1.0, 1.0]
        img = img.sort_index().unstack("xz")
            
        pltimg = ax.imshow(img, extent=make_extent(img), clim=clim_diff, cmap="coolwarm")
        ax.set_frame_on(False)
        ax.set_title("Difference", fontsize=10)
        if y_is_depth == (df_str == "Incoming"):
            ax.set_ylim(sorted(ax.get_ylim()))
        ax.tick_params(axis="x", labelrotation=90)
        if numpy.mod(i, 2) == 0:
            plt.colorbar(pltimg)
        i += 1


def interact_fun(thresh_pair, thresh_nn, required_count, clim_max, show_relative):
    res_in = full_df_incoming.groupby(["xz", "y"]).apply(p_prior_post_fun,  include_groups=False,
                                                    thresh_pair=thresh_pair,
                                                    thresh_nn=thresh_nn,
                                                    required_count=required_count)
    res_out = full_df_outgoing.groupby(["xz", "y"]).apply(p_prior_post_fun,  include_groups=False,
                                                    thresh_pair=thresh_pair,
                                                    thresh_nn=thresh_nn,
                                                    required_count=required_count)
    show_results(res_in, res_out, clim=[0, clim_max], show_relative=show_relative)

sel_thresh_pair = widgets.IntSlider(min=1, max=10, value=1)
sel_thresh_nn = widgets.IntSlider(min=1, max=10, value=1)
sel_required_count = widgets.IntSlider(min=1, max=100, value=1)
sel_clim_max = widgets.FloatSlider(min=0.01, max=1.0, value=0.1, step=0.01)
sel_relative = widgets.Checkbox(value=False)

interactive(interact_fun, thresh_pair=sel_thresh_pair, thresh_nn=sel_thresh_nn,
            clim_max=sel_clim_max, required_count=sel_required_count, show_relative=sel_relative)
