In [1]:
import glob
import os

import jaxley as jx
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.spatial import KDTree

# ---------------------------------------------------------
# 1. CONFIGURATION
# ---------------------------------------------------------
TYPE_NAME = "CT1"
MORPHOLOGY_DIR = "./morphologies"

# Discover the body ID from the skeleton file on disk. Match the exact suffix
# that SWC_PATH uses below (a bare f"{TYPE_NAME}_*.swc" could select an SWC
# whose companion files do not exist), and sort because glob order is
# filesystem-dependent.
swcs_with_type = sorted(
    glob.glob(f"{MORPHOLOGY_DIR}/{TYPE_NAME}_*_skeleton_with_synapses.swc")
)
if not swcs_with_type:
    raise FileNotFoundError(
        f"No '{TYPE_NAME}_*_skeleton_with_synapses.swc' file found in {MORPHOLOGY_DIR}"
    )
if len(swcs_with_type) > 1:
    print(f"Found {len(swcs_with_type)} {TYPE_NAME} skeletons; using {swcs_with_type[0]}")

# Filenames look like "<TYPE>_<bodyId>_skeleton_with_synapses.swc". Parse the
# basename only, so a directory name containing "<TYPE>_" cannot break the split.
body_id = os.path.basename(swcs_with_type[0]).split(f"{TYPE_NAME}_", 1)[1].split("_")[0]
BODY_ID = int(body_id)

# Build paths from the parsed string (not the int) so IDs with leading zeros
# still resolve to the file that was found.
SWC_PATH = f"{MORPHOLOGY_DIR}/{TYPE_NAME}_{body_id}_skeleton_with_synapses.swc"
INPUTS_CSV = f"{MORPHOLOGY_DIR}/{TYPE_NAME}_{body_id}_postsynaptic_connections.csv"
OUTPUTS_CSV = f"{MORPHOLOGY_DIR}/{TYPE_NAME}_{body_id}_presynaptic_connections.csv"

In [None]:
# ---------------------------------------------------------
# 2. LOAD MORPHOLOGY
# ---------------------------------------------------------
print(f"Loading SWC: {SWC_PATH}")
# NOTE(review): the SWC is assumed to be in nanometres (NeuPrint default), but no
# rescaling is applied here, and jaxley's read_swc treats coordinates/radii as
# micrometres. Confirm the units (see the extent printed below) before trusting
# cable properties.
# `ncomp` is the number of compartments per *branch*; ncomp=1 makes every branch
# a single compartment.
ct1 = jx.read_swc(SWC_PATH, ncomp=1)

# Bounding-box size of the compartment centres: a quick unit sanity check.
morphology_extent = np.ptp(ct1.nodes[['x', 'y', 'z']].to_numpy(), axis=0)
print(f"Morphology extent (x, y, z): {morphology_extent}")

Loading SWC: ./morphologies/CT1_10009_skeleton_with_synapses.swc


  warn(
  warn(


In [None]:
from tqdm import tqdm
# ---------------------------------------------------------
# 3. LOAD SYNAPSES
# ---------------------------------------------------------
print("Loading Synapse CSVs...")

# Inputs are synapses ONTO CT1: CT1 is the postsynaptic partner, so the
# contact site on CT1 is given by the *_post coordinate columns.
df_inputs = pd.read_csv(INPUTS_CSV)
input_coords = df_inputs.loc[:, ['x_post', 'y_post', 'z_post']].to_numpy()

# Outputs are synapses FROM CT1: CT1 is the presynaptic partner, so the
# contact site on CT1 is given by the *_pre coordinate columns.
df_outputs = pd.read_csv(OUTPUTS_CSV)
output_coords = df_outputs.loc[:, ['x_pre', 'y_pre', 'z_pre']].to_numpy()

print(
    f"Loaded {len(df_inputs)} Input Synapses",
    f"Loaded {len(df_outputs)} Output Synapses",
    sep="\n",
)

Loading Synapse CSVs...
Loaded 62900 Input Synapses
Loaded 62900 Output Synapses
Mapping synapses to nearest compartments...


12601it [04:06, 51.18it/s]


In [38]:
# 2. Build Tree over CT1's compartment centres.
all_comp_indices = ct1.nodes.index.to_numpy()
# ct1 is a standalone cell here, so global_branch_index equals the cell-local
# branch index used by ct1.branch(i). With ncomp=1 this equals the node index,
# but unlike the node index it stays a *branch* index if ncomp is increased.
branch_of_node = ct1.nodes['global_branch_index'].to_numpy()
tree = KDTree(ct1.nodes[['x', 'y', 'z']].values)

# 3. Query - Input Synapses
# NOTE(review): assumes SWC and CSV share units (both nm per the original note),
# so no scaling is applied; the distance summary below is a quick check.
dists_in, idxs_in = tree.query(input_coords)
df_inputs['jaxley_branch'] = branch_of_node[idxs_in]

# 4. Query - Output Synapses
dists_out, idxs_out = tree.query(output_coords)
df_outputs['jaxley_branch'] = branch_of_node[idxs_out]

# Large match distances usually indicate a unit or offset mismatch.
for match_label, match_dists in (("Input", dists_in), ("Output", dists_out)):
    if len(match_dists):
        print(
            f"{match_label} synapse -> compartment distance: "
            f"median {np.median(match_dists):.2f}, max {np.max(match_dists):.2f}"
        )

In [None]:
import jaxley as jx
from jaxley.channels import Leak
from jaxley.synapses import IonotropicSynapse
import jax.numpy as jnp
import pandas as pd
import numpy as np
from scipy.spatial import KDTree
import glob


from jaxley.synapses import Synapse


class GradedSynapse(Synapse):
    """
    Continuous (graded) transmission synapse with no internal state.

    Conductance g = g_max / (1 + exp(-(v_pre - v_mid) / slope)); the synaptic
    current is g * (v_post - e_syn).

    Implements jaxley's Synapse interface (`update_states` / `compute_current`).
    Parameters are read from the `params` dict supplied at integration time, so
    per-edge values changed with `net.set(...)` (or made trainable) take effect.
    Parameter keys are prefixed with the synapse name ("GradedSynapse_g_max", ...).
    """

    def __init__(self, g_max=1e-3, v_mid=-40.0, slope=5.0, e_syn=0.0, name=None):
        super().__init__(name)
        prefix = self._name
        self.g_max = g_max
        self.v_mid = v_mid
        self.slope = slope
        self.e_syn = e_syn
        # Defaults copied onto every edge created with this synapse type.
        self.synapse_params = {
            f"{prefix}_g_max": g_max,
            f"{prefix}_v_mid": v_mid,
            f"{prefix}_slope": slope,
            f"{prefix}_e_syn": e_syn,
        }
        self.synapse_states = {}
        # Kept for backward compatibility with earlier notebook code.
        self.synapse_state_names = []

    def update_states(self, states, delta_t, pre_voltage, post_voltage, params):
        """Stateless synapse: the conductance depends only on v_pre, so nothing to update."""
        return {}

    def compute_current(self, states, pre_voltage, post_voltage, params):
        """Return g(v_pre) * (v_post - e_syn) using the per-edge `params`."""
        prefix = self._name
        g_max = params[f"{prefix}_g_max"]
        v_mid = params[f"{prefix}_v_mid"]
        slope = params[f"{prefix}_slope"]
        e_syn = params[f"{prefix}_e_syn"]

        # Sigmoidal activation of the presynaptic voltage.
        g = g_max / (1.0 + jnp.exp(-(pre_voltage - v_mid) / slope))
        return g * (post_voltage - e_syn)

    def current(self, state, v_pre, v_post):
        """Backward-compatible helper: current using this instance's default parameters."""
        return self.compute_current(state, v_pre, v_post, self.synapse_params)

# ---------------------------------------------------------
# 2. LOAD DATA
# ---------------------------------------------------------
print(f"Loading SWC: {SWC_PATH}")
# Step 2: Detailed neuron.
# `ncomp` is the number of compartments per branch (ncomp=1: one per branch).
ct1 = jx.read_swc(SWC_PATH, ncomp=1)
ct1.insert(Leak())  # Active channels can be added here if needed

# Load Synapses
df_inputs = pd.read_csv(INPUTS_CSV)
df_outputs = pd.read_csv(OUTPUTS_CSV)

# ---------------------------------------------------------
# 3. MAP SYNAPSES TO CT1 BRANCHES (KDTree)
# ---------------------------------------------------------
# Done on the standalone ct1, before it is merged into the network, so that
# `global_branch_index` equals the cell-local branch index expected by
# network.cell(ct1_net_idx).branch(i) -- for any ncomp, not just ncomp=1.
tree = KDTree(ct1.nodes[['x', 'y', 'z']].values)
ct1_branch_of_node = ct1.nodes['global_branch_index'].to_numpy()


def get_branch_idx(coords):
    """Return the index of the CT1 branch nearest to `coords`.

    `coords` is a single (x, y, z) point (-> scalar) or an (N, 3) array
    (-> length-N array), in the same units as the SWC.
    """
    _, idx = tree.query(coords)
    return ct1_branch_of_node[idx]


# One vectorised KDTree query per table instead of one query per synapse row.
# Stored before grouping so every group below carries the mapping.
df_inputs['ct1_branch'] = get_branch_idx(df_inputs[['x_post', 'y_post', 'z_post']].to_numpy())
df_outputs['ct1_branch'] = get_branch_idx(df_outputs[['x_pre', 'y_pre', 'z_pre']].to_numpy())

# ---------------------------------------------------------
# 4. GROUP SYNAPSES BY PARTNER NEURON
# ---------------------------------------------------------
# Group synapses by Body ID to create "Neurons" instead of just "Synapses".
# If 'bodyId_pre' / 'bodyId_post' columns exist, use them.
# Otherwise, each synapse is treated as a distinct partner (less bio-plausible).
if 'bodyId_pre' in df_inputs.columns:
    input_groups = df_inputs.groupby('bodyId_pre')
    print(f"Grouped {len(df_inputs)} synapses into {len(input_groups)} Input Neurons")
else:
    # Fallback: Treat every synapse as a separate input
    df_inputs['bodyId_pre'] = df_inputs.index
    input_groups = df_inputs.groupby('bodyId_pre')

if 'bodyId_post' in df_outputs.columns:
    output_groups = df_outputs.groupby('bodyId_post')
    print(f"Grouped {len(df_outputs)} synapses into {len(output_groups)} Output Neurons")
else:
    df_outputs['bodyId_post'] = df_outputs.index
    output_groups = df_outputs.groupby('bodyId_post')

# ---------------------------------------------------------
# 5. CREATE POINT NEURONS
# ---------------------------------------------------------
input_cells = [jx.Cell() for _ in range(len(input_groups))]
for cell in input_cells:
    cell.insert(Leak())

# Create list of output point neurons
output_cells = [jx.Cell() for _ in range(len(output_groups))]
for cell in output_cells:
    cell.insert(Leak())

# Combine into one Network
# Order: [Detailed_CT1, Input_1...Input_N, Output_1...Output_M]
network = jx.Network([ct1] + input_cells + output_cells)

# ---------------------------------------------------------
# 6. CONNECT NETWORK
# ---------------------------------------------------------
print("Connecting Network...")

# network.cell(i) indices follow the list order passed to jx.Network above.
ct1_net_idx = 0
input_start_idx = 1
output_start_idx = 1 + len(input_cells)

# INPUTS -> CT1: one point neuron per presynaptic body, one synapse per row.
for i, (pre_body_id, group) in enumerate(input_groups):
    pre_cell = network.cell(input_start_idx + i)
    for branch_idx in group['ct1_branch']:
        # Connect Point(branch 0, loc=1.0) -> Detailed(branch_idx, loc=1.0).
        # IonotropicSynapse's constructor only accepts `name`; its parameters
        # are set on the network's edges once all connections exist (below).
        jx.connect(
            pre_cell.branch(0).loc(1.0),
            network.cell(ct1_net_idx).branch(int(branch_idx)).loc(1.0),
            IonotropicSynapse(),
        )

if len(df_inputs):
    # Toggle sign via e_syn (0 mV excitatory, -80 mV inhibitory).
    network.set("IonotropicSynapse_e_syn", 0.0)  # Excitatory default
    network.set("IonotropicSynapse_gS", 0.001)

# CT1 -> OUTPUTS: one point neuron per postsynaptic body, one synapse per row.
for i, (post_body_id, group) in enumerate(output_groups):
    post_cell = network.cell(output_start_idx + i)
    for branch_idx in group['ct1_branch']:
        # Connect Detailed(branch_idx, loc=1.0) -> Point(branch 0, loc=0.0).
        # Using GradedSynapse since CT1 is likely non-spiking.
        jx.connect(
            network.cell(ct1_net_idx).branch(int(branch_idx)).loc(1.0),
            post_cell.branch(0).loc(0.0),
            GradedSynapse(v_mid=-40.0, g_max=0.001, e_syn=-80.0),  # Inhibitory default
        )

# ---------------------------------------------------------
# 7. STIMULATION & RECORDING
# ---------------------------------------------------------
# Fictitious Drive: Stimulate a random subset of input neurons
print("Setting up stimulation...")
dt = 0.025
t_max = 100.0
STIM_SEED = 0  # fixes which inputs are driven, so reruns are reproducible
n_stim = int(len(input_cells) * 0.2)  # Stimulate 20% of inputs
stim_indices = np.random.default_rng(STIM_SEED).choice(len(input_cells), n_stim, replace=False)

# Create a step current
current = jx.step_current(i_delay=10.0, i_dur=50.0, i_amp=0.1, delta_t=dt, t_max=t_max)

for idx in stim_indices:
    # Input cells occupy network indices [input_start_idx, output_start_idx).
    network.cell(input_start_idx + int(idx)).branch(0).loc(0.0).stimulate(current)

# Record the voltage of every output point neuron (population vector).
for i in range(len(output_cells)):
    network.cell(output_start_idx + i).record("v")

# Also record CT1 branch 0 (root) to see what's happening inside.
network.cell(ct1_net_idx).branch(0).loc(0.0).record("v")

# ---------------------------------------------------------
# 8. RUN SIMULATION
# ---------------------------------------------------------
print("Simulating...")
# To run on GPU, ensure JAX is installed with CUDA
voltage = jx.integrate(network, delta_t=dt)

print("Simulation Complete. Shape:", voltage.shape)
# jx.integrate returns (num_recordings, num_timepoints). Rows follow the order
# in which recordings were registered: the output point neurons first, then
# the CT1 root as the last row.
# Plotting logic would follow here

array([[-249.82421875, -151.82421875, -254.        ],
       [ -13.65039062,  -43.6484375 ,  415.3515625 ],
       [ 284.140625  ,   31.140625  , -134.        ],
       ...,
       [ 100.        ,   81.65625   ,  -98.34375   ],
       [ -23.50976562,  -83.50976562,  100.5078125 ],
       [ 126.59960938,    1.19921875,  -76.        ]], shape=(62900, 3))

In [None]:
import jaxley as jx
from jaxley.synapses import IonotropicSynapse
from jaxley.channels import HH, Leak
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import KDTree
import glob
import jax.numpy as jnp

def create_point_neuron():
    """Build a passive point neuron.

    The neuron is one cell made of one branch holding one compartment, with a
    Leak channel inserted. It is returned as a standalone jx.Cell.
    """
    single_branch = jx.Branch(jx.Compartment(), ncomp=1)
    neuron = jx.Cell(single_branch, parents=[-1])  # -1: the branch is the root
    neuron.insert(Leak())
    return neuron

# Define Synapses
class ExcitatorySynapse(IonotropicSynapse):
    """IonotropicSynapse preset for excitatory transmission (e_syn = 0 mV).

    IonotropicSynapse reads its parameters from keys prefixed with the synapse
    name (e.g. "ExcitatorySynapse_gS"), so the presets are written under those
    prefixed keys. Plain keys such as "e_syn" would be carried along but never read.
    """

    def __init__(self, name=None):
        super().__init__(name)
        prefix = self._name
        self.synapse_params[f"{prefix}_e_syn"] = 0.0  # mV
        self.synapse_params[f"{prefix}_gS"] = 0.001   # uS (Adjust based on density)
        # IonotropicSynapse has no `tau` parameter; its decay is set by the rate
        # k_minus (1/ms). NOTE(review): in jaxley's implementation the decay time
        # constant is (1 - s_inf) / k_minus, i.e. ~1/k_minus once presynaptic drive
        # is off, so k_minus = 1/tau targets tau ~ 5 ms -- confirm for your version.
        self.synapse_params[f"{prefix}_k_minus"] = 1.0 / 5.0

class InhibitorySynapse(IonotropicSynapse):
    """IonotropicSynapse preset for inhibitory transmission (e_syn = -70 mV).

    IonotropicSynapse reads its parameters from keys prefixed with the synapse
    name (e.g. "InhibitorySynapse_e_syn"). With plain keys the reversal potential
    would silently stay at the excitatory default, so prefixed keys are used.
    """

    def __init__(self, name=None):
        super().__init__(name)
        prefix = self._name
        self.synapse_params[f"{prefix}_e_syn"] = -70.0  # mV
        self.synapse_params[f"{prefix}_gS"] = 0.005     # uS
        # IonotropicSynapse has no `tau` parameter; its decay is set by the rate
        # k_minus (1/ms). NOTE(review): in jaxley's implementation the decay time
        # constant is (1 - s_inf) / k_minus, i.e. ~1/k_minus once presynaptic drive
        # is off, so k_minus = 1/tau targets tau ~ 10 ms -- confirm for your version.
        self.synapse_params[f"{prefix}_k_minus"] = 1.0 / 10.0


def _nearest_local_branches(cell, coords):
    """Map each (x, y, z) row of `coords` to the nearest branch index of the standalone `cell`."""
    coords = np.asarray(coords, dtype=float).reshape(-1, 3)
    if coords.shape[0] == 0:
        return np.zeros(0, dtype=int)
    # For a cell that has not been merged into a network, global_branch_index
    # coincides with the cell-local branch index used by net.cell(i).branch(j).
    branch_of_node = cell.nodes['global_branch_index'].to_numpy()
    tree = KDTree(cell.nodes[['x', 'y', 'z']].to_numpy())
    _, nearest = tree.query(coords)
    return branch_of_node[nearest]


def build_network(morphology_path, df_in, df_out, detail_level='full', g_leak=1e-4):
    """
    Build a Jaxley network of [pre..., main, post...] around one target neuron.

    Args:
        morphology_path: SWC file for the main neuron (used when detail_level='full').
        df_in: synapses onto the main neuron, one presynaptic point neuron per row.
            Needs 'x_post', 'y_post', 'z_post' when detail_level='full'.
        df_out: synapses from the main neuron, one postsynaptic point neuron per row.
            Needs 'x_pre', 'y_pre', 'z_pre' when detail_level='full'.
        detail_level: 'full' (SWC morphology) or 'point' (1 compartment).
        g_leak: Leak conductance (S/cm^2) of the main neuron. It is applied in both
            modes, so the two models differ only in morphology.

    Returns:
        (net, n_inputs, n_outputs). Network cell k < n_inputs is the partner of
        df_in row k, cell n_inputs is the main neuron, and cell n_inputs + 1 + k
        is the partner of df_out row k.

    Raises:
        ValueError: if detail_level is not 'full' or 'point'.
    """
    if detail_level not in ('full', 'point'):
        raise ValueError(f"detail_level must be 'full' or 'point', got {detail_level!r}")
    print(f"Building network with detail: {detail_level}")

    n_inputs = len(df_in)
    n_outputs = len(df_out)

    # 1. Main neuron and synapse -> branch mapping.
    if detail_level == 'full':
        main_cell = jx.read_swc(morphology_path, ncomp=1)
        # Map on the standalone cell, before it joins the network: the indices
        # are then cell-local, which is what net.cell(i).branch(j) selects by.
        # (Network-level global_branch_index would be offset by n_inputs.)
        target_branches_in = _nearest_local_branches(
            main_cell, df_in[['x_post', 'y_post', 'z_post']].to_numpy()
        )
        source_branches_out = _nearest_local_branches(
            main_cell, df_out[['x_pre', 'y_pre', 'z_pre']].to_numpy()
        )
    else:
        # Single compartment representing the whole cell.
        main_cell = jx.Cell(jx.Branch(jx.Compartment(), ncomp=1), parents=[-1])
        # For the point model, everything connects to branch 0.
        target_branches_in = np.zeros(n_inputs, dtype=int)
        source_branches_out = np.zeros(n_outputs, dtype=int)

    main_cell.insert(Leak())
    # jaxley names this parameter "Leak_gLeak"; "Leak_g_leak" raises KeyError.
    main_cell.set("Leak_gLeak", g_leak)  # S/cm2

    # 2. Explicit point neurons for every synaptic partner.
    pre_cells = [create_point_neuron() for _ in range(n_inputs)]
    post_cells = [create_point_neuron() for _ in range(n_outputs)]
    net = jx.Network(pre_cells + [main_cell] + post_cells)

    main_cell_idx = n_inputs
    post_cell_start_idx = n_inputs + 1

    # 3. Connect one synapse at a time. NOTE(review): a vectorised
    # net.cell(k).branch(list_of_branches) view may select each branch only once
    # and in node order, so repeated or unsorted branch indices could silently
    # mis-pair pre and post sites. Per-synapse calls keep the pairing exact, at
    # the cost of speed for very large synapse tables.
    print("Connecting Synapses...")

    excitatory = ExcitatorySynapse()
    for pre_idx, branch_idx in enumerate(target_branches_in):
        jx.connect(
            net.cell(pre_idx).branch(0).loc(1.0),
            net.cell(main_cell_idx).branch(int(branch_idx)).loc(0.5),
            excitatory,
        )

    inhibitory = InhibitorySynapse()
    for post_offset, branch_idx in enumerate(source_branches_out):
        jx.connect(
            net.cell(main_cell_idx).branch(int(branch_idx)).loc(0.5),
            net.cell(post_cell_start_idx + post_offset).branch(0).loc(0.0),
            inhibitory,
        )

    return net, n_inputs, n_outputs

# ---------------------------------------------------------
# 4. SIMULATION PROTOCOL
# ---------------------------------------------------------
def run_simulation(net, n_inputs, n_outputs, t_max=100.0, dt=0.1, seed=42):
    """
    Drive a random 20% of the presynaptic neurons with a step current and simulate.

    Args:
        net: network from build_network, ordered [pre..., main, post...].
        n_inputs: number of presynaptic point neurons (network cells 0..n_inputs-1).
        n_outputs: number of postsynaptic point neurons (cells n_inputs+1 onward).
        t_max: stimulus length in ms.
        dt: integration time step in ms.
        seed: seed for choosing which inputs are stimulated.

    Returns:
        (result, post_indices). `result` comes from jx.integrate with shape
        (num_recordings, num_timepoints). Its rows are the n_outputs post-cell
        voltages, followed by the main neuron's branch 0. `post_indices` are
        the network indices of the post cells.
    """
    # 1. Choose the driven inputs. A local RandomState yields the same draws as
    # np.random.seed(seed) + np.random.choice without touching global NumPy state.
    rng = np.random.RandomState(seed)
    active_indices = rng.choice(n_inputs, size=int(n_inputs * 0.2), replace=False)

    # Step current for active input neurons: delay 10 ms, duration 50 ms.
    current = jx.step_current(i_delay=10.0, i_dur=50.0, i_amp=0.2, delta_t=dt, t_max=t_max)
    if active_indices.size:
        net.cell(active_indices.tolist()).branch(0).loc(0.0).stimulate(current)

    # 2. Record: all post neurons first (ascending index), then the main neuron.
    post_indices = list(range(n_inputs + 1, n_inputs + 1 + n_outputs))
    if post_indices:
        net.cell(post_indices).branch(0).loc(0.0).record("v")
    net.cell(n_inputs).branch(0).loc(0.0).record("v")

    # 3. Integrate. Without any stimulus, jaxley cannot infer the duration from
    # the current, so t_max is passed explicitly only in that case.
    print("Simulating...")
    integrate_kwargs = {} if active_indices.size else {"t_max": t_max}
    result = jx.integrate(net, delta_t=dt, **integrate_kwargs)
    return result, post_indices

# ---------------------------------------------------------
# 5. EXECUTION & ANALYSIS
# ---------------------------------------------------------

# Note: For testing, slice your DFs to small numbers (e.g., 50 synapses)
# or JAX compilation might take a while for the first run.
N_SYNAPSES_SUBSET = 50  # set to None for the full run
df_in_subset = df_inputs.iloc[:N_SYNAPSES_SUBSET]
df_out_subset = df_outputs.iloc[:N_SYNAPSES_SUBSET]

# Must match the integration step used by run_simulation (its default dt = 0.1 ms).
DT = 0.1
T_COMPARE_MS = 40.0  # time at which the population vectors are compared

# --- RUN DETAILED MODEL ---
net_detailed, n_in, n_out = build_network(SWC_PATH, df_in_subset, df_out_subset, detail_level='full')
res_detailed, post_idxs = run_simulation(net_detailed, n_in, n_out)
# jx.integrate returns (num_recordings, num_timepoints). Recordings are stored in
# the order they were registered: the n_out post cells first, then the main cell.
v_post_detailed = np.asarray(res_detailed)[:n_out, :]

# --- RUN POINT MODEL ---
net_point, _, _ = build_network(SWC_PATH, df_in_subset, df_out_subset, detail_level='point')
res_point, _ = run_simulation(net_point, n_in, n_out)
v_post_point = np.asarray(res_point)[:n_out, :]

# --- COMPARE ---
# Cosine similarity between the output population vectors at T_COMPARE_MS.
# Deflections from each cell's initial voltage are compared: raw membrane
# potentials share a large negative offset that pushes cosine similarity to ~1
# regardless of the response pattern.
time_idx = int(round(T_COMPARE_MS / DT))
vec_detailed = v_post_detailed[:, time_idx] - v_post_detailed[:, 0]
vec_point = v_post_point[:, time_idx] - v_post_point[:, 0]

norm_product = np.linalg.norm(vec_detailed) * np.linalg.norm(vec_point)
similarity = float(np.dot(vec_detailed, vec_point) / norm_product) if norm_product > 0 else float("nan")

print(f"Cosine Similarity of Output Population Vector: {similarity:.4f}")

# Plotting: one line per post-synaptic neuron (traces are rows, so transpose).
fig, (ax_detailed, ax_point) = plt.subplots(1, 2, figsize=(10, 5), sharey=True)
ax_detailed.plot(np.arange(v_post_detailed.shape[1]) * DT, v_post_detailed.T)
ax_detailed.set(
    title=f"Detailed Model (Cell {BODY_ID})",
    xlabel="Time (ms)",
    ylabel="Post-Synaptic Voltage (mV)",
)
ax_point.plot(np.arange(v_post_point.shape[1]) * DT, v_post_point.T)
ax_point.set(title="Point Model", xlabel="Time (ms)")
plt.show()