# Batch sampler
Reference: https://github.com/galatolofederico/pytorch-balanced-batch

MNIST dataset documentation (torchvision): https://pytorch.org/docs/stable/torchvision/datasets.html#mnist

In [1]:
import torch
import torchvision
import torch.utils.data
import random

class BalancedBatchSampler(torch.utils.data.sampler.Sampler):
    """Sampler that yields dataset indices so every class appears equally often.

    Sample indices are grouped per label at construction time. Classes with
    fewer samples than the largest class are padded by re-drawing (with
    replacement) from their own members until all classes have the same
    size. Iteration interleaves the classes round-robin: the first index of
    every class, then the second index of every class, and so on.

    Args:
        dataset: Sized collection of samples. Only ``len(dataset)`` is used
            when ``labels`` is given.
        labels: Optional per-sample labels, indexable by sample position
            (for example a 1-D tensor or a list). When omitted, the labels
            are looked up on ``dataset`` itself, see ``_get_label``.

    Raises:
        TypeError: If ``labels`` is omitted and no labels can be found on
            ``dataset``.
    """

    # Dataset attributes probed, in order, when no explicit labels are given.
    # NOTE(review): `targets` is expected to be the attribute current
    # torchvision datasets expose and `train_labels` its legacy MNIST name;
    # `labels` covers custom datasets - confirm against the datasets in use.
    _LABEL_ATTRIBUTES = ("targets", "train_labels", "labels")

    def __init__(self, dataset, labels=None):
        self.labels = labels
        self.dataset = dict()

        # Save all the indices for all the classes
        for idx in range(len(dataset)):
            label = self._get_label(dataset, idx)
            self.dataset.setdefault(label, []).append(idx)

        # Size of the largest class; 0 for an empty dataset.
        self.balanced_max = max(
            (len(members) for members in self.dataset.values()), default=0
        )

        # Oversample the classes with fewer elements than the max. Draw only
        # from the class's real members so that earlier duplicates do not
        # increase their own chance of being picked again.
        for members in self.dataset.values():
            originals = list(members)
            shortfall = self.balanced_max - len(originals)
            members.extend(random.choice(originals) for _ in range(shortfall))

        self.keys = list(self.dataset)
        # Kept so existing code reading these attributes keeps working;
        # iteration no longer stores its position on the instance.
        self.currentkey = 0
        self.indices = [-1] * len(self.keys)

    def __iter__(self):
        """Yield indices round-robin over the classes.

        Every call starts from the beginning and keeps its cursor in local
        variables, so an abandoned or concurrently running iterator cannot
        disturb another pass, and an empty sampler simply yields nothing.
        """
        for position in range(self.balanced_max):
            for key in self.keys:
                yield self.dataset[key][position]

    @staticmethod
    def _as_key(value):
        """Return ``value`` as a hashable label.

        Objects with an ``item()`` method (such as 0-d tensors) are unwrapped
        so that equal labels compare and hash equal; anything else is
        returned unchanged.
        """
        unwrap = getattr(value, "item", None)
        return unwrap() if callable(unwrap) else value

    def _get_label(self, dataset, idx, labels=None):
        """Return the label of sample ``idx``.

        Explicit labels passed to the constructor take precedence. Otherwise
        the label is guessed from ``dataset``: the attributes listed in
        ``_LABEL_ATTRIBUTES`` are tried in order, then the second element of
        ``dataset.imgs[idx]``.

        Args:
            dataset: The dataset being sampled.
            idx: Position of the sample.
            labels: Unused; kept for backward compatibility.

        Raises:
            TypeError: If no label source is available.
        """
        if self.labels is not None:
            return self._as_key(self.labels[idx])

        # Trying guessing
        for attribute in self._LABEL_ATTRIBUTES:
            store = getattr(dataset, attribute, None)
            if store is not None:
                return self._as_key(store[idx])

        samples = getattr(dataset, "imgs", None)
        if samples is not None:
            # Each entry is indexed as (something, label); the label is at 1.
            return self._as_key(samples[idx][1])

        raise TypeError("You should pass the tensor of labels to the constructor as second argument")

    def __len__(self):
        """Return the number of indices produced by one full iteration."""
        return self.balanced_max * len(self.keys)

In [3]:
# Load the MNIST training split from ./mnistdata without any sample or target
# transform; download=True asks torchvision to fetch the files when needed.
mnist_ = torchvision.datasets.MNIST(
    root='./mnistdata',
    train=True, transform=None, target_transform=None,
    download=True,
)

In [5]:
# Confirm the loaded object is exactly torchvision's MNIST class; this is an
# identity check on the type, so a subclass would not match.
if torchvision.datasets.MNIST is type(mnist_):
    print('yes')

yes


In [8]:
# Shape of the label tensor of the training split loaded above.
# NOTE(review): newer torchvision releases may expose this attribute as
# `targets` - confirm against the installed version.
mnist_.train_labels.shape

torch.Size([60000])

In [9]:
# Label of the first sample; indexing yields a tensor, not a Python int.
mnist_.train_labels[0]

tensor(5)

In [11]:
# .item() unwraps the single-element tensor into a built-in Python int.
type(mnist_.train_labels[0].item())

int

In [12]:
# Show the concrete class of the dataset object.
type(mnist_)

torchvision.datasets.mnist.MNIST