In [15]:
# Imports and configuration for the burst-analysis figures.
import importlib
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# The project packages (burst_analysis, projects.*) live under project_root.
# The path must be on sys.path BEFORE they are imported; otherwise the imports
# below only succeed when the kernel happens to start in project_root or the
# path was already inserted by an earlier run.
project_root = Path("~/bioinformatics").expanduser().resolve()
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# importlib.reload re-executes each module so edits to its source are picked
# up without restarting the kernel.
import burst_analysis.detection as bad
importlib.reload(bad)
import burst_analysis.computation as bac
importlib.reload(bac)
import burst_analysis.plotting as bat
importlib.reload(bat)
from burst_analysis.loading import SpikeDataLoader
from projects.parkinsons.coordinator import OrchestratorPDx2

data_path = Path("~/bioinformatics/data/extracted/maxtwo_newconfig1").expanduser()
orc = OrchestratorPDx2(data_path)

In [20]:
# Load every recording found under `data_path`.
# `loader.load()` returns a dictionary mapping dataset key -> SpikeData object.
loader = SpikeDataLoader(data_path)
datasets = loader.load()
print("Datasets loaded: {}".format(list(datasets.keys())))

Datasets loaded: ['M06359s1_D53_175µM_T1_1hr', 'M06359s1_D53_175µM_T1_3hr', 'M06359s1_D55_175µM_T2_24hr', 'M06359s1_D56_175µM_T2_48hr', 'M06359s1_D57_175µM_T2_72hr', 'M06359s2_D53_Control_T1_1hr', 'M06359s2_D53_Control_T1_3hr', 'M06359s2_D55_Control_T2_24hr', 'M06359s2_D56_Control_T2_48hr', 'M06359s2_D57_Control_T2_72hr', 'M06359s3_D53_175µM_T1_1hr', 'M06359s3_D53_175µM_T1_3hr', 'M06359s3_D55_175µM_T2_24hr', 'M06359s3_D56_175µM_T2_48hr', 'M06359s3_D57_175µM_T2_72hr', 'M06359s4_D53_175µM_T1_1hr', 'M06359s4_D53_175µM_T1_3hr', 'M06359s4_D55_175µM_T2_24hr', 'M06359s4_D56_175µM_T2_48hr', 'M06359s4_D57_175µM_T2_72hr', 'M06359s5_D53_175µM_T1_1hr', 'M06359s5_D53_175µM_T1_3hr', 'M06359s5_D55_175µM_T2_24hr', 'M06359s5_D56_175µM_T2_48hr', 'M06359s5_D57_175µM_T2_72hr', 'M06359s6_D53_175µM_T1_1hr', 'M06359s6_D53_175µM_T1_3hr', 'M06359s6_D55_175µM_T2_24hr', 'M06359s6_D56_175µM_T2_48hr', 'M06359s6_D57_175µM_T2_72hr', 'MO6359s1_D53_175µM_BASELINE_0hr', 'MO6359s1_D58_175µM_T2_D4', 'MO6359s1_D59_175µM_T

In [None]:
def compute_ifr_matrix(spike_trains, duration, bin_size=0.01):
    """Bin spike trains into a (time bins x units) firing-rate matrix.

    Parameters
    ----------
    spike_trains : sequence of array-like
        One array of spike times per unit, in the same time units as
        ``duration`` and ``bin_size`` (presumably seconds -- TODO confirm
        against SpikeDataLoader).
    duration : float
        Length of the recording; bins cover ``[0, duration]``, with the last
        bin rounded up to a whole ``bin_size`` when ``duration`` is not a
        multiple of it.
    bin_size : float, default 0.01
        Bin width. Must be positive.

    Returns
    -------
    time_axis : ndarray, shape (n_bins,)
        Bin centres.
    ifr : ndarray, shape (n_bins, n_units)
        Spike count per bin divided by ``bin_size``. Spikes outside the binned
        range are ignored; units with no spikes give an all-zero column.

    Raises
    ------
    ValueError
        If ``bin_size`` is not positive.
    """
    if bin_size <= 0:
        raise ValueError(f"bin_size must be positive, got {bin_size}")
    # Derive edges from an integer bin count instead of
    # np.arange(0, duration + bin_size, bin_size): a float arange can emit a
    # spurious extra edge (duration=0.3, bin_size=0.1 gives 4 bins, the last
    # lying entirely after the recording). The tiny tolerance keeps round-off
    # such as 1.1 / 0.1 == 11.000000000000002 from adding a bin.
    n_bins = max(1, int(np.ceil(duration / bin_size - 1e-9)))
    bin_edges = np.arange(n_bins + 1) * bin_size
    # np.histogram closes its last bin on the right; make sure that edge is not
    # rounded below `duration`, so a spike at exactly `duration` is counted.
    bin_edges[-1] = max(bin_edges[-1], duration)

    ifr = np.zeros((n_bins, len(spike_trains)))
    for unit_idx, train in enumerate(spike_trains):
        if len(train) > 0:
            counts, _ = np.histogram(train, bins=bin_edges)
            ifr[:, unit_idx] = counts / bin_size
    time_axis = bin_edges[:-1] + bin_size / 2
    return time_axis, ifr

def compute_cc_matrix(spike_trains, duration, bin_size=0.01):
    """Zero-lag Pearson correlation between the binned rates of every unit pair.

    Parameters
    ----------
    spike_trains, duration, bin_size
        Passed through to ``compute_ifr_matrix``.

    Returns
    -------
    ndarray, shape (n_units, n_units)
        Symmetric correlation matrix with ones on the diagonal. A unit whose
        binned rate is constant (e.g. a silent unit) has correlation 0 with
        every other unit instead of NaN. Always 2-D, also for a single unit.
    """
    _, ifr = compute_ifr_matrix(spike_trains, duration, bin_size)
    n_bins = ifr.shape[0]
    std_rates = ifr.std(axis=0)
    std_rates[std_rates == 0] = 1
    # z-score each unit with the population std (ddof=0); a constant unit
    # becomes an all-zero column.
    ifr_norm = (ifr - ifr.mean(axis=0)) / std_rates
    # For z-scored columns, z_i . z_j / n_bins is exactly Pearson's r. Computing
    # it directly avoids np.corrcoef's 0/0 = NaN for constant units, and
    # np.corrcoef's collapse to a scalar when there is only one unit.
    cc = ifr_norm.T @ ifr_norm / max(n_bins, 1)
    # Guard against round-off pushing |r| past 1 (np.corrcoef clips likewise).
    np.clip(cc, -1.0, 1.0, out=cc)
    np.fill_diagonal(cc, 1.0)
    return cc

def compute_lag_matrix(spike_trains, duration, bin_size=0.01, max_lag=0.35):
    """Lag at which each unit pair's rate cross-correlation has its largest magnitude.

    Rates come from ``compute_ifr_matrix`` and are z-scored per unit. For each
    ordered pair (i, j) the cross-correlation
    ``c(k) = sum_t z_i[t + k] * z_j[t]`` is evaluated for integer bin lags
    ``|k| <= round(max_lag / bin_size)``. The lag with the largest ``|c(k)|``
    is reported, taking the first (most negative) lag on exact ties. This is the
    ``np.correlate(z_i, z_j, mode='full')`` convention: a positive value means
    the peak is reached with unit i shifted later than unit j, i.e. i trails j.

    Parameters
    ----------
    spike_trains, duration, bin_size
        Passed through to ``compute_ifr_matrix``.
    max_lag : float, default 0.35
        Largest lag searched, in the same time units as ``bin_size``. Clamped
        to the recording length.

    Returns
    -------
    ndarray, shape (n_units, n_units)
        ``k * bin_size * 1000`` for the chosen lag k (milliseconds if times are
        in seconds). The diagonal is 0. Pairs with no defined peak also get 0
        rather than the most negative lag searched: this is the case when the
        correlation is identically zero at every lag, e.g. for a silent unit.
    """
    _, ifr = compute_ifr_matrix(spike_trains, duration, bin_size)
    n_bins, n_units = ifr.shape
    std_rates = ifr.std(axis=0)
    std_rates[std_rates == 0] = 1
    ifr_norm = (ifr - ifr.mean(axis=0)) / std_rates

    # round() rather than int(): 0.29 / 0.01 == 28.999999999999996 would
    # truncate to 28. Clamp to the available overlap; an unclamped window
    # larger than the recording produced negative slice bounds that wrapped.
    max_lag_bins = int(round(max_lag / bin_size))
    max_lag_bins = max(0, min(max_lag_bins, n_bins - 1))

    # Evaluate only the lags inside the window, for all pairs at once, instead
    # of a full O(n_bins**2) np.correlate per pair. Strict '>' keeps the first
    # lag (in ascending order) on ties, like np.argmax over the window.
    best_abs = np.full((n_units, n_units), -np.inf)
    best_lag = np.zeros((n_units, n_units), dtype=int)
    for k in range(-max_lag_bins, max_lag_bins + 1):
        if k >= 0:
            corr_k = ifr_norm[k:].T @ ifr_norm[:n_bins - k]
        else:
            corr_k = ifr_norm[:n_bins + k].T @ ifr_norm[-k:]
        abs_k = np.abs(corr_k)
        improved = abs_k > best_abs
        best_abs[improved] = abs_k[improved]
        best_lag[improved] = k
    # An all-zero correlation has no peak; report 0 instead of -max_lag.
    best_lag[best_abs == 0] = 0

    lag_matrix = best_lag * bin_size * 1000
    np.fill_diagonal(lag_matrix, 0)
    return lag_matrix

# -------------------
# Plot Functions
# -------------------

def plot_panel_A(spike_trains, burst_windows, align_to_burst_start=False):
    """Raster of each unit's spikes inside the first (black) and last (red) burst.

    Parameters
    ----------
    spike_trains : sequence of array-like
        Spike times per unit; unit ``unit_idx`` is drawn on row ``unit_idx``.
    burst_windows : sequence of (start, end) pairs
        Burst windows; only the first and last are drawn. Both window edges
        are inclusive.
    align_to_burst_start : bool, default False
        If True, plot times relative to each burst's start so the two bursts
        overlap on the x-axis. The default keeps absolute times.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``burst_windows`` is empty.
    """
    if len(burst_windows) == 0:
        raise ValueError("burst_windows is empty; need at least one (start, end) window")
    first_start, first_end = burst_windows[0][0], burst_windows[0][1]
    last_start, last_end = burst_windows[-1][0], burst_windows[-1][1]
    first_offset = first_start if align_to_burst_start else 0
    last_offset = last_start if align_to_burst_start else 0

    fig, ax = plt.subplots(figsize=(6, 3))
    for unit_idx, spikes in enumerate(spike_trains):
        # Accept plain lists too; the boolean masks below need an ndarray.
        spikes = np.asarray(spikes)
        spikes_first = spikes[(spikes >= first_start) & (spikes <= first_end)]
        spikes_last = spikes[(spikes >= last_start) & (spikes <= last_end)]
        # Label only the first unit's artists so the legend has one entry per burst.
        ax.plot(spikes_first - first_offset, np.full(spikes_first.shape, unit_idx), 'k.',
                alpha=0.5, label='First burst' if unit_idx == 0 else '_nolegend_')
        ax.plot(spikes_last - last_offset, np.full(spikes_last.shape, unit_idx), 'r.',
                alpha=0.5, label='Last burst' if unit_idx == 0 else '_nolegend_')
    ax.set_xlabel('Time from burst start' if align_to_burst_start else 'Time')
    ax.set_ylabel('Unit index')
    if len(spike_trains) > 0:
        ax.legend(loc='upper right', fontsize='small')
    ax.set_title('Panel A: Spike Times First vs Last Burst')
    return fig

def plot_panel_B(cc_matrix):
    """Plot each unit's mean correlation with all *other* units.

    Parameters
    ----------
    cc_matrix : array-like, shape (n_units, n_units)
        Square pairwise correlation matrix (e.g. from ``compute_cc_matrix``).

    Returns
    -------
    matplotlib.figure.Figure
        Line plot of the per-unit (per-column) mean. A unit with no finite
        off-diagonal entries is plotted as NaN (a gap).
    """
    cc = np.asarray(cc_matrix, dtype=float)
    n_units = cc.shape[0]
    # Exclude the diagonal (self-correlation is 1 and would inflate every mean)
    # and non-finite entries (np.corrcoef yields NaN for constant-rate units).
    valid = ~np.eye(n_units, dtype=bool) & np.isfinite(cc)
    totals = np.where(valid, cc, 0.0).sum(axis=0)
    counts = valid.sum(axis=0)
    mean_cc = np.full(n_units, np.nan)
    np.divide(totals, counts, out=mean_cc, where=counts > 0)

    fig, ax = plt.subplots(figsize=(5, 3))
    ax.plot(mean_cc)
    ax.set_xlabel('Unit index')
    ax.set_ylabel('Mean CC with other units')
    # The previous title ("Mean Firing Rate Profile") mislabelled this panel:
    # the data are correlations, not firing rates.
    ax.set_title('Panel B: Mean Cross-Correlation per Unit')
    return fig

def plot_panel_C(cc_matrix):
    """Show the pairwise correlation matrix as a grayscale heatmap.

    The color scale runs from 0 to the 99th percentile of the finite entries,
    so negative correlations render as white and a few extreme values do not
    wash out the rest.

    Parameters
    ----------
    cc_matrix : array-like, shape (n_units, n_units)

    Returns
    -------
    matplotlib.figure.Figure
    """
    cc = np.asarray(cc_matrix, dtype=float)
    fig, ax = plt.subplots(figsize=(5, 5))
    # Ignore NaN/inf when choosing vmax: np.percentile would otherwise return
    # NaN and break the color scale. Fall back to 1 for an empty matrix or a
    # non-positive percentile (vmax <= vmin is a degenerate range).
    finite = cc[np.isfinite(cc)]
    vmax = np.percentile(finite, 99) if finite.size else 1.0
    if not vmax > 0:
        vmax = 1.0
    im = ax.imshow(cc, cmap='gray_r', vmin=0, vmax=vmax)
    fig.colorbar(im, ax=ax, label='Correlation')
    ax.set_xlabel('Unit')
    ax.set_ylabel('Unit')
    ax.set_title('Panel C: Cross-corr matrix')
    return fig

def plot_panel_D(lag_matrix, max_abs_lag_ms=150):
    """Show the peak-cross-correlation lag matrix as a diverging heatmap.

    Parameters
    ----------
    lag_matrix : array-like, shape (n_units, n_units)
        Lags, presumably in ms as produced by ``compute_lag_matrix``
        (TODO confirm for other callers). Row = unit i, column = unit j.
    max_abs_lag_ms : float, default 150
        Symmetric color limit. Larger lags saturate to the end colors. Note
        that ``compute_lag_matrix``'s default window (max_lag=0.35 with
        bin_size=0.01) can yield values up to +/-350.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    im = ax.imshow(lag_matrix, cmap='bwr', vmin=-max_abs_lag_ms, vmax=max_abs_lag_ms)
    fig.colorbar(im, ax=ax, label='Lag (ms)')
    ax.set_xlabel('Unit j')
    ax.set_ylabel('Unit i')
    ax.set_title('Panel D: Lag of Max Cross-Corr')
    return fig

def plot_panel_E(cc_matrix, backbone_units):
    """Violin plots of pairwise correlations grouped by unit type.

    Each unordered pair (i < j) is classified by how many of its units are in
    ``backbone_units``:
    - both: 'BB-BB'
    - one: 'BB-NR'
    - none: 'NR-NR'

    Parameters
    ----------
    cc_matrix : array-like, shape (n_units, n_units)
    backbone_units : iterable of int
        Indices of backbone units; indices outside ``range(n_units)`` are ignored.

    Returns
    -------
    matplotlib.figure.Figure
        Non-finite correlations are dropped. An empty group leaves its x
        position blank instead of crashing ``violinplot``.
    """
    cc = np.asarray(cc_matrix, dtype=float)
    n_units = cc.shape[0]
    # Set membership is O(1); the original list lookups made the pair loop O(n^3).
    backbone = set(backbone_units)
    is_backbone = np.array([u in backbone for u in range(n_units)], dtype=bool)

    # Upper-triangle pairs in the same row-major order as a nested i<j loop.
    rows, cols = np.triu_indices(n_units, k=1)
    pair_vals = cc[rows, cols]
    n_backbone_in_pair = is_backbone[rows].astype(int) + is_backbone[cols]
    finite = np.isfinite(pair_vals)
    groups = [pair_vals[(n_backbone_in_pair == n_bb) & finite] for n_bb in (2, 1, 0)]
    labels = ['BB-BB', 'BB-NR', 'NR-NR']

    fig, ax = plt.subplots(figsize=(5, 3))
    # violinplot raises on empty datasets, so draw only the non-empty groups.
    positions = [pos for pos, group in zip((1, 2, 3), groups) if group.size]
    data = [group for group in groups if group.size]
    if data:
        ax.violinplot(data, positions=positions, showmeans=True)
    ax.set_xticks([1, 2, 3])
    ax.set_xticklabels(labels)
    ax.set_ylabel('Correlation')
    ax.set_title('Panel E: CC by Unit Type')
    return fig

def plot_panel_F(cc_matrix):
    """Histogram, on log-spaced bins, of |correlation| over unique unit pairs.

    Only the strict upper triangle is used: each pair appears once and the
    diagonal is excluded. ``abs`` folds negative correlations into the same
    histogram as positive ones. Zero and non-finite values are dropped because
    they cannot be placed on a log axis.

    Parameters
    ----------
    cc_matrix : array-like, shape (n_units, n_units)

    Returns
    -------
    matplotlib.figure.Figure
    """
    cc = np.asarray(cc_matrix, dtype=float)
    fig, ax = plt.subplots(figsize=(5, 3))
    magnitudes = np.abs(cc[np.triu_indices(cc.shape[0], k=1)])
    magnitudes = magnitudes[np.isfinite(magnitudes) & (magnitudes > 0)]
    if magnitudes.size == 0:
        # The previous single-value 1e-6 placeholder gave 50 identical bin
        # edges, which makes ax.hist raise; show an explicit note instead.
        ax.text(0.5, 0.5, 'No non-zero correlations', ha='center', va='center',
                transform=ax.transAxes)
    else:
        lo, hi = magnitudes.min(), magnitudes.max()
        if hi <= lo:
            # All values equal: widen by a decade each side so edges increase.
            lo, hi = lo / 10, hi * 10
        # geomspace pins both endpoints exactly; 10**log10(hi) from logspace
        # can round just below hi and drop the largest value.
        bins = np.geomspace(lo, hi, 50)
        ax.hist(magnitudes, bins=bins, color='gray', alpha=0.7)
        ax.set_xscale('log')
    ax.set_xlabel('|CC|')
    ax.set_ylabel('Pair count')
    ax.set_title('Panel F: Log-Hist of CCs')
    return fig


# Example usage. spike_trains, duration, burst_windows and backbone_units must
# be defined first. compute_cc_matrix is deterministic, so compute it once and
# reuse it rather than recomputing it for every panel:
# cc_matrix = compute_cc_matrix(spike_trains, duration)
# figA = plot_panel_A(spike_trains, burst_windows)
# figB = plot_panel_B(cc_matrix)
# figC = plot_panel_C(cc_matrix)
# figD = plot_panel_D(compute_lag_matrix(spike_trains, duration))
# figE = plot_panel_E(cc_matrix, backbone_units)
# figF = plot_panel_F(cc_matrix)


In [None]:
# NOTE(review): `spike_trains` and `burst_windows` are not defined in any cell
# of this notebook, so this raises NameError under Restart & Run All.
# TODO: derive them (presumably from one entry of `datasets`) in a cell above.
# Assign the figure rather than leaving it as the cell's last expression: the
# inline backend already renders figures created by plt.subplots, so a bare
# returned Figure would be drawn a second time.
fig_a = plot_panel_A(spike_trains, burst_windows)