In [None]:
import os

import numpy as np
import pynwb

### Helper functions

In [None]:
def times_in_intervals(times, intervals):
    """Return True if every time lies in at least one closed interval.

    Parameters
    ----------
    times : array_like, shape (n,)
        Event times (e.g. spike times), in any order. The caller's array is
        not modified.
    intervals : array_like, shape (m, 2)
        Closed ``[start, end]`` intervals, one per row, in any order; they
        may overlap.

    Returns
    -------
    bool
        True if ``times`` is empty or every time falls inside the union of
        ``intervals``; False otherwise (including when ``intervals`` is empty
        but ``times`` is not).
    """
    intervals = np.asarray(intervals).reshape(-1, 2)
    # Sort intervals by start time; fancy indexing returns a copy, so the
    # caller's array is left untouched.
    intervals = intervals[np.argsort(intervals[:, 0])]
    return times_in_intervals_rec(times, intervals)


def times_in_intervals_rec(times, intervals):
    """Coverage check for intervals already sorted by ascending start time.

    Kept under its original name for backward compatibility; despite the
    name it is no longer recursive. The former recursive version used one
    stack frame per time and per interval, so it exceeded Python's default
    recursion limit (1000) on realistic spike trains.

    Parameters
    ----------
    times : array_like, shape (n,)
        Event times, in any order (not modified).
    intervals : array_like, shape (m, 2)
        Closed ``[start, end]`` intervals, sorted by ascending start.

    Returns
    -------
    bool
        True if every time lies inside the union of ``intervals``.
    """
    times = np.asarray(times).ravel()
    intervals = np.asarray(intervals).reshape(-1, 2)
    if times.size == 0:
        return True
    if intervals.shape[0] == 0:
        return False
    starts = intervals[:, 0]
    # reach[k] is the furthest end among intervals 0..k. A time t is covered
    # iff some interval with start <= t also has end >= t, i.e. iff the reach
    # at the last interval starting at or before t is >= t. This also
    # handles overlapping intervals.
    reach = np.maximum.accumulate(intervals[:, 1])
    # Index of the last interval whose start <= t (-1 when t precedes all).
    last = np.searchsorted(starts, times, side='right') - 1
    covered = (last >= 0) & (times <= reach[np.maximum(last, 0)])
    return bool(covered.all())

### Data import parameters

In [None]:
# Directory holding the converted NWB files. Set the NWB_DATA_DIR
# environment variable to point elsewhere instead of editing this path.
data_dir = os.environ.get('NWB_DATA_DIR', '/home/tjd/Src/nwbquery/examples/franklab/ff_import')
# Session tag: names the NWB file and prefixes the cluster names printed below.
animday = 'Bon04'
nwb_filename = os.path.join(data_dir, animday + '_test.nwb')

### Read NWB file

In [None]:
# Open the NWB file read-only and read its contents into nwbf_read.
# NOTE(review): the handle is intentionally left open -- objects returned by
# read() may still be backed by the HDF5 file, so only close it (io.close())
# after the last cell that touches nwbf_read.
io = pynwb.NWBHDF5IO(
    nwb_filename,
    mode='r',
)
nwbf_read = io.read()

### Collect spike times and observation intervals

In [None]:
# Select cluster (row index into the units table)
# TODO select cluster by metadata
cluster_id = 30

# Get cluster name from the units table's 'cluster_name' column
clname_idx = nwbf_read.units.colnames.index('cluster_name')
cluster_name = animday + ' ' + nwbf_read.units.columns[clname_idx][cluster_id]
print('Cluster name = ' + cluster_name)

# Get spike times
spikes_t = np.asarray(
    nwbf_read.modules['Spike Data']['UnitTimes'].get_unit_spike_times(cluster_id))
print('# of spikes = %d' % spikes_t.size)
if spikes_t.size > 0:
    # min/max rather than [0]/[-1] so this is right even if the stored
    # spike times are not sorted.
    print('Time of 1st/last spike (s): %0.4f / %0.4f \n' % (spikes_t.min(), spikes_t.max()))
else:
    print('Cluster has no spikes\n')

# Get spike observation intervals
obsint_idx = nwbf_read.units.colnames.index('obs_intervals')
obs_IntervalSeries = nwbf_read.units.columns[obsint_idx][cluster_id]  # returns IntervalSeries
obs_codes = np.asarray(obs_IntervalSeries.data)
assert np.all(np.abs(obs_codes) == 1), "Multiple interval types in an IntervalSeries not supported"
# The reshape below pairs consecutive timestamps as (start, end); that is only
# valid when the codes strictly alternate +1 (interval start) / -1 (interval end).
assert (obs_codes.size % 2 == 0
        and np.all(obs_codes[0::2] == 1)
        and np.all(obs_codes[1::2] == -1)), 'IntervalSeries events must alternate start (+1) / end (-1)'
obs_intervals = np.reshape(np.asarray(obs_IntervalSeries.timestamps), (-1, 2))
assert np.all(obs_intervals[:, 1] >= obs_intervals[:, 0]), 'Observation interval ends before it starts'
print('# of intervals = %d' % obs_intervals.shape[0])
print('Spike Observation Intervals (s): ')
print(obs_intervals)
print()

assert times_in_intervals(spikes_t, obs_intervals), 'Spike times found outside of observation intervals'

### Compute average firing rate across intervals

In [None]:
def mean_firing_rate(spikes_t, obs_intervals):
    """Mean firing rate of a unit over its observation intervals.

    Parameters
    ----------
    spikes_t : array_like, shape (n,)
        Spike times (s). All must fall inside ``obs_intervals``.
    obs_intervals : array_like, shape (m, 2)
        ``[start, end]`` observation intervals (s), one per row.

    Returns
    -------
    float
        Spike count divided by the summed interval durations (Hz when times
        are in seconds).

    Raises
    ------
    AssertionError
        If any spike lies outside the observation intervals.
    ValueError
        If the summed observation duration is not positive.
    """
    spikes_t = np.asarray(spikes_t)
    obs_intervals = np.asarray(obs_intervals)
    assert times_in_intervals(spikes_t, obs_intervals), 'Spike times found outside of observation intervals'
    # NOTE(review): durations are summed per interval, so overlapping
    # intervals would be double-counted -- assumes they are disjoint.
    total_duration = np.diff(obs_intervals, axis=1).sum()
    # numpy only warns on n/0 -> inf and 0/0 -> nan; fail loudly instead.
    if not total_duration > 0:
        raise ValueError('Total observation time must be positive, got %r' % total_duration)
    return spikes_t.size / total_duration

fr_mean_Hz = mean_firing_rate(spikes_t, obs_intervals)
print('Mean firing rate of cluster %s: %0.3f Hz' % (cluster_name, fr_mean_Hz) )
