In [None]:
import os
import os.path as osp
import pandas as pd
import h5py
import glob
import numpy as np
from mne.filter import resample
from scipy.signal import resample_poly
import scipy.io

# Inspect the structure of a U-Time HDF5 (.h5) file

In [None]:
# Reference recording from the U-Time "processed" tree, opened below to learn
# the on-disk layout that the conversion section has to reproduce.
h5_file = osp.join(
    '/home/akara/Workspace/U-Time/processed/mass_ss3',
    '01-03-0001 PSG',
    '01-03-0001 PSG.h5')
# Alternative reference file (MASS SS1):
# h5_file = '/home/akara/Workspace/U-Time/processed/mass_ss1/01-01-0001 PSG/01-01-0001 PSG.h5'

In [None]:
# Open the reference file read-only. The handle is reused by the inspection
# cells below and is not closed anywhere in this notebook; call f.close() when done.
f = h5py.File(h5_file, mode='r')

In [None]:
# Top-level groups/datasets in the reference file (the conversion below writes a 'channels' group).
f.keys()

In [None]:
# Names of the per-channel datasets stored under the 'channels' group (e.g. 'F4-CLE').
f['channels'].keys()

In [None]:
# File-level HDF5 attributes of the reference file, listed first by name and
# then as name/value pairs.
attr_names = f.attrs.keys()
print(attr_names)
for attr_name, attr_value in f.attrs.items():
    print(attr_name, attr_value)

In [None]:
# Look at one channel dataset (its shape and dtype) via its full HDF5 path.
f['channels/F4-CLE']

# Convert ISRUC subgroup 1 recordings to the U-Time format

In [None]:
# Where the ISRUC subgroup-1 .mat files are read from, and where the
# U-Time-formatted output is written.
isruc_src_dir = '/home/akara/Workspace/sleep_data/isruc/subgroup_1'
output_dir = '/home/akara/Workspace/U-Time/processed/isruc_sg1'

# Sampling rates (Hz): rate of the source signals, and the rate they are
# resampled to before being written out.
isruc_fs, target_fs = 200, 128

# Which scorer's annotation file ("<sid>_<expert>.txt") to read labels from.
select_expert = 1

# Channels copied from each .mat file; '_' becomes '-' in the HDF5 dataset names.
select_channels = [
    'F3_A2', 'C3_A2', 'F4_A1', 'C4_A1',
    'O1_A2', 'O2_A1', 'ROC_A1', 'LOC_A2',
]

# Integer stage code -> label text written into the .ids annotation files.
ann_dict = {
    code: f'Sleep stage {stage}'
    for code, stage in enumerate(('W', '1', '2', '3', 'R', '?'))
}

In [None]:
# Convert every ISRUC subject .mat file into the U-Time on-disk layout:
#   <output_dir>/<subject> PSG/<subject> PSG.h5
#       one float64 dataset per selected channel under /channels, resampled
#       from isruc_fs to target_fs, plus a 'sample_rate' file attribute;
#   <output_dir>/<subject> PSG/<subject> Annotations.ids
#       one "start,duration,label" row (seconds) per run of identical stages.
# Files are sorted so the processing order (and the log) is reproducible.
mat_files = sorted(glob.glob(os.path.join(isruc_src_dir, 'subject*.mat')))
for subject_file in mat_files:
    print(subject_file)
    sid = os.path.basename(subject_file).split('.')[0]
    sid = int(sid.replace('subject', ''))
    ann_file = os.path.join(
        os.path.dirname(subject_file),
        'annotations',
        f"{sid}_{select_expert}.txt")
    mat = scipy.io.loadmat(subject_file)

    # Extract the selected channels. All channels must have the same number of
    # rows (treated as epochs); each one is flattened row-major into one trace.
    signals = []
    n_epochs = -1
    for ch in select_channels:
        ch_data = mat.get(ch)
        if ch_data is None:
            raise KeyError(f"Channel {ch} not found in {subject_file}")
        if n_epochs == -1:
            n_epochs = ch_data.shape[0]
        else:
            assert n_epochs == ch_data.shape[0]
        signals.append(ch_data.reshape(-1))
        print(f'extracted {ch} channel {ch_data.shape}')
    if n_epochs <= 0:
        raise ValueError(f"No epochs extracted from {subject_file}")

    # Signals: stack to (n_samples, n_channels) and resample along time.
    resampled_x = resample_poly(
        np.swapaxes(np.array(signals), 0, 1).astype(np.float64),
        target_fs,
        isruc_fs,
        axis=0)

    # Annotations: one integer stage code per line of the expert's file.
    y_df = pd.read_csv(ann_file, header=None, names=['ann'])
    y_all = y_df['ann'].to_numpy()
    # Having fewer labels than signal epochs would silently misalign the
    # annotations with the signal, so fail loudly. Extra trailing labels are dropped.
    if len(y_all) < n_epochs:
        raise ValueError(
            f"{ann_file} has {len(y_all)} labels but {subject_file} has {n_epochs} epochs")
    # Replace the REM label from 5 to 4 (ann_dict maps 4 to 'Sleep stage R').
    y = np.where(y_all[:n_epochs] == 5, 4, y_all[:n_epochs])
    print(np.unique(y, return_counts=True))

    # Output paths follow the "<subject> PSG" folder convention.
    out_file = osp.basename(subject_file).split('.')[0]
    out_subject_dir = osp.join(output_dir, f'{out_file} PSG')
    out_signal_file = os.path.join(out_subject_dir, f'{out_file} PSG.h5')
    out_ann_file = os.path.join(out_subject_dir, f'{out_file} Annotations.ids')
    os.makedirs(out_subject_dir, exist_ok=True)

    # The context manager closes the file even if the sanity check below fails.
    # A handle left open in 'w' mode can make re-opening this path fail on a re-run.
    with h5py.File(out_signal_file, 'w') as h5file:
        h5ch = h5file.create_group('channels')
        for ch_i, ch in enumerate(select_channels):
            h5ch.create_dataset(
                ch.replace('_', '-'),
                data=resampled_x[:, ch_i],
                dtype='float64')
        print(h5file['channels'].keys())
        for k in h5file['channels'].keys():
            print(k, h5file['channels'][k])

        # Sanity check: resampling each channel on its own must reproduce
        # exactly what was written from the multi-channel resample.
        for ch in select_channels:
            x_mat = resample_poly(
                mat[ch].reshape(-1).astype(np.float64),
                target_fs,
                isruc_fs,
                axis=0)
            x_h5 = h5file['channels'][ch.replace('_', '-')][()]
            assert np.array_equal(x_mat, x_h5)
        print('Pass sanity check')

        # Attribute for sample_rate
        h5file.attrs.create('sample_rate', target_fs)
        print(h5file.attrs.keys())
        print(f"Set sample_rate to {h5file.attrs['sample_rate']}")
    print(f"Saved extracted signals to {out_signal_file}")

    # Write the annotations in the .ids format: one "start,duration,label" row
    # (in seconds) per run of consecutive identical stage codes.
    print(y)
    epoch_sec = 30
    # Exclusive end index of each run: every label change, plus the end of y.
    y_tran_idx = np.append(np.flatnonzero(np.diff(y)) + 1, len(y))
    print(f"Total seconds: {resampled_x.shape[0] / target_fs}")
    with open(out_ann_file, 'w') as ann_f:
        run_start = 0
        for run_end in y_tran_idx:
            row = f"{run_start*epoch_sec},{(run_end - run_start)*epoch_sec},{ann_dict[y[run_end-1]]}"
            print(row)
            ann_f.write(row + "\n")
            run_start = run_end
    print(f"Saved annotations to {out_ann_file}")