In [None]:
%load_ext autoreload
%autoreload 2

# General
import pynwb
import numpy as np

# Local
from nwb_query import ContinuousData, PointProcess, TimeIntervals

# Plotting
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
# Global plot styling. Every setting goes through `plt.rcParams`:
# `mdates.rcParams` is only an incidental re-import inside matplotlib.dates
# (not public API), so it is not guaranteed to exist across matplotlib versions.
plt.rcParams.update({
    'font.size': 12,
    'lines.solid_capstyle': 'butt',
    # Format used by the automatic date formatter at microsecond resolution
    'date.autoformatter.microsecond': '%H:%M:%S.%f',
})

### Todo
- ContinuousData accessors (based on column names?)
- Make ContinuousData and PointProcess iterable by interval (return single interval and its data)
- Convenience functions: obs_durations (diff of obs_intervals)
- write occupancy function
- query spiking by epoch before doing behavior selection (requires adding epochs to NWBfile)
    - requires extracting epoch information from nspike
- behavior (position/speed) indexed by time, and not by epoch name
    - requires concatenating behavior data on import
- ? Subclass PointProcess for spiking data (include other columns, e.g. clustering metadata)
- ? Subclass ContinuousData for Behavioral data (include SI units?)

### Data import and analysis parameters

In [None]:
# Which dataset to analyze
d = {
    'anim': 'Bon',
    'day': 4,          # 1-indexed
    'epoch': 4,        # 1-indexed
    'cluster_id': 30,
}

# Analysis configuration
c = {
    'speed_threshold': 0.05,  # m/s
}

### Read NWB file

In [None]:
# File name is derived from the animal name and the zero-padded day
animday = '{}{:02d}'.format(d['anim'], d['day'])
nwb_filename = './' + animday + '_test.nwb'
print('Loading file: %s' % nwb_filename)

# Open read-only. The IO handle is deliberately kept in `io` and not closed here.
# NOTE(review): presumably later cells read datasets through it -- confirm before closing.
io = pynwb.NWBHDF5IO(nwb_filename, mode='r')
nwbf = io.read()

# Session start expressed as a POSIX timestamp (seconds)
session_start = nwbf.session_start_time
sst = session_start.timestamp()

### Dataset Query: speed of a given animal, day, and epoch
#### NWBFile (one animal), day, epoch --> ContinuousData (speed)

In [None]:
# HACKY--we should query on epoch directly rather than rebuilding the series name
speed_module_name = 'Speed d{:d} e{:d}'.format(d['day'], d['epoch'])
speed_h5py = nwbf.modules['Behavior']['Speed'][speed_module_name]

# Pull samples and timestamps out of the stored series before wrapping them
speed_samples = speed_h5py.data[()]
speed_timestamps = speed_h5py.timestamps[()]
speed = ContinuousData(data=speed_samples, timestamps=speed_timestamps)

### Dataset Query: position of a given animal, day, and epoch
#### NWBFile (one animal), day, epoch --> ContinuousData (position)

In [None]:
# HACKY--we should query on epoch directly rather than rebuilding the series name
position_module_name = 'Position d{:d} e{:d}'.format(d['day'], d['epoch'])
position_h5py = nwbf.modules['Behavior']['Position'][position_module_name]

# Pass data/timestamps by keyword, matching how the speed ContinuousData is
# built, so the call does not depend on the constructor's positional order.
position = ContinuousData(data=position_h5py.data[()],
                          timestamps=position_h5py.timestamps[()])

# Summary of everything recorded for this epoch
print('*** All position records for epoch ***')
print('# of measurements = %d' % position.data.shape[0])
print('# of intervals = %d' % len(position.obs_intervals))
print('duration of intervals = %0.2f s' % np.sum(position.obs_intervals.durations()))

### Dataset Query: spiking of a given animal and cluster
#### NWBFile (one animal), cluster --> PointProcess (spiking)

In [None]:
# Spike times and observation intervals for the selected cluster
# NOTE(review): the units table is indexed directly with the cluster id -- confirm
# that the id is the intended table index.
cluster_id = d['cluster_id']
spikes_t = nwbf.units['spike_times'][cluster_id]
obs_intervals = TimeIntervals(nwbf.units['obs_intervals'][cluster_id])
spiking = PointProcess(event_times=spikes_t, obs_intervals=obs_intervals)

# Summary of the unfiltered spike train
spiking_obs = spiking.obs_intervals
print('*** Spiking for cluster %s ***' % cluster_id)
print('# of spikes = %d' % spiking.event_times.shape[0])
print('# of intervals = %d (epochs?)' % len(spiking_obs))
print('duration of intervals = %0.2f s' % np.sum(spiking_obs.durations()))

### TODO ----->  Query: Find spiking within an epoch
#### PointProcess (spiking), TimeIntervals (epoch) --> PointProcess (spiking)

In [None]:
# epoch_time_interval = TimeIntervals(np.array([[epoch_start, epoch_end]]))
# spiking.time_query(epoch_time_interval)

### Analysis: Find time intervals where speed > threshold
#### ContinuousData (speed), lambda function --> TimeIntervals
This is an _analysis_, not a query, because we are not simply selecting a subset of a given datatype; i.e., we are not asking for a subset of the speed data, but rather for the intervals where it satisfies a lambda function. The lambda function could have been something different, like "find the times of all upward threshold crossings, and then pad each by 5 seconds on either side." Regardless of how simple or complex the lambda function is, we consider this to be an analysis. Using the output of this to select a subset of the spiking data, however, is a query.

In [None]:
def speed_threshold_fn(x):
    """Return whether x exceeds the configured speed threshold c['speed_threshold']."""
    return x > c['speed_threshold']


# Intervals selected by applying the predicate to the speed data
speed_time_intervals = speed.filter_intervals(speed_threshold_fn)

print('*** Times where speed > threshold ***')
print('# of intervals = %d' % len(speed_time_intervals))
print('duration of intervals = %0.2f s' % np.sum(speed_time_intervals.durations()))

### Query: spiking during time intervals where speed > threshold
#### PointProcess (spiking), TimeIntervals --> PointProcess (spiking)

In [None]:
# Restrict the spike train to the supra-threshold intervals using
# PointProcess's own time query method
spiking_run = spiking.time_query(speed_time_intervals)

run_obs = spiking_run.obs_intervals
print('*** Spiking where speed > threshold ***')
print('# of spikes = %d' % len(spiking_run.event_times))
print('# of intervals = %d' % len(run_obs))
print('duration of intervals = %0.2f s' % np.sum(run_obs.durations()))
print()

### Analysis: Mark animal position at the event time of each spike
#### PointProcess (spiking), ContinuousData (position[m x 2]) --> PointProcess with marks (spike times with associated positions)

In [None]:
# Attach the position data to the running-period spikes as marks
spiking_run_mark_pos = spiking_run.mark_with_ContinuousData(position)

marked_obs = spiking_run_mark_pos.obs_intervals
print('*** Spiking where speed > threshold, marked with position ***')
print('# of marked spikes = %d' % len(spiking_run_mark_pos.event_times))
print('# of intervals = %d' % len(marked_obs))
print('duration of intervals = %0.2f s' % np.sum(marked_obs.durations()))
print()

### Query: Get animal locations during running intervals
#### ContinuousData (position [m x 2]), TimeIntervals --> ContinuousData (position [m_new x 2])

In [None]:
# Restrict the position data to the supra-threshold intervals
position_run = position.time_query(speed_time_intervals)

position_run_obs = position_run.obs_intervals
print('*** Position where speed > threshold ***')
print('# of samples = %d' % position_run.data.shape[0])
print('# of intervals = %d' % len(position_run_obs))
print('duration of intervals = %0.2f s' % np.sum(position_run_obs.durations()))
print()

### Plot spikes by location

In [None]:
# One large axes with equal scaling on x and y
fig1 = plt.figure(1, figsize=(15, 15))
ax1 = fig1.add_subplot(1, 1, 1)
ax1.axis('equal')

# Backdrop: every position sample of the epoch
all_xy = position.data
ax1.plot(all_xy[:, 0], all_xy[:, 1],
         marker='', color='gray', label='Rat location', zorder=1)

# Overlay: one red line per running interval; only the first carries a legend label
run_label = 'Rat location during movement'
for ivl in position_run.obs_intervals.intervals:
    ivl_data = position.time_query(TimeIntervals(ivl)).data
    ax1.plot(ivl_data[:, 0], ivl_data[:, 1],
             marker='', color='red', label=run_label, zorder=2)
    run_label = '_'  # omit later lines from legend

# Spike locations: the marks hold the position associated with each spike
spike_xy = spiking_run_mark_pos.marks
ax1.scatter(spike_xy[:, 0], spike_xy[:, 1],
            marker='D', s=50,
            label='Location at spike times during movement', zorder=3)

ax1.legend()
ax1.set_xlabel('X position (m)')
ax1.set_ylabel('Y position (m)')

# Threshold is stored in m/s and shown in cm/s, hence the factor of 100
title_fmt = 'Spike-position map for {} d{} e{} c{}, speed > {:0.1f} cm/s'
ax1.set_title(title_fmt.format(d['anim'], d['day'], d['epoch'], d['cluster_id'],
                               (c['speed_threshold'] * 100)))
pass  # end the cell on a statement so no return value is echoed