In [8]:
import random
import sys

import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
import numpy as np

from rl_agent import training, ppo_training, DDPG, SAC, TD3, PPO
from data_util import toDeviceDataLoader, load_cifar, to_device
from model_util import VGG
from util import asr, geo_asr, show_attack, diff_affine, project_lp
from adv_attacks import fgsm, pgd, geometric_attack, pgd_pca
from attack_sim import attack_simulation

# Seed every RNG this notebook may touch (Python `random`, NumPy, PyTorch) so that
# Restart & Run All is repeatable. torch.manual_seed alone leaves any NumPy- or
# `random`-based sampling inside the imported helpers unseeded.
SEED = 43
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # per the PyTorch docs this seeds all devices, CUDA included

# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Machine-specific locations: adjust these two paths for your environment.
# NOTE(review): 'D:/datasets/' is an absolute Windows path; consider a configurable root.
dataset_root = 'D:/datasets/'
model_checkpoint = '../models/torch_cifar_vgg.pth'

# CIFAR-10 splits and their device-bound data loaders (helpers from data_util).
cifar10_train, cifar10_val, cifar10_test = load_cifar(dataset_root)
train_loader, val_loader, test_loader = toDeviceDataLoader(cifar10_train, cifar10_val, cifar10_test, device=device)

# Pretrained VGG16 target model, switched to eval mode for attack evaluation.
mdl = to_device(VGG('VGG16'), device)
# map_location remaps stored tensors onto `device`, so a checkpoint saved from a GPU
# still loads when `device` fell back to CPU.
mdl.load_state_dict(torch.load(model_checkpoint, map_location=device))
mdl = mdl.eval()

In [15]:
# Score `geometric_attack` against the loaded model over the test loader.
# Presumably an attack-success rate in [0, 1], by analogy with `asr` — confirm in util.geo_asr.
geometric_attack_asr = geo_asr(test_loader, mdl, geometric_attack)
geometric_attack_asr

0.2581999897956848

In [9]:
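# Reference ASR values for the base model and the FGSM / PGD attacks (cell disabled;
# the recorded output from a previous run is kept on the last line, in the same order).
# print('ASR Base - {}, ASR FGSM - {}, ASR PGD - {}'.format(asr(test_loader, mdl).item(), asr(test_loader, mdl, adv_alg = fgsm).item(), asr(test_loader, mdl, adv_alg = pgd).item()))
# (tensor(0.0829), tensor(0.9119), tensor(0.9999))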

In [10]:
# Visual spot-check of a PGD attack on one test batch (cell disabled).
# test_x, test_y = next(iter(test_loader))
# v2 = pgd(test_x, test_y, mdl)
# show_attack(test_x, v2, mdl)

In [None]:
# --- Attack environment settings -------------------------------------------
# Values are passed straight to attack_sim.attack_simulation; their exact meaning
# is defined there (NOTE(review): confirm semantics of geometric_xi in attack_sim).
GEOMETRIC_XI = 0.2
IN_PCA_COMPONENTS = 512       # number of PCA components used for input reduction

# --- TD3 agent and training hyperparameters --------------------------------
HIDDEN_SIZES = [256] * 2      # actor-critic hidden layer widths
GAMMA = 0.99                  # discount factor
NUM_TEST_EPISODES = 5
MAX_EP_LEN = 40
STEPS_PER_EPOCH = 200
EPOCHS = 500
START_STEPS = 10000           # presumably random-action warm-up steps — confirm in rl_agent.training
N_RUNS = 1
OUTPUT_DIR = './tmp'          # where rl_agent.training writes its outputs

env = attack_simulation(
    mdl=mdl,
    train_ds=cifar10_train,
    test_ds=cifar10_test,
    geometric_xi=GEOMETRIC_XI,
    input_reduction='pca',
    in_pca_components=IN_PCA_COMPONENTS,
    action_reduction='geometric',
    device=device,
)

agent = TD3(
    env=env,
    ac_kwargs=dict(hidden_sizes=HIDDEN_SIZES),
    gamma=GAMMA,
    num_test_episodes=NUM_TEST_EPISODES,
    max_ep_len=MAX_EP_LEN,
)
training(
    agent=agent,
    dir=OUTPUT_DIR,
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=EPOCHS,
    n_runs=N_RUNS,
    start_steps=START_STEPS,
)