In [26]:
import math
import torch
from torch.utils.data.sampler import RandomSampler


class BatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """
    Iterate over tasks and provide a random batch per task in each mini-batch.

    The sampler visits the sub-datasets of a concatenated dataset in
    round-robin order and emits ``batch_size`` consecutive indices from each
    one in turn. A sub-dataset that runs out of samples is re-drawn, so the
    smaller datasets are oversampled until the largest one has been covered.

    Every rank seeds its generator identically and keeps only the positions
    ``rank, rank + gpus, rank + 2 * gpus, ...`` of each sub-dataset
    permutation, so on the first pass different ranks read disjoint samples.

    Args:
        dataset: object exposing ``datasets`` (a sequence of sized datasets)
            and ``cumulative_sizes`` (running totals of their lengths), as
            ``torch.utils.data.ConcatDataset`` does.
        batch_size: number of consecutive indices taken from one sub-dataset
            before switching to the next one.
        rank: index of the current process; used as the stride offset.
        gpus: number of processes sharing the data; used as the stride.
        shuffle: when true the permutation seed is the epoch set through
            ``set_epoch``; otherwise the seed is always 0.
    """
    def __init__(self, dataset, batch_size, rank, gpus, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)
        self.largest_dataset_size = max([len(cur_dataset) for cur_dataset in dataset.datasets])

        # Per-rank epoch length: the largest dataset padded up to a whole number
        # of batches, times the number of datasets, split across the ranks.
        self.number_selected_samples = int(self.batch_size * math.ceil(self.largest_dataset_size / self.batch_size) * len(self.dataset.datasets) / gpus)
        self.number_of_total_size = self.number_selected_samples*gpus
        print('number_selected_samples', self.number_selected_samples)
        print('total sample epoch', self.number_of_datasets * self.largest_dataset_size)
        self.gpus = gpus
        self.rank = rank
        self.shuffle = shuffle
        self.epoch = 0

    def __len__(self):
        """Return the number of indices this rank yields per epoch."""
        return self.number_selected_samples

    def _rank_shard(self, sampler):
        """Draw a fresh permutation from ``sampler`` and keep this rank's strided share."""
        # The same slice is used for the first pass and for every restart, so
        # the first pass is no longer truncated when ``gpus > 1``.
        return iter(list(sampler)[self.rank::self.gpus])

    def __iter__(self):
        """
        Build the list of global indices for one epoch and return an iterator over it.

        Raises:
            ValueError: if a sub-dataset has no sample at all for this rank
                (it holds fewer than ``rank + 1`` items).
        """
        # Deterministic seed: the epoch when shuffling, a constant otherwise.
        g = torch.Generator()
        g.manual_seed(self.epoch if self.shuffle else 0)

        samplers_list = [
            RandomSampler(cur_dataset, generator=g)
            for cur_dataset in self.dataset.datasets
        ]
        sampler_iterators = [self._rank_shard(sampler) for sampler in samplers_list]

        # Offset that maps a sub-dataset local index to the concatenated index.
        push_index_val = [0] + self.dataset.cumulative_sizes[:-1]
        step = self.batch_size * self.number_of_datasets

        # We want to see every sample of the largest dataset, which forces
        # resampling from the smaller datasets.
        epoch_samples = self.largest_dataset_size * self.number_of_datasets

        final_samples_list = []  # indices into the combined dataset
        for _ in range(0, epoch_samples, step):
            if len(final_samples_list) >= self.number_selected_samples:
                # Everything past this point would be discarded by the final slice.
                break
            for i in range(self.number_of_datasets):
                for _ in range(self.batch_size):
                    cur_sample_org = next(sampler_iterators[i], None)
                    if cur_sample_org is None:
                        # Shard exhausted: re-draw it and keep collecting.
                        sampler_iterators[i] = self._rank_shard(samplers_list[i])
                        cur_sample_org = next(sampler_iterators[i], None)
                    if cur_sample_org is None:
                        raise ValueError(
                            "dataset %d has no samples for rank %d of %d"
                            % (i, self.rank, self.gpus)
                        )
                    final_samples_list.append(cur_sample_org + push_index_val[i])

        return iter(final_samples_list[:self.number_selected_samples])

    def set_epoch(self, epoch):
        """Record the epoch number used as the shuffle seed by the next ``__iter__``."""
        self.epoch = epoch

from typing import Iterable
from torch.utils.data import TensorDataset, ConcatDataset, DataLoader
from torch.utils.data.dataset import Dataset
import bisect

class ConCatDatasetWithIndex(ConcatDataset):
    """ConcatDataset variant whose items also report which sub-dataset they came from."""

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__(datasets)

    def __getitem__(self, idx):
        """Return ``(dataset_idx, sample)`` for the global position ``idx``.

        Negative positions count from the end; one that reaches past the
        start raises ``ValueError``.
        """
        total = len(self)
        if idx < 0:
            if -idx > total:
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = total + idx

        # First sub-dataset whose cumulative size is strictly greater than idx.
        which = bisect.bisect_right(self.cumulative_sizes, idx)
        # Global position at which that sub-dataset starts.
        start = self.cumulative_sizes[which - 1] if which != 0 else 0
        return which, self.datasets[which][idx - start]

In [58]:
class TestDataset(Dataset):
    """Synthetic dataset that returns a random tensor of fixed shape for every index.

    Args:
        size: spatial side length; every item has shape
            ``(size // 8, 3, size, size)``.
        length: number of items reported by ``len()`` (defaults to 15, the
            value that was previously hard-coded).
    """
    def __init__(self, size=64, length=15):
        super().__init__()
        self.size = size
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Without this check plain iteration (``for x in ds`` / ``list(ds)``)
        # would never stop, because Python relies on IndexError to end it.
        if not -self.length <= idx < self.length:
            raise IndexError("index %d is out of range for dataset of length %d" % (idx, self.length))
        return torch.randn((self.size // 8, 3, self.size, self.size))

In [59]:
# Two synthetic tasks that differ only in item resolution (64 vs. 128).
d1, d2 = TestDataset(size=64), TestDataset(size=128)

In [70]:
# Combine both tasks; each item comes back as (dataset_idx, sample).
Concat_dataset = ConCatDatasetWithIndex(datasets=[d1, d2])

In [77]:
# Single-process setup (rank 0 of 1) that alternates tasks one sample at a time.
sampler = BatchSchedulerSampler(
    dataset=Concat_dataset,
    batch_size=1,
    rank=0,
    gpus=1,
    shuffle=True,
)

number_selected_samples 30
total sample epoch 30


In [80]:
# shuffle stays False because the custom sampler already decides the order.
dataloader = DataLoader(
    Concat_dataset,
    batch_size=1,
    shuffle=False,
    sampler=sampler,
)

In [81]:
# Show, for every step, the step number, the originating task and the batch shape.
for idx, (idy, batch) in enumerate(dataloader):
    print(idx, idy[0].item(), batch.shape, sep="\n")

0
tensor(0)
torch.Size([1, 8, 3, 64, 64])
1
tensor(1)
torch.Size([1, 16, 3, 128, 128])
2
tensor(0)
torch.Size([1, 8, 3, 64, 64])
3
tensor(1)
torch.Size([1, 16, 3, 128, 128])
4
tensor(0)
torch.Size([1, 8, 3, 64, 64])
5
tensor(1)
torch.Size([1, 16, 3, 128, 128])
6
tensor(0)
torch.Size([1, 8, 3, 64, 64])
7
tensor(1)
torch.Size([1, 16, 3, 128, 128])
8
tensor(0)
torch.Size([1, 8, 3, 64, 64])
9
tensor(1)
torch.Size([1, 16, 3, 128, 128])
10
tensor(0)
torch.Size([1, 8, 3, 64, 64])
11
tensor(1)
torch.Size([1, 16, 3, 128, 128])
12
tensor(0)
torch.Size([1, 8, 3, 64, 64])
13
tensor(1)
torch.Size([1, 16, 3, 128, 128])
14
tensor(0)
torch.Size([1, 8, 3, 64, 64])
15
tensor(1)
torch.Size([1, 16, 3, 128, 128])
16
tensor(0)
torch.Size([1, 8, 3, 64, 64])
17
tensor(1)
torch.Size([1, 16, 3, 128, 128])
18
tensor(0)
torch.Size([1, 8, 3, 64, 64])
19
tensor(1)
torch.Size([1, 16, 3, 128, 128])
20
tensor(0)
torch.Size([1, 8, 3, 64, 64])
21
tensor(1)
torch.Size([1, 16, 3, 128, 128])
22
tensor(0)
torch.Size([1, 8, 