In [1]:
from torch.utils.data import DataLoader
import sys
sys.path.append("..")
from utils.hdf5_data_split_generator import HDF5DataSplitGenerator

In [7]:
lr_channel_names = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4']
hr_channel_names = ['Fp1', 'Fz', 'F3', 'F7', 'FT9', 'FC5', 'FC1', 'C3', 'T7', 'TP9', 'CP5', 'CP1', 'Pz', 'P3', 'P7', 'O1', 'Oz', 'O2', 'P4', 'P8', 'TP10', 'CP6', 'CP2', 'C4', 'T8', 'FT10', 'FC6', 'FC2', 'F4', 'F8', 'Fp2', 'AF7', 'AF3', 'AFz', 'F1', 'F5', 'FT7', 'FC3', 'C1', 'C5', 'TP7', 'CP3', 'P1', 'P5', 'PO7', 'PO3', 'POz', 'PO4', 'PO8', 'P6', 'P2', 'CPz', 'CP4', 'TP8', 'C6', 'C2', 'FC4', 'FT8', 'F6', 'AF8', 'AF4', 'F2', 'FCz']

train_dataset = HDF5DataSplitGenerator(
    dataset_type="train",
    dataset_split="70/25/5",
    eeg_epoch_mode="around_evoked_event",
    fixed_length_duration=3,
    duration_before_onset=0.05,
    duration_after_onset=0.6,
    lr_channel_names=lr_channel_names,
    hr_channel_names=hr_channel_names,
    # subject=1,
    # session=1
)
    
val_dataset = HDF5DataSplitGenerator(
    dataset_type="val",
    dataset_split="70/25/5",
    eeg_epoch_mode="around_evoked_event",
    fixed_length_duration=3,
    duration_before_onset=0.05,
    duration_after_onset=0.6,
    lr_channel_names=lr_channel_names,
    hr_channel_names=hr_channel_names,
    # subject=1,
    # session=1
)

test_dataset = HDF5DataSplitGenerator(
    dataset_type="test",
    dataset_split="70/25/5",
    eeg_epoch_mode="around_evoked_event",
    fixed_length_duration=3,
    duration_before_onset=0.05,
    duration_after_onset=0.6,
    lr_channel_names=lr_channel_names,
    hr_channel_names=hr_channel_names,
    # subject=1,
    # session=1
)
    
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)
    
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

In [6]:
len(train_loader), len(val_loader), len(test_loader)

(82, 350, 70)

In [8]:
train_dataset[0]

{'index_before_random_split': np.int64(11912),
 'sfreq': np.float64(512.0),
 'mode': 'around_evoked_event',
 'total_duration': 0.65,
 'subject': np.int32(3),
 'session': np.int32(2),
 'sample_number': np.int32(353700),
 'lo_res': array([[ 0.23202977,  0.20920864,  0.14622381, ...,  0.00134021,
         -0.15771133, -0.34955075],
        [ 0.01221056, -0.02503948, -0.02038559, ...,  0.3969575 ,
          0.38693252,  0.23240507],
        [ 0.1607427 ,  0.23098284,  0.22597659, ...,  0.94048464,
          0.9919666 ,  0.8506374 ],
        ...,
        [ 0.36972958,  0.28939837,  0.31899136, ..., -0.8681541 ,
         -0.9038392 , -0.95805675],
        [-0.06118771, -0.19649881, -0.24224031, ..., -0.67309356,
         -0.75212777, -0.5791296 ],
        [ 0.25438666,  0.22789586,  0.20551479, ..., -0.7824524 ,
         -0.7823532 , -1.0273788 ]], dtype=float32),
 'hi_res': array([[ 0.22763331,  0.17271604,  0.0899248 , ..., -0.02852346,
         -0.15296294, -0.31382313],
        [ 0.95442

In [9]:
for batch in train_loader:
    print(batch)
    break

{'index_before_random_split': tensor([15645, 37958, 30889, 33205, 38299,   199,  4266, 43883, 16668, 44145,
        12024, 22598, 23097, 15023, 11618, 27193, 25504, 36730, 24712, 39117,
        20831, 27668, 18128, 11137, 24034, 29408,  6451, 22046, 18615, 26143,
        25806,  7154]), 'sfreq': tensor([512., 512., 512., 512., 512., 512., 512., 512., 512., 512., 512., 512.,
        512., 512., 512., 512., 512., 512., 512., 512., 512., 512., 512., 512.,
        512., 512., 512., 512., 512., 512., 512., 512.], dtype=torch.float64), 'mode': ['around_evoked_event', 'around_evoked_event', 'around_evoked_event', 'around_evoked_event', 'around_evoked_event', 'around_evoked_event', 'around_evoked_event', 'around_evoked_event', 'around_evoked_event', 'around_evoked_event', 'around_evoked_event', 'around_evoked_event', 'around_evoked_event', 'around_evoked_event', 'around_evoked_event', 'around_evoked_event', 'around_evoked_event', 'around_evoked_event', 'around_evoked_event', 'around_evoked_eve

In [12]:
for batch in train_loader:
    for img_id in batch['evoked_event_id']:
        if not (0 <= img_id - 1 < 960):
            print("Error: ", img_id)
            break

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "s:\PolySecLabProjects\eeg-image-decode\env\Lib\site-packages\torch\utils\data\_utils\worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "s:\PolySecLabProjects\eeg-image-decode\env\Lib\site-packages\torch\utils\data\_utils\fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "s:\PolySecLabProjects\eeg-image-decode\env\Lib\site-packages\torch\utils\data\_utils\collate.py", line 398, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "s:\PolySecLabProjects\eeg-image-decode\env\Lib\site-packages\torch\utils\data\_utils\collate.py", line 171, in collate
    {
  File "s:\PolySecLabProjects\eeg-image-decode\env\Lib\site-packages\torch\utils\data\_utils\collate.py", line 172, in <dictcomp>
    key: collate(
         ^^^^^^^^
  File "s:\PolySecLabProjects\eeg-image-decode\env\Lib\site-packages\torch\utils\data\_utils\collate.py", line 155, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "s:\PolySecLabProjects\eeg-image-decode\env\Lib\site-packages\torch\utils\data\_utils\collate.py", line 285, in collate_numpy_array_fn
    return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "s:\PolySecLabProjects\eeg-image-decode\env\Lib\site-packages\torch\utils\data\_utils\collate.py", line 155, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "s:\PolySecLabProjects\eeg-image-decode\env\Lib\site-packages\torch\utils\data\_utils\collate.py", line 270, in collate_tensor_fn
    storage = elem._typed_storage()._new_shared(numel, device=elem.device)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "s:\PolySecLabProjects\eeg-image-decode\env\Lib\site-packages\torch\storage.py", line 1198, in _new_shared
    untyped_storage = torch.UntypedStorage._new_shared(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "s:\PolySecLabProjects\eeg-image-decode\env\Lib\site-packages\torch\storage.py", line 410, in _new_shared
    return cls._new_using_filename_cpu(size)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Couldn't open shared file mapping: <torch_30184_1417690782_1>, error code: <1455>


In [5]:
train_dataset.get_evoked_event_metadata_for_item(5752)

[np.int32(935)]

In [11]:
train_dataset.train_indices, train_dataset.get_evoked_event_ids_for_item(13206)

(array([ 5752, 13206, 10755, ...,  7859,  8297, 10659]),
 [[8, 1, 747534, 801],
  [8, 1, 747936, 795],
  [8, 1, 748339, 814],
  [8, 1, 748737, 741]])

In [7]:
for batch in train_loader:
    for index in batch['actual_index']:
        print(index)
        break
    break

tensor(10773)


In [7]:
from tqdm import tqdm

In [10]:
progress_bar = tqdm(train_loader, desc=f"Epoch {1}/{2}", leave=False)

# for i, batch in enumerate(progress_bar):
#     lo_res = batch['lo_res']
#     hi_res = batch['hi_res']

#     # if i % 10 == 0:
#     progress_bar.set_postfix(loss=f"{2.79}", mae=f"{3.0}")

len(progress_bar)


Epoch 1/2:   0%|          | 0/1006 [00:00<?, ?it/s]

1006

In [6]:
for batch in train_loader:
    # print(batch['lo_res'].shape)
    for index in batch['actual_index']:
        evoked_ids = train_dataset.get_evoked_event_ids_for_item(index)
        print(evoked_ids)
    break

# train_dataset[2]

[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
[]
