In [1]:
import torch, torchvision
import numpy as np
import matplotlib.pyplot as plt
import pyro
import tqdm
import os
import common

In [2]:
# Reproducibility: fix the seed before any model or data loader is created.
# NOTE(review): `common` is a project-local module not shown here; presumably
# set_seed seeds the torch/numpy/python RNGs -- confirm against common.py.
common.set_seed(1)

In [3]:
class NN(torch.nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        ni: number of input features.
        nh: number of hidden units.
        no: number of output units (returned as raw scores, no softmax).
    """

    def __init__(self, ni, nh, no):
        super().__init__()
        # Attribute names `A` and `B` define the state_dict keys
        # ("A.weight", "A.bias", "B.weight", "B.bias") used by the Pyro priors.
        self.A = torch.nn.Linear(in_features=ni, out_features=nh)
        self.relu = torch.nn.ReLU()
        self.B = torch.nn.Linear(in_features=nh, out_features=no)

    def forward(self, x):
        """Return the output-layer scores for input batch `x`."""
        hidden = self.A(x)
        activated = self.relu(hidden)
        return self.B(activated)

In [4]:
# Train dataset: MNIST training split, stored in (and downloaded to, if missing)
# the current directory; ToTensor converts each image to a tensor.
train_dataset = torchvision.datasets.MNIST('.', train=True, download=True,
                       transform=torchvision.transforms.ToTensor())
# Train data loader: shuffled mini-batches of 128 examples
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
# Point estimate NN: 28*28 inputs, 1024 hidden units, 10 outputs.
# `model` and `guide` below use this instance as the template whose parameter
# shapes define the priors and the variational parameters.
net = NN(28*28, 1024, 10)

In [5]:
def model(x, y):
    """Generative model: standard-Normal priors on all weights of `net`,
    followed by a categorical likelihood on the labels.

    Args:
        x: input batch that is fed to the sampled network.
        y: observed labels the "obs" site is conditioned on.
    """
    def _standard_normal(param, event_dims):
        # Zero-mean, unit-scale Normal with the same shape as `param`;
        # the trailing `event_dims` dimensions are treated as one event.
        return pyro.distributions.Normal(
            loc=torch.zeros_like(param),
            scale=torch.ones_like(param),
        ).independent(event_dims)

    # Put priors on weights and biases (keys match the parameter names of `net`)
    priors = {
        "A.weight": _standard_normal(net.A.weight, 2),
        "A.bias": _standard_normal(net.A.bias, 1),
        "B.weight": _standard_normal(net.B.weight, 2),
        "B.bias": _standard_normal(net.B.bias, 1),
    }
    # Lift `net` into a distribution over networks and draw one sample
    sample_network = pyro.random_module("module", net, priors)
    regressor = sample_network()
    # Forward pass through the sampled network, normalised to log-probabilities
    log_softmax = torch.nn.LogSoftmax(dim=1)
    lhat = log_softmax(regressor(x))
    # Condition on the observed labels
    likelihood = pyro.distributions.Categorical(logits=lhat).independent(1)
    pyro.sample("obs", likelihood, obs=y)

In [6]:
softplus = torch.nn.Softplus()
def guide(x, y):
    """Variational distribution q(w): an independent Normal per weight of `net`.

    `x` and `y` are accepted only so the signature matches `model`; they are
    not used. Returns one network sampled from the variational distribution.
    """
    def _variational_normal(mu_name, sigma_name, like, event_dims):
        # Learnable location and (softplus-transformed, hence positive) scale,
        # both initialised from a standard normal with the shape of `like`.
        loc = pyro.param(mu_name, torch.randn_like(like))
        scale = softplus(pyro.param(sigma_name, torch.randn_like(like)))
        return pyro.distributions.Normal(loc=loc, scale=scale).independent(event_dims)

    # One factor per parameter of `net`; the dict is built in the same order
    # (A.weight, A.bias, B.weight, B.bias) so parameters are registered in a
    # fixed sequence.
    priors = {
        "A.weight": _variational_normal("Aw_mu", "Aw_sigma", net.A.weight, 2),
        "A.bias": _variational_normal("Ab_mu", "Ab_sigma", net.A.bias, 1),
        "B.weight": _variational_normal("Bw_mu", "Bw_sigma", net.B.weight, 2),
        "B.bias": _variational_normal("Bb_mu", "Bb_sigma", net.B.bias, 1),
    }
    # Return NN module sampled from these random variables
    sample_network = pyro.random_module("module", net, priors)
    return sample_network()

In [7]:
# Do stochastic variational inference to find q(w) closest to p(w|D):
# optimise the parameters registered in `guide` against `model` with Adam
# (learning rate 0.01), using Trace_ELBO as the loss.
svi = pyro.infer.SVI(
    model, guide, pyro.optim.Adam({'lr': 0.01}), pyro.infer.Trace_ELBO(),
)

In [8]:
def train_and_save_models(epochs = 10, K = 100, modelname = "model.pt"):
    """Fit the guide with SVI, then save K sampled networks to `modelname`.

    Does nothing (apart from printing a notice) when `modelname` already
    exists, so an existing checkpoint is never overwritten.

    Args:
        epochs: number of passes over `train_loader`.
        K: number of networks to sample from the guide after training.
        modelname: path of the checkpoint file (a list of state dicts).
    """
    # An existing checkpoint short-circuits training entirely
    if os.path.exists(modelname):
        print("File exists")
        return
    # Train with SVI
    for epoch in range(epochs):
        epoch_loss = 0.
        for images, labels in train_loader:
            # Flatten each image to a 784-vector before the SVI step
            epoch_loss += svi.step(images.view(-1, 28*28), labels)
        # Report the loss averaged over the number of training examples
        epoch_loss /= len(train_loader.dataset)
        print("Epoch %g: Loss = %g" % (epoch, epoch_loss))
    # Sample K models from the guide and keep only their state dicts
    nn_dicts = [guide(None, None).state_dict() for _ in range(K)]
    # Save the models
    torch.save(nn_dicts, modelname)
    print("Saved %d models" % K)

In [9]:
def load_models(K = 100, modelname = "model.pt"):
    """Load K sampled networks from a checkpoint written by train_and_save_models.

    Args:
        K: number of networks to load (the first K state dicts are used).
        modelname: path of the checkpoint file; defaults to "model.pt" so
            existing calls keep working.

    Returns:
        A list of K `NN` instances with the stored weights loaded.

    Raises:
        ValueError: if the checkpoint holds fewer than K state dicts. Without
            this check the missing networks would silently keep their random
            initialisation and still be reported as loaded.
    """
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dicts = torch.load(modelname)
    if len(state_dicts) < K:
        raise ValueError(
            "Requested %d models but %s only contains %d"
            % (K, modelname, len(state_dicts))
        )
    # Load the models
    sampled_models = []
    for state_dict in state_dicts[:K]:
        member = NN(28*28, 1024, 10)
        member.load_state_dict(state_dict)
        sampled_models.append(member)
    print("Loaded %d sample models" % len(sampled_models))
    return sampled_models

In [10]:
# Train and write the checkpoint (skipped when "model.pt" already exists),
# then load the 100 sampled networks used by every experiment below.
train_and_save_models(epochs = 10, K = 100, modelname = "model.pt")
sampled_models = load_models(K = 100)

File exists
Loaded 100 sample models


In [11]:
# Test dataset: MNIST test split (train=False), same root directory and
# ToTensor transform as the training set.
test_dataset = torchvision.datasets.MNIST('.', train=False, download=True,
                       transform=torchvision.transforms.ToTensor())
# Test data loader with batch_size 1: the attack code below works on one
# image at a time (it calls targets.item()), and shuffle=True makes each
# fresh iterator start at a random test example.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)

In [12]:
# Get a batch and flatten the input
def get_images_targets_nontargets():
    """Draw one test example and list every class other than its label.

    Returns:
        images: the example flattened to shape (-1, 28*28).
        targets: the label tensor as produced by `test_loader`.
        nontargets: one single-element tensor per digit class that differs
            from the true label.
    """
    # A fresh iterator is created on every call, so each call yields the
    # first batch of a new pass over `test_loader`.
    images, targets = next(iter(test_loader))
    images = images.reshape(-1, 28*28)
    true_label = targets.item()
    nontargets = [torch.tensor([digit]) for digit in range(10) if digit != true_label]
    return images, targets, nontargets

In [13]:
def forward_pass(model, images, loss_target = None):
    """Run `model` on `images` and return the predicted class index.

    Args:
        model: callable mapping `images` to unnormalised class scores.
        images: input tensor passed to `model`.
        loss_target: optional pair `(loss_fn, target)`. When given (and
            non-empty), `loss_fn(log_probs, target)` is back-propagated so
            that gradients are accumulated on every tensor in the graph
            that requires grad.

    Returns:
        The arg-max index over the log-softmax output as a Python int.
    """
    log_probs = torch.nn.functional.log_softmax(model(images), dim=-1)
    predicted = torch.argmax(log_probs).item()
    # Prediction only: nothing to back-propagate
    if not loss_target:
        return predicted
    criterion, target = loss_target
    criterion(log_probs, target).backward()
    return predicted

In [14]:
def otcm(images, eps, saliency):
    """Take one step of size `eps` against `saliency` and clip to [0, 1].

    Returns a new tensor; `images` itself is not modified.
    """
    step = eps*saliency
    shifted = images - step
    return torch.clamp(shifted, min=0, max=1)

In [15]:
# How many models can an adversarial example fool?
def how_many_can_it_fool(sampled_models, images, eps, saliency):
    """Fraction of `sampled_models` whose prediction changes under the attack.

    Args:
        sampled_models: sequence of networks to evaluate.
        images: clean input passed to each network.
        eps: step size of the perturbation.
        saliency: perturbation direction (noise) applied via `otcm`.

    Returns:
        Number of fooled models divided by the number of models, or 0.0 when
        `sampled_models` is empty (nothing can be fooled).
    """
    if len(sampled_models) == 0:
        return 0.0
    # Only predictions are needed here, so skip autograd bookkeeping even if
    # `images` currently requires grad.
    with torch.no_grad():
        # One step Target Class Method (OTCM); saliency is noise.
        # The adversarial example does not depend on the model, so build it once.
        new_images = otcm(images, eps, saliency)
        fool = 0
        for sampled_model in sampled_models:
            # Forward pass on the clean input, then on the adv. example
            old_class = forward_pass(sampled_model, images)
            new_class = forward_pass(sampled_model, new_images)
            # If we change the class, we fool the model
            fool += int(old_class != new_class)
    return fool/len(sampled_models)

In [16]:
# Collect noises (saliencies)
def collect_saliencies(sampled_models, images, new_targets, eps):
    """Attack every sampled model and keep the perturbations that succeed.

    For each model the sign of the input gradient of the NLL loss w.r.t.
    `new_targets` is used as the perturbation direction. The direction is
    kept only if it changes that model's own prediction.

    Side effects: sets `images.requires_grad = True`, overwrites
    `images.grad`, and switches torch print options to non-scientific mode.

    Returns:
        saliencies: list of kept sign maps, each viewed as 28x28.
        how_many_fooled: for each kept map, the fraction of all sampled
            models it fools (as returned by `how_many_can_it_fool`).
    """
    saliencies, how_many_fooled = [], []
    torch.set_printoptions(sci_mode=False)
    for sampled_model in sampled_models:
        # Reset the input gradient so it reflects this model only
        images.grad = None
        images.requires_grad = True
        # Forward pass + backward on the loss w.r.t. an incorrect class.
        # Note that we just have to ensure this class is different from targets
        clean_class = forward_pass(sampled_model, images, [torch.nn.NLLLoss(), new_targets])
        saliency = images.grad.sign()
        # Forward pass on the adversarial example
        adv_class = forward_pass(sampled_model, otcm(images, eps, saliency))
        if clean_class == adv_class:
            continue
        # How many models can this adv. example fool?
        how_many_fooled.append(how_many_can_it_fool(sampled_models, images, eps, saliency))
        saliencies.append(saliency.view(28, 28))
    return saliencies, how_many_fooled

In [17]:
# distributional saliency map
def distr_saliency_map(saliencies):
    """Combine per-model saliency maps into one map via the pixel-wise median.

    Args:
        saliencies: non-empty list of equally shaped tensors (28x28 in this
            notebook, but any common shape works).

    Returns:
        A flattened float32 tensor holding, for every pixel, the median
        (50th percentile, linearly interpolated for even counts) across maps.

    Raises:
        RuntimeError: from `torch.stack` when `saliencies` is empty.
    """
    stacked = torch.stack(saliencies)
    # choose median perturbation for all pixels at once (axis 0 = models)
    medians = np.percentile(stacked.detach().cpu().numpy(), 50, axis=0)
    return torch.as_tensor(medians, dtype=torch.float32).flatten()

In [18]:
# get a random combination of hyperparameters
def get_random_hyperparam(dict_of_lists):
    """Pick one value per hyperparameter uniformly at random.

    Args:
        dict_of_lists: mapping from hyperparameter name to candidate values.

    Returns:
        A dict with the same keys, each mapped to one value drawn with
        `np.random.choice` (drawn in the mapping's iteration order).
    """
    return {arg: np.random.choice(argvals) for arg, argvals in dict_of_lists.items()}

In [19]:
# try 100 different combinations: each iteration draws a fresh test image,
# a random step size eps and a random wrong class to attack towards.
for i in range(100):
    images, targets, nontargets = get_images_targets_nontargets()
    # eps candidates are 0.00, 0.02, ..., 0.20; classidx is drawn from the
    # list of single-element tensors returned above.
    # NOTE(review): this relies on np.random.choice accepting that list of
    # tensors -- confirm it still does with the installed numpy/torch.
    hp = get_random_hyperparam({
        'eps': np.arange(0.0, 0.22, 0.02),
        'classidx': nontargets
    })
    eps, classidx = hp['eps'], hp['classidx']
    saliencies, how_many_fooled = collect_saliencies(sampled_models, images, torch.tensor([classidx]), eps)
    # No sampled model was fooled by its own perturbation: nothing to
    # aggregate, so this index is missing from the printed output.
    if saliencies == []: continue
    # Aggregate the successful perturbations into one pixel-wise median map
    newsaliency = distr_saliency_map(saliencies)
    # Median fooling rate of the individual perturbations ...
    indiv = np.percentile(how_many_fooled, 50)
    # ... versus the fooling rate of the single aggregated perturbation
    agg = how_many_can_it_fool(sampled_models, images, eps, newsaliency)
    print("%d => Eps:%.2f, NewTarget:%g, MedianOfIndivFool:%.2f, AggFool(MedianPixel):%.2f" % (
        i, eps, classidx, indiv, agg))

0 => Eps:0.10, NewTarget:9, MedianOfIndivFool:0.15, AggFool(MedianPixel):0.31
1 => Eps:0.18, NewTarget:5, MedianOfIndivFool:0.25, AggFool(MedianPixel):0.75
3 => Eps:0.02, NewTarget:8, MedianOfIndivFool:0.01, AggFool(MedianPixel):0.02
4 => Eps:0.12, NewTarget:2, MedianOfIndivFool:0.26, AggFool(MedianPixel):0.72
5 => Eps:0.08, NewTarget:5, MedianOfIndivFool:0.06, AggFool(MedianPixel):0.18
6 => Eps:0.04, NewTarget:5, MedianOfIndivFool:0.01, AggFool(MedianPixel):0.01
7 => Eps:0.20, NewTarget:2, MedianOfIndivFool:0.28, AggFool(MedianPixel):0.64
8 => Eps:0.08, NewTarget:7, MedianOfIndivFool:0.06, AggFool(MedianPixel):0.16
9 => Eps:0.14, NewTarget:1, MedianOfIndivFool:0.09, AggFool(MedianPixel):0.37
10 => Eps:0.14, NewTarget:0, MedianOfIndivFool:0.08, AggFool(MedianPixel):0.23
11 => Eps:0.12, NewTarget:8, MedianOfIndivFool:0.28, AggFool(MedianPixel):0.63
12 => Eps:0.12, NewTarget:1, MedianOfIndivFool:0.15, AggFool(MedianPixel):0.37
14 => Eps:0.16, NewTarget:9, MedianOfIndivFool:0.29, AggFool(

In [20]:
# Grid version of the experiment above: keep one test image fixed and sweep
# every step size eps in 0.00, 0.02, ..., 0.20 against every wrong class.
images, targets, _ = get_images_targets_nontargets()
for eps in np.arange(0.0, 0.22, 0.02):
    for classidx in range(10):
        # Only attack towards classes other than the true label
        if classidx != targets.item():
            saliencies, how_many_fooled = collect_saliencies(sampled_models, images, torch.tensor([classidx]), eps)
            # Skip combinations for which no model was fooled by its own
            # perturbation; they produce no output line.
            if saliencies == []: continue
            # Pixel-wise median of the successful perturbations
            newsaliency = distr_saliency_map(saliencies)
            # Median individual fooling rate vs. aggregated fooling rate
            indiv = np.percentile(how_many_fooled, 50)
            agg = how_many_can_it_fool(sampled_models, images, eps, newsaliency)
            print("Eps:%.2f, NewTarget:%g, MedianOfIndivFool:%.2f, AggFool(MedianPixel):%.2f" % (
                eps, classidx, indiv, agg))

Eps:0.02, NewTarget:0, MedianOfIndivFool:0.05, AggFool(MedianPixel):0.08
Eps:0.02, NewTarget:1, MedianOfIndivFool:0.03, AggFool(MedianPixel):0.06
Eps:0.02, NewTarget:3, MedianOfIndivFool:0.05, AggFool(MedianPixel):0.09
Eps:0.02, NewTarget:4, MedianOfIndivFool:0.04, AggFool(MedianPixel):0.06
Eps:0.02, NewTarget:5, MedianOfIndivFool:0.05, AggFool(MedianPixel):0.06
Eps:0.02, NewTarget:6, MedianOfIndivFool:0.05, AggFool(MedianPixel):0.07
Eps:0.02, NewTarget:7, MedianOfIndivFool:0.04, AggFool(MedianPixel):0.05
Eps:0.02, NewTarget:8, MedianOfIndivFool:0.05, AggFool(MedianPixel):0.10
Eps:0.02, NewTarget:9, MedianOfIndivFool:0.05, AggFool(MedianPixel):0.06
Eps:0.04, NewTarget:0, MedianOfIndivFool:0.08, AggFool(MedianPixel):0.15
Eps:0.04, NewTarget:1, MedianOfIndivFool:0.06, AggFool(MedianPixel):0.17
Eps:0.04, NewTarget:3, MedianOfIndivFool:0.08, AggFool(MedianPixel):0.17
Eps:0.04, NewTarget:4, MedianOfIndivFool:0.08, AggFool(MedianPixel):0.12
Eps:0.04, NewTarget:5, MedianOfIndivFool:0.09, AggF