In [None]:
import os

# Project modules (src/, models) and the ./notebook/ output directory are resolved from
# the repository root, one level above this notebook. Only move up when we are not
# already there, so re-running this cell does not keep climbing the directory tree.
if not os.path.isdir('src'):
    os.chdir('..')

In [None]:
# Reload imported modules automatically before executing code, so edits to project
# modules (e.g. src.*) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from src.utils_freq import (
    rgb2gray,
    dct,
    dct2,
    idct,
    idct2,
    batch_dct2,
    getDCTmatrix,
    batch_idct2,
    batch_dct2_3channel,
    batch_idct2_3channel,
    equal_dist_from_top_left,
    mask_radial,
)
from models import PreActResNet18
from src.utils_dataset import load_dataset
from src.evaluation import test_clean, test_gaussian, test_gaussian_LF_HF

# Evaluate on the first GPU when CUDA is available, otherwise on the CPU.
_device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
dataset = 'cifar100'
# Derive the number of classes from the dataset so that switching `dataset` cannot
# leave a classifier head that mismatches the checkpoint.
num_classes = {'cifar10': 10, 'cifar100': 100}[dataset]
# NOTE(review): 28 matches 28x28 inputs, but CIFAR images are 32x32. dct_matrix is not
# used elsewhere in this notebook; confirm the intended size before relying on it.
dct_matrix = getDCTmatrix(28)
# 128 is presumably the batch size; confirm against load_dataset.
train_loader, test_loader = load_dataset(dataset, 128)

# Checkpoints are stored as '<ckpt_dir><dataset>-<optimizer>.pt'.
ckpt_dir = '/scratch/ssd001/home/ama/workspace/opt-robust/ckpt/'
model_sgd_path = ckpt_dir + dataset + '-sgd.pt'
model_adam_path = ckpt_dir + dataset + '-adam.pt'
model_rmsp_path = ckpt_dir + dataset + '-rmsprop.pt'

# Noise variance passed to test_gaussian_LF_HF.
var = 0.01
# Frequency-band radii passed to test_gaussian_LF_HF.
# NOTE(review): the origin of these precomputed values is not recorded here; document it.
r_list = [10.781431922956449, 15.301968130724937, 19.04982421996594, 22.040185695902814, 24.75901837969084, 27.220978213474407, 29.441362859363448, 31.57831841836015, 34.71337554275011]


def load_model(path):
    """Build a PreActResNet18 for `dataset`, load the state dict stored at `path`, move it
    to `_device`, and return it in eval mode."""
    model = PreActResNet18(dataset, num_classes, False, False)
    # map_location remaps the stored tensors onto _device, so checkpoints saved from a
    # GPU still load when _device falls back to the CPU.
    model.load_state_dict(torch.load(path, map_location=_device))
    model.to(_device)
    # Inference only: switch any BatchNorm/Dropout layers to evaluation behavior.
    # Harmless if the evaluation helpers also call eval().
    model.eval()
    return model


model_sgd = load_model(model_sgd_path)
model_adam = load_model(model_adam_path)
model_rmsp = load_model(model_rmsp_path)
print('model loaded')

In [None]:
# Clean (noise-free) test accuracy and loss for each optimizer's model. test_clean
# returns three values; the third is not used here.
test_acc_sgd, test_loss_sgd, _ = test_clean(test_loader, model_sgd, _device)
test_acc_adam, test_loss_adam, _ = test_clean(test_loader, model_adam, _device)
test_acc_rmsp, test_loss_rmsp, _ = test_clean(test_loader, model_rmsp, _device)
# Backward-compatible alias: `test_loss` has always ended up holding the loss of the
# last model evaluated (RMSProp).
test_loss = test_loss_rmsp

In [None]:
# Report each optimizer's clean test accuracy, one tab-separated line per model.
print("**************** accuracy ****************")
for row_fmt, row_acc in (
    ('SGD\t{:.2f}%', test_acc_sgd),
    ('ADAM\t{:.2f}%', test_acc_adam),
    ('RmsProp\t{:.2f}%', test_acc_rmsp),
):
    print(row_fmt.format(row_acc))

In [None]:
# Number of noise draws passed to test_gaussian_LF_HF.
num_noise = 1
# Seed applied before each model's evaluation so results are reproducible.
# NOTE(review): this assumes test_gaussian_LF_HF samples its noise from the global
# torch/NumPy RNGs; if so, all three models are also evaluated on the same noise draws.
noise_seed = 0


def band_acc_change(model):
    """Return the accuracy change of `model` under band-limited Gaussian perturbations.

    Calls test_gaussian_LF_HF with the notebook-level test_loader, dataset, var, r_list,
    num_noise and _device. Stores `acc[0] - acc[1]` of its accuracy output in column 0
    of a (len(r_list) + 1, 1) array, which is the layout the plotting cell reads.
    """
    torch.manual_seed(noise_seed)
    np.random.seed(noise_seed)
    _, band_acc = test_gaussian_LF_HF(test_loader, dataset, model, var, r_list, num_noise, _device)
    change = np.zeros([len(r_list) + 1, 1])
    change[:, 0] = band_acc[0] - band_acc[1]
    return change


acc_sgd = band_acc_change(model_sgd)
acc_adam = band_acc_change(model_adam)
acc_rmsp = band_acc_change(model_rmsp)

In [None]:
# Human-readable dataset name for the title and output filename ('cifar10' -> 'CIFAR10',
# 'cifar100' -> 'CIFAR100'). Upper-casing also covers any other dataset, where the old
# if/elif left print_dataset undefined.
print_dataset = dataset.upper()

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 4))

# One curve per optimizer: the mean accuracy change across the columns (repeats) of each
# array, with a shaded band of +/- one standard error of the mean. With a single
# column the band has zero width.
for acc, color, label in (
    (acc_sgd, 'C0', 'SGD'),
    (acc_adam, 'C1', 'Adam'),
    (acc_rmsp, 'C2', 'RMSProp'),
):
    x = np.arange(acc.shape[0])
    mean = acc.mean(axis=1)
    # Standard error of the mean is std / sqrt(n), not std / n (which understates it for n > 1).
    sem = acc.std(axis=1) / np.sqrt(acc.shape[1])
    ax.plot(x, mean, color=color, label=label)
    ax.fill_between(x, mean + sem, mean - sem, color=color, alpha=0.3)

ax.set_ylabel('Accuracy change under \n band-limited perturbations (%)', fontsize=13)
ax.set_xlabel('Perturbed frequency band (r)', fontsize=15)
ax.set_title('{}: Freq contribution to acc change'.format(print_dataset), fontsize=14)
ax.legend()
ax.grid()

fig.tight_layout()
fig.savefig('./notebook/{}.pdf'.format(print_dataset))