In [86]:
from typing import Optional, Callable, List
import pandas as pd
import numpy as np
from gluonts.dataset.loader import TrainDataLoader, Batch
from functools import partial
from gluonts.mx.batchify import batchify
from gluonts.dataset import DataBatch, Dataset
from gluonts.itertools import Cyclic, IterableSlice
from gluonts.transform import AdhocTransform, Transformation, Identity

In [37]:
def make_toy_dataset(num_series: int, length: int = 1000, start_step: int = 47) -> List[dict]:
    """Build a list of synthetic hourly series as ``{"start", "target"}`` dicts.

    Series ``k`` (0-based) starts ``start_step * k`` hourly periods after
    2022-03-04 00:00, and its target is the ramp ``k * [0, 1, ..., length - 1]``
    stored as float32. Series 0 is therefore all zeros.

    Args:
        num_series: Number of entries to generate.
        length: Number of observations in every target array.
        start_step: Offset, in periods, between the starts of consecutive series.

    Returns:
        A list with ``num_series`` dicts, each holding a ``pd.Period`` under
        ``"start"`` and a 1-D float32 array under ``"target"``.
    """
    first_start = pd.Period("2022-03-04 00", freq="H")
    return [
        {
            "start": first_start + start_step * k,
            "target": k * np.arange(length, dtype=np.float32),
        }
        for k in range(num_series)
    ]


# Two datasets of different sizes (10 and 20 series) built from the same recipe.
dataset1 = make_toy_dataset(10)
dataset2 = make_toy_dataset(20)

datasets = [dataset1, dataset2]

In [23]:
# Sanity check on a plain single-dataset loader: pull a handful of batches
# and report the shape of each batch's "target" entry.
train_data_loader = TrainDataLoader(
    dataset1,
    batch_size=4,
    stack_fn=partial(batchify),
)

step = 0
for step, data in enumerate(train_data_loader, start=1):
    print(data['target'].shape)
    # Leave the loop once six batches have been shown.
    if step > 5:
        break

(4, 1000)
(4, 1000)
(4, 1000)
(4, 1000)
(4, 1000)
(4, 1000)


In [116]:
class MetaTrainDataLoader:
    """Batch iterator that mixes several training datasets.

    One loader is kept per dataset. Each ``next()`` call first draws a dataset
    index according to ``prob`` and then returns one batch from the loader of
    that dataset. The class keeps no epoch counter of its own, so a ``for``
    loop over it has to ``break`` explicitly.

    Args:
        datasets: The datasets to sample from.
        batch_size: Batch size passed to every ``TrainDataLoader`` built here.
        prob: Sampling probability of each dataset, one entry per dataset.
            ``None`` is forwarded unchanged to ``RandomState.choice``, which
            then samples uniformly.
        stack_fn: Function passed to every ``TrainDataLoader`` built here.
        transform: Transformation passed to every ``TrainDataLoader`` built here.
        data_loaders: Optional ready-made loaders, used instead of building new
            ones. Indexed positionally, so entry ``i`` serves ``datasets[i]``.
        num_batches_per_epoch: Passed to every ``TrainDataLoader`` built here.
        shuffle_buffer_length: Passed to every ``TrainDataLoader`` built here.
        random_state: Source of randomness for picking the dataset. When
            omitted, an unseeded ``np.random.RandomState`` is created on the
            first ``next()`` call.
        verbose: When true (the default), print the index of the sampled
            dataset on every ``next()`` call.

    Raises:
        ValueError: If ``prob`` is given and does not have exactly one entry
            per dataset.
    """

    def __init__(self,
                 datasets: List[Dataset],
                 batch_size: int,
                 prob: Optional[np.ndarray],
                 stack_fn: Callable,
                 transform: Transformation = Identity(),
                 data_loaders: Optional[List[TrainDataLoader]] = None,
                 num_batches_per_epoch: Optional[int] = None,
                 shuffle_buffer_length: Optional[int] = None,
                 random_state: Optional[np.random.RandomState] = None,
                 verbose: bool = True,
):
        # Fail at construction time rather than on the first draw; a length
        # mismatch would otherwise surface as a ValueError inside ``choice``.
        if prob is not None and len(prob) != len(datasets):
            raise ValueError(
                f"prob has {len(prob)} entries but {len(datasets)} datasets were given"
            )

        self.datasets = datasets
        self.batch_size = batch_size
        self.prob = prob
        self.stack_fn = stack_fn
        self.transform = transform
        self.num_batches_per_epoch = num_batches_per_epoch
        self.shuffle_buffer_length = shuffle_buffer_length
        self.dataset_loader_mapping = {}
        self.random_state = random_state
        self.verbose = verbose

        # Map dataset position -> loader, reusing caller-supplied loaders when
        # a non-empty list was given.
        for i, dataset in enumerate(datasets):
            if data_loaders:
                loader = data_loaders[i]
            else:
                loader = TrainDataLoader(dataset,
                                         transform=self.transform,
                                         batch_size=self.batch_size,
                                         stack_fn=self.stack_fn,
                                         num_batches_per_epoch=self.num_batches_per_epoch,
                                         shuffle_buffer_length=self.shuffle_buffer_length)
            self.dataset_loader_mapping[i] = loader

    def __iter__(self):
        """Return ``self``; the object is its own iterator."""
        return self

    def __next__(self):
        """Sample a dataset index and return one batch from its loader."""
        if self.random_state is None:
            self.random_state = np.random.RandomState()
        dataset_ind = self.random_state.choice(range(len(self.datasets)), p=self.prob)
        if self.verbose:
            print(dataset_ind)
        loader = self.dataset_loader_mapping[dataset_ind]
        # NOTE(review): a new iterator is requested from the loader on every
        # call. This presumably relies on the loader resuming its stream
        # instead of restarting from the first batch - confirm for any custom
        # ``data_loaders`` passed in.
        return next(iter(loader))


In [117]:
datasets = [dataset1, dataset2]

# Sampling weights proportional to the number of series in each dataset.
dataset_sizes = np.array([len(series_list) for series_list in datasets])
prob = dataset_sizes / dataset_sizes.sum()

# Fixed seed so the sequence of sampled dataset indices is repeatable.
random_state = np.random.RandomState(seed=69)
train_data_loader = MetaTrainDataLoader(
    datasets,
    batch_size=4,
    prob=prob,
    stack_fn=partial(batchify),
    random_state=random_state,
)
print(prob)

# Eleven batches are fetched before the break triggers; the batches themselves
# are not inspected here.
step = 0
for step, data in enumerate(train_data_loader, start=1):
    if step > 10:
        break

[0.33333333 0.66666667]
0
1
1
1
1
0
0
0
1
1
1
