In [None]:
import glob
import numpy as np
import os
import scipy
import json
from ismember import ismember
import random
import pyedflib
import mne

In [None]:
# This notebook converts the ground-truth labels of all three datasets into 30-second epochs following the AASM guidelines.
# CAP: 30-s epochs, R&K scoring. Labels are split into two files (one per half of each recording).

# WESA: 20-s epochs, AASM scoring. REM is labeled 5; 6 marks the end of the file or a bad channel.
# Occasionally (only a couple of times) a 4 appears in a file by mistake; it is converted to N3 (label 3).

# MASS dataset
# SS1 --> no changes required (maybe also convert to .txt)
# SS2 --> 20 s, R&K
# SS3 --> no changes required (maybe also convert to .txt)
# SS4 --> 20 s, R&K
# SS5 --> 20 s, R&K
# Important: some epochs in the MASS dataset are labeled '?'; what this label means still needs to be checked (TODO).
# '?' labels are left unchanged for now.

# Sleep Stages
# Wake --> 0
# N1 --> 1
# N2 --> 2
# N3 --> 3
# REM --> 4

In [None]:
def cap_gt_processing(pat, gt_files, save_path):
    """Convert one CAP hypnogram from R&K to the AASM label scheme and save it as text.

    The hypnogram is read from the ``'hyp'`` variable of the first path in
    ``gt_files`` that ends with ``f"{pat}.mat"``; only its first column is used.
    Label 4 (R&K stage 4) is merged into N3 (3) and label 5 (REM) becomes 4;
    every other value is kept unchanged.

    Parameters
    ----------
    pat : str
        Recording identifier, e.g. ``"n1"`` or ``"rbd3"``.
    gt_files : list of str
        Candidate ``.mat`` ground-truth file paths.
    save_path : str
        Output ``.txt`` path; one label per line is written.

    Raises
    ------
    FileNotFoundError
        If no path in ``gt_files`` ends with ``f"{pat}.mat"``.
    """
    matches = [x for x in gt_files if x.endswith(str(pat) + ".mat")]
    if not matches:
        raise FileNotFoundError(f"No ground-truth .mat file found for recording '{pat}'")
    # NOTE(review): if several paths match, only the first one (in glob order) is
    # used — confirm this is intended.

    hypno0 = scipy.io.loadmat(matches[0])
    raw = np.array(hypno0['hyp'][:, 0])

    # Masks are computed on the ORIGINAL values so each epoch is remapped at most
    # once (4 -> 3 must not be followed by a second remap).
    hypno = raw.copy()
    hypno[raw == 4] = 3
    hypno[raw == 5] = 4

    # save the new ground truth, one label per line
    with open(save_path, 'w') as f:
        f.writelines(str(x) + '\n' for x in hypno)

def convert_20s_to_30s(file, rng=None):
    """Resample one 20-s hypnogram stored in a ``.mat`` file to 30-s epochs.

    The file may hold scores from several experts under keys containing
    ``"sleepStage_score"``; if there is more than one, one key is chosen at
    random. The last 20-s value is treated as an end-of-file marker (6 in WESA)
    and dropped. Every group of three 20-s windows (60 s) becomes two 30-s
    epochs labeled with the first and the third window. A trailing group of one
    or two windows becomes a single 30-s epoch labeled with its first window.

    Parameters
    ----------
    file : str
        Path to the ``.mat`` file.
    rng : random.Random, optional
        Source of randomness for choosing among several experts. Defaults to the
        global ``random`` module (original behavior); pass a seeded
        ``random.Random`` for reproducible choices.

    Returns
    -------
    key : str
        The expert key whose scores were used.
    change : list of int
        Start indices (in 20-s windows) of full groups whose three windows all
        carry different labels.
    label_30s : list
        The 30-s labels, with values as stored in the file (not remapped).

    Raises
    ------
    ValueError
        If the file contains no key with ``"sleepStage_score"`` in its name.
    """
    f_load = scipy.io.loadmat(file)

    labeled_exp = [k for k in f_load.keys() if "sleepStage_score" in k]
    if not labeled_exp:
        raise ValueError(f"No 'sleepStage_score' key found in {file}")
    if len(labeled_exp) == 1:
        key = labeled_exp[0]
    else:
        chooser = random if rng is None else rng
        key = labeled_exp[chooser.randint(0, len(labeled_exp) - 1)]

    label_20s = f_load[key][0]
    # last value is the end-of-file marker (6 in WESA) and is not converted
    valid = label_20s[:-1]
    new_num_epochs = np.round(len(valid) * 20 / 30)

    change = []
    label_30s = []
    for start in range(0, len(valid), 3):
        group = valid[start:start + 3]
        if len(group) == 3:
            # store groups where all 3 windows are labeled differently
            # (rather an exception or artifact)
            if len(np.unique(group)) == 3:
                change.append(start)
            # 0-30 s is window 0 plus half of window 1 -> window 0 dominates;
            # 30-60 s is half of window 1 plus window 2 -> window 2 dominates.
            label_30s.append(group[0])
            label_30s.append(group[2])
        else:
            # 20 or 40 s left over -> one 30-s epoch
            label_30s.append(group[0])

    assert len(label_30s) == new_num_epochs
    return key, change, label_30s

def edf_to_txt(file):
    """Read sleep-stage annotations and strip their 'Sleep stage ' prefix.

    Despite the name, nothing is written to disk. The annotations are loaded
    with ``mne.read_annotations``. The function returns the onset of the first
    annotation and the annotations' ``description`` array, edited in place so
    that e.g. 'Sleep stage W' becomes 'W'.

    Parameters
    ----------
    file : str
        Path to an annotation file readable by ``mne.read_annotations``.

    Returns
    -------
    onset : float
        ``onset[0]`` of the loaded annotations (presumably seconds — confirm
        against the mne docs).
    stages : array
        The ``description`` array of the annotations with the prefix removed.
    """
    annotations = mne.read_annotations(file)
    stage_labels = annotations.description
    first_onset = annotations.onset[0]

    prefix = 'Sleep stage '
    for position, label in enumerate(stage_labels):
        stage_labels[position] = label.replace(prefix, '')

    return first_onset, stage_labels
    
    

CAP dataset

In [None]:
# CAP healthy controls (HC)
# collect the .mat files with the ground truth
gt_path = "Path\\to\\load\\the\\data"
gt_files = glob.glob(gt_path)

# healthy-control recordings of the CAP dataset to convert
hc_pat = ["n1", "n2", "n3", "n5", "n10", "n11"]

# one converted (R&K -> AASM) .txt file per recording
for pat in hc_pat:
    save_path = os.path.join("Path\\to\\store\\the\\new\\30secodEpochs", f'{pat}.txt')
    cap_gt_processing(pat, gt_files, save_path)

In [None]:
# CAP RBD patients
# gather the .mat files holding the ground truth
gt_path = "Path\\to\\load\\the\\data"
gt_files = glob.glob(gt_path)

# RBD recordings of the CAP dataset: rbd1 ... rbd22 without rbd11
# NOTE(review): rbd11 is excluded — confirm this is intentional.
rbd_pat = [f"rbd{num}" for num in range(1, 23) if num != 11]

for j, pat in enumerate(rbd_pat):
    save_path = os.path.join(
        "Path\\to\\store\\the\\new\\30secodEpochs",
        f'{pat}.txt',
    )
    cap_gt_processing(pat, gt_files, save_path)


WESA dataset

In [None]:
# WESA
path_wesa = "Path\\to\\load\\the\\data"
path_to_save = "Path\\to\\store\\the\\new\\30secodEpochs"
sham_path = os.path.join(path_wesa, 'sham\\*.mat')
verum_path = os.path.join(path_wesa, 'verum\\*.mat')

sham_files = glob.glob(sham_path)
verum_files = glob.glob(verum_path)

# Group to convert ("verum" or "sham"); output goes to <path_to_save>/<group>/
# and the expert list to <path_to_save>/exp_info_<group>.txt.
group = "verum"
group_files = {"sham": sham_files, "verum": verum_files}[group]

# A file can contain labels from one or more experts. With one expert, that
# expert is used; with several, convert_20s_to_30s picks one at random.
name = []
exp = []
jumps = []

group_dir = os.path.join(path_to_save, group)
os.makedirs(group_dir, exist_ok=True)

for file in group_files:
    # only for WESA dataset: the output name starts one character before 'WESA_EEG'
    idx = file.find('WESA_EEG')
    if idx < 0:
        raise ValueError(f"'WESA_EEG' not found in file path: {file}")
    # A leading path separator would make os.path.join treat the name as rooted.
    wesa_name = file[max(idx - 1, 0):-4].lstrip('\\/')

    expert, change, label_30s = convert_20s_to_30s(file)

    jumps.append(file)
    jumps.append(change)
    name.append(file)
    exp.append(expert)

    # WESA -> AASM: 4 appears only by mistake and is converted to N3 (3),
    # REM (5) becomes 4; a remaining 6 is reported.
    for i in range(len(label_30s)):
        if label_30s[i] == 4:
            label_30s[i] = 3
        elif label_30s[i] == 5:
            label_30s[i] = 4
        elif label_30s[i] == 6:
            print('There is 6 in the file')

    # Remapped entries are Python ints while untouched ones keep the dtype
    # loaded from the .mat file (possibly float, printing as e.g. "2.0");
    # int() writes every label in the same integer format.
    with open(os.path.join(group_dir, f"{wesa_name}.txt"), 'w') as f:
        f.writelines(f"{int(label)}\n" for label in label_30s)

# save a list of experts whose scores were used as gt to .txt file
with open(os.path.join(path_to_save, f"exp_info_{group}.txt"), "w") as f:
    for file_name, expert_key in zip(name, exp):
        f.write('%s, %s\n' % (file_name, expert_key))

MASS dataset

In [None]:
# MASS SS1 and SS3: no epoch resampling needed, only the stage names are remapped
path = "Path\\to\\load\\the\\data"
gt_f = glob.glob(path)

path_to_save = "Path\\to\\store\\the\\new\\30secodEpochs"
folder = 'hypno_gt\\'
symbols = len(folder)

# N1, N2, N3 are already labeled 1, 2, 3. Wake 'W' -> '0', REM 'R' -> '4'.
# '?' is left unchanged for now.
mass_aasm_map = {'W': '0', 'R': '4'}

for file in gt_f:
    ix = file.find(folder)
    if ix < 0:
        # without this, ix + symbols would silently point into the wrong part of the path
        raise ValueError(f"'{folder}' not found in file path: {file}")

    # extract stages
    onset_time, stages = edf_to_txt(file)
    stages = [mass_aasm_map.get(stage, stage) for stage in stages]

    with open(path_to_save + f"\\{file[ix+symbols:-4]}.txt", 'w') as f:
        f.write('%s %s\n' % ('Onset time is ', onset_time))
        for items in stages:
            f.write('%s\n' % items)


In [None]:
# MASS SS2, SS4 and SS5: 20-s R&K epochs -> 30-s AASM epochs
path = "Path\\to\\load\\the\\data"
gt_f = glob.glob(path)

path_to_save = "Path\\to\\store\\the\\new\\30secodEpochs"
folder = 'hypno_gt\\'
symbols = len(folder)

# R&K -> AASM stage names: W -> '0', stage 4 -> '3' (merged with stage 3 into N3),
# R -> '4'; '1', '2', '3' stay as they are and '?' is left unchanged for now.
rk_to_aasm = {'W': '0', '4': '3', 'R': '4'}


def downsample_20s_to_30s(label_20s):
    """Turn a sequence of consecutive 20-s labels into 30-s labels.

    Every full group of three 20-s windows (60 s) yields two 30-s epochs labeled
    with the first and the third window. A trailing group of one or two windows
    yields a single 30-s epoch labeled with its first window. All labels are
    used; none is treated as an end-of-file marker.

    Parameters
    ----------
    label_20s : sequence
        Labels of consecutive 20-s windows.

    Returns
    -------
    list
        The 30-s labels; its length is the total duration rounded to whole
        30-s epochs (checked by an assertion).
    """
    label_30s = []
    for start in range(0, len(label_20s), 3):
        group = label_20s[start:start + 3]
        label_30s.append(group[0])
        if len(group) == 3:
            label_30s.append(group[2])
    assert len(label_30s) == np.round(len(label_20s) * 20 / 30)
    return label_30s


for file in gt_f:
    ix = file.find(folder)
    if ix < 0:
        # without this, ix + symbols would silently point into the wrong part of the path
        raise ValueError(f"'{folder}' not found in file path: {file}")

    # extract stages
    onset_time, label_20s = edf_to_txt(file)
    label_30s = [rk_to_aasm.get(stage, stage) for stage in downsample_20s_to_30s(label_20s)]

    with open(path_to_save + f"\\{file[ix+symbols:-4]}.txt", 'w') as f:
        f.write('%s %s\n' % ('Onset time is ', onset_time))
        for items in label_30s:
            f.write('%s\n' % items)

