# Getting Started with MetaQuantus

This notebook demonstrates the functionality of MetaQuantus. For this purpose, we use a pre-trained ResNet-9 model and the customised-MNIST (cMNIST) dataset.

Make sure a GPU runtime is enabled to speed up computation.

#### Metric selection
- Step 1. Load data, model and explanations
- Step 2. Evaluate explanations with Quantus metrics
    - Produce rankings
- Step 3. Select a metric with MetaQuantus
    - Produce different visualisations

#### Hyperparameter tuning
- Step 1. Load data, model and explanations
- Step 2. Choose an estimator with Quantus
- Step 3. Optimise the hyperparameters of a metric

#### Study Convergent Validity
- Step 1. Load data, model and explanations
- Step 2. Choose a category of estimators
- Step 3. Calculate intra-correlation with MetaQuantus

In [4]:
from IPython.display import clear_output
!pip install "quantus[torch]"
!pip install captum
clear_output()

In [7]:
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print('Using device:', torch.cuda.get_device_name(0))
!nvidia-smi

Using device: Tesla T4
Mon Dec 26 14:46:53 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   49C    P8    11W /  70W |      3MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+----------------------------------------------------------------

In [6]:
# Attach Google Drive (Colab only). force_remount=True remounts even when a
# mount already exists, so re-running the cell is safe.
from google.colab import drive

drive.mount(mountpoint='/content/drive', force_remount=True)

Mounted at /content/drive


In [8]:
# Root of the project folder on the mounted Google Drive. Both paths are
# absolute so they do not depend on the current working directory (the
# previous relative assets path only resolved correctly from /content).
path_drive_projects = "/content/drive/MyDrive/Projects/"
path_assets = path_drive_projects + "assets/"
path_results = path_drive_projects + "MetaQuantus/results/"

In [10]:
import sys

# Imported explicitly: later cells use `np`, which would otherwise only be
# available through the wildcard imports below.
import numpy as np

# Import local packages. The checkouts are *prepended* to sys.path so they take
# precedence over any pip-installed copy; quantus is also pip-installed in the
# setup cell, and with append() that installed copy would shadow the local one.
path = "/content/drive/MyDrive/Projects"

sys.path.insert(0, f'{path}/quantus')
import quantus

sys.path.insert(0, f'{path}/MetaQuantus')
import metaquantus
# NOTE(review): wildcard imports hide where names such as ResNet9 and
# setup_xai_methods come from — consider importing them explicitly.
from metaquantus.utils import *
from metaquantus.models import *
import metaquantus.configs

## Step 1. Load data, model and explanations

In [12]:
# Paths.
path_cmnist_model = path_assets + "models/cmnist_resnet9.ckpt"
path_cmnist_assets = path_assets + "test_sets/cmnist_test_set.npy"
s_type = "box"
# Side length of the square cMNIST images. Shared by the mask reshape and the
# estimator settings so the two cannot drift apart.
img_size = 32

# Example for how to reload assets and models to notebook.
model_cmnist = ResNet9(nr_channels=3, nr_classes=10)
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# runtime; load_state_dict then copies the weights into the CPU-built model.
model_cmnist.load_state_dict(torch.load(path_cmnist_model, map_location="cpu"))

# NOTE(review): allow_pickle=True unpickles arbitrary objects — only load
# trusted asset files.
assets_cmnist = np.load(path_cmnist_assets, allow_pickle=True).item()
# .numpy() fails on CUDA tensors; .cpu() is a no-op for tensors already on CPU.
x_batch_cmnist = assets_cmnist["x_batch"].detach().cpu().numpy()
y_batch_cmnist = assets_cmnist["y_batch"].detach().cpu().numpy()
# Masks of the chosen type (s_type), reshaped to one single-channel
# img_size x img_size mask per input.
s_batch_cmnist = assets_cmnist[f"s_batch_{s_type}"].reshape(
    len(x_batch_cmnist), 1, img_size, img_size
)

# Collect all experimental settings in one dictionary.
experimental_settings = {
    "cMNIST": {
        "x_batch": x_batch_cmnist,
        "y_batch": y_batch_cmnist,
        "s_batch": s_batch_cmnist,
        "models": {"ResNet9": model_cmnist},
        # Layer selector kept as a string expression over `model`; presumably
        # evaluated later when building the explanation methods — verify.
        "gc_layers": {"ResNet9": 'list(model.named_modules())[1][1][-6]'},
        "estimator_kwargs": {
            "features": 32 * 2,
            "num_classes": 10,
            "img_size": img_size,
            "percentage": 0.1,
        },
    }
}

In [None]:
# Get explanation methods.
xai_methods = setup_xai_methods(
    gc_layer=dataset_settings[dataset_name]["gc_layers"][model_name],
    img_size=dataset_kwargs["img_size"],
    nr_channels=dataset_kwargs["nr_channels"],
)

In [21]:
from metaquantus.model_perturbation_test import ModelPerturbationTest
from metaquantus.input_perturbation_test import InputPerturbationTest

# Analyser suite: two tests perturb the model (multiplicative noise around 1.0)
# and two perturb the inputs. Each "Resilience" test uses a small noise level
# and each "Adversary" test a large one.
analyser_suite = {
    "Model Resilience Test": ModelPerturbationTest(
        noise_type="multiplicative", mean=1.0, std=0.001, type="Resilience"
    ),
    "Model Adversary Test": ModelPerturbationTest(
        noise_type="multiplicative", mean=1.0, std=2.0, type="Adversary"
    ),
    "Input Resilience Test": InputPerturbationTest(noise=0.001, type="Resilience"),
    "Input Adversary Test": InputPerturbationTest(noise=5.0, type="Adversary"),
}

In [22]:
# Define estimators as {category: {name: (metric instance, lower_is_better)}}.
# All localisation metrics share the same normalisation/aggregation settings.
localisation_kwargs = dict(
    abs=False,
    normalise=True,
    normalise_func=quantus.normalise_func.normalise_by_max,
    return_aggregate=False,
    aggregate_func=np.mean,
    disable_warnings=True,
)

estimators = {
    "Localisation": {
        "Pointing-Game": (quantus.PointingGame(**localisation_kwargs), False),
        "Top-K Intersection": (
            quantus.TopKIntersection(k=10, **localisation_kwargs),
            False,
        ),
        "Relevance Rank Accuracy": (
            quantus.RelevanceRankAccuracy(**localisation_kwargs),
            False,
        ),
        "Relevance Mass Accuracy": (
            quantus.RelevanceMassAccuracy(**localisation_kwargs),
            False,
        ),
    }
}

# Backward-compatible alias for the original (misspelled) variable name; the
# evaluation code reads `estimators`.
estiamtors = estimators

In [None]:


    

    ########################
    # Master run settings. #
    ########################

    # Define metric.
    estimator_category = "Complexity"
    estimator_name = "Sparseness"

    # Define master!
    master = MetaEvaluation(
        analyser_suite=analyser_suite,
        xai_methods=xai_methods,
        iterations=iters,
        nr_perturbations=K,
        write_to_file=False,
    )

    master(
        estimator=estimators[estimator_category][estimator_name][0],
        model=dataset_settings[dataset_name]["models"][model_name],
        x_batch=dataset_settings[dataset_name]["x_batch"],
        y_batch=dataset_settings[dataset_name]["y_batch"],
        a_batch=None,
        s_batch=dataset_settings[dataset_name]["s_batch"],
        channel_first=True,
        softmax=False,
        device=device,
        lower_is_better=estimators[estimator_category][estimator_name][1],
    )