### For the meeting

In [1]:
import sys

sys.path.insert(1, "/home/vinicius/storage1/projects/vanderbilt")

In [21]:
import os
from functools import partial

import emd
import igraph as ig
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import PyEMD
import scipy
import seaborn as sns
import skimage as ski
import umap
import xarray as xr
from frites.conn.conn_sliding_windows import define_windows
from frites.utils import parallel_func
from mne.time_frequency import tfr_array_morlet
from scipy.optimize import curve_fit
from skimage.segmentation import watershed
from tqdm import tqdm

from config import metadata
from VUDA.emd import emd_vec
from VUDA.io.loadbinary import LoadBinary

#### Functions

In [5]:
def to_bin_freq(freqs, peaks):
    """
    Converts peak frequencies into binary indicator vectors over the frequency axis.

    For every row of `peaks`, each peak is snapped to the closest entry of
    `freqs` and the corresponding position of the output row is set to 1.
    Several peaks falling on the same frequency bin mark that bin only once.

    Parameters:
        freqs (numpy.ndarray): 1-D array of frequency values, shape (n_freqs,).
        peaks (numpy.ndarray): 2-D array of peak frequencies, shape (n_blocks, n_peaks).

    Returns:
        numpy.ndarray: Integer array of shape (n_blocks, n_freqs) with ones at
        the bins nearest to the peaks of each block and zeros elsewhere.

    Raises:
        ValueError: If `peaks` is not 2-D.
    """
    freqs = jnp.asarray(freqs)
    peaks = jnp.asarray(peaks)
    if peaks.ndim != 2:
        raise ValueError(f"peaks must be 2-D (n_blocks, n_peaks), got shape {peaks.shape}")
    n_freqs = freqs.shape[0]

    def _mark_row(carry, row_peaks):
        # (n_peaks, n_freqs) distance table -> index of the nearest bin per peak.
        # A single broadcasted argmin replaces the per-peak Python loop, which
        # was unrolled at trace time and failed for zero peaks.
        nearest = jnp.argmin(jnp.abs(freqs[None, :] - row_peaks[:, None]), axis=1)
        marked = jnp.zeros(n_freqs, dtype=int).at[nearest].set(1)
        return carry, marked

    # Scan over blocks so only one (n_peaks, n_freqs) table is alive at a time.
    _, binary = jax.lax.scan(_mark_row, None, peaks)

    return np.asarray(binary)

import numba as nb


# @nb.njit
def overlaps(theta_timings: np.ndarray, gamma_timings: np.ndarray):
    """
    Counts, for every gamma interval, the theta intervals that strictly contain it.

    A theta interval (start, end) is counted for a gamma interval when both the
    gamma start and the gamma end lie strictly between the theta start and the
    theta end.

    Parameters:
        theta_timings (array-like): Shape (n_theta, >=2); column 0 holds the
            interval starts and column 1 the interval ends.
        gamma_timings (array-like): Shape (n_gamma, >=2); same column layout.

    Returns:
        numpy.ndarray: int64 array of length n_gamma with the number of
        containing theta intervals for each gamma interval.
    """
    # Accept nested lists as well as arrays (the body needs 2-D indexing).
    theta = np.asarray(theta_timings)
    gamma = np.asarray(gamma_timings)

    n_gamma = len(gamma)
    n_overlaps = np.zeros(n_gamma, dtype=np.int64)

    # Nothing to compare: avoid 2-D indexing into an empty 1-D array.
    if n_gamma == 0 or theta.size == 0:
        return n_overlaps

    theta_start, theta_end = theta[:, 0], theta[:, 1]

    for i in range(n_gamma):
        gamma_start, gamma_end = gamma[i, 0], gamma[i, 1]
        # Compare directly instead of testing the sign of a difference: the
        # subtraction wraps around for unsigned integer timings.
        contains = (
            (theta_start < gamma_start)
            & (theta_end > gamma_start)
            & (theta_start < gamma_end)
            & (theta_end > gamma_end)
        )
        n_overlaps[i] = contains.sum()

    return n_overlaps

def replace_zeros_with_nan(arr):
    """
    Returns a copy of `arr` in which every entry equal to zero becomes NaN.

    Parameters:
        arr (jax.numpy.ndarray): Input array.

    Returns:
        jax.numpy.ndarray: Same shape as `arr`; non-zero entries are kept,
        zero entries are NaN.
    """
    keep = arr != 0
    return jnp.where(keep, arr, jnp.nan)


@partial(jax.vmap, in_axes=(None, 0))
def get_n_channels(labels: np.ndarray, label: int):
    """
    Largest per-column count of entries carrying a given label.

    The function is vectorized over `label` with `jax.vmap`, so callers pass an
    array of labels and get one value per label; `labels` is shared.

    Parameters:
        labels (numpy.ndarray): Array of labels.
        label (int): Label to count (mapped argument).

    Returns:
        int: Maximum, over the remaining axes, of the number of matches found
        along axis 0 of `labels`.
    """
    matches = labels == label
    # Count matches along the first axis, then keep the largest count.
    counts = matches.sum(0)
    return counts.max()


def nonzeromean(data: np.ndarray):
    """
    Mean taken over the non-zero entries of an array.

    Parameters:
        data (numpy.ndarray): Input array.

    Returns:
        float: Sum of all entries divided by the number of entries whose
        absolute value is greater than zero.
    """
    total = data.sum()
    n_nonzero = (jnp.abs(data) > 0).sum()
    return total / n_nonzero


def nonzerostd(data: np.ndarray):
    """
    Population standard deviation taken over the non-zero entries of an array.

    Parameters:
        data (numpy.ndarray): Input array.

    Returns:
        float: Square root of the mean squared deviation of the non-zero
        entries from their own mean.
    """
    kept = data[data != 0]
    centered = kept - kept.mean()
    return np.sqrt((centered**2).mean())


def get_masked_feature(vector: np.ndarray, labels: np.ndarray, label: int):
    """
    Zeroes out every entry of a feature vector that does not belong to a label.

    Parameters:
        vector (numpy.ndarray): Feature vector.
        labels (numpy.ndarray): Array of labels.
        label (int): Label whose entries are kept.

    Returns:
        numpy.ndarray: Product of the boolean mask `labels == label` and
        `vector`, i.e. `vector` where the label matches and 0 elsewhere.
    """
    selected = labels == label
    return selected * vector


@partial(jax.vmap, in_axes=(None, None, 0))
def average_feature(vector: np.ndarray, labels: np.ndarray, label: int):
    """
    Mean of the non-zero feature values that fall inside a label.

    Vectorized over `label` with `jax.vmap`; `vector` and `labels` are shared.
    Entries outside the label are zeroed by the mask and every zero is then
    turned into NaN, so feature values that are exactly zero inside the label
    are ignored as well.

    Parameters:
        vector (numpy.ndarray): Feature vector.
        labels (numpy.ndarray): Array of labels.
        label (int): Label to average over (mapped argument).

    Returns:
        float: NaN-aware mean of the masked feature vector.
    """
    inside_label = get_masked_feature(vector, labels, label)
    with_gaps = replace_zeros_with_nan(inside_label)
    return jnp.nanmean(with_gaps)


@partial(jax.vmap, in_axes=(None, None, 0))
def spread_feature(vector: np.ndarray, labels: np.ndarray, label: int):
    """
    Standard deviation of the non-zero feature values that fall inside a label.

    Vectorized over `label` with `jax.vmap`; `vector` and `labels` are shared.
    Entries outside the label are zeroed by the mask and every zero is then
    turned into NaN, so feature values that are exactly zero inside the label
    are ignored as well.

    Parameters:
        vector (numpy.ndarray): Feature vector.
        labels (numpy.ndarray): Array of labels.
        label (int): Label to compute the spread for (mapped argument).

    Returns:
        float: NaN-aware standard deviation of the masked feature vector.
    """
    inside_label = get_masked_feature(vector, labels, label)
    with_gaps = replace_zeros_with_nan(inside_label)
    return jnp.nanstd(with_gaps)


@partial(jax.vmap, in_axes=(None, None, 0))
def get_duration(times: np.ndarray, labels: np.ndarray, label: int):
    """
    Computes the duration of a particular label based on the provided times.

    Vectorized over `label` with `jax.vmap`; `times` and `labels` are shared.

    Parameters:
        times (numpy.ndarray): 1-D array of times; broadcast against `labels`
            along its last axis.
        labels (numpy.ndarray): Array of labels.
        label (int): Label to compute the duration for (mapped argument).

    Returns:
        float: Largest minus smallest time stamp at which the label occurs;
        NaN when the label does not occur at all.
    """
    inside_label = labels == label
    # Mask with NaN directly. Multiplying by the mask and then turning zeros
    # into NaN would also discard a genuine time stamp of 0.
    label_times = jnp.where(inside_label, times[None, :], jnp.nan)
    return jnp.nanmax(label_times) - jnp.nanmin(label_times)


@partial(jax.vmap, in_axes=(None, None, 0, None))
def get_spread(nobs: int, labels: np.ndarray, label: int, axes: tuple):
    """
    Computes the spatial spread of a particular label.

    The label mask is summed over `axes` and the smallest and largest index
    with a non-zero count are returned. Vectorized over `label` with
    `jax.vmap`; `nobs`, `labels` and `axes` are shared.

    Parameters:
        nobs (int): Number of channels; passed as the fixed `size` of the
            `jnp.nonzero` result.
        labels (numpy.ndarray): Array of labels.
        label (int): Label to compute the spatial spread for (mapped argument).
        axes (tuple): Axis or axes of `labels` summed over before the non-zero
            search.

    Returns:
        numpy.ndarray: int16 array containing the minimum and maximum channel
        indices.
    """
    # Indices with at least one match; unused slots are padded with NaN.
    # NOTE(review): this relies on the NaN fill value surviving in the index
    # array so that nanmin/nanmax skip the padding — confirm with the
    # installed JAX version.
    out = jnp.nonzero((labels == label).sum(axes), size=nobs, fill_value=jnp.nan)[0]
    ci, cf = jnp.nanmin(out), jnp.nanmax(out)
    # NOTE(review): if the label is absent both values would be NaN before the
    # int16 cast — verify what callers expect in that case.
    return jnp.array([ci, cf], dtype=jnp.int16)

#### Load data

In [27]:
# Recording session and decomposition settings.
date = "10-20-2022"
monkey = "FN"
max_imfs = None
method = "eemd"
condition = "sleep"

# Folder holding every file of this monkey/session.
base_path = os.path.expanduser(f"~/funcog/HoffmanData/{monkey}/{date}/")

# File-name suffix built from the settings above, so that changing `method`
# or `max_imfs` actually changes which files are loaded.
_suffix = f"{condition}_method_{method}_max_imfs_{max_imfs}_std_False.nc"

# `base_path` is already user-expanded, so a plain join is enough.
composites_path = os.path.join(base_path, f"composite_signals_{_suffix}")
ps_composites_path = os.path.join(base_path, f"ps_composite_signals_{_suffix}")

In [28]:
# Open the two NetCDF files of the session as xarray Datasets.
# `ps_composites` is indexed below by channel name and carries "IMFs" and
# "freqs" dimensions (presumably the spectra of the composites — confirm).
composites = xr.open_dataset(composites_path)
ps_composites = xr.open_dataset(ps_composites_path)

#### Spectral bands channels

In [29]:
# 50-point Hann window; convolved along the frequency axis of each channel's
# spectra in the peak-extraction loop below.
kernel = np.hanning(50)

In [30]:
# Names of the entries stored in `composites`; the same names are used below
# to index `ps_composites`.
channels = list(composites.keys())

In [None]:
# `import scipy` alone does not guarantee that the `signal` submodule is loaded.
import scipy.signal

# One array of peak frequencies per channel, in the order of `channels`.
peaks = []

for channel in tqdm(channels):
    # Drop the IMF entries that contain NaNs for this channel.
    data = ps_composites[channel].dropna("IMFs")
    freqs = data.freqs.data

    # Smooth along axis 2 (the "freqs" axis) with the window in `kernel`,
    # keeping the original dims/coords so `argmax("freqs")` works by name.
    data_sm = xr.DataArray(
        scipy.signal.fftconvolve(data, kernel[None, None, :], mode="same", axes=2),
        dims=data.dims,
        coords=data.coords,
    )

    # Not used inside the loop; kept because it fails early on non 3-D data
    # and leaves the sizes available to later cells.
    n_blocks, n_IMFs, n_freqs = data.shape

    # Frequency at which each smoothed spectrum is largest.
    peaks.append(freqs[data_sm.argmax("freqs").data])

 32%|██████████████████▌                                      | 13/40 [00:05<00:11,  2.36it/s]

In [None]:
# Scatter the peak frequencies of every channel, one row per channel.
plt.figure(figsize=(15, 4))
d = []
for pos, _ in enumerate(channels):
    x = peaks[pos].flatten()
    d.append(x)
    # Rows are numbered from 1 on the y axis.
    plt.scatter(x, [pos + 1] * len(x), s=1, c="k")
plt.ylabel("Channel")
plt.xlabel("Frequency [Hz]")

#### Bursts

In [22]:
def replace_zeros_with_nan(arr):
    """
    Returns a copy of `arr` in which every entry equal to zero becomes NaN.

    Parameters:
        arr (jax.numpy.ndarray): Input array.

    Returns:
        jax.numpy.ndarray: Same shape as `arr`; non-zero entries are kept,
        zero entries are NaN.
    """
    keep = arr != 0
    return jnp.where(keep, arr, jnp.nan)


@partial(jax.vmap, in_axes=(None, 0))
def get_n_channels(labels: np.ndarray, label: int):
    """
    Largest per-column count of entries carrying a given label.

    The function is vectorized over `label` with `jax.vmap`, so callers pass an
    array of labels and get one value per label; `labels` is shared.

    Parameters:
        labels (numpy.ndarray): Array of labels.
        label (int): Label to count (mapped argument).

    Returns:
        int: Maximum, over the remaining axes, of the number of matches found
        along axis 0 of `labels`.
    """
    matches = labels == label
    # Count matches along the first axis, then keep the largest count.
    counts = matches.sum(0)
    return counts.max()


def nonzeromean(data: np.ndarray):
    """
    Mean taken over the non-zero entries of an array.

    Parameters:
        data (numpy.ndarray): Input array.

    Returns:
        float: Sum of all entries divided by the number of entries whose
        absolute value is greater than zero.
    """
    total = data.sum()
    n_nonzero = (jnp.abs(data) > 0).sum()
    return total / n_nonzero


def nonzerostd(data: np.ndarray):
    """
    Population standard deviation taken over the non-zero entries of an array.

    Parameters:
        data (numpy.ndarray): Input array.

    Returns:
        float: Square root of the mean squared deviation of the non-zero
        entries from their own mean.
    """
    kept = data[data != 0]
    centered = kept - kept.mean()
    return np.sqrt((centered**2).mean())


def get_masked_feature(vector: np.ndarray, labels: np.ndarray, label: int):
    """
    Zeroes out every entry of a feature vector that does not belong to a label.

    Parameters:
        vector (numpy.ndarray): Feature vector.
        labels (numpy.ndarray): Array of labels.
        label (int): Label whose entries are kept.

    Returns:
        numpy.ndarray: Product of the boolean mask `labels == label` and
        `vector`, i.e. `vector` where the label matches and 0 elsewhere.
    """
    selected = labels == label
    return selected * vector


@partial(jax.vmap, in_axes=(None, None, 0))
def average_feature(vector: np.ndarray, labels: np.ndarray, label: int):
    """
    Mean of the non-zero feature values that fall inside a label.

    Vectorized over `label` with `jax.vmap`; `vector` and `labels` are shared.
    Entries outside the label are zeroed by the mask and every zero is then
    turned into NaN, so feature values that are exactly zero inside the label
    are ignored as well.

    Parameters:
        vector (numpy.ndarray): Feature vector.
        labels (numpy.ndarray): Array of labels.
        label (int): Label to average over (mapped argument).

    Returns:
        float: NaN-aware mean of the masked feature vector.
    """
    inside_label = get_masked_feature(vector, labels, label)
    with_gaps = replace_zeros_with_nan(inside_label)
    return jnp.nanmean(with_gaps)


@partial(jax.vmap, in_axes=(None, None, 0))
def spread_feature(vector: np.ndarray, labels: np.ndarray, label: int):
    """
    Standard deviation of the non-zero feature values that fall inside a label.

    Vectorized over `label` with `jax.vmap`; `vector` and `labels` are shared.
    Entries outside the label are zeroed by the mask and every zero is then
    turned into NaN, so feature values that are exactly zero inside the label
    are ignored as well.

    Parameters:
        vector (numpy.ndarray): Feature vector.
        labels (numpy.ndarray): Array of labels.
        label (int): Label to compute the spread for (mapped argument).

    Returns:
        float: NaN-aware standard deviation of the masked feature vector.
    """
    inside_label = get_masked_feature(vector, labels, label)
    with_gaps = replace_zeros_with_nan(inside_label)
    return jnp.nanstd(with_gaps)


@partial(jax.vmap, in_axes=(None, None, 0))
def get_duration(times: np.ndarray, labels: np.ndarray, label: int):
    """
    Computes the duration of a particular label based on the provided times.

    Vectorized over `label` with `jax.vmap`; `times` and `labels` are shared.

    Parameters:
        times (numpy.ndarray): 1-D array of times; broadcast against `labels`
            along its last axis.
        labels (numpy.ndarray): Array of labels.
        label (int): Label to compute the duration for (mapped argument).

    Returns:
        float: Largest minus smallest time stamp at which the label occurs;
        NaN when the label does not occur at all.
    """
    inside_label = labels == label
    # Mask with NaN directly. Multiplying by the mask and then turning zeros
    # into NaN would also discard a genuine time stamp of 0.
    label_times = jnp.where(inside_label, times[None, :], jnp.nan)
    return jnp.nanmax(label_times) - jnp.nanmin(label_times)


@partial(jax.vmap, in_axes=(None, None, 0, None))
def get_spread(nobs: int, labels: np.ndarray, label: int, axes: tuple):
    """
    Computes the spatial spread of a particular label.

    The label mask is summed over `axes` and the smallest and largest index
    with a non-zero count are returned. Vectorized over `label` with
    `jax.vmap`; `nobs`, `labels` and `axes` are shared.

    Parameters:
        nobs (int): Number of channels; passed as the fixed `size` of the
            `jnp.nonzero` result.
        labels (numpy.ndarray): Array of labels.
        label (int): Label to compute the spatial spread for (mapped argument).
        axes (tuple): Axis or axes of `labels` summed over before the non-zero
            search.

    Returns:
        numpy.ndarray: int16 array containing the minimum and maximum channel
        indices.
    """
    # Indices with at least one match; unused slots are padded with NaN.
    # NOTE(review): this relies on the NaN fill value surviving in the index
    # array so that nanmin/nanmax skip the padding — confirm with the
    # installed JAX version.
    out = jnp.nonzero((labels == label).sum(axes), size=nobs, fill_value=jnp.nan)[0]
    ci, cf = jnp.nanmin(out), jnp.nanmax(out)
    # NOTE(review): if the label is absent both values would be NaN before the
    # int16 cast — verify what callers expect in that case.
    return jnp.array([ci, cf], dtype=jnp.int16)

In [23]:
def load_bursts(channel: str = None, rythm: str = "slow"):
    """
    Loads the labeled bursts of one channel for a given rhythm.

    The file is looked up in the "bursts" sub-folder of the module-level
    `base_path`, and its name also uses the module-level `condition`. Both
    must be defined (run the "Load data" cell first), otherwise a NameError is
    raised.

    Parameters:
        channel (str): Channel name as it appears in the file name,
            e.g. "channel31". Required.
        rythm (str): Either "slow" or "fast".

    Returns:
        xarray.DataArray: Content of the labeled-bursts file.

    Raises:
        ValueError: If `rythm` is not "slow"/"fast" or `channel` is missing.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if rythm not in ("slow", "fast"):
        raise ValueError(f"rythm must be 'slow' or 'fast', got {rythm!r}")
    if channel is None:
        # Would otherwise silently build a file name containing "None".
        raise ValueError("channel must be provided, e.g. 'channel31'")

    fname = f"labeled_bursts_{rythm}_{channel}_{condition}_method_eemd_max_imfs_None_std_False.nc"
    return xr.load_dataarray(os.path.join(base_path, "bursts", fname))

In [24]:
# Labeled slow-rhythm bursts of channel31, cast to integers.
# `load_bursts` reads the globals `base_path` and `condition`, so the
# "Load data" cell must have been run first (see the NameError below).
bursts = (load_bursts(channel="channel31", rythm="slow")).astype(int)

NameError: name 'base_path' is not defined