In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision.models.resnet import BasicBlock, resnet18, resnet50
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from pgd_purify import vae_purify, stae_purify, pgd_linf
from model.nn_model import ResNetEnc, ResNetVAE
import random
import os

def seed_everything(seed: int):
    """Seed every random number generator this notebook uses.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU and all CUDA devices), records the seed in ``PYTHONHASHSEED`` and
    puts cuDNN into its deterministic configuration.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Only affects interpreters/subprocesses started after this point; the
    # hash seed of the already-running kernel is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one (safe no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick different convolution
    # algorithms from run to run, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False

seed_everything(0)

In [2]:
# Compute device: prefer GPU index 2, but fall back to GPU 0 when the host has
# fewer than three GPUs (a bare "cuda:2" would fail there), and to the CPU when
# CUDA is unavailable.
preferred_gpu_index = 2
if torch.cuda.is_available():
    gpu_index = preferred_gpu_index if preferred_gpu_index < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

batch_size = 256               # mini-batch size used by both data loaders
epoch_num = 2048               # number of training epochs
lr_decay_step = 1024           # epoch milestone handed to the MultiStepLR scheduler
classification_weight = 2048   # multiplier on the cross-entropy term of the joint loss
vae_beta = 1                   # multiplier on the KL term of the joint loss

In [3]:
# Evaluation pipeline: tensor conversion only, no augmentation.
transform = transforms.Compose([transforms.ToTensor()])

# Colour jitter + random affine are applied together, as one unit, with
# probability 0.3.
# NOTE(review): ColorJitter() is built with its default arguments; in
# torchvision those defaults are all 0, which looks like a no-op -- confirm
# that this is intended.
color_jitter = transforms.ColorJitter()
random_affine = transforms.RandomAffine(
    (-20, 20),
    translate=(0.0, 0.1),
    scale=(0.9, 1.1),
    fill=0.5,
)
random_transforms_list = transforms.RandomApply(
    torch.nn.ModuleList([color_jitter, random_affine]),
    p=0.3,
)

# Training pipeline: tensor conversion, random horizontal flip, then the
# randomly-applied jitter/affine group.
aug_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=0.5),
    random_transforms_list,
])

# CIFAR-10 splits: augmented training data, un-augmented test data.
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=aug_transform)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Only the training loader shuffles; both use one worker process.
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=1
)
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=1
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Encoder for 32x32 inputs, placed on the target device.
net = ResNetEnc(image_size=32).to(device)
# Wrap the encoder in the VAE model; one .to(device) is enough because moving
# a module that already lives on that device changes nothing.
ResVAE = ResNetVAE(net).to(device)



In [5]:
# Loss modules: cross-entropy for the classifier head and a summed squared
# error module (not referenced by the training loop in this notebook section).
CE_Loss = nn.CrossEntropyLoss()
mseloss = nn.MSELoss(reduction='sum')

# Adam over every parameter of the VAE; the learning rate is multiplied by 0.1
# once the scheduler has been stepped `lr_decay_step` times.
optimizer = optim.Adam(ResVAE.parameters(), lr=1e-4)
lr_milestones = [lr_decay_step]
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=0.1)

In [6]:
test_acc = 0
pbar = tqdm(range(epoch_num))
for epoch in pbar:
    # ---- training pass over the augmented training set ----
    ResVAE.train()
    loss_sum = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        # Full forward pass: reconstruction, latent code, class logits and the
        # two posterior tensors (deterministic=False presumably samples z --
        # confirm against ResNetVAE).
        x_reconst, z, y, mu, log_var = ResVAE(data, deterministic=False, classification_only=False)

        # Summed squared reconstruction error, and the usual closed-form
        # Gaussian KL term computed from mu / log_var.
        recons_loss = ((x_reconst - data) ** 2).sum()
        kld_loss = -0.5 * (1 + log_var - mu ** 2 - log_var.exp()).sum()

        # Joint objective: weighted classification + reconstruction + beta * KL.
        loss_val = CE_Loss(y, target) * classification_weight + recons_loss + vae_beta * kld_loss
        loss_sum += loss_val.item()

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

    # The learning-rate schedule advances once per epoch.
    scheduler.step()

    # ---- evaluation pass over the test set (no gradients needed) ----
    ResVAE.eval()
    pred_list, gt_list = [], []
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data = data.to(device)
            target = target.to(device)
            x_reconst, z, y_test, mu, log_var = ResVAE(data, deterministic=True, classification_only=False)

            # Predicted class = arg-max over the last dimension of the logits.
            pred_list.extend(y_test.argmax(-1).cpu().numpy())
            gt_list.extend(target.cpu().numpy())

    # Fraction of test samples whose prediction matches the label.
    num_correct = np.sum(np.array(gt_list) == np.array(pred_list))
    test_acc = num_correct / len(gt_list)
    pbar.set_postfix({"train loss sum": loss_sum, "test acc": test_acc})

100%|██████████████████████████████████████████████████████████████████████████| 2048/2048 [11:59:16<00:00, 21.07s/it, train loss sum=3.65e+6, test acc=0.935]


In [7]:
# Switch to eval mode, then persist only the weights (state_dict), not the
# whole module object.
ResVAE = ResVAE.eval()
checkpoint_path = './model/cifar_resnet.pth'
torch.save(ResVAE.state_dict(), checkpoint_path)

In [8]:
# Per-class metrics and the confusion matrix, built from whatever gt_list /
# pred_list currently hold (filled by the last evaluation pass that ran).
clean_report = classification_report(gt_list, pred_list)
clean_confusion = confusion_matrix(gt_list, pred_list)
print(clean_report)
print(clean_confusion)

              precision    recall  f1-score   support

           0       0.94      0.95      0.94      1000
           1       0.97      0.97      0.97      1000
           2       0.92      0.91      0.91      1000
           3       0.86      0.86      0.86      1000
           4       0.93      0.93      0.93      1000
           5       0.89      0.89      0.89      1000
           6       0.94      0.97      0.96      1000
           7       0.96      0.94      0.95      1000
           8       0.96      0.97      0.96      1000
           9       0.97      0.95      0.96      1000

    accuracy                           0.94     10000
   macro avg       0.94      0.93      0.93     10000
weighted avg       0.94      0.94      0.93     10000

[[948   2  14   7   2   0   0   4  18   5]
 [  3 974   0   0   1   1   2   1   1  17]
 [ 18   0 912  19  13  16  14   5   3   0]
 [  3   1  15 865  19  65  20   4   6   2]
 [  4   0  18  15 932   9  13   8   1   0]
 [  6   0  10  63  13 890 

In [9]:
# attack and purify
# For every test batch: craft adversarial examples with pgd_linf, classify
# them, run vae_purify on the adversarial batch, and classify the purified
# result.
# NOTE: this rebinds pred_list / gt_list, which the clean-accuracy report cell
# also reads; re-running that cell after this one would report the adversarial
# predictions instead.
pred_list = []
pfy_pred_list = []
gt_list = []
ResVAE = ResVAE.eval()
# tqdm wraps the loader itself (which has a length) rather than enumerate(...),
# so the progress bar can display a total and an ETA.
for batch_idx, (data, target) in enumerate(tqdm(test_loader)):
    data, target = data.to(device), target.to(device)

    # The attack is called outside no_grad (presumably it needs input
    # gradients -- confirm in pgd_purify). data/target are already on `device`.
    adv_vae = pgd_linf(data, target, ResVAE, atk_itr=128, eps=8/255, alpha=1/255, device=device)
    with torch.no_grad():
        y_test = ResVAE(adv_vae, deterministic=True, classification_only=True)

    # Purification is likewise called outside no_grad.
    purify_data_vae = vae_purify(adv_vae, ResVAE, atk_itr=32, eps=8/255, random_iteration=16, device=device)

    with torch.no_grad():
        pfy_y_test = ResVAE(purify_data_vae, deterministic=True, classification_only=True)

    # Arg-max class per sample, before and after purification.
    pred_list += list(y_test.argmax(-1).cpu().numpy())
    pfy_pred_list += list(pfy_y_test.argmax(-1).cpu().numpy())
    gt_list += list(target.cpu().numpy())

print('adversarial acc')
print(classification_report(gt_list, pred_list))
print(confusion_matrix(gt_list, pred_list))
print('purify acc')
print(classification_report(gt_list, pfy_pred_list))
print(confusion_matrix(gt_list, pfy_pred_list))

40it [39:35, 59.38s/it]

adversarial acc
              precision    recall  f1-score   support

           0       0.22      0.20      0.21      1000
           1       0.38      0.41      0.39      1000
           2       0.10      0.13      0.12      1000
           3       0.08      0.08      0.08      1000
           4       0.21      0.16      0.18      1000
           5       0.13      0.15      0.14      1000
           6       0.14      0.13      0.13      1000
           7       0.28      0.27      0.28      1000
           8       0.31      0.25      0.28      1000
           9       0.27      0.32      0.29      1000

    accuracy                           0.21     10000
   macro avg       0.21      0.21      0.21     10000
weighted avg       0.21      0.21      0.21     10000

[[201  42 141  28  51  10 228  22 218  59]
 [ 79 407  42  43   6  60  26  14  62 261]
 [ 79  27 134 103 102  50  83 221  33 168]
 [ 21  60  93  77 165 399  95  37  20  33]
 [ 24   1 222 179 160  97 112 125  30  50]
 [ 16 120 


