In [1]:
%load_ext autoreload
%autoreload 2

# Imports and definitions

In [2]:
import numpy as np
import pandas as pd
from datetime import datetime
import xarray as xr
from pathlib import Path

In [3]:
from ecephys_analyses.data import channel_groups, paths, load
from ecephys.sglx_utils import load_timeseries
import ecephys.signal.timefrequency as tfr

In [4]:
# Keep DataArray/Dataset attrs through arithmetic and reductions (e.g. `cx - wm.values`
# and `.median(dim="channel")` in the cells below), so metadata such as the 'units' and
# 'file_start' attrs survives. The trailing `;` suppresses the useless repr output.
xr.set_options(keep_attrs=True);

<xarray.core.options.set_options at 0x7fd4b1adf0d0>

In [5]:
def get_spectrogram(sig, *, window_s=4.0, overlap_s=1.0, max_freq=300.0, **kwargs):
    """Compute a channel-median Welch spectrogram of a multichannel signal.

    Parameters
    ----------
    sig : xarray.DataArray
        Multichannel signal exposing `fs` (sampling rate, Hz), `units` and
        `fileCreateTime` (read as attributes), plus `time` and `channel` coords.
        Presumably laid out as (time, channel); TODO confirm against
        `tfr.parallel_spectrogram_welch`.
    window_s : float, optional
        Length of each spectrogram segment in seconds (default 4).
    overlap_s : float, optional
        Overlap between consecutive segments in seconds (default 1).
        Must satisfy ``0 <= overlap_s < window_s``.
    max_freq : float or None, optional
        Highest frequency (Hz) kept in the output (default 300).
        ``None`` keeps every frequency.
    **kwargs
        Forwarded unchanged to `tfr.parallel_spectrogram_welch`.

    Returns
    -------
    xarray.DataArray
        Spectrogram with dims ("frequency", "time"), taken as the median across
        channels. Times are offset by the earliest time of `sig`. Attrs hold
        'units' (`<sig units>^2/Hz`) and 'file_start' (`sig.fileCreateTime`).

    Raises
    ------
    ValueError
        If `overlap_s` is negative or not shorter than `window_s`.
    """
    if not 0 <= overlap_s < window_s:
        raise ValueError(
            f"Need 0 <= overlap_s < window_s, got overlap_s={overlap_s}, window_s={window_s}."
        )
    nperseg = int(window_s * sig.fs)
    noverlap = int(overlap_s * sig.fs)
    freqs, spg_times, spg_data = tfr.parallel_spectrogram_welch(
        sig.values, sig.fs, nperseg=nperseg, noverlap=noverlap, **kwargs
    )
    # Segment times are presumably relative to the first sample passed in;
    # shift them onto the signal's own time axis.
    spg_times = spg_times + sig.time.values.min()
    spg = xr.DataArray(
        spg_data,
        dims=("frequency", "time", "channel"),
        coords={"frequency": freqs, "time": spg_times, "channel": sig.channel.values},
        attrs={'units': f"{sig.units}^2/Hz", 'file_start': sig.fileCreateTime},
    )
    return spg.median(dim="channel").sel(frequency=slice(0, max_freq))

In [6]:
def _get_epoch_spgs(bin_path, epoch, cx_chans, wm_chans):
    """Load one epoch, re-reference it to white matter, and compute its spectrograms.

    Parameters
    ----------
    bin_path : path-like
        The `lf.bin` file to read from.
    epoch : namedtuple
        A row of the sr_chans table providing `start_time`, `end_time` and `sr_chans`.
    cx_chans, wm_chans : sequence
        Superficial-cortex and white-matter channels for the subject.

    Returns
    -------
    spgs : xarray.Dataset
        Variables 'mpta_wm_ref' (cortex - white matter) and 'sr_wm_ref'
        (the epoch's sr_chans - white matter).
    file_start
        `fileCreateTime` of the loaded cortex timeseries.
    """
    window = dict(start_time=epoch.start_time, end_time=epoch.end_time)
    cx = load_timeseries(bin_path, cx_chans, **window)
    wm = load_timeseries(bin_path, wm_chans, **window)
    sr = load_timeseries(bin_path, epoch.sr_chans, **window)

    # `.values` strips wm's coordinates so xarray does not try to align on the
    # (different) channel labels; the subtraction is positional and needs matching
    # or broadcastable channel counts.
    cx_wm_ref = cx - wm.values
    sr_wm_ref = sr - wm.values

    spgs = xr.Dataset({'mpta_wm_ref': get_spectrogram(cx_wm_ref),
                       'sr_wm_ref': get_spectrogram(sr_wm_ref)})
    return spgs, cx.fileCreateTime


def get_condition_spgs(subject, experiment, condition):
    """Compute and save white-matter-referenced spectrograms for one condition.

    For each `lf.bin` file of the subject/experiment/condition:

    - every epoch in the matching `sr_chans.csv` that lists sr_chans is processed;
    - the epoch spectrograms are concatenated along time;
    - the result is written to the matching `spg2.nc` path, creating parent
      directories as needed.

    Files with no such epochs are skipped and nothing is written for them.

    Raises
    ------
    ValueError
        If the numbers of .bin, sr_chans.csv and spg2.nc paths differ.
    """
    cx_chans = channel_groups.superficial_ctx[subject]
    wm_chans = channel_groups.white_matter[subject]

    bin_paths = list(paths.get_sglx_style_datapaths(subject, experiment, condition, "lf.bin"))
    sr_chans_paths = list(paths.get_sglx_style_datapaths(subject, experiment, condition, "sr_chans.csv"))
    spg_paths = list(paths.get_sglx_style_datapaths(subject, experiment, condition, "spg2.nc"))
    # zip() would silently drop unmatched entries and could pair the wrong files; fail loudly instead.
    if not len(bin_paths) == len(sr_chans_paths) == len(spg_paths):
        raise ValueError(
            f"Path count mismatch: {len(bin_paths)} lf.bin, "
            f"{len(sr_chans_paths)} sr_chans.csv, {len(spg_paths)} spg2.nc."
        )

    for bin_path, sr_chans_path, spg_path in zip(bin_paths, sr_chans_paths, spg_paths):
        sr_chans_df = load.load_sr_chans(sr_chans_path)
        epochs = []
        file_start = None
        for epoch in sr_chans_df.itertuples():
            if not epoch.sr_chans:
                continue
            epoch_spgs, file_start = _get_epoch_spgs(bin_path, epoch, cx_chans, wm_chans)
            epochs.append(epoch_spgs)

        if not epochs:
            current_time = datetime.now().strftime("%H:%M:%S")
            print(f"{current_time}: Skipped {str(bin_path)} (no epochs with sr_chans)")
            continue

        spgs = xr.concat(epochs, dim="time")
        # The concatenated Dataset carries no file-level attrs of its own; record the file start.
        spgs.attrs['file_start'] = file_start
        Path(spg_path).parent.mkdir(parents=True, exist_ok=True)  # Create parent directories if they do not already exist.
        spgs.to_netcdf(spg_path)
        spgs.close()

        current_time = datetime.now().strftime("%H:%M:%S")
        print(f"{current_time}: Finished {str(bin_path)}")