# Load the data

In [2]:
import os

from portiloopml.portiloop_python.ANN.data.mass_data_new import MassDataset

In [3]:
# Location of the MASS spindles dataset. Set the MASS_DATASET_PATH environment
# variable to point elsewhere; without it the original absolute path is used.
dataset_path = os.environ.get('MASS_DATASET_PATH', '/project/MASS/mass_spindles_dataset')
subject = '01-02-0002'
# NOTE(review): window_size / seq_stride / seq_len are forwarded unchanged to
# MassDataset; presumably they are expressed in samples — confirm in mass_data_new.
dataset = MassDataset(
        dataset_path,
        subjects=[subject],
        window_size=54,
        seq_stride=42,
        seq_len=1,
        use_filtered=False)

Time taken to load 01-02-0002: 4.301700830459595
Time taken to create lookup table: 1.0775446891784668
Number of sampleable indices: 8753947
Number of spindle indexes: 169797
Number of spindles: 909
Number of N1 indexes: 345000
Number of N2 indexes: 4555000
Number of N3 indexes: 1335000
Number of R indexes: 1000000
Number of W indexes: 550000


In [4]:
# Sleep-stage labels of the subject and a boolean mask of samples labelled 1 or 2.
# NOTE(review): the label-to-stage encoding is not visible here — confirm which
# stages 1 and 2 correspond to in MassDataset.
ss_labels = dataset.data[subject]['ss_label']
mask = (ss_labels == 2) | (ss_labels == 1)

# Raw signal of the subject and the spindle onsets stored under 'spindle_mass_fixed'.
signal = dataset.data[subject]['signal']
wamsley_spindles = dataset.data[subject]['spindle_mass_fixed'][subject]['onsets']

In [18]:
import numpy as np
from scipy.signal import filtfilt, firwin

def detect_SSW_Carrier(signal, config):
    """Detect sleep slow waves (SSW) with the Carrier et al. (EJN, 2011) criteria.

    The signal is band-pass filtered (zero-phase FIR), then split into candidate
    waves at successive positive-to-negative zero crossings of the filtered
    signal, so each candidate is a negative phase followed by a positive phase.
    A candidate is kept when all of the following hold (on the filtered signal):

    * peak-to-peak amplitude > ``config['PaP']``          (Carrier: 75 uV)
    * largest negative amplitude > ``config['Neg']``      (Carrier: 40 uV)
    * ``duree_min_max_Neg[0]`` < negative-phase duration < ``duree_min_max_Neg[1]``
      in ms                                               (Carrier: 125-1500 ms)
    * positive-phase duration < ``config['duree_max_Pos']`` in ms
                                                          (Carrier: 1000 ms)

    Args:
        signal: 1-D array-like of samples. Amplitude thresholds are compared
            directly against it, so it should be in the same units as 'PaP' /
            'Neg' (microvolts for the Carrier values).
        config: dict with keys
            'PaP', 'Neg': amplitude thresholds,
            'duree_min_max_Neg': [min, max] negative-phase duration in ms,
            'duree_max_Pos': max positive-phase duration in ms,
            'fmin_max': [low, high] band-pass edges in Hz,
            'fs': sampling rate of ``signal`` in Hz,
            'numtaps' (optional, default 500): FIR filter length.

    Returns:
        dict with
          'Comments': text description of the marker fields,
          'markers': arrays for the kept waves only — 'nsamp' ([start, end]
              crossing sample indices), 'P2P', 'Neg', 'tNe', 'tPo', 'P2P_raw',
              'Neg_raw', 'mfr' — plus 'Thresholds_PaP_Neg',
          'stat': 'N_SSW' (number of kept waves), 'd_SSW' (waves per minute)
              and 'duree' (signal length in minutes).

    Raises:
        ValueError: if ``signal`` is not one-dimensional.
    """
    signal = np.asarray(signal)
    if signal.ndim != 1:
        # Zero-crossing indexing below assumes a single channel / trial.
        raise ValueError(
            'detect_SSW_Carrier expects a 1-D signal, got shape %s' % (signal.shape,))

    th_PaP = config['PaP']
    th_Neg = config['Neg']
    min_tNe = config['duree_min_max_Neg'][0]
    max_tNe = config['duree_min_max_Neg'][1]
    max_tPo = config['duree_max_Pos']
    fmin_max = config['fmin_max']
    fs = config['fs']
    # FIR length: the filter order can change the density of detected waves.
    # NOTE(review): 500 taps at fs=250 span 2 s, shorter than the 6.25 s period
    # of a 0.16 Hz cutoff, so that lower edge is only loosely approximated;
    # increase 'numtaps' if the low edge matters.
    numtaps = config.get('numtaps', 500)

    # Band edges normalised to Nyquist (as firwin expects); two edges with
    # pass_zero=False give a band-pass filter.
    wn = np.array(fmin_max) / (fs / 2)
    ssw_filter = firwin(numtaps, wn, pass_zero=False)
    # Forward-backward filtering is zero-phase, so crossings of sigf line up
    # with the raw signal.
    sigf = filtfilt(ssw_filter, 1, signal)

    # Positive-to-negative zero crossings; two consecutive crossings delimit
    # one candidate wave.
    n_zc = np.flatnonzero(np.diff(np.sign(sigf)) < 0)
    # Fewer than two crossings means no candidate (instead of a negative size).
    n_waves = max(len(n_zc) - 1, 0)

    n_t = np.zeros((n_waves, 2), dtype=int)
    P2P = np.zeros(n_waves)
    Neg = np.zeros(n_waves)
    tNe = np.zeros(n_waves)
    tPo = np.zeros(n_waves)
    PaP_raw = np.zeros(n_waves)
    Neg_raw = np.zeros(n_waves)
    mfr = np.zeros(n_waves)
    keep = []

    for i in range(n_waves):
        start, end = n_zc[i], n_zc[i + 1]
        n_t[i, :] = [start, end]
        # Mean frequency from the full wave length (crossings strictly increase).
        mfr[i] = fs / (end - start)

        # Exclude the border samples on both sides.
        segment = sigf[start + 1:end - 1]
        segment_raw = signal[start + 1:end - 1]
        segNeg = np.flatnonzero(segment < 0)
        segPos = np.flatnonzero(segment >= 0)
        if segNeg.size == 0:
            # Empty segment (crossings <= 2 samples apart) or no negative
            # sample: amplitudes are undefined (max of an empty array raises),
            # so leave the zeros and do not keep this candidate.
            continue

        P2P[i] = np.ptp(segment)
        Neg[i] = np.max(np.abs(segment[segNeg]))
        tNe[i] = (segNeg.size - 1) / fs * 1000  # msec
        tPo[i] = (segPos.size - 1) / fs * 1000  # msec

        PaP_raw[i] = np.ptp(segment_raw)
        # Raw amplitude over the samples where the *filtered* signal is negative.
        Neg_raw[i] = np.max(np.abs(segment_raw[segNeg]))

        # Carrier selection criteria.
        if (P2P[i] > th_PaP and Neg[i] > th_Neg
                and min_tNe < tNe[i] < max_tNe and tPo[i] < max_tPo):
            keep.append(i)

    keep = np.array(keep, dtype=np.intp)
    marker = {
        'Thresholds_PaP_Neg': [th_PaP, th_Neg],
        'nsamp': n_t[keep, :],
        'P2P': P2P[keep],
        'Neg': Neg[keep],
        'tNe': tNe[keep],
        'tPo': tPo[keep],
        'P2P_raw': PaP_raw[keep],
        'Neg_raw': Neg_raw[keep],
        'mfr': mfr[keep],
    }

    length = len(signal) / fs / 60  # minutes
    N_SSW = len(keep)
    Stat_SSW = {'N_SSW': N_SSW, 'd_SSW': N_SSW / length, 'duree': length}

    SSW = {'Comments': '\n\tOutput of the function\n\t%s' % (
            'SSW.marker : One marker field per epoch with the following fields,\n\t' +
            '. nsamp is a table of [start end], the number of line is the number of SSW\n\t' +
            '. Neg is an aray of DOWN phase amplitudes (uV) [filtered signal]\n\t' +
            '. P2P is an aray of peak-to-peak amplitudes (uV) [filtered signal]\n\t' +
            '. Neg_raw is an aray of DOWN phase amplitudes (uV) [raw signal]\n\t' +
            '. P2P_raw is an aray of peak-to-peak amplitudes (uV) [raw signal]\n\t' +
            '. tNe is an array of duration of the DOWN phase\n\t' +
            '. tPo is an array of duration of the UP phase\n\t' +
            '. mfr is an array of the mean frequency\n\t'),
           'markers': marker,
           'stat': Stat_SSW,
           }

    return SSW


In [19]:
# Carrier et al. (EJN, 2011) slow-wave criteria: amplitudes in uV, durations in
# ms, band edges in Hz. 'fs' must match the sampling rate of `signal`.
config = dict(
    PaP=75,                         # minimum peak-to-peak amplitude
    Neg=40,                         # minimum negative amplitude
    duree_min_max_Neg=[125, 1500],  # negative-phase duration bounds
    duree_max_Pos=1000,             # maximum positive-phase duration
    fs=250,                         # sampling rate
    fmin_max=[0.16, 4],             # band-pass edges
)

SSW = detect_SSW_Carrier(signal, config)

In [20]:
# Display the detector output: the 'Comments' description, the per-wave 'markers' and the summary 'stat'.
SSW

{'Comments': '\n\tOutput of the function\n\tSSW.marker : One marker field per epoch with the following fields,\n\t. nsamp is a table of [start end], the number of line is the number of SSW\n\t. Neg is an aray of DOWN phase amplitudes (uV) [filtered signal]\n\t. P2P is an aray of peak-to-peak amplitudes (uV) [filtered signal]\n\t. Neg_raw is an aray of DOWN phase amplitudes (uV) [raw signal]\n\t. P2P_raw is an aray of peak-to-peak amplitudes (uV) [raw signal]\n\t. tNe is an array of duration of the DOWN phase\n\t. tPo is an array of duration of the UP phase\n\t. mfr is an array of the mean frequency\n\t',
 'markers': {'Thresholds_PaP_Neg': [75, 40],
  'nsamp': array([[  15878,   16174],
         [  33753,   33997],
         [  72292,   72584],
         ...,
         [8719532, 8719636],
         [8720740, 8720981],
         [8720981, 8721212]]),
  'P2P': array([ 78.16192341,  99.63219848, 206.50419983, ..., 118.01226497,
         146.04794385, 153.83840919]),
  'Neg': array([ 58.3511435 