In [1]:
import argparse
import glob
import math
import ntpath
import os
import shutil
#import urllib
#import urllib2
from datetime import datetime
import numpy as np
import pandas as pd
from mne import Epochs, pick_types, find_events
from mne.io import concatenate_raws, read_raw_edf
import import_ipynb
import dhedfreader

importing Jupyter notebook from dhedfreader.ipynb


In [2]:
# Integer codes assigned to each sleep stage (these are the label values).
W, N1, N2, N3, REM, UNKNOWN = range(6)

# Stage name -> integer code.
stage_dict = dict(W=W, N1=N1, N2=N2, N3=N3, REM=REM, UNKNOWN=UNKNOWN)

# Integer code -> stage name (the inverse of ``stage_dict``).
class_dict = {code: name for name, code in stage_dict.items()}

# Hypnogram annotation text -> integer code.
# "Sleep stage 3" and "Sleep stage 4" share the N3 code; "Sleep stage ?" and
# "Movement time" both map to UNKNOWN.
ann2label = {
    "Sleep stage W": W,
    "Sleep stage 1": N1,
    "Sleep stage 2": N2,
    "Sleep stage 3": N3,
    "Sleep stage 4": N3,
    "Sleep stage R": REM,
    "Sleep stage ?": UNKNOWN,
    "Movement time": UNKNOWN,
}

# Length of one scoring epoch, in seconds.
EPOCH_SEC_SIZE = 30

In [4]:
def main(data_dir='C:/Users/pratik/Desktop/DM Project/sleep-edf-database-expanded-1.0.0/sleep-cassette',
         output_dir='C:/Users/pratik/Desktop/DM Project/sleep-edf',
         select_ch='EEG Fpz-Cz',
         max_files=2):
    """Convert PSG/hypnogram EDF pairs into one ``.npz`` file per recording.

    For each pair, the selected channel is read, the hypnogram annotations
    are mapped to integer labels through ``ann2label``, samples annotated as
    UNKNOWN are dropped, the signal is cut into ``EPOCH_SEC_SIZE``-second
    epochs, and only the span from 30 minutes before the first non-wake
    epoch to 30 minutes after the last one is kept.

    Parameters
    ----------
    data_dir : str
        Directory searched for ``*PSG.edf`` and ``*Hypnogram.edf`` files.
        The two sorted lists are paired by position.
    output_dir : str
        Directory the ``.npz`` files are written to (created if missing).
    select_ch : str
        Name of the channel column to extract from the recording.
    max_files : int or None
        Maximum number of recordings to process; ``None`` processes all.

    Raises
    ------
    ValueError
        If a labelled annotation's duration is not a multiple of
        ``EPOCH_SEC_SIZE``, or the selected samples cannot be split into
        whole epochs.
    AssertionError
        If the PSG and hypnogram headers disagree on the start time, or the
        number of epochs and labels differ.
    """
    # Read raw and annotation EDF files
    psg_fnames = sorted(glob.glob(os.path.join(data_dir, "*PSG.edf")))
    ann_fnames = sorted(glob.glob(os.path.join(data_dir, "*Hypnogram.edf")))

    # Never index past the shorter list (the previous fixed range(2) raised
    # IndexError when fewer than two files were found).
    n_files = min(len(psg_fnames), len(ann_fnames))
    if max_files is not None:
        n_files = min(n_files, max_files)

    for i in range(n_files):
        raw = read_raw_edf(psg_fnames[i], preload=True, stim_channel=None)
        print('what is sfreq')
        print(raw.info['sfreq'])
        sampling_rate = raw.info['sfreq']
        samples_per_epoch = int(EPOCH_SEC_SIZE * sampling_rate)

        # NOTE(review): `scaling_time` appears to be unsupported by newer MNE
        # releases -- confirm against the installed version before upgrading.
        raw_ch_df = raw.to_data_frame(scaling_time=100.0)[select_ch]
        raw_ch_df = raw_ch_df.to_frame()
        n_samples = len(raw_ch_df)

        # Get raw header
        with open(psg_fnames[i], 'r', errors='ignore') as f:
            reader_raw = dhedfreader.BaseEDFReader(f)
            reader_raw.read_header()
            h_raw = reader_raw.header
        raw_start_dt = datetime.strptime(h_raw['date_time'], "%Y-%m-%d %H:%M:%S")

        # Read annotation and its header; records() must be consumed while
        # the file is still open.
        with open(ann_fnames[i], 'r', errors='ignore') as f:
            reader_ann = dhedfreader.BaseEDFReader(f)
            reader_ann.read_header()
            h_ann = reader_ann.header
            _, _, ann = zip(*reader_ann.records())
        ann_start_dt = datetime.strptime(h_ann['date_time'], "%Y-%m-%d %H:%M:%S")

        # Assert that raw and annotation files start at the same time
        assert raw_start_dt == ann_start_dt

        # Generate label and remove indices
        remove_idx = []    # sample indices of the data that will be removed
        labels = []        # per-epoch labels, one array per annotation
        label_idx = []     # sample indices of the data that have labels
        for a in ann[0]:
            onset_sec, duration_sec, ann_char = a
            ann_str = "".join(ann_char)
            label = ann2label[ann_str]
            # Builtin int replaces the np.int alias, which NumPy >= 1.24 no longer provides.
            idx = int(onset_sec * sampling_rate) + np.arange(duration_sec * sampling_rate, dtype=int)
            if label != UNKNOWN:
                if duration_sec % EPOCH_SEC_SIZE != 0:
                    raise ValueError(
                        "Annotation '{}' at onset {} has duration {} s, which is not a "
                        "multiple of the {} s epoch size".format(
                            ann_str, onset_sec, duration_sec, EPOCH_SEC_SIZE))
                duration_epoch = int(duration_sec / EPOCH_SEC_SIZE)
                labels.append(np.full(duration_epoch, label, dtype=int))
                label_idx.append(idx)
                print("Include onset:{}, duration:{}, label:{} ({})".format(onset_sec, duration_sec, label, ann_str))
            else:
                remove_idx.append(idx)
                print("Remove onset:{}, duration:{}, label:{} ({})".format(onset_sec, duration_sec, label, ann_str))
        labels = np.hstack(labels)

        print("before remove unwanted: {}".format(np.arange(n_samples).shape))
        if len(remove_idx) > 0:
            remove_idx = np.hstack(remove_idx)
            select_idx = np.setdiff1d(np.arange(n_samples), remove_idx)
        else:
            select_idx = np.arange(n_samples)
        print("after remove unwanted: {}".format(select_idx.shape))

        # Select only the data with labels
        print("before intersect label: {}".format(select_idx.shape))
        label_idx = np.hstack(label_idx)
        select_idx = np.intersect1d(select_idx, label_idx)
        print("after intersect label: {}".format(select_idx.shape))

        # Remove extra index
        if len(label_idx) > len(select_idx):
            print("before remove extra labels: {}, {}".format(select_idx.shape, labels.shape))
            extra_idx = np.setdiff1d(label_idx, select_idx)
            # Trim the tail: keep only whole epochs and the labels that go
            # with them. Slicing from the front avoids the `[:-0]` pitfall,
            # which would have emptied both arrays when nothing needed
            # trimming, and also copes with a hypnogram that overruns the
            # recording by more than one epoch.
            if np.all(extra_idx > select_idx[-1]):
                n_full_epochs = len(select_idx) // samples_per_epoch
                select_idx = select_idx[:n_full_epochs * samples_per_epoch]
                labels = labels[:n_full_epochs]
            print("after remove extra labels: {}, {}".format(select_idx.shape, labels.shape))

        # Remove movement and unknown stages if any
        raw_ch = raw_ch_df.values[select_idx]

        # Verify that we can split into 30-s epochs
        if len(raw_ch) == 0 or len(raw_ch) % samples_per_epoch != 0:
            raise ValueError(
                "{} selected samples cannot be split into whole epochs of {} samples".format(
                    len(raw_ch), samples_per_epoch))
        n_epochs = len(raw_ch) // samples_per_epoch

        # Get epochs and their corresponding labels
        x = np.asarray(np.split(raw_ch, n_epochs)).astype(np.float32)
        y = labels.astype(np.int32)
        assert len(x) == len(y)

        # Select on sleep periods
        w_edge_mins = 30
        edge_epochs = w_edge_mins * 60 // EPOCH_SEC_SIZE
        nw_idx = np.where(y != stage_dict["W"])[0]
        if nw_idx.size == 0:
            # No sleep at all: there is no sleep period to centre on.
            print("No non-wake epochs found in {}; skipping".format(psg_fnames[i]))
            continue
        start_idx = max(nw_idx[0] - edge_epochs, 0)
        end_idx = min(nw_idx[-1] + edge_epochs, len(y) - 1)
        print("Data before selection: {}, {}".format(x.shape, y.shape))
        x = x[start_idx:end_idx + 1]
        y = y[start_idx:end_idx + 1]
        print("Data after selection: {}, {}".format(x.shape, y.shape))

        # Save
        filename = ntpath.basename(psg_fnames[i]).replace("-PSG.edf", ".npz")
        save_dict = {
            "x": x,
            "y": y,
            "fs": sampling_rate,
            "ch_label": select_ch,
            "header_raw": h_raw,
            "header_annotation": h_ann,
        }
        print('s    a    v    e dct')
        print(save_dict)
        os.makedirs(output_dir, exist_ok=True)
        np.savez(os.path.join(output_dir, filename), **save_dict)
        print("\n=======================================\n")

# Run the conversion only when this code is executed directly (not imported).
if __name__ == "__main__":
    main()

Extracting EDF parameters from C:\Users\pratik\Desktop\DM Project\sleep-edf-database-expanded-1.0.0\sleep-cassette\SC4001E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 7949999  =      0.000 ... 79499.990 secs...
what is sfreq
100.0
Converting "time" to "<class 'numpy.int64'>"...
read_header
read_header
record self
read_record
read raw record
convert record
Include onset:0.0, duration:30630.0, label:0 (Sleep stage W)
Include onset:30630.0, duration:120.0, label:1 (Sleep stage 1)
Include onset:30750.0, duration:390.0, label:2 (Sleep stage 2)
Include onset:31140.0, duration:30.0, label:3 (Sleep stage 3)
Include onset:31170.0, duration:30.0, label:2 (Sleep stage 2)
Include onset:31200.0, duration:150.0, label:3 (Sleep stage 3)
Include onset:31350.0, duration:30.0, label:3 (Sleep stage 4)
Include onset:31380.0, duration:60.0, label:3 (Sleep stage 3)
Include onset:31440.0, duration:60.0, label:3 (Sleep stage 4)
Include onset:315

after remove unwanted: (7950000,)
before intersect label: (7950000,)
after intersect label: (7950000,)
Data before selection: (2650, 3000, 1), (2650,)
Data after selection: (841, 3000, 1), (841,)
s    a    v    e dct
{'x': array([[[  8.111356 ],
        [ 17.488646 ],
        [ 21.239561 ],
        ...,
        [-10.361905 ],
        [-11.112088 ],
        [ -2.1098902]],

       [[-10.736997 ],
        [-11.393407 ],
        [ -4.4542127],
        ...,
        [ 58.84249  ],
        [ 48.339928 ],
        [ 53.684982 ]],

       [[ 61.37436  ],
        [ 38.68132  ],
        [ 49.558975 ],
        ...,
        [ 33.617584 ],
        [ 34.367767 ],
        [ 31.835897 ]],

       ...,

       [[ -9.330403 ],
        [ -2.6725276],
        [ -2.4849818],
        ...,
        [ 24.14652  ],
        [ 26.678389 ],
        [ 23.115019 ]],

       [[ 24.709158 ],
        [ 22.271063 ],
        [ 27.991209 ],
        ...,
        [ 18.895239 ],
        [ 24.42784  ],
        [ 15.8945055]],


after remove unwanted: (8487000,)
before intersect label: (8487000,)
after intersect label: (8487000,)
Data before selection: (2829, 3000, 1), (2829,)
Data after selection: (1127, 3000, 1), (1127,)
s    a    v    e dct
{'x': array([[[ -5.1516485 ],
        [ -7.9010987 ],
        [  4.624176  ],
        ...,
        [  9.308425  ],
        [  7.271795  ],
        [  0.5509158 ]],

       [[  7.5772896 ],
        [  4.9296703 ],
        [-11.770696  ],
        ...,
        [ 18.676924  ],
        [ 10.021245  ],
        [ 15.927472  ]],

       [[  4.827839  ],
        [ 12.2615385 ],
        [ 30.69304   ],
        ...,
        [  6.5589743 ],
        [ 10.530403  ],
        [  5.7443223 ]],

       ...,

       [[ -0.8747253 ],
        [  3.9113553 ],
        [ -9.530403  ],
        ...,
        [-10.446886  ],
        [ -5.9663005 ],
        [-15.232967  ]],

       [[-19.102564  ],
        [-21.342857  ],
        [-27.758242  ],
        ...,
        [ 39.654213  ],
        [ 23.0556

In [None]:
# Open one of the archives written by main() for inspection.
npz_path = 'C:/Users/pratik/Desktop/DM Project/sleep-edf/SC4001E0.npz'
b = np.load(npz_path)

In [None]:
# Print the object returned by np.load in the previous cell.
# NOTE(review): for an .npz file this is presumably the archive handle rather
# than the arrays themselves -- index it by key (e.g. b['x']) to see the data.
print(b)