# Load library from Google Drive

In [None]:
# Mount Google Drive so the project folder (helper library + pretrained
# weights) is reachable from the Colab VM. force_remount=True requests a fresh
# mount even if Drive is already mounted.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Project folder on Drive. Keep the trailing '/': later cells append file
# names directly to this string.
PATH_DIR = '/content/drive/MyDrive/XAI-Anna-Carlos/'

import sys
# Guard the append so re-running this cell does not stack duplicate entries
# onto sys.path.
if PATH_DIR not in sys.path:
    sys.path.append(PATH_DIR)

import xai_faithfulness_experiments_lib as ff

Mounted at /content/drive


## Load data

In [None]:
import torch
import torchvision

batch_size = 64

# Normalisation constants commonly used for MNIST (per-channel mean / std),
# shared by both splits so train and test preprocessing cannot drift apart.
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
DATA_ROOT = './files/'

mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(DATA_ROOT, train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=True)

# The test loader is deliberately NOT shuffled. Later cells pick
# example_data[example_num] from the first batch. An unseeded shuffle meant a
# different image was analysed on every run, so results were not reproducible.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(DATA_ROOT, train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=False)

# First test batch: the pool of example images/labels used below.
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./files/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./files/MNIST/raw/train-images-idx3-ubyte.gz to ./files/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./files/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./files/MNIST/raw/train-labels-idx1-ubyte.gz to ./files/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./files/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./files/MNIST/raw/t10k-images-idx3-ubyte.gz to ./files/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./files/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./files/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./files/MNIST/raw



## Load pretrained model

In [None]:
# Weights file lives in the same Drive folder as the helper library.
PATH_PRETRAINED = f'{PATH_DIR}mnist-classifier.pth'
network = ff.load_pretrained_model(PATH_PRETRAINED)

## Compute faithfulness measures for a given attribution ranking

In [None]:
import numpy as np

# Seed for the placeholder ranking, so the measures computed below are
# reproducible across runs.
RANKING_SEED = 0

example_num = 10
image = example_data[example_num]
label = example_targets[example_num]

# Create a random ranking for testing purposes: one uniform [0, 1) score per
# element of image.shape[1:] (i.e. the leading dim, presumably the channel, is
# dropped). A seeded Generator replaces the unseeded global np.random state.
rng = np.random.default_rng(RANKING_SEED)
some_ranking = rng.random(tuple(image.shape[1:]))

In [None]:
# Faithfulness measures for the ranking above. with_inverse=True additionally
# yields the "_inv" entries (e.g. 'output_curve_inv' in the output below).
measures = ff.get_measures_for_attributions(
    image, some_ranking, label, network,
    with_inverse=True,
)
print(measures)

{'output_curve': array([-2.3258843e+00, -1.8889745e+00, -2.1040361e+00, -2.2863321e+00,
       -2.4401395e+00, -2.1139369e+00, -1.7141179e+00, -2.0613632e+00,
       -1.1738228e+00, -5.6077927e-01, -1.4116015e-01, -8.9424133e-02,
       -4.9353648e-02, -2.2094758e-02, -1.5155052e-02, -9.4965147e-03,
       -5.7075149e-03, -4.8965542e-03, -2.6477063e-03, -2.3115363e-03],
      dtype=float32), 'is_hit_curve': array([False,  True, False, False, False, False, False, False,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True]), 'mean': -0.95058167, 'at_first_argmax': -1.8889745, 'auc': -17.847538, 'output_curve_inv': array([-2.3207054e+00, -2.1665447e+00, -2.0572743e+00, -1.4884809e+00,
       -7.4710906e-01, -6.9540399e-01, -5.5270916e-01, -2.9220417e-01,
       -2.8339130e-01, -2.1383408e-01, -1.4586502e-01, -4.1034106e-02,
       -2.5854582e-02, -2.1281697e-02, -1.2065759e-02, -1.0044858e-02,
       -7.0660221e-03, -3.5049217e-03, -3.5438850e-

  return F.log_softmax(x)
