In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

from allensdk.brain_observatory.ecephys.ecephys_session import EcephysSession
from allensdk.brain_observatory.ecephys.visualization import plot_mean_waveforms, plot_spike_counts, raster_plot
from dandi import dandiapi
from pynwb import NWBHDF5IO

%matplotlib inline

In [None]:
# Path (relative to this notebook) of the NWB file for this session.
stim_filepath = "../../../data/visual_coding/sub-699733573_ses-715093703.nwb"
# NOTE(review): stim_io is deliberately left open -- pynwb/HDF5 objects may be
# read lazily, so closing it here could break later access. Close it at the end.
stim_io = NWBHDF5IO(stim_filepath, mode="r", load_namespaces=True)
stim_file = stim_io.read()
units = stim_file.units.to_dataframe()
units.head()

In [None]:
# (number of units, number of unit attributes)
len(units.index), len(units.columns)

In [None]:
# Column names available for every unit.
units.columns

In [None]:
# Keep only the units whose sorting quality label is 'good'.
good_units = units.loc[units["quality"] == "good"]

In [None]:
# TODO -- planned analyses:
# - sort units by quality or another attribute (e.g. firing rate)
# - plot spike times for units
# - show the average waveform across units (possibly split by brain area)
# - select a single unit
# - show that unit's waveform
# - for a unit, plot firing rate over time to check for drift
# - show the unit's location

### Showing Spike Times

In [None]:
    def presentationwise_spike_counts(
        self,
        bin_edges,
        stimulus_presentation_ids,
        unit_ids,
        binarize=False,
        dtype=None,
        large_bin_size_threshold=0.001,
        time_domain_callback=None
    ):
        ''' Build an array of spike counts surrounding stimulus onset per
        unit and stimulus frame.

        Parameters
        ----------
        bin_edges : array-like
            Spikes will be counted into the bins defined by these edges.
            Values are in seconds, relative to stimulus onset.
        stimulus_presentation_ids : array-like
            Filter to these stimulus presentations
        unit_ids : array-like
            Filter to these units
        binarize : bool, optional
            If true, all counts greater than 0 will be treated as 1. This
            results in lower storage overhead, but is only reasonable if bin
            sizes are fine (<= 1 millisecond).
        dtype : optional
            Passed through unchanged to ``build_spike_histogram``
            (presumably the dtype of the counts -- TODO confirm).
        large_bin_size_threshold : float, optional
            If binarize is True and the largest bin width is greater than
            this value, a warning will be emitted.
        time_domain_callback : callable, optional
            The time domain is a numpy array whose values are trial-aligned bin
            edges (each row is aligned to a different trial). This optional
            function will be applied to the time domain before counting spikes.

        Returns
        -------
        xarray.DataArray :
            Data array named ``spike_counts`` with dimensions
            (stimulus_presentation_id, time_relative_to_stimulus_onset,
            unit_id) whose values are spike counts. The time coordinate
            holds the bin centers.

        Raises
        ------
        ValueError
            If any row of the time domain contains decreasing bin edges.

        Warns
        -----
        UserWarning
            If ``binarize`` is requested with bins wider than
            ``large_bin_size_threshold``, or if neighbouring rows of the
            time domain overlap.
        '''

        stimulus_presentations = self._filter_owned_df(
            'stimulus_presentations',
            ids=stimulus_presentation_ids)
        units = self._filter_owned_df('units', ids=unit_ids)

        # Convert once up front so every later use works on an ndarray.
        bin_edges = np.array(bin_edges)

        largest_bin_size = np.amax(np.diff(bin_edges))
        if binarize and largest_bin_size > large_bin_size_threshold:
            warnings.warn(
                'You\'ve elected to binarize spike counts, but your maximum '
                f'bin width is {largest_bin_size:2.5f} seconds. '
                'Binarizing spike counts with such a large bin width can '
                'cause significant loss of accuracy! '
                'Please consider only binarizing spike counts '
                f'when your bins are <= {large_bin_size_threshold} '
                'seconds wide.'
            )

        # Trial-aligned bin edges: one row per stimulus presentation.
        domain = build_time_window_domain(
            bin_edges,
            stimulus_presentations['start_time'].values,
            callback=time_domain_callback)

        # np.where on a 2-D mask returns a (row_indices, col_indices) tuple;
        # zip(*...) pairs them element-wise into (row, col) coordinates.
        out_of_order = np.where(np.diff(domain, axis=1) < 0)
        if len(out_of_order[0]) > 0:
            out_of_order_time_bins = [
                (int(row), int(col)) for row, col in zip(*out_of_order)
            ]
            raise ValueError("The time domain specified contains out-of-order "
                             f"bin edges at indices: {out_of_order_time_bins}")

        ends = domain[:, -1]
        starts = domain[:, 0]
        # Gap between each row's first edge and the previous row's last edge;
        # a negative gap means the two neighbouring windows overlap.
        time_diffs = starts[1:] - ends[:-1]
        overlapping = np.where(time_diffs < 0)[0]

        if len(overlapping) > 0:
            # Only directly neighbouring rows are compared; an interval that
            # overlaps several following rows is not reported in full.
            overlapping = [(int(s), int(s) + 1) for s in overlapping]
            warnings.warn("You've specified some overlapping time intervals "
                          f"between neighboring rows: {overlapping}, "
                          "with a maximum overlap of"
                          f" {np.abs(np.min(time_diffs))} seconds.")

        tiled_data = build_spike_histogram(
            domain,
            self.spike_times,
            units.index.values,
            dtype=dtype,
            binarize=binarize
        )

        stim_presentation_id = stimulus_presentations.index.values

        tiled_data = xr.DataArray(
            name='spike_counts',
            data=tiled_data,
            coords={
                'stimulus_presentation_id': stim_presentation_id,
                # Bin centers: left edge plus half the bin width.
                'time_relative_to_stimulus_onset': (bin_edges[:-1] +
                                                    np.diff(bin_edges) / 2),
                'unit_id': units.index.values
            },
            dims=['stimulus_presentation_id',
                  'time_relative_to_stimulus_onset',
                  'unit_id']
        )

        return tiled_data


In [None]:
# presentationwise_spike_counts is an allensdk EcephysSession method, so wrap
# the same NWB file that was opened with pynwb above in a session object.
# (Without this, `session` is undefined and the cell fails on a fresh kernel.)
session = EcephysSession.from_nwb_path(stim_filepath)

# We're going to build an array of spike counts surrounding stimulus presentation onset.
# Bin edges are in seconds relative to onset: from 10 ms before to 400 ms after,
# 200 edges -> 199 bins.
time_bin_edges = np.linspace(-0.01, 0.4, 200)

# Look at responses to the flash stimulus.
flash_250_ms_stimulus_presentation_ids = session.stimulus_presentations[
    session.stimulus_presentations['stimulus_name'] == 'flashes'
].index.values

# Keep only units with a signal-to-noise ratio of at least 1.5.
decent_snr_unit_ids = session.units[
    session.units['snr'] >= 1.5
].index.values

spike_counts_da = session.presentationwise_spike_counts(
    bin_edges=time_bin_edges,
    stimulus_presentation_ids=flash_250_ms_stimulus_presentation_ids,
    unit_ids=decent_snr_unit_ids
)
spike_counts_da

In [None]:
# Average the binned counts over all flash presentations, leaving a
# (time, unit) array, and render it with allensdk's plot_spike_counts
# (imported at the top). The previous version of this cell was a pasted copy
# of that function's body: it referenced undefined names and had a bare
# `return`, which is a SyntaxError outside a function.
mean_spike_counts = spike_counts_da.mean(dim='stimulus_presentation_id')

fig = plot_spike_counts(
    mean_spike_counts,
    mean_spike_counts['time_relative_to_stimulus_onset'],
    'mean spike count',
    'mean spike counts on flash presentations',
)
plt.show()

### Waveforms

In [None]:
# Index label (in `units`) of the example unit inspected in the cells below.
unit_num = 950_913_039

In [None]:
# Reuse `unit_num` rather than repeating the literal ID, so changing the
# example unit in one place updates every cell that inspects it.
units.waveform_mean[unit_num].shape

In [None]:
# One mean waveform per unit, stacked along a new leading unit axis.
waveforms = np.array(units.waveform_mean.tolist())
waveforms.shape

In [None]:
# Unweighted mean over the unit axis (axis 0 of `waveforms`).
avg_waveform = waveforms.mean(axis=0)
avg_waveform.shape

In [None]:
fig, ax = plt.subplots()
ax.plot(avg_waveform)
# Title the figure so it stands alone when the notebook is skimmed.
ax.set_title("Mean waveform averaged over all units")
plt.show()

In [None]:
# Mean waveform of the single example unit chosen by `unit_num`.
waveform = units.loc[unit_num, 'waveform_mean']
waveform.shape

In [None]:
# NOTE(review): this averages over axis 0 of the unit's waveform_mean; whether
# that axis is channels or samples depends on how the NWB file stores it --
# confirm this is the intended reduction.
unit_avg_waveform = np.mean(waveform, axis=0)

In [None]:
fig, ax = plt.subplots()
ax.plot(unit_avg_waveform)
# Name the unit in the title so the figure is identifiable on its own.
ax.set_title(f"Unit {unit_num}: mean waveform (averaged over axis 0)")
plt.show()