In [None]:
import torch, pickle, argparse, os, warnings, copy, time, mlflow
import numpy as np, pytorch_lightning as pl, matplotlib.pyplot as plt, eagerpy as ep
from models import ConvNet, S2ConvNet
from data_loader import load_test_data
from foolbox import PyTorchModel
from attacks import LinfRandomSearch
from tqdm.notebook import tqdm
from attack_helper import batched_predictions
# from attack_helper import run_batched_attack_cpu, batched_accuracy, batched_predictions_eps, save_pickle, batched_logits_eps
from mlflow.tracking.artifact_utils import get_artifact_uri

In [None]:
# Locate the checkpoint logged by the mlflow run named RUN_NAME in the
# 'model_training' experiment, rebuild the ConvNet from its saved
# hyperparameters, load its weights, and load the test set.
tracking_uri = 'sqlite:///mlruns/database.db'
RUN_NAME = str(1668361949)  # value of tags.mlflow.runName for the training run
mlflow.set_tracking_uri(tracking_uri)

df = mlflow.search_runs(experiment_names=['model_training'])
matching_runs = df[df['tags.mlflow.runName'] == RUN_NAME]
if matching_runs.empty:
    raise ValueError(f"no run named {RUN_NAME!r} in experiment 'model_training'")
run_id = matching_runs['run_id'].values[0]

# NOTE(review): os.listdir assumes the artifact URI is a plain local path — confirm.
artifact_path = get_artifact_uri(run_id=run_id, tracking_uri=tracking_uri)
dirs = os.listdir(artifact_path)

# Select the checkpoint explicitly and fail loudly if there is none; otherwise
# `checkpoint` could be unbound, or stale from a previous execution of this cell.
# os.listdir order is arbitrary, so sort for a deterministic choice.
ckpt_files = sorted(s for s in dirs if '.ckpt' in s)
if not ckpt_files:
    raise FileNotFoundError(f"no .ckpt file found in {artifact_path}")
if len(ckpt_files) > 1:
    warnings.warn(f"several checkpoints found, using {ckpt_files[0]!r}: {ckpt_files}")
checkpoint = ckpt_files[0]
checkpoint_path = os.path.join(artifact_path, checkpoint)

# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
# machine; load_state_dict copies the weights into the model's own parameters.
best_model = torch.load(checkpoint_path, map_location='cpu')
hparams = argparse.Namespace(**best_model['hyper_parameters'])
model = ConvNet(hparams, None, None).eval()
model.load_state_dict(best_model['state_dict'])


TEST_PATH = "s2_mnist_cs1.gz"
test_data = load_test_data(TEST_PATH)

In [None]:
# Alternative (disabled): load the S2ConvNet checkpoint from mlflow run
# 1668637573 instead of the ConvNet checkpoint loaded in the previous cell.
# Both cells define `model` and `test_data`, so run only one of them.
# tracking_uri = 'sqlite:///mlruns/database.db'
# mlflow.set_tracking_uri(tracking_uri)
# df=mlflow.search_runs(experiment_names=['model_training'])
# run_id=df[df['tags.mlflow.runName']==str(1668637573)]['run_id'].values[0]
# artifact_path = get_artifact_uri(run_id=run_id, tracking_uri=tracking_uri)
# dirs=os.listdir(artifact_path)

# for s in dirs:
#     if s.find('.ckpt') >= 0:
#         checkpoint = s
#         break

# checkpoint_path = os.path.join(artifact_path, checkpoint)

# best_model = torch.load(checkpoint_path)
# hparams = argparse.Namespace(**best_model['hyper_parameters'])
# model = S2ConvNet(hparams, None, None).eval()
# model.load_state_dict(best_model['state_dict'])


# TEST_PATH = "s2_mnist_cs1.gz"
# test_data = load_test_data(TEST_PATH)

In [None]:
# Build a small class-balanced evaluation subset: the first N_PER_CLASS test
# examples of each class 0..N_CLASSES-1, grouped by class (all 0s, then 1s, ...).
N_CLASSES = 10
N_PER_CLASS = 10

test_subset = test_data[:]  # index the dataset once, not once per field
images_ = test_subset[0]
labels_ = test_subset[1]

# One mask per class, then a single concatenation each, instead of growing the
# tensors with torch.cat inside two duplicated loops.
class_masks = [labels_ == c for c in range(N_CLASSES)]
images = torch.cat([images_[mask][:N_PER_CLASS] for mask in class_masks])
labels = torch.cat([labels_[mask][:N_PER_CLASS] for mask in class_masks])

# NOTE(review): bounds assume pixel values lie in [0, 255] — confirm against load_test_data.
fmodel = PyTorchModel(model, bounds=(0, 255))

In [None]:
bs = 100  # batch size passed to batched_predictions

# Model predictions on the unperturbed images; the attack cell below uses these
# (not the ground-truth labels) as the labels it tries to flip.
clean_pred = batched_predictions(model, images, bs)

In [None]:
# L-inf perturbation budgets to sweep; presumably in the same 0-255 units as the
# fmodel bounds — confirm.
epsilons = [
    0, 0.1, 0.25, 0.5,
    1, 3, 5, 7.5,
    10, 20, 50, 100,
]
attack = LinfRandomSearch()

In [None]:
# Repeat the stochastic random-search attack N_ATTACK_RUNS times. The criterion
# labels are the model's clean predictions, so a "success" presumably means the
# prediction moved away from the clean prediction.
N_ATTACK_RUNS = 100

# Loop-invariant: move the inputs to the GPU and wrap them once, not every run.
# NOTE(review): assumes LinfRandomSearch does not modify its inputs in place.
images_ep = ep.astensor(images.cuda())
clean_pred_ep = ep.astensor(clean_pred.cuda())

success_runs = []
for _run in tqdm(range(N_ATTACK_RUNS)):
    *_, success_ = attack(fmodel, images_ep, clean_pred_ep, epsilons=epsilons)
    success_runs.append(success_.raw.cpu())

# success_ is presumably (len(epsilons), n_images); stacking the runs and moving
# the run axis last gives (len(epsilons), n_images, N_ATTACK_RUNS).
# TODO: no random seed is set, so the results differ between executions.
success = torch.stack(success_runs).permute(1, 2, 0)

In [None]:
# Fraction of successful attack runs, averaged over the last axis (the run axis
# after the permute in the attack cell). `success` is already a torch tensor,
# so no eagerpy wrapping is needed.
success_per_sample = success.float().mean(dim=-1)

In [None]:
# Average the per-sample rates over the images -> one success rate per epsilon.
success_rate = success_per_sample.mean(dim=-1)
print(success_rate)

In [None]:
# Attack success rate as a function of the L-inf budget epsilon.
fig, ax = plt.subplots()
ax.plot(epsilons, success_rate, 'o-')
# Raw strings: '\e' is an invalid escape sequence in a normal string literal
# (DeprecationWarning / SyntaxWarning); r'...' yields the same text without it.
ax.set_xlabel(r'$\epsilon$')
ax.set_ylabel('success rate')
ax.set_title(r'Attack success rate vs. $\epsilon$ ($L_\infty$)')
plt.show()

In [None]:
# Scratch checks (disabled): plot one adversarial example next to its clean
# image and test, via torch.allclose with atol=epsilon, whether the perturbation
# stays within each epsilon. Note: this attacks the ground-truth `labels`,
# whereas the main attack loop uses `clean_pred`; it also rebinds `success`.
# _, advs, success = attack(fmodel, ep.astensor(images.cuda()), ep.astensor(labels.cuda()), epsilons=epsilons)

# plt.imshow(advs[1][0,0].raw.detach().cpu().numpy(), cmap='gray')
# plt.show()

# plt.imshow(images[0,0].numpy(), cmap='gray')
# plt.show()

# print(torch.all(images[0,0] == advs[0][0,0].raw.detach().cpu()))
# print(torch.all(images[0,0] == advs[1][0,0].raw.detach().cpu()))

# for i in range(1,len(epsilons)):
#     print(torch.allclose(images[0,0], advs[i][0,0].raw.detach().cpu(), rtol=0, atol=epsilons[i]))
#     print(torch.allclose(images[0,0], advs[i][0,0].raw.detach().cpu(), rtol=0, atol=epsilons[i]-0.01))
    
# diff = images[0,0] - advs[1][0,0].raw.detach().cpu()

# plt.imshow(diff, cmap='gray')
# plt.show()

# print(diff)
# print(torch.sum((diff == 0)) / torch.prod(torch.Tensor([*diff.size()])))
# diff2 = images[1,0] - advs[1][1,0].raw.detach().cpu()
# print(diff2)