In [71]:
from typing import Optional, Callable, List, Iterable, Iterator
from dataclasses import dataclass
import pandas as pd
import numpy as np
from gluonts.dataset.loader import TrainDataLoader, Batch
from functools import partial
from gluonts.mx.batchify import batchify
from gluonts.dataset import DataBatch, Dataset
from gluonts.itertools import Cyclic, IterableSlice
from gluonts.transform import AdhocTransform, Transformation, Identity

In [62]:
dataset1 = [
    {
        "start": pd.Period("2022-03-04 00", freq="H") + 47 * k,
        "target": k * np.arange(1000, dtype=np.float32),
    }
    for k in range(2)
]

dataset2 = [
    {
        "start": pd.Period("2022-03-04 00", freq="H") + 47 * k,
        "target": k * np.arange(1000, dtype=np.float32),
    }
    for k in range(4)
]

datasets = [dataset1, dataset2]

In [63]:
train_data_loader1 = TrainDataLoader(dataset1, 
                                    batch_size=4, 
                                    stack_fn=partial(batchify))

train_data_loader2 = TrainDataLoader(dataset2, 
                                    batch_size=4, 
                                    stack_fn=partial(batchify))


In [106]:
@dataclass
class MixIterables:
    iterables: List[Iterable] 
    probabilities: List[float]
    random_state: np.random.RandomState = np.random.RandomState()

    def __post_init__(self):
        self.iterators = [iter(iterable) for iterable in self.iterables]

    def __iter__(self):
        idx = self.random_state.choice(range(len(self.iterators)), p=self.probabilities)
        print(idx)
        try:
            yield next(self.iterators[idx])
        except StopIteration as e:
            self.iterators[idx] = iter(self.iterables[idx])
            yield next(self.iterators[idx])
                


In [107]:
dataset = MixIterables([dataset1, dataset2], [0.33, 0.67])
train_data_loader = TrainDataLoader(dataset, 
                                    batch_size=4, 
                                    stack_fn=partial(batchify))

for i, data in enumerate(train_data_loader):
    print(("batch", i))
    if i >= 2:
        break

1
1
0
1
('batch', 0)
1
0
1
1
('batch', 1)
0
1
0
1
('batch', 2)


In [108]:
@dataclass
class MixIterators:
    iterators: List[Iterator] 
    probabilities: List[float]
    random_state: np.random.RandomState = np.random.RandomState()

    def __iter__(self):
        idx = self.random_state.choice(range(len(self.iterators)), p=self.probabilities)
        print(idx)
        yield next(self.iterators[idx])


In [104]:
dataset = MixIterators([iter(Cyclic(dataset1)), iter(Cyclic(dataset2))], [0.33, 0.67])
train_data_loader = TrainDataLoader(dataset, 
                                    batch_size=4, 
                                    stack_fn=partial(batchify))

for i, data in enumerate(train_data_loader):
    print(("batch", i))
    if i >= 2:
        break

1
1
1
1
('batch', 0)
1
1
0
0
('batch', 1)
1
1
1
1
('batch', 2)


### For each batch, sample it from a dataset with a specified probabilities

In [61]:
train_data_loader = MixIterables([train_data_loader1, train_data_loader2], [0.33, 0.67])
step = 0
for data in train_data_loader:
    step += 1
    if step > 4:
        break

0


### STOP READING

In [11]:
class MetaTrainDataLoader:

    def __init__(self, 
                 datasets: List[Dataset], 
                 batch_size: int, 
                 prob: np.array,  
                 stack_fn: Callable, 
                 transform: Transformation = Identity(),
                 data_loaders: Optional[List[TrainDataLoader]]= None, 
                 num_batches_per_epoch: Optional[int] = None,     
                 shuffle_buffer_length: Optional[int] = None,
                 random_state: np.random.RandomState = None
):
        self.datasets = datasets
        self.batch_size = batch_size
        self.prob = prob
        self.stack_fn = stack_fn
        self.transform = transform
        self.num_batches_per_epoch = num_batches_per_epoch
        self.shuffle_buffer_length = shuffle_buffer_length
        self.dataset_loader_mapping = {}
        self.random_state = random_state

        for i in range(len(datasets)):
            if data_loaders:
                self.dataset_loader_mapping[i] = iter(data_loaders[i])
            else:
                self.dataset_loader_mapping[i] = iter(TrainDataLoader(self.datasets[i],
                                                                 transform=self.transform,
                                                                 batch_size=self.batch_size,
                                                                 stack_fn = self.stack_fn, 
                                                                 num_batches_per_epoch=self.num_batches_per_epoch, 
                                                                 shuffle_buffer_length=self.shuffle_buffer_length))

    def __iter__(self):
        return self
    
    def __next__(self):
        if not self.random_state:
            self.random_state = np.random.RandomState()
        dataset_ind = self.random_state.choice(range(len(self.datasets)), p=self.prob)
        print(dataset_ind)
        loader = self.dataset_loader_mapping[dataset_ind]
        return next(loader)


In [12]:
datasets = [dataset1, dataset2]
dataset_sizes = np.array([len(dataset)for dataset in datasets])
prob = dataset_sizes / np.sum(dataset_sizes)
random_state = np.random.RandomState(seed=69)
train_data_loader = MetaTrainDataLoader(datasets, 4, prob, partial(batchify), random_state=random_state)
print(prob)
step = 0
for data in train_data_loader:
    step += 1
    if step > 10:
        break

[0.33333333 0.66666667]
0
1
1
1
1
0
0
0
1
1
1
