# Robustness Experiment Tutorial

## General

This notebook shows examples of how to use the different components of ada-verona.
When these experiments are run outside a notebook, the ExperimentRepository class can be used to
create and organise experiments in a structured manner. Because this notebook is only meant to show the individual components
and their inputs and outputs, it does not use the ExperimentRepository class.
For examples of how to use it, see the example scripts in the scripts/ folder.

## Importing Necessary Components

In [None]:
import logging
from datetime import datetime
from pathlib import Path

import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
from autoverify.verifier import Nnenum

from ada_verona.analysis.report_creator import ReportCreator
from ada_verona.database.dataset.image_file_dataset import ImageFileDataset
from ada_verona.database.dataset.pytorch_experiment_dataset import PytorchExperimentDataset
from ada_verona.database.network import Network
from ada_verona.database.verification_context import VerificationContext
from ada_verona.dataset_sampler.predictions_based_sampler import PredictionsBasedSampler
from ada_verona.epsilon_value_estimator.binary_search_epsilon_value_estimator import (
    BinarySearchEpsilonValueEstimator,
)
from ada_verona.verification_module.auto_verify_module import (
    AutoVerifyModule,
)
from ada_verona.verification_module.property_generator.one2any_property_generator import (
    One2AnyPropertyGenerator,
)

%matplotlib inline
torch.manual_seed(0)
logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.DEBUG)

## Defining Dataset

In [None]:
# define pytorch dataset. Preprocessing can be defined in the transform parameter
torch_dataset = torchvision.datasets.MNIST(root="data", train=False, download=True, transform=transforms.ToTensor())

# wrap pytorch dataset into experiment dataset to keep track of image id
experiment_dataset = PytorchExperimentDataset(dataset=torch_dataset)

# work on a subset of the dataset (the first 10 images) to keep the experiment small
experiment_dataset = experiment_dataset.get_subset(list(range(10)))
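
The experiment dataset wraps the PyTorch dataset so that each sample keeps its original index as an id, which is later used to name the per-image result folders (`image_{data_point.id}`). The following is a minimal, hypothetical sketch of that idea in plain Python; `IdTrackingDataset` and `DataPoint` are illustrative names, not the ada-verona classes, whose internals may differ.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Tuple


@dataclass
class DataPoint:
    """Hypothetical container: a sample plus the index it had in the full dataset."""
    id: int
    data: Any
    label: int


class IdTrackingDataset:
    """Hypothetical stand-in for an experiment dataset that remembers original indices."""

    def __init__(self, samples: Sequence[Tuple[Any, int]], ids: Optional[Sequence[int]] = None):
        self._samples = samples
        # By default every sample keeps its own position as its id
        self._ids: List[int] = list(ids) if ids is not None else list(range(len(samples)))

    def __len__(self) -> int:
        return len(self._ids)

    def __getitem__(self, i: int) -> DataPoint:
        original_id = self._ids[i]
        data, label = self._samples[original_id]
        return DataPoint(id=original_id, data=data, label=label)

    def get_subset(self, indices: Sequence[int]) -> "IdTrackingDataset":
        # Indices are positions in this view; the stored ids stay those of the full dataset
        return IdTrackingDataset(self._samples, [self._ids[i] for i in indices])


samples = [(f"image-{i}", i % 10) for i in range(100)]
subset = IdTrackingDataset(samples).get_subset([5, 7, 9])
print([point.id for point in subset.get_subset([1, 2])])  # [7, 9]
```

Because the ids always refer to the full dataset, even nested subsets map back to the original image indices.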

### Custom dataset

In [None]:
# Alternatively, a custom dataset can be loaded from disk
# using the ImageFileDataset class.

# A preprocessing transform can be passed here as well.
# Note that, as of now, only loading torch tensors from the directory is supported.
preprocessing = transforms.Compose([transforms.Normalize((0.1307,), (0.3081,))])
custom_experiment_dataset = ImageFileDataset(
    image_folder=Path("../tests/test_experiment/data/images"),
    label_file=Path("../tests/test_experiment/data/image_labels.csv"),
    preprocessing=preprocessing,
)
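
`Normalize((0.1307,), (0.3081,))` maps each pixel x to (x - 0.1307) / 0.3081, so inputs that were in [0, 1] end up in a different range. If a property generator operated on normalized inputs, its data bounds and epsilon would have to be expressed in that normalized space. Whether ada-verona perturbs raw or preprocessed inputs is not stated here, so treat this only as a worked example of the arithmetic; `normalized_bounds` and `normalized_epsilon` are hypothetical helpers.

```python
def normalized_bounds(lb: float, ub: float, mean: float, std: float) -> tuple:
    """Map the raw interval [lb, ub] through x -> (x - mean) / std, as Normalize does per channel."""
    return (lb - mean) / std, (ub - mean) / std


def normalized_epsilon(epsilon: float, std: float) -> float:
    """An L-infinity radius scales by 1 / std; the mean shift cancels out."""
    return epsilon / std


lb, ub = normalized_bounds(0.0, 1.0, mean=0.1307, std=0.3081)
print(f"{lb:.4f} {ub:.4f}")  # -0.4242 2.8215
print(f"{normalized_epsilon(0.1, std=0.3081):.4f}")  # 0.3246
```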

## Component Setup

In [None]:
# Time limit for each verification query (in seconds)
timeout = 300

# In this example, a one-to-any property generator is used.
# It creates vnnlib files for one-to-any robustness queries.
# A one-to-one property generator is also implemented in the package and could be used here instead.
# The property generator needs the number of classes
# and the lower and upper bounds of the input data.
# ToTensor scales the MNIST pixels to [0, 1], so these are the bounds used here.
property_generator = One2AnyPropertyGenerator(number_classes=10, data_lb=0, data_ub=1)

# In this example, Nnenum is used as the verifier.
# Any other verifier offered by the autoverify package can be used with the AutoVerifyModule as well.
verifier = AutoVerifyModule(verifier=Nnenum(), timeout=timeout)
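
A one-to-any query asks whether any input in the epsilon-ball around an image can make some other class score at least as high as the true label. The vnnlib file encodes the negation of robustness, so a verifier answering SAT has found a counterexample and UNSAT means the image is robust at that epsilon. The sketch below writes such a query in the VNN-COMP style of VNN-LIB for a flattened input; `one_to_any_vnnlib` is a hypothetical helper, and the files ada-verona actually generates may be laid out differently. It also shows why `data_lb` and `data_ub` matter: the epsilon box is clipped to the valid input range.

```python
from typing import List, Sequence


def one_to_any_vnnlib(image: Sequence[float], label: int, epsilon: float,
                      number_classes: int, data_lb: float, data_ub: float) -> str:
    """Hypothetical sketch of a one-to-any L-infinity robustness query in VNN-LIB syntax."""
    lines: List[str] = [f"(declare-const X_{i} Real)" for i in range(len(image))]
    lines += [f"(declare-const Y_{j} Real)" for j in range(number_classes)]
    # Input region: epsilon-ball around each pixel, clipped to [data_lb, data_ub]
    for i, x in enumerate(image):
        lines.append(f"(assert (<= X_{i} {min(x + epsilon, data_ub):.6f}))")
        lines.append(f"(assert (>= X_{i} {max(x - epsilon, data_lb):.6f}))")
    # Unsafe outputs: at least one other class scores >= the true label
    disjuncts = " ".join(
        f"(and (>= Y_{j} Y_{label}))" for j in range(number_classes) if j != label
    )
    lines.append(f"(assert (or {disjuncts}))")
    return "\n".join(lines)


print(one_to_any_vnnlib([0.2, 0.9], label=0, epsilon=0.05, number_classes=2, data_lb=0.0, data_ub=1.0))
```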

In [6]:
# To compute critical epsilon values, the BinarySearchEpsilonValueEstimator class can be used
epsilon_value_list = [0.001, 0.1, 0.2, 0.3, 0.4]
epsilon_value_estimator = BinarySearchEpsilonValueEstimator(epsilon_value_list=epsilon_value_list, verifier=verifier)
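
Assuming verification results are monotone in epsilon (if an image is verified robust at some epsilon, it is robust at every smaller one), the critical epsilon lies between the largest listed epsilon that verifies and the smallest one with a counterexample, and binary search finds that boundary with O(log n) verifier calls instead of n. The plain-Python sketch below uses a hypothetical `is_robust` callback in place of a real verifier call and ignores timeouts; it illustrates the idea and is not the BinarySearchEpsilonValueEstimator implementation.

```python
from typing import Callable, Optional, Sequence, Tuple


def binary_search_critical_epsilon(
    epsilons: Sequence[float], is_robust: Callable[[float], bool]
) -> Tuple[Optional[float], Optional[float]]:
    """Return (largest epsilon verified robust, smallest epsilon with a counterexample)."""
    eps = sorted(epsilons)
    lo, hi = 0, len(eps) - 1
    last_robust: Optional[float] = None
    first_unsafe: Optional[float] = None
    while lo <= hi:
        mid = (lo + hi) // 2
        if is_robust(eps[mid]):
            # Robust here, so (by monotonicity) everything smaller is robust too
            last_robust = eps[mid]
            lo = mid + 1
        else:
            # Counterexample found, so everything larger fails as well
            first_unsafe = eps[mid]
            hi = mid - 1
    return last_robust, first_unsafe


# Toy stand-in for a verifier: robust for every epsilon below 0.25
print(binary_search_critical_epsilon([0.001, 0.1, 0.2, 0.3, 0.4], lambda e: e < 0.25))  # (0.2, 0.3)
```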

In [7]:
# For this example we take one of the test networks
network = Network(Path("../tests/test_experiment/data/networks/mnist-net_256x2.onnx"))

## Sampling Datapoints

In [8]:
# To compute the robustness of a network, one first has
# to check which data points are classified correctly.
# For that the PredictionsBasedSampler class is used
dataset_sampler = PredictionsBasedSampler(sample_correct_predictions=True)

# Here all the data points that are correctly predicted by the network are sampled
sampled_data = dataset_sampler.sample(network, experiment_dataset)
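
Conceptually, sampling on correct predictions keeps only the data points whose arg-max prediction equals the label, so that the robustness queries start from correctly classified inputs. The sketch below does this on precomputed scores in plain Python; `sample_by_prediction` and its `keep_correct` flag are hypothetical analogues of `sample_correct_predictions`, whereas the real sampler receives a Network and an experiment dataset and presumably runs the model itself.

```python
from typing import List, Sequence


def argmax(scores: Sequence[float]) -> int:
    """Index of the largest score (the first one wins on ties)."""
    return max(range(len(scores)), key=lambda i: scores[i])


def sample_by_prediction(logits: Sequence[Sequence[float]], labels: Sequence[int],
                         keep_correct: bool = True) -> List[int]:
    """Return indices of samples predicted correctly (or incorrectly if keep_correct=False)."""
    return [
        idx
        for idx, (scores, label) in enumerate(zip(logits, labels))
        if (argmax(scores) == label) == keep_correct
    ]


logits = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
labels = [1, 1, 1]
print(sample_by_prediction(logits, labels))  # [0, 2]
print(sample_by_prediction(logits, labels, keep_correct=False))  # [1]
```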

In [None]:
# All 10 images in the subset are predicted correctly by the network
print(f"Size of sampled dataset: {len(sampled_data)}")

## Computing Robustness Distribution

In [None]:
# To compute the critical epsilon value for a given network and data point,
# a verification context is created.
# A folder for intermediate results also has to be passed to the VerificationContext
# so that the vnnlib files can be stored there.
# The results of the individual epsilon queries can be stored there as well.
results = []
now = datetime.now()
now_string = now.strftime("%d-%m-%Y+%H_%M")

# Base directory for the intermediate results (the per-epsilon queries)
intermediate_result_base_path = Path(f"intermediate_results/{now_string}")

network_name = network.path.name.split(".")[0]
for data_point in sampled_data:
    intermediate_result_path = intermediate_result_base_path / network_name / f"image_{data_point.id}"

    verification_context = VerificationContext(
        network,
        data_point,
        intermediate_result_path,
        property_generator=property_generator,
    )
    epsilon_value_result = epsilon_value_estimator.compute_epsilon_value(verification_context)

    print(f"result: {epsilon_value_result}")
    results.append(epsilon_value_result)
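
The loop above writes one folder per data point below a timestamped run directory. The snippet below reproduces that layout with a fixed, made-up timestamp so the resulting path is predictable; `result_folder` is a hypothetical helper, and `Path.stem` yields the same network name as `split(".")[0]` for a file name with a single extension.

```python
from datetime import datetime
from pathlib import Path


def result_folder(base: Path, network_file: Path, image_id: int) -> Path:
    """Hypothetical helper mirroring <base>/<network name>/image_<id>."""
    return base / network_file.stem / f"image_{image_id}"


run_started = datetime(2024, 5, 1, 14, 30)  # fixed example timestamp
base = Path("intermediate_results") / run_started.strftime("%d-%m-%Y+%H_%M")
folder = result_folder(base, Path("mnist-net_256x2.onnx"), image_id=3)
print(folder.as_posix())  # intermediate_results/01-05-2024+14_30/mnist-net_256x2/image_3
```

Since the timestamp has minute resolution, two runs started within the same minute would share a base directory.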

## Create Plots

In [None]:
result_dicts = [x.to_dict() for x in results]
result_df = pd.DataFrame(result_dicts)
# Derive a short network name (file name without extension) from the network path
result_df["network"] = result_df["network_path"].astype(str).str.split("/").str[-1].str.split(".").str[0]

In [None]:
report_creator = ReportCreator(result_df)

In [None]:
report_creator.create_box_figure()

In [None]:
report_creator.create_ecdf_figure()
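
An empirical cumulative distribution function (ECDF) of critical epsilon values gives, for each epsilon, the fraction of data points whose critical epsilon is at most that value, so a curve lying further to the right indicates larger critical epsilons. The pure-Python sketch below computes those points from a list of values; it shows the standard definition and does not reproduce how ReportCreator builds its figure. The example values are made up.

```python
from bisect import bisect_right
from typing import List, Sequence, Tuple


def ecdf(values: Sequence[float]) -> List[Tuple[float, float]]:
    """Return (x, F(x)) for each distinct value, where F(x) = fraction of values <= x."""
    xs = sorted(values)
    n = len(xs)
    return [(x, bisect_right(xs, x) / n) for x in sorted(set(xs))]


print(ecdf([0.2, 0.001, 0.2, 0.4]))  # [(0.001, 0.25), (0.2, 0.75), (0.4, 1.0)]
```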

In [None]:
report_creator.create_hist_figure()

In [None]:
report_creator.create_anneplot()