In [1]:
import os
import random
from time import time

import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch for reproducible runs.

    Args:
        seed: integer seed. A negative value means "do nothing": the function
            returns immediately without touching any RNG or printing.

    When CUDA is available, the CUDA generators are seeded as well and cuDNN is
    put into deterministic, non-benchmarking mode.
    """
    if seed < 0:
        return

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True   # pick deterministic kernels
        cudnn.benchmark = False      # disable autotuning, which can vary run to run

    print("Setting seed done...")

set_random_seed(2022)

Setting seed done...


In [3]:
def rand_bbox(size, lam):
    """Sample a random box whose area is roughly a ``1 - lam`` fraction of the image.

    Args:
        size: shape-like sequence (e.g. ``tensor.size()``) of a 4-D batch. The
            extent of dim 2 is called ``W`` here and that of dim 3 ``H``, so
            callers must slice dim 2 with ``bbx1:bbx2`` and dim 3 with ``bby1:bby2``.
        lam: ratio in [0, 1]; the box side scales with ``sqrt(1 - lam)``.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``. The box is centred on a uniformly drawn
        point and clipped to the image bounds, so it may be smaller near edges.
    """
    W, H = size[2], size[3]
    side_scale = np.sqrt(1. - lam)
    half_w = int(W * side_scale) // 2
    half_h = int(H * side_scale) // 2

    # Centre is drawn x first, then y: this ordering fixes the RNG stream.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1, bbx2 = (np.clip(edge, 0, W) for edge in (cx - half_w, cx + half_w))
    bby1, bby2 = (np.clip(edge, 0, H) for edge in (cy - half_h, cy + half_h))
    return bbx1, bby1, bbx2, bby2

In [10]:
# CIFAR-100 training split (downloaded on first run). No `transform` is passed,
# so samples come back untransformed and are converted with np.array downstream.
trainset = torchvision.datasets.CIFAR100(root='../datasets/cifar', download=True, train=True)

Files already downloaded and verified


In [11]:
# Take the first 16 training samples: images as CHW tensors, labels as ints.
images, labels = [], []
for _, (sample, target) in zip(range(16), trainset):
    hwc_array = np.array(sample)
    images.append(torch.tensor(hwc_array.transpose((2, 0, 1))))  # HWC -> CHW
    labels.append(target)

org_images = torch.stack(images).float()
print(org_images.shape)

torch.Size([16, 3, 32, 32])


In [16]:
# Render CutMix / Cutout / Mixup / original versions of a few samples and save them as PNGs.
beta = 1.0                # Beta(beta, beta) prior for the sampled mixing ratio `lam`
selected_idx = [0, 7, 9]  # batch indices of the samples to export
out_dir = "./images"
# cv2.imwrite does not create missing directories; it just returns False.
os.makedirs(out_dir, exist_ok=True)

# One shared partner permutation and ratio, so every augmentation is shown
# on the same image pairs.
rand_index = torch.randperm(org_images.size()[0])
lam = np.random.beta(beta, beta)


def augment_batch(batch, visual, lam, rand_index):
    """Return an augmented copy of ``batch``; ``batch`` itself is never modified.

    Args:
        batch: 4-D float tensor laid out like ``org_images`` (N, C, H, W).
        visual: "cutmix", "cutout" or "mixup"; any other value returns an
            unmodified copy (used for "original").
        lam: mixing ratio that sizes the CutMix / Cutout box via ``rand_bbox``.
        rand_index: permutation of the batch choosing each sample's partner.

    Returns:
        A new tensor with the same shape as ``batch``.
    """
    out = batch.clone()
    if visual == "cutmix":
        bbx1, bby1, bbx2, bby2 = rand_bbox(out.size(), lam)
        # The RHS uses advanced indexing, so it is already a copy: no aliasing.
        out[:, :, bbx1:bbx2, bby1:bby2] = out[rand_index, :, bbx1:bbx2, bby1:bby2]
        # The pixel-exact lam adjustment is deliberately not done here. It is
        # irrelevant for visualisation, and reassigning the shared `lam` used to
        # change the box size of the later cutout pass.
    elif visual == "cutout":
        bbx1, bby1, bbx2, bby2 = rand_bbox(out.size(), lam)
        out[:, :, bbx1:bbx2, bby1:bby2] = 0
    elif visual == "mixup":
        mixup_lam = 0.5  # fixed 50/50 blend (not the sampled lam), presumably for visual clarity
        out = mixup_lam * out + (1 - mixup_lam) * out[rand_index]
    return out


def to_bgr_uint8(chw):
    """Convert a float CHW image tensor into a contiguous uint8 HWC BGR array for cv2.

    Values are rounded and clipped to [0, 255]. Assumes the tensor is in RGB
    order, as built from the dataset images; cv2.imwrite expects BGR.
    """
    hwc = chw.numpy().transpose((1, 2, 0))                   # CHW -> HWC
    hwc_u8 = np.clip(np.rint(hwc), 0, 255).astype(np.uint8)
    return np.ascontiguousarray(hwc_u8[:, :, ::-1])          # RGB -> BGR


for visual in ["cutmix", "cutout", "mixup", "original"]:
    images = augment_batch(org_images, visual, lam, rand_index)
    for img_idx in selected_idx:
        label = labels[img_idx]
        # NOTE(review): samples sharing a label would overwrite each other's file.
        path = f"{out_dir}/{visual}_{label}.png"
        if not cv2.imwrite(path, to_bgr_uint8(images[img_idx])):
            raise IOError(f"cv2.imwrite failed to write {path}")
        print(img_idx, path)

0 ./images/cutmix_19.png
7 ./images/cutmix_28.png
9 ./images/cutmix_31.png
0 ./images/cutout_19.png
7 ./images/cutout_28.png
9 ./images/cutout_31.png
0 ./images/mixup_19.png
7 ./images/mixup_28.png
9 ./images/mixup_31.png
0 ./images/original_19.png
7 ./images/original_28.png
9 ./images/original_31.png
