## Setup

In [None]:
import sys
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from networkx.algorithms import community
import matplotlib.patches as mpatches
import leidenalg
import igraph as ig
from collections import Counter
import scipy.linalg as la
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import eigs


# Add the src directory to path for importing our package
sys.path.insert(0, os.path.join(os.getcwd(), '..', 'src'))

# Import our DC-SBM implementation
from dcsbm import DCSBM, heldout_split, degrees

# Import bi-LRG implementation
from bilrg import BiLRG, HierarchicalBiLRG

In [None]:
# Root directory holding the input CSV files read by the cells below.
data_path = '/Users/charlesxu/local/data/mcp/flywire/ring_extend/'
# Output directory, nested inside the data directory.
results_path = os.path.join(data_path, 'results')
# The results directory is only created when this flag is set.
save_data = True
if save_data:
    # exist_ok=True makes re-running this cell harmless.
    os.makedirs(results_path, exist_ok=True)

## Preprocessing

In [None]:
# Load the signed connectivity table from disk.
connectivity_matrix_signed_df = pd.read_csv(
    os.path.join(data_path, 'connectivity_matrix_signed.csv')
)

# A table with exactly one more column than rows is taken to carry row labels
# in its first column, which is dropped; any other shape is used unchanged.
C_signed = (
    connectivity_matrix_signed_df.iloc[:, 1:].values
    if connectivity_matrix_signed_df.shape[1] == connectivity_matrix_signed_df.shape[0] + 1
    else connectivity_matrix_signed_df.values
)

print(f"Signed matrix shape: {C_signed.shape}")

# Show sign(C) * log10(|C| + 1) on a colour scale symmetric around zero.
log_magnitude = np.log10(np.abs(C_signed) + 1)
color_limit = np.max(log_magnitude)

plt.figure(figsize=(10, 8))
plt.imshow(
    log_magnitude * np.sign(C_signed),
    cmap='bwr',
    aspect='auto',
    vmin=-color_limit,
    vmax=color_limit,
    interpolation='none',
)
plt.colorbar(label='Number of Synapses (log10)')
plt.title('Signed Connectivity Matrix')
plt.xlabel('Postsynaptic Neuron Index')
plt.ylabel('Presynaptic Neuron Index')
plt.show()

In [None]:
# Cell-type labels: the 'Cell Type' column of ctoi_list.csv, as a plain list.
ctoi_path = os.path.join(data_path, 'ctoi_list.csv')
ctoi_df = pd.read_csv(ctoi_path)
ctoi_list = ctoi_df['Cell Type'].tolist()
print(f"Cell type list length: {len(ctoi_list)}")
print(f"First 10 cell types: {ctoi_list[:10]}")

# Neuron identifiers: the 'Neuron ID' column of noi_list.csv, as a plain list.
noi_path = os.path.join(data_path, 'noi_list.csv')
noi_df = pd.read_csv(noi_path)
noi_list = noi_df['Neuron ID'].tolist()
print(f"Neuron ID list length: {len(noi_list)}")
print(f"First 10 neuron IDs: {noi_list[:10]}")

### Group by cell type

In [None]:
## Group by cell type and then for noi and connectivity matrix
# Positions (in the original ordering) of the neurons belonging to each cell type.
cell_type_indices = {}
for i, cell_type in enumerate(ctoi_list):
    cell_type_indices.setdefault(cell_type, []).append(i)

# Distinct cell types in order of first appearance (dict keys keep insertion order).
unique_cell_types = list(dict.fromkeys(ctoi_list))
seen_types = set(unique_cell_types)

# Alphabetical ordering of the distinct cell types.
unique_cell_types_sorted = sorted(unique_cell_types)

print(f"Sorted unique cell types: {unique_cell_types_sorted}")
print(f"Number of unique cell types: {len(unique_cell_types_sorted)}")

# Permutation that places neurons of the same type next to each other, with
# the types in alphabetical order and the original order kept inside a type.
new_order_sorted = [
    idx
    for cell_type in unique_cell_types_sorted
    for idx in cell_type_indices[cell_type]
]
# Cell-type label and neuron ID for every position of the new ordering.
ctoi_grouped_sorted = [
    cell_type
    for cell_type in unique_cell_types_sorted
    for _ in cell_type_indices[cell_type]
]
noi_grouped_sorted = [noi_list[idx] for idx in new_order_sorted]

print(f"Sorted new order length: {len(new_order_sorted)}")
print(f"Sorted grouped ctoi length: {len(ctoi_grouped_sorted)}")
print(f"Sorted grouped noi length: {len(noi_grouped_sorted)}")

# Apply the same permutation to rows and columns of the connectivity matrix.
C_signed_grouped_sorted = C_signed[np.ix_(new_order_sorted, new_order_sorted)]

print(f"Sorted grouped matrix shape: {C_signed_grouped_sorted.shape}")

# The grouped, alphabetically sorted versions become the working variables.
ctoi_grouped = ctoi_grouped_sorted
noi_grouped = noi_grouped_sorted
C_signed_grouped = C_signed_grouped_sorted
new_order = new_order_sorted
unique_cell_types = unique_cell_types_sorted

# Number of neurons per cell type, keyed in alphabetical order.
cell_type_counts = {
    cell_type: len(cell_type_indices[cell_type])
    for cell_type in unique_cell_types
}

print("Number of cells in each cell type:")
for cell_type, count in cell_type_counts.items():
    print(f"{cell_type}: {count}")

print(f"\nTotal number of cell types: {len(unique_cell_types)}")
print(f"Total number of cells: {sum(cell_type_counts.values())}")

### Remove unconnected nodes

In [None]:
# Connection magnitudes only; the sign plays no role in detecting isolation.
C_unsigned_grouped = np.abs(C_signed_grouped)

# A neuron is unconnected when its row sum (out-degree) and its column sum
# (in-degree) are both zero.
out_degree = C_unsigned_grouped.sum(axis=1)
in_degree = C_unsigned_grouped.sum(axis=0)
zero_degree_mask = (out_degree == 0) & (in_degree == 0)
unconnected_indices = np.where(zero_degree_mask)[0]
connected_indices = np.where(~zero_degree_mask)[0]

print(f"Total neurons: {len(C_unsigned_grouped)}")
print(f"Unconnected neurons (zero in-degree and out-degree): {len(unconnected_indices)}")
print(f"Connected neurons: {len(connected_indices)}")

if len(unconnected_indices) == 0:
    print("\nAll neurons are connected.")
else:
    print(f"\nUnconnected neuron indices: {unconnected_indices[:20]}...")  # Show first 20

    # Which cell types do the dropped neurons belong to?
    unconnected_cell_types = [ctoi_grouped[i] for i in unconnected_indices]
    cell_type_distribution = Counter(unconnected_cell_types)
    print(f"\nCell type distribution of unconnected neurons:")
    for cell_type, count in sorted(cell_type_distribution.items()):
        print(f"  {cell_type}: {count}")

    # Restrict every per-neuron structure to the connected neurons.
    connected_block = np.ix_(connected_indices, connected_indices)
    C_signed_grouped_connected = C_signed_grouped[connected_block]
    C_unsigned_grouped_connected = C_unsigned_grouped[connected_block]
    ctoi_grouped_connected = [ctoi_grouped[i] for i in connected_indices]
    noi_grouped_connected = [noi_grouped[i] for i in connected_indices]
    new_order_connected = [new_order[i] for i in connected_indices]

    print(f"\nConnected matrix shape: {C_unsigned_grouped_connected.shape}")
    print(f"Connected ctoi length: {len(ctoi_grouped_connected)}")
    print(f"Connected noi length: {len(noi_grouped_connected)}")

    # The filtered versions replace the working variables.
    C_signed_grouped = C_signed_grouped_connected
    C_unsigned_grouped = C_unsigned_grouped_connected
    ctoi_grouped = ctoi_grouped_connected
    noi_grouped = noi_grouped_connected
    new_order = new_order_connected
    unique_cell_types = np.unique(ctoi_grouped).tolist()

    # Recount neurons per cell type on the filtered labels
    # (keys appear in order of first occurrence).
    cell_type_counts = dict(Counter(ctoi_grouped))

    print(f"\nUpdated cell type counts (connected neurons only):")
    print(f"Number of unique cell types: {len(unique_cell_types)}")

In [None]:
# Visualize the grouped connectivity matrix with cell type labels.
# Entries are drawn as sign(C) * log10(|C| + 1) on a colour scale symmetric
# around zero.
log_magnitude = np.log10(np.abs(C_signed_grouped) + 1)
color_limit = np.max(log_magnitude)

plt.figure(figsize=(12, 10))
plt.imshow(log_magnitude * np.sign(C_signed_grouped),
           cmap='bwr', aspect='auto',
           vmin=-color_limit,
           vmax=color_limit,
           interpolation='none')
plt.colorbar(label='Number of Synapses (log10)')
plt.title('Signed Connectivity Matrix (Grouped by Cell Type)')
plt.xlabel('Postsynaptic Neuron Index')
plt.ylabel('Presynaptic Neuron Index')

# Add cell type group boundaries and labels.
# Group sizes are read from cell_type_counts, which is recomputed whenever
# neurons are dropped; cell_type_indices still describes the unfiltered data
# and would misplace every boundary after the first removed neuron.
group_boundaries = []
tick_positions = []
tick_labels = []
current_pos = 0
for cell_type in unique_cell_types:
    group_size = cell_type_counts[cell_type]
    group_boundaries.append(current_pos + group_size)

    # Add group label at the center of each group
    center_pos = current_pos + group_size / 2
    tick_positions.append(center_pos)
    tick_labels.append(cell_type)

    current_pos += group_size

# Set x and y ticks with cell type labels (once, after all groups are known)
plt.xticks(tick_positions, tick_labels, rotation=45, ha='right', fontsize=8)
plt.yticks(tick_positions, tick_labels, fontsize=8)

# Add vertical and horizontal lines to separate groups
for boundary in group_boundaries[:-1]:  # Exclude the last boundary
    plt.axhline(y=boundary-0.5, color='k', linewidth=0.5, alpha=0.7)
    plt.axvline(x=boundary-0.5, color='k', linewidth=0.5, alpha=0.7)

plt.tight_layout()
plt.show()

### Weight-only (unsigned) matrix

In [None]:
# Visualize the unsigned grouped connectivity matrix as log10(|C| + 1).
# The colour limits are kept symmetric around zero, as in the signed plots.
log_magnitude = np.log10(np.abs(C_unsigned_grouped) + 1)
color_limit = np.max(log_magnitude)

plt.figure(figsize=(12, 10))
plt.imshow(log_magnitude,
           cmap='bwr', aspect='auto',
           vmin=-color_limit,
           vmax=color_limit,
           interpolation='none')
plt.colorbar(label='Number of Synapses (log10)')
plt.title('Unsigned Connectivity Matrix (Grouped by Cell Type)')
plt.xlabel('Postsynaptic Neuron Index')
plt.ylabel('Presynaptic Neuron Index')

# Add cell type group boundaries and labels.
# Group sizes are read from cell_type_counts, which is recomputed whenever
# neurons are dropped; cell_type_indices still describes the unfiltered data
# and would misplace every boundary after the first removed neuron.
group_boundaries = []
tick_positions = []
tick_labels = []
current_pos = 0
for cell_type in unique_cell_types:
    group_size = cell_type_counts[cell_type]
    group_boundaries.append(current_pos + group_size)

    # Add group label at the center of each group
    center_pos = current_pos + group_size / 2
    tick_positions.append(center_pos)
    tick_labels.append(cell_type)

    current_pos += group_size

# Set x and y ticks with cell type labels (once, after all groups are known)
plt.xticks(tick_positions, tick_labels, rotation=45, ha='right', fontsize=8)
plt.yticks(tick_positions, tick_labels, fontsize=8)

# Add vertical and horizontal lines to separate groups
for boundary in group_boundaries[:-1]:  # Exclude the last boundary
    plt.axhline(y=boundary-0.5, color='k', linewidth=0.5, alpha=0.7)
    plt.axvline(x=boundary-0.5, color='k', linewidth=0.5, alpha=0.7)

plt.tight_layout()
plt.show()

In [None]:
# Workspace checkpoint: report, for each expected variable, whether it exists
# in the notebook namespace and how large it is.
variables_to_keep = [
    'ctoi_grouped',
    'noi_grouped',
    'C_signed_grouped',
    'C_unsigned_grouped',
    'new_order',
    'unique_cell_types',
    'cell_type_counts',
]
print("Variables to keep for further analysis:")
for var in variables_to_keep:
    if var not in globals():
        print(f"\t{var}: NOT FOUND")
        continue

    var_obj = globals()[var]
    var_type = type(var_obj).__name__

    # Prefer an array-style shape, then a length, else call it a scalar.
    if hasattr(var_obj, 'shape'):
        structure = f"shape: {var_obj.shape}"
    elif hasattr(var_obj, '__len__'):
        structure = f"length: {len(var_obj)}"
    else:
        structure = "scalar"

    print(f"\t{var}: {var_type}, {structure}")

### Select cell types

In [None]:
# Restrict all per-neuron structures to the chosen cell types.
select_cell_type = ['Delta7', 'EPG']
# select_cell_type = unique_cell_types # Use all cell types

# Set membership keeps the filter linear in the number of neurons.
selected_types = set(select_cell_type)
select_indices = [i for i, ct in enumerate(ctoi_grouped) if ct in selected_types]

ctoi_grouped = [ctoi_grouped[i] for i in select_indices]
noi_grouped = [noi_grouped[i] for i in select_indices]
C_signed_grouped = C_signed_grouped[np.ix_(select_indices, select_indices)]
C_unsigned_grouped = C_unsigned_grouped[np.ix_(select_indices, select_indices)]
new_order = [new_order[i] for i in select_indices]

# Take the group order from the filtered labels rather than from the order in
# which the types were typed above: the matrix rows keep their existing
# grouping, so labels listed in a different order would be attached to the
# wrong groups in the plots. Types with no remaining neurons are left out.
unique_cell_types = list(dict.fromkeys(ctoi_grouped))
cell_type_counts = dict(Counter(ctoi_grouped))

# Surface requested types that matched nothing instead of silently
# carrying an empty group.
missing_cell_types = [ct for ct in select_cell_type if ct not in cell_type_counts]
if missing_cell_types:
    print(f"Warning: no neurons found for selected cell types: {missing_cell_types}")

In [None]:
# Visualize the grouped connectivity matrix with cell type labels.
# Entries are drawn as sign(C) * log10(|C| + 1) on a colour scale symmetric
# around zero.
log_magnitude = np.log10(np.abs(C_signed_grouped) + 1)
color_limit = np.max(log_magnitude)

plt.figure(figsize=(12, 10))
plt.imshow(log_magnitude * np.sign(C_signed_grouped),
           cmap='bwr', aspect='auto',
           vmin=-color_limit,
           vmax=color_limit,
           interpolation='none')
plt.colorbar(label='Number of Synapses (log10)')
plt.title('Signed Connectivity Matrix (Grouped by Cell Type)')
plt.xlabel('Postsynaptic Neuron Index')
plt.ylabel('Presynaptic Neuron Index')

# Add cell type group boundaries and labels.
# Group sizes are read from cell_type_counts, which reflects the neurons that
# are currently in the matrix; cell_type_indices still describes the
# unfiltered data and would misplace boundaries once neurons were dropped.
group_boundaries = []
tick_positions = []
tick_labels = []
current_pos = 0
for cell_type in unique_cell_types:
    group_size = cell_type_counts[cell_type]
    group_boundaries.append(current_pos + group_size)

    # Add group label at the center of each group
    center_pos = current_pos + group_size / 2
    tick_positions.append(center_pos)
    tick_labels.append(cell_type)

    current_pos += group_size

# Set x and y ticks with cell type labels (once, after all groups are known)
plt.xticks(tick_positions, tick_labels, rotation=45, ha='right', fontsize=8)
plt.yticks(tick_positions, tick_labels, fontsize=8)

# Add vertical and horizontal lines to separate groups
for boundary in group_boundaries[:-1]:  # Exclude the last boundary
    plt.axhline(y=boundary-0.5, color='k', linewidth=0.5, alpha=0.7)
    plt.axvline(x=boundary-0.5, color='k', linewidth=0.5, alpha=0.7)

plt.tight_layout()
plt.show()

In [None]:
# Visualize only the sign of each connection (-1, 0 or +1), grouped by cell type.
plt.figure(figsize=(12, 10))
plt.imshow(np.sign(C_signed_grouped),
           cmap='bwr', aspect='auto',
           vmin=-1,
           vmax=1,
           interpolation='none')
plt.colorbar(label='Connection Sign')
plt.title('Signed Connectivity Matrix (Grouped by Cell Type)')
plt.xlabel('Postsynaptic Neuron Index')
plt.ylabel('Presynaptic Neuron Index')

# Add cell type group boundaries and labels.
# Group sizes are read from cell_type_counts, which reflects the neurons that
# are currently in the matrix; cell_type_indices still describes the
# unfiltered data and would misplace boundaries once neurons were dropped.
group_boundaries = []
tick_positions = []
tick_labels = []
current_pos = 0
for cell_type in unique_cell_types:
    group_size = cell_type_counts[cell_type]
    group_boundaries.append(current_pos + group_size)

    # Add group label at the center of each group
    center_pos = current_pos + group_size / 2
    tick_positions.append(center_pos)
    tick_labels.append(cell_type)

    current_pos += group_size

# Set x and y ticks with cell type labels (once, after all groups are known)
plt.xticks(tick_positions, tick_labels, rotation=45, ha='right', fontsize=8)
plt.yticks(tick_positions, tick_labels, fontsize=8)

# Add vertical and horizontal lines to separate groups
for boundary in group_boundaries[:-1]:  # Exclude the last boundary
    plt.axhline(y=boundary-0.5, color='k', linewidth=0.5, alpha=0.7)
    plt.axvline(x=boundary-0.5, color='k', linewidth=0.5, alpha=0.7)

plt.tight_layout()
plt.show()

In [None]:
# Visualize the unsigned grouped connectivity matrix as log10(|C| + 1).
# The colour limits are kept symmetric around zero, as in the signed plots.
log_magnitude = np.log10(np.abs(C_unsigned_grouped) + 1)
color_limit = np.max(log_magnitude)

plt.figure(figsize=(12, 10))
plt.imshow(log_magnitude,
           cmap='bwr', aspect='auto',
           vmin=-color_limit,
           vmax=color_limit,
           interpolation='none')
plt.colorbar(label='Number of Synapses (log10)')
plt.title('Unsigned Connectivity Matrix (Grouped by Cell Type)')
plt.xlabel('Postsynaptic Neuron Index')
plt.ylabel('Presynaptic Neuron Index')

# Add cell type group boundaries and labels.
# Group sizes are read from cell_type_counts, which reflects the neurons that
# are currently in the matrix; cell_type_indices still describes the
# unfiltered data and would misplace boundaries once neurons were dropped.
group_boundaries = []
tick_positions = []
tick_labels = []
current_pos = 0
for cell_type in unique_cell_types:
    group_size = cell_type_counts[cell_type]
    group_boundaries.append(current_pos + group_size)

    # Add group label at the center of each group
    center_pos = current_pos + group_size / 2
    tick_positions.append(center_pos)
    tick_labels.append(cell_type)

    current_pos += group_size

# Set x and y ticks with cell type labels (once, after all groups are known)
plt.xticks(tick_positions, tick_labels, rotation=45, ha='right', fontsize=8)
plt.yticks(tick_positions, tick_labels, fontsize=8)

# Add vertical and horizontal lines to separate groups
for boundary in group_boundaries[:-1]:  # Exclude the last boundary
    plt.axhline(y=boundary-0.5, color='k', linewidth=0.5, alpha=0.7)
    plt.axvline(x=boundary-0.5, color='k', linewidth=0.5, alpha=0.7)

plt.tight_layout()
plt.show()

### Normalization

#### Uniform scaling

#### Information-based normalization

### Checkpoint

In [None]:
# Workspace checkpoint: report, for each expected variable, whether it exists
# in the notebook namespace and how large it is.
variables_to_keep = [
    'ctoi_grouped',
    'noi_grouped',
    'C_signed_grouped',
    'C_unsigned_grouped',
    'new_order',
    'unique_cell_types',
    'cell_type_counts',
]
print("Variables to keep for further analysis:")
for var in variables_to_keep:
    if var not in globals():
        print(f"\t{var}: NOT FOUND")
        continue

    var_obj = globals()[var]
    var_type = type(var_obj).__name__

    # Prefer an array-style shape, then a length, else call it a scalar.
    if hasattr(var_obj, 'shape'):
        structure = f"shape: {var_obj.shape}"
    elif hasattr(var_obj, '__len__'):
        structure = f"length: {len(var_obj)}"
    else:
        structure = "scalar"

    print(f"\t{var}: {var_type}, {structure}")

## System reduction

In [None]:
# Adjacency matrix used by the Laplacian cells that follow. This binds A to
# the same object as C_unsigned_grouped; no copy is made.
A = C_unsigned_grouped

### Graph Laplacians

#### Directed, combinatorial Laplacians

In [None]:
# Compute out-degree and in-degree for A.
# The degree vectors get their own names (d_out, d_in) so that the summary
# statistics at the end are taken over the n node degrees and not over the
# n x n diagonal matrices, whose off-diagonal zeros would dilute mean and std.
# Out-degree: sum of outgoing connections (row sums)
d_out = np.sum(A, axis=1)

# In-degree: sum of incoming connections (column sums)
d_in = np.sum(A, axis=0)

print(f"Out-degree D_out shape: {d_out.shape}")
print(f"In-degree D_in shape: {d_in.shape}")
print(f"Out-degree range: {d_out.min():.2f} - {d_out.max():.2f}")
print(f"In-degree range: {d_in.min():.2f} - {d_in.max():.2f}")

# Create diagonal degree matrices (these are what later cells consume)
D_out = np.diag(d_out)
D_in = np.diag(d_in)

# Compute out-degree Laplacian: L_out = D_out - A
L_out = D_out - A

# Compute in-degree Laplacian: L_in = D_in - A^T
L_in = D_in - A.T

print(f"\nOut-degree Laplacian L_out shape: {L_out.shape}")
print(f"In-degree Laplacian L_in shape: {L_in.shape}")

# Verify Laplacian properties
# Row sums of L_out should be zero
row_sums_L_out = np.sum(L_out, axis=1)
print(f"\nL_out row sums (should be ~0): max abs = {np.abs(row_sums_L_out).max():.2e}")

# Column sums of L_in should be zero
col_sums_L_in = np.sum(L_in, axis=0)
print(f"L_in column sums (should be ~0): max abs = {np.abs(col_sums_L_in).max():.2e}")

# Show some statistics (per node, from the degree vectors)
print(f"\nDegree statistics:")
print(f"Mean out-degree: {d_out.mean():.2f}")
print(f"Mean in-degree: {d_in.mean():.2f}")
print(f"Std out-degree: {d_out.std():.2f}")
print(f"Std in-degree: {d_in.std():.2f}")

In [None]:
# Row sums of A are out-degrees, column sums are in-degrees.
out_degree = np.sum(A, axis=1)
in_degree = np.sum(A, axis=0)

# One histogram panel per degree type, side by side.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
panels = (
    (axes[0], out_degree, 'blue', 'Out-Degree'),
    (axes[1], in_degree, 'green', 'In-Degree'),
)
for ax, degree, color, label in panels:
    ax.hist(degree, bins=30, color=color, alpha=0.7, edgecolor='black')
    ax.set_xlabel(label)
    ax.set_ylabel('Frequency')
    ax.set_title(f'{label} Distribution')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Summary statistics, out-degree first and in-degree after a blank line.
for heading, degree in (
    ("Out-Degree Statistics:", out_degree),
    ("\nIn-Degree Statistics:", in_degree),
):
    print(heading)
    print(f"  Mean: {degree.mean():.2f}")
    print(f"  Std: {degree.std():.2f}")
    print(f"  Min: {degree.min():.2f}")
    print(f"  Max: {degree.max():.2f}")

In [None]:
# Heatmaps of both combinatorial Laplacians, each on its own figure with a
# colour scale symmetric around zero.
laplacian_panels = (
    (L_out, 'Out-degree Laplacian (L_out = D_out - A)'),
    (L_in, 'In-degree Laplacian (L_in = D_in - A^T)'),
)
for laplacian, panel_title in laplacian_panels:
    color_limit = np.max(np.abs(laplacian))
    plt.figure(figsize=(12, 10))
    plt.imshow(
        laplacian,
        cmap='bwr',
        aspect='auto',
        vmin=-color_limit,
        vmax=color_limit,
        interpolation='none',
    )
    plt.colorbar(label='Value')
    plt.title(panel_title)
    plt.xlabel('Column Index')
    plt.ylabel('Row Index')
    plt.tight_layout()
    plt.show()

# Entry-wise summary statistics of each Laplacian.
laplacian_reports = (
    ("Out-degree Laplacian (L_out) statistics:", L_out),
    ("\nIn-degree Laplacian (L_in) statistics:", L_in),
)
for heading, laplacian in laplacian_reports:
    print(heading)
    print(f"  Min: {laplacian.min():.2f}")
    print(f"  Max: {laplacian.max():.2f}")
    print(f"  Mean: {laplacian.mean():.2f}")
    print(f"  Std: {laplacian.std():.2f}")

#### Random walk Laplacian

In [None]:
# Compute random walk Laplacian L_rw = I - D_out^(-1) * A

# Report zero out-degree nodes before inverting: they are what makes D_out
# singular, so this diagnostic has to be printed even when the inversion fails.
print(f"\nMin diagonal value of D_out: {np.diag(D_out).min():.2f}")
print(f"Number of zero-degree nodes: {np.count_nonzero(np.diag(D_out) == 0)}")

try:
    # First, compute D_out^(-1); raises LinAlgError when D_out is singular.
    D_out_inv = np.linalg.inv(D_out)
except np.linalg.LinAlgError as e:
    # Drop any L_rw left over from an earlier run so the cells that test for
    # its existence do not silently reuse a stale matrix.
    if 'L_rw' in globals():
        del L_rw
    print("Error computing random walk Laplacian:", e)
else:
    # Compute the random walk Laplacian
    I = np.eye(len(A))
    L_rw = I - D_out_inv @ A

    print(f"Random walk Laplacian L_rw shape: {L_rw.shape}")

    # Verify properties of random walk Laplacian
    # Row sums should be zero (since each row of D_out^(-1) * A sums to 1)
    row_sums_L_rw = np.sum(L_rw, axis=1)
    print(f"L_rw row sums (should be ~0): max abs = {np.abs(row_sums_L_rw).max():.2e}")

    # Show some statistics
    print(f"\nRandom walk Laplacian statistics:")
    print(f"L_rw min: {L_rw.min():.4f}")
    print(f"L_rw max: {L_rw.max():.4f}")
    print(f"L_rw mean: {L_rw.mean():.4f}")

In [None]:
# Heatmap of the random walk Laplacian; skipped when L_rw was never computed.
if 'L_rw' in locals():
    # Colour scale symmetric around zero.
    color_limit = np.max(np.abs(L_rw))

    plt.figure(figsize=(12, 10))
    plt.imshow(
        L_rw,
        cmap='bwr',
        aspect='auto',
        vmin=-color_limit,
        vmax=color_limit,
        interpolation='none',
    )
    plt.colorbar(label='Value')
    plt.title('Random Walk Laplacian (L_rw = I - D_out^(-1) * A)')
    plt.xlabel('Column Index')
    plt.ylabel('Row Index')
    plt.tight_layout()
    plt.show()

#### Normalized Laplacian

##### Symmetrize then normalize

Treating A as undirected: the adjacency matrix is symmetrized first, then normalized.

In [None]:
# Compute the symmetric Laplacian L_sym = I - D^(-1/2) * A_sym * D^(-1/2)
A_sym = (A + A.T) / 2
d_sym = np.sum(A_sym, axis=1)
D_sym = np.diag(d_sym)

# Report zero-degree nodes before inverting: they are what makes D_sym
# singular, so this diagnostic has to be printed even when the inversion fails.
print(f"\nMin diagonal value of D_sym: {d_sym.min():.2f}")
print(f"Number of zero-degree nodes: {np.count_nonzero(d_sym == 0)}")

try:
    # Raises LinAlgError when some node has zero symmetrized degree.
    D_sym_inv_sqrt = np.linalg.inv(np.sqrt(D_sym))
except np.linalg.LinAlgError as e:
    # Drop any L_sym left over from an earlier run so later cells cannot
    # silently reuse a stale matrix.
    if 'L_sym' in globals():
        del L_sym
    print("Error computing symmetric Laplacian:", e)
else:
    # Compute the symmetric Laplacian
    I = np.eye(len(A_sym))
    L_sym = I - D_sym_inv_sqrt @ A_sym @ D_sym_inv_sqrt

    print(f"Symmetric Laplacian L_sym shape: {L_sym.shape}")

    # Verify properties of symmetric Laplacian.
    # The vector annihilated by L_sym is sqrt(d), not the all-ones vector:
    # L_sym @ sqrt(d) = sqrt(d) - D^(-1/2) @ A_sym @ 1 = sqrt(d) - sqrt(d) = 0.
    # Plain row sums therefore only vanish when all degrees are equal.
    null_residual_L_sym = L_sym @ np.sqrt(d_sym)
    print(f"L_sym @ sqrt(d) (should be ~0): max abs = {np.abs(null_residual_L_sym).max():.2e}")
    row_sums_L_sym = np.sum(L_sym, axis=1)
    print(f"L_sym row sums (nonzero unless degrees are uniform): max abs = {np.abs(row_sums_L_sym).max():.2e}")

    # Show some statistics
    print(f"\nSymmetric Laplacian statistics:")
    print(f"L_sym min: {L_sym.min():.4f}")
    print(f"L_sym max: {L_sym.max():.4f}")
    print(f"L_sym mean: {L_sym.mean():.4f}")

In [None]:
# Heatmap of the symmetrized adjacency matrix on a log10(x + 1) scale, with
# colour limits symmetric around zero.
log_A_sym = np.log10(A_sym + 1)
color_limit = np.max(np.abs(log_A_sym))

plt.figure(figsize=(12, 10))
plt.imshow(
    log_A_sym,
    cmap='bwr',
    aspect='auto',
    vmin=-color_limit,
    vmax=color_limit,
    interpolation='none',
)
plt.colorbar(label='Number of Synapses (log10)')
plt.title('Symmetric Adjacency Matrix A_sym = (A + A^T) / 2')
plt.xlabel('Neuron Index')
plt.ylabel('Neuron Index')

# Walk through the cell types in order, recording where each group ends and
# where its label should sit (the midpoint of the group).
tick_labels = list(unique_cell_types)
group_boundaries = []
tick_positions = []
current_pos = 0
for cell_type in tick_labels:
    group_size = cell_type_counts[cell_type]
    tick_positions.append(current_pos + group_size / 2)
    current_pos += group_size
    group_boundaries.append(current_pos)

# Cell type names on both axes.
plt.xticks(tick_positions, tick_labels, rotation=45, ha='right', fontsize=8)
plt.yticks(tick_positions, tick_labels, fontsize=8)

# Separator lines between consecutive groups (none after the last group).
for boundary in group_boundaries[:-1]:
    plt.axhline(y=boundary-0.5, color='k', linewidth=0.5, alpha=0.7)
    plt.axvline(x=boundary-0.5, color='k', linewidth=0.5, alpha=0.7)

plt.tight_layout()
plt.show()

# Summary statistics of A_sym.
print(f"A_sym shape: {A_sym.shape}")
print(f"A_sym min: {A_sym.min():.2f}")
print(f"A_sym max: {A_sym.max():.2f}")
print(f"A_sym mean: {A_sym.mean():.2f}")
print(f"Number of non-zero entries: {np.count_nonzero(A_sym)}")

# Largest entry-wise deviation of A_sym from its transpose.
symmetry_error = np.max(np.abs(A_sym - A_sym.T))
print(f"\nSymmetry error (should be ~0): {symmetry_error:.2e}")

In [None]:
# Heatmap of the symmetrized degree matrix D_sym, colour scale symmetric
# around zero.
color_limit = np.max(np.abs(D_sym))

plt.figure(figsize=(12, 10))
plt.imshow(
    D_sym,
    cmap='bwr',
    aspect='auto',
    vmin=-color_limit,
    vmax=color_limit,
    interpolation='none',
)
plt.colorbar(label='Value')
plt.title('Degree Matrix D_sym')
plt.xlabel('Column Index')
plt.ylabel('Row Index')
plt.tight_layout()
plt.show()

In [None]:
# Heatmap of the symmetrize-then-normalize Laplacian, colour scale symmetric
# around zero.
color_limit = np.max(np.abs(L_sym))

plt.figure(figsize=(12, 10))
plt.imshow(
    L_sym,
    cmap='bwr',
    aspect='auto',
    vmin=-color_limit,
    vmax=color_limit,
    interpolation='none',
)
plt.colorbar(label='Value')
plt.title('Symmetrize-then-Normalize Laplacian (L_sym)')
plt.xlabel('Column Index')
plt.ylabel('Row Index')
plt.tight_layout()
plt.show()

##### Balanced symmetrization

Default is row-stochastic (forward).

In [None]:
# Compute L_bal: the direct symmetrization of L_out
# L_bal = D_out^(-1/2) * (D_out - 1/2(A + A^T)) * D_out^(-1/2)

# First, compute the symmetrized adjacency matrix
A_sym = (A + A.T) / 2

try:
    # Compute D_out^(-1/2); raises LinAlgError when some out-degree is zero.
    D_out_inv_sqrt = np.linalg.inv(np.sqrt(D_out))
except np.linalg.LinAlgError as e:
    # Drop any L_bal left over from an earlier run so the cells that test for
    # its existence do not silently reuse a stale matrix.
    if 'L_bal' in globals():
        del L_bal
    print("Error computing L_bal:", e)
else:
    # Compute L_bal
    L_bal = D_out_inv_sqrt @ (D_out - A_sym) @ D_out_inv_sqrt

    print(f"L_bal shape: {L_bal.shape}")

    # Verify symmetry of L_bal
    symmetry_error = np.max(np.abs(L_bal - L_bal.T))
    print(f"L_bal symmetry error (should be ~0): {symmetry_error:.2e}")

    # Verify properties.
    # L_bal @ sqrt(d_out) = D_out^(-1/2) @ (d_out - d_in) / 2, so this residual
    # vanishes only when every node has equal in- and out-degree. Plain row
    # sums are reported for reference but are not expected to be zero.
    balance_residual_L_bal = L_bal @ np.sqrt(np.diag(D_out))
    print(f"L_bal @ sqrt(d_out) (~0 only if in-degree == out-degree): max abs = {np.abs(balance_residual_L_bal).max():.2e}")
    row_sums_L_bal = np.sum(L_bal, axis=1)
    print(f"L_bal row sums (nonzero in general): max abs = {np.abs(row_sums_L_bal).max():.2e}")

    # Show some statistics
    print(f"\nL_bal statistics:")
    print(f"  Min: {L_bal.min():.4f}")
    print(f"  Max: {L_bal.max():.4f}")
    print(f"  Mean: {L_bal.mean():.4f}")
    print(f"  Std: {L_bal.std():.4f}")

In [None]:
# Heatmap of the balanced symmetrized Laplacian; skipped when L_bal was never
# computed.
if 'L_bal' in locals():
    # Colour scale symmetric around zero.
    color_limit = np.max(np.abs(L_bal))

    plt.figure(figsize=(12, 10))
    plt.imshow(
        L_bal,
        cmap='bwr',
        aspect='auto',
        vmin=-color_limit,
        vmax=color_limit,
        interpolation='none',
    )
    plt.colorbar(label='Value')
    plt.title('Balanced Symmetrized Laplacian (L_bal)')
    plt.xlabel('Column Index')
    plt.ylabel('Row Index')
    plt.tight_layout()
    plt.show()

In [None]:
# Workspace checkpoint: report, for each expected variable, whether it exists
# in the notebook namespace and how large it is.
variables_to_keep = [
    'ctoi_grouped', 'noi_grouped',
    'C_signed_grouped', 'C_unsigned_grouped',
    'new_order', 'unique_cell_types', 'cell_type_counts',
    'A_sym', 'D_sym', 'D_in', 'D_out',
    'L_sym', 'L_bal', 'L_in', 'L_out', 'L_rw',
]
print("Variables to keep for further analysis:")
for var in variables_to_keep:
    if var not in globals():
        print(f"\t{var}: NOT FOUND")
        continue

    var_obj = globals()[var]
    var_type = type(var_obj).__name__

    # Prefer an array-style shape, then a length, else call it a scalar.
    if hasattr(var_obj, 'shape'):
        structure = f"shape: {var_obj.shape}"
    elif hasattr(var_obj, '__len__'):
        structure = f"length: {len(var_obj)}"
    else:
        structure = "scalar"

    print(f"\t{var}: {var_type}, {structure}")

#### Laplacian spectrum

##### Eigenspectrum of L_out

In [None]:
# Matrix whose spectrum is examined in this cell.
L = L_out

# All eigenvalues of L, and their real and imaginary parts.
eigenvalues_L = np.linalg.eigvals(L)
real_parts = eigenvalues_L.real
imag_parts = eigenvalues_L.imag

# Real parts in ascending order.
eigenvalues_L_sorted = np.sort(real_parts)

# Sorted real parts against their rank.
plt.figure(figsize=(12, 6))
plt.plot(eigenvalues_L_sorted, 'o-', markersize=4, linewidth=1)
plt.axhline(y=0, color='r', linestyle='--', alpha=0.5)
plt.xlabel('Index')
plt.ylabel('Eigenvalue')
plt.title('Spectrum of Out-degree Laplacian')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Statistics of the real parts; values within 1e-10 of zero count as zero.
n_zero_real = np.count_nonzero(np.abs(eigenvalues_L_sorted) < 1e-10)
n_positive_real = np.count_nonzero(eigenvalues_L_sorted > 1e-10)
n_negative_real = np.count_nonzero(eigenvalues_L_sorted < -1e-10)
print("Eigenvalue statistics:")
print(f"  Min: {eigenvalues_L_sorted.min():.4f}")
print(f"  Max: {eigenvalues_L_sorted.max():.4f}")
print(f"  Mean: {eigenvalues_L_sorted.mean():.4f}")
print(f"  Number of zero eigenvalues (< 1e-10): {n_zero_real}")
print(f"  Number of positive eigenvalues: {n_positive_real}")
print(f"  Number of negative eigenvalues: {n_negative_real}")

# Eigenvalues as points in the complex plane, with both axes marked.
plt.figure(figsize=(10, 10))
plt.scatter(
    real_parts,
    imag_parts,
    s=50,
    alpha=0.6,
    edgecolors='black',
    linewidths=0.5,
)
plt.axhline(y=0, color='r', linestyle='--', alpha=0.3, linewidth=1)
plt.axvline(x=0, color='r', linestyle='--', alpha=0.3, linewidth=1)
plt.xlabel('Real Part')
plt.ylabel('Imaginary Part')
plt.title('Spectrum of Out-degree Laplacian on Complex Plane')
plt.grid(True, alpha=0.3)
plt.axis('equal')
plt.tight_layout()
plt.show()

# Statistics of the complex eigenvalues, using the same 1e-10 tolerance.
n_purely_real = np.count_nonzero(np.abs(imag_parts) < 1e-10)
n_complex = np.count_nonzero(np.abs(imag_parts) >= 1e-10)
n_zero_modulus = np.count_nonzero(np.abs(eigenvalues_L) < 1e-10)
print("Eigenvalue statistics:")
print(f"  Number of eigenvalues: {len(eigenvalues_L)}")
print(f"  Real part range: [{real_parts.min():.4f}, {real_parts.max():.4f}]")
print(f"  Imaginary part range: [{imag_parts.min():.4f}, {imag_parts.max():.4f}]")
print(f"  Number of purely real eigenvalues (|Im| < 1e-10): {n_purely_real}")
print(f"  Number of complex eigenvalues: {n_complex}")
print(f"  Number of zero eigenvalues (|λ| < 1e-10): {n_zero_modulus}")


##### Eigenspectrum of L_rw

In [None]:
# Spectrum of the random walk Laplacian; skipped when L_rw was never computed.
if 'L_rw' in locals():
    # Matrix whose spectrum is examined in this cell.
    L = L_rw

    # All eigenvalues of L, and their real and imaginary parts.
    eigenvalues_L = np.linalg.eigvals(L)
    real_parts = eigenvalues_L.real
    imag_parts = eigenvalues_L.imag

    # Real parts in ascending order.
    eigenvalues_L_sorted = np.sort(real_parts)

    # Sorted real parts against their rank.
    plt.figure(figsize=(12, 6))
    plt.plot(eigenvalues_L_sorted, 'o-', markersize=4, linewidth=1)
    plt.axhline(y=0, color='r', linestyle='--', alpha=0.5)
    plt.xlabel('Index')
    plt.ylabel('Eigenvalue')
    plt.title('Spectrum of Random-Walk Laplacian')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Statistics of the real parts; values within 1e-10 of zero count as zero.
    n_zero_real = np.count_nonzero(np.abs(eigenvalues_L_sorted) < 1e-10)
    n_positive_real = np.count_nonzero(eigenvalues_L_sorted > 1e-10)
    n_negative_real = np.count_nonzero(eigenvalues_L_sorted < -1e-10)
    print("Eigenvalue statistics:")
    print(f"  Min: {eigenvalues_L_sorted.min():.4f}")
    print(f"  Max: {eigenvalues_L_sorted.max():.4f}")
    print(f"  Mean: {eigenvalues_L_sorted.mean():.4f}")
    print(f"  Number of zero eigenvalues (< 1e-10): {n_zero_real}")
    print(f"  Number of positive eigenvalues: {n_positive_real}")
    print(f"  Number of negative eigenvalues: {n_negative_real}")

    # Eigenvalues as points in the complex plane, with both axes marked.
    plt.figure(figsize=(10, 10))
    plt.scatter(
        real_parts,
        imag_parts,
        s=50,
        alpha=0.6,
        edgecolors='black',
        linewidths=0.5,
    )
    plt.axhline(y=0, color='r', linestyle='--', alpha=0.3, linewidth=1)
    plt.axvline(x=0, color='r', linestyle='--', alpha=0.3, linewidth=1)
    plt.xlabel('Real Part')
    plt.ylabel('Imaginary Part')
    plt.title('Spectrum of Random-Walk Laplacian on Complex Plane')
    plt.grid(True, alpha=0.3)
    plt.axis('equal')
    plt.tight_layout()
    plt.show()

    # Statistics of the complex eigenvalues, using the same 1e-10 tolerance.
    n_purely_real = np.count_nonzero(np.abs(imag_parts) < 1e-10)
    n_complex = np.count_nonzero(np.abs(imag_parts) >= 1e-10)
    n_zero_modulus = np.count_nonzero(np.abs(eigenvalues_L) < 1e-10)
    print("Eigenvalue statistics:")
    print(f"  Number of eigenvalues: {len(eigenvalues_L)}")
    print(f"  Real part range: [{real_parts.min():.4f}, {real_parts.max():.4f}]")
    print(f"  Imaginary part range: [{imag_parts.min():.4f}, {imag_parts.max():.4f}]")
    print(f"  Number of purely real eigenvalues (|Im| < 1e-10): {n_purely_real}")
    print(f"  Number of complex eigenvalues: {n_complex}")
    print(f"  Number of zero eigenvalues (|λ| < 1e-10): {n_zero_modulus}")


##### Eigenspectrum of L_sym

In [None]:
# Eigenspectrum of the symmetric normalized Laplacian L_sym.
L = L_sym

# Full spectrum (general solver, so the result may carry an imaginary part)
# and its real parts in ascending order.
eigenvalues_L = np.linalg.eigvals(L)
eigenvalues_L_sorted = np.sort(eigenvalues_L.real)

# Figure 1: sorted real parts against their rank.
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(eigenvalues_L_sorted, 'o-', markersize=4, linewidth=1)
ax.axhline(y=0, color='r', linestyle='--', alpha=0.5)
ax.set_xlabel('Index')
ax.set_ylabel('Eigenvalue')
ax.set_title('Spectrum of Symmetric Normalized Laplacian')
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

# Summary of the real parts; 1e-10 is the "numerically zero" tolerance.
print("Eigenvalue statistics:")
for stat_name, stat_value in (
    ("Min", eigenvalues_L_sorted.min()),
    ("Max", eigenvalues_L_sorted.max()),
    ("Mean", eigenvalues_L_sorted.mean()),
):
    print(f"  {stat_name}: {stat_value:.4f}")
print(f"  Number of zero eigenvalues (< 1e-10): {np.count_nonzero(np.abs(eigenvalues_L_sorted) < 1e-10)}")
print(f"  Number of positive eigenvalues: {np.count_nonzero(eigenvalues_L_sorted > 1e-10)}")
print(f"  Number of negative eigenvalues: {np.count_nonzero(eigenvalues_L_sorted < -1e-10)}")

# Figure 2: the same eigenvalues as points in the complex plane.
real_part = eigenvalues_L.real
imag_part = eigenvalues_L.imag
fig, ax = plt.subplots(figsize=(10, 10))
ax.scatter(real_part, imag_part,
           s=50, alpha=0.6, edgecolors='black', linewidths=0.5)
ax.axhline(y=0, color='r', linestyle='--', alpha=0.3, linewidth=1)
ax.axvline(x=0, color='r', linestyle='--', alpha=0.3, linewidth=1)
ax.set_xlabel('Real Part')
ax.set_ylabel('Imaginary Part')
ax.set_title('Spectrum of Symmetric Normalized Laplacian on Complex Plane')
ax.grid(True, alpha=0.3)
ax.axis('equal')
fig.tight_layout()
plt.show()

# Summary of the complex spectrum.
imag_magnitude = np.abs(imag_part)
print("Eigenvalue statistics:")
print(f"  Number of eigenvalues: {len(eigenvalues_L)}")
print(f"  Real part range: [{real_part.min():.4f}, {real_part.max():.4f}]")
print(f"  Imaginary part range: [{imag_part.min():.4f}, {imag_part.max():.4f}]")
print(f"  Number of purely real eigenvalues (|Im| < 1e-10): {np.count_nonzero(imag_magnitude < 1e-10)}")
print(f"  Number of complex eigenvalues: {np.count_nonzero(imag_magnitude >= 1e-10)}")
print(f"  Number of zero eigenvalues (|λ| < 1e-10): {np.count_nonzero(np.abs(eigenvalues_L) < 1e-10)}")


##### Eigenspectrum of L_bal

In [None]:
# Eigenspectrum of L_bal; the whole cell is skipped when L_bal has not been
# defined.
if 'L_bal' in locals():
    L = L_bal

    # Full spectrum (may be complex) and its real parts in ascending order.
    eigenvalues_L = np.linalg.eigvals(L)
    eigenvalues_L_sorted = np.sort(eigenvalues_L.real)

    # Figure 1: sorted real parts against their rank.
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(eigenvalues_L_sorted, 'o-', markersize=4, linewidth=1)
    ax.axhline(y=0, color='r', linestyle='--', alpha=0.5)
    ax.set_xlabel('Index')
    ax.set_ylabel('Eigenvalue')
    ax.set_title('Spectrum of Symmetrized Out-degree Laplacian')
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()

    # Summary of the real parts; 1e-10 is the "numerically zero" tolerance.
    print("Eigenvalue statistics:")
    for stat_name, stat_value in (
        ("Min", eigenvalues_L_sorted.min()),
        ("Max", eigenvalues_L_sorted.max()),
        ("Mean", eigenvalues_L_sorted.mean()),
    ):
        print(f"  {stat_name}: {stat_value:.4f}")
    print(f"  Number of zero eigenvalues (< 1e-10): {np.count_nonzero(np.abs(eigenvalues_L_sorted) < 1e-10)}")
    print(f"  Number of positive eigenvalues: {np.count_nonzero(eigenvalues_L_sorted > 1e-10)}")
    print(f"  Number of negative eigenvalues: {np.count_nonzero(eigenvalues_L_sorted < -1e-10)}")

    # Figure 2: the same eigenvalues as points in the complex plane.
    real_part = eigenvalues_L.real
    imag_part = eigenvalues_L.imag
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.scatter(real_part, imag_part,
               s=50, alpha=0.6, edgecolors='black', linewidths=0.5)
    ax.axhline(y=0, color='r', linestyle='--', alpha=0.3, linewidth=1)
    ax.axvline(x=0, color='r', linestyle='--', alpha=0.3, linewidth=1)
    ax.set_xlabel('Real Part')
    ax.set_ylabel('Imaginary Part')
    ax.set_title('Spectrum of Symmetrized Out-degree Laplacian on Complex Plane')
    ax.grid(True, alpha=0.3)
    ax.axis('equal')
    fig.tight_layout()
    plt.show()

    # Summary of the complex spectrum.
    imag_magnitude = np.abs(imag_part)
    print("Eigenvalue statistics:")
    print(f"  Number of eigenvalues: {len(eigenvalues_L)}")
    print(f"  Real part range: [{real_part.min():.4f}, {real_part.max():.4f}]")
    print(f"  Imaginary part range: [{imag_part.min():.4f}, {imag_part.max():.4f}]")
    print(f"  Number of purely real eigenvalues (|Im| < 1e-10): {np.count_nonzero(imag_magnitude < 1e-10)}")
    print(f"  Number of complex eigenvalues: {np.count_nonzero(imag_magnitude >= 1e-10)}")
    print(f"  Number of zero eigenvalues (|λ| < 1e-10): {np.count_nonzero(np.abs(eigenvalues_L) < 1e-10)}")


### Checkpoint

In [None]:
# Names of the notebook variables that later sections rely on.
variables_to_keep = [
    'ctoi_grouped', 'noi_grouped', 'C_signed_grouped', 'C_unsigned_grouped',
    'new_order', 'unique_cell_types', 'cell_type_counts',
    'D_out', 'D_in', 'L_out', 'L_in', 'L_rw', 'A_sym', 'D_sym', 'L_sym', 'L_bal',
]


def _describe_structure(obj):
    """Summarise obj by its shape, else by its length, else as a scalar."""
    if hasattr(obj, 'shape'):
        return f"shape: {obj.shape}"
    if hasattr(obj, '__len__'):
        return f"length: {len(obj)}"
    return "scalar"


# Report the type and size of every checkpoint variable that currently exists.
print("Variables to keep for further analysis:")
for var in variables_to_keep:
    if var not in globals():
        print(f"\t{var}: NOT FOUND")
        continue
    var_obj = globals()[var]
    print(f"\t{var}: {type(var_obj).__name__}, {_describe_structure(var_obj)}")

### Renormalization group

We will consider the adjacency $A$ to be directly linked to the network's continuous dynamics.

#### Setup and diffusion dynamics

In [None]:
# Handle sinks: a node whose outgoing weights sum to zero would give a zero
# row sum and a division by zero below, so each such node gets a unit self-loop.
zero_out_mask = np.sum(A, axis=1) == 0
A_fixed = A.copy()
sink_nodes = np.flatnonzero(zero_out_mask)
A_fixed[sink_nodes, sink_nodes] = 1.0  # paired indices -> diagonal entries only

# Consider the row-normalized adjacency as state propagator
# P_hat = D_out^(-1) * A
# D_out is diagonal, so its inverse is applied as a row-wise division instead
# of forming a dense matrix inverse (cheaper and with less round-off).
# NOTE: this rebinds D_out to the out-degree matrix of A_fixed.
out_strength = np.sum(A_fixed, axis=1)
D_out = np.diag(out_strength)
P_hat = A_fixed / out_strength[:, np.newaxis]

# Consider the random walk Laplacian as the diffusion operator
# L_hat = I - P_hat
# or a Markov generator: -L_hat = P_hat - I
L_hat = np.eye(len(A_fixed)) - P_hat

In [None]:
# Compute the diffusion matrix exponential K_hat = exp(-tau * L_hat),
# where tau is the diffusion time.
tau = 5
K_hat = la.expm(-tau * L_hat)

print(f"K_hat shape: {K_hat.shape}")
print(f"K_hat computed as exp(-{tau} * L_hat)")

# Basic value statistics
print(f"\nK_hat statistics:")
print(f"  Min: {K_hat.min():.6f}")
print(f"  Max: {K_hat.max():.6f}")
print(f"  Mean: {K_hat.mean():.6f}")

# Symmetry diagnostic. L_hat = I - P_hat with a row-normalized P_hat is in
# general NOT symmetric, so K_hat is only symmetric in special cases; this
# number measures how far K_hat is from symmetric, it is not an error check.
symmetry_error_K = np.max(np.abs(K_hat - K_hat.T))
print(f"\nK_hat symmetry error: {symmetry_error_K:.2e}")

# Row sums. Every row of L_hat sums to zero (rows of P_hat sum to one), so
# K_hat @ 1 = 1 for any tau: row sums should all be ~1, and deviations
# indicate numerical error in the matrix exponential.
row_sums_K = np.sum(K_hat, axis=1)
print(f"K_hat row sums: min={row_sums_K.min():.6f}, max={row_sums_K.max():.6f}")

# Visualize K_hat on a diverging colour scale centred at zero
K_hat_abs_max = np.max(np.abs(K_hat))
plt.figure(figsize=(12, 10))
plt.imshow(K_hat, aspect='auto', cmap='bwr', vmin=-K_hat_abs_max, vmax=K_hat_abs_max, interpolation='none')
plt.colorbar(label='K_hat values')
plt.title('Visualization of K_hat')
plt.xlabel('Node')
plt.ylabel('Node')
plt.show()

In [None]:
# Canonical density operator rho_hat, built here from the diagonal of K_hat
# only: diag(K_hat) / Tr(K_hat), stored as a diagonal matrix.
K_hat_diagonal = np.diag(K_hat)
rho_hat_diag = K_hat_diagonal / np.trace(K_hat)
rho_hat = np.diag(rho_hat_diag)

print(f"rho_hat shape: {rho_hat.shape}")
print(
    f"rho_hat statistics: min={rho_hat_diag.min():.6f}, "
    f"max={rho_hat_diag.max():.6f}, mean={rho_hat_diag.mean():.6f}"
)

# Show rho_hat on a diverging colour scale centred at zero.
rho_abs_max = np.max(np.abs(rho_hat))
fig, ax = plt.subplots(figsize=(12, 10))
rho_image = ax.imshow(rho_hat, cmap='bwr', aspect='auto',
                      vmin=-rho_abs_max, vmax=rho_abs_max,
                      interpolation='none')
fig.colorbar(rho_image, ax=ax)
ax.set_title("Visualization of rho_hat")
ax.set_xlabel("Features")
ax.set_ylabel("Features")
plt.show()

In [None]:
# Normalized entropy of the density operator:
#   S = -1/log(N) * sum_i mu_i(tau) * log(mu_i(tau))
# with mu_i(tau) the eigenvalues of rho_hat(tau) and N their number.
eigenvalues_rho_hat = np.linalg.eigvals(rho_hat).real
N = len(eigenvalues_rho_hat)

# The 1e-12 offset keeps log() finite when an eigenvalue is exactly zero.
mu_log_mu = eigenvalues_rho_hat * np.log(eigenvalues_rho_hat + 1e-12)
entropy_scale = -1/np.log(N)
S = entropy_scale * np.sum(mu_log_mu)
print(f"System entropy S: {S:.6f}")

#### Eigendecomposition

In [None]:
# Compute eigendecomposition of P_hat (state propagator)
print("Eigendecomposition of P_hat (State Propagator)")
print("=" * 60)

# P_hat is non-symmetric, so left and right eigenvectors differ. Both are taken
# from ONE solver call so that column i of each belongs to the same
# eigenvalue; two independent eig() calls may return the spectrum in different
# orders, which would mis-pair the vectors after sorting.
eigenvalues, left_columns, eigenvectors_right = la.eig(P_hat, left=True, right=True)

# scipy returns vl with vl[:, i]^H @ P_hat = lambda_i * vl[:, i]^H.
# Store the left eigenvectors as rows so that (left @ right) is the
# bi-orthogonality matrix.
eigenvectors_left = left_columns.conj().T
eigenvectors_left_T = eigenvectors_left.T

# Independent cross-check: the spectrum of P_hat^T must equal that of P_hat
eigenvalues_left = np.linalg.eigvals(P_hat.T)
eigenvalue_diff = np.max(np.abs(np.sort(eigenvalues.real) - np.sort(eigenvalues_left.real)))
print(f"Max difference between left and right eigenvalues: {eigenvalue_diff:.2e}")

# Sort by real part; the same permutation is applied to eigenvalues and to
# both eigenvector sets, so the pairing is preserved
idx_sorted = np.argsort(eigenvalues.real)
eigenvalues_sorted = eigenvalues[idx_sorted]
eigenvectors_right_sorted = eigenvectors_right[:, idx_sorted]
eigenvectors_left_sorted = eigenvectors_left[idx_sorted, :]

print(f"\nComputed {len(eigenvalues_sorted)} eigenvalues")
print(f"Eigenvalue range: [{eigenvalues_sorted.real.min():.6f}, {eigenvalues_sorted.real.max():.6f}]")

# Count zero eigenvalues
n_zero_eigs = np.sum(np.abs(eigenvalues_sorted) < 1e-10)
print(f"Number of zero eigenvalues (|λ| < 1e-10): {n_zero_eigs}")

# Verify bi-orthogonality: <w_i, v_j> should vanish for i != j. The vectors are
# not rescaled, so the diagonal is not forced to 1. Within a repeated
# eigenvalue the solver's vectors need not be bi-orthogonal, so degenerate
# eigenvalues can produce non-zero off-diagonal entries.
biorthogonality = eigenvectors_left_sorted @ eigenvectors_right_sorted
off_diag_max = np.max(np.abs(biorthogonality - np.diag(np.diag(biorthogonality))))
print(f"\nBi-orthogonality check - max off-diagonal: {off_diag_max:.2e}")

# Show first few eigenvalues (sign printed explicitly so that negative
# imaginary parts do not render as "+ -0.1i")
print(f"\nFirst 10 eigenvalues:")
for i, eigenvalue in enumerate(eigenvalues_sorted[:10]):
    imag_sign = '-' if eigenvalue.imag < 0 else '+'
    print(f"  λ_{i}: {eigenvalue.real:.6f} {imag_sign} {abs(eigenvalue.imag):.6f}i")

In [None]:
# Spectrum of the state propagator P_hat. eigenvalues_sorted is complex and
# ordered by real part, so the line plot and the summary statistics use its
# real part explicitly (passing the complex array would silently discard the
# imaginary part and emit a ComplexWarning).
eigenvalues_sorted_real = eigenvalues_sorted.real

# Plot the spectrum
plt.figure(figsize=(12, 6))
plt.plot(eigenvalues_sorted_real, 'o-', markersize=4, linewidth=1)
plt.axhline(y=0, color='r', linestyle='--', alpha=0.5)
plt.xlabel('Index')
plt.ylabel('Eigenvalue')
plt.title('Spectrum of State Propagator P_hat')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Print statistics about the eigenvalues (real parts; zero test on |λ|)
print(f"Eigenvalue statistics:")
print(f"  Min: {eigenvalues_sorted_real.min():.4f}")
print(f"  Max: {eigenvalues_sorted_real.max():.4f}")
print(f"  Mean: {eigenvalues_sorted_real.mean():.4f}")
print(f"  Number of zero eigenvalues (< 1e-10): {np.sum(np.abs(eigenvalues_sorted) < 1e-10)}")
print(f"  Number of positive eigenvalues: {np.sum(eigenvalues_sorted_real > 1e-10)}")
print(f"  Number of negative eigenvalues: {np.sum(eigenvalues_sorted_real < -1e-10)}")

# Plot eigenvalues on the complex plane
plt.figure(figsize=(10, 10))
plt.scatter(eigenvalues_sorted.real, eigenvalues_sorted.imag, 
           s=50, alpha=0.6, edgecolors='black', linewidths=0.5)
plt.axhline(y=0, color='r', linestyle='--', alpha=0.3, linewidth=1)
plt.axvline(x=0, color='r', linestyle='--', alpha=0.3, linewidth=1)
plt.xlabel('Real Part')
plt.ylabel('Imaginary Part')
# Title names the operator actually plotted in this cell (P_hat)
plt.title('Spectrum of State Propagator P_hat on Complex Plane')
plt.grid(True, alpha=0.3)
plt.axis('equal')
plt.tight_layout()
plt.show()

In [None]:
# Compute eigendecomposition of L_hat (Random Walk Laplacian)
print("Eigendecomposition of L_hat (Random Walk Laplacian)")
print("=" * 60)

# L_hat is non-symmetric, so left and right eigenvectors differ. Both are taken
# from ONE solver call so that column i of each belongs to the same
# eigenvalue; two independent eig() calls may return the spectrum in different
# orders, which would mis-pair the vectors after sorting.
eigenvalues, left_columns, eigenvectors_right = la.eig(L_hat, left=True, right=True)

# scipy returns vl with vl[:, i]^H @ L_hat = lambda_i * vl[:, i]^H.
# Store the left eigenvectors as rows so that (left @ right) is the
# bi-orthogonality matrix.
eigenvectors_left = left_columns.conj().T
eigenvectors_left_T = eigenvectors_left.T

# Independent cross-check: the spectrum of L_hat^T must equal that of L_hat
eigenvalues_left = np.linalg.eigvals(L_hat.T)
eigenvalue_diff = np.max(np.abs(np.sort(eigenvalues.real) - np.sort(eigenvalues_left.real)))
print(f"Max difference between left and right eigenvalues: {eigenvalue_diff:.2e}")

# Sort by real part; the same permutation is applied to eigenvalues and to
# both eigenvector sets, so the pairing is preserved
idx_sorted = np.argsort(eigenvalues.real)
eigenvalues_sorted = eigenvalues[idx_sorted]
eigenvectors_right_sorted = eigenvectors_right[:, idx_sorted]
eigenvectors_left_sorted = eigenvectors_left[idx_sorted, :]

print(f"\nComputed {len(eigenvalues_sorted)} eigenvalues")
print(f"Eigenvalue range: [{eigenvalues_sorted.real.min():.6f}, {eigenvalues_sorted.real.max():.6f}]")

# Count zero eigenvalues
n_zero_eigs = np.sum(np.abs(eigenvalues_sorted) < 1e-10)
print(f"Number of zero eigenvalues (|λ| < 1e-10): {n_zero_eigs}")

# Verify bi-orthogonality: <w_i, v_j> should vanish for i != j. The vectors are
# not rescaled, so the diagonal is not forced to 1. Within a repeated
# eigenvalue the solver's vectors need not be bi-orthogonal, so degenerate
# eigenvalues can produce non-zero off-diagonal entries.
biorthogonality = eigenvectors_left_sorted @ eigenvectors_right_sorted
off_diag_max = np.max(np.abs(biorthogonality - np.diag(np.diag(biorthogonality))))
print(f"\nBi-orthogonality check - max off-diagonal: {off_diag_max:.2e}")

# Show first few eigenvalues (sign printed explicitly so that negative
# imaginary parts do not render as "+ -0.1i")
print(f"\nFirst 10 eigenvalues:")
for i, eigenvalue in enumerate(eigenvalues_sorted[:10]):
    imag_sign = '-' if eigenvalue.imag < 0 else '+'
    print(f"  λ_{i}: {eigenvalue.real:.6f} {imag_sign} {abs(eigenvalue.imag):.6f}i")

In [None]:
# Spectrum of the random walk Laplacian L_hat. eigenvalues_sorted is complex
# and ordered by real part, so the line plot and the min/max/mean statistics
# use its real part explicitly (passing the complex array would silently
# discard the imaginary part and emit a ComplexWarning).
eigenvalues_sorted_real = eigenvalues_sorted.real

# Plot the spectrum
plt.figure(figsize=(12, 6))
plt.plot(eigenvalues_sorted_real, 'o-', markersize=4, linewidth=1)
plt.axhline(y=0, color='r', linestyle='--', alpha=0.5)
plt.xlabel('Index')
plt.ylabel('Eigenvalue')
plt.title('Spectrum of Random Walk Laplacian L_hat')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Print statistics about the eigenvalues (real parts for min/max/mean,
# complex modulus for the zero and unit-circle counts)
eigenvalue_modulus = np.abs(eigenvalues_sorted)
print(f"Eigenvalue statistics:")
print(f"  Min: {eigenvalues_sorted_real.min():.4f}")
print(f"  Max: {eigenvalues_sorted_real.max():.4f}")
print(f"  Mean: {eigenvalues_sorted_real.mean():.4f}")
print(f"  Number of zero eigenvalues (< 1e-10): {np.sum(eigenvalue_modulus < 1e-10)}")
print(f"  Number of eigenvalues inside unit circle (|λ| < 1): {np.sum(eigenvalue_modulus < 1)}")
print(f"  Number of eigenvalues outside unit circle (|λ| > 1): {np.sum(eigenvalue_modulus > 1)}")

# Plot eigenvalues on the complex plane
plt.figure(figsize=(10, 10))
plt.scatter(eigenvalues_sorted.real, eigenvalues_sorted.imag, 
           s=50, alpha=0.6, edgecolors='black', linewidths=0.5)
plt.axhline(y=0, color='r', linestyle='--', alpha=0.3, linewidth=1)
plt.axvline(x=0, color='r', linestyle='--', alpha=0.3, linewidth=1)
plt.xlabel('Real Part')
plt.ylabel('Imaginary Part')
# Title names the operator actually plotted in this cell (L_hat)
plt.title('Spectrum of Random Walk Laplacian L_hat on Complex Plane')
plt.grid(True, alpha=0.3)
plt.axis('equal')
plt.tight_layout()
plt.show()

#### Bi-orthogonal Laplacian Renormalization Group (bi-LRG)

In [None]:
# Bi-LRG hyper-parameters, passed to BiLRG(k=..., alpha=...) in the next cell.
# k is reported there as the number of slow modes.
k, alpha = 6, 0.95

In [None]:
# Fit a bi-LRG model to the connectivity matrix and summarise the grouping.
print(f"Applying bi-LRG with k={k} slow modes...")
print("=" * 50)

# Model configuration: k-means clustering, Laplacian ('L') as spectral matrix,
# fixed random seed.
bilrg = BiLRG(
    k=k,
    alpha=alpha,
    cluster_method='kmeans',
    realify=False,
    spectral_matrix='L',
    random_state=42,
)

# Fit on the adjacency matrix A.
# Alternative kept for reference: bilrg.fit(A_bilrg, L=L_bal)
A_bilrg = A
bilrg.fit(A_bilrg)

# Headline numbers from the coarse graph.
coarse_graph = bilrg.get_coarse_graph()
print(f"Original network size: {A_bilrg.shape[0]} neurons")
print(f"Coarse network size: {coarse_graph['n_coarse']} groups")
print(f"Spectral fidelity: {coarse_graph['fidelity']:.4f}")
print(f"Slow eigenvalues: {coarse_graph['eigenvalues'][:5]}")

# Number of neurons assigned to each group.
groups = coarse_graph['groups']
group_ids = range(coarse_graph['n_coarse'])
print(f"\nGroup assignments:")
for group_id in group_ids:
    group_size = np.sum(groups == group_id)
    print(f"Group {group_id}: {group_size} neurons")

# Cell-type make-up of each group; empty groups are not printed.
print(f"\nGroup composition by cell type:")
for group_id in group_ids:
    group_mask = groups == group_id
    group_cell_types = [
        cell_type for i, cell_type in enumerate(ctoi_grouped) if group_mask[i]
    ]
    cell_type_counts_group = dict(Counter(group_cell_types))
    if cell_type_counts_group:
        print(f"Group {group_id}: {dict(sorted(cell_type_counts_group.items()))}")

In [None]:
# Scatter plots of the bi-LRG embedding, coloured by group assignment.
print("Visualization of bi-LRG Results")
print("=" * 50)

X_biembed = bilrg.get_embedding()
print(f"Bi-embedding shape: {X_biembed.shape}")

# The second half of the embedding columns is plotted as the "left mode" block.
k_half = X_biembed.shape[1] // 2

# Plot 1: embedding columns 0 and 1 (labelled as right modes / forward dynamics).
fig, ax = plt.subplots(figsize=(8, 6))
right_scatter = ax.scatter(X_biembed[:, 0], X_biembed[:, 1], c=groups,
                           cmap='tab10', alpha=0.7, s=50)
ax.set_xlabel('Right Mode 1 (Forward)')
ax.set_ylabel('Right Mode 2 (Forward)')
ax.set_title('bi-LRG Groups in Right Mode Space')
ax.grid(True, alpha=0.3)
cbar1 = fig.colorbar(right_scatter, ax=ax)
cbar1.set_label('Group ID')
fig.tight_layout()
plt.show()

# Plot 2: columns k_half and k_half + 1 (labelled as left modes / backward dynamics).
fig, ax = plt.subplots(figsize=(8, 6))
left_scatter = ax.scatter(X_biembed[:, k_half], X_biembed[:, k_half + 1], c=groups,
                          cmap='tab10', alpha=0.7, s=50)
ax.set_xlabel('Left Mode 1 (Backward)')
ax.set_ylabel('Left Mode 2 (Backward)')
ax.set_title('bi-LRG Groups in Left Mode Space')
ax.grid(True, alpha=0.3)
cbar2 = fig.colorbar(left_scatter, ax=ax)
cbar2.set_label('Group ID')
fig.tight_layout()
plt.show()

In [None]:
# Slow eigenvalues selected by bi-LRG, drawn in the complex plane.
U_k, V_k, Lambda_k = bilrg.get_modes()

fig, ax = plt.subplots(figsize=(10, 10))
ax.scatter(Lambda_k.real, Lambda_k.imag,
           s=100, alpha=0.7, edgecolors='black', linewidths=1.5)
ax.axhline(y=0, color='r', linestyle='--', alpha=0.3, linewidth=1)
ax.axvline(x=0, color='r', linestyle='--', alpha=0.3, linewidth=1)

# Unit circle as a visual reference.
theta = np.linspace(0, 2*np.pi, 100)
ax.plot(np.cos(theta), np.sin(theta), 'g--', alpha=0.3, linewidth=1, label='Unit circle')

ax.set_xlabel('Real Part')
ax.set_ylabel('Imaginary Part')
ax.set_title('Slow Eigenvalues of bi-LRG on Complex Plane')
ax.grid(True, alpha=0.3)
ax.axis('equal')
ax.legend()
fig.tight_layout()
plt.show()

print(f"Slow eigenvalues (k={len(Lambda_k)}): {Lambda_k}")

In [None]:
# Heatmaps and summary statistics of the coarse-grained (group-level) operators.
print("Coarse Network Analysis")
print("=" * 50)

# Coarse operators produced by bi-LRG.
P_group = coarse_graph['P_group']
L_group = coarse_graph['L_group']
A_group = coarse_graph['A_group']

print(f"Coarse transition matrix P_group shape: {P_group.shape}")
print(f"Coarse Laplacian L_group shape: {L_group.shape}")
print(f"Coarse adjacency A_group shape: {A_group.shape}")


def _show_group_heatmap(values, title, colorbar_label):
    """Draw one group-by-group matrix on a diverging scale centred at zero."""
    limit = np.max(np.abs(values))
    fig, ax = plt.subplots(figsize=(8, 7))
    image = ax.imshow(values, cmap='bwr', aspect='auto', vmin=-limit, vmax=limit)
    ax.set_title(title)
    ax.set_xlabel('Target Group')
    ax.set_ylabel('Source Group')
    fig.colorbar(image, ax=ax, label=colorbar_label)
    fig.tight_layout()
    plt.show()


# P_group and L_group are shown as-is; A_group is shown as log10(|w| + 1).
_show_group_heatmap(P_group, 'Coarse Transition Matrix (P_group)', 'Transition Probability')
_show_group_heatmap(L_group, 'Coarse Laplacian (L_group)', 'Laplacian Value')
_show_group_heatmap(np.log10(np.abs(A_group) + 1), 'Coarse Adjacency (A_group)', 'log10(Weight + 1)')

# Size and fidelity of the coarse network.
n_groups = P_group.shape[0]
print(f"\nCoarse network statistics:")
print(f"Number of groups: {n_groups}")
print(f"Compression ratio: {A_bilrg.shape[0] / n_groups:.1f}x")
print(f"Transition matrix row sums (should be ~1): {P_group.sum(axis=1)}")
print(f"Spectral fidelity: {coarse_graph['fidelity']:.4f}")

# Total and mean weight over all n_groups^2 ordered group pairs.
total_coarse_weight = A_group.sum()
print(f"Total weight in coarse network: {total_coarse_weight:.1f}")
print(f"Average weight per connection: {total_coarse_weight / (n_groups**2):.3f}")

# Five largest entries of A_group, printed only when above 0.01.
print(f"\nStrongest connections in coarse network:")
flat_indices = np.argsort(A_group.flatten())[::-1][:5]
for i, j in zip(*np.unravel_index(flat_indices, A_group.shape)):
    weight = A_group[i, j]
    if weight > 0.01:
        print(f"Group {i} → Group {j}: weight = {weight:.3f}")

In [None]:
# Heatmaps of the bi-Galerkin projected operators and a comparison with the
# Markov-lumped ones.
print("Bi-Galerkin Projection Analysis")
print("=" * 40)

# Real parts of the projected operators, plus the eigenvalues used.
P_galerkin = coarse_graph["P_galerkin"].real
L_galerkin = coarse_graph["L_galerkin"].real
A_galerkin = coarse_graph["A_galerkin"].real
eigenvals = coarse_graph["eigenvalues"]

print(f"Bi-Galerkin operator dimensions: {P_galerkin.shape[0]} x {P_galerkin.shape[1]}")
print(f"Eigenvalues used: {eigenvals}")


def _show_mode_heatmap(values, title):
    """Draw one mode-by-mode matrix on a diverging scale centred at zero."""
    limit = np.max(np.abs(values))
    fig, ax = plt.subplots(figsize=(8, 7))
    image = ax.imshow(values, cmap="bwr", aspect="auto", vmin=-limit, vmax=limit)
    ax.set_title(title)
    ax.set_xlabel("Mode Index")
    ax.set_ylabel("Mode Index")
    fig.colorbar(image, ax=ax, label='Value')
    fig.tight_layout()
    plt.show()


# One figure per operator; A_galerkin is shown as sign(w) * log10(|w| + 1).
_show_mode_heatmap(P_galerkin, "Bi-Galerkin Projected Transition Matrix (P_galerkin)")
_show_mode_heatmap(L_galerkin, "Bi-Galerkin Projected Laplacian (L_galerkin)")
_show_mode_heatmap(np.log10(np.abs(A_galerkin) + 1) * np.sign(A_galerkin),
                   "Bi-Galerkin Projected Adjacency (A_galerkin)")

# Textual summary of what the projections are.
galerkin_shape = P_galerkin.shape
print("\nMathematical Interpretation:")
print(f"P_galerkin = U_k^H @ P @ V_k  (shape: {galerkin_shape})")
print(f"L_galerkin = U_k^H @ L_rw @ V_k  (shape: {L_galerkin.shape})")
print(f"A_galerkin = U_k^H @ A @ V_k  (shape: {A_galerkin.shape})")
print(f"\nThis represents the {galerkin_shape[0]}-dimensional dynamics")
print("projected into the subspace spanned by the slowest modes.")

# Markov-lumped operators, for the dimension comparison below.
P_group = coarse_graph["P_group"]
L_group = coarse_graph["L_group"]
A_group = coarse_graph["A_group"]

print(f"\nComparison with Markov Lumping:")
print(f"Bi-Galerkin dimensions: {galerkin_shape}")
print(f"Markov lumping dimensions: {P_group.shape}")
print(f"\nBi-Galerkin provides a purely spectral reduction,")
print(f"while Markov lumping preserves node groupings and stationarity.")

##### Verify bi-orthogonality

In [None]:
# Check that the slow modes stored on the fitted bi-LRG model are
# bi-orthogonal and diagonalize its Laplacian.
print("Computing Bi-orthogonal Eigenvectors for bilrg")
print("=" * 60)

# Slow modes as stored on the fitted model
U_k = bilrg.U_k_
V_k = bilrg.V_k_
Lambda_k = bilrg.Lambda_k_

# Size the checks from the modes actually returned rather than from the global
# k, so the cell still runs if the model kept a different number of modes.
n_modes = V_k.shape[1]

print(f"\nSelected k={n_modes} slowest modes:")
print(f"U_k shape: {U_k.shape}")
print(f"V_k shape: {V_k.shape}")
print(f"Lambda_k shape: {Lambda_k.shape}")

# Verify bi-orthogonality of the mode subspace: U_k^H @ V_k = I_k
gram_k = U_k.conj().T @ V_k
identity_error_k = np.max(np.abs(gram_k - np.eye(n_modes)))

print(f"\nBi-orthogonality check for k={n_modes} modes: U_k^H @ V_k = I_k")
print(f"\nMax deviation from I_k: {identity_error_k:.2e}")

# Verify that U_k^H @ L @ V_k = Lambda_k (diagonal matrix of eigenvalues)
L_projected = U_k.conj().T @ bilrg.L_ @ V_k

print(f"\nVerification: U_k^H @ L @ V_k")
print(f"Projected Laplacian shape: {L_projected.shape}")
print(f"Expected: diagonal matrix with eigenvalues on diagonal")

# Off-diagonal part should vanish
diagonal_values = np.diag(L_projected)
off_diagonal = L_projected - np.diag(diagonal_values)
off_diagonal_norm = np.max(np.abs(off_diagonal))
print(f"Max off-diagonal magnitude: {off_diagonal_norm:.2e}")

# Diagonal should reproduce Lambda_k
eigenvalue_error = np.max(np.abs(diagonal_values - Lambda_k))
print(f"Max error between diagonal and Lambda_k: {eigenvalue_error:.2e}")

##### Bi-LRG groups

In [None]:
# Visualize the connectivity matrix reordered by bi-LRG groups
bilrg_labels = coarse_graph['groups']

# Order nodes by group label; a stable sort keeps the original node order
# inside each group, so the figure is reproducible.
node_order = np.argsort(bilrg_labels, kind='stable')
A_reordered = A[np.ix_(node_order, node_order)]

# log10(|w| + 1) maps absent connections to 0. Plain log10(|w|) would give
# -inf for them (with a divide-by-zero warning) and 0 for weight-1 edges.
log_weights = np.log10(np.abs(A_reordered) + 1)
log_limit = np.max(log_weights)

plt.figure(figsize=(12, 10))
plt.imshow(log_weights, aspect='auto', cmap='bwr',
           vmin=-log_limit, vmax=log_limit)
plt.colorbar()
plt.title('Reordered Connectivity Matrix (Bi-LRG Groups)')
plt.xlabel('Nodes')
plt.ylabel('Nodes')

# Draw lines to separate groups: return_index gives the first position of each
# label in the sorted label array, i.e. where each group starts.
unique_groups, group_indices = np.unique(bilrg_labels[node_order], return_index=True)
for idx in group_indices[1:]:
    plt.axhline(y=idx - 0.5, color='k', linestyle='--', linewidth=1)
    plt.axvline(x=idx - 0.5, color='k', linestyle='--', linewidth=1)

plt.tight_layout()
plt.show()

In [None]:
# Visualize the signed connectivity matrix reordered by bi-LRG groups
bilrg_labels = coarse_graph['groups']

# Order nodes by group label; a stable sort keeps the original node order
# inside each group, so the figure is reproducible.
node_order = np.argsort(bilrg_labels, kind='stable')
A_signed = C_signed_grouped.copy()
A_reordered = A_signed[np.ix_(node_order, node_order)]

# Signed log scale: sign(w) * log10(|w| + 1). Plain log10(|w|) * sign(w) gives
# -inf * 0 = NaN for absent connections and 0 for weight-1 edges of either sign.
log_magnitude = np.log10(np.abs(A_reordered) + 1)
log_limit = np.max(log_magnitude)

plt.figure(figsize=(12, 10))
plt.imshow(log_magnitude * np.sign(A_reordered), aspect='auto', cmap='bwr',
           vmin=-log_limit, vmax=log_limit)
plt.colorbar()
plt.title('Reordered Connectivity Matrix (Bi-LRG Groups)')
plt.xlabel('Nodes')
plt.ylabel('Nodes')

# Draw lines to separate groups: return_index gives the first position of each
# label in the sorted label array, i.e. where each group starts.
unique_groups, group_indices = np.unique(bilrg_labels[node_order], return_index=True)
for idx in group_indices[1:]:
    plt.axhline(y=idx - 0.5, color='k', linestyle='--', linewidth=1)
    plt.axvline(x=idx - 0.5, color='k', linestyle='--', linewidth=1)

plt.tight_layout()
plt.show()

### Checkpoint

In [None]:
### Checkpoint
# Names of the notebook variables that later sections rely on, now including
# the renormalization-group results.
variables_to_keep = [
    'ctoi_grouped', 'noi_grouped', 'C_signed_grouped', 'C_unsigned_grouped',
    'new_order', 'unique_cell_types', 'cell_type_counts',
    'D_out', 'D_in', 'L_out', 'L_in', 'L_rw', 'A_sym', 'D_sym', 'L_sym', 'L_bal',
    'P_hat', 'L_hat', 'K_hat', 'rho_hat', 'bilrg', 'P_group', 'L_group', 'A_group',
    'P_galerkin', 'L_galerkin', 'A_galerkin', 'bilrg_labels',
]


def _describe_structure(obj):
    """Summarise obj by its shape, else by its length, else as a scalar."""
    if hasattr(obj, 'shape'):
        return f"shape: {obj.shape}"
    if hasattr(obj, '__len__'):
        return f"length: {len(obj)}"
    return "scalar"


# Report the type and size of every checkpoint variable that currently exists.
print("Variables to keep for further analysis:")
for var in variables_to_keep:
    if var not in globals():
        print(f"\t{var}: NOT FOUND")
        continue
    var_obj = globals()[var]
    print(f"\t{var}: {type(var_obj).__name__}, {_describe_structure(var_obj)}")

### Modularity optimization

In [None]:
# The modularity analyses below operate on the unsigned, grouped matrix.
A = C_unsigned_grouped

# Quick sanity report on the input matrix.
for summary_line in (
    f"Matrix A shape: {A.shape}",
    f"Matrix A type: {type(A)}",
    f"Matrix A min: {A.min()}, max: {A.max()}",
    f"Number of non-zero entries: {np.count_nonzero(A)}",
):
    print(summary_line)

#### Louvain method

In [None]:
# Convert the unsigned adjacency matrix to a directed NetworkX graph
G = nx.from_numpy_array(A, create_using=nx.DiGraph())

print(f"Graph created with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")

# Community detection here runs on an undirected graph. G.to_undirected()
# keeps only ONE of the two reciprocal weights (whichever edge is visited
# last), so the symmetrization is done explicitly: w_ij = A_ij + A_ji, with
# self-loop weights left as in A.
A_undirected = A + A.T
np.fill_diagonal(A_undirected, np.diag(A))
G_undirected = nx.from_numpy_array(A_undirected)

# Apply Louvain community detection.
# Note: louvain_communities is NetworkX's own implementation (python-louvain is
# not involved). NetworkX versions that lack it raise AttributeError, in which
# case greedy modularity maximization is used instead.
try:
    communities = community.louvain_communities(G_undirected, weight='weight', seed=42)
    detection_method = "Louvain"
    modularity_label = "Modularity"
except (ImportError, AttributeError):
    print("louvain_communities not available in this NetworkX version, using greedy modularity instead")
    communities = list(community.greedy_modularity_communities(G_undirected, weight='weight'))
    detection_method = "Greedy modularity"
    modularity_label = "Modularity (Newman's Q)"

print(f"{detection_method} detected {len(communities)} communities")

# Sort communities by size (largest first) so that community 0 is the largest
communities_sorted = sorted(communities, key=len, reverse=True)

# Per-node community index, using the size-sorted numbering
community_labels = np.zeros(len(A), dtype=int)
for i, comm in enumerate(communities_sorted):
    community_labels[list(comm)] = i

# Print community sizes
community_sizes = [len(comm) for comm in communities_sorted]
print(f"Community sizes: {community_sizes}")

# Modularity of the partition on the same undirected weighted graph
modularity = community.modularity(G_undirected, communities_sorted)
print(f"{modularity_label}: {modularity:.4f}")

# Use the sorted version from here on
communities = communities_sorted

# Store results for further analysis
louvain_communities = communities
louvain_labels = community_labels
louvain_modularity = modularity

In [None]:
# Unsigned connectivity matrix with nodes grouped by Louvain community.
# Node order: communities in their stored order, node ids ascending within each.
louvain_order = [node for comm in louvain_communities for node in sorted(comm)]
A_louvain_ordered = A[np.ix_(louvain_order, louvain_order)]

# log10(w + 1) scale, symmetric colour limits around zero.
log_weights = np.log10(A_louvain_ordered + 1)
log_limit = np.max(log_weights)

fig, ax = plt.subplots(figsize=(12, 10))
heatmap = ax.imshow(log_weights,
                    cmap='bwr', aspect='auto',
                    vmin=-log_limit,
                    vmax=log_limit,
                    interpolation='none')
fig.colorbar(heatmap, ax=ax, label='Number of Synapses (log10)')
ax.set_title(f'Connectivity Matrix Ordered by Louvain Communities ({len(louvain_communities)} communities)')
ax.set_xlabel('Postsynaptic Neuron Index')
ax.set_ylabel('Presynaptic Neuron Index')

# End position of each community block = running total of community sizes.
community_boundaries = np.cumsum([len(comm) for comm in louvain_communities]).tolist()

# Separator lines between consecutive communities (none after the last one).
for boundary in community_boundaries[:-1]:
    ax.axhline(y=boundary-0.5, color='k', linewidth=0.5, alpha=0.7)
    ax.axvline(x=boundary-0.5, color='k', linewidth=0.5, alpha=0.7)

fig.tight_layout()
plt.show()

# Headline comparison with the cell-type grouping.
print("\nComparison with original cell type grouping:")
print(f"Original cell types: {len(unique_cell_types)}")
print(f"Louvain communities: {len(louvain_communities)}")
print(f"Modularity (Newman's Q): {louvain_modularity:.4f}")

# cell_type -> {community index -> number of neurons of that type in it}
cell_type_to_louvain = {}
for i, cell_type in enumerate(ctoi_grouped):
    louvain_comm = louvain_labels[i]
    per_type_counts = cell_type_to_louvain.setdefault(cell_type, {})
    per_type_counts[louvain_comm] = per_type_counts.get(louvain_comm, 0) + 1

In [None]:
# Signed connectivity matrix with nodes grouped by Louvain community.
# Node order: communities in their stored order, node ids ascending within each.
louvain_order = [node for comm in louvain_communities for node in sorted(comm)]
A_signed = C_signed_grouped
A_louvain_ordered = A_signed[np.ix_(louvain_order, louvain_order)]

# Signed log scale sign(w) * log10(|w| + 1), symmetric colour limits.
log_magnitude = np.log10(np.abs(A_louvain_ordered) + 1)
log_limit = np.max(log_magnitude)

fig, ax = plt.subplots(figsize=(12, 10))
heatmap = ax.imshow(log_magnitude * np.sign(A_louvain_ordered),
                    cmap='bwr', aspect='auto',
                    vmin=-log_limit,
                    vmax=log_limit,
                    interpolation='none')
fig.colorbar(heatmap, ax=ax, label='Number of Synapses (log10)')
ax.set_title(f'Signed Connectivity Matrix Ordered by Louvain Communities ({len(louvain_communities)} communities)')
ax.set_xlabel('Postsynaptic Neuron Index')
ax.set_ylabel('Presynaptic Neuron Index')

# End position of each community block = running total of community sizes.
community_boundaries = np.cumsum([len(comm) for comm in louvain_communities]).tolist()

# Separator lines between consecutive communities (none after the last one).
for boundary in community_boundaries[:-1]:
    ax.axhline(y=boundary-0.5, color='k', linewidth=0.5, alpha=0.7)
    ax.axvline(x=boundary-0.5, color='k', linewidth=0.5, alpha=0.7)

fig.tight_layout()
plt.show()

# Headline comparison with the cell-type grouping.
print("\nComparison with original cell type grouping:")
print(f"Original cell types: {len(unique_cell_types)}")
print(f"Louvain communities: {len(louvain_communities)}")
print(f"Modularity (Newman's Q): {louvain_modularity:.4f}")

# cell_type -> {community index -> number of neurons of that type in it}
cell_type_to_louvain = {}
for i, cell_type in enumerate(ctoi_grouped):
    louvain_comm = louvain_labels[i]
    per_type_counts = cell_type_to_louvain.setdefault(cell_type, {})
    per_type_counts[louvain_comm] = per_type_counts.get(louvain_comm, 0) + 1

In [None]:
# Heatmap of how many neurons of each cell type fall into each Louvain community.
import seaborn as sns

# Contingency table: rows = cell types, columns = Louvain communities.
n_cell_types = len(unique_cell_types)
n_communities = len(louvain_communities)
distribution_matrix = np.zeros((n_cell_types, n_communities))

# One count per neuron at (its cell type row, its community column).
for i, cell_type in enumerate(ctoi_grouped):
    louvain_comm = louvain_labels[i]
    cell_type_idx = unique_cell_types.index(cell_type)
    distribution_matrix[cell_type_idx, louvain_comm] += 1

# Draw the table.
community_tick_labels = [f'Comm {i}' for i in range(n_communities)]
plt.figure(figsize=(16, 12))
sns.heatmap(
    distribution_matrix,
    xticklabels=community_tick_labels,
    yticklabels=unique_cell_types,
    cmap='viridis',
    annot=False,
    fmt='g',
    cbar_kws={'label': 'Number of neurons'},
)
plt.title('Cell Type Distribution Across Louvain Communities')
plt.xlabel('Louvain Community')
plt.ylabel('Cell Type')
plt.xticks(rotation=45)
plt.yticks(rotation=0, fontsize=8)
plt.tight_layout()
plt.show()

# Size and sparsity of the table.
print(f"Distribution matrix shape: {distribution_matrix.shape}")
print(f"Non-zero entries in distribution matrix: {np.count_nonzero(distribution_matrix)}")

# A cell type is "split" when its row has counts in more than one community.
split_cell_types = sum(
    1 for type_row in distribution_matrix if np.count_nonzero(type_row) > 1
)
print(f"Cell types split across multiple communities: {split_cell_types}/{n_cell_types}")

# Purity of a community = share of its neurons that belong to its most common
# cell type; communities with no neurons are left out.
community_purity = [
    community_column.max() / community_column.sum()
    for community_column in distribution_matrix.T
    if community_column.sum() > 0
]

print(f"Average community purity: {np.mean(community_purity):.3f}")
print(f"Community purity range: {np.min(community_purity):.3f} - {np.max(community_purity):.3f}")

#### Leiden algorithm

In [None]:
# Community detection with the Leiden algorithm (modularity objective).
# leidenalg works on igraph graphs, so the NetworkX graph G is converted first.
# NOTE(review): G's node labels are passed to igraph as vertex ids, which
# presumes they are the integers 0..n-1 — confirm where G is built.
edges = list(G.edges(data=True))
edge_list = [(u, v) for u, v, _ in edges]
weights = [d['weight'] for _, _, d in edges]

# Build the undirected, weighted igraph copy of G
g_ig = ig.Graph(directed=False)
g_ig.add_vertices(G.number_of_nodes())
g_ig.add_edges(edge_list)
g_ig.es['weight'] = weights

# Fixed seed so repeated runs give the same partition
partition = leidenalg.find_partition(g_ig, leidenalg.ModularityVertexPartition,
                                     weights='weight', seed=42)

# One list of vertex ids per community
leiden_communities = [list(members) for members in partition]

# Per-vertex community index, taken straight from the partition. This keeps the
# label vector the same length as the graph that was actually clustered, instead
# of sizing it from the separate matrix A and filling it in by hand.
leiden_labels = np.array(partition.membership, dtype=int)

leiden_modularity = partition.modularity

print(f"Leiden detected {len(leiden_communities)} communities")
leiden_community_sizes = [len(comm) for comm in leiden_communities]
print(f"Community sizes: {sorted(leiden_community_sizes, reverse=True)}")
print(f"Modularity (Newman's Q): {leiden_modularity:.4f}")

In [None]:
# Plot the connectivity matrix A with rows and columns grouped by Leiden community
# Node order: communities in sequence, node ids ascending within each community
leiden_order = [node for comm in leiden_communities for node in sorted(comm)]

# Apply the ordering to both axes
A_leiden_ordered = A[np.ix_(leiden_order, leiden_order)]

# log10(count + 1) keeps zero entries at zero; the color limits are symmetric
leiden_log_counts = np.log10(A_leiden_ordered + 1)
leiden_color_limit = np.max(leiden_log_counts)

plt.figure(figsize=(12, 10))
plt.imshow(
    leiden_log_counts,
    cmap='bwr',
    aspect='auto',
    vmin=-leiden_color_limit,
    vmax=leiden_color_limit,
    interpolation='none',
)
plt.colorbar(label='Number of Synapses (log10)')
plt.title(f'Connectivity Matrix Ordered by Leiden Communities ({len(leiden_communities)} communities)')
plt.xlabel('Postsynaptic Neuron Index')
plt.ylabel('Presynaptic Neuron Index')

# Cumulative community sizes: each entry is the index just past the end of a community
leiden_boundaries = []
current_pos = 0
for comm in leiden_communities:
    current_pos += len(comm)
    leiden_boundaries.append(current_pos)

# Draw a separator after every community except the last one.
# The -0.5 shift places each line between two neighbouring pixels.
for boundary in leiden_boundaries[:-1]:
    separator = boundary - 0.5
    plt.axhline(y=separator, color='k', linewidth=0.5, alpha=0.7)
    plt.axvline(x=separator, color='k', linewidth=0.5, alpha=0.7)

plt.tight_layout()
plt.show()

# Summary figures for the Leiden partition next to the cell type grouping
print("\nComparison with original cell type grouping:")
print(f"Original cell types: {len(unique_cell_types)}")
print(f"Leiden communities: {len(leiden_communities)}")
print(f"Modularity (Newman's Q): {leiden_modularity:.4f}")

In [None]:
# Plot the signed connectivity matrix with rows and columns grouped by Leiden community
# Node order: communities in sequence, node ids ascending within each community
leiden_order = [node for comm in leiden_communities for node in sorted(comm)]

# Apply the ordering to both axes of the signed matrix
A_signed = C_signed_grouped
A_leiden_ordered = A_signed[np.ix_(leiden_order, leiden_order)]

# Signed log scale: log10(|w| + 1) gives the magnitude and sign(w) restores the
# sign; the color limits are symmetric so that zero maps to the middle of 'bwr'
leiden_log_magnitude = np.log10(np.abs(A_leiden_ordered) + 1)
leiden_color_limit = np.max(leiden_log_magnitude)

plt.figure(figsize=(12, 10))
plt.imshow(
    leiden_log_magnitude * np.sign(A_leiden_ordered),
    cmap='bwr',
    aspect='auto',
    vmin=-leiden_color_limit,
    vmax=leiden_color_limit,
    interpolation='none',
)
plt.colorbar(label='Number of Synapses (log10)')
plt.title(f'Signed Connectivity Matrix Ordered by Leiden Communities ({len(leiden_communities)} communities)')
plt.xlabel('Postsynaptic Neuron Index')
plt.ylabel('Presynaptic Neuron Index')

# Cumulative community sizes: each entry is the index just past the end of a community
leiden_boundaries = []
current_pos = 0
for comm in leiden_communities:
    current_pos += len(comm)
    leiden_boundaries.append(current_pos)

# Draw a separator after every community except the last one.
# The -0.5 shift places each line between two neighbouring pixels.
for boundary in leiden_boundaries[:-1]:
    separator = boundary - 0.5
    plt.axhline(y=separator, color='k', linewidth=0.5, alpha=0.7)
    plt.axvline(x=separator, color='k', linewidth=0.5, alpha=0.7)

plt.tight_layout()
plt.show()

# Summary figures for the Leiden partition next to the cell type grouping
print("\nComparison with original cell type grouping:")
print(f"Original cell types: {len(unique_cell_types)}")
print(f"Leiden communities: {len(leiden_communities)}")
print(f"Modularity (Newman's Q): {leiden_modularity:.4f}")

In [None]:
# Heatmap of how neurons of each cell type are spread over the Leiden communities
# seaborn and n_cell_types are set up here as well as in the Louvain heatmap
# cell, so this cell does not depend on that one having been run first
import seaborn as sns

# Rows index cell types, columns index Leiden communities
n_cell_types = len(unique_cell_types)
n_leiden_communities = len(leiden_communities)

# Count neurons per (cell type, community) pair
leiden_distribution_matrix = np.zeros((n_cell_types, n_leiden_communities))
for i, cell_type in enumerate(ctoi_grouped):
    leiden_comm = leiden_labels[i]
    cell_type_idx = unique_cell_types.index(cell_type)
    leiden_distribution_matrix[cell_type_idx, leiden_comm] += 1

# Draw the count matrix
plt.figure(figsize=(16, 12))
sns.heatmap(leiden_distribution_matrix,
            xticklabels=[f'Comm {i}' for i in range(n_leiden_communities)],
            yticklabels=unique_cell_types,
            cmap='viridis',
            annot=False,
            fmt='g',
            cbar_kws={'label': 'Number of neurons'})

plt.title('Cell Type Distribution Across Leiden Communities')
plt.xlabel('Leiden Community')
plt.ylabel('Cell Type')
plt.xticks(rotation=45)
plt.yticks(rotation=0, fontsize=8)
plt.tight_layout()
plt.show()

# Basic facts about the count matrix
print(f"Leiden distribution matrix shape: {leiden_distribution_matrix.shape}")
print(f"Non-zero entries in distribution matrix: {np.count_nonzero(leiden_distribution_matrix)}")

# A cell type is "split" when its row has counts in more than one community
leiden_split_cell_types = 0
for i in range(n_cell_types):
    if np.count_nonzero(leiden_distribution_matrix[i, :]) > 1:
        leiden_split_cell_types += 1

print(f"Cell types split across multiple communities: {leiden_split_cell_types}/{n_cell_types}")

# Purity of a community = share of its neurons that belong to its most common
# cell type; communities without any neurons are left out
leiden_community_purity = []
for j in range(n_leiden_communities):
    community_total = leiden_distribution_matrix[:, j].sum()
    if community_total > 0:
        purity = leiden_distribution_matrix[:, j].max() / community_total
        leiden_community_purity.append(purity)

print(f"Average community purity: {np.mean(leiden_community_purity):.3f}")
print(f"Community purity range: {np.min(leiden_community_purity):.3f} - {np.max(leiden_community_purity):.3f}")

### Checkpoint

In [None]:
### Checkpoint
# Names of the notebook-level variables to report on, grouped by the kind of result
variables_to_keep = [
    # grouped data
    'ctoi_grouped', 'noi_grouped', 'C_signed_grouped', 'C_unsigned_grouped',
    'new_order', 'unique_cell_types', 'cell_type_counts',
    # degree, Laplacian and symmetrized matrices
    'D_out', 'D_in', 'L_out', 'L_in', 'L_rw', 'A_sym', 'D_sym', 'L_sym', 'L_bal',
    # bi-LRG results
    'P_hat', 'L_hat', 'K_hat', 'rho_hat', 'bilrg', 'P_group', 'L_group', 'A_group',
    'P_galerkin', 'L_galerkin', 'A_galerkin', 'bilrg_labels',
    # Louvain and Leiden results
    'louvain_communities', 'louvain_labels', 'louvain_modularity',
    'leiden_communities', 'leiden_labels', 'leiden_modularity',
]

print("Variables to keep for further analysis:")
for var_name in variables_to_keep:
    # Report missing names instead of raising, so the listing always completes
    if var_name not in globals():
        print(f"\t{var_name}: NOT FOUND")
        continue

    var_obj = globals()[var_name]

    # Prefer a shape when the object has one, fall back to a length,
    # and label everything else as a scalar
    if hasattr(var_obj, 'shape'):
        structure = f"shape: {var_obj.shape}"
    elif hasattr(var_obj, '__len__'):
        structure = f"length: {len(var_obj)}"
    else:
        structure = "scalar"

    print(f"\t{var_name}: {type(var_obj).__name__}, {structure}")

### Degree-corrected SBM

In [None]:
# Number of blocks passed to the DC-SBM below; copied from `k`, which is set
# in an earlier cell (not visible in this part of the notebook)
K = k  # Number of blocks

In [None]:
# Build the DC-SBM input from the unsigned grouped connectivity matrix
# (copied, so A_dcsbm is a separate object from C_unsigned_grouped)
# NOTE(review): this cell was originally annotated as using the Delta7+EPG
# subset (87 x 87) — confirm against the shape printed below
A_dcsbm = C_unsigned_grouped.copy()
# Alternative input: A_dcsbm = A_sym

n_nonzero = np.count_nonzero(A_dcsbm)

print(f"Input matrix A shape: {A_dcsbm.shape}")
print(f"Matrix type: {type(A_dcsbm)}")
print(f"Number of non-zero entries: {n_nonzero}")
print(f"Matrix density: {n_nonzero / (A_dcsbm.shape[0] * A_dcsbm.shape[1]):.3f}")

# Value range of the input
print("\nMatrix statistics:")
print(f"Min value: {A_dcsbm.min()}")
print(f"Max value: {A_dcsbm.max()}")
print(f"Mean value: {A_dcsbm.mean():.3f}")

# Out- and in-degrees as returned by the dcsbm package helper
k_out, k_in = degrees(A_dcsbm)
degree_vectors = (("Out", k_out), ("In", k_in))

print("\nDegree statistics:")
for direction, degree_vector in degree_vectors:
    print(f"{direction}-degree range: [{degree_vector.min():.0f}, {degree_vector.max():.0f}]")
for direction, degree_vector in degree_vectors:
    print(f"Mean {direction.lower()}-degree: {degree_vector.mean():.2f}")

In [None]:
# Hold out a fraction of the entries for validation
print("Splitting data for cross-validation...")
A_train, A_val, mask_val = heldout_split(A_dcsbm, frac=0.1, stratify_degrees=True, seed=42)

# Non-zero counts; the split outputs are converted to dense arrays for counting
total_orig = np.count_nonzero(A_dcsbm)
total_train = np.count_nonzero(A_train.toarray())
total_val = np.count_nonzero(A_val.toarray())
n_mask_entries = np.count_nonzero(mask_val.toarray())

print(f"Original matrix: {A_dcsbm.shape}, {total_orig} non-zeros")
print(f"Training matrix: {A_train.shape}, {total_train} non-zeros")
print(f"Validation matrix: {A_val.shape}, {total_val} non-zeros")
print(f"Validation mask: {mask_val.shape}, {n_mask_entries} entries")

# Sanity check: train and validation non-zeros shown next to the original count
print("\nSplit verification:")
print(f"Train + validation edges: {total_train + total_val} (original: {total_orig})")
print(f"Held-out fraction: {total_val / total_orig:.3f}")

In [None]:
# Fit the degree-corrected SBM on the training matrix
dcsbm = DCSBM(K=K, max_iter=200, tol=1e-4, seed=42, init="spectral")
dcsbm.fit(A_train)

# Convergence summary
print(f"Converged: {dcsbm.converged_}")
print(f"Number of iterations: {dcsbm.n_iter_}")
print(f"Final ELBO: {dcsbm.elbo_[-1]:.4f}")

# Block label for every node
dcsbm_labels = dcsbm.predict()

print("\nBlock assignments:")
for block_id in range(K):
    block_size = np.count_nonzero(dcsbm_labels == block_id)
    print(f"Block {block_id}: {block_size} nodes")

# For each block, count its members per cell type
print("\nBlock composition by cell type:")
for block_id in range(K):
    block_mask = dcsbm_labels == block_id
    block_cell_types = [
        cell_type
        for neuron_idx, cell_type in enumerate(ctoi_grouped)
        if block_mask[neuron_idx]
    ]
    cell_type_counts_block = dict(Counter(block_cell_types))
    # Sorted by cell type so the printed order is stable
    print(f"Block {block_id}: {dict(sorted(cell_type_counts_block.items()))}")

# Fitted parameters as exposed by the model
params = dcsbm.get_params()
print("\nModel parameters:")
print("Omega (block-block rates):")
print(params['Omega'])
print(f"Block proportions (pi): {params['pi']}")

In [None]:
# Evaluate the fitted DC-SBM: held-out vs. training score, ELBO trace, and
# how confident the soft block memberships are
print("Model Evaluation")
print("=" * 50)

# Score on the held-out entries only (selected by mask_val)
val_score = dcsbm.score(A_val, mask=mask_val)
print(f"Held-out log-likelihood: {val_score:.4f}")

# Score on the training matrix for comparison
train_score = dcsbm.score(A_train)
print(f"Training log-likelihood: {train_score:.4f}")

# Plot ELBO convergence
plt.figure(figsize=(10, 6))
plt.plot(dcsbm.elbo_, 'b-', linewidth=2, label='ELBO')
plt.xlabel('Iteration')
plt.ylabel('Evidence Lower Bound (ELBO)')
plt.title(f'DC-SBM Convergence (K={K})')
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()

# Diagnostics reported by the model
diagnostics = dcsbm.diagnostics()
elbo_trace = diagnostics['elbo_trace']
# Allow decreases of up to 1e-6 between consecutive iterations (numerical noise)
elbo_monotone = all(
    later >= earlier - 1e-6
    for earlier, later in zip(elbo_trace, elbo_trace[1:])
)
print(f"\nDiagnostics:")
print(f"ELBO increased monotonically: {elbo_monotone}")
print(f"ELBO improvement: {elbo_trace[-1] - elbo_trace[0]:.4f}")

# Soft membership matrix; `params` comes from dcsbm.get_params() in the fitting cell
Q = params['Q']
print(f"\nSoft membership analysis:")
print(f"Q matrix shape: {Q.shape}")
print(f"Q row sums (should be ~1): min={Q.sum(axis=1).min():.6f}, max={Q.sum(axis=1).max():.6f}")

# Confidence of a node = its largest membership weight
assignment_confidence = Q.max(axis=1)
print(f"Assignment confidence: mean={assignment_confidence.mean():.3f}, min={assignment_confidence.min():.3f}")
soft_assignments = np.sum(assignment_confidence < 0.8)
# The denominator is the number of rows of Q (one confidence value per node).
# It previously used `labels`, a name that is not defined anywhere in the
# DC-SBM cells.
print(f"Neurons with soft assignments (confidence < 0.8): {soft_assignments}/{len(assignment_confidence)}")

In [None]:
# Show the DC-SBM partition: block-ordered matrix, then cell type vs. block counts
print("Visualization of DC-SBM Results")
print("=" * 50)

# Node order: blocks by increasing id, node indices ascending within each block
dcsbm_order = [
    node
    for block_id in range(K)
    for node in sorted(np.where(dcsbm_labels == block_id)[0])
]

# Apply the ordering to both axes
A_dcsbm_ordered = A_dcsbm[np.ix_(dcsbm_order, dcsbm_order)]

# log10(count + 1) keeps zero entries at zero; the color limits are symmetric
dcsbm_log_counts = np.log10(A_dcsbm_ordered + 1)
dcsbm_color_limit = np.max(dcsbm_log_counts)

plt.figure(figsize=(12, 10))
plt.imshow(
    dcsbm_log_counts,
    cmap='bwr',
    aspect='auto',
    vmin=-dcsbm_color_limit,
    vmax=dcsbm_color_limit,
    interpolation='none',
)
plt.colorbar(label='Number of Synapses (log10)')
plt.title(f'Connectivity Matrix Ordered by DC-SBM Blocks (K={K})')
plt.xlabel('Postsynaptic Neuron Index')
plt.ylabel('Presynaptic Neuron Index')

# Block boundaries: running total of block sizes, starting at 0
block_sizes = [np.sum(dcsbm_labels == block_id) for block_id in range(K)]
block_boundaries = np.cumsum([0] + block_sizes)
for b in block_boundaries:
    line_pos = b - 0.5
    plt.axhline(y=line_pos, color='red', linewidth=2, alpha=0.8)
    plt.axvline(x=line_pos, color='red', linewidth=2, alpha=0.8)

# Label each block at the middle of its column range
label_y = len(dcsbm_labels) + 5
for block_id, block_size in enumerate(block_sizes):
    center_pos = block_boundaries[block_id] + block_size // 2
    plt.text(center_pos, label_y, f'Block {block_id} (n={block_size})',
             ha='center', va='bottom', fontsize=10, fontweight='bold')

plt.tight_layout()
plt.show()

# Compare DC-SBM blocks with cell types
print("\nComparison with biological cell types:")
print(f"Original cell types: {len(unique_cell_types)} (Delta7, EPG)")
print(f"DC-SBM blocks: {len(np.unique(dcsbm_labels))}")

# Contingency table: rows are cell types, columns are blocks
contingency = np.zeros((len(unique_cell_types), K))
for neuron_idx, cell_type in enumerate(ctoi_grouped):
    block_id = dcsbm_labels[neuron_idx]
    contingency[unique_cell_types.index(cell_type), block_id] += 1

# Heatmap of the contingency table
block_tick_labels = [f'Block {i}' for i in range(K)]
plt.figure(figsize=(14, 6))
sns.heatmap(
    contingency,
    xticklabels=block_tick_labels,
    yticklabels=unique_cell_types,
    cmap='viridis',
    annot=True,
    fmt='g',
    cbar_kws={'label': 'Number of neurons'},
)
plt.title('Cell Type Distribution Across DC-SBM Blocks')
plt.xlabel('DC-SBM Block')
plt.ylabel('Cell Type')
plt.tight_layout()
plt.show()

# Text version of the same table; every cell is followed by a single space
print("\nContingency table (cell types vs DC-SBM blocks):")
header_cells = [f"{'Cell Type':<10}"]
header_cells += [f"{'Block ' + str(block_id):<8}" for block_id in range(K)]
print("".join(cell + ' ' for cell in header_cells))
print("-" * (10 + 9 * K))
for row_idx, cell_type in enumerate(unique_cell_types):
    row_cells = [f"{cell_type:<10}"]
    row_cells += [f"{contingency[row_idx, block_id]:>8.0f}" for block_id in range(K)]
    print("".join(cell + ' ' for cell in row_cells))

# Purity of a block = share of its neurons that belong to its most common cell type
print("\nBlock purity analysis:")
for block_id in range(K):
    block_column = contingency[:, block_id]
    block_neurons = np.sum(dcsbm_labels == block_id)
    max_cell_type_in_block = block_column.max()
    # Empty blocks get purity 0 instead of dividing by zero
    purity = max_cell_type_in_block / block_neurons if block_neurons > 0 else 0
    dominant_cell_type = unique_cell_types[np.argmax(block_column)]
    print(f"Block {block_id}: {purity:.3f} purity (dominant: {dominant_cell_type})")

In [None]:
# Show the signed connectivity matrix ordered by DC-SBM block
print("Visualization of DC-SBM Results")
print("=" * 50)

# Node order: blocks by increasing id, node indices ascending within each block
dcsbm_order = [
    node
    for block_id in range(K)
    for node in sorted(np.where(dcsbm_labels == block_id)[0])
]

# Apply the ordering to both axes of the signed matrix
A_dcsbm_signed = C_signed_grouped
A_dcsbm_signed_ordered = A_dcsbm_signed[np.ix_(dcsbm_order, dcsbm_order)]

# Signed log scale: log10(|w| + 1) gives the magnitude and sign(w) restores the
# sign; the color limits are symmetric so that zero maps to the middle of 'bwr'
dcsbm_log_magnitude = np.log10(np.abs(A_dcsbm_signed_ordered) + 1)
dcsbm_signed_color_limit = np.max(dcsbm_log_magnitude)

plt.figure(figsize=(12, 10))
plt.imshow(
    dcsbm_log_magnitude * np.sign(A_dcsbm_signed_ordered),
    cmap='bwr',
    aspect='auto',
    vmin=-dcsbm_signed_color_limit,
    vmax=dcsbm_signed_color_limit,
    interpolation='none',
)
plt.colorbar(label='Number of Synapses (log10)')
plt.title(f'Signed Connectivity Matrix Ordered by DC-SBM Blocks (K={K})')
plt.xlabel('Postsynaptic Neuron Index')
plt.ylabel('Presynaptic Neuron Index')

# Block boundaries: running total of block sizes, starting at 0
block_sizes = [np.sum(dcsbm_labels == block_id) for block_id in range(K)]
block_boundaries = np.cumsum([0] + block_sizes)
for b in block_boundaries:
    line_pos = b - 0.5
    plt.axhline(y=line_pos, color='red', linewidth=2, alpha=0.8)
    plt.axvline(x=line_pos, color='red', linewidth=2, alpha=0.8)

# Label each block at the middle of its column range
label_y = len(dcsbm_labels) + 5
for block_id, block_size in enumerate(block_sizes):
    center_pos = block_boundaries[block_id] + block_size // 2
    plt.text(center_pos, label_y, f'Block {block_id} (n={block_size})',
             ha='center', va='bottom', fontsize=10, fontweight='bold')

plt.tight_layout()
plt.show()

### Checkpoint

In [None]:
### Checkpoint
# Names of the notebook-level variables to report on, grouped by the kind of result
variables_to_keep = [
    # grouped data
    'ctoi_grouped', 'noi_grouped', 'C_signed_grouped', 'C_unsigned_grouped',
    'new_order', 'unique_cell_types', 'cell_type_counts',
    # degree, Laplacian and symmetrized matrices
    'D_out', 'D_in', 'L_out', 'L_in', 'L_rw', 'A_sym', 'D_sym', 'L_sym', 'L_bal',
    # bi-LRG results
    'P_hat', 'L_hat', 'K_hat', 'rho_hat', 'bilrg', 'P_group', 'L_group', 'A_group',
    'P_galerkin', 'L_galerkin', 'A_galerkin', 'bilrg_labels',
    # Louvain and Leiden results
    'louvain_communities', 'louvain_labels', 'louvain_modularity',
    'leiden_communities', 'leiden_labels', 'leiden_modularity',
    # DC-SBM results
    'dcsbm', 'dcsbm_labels',
]

print("Variables to keep for further analysis:")
for var_name in variables_to_keep:
    # Report missing names instead of raising, so the listing always completes
    if var_name not in globals():
        print(f"\t{var_name}: NOT FOUND")
        continue

    var_obj = globals()[var_name]

    # Prefer a shape when the object has one, fall back to a length,
    # and label everything else as a scalar
    if hasattr(var_obj, 'shape'):
        structure = f"shape: {var_obj.shape}"
    elif hasattr(var_obj, '__len__'):
        structure = f"length: {len(var_obj)}"
    else:
        structure = "scalar"

    print(f"\t{var_name}: {type(var_obj).__name__}, {structure}")

### Role similarity

### Reduced dynamics

#### Balanced truncation

#### Koopman mode decomposition

## System identification

### Model-based

### Data-driven (simulation-based)

## Role discovery

### Computational library

### Canonical labeling

## System synthesis