In [None]:
# import libraries 
import nest
import nest.voltage_trace
import numpy as np
import random
import matplotlib.pyplot as plt
import pandas as pd
import networkx as nx
import pathlib
from collections import deque
import os

# NumPy random generator created with a fixed seed (42).
# NOTE(review): no visible cell passes this `rng` to any function — confirm it is still needed.
rng = np.random.default_rng(42)

In [None]:
# Directory the CSV inputs are read from (relative to the working directory).
DATA_DIR = pathlib.Path("../Datasets/Original")

# Cell-type label -> synaptic-sign tag, looked up by annotate_graph:
#   "exc"            -> neuron is marked excitatory
#   "inh"            -> neuron is marked inhibitory
#   "mixed" / "mod"  -> neuron becomes a candidate that may be sampled as inhibitory
# Labels missing from this table are treated as "unknown" by annotate_graph.
CELLTYPE_MAP = {
    "KCs": "exc", "PNs": "exc", "PNs-somato": "exc",
    "LNs": "inh", "MB-FBNs": "inh", "MB-FFNs": "inh",
    "pre-DN-SEZs": "mixed", "pre-DN-VNCs": "mixed", "RGNs": "inh",
    "DN-VNCs": "mixed", "LHNs": "exc",
    "MBONs": "exc", "MBINs": "mod", "DN-SEZs": "mixed", "CNs": "mixed"
}

# Parameters handed to nest.Create("iaf_psc_alpha", ...).
# Written as floats on purpose: NEST is type-sensitive and its tutorial says
# double-typed properties need float literals.
# NOTE(review): units are presumably pF (C_m), ms (tau_m, t_ref) and mV
# (E_L, V_reset, V_th) per NEST conventions — confirm against the model docs.
neuron_params = {
    "C_m": 250.0,
    "tau_m": 30.0,
    "t_ref": 3.0,
    "E_L": -70.0,
    "V_reset": -65.0,
    "V_th": -55.0,
}

In [None]:
# Read the four connectivity matrices (aa, ad, da, dd) in that order;
# the first CSV column becomes the row index of each frame.
aa, ad, da, dd = (
    pd.read_csv(DATA_DIR / f"{prefix}_connectivity_matrix.csv", index_col=0)
    for prefix in ("aa", "ad", "da", "dd")
)

In [None]:
# Bundle the four matrices under their two-letter keys for build_graph / prepare_graph.
conn_mats = dict(aa=aa, ad=ad, da=da, dd=dd)

In [None]:
def load_attributes() -> pd.DataFrame:
    """Load ``s3.csv`` and ``s4.csv`` from ``DATA_DIR`` and join them on ``skid``.

    Returns:
        The left join of s3 with s4 (every s3 row is kept). Non-key columns
        present in both tables get the suffix ``_axon`` (s3 side) or
        ``_dendrite`` (s4 side).
    """
    left_table, right_table = (
        pd.read_csv(DATA_DIR / file_name) for file_name in ("s3.csv", "s4.csv")
    )
    return pd.merge(
        left_table,
        right_table,
        on="skid",
        how="left",
        suffixes=("_axon", "_dendrite"),
    )

In [None]:
def build_graph(conn_mats: dict[str, pd.DataFrame]) -> nx.DiGraph:
    """Build a directed, weighted graph from the four connectivity matrices.

    Args:
        conn_mats: Mapping with the keys ``"aa"``, ``"ad"``, ``"da"`` and
            ``"dd"``. Row labels are presynaptic ids, column labels are
            postsynaptic ids; both must be convertible with ``int``.

    Returns:
        A ``DiGraph`` with one node per row label of ``aa`` and an edge
        ``pre -> post`` wherever the element-wise sum of the four matrices is
        positive. The sum is stored in the edge attribute ``weight``.

    Raises:
        KeyError: If a matrix is missing, or if ``ad``/``da``/``dd`` lack a
            row or column label present in ``aa``.
    """
    aa = conn_mats["aa"]

    # Sum the matrices as arrays instead of doing four scalar .loc lookups per
    # (pre, post) pair. The other matrices are first selected by aa's labels so
    # that the positional sum pairs exactly the cells a label lookup would.
    total = aa.to_numpy()
    for key in ("ad", "da", "dd"):
        total = total + conn_mats[key].loc[aa.index, aa.columns].to_numpy()

    pre_ids = [int(nid) for nid in aa.index]
    post_ids = [int(nid) for nid in aa.columns]

    G = nx.DiGraph()
    G.add_nodes_from(pre_ids)

    # np.nonzero walks the matrix row by row, i.e. in the same
    # pre-major / post-minor order as a nested loop over index and columns.
    # NaN entries compare False and therefore produce no edge.
    rows, cols = np.nonzero(total > 0)
    weights = total[rows, cols].tolist()  # plain Python numbers
    G.add_weighted_edges_from(
        (pre_ids[r], post_ids[c], w)
        for r, c, w in zip(rows.tolist(), cols.tolist(), weights)
    )
    return G

In [None]:
def annotate_graph(G: nx.DiGraph, attr_df: pd.DataFrame,
                   inh_targer_perc: float = 0.20,
                   rng = None) -> None:
    """Annotate ``G`` in place with sign, input/output and depth information.

    Node attributes written: ``is_inh``, ``is_input``, ``is_output``,
    ``signal_depth``. Edge attribute written: ``signal_direction``.

    Args:
        G: Graph whose nodes are integer neuron ids.
        attr_df: Table with a ``skid`` column and a ``celltype_axon`` (or
            ``celltype``) column; cell types are translated via ``CELLTYPE_MAP``.
        inh_targer_perc: Target fraction of inhibitory neurons in the whole
            graph. (The misspelled name is kept so existing keyword calls
            keep working.)
        rng: Object with an in-place ``shuffle(list)`` method, e.g.
            ``random.Random`` or a NumPy ``Generator``. Defaults to
            ``random.Random(42)``.

    Reads ``inputs.csv`` and ``outputs.csv`` from ``DATA_DIR``.
    """
    # `is None` rather than `or`: never depend on the truthiness of an rng object.
    rng = rng if rng is not None else random.Random(42)

    # ---- inhibitory / excitatory ------------------------------------------
    for nid in G.nodes:
        G.nodes[nid]['is_inh'] = None

    known_inh, candidates = set(), []
    for _, row in attr_df.iterrows():
        nid = int(row['skid'])
        if nid not in G:
            continue
        ctype = str(row.get('celltype_axon', row.get('celltype')))
        tag = CELLTYPE_MAP.get(ctype, 'unknown')

        if tag == 'inh':
            G.nodes[nid]['is_inh'] = True
            known_inh.add(nid)
        elif tag == "exc" or tag == "unknown":
            G.nodes[nid]["is_inh"] = False
        else:  # mixed/mod
            candidates.append(nid)

    # The sampling below runs once, after every row has been classified
    # (it also runs when attr_df is empty, so no node is left as None).
    candidates = list(dict.fromkeys(candidates))  # drop repeated skids, keep order
    target_inh = int(inh_targer_perc * G.number_of_nodes())
    # Clamp at zero: a negative slice bound would select almost every candidate.
    remaining_inh = max(0, target_inh - len(known_inh))

    rng.shuffle(candidates)  # use the supplied rng so the draw is reproducible
    selected_inh = set(candidates[:remaining_inh])
    for nid in candidates:
        G.nodes[nid]['is_inh'] = nid in selected_inh

    # Nodes absent from attr_df default to excitatory.
    for nid in G.nodes:
        if G.nodes[nid]['is_inh'] is None:
            G.nodes[nid]['is_inh'] = False

    # ---- is_input / is_output ---------------------------------------------
    inp = pd.read_csv(DATA_DIR / "inputs.csv", index_col=0)
    out = pd.read_csv(DATA_DIR / "outputs.csv", index_col=0)
    for nid in G:
        # ids missing from the table count as 0; the flag needs a value above 50
        G.nodes[nid]["is_input"] = bool(inp["axon_input"].get(nid, 0) > 50)
        G.nodes[nid]["is_output"] = bool(out["axon_output"].get(nid, 0) > 50)

    # ---- signal depth: multi-source BFS from all input neurons ------------
    depth = {n: -1 for n in G.nodes}  # -1 = not reachable from any input
    q = deque(n for n in G.nodes if G.nodes[n]['is_input'])
    for n in q:
        depth[n] = 0
    while q:
        u = q.popleft()
        for v in G.successors(u):
            if depth[v] == -1:
                depth[v] = depth[u] + 1
                q.append(v)
    for n in G.nodes:
        G.nodes[n]['signal_depth'] = depth[n]

    # ---- direction label per edge -----------------------------------------
    def edge_dir(u, v):
        """Classify edge u->v by comparing the depths of its endpoints."""
        du, dv = depth[u], depth[v]
        if du < 0 or dv < 0:
            return 'unknown'
        if du < dv:
            return 'feedforward'
        if du > dv:
            return 'feedback'
        return 'lateral'

    for u, v, d in G.edges(data=True):
        d['signal_direction'] = edge_dir(u, v)

In [None]:
def create_nest_network(G: nx.DiGraph,
                        neuron_params: dict,
                        BASE_W: float = 1,
                        DELAY: float = 1.5,) -> tuple[list[int], dict[int, int]]:
    """Reset NEST and instantiate ``G`` as a population of ``iaf_psc_alpha`` neurons.

    Args:
        G: Graph with a ``weight`` attribute on every edge and an ``is_inh``
            attribute on every node.
        neuron_params: Parameters passed to ``nest.Create``.
        BASE_W: Factor applied to every edge weight.
        DELAY: Synaptic delay used for every connection.

    Returns:
        ``(neurons, node_index)``: the object returned by ``nest.Create`` and
        a dict mapping each node id to its position in that population
        (positions follow the sorted node ids).
    """
    nest.ResetKernel()
    neurons = nest.Create("iaf_psc_alpha", G.number_of_nodes(), params=neuron_params)
    node_index = {nid: i for i, nid in enumerate(sorted(G.nodes()))}

    # NEST is type-sensitive about double-valued properties, and edge weights
    # may be ints or numpy integers, so both values are cast to float.
    delay = float(DELAY)
    for u, v, d in G.edges(data=True):
        sign = -1.0 if G.nodes[u]["is_inh"] else 1.0  # outgoing synapses of inhibitory cells are negative
        src, dst = node_index[u], node_index[v]
        nest.Connect(
            neurons[src:src + 1],
            neurons[dst:dst + 1],
            syn_spec={"weight": float(BASE_W * d["weight"] * sign), "delay": delay}
        )
    return neurons, node_index

In [None]:
def connect_spike_recorder(neurons):
    """Create a NEST ``spike_recorder``, connect ``neurons`` to it and return it."""
    recorder = nest.Create("spike_recorder")
    nest.Connect(neurons, recorder)
    return recorder

In [None]:
def connect_dc_generator(
        stim_ids: list, neurons,
        node_index, stim_amp: float, 
        start: float, stop: float):
    """Create one ``dc_generator`` and connect it to every neuron in ``stim_ids``.

    Args:
        stim_ids: Node ids to stimulate; each must be a key of ``node_index``.
        neurons: Population returned by ``nest.Create``; sliced by position.
        node_index: Mapping node id -> position in ``neurons``.
        stim_amp: Generator ``amplitude``.
        start: Generator ``start`` time.
        stop: Generator ``stop`` time.

    Raises:
        KeyError: If an id in ``stim_ids`` is not in ``node_index``.
    """
    # Callers pass plain ints (e.g. 600, 50, 500); NEST is type-sensitive about
    # double-valued properties, so the values are cast to float.
    dc = nest.Create(
        "dc_generator",
        params={"amplitude": float(stim_amp), "start": float(start), "stop": float(stop)},
    )
    for nid in stim_ids:
        idx = node_index[nid]
        nest.Connect(dc, neurons[idx:idx + 1])

In [None]:
def get_input_neuron_ids(G: nx.DiGraph) -> list[int]:
    """Return the ids of all nodes whose ``is_input`` attribute is truthy.

    Nodes without the attribute are treated as non-input.
    """
    input_ids = []
    for node_id, attrs in G.nodes(data=True):
        if attrs.get("is_input", False):
            input_ids.append(node_id)
    return input_ids

In [None]:
def prepare_graph(conn_mats):
    """Build the connectivity graph and annotate it.

    Args:
        conn_mats: Mapping of connectivity matrices accepted by ``build_graph``.

    Returns:
        The graph after ``annotate_graph`` has been applied to it.
    """
    print('building graph...')
    G = build_graph(conn_mats)
    # Report both sizes; the old message printed the node count twice.
    print(f'Graph: {G.number_of_nodes()} neurons, '
          f'with {G.number_of_edges()} edges')
    attr = load_attributes()
    print('annotating graph...')
    annotate_graph(G=G, attr_df=attr)
    return G

In [None]:
def run_experiment(
        G: nx.DiGraph,
        stim_start: float,
        stim_stop: float,
        tail: float,
        generator_amp: float = 600,
        stim_ids: list = None,
        BASE_W: float = 1,
        DELAY: float = 2.5,):
    """Build the NEST network for ``G``, stimulate it with a DC current and simulate.

    Args:
        G: Annotated graph (see ``create_nest_network`` for what it must carry).
        stim_start: Time at which the DC generator switches on.
        stim_stop: Time at which the DC generator switches off.
        tail: Extra time simulated after ``stim_stop``.
        generator_amp: Amplitude of the DC generator.
        stim_ids: Node ids to stimulate; ``None`` selects the graph's input neurons.
        BASE_W: Weight factor forwarded to ``create_nest_network``.
        DELAY: Synaptic delay forwarded to ``create_nest_network``.

    Returns:
        ``(events, times)``: the spike recorder's ``events`` dict and its
        ``"times"`` entry.
    """
    total_duration = stim_stop + tail

    print('creating nest network...')
    neurons, node_index = create_nest_network(
        G=G,
        neuron_params=neuron_params,
        BASE_W=BASE_W,
        DELAY=DELAY,
    )
    spike_recorder = connect_spike_recorder(neurons)

    # Fall back to the graph's input neurons when no explicit targets are given.
    targets = get_input_neuron_ids(G) if stim_ids is None else stim_ids

    print(f"Stimulation of {len(targets)} neurons ")
    connect_dc_generator(
        stim_ids=targets,
        neurons=neurons,
        node_index=node_index,
        stim_amp=generator_amp,
        start=stim_start,
        stop=stim_stop,
    )

    nest.Simulate(total_duration)

    events = nest.GetStatus(spike_recorder, "events")[0]
    return events, events["times"]

In [None]:
# Stimulation window (start, stop) and the extra time simulated after the stimulus ends.
stim_start, stim_stop, tail = 50, 500, 1000

G = prepare_graph(conn_mats)
events, times = run_experiment(
    G=G,
    stim_start=stim_start,
    stim_stop=stim_stop,
    tail=tail,
    generator_amp=600,
    BASE_W=4.6,
    DELAY=2.5,
)

## REPORTS

In [None]:
def echo_duration(times: np.ndarray, sim_time: float,
                  quiet_ms: float = 20.0,
                  stim_stop: float = None) -> float:
    """Return how long spiking continues after the stimulus ends.

    Spikes later than ``stim_stop`` are scanned in time order. The echo ends
    at the first spike that is followed by at least ``quiet_ms`` without
    another spike; the last spike always qualifies as long as it is not later
    than ``sim_time``.

    Args:
        times: Spike times (any order).
        sim_time: Total simulated time.
        quiet_ms: Minimum silent interval that terminates the echo.
        stim_stop: Time at which the stimulus ended. When omitted, the
            notebook-level variable ``stim_stop`` is used, as older calls
            relied on it.

    Returns:
        ``echo_end - stim_stop``, or ``0.0`` when no spike occurs after
        ``stim_stop``.

    Raises:
        NameError: If ``stim_stop`` is neither passed nor defined globally.
    """
    if stim_stop is None:
        try:
            stim_stop = globals()["stim_stop"]
        except KeyError:
            raise NameError(
                "name 'stim_stop' is not defined; pass stim_stop explicitly"
            ) from None

    spikes = np.sort(np.asarray(times, dtype=float))
    post = spikes[spikes > stim_stop]
    if post.size == 0:
        return 0.0

    # Sentinel placed quiet_ms after the end of the simulation, so the gap
    # following the last spike is at least quiet_ms and argmax finds a True.
    gaps = np.diff(np.append(post, quiet_ms + sim_time))
    end_idx = int(np.argmax(gaps >= quiet_ms))
    return float(post[end_idx] - stim_stop)

In [None]:
# Total simulated time: end of stimulation plus the same `tail` that was passed
# to run_experiment (previously a separate literal 1000 that could drift).
sim_time = stim_stop + tail
echo_duration(times, sim_time=sim_time)

In [None]:
def spike_matrix(neurons, events: dict,
                 sim_time: float, bin_ms: float = 1.0):
    """Bin recorded spikes into a (neuron x time-bin) count matrix.

    Args:
        neurons: Indexable population; ``neurons[i].global_id`` must give the
            id under which element ``i`` appears in ``events["senders"]``.
        events: Dict with equally long ``"senders"`` and ``"times"`` arrays.
        sim_time: Total simulated time; sets the number of bins.
        bin_ms: Bin width, in the same unit as ``times``.

    Returns:
        ``(X, gid2row)``: ``X`` is an int32 array of shape
        ``(len(neurons), ceil(sim_time / bin_ms))`` with spike counts, and
        ``gid2row`` maps each global id to its row in ``X``.

    Raises:
        KeyError: If a sender id does not belong to ``neurons``.
        IndexError: If a spike falls beyond the last bin.
    """
    senders = np.asarray(events["senders"])
    times = np.asarray(events["times"])

    n_neurons = len(neurons)
    n_bins = int(np.ceil(sim_time / bin_ms))
    X = np.zeros((n_neurons, n_bins), dtype=np.int32)

    gid2row = {neurons[i].global_id: i for i in range(n_neurons)}

    bin_idx = (times // bin_ms).astype(int)
    # A spike exactly at sim_time would land one past the end; fold it into the last bin.
    bin_idx[bin_idx == n_bins] = n_bins - 1

    rows = np.fromiter((gid2row[gid] for gid in senders),
                       dtype=np.intp, count=senders.size)
    # Unbuffered add: repeated (row, bin) pairs are each counted, unlike X[rows, bins] += 1.
    np.add.at(X, (rows, bin_idx), 1)

    return X, gid2row

In [None]:
def save_activity_matrix(events: dict, neurons,
                         sim_time: float, bin_ms: float = 1.0,
                         fname: str = "activity_matrix.csv") -> np.ndarray:
    """Bin the spikes with ``spike_matrix`` and write the counts to a CSV file.

    The file has one row per time bin (index column ``time_bin``) and one
    column per neuron, named ``neuron_<row>`` after the neuron's row in the
    count matrix.

    Returns:
        The (neuron x time-bin) count matrix.
    """
    counts, _ = spike_matrix(neurons, events, sim_time, bin_ms)

    headers = [f"neuron_{row}" for row in range(len(counts))]
    pd.DataFrame(counts.T, columns=headers).to_csv(fname, index_label="time_bin")

    print(f"{fname} saved")
    return counts

In [None]:
def save_adj_matrix(G: nx.DiGraph,
                    node_index,
                    BASE_W: float = 1.0,
                    fname: str = "adj.csv"):
    """Write the signed, scaled adjacency matrix of ``G`` to a CSV file.

    Entry ``[pre, post]`` is ``BASE_W * weight`` of the edge, negated when the
    presynaptic node has a truthy ``is_inh`` attribute. Rows and columns are
    labelled with the keys of ``node_index`` in their iteration order.
    """
    size = len(node_index)
    matrix = np.zeros((size, size), dtype=np.float32)

    for pre, post, attrs in G.edges(data=True):
        polarity = -1 if G.nodes[pre]["is_inh"] else 1
        matrix[node_index[pre], node_index[post]] = BASE_W * polarity * attrs["weight"]

    labels = list(node_index)
    frame = pd.DataFrame(matrix, index=labels, columns=labels)
    frame.to_csv(fname)
    print(f"{fname} saved")

In [None]:
def report_most_active_neuron(X, node_list: list[int], G: nx.DiGraph):
    """Print a short profile of the neuron with the highest total spike count.

    Args:
        X: (neuron x time-bin) spike-count matrix.
        node_list: Node id for every row of ``X``, in row order.
        G: Annotated graph the node ids refer to.
    """
    total_spikes = X.sum(axis=1)
    idx_max = int(np.argmax(total_spikes))
    nid_max = node_list[idx_max]

    print("─" * 40)
    print(f"Most active neuron: {nid_max}")
    print(f"Total spikes = {total_spikes[idx_max]}")
    print(f"Signal depth = {G.nodes[nid_max].get('signal_depth')}")
    print(f"In-degree and out-deg = {G.in_degree(nid_max)} / {G.out_degree(nid_max)}")
    print(f"Is inhibitory = {G.nodes[nid_max].get('is_inh')}")
    print(f"Is Input-class neuron = {G.nodes[nid_max].get('is_input')}")

    # nx.has_path(G, n, n) is trivially True (zero-length path), so test for a
    # path from any successor back to the node. A self-loop makes the node its
    # own successor and is therefore reported as well.
    on_cycle = any(nx.has_path(G, succ, nid_max) for succ in G.successors(nid_max))
    if on_cycle:
        print("Self-loop or cycle detected")

In [None]:
def save_activity_by_depth(X, G: nx.DiGraph, node_index, fname: str = "activity_by_depth.csv"):
    """Aggregate spike counts per ``signal_depth`` level and write them to CSV.

    One row is written per depth level, in ascending order, with the columns
    ``signal_depth``, ``total_spikes``, ``mean_spikes`` (total divided by the
    number of neurons at that level) and ``n_neurons``.
    """
    depth_of = nx.get_node_attributes(G, "signal_depth")

    # Collect the matrix rows belonging to each depth level.
    rows_by_level = {}
    for node, level in depth_of.items():
        rows_by_level.setdefault(level, []).append(node_index[node])

    summary = []
    for level in sorted(rows_by_level):
        members = rows_by_level[level]
        spikes = X[members, :].sum()
        summary.append({
            "signal_depth": level,
            "total_spikes": spikes,
            "mean_spikes": spikes / len(members),
            "n_neurons": len(members),
        })

    pd.DataFrame(summary).to_csv(fname, index=False)
    print(f"{fname} saved")

## PREVIOUS CODE

In [None]:
# Connect neurons from the connectivity matrices: for every (pre, post) pair,
# each matrix (aa, ad, da, dd — in that order) with a positive entry adds its
# own synapse, weighted 0.5 * entry with a fixed delay of 1.5.
# NOTE(review): `node_index` and `neurons` are not assigned at notebook scope
# in the cells shown — confirm where they come from before running this cell.
for pre_id in aa.index:
    for post_id in aa.columns:
        # positions of both cells in the NEST population
        pre_idx = node_index[int(pre_id)]
        post_idx = node_index[int(post_id)]

        # single-element slices of the population
        pre_neuron = neurons[pre_idx:pre_idx + 1]
        post_neuron = neurons[post_idx:post_idx + 1]

        for matrix in (aa, ad, da, dd):
            raw_weight = matrix.loc[pre_id, post_id]
            if raw_weight > 0:
                nest.Connect(
                    pre_neuron,
                    post_neuron,
                    syn_spec={"weight": 0.5 * raw_weight, "delay": 1.5}
                )

In [None]:
# Create a spike recorder and connect every neuron to it
# (same two NEST calls, via the helper defined earlier in this notebook).
spike_recorder = connect_spike_recorder(neurons)

In [None]:
# select random neurons for stimulation
# NOTE(review): `random` is already imported in the first cell; this re-import is redundant.
import random 

# Draw 10 distinct positions out of range(N_neurons) and pick those elements of `neurons`.
# NOTE(review): `N_neurons` and `neurons` are not assigned at notebook scope in the
# cells shown, and no seed is set for the `random` module — confirm before running.
num_neurons_to_stimulate = 10
stimulated_indices = random.sample(range(N_neurons), num_neurons_to_stimulate)
stimulated_neurons = [neurons[i] for i in stimulated_indices]
# Bare expression: shown as the cell's output.
stimulated_neurons

In [None]:
# Set up a multimeter that records V_m (interval 0.1) and connect it to the
# first of the stimulated neurons.
multimeter = nest.Create("multimeter", params={"record_from": ["V_m"], "interval": 0.1})
nest.Connect(multimeter, stimulated_neurons[0][0:1])

In [None]:
# Create a DC generator for external stimulation: amplitude 400.0, switched on
# at 50.0 and off at 150.0.
dc = nest.Create("dc_generator", params={"amplitude": 400.0, "start": 50.0, "stop": 150.0})

In [None]:
# Attach the DC generator to every selected neuron, one connection per neuron.
for target in stimulated_neurons:
    nest.Connect(dc, target[0:1])

In [None]:
# Run the simulation for the specified time.
# NOTE(review): this reassigns `sim_time`, which an earlier cell set to
# stim_stop + 1000 — later cells see 200.0 after this one runs.
sim_time = 200.0
nest.Simulate(sim_time)

In [None]:
# Fetch the recorder's "events" entry and print it in full.
spikes = nest.GetStatus(spike_recorder, "events")[0]
print(spikes)

In [None]:
# Print whatever nest.Models() returns.
# NOTE(review): nest.Models() may not exist in recent NEST releases — confirm
# against the installed version.
print(nest.Models())

In [None]:
# Print the V_m values recorded by the multimeter, one per line.
# NOTE(review): `multimeter` is only created in the legacy cell above — this
# cell depends on that cell having run.
dmm = nest.GetStatus(multimeter)[0]
Vms = dmm["events"]["V_m"]
for vm in Vms:
    print(vm)