### Environment setup

In [None]:
!pip install pynwb
!pip install matplotlib

In [None]:
!pip install scipy

In [1]:
from pynwb import NWBHDF5IO
from scipy import interpolate
import numpy as np
import matplotlib.pyplot as plt

### Read NWB files

In [2]:
lfp_filepath = "../../../data/Illusion/lfp_testing/probeA_lfp.nwb"
stim_filepath = "../../../data/Illusion/lfp_testing/spike_times.nwb"

In [3]:
lfp_io = NWBHDF5IO(lfp_filepath, mode="r", load_namespaces=True)
lfp_file = lfp_io.read()
stim_io = NWBHDF5IO(stim_filepath, mode="r", load_namespaces=True)
stim_file = stim_io.read() 

  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


In [4]:
lfp = lfp_file.acquisition["probe_0_lfp_data"]
series = lfp.data

In [5]:
stimulus_names = list(stim_file.intervals.keys())
print(stimulus_names)

['ICkcfg0_presentations', 'ICkcfg1_presentations', 'ICwcfg0_presentations', 'ICwcfg1_presentations', 'RFCI_presentations', 'invalid_times', 'sizeCI_presentations', 'spontaneous_presentations']


### Visualizing stimuli

In [6]:
stim_name = "ICwcfg1_presentations"
stim_table = stim_file.intervals[stim_name]

print({frame for frame in stim_table.frame})

{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0}


In [7]:
print(stim_table[0:50])

    start_time   stop_time stimulus_name  stimulus_block  frame  \
id                                                                
0    99.256575   99.656919       ICwcfg1             0.0    0.0   
1    99.656919  100.057262       ICwcfg1             0.0    0.0   
2   100.057262  100.457598       ICwcfg1             0.0    0.0   
3   100.457598  100.857927       ICwcfg1             0.0    0.0   
4   100.857927  101.258255       ICwcfg1             0.0    0.0   
5   101.258255  101.658603       ICwcfg1             0.0    0.0   
6   101.658603  102.058950       ICwcfg1             0.0    0.0   
7   102.058950  102.459292       ICwcfg1             0.0    0.0   
8   102.459292  102.859629       ICwcfg1             0.0    0.0   
9   102.859629  103.259965       ICwcfg1             0.0    0.0   
10  103.259965  103.660305       ICwcfg1             0.0    0.0   
11  103.660305  104.060644       ICwcfg1             0.0    0.0   
12  104.060644  104.460976       ICwcfg1             0.0    0.

49  [(50, 1, timestamps pynwb.base.TimeSeries at 0...  


In [8]:
select_frame = 13.0

In [9]:
### extract timestamps for given stimulus frame

stim_timestamps = []
prev_frame = None
for i in range(len(stim_table)):
    row = stim_table[i]
    if row.frame[i] != prev_frame and row.frame[i] == select_frame:
        stim_timestamps.append(row.start_time[i])
    prev_frame = row.frame[i]
print(stim_timestamps)

[128.08084646720096, 139.29023498262623, 142.492946467201, 148.09768446720096, 175.32051498262624, 188.13134646720096, 236.17172846720095, 242.577108467201, 250.583848467201, 254.58722046720095, 260.19194246720093, 329.8504967249136, 343.4619449826262, 347.4652849826262, 358.6747204672009, 363.47873498262624, 370.68482046720095, 371.4854849826263, 375.4888249826262, 377.0902127249136, 411.51910498262623, 428.33328046720095, 447.5494449826262, 481.9783927249136, 502.795910467201, 505.9985747249136, 510.802626467201, 523.6133649826262, 526.0154287249137, 529.2181067249136, 541.2282127249135, 551.6369649826262, 559.6436749826262, 562.846384467201, 571.6537649826262, 589.2686027249135, 590.0692787249136, 592.4712984672009, 596.474656467201, 603.6807049826263, 614.890148467201, 617.2921567249136, 618.0928247249136, 630.903616467201, 651.7210749826262, 661.3291587249136, 666.9338924672011, 674.1399267249136, 687.7513849826262, 695.7581149826262, 702.1635207249136, 719.7783249826263, 720.5790

## BREAK

In [None]:
# finding nearest neighbor in lfp time for each stim time

intp_stim_timestamps = []
lfp_ts_idx = 0
for stim_ts in stim_timestamps:
    nearest_lfp_ts = lfp.timestamps[lfp_ts_idx]
    # find nearest neighbor to stim timestamp within lfp timestamps
    while(stim_ts - nearest_lfp_ts > 0):
        diff_t = stim_ts - nearest_lfp_ts
        lfp_ts_idx += 1
        nearest_lfp_ts = lfp.timestamps[lfp_ts_idx]
    # save nearest neighbor timestamp
    intp_stim_timestamps.append(nearest_lfp_ts)
print(intp_stim_timestamps)

3 arrays:
stim_timestamps
lfp.timestamps
lfp.data

goal: using lfp.timestamps, interpolate every lfp.data element to a stim_timestamps element
    iterpolate lfp.timestamps to7 stim_timestamps to create interpolated_timestamps (the main axis) which should be same length as lfp.data
    then use interpolated_timestamps as display axis for lfp data
    view window can then be filtered based on this axis
    


1. align stimulus and lfp
2. identify events within lfp
3. get window of lfp around event
4. split lfp eventwise into 3d array
5. calculate average across events
6. calculate average across nodes and events

In [15]:
stop_time = min(lfp.timestamps[-1],stim_timestamps[-1])
time_axis = np.arange(0,stop_time,step=0.001)
print(len(time_axis))

4357237


In [16]:
lfp.timestamps.shape

(11252278,)

In [17]:
lfp.data.shape

(11252278, 89)

In [18]:
f = interpolate.interp1d(lfp.timestamps, lfp.data, axis=0, kind="nearest", fill_value="extrapolate")
intp_lfp = f(time_axis)
print((len(lfp.data)))
print(lfp.data[0:100])

print(len(intp_lfp))
print(intp_lfp[0:100])


11252278
[[ 5.55750012e-05  6.37650010e-05  3.97799995e-05 ... -1.95000007e-06
   7.99499958e-06  7.99499958e-06]
 [ 3.12000011e-05  3.72450013e-05 -1.71600004e-05 ... -1.17000000e-05
   1.38449996e-05 -1.17000002e-06]
 [ 7.27349980e-05  9.88649990e-05  5.92799988e-05 ... -2.00849991e-05
  -1.15049997e-05 -5.84999998e-06]
 ...
 [-1.81740004e-04 -1.41960001e-04 -2.02800002e-04 ... -4.17299998e-05
  -3.50999994e-06 -6.43500016e-06]
 [-1.82324991e-04 -1.60290001e-04 -2.15474996e-04 ... -3.25649999e-05
  -3.31499996e-06  1.20899995e-05]
 [-2.08650003e-04 -1.86809993e-04 -2.50574987e-04 ... -2.82750007e-05
   5.84999998e-06 -1.95000007e-06]]
4357237
[[ 5.5575e-05  6.3765e-05  3.9780e-05 ... -1.9500e-06  7.9950e-06
   7.9950e-06]
 [ 5.5575e-05  6.3765e-05  3.9780e-05 ... -1.9500e-06  7.9950e-06
   7.9950e-06]
 [ 5.5575e-05  6.3765e-05  3.9780e-05 ... -1.9500e-06  7.9950e-06
   7.9950e-06]
 ...
 [ 5.5575e-05  6.3765e-05  3.9780e-05 ... -1.9500e-06  7.9950e-06
   7.9950e-06]
 [ 5.5575e-05  6.3

In [None]:
intp_lfp_measurements = []
lfp_idx = 0
for ts in time_axis:
    nearest_lfp_ts = lfp.timestamps[lfp_idx]
    # find nearest neighbor to timestamp within lfp timestamps
    print(ts)
    while(nearest_lfp_ts - ts > 0):
        diff_t = nearest_lfp_ts - ts
        lfp_idx += 1
        nearest_lfp_ts = lfp.timestamps[lfp_idx]
        print(lfp_idx, nearest_lfp_ts)
    # save nearest neighbor timestamp
    print("saving",lfp_idx)
    intp_lfp_measurements.append(lfp.data[lfp_idx])
    break
print(intp_lfp_measurements)

## BREAK

In [None]:
### get closest corresponding lfp index to the stimulus timestamps

lfp_events = []
lfp_ts_idx = 0
for stim_ts in stim_timestamps:

    lfp_ts = lfp.timestamps[lfp_ts_idx]
    while(lfp_ts < stim_ts):
        lfp_ts_idx += 1
        lfp_ts = lfp.timestamps[lfp_ts_idx]
    lfp_ts_idx -= 1
    lfp_events.append(lfp_ts_idx)

print(lfp_events)

In [None]:
interval_start = -0.05
interval_end = 0.25

In [None]:
# get lfp data intervals for viewing

if interval_start > 0:
    raise ValueError("interval start must be non-positive")
if interval_end <= 0:
    raise ValueError("interval end must be positive")
    
lfp_event_intervals = []
for lfp_event in lfp_events:
    event_ts = lfp.timestamps[lfp_event]
    
    # iterate backward to get start time idx
    start_idx = lfp_event
    start_ts = lfp.timestamps[start_idx]
    while(start_ts > event_ts + interval_start):
        start_idx -= 1
        if start_idx < 0:
            break
        start_ts = lfp.timestamps[start_idx]
    
    # iterate forward to get end time idx
    end_idx = lfp_event
    end_ts = lfp.timestamps[end_idx]
    while(end_ts < event_ts + interval_end):
        end_idx += 1
        if end_idx >= len(lfp.timestamps):
            break
        end_ts = lfp.timestamps[end_idx]
    
    if start_idx >= end_idx:
        raise ValueError("interval too small")
        
    lfp_event_intervals.append((start_idx, end_idx))
    print(f"Interval for event at {lfp_event}: {start_idx},{end_idx}")

In [None]:
node = 1
event_num = 100

In [None]:
# nodewise view
%matplotlib inline
fig, ax = plt.subplots(figsize=(8, 6))
start_idx, end_idx = lfp_event_intervals[event_num]
ax.plot(lfp.data[start_idx:end_idx,node])
plt.show()