### Compute power spectra for the data

In [1]:
import os
import sys

# Make the project's `src` package importable. The location can be overridden
# with the PHASE_COUPLING_ROOT environment variable; the default is the
# original hard-coded path. The membership check keeps this cell idempotent,
# so re-running it does not push duplicate entries onto sys.path.
PROJECT_ROOT = os.environ.get(
    "PHASE_COUPLING_ROOT", "/home/INT/lima.v/projects/phase_coupling_analysis"
)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(1, PROJECT_ROOT)

In [2]:
import argparse
import os

import numpy as np
import xarray as xr
from tqdm import tqdm

from src.metrics.spectral import xr_psd_array_multitaper
from src.session import session

# from src.session import session
from src.util import _extract_roi, get_dates

###############################################################################
##### Functions
###############################################################################

In [3]:
def load_session_data(
    root,
    sid,
    monkey,
    align,
    evt_dt=(-0.65, 2.0),
    stim_fill_value=6,
    scale=1e6,
):
    """Load one recording session and merge its task and fixation trials.

    Parameters
    ----------
    root : str
        Raw-data root directory, forwarded as ``session(raw_path=...)``.
    sid : str
        Session date, forwarded as ``session(date=...)``.
    monkey : str
        Monkey name, forwarded as ``session(monkey=...)``.
    align : str
        Alignment event, forwarded as ``session(align_to=...)``.
    evt_dt : sequence of two floats, optional
        Window around the alignment event, forwarded (as a list) as
        ``session(evt_dt=...)``. The default reproduces the original
        hard-coded ``[-0.65, 2.0]`` (presumably seconds -- TODO confirm).
    stim_fill_value : int or float, optional
        Value written over NaN entries of the ``stim`` attribute (default 6).
        NOTE(review): presumably NaN marks trials without a stimulus label,
        e.g. fixation trials -- confirm.
    scale : float, optional
        Factor applied to the signal (default 1e6; presumably V -> uV,
        TODO confirm the units of the raw data).

    Returns
    -------
    xarray.DataArray
        Trials with ``trial_type=1`` and ``behavioral_response=1`` followed by
        trials with ``trial_type=2``, concatenated along ``"trials"``.
        ``roi`` labels become ``"<roi>_<channel>"``. Attrs are the first
        group's attrs with ``stim``/``t_cue_on``/``t_cue_off``/``t_match_on``
        concatenated over both groups (same order as the trials), plus
        ``x``/``y``/``z`` taken from ``ses.get_xy_coords()`` and
        ``ses.recording_info["depth"]``.
    """
    ses = session(
        raw_path=root,
        monkey=monkey,
        date=sid,
        session=1,
        slvr_msmod=False,
        only_unique_recordings=False,
        align_to=align,
        evt_dt=list(evt_dt),
    )
    ses.read_from_mat(verbose=False)

    # XYZ coordinates: the (x, y) pairs plus recording depth as z.
    coords = np.concatenate(
        (ses.get_xy_coords(), ses.recording_info["depth"][:, None]), axis=1
    )

    data_task = ses.filter_trials(trial_type=[1], behavioral_response=[1])
    data_fixation = ses.filter_trials(trial_type=[2], behavioral_response=None)

    # Per-trial attrs must follow the trial order of the concatenation below
    # (task trials first, then fixation trials).
    trial_attrs = {
        key: np.hstack((data_task.attrs[key], data_fixation.attrs[key]))
        for key in ("stim", "t_cue_on", "t_cue_off", "t_match_on")
    }
    trial_attrs["stim"] = np.nan_to_num(trial_attrs["stim"], nan=stim_fill_value)

    data = xr.concat((data_task, data_fixation), "trials")

    # Start from an explicit copy of the task attrs, then overwrite/add keys.
    attrs = dict(data_task.attrs)
    attrs.update(trial_attrs)
    attrs.update(x=coords[:, 0], y=coords[:, 1], z=coords[:, 2])
    data.attrs = attrs

    # Make ROI labels unique per channel: "<roi>_<channel>".
    rois = [
        f"{roi}_{channel}" for roi, channel in zip(data.roi.data, data.channels_labels)
    ]
    data = data.assign_coords({"roi": rois})

    # Rebind the array instead of scaling ``data.values`` in place: in-place
    # multiplication of an integer array by a float raises, and for a lazily
    # backed array ``.values`` is a fresh copy, so the update would be lost.
    data.data = data.data * scale

    return data

In [4]:
def create_epoched_data(data, stages=None, pre_match=0.4):
    """Cut every trial into a fixed set of time windows ("epochs").

    Parameters
    ----------
    data : xarray.DataArray
        Trials along ``"trials"`` with a ``"time"`` coordinate. ``fsample``,
        ``t_cue_on`` and ``t_match_on`` are read from attrs.
        NOTE(review): assumes a single trial has dims ``("roi", "time")``
        in that order -- confirm.
    stages : sequence of (start, stop) pairs, optional
        Windows shared by every trial, in the units of the ``time``
        coordinate. ``None`` (default) uses
        ``[-0.4, 0.0], [0, 0.4], [0.5, 0.9], [0.9, 1.3]``.
    pre_match : float, optional
        Length of the last, trial-specific window, which ends at the rounded
        match onset (default 0.4).

    Returns
    -------
    xarray.DataArray
        Dims ``("trials", "roi", "epochs", "time")``. The epochs are ``stages``
        followed by ``[t_match - pre_match, t_match]``. Only ``trials`` and
        ``roi`` carry coordinates. Attrs are copied from ``data``.

    Raises
    ------
    ValueError
        If the windows of a trial do not all contain the same number of
        samples and therefore cannot be stacked. This happens, for example,
        when a match time is NaN and its window is empty.
    """
    if stages is None:
        stages = [[-0.4, 0.0], [0, 0.4], [0.5, 0.9], [0.9, 1.3]]
    # Copy so the caller's sequence is never modified.
    stages = [list(window) for window in stages]

    # Match onset relative to cue onset, divided by the sampling rate
    # (presumably samples -> seconds), rounded to one decimal.
    t_match_on = (data.attrs["t_match_on"] - data.attrs["t_cue_on"]) / data.fsample
    t_match_on = np.round(t_match_on, 1)

    epoch_data = []
    for i in range(data.sizes["trials"]):
        trial = data.isel(trials=i)
        windows = stages + [[t_match_on[i] - pre_match, t_match_on[i]]]
        # Label-based slicing: both window ends are included.
        segments = [trial.sel(time=slice(t_i, t_f)).data for t_i, t_f in windows]
        lengths = [segment.shape[-1] for segment in segments]
        if len(set(lengths)) != 1:
            raise ValueError(
                f"Trial {i}: epoch windows {windows} contain different numbers "
                f"of samples {lengths}; check t_match_on and the time coordinate."
            )
        epoch_data.append(np.stack(segments, axis=-2))

    epoch_data = xr.DataArray(
        np.stack(epoch_data),
        dims=("trials", "roi", "epochs", "time"),
        coords={
            "trials": data.trials,
            "roi": data.roi,
        },
        attrs=data.attrs,
    )

    return epoch_data

###############################################################################
##### Get session and monkey
###############################################################################

In [5]:
# Analysis configuration: alignment event, monkey, and the first entry
# returned by get_dates(monkey), used by the single-session cells below.
align, monkey = "cue", "lucy"
session_number = get_dates(monkey)[0]

In [6]:
# Raw-data root directory (alternative mount kept for reference):
# _ROOT = os.path.expanduser("/media/lima.v/1.42.6-25556/GrayLab/")
_ROOT = os.path.expanduser("/home/INT/lima.v/data/GrayLab/")

# Output directory for this monkey/session's power spectra.
_SAVE = os.path.expanduser(
    f"/home/INT/lima.v/Results/phase_encoding/psd/{monkey}/{session_number}"
)

# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs
# and keeps the cell idempotent on re-run.
os.makedirs(_SAVE, exist_ok=True)

In [7]:
# Load the selected session, with task and fixation trials merged.
data = load_session_data(
    root=_ROOT, sid=session_number, monkey=monkey, align=align
)

###########################################################################
##### Create epoched data
###########################################################################

In [8]:
# Cut every trial into the task epochs defined in create_epoched_data.
epoch_data = create_epoched_data(data=data)

###########################################################################
##### Compute multitaper power spectra for each epoch
###########################################################################

In [9]:
stim_labels = epoch_data.attrs["stim"]

# One multitaper PSD per epoch, stacked along a new "epochs" dimension.
# NOTE(review): the output of this cell reports a single DPSS taper per
# estimate, presumably because the ~0.4 s windows with bandwidth=5 give a
# small time-bandwidth product; consider a larger bandwidth if multitaper
# variance reduction is intended.
sxx = xr.concat(
    [
        xr_psd_array_multitaper(epoch_data.sel(epochs=i), n_jobs=20, bandwidth=5)
        for i in range(epoch_data.sizes["epochs"])
    ],
    "epochs",
)
sxx.attrs = data.attrs

    Using multitaper spectrum estimation with 1 DPSS windows
    Using multitaper spectrum estimation with 1 DPSS windows
    Using multitaper spectrum estimation with 1 DPSS windows
    Using multitaper spectrum estimation with 1 DPSS windows
    Using multitaper spectrum estimation with 1 DPSS windows


###########################################################################
#### All sessions: compute and save the power spectra for every session
###########################################################################

In [None]:
session_numbers = get_dates(monkey)

for session_number in tqdm(session_numbers, desc=f"PSD ({monkey})"):
    # Per-session output directory; exist_ok makes re-runs safe.
    _SAVE = os.path.expanduser(
        f"/home/INT/lima.v/Results/phase_encoding/psd/{monkey}/{session_number}"
    )
    os.makedirs(_SAVE, exist_ok=True)

    # Load and epoch the session (same steps as the single-session cells).
    data = load_session_data(_ROOT, session_number, monkey, align)
    epoch_data = create_epoched_data(data)

    # One multitaper PSD per epoch, stacked along a new "epochs" dimension.
    sxx = xr.concat(
        [
            xr_psd_array_multitaper(epoch_data.sel(epochs=i), n_jobs=20, bandwidth=5)
            for i in range(epoch_data.sizes["epochs"])
        ],
        "epochs",
    )
    # Attach the session metadata to the stacked spectra.
    sxx.attrs = data.attrs

    # NOTE(review): netCDF attributes must be strings, numbers or arrays;
    # to_netcdf will raise if any entry of data.attrs has another type -- confirm.
    sxx.to_netcdf(os.path.join(_SAVE, "sxx.nc"))