In [2]:
import os
import copy
import math
import tqdm
import torch
import pickle
import random
import datetime

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.distributions.beta as beta
import torchvision.transforms.functional as TF

from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter
from torch.nn.functional import relu, avg_pool2d
from sklearn.calibration import calibration_curve

In [3]:
seed = 0

In [4]:
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [6]:
'''Data Preparation Arguments'''
data_prep_from_scratch = True
dataset = 'cifar10'
path_dataset = './Data/CIFAR10/'

'''Data Augmentation Method'''
method = 'vhmixup'

'''Optimization Arguments'''
batch_size = 128

# Preprocess CIFAR

In [7]:
def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

In [8]:
if data_prep_from_scratch and dataset == 'cifar100':
    cifar100_train = unpickle(os.path.join(path_dataset, 'train'))
    cifar100_test = unpickle(os.path.join(path_dataset, 'test'))

    x_tr = torch.from_numpy(cifar100_train[b'data'].reshape((-1,32,32,3), order='F')).permute(0,2,1,3)
    y_tr = torch.LongTensor(cifar100_train[b'fine_labels'])
    x_te = torch.from_numpy(cifar100_test[b'data'].reshape((-1,32,32,3), order='F')).permute(0,2,1,3)
    y_te = torch.LongTensor(cifar100_test[b'fine_labels'])

    torch.save((x_tr, y_tr, x_te, y_te), os.path.join(path_dataset, '{}.pt'.format(dataset)))

In [9]:
if data_prep_from_scratch and dataset == 'cifar10':
    x_tr, y_tr = None, None
    for b in range(5):
        cifar10_train = unpickle(os.path.join(path_dataset, 'data_batch_{}'.format(b+1)))
        
        batch_img = torch.from_numpy(cifar10_train[b'data'].reshape((-1,32,32,3), order='F')).permute(0,2,1,3)
        batch_label = torch.LongTensor(cifar10_train[b'labels'])
        
        if x_tr is None:
            x_tr = batch_img
            y_tr = batch_label
        else:
            x_tr = torch.cat((x_tr, batch_img), dim=0)
            y_tr = torch.cat((y_tr, batch_label), dim=0)
    
    cifar10_test = unpickle(os.path.join(path_dataset, 'test_batch'))
    x_te = torch.from_numpy(cifar10_test[b'data'].reshape((-1,32,32,3), order='F')).permute(0,2,1,3)
    y_te = torch.LongTensor(cifar10_test[b'labels'])
    
    torch.save((x_tr, y_tr, x_te, y_te), os.path.join(path_dataset, '{}.pt'.format(dataset)))

# Load Data

In [10]:
def load_datasets(path):
    d = torch.load(path)
    d_tr = (d[0], d[1])
    d_te = (d[2], d[3])
    if dataset == 'cifar100':
        n_outputs = 100
    else:
        n_outputs = 10        
    return d_tr, d_te, n_outputs

In [11]:
d_tr, d_te, n_outputs = load_datasets(os.path.join(path_dataset, '{}.pt'.format(dataset)))

# Dataloader

In [12]:
class CIFAR(torch.utils.data.Dataset):
    def __init__(self, pack, method, train=False):
        self.x = pack[0]
        self.y = pack[1]
        self.img_size = (3,32,32)
        
        self.method = method
        self.train = train
    
    def __len__(self):
        return len(self.x)
    
    def transform(self, img):
        top = torch.randint(0,8,(1,))
        left = torch.randint(0,8,(1,))
        img = TF.crop(img, top=top, left=left, height=self.img_size[1], width=self.img_size[2])
        
        if torch.rand(1) > 0.5:
            img = TF.hflip(img)
            
        return img
    
    def __getitem__(self, item):
        x = self.x[item].float() / 255.0
        
        x = x.permute(2,0,1)
        
        if self.train:
            x = TF.pad(x, padding=4)
            x = self.transform(x)
        
        if 'bcplus' not in self.method:
            mean_image = torch.from_numpy(np.array([0.4914, 0.4822, 0.4465])).float()
            std_image = torch.from_numpy(np.array([0.2023, 0.1994, 0.2010])).float()
        else:
            x = x - torch.mean(x)
            mean_image = torch.from_numpy(np.array([0.21921569, 0.21058824, 0.22156863])).float()
            std_image = torch.from_numpy(np.array([0.2023, 0.1994, 0.2010])).float()
        
        x = x.permute(1,2,0)
        x = x - mean_image
        x = x / std_image
        
        return x.permute(2,0,1), self.y[item]

In [13]:
test_datasets = CIFAR(d_te, method)
test_dataloaders = torch.utils.data.DataLoader(test_datasets, batch_size=batch_size, shuffle=False)

In [14]:
print(len(test_dataloaders))

79


In [15]:
for i, d in enumerate(test_dataloaders):
    mean_image = torch.from_numpy(np.array([0.4914, 0.4822, 0.4465])).float()
    std_image = torch.from_numpy(np.array([0.2023, 0.1994, 0.2010])).float()
    img = d[0].permute(0,2,3,1) * std_image + mean_image
    save_image(img.permute(0,3,1,2), './tmp.png')
    break

# Augmentations

In [16]:
def verticalConcatMask(batch_1, batch_2):
    b, c, h, w = batch_1[0].shape[0], batch_1[0].shape[1], batch_1[0].shape[2], batch_1[0].shape[3]
    
    lambda_vertical_beta = beta.Beta(torch.tensor([1.]), torch.tensor([1.]))
    lambda_vertical = lambda_vertical_beta.sample(torch.Size([b])).view(b,1)
    
    img_1, label_1 = batch_1[0], batch_1[1]
    img_2, label_2 = batch_2[0], batch_2[1]
    
    if len(label_1.shape) == 1:
        label_1 = F.one_hot(label_1, num_classes=n_outputs)
        
    if len(label_2.shape) == 1:
        label_2 = F.one_hot(label_2, num_classes=n_outputs)
    
    binary_mask = torch.ones(img_1.shape)
    for b_indx in range(b):
        binary_mask[b_indx,:,(lambda_vertical[b_indx]*h).long():,:] = 0
        
    vertical_img = binary_mask * img_1 + (1 - binary_mask) * img_2
    vertical_label = (lambda_vertical*h).long().repeat(1,n_outputs) / h * label_1 + \
                        (h - (lambda_vertical*h).long()).repeat(1,n_outputs) / h * label_2
    
    return vertical_img, vertical_label

In [17]:
def horizontalConcatMask(batch_1, batch_2):
    b, c, h, w = batch_1[0].shape[0], batch_1[0].shape[1], batch_1[0].shape[2], batch_1[0].shape[3]
    
    lambda_horizontal_beta = beta.Beta(torch.tensor([1.]), torch.tensor([1.]))
    lambda_horizontal = lambda_horizontal_beta.sample(torch.Size([b])).view(b,1)
    
    img_1, label_1 = batch_1[0], batch_1[1]
    img_2, label_2 = batch_2[0], batch_2[1]
    
    if len(label_1.shape) == 1:
        label_1 = F.one_hot(label_1, num_classes=n_outputs)
        
    if len(label_2.shape) == 1:
        label_2 = F.one_hot(label_2, num_classes=n_outputs)
    
    binary_mask = torch.ones(img_1.shape)
    for b_indx in range(b):
        binary_mask[b_indx,:,:,(lambda_horizontal[b_indx]*w).long():] = 0
        
    horizontal_img = binary_mask * img_1 + (1 - binary_mask) * img_2
    horizontal_label = (lambda_horizontal*w).long().repeat(1,n_outputs) / w * label_1 + \
                        (w - (lambda_horizontal*w).long()).repeat(1,n_outputs) / w * label_2
    
    return horizontal_img, horizontal_label

In [18]:
def Mixup(batch_1, batch_2):
    b, c, h, w = batch_1[0].shape[0], batch_1[0].shape[1], batch_1[0].shape[2], batch_1[0].shape[3]
    
    lambda_mixup_beta = beta.Beta(torch.tensor([1.]), torch.tensor([1.]))
    lambda_mixup = lambda_mixup_beta.sample(torch.Size([b])).view(b,1)
    
    img_1, label_1 = batch_1[0], batch_1[1]
    img_2, label_2 = batch_2[0], batch_2[1]
    
    if len(label_1.shape) == 1:
        label_1 = F.one_hot(label_1, num_classes=n_outputs)
        
    if len(label_2.shape) == 1:
        label_2 = F.one_hot(label_2, num_classes=n_outputs)
    
    mixed_img = lambda_mixup.reshape(b,1,1,1).repeat(1,c,h,w) * img_1 + \
                    (1 - lambda_mixup.reshape(b,1,1,1).repeat(1,c,h,w)) * img_2
    mixed_label = lambda_mixup.repeat(1,n_outputs) * label_1 + \
                    (1 - lambda_mixup.repeat(1,n_outputs)) * label_2
    
    return mixed_img, mixed_label

In [19]:
def VHMixup(batch_1, batch_2):
    b, c, h, w = batch_1[0].shape[0], batch_1[0].shape[1], batch_1[0].shape[2], batch_1[0].shape[3]
    
    lambda_mixup_beta = beta.Beta(torch.tensor([1.]), torch.tensor([1.]))
    lambda_mixup = lambda_mixup_beta.sample(torch.Size([b])).view(b,1)
    
    vertical_concat, vertical_label = verticalConcatMask(batch_1, batch_2)
    horizontal_concat, horizontal_label = horizontalConcatMask(batch_1, batch_2)
    
    mixed_img = lambda_mixup.reshape(b,1,1,1).repeat(1,c,h,w) * vertical_concat + \
                    (1 - lambda_mixup.reshape(b,1,1,1).repeat(1,c,h,w)) * horizontal_concat
    mixed_label = lambda_mixup.repeat(1,n_outputs) * vertical_label + \
                    (1 - lambda_mixup.repeat(1,n_outputs)) * horizontal_label
    
    return mixed_img, mixed_label

In [20]:
def VHBCplus(batch_1, batch_2):
    b, c, h, w = batch_1[0].shape[0], batch_1[0].shape[1], batch_1[0].shape[2], batch_1[0].shape[3]
    
    vertical_concat, vertical_label = verticalConcatMask(batch_1, batch_2)
    horizontal_concat, horizontal_label = horizontalConcatMask(batch_1, batch_2)
    
    lambda_uni = torch.rand(b)
    lambda_factor = (1 - lambda_uni) / lambda_uni
    
    vertical_std = torch.std(vertical_concat.view(b,-1),dim=1)
    horizontal_std = torch.std(horizontal_concat.view(b,-1),dim=1)
    std_factor = vertical_std / horizontal_std

    p = 1 / (1 + std_factor * lambda_factor)
    
    denom = torch.sqrt(p**2 + (1-p)**2)
    
    c, h, w = batch_1[0].shape[1], batch_1[0].shape[2], batch_1[0].shape[3]
    
    bcplus_img = (p.reshape(b,1,1,1).repeat(1,c,h,w) * vertical_concat + \
                      (1 - p).reshape(b,1,1,1).repeat(1,c,h,w) * horizontal_concat) / denom.reshape(b,1,1,1).repeat(1,c,h,w)
    bcplus_label = lambda_uni.reshape(b,1).repeat(1,n_outputs) * vertical_label + \
                        (1 - lambda_uni.reshape(b,1).repeat(1,n_outputs)) * horizontal_label
    
    return bcplus_img, bcplus_label

# Model

In [21]:
def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes, nf):
        super(ResNet, self).__init__()
        self.in_planes = nf

        self.conv1 = conv3x3(3, nf * 1)
        self.bn1 = nn.BatchNorm2d(nf * 1)
        self.layer1 = self._make_layer(block, nf * 1, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, nf * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, nf * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, nf * 8, num_blocks[3], stride=2)
        self.linear = nn.Linear(nf * 8 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        bsz = x.size(0)
        out = relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [22]:
def ResNet18(nclasses, nf=64):
    return ResNet(BasicBlock, [2, 2, 2, 2], nclasses, nf)

# Evaluate

In [23]:
model = ResNet18(n_outputs).to(device)
model.load_state_dict(torch.load('./Results/cidar10_ResNet_baseline/models/final.pth'))

<All keys matched successfully>

In [40]:
augmentor = VHMixup

In [41]:
def calc_bins(preds):
    # Assign each prediction to a bin
    num_bins = 100
    bins = np.linspace(0.1, 1, num_bins)
    binned = np.digitize(preds, bins)

    # Save the accuracy, confidence and size of each bin
    bin_accs = np.zeros(num_bins)
    bin_confs = np.zeros(num_bins)
    bin_sizes = np.zeros(num_bins)
    
    for bin in range(num_bins):
        bin_sizes[bin] = len(preds[binned == bin])
        if bin_sizes[bin] > 0:
            bin_accs[bin] = (labels_oneh[binned==bin]).sum() / bin_sizes[bin]
            bin_confs[bin] = (preds[binned==bin]).sum() / bin_sizes[bin]

    return bins, binned, bin_accs, bin_confs, bin_sizes

In [42]:
def get_metrics(preds):
    ECE = 0
    MCE = 0
    bins, _, bin_accs, bin_confs, bin_sizes = calc_bins(preds)

    for i in range(len(bins)):
        abs_conf_dif = abs(bin_accs[i] - bin_confs[i])
        ECE += (bin_sizes[i] / sum(bin_sizes)) * abs_conf_dif
        MCE = max(MCE, abs_conf_dif)

    return ECE, MCE, _

In [43]:
def softXEnt(output, target):
    logprobs = torch.nn.functional.log_softmax(output, dim = 1)
    return  -(target * logprobs).sum() / output.shape[0]

In [44]:
model.eval()
pred_probs = None
labels_oneh = None

for i, d in enumerate(tqdm.tqdm(test_dataloaders)):
    x, y = d[0], d[1]

    x = x.float().to(device)
    y = y.long().to(device)

    with torch.no_grad():
        out_prob = F.softmax(model(x), dim=1).cpu().numpy()
    
    if pred_probs is None:
        pred_probs = out_prob
    else:
        pred_probs = np.concatenate((pred_probs, out_prob), axis=0)
        
    if labels_oneh is None:
        labels_oneh = F.one_hot(y, num_classes=10).cpu().numpy()
    else:
        labels_oneh = np.concatenate((labels_oneh, F.one_hot(y, num_classes=10).cpu().numpy()), axis=0)

100%|██████████████████████████████████████████████████████████████████████████████████| 79/79 [00:15<00:00,  5.25it/s]


In [45]:
pred_probs = pred_probs.flatten()
labels_oneh = labels_oneh.flatten()

In [46]:
ece, mce, binned = get_metrics(pred_probs)

In [48]:
ece

0.006550298200845712

In [None]:
model.eval()
total_loss = 0
total_acc = 0; atleast_one_acc = 0; prime_acc = 0
for i, d in enumerate(tqdm.tqdm(test_dataloaders)):
    indx = torch.randperm(d[0].shape[0])
    x1, y1 = d[0], d[1]
    x2, y2 = x1[indx], y1[indx]

    x, y = augmentor((x1,y1), (x2,y2))

    x = x.float().to(device)
    y = y.float().to(device)

    with torch.no_grad():
        out_prob = model(x)
        loss = softXEnt(out_prob, y)
                
    out_labels = torch.topk(out_prob, 2, dim=1)[1]
    out_l1 = out_labels[:,0]; out_l2 = out_labels[:,1]
    
    pred_labels = torch.topk(y, 2, dim=1)[1]
    pred_l1 = pred_labels[:,0]; pred_l2 = pred_labels[:,1]
    
    prediction = out_l1.cpu().numpy()
    truth = pred_l1.cpu().numpy()
    acc_l1 = prediction == truth
        
    prediction = out_l2.cpu().numpy()
    truth = pred_l2.cpu().numpy()
    acc_l2 = prediction == truth
    
    primary_label = y1.cpu().numpy()
    acc_prime_1 = out_l1.cpu().numpy() == primary_label
    acc_prime_2 = out_l2.cpu().numpy() == primary_label
    
    acc_prime = np.count_nonzero(acc_prime_1 | acc_prime_2)
    prime_acc += acc_prime
    
    one_acc = np.count_nonzero(acc_l1 | acc_l2)
    atleast_one_acc += one_acc
    
    both_acc = np.count_nonzero(acc_l1 & acc_l2)
    total_acc += both_acc
    
    total_loss += loss.item()    

In [None]:
print(total_acc / len(test_datasets), atleast_one_acc / len(test_datasets), total_loss / len(test_datasets), prime_acc / len(test_datasets))