# <Semi-supervised learning tutorial 1 - CIFAR10 dataset & RandAugment>

In [None]:
# Clone the tutorial repository; it provides the `augmentation` module imported below
# and the `data` directory that `data_path` points at.
# NOTE: `!` is an IPython shell escape — this line only runs inside Jupyter/Colab.
! git clone https://github.com/KU-DIC/LG_semi_supervised_learning_day17.git

In [None]:
import collections
import random

import matplotlib.pyplot as plt
import numpy as np
import PIL
from PIL import Image
from torchvision import datasets

from LG_semi_supervised_learning_day17.augmentation import *

In [None]:
# Directory holding the CIFAR-10 files and the sample image used below.
# Assumes the repo was cloned into /content (presumably the Colab working directory).
data_path = '/content/LG_semi_supervised_learning_day17/data'

In [None]:
# Load the train and test splits of CIFAR-10 from the local data directory.
# download=False: the files must already exist under data_path (no network fetch).
train_dataset, test_dataset = (
    datasets.CIFAR10(data_path, train=use_train_split, download=False)
    for use_train_split in (True, False)
)

In [None]:
# Class names of the training split.
train_dataset.classes

In [None]:
# Mapping from class name to its integer label.
train_dataset.class_to_idx

In [None]:
# Integer labels of the first 10 training samples.
train_dataset.targets[:10]

In [None]:
# Shape of the raw training image array (expected (N, 32, 32, 3) — confirm from the output).
train_dataset.data.shape

In [None]:
# Shape of the raw test image array.
test_dataset.data.shape

In [None]:
# Number of training samples per label — checks class balance.
collections.Counter(train_dataset.targets)

In [None]:
# Number of test samples per label.
collections.Counter(test_dataset.targets)

In [None]:
# Class names (the keys of class_to_idx).
train_dataset.class_to_idx.keys()

In [None]:
# Integer labels (the values of class_to_idx).
train_dataset.class_to_idx.values()

In [None]:
# Show one randomly chosen training image for every class.
np.random.seed(0)  # makes the per-class image choice reproducible
fig = plt.figure(figsize=(20, 20))
# Convert the label list to an array once; it is the same on every iteration.
targets = np.array(train_dataset.targets)
for cls, i in train_dataset.class_to_idx.items():
    index_list = np.where(targets == i)[0]
    index = np.random.choice(index_list, 1, replace=False)
    img = train_dataset.data[index][0]

    # 5x5 grid; label i goes to cell i+1 (subplot positions are 1-based).
    fig.add_subplot(5, 5, i+1)
    plt.imshow(img)
    plt.title(f'{cls} - {i}')
plt.show()

In [None]:
def rand_augment_pool():
    """Return the operation pool for the generic (non-CIFAR) RandAugment.

    The pool is tuned for the images it will be applied to. Each entry is
    ``(op, max_v, bias)``: an augmentation function plus the per-op
    hyperparameters that are passed to it as ``max_v`` and ``bias``
    (``None`` where the pool sets no value). A fresh list is returned on
    every call.
    """
    pool = [(op, None, None) for op in (AutoContrast, Equalize, Invert)]
    pool += [(Rotate, 30, None), (Posterize, 4, 0), (Solarize, 256, None)]
    pool += [(op, 1.8, 0.1) for op in (Color, Contrast, Brightness, Sharpness)]
    pool += [(op, 0.3, None) for op in (ShearX, ShearY)]
    pool += [(op, 100, None) for op in (TranslateXConst, TranslateYConst)]
    return pool

In [None]:
class RandAugment(object):
    """Generic RandAugment: apply ``n`` randomly drawn ops at magnitude ``m``.

    Ops are drawn with replacement (``random.choices``) from
    ``rand_augment_pool()`` on every call.

    Args:
        n: number of ops applied per call (coerced with ``int``).
        m: magnitude passed to every op as ``v``.
        resample_mode: resampling filter to publish as ``RESAMPLE_MODE``
            (presumably read by the geometric ops in augmentation.py).
        fill: colour to publish as ``FILL_COLOR`` (presumably the fill for
            pixels exposed by geometric ops).
        verbose: if True (the default, matching the tutorial's output), print
            each selected op before applying it.
    """

    def __init__(self, n, m, resample_mode=PIL.Image.BICUBIC, fill=(128, 128, 128), verbose=True):
        # `global` only rebinds these names in *this* notebook's namespace.
        # Functions obtained via `from ...augmentation import *` resolve globals
        # in their own module, so they would never see the new values; set them
        # on that module as well.
        # NOTE(review): assumes the augmentation ops read RESAMPLE_MODE /
        # FILL_COLOR from module scope — confirm against augmentation.py.
        global RESAMPLE_MODE, FILL_COLOR
        RESAMPLE_MODE = resample_mode
        FILL_COLOR = fill
        from LG_semi_supervised_learning_day17 import augmentation as aug_module
        aug_module.RESAMPLE_MODE = resample_mode
        aug_module.FILL_COLOR = fill

        self.n = int(n)
        self.m = m
        self.verbose = verbose
        self.augment_pool = rand_augment_pool()

    def __call__(self, img):
        """Return ``img`` after applying ``self.n`` randomly chosen ops in sequence."""
        ops = random.choices(self.augment_pool, k=self.n)
        for op, max_v, bias in ops:
            if self.verbose:
                print(op)
            img = op(img, v=self.m, max_v=max_v, bias=bias)
        return img

In [None]:
# Load the sample airplane photo from the data directory as a 3-channel RGB image.
path = f'{data_path}/air_plane.jpg'
with Image.open(path) as raw_image:
    # convert() returns a new, fully loaded image, so it outlives the closed file.
    image = raw_image.convert('RGB')

In [None]:
# Apply RandAugment to the sample photo: n ops per image, each at magnitude m.
n, m = 5, 10
ra = RandAugment(n=n, m=m)
transform_image = ra(image)

In [None]:
# Compare the original photo with its RandAugment output side by side.
# No RNG seeding here: nothing in this cell draws random numbers. To make the
# augmentation itself reproducible, seed Python's `random` (used by
# RandAugment to pick ops) before calling `ra` in the previous cell.
fig = plt.figure(figsize=(10, 10))

fig.add_subplot(1, 2, 1)
plt.imshow(image)
plt.title('Original')

fig.add_subplot(1, 2, 2)
plt.imshow(transform_image)
plt.title(f'RandAugment n={n}, m={m}')

plt.show()

In [None]:
def cifar_augment_pool():
    """Return the RandAugment operation pool tuned for 32x32 CIFAR images.

    Each entry is ``(op, max_v, bias)``, with per-op hyperparameters passed to
    the op as ``max_v`` and ``bias`` (``None`` where the pool sets no value).
    Translation and cutout are bounded by 1/8 of the 32-pixel image side.
    A fresh list is returned on every call.
    """
    side = 32
    eighth = side // 8  # 4 pixels
    pool = [(op, None, None) for op in (AutoContrast, Equalize, Invert)]
    pool += [(Rotate, 30, None), (Posterize, 4, 0), (Solarize, 256, None)]
    pool += [(op, 1.8, 0.1) for op in (Color, Contrast, Brightness, Sharpness)]
    pool += [(op, 0.3, None) for op in (ShearX, ShearY)]
    pool += [(op, eighth, None) for op in (TranslateXConst, TranslateYConst, CutoutConst)]
    return pool

In [None]:
class RandAugmentCIFAR(object):
    """RandAugment variant for 32x32 CIFAR images.

    Each call applies ``n`` ops drawn with replacement from
    ``cifar_augment_pool()`` at magnitude ``m``, then always applies a final
    ``CutoutConst`` with ``max_v=32 // 4``.

    Args:
        n: number of pool ops applied per call (coerced with ``int``).
        m: magnitude passed to every op as ``v``.
        resample_mode: resampling filter to publish as ``RESAMPLE_MODE``
            (presumably read by the geometric ops in augmentation.py).
        fill: colour to publish as ``FILL_COLOR`` (presumably the fill for
            pixels exposed by geometric ops).
    """

    def __init__(self, n, m, resample_mode=PIL.Image.BICUBIC, fill=(128, 128, 128)):
        # `global` only rebinds these names in *this* notebook's namespace.
        # Functions obtained via `from ...augmentation import *` resolve globals
        # in their own module, so they would never see the new values; set them
        # on that module as well.
        # NOTE(review): assumes the augmentation ops read RESAMPLE_MODE /
        # FILL_COLOR from module scope — confirm against augmentation.py.
        global RESAMPLE_MODE, FILL_COLOR
        RESAMPLE_MODE = resample_mode
        FILL_COLOR = fill
        from LG_semi_supervised_learning_day17 import augmentation as aug_module
        aug_module.RESAMPLE_MODE = resample_mode
        aug_module.FILL_COLOR = fill

        self.n = int(n)
        self.m = m
        self.augment_pool = cifar_augment_pool()

    def __call__(self, img):
        """Return ``img`` after ``self.n`` random ops followed by a fixed cutout."""
        ops = random.choices(self.augment_pool, k=self.n)
        for op, max_v, bias in ops:
            img = op(img, v=self.m, max_v=max_v, bias=bias)
        # Unconditional final cutout, bounded by a quarter of the 32-pixel side.
        # CutoutConst is called without `bias` here, so it uses its own default.
        img = CutoutConst(img, v=self.m, max_v=32 // 4)
        return img

In [None]:
# CIFAR RandAugment with 5 ops per image at magnitude 10.
ra_cifar = RandAugmentCIFAR(n=5, m=10)

In [None]:
# For each class, show one random training image next to its RandAugmentCIFAR output.
# Seed both RNGs: NumPy picks the images, while RandAugmentCIFAR picks its ops
# with Python's `random` (the individual ops may draw from either — TODO confirm).
np.random.seed(0)
random.seed(0)
fig = plt.figure(figsize=(5, 40))
# Convert the label list to an array once; it is the same on every iteration.
targets = np.array(train_dataset.targets)
for cls, i in train_dataset.class_to_idx.items():
    index_list = np.where(targets == i)[0]
    index = np.random.choice(index_list, 1, replace=False)
    img = Image.fromarray(train_dataset.data[index][0])
    transform_image = ra_cifar(img)

    # Row i of a 10x2 grid: original on the left, augmented on the right.
    fig.add_subplot(10, 2, 2*i+1)
    plt.imshow(img)
    plt.title(f'{cls} - {i}')

    fig.add_subplot(10, 2, 2*i+2)
    plt.imshow(transform_image)
    plt.title(f'{cls} - {i} RandAug')
plt.show()