In [1]:
import sys
import os
import torch
from torch.utils.data import IterableDataset

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "..")))

from preprocessing.transform.dataset_builder import Builder
from preprocessing.transform.probabilistic_mixing_dataset import ProbabilisticMixingDataset
from preprocessing.downloader.gift_eval import load_gifteval_dataset_wrapper

In [2]:
class PostProcessingDataset(IterableDataset):
    """Iterable dataset that mixes windowed pairs drawn from several sources.

    Every source is wrapped in its own ``Builder`` pipeline configured with a
    sliding window of ``window_size + prediction_depth`` and a map step that
    splits its argument at index ``window_size``. The resulting pipelines are
    keyed by their position in the input (as a string) and handed to
    ``ProbabilisticMixingDataset``, which decides how they are combined.

    Args:
        gift_eval_datasets: Iterable of source datasets, each accepted by
            ``Builder``.
        window_size: Split index; everything before it becomes the first
            element of the pair.
        prediction_depth: Extra length requested on top of ``window_size``;
            the part from ``window_size`` onwards becomes the second element.
        seed: Seed forwarded unchanged to ``ProbabilisticMixingDataset``.
    """

    def __init__(self, gift_eval_datasets, window_size=512, prediction_depth=1, seed=42):
        super().__init__()

        def split_window(window):
            # Everything before `window_size` goes first, the remainder second.
            # NOTE(review): presumably the first axis is time — confirm with Builder.
            return (window[:window_size], window[window_size:])

        pipelines = {}
        for position, source in enumerate(gift_eval_datasets):
            staged = Builder(source)
            staged = staged.sliding_window(window_size + prediction_depth)
            staged = staged.map(split_window)
            pipelines[str(position)] = staged.build()

        self.ds = ProbabilisticMixingDataset(pipelines, seed=seed)

    def __iter__(self):
        """Yield every item produced by the wrapped mixing dataset, unchanged."""
        yield from self.ds

In [3]:
# Load each GIFT-Eval source on its own so the dataset gets one entry per source.
fox_source = load_gifteval_dataset_wrapper(['bdg-2_fox'])
bear_source = load_gifteval_dataset_wrapper(['bdg-2_bear'])

# No overrides: window_size, prediction_depth and seed keep their defaults.
ds = PostProcessingDataset([fox_source, bear_source])

In [4]:
# Create an iterator over the mixed dataset. `PostProcessingDataset.__iter__`
# is a generator function, so this only builds the generator object; nothing
# is pulled from the underlying dataset until `next()` is called on it.
# NOTE(review): `it` is not used by the cells shown here — the preview loop
# below calls iter(ds) itself — confirm whether a later cell relies on it.
it = iter(ds)

In [5]:
# Sanity check: print the shapes of both elements for the first twelve items
# (indices 0 through 11) and then stop, since the stream may be long.
# `i` and `item` keep their names because loop variables stay in the notebook
# namespace after the cell runs.
for i, item in enumerate(ds):
    first_part = item[0]
    second_part = item[1]
    print(i, (first_part.shape, second_part.shape))

    if i >= 11:
        break

0 (torch.Size([512, 1]), torch.Size([1, 1]))
1 (torch.Size([512, 1]), torch.Size([1, 1]))
2 (torch.Size([512, 1]), torch.Size([1, 1]))
3 (torch.Size([512, 1]), torch.Size([1, 1]))
4 (torch.Size([512, 1]), torch.Size([1, 1]))
5 (torch.Size([512, 1]), torch.Size([1, 1]))
6 (torch.Size([512, 1]), torch.Size([1, 1]))
7 (torch.Size([512, 1]), torch.Size([1, 1]))
8 (torch.Size([512, 1]), torch.Size([1, 1]))
9 (torch.Size([512, 1]), torch.Size([1, 1]))
10 (torch.Size([512, 1]), torch.Size([1, 1]))
11 (torch.Size([512, 1]), torch.Size([1, 1]))
