In [7]:
import argparse
import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
from foolbox import PyTorchModel, accuracy, samples
from foolbox.attacks import LinfPGD, FGSM
from advertorch.attacks import LinfSPSAAttack
from trainers import Trainer, FGSMTrainer
from robustbench.model_zoo.models import Carmon2019UnlabeledNet
from utils import adversarial_accuracy, fgsm_, gradient_norm
import eagerpy as ep
from Nets import CIFAR_Wide_Res_Net, CIFAR_Res_Net
%load_ext autoreload
%autoreload 2
%aimport Nets, trainers

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# setup: device, data pipeline, and CIFAR-10 train/test loaders.
device = torch.device("cpu")
batch_size = 8
DATA_ROOT = "./data"  # where torchvision caches/downloads CIFAR-10
NUM_WORKERS = 2       # DataLoader worker processes

# No Normalize step: images are only converted with ToTensor, so inputs stay
# in the [0, 1] range produced by ToTensor.
transform = transforms.Compose([transforms.ToTensor()])

# Images of [0, 1] mapped through (v - 0.5) / 0.5, i.e. -1.0 and 1.0.
# NOTE(review): no Normalize is applied above and these are not used in the
# visible cells — confirm whether downstream code still relies on them.
normalized_min = (0 - 0.5) / 0.5
normalized_max = (1 - 0.5) / 0.5

train_dataset = datasets.CIFAR10(root=DATA_ROOT, train=True,
                                 download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                           shuffle=True, num_workers=NUM_WORKERS)
test_dataset = datasets.CIFAR10(root=DATA_ROOT, train=False,
                                download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                          shuffle=False, num_workers=NUM_WORKERS)

# Short label names; presumably indexed by the dataset's integer class id (0-9).
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


## Load regular CIFAR-10 model

In [3]:
# Regular (non-adversarially trained) CIFAR-10 ResNet, put in evaluation mode.
model = CIFAR_Res_Net(device).eval()
# map_location remaps tensors saved from another device (e.g. a GPU run) onto
# the configured `device`, so the checkpoint also loads on a CPU-only machine.
# torch.load unpickles the file: only load checkpoints from a trusted source.
model.load_state_dict(torch.load("models/cifar_res_net.model", map_location=device))

<All keys matched successfully>

## Load CIFAR-10 model trained with large FGSM steps

In [None]:
# Same ResNet architecture, loaded with the weights from the FGSM training run,
# put in evaluation mode.
fgsm_model = CIFAR_Res_Net(device).eval()
# map_location remaps tensors saved from another device (e.g. a GPU run) onto
# the configured `device`, so the checkpoint also loads on a CPU-only machine.
# torch.load unpickles the file: only load checkpoints from a trusted source.
fgsm_model.load_state_dict(torch.load("models/cifar_res_net_fgsm06.model", map_location=device))

### Check gradient norms

In [5]:
# Number of leading training images used for the gradient-norm check.
n_examples = 2000
# Fetch each (image, label) pair exactly once: indexing the dataset decodes the
# image and applies `transform`, so reading it separately for x and y doubled
# that work.
sample_images, sample_labels = zip(*(train_dataset[i] for i in range(n_examples)))
# x: the images stacked along a new leading batch dimension.
x = torch.stack(sample_images).to(device)
# y: the integer class labels as an int64 tensor.
y = torch.tensor(sample_labels, dtype=torch.long).to(device)

In [8]:
# Gradient-norm check for the regular model on the first n_examples training
# images (`gradient_norm` comes from the project's `utils` module).
# NOTE(review): the recorded run of this cell ended in KeyboardInterrupt; with
# device=cpu and n_examples=2000 it is presumably too slow to finish — consider
# a smaller n_examples or a GPU device before re-running.
gradient_norm(model, x, y, device=device)

KeyboardInterrupt: 