In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pynwb
# Global matplotlib style: 12 pt text, and square-cut ("butt") solid line ends so the
# thick interval bars drawn later end exactly at their endpoints instead of overshooting.
plt.rcParams.update({
    'font.size': 12,
    'lines.solid_capstyle': 'butt',
})


### Helper functions

In [None]:
def times_in_intervals(times, intervals):
    """Return True if every time in `times` lies inside at least one interval.

    Parameters
    ----------
    times : array-like, shape (n,)
        Event times (e.g. spike times). Not modified -- a sorted copy is used.
    intervals : array-like, shape (k, 2)
        Rows of [start, end]; both bounds are inclusive. Need not be sorted and
        may overlap.

    Returns
    -------
    bool
        True when `times` is empty or every time is covered; False otherwise
        (including when `times` is non-empty and `intervals` is empty).
    """
    # np.sort returns a copy, so the caller's array (e.g. spikes_t) keeps its order.
    times = np.sort(np.asarray(times))
    intervals = np.asarray(intervals)
    # Sort intervals by start time
    intervals = intervals[intervals[:, 0].argsort()]
    return times_in_intervals_rec(times, intervals)


def times_in_intervals_rec(times, intervals):
    """Sweep helper for `times_in_intervals`.

    Expects `times` sorted ascending and `intervals` as a (k, 2) array sorted by
    start. Iterative despite the historical name, so long spike trains do not hit
    Python's recursion limit.
    """
    n_times = len(times)
    n_intervals = intervals.shape[0]
    i = j = 0
    while i < n_times:
        if j >= n_intervals:
            # Times remain but no intervals are left to cover them.
            return False
        t = times[i]
        if t < intervals[j, 0]:
            # Starts are ascending, so no remaining interval can contain t.
            return False
        if t <= intervals[j, 1]:
            i += 1  # covered; the next (later) time may still fit this interval
        else:
            j += 1  # t is past this interval's end, and so are all later times
    return True

### Data import parameters

In [None]:
# Directory containing the exported NWB file. Defaults to the original author's local
# path; set the NWBQUERY_DATA_DIR environment variable to point elsewhere.
data_dir = os.environ.get('NWBQUERY_DATA_DIR', '/home/tjd/Src/nwbquery/examples/franklab/ff_import')
# Session tag used in the file name and as the cluster-name prefix
# (presumably animal name + recording day).
animday = 'Bon04'
nwb_filename = os.path.join(data_dir, animday + '_test.nwb')

### Read NWB file

In [None]:
# Open the NWB file read-only and load its root NWBFile object.
# NOTE(review): `io` is never closed in the visible cells. Objects reached via `nwbf`
# (e.g. `.data` / `.timestamps` used below) are presumably read lazily from the open
# HDF5 file, so keep it open while they are in use and call io.close() at the end.
# The name `io` also shadows the stdlib `io` module should that ever be imported.
io = pynwb.NWBHDF5IO(nwb_filename, mode='r')
nwbf = io.read()

### Collect spike times and observation intervals

In [None]:
# Select cluster 
# TODO select cluster by metadata
cluster_id = 30

# Get cluster name: row `cluster_id` of the units table's 'cluster_name' column,
# prefixed with the session tag.
clname_idx = nwbf.units.colnames.index('cluster_name')
cluster_name = animday + ' ' + nwbf.units.columns[clname_idx][cluster_id]
print('Cluster name = ' + cluster_name)

# Get spike times
spikes_t = np.asarray(nwbf.modules['Spike Data']['UnitTimes'].get_unit_spike_times(cluster_id))
print('# of spikes = %d' % spikes_t.size)
if spikes_t.size:
    # min/max rather than [0]/[-1]: correct even if the stored times are unsorted,
    # and skipped entirely for a cluster with no spikes (avoids IndexError).
    print('Time of 1st/last spike (s): %0.4f / %0.4f \n' % (spikes_t.min(), spikes_t.max()))

# Get spike observation intervals. Each unit stores an IntervalSeries whose `data`
# is >0 at an interval start and <0 at an interval end (NWB convention), paired
# element-wise with `timestamps`.
obsint_idx = nwbf.units.colnames.index('obs_intervals')
obs_IntervalSeries = nwbf.units.columns[obsint_idx][cluster_id] # returns IntervalSeries
obs_codes = np.asarray(obs_IntervalSeries.data)
assert np.all(np.abs(obs_codes)==1), "Multiple interval types in an IntervalSeries not supported"
# The (-1, 2) reshape below is only meaningful if entries strictly alternate start/end.
assert (obs_codes.size % 2 == 0
        and np.all(obs_codes[0::2] > 0)
        and np.all(obs_codes[1::2] < 0)), \
    "IntervalSeries data must alternate interval start (+1) and end (-1)"
obs_intervals = np.reshape(np.asarray(obs_IntervalSeries.timestamps), (-1,2))
print('# of intervals = %d' % obs_intervals.shape[0])
print('Spike Observation Intervals (s): ')
print(obs_intervals)
print()

assert times_in_intervals(spikes_t, obs_intervals), 'Spike times found outside of observation intervals'

### Plot spiking and some sample time queries

In [None]:
fig1 = plt.figure(1, figsize=(15,6))
ax1 = fig1.add_subplot(1,1,1)
# (y position, tick label) pairs for the y axis
labels = [(1, 'Acquired Spiking Data')]

# Top row (y=1): all observation intervals as translucent bars, every spike as a tick
obsint_h = ax1.plot(obs_intervals.T, np.full(obs_intervals.T.shape, 1), 'b', linewidth=25, marker='', alpha=0.1)
spikes_h = ax1.plot(spikes_t, np.full(spikes_t.shape, 1), marker='|', markersize=10, linestyle='', color='b')

# Time queries: (query, short_label, long_label, mockup answers)
tqs = [
    # Get spikes from epoch 1
    (obs_intervals[0,:], 'A', 'Full epoch', obs_intervals[0,:]),
    # Get spikes from all run epochs
    (obs_intervals[(1,3,5),:], 'B', 'Multiple epochs', obs_intervals[(1,3,5),:]),
    # Get spikes from valid interval with no spiking
    ([5900, 6200], 'C', 'Full overlap, no spikes', [5900, 6200]),
    # partial overlap between query and obs_int
    ([6600, 7050], 'D', 'Partial overlap', [6600, 6811]),
    # non-overlap between query and obs_int
    ([9500, 9900], 'E', 'No overlap', []),
]

y_offset = -5
spacing = 6
cmap = plt.get_cmap("tab10")
plots_h = []
for i, tq in enumerate(tqs):
    tq_intervals = np.array(tq[0]).T
    ypos = y_offset - i*spacing
    labels.append((ypos, '[ %s ]  %s' % (tq[1], tq[2])))
    # Query interval(s) as a colored bar, with faint dashed guides up to the data row
    line_h = ax1.plot(tq_intervals, np.full(tq_intervals.shape, ypos),
                      color=cmap(i), linewidth=5, marker='')
    ax1.vlines(tq_intervals, ypos, -1, color=cmap(i), linestyle='--', alpha=0.2)
    plots_h.append(line_h[0])

    # Mockup query response, drawn 3 units below the query bar
    r_intervals = np.array(tq[3], ndmin=2)
    r_spikes = np.array([])
    if r_intervals.size:
        # Spikes inside any response interval (inclusive bounds), in interval order
        r_spikes = np.concatenate([spikes_t[(spikes_t >= lo) & (spikes_t <= hi)]
                                   for lo, hi in r_intervals])
        ax1.plot(r_intervals.T, np.full(r_intervals.T.shape, ypos-3),
                 color='b', linewidth=25, marker='', alpha=0.1)
        # Separate handle name so `spikes_h` (used in the legend) stays the top-row plot
        r_spikes_h = ax1.plot(r_spikes, np.full(r_spikes.shape, ypos-3),
                              marker='|', markersize=10, linestyle='', color='b')

# Leave one `spacing` below the lowest query row; computed from len(tqs) rather
# than the loop variable so it does not depend on the loop having run.
lowest_ypos = y_offset - (len(tqs) - 1)*spacing
ax1.set_ylim([lowest_ypos - spacing, 2*spacing])
ax1.set_yticks([l[0] for l in labels])
ax1.set_yticklabels([l[1] for l in labels])

ax1.set_xlabel('Time (s)')
ax1.legend([spikes_h[0], obsint_h[0]], ['Spike times', 'Spike observation Intervals'],
           labelspacing=1, borderpad=1, loc='upper right' )
plt.tight_layout()
fig1.savefig('./spike_timequeries.png', dpi=200)
None  # suppress the repr of the last call as cell output

In [None]:
# Inspect the last time-query tuple (query, short_label, long_label, mockup answers).
# Indexes `tqs` explicitly instead of relying on the loop variable leaking from the plot cell.
tqs[-1]

### Compute average firing rate across intervals

In [None]:
def mean_firing_rate(spikes_t, obs_intervals):
    """Mean firing rate: spike count divided by total observed time.

    Parameters
    ----------
    spikes_t : array-like, shape (n_spikes,)
        Spike times, in the same units as `obs_intervals`.
    obs_intervals : array-like, shape (n_intervals, 2)
        [start, end] rows; the denominator is the sum of (end - start).
        NOTE(review): overlapping intervals would be double-counted -- assumes disjoint.

    Returns
    -------
    float
        Spikes per unit time (Hz when times are in seconds).

    Raises
    ------
    AssertionError
        If any spike lies outside every observation interval.
    ValueError
        If the total observed time is not positive (would otherwise give inf/nan).
    """
    spikes_t = np.asarray(spikes_t)
    obs_intervals = np.asarray(obs_intervals)
    assert times_in_intervals(spikes_t, obs_intervals), 'Spike times found outside of observation intervals'
    total_time = np.diff(obs_intervals, axis=1).sum()
    if not total_time > 0:
        raise ValueError('Total observation time must be positive, got %r' % (total_time,))
    return spikes_t.size / total_time

# Overall rate for the selected cluster across all of its observation intervals
fr_mean_Hz = mean_firing_rate(spikes_t=spikes_t, obs_intervals=obs_intervals)
fr_report = 'Mean firing rate of cluster %s: %0.3f Hz' % (cluster_name, fr_mean_Hz)
print(fr_report)

### Collect some behavioral intervals (times when speed exceeds a threshold)

In [None]:
# Speed time series (presumably day 4, epoch 1, per its name 'Speed d4 e1')
speed = nwbf.modules['Behavior']['Speed']['Speed d4 e1']
# TODO: settle how behavioral timestamps are aligned to session time; the next cell
# subtracts nwbf.session_start_time.timestamp() from the speed timestamps.
electrode_query = nwbf.electrodes['location']

In [None]:
# Speed timestamps relative to session start (s).
# NOTE(review): assumes speed.timestamps are absolute POSIX times; if
# session_start_time is a naive datetime, .timestamp() treats it as local time.
speed_ts = np.array(speed.timestamps)-nwbf.session_start_time.timestamp()
speed_data = np.array(speed.data)

# Threshold in the units of speed.data -- TODO confirm units (presumably cm/s)
SPEED_THRESHOLD = 5
speed_gt5_ts = speed_ts[speed_data > SPEED_THRESHOLD]

fig2 = plt.figure(2, figsize=(15,4))
ax2 = fig2.add_subplot(1,1,1)
ax2.plot(speed_gt5_ts, np.full(speed_gt5_ts.shape, 1), 'x')
ax2.set_xlabel('Time (s)')
ax2.set_title('Samples with speed > %g' % SPEED_THRESHOLD);