In [1]:
import random
from random import shuffle

import torch

from wearsed.dataset.WearSEDDataset import WearSEDDataset

In [2]:
# Recording loader. Items are later unpacked as ((hypnogram, spo2, pleth), labels),
# so the order of `signals_to_read` presumably fixes the order of the returned
# signals after the hypnogram — TODO confirm against WearSEDDataset.
dataset = WearSEDDataset(
    signals_to_read=['SpO2', 'Pleth'],
    mesaid_path='../wearsed/dataset/data_ids/',
)

In [16]:
def get_random_start(labels, max_time, seq_length):
    """Draw a random window start and report whether that window has a positive label.

    Args:
        labels: per-step label tensor; a window counts as positive when the sum
            of its labels is > 0.
        max_time: largest admissible start index (inclusive). The batching code
            passes ``len(labels) - seq_length``, i.e. the start whose window ends
            exactly at the last label.
        seq_length: window length in steps.

    Returns:
        ``(start, has_positive)`` where ``start`` is a Python int drawn uniformly
        from ``[0, max_time]`` and ``has_positive`` is a Python bool.

    Raises:
        ValueError: if ``max_time`` is negative (labels shorter than one window).
    """
    if max_time < 0:
        raise ValueError(
            f'max_time must be >= 0, got {max_time} (labels shorter than seq_length?)'
        )
    # torch.randint's `high` is exclusive: +1 makes max_time itself reachable and
    # avoids an empty range when the recording is exactly one window long.
    start = torch.randint(0, max_time + 1, (1,)).item()
    end = start + seq_length
    return start, bool(labels[start:end].sum() > 0)

def get_batch(signals, labels, batch_size, seq_length, pleth_rate=256):
    """Sample ``batch_size`` time-aligned windows of ``seq_length`` steps from one recording.

    Window starts come from ``get_random_start``. During the first
    ``batch_size // 2`` draws only windows containing at least one positive
    label are kept; after that every draw is accepted. The batch is therefore
    biased towards (but not restricted to) windows with events, and the starts
    are shuffled so the positive-only draws are not clustered at the front.

    Args:
        signals: ``(hypnogram, spo2, pleth)``. ``hypnogram`` and ``spo2`` are
            indexed per label step; ``pleth`` is indexed with ``pleth_rate``
            consecutive samples per label step (assumed 1-D — TODO confirm).
        labels: per-step label tensor.
        batch_size: number of windows to return.
        seq_length: window length in label steps.
        pleth_rate: pleth samples per label step (defaults to the previously
            hard-coded 256).

    Returns:
        ``(batch_signals, batch_labels)`` with shapes
        ``(batch_size, 2 + pleth_rate, seq_length)`` and
        ``(batch_size, seq_length)``. Row 0 is the hypnogram, row 1 SpO2 and
        rows ``2:`` the pleth samples; column ``t`` holds everything belonging
        to label step ``start + t``.

    Raises:
        ValueError: if ``labels`` is shorter than ``seq_length``.
    """
    (hypnogram, spo2, pleth) = signals
    max_time = len(labels) - seq_length
    if max_time < 0:
        raise ValueError(
            f'recording has {len(labels)} label steps, fewer than seq_length={seq_length}'
        )

    tries = 0
    random_starts = []
    while len(random_starts) < batch_size:
        random_start, has_positive_class = get_random_start(labels, max_time, seq_length)
        # Favour positive windows for the first batch_size // 2 draws, then accept anything.
        if has_positive_class or tries >= batch_size // 2:
            random_starts.append(random_start)
        tries += 1
    shuffle(random_starts)

    batch_signals = []
    batch_labels = []
    for start in random_starts:
        end = start + seq_length
        seq_hypnogram = hypnogram[start:end].view((1, -1))
        seq_spo2 = spo2[start:end].view((1, -1))
        # pleth is time-major (pleth_rate consecutive samples per step). Reshape to
        # (seq_length, pleth_rate) and transpose so column t holds step t's samples;
        # a direct .view((pleth_rate, -1)) would lay consecutive samples along a row
        # and scatter every column across the whole window.
        seq_pleth = pleth[start * pleth_rate:end * pleth_rate].view((seq_length, pleth_rate)).T
        combined_signal = torch.cat([seq_hypnogram, seq_spo2, seq_pleth], dim=0)
        batch_signals.append(combined_signal)
        batch_labels.append(labels[start:end])

    return torch.stack(batch_signals), torch.stack(batch_labels)

def get_multi_batch(dataset, i, multi_batch_size, batch_size, seq_length):
    """Build the ``i``-th multi-batch from ``multi_batch_size`` consecutive recordings.

    Dataset items ``multi_batch_size * i`` through
    ``multi_batch_size * i + multi_batch_size - 1`` are each turned into a batch
    with ``get_batch``, and the per-recording batches are concatenated along the
    batch dimension (so the leading size is ``multi_batch_size * batch_size``).

    Each dataset item is expected to unpack as ``((hypnogram, spo2, pleth), labels)``.
    """
    base = multi_batch_size * i
    signal_parts = []
    label_parts = []
    for offset in range(multi_batch_size):
        (hypnogram, spo2, pleth), event_or_not = dataset[base + offset]
        recording_signals, recording_labels = get_batch(
            (hypnogram, spo2, pleth), event_or_not, batch_size, seq_length
        )
        signal_parts.append(recording_signals)
        label_parts.append(recording_labels)
    return torch.cat(signal_parts), torch.cat(label_parts)

In [17]:
# Sampling draws from two RNGs (torch for window starts, Python's `random` for the
# shuffle inside get_batch); seed both so this multi-batch is reproducible.
SEED = 0
torch.manual_seed(SEED)
random.seed(SEED)

N_RECORDINGS = 4            # consecutive dataset items per multi-batch
WINDOWS_PER_RECORDING = 32  # windows sampled from each recording
SEQ_LENGTH = 30 * 60        # window length in label steps (30 min if one step = 1 s — TODO confirm)

mb_sig, mb_lbl = get_multi_batch(dataset, 0, N_RECORDINGS, WINDOWS_PER_RECORDING, SEQ_LENGTH)

print(f'{mb_sig.shape=}')
print(f'{mb_lbl.shape=}')

# Rows per window: hypnogram + SpO2 + 256 pleth samples per step.
assert mb_sig.shape == (N_RECORDINGS * WINDOWS_PER_RECORDING, 2 + 256, SEQ_LENGTH)
assert mb_lbl.shape == (N_RECORDINGS * WINDOWS_PER_RECORDING, SEQ_LENGTH)

mb_sig.shape=torch.Size([128, 258, 1800])
mb_lbl.shape=torch.Size([128, 1800])
