In [None]:
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import pandas as pd

import npc_sessions

In [None]:
# Load one pilot session by its ID and show the subject's genotype.
SESSION_ID = 'DRpilot_662892_20230821'
session = npc_sessions.Session(SESSION_ID)
subject_info = session.subject.df
subject_info['genotype']

In [None]:
# Keys of the session's `_intervals` mapping (iterating a mapping yields its keys).
tuple(session._intervals)

In [None]:
# Diode-measured vs. sync square-flip timing. Each panel is titled with the filename
# prefix of a stimulus whose frame times could be computed (value not None);
# NOTE(review): assumes panels are in the same order as those stimuli — confirm.
fig, axes = session.sync_data.plot_diode_measured_sync_square_flips()
stim_display_times = npc_sessions.get_stim_frame_times(*session.stim_paths, sync=session.sync_data)
names = tuple(
    stim_path
    for stim_path, frame_times in stim_display_times.items()
    if frame_times is not None
)
for ax_idx, ax in enumerate(axes):
    stim_label = names[ax_idx].stem.split('_')[0]  # text before the first '_' in the file stem
    ax.set_title(stim_label)
fig.set_size_inches(12, 6)

In [None]:
import npc_sessions.utils as utils

# For each OptoTagging interval object, compare its `start_time` with the frame times
# indexed by the `trialOptoOnsetFrame` entries of its `_hdf5` data, and report the mean
# difference (start_time - onset frame time) per block.
# The intervals are looked up here, rather than via `optotagging_trials` (defined in a
# later cell, and a tuple rather than a single interval object), so this cell runs on
# a fresh kernel.
onset_time_offsets = {}
for interval_name, intervals in session._intervals.items():
    if 'OptoTagging' not in interval_name:
        continue
    onset_frame_times = utils.safe_index(
        intervals._frame_times,
        intervals._hdf5["trialOptoOnsetFrame"][intervals.trial_index],
    )
    onset_time_offsets[interval_name] = np.mean(intervals.start_time - onset_frame_times)
onset_time_offsets

In [None]:
# Every interval object whose key mentions 'OptoTagging', in mapping order.
optotagging_trials = tuple(
    intervals
    for key, intervals in session._intervals.items()
    if 'OptoTagging' in key
)
# inspect the trials table of the second such block
optotagging_trials[1]._df

In [None]:
# Plot the first recorded stimulus waveform of one opto-tagging block.
# `optotagging_trials` is a tuple of interval objects, so a block must be selected
# before `_stim_recordings` can be accessed.
WAVEFORM_BLOCK_IDX = 0  # which opto-tagging block to inspect
waveform = optotagging_trials[WAVEFORM_BLOCK_IDX]._stim_recordings[0].presentation.waveform
fig, ax = plt.subplots()
ax.plot(waveform.timestamps, waveform.samples)
ax.set(
    title=f'opto-tagging block {WAVEFORM_BLOCK_IDX}: first stim waveform',
    xlabel='waveform timestamps',
    ylabel='waveform samples',
)
plt.show()

In [None]:
# Units table (electrodes + spike times) for this session, keeping only rows where the
# `default_qc` column is truthy.
all_units = npc_sessions.get_units_electrodes_spike_times(session.id)
units = all_units.filter('default_qc')
units

In [None]:
# Distinct laser `location` values across all opto-tagging blocks.
# Built from `optotagging_trials` rather than `trials`, which only exists after the
# raster-plot loop further down has run (hidden state on a fresh kernel).
opto_locations = pl.concat(
    [block._df.select('location') for block in optotagging_trials]
).unique(maintain_order=True)
opto_locations

In [None]:
import bisect
import numpy as np
import numpy.typing as npt
import scipy.ndimage
import nwbwidgets.analysis.spikes as spikes
from typing import TypeVar, Sequence
T = TypeVar('T', bound=np.generic)


def extract_times(times: Sequence[T], start_time, stop_time) -> npt.NDArray[T]:
    """Return the elements of sorted ``times`` within the closed interval [start_time, stop_time].

    ``times`` must be sorted ascending; both bounds are inclusive. The matching
    slice is converted with ``np.asarray``.

    >>> extract_times([0,1,1.5,2], 1, 2)
    array([1. , 1.5, 2. ])
    """
    # first index with value >= start_time
    first_idx = bisect.bisect_left(times, start_time)
    # one past the last index with value <= stop_time; search only right of first_idx
    end_idx = bisect.bisect_right(times, stop_time, lo=first_idx)
    window = times[first_idx:end_idx]
    return np.asarray(window)


In [None]:
import polars as pl
# Mean firing rate of each unit inside one fixed time window, and its difference from
# the unit's overall `firing_rate`.
# NOTE(review): window bounds are hard-coded (s); presumably one opto pulse + 100 ms —
# TODO derive from the opto-tagging trials table.
start, stop = 948.39387, 948.59637 + 0.1
window_duration = stop - start

# spikes falling in [start, stop], divided by the window length.
# The original divided by `stop` and then subtracted `start` (operator precedence), and
# called `rolling_sum()` without a window size.
window_rate = (
    pl.element().filter(pl.element().ge(start) & pl.element().le(stop)).count()
    / window_duration
)
out = units.select(
    'unit_name',
    'firing_rate',
    pl.col('spike_times').list.eval(window_rate, parallel=True).alias('resp'),
).explode('resp')  # list.eval yields a 1-element list per unit; unwrap to a scalar
# NOTE(review): assumes `firing_rate` is in spikes/s so the subtraction is meaningful
out.with_columns((pl.col('resp') - pl.col('firing_rate')).alias('delta_resp'))

In [None]:
import matplotlib.pyplot as plt
import numpy as np
units: pl.DataFrame

# Spike rasters for each unit, aligned to the start of opto-tagging trials on its probe.
# One figure per unit: columns are opto-tagging blocks (`optotagging_trials`), rows are
# the distinct trial durations (stop - start, rounded to 4 decimals) across all blocks.
PAD_START = 0.03  # s shown before each trial's start_time
WINDOW_END = 0.6  # each panel ends ~this many s after start_time

dfs = [
    t._df.with_columns((pl.col('stop_time') - pl.col('start_time')).round(4).alias('duration'))
    for t in optotagging_trials
]
durations = pl.concat(dfs)['duration'].unique().sort().to_list()

unit_rows = (
    units.sort('peak_channel', descending=False)
    .select(
        # remove the literal 'Probe' prefix ('ProbeA' -> 'A'); str.strip('Probe') would strip
        # any of the characters P/r/o/b/e from both ends and is removed in newer polars
        pl.col('device_name').str.replace('Probe', '', literal=True),
        'peak_channel',
        'spike_times',
    )
    .iter_rows()
)

for probe_letter, peak_channel, unit_spike_times in unit_rows:
    spike_times = np.asarray(unit_spike_times)  # once per unit, not once per block
    # squeeze=False keeps `axes` 2-D even with a single block or a single duration
    fig, axes = plt.subplots(len(durations), len(dfs), figsize=(12, 6), squeeze=False)
    last_trials = None  # last plotted trial group; supplies the bregma coords in the suptitle

    for hidx, (df, haxes) in enumerate(zip(dfs, axes.T)):
        trials = df.filter(pl.col('location').str.contains(probe_letter))
        if trials.is_empty():
            continue
        block_title = f"{trials['start_time'].min():.2f} : {trials['stop_time'].max():.2f} s"

        for vidx, (ax, duration) in enumerate(zip(haxes, durations)):
            _trials = trials.filter(pl.col('duration') == duration)
            if _trials.is_empty():
                continue
            last_trials = _trials
            pad_end = WINDOW_END - duration
            onsets = _trials['start_time'].to_numpy()
            offsets = _trials['stop_time'].to_numpy()
            # per-trial index range of spikes in [onset - PAD_START, offset + pad_end]
            first = np.searchsorted(spike_times, onsets - PAD_START)
            last = np.searchsorted(spike_times, offsets + pad_end)
            ax.eventplot([spike_times[a:b] - onset for a, b, onset in zip(first, last, onsets)])

            offset_time = offsets[0] - onsets[0]
            ax.set(xmargin=0, ymargin=0, xlim=[-PAD_START, offset_time + pad_end])
            # shade the trial's start -> stop interval
            y_low, y_high = ax.get_ylim()
            ax.add_patch(plt.Rectangle((0, y_low), offset_time, y_high - y_low, color=[.8] * 3, alpha=.5))
            ax.set(xticks=[0, offset_time, offset_time + pad_end])
            plt.setp(ax.get_xticklabels(), rotation=30)
            if vidx == len(durations) - 1:  # bottom row (rows index durations, not blocks)
                ax.set(xlabel='time, s')
            else:
                ax.set(xticklabels=[])
            ax.set_title(block_title, fontsize=6)
            if hidx == 0:
                ax.set(ylabel='trials')
            else:
                ax.set(yticklabels=[])

    if last_trials is None:
        # no trial targeted this unit's probe: don't show an empty figure (and never reuse
        # trial rows left over from a previous unit in the title)
        plt.close(fig)
        continue
    fig.suptitle(
        f"Probe {probe_letter}, peak channel = {peak_channel}, laser Bregma = ({last_trials['bregma_x'][0]:.2f}, {last_trials['bregma_y'][0]:.2f})",
        fontsize=6,
    )
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # free memory when looping over many units