In [1]:
%load_ext autoreload
%autoreload 2

import sys
from pathlib import Path

# Put the directory two levels above the notebook's working directory on
# sys.path — presumably so the project-local imports below (audrey, axorus,
# utils) resolve; TODO confirm the expected layout.
current_dir = Path().resolve()
project_root = current_dir.parent.parent.as_posix()
# The membership guard keeps this cell idempotent: re-running it does not
# stack duplicate sys.path entries.
if project_root not in sys.path:
    sys.path.append(project_root)

# %% General setup
import pandas as pd
from pathlib import Path
from audrey.data_io import DataIO

import utils
import numpy as np
from utils import make_figure, save_fig
from scipy.stats import wilcoxon
from axorus.preprocessing.project_colors import ProjectColors


# Load data
# Session selection and paths. `data_dir` holds `<session_id>_cells.csv` and
# the `bootstrapped/` folder read by the plotting cells; figures are written
# under `C:\audrey\Figure\<session_id>`.
session_id = '250904_A'
data_dir = Path(r'C:\audrey')
# Derived from session_id so switching sessions cannot write into another
# session's figure folder.
figure_dir = Path(r'C:\audrey\Figure') / session_id

data_io = DataIO(data_dir)
loadname = data_dir / f'{session_id}_cells.csv'
data_io.load_session(session_id, load_pickle=True, load_waveforms=False)
# NOTE(review): this re-writes the pickle on every run, even when load_session
# just read from it — confirm it is only needed after a first (h5) load.
data_io.dump_as_pickle()
# Two header rows -> MultiIndex columns, indexed later as
# (train_id, 'is_significant').
cells_df = pd.read_csv(loadname, header=[0, 1], index_col=0)

INCLUDE_RANGE = 50  # include cells at max distance = 50 um

clrs = ProjectColors()



Loading pickled data (not from h5 file)


In [2]:


# %% Detect electrode stim site with most significant responses, per cell
# For each cluster, count the trains per electrode whose `is_significant` flag
# in cells_df is True, and keep the electrode with the highest count. An
# electrode needs at least MIN_SIG_TRAINS significant trains to qualify;
# otherwise the cluster maps to None.

electrodes = data_io.burst_df.electrode.unique()

MIN_SIG_TRAINS = 2  # same as the previous `n_sig > 1` criterion

# Train ids per electrode depend only on burst_df, so compute them once rather
# than re-querying for every cluster. A boolean mask also works for
# non-numeric electrode ids, unlike an f-string query.
train_ids_per_electrode = {
    ec: data_io.burst_df.loc[data_io.burst_df.electrode == ec, 'train_id'].unique()
    for ec in electrodes
}


def _flag_is_true(value) -> bool:
    """Return True only for a genuine boolean True (Python bool or numpy.bool_).

    `value is True` is False for numpy.bool_ scalars, which `DataFrame.loc`
    returns from bool-typed columns, so an identity check silently misses
    significant trains. NaN / non-bool values count as not significant.
    """
    return isinstance(value, (bool, np.bool_)) and bool(value)


pref_ec_dict = {}

for cluster_id in data_io.cluster_df.index.values:

    pref_ec = None
    n_sig_pref_ec = 0

    for ec, tids in train_ids_per_electrode.items():
        n_sig = sum(
            _flag_is_true(cells_df.loc[cluster_id, (tid, 'is_significant')])
            for tid in tids
        )
        # Strict '>' keeps the first electrode (burst_df order) on ties.
        if n_sig >= MIN_SIG_TRAINS and n_sig > n_sig_pref_ec:
            pref_ec = ec
            n_sig_pref_ec = n_sig

    pref_ec_dict[cluster_id] = pref_ec


In [3]:
# Print available recordings in data: a header line, then one "- <name>" line
# per distinct recording_name in burst_df (order of first appearance).
recording_lines = [f"- {name}" for name in data_io.burst_df.recording_name.unique()]
print("\n".join(["Available recordings in data:", *recording_lines]))


Available recordings in data:
- 250904_A_005_noblocker_PA_prr-series
- 250904_A_006_noblocker_PA_prr-series
- 250904_A_007_noblocker_DMD_light-series20ms
- 250904_A_008_noblocker_DMD_light-series200ms
- 250904_A_009_noblocker_PADMD_light-prr-series


In [4]:

#%% Plot raster plots for each individual cell
# One figure per cell. It shows spike rasters for every stimulus strength of
# the selected recordings, restricted to the cell's preferred electrode
# (pref_ec_dict). Cells without a preferred electrode are skipped.

cluster_ids = data_io.cluster_df.index.values

# Select which recordings to plot
recording_names = [
    '250904_A_005_noblocker_PA_prr-series',
    '250904_A_006_noblocker_PA_prr-series',
    '250904_A_007_noblocker_DMD_light-series20ms',
    '250904_A_008_noblocker_DMD_light-series200ms',
    '250904_A_009_noblocker_PADMD_light-prr-series',
]

# Recording-number token -> (burst_df column with the varied stimulus strength,
# burst_df column with the burst duration or None, label used in tick text).
raster_stim_config = {
    '_005_': ('duty_cycle', 'burst_duration', 'PA'),
    '_006_': ('duty_cycle', 'burst_duration', 'PA'),
    '_007_': ('light_intensity', None, 'DMD'),
    '_008_': ('light_intensity', None, 'DMD'),
    '_009_': ('laser_duty_cycle', 'laser_burst_duration', 'PA+DMD'),
}

for cluster_id in cluster_ids:

    electrode = pref_ec_dict[cluster_id]

    if electrode is None:  # Skip cells without pref electrode
        continue

    # Only load the bootstrap pickle for cells that are actually plotted.
    cluster_data = utils.load_obj(data_dir / 'bootstrapped' / f'bootstrap_{cluster_id}.pkl')

    # Setup figure layout
    fig = utils.make_figure(width=1, height=1.5,
        x_domains={1: [[0.1, 0.9]],},
        y_domains={1: [[0.1, 0.9]]},
    )
    pos = dict(row=1, col=1)

    # Each spike becomes a vertical segment from y=burst_offset to
    # y=burst_offset+1; NaN separators break the line between segments.
    burst_offset = 0
    x_plot, y_plot = [], []
    yticks = []
    ytext = []
    bins = None

    for rec_name in recording_names:
        token = next((t for t in raster_stim_config if t in rec_name), None)
        if token is None:
            raise ValueError(f'No stimulus configuration for recording {rec_name!r}')
        stim_col, bd_col, rname = raster_stim_config[token]

        # '@electrode' refers to the Python variable. A bare 'electrode' inside
        # query() names the column, so 'electrode == electrode' matched every
        # electrode.
        # NOTE(review): DMD trains are now kept only when their electrode value
        # equals the cell's preferred electrode — confirm this is intended.
        d_select = (
            data_io.burst_df
            .query('electrode == @electrode and recording_name == @rec_name')
            .sort_values(stim_col, kind='stable')  # sort by this recording's strength column
        )

        for stim_strength in d_select[stim_col].dropna().unique():
            # First train of this recording/electrode with this stimulus strength.
            train_row = d_select[d_select[stim_col] == stim_strength].iloc[0]
            tid = train_row['train_id']
            bd = train_row[bd_col] if bd_col is not None else np.nan

            spike_times = cluster_data[tid]['spike_times']
            bins = cluster_data[tid]['bins']

            ytext.append(f'stimval: {stim_strength:.0f} bd: {bd:.0f}, {rname}')
            yticks.append(burst_offset + len(spike_times) / 2)

            for sp in spike_times:
                x_plot.append(np.vstack([sp, sp, np.full(sp.size, np.nan)]).T.flatten())
                y_plot.append(np.vstack([np.ones(sp.size) * burst_offset,
                                         np.ones(sp.size) * burst_offset + 1,
                                         np.full(sp.size, np.nan)]).T.flatten())
                burst_offset += 1

    if not x_plot:
        # np.hstack([]) would raise; nothing matched for this cell.
        print(f'no bursts to plot for {cluster_id}, skipped')
        continue

    x_plot = np.hstack(x_plot)
    y_plot = np.hstack(y_plot)

    fig.add_scatter(
        x=x_plot, y=y_plot,
        mode='lines', line=dict(color='black', width=0.5),
        showlegend=False,
        **pos,
    )

    fig.update_xaxes(
        tickvals=np.arange(-500, 500, 100),
        title_text='time [ms]',
        # bins of the last plotted train; assumed identical across trains.
        range=[bins[0]-1, bins[-1]+1],
        **pos,
    )

    fig.update_yaxes(
        tickvals=yticks,
        ticktext=ytext,
        **pos,
    )

    sname = figure_dir  / 'raster plots' / f'{cluster_id}'

    utils.save_fig(fig, sname, display=False)



saved: C:\audrey\Figure\250904_A\raster plots\uid_250904_000.png
saved: C:\audrey\Figure\250904_A\raster plots\uid_250904_001.png
saved: C:\audrey\Figure\250904_A\raster plots\uid_250904_002.png
saved: C:\audrey\Figure\250904_A\raster plots\uid_250904_003.png
saved: C:\audrey\Figure\250904_A\raster plots\uid_250904_004.png
saved: C:\audrey\Figure\250904_A\raster plots\uid_250904_005.png
saved: C:\audrey\Figure\250904_A\raster plots\uid_250904_006.png
saved: C:\audrey\Figure\250904_A\raster plots\uid_250904_007.png
saved: C:\audrey\Figure\250904_A\raster plots\uid_250904_008.png
saved: C:\audrey\Figure\250904_A\raster plots\uid_250904_009.png
saved: C:\audrey\Figure\250904_A\raster plots\uid_250904_010.png
saved: C:\audrey\Figure\250904_A\raster plots\uid_250904_011.png
saved: C:\audrey\Figure\250904_A\raster plots\uid_250904_012.png
saved: C:\audrey\Figure\250904_A\raster plots\uid_250904_013.png
saved: C:\audrey\Figure\250904_A\raster plots\uid_250904_014.png
saved: C:\audrey\Figure\2

In [36]:
# Plot single cell response curves
# 1 plot per recording with response curves per condition.
# Each cell with a preferred electrode gets a 2x3 grid (one panel per
# recording). Each panel shows, per stimulus strength, the bootstrapped firing
# rate, its CI band, and a grey band over the stimulus window [0, burst duration].

# Setup figure spacing
# NOTE(review): x_offset / x_spacing are not used in this cell; kept in case
# other cells rely on them.
x_offset = 0.05
x_spacing = 0.1


cluster_ids = data_io.cluster_df.index.values

# Select which recordings to plot
recording_names = [
    '250904_A_005_noblocker_PA_prr-series',
    '250904_A_006_noblocker_PA_prr-series',
    '250904_A_007_noblocker_DMD_light-series20ms',
    '250904_A_008_noblocker_DMD_light-series200ms',
    '250904_A_009_noblocker_PADMD_light-prr-series',
]

plot_titles = [
    'PA',
    'PA',
    'DMD - 20ms',
    'DMD - 200ms',
    'PA+DMD',
]

N_PLOT_COLS = 3  # panels per row; must match n_cols in the simple_fig call

# Recording-number token -> (burst_df column with the varied stimulus strength,
# burst_df column with the burst duration, ProjectColors color function).
rate_stim_config = {
    '_005_': ('duty_cycle', 'laser_burst_duration', clrs.duty_cycle),
    '_006_': ('duty_cycle', 'laser_burst_duration', clrs.duty_cycle),
    '_007_': ('light_intensity', 'dmd_burst_duration', clrs.dmd_light_stim),
    '_008_': ('light_intensity', 'dmd_burst_duration', clrs.dmd_light_stim),
    '_009_': ('laser_duty_cycle', 'laser_burst_duration', clrs.padmd_stim),
}


# cluster_ids = ['uid_250904_015']
for cluster_id in cluster_ids:

    # Get preferred electrode
    electrode = pref_ec_dict[cluster_id]
    if electrode is None:  # Skip cells without pref electrode
        continue

    # Setup figure layout
    fig = utils.simple_fig(width=1, height=1.5, n_cols=3, n_rows=2,
        subplot_titles={
            1: [plot_titles[0], plot_titles[1], plot_titles[2]],
            2: [plot_titles[3], plot_titles[4], '']}
    )

    # Load cluster data
    cluster_data = utils.load_obj(data_dir / 'bootstrapped' / f'bootstrap_{cluster_id}.pkl')

    for rec_i, rec_name in enumerate(recording_names):
        # Fill the grid row by row: recording rec_i -> 1-based (row, col).
        plot_row, plot_col = divmod(rec_i, N_PLOT_COLS)
        pos = dict(row=plot_row + 1, col=plot_col + 1)

        token = next((t for t in rate_stim_config if t in rec_name), None)
        if token is None:
            raise ValueError(f'No stimulus configuration for recording {rec_name!r}')
        stim_col, bd_col, color_fn = rate_stim_config[token]

        # '@electrode' refers to the Python variable. A bare 'electrode' inside
        # query() names the column, so 'electrode == electrode' matched every
        # electrode.
        # NOTE(review): DMD trains are now kept only when their electrode value
        # equals the cell's preferred electrode — confirm this is intended.
        d_select = (
            data_io.burst_df
            .query('electrode == @electrode and recording_name == @rec_name')
            .sort_values(stim_col, kind='stable')  # sort by this recording's strength column
        )

        y_max = 0
        drew_any = False

        for stim_strength in d_select[stim_col].dropna().unique():
            # First train of this recording/electrode with this stimulus strength.
            train_row = d_select[d_select[stim_col] == stim_strength].iloc[0]
            tid = train_row['train_id']
            bd = train_row[bd_col]
            clr_a = color_fn(stim_strength, 0.2)  # translucent, for the CI band
            clr = color_fn(stim_strength)

            bins = cluster_data[tid]['bins']
            sp = cluster_data[tid]['firing_rate']
            ci_low = cluster_data[tid]['firing_rate_ci_low']
            ci_high = cluster_data[tid]['firing_rate_ci_high']

            if sp is None:
                continue

            # nanmax: a NaN in the CI would make a plain `.max()` comparison
            # silently fail.
            y_max = max(y_max, np.nanmax(ci_high))
            drew_any = True

            # Grey band over the stimulus window; the y-range set below clips it.
            fig.add_scatter(x=[0, 0, bd, bd],
                            y=[0, 1000, 1000, 0],
                            line=dict(width=0),
                            mode='lines',
                            fillcolor='rgba(128, 128, 128, 0.1)',
                            fill='toself',
                            showlegend=False,
                            **pos,)

            # Lower CI bound first: the next trace fills 'tonexty' down to it.
            fig.add_scatter(
                x=bins, y=ci_low, line=dict(width=0), mode='lines',
                showlegend=False,
                **pos,
            )

            fig.add_scatter(
                x=bins,
                y=ci_high,
                line=dict(width=0),
                mode='lines',
                fill='tonexty',
                fillcolor=clr_a,  # shaded area
                showlegend=False, **pos,
            )

            fig.add_scatter(
                x=bins, y=sp, line=dict(width=1, color=clr), name='Firing rate', **pos,
                showlegend=False,
            )

        if not drew_any:
            continue

        fig.update_xaxes(
            tickvals=np.arange(-500, 500, 100),
            title_text='time [ms]',
            range=[-101, 351],
            **pos,
        )

        # Do not pass empty tickvals/ticktext here. An empty tickvals array
        # switches plotly to tickmode='array' with no ticks, which hid every
        # y tick label.
        if y_max > 0:
            fig.update_yaxes(range=[0, y_max], **pos)

    sname = figure_dir  / 'firing_rate_per_condition' / f'{cluster_id}'

    utils.save_fig(fig, sname, display=False)

    




saved: C:\audrey\Figure\250904_A\firing_rate_per_condition\uid_250904_000.png
saved: C:\audrey\Figure\250904_A\firing_rate_per_condition\uid_250904_001.png
saved: C:\audrey\Figure\250904_A\firing_rate_per_condition\uid_250904_002.png
saved: C:\audrey\Figure\250904_A\firing_rate_per_condition\uid_250904_003.png
saved: C:\audrey\Figure\250904_A\firing_rate_per_condition\uid_250904_004.png
saved: C:\audrey\Figure\250904_A\firing_rate_per_condition\uid_250904_005.png
saved: C:\audrey\Figure\250904_A\firing_rate_per_condition\uid_250904_006.png
saved: C:\audrey\Figure\250904_A\firing_rate_per_condition\uid_250904_007.png
saved: C:\audrey\Figure\250904_A\firing_rate_per_condition\uid_250904_008.png
saved: C:\audrey\Figure\250904_A\firing_rate_per_condition\uid_250904_009.png
saved: C:\audrey\Figure\250904_A\firing_rate_per_condition\uid_250904_010.png
saved: C:\audrey\Figure\250904_A\firing_rate_per_condition\uid_250904_011.png
saved: C:\audrey\Figure\250904_A\firing_rate_per_condition\uid_2

KeyboardInterrupt: 

In [30]:
# Inspect the columns available in burst_df (ids, stimulus parameters,
# recording metadata); DataFrame.keys() returns the same Index as .columns.
data_io.burst_df.keys()

Index(['rec_id', 'burst_id', 'laser_train_onset', 'laser_burst_onset',
       'laser_burst_offset', 'train_id', 'burst_period', 'burst_count',
       'burst_duration', 'duty_cycle', 'electrode', 'protocol', 'pos_xyz',
       'Recording Number', 'Retina Slice', 'Blocker', 'Blocker Admission',
       'Washout Start', 'Animal', 'Animal Birthdate', 'Implant',
       'Implant Diameter', 'Laser', 'Connected Fibers', 'Mea', 'Medium',
       'Attenuators', 'Laser level', 'Inter protocol interval',
       'inter_trial_interval', 'rec_file', 'baseline_duration',
       'dmd_baseline_duration', 'dmd_spot_size_um', 'dmd_frame_freq',
       'stim_offset', 'dmd_burst_duration', 'dmd_burst_period',
       'dmd_burst_count', 'light_intensity', 'Laser level.1',
       'inter_trial_interval.1', 'laser_burst_duration', 'laser_burst_period',
       'laser_burst_count', 'laser_duty_cycle', 'laser_prr', 'recording_name',
       'rec_train_i', 'laser_x', 'laser_y', 'stimtype', 'dmd_train_onset',
       'dmd_