In [9]:
import sys
from pathlib import Path
from utils import make_figure, save_fig
current_dir = Path().resolve()
sys.path.append(current_dir.parent.parent.as_posix())
import os
import pandas as pd
from sonogenetics.analysis.data_io import DataIO
import numpy as np
from sonogenetics.project_colors import ProjectColors
from sonogenetics.analysis.analysis_params import dataset_dir, figure_dir_analysis

# Session to analyse
session_id = '2026-02-19 mouse c57 5713 Mekano6 A'

# All figures of this notebook go into a per-session sub-folder.
# The folder is created *after* appending the session id. Creating it before
# the reassignment only made the parent folder, never the one actually used.
figure_dir_analysis = Path(figure_dir_analysis) / session_id
figure_dir_analysis.mkdir(parents=True, exist_ok=True)

# Load data
data_io = DataIO(dataset_dir)
print(session_id)
# Loads from the raw session files (load_pickle=False), then dumps a pickle.
# NOTE(review): presumably this refreshes a cache used for faster reloads; confirm.
data_io.load_session(session_id, load_pickle=False, load_waveforms=False)
data_io.dump_as_pickle()

# Per-cell table; two header rows -> MultiIndex columns
loadname = dataset_dir / f'{session_id}_cells.csv'
cells_df = pd.read_csv(loadname, header=[0, 1], index_col=0)
clrs = ProjectColors()

# Print available recording ids
print("Available recording ids:")
for rec_id in data_io.recording_ids:
    print(f"- {rec_id}")

print('session protocols')
for pname in data_io.train_df.protocol.unique():
    print(f'-{pname}')

2026-02-19 mouse c57 5713 Mekano6 A
Available recording ids:
- rec_2_pa_dose_sequence_1
- rec_3_pa_dmd_pilot1
session protocols
-pa_dose_sequence_1
-pa_dmd_pilot1


In [10]:
# Distinct values of the burst-level DMD flag.
# NOTE: the recorded output is array(['False', 'True'], dtype=object), i.e. strings,
# not booleans. Compare against strings when filtering on this column.
data_io.burst_df['has_dmd'].unique()

array(['False', 'True'], dtype=object)

In [2]:
# Check trigger timing: for every train of the DMD protocol, plot the per-burst
# difference between laser onset and DMD burst onset, coloured by the nominal
# laser onset delay. Trains without a laser onset delay (NaN) are skipped.

# Select a protocol
df = data_io.burst_df.query('protocol == "pa_dmd_pilot1"')
print(df.laser_onset_delay.unique())

# Colour per nominal delay [ms]. Delays not listed fall back to grey.
# Otherwise the first such train would raise NameError and later ones would
# silently reuse the previous train's colour.
delay_colors = {10: 'red', 20: 'green', 30: 'blue'}

fig = make_figure(
    width=1, height=1,
    x_domains={1: [[0.1, 0.9]]},
    y_domains={1: [[0.1, 0.9]]},
)

# One legend entry per delay, not per train. Traces of the same delay share a
# legend group, so clicking the entry toggles all of them.
delays_in_legend = set()

for (delay, tid), train_bursts in df.groupby(['laser_onset_delay', 'train_id']):
    # groupby drops NaN keys by default; keep the explicit guard as a safeguard
    if pd.isna(delay):
        continue

    y = train_bursts.laser_burst_onset - train_bursts.dmd_burst_onset
    x = np.arange(y.size)  # burst index within the train

    fig.add_scatter(
        x=x, y=y,
        mode='markers+lines',
        showlegend=delay not in delays_in_legend,
        name=f'{delay:.0f}',
        legendgroup=f'{delay:.0f}',
        marker=dict(color=delay_colors.get(delay, 'grey')),
    )
    delays_in_legend.add(delay)

fig.update_xaxes(
    title_text='burst i'
)
fig.update_yaxes(
    tickvals=np.arange(10, 100, 10),
    title_text='laser onset - burst onset [ms]',
    showgrid=False,
)
save_fig(
    fig=fig, savename=figure_dir_analysis / 'misc' / 'onset_delays', display=True
)




[nan 10. 20. 30.]
saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\misc\onset_delays.png
displaying figure


In [3]:
from sonogenetics.analysis.plot_responses_single_session import plot_raster_per_protocol

# plot_raster_per_protocol(data_io=data_io)


In [4]:
import numpy as np
from scipy.ndimage import gaussian_filter1d

# Population heatmap, one figure per trial of `protocol`.
# Rows are clusters and columns are time bins relative to DMD burst onset.
# Values are the trial's burst-averaged firing rate, z-scored to its own
# pre-burst baseline.

# -----------------------------
# Parameters
# -----------------------------
protocol = 'pa_dmd_pilot1'

t_pre = 100       # [ms] first bin centre is -t_pre
t_after = 200     # [ms] bin centres stop before t_after
stepsize = 5      # [ms] spacing between bin centres
binwidth = 30     # [ms] width of each (overlapping) counting window

smooth_sigma = 3  # gaussian smoothing along time, in bins; 0 disables
baseline_t0 = -120
baseline_t1 = 0
max_z = 5         # z-scores are clipped to [-max_z, max_z]

# The bin grid and masks depend only on the parameters -> compute once
bin_centres = np.arange(-t_pre, t_after, stepsize)
baseline_mask = (bin_centres >= baseline_t0) & (bin_centres < baseline_t1)
if not np.any(baseline_mask):
    raise ValueError("Baseline window outside bin range.")
post_mask = bin_centres >= 0

# -----------------------------
# Get trials + cluster labels
# -----------------------------
trials = data_io.train_df.query('protocol == @protocol')
cluster_ids = list(data_io.cluster_ids)
# y-axis labels "<cluster number>: <channel>", identical for every trial
cluster_labels = np.array([
    f"{cid.split('_')[-1]}: {data_io.cluster_df.loc[cid, 'ch']}"
    for cid in cluster_ids
])
sort_idx = None  # row order; fixed by the first trial so all heatmaps share it

for ti, tid in enumerate(trials.index.values):
    rec_id = trials.loc[tid, 'recording_name']
    burst_onsets = data_io.burst_df.query('train_id == @tid').dmd_burst_onset.to_numpy()

    # Symmetric counting window around every (burst, bin) centre,
    # shape (n_bursts, n_bins)
    win_centres = burst_onsets[:, None] + bin_centres[None, :]
    win_start = win_centres - binwidth / 2
    win_stop = win_centres + binwidth / 2

    # -----------------------------
    # Compute population response
    # -----------------------------
    mean_counts = []
    for cid in cluster_ids:
        spikes = np.sort(np.asarray(data_io.spiketimes[rec_id][cid]))
        # #spikes in [start, stop) = #(spikes < stop) - #(spikes < start),
        # vectorised over all bursts and bins instead of a Python double loop
        counts = (np.searchsorted(spikes, win_stop, side='left')
                  - np.searchsorted(spikes, win_start, side='left'))
        mean_counts.append(counts.mean(axis=0))

    # -----------------------------
    # Convert to firing rate (Hz): mean count per window / window length [s]
    # -----------------------------
    population_fr = np.array(mean_counts) / (binwidth / 1000)

    # -----------------------------
    # Optional smoothing (along time, per cluster)
    # -----------------------------
    if smooth_sigma > 0:
        population_fr = gaussian_filter1d(population_fr, sigma=smooth_sigma, axis=1)

    # -----------------------------
    # Z-score to baseline window
    # -----------------------------
    baseline = population_fr[:, baseline_mask]

    baseline_mean = baseline.mean(axis=1, keepdims=True)
    baseline_std = baseline.std(axis=1, keepdims=True)

    baseline_std[baseline_std == 0] = 1  # avoid divide-by-zero for flat baselines

    population_z = (population_fr - baseline_mean) / baseline_std

    # Optional clipping
    population_z = np.clip(population_z, -max_z, max_z)

    # -----------------------------
    # Sort neurons by strongest modulation (abs peak after 0 ms).
    # Only the first trial defines the order, so rows line up across figures.
    # -----------------------------
    if sort_idx is None:
        peak_modulation = np.abs(population_z[:, post_mask]).max(axis=1)
        sort_idx = np.argsort(peak_modulation)[::-1]

    population_z = population_z[sort_idx]
    cluster_list = cluster_labels[sort_idx]

    # -----------------------------
    # Plot heatmap
    # -----------------------------
    fig = make_figure(
        height=2,
        y_domains={1: [[0.1, 0.99]]}
    )

    fig.add_heatmap(
        z=population_z,
        x=bin_centres,
        y=np.arange(population_z.shape[0]),
        colorscale='RdBu_r',
        zmid=0,
        zmin=-max_z,
        zmax=max_z,
        showscale=False
    )

    fig.update_yaxes(
        tickvals=np.arange(population_z.shape[0]),
        ticktext=cluster_list
    )

    fig.update_xaxes(
        tickvals=np.arange(-200, 301, 50),
        title_text='time [ms]'
    )

    sname = figure_dir_analysis / f'heatmap_{ti}_has_laser_{trials.loc[tid, "has_laser"]}'
    save_fig(fig=fig, savename=sname, display=False)

saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_0_has_laser_False.png
saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_1_has_laser_False.png
saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_2_has_laser_False.png
saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_3_has_laser_False.png
saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_4_has_laser_False.png
saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_5_has_laser_False.png
saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_6_has_laser_False.png
saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_7_has_laser_False.png
saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_8_has_laser_False.png
saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_9_has_laser_True.png
saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_10_has_laser_True.png
saved: C:\sono\figures

In [4]:
import numpy as np
from scipy.ndimage import gaussian_filter1d

# Per-cell heatmap, one figure per cluster.
# Rows are the trials of `protocol` and columns are time bins relative to DMD
# burst onset. Values are the cluster's burst-averaged firing rate, z-scored to
# each trial's own pre-burst baseline.

# -----------------------------
# Parameters
# -----------------------------
protocol = 'pa_dmd_pilot1'

t_pre = 100       # [ms] first bin centre is -t_pre
t_after = 200     # [ms] bin centres stop before t_after
stepsize = 5      # [ms] spacing between bin centres
binwidth = 30     # [ms] width of each (overlapping) counting window

smooth_sigma = 3  # gaussian smoothing along time, in bins; 0 disables
baseline_t0 = -120
baseline_t1 = 0
max_z = 5         # z-scores are clipped to [-max_z, max_z]

# The bin grid and baseline mask depend only on the parameters -> compute once
bin_centres = np.arange(-t_pre, t_after, stepsize)
baseline_mask = (bin_centres >= baseline_t0) & (bin_centres < baseline_t1)
if not np.any(baseline_mask):
    raise ValueError("Baseline window outside bin range.")

# -----------------------------
# Get trials + recording
# -----------------------------
trials = data_io.train_df.query('protocol == @protocol')

# Recording, window centres and row label of each trial do not depend on the
# cluster. Look them up once instead of once per (cluster, trial).
trial_windows = []
trial_labels = []
for ti, tid in enumerate(trials.index.values):
    burst_onsets = data_io.burst_df.query('train_id == @tid').dmd_burst_onset.to_numpy()
    # Centre of every (burst, bin) counting window, shape (n_bursts, n_bins)
    win_centres = burst_onsets[:, None] + bin_centres[None, :]
    trial_windows.append((trials.loc[tid, 'recording_name'], win_centres))
    trial_labels.append(f'{ti}: has_laser={trials.loc[tid, "has_laser"]}')

for cid in data_io.cluster_ids:
    # Mean spike count per window for every trial of this cluster
    frates = []
    for rec_id, win_centres in trial_windows:
        spikes = np.sort(np.asarray(data_io.spiketimes[rec_id][cid]))
        # #spikes in [centre - bw/2, centre + bw/2), vectorised over bursts and bins
        counts = (np.searchsorted(spikes, win_centres + binwidth / 2, side='left')
                  - np.searchsorted(spikes, win_centres - binwidth / 2, side='left'))
        frates.append(counts.mean(axis=0))

    # -----------------------------
    # Convert to firing rate (Hz): mean count per window / window length [s]
    # -----------------------------
    cell_fr = np.array(frates) / (binwidth / 1000)

    # -----------------------------
    # Optional smoothing (along time, per trial)
    # -----------------------------
    if smooth_sigma > 0:
        cell_fr = gaussian_filter1d(cell_fr, sigma=smooth_sigma, axis=1)

    # -----------------------------
    # Z-score each trial to its own baseline window
    # -----------------------------
    baseline = cell_fr[:, baseline_mask]

    baseline_mean = baseline.mean(axis=1, keepdims=True)
    baseline_std = baseline.std(axis=1, keepdims=True)

    # A flat baseline (std 0) cannot be z-scored. Divide by 1 instead, which
    # leaves a baseline-subtracted rate rather than inf/nan; this matches the
    # population heatmap above. Selecting rows by `baseline != 0` tested
    # individual elements, not the std, so a constant non-zero baseline still
    # divided by zero.
    baseline_std[baseline_std == 0] = 1
    cell_fr = (cell_fr - baseline_mean) / baseline_std

    # Optional clipping
    cell_fr = np.clip(cell_fr, -max_z, max_z)

    # -----------------------------
    # Plot heatmap
    # -----------------------------
    fig = make_figure(
        height=2,
        y_domains={1: [[0.1, 0.99]]}
    )

    fig.add_heatmap(
        z=cell_fr,
        x=bin_centres,
        y=np.arange(cell_fr.shape[0]),
        colorscale='RdBu_r',
        zmid=0,
        zmin=-max_z,
        zmax=max_z,
        showscale=False
    )

    # Label rows with trial index and laser flag so laser trials are identifiable
    fig.update_yaxes(
        tickvals=np.arange(cell_fr.shape[0]),
        ticktext=trial_labels,
    )

    fig.update_xaxes(
        tickvals=np.arange(-200, 301, 50),
        title_text='time [ms]'
    )

    sname = figure_dir_analysis / 'heatmap_per_cell' / f'heatmap_{protocol}_{cid}'
    save_fig(fig=fig, savename=sname, display=False)



saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_per_cell\heatmap_pa_dmd_pilot1_uid_2026-02-19 mouse c57 5713 Mekano6 A_000.png
saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_per_cell\heatmap_pa_dmd_pilot1_uid_2026-02-19 mouse c57 5713 Mekano6 A_001.png
saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_per_cell\heatmap_pa_dmd_pilot1_uid_2026-02-19 mouse c57 5713 Mekano6 A_002.png
saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_per_cell\heatmap_pa_dmd_pilot1_uid_2026-02-19 mouse c57 5713 Mekano6 A_003.png
saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_per_cell\heatmap_pa_dmd_pilot1_uid_2026-02-19 mouse c57 5713 Mekano6 A_004.png
saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_per_cell\heatmap_pa_dmd_pilot1_uid_2026-02-19 mouse c57 5713 Mekano6 A_005.png
saved: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\heatmap_per_cell\heatmap_pa_dmd_pilot1_uid_2026-02-19 mouse c57 5

In [8]:
# Distinct electrode values across all trains of this session
data_io.train_df['electrode'].unique()

array([238., 165.,  24.,  47.,  16.,  43.], dtype=float32)