In [5]:
import os

# Move one directory up (to where the `utils` package imported in the next cell
# is expected to live). Guarded so that re-executing this cell in the same
# kernel does not keep walking further up the directory tree, which would make
# `from utils.spliced_dataset import *` fail on the second run.
if '_START_DIR' not in globals():
    _START_DIR = os.getcwd()  # directory the kernel was in before the first chdir
    os.chdir('..')

In [7]:
from utils.spliced_dataset import *

In [None]:
import gc

import torch
import numpy as np
from torch.utils.data import DataLoader

sample_rate = 256
window_size = sample_rate * 60  # take one minute (in samples) as the window length
batch_size = 64
num_workers = 8  # worker processes for every DataLoader built below

# Default train 70% : valid 10% : test 20%

num_samples = 8000  # number of draws requested from AugmentRandomSampler


def step_func(y):
    """Return the sampler step size (in samples) for a window labelled *y*.

    A callable step lets each label category advance at a different rate:
    `classification_label.inter` windows (presumably interictal) step 4 windows
    (4 * 60 s = 240 s); every other label steps half a window (30 s).
    """
    return 4 * window_size if y == classification_label.inter else window_size // 2


for patient in range(1, 24 + 1):
    print(f"Testing with CHB-MIT patient chb{patient:02d}")
    data_dir = f'<path to your chbmit dataset>/chb{patient:02d}'
    json_path = f'<path to your chbmit segment dir>/chb{patient:02d}/segment_info.json'
    spdst, _ = get_regression_datasets(data_dir, json_path, window_size, sample_rate)

    # 1. Check to see if imported SplicedDataset works well.
    # Fetch the first item once; each index access may reload data from disk.
    x0, y0 = spdst[0]
    n_chans = x0.shape[0]
    print(f"SplicedDataset is of length={len(spdst)}, each idx yields (X=Tensor {x0.shape}, y={type(y0)})")

    train_ds, valid_ds, test_ds = spdst.split_by_proportion()

    # Intervals in the trailing comments are step / sample_rate seconds.
    train_iter = DataLoader(train_ds, batch_size=batch_size, num_workers=num_workers,
                            sampler=AugmentSequentialSampler(train_ds, True, step_size=step_func))  # Expected interval 240.0s or 30.0s
    valid_iter = DataLoader(valid_ds, batch_size=batch_size, num_workers=num_workers,
                            sampler=AugmentSequentialSampler(valid_ds, True, sample_rate // 32))  # Expected interval 0.03125s (8 samples)
    test_iter = DataLoader(test_ds, batch_size=batch_size, num_workers=num_workers,
                           sampler=AugmentSequentialSampler(test_ds, True, sample_rate))  # Expected interval 1.0s

    # 2. Check to see if DataLoader with custom sampler works well.
    # A DataLoader is iterable but not an iterator, so next() needs an explicit iter();
    # the temporary iterator is dropped right after the first batch is taken.
    x, y = next(iter(train_iter))
    print(x.shape, y)
    x, y = next(iter(valid_iter))
    print(x.shape, y)
    x, y = next(iter(test_iter))
    print(x.shape, y)

    # Release everything from step 1/2 (including the dataset itself) before the
    # classification datasets are built, so both are not held in memory at once.
    del x, y, x0, y0, train_iter, valid_iter, test_iter, train_ds, valid_ds, test_ds, spdst
    gc.collect()

    # 3. Check the Classification* datasets and the weighted random sampler as well.
    spdst, _ = get_classification_datasets(data_dir, json_path, window_size, sample_rate)
    seg_mask = 'r' * len(spdst.datasets)  # Use all data for training
    train_ds, _, _ = spdst.split_by_seg(seg_mask)
    weights = {classification_label.pre: 1,
               classification_label.onset: 1,
               classification_label.inter: 2}  # Pre:Onset:Inter = 1:1:2
    # TODO RuntimeError: number of categories cannot exceed 2^24
    # NOTE(review): that message is raised by torch.multinomial when drawing from more than
    # 2**24 weights — presumably the sampler builds one weight per sample position; confirm.
    # The sampler is built over train_ds so its indices match the dataset the loader reads.
    train_iter = DataLoader(train_ds, batch_size=batch_size, num_workers=num_workers,
                            sampler=AugmentRandomSampler(data_source=train_ds, weights=weights, num_samples=num_samples))
    x, y = next(iter(train_iter))
    print(x.shape, y)

    print()
