In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import firwin
from tqdm import tqdm
from portiloopml.portiloop_python.ANN.wamsley_utils import detect_lacourse
import os

In [19]:


def shift_numpy(arr, num, fill_value=np.nan):
    result = np.empty_like(arr)
    if num > 0:
        result[:num] = fill_value
        result[num:] = arr[:-num]
    elif num < 0:
        result[num:] = fill_value
        result[:num] = arr[-num:]
    else:
        result[:] = arr
    return result

class FIR:
    def __init__(self, nb_channels, coefficients, buffer=None):
        
        self.coefficients = np.expand_dims(np.array(coefficients), axis=1)
        self.taps = len(self.coefficients)
        self.nb_channels = nb_channels
        self.buffer = np.array(buffer) if buffer is not None else np.zeros((self.taps, self.nb_channels))
    
    def filter(self, x):
        self.buffer = shift_numpy(self.buffer, 1, x)
        filtered = np.sum(self.buffer * self.coefficients, axis=0)
        return filtered

    
class FilterPipeline:
    def __init__(self,
                 nb_channels,
                 sampling_rate,
                 power_line_fq=60,
                 use_custom_fir=False,
                 custom_fir_order=20,
                 custom_fir_cutoff=30,
                 alpha_avg=0.1,
                 alpha_std=0.001,
                 epsilon=0.000001,
                 filter_args=[]):
        if len(filter_args) > 0:
            use_fir, use_notch, use_std = filter_args
        else:
            use_fir=True,
            use_notch=False,
            use_std=True
        self.use_fir = use_fir
        self.use_notch = use_notch
        self.use_std = use_std
        self.nb_channels = nb_channels
        assert power_line_fq in [50, 60], f"The only supported power line frequencies are 50 Hz and 60 Hz"
        if power_line_fq == 60:
            self.notch_coeff1 = -0.12478308884588535
            self.notch_coeff2 = 0.98729186796473023
            self.notch_coeff3 = 0.99364593398236511
            self.notch_coeff4 = -0.12478308884588535
            self.notch_coeff5 = 0.99364593398236511
        else:
            self.notch_coeff1 = -0.61410695998423581
            self.notch_coeff2 =  0.98729186796473023
            self.notch_coeff3 = 0.99364593398236511
            self.notch_coeff4 = -0.61410695998423581
            self.notch_coeff5 = 0.99364593398236511
        self.dfs = [np.zeros(self.nb_channels), np.zeros(self.nb_channels)]
        
        self.moving_average = None
        self.moving_variance = np.zeros(self.nb_channels)
        self.ALPHA_AVG = alpha_avg
        self.ALPHA_STD = alpha_std
        self.EPSILON = epsilon
        
        if use_custom_fir:
            self.fir_coef = firwin(numtaps=custom_fir_order+1, cutoff=custom_fir_cutoff, fs=sampling_rate)
        else:
            self.fir_coef = [
                0.001623780150148094927192721215192250384,
                0.014988684599373741992978104065059596905,
                0.021287595318265635502275046064823982306,
                0.007349500393709578957568417933998716762,
                -0.025127515717112181709014251396183681209,
                -0.052210507359822452833064687638398027048,
                -0.039273839505489904766477593511808663607,
                0.033021568427940004020193498490698402748,
                0.147606943281569008563636202779889572412,
                0.254000252034505602516389899392379447818,
                0.297330876398883392486283128164359368384,
                0.254000252034505602516389899392379447818,
                0.147606943281569008563636202779889572412,
                0.033021568427940004020193498490698402748,
                -0.039273839505489904766477593511808663607,
                -0.052210507359822452833064687638398027048,
                -0.025127515717112181709014251396183681209,
                0.007349500393709578957568417933998716762,
                0.021287595318265635502275046064823982306,
                0.014988684599373741992978104065059596905,
                0.001623780150148094927192721215192250384]
        self.fir = FIR(self.nb_channels, self.fir_coef)
        
    def filter(self, value):
        """
        value: a numpy array of shape (data series, channels)
        """
        for i, x in enumerate(value):  # loop over the data series
            # FIR:
            if self.use_fir:
                x = self.fir.filter(x)
            # notch:
            if self.use_notch:
                denAccum = (x - self.notch_coeff1 * self.dfs[0]) - self.notch_coeff2 * self.dfs[1]
                x = (self.notch_coeff3 * denAccum + self.notch_coeff4 * self.dfs[0]) + self.notch_coeff5 * self.dfs[1]
                self.dfs[1] = self.dfs[0]
                self.dfs[0] = denAccum
            # standardization:
            if self.use_std:
                if self.moving_average is not None:
                    delta = x - self.moving_average
                    self.moving_average = self.moving_average + self.ALPHA_AVG * delta
                    self.moving_variance = (1 - self.ALPHA_STD) * (self.moving_variance + self.ALPHA_STD * delta**2)
                    moving_std = np.sqrt(self.moving_variance)
                    x = (x - self.moving_average) / (moving_std + self.EPSILON)
                else:
                    self.moving_average = x
            try:
                value[i] = x
            except:
                print(f"Error in filtering: {x}")
                continue
        return value

In [20]:
def online_detrend(y, alpha=0.95):
    detrended_y = np.zeros_like(y)
    trend = 0
    for i in range(len(y)):
        trend = alpha * trend + (1 - alpha) * y[i]
        detrended_y[i] = y[i] - trend
    return detrended_y

In [39]:
def raw2filtered(raw):
    '''
    Take in the raw data and filter it online, detrend it
    '''
    filtering4lac = FilterPipeline(nb_channels=1, sampling_rate=250, filter_args=[True, True, False])
    filtered4lac = []
    print(f'Filtering Data Online')
    for i in tqdm(raw):
        filtered4lac.append(filtering4lac.filter(np.array([i])))

    print(f'Detrending Data')
    detrended_data = online_detrend(np.array(filtered4lac).flatten())

    print(f"Running Lacourse")
    data_detect = np.array(detrended_data)
    mask = np.ones(len(data_detect), dtype=bool)
    lacourse = detect_lacourse(
        data_detect,
        mask,
        sampling_rate=250,
    )

    if len(lacourse) == 0:
        return None, None, None

    print(f"Lacourse found {len(lacourse)} spindles")

    # Filter data online like on Portiloop
    print(f"Filtering Online with standardization")
    filtering_online = FilterPipeline(nb_channels=1, sampling_rate=250, filter_args=[True, True, True])
    filtered_online = []
    for i in tqdm(raw):
        filtered_online.append(filtering_online.filter(np.array([i])))
    filtered_online = np.array(filtered_online).flatten()

    return detrended_data, lacourse, filtered_online 


In [41]:
path = '/project/portinight-raw/PN_08_AC/'
save_path = '/project/portinight-dataset/'
subject_id = 'PN_08_AC'
age = 24
gender = 'F'

# Iterate through all the csvs in the folder
data = {}
for idx, file in enumerate(os.listdir(path)):
    if file.endswith(".csv"):
        print(f"Processing {file}")

        # Get the filename
        filename = file.split('_')[:-1]
        filename = '_'.join(filename)

        df = pd.read_csv(path + file, on_bad_lines='warn', encoding_errors='ignore')
        # convert useful data to floats
        df['converted'] = pd.to_numeric(df.iloc[:, 1], errors='coerce')
        df['converted'] = df['converted'].fillna(method='ffill')
        useful_data = df['converted'].values
        filtered4lac, lacourse, filtered_online = raw2filtered(useful_data)

        if filtered4lac is None:
            print(f"No spindles found in {file}, skipping")
            continue

        spindle_info_mass = {}
        spindle_info_mass[filename] = {
            'onsets': [],
            'offsets': [],
            'labels_num': []
        }

        for spindle in lacourse:
            spindle_info_mass[filename]['onsets'].append(spindle[0])
            spindle_info_mass[filename]['offsets'].append(spindle[1])
            spindle_info_mass[filename]['labels_num'].append(1)

        data[filename] = {}

        data[filename]['age'] = age
        data[filename]['gender'] = gender
        data[filename]['signal_mass'] = filtered4lac
        data[filename]['signal_filt'] = filtered_online
        data[filename]['ss_label'] = np.ones(len(filtered4lac)) * 5
        data[filename]['spindle_mass_lacourse'] = spindle_info_mass

print(f"Saving f{subject_id}.npz")
np.savez_compressed(os.path.join(save_path, f"{subject_id}.npz"), data)


Processing PN_08_AC_Night6a_Stim.csv


Skipping line 57060: expected 7 fields, saw 12

Skipping line 339025: expected 7 fields, saw 8
Skipping line 344135: expected 7 fields, saw 9

Skipping line 2033544: expected 7 fields, saw 13
Skipping line 2033549: expected 7 fields, saw 10

  df = pd.read_csv(path + file, on_bad_lines='warn', encoding_errors='ignore')
  df['converted'] = df['converted'].fillna(method='ffill')


Filtering Data Online


  value[i] = x
100%|██████████| 2162978/2162978 [01:16<00:00, 28278.71it/s]


Detrending Data
Running Lacourse
Lacourse found 471 spindles
Filtering Online with standardization


100%|██████████| 2162978/2162978 [01:40<00:00, 21502.35it/s]


Processing PN_08_AC_Night2_Stim.csv


Skipping line 297054: expected 7 fields, saw 8
Skipping line 297077: expected 7 fields, saw 8
Skipping line 297112: expected 7 fields, saw 8
Skipping line 299105: expected 7 fields, saw 16
Skipping line 299130: expected 7 fields, saw 8
Skipping line 301967: expected 7 fields, saw 13
Skipping line 301969: expected 7 fields, saw 11
Skipping line 303816: expected 7 fields, saw 11
Skipping line 303817: expected 7 fields, saw 11

Skipping line 426256: expected 7 fields, saw 8
Skipping line 426257: expected 7 fields, saw 17
Skipping line 426275: expected 7 fields, saw 10
Skipping line 426325: expected 7 fields, saw 12
Skipping line 426326: expected 7 fields, saw 12
Skipping line 426830: expected 7 fields, saw 19
Skipping line 430414: expected 7 fields, saw 12
Skipping line 430436: expected 7 fields, saw 11
Skipping line 430780: expected 7 fields, saw 8
Skipping line 431521: expected 7 fields, saw 31
Skipping line 431526: expected 7 fields, saw 14
Skipping line 431527: expected 7 fields, saw 

Filtering Data Online


  value[i] = x
100%|██████████| 6655663/6655663 [03:55<00:00, 28268.49it/s]


Detrending Data
Running Lacourse


  out[i] = (dat[i] - mean(windat)) / std(stddat)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


No spindles found in PN_08_AC_Night2_Stim.csv, skipping
Processing PN_08_AC_Night3_Sham.csv


Skipping line 1294303: expected 7 fields, saw 13
Skipping line 1294376: expected 7 fields, saw 11

Skipping line 1415599: expected 7 fields, saw 12
Skipping line 1415601: expected 7 fields, saw 8
Skipping line 1415613: expected 7 fields, saw 13
Skipping line 1415643: expected 7 fields, saw 14
Skipping line 1415646: expected 7 fields, saw 12
Skipping line 1433296: expected 7 fields, saw 8
Skipping line 1433341: expected 7 fields, saw 8

Skipping line 2869941: expected 7 fields, saw 11

Skipping line 2916664: expected 7 fields, saw 8
Skipping line 2916706: expected 7 fields, saw 11
Skipping line 2916710: expected 7 fields, saw 8
Skipping line 2917012: expected 7 fields, saw 18
Skipping line 2917789: expected 7 fields, saw 9
Skipping line 2918834: expected 7 fields, saw 12
Skipping line 2918835: expected 7 fields, saw 12
Skipping line 2918961: expected 7 fields, saw 14
Skipping line 2918973: expected 7 fields, saw 10
Skipping line 2919134: expected 7 fields, saw 12
Skipping line 2919135: 

Filtering Data Online


  value[i] = x
100%|██████████| 6836126/6836126 [04:04<00:00, 27925.63it/s]


Detrending Data
Running Lacourse


  out[i] = (dat[i] - mean(windat)) / std(stddat)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


No spindles found in PN_08_AC_Night3_Sham.csv, skipping
Processing PN_08_AC_Night4_Stim.csv


Skipping line 1701418: expected 7 fields, saw 8

Skipping line 4971264: expected 7 fields, saw 14

Skipping line 7045695: expected 7 fields, saw 9

Skipping line 7198436: expected 7 fields, saw 11
Skipping line 7198437: expected 7 fields, saw 13
Skipping line 7198454: expected 7 fields, saw 9
Skipping line 7198455: expected 7 fields, saw 8
Skipping line 7198457: expected 7 fields, saw 10
Skipping line 7198476: expected 7 fields, saw 10
Skipping line 7198479: expected 7 fields, saw 12

Skipping line 7244529: expected 7 fields, saw 11

  df = pd.read_csv(path + file, on_bad_lines='warn', encoding_errors='ignore')
  df['converted'] = df['converted'].fillna(method='ffill')


Filtering Data Online


  value[i] = x
100%|██████████| 7431566/7431566 [04:24<00:00, 28097.20it/s]


Detrending Data
Running Lacourse
Lacourse found 1638 spindles
Filtering Online with standardization


100%|██████████| 7431566/7431566 [05:45<00:00, 21518.50it/s]


Processing PN_08_AC_Night6b_Stim.csv


Skipping line 360787: expected 7 fields, saw 13
Skipping line 360805: expected 7 fields, saw 8
Skipping line 360807: expected 7 fields, saw 14
Skipping line 360815: expected 7 fields, saw 19
Skipping line 362859: expected 7 fields, saw 15
Skipping line 362865: expected 7 fields, saw 8
Skipping line 362866: expected 7 fields, saw 16
Skipping line 366875: expected 7 fields, saw 10
Skipping line 366877: expected 7 fields, saw 13
Skipping line 366885: expected 7 fields, saw 13
Skipping line 366916: expected 7 fields, saw 12

Skipping line 446079: expected 7 fields, saw 10
Skipping line 447225: expected 7 fields, saw 12
Skipping line 448345: expected 7 fields, saw 8
Skipping line 448365: expected 7 fields, saw 19
Skipping line 482681: expected 7 fields, saw 11

  df = pd.read_csv(path + file, on_bad_lines='warn', encoding_errors='ignore')
  df['converted'] = df['converted'].fillna(method='ffill')


Filtering Data Online


  value[i] = x
100%|██████████| 5021967/5021967 [02:59<00:00, 27978.78it/s]


Detrending Data
Running Lacourse


  out[i] = (dat[i] - mean(windat)) / std(stddat)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


No spindles found in PN_08_AC_Night6b_Stim.csv, skipping
Processing PN_08_AC_Night5_Sham.csv


Skipping line 236252: expected 7 fields, saw 11
Skipping line 236814: expected 7 fields, saw 10
Skipping line 236816: expected 7 fields, saw 11
Skipping line 237715: expected 7 fields, saw 11

Skipping line 263038: expected 7 fields, saw 13
Skipping line 263040: expected 7 fields, saw 11
Skipping line 271956: expected 7 fields, saw 12
Skipping line 272487: expected 7 fields, saw 12
Skipping line 272835: expected 7 fields, saw 12
Skipping line 273053: expected 7 fields, saw 16
Skipping line 273086: expected 7 fields, saw 8
Skipping line 273448: expected 7 fields, saw 8
Skipping line 273640: expected 7 fields, saw 8

Skipping line 559005: expected 7 fields, saw 13
Skipping line 598073: expected 7 fields, saw 13
Skipping line 598075: expected 7 fields, saw 8
Skipping line 598670: expected 7 fields, saw 12
Skipping line 647910: expected 7 fields, saw 9

Skipping line 775543: expected 7 fields, saw 8
Skipping line 775554: expected 7 fields, saw 8
Skipping line 775573: expected 7 fields, saw

Filtering Data Online


  value[i] = x
100%|██████████| 6268914/6268914 [03:41<00:00, 28295.91it/s]


Detrending Data
Running Lacourse


  out[i] = (dat[i] - mean(windat)) / std(stddat)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


No spindles found in PN_08_AC_Night5_Sham.csv, skipping
Saving fPN_08_AC.npz


In [36]:
# Count how many NaNs are in the data
df['converted'].isna().sum()

0

In [26]:
save_path = '/project/portinight-dataset/'
subject_id = 'PN_01_HJ'

# Try loading the portinight data:
data = np.load(os.path.join(save_path, f"{subject_id}.npz"), allow_pickle=True)

In [27]:

data['arr_0']

array({'PN_01_HJ_Night1': {'age': 28, 'gender': 'M', 'signal_mass': array([  16.36395368,  248.29117582, 1204.59631932, ...,   62.28175158,
         78.37960546,   90.34512397]), 'signal_filt': array([17.2252144 , 28.4747362 , 27.93541527, ...,  1.04844287,
        1.33055178,  1.51649733]), 'ss_label': array([5., 5., 5., ..., 5., 5., 5.]), 'spindle_mass_lacourse': {'PN_01_HJ_Night1': {'onsets': [4375, 13300, 14450, 24650, 32450, 34425, 38500, 39400, 43400, 56875, 63725, 65950, 66650, 71600, 79900, 87250, 98825, 107025, 114450, 129324, 130250, 140375, 146475, 150000, 156450, 169300, 169775, 172425, 178200, 181675, 188825, 189100, 190150, 190400, 190725, 192075, 195850, 202325, 207225, 207925, 209025, 209900, 213200, 214075, 222250, 222675, 224950, 227100, 228000, 229325, 239200, 240275, 241500, 242475, 249575, 250550, 254850, 256925, 263325, 263775, 264700, 267775, 271150, 273125, 279225, 281900, 296650, 300350, 309450, 342000, 343650, 347350, 347875, 355700, 370750, 392475, 404150, 41

In [60]:
data.keys()

dict_keys(['PN_01_HJ_Night1', 'PN_01_HJ_Night4', 'PN_01_HJ_Night5', 'PN_01_HJ_Night6', 'PN_01_HJ_Night3'])

In [66]:
# clean the dictionary keys
for new_key in list(data.keys()):
    print(data[new_key]['spindle_mass_lacourse'].keys())
    old_key = list(data[new_key]['spindle_mass_lacourse'].keys())[0]
    print(old_key)
    data[new_key]['spindle_mass_lacourse'][new_key] = data[new_key]['spindle_mass_lacourse'].pop(old_key)

dict_keys(['PN_01_HJ_Night1_Stim_night0'])
PN_01_HJ_Night1_Stim_night0
dict_keys(['PN_01_HJ_Night4_Sham_night1'])
PN_01_HJ_Night4_Sham_night1
dict_keys(['PN_01_HJ_Night5_Stim_night2'])
PN_01_HJ_Night5_Stim_night2
dict_keys(['PN_01_HJ_Night6_Sham_night5'])
PN_01_HJ_Night6_Sham_night5
dict_keys(['PN_01_HJ_Night3_Stim_night6'])
PN_01_HJ_Night3_Stim_night6


In [67]:
data['PN_01_HJ_Night1']['spindle_mass_lacourse'].keys()

dict_keys(['PN_01_HJ_Night1'])

In [69]:
# Save the cleaned data
np.savez_compressed(os.path.join(save_path, f"{subject_id}.npz"), data)

In [13]:
from portiloopml.portiloop_python.ANN.data.mass_data_new import MassDataset

data_path = '/project/MASS/mass_spindles_dataset/'
subjects = ['PN_01_HJ_Night4', '01-01-0001']

dataset = MassDataset(data_path, 30, 30, 30, subjects=subjects, use_filtered=False, sampleable='spindles', compute_spindle_labels=False)

Time taken to load PN_01_HJ_Night4: 21.723674297332764
Time taken to load 01-01-0001: 2.1120471954345703
Time taken to create lookup table: 1.442598819732666
Number of sampleable indices: 14841701
Number of spindle indexes: 422714
Number of spindles: 2301
Number of N1 indexes: 0
Number of N2 indexes: 0
Number of N3 indexes: 0
Number of R indexes: 0
Number of W indexes: 0


In [51]:
# Read the list of subjects from the file
subject_list_path = '/home/ubuntu/portiloop-training/subjects_portinight.txt'

with open(subject_list_path, 'r') as f:
    subjects = f.readlines()

subjects = [x.strip() for x in subjects]


for subject in subjects[15:]:
    dataset = MassDataset(data_path, 30, 30, 30, subjects=[subject], use_filtered=False, sampleable='spindles', compute_spindle_labels=False)

Time taken to load PN_08_AC_Night4: 1.5086579322814941
Time taken to create lookup table: 0.6257777214050293
Number of sampleable indices: 7430667
Number of spindle indexes: 189157
Number of spindles: 1638
Number of N1 indexes: 0
Number of N2 indexes: 0
Number of N3 indexes: 0
Number of R indexes: 0
Number of W indexes: 0
Time taken to load PN_08_AC_Night6a: 1.574091911315918
Time taken to create lookup table: 0.1373581886291504
Number of sampleable indices: 2162079
Number of spindle indexes: 54092
Number of spindles: 471
Number of N1 indexes: 0
Number of N2 indexes: 0
Number of N3 indexes: 0
Number of R indexes: 0
Number of W indexes: 0


KeyError: 'PN_08_AC_Night6b'