In [1]:
%config IPCompleter.use_jedi = False
%pdb off
# %load_ext autoreload
# %autoreload 2
import sys
import traceback # for stack trace formatting
import importlib
from pathlib import Path
from benedict import benedict
import numpy as np

# required to enable non-blocking interaction:
# %gui qt
# !env QT_API="pyqt5"
%gui qt5
# %gui qt6
# from PyQt5.Qt import QApplication
# # start qt event loop
# _instance = QApplication.instance()
# if not _instance:
#     _instance = QApplication([])
# app = _instance

from copy import deepcopy
from numba import jit
import numpy as np
import pandas as pd
from benedict import benedict # https://github.com/fabiocaccamo/python-benedict#usage

# Pho's Formatting Preferences
# from pyphocorehelpers.preferences_helpers import set_pho_preferences, set_pho_preferences_concise, set_pho_preferences_verbose
# set_pho_preferences_concise()

## Pho's Custom Libraries:
from pyphocorehelpers.general_helpers import CodeConversion
from pyphocorehelpers.print_helpers import print_keys_if_possible, print_value_overview_only, document_active_variables

# pyPhoPlaceCellAnalysis:
from pyphoplacecellanalysis.General.Pipeline.NeuropyPipeline import NeuropyPipeline # get_neuron_identities

# NeuroPy (Diba Lab Python Repo) Loading
# from neuropy import core
from neuropy.analyses.placefields import PlacefieldComputationParameters
from neuropy.core.epoch import NamedTimerange
from neuropy.core.session.Formats.BaseDataSessionFormats import DataSessionFormatRegistryHolder
from neuropy.core.session.Formats.Specific.BapunDataSessionFormat import BapunDataSessionFormatRegisteredClass
from neuropy.core.session.Formats.Specific.KDibaOldDataSessionFormat import KDibaOldDataSessionFormatRegisteredClass
from neuropy.core.session.Formats.Specific.RachelDataSessionFormat import RachelDataSessionFormat
from neuropy.core.session.Formats.Specific.HiroDataSessionFormat import HiroDataSessionFormatRegisteredClass

## For computation parameters:
from neuropy.analyses.placefields import PlacefieldComputationParameters
from neuropy.utils.dynamic_container import DynamicContainer
from neuropy.utils.result_context import IdentifyingContext
from neuropy.core.session.Formats.BaseDataSessionFormats import find_local_session_paths

# from PendingNotebookCode import _perform_batch_plot, _build_batch_plot_kwargs
from pyphoplacecellanalysis.General.NonInteractiveWrapper import batch_load_session, batch_extended_computations, SessionBatchProgress, batch_programmatic_figures, batch_extended_programmatic_figures
from pyphoplacecellanalysis.General.Pipeline.NeuropyPipeline import PipelineSavingScheme

# Batch bookkeeping containers and the global "save to disk" switch.
session_batch_status = {}
session_batch_errors = {}
enable_saving_to_disk = False

# Known data roots, one per computer. The first one that exists on this machine is used, so the
# notebook no longer needs lines commented in/out by hand when switching computers.
_known_data_root_parent_paths = [
    Path(r'W:\Data'), # Windows Apogee
    Path(r'/media/MAX/Data'), # Diba Lab Workstation Linux
    Path(r'/Volumes/MoverNew/data'), # rMBP
]
# Fall back to the first candidate when none exist so that the assert below reports a concrete path.
global_data_root_parent_path = next((a_path for a_path in _known_data_root_parent_paths if a_path.exists()), _known_data_root_parent_paths[0])
assert global_data_root_parent_path.exists(), f"global_data_root_parent_path: none of the known data roots exist ({[str(a_path) for a_path in _known_data_root_parent_paths]}). Add this computer's data root to `_known_data_root_parent_paths`."

Automatic pdb calling has been turned OFF
build_module_logger(module_name="Spike3D.pipeline"):
	 Module logger com.PhoHale.Spike3D.pipeline has file logging enabled and will log to EXTERNAL\TESTING\Logging\debug_com.PhoHale.Spike3D.pipeline.log


# Load Pipeline

In [2]:
# ---- Load Data: discover the session folders for one animal / experiment ----
active_data_mode_name = 'kdiba'

# The data must be pre-processed into .mat files using the MATLAB script located here:
#     neuropy/data_session_pre_processing_scripts/KDIBA/IIDataMat_Export_ToPython_2022_08_01.m

# Context and folder shared by every session of this data format:
local_session_root_parent_context = IdentifyingContext(format_name=active_data_mode_name)
local_session_root_parent_path = global_data_root_parent_path / 'KDIBA'

# Active selection: animal `gor01`, experiment `one`.
# Alternatives that used to live here as commented-out code: ('gor01', 'two'), ('vvp01', 'one') and
# ('vvp01', 'two'); each of those called find_local_session_paths with blacklist=[].
local_session_parent_context = local_session_root_parent_context.adding_context(
    collision_prefix='animal',
    animal='gor01',
    exper_name='one',
)  # IdentifyingContext<('kdiba', 'gor01', 'one')>
local_session_parent_path = local_session_root_parent_path.joinpath(
    local_session_parent_context.animal,
    local_session_parent_context.exper_name,
)  # 'gor01', 'one'
local_session_paths_list, local_session_names_list = find_local_session_paths(
    local_session_parent_path,
    blacklist=['PhoHelpers', 'Spike3D-Minimal-Test', 'Unused'],
)

# One context per discovered session name, e.g. IdentifyingContext<('kdiba', 'gor01', 'one', '2006-6-07_11-26-53')>
local_session_contexts_list = [
    local_session_parent_context.adding_context(collision_prefix='sess', session_name=a_name)
    for a_name in local_session_names_list
]

# Mark every session that has no status yet as NOT_STARTED; statuses that are already set are kept.
for curr_session_basedir in local_session_paths_list:
    if session_batch_status.get(curr_session_basedir) is None:
        session_batch_status[curr_session_basedir] = SessionBatchProgress.NOT_STARTED

session_batch_status

local_session_names_list: ['2006-6-07_11-26-53', '2006-6-08_14-26-15', '2006-6-09_1-22-43', '2006-6-09_3-23-37', '2006-6-12_15-55-31', '2006-6-13_14-42-6']


{WindowsPath('W:/Data/KDIBA/gor01/one/2006-6-07_11-26-53'): <SessionBatchProgress.NOT_STARTED: 'NOT_STARTED'>,
 WindowsPath('W:/Data/KDIBA/gor01/one/2006-6-08_14-26-15'): <SessionBatchProgress.NOT_STARTED: 'NOT_STARTED'>,
 WindowsPath('W:/Data/KDIBA/gor01/one/2006-6-09_1-22-43'): <SessionBatchProgress.NOT_STARTED: 'NOT_STARTED'>,
 WindowsPath('W:/Data/KDIBA/gor01/one/2006-6-09_3-23-37'): <SessionBatchProgress.NOT_STARTED: 'NOT_STARTED'>,
 WindowsPath('W:/Data/KDIBA/gor01/one/2006-6-12_15-55-31'): <SessionBatchProgress.NOT_STARTED: 'NOT_STARTED'>,
 WindowsPath('W:/Data/KDIBA/gor01/one/2006-6-13_14-42-6'): <SessionBatchProgress.NOT_STARTED: 'NOT_STARTED'>}

# Single basedir (non-batch) testing:

In [3]:
%pdb off
# Single-session test: use the first discovered session folder.
basedir = local_session_paths_list[0] # NOT 3
print(f'basedir: {basedir}')

# ---- Load Pipeline ----
# force_reload=True skips the pickled pipeline; SKIP_SAVING means nothing is written back by this call.
# Variants used previously:
#   - saving_mode=PipelineSavingScheme.TEMP_THEN_OVERWRITE (instead of SKIP_SAVING)
#   - force_reload=False with active_pickle_filename='20221214200324-loadedSessPickle.pkl'
#   - force_reload=False with active_pickle_filename='loadedSessPickle - full-good.pkl'
curr_active_pipeline = batch_load_session(
    global_data_root_parent_path,
    active_data_mode_name,
    basedir,
    saving_mode=PipelineSavingScheme.SKIP_SAVING,
    force_reload=True,
    skip_extended_batch_computations=True,
    debug_print=False,
)
## SAVE AFTERWARDS!

Automatic pdb calling has been turned OFF
basedir: W:\Data\KDIBA\gor01\one\2006-6-07_11-26-53
Skipping loading from pickled file because force_reload == True.
Loading matlab import file results : W:\Data\KDIBA\gor01\one\2006-6-07_11-26-53\2006-6-07_11-26-53.epochs_info.mat... done.
Loading matlab import file results : W:\Data\KDIBA\gor01\one\2006-6-07_11-26-53\2006-6-07_11-26-53.position_info.mat... done.
Loading matlab import file results : W:\Data\KDIBA\gor01\one\2006-6-07_11-26-53\2006-6-07_11-26-53.spikes.mat... 



done.
Failure loading .position.npy. Must recompute.

Computing linear positions for all active epochs for session... Saving updated position results results : W:\Data\KDIBA\gor01\one\2006-6-07_11-26-53\2006-6-07_11-26-53.position.npy... 2006-6-07_11-26-53.position.npy saved
done.
	 force_recompute is True! Forcing recomputation of .interpolated_spike_positions.npy

Computing interpolate_spike_positions columns results : spikes_df... done.
	 Saving updated interpolated spike position results results : W:\Data\KDIBA\gor01\one\2006-6-07_11-26-53\2006-6-07_11-26-53.interpolated_spike_positions.npy... 2006-6-07_11-26-53.interpolated_spike_positions.npy saved
done.
Loading matlab import file results : W:\Data\KDIBA\gor01\one\2006-6-07_11-26-53\2006-6-07_11-26-53.laps_info.mat... done.
setting laps object.
session.laps loaded successfully!
Loading matlab import file results : W:\Data\KDIBA\gor01\one\2006-6-07_11-26-53\2006-6-07_11-26-53.replay_info.mat... done.
session.replays loaded success

ImportError: cv2 must be installed manually. Try to: <pip install opencv-python>

# Future: theta-dependent placefields: build separate placefields for each phase of theta (binned in theta). There should be one set (where the animal is representing the present) that nearly perfectly predicts the animal's location.
    # the rest of the variability 

    1. Basic Hilbert transform
    2. But Theta wave-shape (sawtooth) at higher running speeds.
        - do peak-to-trough and trough-to-peak separate
        ** Nat will send me something
        
- remember Eloy's theta-dependent placefields. I'm ashamed that I fucked up with Eloy.


# Imports

Nat's code for detecting the sawtooth theta (roughly lines 271-393) is here: https://github.com/diba-lab/ephys/blob/master/Analysis/python/LFP/scripts/theta_phase_stim_verify.py

It's all based on this paper: https://www.jneurosci.org/content/32/2/423

In [None]:
## Scratchpad for opto
import matplotlib.pyplot as plt
import numpy as np
# import Analysis.python.LFP.preprocess_data as pd

import scipy.signal as signal
import pickle
import os
# import Analysis.python.LFP.helpers as helpers

## LFP analysis functions from https://github.com/diba-lab/ephys/blob/master/Analysis/python/LFP/lfp_analysis.py

# instead of `import Analysis.python.LFP.lfp_analysis as lfp`
class lfp(object):
    """Namespace of static LFP helpers (Butterworth band-pass filtering and local-extremum lookup).

    Stands in for `import Analysis.python.LFP.lfp_analysis as lfp`, so every helper is a
    staticmethod that is called as `lfp.<name>(...)`.
    """

    ## Create Butterworth filter - copied from scipy-cookbook webpage
    @staticmethod
    def butter_bandpass(lowcut, highcut, fs, order=2):
        """
        Build the coefficients of a Butterworth band-pass filter. Copied from the scipy-cookbook webpage.
        :param lowcut: lower edge of the pass band, in Hz
        :param highcut: upper edge of the pass band, in Hz
        :param fs: sampling rate in Hz
        :param order: (optional) order passed to `signal.butter`, 2 = default
        :return: (b, a) numerator and denominator coefficient arrays
        """
        nyq = 0.5 * fs
        # signal.butter expects the band edges as fractions of the Nyquist frequency
        low = lowcut / nyq
        high = highcut / nyq
        b, a = signal.butter(order, [low, high], btype='band')

        return b, a


    ## filter data through butterworth filter
    @staticmethod
    def butter_bandpass_filter(data, lowcut, highcut, fs, type='filtfilt', order=2):
        """
        Filter data through a Butterworth band-pass filter. Copied from the scipy-cookbook webpage.
        :param data: array of data sampled at fs
        :param lowcut: lower edge of the pass band in Hz (e.g. 4)
        :param highcut: upper edge of the pass band in Hz (e.g. 10)
        :param fs: sampling rate in Hz (e.g. 30000)
        :param type: 'filtfilt' (default) filters both ways, 'lfilt' filters forward only (and likely induces a phase offset).
        :param order: (optional) default = 2 to match Siegle & Wilson, eLife (2014)
        :return: filt_data: filtered data
        :raises ValueError: if `type` is neither 'lfilt' nor 'filtfilt'
        """
        if type not in ('lfilt', 'filtfilt'):
            raise ValueError(f"type must be 'lfilt' or 'filtfilt', got {type!r}")

        # Must be qualified with the class name: a bare `butter_bandpass` is not in scope inside a staticmethod.
        b, a = lfp.butter_bandpass(lowcut, highcut, fs, order=order)
        if type == 'lfilt':
            filt_data = signal.lfilter(b, a, data)
        else:
            filt_data = signal.filtfilt(b, a, data)

        return filt_data


    ## Peak-trough detection via Belluscio et al. (2012) J. Neuro
    @staticmethod
    def get_local_extrema(trace, type='max'):
        """ Get the single local extremum of a trace, assuming it occurs near the middle of the trace.

        The comparison window is half the trace length on each side, so a sample only qualifies when it
        strictly beats every sample within that window.
        :param trace: lfp trace (1-D array-like)
        :param type: 'max' (default) or 'min'
        :return: index in trace where the max/min is located. np.nan if exactly one such extremum is not
            found (e.g. the extreme value sits at the edge of the trace, or the trace has fewer than 2 samples).
        :raises ValueError: if `type` is neither 'max' nor 'min'
        """
        if type not in ('max', 'min'):
            raise ValueError(f"type must be 'max' or 'min', got {type!r}")

        trace = np.asarray(trace)
        half_width = int(len(trace)/2)
        if half_width < 1:
            # scipy rejects order < 1; a trace this short cannot contain an interior extremum anyway
            return np.nan

        if type == 'max':
            temp = signal.argrelmax(trace, order=half_width)[0]
        else:
            temp = signal.argrelmin(trace, order=half_width)[0]

        if temp.size == 1:
            ind_rel_extreme = temp[0]
        else:
            ind_rel_extreme = np.nan

        return ind_rel_extreme

    
from neuropy.analyses import oscillations
## Plot trace in a nice working window

In [None]:
%pdb off

In [None]:
# Grab the session's EEG/LFP file handle and read its signal into `traces`.
# NOTE(review): get_signal() is called with no arguments here - presumably that returns every channel; confirm against neuropy's BinarysignalIO.
lfpFile = curr_active_pipeline.sess.eegfile # neuropy.io.binarysignalio.BinarysignalIO
traces = lfpFile.get_signal()

In [None]:
# Run neuropy's ripple detection on `traces` with the session's probe group; the bare last line displays the result.
ripple_epochs = oscillations.detect_ripple_epochs(traces, curr_active_pipeline.sess.probegroup)
ripple_epochs

In [None]:
# Re-read the signal for channel index 19 only (the same index a later cell assigns to `chan_plot`).
# NOTE(review): this rebinds `traces`, which the ripple-detection cell also reads - re-running that cell
# after this one would hand it this channel-19 result instead of the original get_signal() output.
traces = lfpFile.get_signal(channel_indx=19)
traces

In [None]:
# Inspect the file's `n_channels` attribute (presumably the number of recorded channels - confirm).
lfpFile.n_channels

In [None]:
# Inspect the file's `n_frames` attribute (presumably the number of samples per channel - confirm).
lfpFile.n_frames

In [None]:
# Display the loaded session object of the active pipeline.
curr_active_pipeline.sess

In [None]:
# Channels of interest
chan_plot = 19  # channel you triggered off of
artifact_chan = 13  # this channel should have good stimulation artifact on it for reference...

# NOTE(review): `traces_ds`, `plot_bool`, `lowcut`, `highcut` and `SRlfp` are not defined in the visible
# cells above, and `order` is only assigned in the following cell - they must already exist in the kernel.
trace = traces_ds[plot_bool, chan_plot]

# Same band-pass applied two ways: forward-only ('lfilt') and forward-backward ('filtfilt').
trace_lfilt = lfp.butter_bandpass_filter(
    trace, lowcut=lowcut, highcut=highcut, fs=SRlfp, type='lfilt', order=order,
)
trace_filtfilt = lfp.butter_bandpass_filter(
    trace, lowcut=lowcut, highcut=highcut, fs=SRlfp, type='filtfilt', order=order,
)

In [None]:
## Peak-trough method (Belluscio et al. 2012 J Neuro) - fold into lfp_analysis.peak_trough_detect eventually
# from Nat's https://github.com/diba-lab/ephys/blob/master/Analysis/python/LFP/scripts/theta_phase_stim_verify.py

## Needs (must already exist in the kernel): trace, trace_lfilt, SRlfp, time_plot, start_time, time_span, v_range
order = 2

lowcut_bell = 1  # Hz
highcut_bell = 80  # Hz
peak_trough_offset_sec = 0.07  # seconds to look for trough of wide-filtered trace next to 4-10Hz filtered trace

wide_filt = lfp.butter_bandpass_filter(trace, lowcut_bell, highcut_bell, SRlfp, order=order)

fig, ax = plt.subplots(1, 1, sharex=True, sharey=True)
fig.set_size_inches([26, 3])
hraw = ax.plot(time_plot, trace)
ax.plot(time_plot, wide_filt, 'm')
ax.plot(time_plot, trace_lfilt, 'k--')
ax.set_xlim([start_time*60, start_time*60 + time_span])
ax.set_ylim([-v_range, v_range])
ax.set_xlabel('Time (s)')  # must be a plain string: a list would be rendered as its repr
ax.set_ylabel('uV')

offset_frames = np.round(peak_trough_offset_sec*SRlfp)  # not used further down in this cell

# First detect peak and trough off narrowband filtered signal - do hilbert transform
# trough = -pi->pi, peak = 0 (- -> +)
trace_analytic = signal.hilbert(trace_lfilt)  # get real and imaginary parts of signal
trig_trace_phase = np.angle(trace_analytic)
# ax.plot(time_plot, trig_trace_phase*v_range/8, 'r-')
# A peak is where the phase goes from negative to non-negative between consecutive samples;
# a trough is where it goes from positive to non-positive (the +pi -> -pi wrap).
# One False is appended so the masks have the same length as the trace.
peak_bool = np.bitwise_and(trig_trace_phase[0:-1] < 0, trig_trace_phase[1:] >= 0)
peak_bool = np.append(peak_bool, False)
trough_bool = np.bitwise_and(trig_trace_phase[0:-1] > 0, trig_trace_phase[1:] <= 0)
trough_bool = np.append(trough_bool, False)

# Indices to peak and trough of narrowband trace
peak_inds = np.where(peak_bool)[0]
trough_inds = np.where(trough_bool)[0]

# Check that above code works...
# ax.plot(time_plot[peak_bool], trace_lfilt[peak_bool], 'r*')
# ax.plot(time_plot[trough_bool], trace_lfilt[trough_bool], 'g*')

##  Plot times between peak and trough - seems like looking 0.07 seconds to either side should be ok...
fig2, ax2 = plt.subplots(1, 2)
ax2[0].hist(np.diff(trough_inds)/SRlfp)
ax2[0].set_xlabel('Trough-to-trough times (s)')
ax2[1].hist(np.diff(peak_inds)/SRlfp)
ax2[1].set_xlabel('Peak-to-peak times (s)')

## now step through and find closest peak/trough in the wide-filtered trace when compared to the narrowband filtered trace.
# NOTE: an earlier comment here said this section was commented out to avoid overwriting existing values
# until downsampling is implemented. The code below is live and rebuilds both lists on every run.
wide_peak_inds = []
wide_trough_inds = []

# Step through and look for each trough in the WIDE filtered signal between two peaks in the NARROW filtered signal
# how fast is this compared to just running it on all the trace and looking for closest inds? Bet it depends on if I
# downsample first...

n = 0
for idp, idp1 in zip(peak_inds[0:-1], peak_inds[1:]):
    # get_local_extrema returns an index relative to the slice (or np.nan), so shift it back by idp
    wide_trough_inds.append(lfp.get_local_extrema(wide_filt[idp:idp1], type='min') + idp)
    n += 1
    if n % 100 == 0:  # progress printout every 100 cycles
        print(n)

n = 0
# Ditto to above but for peaks
for idt, idt1 in zip(trough_inds[0:-1], trough_inds[1:]):
    wide_peak_inds.append(lfp.get_local_extrema(wide_filt[idt:idt1], type='max') + idt)
    n += 1
    if n % 100 == 0:
        print(n)

## looks decent except when there is crappy theta. Filter out these epochs? Put on speed threshold?
wide_peak_inds_good = [idp for idp in wide_peak_inds if not np.isnan(idp)]
wide_trough_inds_good = [idt for idt in wide_trough_inds if not np.isnan(idt)]

ax.plot(time_plot[wide_peak_inds_good], wide_filt[wide_peak_inds_good], 'ro')
ax.plot(time_plot[wide_trough_inds_good], wide_filt[wide_trough_inds_good], 'go')

## Get rise and falling times of theta - trough = -pi/+pi, peak = 0

# The two lists can differ in length (the narrowband peak and trough counts may differ by one), which would
# make the element-wise subtraction fail; compare only the entries both lists have.
n_paired = min(len(wide_peak_inds), len(wide_trough_inds))
peak_minus_trough = np.array(wide_peak_inds[:n_paired], dtype=float) - np.array(wide_trough_inds[:n_paired], dtype=float)

# if peak times generally lead trough times, chop off first peak value
if np.nanmean(peak_minus_trough) < 0:
    peak_inds_use = wide_peak_inds[1:]
    trough_inds_use = wide_trough_inds[0:-1]
    next_trough_inds = wide_trough_inds[1:]
else:
    peak_inds_use = wide_peak_inds
    trough_inds_use = wide_trough_inds
    next_trough_inds = wide_trough_inds[1:]


# For every (trough, peak, next trough) triple, record five landmark indices and their nominal phases.
# Cycles without reliable landmarks contribute five NaNs so both lists stay aligned.
wave_phase_inds = []
wave_phases = []
for idt, idp, idt1 in zip(trough_inds_use, peak_inds_use, next_trough_inds):
    cycle_inds = [np.nan, np.nan, np.nan, np.nan, np.nan]
    if not np.any(np.isnan([idt, idp, idt1])) and idt < idp < idt1:  # only run below if you have reliable peak/trough info
        trace_snippet = wide_filt[idt:idt1]  # grab a snippet of the trace to use
        # Make sure trace is not all above or below zero and that peak/troughs are above/below zero
        snippet_is_usable = not (np.all(trace_snippet <= 0) or np.all(trace_snippet >= 0)
                                 or trace_snippet[0] > 0 or trace_snippet[-1] > 0 or wide_filt[idp] < 0)
        if snippet_is_usable:
            snippet_inds = np.arange(idt, idt1)
            # last non-positive sample before the peak / first non-positive sample after it (snippet-relative)
            rise_candidates = np.where(np.bitwise_and(trace_snippet <= 0, snippet_inds < idp))[0]
            fall_candidates = np.where(np.bitwise_and(trace_snippet <= 0, snippet_inds > idp))[0]
            if rise_candidates.size > 0 and fall_candidates.size > 0:  # guard: np.max/np.min raise on empty input
                rise_zero = np.max(rise_candidates)
                fall_zero = np.min(fall_candidates)
                cycle_inds = [idt, idt + rise_zero, idp, idt + fall_zero, idt1-1]
    wave_phase_inds.extend(cycle_inds)
    wave_phases.extend([-np.pi, -np.pi/2, 0, np.pi/2, np.pi])
    

In [None]:
# now plot: overlay the piecewise wave phase (scaled by v_range/8) on the trace axis `ax`
wave_phase_inds_good = [idph for idph in wave_phase_inds if not np.isnan(idph)]
wave_phases_good = [ph for ph, idph in zip(wave_phases, wave_phase_inds) if not np.isnan(idph)]
# explicit integer dtype so the index array stays valid even when the list is empty
ax.plot(time_plot[np.asarray(wave_phase_inds_good, dtype=int)], np.asarray(wave_phases_good)*v_range/8, 'r-')


## histogram of rise times vs fall times overlaid to prove I'm doing things correctly
# The three index lists are not guaranteed to have equal lengths (`next_trough_inds` can be one shorter),
# which would make the element-wise subtraction fail; use only the cycles present in all three.
n_cycles = min(len(trough_inds_use), len(peak_inds_use), len(next_trough_inds))
cycle_trough_inds = np.array(trough_inds_use[:n_cycles], dtype=float)
cycle_peak_inds = np.array(peak_inds_use[:n_cycles], dtype=float)
cycle_next_trough_inds = np.array(next_trough_inds[:n_cycles], dtype=float)
rise_times = (cycle_peak_inds - cycle_trough_inds)/SRlfp  # trough -> peak, in seconds
fall_times = (cycle_next_trough_inds - cycle_peak_inds)/SRlfp  # peak -> next trough, in seconds

fig35, ax35 = plt.subplots(1, 2)
fig35.set_size_inches([13.5, 4.8])
for an_ax, phase_times, x_label in zip(ax35, (rise_times, fall_times), ('Rising Phase Times (s)', 'Falling Phase Times (s)')):
    an_ax.hist(phase_times, bins=20, range=(-0.15, 0.3))
    an_ax.set_title('Peak-Trough Method')
    an_ax.set_xlabel(x_label)
    # the text position is in data coordinates (x = 0.15 s, y = 1000 counts)
    an_ax.text(0.15, 1000, 'mean = ' + '{:.3f}'.format(np.nanmean(phase_times)) + ' sec')
