In [1]:
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
import numpy as np
import sys
from rl_agent import training, ppo_training, DDPG, SAC, TD3, PPO
from data_util import toDeviceDataLoader, load_cifar, to_device
from model_util import VGG
from util import asr, geo_asr, show_attack, diff_affine, project_lp
from adv_attacks import fgsm, pgd, geometric_attack, pgd_pca
from attack_sim import attack_simulation

# Seed every RNG library this notebook imports so reruns are comparable.
# Previously only torch was seeded; numpy's global generator was left
# unseeded, so any numpy-based sampling differed from run to run.
SEED = 43
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
# Root directory handed to load_cifar. NOTE(review): this is an absolute,
# machine-specific path and must be adjusted on other machines.
dataset_root = 'D:/datasets/'
cifar10_train, cifar10_val, cifar10_test = load_cifar(dataset_root)
train_loader, val_loader, test_loader = toDeviceDataLoader(cifar10_train, cifar10_val, cifar10_test, device = device)

# Build the VGG16 network on the selected device and restore its saved weights.
# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# written from a CUDA session still loads when only the CPU is available.
mdl = to_device(VGG('VGG16'), device)
mdl.load_state_dict(torch.load('../models/torch_cifar_vgg.pth', map_location=device))
# Switch the network to evaluation mode before it is used below.
mdl = mdl.eval()

In [15]:
# Evaluate geometric_attack against mdl over test_loader; the value is shown
# as the cell output (presumably an attack success rate -- confirm in util.geo_asr).
geometric_attack_score = geo_asr(test_loader, mdl, geometric_attack)
geometric_attack_score

0.2581999897956848

In [9]:
# print('ASR Base - {}, ASR FGSM - {}, ASR PGD - {}'.format(asr(test_loader, mdl).item(), asr(test_loader, mdl, adv_alg = fgsm).item(), asr(test_loader, mdl, adv_alg = pgd).item()))
# Output recorded from an earlier run, presumably in the order (base, FGSM, PGD): (tensor(0.0829), tensor(0.9119), tensor(0.9999))

In [10]:
# test_x, test_y = next(iter(test_loader))
# v2 = pgd(test_x, test_y, mdl)
# show_attack(test_x, v2, mdl)

In [3]:
# Settings for the attack-simulation environment, gathered in one mapping
# and expanded into keyword arguments when the environment is constructed.
env_settings = dict(
    mdl=mdl,
    train_ds=cifar10_train,
    test_ds=cifar10_test,
    geometric_xi=0.2,
    input_reduction='pca',
    in_pca_components=512,
    action_reduction='geometric',
    device=device,
)
env = attack_simulation(**env_settings)

In [4]:
# TD3 agent bound to `env`; the actor-critic receives hidden_sizes of [256, 256].
td3_network_kwargs = {'hidden_sizes': [256, 256]}
agent = TD3(
    env=env,
    ac_kwargs=td3_network_kwargs,
    gamma=0.99,
    num_test_episodes=5,
    max_ep_len=40,
)

# Launch a single training run; './tmp' is passed as the `dir` argument.
training(
    agent=agent,
    dir='./tmp',
    steps_per_epoch=200,
    epochs=500,
    n_runs=1,
    start_steps=10000,
)

run 0
1 [-0.0024696842301636934, 1.025192341330694e-05, 0.013367761857807636, 0.0001920275535667315, 0.00015919574070721865]
2 [0.023434162139892578, 8.344636626134161e-07, 0.5767449140548706, 0.00011824862303910777, 2.622600504764705e-06]
3 [0.715265154838562, -6.532447878271341e-05, 0.007919995114207268, 9.05985689314548e-06, 0.6970409154891968]
4 [0.0004459800838958472, -4.2914925870718434e-06, 0.000505319272633642, 0.6567240953445435, -0.0002577176783233881]
5 [-0.11596108973026276, 7.986986929608975e-06, -0.0112962881103158, 0.013266916386783123, 0.8490469455718994]
6 [0.11355499923229218, 0.0017153569497168064, -2.381903648376465, 1.5199663639068604, 0.01636498235166073]
7 [0.0032965634018182755, 0.021989166736602783, 2.1177315711975098, 0.08442249894142151, 0.007910056039690971]
8 [0.7852991819381714, 0.00013314727402757853, -0.8347034454345703, 0.00207909825257957, 2.980217686854303e-06]
9 [3.457058937783586e-06, 0.7147964835166931, 4.3867839849554e-05, 0.8059764504432678, 0.88

KeyboardInterrupt: 