In [5]:
import os
# Move the process working directory up one level.
# NOTE(review): presumably this makes the notebook run from the project root
# so that the `utils` package imported in the next cell resolves — confirm.
os.chdir('..')

In [7]:
from utils.spliced_dataset import *

In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader
# ---- Smoke-test configuration ----

# Batch size handed to the data loaders built below.
batch_size = 64

# Forwarded as `num_samples` to AugmentRandomDataLoader below.
num_samples = 8000

# Samples per second of the recordings.
sample_rate = 256

# Test window length in samples: 60 seconds' worth of data.
window_size = 60 * sample_rate

# The proportional split used below is train 70% : valid 10% : test 20%.

def step_func(y):
    """Choose the step size (in samples) to send after a sample labelled ``y``.

    Returns ``4 * window_size`` when ``y > 7200`` and ``window_size // 2``
    otherwise. With the 60-second window at 256 samples/s configured above,
    that is a 240 s or a 30 s advance respectively.

    NOTE(review): the original note reads "overlap during pre/onset,
    non-overlap during inter", so ``y`` is presumably the regression label
    with 7200 as the inter-ictal threshold — confirm against
    ``get_regression_datasets``.

    Defined once with ``def`` rather than as a lambda rebuilt on every loop
    iteration; the name and call signature are unchanged.
    """
    return 4 * window_size if y > 7200 else window_size // 2


# Run the same three smoke checks for every CHB-MIT patient, chb01 .. chb24.
for patient in range(1, 24 + 1):
    print(f"Testing with CHB-MIT patient chb{patient:02d}")
    data_dir = f'<path to your chbmit dataset>/chb{patient:02d}'
    json_path = f'<path to your chbmit segment dir>/chb{patient:02d}/segment_info.json'
    ds_lst, _ = get_regression_datasets(data_dir, json_path, window_size, sample_rate)

    spdst = SplicedDataset(ds_lst)
    # 1. Check that the imported SplicedDataset can be measured and indexed.
    #    The first item is fetched once and reused for every field reported.
    first_item = spdst[0]
    n_chans = first_item[0].shape[0]
    print(f"SplicedDataset is of length={len(spdst)}, each idx yields (X=Tensor {first_item[0].shape}, y={type(first_item[1])})")
    del first_item  # Do not keep the probe sample alive past this check

    train_ds, valid_ds, test_ds = spdst.split_by_proportion(train_ratio=0.7, valid_ratio=0.1, test_ratio=0.2)

    # Alternative kept for reference: drive the splits through a DataLoader
    # with AugmentSequentialSampler. `step_size` can be a pure function, a
    # callable object with internal state, or an integer.
    #
    # train_iter = DataLoader(train_ds, batch_size=batch_size, num_workers=8,
    #                         sampler=AugmentSequentialSampler(train_ds, random_offset=True, step_size=step_func)) # Expected interval 240.0s or 30.0s
    # valid_iter = DataLoader(valid_ds, batch_size=batch_size, num_workers=8,
    #                         sampler=AugmentSequentialSampler(valid_ds, random_offset=True, step_size=sample_rate//32)) # Expected interval 0.03125s
    # test_iter  = DataLoader(test_ds , batch_size=batch_size, num_workers=8,
    #                         sampler=AugmentSequentialSampler(test_ds, random_offset=True, step_size=sample_rate)) # Expected interval 1.0s

    train_iter = SequentialGenerator(train_ds.datasets, random_offset=True)
    valid_iter = SequentialGenerator(valid_ds.datasets, random_offset=True)
    test_iter = SequentialGenerator(test_ds.datasets, random_offset=True)

    # 2. Pull two samples from each generator: the first with send(None),
    #    the second after sending the step size to advance by.
    x, y = train_iter.send(None)
    print(x.shape, y)
    x, y = train_iter.send(step_func(y))  # Step depends on the label just seen
    print(x.shape, y)

    x, y = valid_iter.send(None)
    print(x.shape, y)
    x, y = valid_iter.send(sample_rate // 32)
    print(x.shape, y)

    x, y = test_iter.send(None)
    print(x.shape, y)
    x, y = test_iter.send(sample_rate)
    print(x.shape, y)

    del x, y, train_iter, valid_iter, test_iter, train_ds, valid_ds, test_ds, spdst, ds_lst # Trigger GC in advance

    # 3. Check the classification datasets and the ratio-balanced loader as well
    ds_lst, _ = get_classification_datasets(data_dir, json_path, window_size, sample_rate)
    spdst = SplicedDataset(ds_lst)
    # Use all data for training: 'r' for training, 'v' for valid, 't' for test, others are ignored
    seg_mask = 'r' * len(ds_lst)
    train_ds, _, _ = spdst.split_by_seg(seg_mask)
    expected_ratio_dict = { classification_label.pre: 1,
                            classification_label.onset: 1,
                            classification_label.inter: 2} # Expected Pre:Onset:Inter = 1:1:2

    train_iter = AugmentRandomDataLoader(ds_lst=train_ds.datasets, ratio_dict=expected_ratio_dict, batch_size=batch_size,
                                         num_samples=num_samples, num_workers=8)

    x, y = next(iter(train_iter))
    print(x.shape, y)
    y = y.tolist()
    # Print the pre / onset / inter counts of the first batch on one line,
    # to compare by eye against the requested 1:1:2 ratio.
    label_order = (classification_label.pre, classification_label.onset, classification_label.inter)
    print(*(sum(k == label for k in y) for label in label_order))

    print()