In [1]:
import torch 
import numpy as np
import os
import torchvision.transforms as transforms
import torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from utils import my_fgsm
%matplotlib inline

In [2]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [3]:
NORMALIZE = False

CLASSES = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data_home = 'F:\\work'
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])
test_transform = transforms.Compose([transforms.ToTensor()])
train_set = torchvision.datasets.CIFAR10(root=os.path.join(data_home, 'dataset/CIFAR10'), train=True, download=True, transform=train_transform)
test_set = torchvision.datasets.CIFAR10(root=os.path.join(data_home, 'dataset/CIFAR10'), train=False, download=True, transform=test_transform)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True, num_workers=2)

In [5]:
def get_cost_matric(i_label):
    C = torch.ones(10,10)
    C[i_label,:] = 2
    C[:,i_label] = 2
    C = C - torch.diag(C.diag())
    return C

class Loss_cost_sensitive(nn.Module):
    def __init__(self,model):
        super(Loss_cost_sensitive, self).__init__()
        self.model = model
        
    def forward(self, data, target, c):
        
        l1 = F.cross_entropy(data, target, reduction='mean')
        p = F.softmax(data, 1)
        
        cost_sentive = c[:,target]
        cost_sentive = cost_sentive.T
        l2 = p.mul(cost_sentive)
        l2 = l2.sum(1).mean()
        
        conv_weight = self.model.conv1.weight
        loss_x = torch.norm(conv_weight, p=1) - torch.norm(conv_weight, p=2)
#         loss_x = -0.1*torch.norm(conv_weight, p=1) - torch.norm(conv_weight, p=2)
#         return l1
        return l1+l2 + loss_x

In [6]:
# 一般对抗训练的模型
model_ADV = LeNet()

model_ADV.load_state_dict(torch.load('../model/Lenet_CIFAR.pt'))

# 正常模型
model_normal = LeNet()
model_normal.load_state_dict(torch.load('../model/Lenet_CIFAR.pt'))

<All keys matched successfully>

In [None]:
# 记录结果
results_infos = {}

# 先读取未经过对抗训练的模型
# 在进行对抗训练

# 参数
epsilon = 0.3
if NORMALIZE:
    model_path = ''
else:
    model_path = '../model/Lenet_CIFAR.pt'

# 循环 对每一个类分别进行保护
for i_label in range(10):
    # 读取预训练模型
    model_cost_sensitive = LeNet()
#     model_cost_sensitive.load_state_dict(torch.load(model_path))
#     print('load model for initialization: {}'.format(model_path))
    model_cost_sensitive = model_cost_sensitive.to(DEVICE)
    
    criterion_cost_sensitive = Loss_cost_sensitive(model_cost_sensitive)
#     criterion_cost_sensitive = nn.CrossEntropyLoss()
    
    optimizer = torch.optim.Adam(params=model_cost_sensitive.parameters(), lr=0.001)
    schedule = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.9)
    
    C = get_cost_matric(i_label)
    C = C.to(DEVICE)
    print('protect label: {}'.format(i_label))
    print('load cost matric: ')
    print(C)
    
    LABEL = 'Protect Label ' + str(i_label)
    
    # 开始训练
    print('开始训练')
    for epoch in range(20):
        count = 0
        loss_sum = 0
        model_cost_sensitive.train()
        schedule.step(epoch)
        for data, target in train_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)

            optimizer.zero_grad()
            output = model_cost_sensitive(data)    
            loss = criterion_cost_sensitive(output, target, C)
#             loss = criterion_cost_sensitive(output,target)
            loss_sum += loss.item()
            loss.backward()
            optimizer.step()
            
            count += len(data)
            print('\r {}|{}, loss:{}'.format(count, len(train_loader.dataset), loss_sum), end='')
        
        # 测试
        correct = 0
        model_cost_sensitive.eval()
        for data, target in test_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model_cost_sensitive(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
        print('epoch: {}, test correct: {}'.format(epoch,correct/len(test_loader.dataset)))
        
#         correct = 0
#         for data, target in test_loader:
#             data, target = data.to(DEVICE), target.to(DEVICE)
#             data, sign = my_fgsm(data, target, model_cost_sensitive, criterion_cost_sensitive, epsilon, DEVICE, C)
#             output = model_cost_sensitive(data)
#             pred = output.argmax(dim=1, keepdim=True)
#             correct += pred.eq(target.view_as(pred)).sum().item()
#         print('epoch: {}, test correct on adv: {}'.format(epoch,correct/len(test_loader.dataset)))
    
#     # 保存模型
#     if NORMALIZE:
#         model_save_path = './model/LeNet_MNIST_cost_sensitive_extension_' + str(i_label) +'.pt'
#     else:
#         model_save_path = './model/LeNet_MNIST_unnormalized_cost_sensitive_extension_' + str(i_label) +'.pt'
#     torch.save(model_cost_sensitive.state_dict(), model_save_path)
    
    
    # 训练结束
#     比较结果

    ## 对 model_cost_sensitive 的评估
    print('对 model_cost_sensitive 的评估')
    results_info = {}
    model_cost_sensitive.eval()
    images_targets = {}
    for special_index in range(10):
        count = 0
        correct = 0

        for data, target in test_loader:
            data = data[target==special_index]
            target = target[target==special_index]
            if len(data) == 0:
                continue

            data, target = data.to(DEVICE), target.to(DEVICE)
            data, sign = my_fgsm(data, target, model_cost_sensitive, criterion_cost_sensitive, epsilon, DEVICE, C)
            output = model_cost_sensitive(data)

            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            count += len(data)
            print('\r {}'.format(count), end='')
        images_targets[special_index] = [count, correct/count]
        print('\n {} correct: {}'.format(special_index,correct/count))
    results_info[1] = images_targets
    
    ## 对 model_normal 的评估
    print('对 model_normal的评估')
    model_normal.eval()
    criterion_normal = nn.CrossEntropyLoss()
    for special_index in range(10):
        count = 0
        correct = 0
        for data, target in test_loader:
            data = data[target==special_index]
            target = target[target==special_index]
            if len(data) == 0:
                continue
            data, target = data.to(DEVICE), target.to(DEVICE)
            data, sign = my_fgsm(data, target, model_normal, criterion_normal, epsilon, DEVICE)
            output = model_normal(data)

            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            count += len(data)
            print('\r {}'.format(count), end='')   
        print('\n {} correct: {}'.format(special_index,correct/count))

    
    ## 对 model_adv 的评估
    print('对 model_adv的评估')
    model_ADV.eval()
    images_targets = {}
    criterion_normal = nn.CrossEntropyLoss()
    for special_index in range(10):
        count = 0
        correct = 0

        for data, target in test_loader:
            data = data[target==special_index]
            target = target[target==special_index]
            if len(data) == 0:
                continue

            data, target = data.to(DEVICE), target.to(DEVICE)
            data, sign = my_fgsm(data, target, model_ADV, criterion_normal, epsilon, DEVICE)
            output = model_ADV(data)

            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            count += len(data)
            print('\r {}'.format(count), end='')
        images_targets[special_index] = [count, correct/count]
        print('\n {} correct: {}'.format(special_index,correct/count))
    results_info[2] = images_targets
    
    # 记录结果
    results_infos[i_label] = results_info

protect label: 0
load cost matric: 
tensor([[0., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
        [2., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
        [2., 1., 0., 1., 1., 1., 1., 1., 1., 1.],
        [2., 1., 1., 0., 1., 1., 1., 1., 1., 1.],
        [2., 1., 1., 1., 0., 1., 1., 1., 1., 1.],
        [2., 1., 1., 1., 1., 0., 1., 1., 1., 1.],
        [2., 1., 1., 1., 1., 1., 0., 1., 1., 1.],
        [2., 1., 1., 1., 1., 1., 1., 0., 1., 1.],
        [2., 1., 1., 1., 1., 1., 1., 1., 0., 1.],
        [2., 1., 1., 1., 1., 1., 1., 1., 1., 0.]])
开始训练




 50000|50000, loss:3488.5374925136566epoch: 0, test correct: 0.2494
 50000|50000, loss:2358.2944946289062epoch: 1, test correct: 0.3232
 50000|50000, loss:2202.1200180053717epoch: 2, test correct: 0.3484
 50000|50000, loss:2098.3165965080267epoch: 3, test correct: 0.3799
 50000|50000, loss:2042.5490832328796epoch: 4, test correct: 0.3829
 50000|50000, loss:2006.4566664695745epoch: 5, test correct: 0.3979
 50000|50000, loss:1974.7070100307465epoch: 6, test correct: 0.4073
 50000|50000, loss:1950.2349448204046epoch: 7, test correct: 0.418
 50000|50000, loss:1925.6282587051392epoch: 8, test correct: 0.4188
 50000|50000, loss:1905.8589544296265epoch: 9, test correct: 0.4224
 50000|50000, loss:1887.3484418392181epoch: 10, test correct: 0.4314
 50000|50000, loss:1874.6392643451696epoch: 11, test correct: 0.4345
 50000|50000, loss:1859.4913933277134epoch: 12, test correct: 0.4335
 50000|50000, loss:1847.5138382911682epoch: 13, test correct: 0.4384
 50000|50000, loss:1838.1556876897812epoch: 1