In [None]:
%load_ext autoreload
%autoreload 2

import os
from datetime import datetime

import numpy as np
import pynwb
import pytz

import nspike_helpers as ns

### Session-specific import parameters
File paths, animal/day identifiers, and acquisition constants used by the rest of the notebook.

In [None]:
# Debug settings
limit_num_of_tets = 3 # To speed up testing. Set to None to load all tets

# Session-specific params
data_dir = '/opt/data46/FrankData/kkay/Bon/'
nwb_filename = 'bon03_eegtest.nwb'
data_source = 'Animal Bond'
anim = 'Bon' 
day = 3
day_str = '03'

# 'Wall clock' (i.e. actual) date and time of the Nspike time = 0 for this experiment.
# NOTE: this is not the actual zero_time, as we don't have easy access to that.
# pytz zones must be attached with localize(): passing a pytz zone through the
# `tzinfo=` argument of datetime() selects the zone's earliest (local mean time)
# offset, about -07:53 for US/Pacific, instead of the PST offset valid on this date.
dataset_zero_time = pytz.timezone('US/Pacific').localize(datetime(2006, 1, 1, 12, 0, 0))

# General params/presets
file_create_date = datetime.now()

source = 'NSpike data acquisition system'
eeg_samprate = 1500.0 # Hz

# File/directory names inside data_dir
eeg_subdir = "EEG"
epochs_file = "times.mat"
tetinfo_file = "tetinfo.mat"
timestamps_per_sec = 10000

### Parse inputs and create NWBfile

In [None]:
# Check the input arguments.
# Raise instead of print + exit(): in a notebook, exit() ends the interactive session
# rather than reporting an error, whereas an exception stops "Run All" at this cell
# and leaves the kernel usable.
if not os.path.exists(data_dir):
    raise FileNotFoundError('Error: data_dir %s does not exist' % data_dir)

# Filename prefix (lower-cased animal name) and the EEG sub-directory
prefix = anim.lower()
eeg_path = os.path.join(data_dir, eeg_subdir)

# Calculate the POSIX timestamp when Nspike clock = 0 (seconds)
Nspike_posixtime_offset = dataset_zero_time.timestamp()

In [None]:
# Keyword metadata for the NWB file, collected in one place and expanded
# into the constructor call below.
nwbfile_metadata = {
    'file_create_date': file_create_date,
    'lab': 'Frank Laboratory',
    'experimenter': 'Mattias Karlsson',
    'institution': 'UCSF',
    'experiment_description': 'Recordings from awake behaving rat',
    'session_id': data_dir,
}

# The four leading arguments are positional; their order matters.
nwbf = pynwb.NWBFile(
    data_source,
    'Converted NSpike data from %s' % data_dir,
    anim + day_str,
    dataset_zero_time,
    **nwbfile_metadata)

### Animal Behavior

In [None]:
# Empty containers for position, head direction and speed; each wraps the
# (initially empty) list handed to it.
position_list = []
pos = pynwb.behavior.Position(data_source, position_list)

direction_list = []
dir = pynwb.behavior.CompassDirection(data_source, direction_list)

speed_list = []
speed = pynwb.behavior.BehavioralTimeSeries(data_source, speed_list)

# Placeholders for epoch bookkeeping
time_list = {}
nwb_epoch = {}

# Locate the position and task files for this animal
pos_files = ns.get_files_by_day(data_dir, prefix, 'pos')
task_files = ns.get_files_by_day(data_dir, prefix, 'task')

In [None]:
# The MATLAB cell arrays are indexed by zero-based day, the file lookups by `day` itself
day_col = day - 1

# Task structure for this day
mat = ns.loadmat2(task_files[day])
task_struct = mat['task'][0, day_col]

# Position structure for this day: a row vector with one cell per epoch
mat = ns.loadmat2(pos_files[day])
pos_struct = mat['pos'][0, day_col][0, :]

# Start/end times of each epoch are accumulated here while the position,
# head direction and velocity series are built
time_list = list()
time_index = 0

In [None]:
# Assumed column order of pos_epoch['data']: (time, x, y, dir, vel)
time_idx, x_idx, y_idx, dir_idx, vel_idx = range(5)

for epoch_ind, pos_epoch in enumerate(pos_struct):
    epoch_num = epoch_ind + 1

    # Epochs without position data are skipped entirely
    if pos_epoch.size == 0:
        continue

    pos_epoch = pos_epoch[0, 0]  # unwrap the 1x1 ndarray
    samples = pos_epoch['data']

    # Shift the time column by the Nspike clock offset to obtain POSIX time
    timestamps = samples[:, time_idx] + Nspike_posixtime_offset

    # TODO: create a shared TimeSeries for timestamps, across all behavioral timeseries

    # Remember the first and last timestamp of this epoch
    time_list.append([timestamps[0], timestamps[-1]])

    # cm/pixel -> m/pixel
    m_per_pixel = float(pos_epoch['cmperpixel']) / 100

    # Every series added to a container needs a unique name, hence the
    # day/epoch suffix in each name below.
    pos.create_spatial_series(
        name='Position d%d e%d' % (day, epoch_num),
        source='overhead camera',
        timestamps=timestamps,
        data=samples[:, (x_idx, y_idx)],
        reference_frame='corner of video frame',
        conversion=m_per_pixel)

    dir.create_spatial_series(
        name='Head Direction d%d e%d' % (day, epoch_num),
        source='overhead camera',
        timestamps=timestamps,
        data=samples[:, dir_idx],
        reference_frame='0=facing top of video frame (?), positive clockwise (?)')

    # The stored values are in pixels/s; `unit` describes them after `conversion`
    speed.create_timeseries(
        name='Speed d%d e%d' % (day, epoch_num),
        source='overhead camera',
        timestamps=timestamps,
        data=samples[:, vel_idx],
        unit='m/s',
        conversion=m_per_pixel,
        description='smoothed movement speed estimate')

# One [start, end] row per non-empty epoch
time_list = np.asarray(time_list)
                                
    

In [None]:
# Processing module that holds all behavioral variables
behav_mod = nwbf.create_processing_module('Behavior', data_source, 'Behavioral variables')

# Register position, direction and speed, in that order
for behav_interface in (pos, dir, speed):
    behav_mod.add_data_interface(behav_interface)

### Epochs (Not currently implemented)

In [None]:
# # create a list to store all of the fields in the task structure and 
# # a parallel list to store the created task interval structures
# task_fields = list()
# task_intervals = list()

# # each day will be defined as a single Epoch in NWB so we go through and get the first and last time from the
# # position data
# day_start = time_list[0,0]
# day_end = time_list[-1,1]

# nwb_epoch = nwbf.create_epoch('day %s' % day, data_source, day_start, day_end, [], 'day %s' % day)
# # add ignore intervals for the spaces between our epochs (it's not clear if this is necessary, but it won't hurt)
# # also, there's probably a more "python-ic" way to do this, but I don't know what it is 8-)
# if len(time_list) > 1:
#         n = 1
#         while n <= #NO!# len(time_list):
#                 #nwb_epoch.add_ignore_interval(time_list[n-1][1], time_list[n][0])
#                 n += 1

# # now we go through the task structure and add a new interval series or an interval to an existing interval series
# # for each element in the task structure
# for epoch_num, task_epoch in enumerate(task_struct):
#         if task_epoch.size > 0:
#                 task_epoch = task_epoch[0,0] # retrieve dict from 1x1 ndarray
#                 for field_name, value in task_epoch.items()
#                         if field_name not in task_fields:
#                                 # add the field_name to the list and create a new IntervalSeries for it
#                                 task_fields.append(field_name)
# #                               tmp_array = np.ndarray(2);
# #                               tmp_array = [time_list[epoch_num][0], time_list[epoch_num][1]]
#                                 tmp_interval = IntervalSeries(field_name, 'matlab task structure')
#                                 # add the interval for this epoch
#                                 tmp_interval.add_interval(*time_list[epoch_ind,:])
#                                 task_intervals.append(tmp_interval)
#                         else:
#                                 # add the interval to appropriate element of the list
#                                 task_intervals[task_fields.index(field_name)].add_interval(*time_list[epoch_ind])


# # Now add the complete list of task intervals to the behav_mod module
# for interval in task_intervals:
#         behav_mod.add_data_interface(BehavioralEpochs('task information', interval))

## Tetrode info
Load in `tetinfo` struct and populate ElectrodeTable, electrode groups, etc.

In [None]:
# Create the electrode table.
# The logic here is as follows:
#   Each Tetrode gets its own ElectrodeGroup and ElectrodeTableRegion
#   Each individual recording channel gets its own row in nwbfile.electrodes

# These files all contain tetrodes, so we assume four channels per electrode group
nchan_per_tetrode = 4
# Build the path with os.path.join so a trailing slash on data_dir does not
# produce a doubled separator
tetinfo_filename = os.path.join(data_dir, prefix + tetinfo_file)
recording_device = nwbf.create_device('NSpike acquisition system', data_source)

# Per-tetrode lookups, keyed by tetrode number, filled in while the electrodes are added
tet_electrode_group = dict()
tet_electrode_table_region = dict()
lfp_electrode_table_region = dict()

In [None]:
# Load the tetinfo .mat file; its 'tetinfo' entry is what gets unpacked into per-tetrode structs
mat = ns.loadmat2(tetinfo_filename)
# tetinfo = mat['tetinfo'][0, day-1][0, 0].squeeze(axis=0)

# for (i,t) in enumerate(tetinfo):
#     print(i)
#     print(type(t))
#     print(t.shape)
#     print(type(t[0, 0]['area'][0]))
#     print(t[0, 0]['area'].dtype)
# #     print(t[0, 0]['area'][0])



In [None]:
# Per-tetrode structs of the first epoch only (the other epochs are duplicates)
tet_structs = mat['tetinfo'][0, day - 1][0, 0].squeeze(axis=0)

# Map 1-indexed tetrode number -> struct, leaving out empty entries
tets = {}
for struct_ind, tet_struct in enumerate(tet_structs):
    if tet_struct.size > 0:
        tets[struct_ind + 1] = tet_struct[0, 0]

# For debugging, keep only the lowest-numbered tetrodes
# (a limit of None keeps them all)
subset_keys = sorted(tets)[:limit_num_of_tets]
tets = {tet_num: tets[tet_num] for tet_num in subset_keys}

print(limit_num_of_tets)
print("Using tetrode numbers:")
print(subset_keys)

In [None]:
# The tetinfo data is nested as [day][epoch][tetrode] but repeats the same info
# for every epoch, so the first epoch is used for everything.
#
# Each tetrode gets one electrode group, `nchan_per_tetrode` electrode rows, and
# two table regions: one covering all of its channels and one covering only its
# first channel (assumed to be the LFP channel).
chan_num = 0  # running electrode id across the whole day of data
for tet_num, tet in tets.items():
    hemisphere = '?'

    # tet.area/.subarea are 1-d arrays of Unicode strings; str() turns the numpy
    # string into a plain str (h5py may not accept numpy.str_ objects)
    if 'area' in tet:
        area = str(tet['area'][0])
    else:
        area = '?'

    if 'sub_area' in tet:
        sub_area = str(tet['sub_area'][0])
        location = area + ' ' + sub_area
    else:
        sub_area = '?'
        location = area

    # tet.depth is a 1x1 cell array wrapping a 1x1 numeric array.
    # NOTE(review): the / 12 / 80 * 25.4 scaling presumably converts drive-screw
    # units to millimetres - confirm.
    if 'depth' in tet:
        coord = [np.nan, np.nan, tet['depth'][0, 0][0, 0] / 12 / 80 * 25.4]
    else:
        coord = [np.nan, np.nan, np.nan]
    impedance = np.nan
    filtering = 'unknown - likely 600Hz-6KHz'

    channel_location = [location] * 4
    channel_coordinates = [coord] * 4
    electrode_name = "%02d-%02d" % (day, tet_num)
    description = "tetrode {tet_num} located in {location} on day {day}".format(
        tet_num=tet_num, location=location, day=day)

    # One electrode group per tetrode
    group = nwbf.create_electrode_group(electrode_name,
                                        data_source,
                                        description,
                                        location,
                                        recording_device)
    tet_electrode_group[tet_num] = group

    # One electrode-table row per channel of this tetrode
    first_chan = chan_num
    for chan_in_tet in range(nchan_per_tetrode):
        nwbf.add_electrode(x=coord[0],
                           y=coord[1],
                           z=coord[2],
                           imp=impedance,
                           location=location,
                           filtering=filtering,
                           group=group,
                           group_name=group.name,
                           id=chan_num)
        chan_num += 1

    table_region_description = 'ntrode %d region' % tet_num
    table_region_name = '%d' % tet_num

    # Region spanning every channel that was just added for this tetrode
    tet_electrode_table_region[tet_num] = nwbf.create_electrode_table_region(
        list(range(first_chan, chan_num)),
        table_region_description,
        table_region_name)

    # Region for this tetrode's LFP recording
    # (assume that LFP is taken from the first channel)
    lfp_electrode_table_region[tet_num] = nwbf.create_electrode_table_region(
        [first_chan],
        table_region_description,
        table_region_name)

In [None]:
# tet_electrode_table_region[1].region
# nwbf.ec_electrode_groups['03-01'].description
# Show the tetrode-number -> electrode group mapping as the cell output
tet_electrode_group

In [None]:
# # get the (unique) electrodeTable for this file
# et = nwbf.get_acquisition('LFP').electrical_series['boneeg-3-1']
# # query for rows in table with known group name (day-tetnum, in this case)
# # eti = et.which(group_name='03-01')
# # do a list comprehension to get the tuples
# # et_result = [nwbf.ec_electrodes.data[idx] for idx in eti] # list comprehension
# # get electrodeTable column index for field 'group'
# # colidx_group = et.columns.index('group')

# # print(et.data[eti[0]][colidx_group]) # electrodeGroup for tet 3-10
# # print(et_result[0][colidx_group])   # same as above

# print(et) # NB no 'fields', why?

## LFP

In [None]:
%%time

eeg_files = ns.get_eeg_by_day(eeg_path, prefix, 'eeg')

# LFP container; one electrical series is added to it per selected tetrode,
# built from the EEG/*eeg*.mat files
lfp_data = []
lfp = pynwb.ecephys.LFP(data_source, lfp_data)

print('processing LFP data for day %2d' % day)
dayfiles = eeg_files[day]
for tet_num in tets:
    print(' -> tet_num: %d' % tet_num)
    timestamps, data = ns.build_day_eeg(dayfiles[tet_num], eeg_samprate)

    # Shift the timestamps (in place) to POSIX time
    timestamps += Nspike_posixtime_offset

    name = "{prefix}eeg-{day}-{tet}".format(prefix=prefix, day=day, tet=tet_num)

    # Report the dtypes being stored
    print(' '.join([str(timestamps.dtype), str(data.dtype)]))

    # Each series points at the tetrode's LFP electrode table region
    lfp.create_electrical_series(name,
                                 source,
                                 data,
                                 lfp_electrode_table_region[tet_num],
                                 timestamps=timestamps)
print('processed LFP data from day %d' % day)

In [None]:
# Add the LFP container (holding the electrical series created per tetrode) to the file's acquisition data
nwbf.add_acquisition(lfp)

## Spikes

In [None]:
# Unit metadata columns, declared before any unit is added.
unit_columns = [
    # External clustering software gives names for each cluster--we want to preserve these
    ('cluster_name',  '(str) cluster name from clustering software'),
    ('elec_group',    '(electrodeGroup) nTrode on which spikes were recorded'),
    # For tetrode data, this will usually be all channels in the tetrode
    ('neighborhood',  '(electrodeTableRegion) list of electrodes on which spikes were clustered'),
    # AKA 'Valid_times'--the times during which a spike from this cluster could have
    # possibly been observed (periods between behavior epochs, acquisition dropouts, etc.)
    ('obs_intervals', '(intervalSeries) Observation Intervals for the spike times'),
]
for column_name, column_description in unit_columns:
    nwbf.add_unit_column(column_name, column_description)

In [None]:
# Read the spike times for this day from the spikes file.
# Each cluster later receives a unique id, counted from zero.
spike_files = ns.get_files_by_day(data_dir, prefix, 'spikes')
spike_file = spike_files[day]
print('\nLoading spikes file :' + spike_file)
mat = ns.loadmat2(spike_file)

In [None]:
# Processing module and spike-time container for the clustered spikes
spike_mod = nwbf.create_processing_module('Spike Data', data_source, 'Clustered Spikes')
spike_UnitTimes = pynwb.misc.UnitTimes(data_source)

# Bookkeeping that is filled in while the clusters are processed
spike_unit = list()
obs_intervals = dict()
cluster_by_tet = dict()
cluster_id = 0

# The Matlab structs are nested by day, epoch, tetrode, cluster. All spikes of a given
# cluster *across multiple epochs* should end up in the same spike list, so the data is
# regrouped into a nested dict keyed by 1) tetrode, 2) cluster number, 3) epoch.
# Keys are 1-indexed, consistent with the original data collection. Only one day is
# processed at a time for now, so days are not nested.

In [None]:
# Regroup the per-epoch Matlab cells into cluster_by_tet[tet][cluster][epoch].
# All *_num values are 1-indexed to match the original data collection.
spike_struct = mat['spikes'][0, day - 1].T  # one Matlab 'cell' per epoch
for epoch_ind, espikes in enumerate(spike_struct):
    espikes = espikes[0].T  # one Matlab 'cell' per tetrode
    epoch_num = epoch_ind + 1

    for tet_ind, tspikes in enumerate(espikes):
        tspikes = tspikes[0].T  # one Matlab 'cell' per cluster
        tet_num = tet_ind + 1

        # Only tetrodes in the selected subset are kept
        if tet_num not in tets:
            continue

        tet_clusters = cluster_by_tet.setdefault(tet_num, {})
        for cluster_ind, cspikes in enumerate(tspikes):
            cspikes = cspikes[0]
            cluster_num = cluster_ind + 1

            # Skip cluster slots with nothing in them
            if not len(cspikes):
                continue

            tet_clusters.setdefault(cluster_num, {})[epoch_num] = cspikes[0, 0]
                                

In [None]:
# For every cluster: pool its spike times across epochs, collect its observation
# intervals into an IntervalSeries, add a row to the units table and register the
# spike times with UnitTimes. Cluster ids are assigned sequentially via cluster_id.
#
# NOTE(review): position and LFP timestamps in this notebook are shifted by
# Nspike_posixtime_offset, whereas the spike times and 'timerange' values are stored
# here exactly as read from the spikes file - confirm that this is intended, and
# that 'timerange' is expressed in seconds.
for tet_num in cluster_by_tet.keys():
    obs_intervals[tet_num] = {}
    for cluster_num in cluster_by_tet[tet_num].keys():
        print('Adding cluster id %d' % cluster_id)

        cluster_name = 'd%d t%d c%d' % (day, tet_num, cluster_num)
        print('cluster name: ' + cluster_name)

        cluster_tmp = cluster_by_tet[tet_num][cluster_num]

        # Observation intervals for this cluster, one interval per 'timerange' row
        obs_intervals[tet_num][cluster_num] = pynwb.misc.IntervalSeries(
            name=cluster_name,
            source=source,
            description='Observation intervals for spikes from cluster ' +
                        str(cluster_num) + ' on tetrode ' + str(tet_num))

        spikes_ep = []
        for epoch in cluster_tmp.keys():
            # Epochs in which the cluster fired no spikes contribute no spike times...
            if cluster_tmp[epoch]['data'].shape[0]:
                spikes_ep.append(cluster_tmp[epoch]['data'][:, 0])
            # ...but their observation intervals are still recorded
            for obs_intervals_cl_ep in cluster_tmp[epoch]['timerange']:
                obs_intervals[tet_num][cluster_num].add_interval(*obs_intervals_cl_ep.T.astype(float))

        # np.concatenate raises ValueError on an empty list, which happens when the
        # cluster has no spikes in any epoch; store an empty spike train in that case.
        if spikes_ep:
            spiketimes = np.concatenate(spikes_ep)
        else:
            spiketimes = np.empty(0, dtype=float)

        # Add the observation intervals to the file (1 per cluster) so that the
        # units table can refer to them
        spike_mod.add_data_interface(obs_intervals[tet_num][cluster_num])

        nwbf.add_unit(data={'cluster_name': cluster_name,
                            'elec_group': tet_electrode_group[tet_num],
                            # the LFP region (first channel of the tetrode) is stored here,
                            # not the full-tetrode region in tet_electrode_table_region
                            'neighborhood': lfp_electrode_table_region[tet_num],
                            'obs_intervals': obs_intervals[tet_num][cluster_num]},
                      id=cluster_id)

        spike_UnitTimes.add_spike_times(cluster_id, spiketimes)

        cluster_id += 1

In [None]:
# Add UnitTimes (the per-cluster spike times) to the spike_mod ProcessingModule ('Spike Data')
spike_mod.add_data_interface(spike_UnitTimes)

### Write out NWBfile!

In [None]:
# Write the in-memory NWBFile (nwbf) to nwb_filename; the `with` block closes the
# writer once the write has finished
with pynwb.NWBHDF5IO(nwb_filename, mode='w') as iow:
    iow.write(nwbf)

### Read our NWBfile, and check some roundtrip data

In [None]:
# Re-open the file that was just written, read-only, and load its contents into nwbf_read
io = pynwb.NWBHDF5IO(nwb_filename, mode='r')
nwbf_read = io.read()

In [None]:
io.close()

In [None]:
nwbf_read.get_acquisition('LFP').electrical_series['boneeg-3-2']

In [None]:
cl_id = 1
print(nwbf_read.modules['Spike Data']['UnitTimes'].get_unit_spike_times(cl_id).shape)
clname_idx = nwbf_read.units.colnames.index('cluster_name')
obsint_idx = nwbf_read.units.colnames.index('obs_intervals')
print(nwbf_read.units.columns[clname_idx][cl_id])
print(nwbf_read.units.columns[obsint_idx][cl_id].timestamps)


In [None]:
es = nwbf_read.get_acquisition('LFP').electrical_series['boneeg-3-2']

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.plot(es.timestamps[0:10000], es.data[0:10000])