In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

PATH_DIR = '/content/drive/MyDrive/XAI-Anna-Carlos/'

import sys
sys.path.append(PATH_DIR)

import xai_faithfulness_experiments_lib as ff

Mounted at /content/drive


## Load model

In [5]:
PATH_PRETRAINED = PATH_DIR + 'mnist-classifier.pth'
network = ff.load_pretrained_model(PATH_PRETRAINED)

## Load file in old format

In [18]:
filename = PATH_DIR + 'Old data/genetic_generated.npz'
old_file = ff.load_generated_data_old_format(filename)

In [19]:
old_file.keys()

dict_keys(['image', 'label', 'rankings', 'qmeans', 'qmean_invs', 'qargmaxs', 'qargmax_invs', 'qaucs', 'qauc_invs', 'output_curves', 'is_hit_curves', 'output_curves_inv', 'is_hit_curves_inv'])

## Compute the measures that were not in the file

In [21]:
from tqdm import tqdm
import numpy as np
curves = old_file['output_curves']
hit_curves = old_file['is_hit_curves']
curves_inv = old_file['output_curves_inv']
hit_curves_inv = old_file['is_hit_curves_inv']

qaucs = np.zeros(curves.shape[0])
qauc_invs = np.zeros(curves.shape[0])
qargmaxs = np.zeros(curves.shape[0])
qargmax_invs = np.zeros(curves.shape[0])

for i in tqdm(range(curves.shape[0])):
    qaucs[i] = ff.measure_auc(curves[i])
    qauc_invs[i] = qaucs[i] - ff.measure_auc(curves_inv[i])
    qargmaxs[i] = ff.measure_output_at_first_argmax(curves[i], hit_curves[i])
    qargmax_invs[i] = qargmaxs[i] - ff.measure_output_at_first_argmax(curves_inv[i], hit_curves_inv[i])

100%|██████████| 500000/500000 [00:15<00:00, 32282.74it/s]


In [22]:
np.savez(PATH_DIR + 'genetic_generated.npz', \
         image=old_file['image'], \
         label=old_file['label'], \
         rankings=old_file['rankings'], \
         qmeans=old_file['qmeans'], \
         qmean_invs=old_file['qmean_invs'], \
         qargmaxs=qargmaxs, \
         qargmax_invs=qargmax_invs, \
         qaucs=qaucs, \
         qauc_invs=qauc_invs, \
         output_curves=curves, \
         is_hit_curves=hit_curves, \
         output_curves_inv=curves_inv, \
         is_hit_curves_inv=hit_curves_inv)

# TEST

In [17]:
new_file = ff.load_generated_data(PATH_DIR + 'random_generated.npz')
print(dict(new_file))

{'image': array([[[-0.42421296, -0.42421296, -0.42421296, -0.42421296,
         -0.42421296, -0.42421296, -0.42421296, -0.42421296,
         -0.42421296, -0.42421296, -0.42421296, -0.42421296,
         -0.42421296, -0.42421296, -0.42421296, -0.42421296,
         -0.42421296, -0.42421296, -0.42421296, -0.42421296,
         -0.42421296, -0.42421296, -0.42421296, -0.42421296,
         -0.42421296, -0.42421296, -0.42421296, -0.42421296],
        [-0.42421296, -0.42421296, -0.42421296, -0.42421296,
         -0.42421296, -0.42421296, -0.42421296, -0.42421296,
         -0.42421296, -0.42421296, -0.42421296, -0.42421296,
         -0.42421296, -0.42421296, -0.42421296, -0.42421296,
         -0.42421296, -0.42421296, -0.42421296, -0.42421296,
         -0.42421296, -0.42421296, -0.42421296, -0.42421296,
         -0.42421296, -0.42421296, -0.42421296, -0.42421296],
        [-0.42421296, -0.42421296, -0.42421296, -0.42421296,
         -0.42421296, -0.42421296, -0.42421296, -0.42421296,
         -0.