In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

import torchvision
import torchvision.transforms as transforms

import torchattacks

import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import os
import sys

from tqdm import tqdm

from resnet import ResNet18
from data_aug import collect_advs, AdvDataSet

TITAN RTX


In [2]:
if torch.cuda.is_available() == True:
    device = torch.device('cuda:1')
    print(torch.cuda.get_device_name())
else:
    device = torch.device('cpu')
device

TITAN RTX


device(type='cuda', index=1)

In [3]:
batch_size = 128
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def load_data():    
    transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), 
                                    transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    transform_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    transform_test = transforms.Compose([transforms.ToTensor()])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_test)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
    
    dataset = {'train': trainset, 'test': testset}
    data_loader = {'train': train_loader, 'test': test_loader}
    return dataset, data_loader

def load_iter(data_loader, data_type):
    bar_format = '{bar:30} {n_fmt}/{total_fmt} [{elapsed}<{remaining} {rate_fmt}] {desc}'
    
    if data_type == 'train':
        train_loader = data_loader['train']
        train_iter = tqdm(enumerate(train_loader), total=len(train_loader), unit_scale=batch_size, bar_format=bar_format)
        return train_iter
    elif data_type == 'test':
        test_loader = data_loader['test']
        test_iter = tqdm(enumerate(test_loader), total=len(test_loader), unit_scale=batch_size, bar_format=bar_format)
        return test_iter
    else:
        print('Data Error!!!')

In [4]:
def imshow(img):
    npimg = img.cpu().detach().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

In [5]:
def train(model):
    model.train()
    train_loss = 0
    total = 0
    correct = 0
    step = 0
    train_iter = load_iter(data_loader, 'train')
    for i, (batch, label) in train_iter:
        batch, label = batch.to(device), label.to(device)
        output = model(batch)
        
        optimizer.zero_grad()
        loss = loss_function(output, label)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = output.max(1)
        total += label.size(0)
        correct += predicted.eq(label).sum().item()

        acc = 100. * correct / total
        train_iter.set_description(f'[{acc:.2f}% ({correct}/{total})]', True)

In [6]:
def test(model):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    acc = 0.
    test_iter = load_iter(data_loader, 'test')

    for i, (batch, label) in test_iter:
        batch, label = batch.to(device), label.to(device)
        output = model(batch)
        loss = loss_function(output, label)

        test_loss += loss.item()
        _, predicted = output.max(1)
        total += label.size(0)
        correct += predicted.eq(label).sum().item()

        acc = 100. * correct / total
        test_iter.set_description(f'[{acc:.2f}%({correct}/{total})]', True)
    return acc

In [7]:
def save_model(epoch, acc, optimizer):
    global best_acc
    if acc > best_acc:
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
            'optimizer': optimizer.state_dict()
        }
        if not os.path.isdir('models'):
            os.mkdir('models')
        torch.save(state, f'./models/model_{epoch}.pth')
        best_acc = acc
        print('Saving Model...')

def load_model(name):
    state_dict = torch.load(f'./models/{name}.pth', map_location=device)
    model = ResNet18()
    model.to(device)
    model.load_state_dict(state_dict['model'])
    optimizer = optim.SGD(model.parameters(), lr=1e-1, momentum=0.9, weight_decay=1e-4)
#     optimizer.load_state_dict(state_dict['optimizer'])
    return model, optimizer

In [20]:
# Training
model = ResNet18()
model.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-1, momentum=0.9, weight_decay=1e-4)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[25, 35])
best_acc = 0
dataset, data_loader = load_data()
for epoch in range(1, 51):
    print(f'Epoch {epoch}')
    train(model)
    test_acc = test(model)
    save_model(epoch, test_acc, optimizer)
    scheduler.step()

Files already downloaded and verified
Files already downloaded and verified


▏                              256/50048 [00:00<00:30 1627.43it/s] [10.94% (28/256)]: 

Epoch 1
torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])


▍                              768/50048 [00:00<00:28 1719.90it/s] [13.93% (107/768)]: 

torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])


▌                              1024/50048 [00:00<00:27 1758.41it/s] [14.67% (169/1152)]: 

torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])


▉                              1536/50048 [00:00<00:26 1811.47it/s] [15.43% (237/1536)]: 

torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])


█                              1792/50048 [00:01<00:26 1827.18it/s] [16.61% (319/1920)]: 

torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])


█▍                             2304/50048 [00:01<00:25 1847.77it/s] [17.36% (400/2304)]: 

torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])


█▌                             2560/50048 [00:01<00:25 1853.14it/s] [18.42% (495/2688)]: 

torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])


█▊                             3072/50048 [00:01<00:25 1859.72it/s] [19.17% (589/3072)]: 

torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])


█▉                             3328/50048 [00:01<00:25 1858.46it/s] [19.79% (684/3456)]: 

torch.Size([128, 10])
torch.Size([128, 10])
torch.Size([128, 10])


██▏                            3584/50048 [00:02<00:26 1774.67it/s] [20.09% (720/3584)]: 

torch.Size([128, 10])
torch.Size([128, 10])





KeyboardInterrupt: 

In [20]:
# Validation
model, optimizer = load_model('baseline_2')
loss_function = nn.CrossEntropyLoss()
dataset, data_loader = load_data()
test(model)

Files already downloaded and verified
Files already downloaded and verified


██████████████████████████████ 10112/10112 [00:01<00:00 5482.08it/s] [87.17%(8717/10000)]: 


87.17

## Adversarial Training

In [38]:
def adv_train(baseline, model):
    model.train()
    train_loss = 0
    total = 0
    correct = 0
    step = 0
    count = 0
    train_iter = load_iter(data_loader, 'train')
    for i, (batch, label) in train_iter:
        batch, label = batch.to(device), label.to(device)
        adv_batch = atk(batch, label)
        
        _, pred = baseline(batch).max(1)
        _, adv_pred = baseline(adv_batch).max(1)
        latent_vec = baseline.load_vec()
        
        output = model(batch)
        adv_output = model(adv_batch)
        
        
        for j, equal in enumerate(pred.eq(adv_pred)):
            dis = distance(latent_vec[j], weight[label[j].item()], bias[label[j].item()])
            if not equal and dis <= boundary_dict[label[j].item()]['bot_20']:
                count += 1
                output[j] = 0.5 * output[j] + 0.5 * adv_output[j]
    
        optimizer.zero_grad()
        loss = loss_function(output, label)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        _, predicted = output.max(1)
        total += label.size(0)
        correct += predicted.eq(label).sum().item()

        acc = 100. * correct / total
        train_iter.set_description(f'[{acc:.2f}% ({correct}/{total})] [{count}]', True)

In [39]:
# Training
model = ResNet18()
model.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-1, momentum=0.9, weight_decay=1e-4)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[25, 35])
best_acc = 0
dataset, data_loader = load_data()
baseline, _ = load_model('baseline_2')
atk = torchattacks.FGSM(baseline, eps=0.5/255)
for epoch in range(1, 51):
    print(f'Epoch {epoch}')
    adv_train(baseline, model)
    test_acc = test(model)
    save_model('bot20', test_acc, optimizer)
    scheduler.step()

Files already downloaded and verified
Files already downloaded and verified


                               0/50048 [00:00<? ?it/s] 

Epoch 1


██████████████████████████████ 50048/50048 [01:28<00:00 567.30it/s] [45.30% (22649/50000)] [5068]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5639.74it/s] [54.34%(5434/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 2


██████████████████████████████ 50048/50048 [01:28<00:00 565.96it/s] [67.32% (33660/50000)] [5090]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5573.48it/s] [72.63%(7263/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 3


██████████████████████████████ 50048/50048 [01:28<00:00 564.97it/s] [76.60% (38302/50000)] [5093]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5611.79it/s] [75.76%(7576/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 4


██████████████████████████████ 50048/50048 [01:28<00:00 566.30it/s] [81.83% (40914/50000)] [5050]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5811.31it/s] [80.03%(8003/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 5


██████████████████████████████ 50048/50048 [01:28<00:00 563.82it/s] [85.64% (42818/50000)] [5012]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5608.14it/s] [81.50%(8150/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 6


██████████████████████████████ 50048/50048 [01:28<00:00 563.82it/s] [88.23% (44116/50000)] [5134]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5613.13it/s] [81.55%(8155/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 7


██████████████████████████████ 50048/50048 [01:28<00:00 564.10it/s] [90.77% (45386/50000)] [5037]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5612.54it/s] [82.61%(8261/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 8


██████████████████████████████ 50048/50048 [01:28<00:00 563.98it/s] [92.90% (46451/50000)] [5058]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5605.09it/s] [84.03%(8403/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 9


██████████████████████████████ 50048/50048 [01:28<00:00 564.03it/s] [94.11% (47053/50000)] [5072]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5620.08it/s] [83.66%(8366/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 10


██████████████████████████████ 50048/50048 [01:28<00:00 564.54it/s] [95.21% (47606/50000)] [4973]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5632.06it/s] [81.89%(8189/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 11


██████████████████████████████ 50048/50048 [01:28<00:00 563.98it/s] [96.21% (48105/50000)] [5078]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5652.67it/s] [84.99%(8499/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 12


██████████████████████████████ 50048/50048 [01:28<00:00 564.30it/s] [96.65% (48326/50000)] [5018]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5640.62it/s] [83.37%(8337/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 13


██████████████████████████████ 50048/50048 [01:28<00:00 564.15it/s] [97.39% (48697/50000)] [5067]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5667.21it/s] [83.77%(8377/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 14


██████████████████████████████ 50048/50048 [01:28<00:00 564.35it/s] [97.55% (48774/50000)] [5044]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5662.15it/s] [83.85%(8385/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 15


██████████████████████████████ 50048/50048 [01:28<00:00 564.09it/s] [97.67% (48836/50000)] [5115]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5660.36it/s] [82.87%(8287/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 16


██████████████████████████████ 50048/50048 [01:28<00:00 564.21it/s] [97.96% (48981/50000)] [5051]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5662.58it/s] [83.42%(8342/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 17


██████████████████████████████ 50048/50048 [01:28<00:00 564.05it/s] [97.96% (48982/50000)] [5104]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5656.43it/s] [84.47%(8447/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 18


██████████████████████████████ 50048/50048 [01:28<00:00 564.39it/s] [98.13% (49064/50000)] [5034]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5655.04it/s] [82.99%(8299/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 19


██████████████████████████████ 50048/50048 [01:28<00:00 564.13it/s] [98.11% (49054/50000)] [5064]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5656.37it/s] [84.03%(8403/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 20


██████████████████████████████ 50048/50048 [01:28<00:00 564.28it/s] [98.16% (49081/50000)] [5030]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5657.75it/s] [84.27%(8427/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 21


██████████████████████████████ 50048/50048 [01:28<00:00 564.18it/s] [98.17% (49083/50000)] [5052]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5660.96it/s] [84.90%(8490/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 22


██████████████████████████████ 50048/50048 [01:28<00:00 564.24it/s] [98.52% (49259/50000)] [5065]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5663.48it/s] [83.33%(8333/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 23


██████████████████████████████ 50048/50048 [01:28<00:00 564.26it/s] [98.24% (49119/50000)] [5073]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5668.19it/s] [83.63%(8363/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 24


██████████████████████████████ 50048/50048 [01:28<00:00 564.05it/s] [98.21% (49107/50000)] [5124]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5658.48it/s] [83.88%(8388/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 25


██████████████████████████████ 50048/50048 [01:28<00:00 564.23it/s] [98.52% (49260/50000)] [5070]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5666.47it/s] [83.91%(8391/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 26


██████████████████████████████ 50048/50048 [01:28<00:00 564.38it/s] [99.64% (49820/50000)] [5044]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5665.39it/s] [86.89%(8689/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 27


██████████████████████████████ 50048/50048 [01:28<00:00 563.62it/s] [99.97% (49985/50000)] [5112]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5669.39it/s] [86.92%(8692/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 28


██████████████████████████████ 50048/50048 [01:28<00:00 564.26it/s] [100.00% (49998/50000)] [5031]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5664.06it/s] [86.96%(8696/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 29


██████████████████████████████ 50048/50048 [01:28<00:00 564.20it/s] [99.99% (49997/50000)] [5039]:  
██████████████████████████████ 10112/10112 [00:01<00:00 5671.21it/s] [87.12%(8712/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 30


██████████████████████████████ 50048/50048 [01:28<00:00 564.32it/s] [100.00% (49999/50000)] [5075]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5613.50it/s] [87.22%(8722/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 31


██████████████████████████████ 50048/50048 [01:28<00:00 563.92it/s] [100.00% (49999/50000)] [5131]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5655.32it/s] [87.30%(8730/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 32


██████████████████████████████ 50048/50048 [01:28<00:00 564.56it/s] [100.00% (49999/50000)] [5005]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5664.63it/s] [87.29%(8729/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 33


██████████████████████████████ 50048/50048 [01:28<00:00 564.51it/s] [100.00% (50000/50000)] [5041]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5672.76it/s] [87.26%(8726/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 34


██████████████████████████████ 50048/50048 [01:28<00:00 564.37it/s] [100.00% (50000/50000)] [5078]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5615.75it/s] [87.24%(8724/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 35


██████████████████████████████ 50048/50048 [01:28<00:00 564.14it/s] [100.00% (50000/50000)] [5098]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5668.94it/s] [87.34%(8734/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 36


██████████████████████████████ 50048/50048 [01:28<00:00 564.46it/s] [100.00% (49998/50000)] [5038]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5669.45it/s] [87.38%(8738/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Saving Model...
Epoch 37


██████████████████████████████ 50048/50048 [01:28<00:00 564.77it/s] [100.00% (49999/50000)] [4950]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5671.67it/s] [87.31%(8731/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 38


██████████████████████████████ 50048/50048 [01:28<00:00 564.19it/s] [100.00% (50000/50000)] [5088]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5663.10it/s] [87.30%(8730/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 39


██████████████████████████████ 50048/50048 [01:28<00:00 564.50it/s] [100.00% (49999/50000)] [5042]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5670.78it/s] [87.26%(8726/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 40


██████████████████████████████ 50048/50048 [01:28<00:00 564.13it/s] [100.00% (49999/50000)] [5084]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5669.70it/s] [87.28%(8728/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 41


██████████████████████████████ 50048/50048 [01:28<00:00 564.32it/s] [100.00% (49999/50000)] [5057]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5671.54it/s] [87.27%(8727/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 42


██████████████████████████████ 50048/50048 [01:28<00:00 564.11it/s] [100.00% (50000/50000)] [5094]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5664.47it/s] [87.18%(8718/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 43


██████████████████████████████ 50048/50048 [01:28<00:00 564.04it/s] [100.00% (49999/50000)] [5094]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5667.46it/s] [87.37%(8737/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 44


██████████████████████████████ 50048/50048 [01:28<00:00 564.53it/s] [100.00% (50000/50000)] [5007]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5644.42it/s] [87.27%(8727/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 45


██████████████████████████████ 50048/50048 [01:28<00:00 564.23it/s] [100.00% (50000/50000)] [5086]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5669.01it/s] [87.24%(8724/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 46


██████████████████████████████ 50048/50048 [01:28<00:00 564.28it/s] [100.00% (50000/50000)] [5068]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5669.99it/s] [87.33%(8733/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 47


██████████████████████████████ 50048/50048 [01:28<00:00 564.26it/s] [100.00% (50000/50000)] [5075]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5671.73it/s] [87.23%(8723/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 48


██████████████████████████████ 50048/50048 [01:28<00:00 564.36it/s] [100.00% (50000/50000)] [5052]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5633.08it/s] [87.32%(8732/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 49


██████████████████████████████ 50048/50048 [01:28<00:00 564.24it/s] [100.00% (50000/50000)] [5079]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5668.58it/s] [87.36%(8736/10000)]: 
                               0/50048 [00:00<? ?it/s] 

Epoch 50


██████████████████████████████ 50048/50048 [01:28<00:00 564.32it/s] [100.00% (49999/50000)] [5024]: 
██████████████████████████████ 10112/10112 [00:01<00:00 5659.96it/s] [87.41%(8741/10000)]: 


Saving Model...


In [8]:
baseline, optimizer = load_model('baseline_2')
atk = torchattacks.FGSM(baseline, eps=0.5/255)
dataset, data_loader = load_data()
adv_instances = collect_advs(baseline, data_loader, atk)

Files already downloaded and verified
Files already downloaded and verified


██████████████████████████████ 50048/50048 [00:35<00:00 1415.43it/s] [# of Collected Adv Instances : 6023]: 


In [13]:
advset = AdvDataSet(adv_instances, need_perturb_label=True)
advset.slicing(42, 1000)
trainset = dataset['train']
advset_list = advset.concat_dataset(trainset)
len(advset)

6023

In [14]:
model_list = ['baseline_3']
model_list += [f'adv_{i}' for i in range(4)]
for m in model_list: 
    state_dict = torch.load(f'models/{m}.pth')
    print(m, state_dict['acc'])

baseline_3 87.66
adv_0 87.46
adv_1 86.93
adv_2 86.89
adv_3 86.74


## Hyperplane Distance

In [15]:
weight, bias = baseline.load_weight()
weight.shape, bias.shape

(torch.Size([10, 512]), torch.Size([10]))

In [16]:
def distance(a, b, c):
    a = a.cpu().detach().numpy()
    b = b.cpu().detach().numpy()
    c = c.cpu().detach().numpy()
    return np.abs(np.matmul(a, b) + c) / np.linalg.norm(b)

In [17]:
adv_loader = torch.utils.data.DataLoader(advset, batch_size=batch_size, shuffle=True)
bar_format = '{bar:30} {n_fmt}/{total_fmt} [{elapsed}<{remaining} {rate_fmt}] {desc}'
adv_iter = tqdm(enumerate(adv_loader), total=len(adv_loader), unit_scale=batch_size, bar_format=bar_format)

sort_dict = {}
boundary_dict = {}
for i in range(10):
    sort_dict[i] = []
    boundary_dict[i] = {}

new_adv_instances = []
for i, (batch, label) in adv_iter:
    batch, label = batch.to(device), label.to(device)
    output = baseline(batch)
    
    latent_vec = baseline.load_vec()
    old_label = label
    
    for j in range(batch.shape[0]):
        dis = distance(latent_vec[j], weight[old_label[j]], bias[old_label[j]])
        sort_dict[old_label[j].item()].append((dis, batch[j].cpu().detach(), old_label[j].item()))

for i in range(10):
    sort_dict[i].sort()
    idx_10 = int(len(sort_dict[i]) * 0.1)
    idx_20 = int(len(sort_dict[i]) * 0.2)
    boundary_dict[i]['top_10'] = sort_dict[i][idx_10][0]
    boundary_dict[i]['top_20'] = sort_dict[i][idx_20][0]
    boundary_dict[i]['bot_10'] = sort_dict[i][-idx_10][0]
    boundary_dict[i]['bot_20'] = sort_dict[i][-idx_20][0]
    
boundary_dict

██████████████████████████████ 6144/6144 [00:01<00:00 3631.48it/s] 


{0: {'top_10': 1.9921111,
  'top_20': 2.4569414,
  'bot_10': 4.5054946,
  'bot_20': 4.029607},
 1: {'top_10': 4.2631626,
  'top_20': 4.670611,
  'bot_10': 7.936803,
  'bot_20': 7.073451},
 2: {'top_10': 0.6874836,
  'top_20': 1.0158967,
  'bot_10': 3.1587703,
  'bot_20': 2.7136226},
 3: {'top_10': 0.7757223,
  'top_20': 1.1098003,
  'bot_10': 3.1452658,
  'bot_20': 2.6339307},
 4: {'top_10': 0.8189386,
  'top_20': 1.1776764,
  'bot_10': 3.1004822,
  'bot_20': 2.7183628},
 5: {'top_10': 1.1155912,
  'top_20': 1.5504944,
  'bot_10': 3.7928708,
  'bot_20': 3.3094227},
 6: {'top_10': 1.5036038,
  'top_20': 1.9126078,
  'bot_10': 3.988985,
  'bot_20': 3.6457815},
 7: {'top_10': 1.8043453,
  'top_20': 2.1294463,
  'bot_10': 4.8365936,
  'bot_20': 4.1259365},
 8: {'top_10': 2.617714,
  'top_20': 3.0114908,
  'bot_10': 5.2799406,
  'bot_20': 4.776397},
 9: {'top_10': 3.269793,
  'top_20': 3.720141,
  'bot_10': 6.762093,
  'bot_20': 5.922116}}

In [None]:
advset = AdvDataSet(new_adv_instances, need_perturb_label=False)
advset.slicing(42, 3727)
trainset = dataset['train']
advset_list = advset.concat_dataset(trainset)
len(advset_list[0])

## Adversarial Attack

In [12]:
def test(model):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    acc = 0.
    test_iter = load_iter(data_loader, 'test')

    for i, (batch, label) in test_iter:
        batch, label = batch.to(device), label.to(device)
        adv_batch = atk(batch, label)
        output = model(adv_batch)
        loss = loss_function(output, label)

        test_loss += loss.item()
        _, predicted = output.max(1)
        total += label.size(0)
        correct += predicted.eq(label).sum().item()

        acc = 100. * correct / total
        test_iter.set_description(f'[{acc:.2f}%({correct}/{total})]', True)
    return acc

In [16]:
model_list = ['baseline_2', 'fgsm', 'top10', 'top20', 'bot10', 'bot20']
loss_function = nn.CrossEntropyLoss()
dataset, data_loader = load_data()
acc = []
for m in model_list: 
    state_dict = torch.load(f'models/{m}.pth')
    model, optimizer = load_model(f'{m}')
    atk = torchattacks.FGSM(model, eps=0.5/255)
    acc.append((m, state_dict['acc'], test(model)))
acc

Files already downloaded and verified
Files already downloaded and verified


██████████████████████████████ 10112/10112 [00:06<00:00 1654.52it/s] [61.35%(6135/10000)]: 
██████████████████████████████ 10112/10112 [00:06<00:00 1672.33it/s] [69.03%(6903/10000)]: 
██████████████████████████████ 10112/10112 [00:06<00:00 1669.64it/s] [67.63%(6763/10000)]: 
██████████████████████████████ 10112/10112 [00:06<00:00 1674.64it/s] [66.04%(6604/10000)]: 
██████████████████████████████ 10112/10112 [00:06<00:00 1671.01it/s] [68.72%(6872/10000)]: 
██████████████████████████████ 10112/10112 [00:06<00:00 1671.75it/s] [66.81%(6681/10000)]: 


[('baseline_2', 87.17, 61.35),
 ('fgsm', 87.26, 69.03),
 ('top10', 87.64, 67.63),
 ('top20', 86.65, 66.04),
 ('bot10', 87.07, 68.72),
 ('bot20', 86.57, 66.81)]