In [None]:

import logging
import torch
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")  # timestamped, INFO and above
torch.manual_seed(0)  # fix torch's global RNG so repeated runs are reproducible
from pathlib import Path


import torchvision
import torchvision.transforms as transforms

from ada_verona.database.dataset.experiment_dataset import ExperimentDataset
from ada_verona.database.dataset.pytorch_experiment_dataset import PytorchExperimentDataset
from ada_verona.database.experiment_repository import ExperimentRepository
from ada_verona.dataset_sampler.dataset_sampler import DatasetSampler
from ada_verona.dataset_sampler.predictions_based_sampler import PredictionsBasedSampler
from ada_verona.epsilon_value_estimator.binary_search_epsilon_value_estimator import (
    BinarySearchEpsilonValueEstimator,
)
from ada_verona.epsilon_value_estimator.epsilon_value_estimator import EpsilonValueEstimator
from ada_verona.verification_module.attack_estimation_module import AttackEstimationModule
from ada_verona.verification_module.attacks.fgsm_attack import FGSMAttack
from ada_verona.verification_module.property_generator.one2any_property_generator import (
    One2AnyPropertyGenerator,
)
from ada_verona.verification_module.property_generator.property_generator import PropertyGenerator
from ada_verona.database.machine_learning_model.pytorch_network import PyTorchNetwork
from ada_verona.database.machine_learning_model.torch_model_wrapper import TorchModelWrapper

import torch
from onnx2torch import convert

import matplotlib.pyplot as plt

import timm as timm
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Network under evaluation: a timm CLIP ViT-L/14 checkpoint fine-tuned on ImageNet-1k.
MODEL_NAME = "vit_large_patch14_clip_224.openai_ft_in1k"
backbone = timm.create_model(MODEL_NAME, pretrained=True)

# The dataset transform only does Resize + ToTensor, so images reach the network in raw
# [0, 1] pixel space. Prepend the checkpoint's own mean/std normalisation to the network so
# that (a) the classifier sees the input statistics it was trained with and (b) the epsilon
# perturbations searched later stay expressed in un-normalised pixel units.
data_config = timm.data.resolve_data_config({}, model=backbone)
model = torch.nn.Sequential(
    transforms.Normalize(mean=data_config["mean"], std=data_config["std"]),
    backbone,
).eval()  # inference mode: dropout / drop-path layers (if any) are disabled

# Input shape is (batch=1, *input_size) as reported by the checkpoint's data config.
network = PyTorchNetwork(model, [1, *data_config["input_size"]], MODEL_NAME)

In [None]:
# Candidate epsilons for the search: from 0.0 up to (excluding) 0.4 in steps of 0.0039.
epsilon_list = np.arange(0.00, 0.4, 0.0039)
experiment_repository_path = Path("test_experiment")

# ImageNet validation split, resized to 224x224 and converted to tensors (no normalisation here).
imagenet_transform = transforms.Compose([transforms.Resize([224, 224]), transforms.ToTensor()])
torch_dataset = torchvision.datasets.ImageNet(root="./data", split="val", transform=imagenet_transform)
dataset = PytorchExperimentDataset(dataset=torch_dataset)

In [None]:
# Experiment setup: result repository, attack-based verifier, epsilon search and data sampling.
experiment_repository = ExperimentRepository(base_path=experiment_repository_path, network_folder=None)

experiment_name = "vit_large_patch14_clip_224.openai_ft_in1k-FGSM"
# Both the verifier and the sampler take a top-k setting; keep it in one constant so the
# two cannot drift apart when the value is tuned.
TOP_K = 5
SAMPLE_CORRECT_PREDICTIONS = True

property_generator = One2AnyPropertyGenerator()

attack = FGSMAttack()
verifier = AttackEstimationModule(attack=attack, top_k=TOP_K)

# Binary search over the candidate epsilons. The estimator gets a copy, so any in-place
# changes it might make do not affect epsilon_list (which is also saved below).
epsilon_value_estimator = BinarySearchEpsilonValueEstimator(
    epsilon_value_list=epsilon_list.copy(), verifier=verifier
)
dataset_sampler = PredictionsBasedSampler(
    sample_correct_predictions=SAMPLE_CORRECT_PREDICTIONS, top_k=TOP_K
)

experiment_repository.initialize_new_experiment(experiment_name)
# Record every setting that influences the results so the run can be traced and reproduced.
experiment_repository.save_configuration(
    dict(
        experiment_name=experiment_name,
        experiment_repository_path=str(experiment_repository_path),
        dataset=str(dataset),
        epsilon_list=[str(x) for x in epsilon_list],
        attack=type(attack).__name__,
        top_k=TOP_K,
        sample_correct_predictions=SAMPLE_CORRECT_PREDICTIONS,
    )
)

# Select the data points to evaluate, filtered by the network's predictions.
sampled_data = dataset_sampler.sample(network, dataset)

In [None]:
# For each sampled point: build its verification context, run the epsilon search on it,
# and persist the resulting epsilon value in the experiment repository.
for data_point in sampled_data:
    verification_context = experiment_repository.create_verification_context(
        network,
        data_point,
        property_generator,
    )
    epsilon_value_result = epsilon_value_estimator.compute_epsilon_value(
        verification_context
    )
    experiment_repository.save_result(epsilon_value_result)
