In [33]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
from torch import nn
import copy
from torch.nn import functional as F
from torch.autograd import Variable
from torch.autograd.gradcheck import zero_gradients

In [8]:
# LeNet network
class LeNet(nn.Module):
    """LeNet-style CNN: two 5x5 convolutions followed by three linear layers.

    ``fc1`` expects 16*5*5 flattened features, which is what the conv/pool
    stack produces for a 1x28x28 input. The output is 10 raw class scores;
    no softmax is applied.
    """

    def __init__(self):
        super().__init__()

        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)

        # Fully connected classifier head.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the raw class scores for a batch ``x`` of shape (N, 1, H, W)."""
        # Each conv stage: convolution -> ReLU -> 2x2 max pooling.
        features = F.max_pool2d(F.relu(self.conv1(x)), 2)
        features = F.max_pool2d(F.relu(self.conv2(features)), 2)

        # Flatten everything but the batch dimension for the linear layers.
        flat = features.view(features.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
    
def my_fgsm(input, labels, model, criterion, epsilon, device, c=None):
    """Craft FGSM adversarial examples: ``input + epsilon * sign(dLoss/dInput)``.

    Args:
        input: Batch of inputs accepted by ``model``. It is not modified.
        labels: Targets passed to ``criterion`` as its second argument.
        model: ``nn.Module`` under attack.
        criterion: ``nn.Module`` loss, called as ``criterion(out, labels)`` or,
            when ``c`` is given, as ``criterion(out, labels, c)``.
        epsilon: Step size, must satisfy ``0 <= epsilon <= 1``. It is applied
            in whatever scale ``input`` is expressed in; the result is not
            clamped.
        device: Device on which the attack is computed; ``input`` and
            ``labels`` are moved there.
        c: Optional extra argument forwarded to ``criterion`` (for example a
            cost matrix).

    Returns:
        ``(adversarial, signs)``: the perturbed batch, detached from the
        autograd graph, and the sign of the loss gradient with respect to the
        input.

    Raises:
        AssertionError: If ``model`` or ``criterion`` is not an ``nn.Module``
            or ``epsilon`` lies outside ``[0, 1]``.
    """
    assert isinstance(model, torch.nn.Module), "Input parameter model is not nn.Module. Check the model"
    assert isinstance(criterion, torch.nn.Module), "Input parameter criterion is no Loss. Check the criterion"
    assert (0 <= epsilon <= 1), "episilon must be 0 <= epsilon <= 1"

    # Create the differentiable leaf only AFTER detaching and moving to the
    # device. Marking the tensor first and moving it afterwards yields a
    # non-leaf copy whose .grad is never populated.
    x = input.detach().to(device).requires_grad_(True)
    labels = labels.to(device)

    out = model(x)
    if c is None:
        loss = criterion(out, labels)
    else:
        loss = criterion(out, labels, c)

    # Ask autograd for the input gradient only, so the model's parameter
    # .grad buffers are left untouched by the attack.
    (grad,) = torch.autograd.grad(loss, x)

    signs = torch.sign(grad)
    adversarial = (x + epsilon * signs).detach()

    return adversarial, signs

# 训练LeNet

记作：model_NORMAL

In [9]:
# Toggle for input normalisation with the mean/std constants given below.
NORMALIZE = True

# ToTensor is always applied; Normalize is appended only when requested.
transform_steps = [transforms.ToTensor()]
if NORMALIZE:
    transform_steps.append(transforms.Normalize((0.1307,), (0.3081,)))
trans = transforms.Compose(transform_steps)

# Use the GPU whenever CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [10]:
# Root directory for datasets. The original location stays the default, but it
# can be overridden through the DATA_HOME environment variable so the notebook
# also runs on machines that do not have this drive.
data_home = os.environ.get('DATA_HOME', 'F:\\work')
mnist_root = os.path.join(data_home, 'dataset', 'MNIST')

# MNIST train / test splits, downloaded on first use, with the shared transform.
train_set = torchvision.datasets.MNIST(root=mnist_root, train=True, download=True, transform=trans)
test_set = torchvision.datasets.MNIST(root=mnist_root, train=False, download=True, transform=trans)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
# Accuracy does not depend on sample order, so the test loader is not shuffled.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False, num_workers=2)

In [12]:
# Instantiate the network
model_NORMAL = LeNet()
model_NORMAL = model_NORMAL.to(DEVICE)

# Hyper-parameters
learning_rate = 0.01
epochs = 10

optimizer_NORMAL = torch.optim.SGD(params=model_NORMAL.parameters(), lr=learning_rate, momentum=0.5)
criterion_NORMAL = torch.nn.functional.cross_entropy

# Training loop: one pass over the training set, then a clean-data evaluation,
# per epoch.
for epoch in range(epochs):
    model_NORMAL.train()
    count = 0
    for data, target in train_loader:
        data, target = data.to(DEVICE), target.to(DEVICE)

        optimizer_NORMAL.zero_grad()
        output = model_NORMAL(data)
        loss = criterion_NORMAL(output, target)
        loss.backward()
        optimizer_NORMAL.step()

        # In-place progress display: samples seen so far | dataset size.
        count += len(data)
        print('\r {}|{}'.format(count, len(train_loader.dataset)), end='')

    # Evaluation on clean test data. No gradients are needed here, so the
    # forward passes run under no_grad to avoid building autograd graphs.
    correct = 0
    model_NORMAL.eval()
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model_NORMAL(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    print('epoch: {}, test correct on clean data: {}'.format(epoch,correct/len(test_loader.dataset)))

 60000|60000epoch: 0, test correct on clean data: 0.9468
 60000|60000epoch: 1, test correct on clean data: 0.9703
 60000|60000epoch: 2, test correct on clean data: 0.9821
 60000|60000epoch: 3, test correct on clean data: 0.9839
 60000|60000epoch: 4, test correct on clean data: 0.985
 60000|60000epoch: 5, test correct on clean data: 0.9864
 60000|60000epoch: 6, test correct on clean data: 0.9879
 60000|60000epoch: 7, test correct on clean data: 0.9881
 60000|60000epoch: 8, test correct on clean data: 0.9877
 60000|60000epoch: 9, test correct on clean data: 0.9882


In [18]:
# FGSM attack evaluation of model_NORMAL, reported per class
epsilon = 0.3

model_NORMAL.eval()
# Loss in nn.Module form, as passed to my_fgsm.
criterion_NORMAL = nn.CrossEntropyLoss()

num_classes = 10
class_correct = torch.zeros(num_classes, dtype=torch.long)
class_count = torch.zeros(num_classes, dtype=torch.long)

# Single pass over the test set with per-class tallies, instead of re-reading
# and re-attacking the whole set once per class. The loss is a batch mean and
# the model treats samples independently, so the gradient sign of each sample
# (and hence its adversarial example) does not depend on which other samples
# share its batch, up to floating-point rounding.
for data, target in test_loader:
    data, target = data.to(DEVICE), target.to(DEVICE)
    adv_data, sign = my_fgsm(data, target, model_NORMAL, criterion_NORMAL, epsilon, DEVICE)

    # Only the prediction is needed here, so no autograd graph is built.
    with torch.no_grad():
        pred = model_NORMAL(adv_data.detach()).argmax(dim=1)
    hit = pred.eq(target)

    class_count += torch.bincount(target.cpu(), minlength=num_classes)
    class_correct += torch.bincount(target[hit].cpu(), minlength=num_classes)
    print('\r {}'.format(int(class_count.sum())), end='')
print()

correct_total = 0
for special_index in range(num_classes):
    count = int(class_count[special_index])
    correct = int(class_correct[special_index])
    correct_total += correct
    print(' {}'.format(count))
    # max(count, 1) avoids a division by zero for a class with no samples.
    print(' {} correct: {}'.format(special_index,correct/max(count, 1)))
print('avg acc on FGSM attach: {}'.format(correct_total/len(test_loader.dataset)))

 980
 0 correct: 0.8316326530612245
 1135
 1 correct: 0.8405286343612335
 1032
 2 correct: 0.6036821705426356
 1010
 3 correct: 0.8118811881188119
 982
 4 correct: 0.4226069246435845
 892
 5 correct: 0.7219730941704036
 958
 6 correct: 0.7807933194154488
 1028
 7 correct: 0.6585603112840467
 974
 8 correct: 0.5328542094455853
 1009
 9 correct: 0.6392467789890981
avg acc on FGSM attach: 0.686


# LeNet对抗训练

FGSM生成对抗样本

In [21]:
# Create the network, initialised from model_NORMAL
model_ADV = copy.deepcopy(model_NORMAL)
model_ADV = model_ADV.to(DEVICE)

epsilon = 0.3
criterion_ADV = torch.nn.functional.cross_entropy
# nn.Module form of the same loss, passed to my_fgsm.
criterion_ADV_v = nn.CrossEntropyLoss()
optimizer_ADV = torch.optim.SGD(params=model_ADV.parameters(), lr=0.01, momentum=0.5)

for epoch in range(10):
    # Set the mode explicitly, like the other training cells do. LeNet has no
    # dropout or batch-norm layers, so this does not change the results.
    model_ADV.train()
    count = 0
    for data, target in train_loader:
        data, target = data.to(DEVICE), target.to(DEVICE)

        # Replace the clean batch with FGSM examples crafted against the
        # current weights; detach so the training graph does not also track
        # gradients with respect to the input.
        data, sign = my_fgsm(data, target, model_ADV, criterion_ADV_v, epsilon, DEVICE)
        data = data.detach()

        # Reset parameter gradients only after the attack, so nothing it may
        # have accumulated leaks into this update.
        optimizer_ADV.zero_grad()
        output = model_ADV(data)
        loss = criterion_ADV(output, target)
        loss.backward()
        optimizer_ADV.step()
        count += len(data)
        print('\r {}|{}'.format(count, len(train_loader.dataset)), end='')

    # Evaluation on FGSM examples generated from the test set.
    model_ADV.eval()
    correct = 0
    for data, target in test_loader:
        data, target = data.to(DEVICE), target.to(DEVICE)

        # The attack itself needs gradients, so it stays outside no_grad.
        data, sign = my_fgsm(data, target, model_ADV, criterion_ADV_v, epsilon, DEVICE)

        with torch.no_grad():
            output = model_ADV(data.detach())
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    print('epoch:{}, test correct: {}'.format(epoch, correct/len(test_loader.dataset)))

 60000|60000epoch:0, test correct: 0.9077
 60000|60000epoch:1, test correct: 0.922
 60000|60000epoch:2, test correct: 0.9321
 60000|60000epoch:3, test correct: 0.9333
 60000|60000epoch:4, test correct: 0.9358
 60000|60000epoch:5, test correct: 0.9342
 60000|60000epoch:6, test correct: 0.9413
 60000|60000epoch:7, test correct: 0.9457
 60000|60000epoch:8, test correct: 0.9471
 60000|60000epoch:9, test correct: 0.9474


In [24]:
## Per-class accuracy of model_ADV on FGSM examples
epsilon = 0.3

model_ADV.eval()

num_classes = 10
class_correct = torch.zeros(num_classes, dtype=torch.long)
class_count = torch.zeros(num_classes, dtype=torch.long)

# Single pass over the test set with per-class tallies, instead of re-reading
# and re-attacking the whole set once per class. The loss is a batch mean and
# the model treats samples independently, so the gradient sign of each sample
# does not depend on which other samples share its batch, up to
# floating-point rounding.
for data, target in test_loader:
    data, target = data.to(DEVICE), target.to(DEVICE)
    adv_data, sign = my_fgsm(data, target, model_ADV, criterion_ADV_v, epsilon, DEVICE)

    # Only the prediction is needed here, so no autograd graph is built.
    with torch.no_grad():
        pred = model_ADV(adv_data.detach()).argmax(dim=1)
    hit = pred.eq(target)

    class_count += torch.bincount(target.cpu(), minlength=num_classes)
    class_correct += torch.bincount(target[hit].cpu(), minlength=num_classes)
    print('\r {}'.format(int(class_count.sum())), end='')
print()

correct_total = 0
for special_index in range(num_classes):
    count = int(class_count[special_index])
    correct = int(class_correct[special_index])
    correct_total += correct
    print(' {}'.format(count))
    # max(count, 1) avoids a division by zero for a class with no samples.
    print(' {} correct: {}'.format(special_index,correct/max(count, 1)))
print('avg acc on FGSM attach: {}'.format(correct_total/len(test_loader.dataset)))

 980
 0 correct: 0.9795918367346939
 1135
 1 correct: 0.9859030837004406
 1032
 2 correct: 0.9437984496124031
 1010
 3 correct: 0.9603960396039604
 982
 4 correct: 0.9348268839103869
 892
 5 correct: 0.9473094170403588
 958
 6 correct: 0.9582463465553236
 1028
 7 correct: 0.943579766536965
 974
 8 correct: 0.9158110882956879
 1009
 9 correct: 0.8999008919722498
avg acc on FGSM attach: 0.9474


# 训练CSE

只训练保护类0的模型

In [25]:
class Loss_cost_sensitive(nn.Module):
    """Cross-entropy plus a cost-weighted probability term and a conv1 weight term.

    The wrapped ``model`` is stored so that ``model.conv1.weight`` can be read
    when the loss is evaluated.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, data, target, c):
        """Return the scalar loss.

        Args:
            data: Raw class scores, one row per sample.
            target: Integer class index of each sample.
            c: Cost matrix; column ``target[i]`` supplies the per-class costs
                applied to the softmax probabilities of sample ``i``.
        """
        # Term 1: ordinary mean cross-entropy.
        ce_term = F.cross_entropy(data, target, reduction='mean')

        # Term 2: softmax probabilities weighted by each sample's cost column,
        # summed over classes and averaged over the batch.
        probs = F.softmax(data, 1)
        per_sample_costs = c[:, target].T
        cost_term = (probs * per_sample_costs).sum(1).mean()

        # Term 3: L1 norm minus L2 norm of the first conv layer's weights.
        first_conv_weight = self.model.conv1.weight
        weight_term = torch.norm(first_conv_weight, p=1) - torch.norm(first_conv_weight, p=2)

        return ce_term + cost_term + weight_term
    
def get_cost_matric(i_label, num_classes=10, protected_cost=10):
    """Build the cost matrix that protects class ``i_label``.

    Every off-diagonal entry is 1, except those in row ``i_label`` and column
    ``i_label``, which are ``protected_cost``. The diagonal is 0.

    Args:
        i_label: Index of the protected class.
        num_classes: Side length of the square matrix (default 10).
        protected_cost: Cost written into the protected row and column
            (default 10).

    Returns:
        A ``(num_classes, num_classes)`` float tensor.
    """
    C = torch.ones(num_classes, num_classes)
    C[i_label, :] = protected_cost
    C[:, i_label] = protected_cost
    # Subtracting the diagonal from itself zeroes it.
    C = C - torch.diag(C.diag())
    return C

In [28]:
# Create the network, initialised from model_NORMAL
model_CSE_0 = copy.deepcopy(model_NORMAL)
model_CSE_0 = model_CSE_0.to(DEVICE)

criterion_CSE = Loss_cost_sensitive(model_CSE_0)
optimizer_CSE = torch.optim.SGD(params=model_CSE_0.parameters(), lr=0.01, momentum=0.5)

# Cost matrix protecting class 0
C = get_cost_matric(0)
C = C.to(DEVICE)

# Training loop: one pass over the training set, then a clean-data evaluation,
# per epoch.
for epoch in range(10):
    count = 0
    model_CSE_0.train()
    for data, target in train_loader:
        data, target = data.to(DEVICE), target.to(DEVICE)
        optimizer_CSE.zero_grad()
        output = model_CSE_0(data)
        loss = criterion_CSE(output, target, C)

        loss.backward()
        optimizer_CSE.step()

        # In-place progress display: samples seen so far | dataset size.
        count += len(data)
        print('\r {}|{}'.format(count, len(train_loader.dataset)), end='')

    # Evaluation on clean test data. Switch to eval mode, matching the
    # model_NORMAL training cell, and skip autograd bookkeeping since only
    # predictions are needed.
    correct = 0
    model_CSE_0.eval()
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model_CSE_0(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    print('epoch: {}, test correct: {}'.format(epoch,correct/len(test_loader.dataset)))


 60000|60000epoch: 0, test correct: 0.977
 60000|60000epoch: 1, test correct: 0.9764
 60000|60000epoch: 2, test correct: 0.9816
 60000|60000epoch: 3, test correct: 0.981
 60000|60000epoch: 4, test correct: 0.9841
 60000|60000epoch: 5, test correct: 0.9835
 60000|60000epoch: 6, test correct: 0.9823
 60000|60000epoch: 7, test correct: 0.9834
 60000|60000epoch: 8, test correct: 0.985
 60000|60000epoch: 9, test correct: 0.9847


In [29]:
# Per-class accuracy of model_CSE_0 on FGSM examples
epsilon = 0.3

model_CSE_0.eval()

num_classes = 10
class_correct = torch.zeros(num_classes, dtype=torch.long)
class_count = torch.zeros(num_classes, dtype=torch.long)

# Single pass over the test set with per-class tallies, instead of re-reading
# and re-attacking the whole set once per class. Every input-dependent term
# of the loss is a batch mean and the model treats samples independently, so
# the gradient sign of each sample does not depend on which other samples
# share its batch, up to floating-point rounding.
for data, target in test_loader:
    data, target = data.to(DEVICE), target.to(DEVICE)
    # The cost-sensitive criterion takes the cost matrix C as extra argument.
    adv_data, sign = my_fgsm(data, target, model_CSE_0, criterion_CSE, epsilon, DEVICE, C)

    # Only the prediction is needed here, so no autograd graph is built.
    with torch.no_grad():
        pred = model_CSE_0(adv_data.detach()).argmax(dim=1)
    hit = pred.eq(target)

    class_count += torch.bincount(target.cpu(), minlength=num_classes)
    class_correct += torch.bincount(target[hit].cpu(), minlength=num_classes)
    print('\r {}'.format(int(class_count.sum())), end='')
print()

correct_total = 0
for special_index in range(num_classes):
    count = int(class_count[special_index])
    correct = int(class_correct[special_index])
    correct_total += correct
    print(' {}'.format(count))
    # max(count, 1) avoids a division by zero for a class with no samples.
    print(' {} correct: {}'.format(special_index,correct/max(count, 1)))
print('avg acc on FGSM attach: {}'.format(correct_total/len(test_loader.dataset)))

 980
 0 correct: 0.9846938775510204
 1135
 1 correct: 0.9806167400881057
 1032
 2 correct: 0.9486434108527132
 1010
 3 correct: 0.9435643564356435
 982
 4 correct: 0.9419551934826884
 892
 5 correct: 0.9159192825112108
 958
 6 correct: 0.965553235908142
 1028
 7 correct: 0.9542801556420234
 974
 8 correct: 0.8655030800821355
 1009
 9 correct: 0.8572844400396432
avg acc on FGSM attach: 0.9366


# 测试上面三个模型在deepfool下的性能

In [30]:
def deepfool(image, net, num_classes=10, overshoot=0.02, max_iter=3):

    """
       Compute a DeepFool adversarial perturbation for a single image.

       :param image: a single un-batched 3-D image tensor; the batch dimension is added internally.
       :param net: network (input: images, output: values of activation **BEFORE** softmax).
       :param num_classes: num_classes (limits the number of classes to test against, by default = 10);
                           clamped to the number of outputs the network produces.
       :param overshoot: the accumulated perturbation is scaled by (1 + overshoot) before it is applied (default = 0.02).
       :param max_iter: maximum number of iterations for deepfool (default = 3)
       :return: total perturbation (numpy array, already scaled by 1 + overshoot), number of iterations performed,
                label predicted for the original image, label predicted for the perturbed image, and the perturbed
                image tensor (with a leading batch dimension once at least one iteration has run)
    """
    # When CUDA is available both the image and the network are moved there.
    is_cuda = torch.cuda.is_available()
    if is_cuda:
        image = image.cuda()
        net = net.cuda()

    # Rank the classes by their score on the clean image; I[0] is the predicted label.
    with torch.no_grad():
        f_image = net(image[None, :, :, :]).cpu().numpy().flatten()
    I = f_image.argsort()[::-1]

    I = I[0:num_classes]
    # The network may expose fewer outputs than requested; never index past them.
    num_classes = len(I)
    label = I[0]

    input_shape = tuple(image.shape)
    pert_image = image.detach().clone()
    w = np.zeros(input_shape)
    r_tot = np.zeros(input_shape)

    loop_i = 0

    # Fresh leaf tensor so gradients with respect to the input can be requested.
    x = pert_image[None, :].detach().clone().requires_grad_(True)
    fs = net(x)
    k_i = label

    while k_i == label and loop_i < max_iter:

        pert = np.inf
        # torch.autograd.grad returns the input gradient directly: nothing
        # accumulates between classes (so no manual zeroing is required) and
        # the network's parameter gradients are left untouched.
        grad_orig = torch.autograd.grad(fs[0, int(I[0])], x, retain_graph=True)[0].cpu().numpy()

        for k in range(1, num_classes):
            cur_grad = torch.autograd.grad(fs[0, int(I[k])], x, retain_graph=True)[0].cpu().numpy()

            # set new w_k and new f_k
            w_k = cur_grad - grad_orig
            f_k = (fs[0, int(I[k])] - fs[0, int(I[0])]).detach().cpu().numpy()

            w_k_norm = np.linalg.norm(w_k.flatten())
            if w_k_norm == 0:
                # Identical gradients give no direction towards class k.
                continue
            pert_k = abs(f_k) / w_k_norm

            # keep the closest linearised decision boundary
            if pert_k < pert:
                pert = pert_k
                w = w_k

        # compute r_i and r_tot
        stalled = not np.isfinite(pert)
        if stalled:
            # No usable direction was found: apply a zero step instead of
            # producing NaNs, then stop after refreshing the outputs below.
            r_i = np.zeros_like(grad_orig)
        else:
            # Added 1e-4 for numerical stability
            r_i = (pert + 1e-4) * w / np.linalg.norm(w)
        r_tot = np.float32(r_tot + r_i)

        pert_image = image + (1 + overshoot) * torch.from_numpy(r_tot).to(image.device)

        x = pert_image.detach().clone().requires_grad_(True)
        fs = net(x)
        k_i = np.argmax(fs.detach().cpu().numpy().flatten())

        loop_i += 1
        if stalled:
            break

    r_tot = (1 + overshoot) * r_tot

    return r_tot, loop_i, label, k_i, pert_image



In [31]:
# Rebuild both loaders with batch_size=1 (this rebinds the existing
# train_loader / test_loader names), train set first, then test set.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    for dataset in (train_set, test_set)
)

## CSE

In [34]:
# Per-class accuracy of model_CSE_0 on DeepFool examples

model_CSE_0.eval()

num_classes = 10
class_correct = [0] * num_classes
class_count = [0] * num_classes
seen = 0

# Single pass over the test set with per-class tallies, instead of loading the
# whole set once per class and discarding every sample of the other classes.
for data, target in test_loader:
    data, target = data.to(DEVICE), target.to(DEVICE)
    # deepfool attacks one un-batched image at a time, so walk the batch;
    # this works for any loader batch size.
    for image, true_label in zip(data, target):
        r, loop_i, label_orig, label_pert, pert_image = deepfool(image, model_CSE_0)

        # Only the prediction is needed here, so no autograd graph is built.
        with torch.no_grad():
            pred = model_CSE_0(pert_image).argmax(dim=1)

        true_label = int(true_label)
        class_count[true_label] += 1
        class_correct[true_label] += int(pred.item() == true_label)
        seen += 1
    print('\r {}'.format(seen), end='')
print()

correct_total = 0
for special_index in range(num_classes):
    count = class_count[special_index]
    correct = class_correct[special_index]
    correct_total += correct
    print(' {}'.format(count))
    # max(count, 1) avoids a division by zero for a class with no samples.
    print(' {} correct: {}'.format(special_index,correct/max(count, 1)))
# The attack evaluated here is DeepFool, so the summary line names it.
print('avg acc on DeepFool attack: {}'.format(correct_total/len(test_loader.dataset)))

 980
 0 correct: 0.8683673469387755
 1135
 1 correct: 0.8449339207048459
 1032
 2 correct: 0.7868217054263565
 1010
 3 correct: 0.801980198019802
 982
 4 correct: 0.7260692464358453
 892
 5 correct: 0.7488789237668162
 958
 6 correct: 0.848643006263048
 1028
 7 correct: 0.7675097276264592
 974
 8 correct: 0.6765913757700205
 1009
 9 correct: 0.5153617443012884
avg acc on FGSM attach: 0.7594


## ADV

In [35]:
# Per-class accuracy of model_ADV on DeepFool examples
model_ADV.eval()

num_classes = 10
class_correct = [0] * num_classes
class_count = [0] * num_classes
seen = 0

# Single pass over the test set with per-class tallies, instead of loading the
# whole set once per class and discarding every sample of the other classes.
for data, target in test_loader:
    data, target = data.to(DEVICE), target.to(DEVICE)
    # deepfool attacks one un-batched image at a time, so walk the batch;
    # this works for any loader batch size.
    for image, true_label in zip(data, target):
        r, loop_i, label_orig, label_pert, pert_image = deepfool(image, model_ADV)

        # Only the prediction is needed here, so no autograd graph is built.
        with torch.no_grad():
            pred = model_ADV(pert_image).argmax(dim=1)

        true_label = int(true_label)
        class_count[true_label] += 1
        class_correct[true_label] += int(pred.item() == true_label)
        seen += 1
    print('\r {}'.format(seen), end='')
print()

correct_total = 0
for special_index in range(num_classes):
    count = class_count[special_index]
    correct = class_correct[special_index]
    correct_total += correct
    print(' {}'.format(count))
    # max(count, 1) avoids a division by zero for a class with no samples.
    print(' {} correct: {}'.format(special_index,correct/max(count, 1)))
# The attack evaluated here is DeepFool, so the summary line names it.
print('avg acc on DeepFool attack: {}'.format(correct_total/len(test_loader.dataset)))

 980
 0 correct: 0.2571428571428571
 1135
 1 correct: 0.3832599118942731
 1032
 2 correct: 0.24321705426356588
 1010
 3 correct: 0.3425742574257426
 982
 4 correct: 0.11507128309572301
 892
 5 correct: 0.38228699551569506
 958
 6 correct: 0.21085594989561587
 1028
 7 correct: 0.20428015564202334
 974
 8 correct: 0.16119096509240247
 1009
 9 correct: 0.09117938553022795
avg acc on FGSM attach: 0.2399


## Normal

In [36]:
# Per-class accuracy of model_NORMAL on DeepFool examples
model_NORMAL.eval()

num_classes = 10
class_correct = [0] * num_classes
class_count = [0] * num_classes
seen = 0

# Single pass over the test set with per-class tallies, instead of loading the
# whole set once per class and discarding every sample of the other classes.
for data, target in test_loader:
    data, target = data.to(DEVICE), target.to(DEVICE)
    # deepfool attacks one un-batched image at a time, so walk the batch;
    # this works for any loader batch size.
    for image, true_label in zip(data, target):
        r, loop_i, label_orig, label_pert, pert_image = deepfool(image, model_NORMAL)

        # Only the prediction is needed here, so no autograd graph is built.
        with torch.no_grad():
            pred = model_NORMAL(pert_image).argmax(dim=1)

        true_label = int(true_label)
        class_count[true_label] += 1
        class_correct[true_label] += int(pred.item() == true_label)
        seen += 1
    print('\r {}'.format(seen), end='')
print()

correct_total = 0
for special_index in range(num_classes):
    count = class_count[special_index]
    correct = class_correct[special_index]
    correct_total += correct
    print(' {}'.format(count))
    # max(count, 1) avoids a division by zero for a class with no samples.
    print(' {} correct: {}'.format(special_index,correct/max(count, 1)))
# The attack evaluated here is DeepFool, so the summary line names it.
print('avg acc on DeepFool attack: {}'.format(correct_total/len(test_loader.dataset)))

 980
 0 correct: 0.0336734693877551
 1135
 1 correct: 0.04669603524229075
 1032
 2 correct: 0.050387596899224806
 1010
 3 correct: 0.09504950495049505
 982
 4 correct: 0.023421588594704685
 892
 5 correct: 0.0795964125560538
 958
 6 correct: 0.04488517745302714
 1028
 7 correct: 0.04669260700389105
 974
 8 correct: 0.054414784394250515
 1009
 9 correct: 0.007928642220019821
avg acc on FGSM attach: 0.048
