### Packages

In [None]:
cd '/h/ama/workspace/ama-at-vector/best-mix'

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import math
import pickle

from mixup import mixup_graph
import time
from utils_mixup import gradmix_v2, gradmix_v2_improved

os.environ['KMP_DUPLICATE_LIB_OK']='True'
%matplotlib inline

In [4]:
def print_fig(input, target=None, title=None, save_dir=None):
    fig, axes = plt.subplots(1,len(input),figsize=(3*len(input),3))
    if title:
        fig.suptitle(title, size=16)
    if len(input) == 1 :
        axes = [axes]
        
    for i, ax in enumerate(axes):
        if len(input.shape) == 4:
            ax.imshow(input[i].permute(1,2,0).numpy())
        else :
            ax.imshow(input[i].numpy(), cmap='gray', vmin=0., vmax=1.)
        
        if target is not None:
            output = net((input[i].unsqueeze(0) - mean)/std)
            loss = criterion(output, target[i:i+1])
            ax.set_title("loss: {:.3f}\n pred: {}\n true : {}".format(loss, CIFAR100_LABELS_LIST[output.max(1)[1][0]], CIFAR100_LABELS_LIST[target[i]]))
        ax.axis('off')
    plt.subplots_adjust(wspace = 0.1)
    
    if save_dir is not None:
        plt.savefig(save_dir, bbox_inches = 'tight',  pad_inches = 0)
        
    plt.show()

### Model, Data, Saliency

In [5]:
''' Model '''
import models
# import torchvision.models as models
from load_data import load_data_subset
from collections import OrderedDict


resnet = models.__dict__['preactresnet18'](10, False, 1).cuda()

# checkpoint = torch.load('/group-volume/Multimodal-Learning/ssl/vse_files/runs/fast_autoaugment/models/cifar10_preact_ckpt/vanilla.pth.tar')
checkpoint = torch.load('checkpoint/cifar10_preact_ckpt_vanilla.pth.tar')

od = OrderedDict()
for key in checkpoint['state_dict'].keys():
    od[key[7:]] = checkpoint['state_dict'][key]
resnet.load_state_dict(od)

# resnet = models.resnet18(pretrained=True)


# resnet.load_state_dict(checkpoint['state_dict'])
# mean = torch.tensor([x / 255 for x in [125.3, 123.0, 113.9]],dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
# std = torch.tensor([x / 255 for x in [63.0, 62.1, 66.7]], dtype=torch.float32).reshape(1, 3, 1, 1).cuda()
labels_per_class = 5000
mean = torch.tensor([125.3, 123.0, 113.9])/255
std = torch.tensor([63.0, 62.1, 66.7])/255
mean_torch = mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
std_torch = std.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

criterion = nn.CrossEntropyLoss()


batch_size = 100
workers=2
dataset='cifar10'
# data_dir='/group-volume/Multimodal-Learning/ssl/vse_files/runs/fast_autoaugment/data'
data_dir='data'
valid_labels_per_class=0
mixup_alpha=0
train_loader, valid_loader, _, test_loader, num_classes = load_data_subset(batch_size,workers,dataset,data_dir,labels_per_class=labels_per_class,valid_labels_per_class=valid_labels_per_class,mixup_alpha=mixup_alpha)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
''' Data '''

sample_num=100
for x,y in test_loader:
    input_sp,targets = x[:sample_num,:],y[:sample_num]
    break
    
# print_fig((input_sp * std_torch + mean_torch)[:sample_num])

In [7]:
''' Saliency '''
resnet.cpu()
resnet.eval()
input_var = input_sp[:sample_num].clone().detach().requires_grad_(True)
output = resnet(input_var)
loss = criterion(output, targets[:sample_num])
loss.backward()

blurr = torchvision.transforms.GaussianBlur(5, sigma=(1.0, 1.0))
grad = blurr(input_var.grad.detach().abs().mean(dim=1).squeeze())


### Ours

In [8]:
total_time_list = []
update_time_list = []
update_counter_list = []
for i in range(10):
    tic = time.perf_counter()
    mixed_x, mixed_y, mixed_lam= gradmix_v2(input_sp.cuda(), targets.cuda(), grad.unsqueeze(1).cuda(), 
                                                 alpha = 0.5, normalization = 'standard', stride = 1, debug=False, 
                                                 rand_pos = 0)


    toc = time.perf_counter()
    total_time_list.append(toc-tic)
#     update_time_list.append(total_time_in_update)
#     update_counter_list.append(update_counter)
# update_time_list = np.array(update_time_list)
total_time_list= np.array(total_time_list)
# update_counter_list=np.array(update_counter_list)
print(f"Average total time: {total_time_list.mean():0.4f} seconds (var: {total_time_list.var():0.4f})")
# print(f"Average time on update: {update_time_list.mean():0.4f} seconds (var: {update_time_list.var():0.4f})")
# print(f"Average update: {update_counter_list.mean():0.4f} times")
print(f"Max/Min time: {total_time_list.max():0.4f}/{total_time_list.min():0.4f} seconds")


Average total time: 1.0593 seconds (var: 0.0008)
Max/Min time: 1.1031/1.0076 seconds


In [9]:
total_time_list = []
update_time_list = []
update_counter_list = []
for i in range(10):
    tic = time.perf_counter()
    mixed_x, mixed_y, mixed_lam= gradmix_v2(input_sp.cuda(), targets.cuda(), grad.unsqueeze(1).cuda(), 
                                                 alpha = 0.5, normalization = 'standard', stride = 1, debug=False, 
                                                 rand_pos = 1)


    toc = time.perf_counter()
    total_time_list.append(toc-tic)
#     update_time_list.append(total_time_in_update)
#     update_counter_list.append(update_counter)
# update_time_list = np.array(update_time_list)
total_time_list= np.array(total_time_list)
# update_counter_list=np.array(update_counter_list)
print(f"Average total time: {total_time_list.mean():0.4f} seconds (var: {total_time_list.var():0.4f})")
# print(f"Average time on update: {update_time_list.mean():0.4f} seconds (var: {update_time_list.var():0.4f})")
# print(f"Average update: {update_counter_list.mean():0.4f} times")
print(f"Max/Min time: {total_time_list.max():0.4f}/{total_time_list.min():0.4f} seconds")


Average total time: 0.6084 seconds (var: 0.0001)
Max/Min time: 0.6223/0.5997 seconds


In [13]:
total_time_list = []
update_time_list = []
update_counter_list = []
for i in range(10):
    tic = time.perf_counter()
    mixed_x, mixed_y, mixed_lam= gradmix_v2_improved(input_sp.cuda(), targets.cuda(), grad.unsqueeze(1).cuda(), 
                                                 alpha = 0.5, normalization = 'standard', stride = 1, debug=False, 
                                                 rand_pos = 0)


    toc = time.perf_counter()
    total_time_list.append(toc-tic)
#     update_time_list.append(total_time_in_update)
#     update_counter_list.append(update_counter)
# update_time_list = np.array(update_time_list)
total_time_list= np.array(total_time_list)
# update_counter_list=np.array(update_counter_list)
print(f"Average total time: {total_time_list.mean():0.4f} seconds (var: {total_time_list.var():0.4f})")
# print(f"Average time on update: {update_time_list.mean():0.4f} seconds (var: {update_time_list.var():0.4f})")
# print(f"Average update: {update_counter_list.mean():0.4f} times")
print(f"Max/Min time: {total_time_list.max():0.4f}/{total_time_list.min():0.4f} seconds")


Average total time: 0.4760 seconds (var: 0.0002)
Max/Min time: 0.4966/0.4433 seconds


In [12]:
total_time_list = []
update_time_list = []
update_counter_list = []
for i in range(10):
    tic = time.perf_counter()
    mixed_x, mixed_y, mixed_lam= gradmix_v2_improved(input_sp.cuda(), targets.cuda(), grad.unsqueeze(1).cuda(), 
                                                 alpha = 0.5, normalization = 'standard', stride = 1, debug=False, 
                                                 rand_pos = 1)


    toc = time.perf_counter()
    total_time_list.append(toc-tic)
#     update_time_list.append(total_time_in_update)
#     update_counter_list.append(update_counter)
# update_time_list = np.array(update_time_list)
total_time_list= np.array(total_time_list)
# update_counter_list=np.array(update_counter_list)
print(f"Average total time: {total_time_list.mean():0.4f} seconds (var: {total_time_list.var():0.4f})")
# print(f"Average time on update: {update_time_list.mean():0.4f} seconds (var: {update_time_list.var():0.4f})")
# print(f"Average update: {update_counter_list.mean():0.4f} times")
print(f"Max/Min time: {total_time_list.max():0.4f}/{total_time_list.min():0.4f} seconds")


Average total time: 0.3696 seconds (var: 0.0001)
Max/Min time: 0.3843/0.3586 seconds
