In [26]:
#调试开关
import logging

#添加系统路径
import sys
sys.path.append("../AdvBox/")

import torch
import torchvision
import torch.nn as nn
from torchvision import datasets, transforms
import torch.utils.data.dataloader as Data
from advbox.adversary import Adversary
from advbox.attacks.gradient_method import FGSM
from advbox.models.pytorch import PytorchModel
import matplotlib.pyplot as plt
import numpy as np

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [27]:
# CNN classifier used for MNIST.
class Net(nn.Module):
    """Two convolution/pooling stages followed by two fully connected layers.

    The shape comments below assume a (batch, 1, 28, 28) input, which is what
    the 32*7*7 input width of ``fc1`` requires.
    """

    def __init__(self):
        super().__init__()
        # With stride=1 and padding=(kernel_size-1)/2 a convolution keeps the
        # height and width unchanged; each 2x2 max-pool then halves them.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),   # (1,28,28) -> (16,28,28)
            nn.ReLU(),
            nn.MaxPool2d(2),                                        # -> (16,14,14)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),  # -> (32,14,14)
            nn.ReLU(),
            nn.MaxPool2d(2),                                        # -> (32,7,7)
        )
        self.fc1 = nn.Sequential(nn.Linear(32 * 7 * 7, 500), nn.ReLU())
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Return the 10 raw output scores (no softmax is applied) per sample."""
        features = self.conv2(self.conv1(x))   # (batch, 32, 7, 7)
        flat = features.flatten(start_dim=1)   # (batch, 32*7*7)
        return self.fc2(self.fc1(flat))

In [28]:
class AutoEncoder(nn.Module):
    """Convolutional autoencoder for single-channel images.

    The encoder applies two 3x3 convolutions, each followed by ReLU and a 2x2
    max-pool, so height and width shrink by a factor of 4. The decoder mirrors
    this with two nearest-neighbour x2 upsampling steps, each followed by a 3x3
    convolution. The last convolution has no activation, so the output values
    are unbounded.
    """

    def __init__(self):
        super().__init__()
        # Every convolution uses a 3x3 kernel with stride 1 and padding 1,
        # which leaves the spatial size unchanged.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, **conv_kwargs),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, **conv_kwargs),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(128, 64, **conv_kwargs),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(64, 1, **conv_kwargs),
        )

    def forward(self, x):
        """Encode ``x`` and decode it again, returning the reconstruction."""
        return self.decoder(self.encoder(x))

In [29]:
# Reconstruction objective: mean squared error between the autoencoder output
# and the target image.
loss_func = nn.MSELoss()

# Autoencoder placed on the selected device, optimised with Adam (learning rate 0.01).
autoencoder = AutoEncoder()
autoencoder = autoencoder.to(device)
optimizer = torch.optim.Adam(params=autoencoder.parameters(), lr=0.01)

In [30]:
# MNIST training split, read from a local copy (download=False, so the files
# must already be present). ToTensor is the only transform applied.
# NOTE(review): the dataset root is an absolute, machine-specific path.
train_data = datasets.MNIST(
    '/home/shenchenkai/data/mnist_pytorch/',
    train=True,
    download=False,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Shuffled mini-batches of 128 samples.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)

In [31]:
# Train the autoencoder to reconstruct clean images from noise-corrupted copies.
# NOTE(review): `loss_func` is rebound to CrossEntropyLoss further down in this
# notebook; make sure it still refers to nn.MSELoss when this cell is run.
NUM_EPOCHS = 10   # number of passes over the training set
NOISE_STD = 0.1   # scale of the Gaussian noise added to the inputs
LOG_EVERY = 100   # print the loss every LOG_EVERY batches

for epoch in range(NUM_EPOCHS):
    # Labels are not needed for reconstruction, so they are not moved to the device.
    for i, (inputs, _) in enumerate(train_loader):
        inputs = inputs.to(device)

        # Add Gaussian noise, sampled directly on the inputs' device (no CPU
        # allocation plus copy per batch), then clamp back into [0, 1].
        inputs_noise = inputs + NOISE_STD * torch.randn_like(inputs)
        inputs_noise = torch.clamp(inputs_noise, 0.0, 1.0)

        optimizer.zero_grad()
        output = autoencoder(inputs_noise)
        # The target is the clean image, not the noisy one.
        loss = loss_func(output, inputs)
        loss.backward()
        optimizer.step()

        # Batch 0 is skipped, so the first report of each epoch is at batch LOG_EVERY.
        if (i % LOG_EVERY == 0) and (i > 0):
            print("Epoch={} batch={} loss={}".format(epoch + 1, i, loss.item()))

Epoch=1 batch=100 loss=0.03581857308745384
Epoch=1 batch=200 loss=0.0293037798255682
Epoch=1 batch=300 loss=0.029111456125974655
Epoch=1 batch=400 loss=0.028390636667609215
Epoch=2 batch=100 loss=0.027548573911190033
Epoch=2 batch=200 loss=0.02530878223478794
Epoch=2 batch=300 loss=0.025680825114250183
Epoch=2 batch=400 loss=0.024080397561192513
Epoch=3 batch=100 loss=0.026186615228652954
Epoch=3 batch=200 loss=0.023570651188492775
Epoch=3 batch=300 loss=0.02365405671298504
Epoch=3 batch=400 loss=0.02442971058189869
Epoch=4 batch=100 loss=0.025155559182167053
Epoch=4 batch=200 loss=0.023913631215691566
Epoch=4 batch=300 loss=0.02372409589588642
Epoch=4 batch=400 loss=0.02442202717065811
Epoch=5 batch=100 loss=0.024600552394986153
Epoch=5 batch=200 loss=0.024714380502700806
Epoch=5 batch=300 loss=0.023196186870336533
Epoch=5 batch=400 loss=0.023859839886426926
Epoch=6 batch=100 loss=0.023595696315169334
Epoch=6 batch=200 loss=0.023206466808915138
Epoch=6 batch=300 loss=0.023420490324497

## 验证自编码器去噪对模型识别的影响

In [19]:
# Path of the file from which the classifier weights are loaded.
pretrained_model = "model/mnist_model_dict.pth"
# Number of test samples to process before the evaluation loop stops.
TOTAL_NUM = 1000

In [20]:
# MNIST test split, served one image at a time in random order; the evaluation
# loop stops after TOTAL_NUM samples, which amounts to a random subset.
# ToTensor is the only transform applied. The original note says the data is
# already normalised - presumably this refers to ToTensor's scaling; confirm.
# NOTE(review): the dataset root is an absolute, machine-specific path.
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='/home/shenchenkai/data/mnist_pytorch/',
        train=False,
        download=False,
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    batch_size=1,
    shuffle=True,
)

In [21]:
# Log whether CUDA is available and select the matching device.
logging.info("CUDA Available: {}".format(torch.cuda.is_available()))
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the classifier on that device and restore the pretrained weights.
# The checkpoint is first mapped to the CPU; load_state_dict then copies the
# values into the model's parameters.
model = Net().to(device)
model.load_state_dict(torch.load(pretrained_model, map_location='cpu'))
# Switch to evaluation mode. Net as defined in this notebook has no Dropout
# layers, so this does not change its outputs here.
model.eval()

# Counters for the evaluation loop:
#   total_count   - samples processed so far
#   pre_count     - samples classified correctly before denoising
#   decoded_count - samples classified correctly after denoising
total_count = pre_count = decoded_count = 0

In [22]:
# Compare the classifier's accuracy on clean test images before and after the
# images pass through the autoencoder.
# The counters are reset here so that re-running this cell starts from zero
# instead of accumulating onto a previous run.
total_count = 0
pre_count = 0       # correct predictions on the untouched image
decoded_count = 0   # correct predictions on the autoencoder output

# Pure inference: no_grad skips autograd bookkeeping for both networks.
with torch.no_grad():
    for i, data in enumerate(test_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # The loader uses batch_size=1, so the batch holds one label; extract it
        # as a plain int so comparisons do not mix tensors and numpy scalars.
        true_label = labels[0].item()
        total_count += 1

        # Prediction before denoising
        pre_label = int(np.argmax(model(inputs).cpu().numpy()))
        if pre_label == true_label:
            pre_count += 1

        # Prediction after denoising with the autoencoder
        output = autoencoder(inputs)
        output = output.view(1, 1, 28, 28)
        decoded_label = int(np.argmax(model(output).cpu().numpy()))
        if decoded_label == true_label:
            decoded_count += 1

        if total_count >= TOTAL_NUM:
            break

# Report after the loop so the summary is printed even when the loader holds
# fewer than TOTAL_NUM samples.
if total_count > 0:
    print(
        "[TEST_DATASET]: pre_count=%d, total_count=%d, pre_count_rate=%f  decoded_count=%d decoded_count_rate=%f"
        % (pre_count, total_count, float(pre_count) / total_count,
           decoded_count, float(decoded_count) / total_count))

[TEST_DATASET]: pre_count=980, total_count=1000, pre_count_rate=0.980000  decoded_count=956 decoded_count_rate=0.956000


## 使用自编码器过滤噪音

In [23]:
# Loss passed to the advbox model wrapper for the attack.
# NOTE(review): this rebinds `loss_func`, which held nn.MSELoss for the
# autoencoder training loop earlier in the notebook.
loss_func = torch.nn.CrossEntropyLoss()
# Path of the file from which the classifier weights are loaded.
pretrained_model = "model/mnist_model_dict.pth"
# Number of test samples to attack before the loop stops.
TOTAL_NUM = 1000

In [24]:
# MNIST test split, served one image at a time in random order; the attack
# loop stops after TOTAL_NUM samples, which amounts to a random subset.
# ToTensor is the only transform applied.
# NOTE(review): the dataset root is an absolute, machine-specific path.
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='/home/shenchenkai/data/mnist_pytorch/',
        train=False,
        download=False,
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    batch_size=1,
    shuffle=True,
)

# Log whether CUDA is available and select the matching device.
logging.info("CUDA Available: {}".format(torch.cuda.is_available()))
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the classifier on that device and restore the pretrained weights.
model = Net().to(device)
model.load_state_dict(torch.load(pretrained_model, map_location='cpu'))
# Switch to evaluation mode. Net as defined in this notebook has no Dropout
# layers, so this does not change its outputs here.
model.eval()

# Wrap the classifier for advbox. (0, 1) is presumably the valid pixel range
# and channel_axis=1 the channel dimension of the input - confirm against advbox.
m = PytorchModel(model, loss_func, (0, 1), channel_axis=1)
# FGSM attack instance operating on the wrapped model.
attack = FGSM(m)
# NOTE(review): the original comment described a step of 0.1, but the value
# configured here is 0.01 - confirm which one is intended.
attack_config = dict(epsilons=0.01)

# Counters for the attack loop:
#   total_count           - samples processed so far
#   fooling_count         - successful attacks before denoising
#   decoded_fooling_count - attacks that still succeed after denoising
#   decoded_count         - meant to count clean samples recognised after
#                           denoising; the attack loop never updates it
total_count = fooling_count = decoded_fooling_count = decoded_count = 0

cuda


In [25]:
# Attack each test image with FGSM, then check whether passing the adversarial
# example through the autoencoder restores the correct prediction.
# The counters are reset here so that re-running this cell starts from zero
# instead of accumulating onto a previous run.
total_count = 0
fooling_count = 0           # attacks that succeeded
decoded_fooling_count = 0   # successful attacks still misclassified after denoising

for i, data in enumerate(test_loader):
    inputs, labels = data
    inputs, labels = inputs.numpy(), labels.numpy()
    # `labels` is a length-1 array (the loader uses batch_size=1). Use its
    # scalar element everywhere: formatting an array with %d relies on an
    # implicit array-to-scalar conversion that NumPy has deprecated.
    orig_label = labels[0]

    total_count += 1
    adversary = Adversary(inputs, orig_label)

    # FGSM non-targeted attack
    adversary = attack(adversary, **attack_config)
    if adversary.is_successful():
        fooling_count += 1
        print(
            'attack success, original_label=%d, adversarial_label=%d, count=%d'
            % (orig_label, adversary.adversarial_label, total_count))

        # The adversarial example is stored in adversary.adversarial_example;
        # copy it into a float tensor on the classifier's device.
        adversary_image = torch.from_numpy(
            np.copy(adversary.adversarial_example)).to(device).float()

        # The attack is finished at this point; what follows is pure
        # inference, so autograd bookkeeping is not needed.
        with torch.no_grad():
            pre_label = np.argmax(model(adversary_image).cpu().numpy())

            # Denoise with the autoencoder, then classify again
            output = autoencoder(adversary_image)
            output = output.view(1, 1, 28, 28)
            decoded_label = np.argmax(model(output).cpu().numpy())

        # The attack survives denoising if the label is still wrong
        if decoded_label != orig_label:
            print("orig_label={} adv_label={} decoded_label={}".format(
                orig_label, pre_label, decoded_label))
            decoded_fooling_count += 1

    else:
        print('attack failed, original_label=%d, count=%d' %
              (orig_label, total_count))

    if total_count >= TOTAL_NUM:
        break

# Report after the loop so the summary is printed even when the loader holds
# fewer than TOTAL_NUM samples. Both rates are relative to total_count.
if total_count > 0:
    print(
        "[TEST_DATASET]: fooling_count=%d, total_count=%d, fooling_rate=%f decoded_fooling_count=%d  decoded_fooling_count_rate=%f"
        % (fooling_count, total_count,
           float(fooling_count) / total_count, decoded_fooling_count,
           float(decoded_fooling_count) / total_count))
print("fgsm attack done")

attack success, original_label=7, adversarial_label=9, count=1
attack success, original_label=2, adversarial_label=3, count=2
attack success, original_label=5, adversarial_label=6, count=3
attack success, original_label=9, adversarial_label=4, count=4
attack success, original_label=9, adversarial_label=4, count=5
attack success, original_label=2, adversarial_label=8, count=6
attack success, original_label=6, adversarial_label=4, count=7
attack success, original_label=2, adversarial_label=7, count=8
orig_label=2 adv_label=7 decoded_label=7
attack success, original_label=0, adversarial_label=6, count=9
attack success, original_label=1, adversarial_label=3, count=10
attack success, original_label=4, adversarial_label=9, count=11
attack success, original_label=2, adversarial_label=3, count=12
attack success, original_label=9, adversarial_label=7, count=13
attack success, original_label=2, adversarial_label=1, count=14
attack success, original_label=3, adversarial_label=6, count=15
attack s

attack success, original_label=0, adversarial_label=2, count=128
attack success, original_label=1, adversarial_label=8, count=129
attack success, original_label=7, adversarial_label=3, count=130
attack success, original_label=8, adversarial_label=6, count=131
attack success, original_label=9, adversarial_label=4, count=132
attack success, original_label=9, adversarial_label=7, count=133
attack success, original_label=8, adversarial_label=2, count=134
attack success, original_label=0, adversarial_label=9, count=135
attack success, original_label=7, adversarial_label=9, count=136
attack success, original_label=9, adversarial_label=8, count=137
attack success, original_label=4, adversarial_label=0, count=138
attack success, original_label=7, adversarial_label=9, count=139
attack success, original_label=6, adversarial_label=8, count=140
attack success, original_label=4, adversarial_label=9, count=141
attack success, original_label=7, adversarial_label=9, count=142
attack success, original_

attack success, original_label=0, adversarial_label=9, count=252
attack success, original_label=2, adversarial_label=3, count=253
attack success, original_label=6, adversarial_label=8, count=254
attack success, original_label=3, adversarial_label=9, count=255
attack success, original_label=2, adversarial_label=8, count=256
attack success, original_label=4, adversarial_label=9, count=257
attack success, original_label=5, adversarial_label=6, count=258
attack success, original_label=0, adversarial_label=6, count=259
attack success, original_label=6, adversarial_label=0, count=260
attack success, original_label=6, adversarial_label=4, count=261
attack success, original_label=8, adversarial_label=6, count=262
attack success, original_label=2, adversarial_label=1, count=263
attack success, original_label=0, adversarial_label=9, count=264
attack success, original_label=9, adversarial_label=8, count=265
attack success, original_label=1, adversarial_label=8, count=266
attack success, original_

attack success, original_label=0, adversarial_label=9, count=375
attack success, original_label=7, adversarial_label=9, count=376
attack success, original_label=0, adversarial_label=9, count=377
attack success, original_label=5, adversarial_label=8, count=378
attack success, original_label=1, adversarial_label=8, count=379
attack success, original_label=3, adversarial_label=9, count=380
attack success, original_label=9, adversarial_label=8, count=381
attack success, original_label=9, adversarial_label=4, count=382
attack success, original_label=6, adversarial_label=0, count=383
attack success, original_label=3, adversarial_label=5, count=384
attack success, original_label=0, adversarial_label=6, count=385
attack success, original_label=5, adversarial_label=9, count=386
attack success, original_label=4, adversarial_label=8, count=387
orig_label=4 adv_label=8 decoded_label=8
attack success, original_label=4, adversarial_label=8, count=388
attack success, original_label=8, adversarial_lab

attack success, original_label=9, adversarial_label=4, count=503
attack success, original_label=9, adversarial_label=8, count=504
attack success, original_label=1, adversarial_label=8, count=505
attack success, original_label=4, adversarial_label=9, count=506
attack success, original_label=4, adversarial_label=9, count=507
attack success, original_label=2, adversarial_label=1, count=508
attack success, original_label=6, adversarial_label=4, count=509
attack success, original_label=8, adversarial_label=6, count=510
attack success, original_label=4, adversarial_label=9, count=511
attack success, original_label=3, adversarial_label=9, count=512
attack success, original_label=5, adversarial_label=9, count=513
attack success, original_label=8, adversarial_label=6, count=514
attack success, original_label=2, adversarial_label=3, count=515
attack success, original_label=7, adversarial_label=9, count=516
attack success, original_label=0, adversarial_label=9, count=517
attack success, original_

attack success, original_label=5, adversarial_label=3, count=626
orig_label=5 adv_label=3 decoded_label=6
attack success, original_label=3, adversarial_label=8, count=627
attack success, original_label=5, adversarial_label=9, count=628
attack success, original_label=8, adversarial_label=5, count=629
orig_label=8 adv_label=5 decoded_label=5
attack success, original_label=1, adversarial_label=8, count=630
attack success, original_label=3, adversarial_label=8, count=631
attack success, original_label=2, adversarial_label=8, count=632
attack success, original_label=2, adversarial_label=3, count=633
attack success, original_label=3, adversarial_label=9, count=634
attack success, original_label=8, adversarial_label=5, count=635
attack success, original_label=5, adversarial_label=3, count=636
attack success, original_label=9, adversarial_label=8, count=637
attack success, original_label=5, adversarial_label=3, count=638
attack success, original_label=9, adversarial_label=4, count=639
attack s

attack success, original_label=8, adversarial_label=4, count=747
attack success, original_label=2, adversarial_label=5, count=748
attack success, original_label=3, adversarial_label=9, count=749
orig_label=3 adv_label=9 decoded_label=8
attack success, original_label=4, adversarial_label=8, count=750
attack success, original_label=3, adversarial_label=5, count=751
attack success, original_label=2, adversarial_label=3, count=752
attack success, original_label=8, adversarial_label=3, count=753
attack success, original_label=7, adversarial_label=8, count=754
attack success, original_label=8, adversarial_label=5, count=755
attack success, original_label=8, adversarial_label=3, count=756
attack success, original_label=4, adversarial_label=0, count=757
attack success, original_label=5, adversarial_label=9, count=758
attack success, original_label=8, adversarial_label=3, count=759
attack success, original_label=5, adversarial_label=8, count=760
attack success, original_label=8, adversarial_lab

attack success, original_label=1, adversarial_label=8, count=873
attack success, original_label=3, adversarial_label=5, count=874
attack success, original_label=7, adversarial_label=9, count=875
attack success, original_label=9, adversarial_label=8, count=876
attack success, original_label=9, adversarial_label=5, count=877
orig_label=9 adv_label=5 decoded_label=5
attack success, original_label=8, adversarial_label=5, count=878
attack success, original_label=8, adversarial_label=0, count=879
orig_label=8 adv_label=0 decoded_label=0
attack success, original_label=8, adversarial_label=4, count=880
attack success, original_label=5, adversarial_label=8, count=881
attack success, original_label=2, adversarial_label=7, count=882
attack success, original_label=7, adversarial_label=3, count=883
attack success, original_label=9, adversarial_label=8, count=884
attack success, original_label=3, adversarial_label=5, count=885
attack success, original_label=7, adversarial_label=9, count=886
attack s

attack success, original_label=7, adversarial_label=9, count=996
attack success, original_label=9, adversarial_label=8, count=997
attack success, original_label=6, adversarial_label=4, count=998
attack success, original_label=6, adversarial_label=8, count=999
attack success, original_label=5, adversarial_label=8, count=1000
[TEST_DATASET]: fooling_count=1000, total_count=1000, fooling_rate=1.000000 decoded_fooling_count=59  decoded_fooling_count_rate=0.059000
fgsm attack done
