# Connectomic analysis for flyConnectome project

By Charles Xu @ KIBM, UCSD

## Set up environment
### Import libraries

In [None]:
# The plotting examples below require holoviews, hvplot, and bokeh:
# conda install -c conda-forge bokeh holoviews hvplot
import numpy as np
import pandas as pd

import bokeh
import hvplot.pandas
import holoviews as hv

import bokeh.palettes
from bokeh.plotting import figure, show, output_notebook
output_notebook()

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.colors as colors

import os
import datetime

import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

### Helper functions

In [None]:
# Define function to compute connectivity score
def compute_connectivity_score(df, rois, max_size):
    """
    Compute the connectivity score for an ordered pair of ROIs.

    For each neuron, the score is the fraction of its input (post) synapses that
    lie in ROI 1, times the fraction of its output (pre) synapses that lie in
    ROI 2, weighted by the neuron's size (pre + post synapse count) relative to
    ``max_size``. The ROI-level score is the sum over all neurons.

    df: pandas DataFrame. Per-neuron table with 'pre' and 'post' totals and the
        per-ROI columns '<ROI>_post' (for rois[0]) and '<ROI>_pre' (for rois[1]).
    rois: tuple. Ordered pair (input/post ROI, output/pre ROI).
    max_size: number. Size (pre + post) used to normalise each neuron's weight.

    Returns (connectivity_score_neuron, connectivity_score_roi): the per-neuron
    scores as a Series aligned with ``df`` and their sum as a scalar.
    """
    post_roi, pre_roi = rois[0], rois[1]
    connectivity_score_neuron = (
        df[post_roi + '_post'] / df['post']
        * df[pre_roi + '_pre'] / df['pre']
        * (df['post'] + df['pre']) / max_size
    )
    # Neurons with no post (or pre) synapses give 0/0 = NaN; they contribute nothing
    connectivity_score_neuron = connectivity_score_neuron.fillna(0)
    connectivity_score_roi = connectivity_score_neuron.sum()
    return connectivity_score_neuron, connectivity_score_roi

# Define function to get dataframe
def get_df(df_dict, roi_pair):
    """
    Look up the dataframe stored for a pair of ROIs, accepting either order.

    The key (roi1, roi2) is tried first, then (roi2, roi1), so an unordered
    store (e.g. keyed by combinations) can serve ordered ROI pairs.
    df_dict: dictionary. Dictionary of dataframes, keyed by ROI pair.
    roi_pair: tuple. Pair of ROIs (roi1, roi2) to get the dataframe for.
    Raises ValueError if neither ordering is present (or both map to None).
    """
    first, second = roi_pair
    for candidate_key in ((first, second), (second, first)):
        match = df_dict.get(candidate_key)
        if match is not None:
            return match
    raise ValueError(f"No dataframe found for the ROIs {first} and {second}")

### User inputs

In [None]:
import os

# neuPrint authentication token. Do not commit a real token to source control:
# it is read from the NEUPRINT_APPLICATION_CREDENTIALS environment variable
# (you may paste one here locally, but never commit it). If this is None,
# neuprint's Client falls back to that environment variable itself.
# NOTE(review): a personal token was previously hard-coded here; it should be
# revoked/regenerated on neuprint.janelia.org.
TOKEN = os.environ.get('NEUPRINT_APPLICATION_CREDENTIALS')

savefigs = False  # Whether to save figures
figdir = 'figures'  # Directory to save figures in

savedata = True  # Whether to save data
datadir = 'data'  # Directory to save data in

# Alternative ROI sets used previously (kept for reference):
# rois = ['EB', 'FB', 'PB']
# rois = ['EB', 'PB', 'NO', 'FB',                 # CX
#         'LAL(R)', 'BU(R)', 'CRE(R)', 'WED(R)',  # EB, NO, FB
#         'IB', 'IPS(R)', 'SPS(R)',               # PB
#         'SIP(R)', 'SLP(R)', 'SMP(R)',           # FB (optional)
#         'SNP(R)']                               # EB, FB
# rois = ['EB', 'PB', 'NO', 'FB', 'MB(R)']

# Regions of interest. Order matters: the plotting cells outline the first 4
# (CX) and first 8 (CX + monosynaptic) ROIs of this list.
rois = ['EB', 'PB', 'NO', 'FB',  # CX
        'LAL(R)', 'BU(R)', 'CRE(R)', 'SNP(R)',  # Monosynaptic
        'IB', 'SPS(R)', 'IPS(R)',
        'AOTU(R)',
        'ICL(R)', 'PVLP(R)', 'WED(R)', 'SAD',
        'LH(R)', 'MB(R)', 'AL(R)']  # More distant regions

neuron_properties = ['bodyId', 'instance', 'type', 'pre', 'post', 'status', 'cropped']  # Properties of interest

In [None]:
import os
import itertools

# Create the output directories up front: plt.savefig and np.savetxt do not
# create missing parent directories, so saving data into `datadir` would
# otherwise fail on a fresh checkout.
os.makedirs(figdir, exist_ok=True)
os.makedirs(datadir, exist_ok=True)

# Unordered ROI pairs (including self-pairs): one neuron query per pair.
roi_combs = list(itertools.combinations_with_replacement(rois, 2))
# Ordered ROI pairs (input/post ROI, output/pre ROI) for connectivity scores.
roi_perms = list(itertools.product(rois, repeat=2))

### Create a Client

Before you begin, you must create a [Client][client] object.  It will be stored globally and used for all communication with the neuprint server.

Initialize it with your personal authentication token.  See the [Quickstart][qs] guide for details.

[qs]: https://connectome-neuprint.github.io/neuprint-python/docs/quickstart.html
[client]: https://connectome-neuprint.github.io/neuprint-python/docs/client.html

In [None]:
from neuprint import Client

# Connect to the public neuPrint server and select the hemibrain v1.2.1 dataset.
NEUPRINT_SERVER = 'neuprint.janelia.org'
NEUPRINT_DATASET = 'hemibrain:v1.2.1'
c = Client(NEUPRINT_SERVER, NEUPRINT_DATASET, TOKEN)

If ipywidgets installation is required, run these commands:

```shell
conda install -c conda-forge ipywidgets
jupyter nbextension enable --py widgetsnbextension
conda install -c conda-forge jupyterlab_widgets
```

## Fetch neurons

### ROIs

In neuprint, each neuron is annotated with the list of regions (ROIs) it intersects, along with the synapse counts in each.

The ROIs comprise a hierarchy, with smaller ROIs nested within larger ROIs.  Furthermore, **primary** ROIs are guaranteed not to overlap, and they roughly tile the entire brain (with some gaps).

For a quick overview of the ROI hierarchy, use [fetch_roi_hierarchy()][fetch_roi_hierarchy].

[fetch_roi_hierarchy]: https://connectome-neuprint.github.io/neuprint-python/docs/queries.html#neuprint.queries.fetch_roi_hierarchy

In [None]:
from neuprint import fetch_roi_hierarchy

# Render the ROI hierarchy as text, marking primary ROIs with '*'.
# The positional False is neuprint's include_subprimary flag, so ROIs nested
# below the primary level are left out.
hierarchy_text = fetch_roi_hierarchy(False, mark_primary=True, format='text')
print(hierarchy_text)

### Neuron Search Criteria

Specify neurons of interest by `bodyId`, `type`/`instance`, or via a [NeuronCriteria][NeuronCriteria] object.
With `NeuronCriteria`, you can specify multiple search constraints, including the ROIs in which matched neurons must contain synapses.

[NeuronCriteria]: https://connectome-neuprint.github.io/neuprint-python/docs/neuroncriteria.html

In [None]:
from neuprint import NeuronCriteria as NC

# Build one NeuronCriteria per unordered ROI pair, selecting the neurons that
# have synapses in the ROIs of that pair.
criteria_combs = {roi_pair: NC(rois=roi_pair) for roi_pair in roi_combs}

### Fetch neuron properties

Neuron properties and per-ROI synapse distributions can be obtained with [fetch_neurons()][fetch_neurons].  Two dataframes are returned: one for neuron properties, and one for the counts of synapses in each ROI.

[fetch_neurons]: https://connectome-neuprint.github.io/neuprint-python/docs/queries.html#neuprint.queries.fetch_neurons

In [None]:
from neuprint import fetch_neurons

# For every ROI pair, fetch the matching neurons' properties and their
# per-ROI synapse counts (fetch_neurons returns both tables).
neuron_dfs = {}
roi_counts_dfs = {}
for roi_pair, criteria_pair in criteria_combs.items():
    properties_df, counts_df = fetch_neurons(criteria_pair)
    neuron_dfs[roi_pair] = properties_df
    roi_counts_dfs[roi_pair] = counts_df
    print(f'Fetched {properties_df.shape[0]} neurons for {roi_pair} with properties: {list(properties_df.columns)}')

The total counts of presynaptic and postsynaptic points in each neuron are given in the `pre` and `post` columns:

In [None]:
# Keep only the properties we care about.
# `.copy()` makes each trimmed table an independent DataFrame, so the column
# writes in the following cells modify it directly rather than a possible view
# of the full fetch_neurons result (the source of SettingWithCopyWarning).
for roi_pair, neuron_df in neuron_dfs.items():
    neuron_dfs[roi_pair] = neuron_df[neuron_properties].copy()
    print(f'The new columns are: {list(neuron_dfs[roi_pair].columns)}')

In [None]:
# Add zero-initialised placeholder columns <ROI>_pre and <ROI>_post for both
# ROIs of every pair; they are filled from the per-ROI counts further down.
# Silence pandas' SettingWithCopyWarning globally for these column writes.
pd.options.mode.chained_assignment = None  # default='warn'
for roi_pair, neuron_df in neuron_dfs.items():
    for roi in roi_pair:
        for suffix in ('_pre', '_post'):
            neuron_df[roi + suffix] = 0
    print(f'The new columns are: {list(neuron_df.columns)}')

The per-ROI synapse counts are returned in the second DataFrame.

<div class="alert alert-info">
    
**Note:** Since ROIs overlap (see hierarchy above), the sum of a body's per-ROI counts can exceed its total `pre` and `post` counts above.

</div>


In [None]:
# Print the columns of the per-ROI synapse counts table for each ROI pair
# (only the tables are needed, so iterate the dict's values).
for roi_counts_df in roi_counts_dfs.values():
    print(f'The columns are: {list(roi_counts_df.columns)}')

In [None]:
# Fill in the per-ROI synapse counts.
# For each ROI of a pair, build a bodyId -> count lookup from the long-format
# roi_counts_df once and map it onto the neuron table by bodyId. This replaces
# a per-neuron boolean scan of roi_counts_df (O(neurons x rows) per ROI) and
# aligns rows by bodyId instead of assuming the neuron table has a RangeIndex.
for roi_pair, roi_counts_df in roi_counts_dfs.items():
    neuron_df = neuron_dfs[roi_pair]
    for roi in roi_pair:
        # Keep the first row per bodyId, matching the previous `index[0]` choice
        roi_rows = (roi_counts_df[roi_counts_df['roi'] == roi]
                    .drop_duplicates(subset='bodyId', keep='first')
                    .set_index('bodyId'))
        for side in ('pre', 'post'):
            column = f'{roi}_{side}'
            looked_up = neuron_df['bodyId'].map(roi_rows[side])
            # Neurons without an entry for this ROI keep their placeholder value
            neuron_df[column] = looked_up.fillna(neuron_df[column]).astype(neuron_df[column].dtype)
    print(f'The size of the dataframe for {roi_pair} is: {neuron_dfs[roi_pair].shape}')

## Compute the connectivity scores for ROI pairs

In [None]:
# Largest 'neuron size' (pre + post synapse count) over every fetched neuron
# table, starting from 0; used to normalise scores in compute_connectivity_score.
max_size = max([0] + [max(neuron_df['pre'] + neuron_df['post']) for neuron_df in neuron_dfs.values()])
max_size

In [None]:
# Score every ordered (input ROI, output ROI) pair. Both orderings of a pair
# share one neuron table, which get_df finds regardless of order.
connectivity_score_neuron = {}
connectivity_score_roi = {}
for roi_pair in roi_perms:
    pair_df = get_df(neuron_dfs, roi_pair)
    per_neuron_scores, roi_score = compute_connectivity_score(pair_df, roi_pair, max_size)
    connectivity_score_neuron[roi_pair] = per_neuron_scores
    connectivity_score_roi[roi_pair] = roi_score

In [None]:
# Display the summed connectivity score for every ordered (input ROI, output ROI) pair
connectivity_score_roi

In [None]:
import os

# Transform the connectivity scores into a connectivity matrix.
# Row = input (postsynaptic) ROI, column = output (presynaptic) ROI.
# Rows are stored in REVERSED `rois` order (row 0 = last ROI, i.e. ROI i sits
# in row -(i+1)) so the first ROI is drawn at the bottom by imshow; columns
# follow `rois` order. Pairs without a score default to 0.
connectivity_matrix = np.array(
    [[connectivity_score_roi.get((roi1, roi2), 0) for roi2 in rois]
     for roi1 in reversed(rois)],
    dtype=float,
).reshape(len(rois), len(rois))

print(connectivity_matrix)

# Save the connectivity matrix (create the data directory if it is missing,
# since np.savetxt does not create parent directories)
if savedata:
    os.makedirs(datadir, exist_ok=True)
    np.savetxt(os.path.join(datadir, 'connectivity_matrix.csv'), connectivity_matrix, delimiter=',')

In [None]:
# Reload the connectivity matrix from the CSV written above, so the following
# cells use exactly the saved values.
if savedata:
    matrix_csv_path = os.path.join(datadir, 'connectivity_matrix.csv')
    connectivity_matrix = np.loadtxt(matrix_csv_path, delimiter=',')

In [None]:
# # Plot the connectivity matrix
# fig, ax = plt.subplots(figsize=(10, 10))
# im = ax.imshow(connectivity_matrix, cmap='Reds', vmin=0.1, vmax=1)
# cbar = fig.colorbar(im, ax=ax, shrink=0.5)
# cbar.set_label('Connectivity', rotation=270, labelpad=20, fontsize=12)
# ax.set_xticks(np.arange(len(rois)))
# ax.set_yticks(np.arange(len(rois)))
# ax.set_xticklabels(rois, rotation=90)
# ax.set_yticklabels(rois[::-1])
# ax.set_xlabel('Output (presynaptic) ROI', fontsize=12)
# ax.set_ylabel('Input (postsynaptic) ROI', fontsize=12)

# rect1 = patches.Rectangle((-0.5, len(rois) - 5.5), 5, 5, linewidth=1, edgecolor='k', facecolor='none')
# rect2 = patches.Rectangle((-0.5, len(rois) - 9.5), 9, 9, linewidth=1, edgecolor='k', facecolor='none')
# ax.add_patch(rect1)
# ax.add_patch(rect2)

# if savefigs:
#     current_datetime = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
#     plt.savefig(f'{figdir}/connectivity_matrix_{current_datetime}.pdf', bbox_inches='tight')

In [None]:
## Plot the connectivity matrix with a log scale
# Penalize weak connections: entries below the threshold (including zeros)
# become NaN, so they render blank and stay out of the logarithmic color scale.
penalty_threshold = 1e-1
penalized_connectivity_matrix = np.where(connectivity_matrix < penalty_threshold, np.nan, connectivity_matrix)
log_norm = colors.LogNorm(vmin=penalty_threshold, vmax=np.nanmax(penalized_connectivity_matrix))

# Plot the connectivity matrix
fig, ax = plt.subplots(figsize=(10, 10))
im = ax.imshow(penalized_connectivity_matrix, cmap='Reds', norm=log_norm)
cbar = fig.colorbar(im, ax=ax, shrink=0.5)
cbar.set_label('Connectivity', rotation=270, labelpad=20, fontsize=20)

tick_positions = np.arange(len(rois))
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)
ax.set_xticklabels(rois, rotation=90, fontsize=16)
# Rows are stored in reversed ROI order, so the y labels are reversed too
ax.set_yticklabels(rois[::-1], fontsize=16)
ax.set_xlabel('Output (presynaptic) ROI', fontsize=20)
ax.set_ylabel('Input (postsynaptic) ROI', fontsize=20)

# Outline the bottom-left squares spanned by the first 4 ROIs (CX) and the
# first 8 ROIs (CX + monosynaptic) of `rois`.
for block_size in (4, 8):
    ax.add_patch(patches.Rectangle((-0.5, len(rois) - block_size - 0.5), block_size, block_size,
                                   linewidth=1, edgecolor='k', facecolor='none'))

# NOTE(review): this overrides the `savefigs` user input defined at the top of
# the notebook, so this figure is always saved.
savefigs = True
if savefigs:
    current_datetime = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    plt.savefig(f'{figdir}/connectivity_matrix_{current_datetime}.pdf', bbox_inches='tight')
    plt.savefig(f'{figdir}/connectivity_matrix_{current_datetime}.svg', bbox_inches='tight')

In [None]:
## Plot the connectivity matrix for CX with a log scale
# Crop the connectivity matrix to only include CX ROIs.
# The matrix stores ROI i in column i but in row len(rois) - 1 - i (rows are
# reversed). Taking rows for cx_rois in reverse keeps the crop in the same
# orientation as the full plot (first CX ROI at the bottom), for any position
# of the CX ROIs within `rois` (not only when they come first).
cx_rois = ['EB', 'PB', 'NO', 'FB']
cx_indices_col = np.array([rois.index(roi) for roi in cx_rois])
cx_indices_row = (len(rois) - 1) - cx_indices_col[::-1]
penalized_connectivity_matrix_cx = penalized_connectivity_matrix[np.ix_(cx_indices_row, cx_indices_col)]

# Plot the connectivity matrix. vmax comes from the FULL matrix so the colour
# scale matches the full-brain plot.
fig, ax = plt.subplots(figsize=(10, 10))
im = ax.imshow(penalized_connectivity_matrix_cx, cmap='Reds', norm=colors.LogNorm(vmin=penalty_threshold, vmax=np.nanmax(penalized_connectivity_matrix)))
cbar = fig.colorbar(im, ax=ax, shrink=0.5)
cbar.set_label('Connectivity', rotation=270, labelpad=20, fontsize=20)
ax.set_xticks(np.arange(len(cx_rois)))
ax.set_yticks(np.arange(len(cx_rois)))
ax.set_xticklabels(cx_rois, rotation=90, fontsize=16)
ax.set_yticklabels(cx_rois[::-1], fontsize=16)
ax.set_xlabel('Output (presynaptic) ROI', fontsize=20)
ax.set_ylabel('Input (postsynaptic) ROI', fontsize=20)

# NOTE(review): this overrides the `savefigs` user input defined at the top of
# the notebook, so this figure is always saved.
savefigs = True
if savefigs:
    current_datetime = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    plt.savefig(f'{figdir}/connectivity_matrix_CX_{current_datetime}.pdf', bbox_inches='tight')
    plt.savefig(f'{figdir}/connectivity_matrix_CX_{current_datetime}.svg', bbox_inches='tight')

In [None]:
# normalized_connectivity_matrix = connectivity_matrix / connectivity_matrix.max()

# # Penalize weak connections
# penalty_threshold = 1e-2
# normalized_connectivity_matrix[normalized_connectivity_matrix < penalty_threshold] = np.nan

# # Log-transform the normalized connectivity matrix
# log_normalized_connectivity_matrix = np.log10(normalized_connectivity_matrix)

# # Plot the log-transformed normalized connectivity matrix
# fig, ax = plt.subplots(figsize=(10, 10))
# im = ax.imshow(log_normalized_connectivity_matrix, cmap='Reds', vmin=np.log10(penalty_threshold), vmax=0)
# cbar = fig.colorbar(im, ax=ax, shrink=0.5)
# cbar.set_label(r'Connectivity ($\log_{10}$)', rotation=270, labelpad=20, fontsize=12)
# ax.set_xticks(np.arange(len(rois)))
# ax.set_yticks(np.arange(len(rois)))
# ax.set_xticklabels(rois, rotation=90)
# ax.set_yticklabels(rois[::-1])
# ax.set_xlabel('Output (presynaptic) ROI', fontsize=12)
# ax.set_ylabel('Input (postsynaptic) ROI', fontsize=12)

# rect1 = patches.Rectangle((-0.5, len(rois) - 4.5), 4, 4, linewidth=1, edgecolor='k', facecolor='none')
# rect2 = patches.Rectangle((-0.5, len(rois) - 9.5), 9, 9, linewidth=1, edgecolor='k', facecolor='none')
# ax.add_patch(rect1)
# ax.add_patch(rect2)

In [None]:
# log_connectivity_matrix = np.log10(connectivity_matrix)
# log_connectivity_matrix = np.nan_to_num(log_connectivity_matrix, neginf=np.nan)
# log_connectivity_matrix[log_connectivity_matrix < -2] = np.nan

# # Calculate the minimum and maximum of the log_connectivity_matrix
# log_min = np.nanmin(log_connectivity_matrix)
# log_max = np.nanmax(log_connectivity_matrix)

# # Normalize log_connectivity_matrix from 0 to 1
# normalized_log_connectivity_matrix = 0 + ( (log_connectivity_matrix - log_min) * (1 - 0) ) / (log_max - log_min)

# # Plot the normalized log-transformed connectivity matrix
# fig, ax = plt.subplots(figsize=(10, 10))
# im = ax.imshow(normalized_log_connectivity_matrix, cmap='Reds', vmin=0.2)
# cbar = fig.colorbar(im, ax=ax, shrink=0.5)
# cbar.set_label(r'Connectivity ($\log_{10}$, normalized)', rotation=270, labelpad=20, fontsize=12)
# ax.set_xticks(np.arange(len(rois)))
# ax.set_yticks(np.arange(len(rois)))
# ax.set_xticklabels(rois, rotation=90)
# ax.set_yticklabels(rois[::-1])
# ax.set_xlabel('Output (presynaptic) ROI', fontsize=12)
# ax.set_ylabel('Input (postsynaptic) ROI', fontsize=12)

# rect1 = patches.Rectangle((-0.5, len(rois) - 4.5), 4, 4, linewidth=1, edgecolor='k', facecolor='none')
# rect2 = patches.Rectangle((-0.5, len(rois) - 9.5), 9, 9, linewidth=1, edgecolor='k', facecolor='none')
# ax.add_patch(rect1)
# ax.add_patch(rect2)

## Compute closeness centrality of the weighted graph

In [None]:
import os

import networkx as nx
import numpy as np

connectivity_matrix = np.loadtxt(os.path.join(datadir, 'connectivity_matrix.csv'), delimiter=',')

# The saved matrix stores ROI i in row -(i+1) (rows reversed for plotting) but
# in column i. Flip the rows back so row k and column k both refer to rois[k];
# otherwise each graph node would mix two different ROIs.
adjacency_matrix = connectivity_matrix[::-1]

# NOTE(review): the matrix is not symmetric (score(A, B) != score(B, A)). With
# an undirected nx.Graph, one direction's entry overwrites the other's for each
# ROI pair. Consider nx.DiGraph or an explicit symmetrization; confirm intent.
G = nx.from_numpy_array(adjacency_matrix, create_using=nx.Graph)
# Label nodes with ROI names instead of integer indices
G = nx.relabel_nodes(G, dict(enumerate(rois)))

# networkx treats the `distance` attribute as an edge length (shorter = closer),
# but a connectivity score is a strength (larger = closer). Use the reciprocal
# score as the distance. Zero entries are not turned into edges by
# from_numpy_array, so no division by zero occurs.
for _, _, edge_data in G.edges(data=True):
    edge_data['distance'] = 1.0 / edge_data['weight']

# Calculate closeness centrality for each node, considering weights
closeness_centrality_weighted = nx.closeness_centrality(G, distance='distance')

print("Closeness Centrality of Each Node (Weighted):", closeness_centrality_weighted)