In [1]:
import numpy as np

import torch
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

In [2]:
# Load the MNIST training split, stored under the 'data' directory
# (downloaded first when it is not already present).
# NOTE(review): the transform is attached to this dataset object only; code that
# reads `train_dataset.data` directly (as split_dataset does) works on the stored
# tensors as-is -- presumably without ToTensor applied; confirm if that matters.
train_dataset = MNIST(
    root='data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

In [3]:
class SubDataset(Dataset):
    """Map-style dataset over an already-selected pair of sample/label containers.

    Args:
        data: Indexable, sized container holding the samples.
        targets: Indexable, sized container holding the labels; entry ``k``
            is the label of ``data[k]``.
        transform: Optional callable applied to each sample when it is
            fetched. Defaults to ``None``, in which case the sample is
            returned exactly as stored.

    Raises:
        ValueError: If ``data`` and ``targets`` do not have the same length.
    """

    def __init__(self, data, targets, transform=None):
        # Fail early: a mismatch would otherwise surface as an IndexError in
        # the middle of an epoch, or silently hide trailing samples.
        if len(data) != len(targets):
            raise ValueError(
                f"data and targets must have the same length, "
                f"got {len(data)} and {len(targets)}"
            )
        self.data = data
        self.targets = targets
        self.transform = transform

    def __len__(self):
        """Return the number of (sample, label) pairs."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` at ``idx``, with ``transform`` applied to the sample if set."""
        sample = self.data[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.targets[idx]

In [4]:
def split_dataset(dataset, n_task=5, n_classes=10):
    """Split a labelled dataset into ``n_task`` subsets of consecutive classes.

    Classes ``0 .. n_classes - 1`` are partitioned into ``n_task`` contiguous
    groups. When ``n_classes`` is not a multiple of ``n_task`` the first
    ``n_classes % n_task`` tasks receive one extra class, so no class is
    dropped. When it divides evenly, every task gets
    ``n_classes // n_task`` classes.

    Args:
        dataset: Object exposing ``data`` and ``targets`` attributes. They are
            indexed with a boolean mask built from comparisons on ``targets``
            (assumes integer class labels stored in a tensor -- samples whose
            label falls outside ``[0, n_classes)`` end up in no subset).
        n_task: Number of subsets to produce; must satisfy
            ``1 <= n_task <= n_classes``.
        n_classes: Total number of classes to distribute.

    Returns:
        list: ``n_task`` ``SubDataset`` objects, ordered by class range.

    Raises:
        ValueError: If ``n_task`` is smaller than 1 or larger than ``n_classes``.
    """
    if n_task < 1:
        raise ValueError(f"n_task must be at least 1, got {n_task}")
    if n_task > n_classes:
        raise ValueError(
            f"n_task ({n_task}) cannot exceed n_classes ({n_classes})"
        )

    data = dataset.data
    targets = dataset.targets

    # Every task gets `base` classes; the first `extra` tasks get one more so
    # that the remainder of the division is not silently discarded.
    base, extra = divmod(n_classes, n_task)

    subdatasets = []
    start = 0
    for i in range(n_task):
        stop = start + base + (1 if i < extra else 0)
        # Classes of a task are consecutive, so one range test selects them all.
        index = (targets >= start) & (targets < stop)
        subdatasets.append(SubDataset(data[index], targets[index]))
        start = stop

    return subdatasets

In [5]:
# Five tasks over the ten digit classes -> two consecutive digits per task.
subdatasets = split_dataset(train_dataset, n_task=5, n_classes=10)

In [6]:
# Sanity check: print the labels of one mini-batch from every task subset to
# confirm that each subset only contains its own classes.
# `i` is the task index; it is not used inside the loop body.
for i, subdataset in enumerate(subdatasets):

    # shuffle=True draws a random order, so the printed labels differ between
    # runs (no seed is set in the visible cells). drop_last=True discards a
    # trailing batch with fewer than 32 samples.
    data_loader = DataLoader(subdataset, batch_size=32, shuffle=True, drop_last=True)

    # Only the first batch is inspected; `break` stops after one iteration.
    for x, y in data_loader:
        print(y)
        break

tensor([0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1,
        0, 0, 0, 1, 0, 1, 1, 0])
tensor([3, 2, 2, 3, 3, 2, 3, 2, 3, 3, 2, 3, 3, 2, 3, 2, 2, 2, 2, 3, 2, 3, 3, 2,
        3, 2, 2, 3, 3, 2, 2, 2])
tensor([4, 4, 5, 4, 4, 4, 4, 4, 4, 5, 5, 4, 4, 5, 5, 5, 4, 4, 5, 4, 4, 4, 4, 5,
        4, 4, 5, 4, 4, 4, 4, 5])
tensor([7, 7, 6, 7, 6, 7, 6, 6, 7, 7, 6, 7, 7, 7, 7, 6, 6, 6, 7, 6, 6, 7, 6, 6,
        6, 7, 6, 7, 7, 7, 6, 7])
tensor([8, 9, 8, 9, 9, 9, 9, 8, 9, 8, 8, 8, 8, 8, 9, 9, 8, 9, 8, 9, 8, 9, 9, 9,
        8, 9, 8, 8, 8, 8, 8, 9])
