In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms, utils
import numpy as np
import pdb
import argparse
import time
import os
import sys
import foolbox
import wideresnet
from collections import OrderedDict
from utils import *

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

In [7]:
def remove_module_state_dict(state_dict):
    new_state_dict = OrderedDict()
    for key in state_dict.keys():
        new_key = '.'.join(key.split('.')[1:])
        new_state_dict[new_key] = state_dict[key]
    return new_state_dict

In [8]:
class train_args():
    def __init__(self, param_dict):
        # training
        self.dataset = 'cifar'
        self.batch_size = 50
        self.norm = None

        # EBM specific
        self.n_steps = 100 
        self.width = 10
        self.depth = 28
        # 
        self.n_steps_refine = 0
        self.n_classes = 10
        self.init_batch_size = 128
        self.softmax_ce = True
        # attack
        self.attack_conf = True
        self.random_init= True
        self.threshold = .7
        self.debug = True
        self.no_random_start = True
        self.load_path = None
        self.distance = 'Linf'
        self.n_steps_pgd_attack = 40
        self.start_batch = -1
        self.end_batch = 2
        self.sgld_sigma = 1e-2
        self.n_dup_chains = 5
        self.sigma = .03
        self.base_dir = './adv_results'
        # logging
        self.exp_name = 'exp'
        
        # set from inline dict
        for key in param_dict:
            #print(key, '->', param_dict[key])
            setattr(self, key, param_dict[key])

In [9]:
# setup change from defaults
inline_parms = {"load_path": "./production/ours_nb/run1/best_valid_ckpt.pt"}

# instantiate
args = train_args(inline_parms)

device = torch.device('cuda')
#args_ = vars(args)
#for key in args_.keys():
#    print('{}:   {}'.format(key,args_[key]))
base_dir = args.base_dir
save_dir = os.path.join(base_dir, args.exp_name, 'saved_model')
last_dir = os.path.join(save_dir,'last')
best_dir = os.path.join(save_dir,'best')
data_dir = os.path.join(base_dir,'data')

In [10]:
class gradient_attack_wrapper(nn.Module):
  def __init__(self, model):
    super(gradient_attack_wrapper, self).__init__()
    self.model = model.eval()

  def forward(self, x):
    x = x - 0.5
    x = x / 0.5
    x.requires_grad_()
    out = self.model.module.refined_logits(x)
    return out

  def eval(self):
    return self.model.eval()

In [11]:
model_attack_wrapper =gradient_attack_wrapper

transformer_train  = transforms.Compose([transforms.ToTensor()])
transformer_test  = transforms.Compose([transforms.ToTensor()])

data_loader  = torch.utils.data.DataLoader(datasets.CIFAR10(data_dir, train=False,
                                                            transform=transformer_test, download=True),
                                           batch_size=args.batch_size, shuffle=False, num_workers=10)
init_loader = torch.utils.data.DataLoader(datasets.CIFAR10(data_dir, train=True,
                                                           download=True, transform=transformer_train),
                                          batch_size=args.init_batch_size, shuffle=True, num_workers=1)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
class F(nn.Module):
    def __init__(self, depth=28, width=2, norm=None):
        super(F, self).__init__()
        self.f = wideresnet.Wide_ResNet(depth, width, norm=norm)
        self.energy_output = nn.Linear(self.f.last_dim, 1)
        self.class_output = nn.Linear(self.f.last_dim, 10)

    def forward(self, x, y=None):
        penult_z = self.f(x)
        return self.energy_output(penult_z).squeeze()

    def classify(self, x):
        penult_z = self.f(x)
        return self.class_output(penult_z)


class CCF(F):
    def __init__(self, depth=28, width=2, norm=None):
        super(CCF, self).__init__(depth, width, norm=norm)

    def forward(self, x, y=None):
        logits = self.classify(x)
        if y is None:
            return logits.logsumexp(1)
        else:
            return torch.gather(logits, 1, y[:, None])

In [13]:
# construct model and ship to GPU
f = CCF(args.depth, args.width, args.norm)
print(args.load_path)
print("loading model from {args.load_path}")
ckpt_dict = torch.load(args.load_path)
if "model_state_dict" in ckpt_dict:
    # loading from a new checkpoint
    f.load_state_dict(ckpt_dict["model_state_dict"])
else:
    # loading from an old checkpoint
    f.load_state_dict(ckpt_dict)

| Wide-Resnet 28x10
ckpt_48_1-23.pt
loading model from {args.load_path}


In [14]:
# wrapper class to provide utilities for what you need
class DummyModel(nn.Module):
    def __init__(self, f):
        super(DummyModel, self).__init__()
        self.f = f

    def logits(self, x):
        return self.f.classify(x)

    def refined_logits(self, x, n_steps=args.n_steps_refine):
        xs = x.size()
        dup_x = x.view(xs[0], 1, xs[1], xs[2], xs[3]).repeat(1, args.n_dup_chains, 1, 1, 1)
        dup_x = dup_x.view(xs[0] * args.n_dup_chains, xs[1], xs[2], xs[3])
        dup_x = dup_x + torch.randn_like(dup_x) * args.sigma
        refined = self.refine(dup_x, n_steps=n_steps, detach=False)
        logits = self.logits(refined)
        logits = logits.view(x.size(0), args.n_dup_chains, logits.size(1))
        logits = logits.mean(1)
        return logits

    def classify(self, x):
        logits = self.logits(x)
        pred = logits.max(1)[1]
        return pred

    def logpx_score(self, x):
        # unnormalized logprob, unconditional on class
        return self.f(x)

    def refine(self, x, n_steps=args.n_steps_refine, detach=True):
        # runs a markov chain seeded at x, use n_steps=10
        x_k = torch.autograd.Variable(x, requires_grad=True) if detach else x
        # sgld
        for k in range(n_steps):
            f_prime = torch.autograd.grad(self.f(x_k).sum(), [x_k], retain_graph=True)[0]
            x_k.data += f_prime + args.sgld_sigma * torch.randn_like(x_k)
        final_samples = x_k.detach() if detach else x_k
        return final_samples

    def grad_norm(self, x):
        x_k = torch.autograd.Variable(x, requires_grad=True)
        f_prime = torch.autograd.grad(self.f(x_k).sum(), [x_k], retain_graph=True)[0]
        grad = f_prime.view(x.size(0), -1)
        return grad.norm(p=2, dim=1)

    def logpx_delta_score(self, x, n_steps=args.n_steps_refine):
        # difference in logprobs from input x and samples from a markov chain seeded at x
        #
        init_scores = self.f(x)
        x_r = self.refine(x, n_steps=n_steps)
        final_scores = self.f(x_r)
        # for real data final_score is only slightly higher than init_score
        return init_scores - final_scores

    def logp_grad_score(self, x):
        return -self.grad_norm(x)


In [None]:
f = DummyModel(f)
model = f.to(device)
model = nn.DataParallel(model).to(device)

model.eval()
## Define criterion
criterion = foolbox.criteria.Misclassification()

## Initiate attack and wrap model
model_wrapped = model_attack_wrapper(model)
fmodel = foolbox.models.PyTorchModel(model_wrapped, bounds=(0.,1.), num_classes=10, device=device)

if args.distance == 'L2':
    distance = foolbox.distances.MeanSquaredDistance
    attack = foolbox.attacks.L2BasicIterativeAttack(model=fmodel, criterion=criterion, distance=distance)
else:
    distance = foolbox.distances.Linfinity
    attack = foolbox.attacks.RandomStartProjectedGradientDescentAttack(model=fmodel, criterion=criterion, distance=distance)

print('Starting...')
for i, (img, label) in enumerate(data_loader):
    adversaries = []
    if i < args.start_batch:
        continue
    if i >= args.end_batch:
      break
    img = img.data.cpu().numpy()
    logits = model_wrapped(torch.from_numpy(img[:, :, :, :]).to(device))
    _, top = torch.topk(logits,k=2,dim=1)
    top = top.data.cpu().numpy()
    pred = top[:,0]
    for k in range(len(label)):
      im = img[k,:,:,:]
      orig_label = label[k].data.cpu().numpy()
      if pred[k] != orig_label:
        continue
      best_adv = None
      for ii in range(20):
          try:
            adversarial = attack(im, label=orig_label, unpack=False, random_start=True, iterations=args.n_steps_pgd_attack) 
            if ii == 0 or best_adv.distance > adversarial.distance:
                best_adv = adversarial
          except:
            continue
      try:
          adversaries.append((im, orig_label, adversarial.image, adversarial.adversarial_class))
      except:
          continue
    adv_save_dir = os.path.join(base_dir, args.exp_name)
    save_file = 'adversarials_batch_'+str(i)
    if not os.path.exists(os.path.join(adv_save_dir,save_file)):
        os.makedirs(os.path.join(adv_save_dir,save_file))
    np.save(os.path.join(adv_save_dir,save_file),adversaries)

Starting...


  'The PyTorch model is in training mode and therefore might'
