In [None]:
import torch
from configs.config_task1 import get_cfg_defaults
from data.dataset import load_cifar_dataset
from configs.config_task1 import get_cfg_defaults


In [None]:
# Build the configuration object from the project's defaults.
cfg = get_cfg_defaults()
# Use a small batch: the loaders below read this value, and the plotting cell
# indexes samples 3..5 of a single batch, so it must be at least 6.
cfg.TRAIN.batch_size = 10

In [None]:
# Load both dataset splits, then wrap each one in a DataLoader.
train_dataset, test_dataset, num_classes = load_cifar_dataset(cfg)

# The two loaders are configured identically; shuffle stays off so that
# every run draws the same first batch from each split.
loader_kwargs = dict(
    batch_size=cfg.TRAIN.batch_size,
    shuffle=False,
    pin_memory=False,
    num_workers=2,
)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, **loader_kwargs)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, **loader_kwargs)

In [None]:
import numpy as np
def mixup_data(x1, x2, y1, y2, alpha):
    """Blend two inputs with equal weight and return both targets.

    Args:
        x1: First input; must support scalar multiplication and addition.
        x2: Second input, combined element-wise with ``x1``.
        y1: Target belonging to ``x1``.
        y2: Target belonging to ``x2``.
        alpha: Accepted for interface compatibility but never read; the
            mixing coefficient is fixed at 0.5 rather than sampled.

    Returns:
        A tuple ``(mixed_x, (y1, y2), lam)`` where
        ``mixed_x = lam * x1 + (1 - lam) * x2`` and ``lam`` is 0.5.
    """
    lam = 0.5
    first_term = lam * x1
    second_term = (1 - lam) * x2
    targets = (y1, y2)
    return first_term + second_term, targets, lam

In [None]:
def rand_bbox(size, lam):
    """Sample a random box whose area is roughly ``1 - lam`` of the image.

    Args:
        size: Shape of a single image. Only ``size[1]`` and ``size[2]`` are
            read; the "x" coordinates returned refer to ``size[1]`` and the
            "y" coordinates to ``size[2]``.
        lam: Fraction of the image to keep, so the box side ratio is
            ``sqrt(1 - lam)``. Expected to lie in ``[0, 1]``.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``: box bounds clipped to the image, with
        ``0 <= bbx1 <= bbx2 <= size[1]`` and ``0 <= bby1 <= bby2 <= size[2]``.
        Because of clipping the box can be smaller than requested.
    """
    W = size[1]
    H = size[2]
    cut_rat = np.sqrt(1. - lam)
    # `np.int` was removed in NumPy 1.24; the builtin `int` truncates toward
    # zero in exactly the same way.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Box centre drawn uniformly over the image (uses NumPy's global RNG).
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    # Clip so that a box centred near an edge stays inside the image.
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2
def cutmix_data(x1,x2,y1,y2,lam):
    """Paste a random rectangular patch of ``x2`` into a copy of ``x1``.

    Args:
        x1: Base image tensor; indexed as ``[:, a:b, c:d]``, so it needs
            three dimensions. It is NOT modified.
        x2: Image tensor the patch is taken from, indexed the same way.
        y1: Target belonging to ``x1``.
        y2: Target belonging to ``x2``.
        lam: Requested fraction of ``x1`` to keep; passed to ``rand_bbox``
            to size the patch.

    Returns:
        ``(mixed_x, (y1, y2), lam)`` where the returned ``lam`` is
        recomputed from the actual (clipped) patch area.
    """
    bbx1, bby1, bbx2, bby2 = rand_bbox(x1.size(), lam)

    # Work on a copy: callers often pass a view such as `batch[i]`, and
    # writing into it would silently overwrite the caller's batch.
    mixed_x = x1.clone()
    mixed_x[:, bbx1:bbx2, bby1:bby2] = x2[:, bbx1:bbx2, bby1:bby2]

    # Adjust lambda to exactly match the pixel ratio of the pasted patch,
    # which may be smaller than requested after clipping at the borders.
    patch_area = (bbx2 - bbx1) * (bby2 - bby1)
    image_area = x1.size()[-1] * x1.size()[-2]
    lam = 1 - (patch_area / image_area)

    return mixed_x, (y1, y2), lam

In [None]:
# visualize
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from data.argument_type import Cutout,Mixup,Cutmix

# Take the first batch of each split: `images` are the base samples and
# `mix_images` are the partners used for Mixup / CutMix.
images, labels = next(iter(train_loader))
mix_images, mix_labels = next(iter(test_loader))
print(images.size())  # shape of the first training batch

cutout = Cutout(1, 16)
sample_ids = range(3, 6)  # samples of the batch to display, one per column

# 5 rows x 3 columns: original, Cutout, Mixup, CutMix, mixing partner.
fig, axes = plt.subplots(5, 3, figsize=(9, 9))


def _show_image(ax, img, title):
    """Draw one channel-first image tensor on `ax` with `title`, axes hidden."""
    ax.set_title(title)
    ax.imshow(img.permute(1, 2, 0))  # imshow expects channel-last
    ax.axis('off')


# Row 0: untouched training samples.
for col, i in enumerate(sample_ids):
    _show_image(axes[0, col], images[i], labels[i].item())

# Row 1: Cutout, applied to a copy so `images` stays intact.
cutout_images = images.clone()
for col, i in enumerate(sample_ids):
    cutout_images[i] = cutout(cutout_images[i])
    _show_image(axes[1, col], cutout_images[i], labels[i].item())

# Row 2: Mixup of each training sample with its partner.
mixup_images = images.clone()
for col, i in enumerate(sample_ids):
    mixup_images[i], mixup_labels, lam = mixup_data(
        images[i], mix_images[i], labels[i], mix_labels[i], alpha=0.5)
    print(lam)
    _show_image(axes[2, col], mixup_images[i],
                (mixup_labels[0].item(), mixup_labels[1].item()))

# Row 3: CutMix. `images[i]` is a view into the batch and cutmix_data may
# write into its first argument, so hand it a clone to keep `images` clean.
for col, i in enumerate(sample_ids):
    cutmix_img, cutmix_labels, cutmix_lam = cutmix_data(
        images[i].clone(), mix_images[i], labels[i], mix_labels[i], lam=0.5)
    _show_image(axes[3, col], cutmix_img,
                (cutmix_labels[0].item(), cutmix_labels[1].item()))

# Row 4: the partner samples that were mixed in.
for col, i in enumerate(sample_ids):
    _show_image(axes[4, col], mix_images[i], mix_labels[i].item())

plt.show()