In [1]:
import sys
from pathlib import Path

from utils import make_figure, save_fig

current_dir = Path().resolve()
sys.path.append(current_dir.parent.parent.as_posix())
import os

# %% General setup
import pandas as pd
from pathlib import Path
from sonogenetics.analysis.data_io import DataIO

import utils
import numpy as np
from sonogenetics.project_colors import ProjectColors
from sonogenetics.analysis.analysis_params import dataset_dir, figure_dir_analysis
from sonogenetics.preprocessing.dataset_sessions import dataset_sessions
from sonogenetics.analysis.analyse_responses import BootstrapOutput
# Load data

# exist_ok makes this cell safe to re-run and avoids the race between a
# separate existence check and the creation call.
os.makedirs(figure_dir_analysis, exist_ok=True)

data_io = DataIO(dataset_dir)
session_id = '2026-02-19 mouse c57 5713 Mekano6 A'

# From here on, figures go to a per-session subfolder. This rebinds the name
# imported from analysis_params; the subfolder itself is not created here
# (presumably utils.save_fig creates missing parent folders — TODO confirm).
figure_dir_analysis = figure_dir_analysis / session_id
print(session_id)
data_io.load_session(session_id, load_pickle=True, load_waveforms=False)
# NOTE(review): this dumps a pickle right after loading from one; presumably it
# is meant to create the pickle cache when the session was read from h5 —
# confirm it is still needed.
data_io.dump_as_pickle()

# Per-cluster response table: rows indexed by cluster id, two-level columns
# (train id, statistic), e.g. (tid, 'is_excited').
loadname = dataset_dir / f'{session_id}_cells.csv'
cells_df = pd.read_csv(loadname, header=[0, 1], index_col=0)
clrs = ProjectColors()

INCLUDE_RANGE = 50  # include cells at max distance = 50 um

# Print available recording ids
print("Available recording ids:")
for rec_id in data_io.recording_ids:
    print(f"- {rec_id}")


2026-02-19 mouse c57 5713 Mekano6 A
Loading pickled data (not from h5 file)
Available recording ids:
- rec_2_pa_dose_sequence_1
- rec_3_pa_dmd_pilot1


In [2]:
# %% Detect electrode stim site with most significant responses, per cell
# For each cluster, the preferred electrode ('ec') is the one with the most
# trains flagged is_excited or is_inhibited in cells_df. The first electrode
# in burst_df order wins ties, and 'ec' is NaN when no train is significant.
# The cluster is then labelled from its excited/inhibited train counts at
# that electrode.

# Train ids delivered on each electrode. This mapping does not depend on the
# cluster, so it is built once instead of re-querying burst_df per cluster.
train_ids_per_ec = {
    ec: data_io.burst_df.loc[data_io.burst_df.electrode == ec, 'train_id'].unique()
    for ec in data_io.burst_df.electrode.unique()
}


def _count_responses(cluster_id, train_ids):
    """Count trains with a significant response for one cluster.

    Returns (n_excited, n_inhibited, n_significant), where n_significant
    counts trains flagged excited OR inhibited (each train at most once).
    """
    n_ex = n_in = n_sig = 0
    for tid in train_ids:
        is_ex = bool(cells_df.loc[cluster_id, (tid, 'is_excited')])
        is_in = bool(cells_df.loc[cluster_id, (tid, 'is_inhibited')])
        n_ex += is_ex
        n_in += is_in
        n_sig += is_ex or is_in
    return n_ex, n_in, n_sig


def _response_label(n_ex, n_in):
    """Map excited/inhibited train counts to 'ex+in', 'ex', 'in' or 'none'."""
    # NOTE(review): the thresholds are asymmetric. 'ex' needs >1 excited
    # trains, and a lone inhibited train is ignored next to excitation, yet a
    # lone inhibited train is enough for 'in'. Kept as in the original
    # analysis — confirm this is intended (changing it alters the counts).
    if n_ex > 1:
        return 'ex+in' if n_in > 1 else 'ex'
    return 'in' if n_in > 0 else 'none'


cluster_id_list = list(data_io.cluster_ids)
records = []
for cluster_id in cluster_id_list:
    pref_ec, best = None, None
    for ec, train_ids in train_ids_per_ec.items():
        counts = _count_responses(cluster_id, train_ids)
        # Strict '>' keeps the earlier electrode on ties.
        if counts[2] > 0 and (best is None or counts[2] > best[2]):
            pref_ec, best = ec, counts
    label = 'none' if best is None else _response_label(best[0], best[1])
    records.append({'ec': pref_ec, 'label': label})

# Built in one go from records (explicit columns keep 'label' present even
# when there are no clusters).
cluster_df = pd.DataFrame(records, index=cluster_id_list, columns=['ec', 'label'])

for lbl, n_clusters in cluster_df.groupby('label').size().items():
    print(f'{lbl}: {n_clusters}')


ex: 26
ex+in: 35
in: 6
none: 5


In [3]:
# %% Raster plots of each cluster's responses at its preferred electrode
# Recordings to plot; other recordings of the session are skipped.
RASTER_RECORDINGS = ('rec_2_pa_dose_sequence_1',)

cluster_ids = data_io.cluster_df.index.values
electrodes  = data_io.burst_df.electrode.unique()
print(f'saving data in: {figure_dir_analysis / "raster plots"}')


def _raster_figure(cluster_id, ec, rec_name, cluster_data):
    """Build a spike raster of one cluster for all trains on `ec` in `rec_name`.

    Rows are grouped per (burst duration, PRR, DAC voltage) condition, with one
    row per entry of that train's spike_times. A shaded box spans 0..burst
    duration; blue/red vertical lines mark the excitation/inhibition onset
    taken from cells_df.

    Returns the figure, or None when no selected train has bootstrap data
    with at least one burst.
    """
    burst_offset = 0
    x_plot, y_plot = [], []
    x_lines, y_lines = [], []
    # Leading None keeps np.hstack below from failing when there are no onsets.
    ex_x_lines, ex_y_lines = [None], [None]
    in_x_lines, in_y_lines = [None], [None]
    yticks, ytext = [], []
    bins = None

    d_select = data_io.burst_df.query('electrode == @ec and '
                                      'recording_name == @rec_name')

    for (bd, prr, dv), df in d_select.groupby(
            ['laser_burst_duration', 'laser_pulse_repetition_rate', 'dac_voltage']):
        tids = df.train_id.unique()
        assert len(tids) == 1, f'expected one train per condition, got {tids}'
        tid = tids[0]
        if tid not in cluster_data.keys():
            continue
        trial_data = cluster_data[tid]
        spike_times = trial_data.spike_times
        bins = trial_data.bins
        row_start = burst_offset

        ytext.append(f'dv: {dv:.0f} bd: {bd:.0f}, prr: {prr:.0f}')
        yticks.append(row_start + len(spike_times) / 2)

        # Each spike becomes a vertical segment (t, row) -> (t, row + 1);
        # NaN separates segments so a single line trace draws them all.
        for sp in spike_times:
            nan_gap = np.full(sp.size, np.nan)
            x_plot.append(np.vstack([sp, sp, nan_gap]).T.flatten())
            y_plot.append(np.vstack([np.ones(sp.size) * burst_offset,
                                     np.ones(sp.size) * burst_offset + 1,
                                     nan_gap]).T.flatten())
            burst_offset += 1

        # Shaded box over this condition's rows, from 0 to the burst duration.
        x_lines.extend([0, bd, bd, 0, 0, None])
        y_lines.extend([row_start, row_start, burst_offset, burst_offset, row_start, None])

        if cells_df.loc[cluster_id, (tid, 'is_excited')]:
            lat = cells_df.loc[cluster_id, (tid, 'excitation_start')]
            ex_x_lines.extend([lat, lat, None])
            ex_y_lines.extend([row_start, burst_offset, None])

        if cells_df.loc[cluster_id, (tid, 'is_inhibited')]:
            lat = cells_df.loc[cluster_id, (tid, 'inhibition_start')]
            in_x_lines.extend([lat, lat, None])
            in_y_lines.extend([row_start, burst_offset, None])

    if not x_plot:
        return None

    fig = utils.make_figure(
        width=1,
        height=1.5,
        x_domains={1: [[0.15, 0.95]]},
        y_domains={1: [[0.1, 0.9]]},
        subplot_titles={1: [f'response type: {cluster_df.loc[cluster_id, "label"]}']},
    )
    pos = dict(row=1, col=1)

    # Stimulus shading (fill only; the outline width is effectively zero).
    fig.add_scatter(
        x=x_lines, y=y_lines,
        mode='lines', line=dict(width=0.00001, color='black'),
        fill='toself', fillcolor='rgba(0, 200, 100, 0.1)',
        showlegend=False,
        **pos,
    )
    # Excitation onset
    fig.add_scatter(x=np.hstack(ex_x_lines), y=np.hstack(ex_y_lines), mode='lines',
                    line=dict(color='blue', width=1), showlegend=False, **pos)
    # Inhibition onset
    fig.add_scatter(x=np.hstack(in_x_lines), y=np.hstack(in_y_lines), mode='lines',
                    line=dict(color='red', width=1), showlegend=False, **pos)
    # Spike raster
    fig.add_scatter(x=np.hstack(x_plot), y=np.hstack(y_plot), mode='lines',
                    line=dict(color='black', width=0.5), showlegend=False, **pos)

    fig.update_xaxes(
        tickvals=np.arange(-500, 500, 100),
        title_text='time [ms]',
        # `bins` is taken from the last plotted train.
        range=[bins[0] - 1, bins[-1] + 1],
        **pos,
    )
    fig.update_yaxes(
        range=[0, burst_offset],
        tickvals=yticks,
        ticktext=ytext,
        **pos,
    )
    return fig


# Recordings to plot, in session order.
raster_recordings = [rec for rec in data_io.recording_ids if rec in RASTER_RECORDINGS]

for cluster_id in cluster_ids:
    pref_ec = cluster_df.loc[cluster_id, 'ec']
    # Only the preferred electrode is plotted. Clusters without one (NaN/None)
    # are skipped before their bootstrap file is loaded.
    if pd.isna(pref_ec) or not raster_recordings:
        continue
    cluster_data = utils.load_obj(dataset_dir / 'bootstrapped' / f'bootstrap_{cluster_id}.pkl')
    for rec_name in raster_recordings:
        fig = _raster_figure(cluster_id, pref_ec, rec_name, cluster_data)
        if fig is None:
            continue
        sname = figure_dir_analysis / 'raster plots' / rec_name / 'significant_responses' / f'{cluster_id}'
        utils.save_fig(fig, sname, display=False, verbose=False)


saving data in: C:\sono\figures\2026-02-19 mouse c57 5713 Mekano6 A\raster plots
