In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
import torch
from torch import functional as F
from torch import nn
from torchvision import datasets, transforms, models
from scipy.stats import norm, binom_test
from statsmodels.stats.proportion import proportion_confint
import time

  import pandas.util.testing as tm


In [23]:
class SmoothingWrapper:  # Not a Module
    """ Randomized-smoothing wrapper around a base classifier.

    Inputs are perturbed with Gaussian noise of standard deviation ``sigma``
    and the hard predictions of the wrapped model are aggregated by
    Monte Carlo voting. This is a plain Python object rather than an
    ``nn.Module``; the wrapped network is available as ``self.model``.
    """

    def __init__(self, model, sigma, batch_size=64):
        """ Stores the wrapped model and the smoothing settings
        :param model: callable mapping a batch [size x channel x height x width]
                      to scores [size x classes]
        :param sigma: standard deviation of the Gaussian noise
        :param batch_size: maximal number of noisy samples per forward pass
        """
        self.model = model
        self.sigma = sigma
        self.batch_size = batch_size

    @staticmethod
    def _binom_pvalue(successes, trials):
        """ Two-sided p-value of the exact binomial test against p = 0.5
        :param successes: number of successes (int)
        :param trials: number of trials (int)
        :return: the p-value as a float
        """
        try:
            # binom_test was deprecated in SciPy 1.10 and removed in 1.12;
            # binomtest is its replacement and is preferred when available.
            from scipy.stats import binomtest
        except ImportError:
            return binom_test(successes, trials)
        return binomtest(successes, trials).pvalue

    def predict(self, x, alpha, mc_size=None):
        """ Predicts the class of x by majority vote over noisy copies
        :param x: the input [channel x height x width]
        :param alpha: the failure probability
        :param mc_size: the number of Monte Carlo samples to use
                        (defaults to self.batch_size)
        :return: predicted class, or None when the two most frequent classes
                 cannot be told apart at significance level alpha
        """
        counts = self.get_prediction_counts(x, mc_size)
        # Tensor.sort returns a (values, indices) pair, so the sorted counts
        # have to be taken from .values before slicing off the top two.
        top1, top2 = counts.sort(descending=True).values[:2].tolist()
        if self._binom_pvalue(top1, top1 + top2) < alpha:
            return counts.argmax().item()
        return None

    def get_prediction_counts(self, x, mc_size=None):
        """ Gets prediction counts for MC
        :param x: the input [channel x height x width]
        :param mc_size: the number of Monte Carlo samples to use
                        (defaults to self.batch_size)
        :return: integer tensor [classes] with the number of noisy samples
                 assigned to each class; None if mc_size is not positive
        """
        if mc_size is None:
            mc_size = self.batch_size
        counts = None
        with torch.no_grad():
            for processed in range(0, mc_size, self.batch_size):
                # The last chunk may be smaller than batch_size.
                sz = min(self.batch_size, mc_size - processed)
                outputs = self.model(self._get_noised_inputs(x, sz))
                # minlength keeps one slot per class even for unseen classes.
                add = outputs.argmax(1).bincount(minlength=outputs.shape[1])
                if counts is None:
                    counts = add
                else:
                    counts += add
        return counts

    def _get_noised_inputs(self, x, size):
        """ Gets noised inputs
        :param x: the input [channel x height x width]
        :param size: the number of noisy copies to draw
        :return: [size x ch x heights x width] - samples from N(x, diag(self.sigma**2))
        """
        with torch.no_grad():
            x = x.expand((size, *x.shape))
            # randn_like already matches the dtype and device of x.
            return x + torch.randn_like(x) * self.sigma

    def certify(self, x, n_sel, n_est, alpha):
        """ Certify radius with probability 1 - alpha
        :param x: the input [channel x height x width]
        :param n_sel: the number of Monte Carlo samples to use for selection
        :param n_est: the number of Monte Carlo samples to use for estimation
        :param alpha: the error probability
        :return: (predicted class, certified radius) or (None, None)
        """
        counts_selection = self.get_prediction_counts(x, n_sel)
        result_class = counts_selection.argmax().item()
        # A fresh set of samples is drawn for the estimate so that it is
        # independent of the samples used to pick the class.
        count_est = self.get_prediction_counts(x, n_est)[result_class].item()
        # Passing 2 * alpha to the two-sided Clopper-Pearson interval makes
        # its lower end a one-sided bound at level alpha.
        est_lo, _ = proportion_confint(count_est, n_est, 2 * alpha, method='beta')
        # norm.ppf(est_lo) is negative below 0.5, so no radius can be
        # certified unless the lower bound reaches one half.
        if est_lo < 0.5:
            return None, None
        return result_class, self.sigma * norm.ppf(est_lo)

    def eval(self):
        """ Switches the wrapped model to evaluation mode """
        self.model.eval()

    def train(self):
        """ Switches the wrapped model to training mode """
        self.model.train()

    def get_training_output(self, x):
        """ Add noise to the sample for training and compute the model there
        :param x: the input sample [batch x channel x height x width]
        :return: output [batch x classes]
        """
        return self.model(x + torch.randn_like(x) * self.sigma)

In [4]:
class MnistModel(nn.Module):
    """ ResNet-18 adapted to single-channel inputs and ten output classes.

    The torchvision network is built with ``pretrained=True``; its first
    convolution is then replaced by a freshly initialised one taking a single
    input channel, and its final fully connected layer by a 512 -> 10 linear
    layer.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Same geometry as the stock stem (7x7, stride 2, padding 3, no bias),
        # but with one input channel.
        backbone.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        backbone.fc = nn.Linear(in_features=512, out_features=10)
        self.resnet18 = backbone

    def forward(self, x):
        """ Runs the adapted ResNet-18 on x and returns its output """
        return self.resnet18(x)

In [5]:
class PeriodicPrinter:
    """ Rate-limited wrapper around the built-in print.

    A message is printed only if more than ``secs`` seconds have elapsed
    since the last message that was actually printed.
    """

    def __init__(self, secs):
        """
        :param secs: minimal number of seconds between two printed messages
        """
        self.secs = secs
        # Start far enough in the past for the first call to get through.
        self.last_time = -secs

    def print(self, *args):
        """ Prints args if the quiet period is over
        :param args: positional arguments forwarded to print
        :return: True if the message was printed, False otherwise
        """
        now = time.time()
        due = self.last_time + self.secs < now
        if due:
            print(*args)
            self.last_time = now
        return due

In [6]:
class AverageCounter:
    """ Accumulates values and reports their arithmetic mean. """

    def __init__(self):
        self.total = 0    # number of values added since the last reset
        self.value = 0.   # sum of the values added since the last reset

    def get_average(self):
        """ Returns the mean of the accumulated values, or NaN if there are none """
        return self.value / self.total if self.total != 0 else float('nan')

    def add(self, x):
        """ Adds one value to the running sum
        :param x: the value to accumulate
        """
        self.value += x
        self.total += 1

    def zero(self, right_now=True):
        """ Resets the counter
        :param right_now: when falsy the call is a no-op, which lets callers
                          pass a condition straight through
        """
        if not right_now:
            return
        self.value = 0
        self.total = 0

In [7]:
def train(loader, wrapper, criterion, optimizer, device=torch.device('cpu')):
    """ Runs one pass over loader, updating the wrapped model on noisy inputs.

    The running average of the loss is printed at most once every five
    seconds and restarted after each printout.

    :param loader: iterable of (inputs, targets) batches
    :param wrapper: object providing train() and get_training_output(inputs)
    :param criterion: loss taking (outputs, targets)
    :param optimizer: optimizer over the parameters of the wrapped model
    :param device: device the batches are moved to
    """
    wrapper.train()
    running_loss = AverageCounter()
    progress = PeriodicPrinter(5)

    for inputs, targets in loader:
        outputs = wrapper.get_training_output(inputs.to(device))
        loss = criterion(outputs, targets.to(device))

        running_loss.add(loss.item())
        was_printed = progress.print('{:.4f}'.format(running_loss.get_average()))
        # Restart the average only when it has just been reported.
        running_loss.zero(was_printed)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        

In [8]:
# Use the GPU when one is available and fall back to the CPU otherwise,
# so the notebook also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
# MNIST training split, stored under mnist/ and downloaded if missing.
# Images are converted to tensors and normalised with mean 0.1307 and
# standard deviation 0.3081.
mnist_train_dataset = datasets.MNIST(
    root='mnist/',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]),
)

In [10]:
# MNIST test split (train=False), stored under mnist/ and downloaded if
# missing. Images are converted to tensors and normalised with mean 0.1307
# and standard deviation 0.3081.
mnist_test_dataset = datasets.MNIST(
    root='mnist/',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]),
)

In [11]:
# Batches of 256 training examples. shuffle=True reshuffles the data at the
# start of every epoch so that the batches differ from one epoch to the next.
train_loader = torch.utils.data.DataLoader(mnist_train_dataset, batch_size=256, shuffle=True)

In [37]:
# Training setup: cross-entropy loss, a smoothing wrapper (sigma = 1, at most
# 256 noisy samples per forward pass) around a fresh MnistModel on `device`,
# and Adam over the parameters of the wrapped model.
criterion = nn.CrossEntropyLoss()
wrapper = SmoothingWrapper(
    model=MnistModel().to(device),
    sigma=1,
    batch_size=256,
)
optimizer = torch.optim.Adam(wrapper.model.parameters(), lr=1e-3)

In [38]:
# Train for up to 100 epochs; `epoch` holds the 1-based number of the epoch
# currently running, so it is still meaningful if the cell is interrupted.
epoch = 0
for epoch in range(1, 101):
    print(f'Epoch {epoch}')
    train(train_loader, wrapper, criterion, optimizer, device)

Epoch 1
2.6239
0.6684
0.2256
0.1696
Epoch 2
0.1390
0.1219
0.1060
0.1015
Epoch 3
0.1189
0.0891
0.0774
0.0783
Epoch 4
0.1068
0.0661
0.0644
0.0712
Epoch 5
0.0861
0.0576
0.0538
0.0569
Epoch 6
0.0824
0.0529
0.0514
0.0509
Epoch 7
0.0700
0.0485
0.0443
0.0491
Epoch 8
0.0583
0.0438
0.0446
0.0438
Epoch 9
0.0407
0.0378
0.0404
0.0386
Epoch 10
0.0481
0.0348
0.0411
0.0372
Epoch 11
0.0200
0.0361
0.0342
0.0369
Epoch 12
0.0576
0.0341
0.0296
0.0333
Epoch 13
0.0342
0.0336
0.0340
0.0344
Epoch 14
0.0230
0.0326
0.0292
0.0368
Epoch 15
0.0290
0.0295
0.0288
0.0335
Epoch 16
0.0203
0.0304
0.0296
0.0326
Epoch 17
0.0211
0.0251
0.0255
0.0257
Epoch 18
0.0438
0.0221
0.0279
0.0273
Epoch 19
0.0182
0.0246
0.0257
0.0276
Epoch 20
0.0270
0.0266
0.0209
0.0268
Epoch 21
0.0187
0.0224
0.0306
0.0232
Epoch 22
0.0112
0.0237
0.0226
0.0242
Epoch 23
0.0097
0.0237
0.0182
0.0190
Epoch 24
0.0653
0.0204
0.0152
0.0266
Epoch 25
0.0352
0.0201
0.0210
0.0186
Epoch 26
0.0331
0.0194
0.0180


KeyboardInterrupt: ignored

In [41]:
def test(dataset, wrapper, n, alpha, radii, device):
    """ Certifies every example of dataset and reports certified accuracy.

    An example counts towards the accuracy at radius r when the certified
    class equals its label and r is strictly smaller than the certified
    radius. Intermediate accuracies are printed at most once every five
    seconds, the final ones after the last example.

    :param dataset: iterable of (input, label) pairs supporting len()
    :param wrapper: object providing eval() and certify(x, n_sel, n_est, alpha)
    :param n: number of Monte Carlo samples, used for both selection and estimation
    :param alpha: the error probability passed to certify
    :param radii: radii at which the certified accuracy is reported
    :param device: device the inputs are moved to
    :return: array with one entry per example - the certified radius when the
             certified class equals the label, NaN otherwise
    """
    wrapper.eval()
    radii = np.asarray(radii)
    num_examples = len(dataset)
    res = np.zeros(num_examples)
    accuracies = np.zeros(len(radii))
    printer = PeriodicPrinter(5)

    for i, (X, y) in enumerate(dataset):
        cl, radius = wrapper.certify(X.to(device), n, n, alpha)
        if cl == y:
            res[i] = radius
            # One more hit for every radius strictly below the certified one.
            accuracies[radii < radius] += 1
        else:
            # Abstentions and misclassifications are recorded as NaN.
            res[i] = float('nan')

        if printer.print(f'Iteration {i + 1}'):
            running_percent = accuracies / (1 + i) * 100
            for r, a in zip(radii, running_percent):
                print(f'Accuracy @ {r} = {a}%')

    final_fraction = accuracies / num_examples
    for r, a in zip(radii, final_fraction):
        print(f'Accuracy @ {r} = {a * 100}%')
    return res

In [None]:
# Certify the MNIST test split with 10000 Monte Carlo samples per example and
# error probability 0.001, reporting certified accuracy at the listed radii.
res = test(
    dataset=mnist_test_dataset,
    wrapper=wrapper,
    n=10000,
    alpha=0.001,
    radii=[0, 1., 2.5, 3., 3.5, 4., 4.5, 5., 10.],
    device=device,
)

Iteration 1
Accuracy @ 0.0 = 100.0%
Accuracy @ 1.0 = 100.0%
Accuracy @ 2.5 = 100.0%
Accuracy @ 3.0 = 100.0%
Accuracy @ 3.5 = 0.0%
Accuracy @ 4.0 = 0.0%
Accuracy @ 4.5 = 0.0%
Accuracy @ 5.0 = 0.0%
Accuracy @ 10.0 = 0.0%
Iteration 9
Accuracy @ 0.0 = 100.0%
Accuracy @ 1.0 = 100.0%
Accuracy @ 2.5 = 88.88888888888889%
Accuracy @ 3.0 = 88.88888888888889%
Accuracy @ 3.5 = 0.0%
Accuracy @ 4.0 = 0.0%
Accuracy @ 4.5 = 0.0%
Accuracy @ 5.0 = 0.0%
Accuracy @ 10.0 = 0.0%
Iteration 17
Accuracy @ 0.0 = 100.0%
Accuracy @ 1.0 = 100.0%
Accuracy @ 2.5 = 94.11764705882352%
Accuracy @ 3.0 = 94.11764705882352%
Accuracy @ 3.5 = 0.0%
Accuracy @ 4.0 = 0.0%
Accuracy @ 4.5 = 0.0%
Accuracy @ 5.0 = 0.0%
Accuracy @ 10.0 = 0.0%
Iteration 25
Accuracy @ 0.0 = 100.0%
Accuracy @ 1.0 = 96.0%
Accuracy @ 2.5 = 92.0%
Accuracy @ 3.0 = 84.0%
Accuracy @ 3.5 = 0.0%
Accuracy @ 4.0 = 0.0%
Accuracy @ 4.5 = 0.0%
Accuracy @ 5.0 = 0.0%
Accuracy @ 10.0 = 0.0%
Iteration 33
Accuracy @ 0.0 = 100.0%
Accuracy @ 1.0 = 96.96969696969697%
Accu

In [32]:
# Number of examples in the MNIST test split.
len(mnist_test_dataset)

10000