# winddy

用训练集CIFAR 对抗训练 ResNet18

In [14]:
import numpy as np
import torch
import os
import torchvision.transforms as transforms 
import torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from LeNet import LeNet
import sys
sys.path.append('./model')
from resnet import ResNet18

%matplotlib inline

In [15]:
NORMALIZE = True
RESUME = True

os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str,[0,2,7]))
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
cudnn.benchmark = True

In [16]:
if NORMALIZE:
    trans_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    trans_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
else:
    trans_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    trans_test = transforms.Compose([
        transforms.ToTensor(),
    ])

In [17]:
data_home = '/data/winddy/'

train_set = torchvision.datasets.CIFAR10(root=os.path.join(data_home, 'dataset/CIFAR10'), train=True, download=True, transform=trans_train)
test_set = torchvision.datasets.CIFAR10(root=os.path.join(data_home, 'dataset/CIFAR10'), train=False, download=True, transform=trans_test)

Files already downloaded and verified
Files already downloaded and verified


In [18]:
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=True, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [19]:
# 构建网络结构
net = ResNet18()
net = net.to(DEVICE)
net = torch.nn.DataParallel(net)

if RESUME:
# Load checkpoint.
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/ckpt.pth')
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']
    
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

==> Resuming from checkpoint..


In [20]:
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        print('\r batch_idx: {} | Loss: {} | Acc: {} '.format(batch_idx, 
                                                              train_loss/(batch_idx+1), 100.*correct/total ), end='')
#         progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
#             % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
        
def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            print('\r batch_idx: {} | Loss: {} | Acc: {} '.format(batch_idx, 
                                                              test_loss/(batch_idx+1), 100.*correct/total ), end='')
#             progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
#                 % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
            
    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc

In [21]:
best_acc = 0  # best test accuracy
start_epoch = 0
for epoch in range(start_epoch, start_epoch+200):
    train(epoch)
    test(epoch)


Epoch: 0
 batch_idx: 78 | Loss: 0.3216810831917992 | Acc: 91.05 565705128206   Saving..

Epoch: 1
 batch_idx: 78 | Loss: 0.32489821475140657 | Acc: 91.45 36217948718   Saving..

Epoch: 2
 batch_idx: 78 | Loss: 0.30908908081960074 | Acc: 91.36 15384615384   
Epoch: 3
 batch_idx: 78 | Loss: 0.30799488081962245 | Acc: 91.72 6282051282  2 Saving..

Epoch: 4
 batch_idx: 78 | Loss: 0.2805704056179222 | Acc: 92.44 790064102564   Saving..

Epoch: 5
 batch_idx: 78 | Loss: 0.34105642347396176 | Acc: 90.79 5641025641    
Epoch: 6
 batch_idx: 78 | Loss: 0.33029308311546907 | Acc: 91.59 5448717949    
Epoch: 7
 batch_idx: 78 | Loss: 0.306318882051148 | Acc: 92.28 767628205128  6 
Epoch: 8
 batch_idx: 78 | Loss: 0.33230514880977097 | Acc: 90.96 5608974359    
Epoch: 9
 batch_idx: 78 | Loss: 0.3095919775623309 | Acc: 91.76 79487179488    
Epoch: 10
 batch_idx: 78 | Loss: 0.29556303084651125 | Acc: 91.6 53846153847    
Epoch: 11
 batch_idx: 78 | Loss: 0.31615748718569553 | Acc: 91.16 83333333333   
E

 batch_idx: 78 | Loss: 0.35714713550066646 | Acc: 90.82 0448717949  7 
Epoch: 101
 batch_idx: 78 | Loss: 0.3014474937432929 | Acc: 91.73 76282051282    
Epoch: 102
 batch_idx: 78 | Loss: 0.28286192122894 | Acc: 92.5 4979967948718 4   
Epoch: 103
 batch_idx: 78 | Loss: 0.30559930441123023 | Acc: 91.76 8108974359    
Epoch: 104
 batch_idx: 78 | Loss: 0.31470954908600335 | Acc: 91.83 2307692308  6 
Epoch: 105
 batch_idx: 78 | Loss: 0.29612850935398777 | Acc: 91.94 9935897436  5 
Epoch: 106
 batch_idx: 78 | Loss: 0.33725800299191777 | Acc: 91.36 15384615384   
Epoch: 107
 batch_idx: 78 | Loss: 0.29265336173621914 | Acc: 92.02 2756410257  9 
Epoch: 108
 batch_idx: 78 | Loss: 0.29337241962740696 | Acc: 91.58 0641025641  1 
Epoch: 109
 batch_idx: 78 | Loss: 0.31668127243277394 | Acc: 91.69 73076923077   
Epoch: 110
 batch_idx: 78 | Loss: 0.31059765711992604 | Acc: 91.44 29807692308   
Epoch: 111
 batch_idx: 78 | Loss: 0.2924452227882192 | Acc: 92.04 25961538461    
Epoch: 112
 batch_idx: 78 |

In [22]:
## 保存模型
if not os.path.exists('./model'):
    os.makedirs('./model')
if NORMALIZE:
    model_path = './model/ResNet18_CIFAR10.pt'
else:
    model_path = './model/ResNet18_CIFAR10_unNormalize.pt'
torch.save(net.state_dict(), model_path)