In [1]:
%reload_ext autoreload
%autoreload 2

In [3]:
import os
import cupy as cp
import numpy as np
import datajoint as dj
import spyglass as nd
import pandas as pd
import matplotlib.pyplot as plt
import multiprocessing
import pandas as pd

import pynwb

# ignore datajoint+jupyter async warnings
import warnings
warnings.simplefilter('ignore', category=DeprecationWarning)
warnings.simplefilter('ignore', category=ResourceWarning)

from spyglass.common import (Session, IntervalList,LabMember, LabTeam, Raw, Nwbfile,
                            Electrode,StateScriptFile)

from spyglass.common.nwb_helper_fn import get_nwb_copy_filename
from spyglass.common.common_task import TaskEpoch

# Here are the analysis tables specific to Shijie Gu
from spyglass.shijiegu.Analysis_SGU import TrialChoice,RippleTimes

In [4]:
import matlab.engine
eng=matlab.engine.start_matlab()

In [5]:
# MATLAB path
eng.addpath('/home/shijiegu/Documents/MATLAB/radial_sequence', nargout=0)

In [6]:
def translate_time(trodes_sample_time,sample_count,time_seconds):
    '''
    INPUT:
    trodes_sample_time, (n,), trodes time in sample count to be translated to system time in seconds
    sample_count: numpy array, (N,), trodes time in sample count for the whole recording
    time_seconds: numpy array, (N,), system time in seconds for the whole recording
    see also MATLAB counterpart translate_time
    
    RETURN: translated_sys_time, (n,), system time in seconds for inquired trodes sample time
    
    '''
    nan_ind=np.isnan(trodes_sample_time)
    xy,ind1,ind2=np.intersect1d(trodes_sample_time[~nan_ind],sample_count,return_indices=True)
    assert np.sum(~nan_ind)==len(ind2)
    translated_sys_time=np.zeros_like(trodes_sample_time)+np.nan
    translated_sys_time[~nan_ind]=time_seconds[ind2]
    
    return translated_sys_time*10**-9

    '''
    translated_sys_time=np.zeros(len(trodes_sample_time))+np.nan;
    for i in range(len(trodes_sample_time)):
        if not np.isnan(trodes_sample_time[i]):
            ind=np.argwhere(trodes_sample_time[i]<=sample_count).ravel()[0];
            translated_sys_time[i]=time_seconds[ind];
    '''

### Input nwb file, look at the epoch names

In [62]:
# the only cell to be edited
nwb_file_name = 'molly20220416.nwb'
nwb_copy_file_name = get_nwb_copy_filename(nwb_file_name)

In [63]:
nwb_file_abs_path = (Nwbfile & {'nwb_file_name':nwb_copy_file_name}
                     ).fetch1('nwb_file_abs_path')
io = pynwb.NWBHDF5IO(nwb_file_abs_path,'r')
nwbf = io.read()

sample_count=np.array(nwbf.processing['sample_count'].data_interfaces['sample_count'].data)
time_seconds=np.array(nwbf.processing['sample_count'].data_interfaces['sample_count'].timestamps)

  return uuid_from_stream(Path(filepath).open("rb"), init_string=init_string)
  arr = numpy.ndarray(selection.mshape, dtype=new_dtype)
  warn("Length of data does not match length of timestamps. Your data may be transposed. Time should be on "


In [64]:
def parse_behavior(nwb_name,epoch,time_seconds,sample_count):
    # get epoch name
    epoch_name=(TaskEpoch() &
                {'nwb_file_name':nwb_name,
                 'epoch':epoch}).fetch1('interval_list_name')
    
    # get the statescript file content
    ssfile=(StateScriptFile & {'nwb_file_name':nwb_name,
                   'epoch':epoch}).fetch_nwb()
    sscontent=ssfile[0]['file'].content

    # parse statescript: result is in LOG or SSLOG_dict
    [log,variablename]=eng.parse_behavior4python(sscontent,nargout=2)
    log_np=np.array(log)
    LOG=log_np.copy()
    LOG[:,0]=log_np[:,0]*30 # from ms to Trodes sample time: sample = ms * 30 sample /ms 
    LOG[:,2]=log_np[:,2]*30
    
    # get the session start time in seconds
    # Trodes re-start counts of sample at some point.
    # restrict to start time and end time allows us to go around the problem
    start_time=(IntervalList & {'nwb_file_name':nwb_copy_file_name,
                'interval_list_name':epoch_name}).fetch1('valid_times')[0][0]
    end_time=(IntervalList & {'nwb_file_name':nwb_copy_file_name,
                'interval_list_name':epoch_name}).fetch1('valid_times')[-1][-1]
    
    session_ind=np.logical_and(time_seconds>=start_time*10**9,
                               time_seconds<=end_time*10**9)
    sample_count_session=sample_count[session_ind]
    time_seconds_session=time_seconds[session_ind]
    
    # from Trodes sample time to seconds
    LOG[:,0]=translate_time(LOG[:,0],sample_count_session,time_seconds_session)
    LOG[:,2]=translate_time(LOG[:,2],sample_count_session,time_seconds_session)
    
    # Dataframe
    SSLOG=pd.DataFrame(LOG,columns=variablename,index=np.arange(log_np.shape[0])+1)
    
    # Dataframe to dictionary because Datajoint does not allow pd dataframe
    SSLOG_dict=SSLOG.to_dict()
    return SSLOG_dict

In [65]:
epoch_num_name=(TaskEpoch() & {'nwb_file_name':nwb_copy_file_name}).fetch('epoch','interval_list_name')
epoch_name=epoch_num_name[1]
epoch_num=epoch_num_name[0]
epoch_name

array(['01_Seq2Sleep1', '02_Seq2Session1', '03_Seq2Sleep2',
       '04_Seq2Session2', '05_Seq2Sleep3', '06_Seq2Session3',
       '07_Seq2Sleep4', '08_Seq2Session4', '09_Seq2Sleep5',
       '10_Seq2Session5', '11_Seq2Sleep6'], dtype=object)

In [66]:
# find run epochs
epoch_num2insert=[]
for i in range(len(epoch_name)):
    n=epoch_name[i] #name of epoch
    if n[-8:-1].lower()=='session': #all lower case in case typo in data input
        epoch_num2insert.append(epoch_num[i])  
epoch_num2insert

[2, 4, 6, 8, 10]

In [67]:
# for all run epochs, insert parsed result into TrialChoice table
for e in epoch_num2insert:
    parsedlog=parse_behavior(nwb_copy_file_name,e,time_seconds,sample_count)
    key={'nwb_file_name':nwb_copy_file_name,'epoch':e,'choice_reward':parsedlog}
    TrialChoice().make(key,replace=True)

  return uuid_from_stream(Path(filepath).open("rb"), init_string=init_string)
  arr = numpy.ndarray(selection.mshape, dtype=new_dtype)
  warn("Length of data does not match length of timestamps. Your data may be transposed. Time should be on "


Great. Ends meet.


  return uuid_from_stream(Path(filepath).open("rb"), init_string=init_string)


Great. Ends meet.


  return uuid_from_stream(Path(filepath).open("rb"), init_string=init_string)


Great. Ends meet.


  return uuid_from_stream(Path(filepath).open("rb"), init_string=init_string)


Great. Ends meet.


  return uuid_from_stream(Path(filepath).open("rb"), init_string=init_string)


Great. Ends meet.


In [68]:
TrialChoice() & {'nwb_file_name':nwb_copy_file_name}

nwb_file_name  name of the NWB file,epoch  the session epoch for this task and apparatus(1 based),"epoch_name  session name, get from IntervalList","choice_reward  pandas dataframe, choice"
molly20220416_.nwb,2,02_Seq2Session1,=BLOB=
molly20220416_.nwb,4,04_Seq2Session2,=BLOB=
molly20220416_.nwb,6,06_Seq2Session3,=BLOB=
molly20220416_.nwb,8,08_Seq2Session4,=BLOB=
molly20220416_.nwb,10,10_Seq2Session5,=BLOB=


### Confrim that the choice reward information is in ```TrialChoice```

In [51]:
logtest=(TrialChoice & {'nwb_file_name':'molly20220420_.nwb','epoch':10}).fetch1('choice_reward')

In [53]:
pd.DataFrame(logtest)

Unnamed: 0,timestamp_H,Home,timestamp_O,OuterWellIndex,rewardNum,current,future_H,future_O,past,past_reward
1,1.650492e+09,1.0,1.650492e+09,3.0,2.0,3.0,3.0,4.0,,
2,1.650492e+09,1.0,1.650492e+09,4.0,2.0,4.0,4.0,2.0,3.0,3.0
3,1.650492e+09,1.0,1.650492e+09,2.0,2.0,2.0,2.0,3.0,4.0,4.0
4,1.650492e+09,1.0,1.650492e+09,3.0,1.0,3.0,3.0,1.0,2.0,2.0
5,1.650492e+09,1.0,1.650492e+09,1.0,2.0,1.0,1.0,4.0,3.0,2.0
...,...,...,...,...,...,...,...,...,...,...
77,1.650494e+09,1.0,1.650494e+09,4.0,2.0,4.0,4.0,1.0,2.0,3.0
78,1.650494e+09,1.0,1.650494e+09,1.0,1.0,1.0,1.0,3.0,4.0,4.0
79,1.650494e+09,1.0,1.650494e+09,3.0,1.0,3.0,3.0,2.0,1.0,4.0
80,1.650494e+09,1.0,1.650494e+09,2.0,3.0,2.0,2.0,,3.0,4.0
