# 在CIFAR数据集上测试CSE、CSA

In [1]:
import torch 
import numpy as np
import os
import torchvision.transforms as transforms
import torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import copy
from torch.autograd.gradcheck import zero_gradients
from utils import *


In [2]:
class LeNet(nn.Module):
    """LeNet-5 style CNN for 3x32x32 images producing 10 class logits.

    Layout: conv(3->6, 5x5) -> ReLU -> maxpool(2) -> conv(6->16, 5x5) -> ReLU
    -> maxpool(2) -> fc(400->120) -> ReLU -> fc(120->84) -> ReLU -> fc(84->10).
    Attribute names are part of the state_dict keys used by saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        # 32x32 input -> 28 (conv1) -> 14 (pool) -> 10 (conv2) -> 5 (pool): 16*5*5 features.
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return unnormalized class scores of shape (batch, 10) for input batch `x`."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# FGSM 攻击

In [3]:
NORMALIZE = True

CLASSES = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# NOTE(review): cuda_num is not defined in this notebook; presumably it comes
# from `from utils import *` — confirm. It is only evaluated when CUDA exists.
if torch.cuda.is_available():
    DEVICE = torch.device(cuda_num)
else:
    DEVICE = torch.device("cpu")
batch_size = 64

data_home = 'F:\\work'


def _make_transform():
    """Build a fresh CIFAR-10 preprocessing pipeline: ToTensor, plus Normalize when NORMALIZE is set."""
    steps = [transforms.ToTensor()]
    if NORMALIZE:
        # Per-channel mean / std constants used for CIFAR-10 inputs.
        steps.append(transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)))
    return transforms.Compose(steps)


# Train and test get separate (identical) pipelines.
train_transform = _make_transform()
test_transform = _make_transform()

_cifar_root = os.path.join(data_home, 'dataset/CIFAR10')
train_set = torchvision.datasets.CIFAR10(root=_cifar_root, train=True, download=True, transform=train_transform)
test_set = torchvision.datasets.CIFAR10(root=_cifar_root, train=False, download=True, transform=test_transform)

_loader_kwargs = dict(batch_size=batch_size, shuffle=True, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_set, **_loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_set, **_loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
def my_fgsm(input, labels, model, criterion, epsilon, device, c=None):
    """Craft FGSM adversarial examples: x_adv = x + epsilon * sign(d loss / d x).

    Args:
        input: input tensor to perturb (in this notebook, a batch of CIFAR images).
            It is not modified.
        labels: target labels for `input`; moved to `device`.
        model: the classifier (`nn.Module`) under attack. Its parameter `.grad`
            buffers are left untouched.
        criterion: loss module. It is called as `criterion(out, labels)`, or as
            `criterion(out, labels, c)` when `c` is given.
        epsilon: perturbation size, must satisfy 0 <= epsilon <= 1.
        device: device on which the attack is computed.
        c: optional extra argument forwarded to `criterion` (the notebook passes
            a cost matrix here).

    Returns:
        (adv_input, signs): the perturbed input (detached, on `device`) and the
        sign of the input gradient.

    Raises:
        AssertionError: if `model`/`criterion` are not modules or epsilon is out of range.
    """
    assert isinstance(model, torch.nn.Module), "Input parameter model is not nn.Module. Check the model"
    assert isinstance(criterion, torch.nn.Module), "Input parameter criterion is no Loss. Check the criterion"
    assert (0 <= epsilon <= 1), "episilon must be 0 <= epsilon <= 1"

    # Move to `device` *before* enabling gradients so the tensor we differentiate
    # against is a leaf. `.to(device)` after requires_grad=True would yield a
    # non-leaf copy whose gradient is never stored.
    input_for_gradient = input.detach().clone().to(device).requires_grad_(True)
    labels = labels.to(device)

    out = model(input_for_gradient)
    if c is None:
        loss = criterion(out, labels)
    else:
        loss = criterion(out, labels, c)

    # Differentiate w.r.t. the input only. Unlike loss.backward(), this does not
    # accumulate gradients into the model parameters on every attack call.
    grad, = torch.autograd.grad(loss, input_for_gradient)
    signs = torch.sign(grad)

    adv_input = input_for_gradient.detach() + epsilon * signs
    return adv_input, signs

In [None]:
# Collected results: results_infos[i_label] = {1: <per-class stats of model_CSA>,
#                                              2: <per-class stats of model_CSE>},
# where per-class stats map class index -> [n_test_samples, accuracy under FGSM].
results_infos = {}

# Parameters
epsilon_model = 0.3   # epsilon encoded in the CSA checkpoint file name (presumably the adversarial-training epsilon)
epsilon_attack = 0.3  # FGSM perturbation size used for this evaluation
criterion_CSA = Loss_cost_sensitive()


def _load_lenet(path):
    """Create a LeNet, load the state dict stored at `path`, and move it to DEVICE."""
    model = LeNet()
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(path, map_location=DEVICE))
    print('load model for initialization: {}'.format(path))
    return model.to(DEVICE)


def _evaluate_fgsm_per_class(model, criterion, epsilon, cost_matrix, num_classes=10):
    """Evaluate `model` on test_loader under an FGSM attack, separately for every class.

    For each class index, only the test samples of that class are attacked with
    my_fgsm (forwarding `cost_matrix` to the criterion) and then classified.

    Returns:
        dict mapping class index -> [number of samples, accuracy on the attacked
        samples]. The accuracy is NaN if the class has no samples.
    """
    model.eval()
    images_targets = {}
    for special_index in range(num_classes):
        count = 0
        correct = 0

        for data, target in test_loader:
            mask = target == special_index
            data, target = data[mask], target[mask]
            if len(data) == 0:
                continue

            data, target = data.to(DEVICE), target.to(DEVICE)
            data, _ = my_fgsm(data, target, model, criterion, epsilon, DEVICE, cost_matrix)
            # The attack needs gradients; the final prediction does not.
            with torch.no_grad():
                output = model(data)

            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            count += len(data)
            print('\r {}'.format(count), end='')
        accuracy = correct / count if count else float('nan')
        images_targets[special_index] = [count, accuracy]
        print('\n {} correct: {}'.format(special_index, accuracy))
    return images_targets


# Protect each class in turn: evaluate the CSA and CSE models built for that class.
for i_label in range(10):
    path_model_CSA = '../model/LeNet_CIFAR_adv_cost_sensitive_' + str(i_label) + '_e' + str(epsilon_model) + '.pt'
    # TODO: the CSE checkpoint path was left empty in the original notebook; set it here.
    # It is checked up front so the run fails before spending time on the CSA evaluation.
    path_model_CSE = ''
    if not path_model_CSE:
        raise ValueError('path_model_CSE is empty: set the CSE checkpoint path for protected label {}'.format(i_label))

    model_CSA = _load_lenet(path_model_CSA)

    C = get_cost_matric(i_label).to(DEVICE)
    print('protect label: {}'.format(i_label))
    print('load cost matric: ')
    print(C)
    LABEL = 'Protect Label ' + str(i_label)

    results_info = {}

    # Evaluation of model_CSA
    print('对 model_CSA 的评估')
    results_info[1] = _evaluate_fgsm_per_class(model_CSA, criterion_CSA, epsilon_attack, C)

    # Evaluation of model_CSE
    print('对 model_CSE 的评估')
    model_CSE = _load_lenet(path_model_CSE)
    criterion_CSE = Loss_CSE(model_CSE)
    results_info[2] = _evaluate_fgsm_per_class(model_CSE, criterion_CSE, epsilon_attack, C)

    results_infos[i_label] = results_info

## 保存

In [None]:
import pandas as pd

output_dirs = './output/CIFAR'
os.makedirs(output_dirs, exist_ok=True)

# Accuracy of each method on the protected class, one entry per protected label.
I_avg = {'CSA': [], 'CSE': []}
# The context manager writes and closes the workbook; ExcelWriter.save() was
# removed in pandas 2.0.
with pd.ExcelWriter(os.path.join(output_dirs, 'CIFAR_FGSM.xlsx')) as writer:
    for i_label, result_info in results_infos.items():
        # result_info[1] / result_info[2]: {class: [count, accuracy]} for CSA / CSE.
        I_avg['CSA'].append(result_info[1][i_label][1])
        I_avg['CSE'].append(result_info[2][i_label][1])

        # One row per method (CSA, CSE), one column per class, cells = accuracy.
        df = pd.DataFrame({
            'CSA': {cls: stats[1] for cls, stats in result_info[1].items()},
            'CSE': {cls: stats[1] for cls, stats in result_info[2].items()},
        }).T

        df.to_excel(writer, sheet_name=str(i_label))

print('I of CSA: {}'.format(np.mean(I_avg['CSA'])))
print('I of CSE: {}'.format(np.mean(I_avg['CSE'])))