# winddy

用 CIFAR-10 训练集训练 ResNet18:先进行正常训练,再用 FGSM 对抗样本进行对抗训练。

# 正常训练

In [1]:
import numpy as np
import torch
import os
import torchvision.transforms as transforms 
import torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from LeNet import LeNet
import sys
sys.path.append('./model')
from resnet import ResNet18


%matplotlib inline

In [2]:
# Experiment switches: NORMALIZE selects the preprocessing pipeline and the
# saved-model file name; RESUME restores ./checkpoint/ckpt.pth before training.
NORMALIZE, RESUME = True, False

# Expose only GPUs 0, 2 and 7 to this process, then pick the primary device.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,2,7'
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')
cudnn.benchmark = True

In [3]:
# Preprocessing pipelines. The training pipeline adds random-crop and
# horizontal-flip augmentation; when NORMALIZE is set, both pipelines finish
# with the same per-channel Normalize step.
_train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
]
_test_steps = [transforms.ToTensor()]
if NORMALIZE:
    for _steps in (_train_steps, _test_steps):
        _steps.append(
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)))
trans_train = transforms.Compose(_train_steps)
trans_test = transforms.Compose(_test_steps)

In [4]:
# Root folder of the local data store; CIFAR-10 is kept (and downloaded if
# missing) under <data_home>/dataset/CIFAR10.
data_home = '/data/winddy/'
_cifar_root = os.path.join(data_home, 'dataset/CIFAR10')

train_set = torchvision.datasets.CIFAR10(
    root=_cifar_root, train=True, download=True, transform=trans_train)
test_set = torchvision.datasets.CIFAR10(
    root=_cifar_root, train=False, download=True, transform=trans_test)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Mini-batch iterators. Only the training data is reshuffled each epoch; the
# test set is read in a fixed order so that evaluation runs are repeatable.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)

# Display names for the ten class indices.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [16]:
# Build the network, move it to the primary device and wrap it in DataParallel.
net = ResNet18()
net = net.to(DEVICE)
net = torch.nn.DataParallel(net)

if RESUME:
    # Restore the weights and bookkeeping that test() wrote to ./checkpoint/ckpt.pth.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    # map_location remaps the stored tensors onto the current device, so the
    # checkpoint still loads when the GPU it was saved from is not visible now.
    checkpoint = torch.load('./checkpoint/ckpt.pth', map_location=DEVICE)
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

# Loss and optimiser used by train() and test().
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

In [17]:
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``net``, ``optimizer``, ``criterion`` and
    ``DEVICE``. After every batch it rewrites a single progress line showing
    the mean batch loss so far and the running accuracy in percent.

    Args:
        epoch: Epoch number; only used for the header line that is printed.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    running_loss, n_correct, n_seen = 0, 0, 0
    for step, (images, labels) in enumerate(train_loader):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        logits = net(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        # max(1) returns (values, indices); the indices are the predicted classes.
        predictions = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += predictions.eq(labels).sum().item()

        status = '\r batch_idx: {} | Loss: {} | Acc: {} '.format(
            step, running_loss/(step+1), 100.*n_correct/n_seen)
        print(status, end='')
        
def test(epoch):
    """Evaluate ``net`` on ``test_loader`` and checkpoint it on improvement.

    Runs without gradient tracking and rewrites a single progress line with
    the mean batch loss and running accuracy. If the final accuracy is
    strictly higher than the global ``best_acc``, the weights, accuracy and
    epoch are written to ``./checkpoint/ckpt.pth`` and ``best_acc`` is updated.

    Args:
        epoch: Epoch number stored in the checkpoint.
    """
    global best_acc
    net.eval()
    running_loss, n_correct, n_seen = 0, 0, 0
    with torch.no_grad():
        for step, (images, labels) in enumerate(test_loader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            logits = net(images)
            batch_loss = criterion(logits, labels)

            running_loss += batch_loss.item()
            predictions = logits.max(1)[1]
            n_seen += labels.size(0)
            n_correct += predictions.eq(labels).sum().item()

            status = '\r batch_idx: {} | Loss: {} | Acc: {} '.format(
                step, running_loss/(step+1), 100.*n_correct/n_seen)
            print(status, end='')

    accuracy = 100.*n_correct/n_seen
    if not accuracy > best_acc:
        # No improvement over the best result so far: keep the old checkpoint.
        return

    print('Saving..')
    snapshot = dict(net=net.state_dict(), acc=accuracy, epoch=epoch)
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    torch.save(snapshot, './checkpoint/ckpt.pth')
    best_acc = accuracy

In [18]:
# Start from scratch unless a checkpoint was restored. With RESUME set, the
# model-construction cell has already loaded best_acc and start_epoch from the
# checkpoint, and resetting them here would throw that progress away.
if not RESUME:
    best_acc = 0  # best test accuracy
    start_epoch = 0

# 200 epochs of training, each followed by an evaluation/checkpoint pass.
for epoch in range(start_epoch, start_epoch+200):
    train(epoch)
    test(epoch)


Epoch: 0
 batch_idx: 78 | Loss: 1.3250955675221696 | Acc: 56.83 4551282051 4  Saving..

Epoch: 1
 batch_idx: 78 | Loss: 0.9298833668986454 | Acc: 67.66 2532051282    Saving..

Epoch: 2
 batch_idx: 78 | Loss: 0.7305333297463912 | Acc: 75.27 44871794872  Saving..

Epoch: 3
 batch_idx: 78 | Loss: 0.6912991412078278 | Acc: 77.11 4935897436   Saving..

Epoch: 4
 batch_idx: 78 | Loss: 0.7177353218386445 | Acc: 77.05 23717948718  
Epoch: 5
 batch_idx: 78 | Loss: 0.6310713445838494 | Acc: 79.27 84294871794   Saving..

Epoch: 6
 batch_idx: 78 | Loss: 0.6136832286285449 | Acc: 79.28 79487179488   Saving..

Epoch: 7
 batch_idx: 78 | Loss: 0.4970438027683693 | Acc: 83.99 440705128206  Saving..

Epoch: 8
 batch_idx: 78 | Loss: 0.5072396141064318 | Acc: 83.77 2243589743    
Epoch: 9
 batch_idx: 78 | Loss: 0.4806847044184238 | Acc: 84.26 79166666667   Saving..

Epoch: 10
 batch_idx: 78 | Loss: 0.43260920575902434 | Acc: 85.68 673076923 1  Saving..

Epoch: 11
 batch_idx: 78 | Loss: 0.5101283098323436

 batch_idx: 78 | Loss: 0.34959426663721666 | Acc: 90.63 98397435898   
Epoch: 100
 batch_idx: 78 | Loss: 0.2864048461182208 | Acc: 91.88 698717948718   
Epoch: 101
 batch_idx: 78 | Loss: 0.31656663817695424 | Acc: 91.24 96153846153   
Epoch: 102
 batch_idx: 78 | Loss: 0.27775239152244374 | Acc: 92.36 5641025641    Saving..

Epoch: 103
 batch_idx: 78 | Loss: 0.3557837994038304 | Acc: 90.36 45673076923    
Epoch: 104
 batch_idx: 78 | Loss: 0.3111067242637465 | Acc: 91.55 650641025641 2 
Epoch: 105
 batch_idx: 78 | Loss: 0.290947226595275 | Acc: 91.73 167467948718    
Epoch: 106
 batch_idx: 78 | Loss: 0.3522158863046501 | Acc: 90.64 50641025641    
Epoch: 107
 batch_idx: 78 | Loss: 0.3101107897826388 | Acc: 91.21 591346153847 6 
Epoch: 108
 batch_idx: 78 | Loss: 0.30303065922064115 | Acc: 91.64 1858974359    
Epoch: 109
 batch_idx: 78 | Loss: 0.3098806233345708 | Acc: 91.69 69871794872  2 
Epoch: 110
 batch_idx: 78 | Loss: 0.32385309144288676 | Acc: 91.2 89743589743    
Epoch: 111
 batch_

 batch_idx: 78 | Loss: 0.32112033985838107 | Acc: 91.72 77884615384   
Epoch: 199
 batch_idx: 78 | Loss: 0.302431238084277 | Acc: 91.71 9671474358974   

In [19]:
# Save the current weights of `net` under ./model; the file name records
# whether the inputs were normalised.
if not os.path.exists('./model'):
    os.makedirs('./model')
model_path = ('./model/ResNet18_CIFAR10.pt' if NORMALIZE
              else './model/ResNet18_CIFAR10_unNormalize.pt')
torch.save(net.state_dict(), model_path)

# 对抗训练

In [1]:
import numpy as np
import torch
import os
import torchvision.transforms as transforms 
import torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from LeNet import LeNet
import sys
sys.path.append('./model')
from resnet import ResNet18
sys.path.append('./utils')
from myUtils import my_fgsm, my_imshow
%matplotlib inline

In [2]:
# Experiment switches for the adversarial-training session.
NORMALIZE = True
RESUME = True

# Expose GPUs 1, 3, 6 and 7 to this process; the second visible one is the
# primary device when CUDA is available.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(['1', '3', '6', '7'])
_device_name = 'cuda:1' if torch.cuda.is_available() else 'cpu'
DEVICE = torch.device(_device_name)
cudnn.benchmark = True

In [3]:
# Preprocessing pipelines. The optional Normalize step is appended to both the
# augmented training pipeline and the plain test pipeline when NORMALIZE is set.
_normalize = (
    [transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]
    if NORMALIZE else []
)
trans_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    *_normalize,
])
trans_test = transforms.Compose([
    transforms.ToTensor(),
    *_normalize,
])

In [4]:
# Root folder of the local data store; both splits share the same CIFAR-10
# directory and are downloaded on demand.
data_home = '/data/winddy/'
_cifar_args = dict(root=os.path.join(data_home, 'dataset/CIFAR10'), download=True)

train_set = torchvision.datasets.CIFAR10(train=True, transform=trans_train, **_cifar_args)
test_set = torchvision.datasets.CIFAR10(train=False, transform=trans_test, **_cifar_args)


Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Mini-batch iterators. Only the training data is reshuffled each epoch; the
# test set is read in a fixed order so that clean and adversarial evaluation
# runs are repeatable.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=16)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False, num_workers=16)

# Display names for the ten class indices.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [6]:
# Build the network and initialise it from the normally trained weights that
# match the current preprocessing setting.
if NORMALIZE:
    model_path = './model/ResNet18_CIFAR10.pt'
else:
    model_path = './model/ResNet18_CIFAR10_unNormalize.pt'

model_adv = ResNet18()
model_adv = model_adv.to(DEVICE)
model_adv = torch.nn.DataParallel(model_adv, device_ids=[1,2,3])
print('load model: {}'.format(model_path))
# map_location remaps the stored tensors onto the current primary device, so
# the file loads regardless of which GPU index it was saved from.
model_adv.load_state_dict(torch.load(model_path, map_location=DEVICE))

# Loss and optimiser shared by the evaluation and adversarial-training cells.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model_adv.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

load model: ./model/ResNet18_CIFAR10.pt


In [7]:
# Clean-data test: loss and accuracy of the loaded model on unmodified test images.
print('test... ...')

model_adv.eval()
test_loss, correct, total = 0, 0, 0
with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(test_loader):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        logits = model_adv(images)

        test_loss += criterion(logits, labels).item()
        total += labels.size(0)
        # max(1) returns (values, indices); the indices are the predicted classes.
        correct += logits.max(1)[1].eq(labels).sum().item()

        progress = '\r batch_idx: {} | Loss: {} | Acc: {} '.format(
            batch_idx, test_loss/(batch_idx+1), 100.*correct/total)
        print(progress, end='')


test... ...
 batch_idx: 78 | Loss: 0.3051270984018905 | Acc: 91.71 73076923077  

In [10]:
# Adversarial test: loss and accuracy on test images perturbed by my_fgsm.
print('adversarial test... ...')

epsilon = 0.3  # passed to my_fgsm; presumably the perturbation size - confirm in ./utils/myUtils.py
model_adv.eval()
test_loss = 0
correct = 0
total = 0

for batch_idx, (inputs, targets) in enumerate(test_loader):
    inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
    # The attack is called outside no_grad because it presumably needs input
    # gradients. NOTE(review): my_fgsm lives in ./utils/myUtils.py; it appears
    # to return the perturbed batch plus a second value that is unused here.
    inputs, sign = my_fgsm(inputs, targets, model_adv, criterion, epsilon, DEVICE)
    # Scoring pass only: no autograd graph is needed, as in the clean test loop.
    with torch.no_grad():
        outputs = model_adv(inputs)
        loss = criterion(outputs, targets)

    test_loss += loss.item()
    _, predicted = outputs.max(1)
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()

    print('\r batch_idx: {} | Loss: {} | Acc: {} '.format(batch_idx, 
                                                      test_loss/(batch_idx+1), 100.*correct/total ), end='')


adversarial test... ...
 batch_idx: 78 | Loss: 4.841943517515931 | Acc: 23.99 8397435897434 

In [10]:
# Adversarial training: each epoch trains on batches perturbed by my_fgsm and
# then measures loss/accuracy on a perturbed copy of the test set.
epsilon = 0.3  # passed to my_fgsm; presumably the perturbation size - confirm in ./utils/myUtils.py

for epoch in range(200):
    # ---- training pass on adversarial examples ----
    model_adv.train()
    train_loss = 0
    correct = 0
    total = 0

    print('\nepoch: {}'.format(epoch))
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        # Replace the clean batch by its perturbed version (second return value unused).
        inputs, sign = my_fgsm(inputs, targets, model_adv, criterion, epsilon, DEVICE)

        # Gradients are cleared after the attack call, so anything it may have
        # accumulated on the parameters is dropped before the training backward pass.
        optimizer.zero_grad()
        outputs = model_adv(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        print('\r batch_idx: {} | Loss: {} | Acc: {} '.format(batch_idx, 
                                                              train_loss/(batch_idx+1), 100.*correct/total ), end='')
    print('test... ...')

    # ---- evaluation pass on adversarial test examples ----
    model_adv.eval()
    test_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(test_loader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        # The attack is called outside no_grad because it presumably needs input gradients.
        inputs, sign = my_fgsm(inputs, targets, model_adv, criterion, epsilon, DEVICE)
        # Scoring pass only: no autograd graph is needed for evaluation.
        with torch.no_grad():
            outputs = model_adv(inputs)
            loss = criterion(outputs, targets)

        test_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        print('\r batch_idx: {} | Loss: {} | Acc: {} '.format(batch_idx, 
                                                          test_loss/(batch_idx+1), 100.*correct/total ), end='')



epoch: 0
 batch_idx: 390 | Loss: 1.7167409733128365 | Acc: 37.064 3397435897  test... ...
 batch_idx: 78 | Loss: 2.331100574022607 | Acc: 27.05 221153846153  
epoch: 1
 batch_idx: 390 | Loss: 0.5275202078358902 | Acc: 81.726 076923077  test... ...
 batch_idx: 78 | Loss: 0.9699833234654197 | Acc: 68.97 32051282051 
epoch: 2
 batch_idx: 390 | Loss: 0.374040286254395 | Acc: 87.544 20673076923  test... ...
 batch_idx: 78 | Loss: 0.7727149408074874 | Acc: 76.87 98076923077 
epoch: 3
 batch_idx: 390 | Loss: 0.1937954072719035 | Acc: 93.738 97493573265 test... ...
 batch_idx: 78 | Loss: 0.5101497252530689 | Acc: 85.13 23397435898  
epoch: 4
 batch_idx: 390 | Loss: 0.13921023524173384 | Acc: 95.562 1346153847 test... ...
 batch_idx: 78 | Loss: 0.5901166259110728 | Acc: 83.01 78846153847 
epoch: 5
 batch_idx: 390 | Loss: 0.13996673890811098 | Acc: 95.418 7628205128 test... ...
 batch_idx: 78 | Loss: 0.6957542839684064 | Acc: 80.09 09294871794 
epoch: 6
 batch_idx: 390 | Loss: 0.101098288865307

 batch_idx: 390 | Loss: 0.04799989977722888 | Acc: 98.462 54487179488 test... ...
 batch_idx: 78 | Loss: 1.0717971151388144 | Acc: 75.85 36217948718 
epoch: 52
 batch_idx: 390 | Loss: 0.054578789088236705 | Acc: 98.272 397435898  test... ...
 batch_idx: 78 | Loss: 0.9879989684382572 | Acc: 77.8 451923076923 
epoch: 53
 batch_idx: 390 | Loss: 0.05585531074353649 | Acc: 98.234 7628205128  test... ...
 batch_idx: 78 | Loss: 1.016937546337707 | Acc: 76.29 203525641026 
epoch: 54
 batch_idx: 390 | Loss: 0.051310732941645794 | Acc: 98.386 666666667  test... ...
 batch_idx: 78 | Loss: 0.8721895353703559 | Acc: 79.87 80448717949 
epoch: 55
 batch_idx: 390 | Loss: 0.05457351820262344 | Acc: 98.332 3012820512  test... ...
 batch_idx: 78 | Loss: 0.9581662899331201 | Acc: 77.67 6282051282  
epoch: 56
 batch_idx: 390 | Loss: 0.05095545358746253 | Acc: 98.448 52243589743 test... ...
 batch_idx: 78 | Loss: 0.9418792234191412 | Acc: 78.41 44871794872 
epoch: 57
 batch_idx: 390 | Loss: 0.04536803896584

 batch_idx: 390 | Loss: 0.04996374842074826 | Acc: 98.53 64743589743  test... ...
 batch_idx: 78 | Loss: 0.8909506141384945 | Acc: 80.26 46153846153 
epoch: 103
 batch_idx: 390 | Loss: 0.05735547697204915 | Acc: 98.204 3461538461  test... ...
 batch_idx: 78 | Loss: 0.5151768786997735 | Acc: 86.82 91025641026 
epoch: 104
 batch_idx: 390 | Loss: 0.05556590238686108 | Acc: 98.252 92307692 8  test... ...
 batch_idx: 78 | Loss: 0.6825889058505432 | Acc: 82.77 40384615384 
epoch: 105
 batch_idx: 390 | Loss: 0.04553359056182225 | Acc: 98.554 858974359 9 test... ...
 batch_idx: 78 | Loss: 0.6240606504150584 | Acc: 84.37 96794871794 
epoch: 106
 batch_idx: 390 | Loss: 0.04912189004556907 | Acc: 98.396 4358974359  test... ...
 batch_idx: 78 | Loss: 0.6973944068709507 | Acc: 83.03 83653846153 
epoch: 107
 batch_idx: 390 | Loss: 0.04629948814792554 | Acc: 98.614 8525641026  test... ...
 batch_idx: 78 | Loss: 0.816856311846383 | Acc: 80.56 88782051282  
epoch: 108
 batch_idx: 390 | Loss: 0.04009218

 batch_idx: 78 | Loss: 1.0315678010258493 | Acc: 77.32 1794871794  
epoch: 153
 batch_idx: 390 | Loss: 0.0497405391014979 | Acc: 98.492 759615384616 test... ...
 batch_idx: 78 | Loss: 0.8965647773274893 | Acc: 79.09 50641025641 
epoch: 154
 batch_idx: 390 | Loss: 0.04394725521030786 | Acc: 98.702 92307692308 test... ...
 batch_idx: 78 | Loss: 0.8022089996669866 | Acc: 81.2 91987179488  
epoch: 155
 batch_idx: 390 | Loss: 0.0430303447072387 | Acc: 98.734 97435897436  test... ...
 batch_idx: 78 | Loss: 0.6901809784430492 | Acc: 83.33 33333333333 
epoch: 156
 batch_idx: 390 | Loss: 0.05238144177838665 | Acc: 98.384 1025641026  test... ...
 batch_idx: 78 | Loss: 0.6480053003075756 | Acc: 84.15 63141025641 
epoch: 157
 batch_idx: 390 | Loss: 0.04705542096834811 | Acc: 98.586 3717948718  test... ...
 batch_idx: 78 | Loss: 0.7304009553752367 | Acc: 81.73 73717948718 
epoch: 158
 batch_idx: 390 | Loss: 0.047479442289799376 | Acc: 98.574 2756410257 test... ...
 batch_idx: 78 | Loss: 0.770183232

In [16]:
# Save the adversarially trained weights under ./model. The model is moved to
# the CPU first, so the stored tensors are CPU tensors.
model_adv = model_adv.cpu()
if not os.path.exists('./model'):
    os.makedirs('./model')
model_path = ('./model/ResNet18_CIFAR10_adv.pt' if NORMALIZE
              else './model/ResNet18_CIFAR10_unNormalize_adv.pt')
torch.save(model_adv.state_dict(), model_path)

In [11]:
############################
# Reload the adversarially trained model into a freshly built network.
# Both branches are spelled out so the path never depends on a value left over
# from an earlier cell.
if NORMALIZE:
    model_path = './model/ResNet18_CIFAR10_adv.pt'
else:
    model_path = './model/ResNet18_CIFAR10_unNormalize_adv.pt'

model_adv = ResNet18()
model_adv = model_adv.to(DEVICE)
model_adv = torch.nn.DataParallel(model_adv, device_ids=[1,2,3])
print('load model: {}'.format(model_path))
# map_location places the stored tensors on the current primary device while loading.
model_adv.load_state_dict(torch.load(model_path, map_location=DEVICE))

# Loss and optimiser are rebuilt so they refer to the reloaded model's parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model_adv.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

load model: ./model/ResNet18_CIFAR10_adv.pt


In [12]:
# Adversarial test of the reloaded model: loss and accuracy on test images
# perturbed by my_fgsm.
print('adversarial test... ...')

epsilon = 0.3  # passed to my_fgsm; presumably the perturbation size - confirm in ./utils/myUtils.py
model_adv.eval()
test_loss = 0
correct = 0
total = 0

for batch_idx, (inputs, targets) in enumerate(test_loader):
    inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
    # The attack is called outside no_grad because it presumably needs input
    # gradients; its second return value is unused here.
    inputs, sign = my_fgsm(inputs, targets, model_adv, criterion, epsilon, DEVICE)
    # Scoring pass only: no autograd graph is needed for evaluation.
    with torch.no_grad():
        outputs = model_adv(inputs)
        loss = criterion(outputs, targets)

    test_loss += loss.item()
    _, predicted = outputs.max(1)
    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()

    print('\r batch_idx: {} | Loss: {} | Acc: {} '.format(batch_idx, 
                                                      test_loss/(batch_idx+1), 100.*correct/total ), end='')


adversarial test... ...
 batch_idx: 78 | Loss: 0.8189108413231524 | Acc: 80.97 58333333333 