In [1]:
# Reload edited modules (e.g. ecephys_analyses) automatically before executing each cell,
# so changes to project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import json 
import ecephys_analyses as ea
from sglxarray import load_contiguous_triggers, load_trigger
from ecephys.xrsig import get_kcsd
from ecephys.utils import zscore_to_value

In [3]:
def get_spw_detection_parameters(
    subject,
    experiment,
    alias,
    probe,
    files,
    detection_threshold_zscore=2.5,
    boundary_threshold_zscore=1,
    electrode_pitch=0.020,
    minimum_duration=0.005,
    overwrite=False,
):
    """Estimate sharp-wave (SPW) detection parameters and save them as JSON.

    Loads the hippocampal channels from the selected LFP files and estimates a
    kCSD over them, with L-curve parameter selection. It then sums the CSD over
    the stratum radiatum channels and negates it. The requested z-scores of
    that combined signal are converted into absolute detection and boundary
    thresholds. The thresholds and the kCSD estimation parameters are written
    to ``sharp_wave_detection_params.json`` in the subject's experiment
    directory.

    Parameters
    ----------
    subject, experiment, alias, probe : str
        Identify the recording; passed through to ``ecephys_analyses`` lookups.
    files : slice or index
        Selects which of the alias's LFP bin files are used for estimation,
        e.g. ``slice(None, 1)`` for only the first file.
    detection_threshold_zscore : float
        Z-score of the combined CSD that defines the detection threshold.
    boundary_threshold_zscore : float
        Z-score of the combined CSD that defines the boundary threshold.
    electrode_pitch : float
        Spacing between adjacent loaded channels. It is used to build the
        electrode positions and as ``gdx`` for kCSD. Presumably in mm
        (20 um by default) -- TODO confirm.
    minimum_duration : float
        Minimum event duration stored in the parameters (presumably seconds).
    overwrite : bool
        If False (default), refuse to replace an existing parameters file.

    Returns
    -------
    dict
        The parameters that were written to disk.

    Raises
    ------
    FileExistsError
        If the parameters file already exists and ``overwrite`` is False.
        This is checked before any data is loaded, so no expensive work is wasted.
    """
    params_path = ea.get_experiment_file("sharp_wave_detection_params.json", experiment, subject)
    # Fail fast: the L-curve kCSD fit below is expensive, so don't run it only
    # to discover afterwards that the output file cannot be written.
    if params_path.exists() and not overwrite:
        raise FileExistsError(f"{params_path} already exists; pass overwrite=True to replace it.")

    sr_chans = ea.get_channels(subject, experiment, probe, 'stratum_radiatum')
    hpc_chans = ea.get_channels(subject, experiment, probe, 'hippocampus')
    internal_reference = ea.get_channels(subject, experiment, probe, 'internal_reference')

    bin_paths = ea.get_lfp_bin_paths(subject, experiment, alias, probe=probe)[files]
    files_used = [str(path) for path in bin_paths]

    sig = load_contiguous_triggers(bin_paths, hpc_chans)

    # Every loaded channel gets a position. The internal-reference channels
    # are then excluded from the fit via `drop_chans`.
    ele_pos = np.arange(len(sig.channel)) * electrode_pitch
    csd = get_kcsd(sig, ele_pos, drop_chans=internal_reference, do_lcurve=True, gdx=electrode_pitch)

    # Collapse the stratum radiatum CSD into one trace and negate it. Presumably
    # this makes current sinks positive so that thresholds are upward crossings
    # -- confirm against the detector.
    sr_csd = csd.swap_dims({'pos': 'channel'}).sel(channel=sr_chans)
    combined_csd = -sr_csd.sum(dim='channel')
    detection_threshold = zscore_to_value(combined_csd.values, detection_threshold_zscore)
    boundary_threshold = zscore_to_value(combined_csd.values, boundary_threshold_zscore)

    kcsd = csd.kcsd
    csd_params = dict(
        files_used_for_parameter_estimation=files_used,
        electrode_pitch=electrode_pitch,
        xmin=kcsd.xmin,
        xmax=kcsd.xmax,
        n_estm=kcsd.n_estm,
        gdx=kcsd.gdx,
        lambd=kcsd.lambd,
        R=kcsd.R,
        csd_channels=hpc_chans.tolist(),
        channels_omitted_from_csd_estimation=internal_reference.tolist(),
        ele_pos=kcsd.ele_pos.tolist(),
    )
    spw_params = dict(
        files_used_for_parameter_estimation=files_used,
        detection_threshold_zscore=detection_threshold_zscore,
        boundary_threshold_zscore=boundary_threshold_zscore,
        detection_threshold=detection_threshold,
        boundary_threshold=boundary_threshold,
        minimum_duration=minimum_duration,
        detection_chans=sr_chans.tolist(),
        csd_params=csd_params,
    )

    def _to_builtin(obj):
        # The json module cannot encode numpy scalars other than float64
        # (e.g. int64, float32) or numpy arrays; convert them to Python builtins.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    # Serialize before creating the file. If encoding fails, no half-written
    # file is left behind to block the next run via the exclusive-create guard.
    text = json.dumps(spw_params, indent=4, default=_to_builtin)

    params_path.parent.mkdir(parents=True, exist_ok=True)
    # "x" still refuses to clobber a file created after the early check above.
    with open(params_path, "w" if overwrite else "x") as params_file:
        params_file.write(text)

    return spw_params

In [4]:
get_spw_detection_parameters(
    subject="Adrian",
    experiment="novel_objects_deprivation",
    alias="recovery_sleep",
    probe="imec1",
    files=slice(None, 1),  # estimate from the first LFP file only
)

nChan: 385, nFileSamp: 18000083
Performing L-curve parameter estimation...
No lambda given, using defaults
min lambda 1e-12
max lambda 0.0126
min lambda 1e-12
max lambda 0.0126
l-curve (all lambda):  0.23
Best lambda and R =  0.0011704755183051712 ,  0.23
