In [5]:
import os

# Switch to the parent of the notebook's directory so that the next cell can run
# `from utils.spliced_dataset import *` (presumably `utils/` lives in that parent
# directory — confirm against the repo layout).
# The starting directory is remembered the first time this cell runs, so re-running the
# cell always lands in the same place instead of climbing one more level each time.
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.dirname(os.path.abspath(NOTEBOOK_DIR)))

In [7]:
from utils.spliced_dataset import *

In [None]:
import gc
import os

import torch
import numpy as np
from torch.utils.data import DataLoader

# ---- Configuration ---------------------------------------------------------------------
# Dataset locations: set these environment variables (or edit the defaults) to point at the
# local CHB-MIT recordings and the directory holding each patient's segment_info.json.
CHBMIT_DATA_ROOT = os.environ.get('CHBMIT_DATA_ROOT', '<path to your chbmit dataset>')
CHBMIT_SEGMENT_ROOT = os.environ.get('CHBMIT_SEGMENT_ROOT', '<path to your chbmit segment dir>')
PATIENTS = range(1, 24 + 1)  # chb01 .. chb24
NUM_WORKERS = 8              # worker processes for every DataLoader built below

sample_rate = 256                # samples per second
window_size = sample_rate * 60   # one-minute window, expressed in samples
batch_size = 64

# Regression split used below: train 70% : valid 10% : test 20%

num_samples = 8000               # passed as `num_samples` to AugmentRandomDataLoader

for patient in PATIENTS:
    print(f"Testing with CHB-MIT patient chb{patient:02d}")
    data_dir = f'{CHBMIT_DATA_ROOT}/chb{patient:02d}'
    json_path = f'{CHBMIT_SEGMENT_ROOT}/chb{patient:02d}/segment_info.json'
    spdst, _ = get_regression_datasets(data_dir, json_path, window_size, sample_rate)

    # 1. Check that the imported SplicedDataset indexes correctly.
    #    Fetch the first item once; indexing spdst[0] repeatedly would reload the window.
    sample0 = spdst[0]
    n_chans = sample0[0].shape[0]
    print(f"SplicedDataset is of length={len(spdst)}, each idx yields (X=Tensor {sample0[0].shape}, y={type(sample0[1])})")

    train_ds, valid_ds, test_ds = spdst.split_by_proportion(train_ratio=0.7, valid_ratio=0.1, test_ratio=0.2)

    # `step_size` may be an int, a pure function, or a callable object with internal state.
    # Here it depends on the regression label y (presumably seconds — confirm):
    #   y > 7200 (inter-ictal)  -> stride of 4 windows (240 s, windows do not overlap)
    #   otherwise (pre/onset)   -> stride of half a window (30 s, 50% overlap)
    step_func = lambda y: 4 * window_size if y > 7200 else window_size // 2

    train_iter = DataLoader(train_ds, batch_size=batch_size, num_workers=NUM_WORKERS,
                            sampler=AugmentSequentialSampler(train_ds, random_offset=True, step_size=step_func))  # expected interval 240.0 s or 30.0 s
    valid_iter = DataLoader(valid_ds, batch_size=batch_size, num_workers=NUM_WORKERS,
                            sampler=AugmentSequentialSampler(valid_ds, random_offset=True, step_size=sample_rate // 32))  # expected interval 0.03125 s
    test_iter = DataLoader(test_ds, batch_size=batch_size, num_workers=NUM_WORKERS,
                           sampler=AugmentSequentialSampler(test_ds, random_offset=True, step_size=sample_rate))  # expected interval 1.0 s

    # 2. Check that DataLoader works with the custom sampler: pull one batch per split.
    for loader in (train_iter, valid_iter, test_iter):
        x, y = next(iter(loader))
        print(x.shape, y)

    # Release the regression objects (including the dataset itself and its companion `_`)
    # before loading the classification datasets, so both are never resident at once.
    # `del` only drops the names; gc.collect() additionally reclaims any reference cycles.
    del x, y, loader, train_iter, valid_iter, test_iter, train_ds, valid_ds, test_ds
    del sample0, spdst, _
    gc.collect()

    # 3. Check the classification variants as well.
    spdst, _ = get_classification_datasets(data_dir, json_path, window_size, sample_rate)
    # One mask character per segment: 'r' = train, 'v' = valid, 't' = test, anything else is
    # ignored. Every segment goes to training here.
    seg_mask = 'r' * len(spdst.datasets)
    train_ds, _, _ = spdst.split_by_seg(seg_mask)
    weights = {classification_label.pre: 1,
               classification_label.onset: 1,
               classification_label.inter: 2}  # expected Pre:Onset:Inter = 1:1:2

    # DataLoader + AugmentRandomSampler is deprecated for this: the spliced datasets are
    # large (chb01's regression dataset alone has 17,913,883 windows, above 2^24), and
    # torch.multinomial fails with "number of categories cannot exceed 2^24".
    # AugmentRandomDataLoader is used instead.
    train_iter = AugmentRandomDataLoader(ds_lst=train_ds.datasets, weights=weights, batch_size=batch_size,
                                         num_samples=num_samples, num_workers=NUM_WORKERS)

    x, y = next(iter(train_iter))
    print(x.shape, y)
    y = y.tolist()
    # Per-class counts in the sampled batch, printed as "pre onset inter".
    # NOTE(review): relies on classification_label members comparing equal to the plain
    # ints from tolist() (e.g. an IntEnum); with a plain Enum every count would be 0.
    counts = [sum(k == label for k in y)
              for label in (classification_label.pre, classification_label.onset, classification_label.inter)]
    print(*counts)

    print()

Testing with CHB-MIT patient chb01
SplicedDataset is of length=17913883, each idx yields (X=Tensor torch.Size([22, 15360]), y=<class 'torch.Tensor'>)
torch.Size([64, 22, 15360]) tensor([10175.8008,  9935.8008,  9695.8008,  9455.8008,  9215.8008,  8975.8008,
         8735.8008,  8495.8008,  8255.8008,  8015.8008,  7775.8008,  7535.8008,
         7176.0703,  7146.0703,  7116.0703,  7086.0703,  7056.0703,  7026.0703,
         6996.0703,  6966.0703,  6936.0703,  6906.0703,  6876.0703,  6846.0703,
         6816.0703,  6786.0703,  6756.0703,  6726.0703,  6590.0039,  6560.0039,
         6530.0039,  6500.0039,  6470.0039,  6440.0039,  6410.0039,  6380.0039,
         6350.0039,  6320.0039,  6290.0039,  6260.0039,  6230.0039,  6200.0039,
         6170.0039,  6140.0039,  6110.0039,  6080.0039,  6050.0039,  6020.0039,
         5990.0039,  5960.0039,  5930.0039,  5900.0039,  5870.0039,  5840.0039,
         5810.0039,  5780.0039,  5750.0039,  5720.0039,  5690.0039,  5660.0039,
         5630.0039,  5