In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

import torchvision
import torchvision.transforms as transforms

import torchattacks

import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import os
import sys

from tqdm import tqdm

from resnet import ResNet18
from data_aug import collect_advs, AdvDataSet

TITAN RTX


In [2]:
if torch.cuda.is_available() == True:
    device = torch.device('cuda:0')
    print(torch.cuda.get_device_name())
else:
    device = torch.device('cpu')
device

TITAN RTX


device(type='cuda', index=0)

In [3]:
batch_size = 128
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def load_data():    
    transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), 
                                    transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    transform_test = transforms.Compose([transforms.ToTensor()])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_test)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
    
    dataset = {'train': trainset, 'test': testset}
    data_loader = {'train': train_loader, 'test': test_loader}
    return dataset, data_loader

def load_iter(data_loader, data_type):
    bar_format = '{bar:30} {n_fmt}/{total_fmt} [{elapsed}<{remaining} {rate_fmt}] {desc}'
    
    if data_type == 'train':
        train_loader = data_loader['train']
        train_iter = tqdm(enumerate(train_loader), total=len(train_loader), unit_scale=batch_size, bar_format=bar_format)
        return train_iter
    elif data_type == 'test':
        test_loader = data_loader['test']
        test_iter = tqdm(enumerate(test_loader), total=len(test_loader), unit_scale=batch_size, bar_format=bar_format)
        return test_iter
    else:
        print('Data Error!!!')

In [4]:
def imshow(img):
    npimg = img.cpu().detach().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

In [5]:
def train(model):
    model.train()
    train_loss = 0
    total = 0
    correct = 0
    step = 0
    train_iter = load_iter(data_loader, 'train')
    for i, (batch, label) in train_iter:
        batch, label = batch.to(device), label.to(device)
        output = model(batch)
        
        optimizer.zero_grad()
        loss = loss_function(output, label)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = output.max(1)
        total += label.size(0)
        correct += predicted.eq(label).sum().item()

        acc = 100. * correct / total
        train_iter.set_description(f'[{acc:.2f}% ({correct}/{total})]', True)

In [6]:
def test(model):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    acc = 0.
    test_iter = load_iter(data_loader, 'test')

    for i, (batch, label) in test_iter:
        batch, label = batch.to(device), label.to(device)
        output = model(batch)
        loss = loss_function(output, label)

        test_loss += loss.item()
        _, predicted = output.max(1)
        total += label.size(0)
        correct += predicted.eq(label).sum().item()

        acc = 100. * correct / total
        test_iter.set_description(f'[{acc:.2f}%({correct}/{total})]', True)
    return acc

In [7]:
def save_model(epoch, acc, optimizer):
    global best_acc
    if acc > best_acc:
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
            'optimizer': optimizer.state_dict()
        }
        if not os.path.isdir('models'):
            os.mkdir('models')
        torch.save(state, f'./models/model_{epoch}.pth')
        best_acc = acc
        print('Saving Model...')

def load_model(name):
    state_dict = torch.load(f'./models/{name}.pth', map_location=device)
    model = ResNet18()
    model.to(device)
    model.load_state_dict(state_dict['model'])
    optimizer = optim.SGD(model.parameters(), lr=1e-1, momentum=0.9, weight_decay=1e-4)
#     optimizer.load_state_dict(state_dict['optimizer'])
    return model, optimizer

In [8]:
# Training
model = ResNet18()
model.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-1, momentum=0.9, weight_decay=1e-4)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[25, 35])
best_acc = 0
dataset, data_loader = load_data()
for epoch in range(1, 51):
    print(f'Epoch {epoch}')
    train(model)
    test_acc = test(model)
    save_model(epoch, test_acc, optimizer)
    scheduler.step()

KeyboardInterrupt: 

In [None]:
# Validation
model, optimizer = load_model('baseline_3')
loss_function = nn.CrossEntropyLoss()
dataset, data_loader = load_data()
test(model)

## Adversarial Training

In [None]:
model, optimizer = load_model('baseline_3')
epsilon = 0.01
dataset, data_loader = load_data()
dataset, data_loader = load_data()
adv_instances = collect_advs(model, data_loader, epsilon)

In [20]:
def adv_train(model):
    model.train()
    train_loss = 0
    total = 0
    correct = 0
    step = 0
    train_iter = load_iter(data_loader, 'train')
    for i, (batch, label) in train_iter:
        batch, label = batch.to(device), label.to(device)
        adv_batch = atk(batch, label)
        
        _, pred = model(batch).max(1)
        _, adv_pred = model(adv_batch).max(1)
        
        output = torch.zeros((batch.shape[0], 10), device=device)
        for j, k in enumerate(pred.eq(adv_pred)):
            if k == True:
                output[j] = model(batch[j].view(1, 3, 32, 32)).view(-1)
            else:
                output[j] = 0.5 * model(batch[j].view(1, 3, 32, 32)).view(-1) + 0.5 * model(adv_batch[j].view(1, 3, 32, 32)).view(-1)

        optimizer.zero_grad()
        loss = loss_function(output, label)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = output.max(1)
        total += label.size(0)
        correct += predicted.eq(label).sum().item()

        acc = 100. * correct / total
        train_iter.set_description(f'[{acc:.2f}% ({correct}/{total})]', True)

In [None]:
# Training
model = ResNet18()
model.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-1, momentum=0.9, weight_decay=1e-4)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[25, 35])
best_acc = 0
dataset, data_loader = load_data()
baseline, _ = load_model('baseline_3')
atk = torchattacks.PGD(baseline, eps=8/255, alpha=2/255, steps=4)
for epoch in range(1, 51):
    print(f'Epoch {epoch}')
    adv_train(model)
    test_acc = test(model)
    save_model(epoch, test_acc, optimizer)
    scheduler.step()

Files already downloaded and verified
Files already downloaded and verified


                               0/50048 [00:00<? ?it/s] 

Epoch 1


██████████████████████████████ 50048/50048 [06:56<00:00 120.13it/s] [36.05% (18024/50000)]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5597.18it/s] [36.39%(3639/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 2


██████████████████████████████ 50048/50048 [07:07<00:00 117.09it/s] [63.03% (31515/50000)]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5586.81it/s] [45.30%(4530/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 3


██████████████████████████████ 50048/50048 [07:17<00:00 114.49it/s] [73.33% (36663/50000)]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5579.54it/s] [48.51%(4851/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 4


██████████████████████████████ 50048/50048 [07:16<00:00 114.60it/s] [78.84% (39418/50000)]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5576.82it/s] [49.34%(4934/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 5


███████████████████████████▌   46080/50048 [06:48<00:35 111.65it/s] [83.73% (38582/46080)]: 

In [None]:
advset = AdvDataSet(adv_instances, need_perturb_label=True)
advset.slicing(42, 1000)
trainset = dataset['train']
advset_list = advset.concat_dataset(trainset)
len(advset)

In [None]:
model_list = ['baseline_3']
model_list += [f'adv_{i}' for i in range(4)]
for m in model_list: 
    state_dict = torch.load(f'models/{m}.pth')
    print(m, state_dict['acc'])

## Hyperplane Distance

In [None]:
weight, bias = model.load_weight()
weight.shape, bias.shape

In [None]:
def distance(a, b, c):
    a = a.cpu().detach().numpy()
    b = b.cpu().detach().numpy()
    c = c.cpu().detach().numpy()
    return np.abs(np.matmul(a, b) + c) / np.linalg.norm(b)

In [None]:
adv_loader = torch.utils.data.DataLoader(advset, batch_size=batch_size, shuffle=True)
bar_format = '{bar:30} {n_fmt}/{total_fmt} [{elapsed}<{remaining} {rate_fmt}] {desc}'
adv_iter = tqdm(enumerate(adv_loader), total=len(adv_loader), unit_scale=batch_size, bar_format=bar_format)

sort_dict = {}
for i in range(10):
    sort_dict[i] = []

new_adv_instances = []
for i, (batch, label, perturb_label) in adv_iter:
    batch, label = batch.to(device), label.to(device)
    output = model(batch)
    
    latent_vec = model.load_vec()
    old_label = label
    new_label = perturb_label
    
    
    for j in range(batch.shape[0]):
        dis = distance(latent_vec[j], weight[old_label[j]], bias[old_label[j]])
        sort_dict[old_label[j].item()].append((dis, batch[j].cpu().detach(), old_label[j].item()))

for i in range(10):
    sort_dict[i].sort()
    idx = int(len(sort_dict[i]) * 0.1)
    sort_dict[i] = sort_dict[i][:-idx]
    
    for s in sort_dict[i]:
        new_adv_instances.append((s[1], s[2]))
len(new_adv_instances)

In [None]:
advset = AdvDataSet(new_adv_instances, need_perturb_label=False)
advset.slicing(42, 3727)
trainset = dataset['train']
advset_list = advset.concat_dataset(trainset)
len(advset_list[0])