In [None]:
from tensorflow.keras.datasets import cifar10

import albumentations
import cv2
import numpy as np
import matplotlib.pyplot as plt
import cv2
# augmentation method를 import합니다.
from albumentations import (
    Compose, HorizontalFlip, CLAHE, HueSaturationValue,
    RandomBrightness, RandomContrast, RandomGamma,
    ToFloat, ShiftScaleRotate
)

In [None]:
# Augmentation pipelines. Each transform used below is described in the
# albumentations documentation: https://albumentations.ai/docs/
#
# NOTE(review): RandomBrightness and RandomContrast appear to be deprecated in
# newer albumentations releases in favour of RandomBrightnessContrast --
# confirm against the installed version before upgrading.

# Training pipeline: the transforms are applied in the order listed, each with
# its own probability `p`, and the result is finally converted to float.
Aug_train = Compose(
    [
        HorizontalFlip(p=0.5),
        RandomContrast(limit=0.2, p=0.5),
        RandomGamma(gamma_limit=(80, 120), p=0.5),
        RandomBrightness(limit=0.2, p=0.5),
        HueSaturationValue(
            hue_shift_limit=5,
            sat_shift_limit=20,
            val_shift_limit=10,
            p=0.9,
        ),
        ShiftScaleRotate(
            shift_limit=0.0625,
            scale_limit=0.1,
            rotate_limit=15,
            border_mode=cv2.BORDER_REFLECT_101,
            p=0.8,
        ),
        ToFloat(max_value=255),
    ]
)

# Test pipeline: no random augmentation, only the same float conversion.
Aug_test = Compose([ToFloat(max_value=255)])

In [None]:
from tensorflow.python.keras.utils.data_utils import Sequence

# Inherit from Sequence so the dataset can be consumed like a generator.
class CIFAR10Dataset(Sequence):
    """Batched view over paired image/label collections with augmentation.

    Each item of the dataset is one batch: the images in the batch are passed
    one by one through ``augmentations`` and stacked into a single array.

    Args:
        x_set: Sliceable collection of images. Each element is passed to
            ``augmentations`` as the ``image`` keyword argument.
        y_set: Sliceable collection of labels, aligned with ``x_set``.
        batch_size: Number of samples per batch; must be at least 1.
        augmentations: Callable invoked as ``augmentations(image=img)`` that
            returns a mapping containing the result under the ``"image"`` key.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or if ``x_set`` and
            ``y_set`` differ in length.

    NOTE(review): in this file ``Sequence`` is imported from the private
    ``tensorflow.python.keras.utils.data_utils`` module. The public path is
    ``tensorflow.keras.utils.Sequence``; the private one may not be available
    or recognised by ``tf.keras`` in newer TensorFlow releases -- confirm.
    """

    def __init__(self, x_set, y_set, batch_size, augmentations):
        # Let the base class initialise its own state; newer Keras dataset
        # base classes expect this call.
        super().__init__()

        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        if len(x_set) != len(y_set):
            # A length mismatch would silently pair images with wrong labels.
            raise ValueError(
                "x_set and y_set must have the same length, got "
                f"{len(x_set)} and {len(y_set)}"
            )

        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.augment = augmentations

    def __len__(self):
        """Return the number of batches; the last batch may be partial."""
        # Integer ceil-division avoids a float round-trip.
        return (len(self.x) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Load and augment the batch at position ``idx``.

        Args:
            idx: Batch index. Negative values count from the last batch.

        Returns:
            A ``(images, labels)`` tuple: the augmented images stacked along
            a new first axis, and the matching labels as a NumPy array.

        Raises:
            IndexError: If ``idx`` does not refer to an existing batch.
        """
        n_batches = len(self)
        position = idx + n_batches if idx < 0 else idx
        if not 0 <= position < n_batches:
            # Fail clearly instead of letting np.stack choke on an empty list.
            raise IndexError(
                f"batch index {idx} out of range for {n_batches} batches"
            )

        start = position * self.batch_size
        stop = start + self.batch_size
        batch_x = self.x[start:stop]
        batch_y = self.y[start:stop]

        # Apply the augmentation per image, then stack into one array.
        augmented = [self.augment(image=img)["image"] for img in batch_x]
        return np.stack(augmented, axis=0), np.array(batch_y)

# Number of samples in each batch produced by the generators below.
BATCH_SIZE = 8

# Load the CIFAR-10 train and test splits.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Wrap each split in a batch generator: the training split uses the random
# augmentation pipeline, the test split uses the test pipeline.
train_gen = CIFAR10Dataset(
    x_set=x_train,
    y_set=y_train,
    batch_size=BATCH_SIZE,
    augmentations=Aug_train,
)
test_gen = CIFAR10Dataset(
    x_set=x_test,
    y_set=y_test,
    batch_size=BATCH_SIZE,
    augmentations=Aug_test,
)

In [None]:
# Plot the first augmented training batch as an image grid.

# Class names indexed by CIFAR-10 label id (order from the Keras CIFAR-10
# dataset documentation).
CIFAR10_CLASS_NAMES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)

# Index the generator directly rather than going through an iterator.
images, labels = train_gen[0]

# Derive the grid from the batch length so any batch size fits
# (a fixed 3x3 grid fails once a batch holds more than 9 images).
n_images = len(images)
n_cols = int(np.ceil(np.sqrt(n_images)))
n_rows = int(np.ceil(n_images / n_cols))

# squeeze=False keeps `axes` two-dimensional even for a 1x1 grid.
fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
flat_axes = axes.ravel()

for ax, image, label in zip(flat_axes, images, labels):
    # Labels may arrive as length-1 arrays; reduce to a plain int so the
    # caption shows a class name instead of an array repr such as "[6]".
    class_id = int(np.ravel(label)[0])
    ax.imshow(image)
    ax.set_xlabel(f"{CIFAR10_CLASS_NAMES[class_id]} ({class_id})")
    ax.set_xticks([])
    ax.set_yticks([])

# Hide grid cells that have no image.
for ax in flat_axes[n_images:]:
    ax.axis("off")

plt.tight_layout()
plt.show()