In [None]:
""" Imports """

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import torch
import random
import wfdb

from torch.utils.data import DataLoader, Subset, SequentialSampler

from dataset_utils import get_data_paths, ECGDataset

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [None]:
""" Set configuration """

# One fixed seed drives both Python's and PyTorch's global random generators.
seed = 123
torch.manual_seed(seed)
random.seed(seed)

# Every tunable value of this notebook lives in this single dictionary.
config = dict(
    # --- Dataset configs ---
    fs=256,               # Frequency of the training sequence
    segment_length=8,     # Length of every training sequence in seconds
    split=[.8, .1, .1],   # Fractions of the patient list used as split points

    # --- Pre-processing switches handed to ECGDataset ---
    instance_normalization=True,
    high_pass=True,
    notch=True,

    # --- Visualisation ---
    num_samples=30,       # Upper bound on the number of signals plotted
)

# Colour used for all text, axis labels and ticks in the plots.
PLOT_COLOR = 'black'

In [None]:
""" Load dataset """

# Re-seed Python's RNG at the top of the cell so that the shuffle below is the
# same every time this cell is executed. Without this, re-running the cell
# (without re-running the configuration cell first) shuffles with an advanced
# RNG state and silently yields a split that differs from the training split.
# On a fresh "Restart & Run All" the RNG state is unchanged by this call.
random.seed(seed)

data_paths = get_data_paths('data/sleep_dataset/')
# NOTE(review): a reproducible split also requires get_data_paths to return
# the paths in a deterministic order -- confirm in dataset_utils.
random.shuffle(data_paths)

print("Found {} patients in total".format(len(data_paths)))

# Apply same split as during training
split1 = int(np.floor(config['split'][0] * len(data_paths)))
split2 = int(np.floor((config['split'][0] + config['split'][1]) * len(data_paths)))

# Contiguous, non-overlapping slices of the shuffled patient list.
train_paths = data_paths[:split1]
val_paths = data_paths[split1:split2]
test_paths = data_paths[split2:]

# Signal settings shared by all three datasets; only the patient paths differ.
dataset_kwargs = dict(
    fs=config['fs'],
    seg_length=config['segment_length'],
    instance_normalization=config['instance_normalization'],
    high_pass=config['high_pass'],
    notch=config['notch'],
)

train_dataset = ECGDataset(data_paths=train_paths, **dataset_kwargs)
val_dataset = ECGDataset(data_paths=val_paths, **dataset_kwargs)
test_dataset = ECGDataset(data_paths=test_paths, **dataset_kwargs)

# Loaders use the DataLoader defaults (per the PyTorch docs: one item per
# batch, no shuffling).
train_loader = DataLoader(train_dataset)
val_loader = DataLoader(val_dataset)
test_loader = DataLoader(test_dataset)

In [None]:
""" Visualize train_dataset """

# Use the notebook-wide plot colour for all text, axis labels and ticks.
mpl.rcParams.update({
    'text.color': PLOT_COLOR,
    'axes.labelcolor': PLOT_COLOR,
    'xtick.color': PLOT_COLOR,
    'ytick.color': PLOT_COLOR,
})

# Putting the range first makes zip() stop after exactly `num_samples`
# batches, so no additional batch is loaded only to be thrown away (which the
# enumerate-then-break pattern does).
for i_sample, sample in zip(range(config['num_samples']), train_loader):
    # The second element of each batch is not needed for plotting.
    signal, _ = sample
    # Assumes the batch tensor is laid out as (batch, channel, time), so this
    # picks the first channel of the first item -- TODO confirm with ECGDataset.
    signal = signal[0, 0].numpy()

    fig, ax = plt.subplots(figsize=(20, 4))
    ax.plot(signal)
    ax.set_title("Training sample {}".format(i_sample))
    ax.set_xlabel("Sample index")
    plt.show()
    # Release the figure explicitly so that plotting many samples does not
    # accumulate open figures on backends that keep them alive after show().
    plt.close(fig)

In [None]:
""" Count patients """

def _count_genders(paths):
    """Count male and female patients in a list of patient path entries.

    The gender code of an entry is taken from its first element: the name of
    the second-to-last '/'-separated component, and within that name the text
    after the last '-'. Entries whose code is 'M' count as male; every other
    code counts as female.

    NOTE(review): assumes the first element of every entry is a
    '/'-separated path string with at least two components -- confirm
    against get_data_paths.

    Args:
        paths: Sequence of patient path entries.

    Returns:
        Tuple ``(num_male, num_female)``.
    """
    num_male = sum(
        1 for entry in paths
        if entry[0].split('/')[-2].split('-')[-1] == 'M'
    )
    return num_male, len(paths) - num_male


def _report_split(split_name, paths):
    """Print the patient count and gender distribution of one data split.

    Args:
        split_name: Name inserted into the printed headline.
        paths: Sequence of patient path entries belonging to the split.

    Returns:
        Tuple ``(num_total, num_male, num_female, male_percent,
        female_percent)``. Both percentages are 0.0 for an empty split
        instead of raising ZeroDivisionError.
    """
    num_total = len(paths)
    num_male, num_female = _count_genders(paths)

    if num_total:
        male_percent = num_male / num_total * 100
        female_percent = num_female / num_total * 100
    else:
        # A very small dataset can leave a split without any patient.
        male_percent = female_percent = 0.0

    print("Number of patients for {}: {}".format(split_name, num_total))
    print("Male: {} ({:.2f}%), female: {} ({:.2f}%)".format(
        num_male, male_percent, num_female, female_percent))
    return num_total, num_male, num_female, male_percent, female_percent


# Training
(num_train, num_male_train, num_female_train,
 num_male_train_percent, num_female_train_percent) = _report_split("training", train_paths)

# Validation
(num_val, num_male_val, num_female_val,
 num_male_val_percent, num_female_val_percent) = _report_split("validation", val_paths)

# Test
(num_test, num_male_test, num_female_test,
 num_male_test_percent, num_female_test_percent) = _report_split("test", test_paths)