In [41]:
import sys

sys.path.insert(0, "../..")
from pathlib import Path
from src.data import dataloader
from src.data import make_dataset

In [42]:
# Fetch the EEG data with the project helper and keep the returned path;
# the log output below shows where the files end up on disk.
eegPath = make_dataset.get_eeg()

2022-06-12 13:46:43.183 | INFO     | src.data.make_dataset:get_eeg:18 - Data is downloaded to ../../data/raw/datasets/eeg.


In [43]:
# Build the dataset from the EEG directory returned by get_eeg().
dataset = dataloader.BaseDataset(datapath=eegPath)
# Last expression is displayed by Jupyter; no print() needed.
len(dataset)

24


In [44]:
# Shape of the first tensor in the first dataset item (indexing invokes __getitem__).
# NOTE(review): the two axes are presumably (lines, channels) — confirm in BaseDataset.
dataset[0][0].shape

torch.Size([188, 14])

First, the BaseDataIterator is used. This iterator automatically reduces the window size if it is larger than the smallest chunk in the BaseDataset. It also takes another chunk if the current chunk does not have enough lines. For these reasons, all tensors in one batch always have the same shape. The downside is that some information is thrown away by doing this.

In [45]:
# Fixed-shape batches: every window in a batch has exactly window_size rows.
loader = dataloader.BaseDataIterator(dataset=dataset, window_size=10, batchsize=32)
iterator = iter(loader)
# Draw from the iterator just created rather than calling next() on the loader directly.
batch = next(iterator)
batch[0].shape

torch.Size([32, 10, 14])

If we try to use a window size larger than 21 (here 25), the BaseDataIterator prints a warning and sets the window size to 21:

In [49]:
# Request a window longer than the shortest chunk; the iterator prints a warning
# and clamps window_size to the maximum it supports (21 for this dataset).
loader = dataloader.BaseDataIterator(dataset=dataset, window_size=25, batchsize=32)
iterator = iter(loader)
batch = next(iterator)
batch[0].shape


Maximum window length is 21, setting window length to 21. Use PaddedDataIterator for bigger window size


torch.Size([32, 21, 14])

In [50]:
# Inspect the batch from the previous cell; torch abbreviates large tensors with "...".
batch

(tensor([[[4312.3101, 4022.0500, 4278.4600,  ..., 4284.6201, 4612.3101,
           4368.2100],
          [4304.1001, 4016.9199, 4273.8501,  ..., 4282.5601, 4606.1499,
           4364.1001],
          [4303.0801, 4016.9199, 4270.7700,  ..., 4276.9199, 4602.0498,
           4362.5601],
          ...,
          [4277.4399, 3990.7700, 4246.6699,  ..., 4257.9502, 4591.7900,
           4339.4902],
          [4284.6201, 3991.7900, 4251.2798,  ..., 4267.1802, 4596.4102,
           4350.7700],
          [4287.6899, 3997.4399, 4260.0000,  ..., 4274.3599, 4597.9502,
           4350.7700]],
 
         [[4238.4600, 3994.3601, 4236.9199,  ..., 4253.8501, 4533.3301,
           4261.5400],
          [4229.7402, 3989.2300, 4234.3599,  ..., 4248.7202, 4530.2598,
           4255.3799],
          [4224.6201, 3990.2600, 4232.3101,  ..., 4245.6401, 4527.1802,
           4250.7700],
          ...,
          [4241.0298, 4007.1799, 4233.8501,  ..., 4246.6699, 4541.0298,
           4284.6201],
          [4244.1

Another way of using the data is to pad it when the window size is too big. In the PaddedDataIterator, a sequence is padded if it is too short. For this reason, we can use a window size larger than the length of the shortest observation. The PaddedDataIterator also has an extra argument, `min_nr_lines`, which specifies the minimum number of items an observation must contain; if this condition is not met, the previous window is taken instead.

In [51]:
# Padded batches: window_size may exceed the shortest observation because short
# sequences are padded; min_nr_lines sets the minimum lines an observation must have.
loaderPadded = dataloader.PaddedDataIterator(dataset=dataset, window_size=40, batchsize=32, min_nr_lines=5)
iterator = iter(loaderPadded)
next(iterator)  # discard the first batch; the second batch is the one inspected below
batch = next(iterator)

batch[0].shape

torch.Size([32, 40, 14])

In [39]:
# Inspect the padded batch; rows of zeros at the end of a window are padding.
# NOTE(review): this cell's execution count (39) is lower than the loader cell's (51),
# so the saved output may come from an earlier run — Restart & Run All to refresh it.
batch

(tensor([[[4299.4902, 3995.8999, 4260.0000,  ..., 4281.0298, 4583.0801,
           4349.7402],
          [4300.0000, 3995.3799, 4266.1499,  ..., 4285.1299, 4587.1802,
           4351.7900],
          [4299.4902, 3993.8501, 4265.6401,  ..., 4286.6699, 4587.6899,
           4349.7402],
          ...,
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000]],
 
         [[4296.9199, 3991.7900, 4264.1001,  ..., 4283.5898, 4613.3301,
           4364.6201],
          [4297.9502, 3992.8201, 4266.6699,  ..., 4280.0000, 4617.9502,
           4364.6201],
          [4298.9702, 3993.8501, 4265.6401,  ..., 4281.5400, 4621.0298,
           4363.0801],
          ...,
          [4296.9199, 3996.9199, 4269.7402,  ..., 4279.4902, 4610.7700,
           4356.9199],
          [4299.4