In [1]:
import sys
from pathlib import Path
current_dir = Path().resolve()
sys.path.append(current_dir.parent.parent.as_posix())
import os

# %% General setup
import pandas as pd
from pathlib import Path
from sonogenetics.analysis.data_io import DataIO

import utils
import numpy as np
from sonogenetics.project_colors import ProjectColors
from sonogenetics.analysis.analysis_params import dataset_dir, figure_dir_analysis
from sonogenetics.preprocessing.dataset_sessions import dataset_sessions

# Load data


if not os.path.exists(figure_dir_analysis):
    os.makedirs(figure_dir_analysis)
    
data_io = DataIO(dataset_dir)

session_id = data_io.sessions[0]

figure_dir_analysis = figure_dir_analysis / session_id
print(session_id)
data_io.load_session(data_io.sessions[0], load_pickle=True, load_waveforms=False)
data_io.dump_as_pickle()

loadname = dataset_dir / f'{session_id}_cells.csv'
cells_df = pd.read_csv(loadname, header=[0, 1], index_col=0)
clrs = ProjectColors()

INCLUDE_RANGE = 50  # include cells at max distance = 50 um

# Print available recording ids
print("Available recording ids:")
for rec_id in data_io.recording_ids:
    print(f"- {rec_id}")


2025-02-16 mouse c57 566 eMSCL A
Loading pickled data (not from h5 file)
Available recording ids:
- rec_1_26-02-16_A_pilot021626_noblocker


In [2]:
def plot_single_trial(cd, td, ci, display=True):
    fig = utils.make_figure(
        width    =1,
        height   =1.5,
        x_domains={
            1: [[0.15, 0.95]],
        },
        y_domains={
            1: [[0.1, 0.9]]
        },
    )

    # Setup variables for plotting
    burst_offset   = 0
    x_plot, y_plot = [], []
    yticks         = []
    ytext          = []
    pos            = dict(row=1, col=1)

    has_sig = False

    spike_times = cd[tid]['spike_times']
    bins = cd[tid]['bins']

    # ytext.append(f'dv: {dv:.0f} bd: {bd:.0f}, prr: {prr:.0f}')

    # yticks.append(burst_offset + len(spike_times) / 2)

    for burst_i, sp in enumerate(spike_times):
        x_plot.append(np.vstack([sp, sp, np.full(sp.size, np.nan)]).T.flatten())
        y_plot.append(np.vstack([np.ones(sp.size) * burst_offset,
                                np.ones(sp.size)* burst_offset +1, np.full(sp.size, np.nan)]).T.flatten())
        burst_offset += 1

    x_plot = np.hstack(x_plot)
    y_plot = np.hstack(y_plot)


    fig.add_scatter(
        x = x_plot, y = y_plot,
        mode = 'lines', line = dict(color='black', width=0.5),
        showlegend = False,
        **pos,
    )

    fig.update_xaxes(
        tickvals = np.arange(-500, 500, 100),
        title_text = f'time [ms]',
        range = [bins[0]-1, bins[-1]+1],
        **pos,
    )

    fig.update_yaxes(
        range=[0, burst_offset],
        tickvals = yticks,
        ticktext = ytext,
        **pos,
    )

    sname = figure_dir_analysis  / 'raster plots' / f'{ci}' / f'{ci}_{td}'
    utils.save_fig(fig, sname, display=display, verbose=False)


In [3]:
import numpy as np

def plot_firing_rate(
    spiketimes,
    ci,
    display=False,
    t_start=0.0,          # seconds
    t_end=10.0,           # seconds
    window_width_ms=50.0,
    step_size_ms=5.0,
):
    """
    Plot firing rate using sliding window.

    Parameters
    ----------
    spiketimes : 1D array
        Spike times in milliseconds.
    t_start : float
        Start time in seconds (default 0).
    t_end : float
        End time in seconds (default 10).
    window_width_ms : float
        Sliding window width in ms (default 50).
    step_size_ms : float
        Step size in ms (default 5).
    """

    spiketimes = np.asarray(spiketimes)

    # Convert time window to ms
    t_start_ms = t_start * 1000.0
    t_end_ms   = t_end * 1000.0

    # Keep spikes within requested time window
    mask = (spiketimes >= t_start_ms) & (spiketimes <= t_end_ms)
    spikes = spiketimes[mask]

    # Define sliding window centers
    window_starts = np.arange(
        t_start_ms,
        t_end_ms - window_width_ms + step_size_ms,
        step_size_ms
    )

    firing_rates = np.zeros(len(window_starts))

    # Compute firing rate (Hz)
    for i, w_start in enumerate(window_starts):
        w_end = w_start + window_width_ms
        count = np.sum((spikes >= w_start) & (spikes < w_end))
        firing_rates[i] = count / (window_width_ms / 1000.0)  # convert to Hz

    # Convert x-axis to seconds (center of window)
    x_plot = (window_starts + window_width_ms / 2.0) / 1000.0
    y_plot = firing_rates

    # ---- Plot ----
    fig = utils.make_figure(
        width=1,
        height=1.5,
        x_domains={
            1: [[0.15, 0.95]],
        },
        y_domains={
            1: [[0.1, 0.9]]
        },
    )

    pos = dict(row=1, col=1)

    fig.add_scatter(
        x=x_plot,
        y=y_plot,
        mode='lines',
        line=dict(color='black', width=0.5),
        showlegend=False,
        **pos,
    )

    sname = figure_dir_analysis / 'raster plots' / f'{ci}' / f'{ci}'
    utils.save_fig(fig, sname, display=display, verbose=False)

    return x_plot, y_plot


In [2]:
df = data_io.burst_df
print(df.shape)
print(df.burst_id[-1] * 2)

train_ids = df.train_id.unique()
# print(len(train_ids))

tid = train_ids[20]
df2 = df.loc[df.train_id == tid]
print(df2.shape)

# for c in df2.columns:
#     print(c)

onset = df2.laser_burst_onset.values - df2.laser_burst_onset.values[0]
offset = df2.laser_burst_offset.values - df2.laser_burst_onset.values
# print(offset)

cluster_id = r'uid_2026-02-11 mouse c57 565 eMSCL A_029'
cluster_data = utils.load_obj(dataset_dir / 'bootstrapped' / f'bootstrap_{cluster_id}.pkl')
ct_data = cluster_data[tid]
# print(ct_data.keys())

spiketimes = data_io.spiketimes
sp = spiketimes['rec_2_pilot_021126'][cluster_id]
plot_firing_rate(sp, cluster_id, display=True, window_width_ms=100)
# plot_single_trial(cluster_data, tid, cluster_id)


(720, 39)
1438.0
(30, 39)



Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`



KeyError: 'tid_2025-02-16 mouse c57 566 eMSCL A_020'

### Detect the electrode stimulation site with the most significant responses, per cell

In [5]:
# %% Detect electrode stim site with most significant responses, per cell
pref_ec_df = pd.DataFrame()

for cluster_id in data_io.cluster_ids:
    pref_ec = None
    n_sig_pref_ec = None

    for ec in data_io.burst_df.electrode.unique():

        df = data_io.burst_df.query(f'electrode == {float(ec)}')
        tids = df.train_id.unique()
        n_sig = 0
        for tid in tids:


            t = cells_df.loc[cluster_id, tid]['is_significant']
            if cells_df.loc[cluster_id, (tid, 'is_significant')] == True:
                n_sig += 1


        if n_sig > 1:
            if pref_ec is None or n_sig > n_sig_pref_ec:
                pref_ec = ec
                n_sig_pref_ec = n_sig
            elif n_sig == n_sig_pref_ec:
                print(f'cluster {cluster_id} has 2 pref ecs')

    # if n_sig_pref_ec is None or n_sig == 0:
    #     print(f'{cluster_id}: no sig responses')
    # else:
    #     print(f'{cluster_id}: {n_sig}')

        
    pref_ec_df.at[cluster_id, 'ec'] = pref_ec

print('Electrode stimulation site with the most significant responses per cell detected.\n\n------ End Of Cell ------')

Electrode stimulation site with the most significant responses per cell detected.

------ End Of Cell ------


### Plot the raster plots for each cell

In [6]:
##%% Plot raster plots for each individual cell, during stimulation at each electrode

cluster_ids = data_io.cluster_df.index.values

electrodes  = data_io.burst_df.electrode.unique()

print(f'saving data in: {figure_dir_analysis / "raster plots"}')

x = 0

for cluster_id in cluster_ids:

    # if '_029' not in cluster_id:
    #     continue

    cluster_data = utils.load_obj(dataset_dir / 'bootstrapped' / f'bootstrap_{cluster_id}.pkl')

    n_electrodes = electrodes.size

    for ec in electrodes:

        # Setup figure layout
        fig = utils.make_figure(
            width    =1,
            height   =1.5,
            x_domains={
                1: [[0.15, 0.95]],
            },
            y_domains={
                1: [[0.1, 0.9]]
            },
        )

        # Setup variables for plotting
        burst_offset   = 0
        x_plot, y_plot = [], []
        x_lines, y_lines = [], []
        yticks         = []
        ytext          = []
        pos            = dict(row=1, col=1)

        has_sig = False

        for rec_i, rec_name in enumerate(data_io.recording_ids):
            d_select = data_io.burst_df.query('electrode == @ec and '
                                                'recording_name == @rec_name').copy()

            for (bd, prr, dv), df in d_select.groupby(['laser_burst_duration', 'laser_pulse_repetition_rate', 'dac_voltage']):    
                train_plot_height_start = burst_offset

                tids = df.train_id.unique()
                assert len(tids) == 1
                tid = tids[0]  
                spike_times = cluster_data[tid]['spike_times']
                bins = cluster_data[tid]['bins']

                ytext.append(f'dv: {dv:.0f} bd: {bd:.0f}, prr: {prr:.0f}')

                yticks.append(burst_offset + len(spike_times) / 2)

                for burst_i, sp in enumerate(spike_times):
                    x_plot.append(np.vstack([sp, sp, np.full(sp.size, np.nan)]).T.flatten())
                    y_plot.append(np.vstack([np.ones(sp.size) * burst_offset,
                                            np.ones(sp.size)* burst_offset +1, np.full(sp.size, np.nan)]).T.flatten())
                    burst_offset += 1

                x_lines.extend([0, bd, bd, 0, 0, None])
                y_lines.extend([train_plot_height_start, train_plot_height_start,
                                burst_offset, burst_offset, train_plot_height_start, None])

        if len(x_plot) == 0:
            continue
        
        x_plot = np.hstack(x_plot)
        y_plot = np.hstack(y_plot)

        fig.add_scatter(
            x=x_lines, y=y_lines,
            mode='lines', line=dict(width=0.00001, color='black'),
            fill='toself', fillcolor='rgba(0, 200, 100, 0.1)',
            showlegend=False,
            **pos,
        )

        fig.add_scatter(
            x = x_plot, y = y_plot,
            mode = 'lines', line = dict(color='black', width=0.5),
            showlegend = False,
            **pos,
        )

        fig.update_xaxes(
            tickvals = np.arange(-500, 500, 100),
            title_text = f'time [ms]',
            range = [bins[0]-1, bins[-1]+1],
            **pos,
        )

        fig.update_yaxes(
            range=[0, burst_offset],
            tickvals = yticks,
            ticktext = ytext,
            **pos,
        )


        if ec == pref_ec_df.loc[cluster_id, 'ec']:
            sname = figure_dir_analysis  / 'raster plots' / 'significant_responses' / f'{cluster_id}' 
            utils.save_fig(fig, sname, display=False, verbose=False)
        else:
            sname = figure_dir_analysis  / 'raster plots' / 'not_significant_responses' / f'{cluster_id}_{ec}'
            utils.save_fig(fig, sname, display=False, verbose=False)



saving data in: C:\sono\figures\2025-02-16 mouse c57 566 eMSCL A\raster plots


In [10]:
##%% optimized single-panel firing rate plotting
from tqdm import tqdm 

# for cluster_id in tqdm(data_io.cluster_df.index.values, desc='Plotting firing rates'):
for cluster_id in data_io.cluster_df.index.values:
    # if '_029' not in cluster_id:
        # continue


    # Load data for this cluster
    cluster_data = utils.load_obj(
        dataset_dir / 'bootstrapped' / f'bootstrap_{cluster_id}.pkl'
    )




    # Create a *single panel* figure
    fig = utils.make_figure(
        width    = 1,
        height   = 1.5,
        x_domains={1: [[0.1, 0.4], [0.6, 0.9]],
                2: [[0.1, 0.4], [0.6, 0.9]],                   
                },
        y_domains={1: [[0.6, 0.9], [0.6, 0.9]], 
                2: [[0.1, 0.4], [0.1, 0.4]],},
        subplot_titles={1: ['burst duration 10', 'burst duration 20'], 2: ['', ''], }
    )


    # Track global y-limits
    y_max = 0

    has_shaded = []

    ec = pref_ec_df.loc[cluster_id, 'ec']
    if pd.isna(ec):
        continue

    # Loop over all recordings (all curves go on same axis)
    for rec_i, rec_name in enumerate(data_io.recording_ids):

        # Filter data
        d_select = data_io.burst_df.query(
            'electrode == @ec and recording_name == @rec_name'
        ).copy()

        for (bd, prr, dv), df in d_select.groupby(['laser_burst_duration', 'laser_pulse_repetition_rate', 'dac_voltage']):

            if bd == 10:
                col = 1
            elif bd == 20:
                col = 2
            else:
                raise ValueError(f'Unexpected burst duration: {bd}')

            if prr == 4000:
                row = 1
            elif prr == 6000:
                row = 2
            else:
                raise ValueError(f'Unexpected pulse repetition rate: {prr}')
            
            pos = dict(row=row, col=col)
        
            tid  = df.train_id.iloc[0]           

            min_laser_level = 3000
            max_laser_level = 8000

            laser_level = int((dv - min_laser_level) / (max_laser_level - min_laser_level) * 100)

            clr_a = clrs.laser_level(dv, alpha = 0.2)
            clr   = clrs.laser_level(dv, alpha = 1.0)

            # Load bootstrap FR data
            bins    = cluster_data[tid]['bins']
            sp      = cluster_data[tid]['firing_rate']
            ci_low  = cluster_data[tid]['firing_rate_ci_low']
            ci_high = cluster_data[tid]['firing_rate_ci_high']
            latency = cells_df.loc[cluster_id, (tid, 'response_latency')]
            fr      = cells_df.loc[cluster_id, (tid, 'response_firing_rate')]

            if ci_high is not None:
                y_max = max(y_max, ci_high.max())

            if pos not in has_shaded:
                xp = [0, 0, bd, bd]
                yp = [0, 1000, 1000, 0]
                fig.add_scatter(
                    x=xp, y=yp, mode='lines',
                    line=dict(color='black', width=0.001),
                    fill='toself', fillcolor='rgba(100, 100, 0, 0.1)',
                    showlegend=False,
                    **pos
                )
                has_shaded.append(pos)

            if pd.notna(latency):
                fig.add_scatter(
                    x=[latency, latency],
                    y=[0, 1000],
                    line=dict(color=clr, dash='2px,2px', width=0.5),
                    showlegend=False,
                    mode='lines',
                    **pos,
                )

            if pd.notna(fr):
                fig.add_scatter(
                    x=bins,
                    y=np.ones_like(bins) * fr,
                    mode='lines',
                    line=dict(color=clr, dash='2px,2px', width=0.5),
                    showlegend=False,
                    **pos,
                )


            # --- confidence interval ---
            fig.add_scatter(
                x=bins,
                y=ci_low,
                line=dict(width=0),
                showlegend=False,
                **pos,
            )
            fig.add_scatter(
                x=bins,
                y=ci_high,
                mode='lines',
                line=dict(width=0),
                fill='tonexty',
                fillcolor=clr_a,
                showlegend=False,
                **pos,
            )

            # --- firing-rate line ---
            show_leg = False

            fig.add_scatter(
                x=bins,
                y=sp,
                name='',
                line=dict(width=1, color=clr),
                showlegend=show_leg,
                **pos,
            )

    # Final axes formatting
    for i in range(1, 3):
        for j in range(1, 3):
            fig.update_xaxes(
                tickvals=np.arange(-500, 500, 100),
                title_text='Time (ms)',
                range=[-101, 351],
                row=i, col=j,
            )
            if y_max > 200:
                ystep = 50
            elif y_max > 100:
                ystep = 25
            else:
                ystep = 10
                
            fig.update_yaxes(
                title_text='Firing rate (Hz)',
                tickvals=np.arange(0, np.ceil(y_max/10)*10 + 1, ystep),  # nice 10-Hz spacing
                range=[0, y_max],
                row=i, col=j,
            )

    # Save figure
    sname = figure_dir_analysis / 'firing_rate_per_condition' / f'{cluster_id}_{ec}'

    utils.save_fig(fig, sname, display=False, formats=['png'], verbose=False)
    print(f'saved: {sname}')

print(f'Saved figures in {figure_dir_analysis}')



saved: C:\sono\figures\2025-02-16 mouse c57 566 eMSCL A\firing_rate_per_condition\uid_2025-02-16 mouse c57 566 eMSCL A_000_50.0
saved: C:\sono\figures\2025-02-16 mouse c57 566 eMSCL A\firing_rate_per_condition\uid_2025-02-16 mouse c57 566 eMSCL A_001_50.0
saved: C:\sono\figures\2025-02-16 mouse c57 566 eMSCL A\firing_rate_per_condition\uid_2025-02-16 mouse c57 566 eMSCL A_003_50.0
saved: C:\sono\figures\2025-02-16 mouse c57 566 eMSCL A\firing_rate_per_condition\uid_2025-02-16 mouse c57 566 eMSCL A_004_50.0
saved: C:\sono\figures\2025-02-16 mouse c57 566 eMSCL A\firing_rate_per_condition\uid_2025-02-16 mouse c57 566 eMSCL A_007_50.0
saved: C:\sono\figures\2025-02-16 mouse c57 566 eMSCL A\firing_rate_per_condition\uid_2025-02-16 mouse c57 566 eMSCL A_008_250.0
saved: C:\sono\figures\2025-02-16 mouse c57 566 eMSCL A\firing_rate_per_condition\uid_2025-02-16 mouse c57 566 eMSCL A_009_250.0
saved: C:\sono\figures\2025-02-16 mouse c57 566 eMSCL A\firing_rate_per_condition\uid_2025-02-16 mouse

In [39]:
# Plot session statistics

x_spacing = 3

bds = data_io.burst_df.laser_burst_duration.unique()
prrs = data_io.burst_df.laser_pulse_repetition_rate.unique()
dpwrs = data_io.burst_df.dac_voltage.unique()

cluster_id = [c for c in data_io.cluster_ids if '_026' in c][0]


# Create a *single panel* figure
fig = utils.make_figure(
    width    = 1,
    height   = 1.5,
    x_domains={1: [[0.1, 0.4], [0.6, 0.9]],
            2: [[0.1, 0.4], [0.6, 0.9]],                   
            },
    y_domains={1: [[0.6, 0.9], [0.6, 0.9]], 
            2: [[0.1, 0.4], [0.1, 0.4]],},
    subplot_titles={1: ['burst duration 10', 'burst duration 20'], 2: ['', ''], }
)

bd = bds[0]
prr = prrs[0]

for bd_i, bd in enumerate(bds):
    for prr_i, prr in enumerate(prrs):
        xtext = []
        xticks = []

        for dv_i, dv in enumerate(dpwrs):

            firing_rates = []

            for cluster_id in data_io.cluster_ids:
                ec = pref_ec_df.loc[cluster_id, 'ec']
                if pd.isna(ec):
                    continue
                tid = data_io.burst_df.query(f'laser_burst_duration == {bd} and laser_pulse_repetition_rate == {prr} and dac_voltage == {dv} and electrode == {ec}').train_id.unique()
                assert len(tid) == 1
                tid = tid[0]

                firing_rates.append(cells_df.loc[cluster_id, tid]['response_firing_rate'])

            firing_rates = np.array(firing_rates)

            pos = dict(row=prr_i+1, col=bd_i+1)

            min_laser_level = 3000
            max_laser_level = 8000

            laser_level = int((dv - min_laser_level) / (max_laser_level - min_laser_level) * 100)

            box_specs = dict(
                name='',
                boxpoints='all',
                marker=dict(color=clrs.laser_level(dv, alpha = 1.0), size=2),
                line=dict(color=clrs.laser_level(dv, alpha = 1.0), width=1.5),
                showlegend=False,
            )


            fig.add_box(
                x=np.ones_like(firing_rates) * (dv_i * x_spacing),
                y=firing_rates,
                **box_specs,
                **pos,
            )

            xticks.append(dv_i * x_spacing)
            xtext.append(f'{dv / 1000:.1f}')


        fig.update_xaxes(
            tickvals=xticks,
            ticktext=xtext,
            title_text='dac [V]' if pos['row'] == 2 else '',
            **pos,
        )

        fig.update_yaxes(
            tickvals=np.arange(0, 200, 50),
            title_text='FR [Hz]' if pos['col'] == 1 else '',
            range=[0, 180],
            **pos,
        )

# Save figure
sname = figure_dir_analysis /  f'firing_rates'

utils.save_fig(fig, sname, display=True, formats=['png'], verbose=False)
print(f'saved: {sname}')



displaying figure
saved: C:\sono\figures\2025-02-16 mouse c57 566 eMSCL A\firing_rates


In [38]:
data_io.burst_df.Attenuators.unique()

array([5.8], dtype=float32)