<a href="https://colab.research.google.com/github/SeongBeomLEE/Study/blob/main/Experiment/Pytorch_StratifiedSampler.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

PyTorch StratifiedSampler 사용법

https://discuss.pytorch.org/t/how-to-enable-the-dataloader-to-sample-from-each-class-with-equal-probability/911/7

In [31]:
import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import os

import warnings
# Suppress every warning category for the remainder of the session.
warnings.filterwarnings(action='ignore')

In [32]:
from sklearn.model_selection import StratifiedKFold

class StratifiedSampler(torch.utils.data.Sampler):
    """Batch sampler that yields class-stratified mini-batches.

    The labels are partitioned with ``StratifiedKFold`` into
    ``len(y) // batch_size`` folds, and the held-out ("test") indices of each
    fold are yielded as one batch. Each batch therefore approximately
    preserves the class proportions of ``y`` (equal class counts only when
    ``y`` itself is balanced). Because the number of folds is rounded down,
    batches may hold slightly more than ``batch_size`` samples when ``len(y)``
    is not a multiple of ``batch_size``.

    Intended to be passed to ``DataLoader`` via ``batch_sampler=``.

    Args:
        y: 1D label array (``torch.Tensor`` or NumPy array).
        batch_size: Target number of samples per batch; must be positive.
        shuffle: If True, samples are shuffled within each class before
            splitting, with a fresh random state drawn on every iteration.

    Raises:
        AssertionError: If ``y`` is not one-dimensional.
        ValueError: If ``batch_size`` is not positive, or if ``y`` is too
            short to form at least two batches.
    """
    def __init__(self, y, batch_size, shuffle=True):
        if torch.is_tensor(y):
            y = y.cpu().numpy()
        assert len(y.shape) == 1, 'label array must be 1D'
        if batch_size <= 0:
            raise ValueError(f'batch_size must be positive, got {batch_size}')
        n_batches = len(y) // batch_size
        # StratifiedKFold needs at least two splits; fail with a clear message.
        if n_batches < 2:
            raise ValueError(
                f'need at least 2 batches, but len(y)={len(y)} and '
                f'batch_size={batch_size} give {n_batches}'
            )
        self.n_batches = n_batches
        self.skf = StratifiedKFold(n_splits=n_batches, shuffle=shuffle)
        # split() only needs X for its length, so a constant placeholder is
        # enough and avoids consuming the global torch RNG.
        self.X = torch.zeros(len(y), 1).numpy()
        self.y = y
        self.shuffle = shuffle

    def __iter__(self):
        """Yield one array of sample indices per batch."""
        if self.shuffle:
            # Re-seed from torch's RNG so each epoch gets a different split
            # while remaining reproducible under torch.manual_seed.
            self.skf.random_state = torch.randint(0,int(1e8),size=()).item()
        for train_idx, test_idx in self.skf.split(self.X, self.y):
            yield test_idx

    def __len__(self):
        """Return the number of batches produced per iteration.

        A batch sampler's length is what ``len(DataLoader)`` reports, so it
        must count batches rather than samples.
        """
        return self.n_batches

In [33]:
# The only preprocessing step is conversion to a tensor.
common_transform = transforms.Compose([transforms.ToTensor()])

# Train and test splits share the same root directory and transform.
fashion_train_transformed = torchvision.datasets.FashionMNIST(
    root='./fashion',
    train=True,
    download=True,
    transform=common_transform,
)
fashion_test_transformed = torchvision.datasets.FashionMNIST(
    root='./fashion',
    train=False,
    download=True,
    transform=common_transform,
)

In [48]:
# Mini-batch size used for the DataLoader and the StratifiedSampler below.
BATCH_SIZE = 128

In [89]:
from torch.utils.data import Dataset

class TrainDataset(Dataset):
    """Map-style dataset that pairs each element of ``x`` with its label in ``y``.

    Args:
        x: Indexable collection of inputs (tensor, array or sequence).
        y: Indexable collection of labels, indexed in parallel with ``x``.
        train: Accepted for backward compatibility; not used.
    """

    def __init__(self, x, y, train = False):
        self.X = x
        self.y = y

    def __len__(self):
        """Return the number of items in ``x``."""
        return len(self.X)

    @staticmethod
    def _copy_as_tensor(value):
        """Return ``value`` as a new tensor that shares no storage with it."""
        if torch.is_tensor(value):
            # Same result as torch.tensor(value), without the copy-construct
            # UserWarning that call emits for tensor inputs.
            return value.detach().clone()
        return torch.tensor(value)

    def __getitem__(self, idx):
        """Return the ``(x[idx], y[idx])`` pair as freshly copied tensors."""
        X = self._copy_as_tensor(self.X[idx])
        y = self._copy_as_tensor(self.y[idx])
        return X, y

In [90]:
# The "features" are just sample positions, so each batch reveals which
# dataset indices it drew; the labels are the FashionMNIST training targets.
y = fashion_train_transformed.targets
x = torch.arange(len(fashion_train_transformed))
x_dataset = TrainDataset(x=x, y=y)

In [92]:
# Count how many training samples carry each class label.
arr = fashion_train_transformed.targets.numpy()
unique, counts = np.unique(arr, return_counts=True)
uniq_cnt_dict = {label: count for label, count in zip(unique, counts)}
uniq_cnt_dict

{0: 6000,
 1: 6000,
 2: 6000,
 3: 6000,
 4: 6000,
 5: 6000,
 6: 6000,
 7: 6000,
 8: 6000,
 9: 6000}

In [111]:
# Baseline loader: fixed-size sequential batches with no stratification.
not_stratify_fashion_train_dataloader = torch.utils.data.DataLoader(
    x_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)

# Class histogram of the first batch only (element 1 of a batch is its labels).
arr = next(iter(not_stratify_fashion_train_dataloader))[1].numpy()
unique, counts = np.unique(arr, return_counts=True)
uniq_cnt_dict = {label: count for label, count in zip(unique, counts)}
uniq_cnt_dict

{0: 13, 1: 15, 2: 12, 3: 16, 4: 10, 5: 14, 6: 15, 7: 11, 8: 8, 9: 14}

In [None]:
# For every baseline batch, print its sample indices and per-class counts.
# `i` is the (indices, labels) pair produced by the loader.
for i in not_stratify_fashion_train_dataloader:
    arr = i[1].numpy()
    unique, counts = np.unique(arr, return_counts=True)
    uniq_cnt_dict = {label: count for label, count in zip(unique, counts)}
    print(i[0], uniq_cnt_dict)

In [113]:
# `i` still holds the final batch from the loop in the preceding cell.
# Look its labels up again through the returned indices and count them.
arr = x_dataset.y[i[0]].numpy()
unique, counts = np.unique(arr, return_counts=True)
uniq_cnt_dict = {label: count for label, count in zip(unique, counts)}
print(uniq_cnt_dict)

{0: 12, 1: 12, 2: 14, 3: 11, 4: 11, 5: 7, 6: 8, 7: 5, 8: 10, 9: 6}


In [114]:
# Stratified loader: the sampler supplies whole batches of indices, so it is
# passed as `batch_sampler` and no `batch_size` is given to the DataLoader.
sampler = StratifiedSampler(
    y=x_dataset.y,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

stratify_fashion_train_dataloader = torch.utils.data.DataLoader(
    x_dataset,
    num_workers=2,
    batch_sampler=sampler,
)

# Class histogram of the first stratified batch.
arr = next(iter(stratify_fashion_train_dataloader))[1].numpy()
unique, counts = np.unique(arr, return_counts=True)
uniq_cnt_dict = {label: count for label, count in zip(unique, counts)}
uniq_cnt_dict

{0: 13, 1: 13, 2: 13, 3: 13, 4: 13, 5: 12, 6: 13, 7: 13, 8: 13, 9: 13}

In [None]:
# For every stratified batch, print its sample indices and per-class counts.
# `i` is the (indices, labels) pair produced by the loader.
for i in stratify_fashion_train_dataloader:
    arr = i[1].numpy()
    unique, counts = np.unique(arr, return_counts=True)
    uniq_cnt_dict = {label: count for label, count in zip(unique, counts)}
    print(i[0], uniq_cnt_dict)

In [116]:
# `i` is the last batch left over from the loop in the preceding cell; map its
# indices back to labels and tally the classes.
arr = x_dataset.y[i[0]].numpy()
unique, counts = np.unique(arr, return_counts=True)
uniq_cnt_dict = {label: count for label, count in zip(unique, counts)}
print(uniq_cnt_dict)

{0: 13, 1: 13, 2: 13, 3: 13, 4: 13, 5: 12, 6: 13, 7: 13, 8: 13, 9: 12}
