<a href="https://colab.research.google.com/github/SeongBeomLEE/Study/blob/main/Experiment/Pytorch_StratifiedSampler.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

PyTorch StratifiedSampler 사용법

https://discuss.pytorch.org/t/how-to-enable-the-dataloader-to-sample-from-each-class-with-equal-probability/911/7

In [31]:
import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import os

import warnings
warnings.filterwarnings(action='ignore')

In [32]:
from sklearn.model_selection import StratifiedKFold

class StratifiedSampler(torch.utils.data.Sampler):
    """Batch sampler whose batches follow the overall class distribution.

    The labels are partitioned with ``StratifiedKFold`` into
    ``len(y) // batch_size`` folds, and the held-out indices of each fold are
    yielded as one batch. Intended to be passed to a ``DataLoader`` through
    its ``batch_sampler=`` argument.

    Because the folds partition the whole label array, batches can be slightly
    larger than ``batch_size`` when ``len(y)`` is not a multiple of it.

    Args:
        y: 1-D labels, as a tensor, NumPy array, or sequence.
        batch_size: Target number of samples per batch.
        shuffle: Forwarded to ``StratifiedKFold``. When true, a new random
            state is drawn from torch's RNG at the start of every iteration.

    Raises:
        AssertionError: If ``y`` is not one-dimensional.
        ValueError: If fewer than two batches of ``batch_size`` fit in ``y``.
    """
    def __init__(self, y, batch_size, shuffle=True):
        if torch.is_tensor(y):
            y = y.cpu().numpy()
        # Accept plain sequences too; a no-op for arrays.
        y = np.asarray(y)
        assert len(y.shape) == 1, 'label array must be 1D'
        n_batches = len(y) // batch_size
        if n_batches < 2:
            # StratifiedKFold needs at least two splits; fail with a message
            # phrased in terms of this sampler's own arguments.
            raise ValueError(
                'batch_size={} yields {} batch(es) for {} labels; '
                'at least 2 are required'.format(batch_size, n_batches, len(y))
            )
        self.skf = StratifiedKFold(n_splits=n_batches, shuffle=shuffle)
        # Placeholder features: split() is only given X alongside y here,
        # the sampler itself never reads these values.
        self.X = torch.randn(len(y),1).numpy()
        self.y = y
        self.shuffle = shuffle
        self.n_batches = n_batches

    def __iter__(self):
        """Yield one array of dataset indices per batch."""
        if self.shuffle:
            # Re-seed from torch's RNG so every pass produces new batches.
            self.skf.random_state = torch.randint(0,int(1e8),size=()).item()
        for _, batch_idx in self.skf.split(self.X, self.y):
            yield batch_idx

    def __len__(self):
        """Return the number of batches produced per pass.

        A batch sampler's length is what ``len(DataLoader)`` reports, so it
        must count batches rather than samples.
        """
        return self.n_batches

In [33]:
# The only preprocessing step applied to both splits is ToTensor.
common_transform = transforms.Compose([transforms.ToTensor()])

# Train split first, then test split; both are stored under ./fashion and
# downloaded on first use.
fashion_train_transformed, fashion_test_transformed = (
    torchvision.datasets.FashionMNIST(
        root='./fashion',
        train=is_train_split,
        download=True,
        transform=common_transform,
    )
    for is_train_split in (True, False)
)

In [48]:
# Batch size shared by the plain DataLoader and the StratifiedSampler used later in the notebook.
BATCH_SIZE = 128

In [89]:
from torch.utils.data import Dataset

class TrainDataset(Dataset):
    """Map-style dataset pairing each entry of ``x`` with the matching entry of ``y``.

    Args:
        x: Indexable collection of inputs (a tensor of sample indices in this
            notebook).
        y: Indexable collection of targets. Expected to have the same length
            as ``x``; this is not checked.
        train: Unused. Kept so existing call sites that pass it keep working.
    """
    def __init__(self, x, y, train = False):
        self.X = x
        self.y = y

    def __len__(self):
        """Return the number of inputs."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(x[idx], y[idx])`` as freshly copied tensors."""
        return self._to_tensor(self.X[idx]), self._to_tensor(self.y[idx])

    @staticmethod
    def _to_tensor(value):
        """Copy ``value`` into a new tensor without the copy-construct warning.

        ``torch.tensor(existing_tensor)`` emits a UserWarning on every call;
        ``detach().clone()`` is the recommended equivalent for tensors.
        """
        if torch.is_tensor(value):
            return value.detach().clone()
        return torch.tensor(value)

In [90]:
# The "inputs" are just the sample positions 0..N-1, so each batch shows
# exactly which rows of the training set it drew; targets are the class labels.
y = fashion_train_transformed.targets
x = torch.arange(len(fashion_train_transformed))
x_dataset = TrainDataset(x=x, y=y)

In [92]:
# Class frequencies over the whole training split.
arr = fashion_train_transformed.targets.numpy()
unique, counts = np.unique(arr, return_counts=True)
uniq_cnt_dict = {label: count for label, count in zip(unique, counts)}
uniq_cnt_dict

{0: 6000,
 1: 6000,
 2: 6000,
 3: 6000,
 4: 6000,
 5: 6000,
 6: 6000,
 7: 6000,
 8: 6000,
 9: 6000}

In [111]:
# Baseline: sequential, fixed-size batches with no stratification.
not_stratify_fashion_train_dataloader = torch.utils.data.DataLoader(
    x_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)

# Class frequencies of the first batch (element 1 of a batch holds the labels).
arr = next(iter(not_stratify_fashion_train_dataloader))[1].numpy()
unique, counts = np.unique(arr, return_counts=True)
uniq_cnt_dict = {label: count for label, count in zip(unique, counts)}
uniq_cnt_dict

{0: 13, 1: 15, 2: 12, 3: 16, 4: 10, 5: 14, 6: 15, 7: 11, 8: 8, 9: 14}

In [112]:
# Print the sample indices and per-class counts of every baseline batch.
# `i` stays bound to the last batch after the loop; the next cell reads it.
for i in not_stratify_fashion_train_dataloader:
    arr = i[1].numpy()
    unique, counts = np.unique(arr, return_counts=True)
    uniq_cnt_dict = {label: count for label, count in zip(unique, counts)}
    print(i[0], uniq_cnt_dict)

[1;30;43m스트리밍 출력 내용이 길어서 마지막 5000줄이 삭제되었습니다.[0m
        10772, 10773, 10774, 10775, 10776, 10777, 10778, 10779, 10780, 10781,
        10782, 10783, 10784, 10785, 10786, 10787, 10788, 10789, 10790, 10791,
        10792, 10793, 10794, 10795, 10796, 10797, 10798, 10799, 10800, 10801,
        10802, 10803, 10804, 10805, 10806, 10807, 10808, 10809, 10810, 10811,
        10812, 10813, 10814, 10815, 10816, 10817, 10818, 10819, 10820, 10821,
        10822, 10823, 10824, 10825, 10826, 10827, 10828, 10829, 10830, 10831,
        10832, 10833, 10834, 10835, 10836, 10837, 10838, 10839, 10840, 10841,
        10842, 10843, 10844, 10845, 10846, 10847, 10848, 10849, 10850, 10851,
        10852, 10853, 10854, 10855, 10856, 10857, 10858, 10859, 10860, 10861,
        10862, 10863, 10864, 10865, 10866, 10867, 10868, 10869, 10870, 10871,
        10872, 10873, 10874, 10875, 10876, 10877, 10878, 10879]) {0: 9, 1: 7, 2: 11, 3: 15, 4: 17, 5: 11, 6: 15, 7: 13, 8: 13, 9: 17}
tensor([10880, 10881, 10882, 10883, 

In [113]:
# Cross-check the last baseline batch: look its indices up in the dataset's
# labels and count the classes again.
arr = x_dataset.y[i[0]].numpy()
unique, counts = np.unique(arr, return_counts=True)
uniq_cnt_dict = {label: count for label, count in zip(unique, counts)}
print(uniq_cnt_dict)

{0: 12, 1: 12, 2: 14, 3: 11, 4: 11, 5: 7, 6: 8, 7: 5, 8: 10, 9: 6}


In [114]:
# Stratified batches are produced by the sampler from the dataset's labels.
sampler = StratifiedSampler(y=x_dataset.y, batch_size=BATCH_SIZE, shuffle=False)

# The sampler is passed as batch_sampler, so it yields whole batches of
# indices; batch_size and shuffle are therefore not given to the DataLoader.
stratify_fashion_train_dataloader = torch.utils.data.DataLoader(
    x_dataset,
    batch_sampler=sampler,
    num_workers=2,
)

# Class frequencies of the first stratified batch.
arr = next(iter(stratify_fashion_train_dataloader))[1].numpy()
unique, counts = np.unique(arr, return_counts=True)
uniq_cnt_dict = {label: count for label, count in zip(unique, counts)}
uniq_cnt_dict

{0: 13, 1: 13, 2: 13, 3: 13, 4: 13, 5: 12, 6: 13, 7: 13, 8: 13, 9: 13}

In [115]:
# Print the sample indices and per-class counts of every stratified batch.
# `i` stays bound to the last batch after the loop; the next cell reads it.
for i in stratify_fashion_train_dataloader:
    arr = i[1].numpy()
    unique, counts = np.unique(arr, return_counts=True)
    uniq_cnt_dict = {label: count for label, count in zip(unique, counts)}
    print(i[0], uniq_cnt_dict)

[1;30;43m스트리밍 출력 내용이 길어서 마지막 5000줄이 삭제되었습니다.[0m
        10613, 10619, 10621, 10622, 10631, 10633, 10653, 10684, 10688, 10689,
        10690, 10692, 10694, 10696, 10701, 10705, 10706, 10711, 10713, 10714,
        10718, 10719, 10720, 10726, 10728, 10732, 10734, 10735, 10736, 10738,
        10739, 10740, 10741, 10749, 10762, 10766, 10775, 10779, 10785, 10818,
        10938, 10941, 10942, 10953, 10955, 10956, 10961, 10968, 10971, 10979,
        10987, 10988, 10990, 11108, 11110, 11111, 11127, 11134, 11149, 11163,
        11172, 11183, 11191, 11192, 11199, 11201, 11534, 11536, 11537, 11550,
        11567, 11596, 11611, 11618, 11626, 11664, 11683, 11687, 11689]) {0: 13, 1: 13, 2: 13, 3: 13, 4: 13, 5: 13, 6: 13, 7: 12, 8: 13, 9: 13}
tensor([10370, 10404, 10405, 10407, 10441, 10452, 10454, 10455, 10475, 10479,
        10482, 10486, 10488, 10489, 10495, 10516, 10521, 10539, 10555, 10575,
        10580, 10586, 10590, 10597, 10614, 10625, 10634, 10635, 10639, 10640,
        10667, 10671, 10674

In [116]:
# Cross-check the last stratified batch: look its indices up in the dataset's
# labels and count the classes again.
arr = x_dataset.y[i[0]].numpy()
unique, counts = np.unique(arr, return_counts=True)
uniq_cnt_dict = {label: count for label, count in zip(unique, counts)}
print(uniq_cnt_dict)

{0: 13, 1: 13, 2: 13, 3: 13, 4: 13, 5: 12, 6: 13, 7: 13, 8: 13, 9: 12}
