In [1]:
import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable
from torchvision import datasets, transforms
from model.VAE import *
from blackbox_pgd_model.wideresnet_update import *
from pgd_attack import *
import torch.optim as optim
import numpy as np
from util import *
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Settings that an earlier script version read from argparse are fixed here as
# plain module-level names so every cell can use them directly.

# --- Reproducibility -------------------------------------------------------
# Seed the CPU and CUDA generators and switch cuDNN to its deterministic mode.
torch.manual_seed(1)
torch.cuda.manual_seed(1)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# --- Device ----------------------------------------------------------------
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Extra DataLoader options are only passed when a GPU is available.
if use_cuda:
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

# --- Hyper-parameters ------------------------------------------------------
batch_size = 200
test_batch_size = 200
beta = 0.5

# --- Data ------------------------------------------------------------------
# The same tensor-only transform is applied to both splits.
transform_test = transforms.Compose([transforms.ToTensor()])

trainset = torchvision.datasets.MNIST(
    root='../data', train=True, download=True, transform=transform_test)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, **kwargs)

testset = torchvision.datasets.MNIST(
    root='../data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=test_batch_size, shuffle=False, **kwargs)

def train(vae_model, c_model, data_loader, vae_optimizer, c_optimizer, epoch_num, vae_epochs=30):
    """Run one training epoch over ``data_loader``.

    Training is staged: while ``epoch_num <= vae_epochs`` only the VAE is
    optimised (with ``v_loss``); afterwards only the classifier is optimised
    (with ``c_loss``). The classifier always sees the VAE's fourth output
    detached and reshaped to ``(-1, 160, 8, 8)``, so its loss never
    back-propagates into the VAE.

    Args:
        vae_model: module whose forward returns ``(x_hat, mean, log_v, x_)``.
        c_model: classifier applied to the reshaped ``x_``.
        data_loader: iterable of ``(data, target)`` batches.
        vae_optimizer: optimizer stepped during the VAE stage.
        c_optimizer: optimizer stepped during the classifier stage.
        epoch_num: index of the current epoch, used to choose the stage.
        vae_epochs: last epoch (inclusive) of the VAE stage; defaults to 30.

    Returns:
        ``(v_loss_sum, c_loss_sum)``: the per-batch losses summed over the
        epoch. The sums carry no autograd history. Both are the integer ``0``
        when ``data_loader`` yields no batches.
    """
    vae_model.train()
    c_model.train()
    v_loss_sum = 0
    c_loss_sum = 0
    train_vae = epoch_num <= vae_epochs
    for data, target in data_loader:
        data, target = data.to(device), target.to(device)
        vae_optimizer.zero_grad()
        c_optimizer.zero_grad()
        x_hat, mean, log_v, x_ = vae_model(data)
        logit = c_model(x_.detach().view(-1, 160, 8, 8))
        v_loss, c_loss = loss_function_mean(data, target, x_hat, mean, log_v, logit)
        # Accumulate detached values: adding the live loss tensors would keep a
        # reference to every batch's autograd graph for the whole epoch.
        v_loss_sum += v_loss.detach()
        c_loss_sum += c_loss.detach()
        # Only back-propagate the loss whose optimizer is stepped in this
        # stage; gradients of the other loss would be zeroed unused at the
        # start of the next iteration.
        if train_vae:
            v_loss.backward()
            vae_optimizer.step()
        else:
            c_loss.backward()
            c_optimizer.step()
    return v_loss_sum, c_loss_sum

def eval_train(vae_model, c_model):
    """Count and print the classifier's mistakes on ``train_loader``.

    Both models are put in eval mode and the whole pass runs without gradient
    tracking. The classifier receives the VAE's fourth output reshaped to
    ``(-1, 160, 8, 8)``. Nothing is returned; the count is only printed.
    """
    vae_model.eval()
    c_model.eval()
    err_num = 0
    with torch.no_grad():
        for batch in train_loader:
            data = batch[0].to(device)
            target = batch[1].to(device)
            _, _, _, features = vae_model(data)
            logit = c_model(features.detach().view(-1, 160, 8, 8))
            # Index of the largest logit per row is the predicted class.
            predicted = logit.data.max(1)[1]
            wrong = predicted.ne(target.data)
            err_num += wrong.float().sum()
    print('train error num:{}'.format(err_num))
def eval_test(vae_model, c_model):
    """Count and print the classifier's mistakes on ``test_loader``.

    Both models are put in eval mode and the whole pass runs without gradient
    tracking. The classifier receives the VAE's fourth output reshaped to
    ``(-1, 160, 8, 8)``. Nothing is returned; the count is only printed.
    """
    vae_model.eval()
    c_model.eval()
    err_num = 0
    with torch.no_grad():
        for batch in test_loader:
            data = batch[0].to(device)
            target = batch[1].to(device)
            _, _, _, features = vae_model(data)
            logit = c_model(features.detach().view(-1, 160, 8, 8))
            # Index of the largest logit per row is the predicted class.
            predicted = logit.data.max(1)[1]
            wrong = predicted.ne(target.data)
            err_num += wrong.float().sum()
    print('test error num:{}'.format(err_num))


In [2]:
def testtime_update(vae_model, c_model, x_adv, target, learning_rate=0.1, num = 30, mode = 'mean'):
    x_adv = x_adv.detach()
    x_hat_adv, mean, log_v, x_ = vae_model(x_adv)
    for _ in range(num):
        if (x_hat_adv != x_hat_adv).sum() > 0:
            print('nan Error')
            exit()
        if mode == 'mean':
            loss = nn.functional.binary_cross_entropy(x_hat_adv, x_adv, size_average=False, reduction='mean')
            # loss = vae_loss_mean(x_adv, x_hat_adv, mean, log_v)
        else:
            loss = nn.functional.binary_cross_entropy(x_hat_adv, x_adv, reduction='sum')
            # loss = vae_loss_sum(x_adv, x_hat_adv, mean, log_v)
        # x_.retain_grad()
        mean.retain_grad()
        log_v.retain_grad()
        loss.backward(retain_graph=True)
        with torch.no_grad():
            # x_.data -= learning_rate * x_.grad.data
            mean.data -= learning_rate * mean.grad.data
            log_v.data -= learning_rate * log_v.grad.data
        # x_.grad.data.zero_()
        mean.grad.data.zero_()
        log_v.grad.data.zero_()
        x_hat_adv = vae_model.decoder(vae_model.reparameterize(mean, log_v))
        # x_hat_adv = vae_model.re_forward(x_)
    x_cat = torch.cat((mean, log_v), 1)
    logit_adv = c_model(x_cat)
        # print((logit_adv.data.max(1)[1] != target.data).float().sum())
    return logit_adv

def _latent_logits(vae_model, c_model, x):
    """Return ``c_model`` logits for ``x`` from the VAE's concatenated ``(mean, log_v)``."""
    _, mean, log_v, _ = vae_model(x)
    return c_model(torch.cat((mean, log_v), 1))


def _num_errors(logits, target):
    """Count rows of ``logits`` whose largest entry is not at index ``target``."""
    return int((logits.max(1)[1] != target).sum().item())


def test(vae_model, c_model, max_batches=1):
    """Evaluate natural and PGD accuracy, with and without ``testtime_update``.

    For each evaluated batch of ``test_loader`` four sets of logits are
    computed, in this order: natural input, PGD example from ``pgd_mnist``,
    natural input after ``testtime_update`` (lr 0.01, 20 steps), and PGD
    example after ``testtime_update`` (lr 1.0, 1 step). Error counts are
    accumulated over the evaluated batches and printed at the end.

    Args:
        vae_model: VAE whose forward returns ``(x_hat, mean, log_v, x_)``.
        c_model: classifier taking the concatenated ``(mean, log_v)``.
        max_batches: number of batches to evaluate. The default of 1 keeps the
            previous behaviour, where the function returned after the first
            batch. ``None`` evaluates the whole loader.

    Returns:
        ``(logit, logit_adv, logit_nat_new, logit_adv_new, target)``
        concatenated over the evaluated batches. The logits are detached so
        that no autograd graph is kept alive across batches. Returns ``None``
        when no batch was evaluated.
    """
    from itertools import islice

    c_model.eval()
    vae_model.eval()
    labels = ('natural', 'adversarial', 'natural + update', 'adversarial + update')
    err_counts = [0] * len(labels)
    collected = [[] for _ in labels]
    targets = []

    for data, target in islice(test_loader, max_batches):
        data, target = data.to(device), target.to(device)
        # The input is handed on as a fresh leaf that requires grad.
        data = data.detach().requires_grad_(True)

        logit = _latent_logits(vae_model, c_model, data)
        # NOTE(review): 40, 0.3, 0.01 are presumably steps / epsilon / step
        # size - confirm against pgd_attack.pgd_mnist.
        x_adv = pgd_mnist(vae_model, c_model, data, target, 40, 0.3, 0.01)
        logit_adv = _latent_logits(vae_model, c_model, x_adv)
        logit_nat_new = testtime_update(vae_model, c_model, data, target, learning_rate=0.01, num=20)
        logit_adv_new = testtime_update(vae_model, c_model, x_adv, target, learning_rate=1.0, num=1)

        batch_logits = (logit, logit_adv, logit_nat_new, logit_adv_new)
        for slot, value in enumerate(batch_logits):
            value = value.detach()
            collected[slot].append(value)
            err_counts[slot] += _num_errors(value, target)
        targets.append(target)

    num_samples = sum(t.shape[0] for t in targets)
    print('evaluated {} of {} test samples'.format(num_samples, len(test_loader.dataset)))
    for label, count in zip(labels, err_counts):
        print('{} errors: {}'.format(label, count))

    if not targets:
        return None
    logits_all = [torch.cat(chunks) for chunks in collected]
    return logits_all[0], logits_all[1], logits_all[2], logits_all[3], torch.cat(targets)

In [3]:

# Build the VAE and the classifier. The classifier input is twice the latent
# size because it consumes the concatenated (mean, log_v) pair.
vae_model = VAE(zDim=256).to(device)
c_model = classifier(input_dim=256*2).to(device)
vae_model_path = './model-checkpoint/mnist-vae-model-54.pt'
c_model_path = './model-checkpoint/mnist-c-model-54.pt'

# map_location=device lets the checkpoints load on whichever device was
# selected above, including CPU-only machines when they were saved from a GPU.
vae_model.load_state_dict(torch.load(vae_model_path, map_location=device))
c_model.load_state_dict(torch.load(c_model_path, map_location=device))


In [4]:
# Run the evaluation once and keep the returned logits and labels for the
# inspection cells that follow.
logit, logit_adv,logit_nat_new, logit_adv_new, target = test(vae_model, c_model)



In [7]:
# How many samples of the returned batch are misclassified on natural inputs
# (arg-max logit differs from the label), followed by the labels themselves.
print(logit.data.max(1)[1].ne(target).sum())
print(target)

tensor(176, device='cuda:0')
tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5,
        4, 0, 7, 4, 0, 1, 3, 1, 3, 4, 7, 2, 7, 1, 2, 1, 1, 7, 4, 2, 3, 5, 1, 2,
        4, 4, 6, 3, 5, 5, 6, 0, 4, 1, 9, 5, 7, 8, 9, 3, 7, 4, 6, 4, 3, 0, 7, 0,
        2, 9, 1, 7, 3, 2, 9, 7, 7, 6, 2, 7, 8, 4, 7, 3, 6, 1, 3, 6, 9, 3, 1, 4,
        1, 7, 6, 9, 6, 0, 5, 4, 9, 9, 2, 1, 9, 4, 8, 7, 3, 9, 7, 4, 4, 4, 9, 2,
        5, 4, 7, 6, 7, 9, 0, 5, 8, 5, 6, 6, 5, 7, 8, 1, 0, 1, 6, 4, 6, 7, 3, 1,
        7, 1, 8, 2, 0, 2, 9, 9, 5, 5, 1, 5, 6, 0, 3, 4, 4, 6, 5, 4, 6, 5, 4, 5,
        1, 4, 4, 7, 2, 3, 2, 7, 1, 8, 1, 8, 1, 8, 5, 0, 8, 9, 2, 5, 0, 1, 1, 1,
        0, 9, 0, 3, 1, 6, 4, 2], device='cuda:0')


In [21]:
# Pick one sample of the returned batch and show its label next to the
# logits computed on the natural input.
idx = 2
print(target[idx])
print(logit[idx])

tensor(1, device='cuda:0')
tensor([-5.1969, 10.0791, -3.8976, -4.7219, -3.6014, -5.1634, -3.6767, -1.5791,
        -2.7491, -4.7115], device='cuda:0', grad_fn=<SelectBackward0>)


In [23]:
# Logits for the same sample after testtime_update was applied to the natural input.
print(logit_nat_new[idx])

tensor([-5.0447, 10.0288, -3.9240, -4.8360, -3.6265, -5.2876, -3.6403, -1.7288,
        -2.5805, -4.4675], device='cuda:0', grad_fn=<SelectBackward0>)


In [22]:
# Logits for the same sample on the PGD example, before and after testtime_update.
print(logit_adv[idx])
print(logit_adv_new[idx])

tensor([-7.2340, -4.7403, -6.8522, -4.4464, -6.8310, -5.9665, -4.9337, -8.9789,
        17.0630, -5.8113], device='cuda:0', grad_fn=<SelectBackward0>)
tensor([ 3.3812, -4.5472, -7.6680, -6.7861, -5.0538, -7.6097,  3.8488, -6.2713,
        -0.3787, -3.1988], device='cuda:0', grad_fn=<SelectBackward0>)
