In [0]:
# this cell contains all the commands necessary to run this notebook in colab
# if you cloned the repository and run this notebook locally you do not need to run these commands
!wget https://raw.githubusercontent.com/wielandbrendel/robustness_workshop/v0.0.1/01_kwta/resnet.py
!wget https://raw.githubusercontent.com/wielandbrendel/robustness_workshop/v0.0.1/01_kwta/models.py

In [0]:
# run this cell the first time you execute this notebook to download the pretrained weights
!wget https://github.com/wielandbrendel/robustness_workshop/releases/download/v0.0.1/kwta_spresnet18_0.1_cifar_adv.pth

In [0]:
# install the latest master version of Foolbox 3.0
!pip3 install git+https://github.com/bethgelab/foolbox.git

In [0]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision
import numpy as np

import resnet

### Load data

The CIFAR-10 train and test splits are downloaded to `./data` if they are not already present.

In [0]:
# ToTensor scales pixels to [0, 1]; Normalize with mean 0 and std 1 is then an identity,
# so model inputs stay in [0, 1]. Note: transforms.Normalize takes a *standard deviation*;
# `norm_var` is passed in that role (the name is kept in case later cells reference it).
norm_mean = 0
norm_var = 1
normalize = transforms.Normalize((norm_mean, norm_mean, norm_mean), (norm_var, norm_var, norm_var))

# Training-time augmentation: random 32x32 crop from a 4-pixel padded image plus horizontal flips.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# No augmentation at test time.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

cifar_train = datasets.CIFAR10("./data", train=True, download=True, transform=transform_train)
cifar_test = datasets.CIFAR10("./data", train=False, download=True, transform=transform_test)
train_loader = DataLoader(cifar_train, batch_size=256, shuffle=True)
# Fixed order so evaluation (and any batch taken from this loader) is reproducible across runs.
test_loader = DataLoader(cifar_test, batch_size=100, shuffle=False)

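### Load model

Build the k-WTA sparse ResNet-18 and load the pretrained weights downloaded at the top of the notebook.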

In [0]:
# For speed, enable a GPU on Colab: Runtime >> Change runtime type >> Hardware accelerator.
# Without a GPU the notebook falls back to the CPU, which works but is much slower.
gamma = 0.1   # used for all four `sparsities` entries and to select the checkpoint file
eps = 0.031   # L-inf budget (~8/255); not used in this cell, presumably by later attack cells
filepath = f'kwta_spresnet18_{gamma}_cifar_adv.pth'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if device.type == 'cpu':
    print('Warning: no CUDA device found, running on CPU (this will be slow).')

model = resnet.SparseResNet18(sparsities=[gamma] * 4, sparse_func='vol').to(device)
# map_location lets a checkpoint saved with GPU tensors load on whichever device is active.
model.load_state_dict(torch.load(filepath, map_location=device))
model.eval();

### Clean accuracy

Accuracy of the model on the unperturbed CIFAR-10 test set.

In [0]:
# Fraction of test images classified correctly without any adversarial perturbation.
acc = 0            # number of correctly classified images
total_number = 0   # number of images evaluated

# Inference only: no_grad skips building the autograd graph, saving memory and time.
with torch.no_grad():
    for images, labels in test_loader:
        logits = model(images.to(device))
        predictions = logits.argmax(dim=1).cpu()
        acc += (predictions == labels).sum().item()
        total_number += images.shape[0]

print(acc / total_number)