CSD estimation is about twice as fast if you can reuse the kCSD parameters saved in another file and skip the L-curve search.

In [1]:
# Automatically re-import edited modules (e.g. ecephys, ecephys_analyses) before
# each cell runs, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import warnings
# "once": print only the first occurrence of each matching warning instead of
# repeating it every time it is triggered (e.g. once per processed file).
warnings.filterwarnings(action='once')

In [3]:
from sglxarray import load_trigger
from ecephys.xrsig import get_kcsd
import ecephys_analyses as ea

In [4]:
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.signal import find_peaks
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from datetime import datetime
import json

In [5]:
def plot_epoched_profile(da, figsize=(36, 8)):
    """Draw a channel-by-epoch DataArray as a heatmap on a new figure.

    Rows are labeled with ``da.channel`` values and columns with ``da.time``
    values rounded to whole numbers. To keep the axes legible, only every 2nd
    time label and every 4th channel label is shown. No colorbar is drawn.

    Parameters
    ----------
    da : xarray.DataArray
        2-D array with ``channel`` and ``time`` coordinates. Rows are
        labeled as channels, so ``da`` is assumed to be (channel, time) --
        TODO confirm against callers.
    figsize : tuple of float
        Figure size passed to ``plt.subplots``.
    """
    _, ax = plt.subplots(figsize=figsize)
    time_labels = da.time.values.round()
    channel_labels = da.channel.values
    sns.heatmap(da, xticklabels=time_labels, yticklabels=channel_labels, cbar=False, ax=ax)

    # Heatmap cell i spans [i, i+1], so its tick sits at i + 0.5. Thin the tick
    # positions AND their labels together: re-setting only the positions leaves
    # it to the matplotlib version whether each kept tick retains its own label.
    ax.set_xticks((np.arange(len(time_labels)) + 0.5)[::2])
    ax.set_xticklabels(time_labels[::2])
    ax.set_yticks((np.arange(len(channel_labels)) + 0.5)[::4])
    ax.set_yticklabels(channel_labels[::4])
    ax.set(xlabel='Epoch center time (s)', ylabel='Channel')

In [6]:
def plot_profile(da, figsize=(36, 5), ylabel=None, negative_peaks=False, *, prominence=1000, distance=10):
    """Bar-plot one value per channel and mark peaks with red crosses.

    Parameters
    ----------
    da : xarray.DataArray
        1-D profile with a ``channel`` coordinate (assumed one value per
        channel -- TODO confirm against callers).
    figsize : tuple of float
        Figure size passed to ``plt.subplots``.
    ylabel : str or None
        Y-axis label.
    negative_peaks : bool
        If True, mark local minima (e.g. deepest CSD sinks) instead of
        local maxima.
    prominence, distance :
        Forwarded to ``scipy.signal.find_peaks``. ``prominence`` is in the
        units of ``da``; ``distance`` is the minimum number of channels
        between peaks. The defaults are the values this function always
        used. The same absolute prominence may not suit profiles with
        different units (e.g. variance vs. CSD), hence the parameters.
    """
    _, ax = plt.subplots(figsize=figsize)
    values = da.values
    positions = np.arange(len(da.channel))

    # The bar axis is categorical, so bars sit at 0..n-1. That lets peak
    # indices from find_peaks double as x positions for the markers below.
    sns.barplot(x=positions, y=da, color='steelblue', ax=ax)
    ax.set(xlabel="Channel", ylabel=ylabel)
    ax.set_xticks(positions)
    ax.set_xticklabels(da.channel.values, rotation=90)

    # find_peaks only detects maxima, so negate the profile to find minima.
    search_values = -values if negative_peaks else values
    peaks, _ = find_peaks(search_values, prominence=prominence, distance=distance)
    # Plot the markers at the original (un-negated) values.
    sns.scatterplot(x=peaks, y=values[peaks], marker='x', color='red', ax=ax)

In [7]:
def write_file_report(csd, epoch_length, pdf_path):
    """Write a PDF of epoched CSD variance and CSD nadir (minimum) profiles.

    The CSD is split along ``time`` into consecutive, non-overlapping epochs
    of ``epoch_length`` seconds; a trailing partial epoch is dropped. The PDF
    contains, in order:

    - an overview heatmap of per-epoch variance;
    - an overview heatmap of negated per-epoch minima;
    - for every epoch, a variance bar plot and a nadir bar plot with peaks
      marked.

    Nothing is written when the CSD holds fewer samples than one epoch.

    Parameters
    ----------
    csd : xarray.DataArray
        CSD with a ``time`` dimension, a ``channel`` coordinate, and an
        ``fs`` attribute used as the sampling rate (assumed Hz -- confirm).
    epoch_length : int or float
        Epoch duration in seconds.
    pdf_path : pathlib.Path
        Base output path. ``-{epoch_length}s`` is inserted before the suffix,
        e.g. ``x.pdf`` -> ``x-300s.pdf``.
    """
    samples_per_epoch = int(csd.fs * epoch_length)
    # With fewer samples than one epoch, coarsen(boundary='trim') would leave no
    # epochs to plot. Count samples rather than comparing the last timestamp to
    # epoch_length, which is only equivalent when the time axis starts at 0.
    if csd.sizes["time"] < samples_per_epoch:
        return

    # Each epoch's time coordinate is its first timestamp (coord_func "min").
    # NOTE(review): plot_epoched_profile labels this axis 'Epoch center time (s)';
    # confirm whether start or center time is intended.
    epochs = csd.coarsen(time=samples_per_epoch, boundary='trim', coord_func={"time": "min"})
    epoched_csd_nadirs = epochs.min()
    epoched_csd_variance = epochs.var()

    pdf_path = pdf_path.parent / f"{pdf_path.stem}-{epoch_length}s{pdf_path.suffix}"
    with PdfPages(pdf_path) as pdf:
        plot_epoched_profile(epoched_csd_variance)
        plt.title(f'CSD variance, {epoch_length}s epochs')
        pdf.savefig()
        plt.close()

        # Negate so the deepest sinks (most negative minima) plot as the largest values.
        plot_epoched_profile(-epoched_csd_nadirs)
        plt.title(f'CSD nadirs, {epoch_length}s epochs')
        pdf.savefig()
        plt.close()

        # Select epochs by dimension name rather than by axis position.
        for epoch, epoch_time in enumerate(epoched_csd_nadirs.time.values):
            plot_profile(epoched_csd_variance.isel(time=epoch), ylabel="CSD variance")
            plt.title(f'CSD variance, epoch: {epoch}, time: {epoch_time}s')
            pdf.savefig()
            plt.close()

            plot_profile(epoched_csd_nadirs.isel(time=epoch), ylabel="Deepest CSD sink (mA/mm)", negative_peaks=True)
            plt.title(f'CSD nadirs, epoch: {epoch}, time: {epoch_time}s')
            pdf.savefig()
            plt.close()

        d = pdf.infodict()
        d['Title'] = pdf_path.stem
        d['Author'] = 'Graham Findlay'
        d['Subject'] = 'CSD nadir profiles for tracking CA1 drift'
        d['CreationDate'] = datetime.now()

In [8]:
def write_condition_reports(subject, experiment, probe, use_spw_params=True):
    """Compute kCSD for every LFP file of one probe and write drift-tracking PDFs.

    For each LFP ``.bin`` file from ``ea.get_lfp_bin_paths``:

    - the ``drift_tracking`` channels are loaded;
    - a kCSD estimate is computed;
    - ``write_file_report`` is called with 300 s and then 60 s epochs, using
      the file's ``CSD_SR_markers.pdf`` analysis counterpart as the base
      output path.

    A timestamped progress line is printed after each file.

    Parameters
    ----------
    subject, experiment, probe : str
        Identify the recording, e.g. ("Adrian", "novel_objects_deprivation", "imec1").
    use_spw_params : bool
        If True, reuse the kCSD parameters stored under ``"csd_params"`` in
        the experiment's ``sharp_wave_detection_params.json`` and skip the
        L-curve search. If False, drop the ``internal_reference`` channels and
        run kCSD with the L-curve search, which is slower.
    """
    drift_tracking_chans = ea.get_channels(subject, experiment, probe, "drift_tracking")

    # Everything get_kcsd needs besides the signal is the same for every file,
    # so resolve it once before the loop.
    if use_spw_params:
        spw_params_path = ea.get_experiment_file("sharp_wave_detection_params.json", experiment, subject)
        with open(spw_params_path) as spw_params_file:
            csd_params = json.load(spw_params_file)["csd_params"]
        electrode_pitch = csd_params["electrode_pitch"]
        kcsd_kwargs = dict(
            drop_chans=csd_params["channels_omitted_from_csd_estimation"],
            do_lcurve=False,
            gdx=csd_params["gdx"],
            R_init=csd_params["R"],
            lambd=csd_params["lambd"],
        )
    else:
        # NOTE(review): presumably 20 um expressed in mm (CSD is plotted in mA/mm) -- confirm.
        electrode_pitch = 0.020
        kcsd_kwargs = dict(
            drop_chans=ea.get_channels(subject, experiment, probe, "internal_reference"),
            do_lcurve=True,
            gdx=electrode_pitch,
        )

    bin_paths = ea.get_lfp_bin_paths(subject, experiment, probe=probe)
    pdf_paths = ea.get_analysis_counterparts(bin_paths, "CSD_SR_markers.pdf", subject)

    for bin_path, pdf_path in zip(bin_paths, pdf_paths):
        sig = load_trigger(bin_path, drift_tracking_chans)
        # Equally spaced electrode positions, one per loaded channel.
        electrode_positions = np.arange(len(sig.channel)) * electrode_pitch
        csd = get_kcsd(sig, electrode_positions, **kcsd_kwargs).swap_dims({"pos": "channel"})

        for epoch_length in (300, 60):
            write_file_report(csd, epoch_length, pdf_path)

        current_time = datetime.now().strftime("%H:%M:%S")
        print(f"{current_time}: Finished {bin_path}")

In [None]:
# Generate the CSD drift-tracking reports for this subject/experiment/probe,
# reusing the saved sharp-wave kCSD parameters (use_spw_params defaults to True).
write_condition_reports("Adrian", "novel_objects_deprivation", "imec1")