In [None]:
import os
import copy
import math
import tqdm
import torch
import pickle
import random
import datetime

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.distributions.beta as beta
import torchvision.transforms.functional as TF

from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter
from torch.nn.functional import relu, avg_pool2d

In [None]:
# Random seed for this run; it is passed to the RNG seeding calls and is also
# embedded in the results directory name.
seed = 0

In [None]:
# Seed the three random number generators used in this notebook
# (PyTorch, NumPy and Python's `random`) with the same value.
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
'''Data Preparation Arguments'''
# When True, the preprocessing cells rebuild '<dataset>.pt' from the raw CIFAR pickle files.
data_prep_from_scratch = True
# 'cifar10' or 'cifar100'; selects the preprocessing cell and the number of classes.
dataset = 'cifar10'
# Directory holding the raw pickle files; the preprocessed '<dataset>.pt' is written here too.
path_dataset = './Data/CIFAR10/'

'''Data Augmentation Method'''
# 'vhmixup' selects VHMixup; any other value selects VHBCplus. A value containing
# 'bcplus' also switches the dataset to its BC+ normalization branch.
method = 'vhmixup'

'''Optimization Arguments'''
batch_size = 128
# The training loop iterates over range(train_epochs + 1).
train_epochs = 225

# Initial learning rate for SGD.
lr = 0.01
momentum = 0.9
weight_decay = 5e-4
# Both schedulers are stepped once per training iteration, so the milestones
# below are iteration counts, not epochs.
# Warm-up: multiply the LR by 10 at iteration 400 (0.01 -> 0.1).
lr_scheduler_1_gamma = 10.0; milestones_1 = [400]
# Decay: multiply the LR by 0.1 at each milestone (0.1 -> 0.01 -> 0.001 -> 0.0001).
lr_scheduler_2_gamma = 0.1; milestones_2 = [32000, 48000, 70000]

# Preprocess CIFAR

In [None]:
def unpickle(file):
    """Load and return the object stored in a pickle file.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object. Because ``encoding='bytes'`` is used, string keys
        written by Python 2 come back as ``bytes`` (e.g. ``b'data'``).
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only use this on dataset files obtained from a trusted source.
    with open(file, 'rb') as fo:
        # Named `batch` rather than `dict` so the builtin is not shadowed.
        batch = pickle.load(fo, encoding='bytes')
    return batch

In [None]:
if data_prep_from_scratch and dataset == 'cifar100':
    # Convert the raw CIFAR-100 pickle files into tensors and cache them as one .pt file.
    cifar100_splits = {}
    for split_name in ('train', 'test'):
        raw_split = unpickle(os.path.join(path_dataset, split_name))
        # Each row of b'data' is one flattened image. The Fortran-order reshape
        # followed by swapping axes 1 and 2 yields an (N, 32, 32, 3) tensor.
        flat_images = raw_split[b'data'].reshape((-1,32,32,3), order='F')
        split_images = torch.from_numpy(flat_images).permute(0,2,1,3)
        split_labels = torch.LongTensor(raw_split[b'fine_labels'])
        cifar100_splits[split_name] = (split_images, split_labels)

    x_tr, y_tr = cifar100_splits['train']
    x_te, y_te = cifar100_splits['test']

    torch.save((x_tr, y_tr, x_te, y_te), os.path.join(path_dataset, '{}.pt'.format(dataset)))

In [None]:
if data_prep_from_scratch and dataset == 'cifar10':
    # Convert the raw CIFAR-10 pickle files into tensors and cache them as one .pt file.
    train_image_parts, train_label_parts = [], []
    for batch_number in range(1, 6):
        # The training set is stored in files data_batch_1 ... data_batch_5.
        raw_batch = unpickle(os.path.join(path_dataset, 'data_batch_{}'.format(batch_number)))

        # Each row of b'data' is one flattened image. The Fortran-order reshape
        # followed by swapping axes 1 and 2 yields an (N, 32, 32, 3) tensor.
        flat_images = raw_batch[b'data'].reshape((-1,32,32,3), order='F')
        train_image_parts.append(torch.from_numpy(flat_images).permute(0,2,1,3))
        train_label_parts.append(torch.LongTensor(raw_batch[b'labels']))

    # Concatenate the batches along the sample dimension, in file order.
    x_tr = torch.cat(train_image_parts, dim=0)
    y_tr = torch.cat(train_label_parts, dim=0)

    raw_test = unpickle(os.path.join(path_dataset, 'test_batch'))
    flat_test_images = raw_test[b'data'].reshape((-1,32,32,3), order='F')
    x_te = torch.from_numpy(flat_test_images).permute(0,2,1,3)
    y_te = torch.LongTensor(raw_test[b'labels'])

    torch.save((x_tr, y_tr, x_te, y_te), os.path.join(path_dataset, '{}.pt'.format(dataset)))

# Load Data

In [None]:
def load_datasets(path):
    """Load the cached dataset file and split it into train and test pairs.

    Args:
        path: Path of a file written with ``torch.save`` whose first four
            entries are (train images, train labels, test images, test labels).

    Returns:
        A tuple ``(d_tr, d_te, n_outputs)`` where ``d_tr`` and ``d_te`` are
        ``(images, labels)`` pairs and ``n_outputs`` is 100 when the
        module-level ``dataset`` equals ``'cifar100'`` and 10 otherwise.
    """
    saved = torch.load(path)
    train_pair = (saved[0], saved[1])
    test_pair = (saved[2], saved[3])
    # The class count is decided by the global `dataset` setting, not by the file contents.
    num_classes = 100 if dataset == 'cifar100' else 10
    return train_pair, test_pair, num_classes

In [None]:
# Load the cached tensors: (images, labels) pairs for train and test, plus the class count.
d_tr, d_te, n_outputs = load_datasets(os.path.join(path_dataset, '{}.pt'.format(dataset)))

# Dataloader

In [None]:
class CIFAR(torch.utils.data.Dataset):
    """Map-style dataset that normalizes CIFAR images and optionally augments them.

    Args:
        pack: ``(images, labels)``; ``images[i]`` is an (H, W, 3) tensor with
            values in 0..255 and ``labels[i]`` is the matching label.
        method: Augmentation method name. If it contains ``'bcplus'`` the
            per-image mean is subtracted before channel normalization and a
            different channel mean is used.
        train: When True, each image is zero-padded by 4 pixels per side,
            randomly cropped back to ``img_size`` and randomly flipped
            horizontally.
    """

    # Per-channel normalization statistics, in channel order.
    _MEAN = (0.4914, 0.4822, 0.4465)
    _MEAN_BCPLUS = (0.21921569, 0.21058824, 0.22156863)
    _STD = (0.2023, 0.1994, 0.2010)
    # Zero padding added on every side before the random crop.
    _PAD = 4

    def __init__(self, pack, method, train=False):
        self.x = pack[0]
        self.y = pack[1]
        self.img_size = (3,32,32)

        self.method = method
        self.train = train

        # Build the normalization tensors once, shaped (3, 1, 1) so they
        # broadcast over (C, H, W) images, instead of on every __getitem__.
        self._mean = torch.tensor(self._MEAN, dtype=torch.float32).view(3, 1, 1)
        self._mean_bcplus = torch.tensor(self._MEAN_BCPLUS, dtype=torch.float32).view(3, 1, 1)
        self._std = torch.tensor(self._STD, dtype=torch.float32).view(3, 1, 1)

    def __len__(self):
        return len(self.x)

    def transform(self, img):
        """Randomly crop ``img`` (C, H, W) to ``img_size`` and flip it with probability 0.5.

        Every valid crop position is sampled uniformly, including the last one:
        a 40x40 input has offsets 0..8 for a 32x32 crop. ``torch.randint``
        excludes its upper bound, hence the ``+ 1``.
        """
        crop_h, crop_w = self.img_size[1], self.img_size[2]
        # max(..., 0) keeps an input that is already crop-sized usable (offset 0).
        max_top = max(img.shape[-2] - crop_h, 0)
        max_left = max(img.shape[-1] - crop_w, 0)
        # Convert to plain ints so slicing does not depend on tensor indices.
        top = int(torch.randint(0, max_top + 1, (1,)).item())
        left = int(torch.randint(0, max_left + 1, (1,)).item())
        img = img[..., top:top + crop_h, left:left + crop_w]

        if torch.rand(1) > 0.5:
            # Horizontal flip: reverse the width axis.
            img = img.flip(-1)

        return img

    def __getitem__(self, item):
        """Return ``(image, label)`` with the image as a normalized (3, H, W) float tensor."""
        # Scale to [0, 1] and move channels first: (H, W, C) -> (C, H, W).
        x = self.x[item].float() / 255.0
        x = x.permute(2,0,1)

        if self.train:
            pad = self._PAD
            # Zero-pad width and height on both sides, then crop/flip.
            x = torch.nn.functional.pad(x, (pad, pad, pad, pad))
            x = self.transform(x)

        # The method is checked on every call so changing `self.method`
        # after construction still takes effect.
        if 'bcplus' not in self.method:
            mean_image = self._mean
        else:
            # BC+ branch: remove the per-image mean first.
            x = x - torch.mean(x)
            mean_image = self._mean_bcplus

        x = (x - mean_image) / self._std

        return x, self.y[item]

In [None]:
# Training data: augmentation enabled, reshuffled every epoch, and the last
# incomplete batch dropped so every batch has exactly `batch_size` samples.
train_datasets = CIFAR(d_tr, method, train=True)
train_dataloaders = torch.utils.data.DataLoader(train_datasets, batch_size=batch_size, shuffle=True, drop_last=True)

In [None]:
# Test data: no augmentation (train defaults to False), fixed order, all samples kept.
test_datasets = CIFAR(d_te, method)
test_dataloaders = torch.utils.data.DataLoader(test_datasets, batch_size=batch_size, shuffle=False)

In [None]:
# Number of samples in the training and test datasets.
print(len(train_datasets), len(test_datasets))

In [None]:
# Number of batches per pass over each loader.
print(len(train_dataloaders), len(test_dataloaders))

In [None]:
# Sanity check: undo the channel normalization on the first training batch and
# write it as an image grid to './tmp.png'; only one batch is processed.
# NOTE(review): these are the mean/std of the non-'bcplus' branch of the dataset,
# so the saved image is not an exact inverse when `method` contains 'bcplus'.
for i, d in enumerate(train_dataloaders):
    mean_image = torch.from_numpy(np.array([0.4914, 0.4822, 0.4465])).float()
    std_image = torch.from_numpy(np.array([0.2023, 0.1994, 0.2010])).float()
    # Channels-last so the (3,) statistics broadcast over the last axis.
    img = d[0].permute(0,2,3,1) * std_image + mean_image
    save_image(img.permute(0,3,1,2), './tmp.png')
    break

# Augmentations

In [None]:
def verticalConcatMask(batch_1, batch_2):
    """Stack the top rows of one batch on the bottom rows of another.

    For every sample a ratio is drawn from Beta(1, 1); the first
    ``int(ratio * H)`` rows come from ``batch_1`` and the rest from
    ``batch_2``. Labels are mixed with the same row fractions.

    Args:
        batch_1: ``(images, labels)`` with images of shape (B, C, H, W). Labels
            are either class indices of shape (B,), which are one-hot encoded
            with the module-level ``n_outputs``, or soft labels of shape
            (B, num_classes).
        batch_2: Second ``(images, labels)`` pair with the same shapes.

    Returns:
        ``(vertical_img, vertical_label)``: the concatenated images (B, C, H, W)
        and the mixed soft labels (B, num_classes).
    """
    img_1, label_1 = batch_1[0], batch_1[1]
    img_2, label_2 = batch_2[0], batch_2[1]
    b, h = img_1.shape[0], img_1.shape[2]

    lambda_vertical_beta = beta.Beta(torch.tensor([1.]), torch.tensor([1.]))
    lambda_vertical = lambda_vertical_beta.sample(torch.Size([b])).view(b,1)
    # Number of rows taken from img_1 for each sample, shape (B, 1).
    rows_from_1 = (lambda_vertical * h).long()

    if len(label_1.shape) == 1:
        label_1 = F.one_hot(label_1, num_classes=n_outputs)

    if len(label_2.shape) == 1:
        label_2 = F.one_hot(label_2, num_classes=n_outputs)

    float_dtype = torch.get_default_dtype()

    # Vectorized mask: 1 where the row index is below the per-sample cut, else 0.
    # Shape (B, 1, H, 1) broadcasts over channels and width; it is created on
    # the same device as the images.
    row_index = torch.arange(h).view(1, 1, h, 1)
    binary_mask = (row_index < rows_from_1.view(b, 1, 1, 1)).to(device=img_1.device, dtype=float_dtype)

    vertical_img = binary_mask * img_1 + (1 - binary_mask) * img_2

    # Convert to float before dividing: `/` on integer tensors behaves
    # differently across PyTorch versions. The (B, 1) weights broadcast over classes.
    weight_1 = (rows_from_1.to(float_dtype) / h).to(label_1.device)
    weight_2 = ((h - rows_from_1).to(float_dtype) / h).to(label_2.device)
    vertical_label = weight_1 * label_1 + weight_2 * label_2

    return vertical_img, vertical_label

In [None]:
def horizontalConcatMask(batch_1, batch_2):
    """Join the left columns of one batch with the right columns of another.

    For every sample a ratio is drawn from Beta(1, 1); the first
    ``int(ratio * W)`` columns come from ``batch_1`` and the rest from
    ``batch_2``. Labels are mixed with the same column fractions.

    Args:
        batch_1: ``(images, labels)`` with images of shape (B, C, H, W). Labels
            are either class indices of shape (B,), which are one-hot encoded
            with the module-level ``n_outputs``, or soft labels of shape
            (B, num_classes).
        batch_2: Second ``(images, labels)`` pair with the same shapes.

    Returns:
        ``(horizontal_img, horizontal_label)``: the concatenated images
        (B, C, H, W) and the mixed soft labels (B, num_classes).
    """
    img_1, label_1 = batch_1[0], batch_1[1]
    img_2, label_2 = batch_2[0], batch_2[1]
    b, w = img_1.shape[0], img_1.shape[3]

    lambda_horizontal_beta = beta.Beta(torch.tensor([1.]), torch.tensor([1.]))
    lambda_horizontal = lambda_horizontal_beta.sample(torch.Size([b])).view(b,1)
    # Number of columns taken from img_1 for each sample, shape (B, 1).
    cols_from_1 = (lambda_horizontal * w).long()

    if len(label_1.shape) == 1:
        label_1 = F.one_hot(label_1, num_classes=n_outputs)

    if len(label_2.shape) == 1:
        label_2 = F.one_hot(label_2, num_classes=n_outputs)

    float_dtype = torch.get_default_dtype()

    # Vectorized mask: 1 where the column index is below the per-sample cut, else 0.
    # Shape (B, 1, 1, W) broadcasts over channels and height; it is created on
    # the same device as the images.
    col_index = torch.arange(w).view(1, 1, 1, w)
    binary_mask = (col_index < cols_from_1.view(b, 1, 1, 1)).to(device=img_1.device, dtype=float_dtype)

    horizontal_img = binary_mask * img_1 + (1 - binary_mask) * img_2

    # Convert to float before dividing: `/` on integer tensors behaves
    # differently across PyTorch versions. The (B, 1) weights broadcast over classes.
    weight_1 = (cols_from_1.to(float_dtype) / w).to(label_1.device)
    weight_2 = ((w - cols_from_1).to(float_dtype) / w).to(label_2.device)
    horizontal_label = weight_1 * label_1 + weight_2 * label_2

    return horizontal_img, horizontal_label

In [None]:
def VHMixup(batch_1, batch_2):
    """Blend a vertical and a horizontal concatenation of two batches.

    A per-sample weight is drawn from Beta(1, 1). The output image is
    ``weight * vertical + (1 - weight) * horizontal`` and the labels returned
    by the two concatenation functions are blended with the same weight.

    Args:
        batch_1: ``(images, labels)`` with images of shape (B, C, H, W).
        batch_2: Second ``(images, labels)`` pair with the same shapes.

    Returns:
        ``(mixed_img, mixed_label)``.
    """
    b, c, h, w = batch_1[0].shape[:4]

    # The mixing weight is sampled before the two concatenations are built.
    mixup_distribution = beta.Beta(torch.tensor([1.]), torch.tensor([1.]))
    mix_weight = mixup_distribution.sample(torch.Size([b])).view(b,1)

    v_img, v_label = verticalConcatMask(batch_1, batch_2)
    h_img, h_label = horizontalConcatMask(batch_1, batch_2)

    # Expand the (B, 1) weight to the full image and label shapes.
    img_weight = mix_weight.reshape(b,1,1,1).expand(b,c,h,w)
    label_weight = mix_weight.expand(b,n_outputs)

    mixed_img = img_weight * v_img + (1 - img_weight) * h_img
    mixed_label = label_weight * v_label + (1 - label_weight) * h_label

    return mixed_img, mixed_label

In [None]:
def VHBCplus(batch_1, batch_2):
    """Mix a vertical and a horizontal concatenation with std-aware weights.

    A ratio ``r`` is drawn uniformly per sample. The image weight is
    ``p = 1 / (1 + (std_v / std_h) * (1 - r) / r)``, where ``std_v`` and
    ``std_h`` are the per-sample standard deviations of the two concatenated
    images, and the mixed image is divided by ``sqrt(p**2 + (1 - p)**2)``.
    Labels are blended with ``r`` itself.

    Args:
        batch_1: ``(images, labels)`` with images of shape (B, C, H, W).
        batch_2: Second ``(images, labels)`` pair with the same shapes.

    Returns:
        ``(bcplus_img, bcplus_label)``.
    """
    b = batch_1[0].shape[0]
    # Lower bound for the two denominators below.
    eps = 1e-8

    vertical_concat, vertical_label = verticalConcatMask(batch_1, batch_2)
    horizontal_concat, horizontal_label = horizontalConcatMask(batch_1, batch_2)

    # torch.rand samples from [0, 1) and can return exactly 0, which would make
    # (1 - r) / r infinite. Clamping only changes that degenerate draw.
    lambda_uni = torch.rand(b).clamp_min(eps)
    lambda_factor = (1 - lambda_uni) / lambda_uni

    vertical_std = torch.std(vertical_concat.view(b,-1),dim=1)
    # A constant image has zero std; clamp so the ratio stays finite.
    horizontal_std = torch.std(horizontal_concat.view(b,-1),dim=1).clamp_min(eps)
    std_factor = vertical_std / horizontal_std

    p = 1 / (1 + std_factor * lambda_factor)

    denom = torch.sqrt(p**2 + (1-p)**2)

    # Per-sample scalars reshaped to (B, 1, 1, 1) / (B, 1) broadcast over the
    # image and label dimensions.
    p_img = p.reshape(b,1,1,1)
    bcplus_img = (p_img * vertical_concat + (1 - p_img) * horizontal_concat) / denom.reshape(b,1,1,1)

    label_weight = lambda_uni.reshape(b,1)
    bcplus_label = label_weight * vertical_label + (1 - label_weight) * horizontal_label

    return bcplus_img, bcplus_label

# Model

In [None]:
def conv3x3(in_planes, out_planes, stride=1):
    """Return a 3x3 ``nn.Conv2d`` with padding 1 and no bias.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride (default 1).
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer

class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a shortcut connection.

    The shortcut is an empty ``nn.Sequential`` (identity) unless the stride is
    not 1 or the channel count changes, in which case it is a 1x1 convolution
    followed by batch norm.
    """

    # Ratio of output channels to `planes`.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        changes_shape = stride != 1 or in_planes != out_planes
        if changes_shape:
            # Projection shortcut so the residual matches the main branch's shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply conv-bn-relu, conv-bn, add the shortcut, then a final relu."""
        main = relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        main += self.shortcut(x)
        return relu(main)

class ResNet(nn.Module):
    """ResNet backbone with a 3x3 stem, four stages and a linear classifier.

    Args:
        block: Residual block class; must expose an ``expansion`` attribute and
            accept ``(in_planes, planes, stride)``.
        num_blocks: Number of blocks in each of the four stages.
        num_classes: Size of the output layer.
        nf: Channel width of the stem and first stage; later stages use
            2x, 4x and 8x this width.
    """

    def __init__(self, block, num_blocks, num_classes, nf):
        super().__init__()
        # Tracks the input width of the next block; updated by _make_layer.
        self.in_planes = nf

        self.conv1 = conv3x3(3, nf * 1)
        self.bn1 = nn.BatchNorm2d(nf * 1)
        # The first stage keeps the resolution; each later stage halves it.
        self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2)
        self.linear = nn.Linear(nf * 8 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch."""
        features = relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        # 4x4 average pooling, then flatten to (batch, features).
        pooled = avg_pool2d(features, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

In [None]:
def ResNet18(nclasses, nf=64):
    """Build a ResNet with BasicBlock and two blocks in each of the four stages.

    Args:
        nclasses: Number of output classes.
        nf: Base channel width (default 64).
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(block=BasicBlock, num_blocks=blocks_per_stage, num_classes=nclasses, nf=nf)

# Train

### Log Files

In [None]:
import shutil

# Results layout: ./Results/<dataset>_ResNet_<method>_<seed>/{logs,models}/
ROOT_DIR = './Results/'
now =  '{}_ResNet_{}_{}'.format(dataset, method, seed)

LOG_DIR = ROOT_DIR + now + '/logs/'
MODEL_DIR = ROOT_DIR + now + '/models/'

# Start every run with an empty log directory so events from a previous run
# with the same name are not mixed with the new ones.
if os.path.exists(LOG_DIR):
    shutil.rmtree(LOG_DIR)

# makedirs creates ROOT_DIR and the run directory as needed. exist_ok=True
# keeps an existing models directory (and its checkpoints) untouched.
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

summary_writer = SummaryWriter(LOG_DIR)

In [None]:
# Build a ResNet-18 with one output per class and move it to the selected device.
model = ResNet18(n_outputs).to(device)

In [None]:
# Pick the batch-level augmentation: 'vhmixup' -> VHMixup, anything else -> VHBCplus.
# NOTE(review): the CIFAR dataset only switches to its BC+ normalization when
# 'bcplus' appears in `method`, so an unrecognized name would pair VHBCplus with
# the default normalization - confirm that is intended.
if method == 'vhmixup':
    augmentor = VHMixup
else:
    augmentor = VHBCplus

### Optimizer and Schedulers

In [None]:
# SGD with momentum and weight decay; hyper-parameters come from the configuration cell.
opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
# Two MultiStepLR schedulers share the optimizer: scheduler_1 multiplies the LR by
# lr_scheduler_1_gamma at milestones_1, scheduler_2 by lr_scheduler_2_gamma at
# milestones_2. The training loop steps both once per iteration.
scheduler_1 = torch.optim.lr_scheduler.MultiStepLR(opt, gamma=lr_scheduler_1_gamma, milestones=milestones_1)
scheduler_2 = torch.optim.lr_scheduler.MultiStepLR(opt, gamma=lr_scheduler_2_gamma, milestones=milestones_2)

In [None]:
def softXEnt(output, target):
    """Cross-entropy between logits and soft (probability) targets.

    Args:
        output: Logits of shape (batch, num_classes).
        target: Target weights of the same shape.

    Returns:
        Scalar tensor: the negated sum of ``target * log_softmax(output)``
        over all entries, divided by the batch size.
    """
    batch_size = output.shape[0]
    log_probabilities = output.log_softmax(dim=1)
    weighted_sum = torch.sum(target * log_probabilities)
    return -weighted_sum / batch_size

### Training

In [None]:
# Number of optimizer steps per epoch, used to build a global iteration index
# for logging.
steps_per_epoch = len(train_dataloaders)

# NOTE(review): range(train_epochs + 1) runs one more epoch than `train_epochs`
# - confirm this is intended.
for epoch in range(train_epochs+1):
    model.train()
    for i, d in enumerate(tqdm.tqdm(train_dataloaders)):
        # Pair every sample with a randomly chosen partner from the same batch.
        indx = torch.randperm(d[0].shape[0])
        x1, y1 = d[0], d[1]
        x2, y2 = x1[indx], y1[indx]

        # The augmentor returns mixed images and soft labels.
        x, y = augmentor((x1,y1), (x2,y2))

        x = x.float().to(device)
        y = y.float().to(device)

        opt.zero_grad()

        out = model(x)

        loss = softXEnt(out, y)
        loss.backward()

        opt.step()

        # Pass an explicit step: without it every point is written at the same
        # step and the loss curve collapses to a single x position.
        global_step = epoch * steps_per_epoch + i
        summary_writer.add_scalar('Loss', loss.item(), global_step)

        # Scheduler is defined based on total number of iterations
        scheduler_1.step()
        scheduler_2.step()

    # Evaluate on the full test set after every epoch.
    model.eval()
    total_acc = 0
    for i, d in enumerate(test_dataloaders):
        x, y = d[0], d[1]
        x = x.float().to(device)
        y = y.long().to(device)

        with torch.no_grad():
            out_prob = model(x)

        pred = torch.argmax(out_prob, dim=1)
        prediction = pred.cpu().numpy()
        truth = y.cpu().numpy()
        # Number of correct predictions in this batch.
        acc = np.count_nonzero(prediction == truth)

        total_acc += acc

    print('After the accuracy after {} epochs is {}'.format(epoch, total_acc / len(test_datasets)))
    print()

    # Log accuracy against the epoch index so it forms a curve over training.
    summary_writer.add_scalar('Eval ACC', total_acc / len(test_datasets), epoch)

    # Checkpoint every 10th epoch (including epoch 0).
    if epoch % 10 == 0:
        torch.save(model.state_dict(), MODEL_DIR+'epoch_{}.pth'.format(epoch))

In [None]:
# Test accuracy from the most recent evaluation pass (displayed as the cell output).
total_acc / len(test_datasets)

In [None]:
# Save the final weights as 'final.pth' in MODEL_DIR.
torch.save(model.state_dict(), MODEL_DIR+'final.pth')

# Evaluate

In [None]:
# Final evaluation: count correct top-1 predictions over the whole test set.
model.eval()
total_acc = 0
with torch.no_grad():
    for d in test_dataloaders:
        x = d[0].float().to(device)
        y = d[1].long().to(device)

        pred = torch.argmax(model(x), dim=1)
        # Compare on the CPU and count the matches in this batch.
        total_acc += np.count_nonzero(pred.cpu().numpy() == y.cpu().numpy())

# `epoch` still holds the index of the last epoch run by the training loop.
print('After the accuracy after {} epochs is {}'.format(epoch, total_acc / len(test_datasets)))