In [None]:
%load_ext autoreload
%autoreload 2

import os

import matplotlib.pyplot as plt
import numpy as np
import pynwb

# pip install python-intervals, *NOT* pip install intervals
import intervals as iv

import query_helpers as qu
# Global figure style for this notebook. The 'butt' cap style stops thick lines
# exactly at their end points, so interval bars end exactly at interval bounds.
plt.rcParams.update({
    'font.size': 12,
    'lines.solid_capstyle': 'butt',
})


### Data import parameters

In [None]:
# Location of the imported NWB files. Defaults to the original author's checkout;
# set the NWBQUERY_DATA_DIR environment variable to use another location.
data_dir = os.environ.get('NWBQUERY_DATA_DIR',
                          '/Users/tjd/Src/nwbquery/examples/franklab/ff_import')
# Animal + day identifier; it names both the NWB file and the cluster label below.
animday = 'Bon04'
nwb_filename = os.path.join(data_dir, animday + '_test.nwb')

### Read NWB file

In [None]:
# Open the NWB file read-only. `io` stays open (it is closed in the last cell)
# because the cells below keep reading from `nwbf`; presumably pynwb loads
# datasets lazily, so the handle must outlive those reads -- TODO confirm.
io = pynwb.NWBHDF5IO(nwb_filename, mode='r')
nwbf = io.read()

### Collect spike times and observation intervals

In [None]:
# Select cluster
# TODO select cluster by metadata
cluster_id = 30

# Get cluster name from the 'cluster_name' column of the units table
clname_idx = nwbf.units.colnames.index('cluster_name')
cluster_name = animday + ' ' + nwbf.units.columns[clname_idx][cluster_id]
print('Cluster name = ' + cluster_name)

# Get spike times
spikes_t = nwbf.modules['Spike Data']['UnitTimes'].get_unit_spike_times(cluster_id)
print('# of spikes = %d' % spikes_t.size)
if spikes_t.size > 0:  # indexing [0]/[-1] would fail on an empty spike train
    print('Time of 1st/last spike (s): %0.4f / %0.4f \n' % (spikes_t[0], spikes_t[-1]))

# Get spike observation intervals
obsint_idx = nwbf.units.colnames.index('obs_intervals')
obs_IntervalSeries = nwbf.units.columns[obsint_idx][cluster_id]  # returns IntervalSeries
obs_intervals = qu.intervals_from_IntervalSeries(obs_IntervalSeries)

print('# of intervals = %d' % len(obs_intervals))
print('Spike Observation Intervals (s): ')
print(obs_intervals)
print()

# qu.times_in_intervals returns the spike times that fall inside the intervals
# (its result is measured with len() later in this notebook), so compare counts:
# asserting on the returned collection itself only checks that it is non-empty.
n_spikes_in_obs = len(qu.times_in_intervals(spikes_t, obs_intervals))
assert n_spikes_in_obs == spikes_t.size, 'Spike times found outside of observation intervals'

### Run some sample time queries

In [None]:
# Time queries: each entry is a (query intervals, short label, long label) tuple.
tqs = [
    # A: one full observation interval (epoch 1)
    (obs_intervals[0], 'A', 'Full epoch'),
    # B: union of observation intervals 0, 2, 4 and 6 (the sleep epochs)
    (obs_intervals[0] | obs_intervals[2] | obs_intervals[4] | obs_intervals[6],
     'B', 'Multiple epochs'),
    # C: a valid interval with no spiking
    (iv.closed(5900, 6200), 'C', 'Full overlap, no spikes'),
    # D: only partially overlaps the observation intervals
    (iv.closed(6600, 7050), 'D', 'Partial overlap'),
    # E: does not overlap the observation intervals
    (iv.closed(9500, 9900), 'E', 'No overlap'),
]

# Behavioral query F: times when speed exceeds a threshold. The speed timestamps
# are shifted by the session start time, presumably to match the time base of
# the spikes / observation intervals.
# TODO: harmonize timestamps across speed (POSIX time?) and spikes/obs_ints. During import.
speed = nwbf.modules['Behavior']['Speed']['Speed d4 e4']
speed_ts = np.array(speed.timestamps) - nwbf.session_start_time.timestamp()
speed_data = np.array(speed.data)


def speed_threshold_fn(x):
    return x > 0.05  # m/s


speed_gt_intervals = qu.intervals_from_continuous(speed_data, speed_ts, speed_threshold_fn)
# NOTE(review): originally commented "limit to run epochs", but index 2 is one of
# the epochs that query B treats as sleep epochs -- confirm which is intended.
speed_gt_ep3_intervals = speed_gt_intervals & obs_intervals[2]

tqs.append((speed_gt_intervals, 'F', 'Complex behavioral query'))

# Responses: for each query, the part covered by the observation intervals
# (`&` is interval intersection) and the spike times inside that part.
rs = []
for tq_intervals, _short_label, _long_label in tqs:
    r_intervals = tq_intervals & obs_intervals
    r_spikes = qu.times_in_intervals(spikes_t, r_intervals)
    print(len(r_spikes))
    rs.append((r_intervals, r_spikes))
    

### Plot spiking and the sample time queries

In [None]:
fig1 = plt.figure(1, figsize=(15, 6))
ax1 = fig1.add_subplot(111)

# (y position, tick label) pairs for the y axis; the acquired data's tick is at y = 1.
ypos = 1
labels = [(ypos, 'Acquired Spiking Data')]

# Draw the selected cluster's observation intervals and spike times.
int_h, times_h = qu.plot_pointprocess(obs_intervals, spikes_t, axis=ax1)

In [None]:
# Layout of the query rows below the acquired data: query i is drawn at
# y = y_offset - i*spacing, and its response 3 units below that.
y_offset = -5
spacing = 6
cmap = plt.get_cmap("tab10")
plots_h = []
obsint_h = None  # legend handle for an observation-interval bar
for i, tq in enumerate(tqs):
    # Query intervals, transposed so each column is one [start, end] pair
    # (assumes qu.array_from_intervals returns one row per interval -- TODO confirm).
    tq_intervals = qu.array_from_intervals(tq[0]).T
    ypos = y_offset - i * spacing
    labels.append((ypos, '[ %s ]  %s' % (tq[1], tq[2])))
    line_h = ax1.plot(tq_intervals, np.full(tq_intervals.shape, ypos),
                      color=cmap(i), linewidth=5, marker='')
    # Faint dashed guides from the query bounds up to y = -1.
    ax1.vlines(tq_intervals, ypos, -1, color=cmap(i), linestyle='--', alpha=0.2)
    plots_h.append(line_h[0])

    # Plot the query response (observed intervals + spikes) below the query.
    r_intervals = qu.array_from_intervals(rs[i][0]).T
    r_spikes = rs[i][1]

    r_int_h = ax1.plot(r_intervals, np.full(r_intervals.shape, ypos - 3),
                       color='b', linewidth=25, marker='', alpha=0.1)
    if r_int_h:  # may be empty when the response has no intervals
        obsint_h = r_int_h[0]
    spikes_h = ax1.plot(r_spikes, np.full(len(r_spikes), ypos - 3),
                        marker='|',
                        markersize=10,
                        linestyle='',
                        color='b')

ax1.set_ylim([ypos - spacing, 2 * spacing])
ax1.set_yticks([l[0] for l in labels])
ax1.set_yticklabels([l[1] for l in labels])

ax1.set_xlabel('Time (s)')
# The response interval bars (blue, width 25, alpha 0.1) are used as the
# legend entry for observation intervals.
legend_handles = [spikes_h[0]]
legend_labels = ['Spike times']
if obsint_h is not None:
    legend_handles.append(obsint_h)
    legend_labels.append('Spike observation Intervals')
ax1.legend(legend_handles, legend_labels,
           labelspacing=1, borderpad=1, loc='upper right')
plt.tight_layout()
fig1.savefig('./spike_timequeries.png', dpi=200)
None

In [None]:
# Zoom in on the last query's response: from 1 s before its first spike to
# 20 s after it. The x-limits are left unchanged if that response has no spikes.
last_r_spikes = rs[-1][1]
if len(last_r_spikes) > 0:
    first_spike_t = last_r_spikes[0]
    ax1.set_xlim(first_spike_t - 1, first_spike_t + 20)
fig1

### Compute mean firing rate across observation intervals

In [None]:
def mean_firing_rate(spikes_t, obs_intervals):
    """Mean firing rate (Hz) of a spike train over its observation intervals.

    Parameters
    ----------
    spikes_t : array-like
        Spike times in seconds; every spike must lie inside `obs_intervals`.
    obs_intervals : intervals.Interval
        Observation intervals in seconds (presumably as produced by
        qu.intervals_from_IntervalSeries).

    Returns
    -------
    float
        Number of spikes divided by the total duration of `obs_intervals`.

    Raises
    ------
    AssertionError
        If any spike time falls outside `obs_intervals`.
    ValueError
        If the observation intervals have zero total duration.
    """
    n_spikes = len(spikes_t)
    # qu.times_in_intervals returns the spikes inside the intervals; compare
    # counts rather than testing the truthiness of the returned collection.
    n_inside = len(qu.times_in_intervals(spikes_t, obs_intervals))
    assert n_inside == n_spikes, 'Spike times found outside of observation intervals'
    obs_array = qu.array_from_intervals(obs_intervals)  # presumably rows of [start, end]
    total_duration = np.diff(obs_array, axis=1).sum()
    if total_duration <= 0:
        raise ValueError('Observation intervals have zero total duration')
    return n_spikes / total_duration

fr_mean_Hz = mean_firing_rate(spikes_t, obs_intervals)
print('Mean firing rate of cluster %s: %0.3f Hz' % (cluster_name, fr_mean_Hz))

### Collect some behavioral intervals

In [None]:
# NOTE(review): the threshold here is 5, while the earlier behavioral query uses
# 0.05 with a "m/s" comment -- confirm the units of speed_data.
def speed_threshold(x):
    return x > 5


# Same call signature as the earlier query: (values, timestamps, predicate).
speed_gt_intervals = qu.intervals_from_continuous(speed_data, speed_ts, speed_threshold)
# NOTE(review): originally commented "limit to run epochs", but tqs[1][0] is
# query B ('Multiple epochs'), built from the epochs labelled as sleep epochs.
speed_gt_run_intervals = speed_gt_intervals & tqs[1][0]
speed_gt_run_spikes = qu.times_in_intervals(spikes_t, speed_gt_run_intervals)

In [None]:
# Plot the speed-thresholded intervals and the spikes that fall inside them.
fig2 = plt.figure(2, figsize=(15, 4))
ax2 = fig2.add_subplot(1, 1, 1)
iv_to_plot = qu.array_from_intervals(speed_gt_run_intervals).T
ax2.plot(iv_to_plot, np.full(iv_to_plot.shape, 1),
         color='b', linewidth=25, marker='', alpha=0.1)
ax2.plot(speed_gt_run_spikes, np.full(len(speed_gt_run_spikes), 1),
         marker='|',
         markersize=10,
         linestyle='',
         color='b')
ax2.set_xlabel('Time (s)')
ax2.set_yticks([])
ax2.set_title(cluster_name + ': spikes during speed-threshold intervals')
None

In [None]:
# Close the NWB file handle opened near the top of the notebook; presumably
# `nwbf` data should not be accessed after this point.
io.close()