In [7]:
import sys
import os
import torch
from datasets import load_dataset
from typing import List
from torch.utils.data import IterableDataset

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "..")))

from preprocessing.transform.dataset_builder import Builder
from preprocessing.transform.probabilistic_mixing_dataset import ProbabilisticMixingDataset

In [8]:
class GiftEvalLoadingDataset(IterableDataset):
    """Stream the target series of a GiftEval shard as float32 tensors.

    Each example pulled from the wrapped shard is indexed with ``"target"``,
    and every row found there is yielded as its own tensor with a trailing
    singleton dimension appended. Other fields of the example are ignored
    for now.

    Args:
        gift_eval_shard: Iterable of examples; each example must support
            ``example["target"]`` and that value must be iterable row by row.

    Yields:
        torch.Tensor: one ``float32`` tensor per target row, with ``NaN``
        entries replaced by ``0``. ``torch.nan_to_num`` is called with its
        default ``posinf``/``neginf`` handling, so infinities are not kept
        as-is either (see the torch documentation for the exact values).
    """

    def __init__(self, gift_eval_shard: IterableDataset):
        super().__init__()
        self.ds = gift_eval_shard

    def __iter__(self):
        # A plain `for` loop ends cleanly when the shard is exhausted. Unlike a
        # hand-written `next()` / `except StopIteration` wrapper around the
        # whole body, it does not mistake a StopIteration leaking out of the
        # per-example processing for "end of data" and silently truncate the
        # stream; such a leak now surfaces as a RuntimeError (PEP 479).
        for example in self.ds:
            for row in example["target"]:
                series = torch.tensor(row, dtype=torch.float32).unsqueeze(-1)
                yield torch.nan_to_num(series, nan=0)

In [14]:
class PostProcessingDataset(IterableDataset):
    """Combine several source datasets into one stream of split windows.

    Every source dataset is run through a ``Builder`` pipeline:
    ``sliding_window(window_size + prediction_depth)``, then a ``map`` that
    splits each item into ``(item[:window_size], item[window_size:])``, then
    ``build()``. The resulting pipelines are keyed by their position (as a
    string) and handed to ``ProbabilisticMixingDataset`` together with ``seed``.

    Args:
        gift_eval_datasets: Iterable of source datasets, one pipeline each.
        window_size: Number of leading elements placed in the first half of
            each yielded pair.
        prediction_depth: Added to ``window_size`` to form the argument of
            ``sliding_window``; presumably the length of the second half of
            each pair - confirm against ``Builder.sliding_window``.
        seed: Forwarded unchanged to ``ProbabilisticMixingDataset``.
    """

    def __init__(self, gift_eval_datasets, window_size=512, prediction_depth=1, seed=42):
        super().__init__()

        def split_window(window):
            # Head of the window is the model input, the remainder the target.
            return (window[:window_size], window[window_size:])

        pipelines = {}
        for position, source in enumerate(gift_eval_datasets):
            windowed = Builder(source).sliding_window(window_size + prediction_depth)
            pipelines[str(position)] = windowed.map(split_window).build()

        self.ds = ProbabilisticMixingDataset(pipelines, seed=seed)

    def __iter__(self):
        # Pass every item of the mixed dataset through unchanged.
        yield from self.ds

In [10]:
# Open the GiftEval pretraining corpus in streaming mode (train split).
gift_eval_ds = load_dataset(
    "Salesforce/GiftEvalPretrain",
    split="train",
    streaming=True,
)

Resolving data files:   0%|          | 0/6528 [00:00<?, ?it/s]

In [15]:
# Number of shards the stream is split into. 6528 is the data-file count shown
# in the "Resolving data files" progress output of the `load_dataset` cell.
NUM_SHARDS = 6528


def get_shard(index, num_shards=NUM_SHARDS):
    """Wrap one shard of `gift_eval_ds` in a `GiftEvalLoadingDataset`.

    Args:
        index: Shard index forwarded to `gift_eval_ds.shard`.
        num_shards: Total shard count forwarded to `gift_eval_ds.shard`.

    Returns:
        A `GiftEvalLoadingDataset` reading from the selected shard.
    """
    return GiftEvalLoadingDataset(gift_eval_ds.shard(num_shards=num_shards, index=index))


# Mix the first two shards into a single post-processed stream.
ds = PostProcessingDataset([get_shard(0), get_shard(1)])

In [16]:
# Sanity check: print the shapes of the first pair produced by `ds`, then stop.
# The loop always exits on its first iteration, so no modulo guard is needed;
# using `for ... break` (rather than `next`) keeps an empty stream a no-op.
for i, item in enumerate(ds):
    print(i, (item[0].shape, item[1].shape))
    break

0 (torch.Size([512, 1]), torch.Size([1, 1]))
