In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tensorflow.python.platform import flags
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.autograd.gradcheck import zero_gradients

from cleverhans.attacks import FastGradientMethod
from cleverhans.model import CallableModelWrapper
from cleverhans.utils import AccuracyReport
from cleverhans.utils_pytorch import convert_pytorch_model_to_tf

FLAGS = flags.FLAGS


class nnModel(nn.Module):
    """Small convolutional classifier for 1x28x28 inputs (MNIST).

    Architecture: two 5x5 convolutions with "same" padding, each followed by
    ReLU and 2x2 max pooling, then a 128-unit ReLU layer and a 10-way linear
    output. Based on
    https://github.com/rickiepark/pytorch-examples/blob/master/mnist.ipynb
    """

    def __init__(self):
        super(nnModel, self).__init__()
        # 28x28 input; padding=2 with a 5x5 kernel keeps the spatial size.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32,
                               kernel_size=5, padding=2)
        # Operates on the 14x14 maps left by the first pooling step.
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64,
                               kernel_size=5, padding=2)
        # After the second pooling step the maps are 7x7.
        self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return the raw (pre-softmax) class scores, one row per input."""
        maps = x
        for conv in (self.conv1, self.conv2):
            maps = F.max_pool2d(F.relu(conv(maps)), 2)
        flat = maps.view(-1, 64 * 7 * 7)  # one row of 3136 features per image
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

class aceModel_f(nn.Module):
    """Image-side feature network: maps 1x28x28 inputs to 128 ReLU features.

    Same convolutional trunk and first fully connected layer as the basic
    MNIST model from
    https://github.com/rickiepark/pytorch-examples/blob/master/mnist.ipynb
    but without the final 10-way output layer.
    """

    def __init__(self):
        super(aceModel_f, self).__init__()
        # 28x28 input; padding=2 with a 5x5 kernel keeps the spatial size.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32,
                               kernel_size=5, padding=2)
        # Operates on the 14x14 maps left by the first pooling step.
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64,
                               kernel_size=5, padding=2)
        # After the second pooling step the maps are 7x7.
        self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=128)

    def forward(self, x):
        """Return the 128-dimensional non-negative feature vector per input."""
        pooled1 = F.max_pool2d(F.relu(self.conv1(x)), 2)
        pooled2 = F.max_pool2d(F.relu(self.conv2(pooled1)), 2)
        flattened = pooled2.view(-1, 64 * 7 * 7)  # one row per image
        return F.relu(self.fc1(flattened))

class aceModel_g(nn.Module):
    """Label-side feature network: one linear layer plus ReLU, 10 -> 128.

    In this notebook it is fed one-hot label vectors and its output is paired
    with the 128-dimensional output of ``aceModel_f``.
    """

    def __init__(self):
        super(aceModel_g, self).__init__()
        self.fc1 = nn.Linear(in_features=10, out_features=128)

    def forward(self, y):
        """Return the 128-dimensional non-negative feature vector per row of ``y``."""
        return F.relu(self.fc1(y))

class aceModel(nn.Module):
    """Stand-alone classifier with the same layer layout as ``nnModel``.

    The notebook does not train this network directly: its conv1/conv2/fc1
    parameters are copied from a trained ``aceModel_f`` and its fc2 parameters
    are filled in from the class prior and the ``aceModel_g`` features, so the
    result can be used like an ordinary logits-producing model.
    Layer layout based on
    https://github.com/rickiepark/pytorch-examples/blob/master/mnist.ipynb
    """

    def __init__(self):
        super(aceModel, self).__init__()
        # Both convolutions use a 5x5 kernel with padding=2 ("same" padding):
        # 28x28 -> pool -> 14x14 -> pool -> 7x7.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def _conv_block(self, conv, inputs):
        """Apply ``conv`` followed by ReLU and 2x2 max pooling."""
        return F.max_pool2d(F.relu(conv(inputs)), 2)

    def forward(self, x):
        """Return the 10 per-class scores for each input image."""
        out = self._conv_block(self.conv1, x)
        out = self._conv_block(self.conv2, out)
        out = out.view(-1, 64 * 7 * 7)  # flatten to one row per image
        out = F.relu(self.fc1(out))
        return self.fc2(out)

def neg_hscore(f, g):
    """Loss used to train the (f, g) feature pair.

    Both inputs are 2-D tensors with one row per sample and the same number
    of columns. After subtracting each column's batch mean, the function
    returns

        -mean_i(<f_i, g_i>) + trace(cov(f) @ cov(g)) / 2

    where cov(.) is the sample covariance (divisor: batch size minus one).
    The result is a scalar tensor.
    """
    f_centered = f - f.mean(0)
    g_centered = g - g.mean(0)

    # Average, over the batch, of the row-wise inner product of the features.
    inner = (f_centered * g_centered).sum(1).mean()

    # Sample covariance matrices of the two feature sets.
    f_cov = f_centered.t().mm(f_centered) / (f_centered.size(0) - 1.)
    g_cov = g_centered.t().mm(g_centered) / (g_centered.size(0) - 1.)

    penalty = torch.trace(f_cov.mm(g_cov)) / 2.
    return - inner + penalty

In [2]:
# Seed the random number generators up front so that weight initialisation
# and batch shuffling are repeatable after "Restart Kernel -> Run All".
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Hyperparameters for the clean-training cell.
nb_epochs = 1
batch_size = 128
learning_rate = 0.001
# End indices used to slice the image tensors below.
# NOTE(review): with the value -1 the slice `[:-1]` drops the final image of
# each split instead of keeping all of them, and the label tensors are not
# sliced. Left unchanged here so dataset sizes stay as they were -- confirm
# whether "-1" was meant to mean "use everything".
train_end = -1
test_end = -1

# model_nn  : cross-entropy baseline classifier.
# model_f/g : image-side and label-side feature networks trained with neg_hscore.
# model_ace : classifier shell that later receives weights derived from f and g.
# Instantiation order is kept as-is because it determines which random draws
# each model's initial weights receive.
model_ace = aceModel()
model_nn = nnModel()
model_f = aceModel_f()
model_g = aceModel_g()

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size)

train_loader.dataset.train_data = train_loader.dataset.train_data[:train_end]
test_loader.dataset.test_data = test_loader.dataset.test_data[:test_end]

# One optimizer for the baseline, one shared by both feature networks.
optimizer_nn = optim.Adam(model_nn.parameters(), lr=learning_rate)
optimizer_ace = optim.Adam(list(model_f.parameters()) + list(model_g.parameters()), lr=learning_rate)
     

In [3]:
# Train the cross-entropy baseline (model_nn) and the feature pair
# (model_f, model_g) side by side on clean data, evaluate both on the test set
# after every epoch, then fold the feature pair into model_ace and save both
# classifiers.

# Inputs are placed on whatever device the models already live on. Moving
# inputs to CUDA merely because CUDA is available fails when the models
# themselves were never moved.
device = next(model_nn.parameters()).device
eye = torch.eye(10).to(device)  # one one-hot row per class, fed to model_g

for epoch in range(nb_epochs):
    train_total = 0
    # Per-class label counts; normalised into the empirical class prior below.
    py = np.zeros((1, 10))
    for xs, ys in train_loader:
        ys_1hot = torch.zeros(len(ys), 10).scatter_(1, ys.view(-1, 1), 1)
        xs, ys, ys_1hot = xs.to(device), ys.to(device), ys_1hot.to(device)

        optimizer_nn.zero_grad()
        optimizer_ace.zero_grad()

        # The two losses share no parameters, so their backward passes are
        # independent of each other.
        logits_nn = model_nn(xs)
        f = model_f(xs)
        g = model_g(ys_1hot)
        loss_ace = neg_hscore(f, g)
        loss_ace.backward()
        loss_nn = F.cross_entropy(logits_nn, ys)
        loss_nn.backward()

        optimizer_nn.step()
        optimizer_ace.step()

        py = py + torch.sum(ys_1hot, 0).cpu().numpy()
        train_total += len(xs)
    print('Epoch {} finished.'.format(epoch))
    # These are the losses of the final mini-batch only, not epoch averages.
    print("loss_nn:{},loss_ace:{}".format(loss_nn, loss_ace))
    py = py.astype(float) / train_total

    # Evaluate on clean data
    total = 0
    correct_ace = 0
    correct_nn = 0
    # Centred label-side features, one row per class.
    g_test = model_g(eye).data.cpu().numpy()
    g_test = g_test - np.mean(g_test, axis=0)
    for xs, ys in test_loader:
        # Compare predictions against plain NumPy labels so the comparison
        # does not depend on tensor/ndarray coercion or on the device.
        labels = ys.numpy()
        xs = xs.to(device)

        logits_nn = model_nn(xs).data.cpu().numpy()
        # Image-side features, centred with the mean of the current batch.
        f_test = model_f(xs).data.cpu().numpy()
        f_test = f_test - np.mean(f_test, axis=0)

        # Per-class score: prior * (1 + <f(x), g(y)>).
        pygx = py * (1 + np.matmul(f_test, g_test.T))

        correct_ace += (np.argmax(pygx, axis=1) == labels).sum()
        correct_nn += (np.argmax(logits_nn, axis=1) == labels).sum()
        total += len(xs)

    nn_acc = float(correct_nn) / total
    ace_acc = float(correct_ace) / total

    print('NN test accuracy: %.2f%%' % (nn_acc * 100))
    print('ACE test accuracy: %.2f%%' % (ace_acc * 100))

# Build model_ace: reuse conv1/conv2/fc1 from model_f and synthesise fc2 so
# that fc2(h) = py * (g_test @ h) + py for a feature vector h.
model_f_dict = model_f.state_dict()
model_ace_dict = model_ace.state_dict()
model_f_dict = {key: value for key, value in model_f_dict.items() if key in model_ace_dict}
model_f_dict['fc2.weight'] = torch.from_numpy((py * g_test.T).T).to(device)
model_f_dict['fc2.bias'] = torch.from_numpy(py).view(10).to(device)
model_ace.load_state_dict(model_f_dict)

torch.save(model_nn.state_dict(), 'model_nn_cpu_pytorch_{}epochs.pkl'.format(nb_epochs))
torch.save(model_ace.state_dict(), 'model_ace_cpu_pytorch_{}epochs.pkl'.format(nb_epochs))



Epoch 0 finished.
loss_nn:0.08271235227584839,loss_ace:-3.479642629623413
NN test accuracy: 98.24%
ACE test accuracy: 98.69%


In [17]:
# Adversarial training: for every mini-batch, craft one-step sign-gradient
# (FGSM-style) perturbations of size `eps` against each model's own loss, then
# take one optimizer step on the perturbed batch. Afterwards model_ace is
# rebuilt from the updated model_f / model_g.
print("Adversarial Training Start.")
eps=0.2
epochs=1
# Inputs follow the device the models live on.
device = next(model_nn.parameters()).device
for epoch in range(epochs):
    for xs, ys in train_loader:
        ys_1hot = torch.zeros(len(ys), 10).scatter_(1, ys.view(-1, 1), 1)
        xs, ys, ys_1hot = xs.to(device), ys.to(device), ys_1hot.to(device)

        # Fresh leaf copies of the batch so the input gradients can be read.
        # Being new tensors, they start without a gradient, so nothing needs
        # to be zeroed first.
        x_nn = Variable(xs, requires_grad=True)
        x_ace = Variable(xs, requires_grad=True)

        # Adversarial batch for the baseline: step along the sign of the
        # cross-entropy gradient, then clip back into the valid pixel range.
        _loss = F.cross_entropy(model_nn(x_nn), ys)
        _loss.backward()
        nn_adv = torch.clamp(x_nn.data + eps * torch.sign(x_nn.grad.data), 0.0, 1.0)

        # Adversarial batch for the feature pair, using neg_hscore as the loss.
        _loss_ace = neg_hscore(model_f(x_ace), model_g(ys_1hot))
        _loss_ace.backward()
        ace_adv = torch.clamp(x_ace.data + eps * torch.sign(x_ace.grad.data), 0.0, 1.0)

        # The crafting passes above also deposited gradients on the model
        # parameters; zero_grad() discards them before the real update.
        optimizer_nn.zero_grad()
        logits_nn = model_nn(nn_adv)
        loss_nn = F.cross_entropy(logits_nn, ys)
        loss_nn.backward()
        optimizer_nn.step()

        optimizer_ace.zero_grad()
        f = model_f(ace_adv)
        g = model_g(ys_1hot)
        loss_ace = neg_hscore(f, g)
        loss_ace.backward()
        optimizer_ace.step()
    print("Epoch {} finished.".format(epoch+1))

# Rebuild model_ace from the adversarially trained feature pair.
# `py` is the class prior computed by the clean-training cell and is reused
# here unchanged, so that cell must have been run first.
g_test = model_g(torch.eye(10).to(device)).data.cpu().numpy()
g_test = g_test - np.mean(g_test, axis = 0)
model_f_dict = model_f.state_dict()
model_ace_dict = model_ace.state_dict()
model_f_dict = {key: value for key, value in model_f_dict.items() if key in model_ace_dict}
model_f_dict['fc2.weight'] = torch.from_numpy((py * g_test.T).T).to(device)
model_f_dict['fc2.bias'] = torch.from_numpy(py).view(10).to(device)
model_ace.load_state_dict(model_f_dict)
print("Adversarial Training Finish.")

Adversarial Training Start.




Epoch 1 finished.
Adversarial Training Finish.


In [18]:
# Compare robustness to FGSM of the in-memory models (model_nn, model_ace)
# against the clean-trained checkpoints saved by the clean-training cell
# (suffix "_ori").

# The checkpoint names are built from nb_epochs, exactly as they were when
# saved, so changing nb_epochs cannot silently load a stale or missing file.
model_ace_ori = aceModel()
model_ace_ori.load_state_dict(torch.load('model_ace_cpu_pytorch_{}epochs.pkl'.format(nb_epochs)))
model_nn_ori = nnModel()
model_nn_ori.load_state_dict(torch.load('model_nn_cpu_pytorch_{}epochs.pkl'.format(nb_epochs)))

# The session is only needed for this evaluation; the context manager
# releases it when the cell finishes.
with tf.Session() as sess:
    x_op = tf.placeholder(tf.float32, shape=(None, 1, 28, 28,))

    # Convert pytorch model to a tf_model and wrap it in cleverhans
    tf_model_nn = convert_pytorch_model_to_tf(model_nn)
    tf_model_ace = convert_pytorch_model_to_tf(model_ace)
    tf_model_nn_ori = convert_pytorch_model_to_tf(model_nn_ori)
    tf_model_ace_ori = convert_pytorch_model_to_tf(model_ace_ori)
    cleverhans_model_nn = CallableModelWrapper(tf_model_nn, output_layer='logits')
    cleverhans_model_ace = CallableModelWrapper(tf_model_ace, output_layer='logits')
    cleverhans_model_nn_ori = CallableModelWrapper(tf_model_nn_ori, output_layer='logits')
    cleverhans_model_ace_ori = CallableModelWrapper(tf_model_ace_ori, output_layer='logits')

    # (label, TF-callable model, cleverhans wrapper), in the order the
    # results are printed.
    eval_targets = [
        ('nn', tf_model_nn, cleverhans_model_nn),
        ('nn_ori', tf_model_nn_ori, cleverhans_model_nn_ori),
        ('ace', tf_model_ace, cleverhans_model_ace),
        ('ace_ori', tf_model_ace_ori, cleverhans_model_ace_ori),
    ]

    for eps in np.arange(0.05,0.45,0.05):

        fgsm_params = {'eps': eps,
                       'clip_min': 0.,
                       'clip_max': 1.}

        # For each model: attack it with FGSM, then classify the adversarial
        # images with that same model.
        adv_pred_ops = []
        for _, tf_model, cleverhans_model in eval_targets:
            fgsm = FastGradientMethod(cleverhans_model, sess=sess)
            adv_x = fgsm.generate(x_op, **fgsm_params)
            adv_pred_ops.append(tf_model(adv_x))

        # Run an evaluation of our models against fgsm
        total = 0
        correct = [0] * len(eval_targets)
        for xs, ys in test_loader:
            # Plain NumPy arrays for both the feed and the label comparison.
            images, labels = xs.numpy(), ys.numpy()
            # One run per model, in sequence, as in the original evaluation.
            for idx, adv_pred in enumerate(adv_pred_ops):
                adv_preds = sess.run(adv_pred, feed_dict={x_op: images})
                correct[idx] += (np.argmax(adv_preds, axis=1) == labels).sum()
            total += len(xs)

        print('eps:{}'.format(eps))
        for (label, _, _), n_correct in zip(eval_targets, correct):
            print('{} Adv accuracy: {:.3f}'.format(label, float(n_correct) / total * 100))

eps:0.05
nn Adv accuracy: 97.680
nn_ori Adv accuracy: 94.589
ace Adv accuracy: 97.350
ace_ori Adv accuracy: 96.080
eps:0.1
nn Adv accuracy: 96.120
nn_ori Adv accuracy: 80.658
ace Adv accuracy: 97.540
ace_ori Adv accuracy: 87.719
eps:0.15000000000000002
nn Adv accuracy: 93.379
nn_ori Adv accuracy: 53.185
ace Adv accuracy: 97.090
ace_ori Adv accuracy: 69.417
eps:0.2
nn Adv accuracy: 89.439
nn_ori Adv accuracy: 25.463
ace Adv accuracy: 96.150
ace_ori Adv accuracy: 37.774
eps:0.25
nn Adv accuracy: 82.528
nn_ori Adv accuracy: 8.781
ace Adv accuracy: 94.599
ace_ori Adv accuracy: 9.201
eps:0.3
nn Adv accuracy: 72.307
nn_ori Adv accuracy: 3.130
ace Adv accuracy: 92.399
ace_ori Adv accuracy: 3.990
eps:0.35000000000000003
nn Adv accuracy: 56.886
nn_ori Adv accuracy: 1.600
ace Adv accuracy: 89.449
ace_ori Adv accuracy: 3.150
eps:0.4
nn Adv accuracy: 38.574
nn_ori Adv accuracy: 1.130
ace Adv accuracy: 85.359
ace_ori Adv accuracy: 2.980


In [19]:
# Write the current weights of model_nn and model_ace to the "model_adv_*"
# checkpoint files.
for _model, _path in (
        (model_nn, 'model_adv_nn_cpu_pytorch_1epochs.pkl'),
        (model_ace, 'model_adv_ace_cpu_pytorch_1epochs.pkl')):
    torch.save(_model.state_dict(), _path)

In [20]:
# Same FGSM comparison as the previous evaluation cell, but with all four
# models restored from disk instead of taken from kernel memory.

# "model_adv_*" names match the literal names used by the save cell above.
model_ace = aceModel()
model_ace.load_state_dict(torch.load('model_adv_ace_cpu_pytorch_1epochs.pkl'))
model_nn = nnModel()
model_nn.load_state_dict(torch.load('model_adv_nn_cpu_pytorch_1epochs.pkl'))
# The clean-trained checkpoints were saved with nb_epochs in their name, so
# build the names the same way here.
model_ace_ori = aceModel()
model_ace_ori.load_state_dict(torch.load('model_ace_cpu_pytorch_{}epochs.pkl'.format(nb_epochs)))
model_nn_ori = nnModel()
model_nn_ori.load_state_dict(torch.load('model_nn_cpu_pytorch_{}epochs.pkl'.format(nb_epochs)))

# The session is only needed for this evaluation; the context manager
# releases it when the cell finishes.
with tf.Session() as sess:
    x_op = tf.placeholder(tf.float32, shape=(None, 1, 28, 28,))

    # Convert pytorch model to a tf_model and wrap it in cleverhans
    tf_model_nn = convert_pytorch_model_to_tf(model_nn)
    tf_model_ace = convert_pytorch_model_to_tf(model_ace)
    tf_model_nn_ori = convert_pytorch_model_to_tf(model_nn_ori)
    tf_model_ace_ori = convert_pytorch_model_to_tf(model_ace_ori)
    cleverhans_model_nn = CallableModelWrapper(tf_model_nn, output_layer='logits')
    cleverhans_model_ace = CallableModelWrapper(tf_model_ace, output_layer='logits')
    cleverhans_model_nn_ori = CallableModelWrapper(tf_model_nn_ori, output_layer='logits')
    cleverhans_model_ace_ori = CallableModelWrapper(tf_model_ace_ori, output_layer='logits')

    # (label, TF-callable model, cleverhans wrapper), in the order the
    # results are printed.
    eval_targets = [
        ('nn', tf_model_nn, cleverhans_model_nn),
        ('nn_ori', tf_model_nn_ori, cleverhans_model_nn_ori),
        ('ace', tf_model_ace, cleverhans_model_ace),
        ('ace_ori', tf_model_ace_ori, cleverhans_model_ace_ori),
    ]

    for eps in np.arange(0.05,0.45,0.05):

        fgsm_params = {'eps': eps,
                       'clip_min': 0.,
                       'clip_max': 1.}

        # For each model: attack it with FGSM, then classify the adversarial
        # images with that same model.
        adv_pred_ops = []
        for _, tf_model, cleverhans_model in eval_targets:
            fgsm = FastGradientMethod(cleverhans_model, sess=sess)
            adv_x = fgsm.generate(x_op, **fgsm_params)
            adv_pred_ops.append(tf_model(adv_x))

        # Run an evaluation of our models against fgsm
        total = 0
        correct = [0] * len(eval_targets)
        for xs, ys in test_loader:
            # Plain NumPy arrays for both the feed and the label comparison.
            images, labels = xs.numpy(), ys.numpy()
            # One run per model, in sequence, as in the original evaluation.
            for idx, adv_pred in enumerate(adv_pred_ops):
                adv_preds = sess.run(adv_pred, feed_dict={x_op: images})
                correct[idx] += (np.argmax(adv_preds, axis=1) == labels).sum()
            total += len(xs)

        print('eps:{}'.format(eps))
        for (label, _, _), n_correct in zip(eval_targets, correct):
            print('{} Adv accuracy: {:.3f}'.format(label, float(n_correct) / total * 100))

eps:0.05
nn Adv accuracy: 97.680
nn_ori Adv accuracy: 94.589
ace Adv accuracy: 97.350
ace_ori Adv accuracy: 96.080
eps:0.1
nn Adv accuracy: 96.120
nn_ori Adv accuracy: 80.658
ace Adv accuracy: 97.540
ace_ori Adv accuracy: 87.719
eps:0.15000000000000002
nn Adv accuracy: 93.379
nn_ori Adv accuracy: 53.185
ace Adv accuracy: 97.090
ace_ori Adv accuracy: 69.417
eps:0.2
nn Adv accuracy: 89.439
nn_ori Adv accuracy: 25.463
ace Adv accuracy: 96.150
ace_ori Adv accuracy: 37.774
eps:0.25
nn Adv accuracy: 82.528
nn_ori Adv accuracy: 8.781
ace Adv accuracy: 94.599
ace_ori Adv accuracy: 9.201
eps:0.3
nn Adv accuracy: 72.307
nn_ori Adv accuracy: 3.130
ace Adv accuracy: 92.399
ace_ori Adv accuracy: 3.990
eps:0.35000000000000003
nn Adv accuracy: 56.886
nn_ori Adv accuracy: 1.600
ace Adv accuracy: 89.449
ace_ori Adv accuracy: 3.150
eps:0.4
nn Adv accuracy: 38.574
nn_ori Adv accuracy: 1.130
ace Adv accuracy: 85.359
ace_ori Adv accuracy: 2.980
