In [None]:
# %% [markdown]
# # Measure performance
# This notebook loads files with precomputed measures (*qmeans*, *qbas* and *qinv*) for a set of rankings of a given dataset instance, and then measures the performance of the different alternative measures.
# 
# ## 1. Load libraries, model and data

# %%

# Import the necessary libraries
import sys
import os
PROJ_DIR = os.path.realpath(os.path.dirname(os.path.dirname(os.path.abspath(''))))
sys.path.append(os.path.join(PROJ_DIR,'src'))
import xai_faithfulness_experiments_lib_edits as fl
import numpy as np
from typing import Optional
from matplotlib import pyplot as plt
from quantus import AttributionLocalisation

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
DATASET = 'cmnist'
MODEL_NAME = 'resnet18'
REFERENCE_MODEL_NAME = MODEL_NAME#'resnet50w'
GENERATION = '_randomattr'
TARGET_MEASURE = 'AttributionLocalisation'#'faithfulness_correlation'
# Upper bound on how many rankings of each results file are inspected/plotted.
MAX_RANKINGS_PER_FILE = 5
at = AttributionLocalisation(abs=False, normalise=False, disable_warnings=True)
network = fl.load_pretrained_cmnist_resnet18_model(os.path.join(PROJ_DIR,'assets','models','cmnist-resnet18.pth'))

# Per-channel constants used only to de-normalise the input image for display
# (image * DATA_STD + DATA_MEAN).
# NOTE(review): these numbers look like the commonly quoted CIFAR-100
# statistics with mean and std swapped -- confirm against the transform that
# was used when the model was trained before relying on the rendered colours.
DATA_MEAN = [0.2675, 0.2565, 0.2761]
DATA_STD = [0.5071, 0.4867, 0.4408]

# Accumulators of (localization, derived value) pairs, one entry per ranking.
q_invs = []       # derived value: localization - localization of the inverse ranking
q_invs_file = []  # derived value: the '<TARGET_MEASURE>_inv' entry stored in the file
q_bas = []        # derived value: localization - localization of a random attribution

# Sentinel so that "no results file matched" can be reported explicitly below.
qmeans = None


def _localisation_score(attribution, row, label, s_mask):
    """Evaluate the module-level metric `at` on a single sample.

    The attribution is summed over its first axis (keeping that axis) and a
    leading batch axis of size one is added to every input before calling
    `at` with the module-level `network`.

    Args:
        attribution: Attribution array for the sample.
        row: The input sample (`data['row']`).
        label: The label of the sample (`data['label']`).
        s_mask: The segmentation mask of the sample (`data['s_mask']`).

    Returns:
        The first element of the value returned by `at`.
    """
    return at(model=network,
              x_batch=np.expand_dims(row, axis=0),
              y_batch=np.expand_dims(label, axis=0),
              a_batch=np.expand_dims(attribution.sum(axis=0, keepdims=True), axis=0),
              s_batch=np.expand_dims(s_mask, axis=0))[0]


def _scatter_pairs(pairs, label):
    """Scatter the second element of each pair against the z-scored first one.

    The first element is standardised with the module-level `qmean_mean` and
    `qmean_std`, which must be defined before this helper is called.
    """
    xs = [(score - qmean_mean) / qmean_std for score, _ in pairs]
    ys = [value for _, value in pairs]
    plt.scatter(xs, ys, label=label)


# ---------------------------------------------------------------------------
# Per-file inspection
# ---------------------------------------------------------------------------
for FILENAME in os.listdir(os.path.join(PROJ_DIR,'results')):
    if not (FILENAME.startswith(DATASET) and FILENAME.endswith(f'{MODEL_NAME}{GENERATION}_localization_s_area_measures.npz')):
        continue
    print(FILENAME)

    # Load data
    data = fl.load_generated_data(os.path.join(PROJ_DIR, 'results', FILENAME))

    qmeans = data[TARGET_MEASURE]
    qmeans_inv = data[TARGET_MEASURE + '_inv']
    rankings = data['rankings']
    mask = data['s_mask']
    maskb = mask.astype(bool).flatten()

    for i in range(min(rankings.shape[0], MAX_RANKINGS_PER_FILE)):
        c = rankings[i]
        # Presumably the "opposite" ranking of c; semantics live in fl.get_inverse.
        c_inv = fl.get_inverse(c)
        # Currently plain aliases; the positive-only variants are disabled.
        c_positive =  c#np.where(c>0, c, 0)
        cinv_positive =  c_inv#np.where(c_inv>0, c_inv, 0)

        # Top row: input image, ranking, inverse ranking.
        # Bottom row: mask, ranking * mask, inverse ranking * mask.
        fig, axs = plt.subplots(2,3)
        axs[0][0].imshow(np.moveaxis(data['row'], 0, -1) * DATA_STD + DATA_MEAN, cmap='winter')
        axs[0][1].imshow(np.moveaxis(c, 0, -1).sum(axis=2))
        axs[0][2].imshow(np.moveaxis(c_inv, 0, -1).sum(axis=2))
        axs[1][0].imshow(np.moveaxis(mask, 0, -1))
        axs[1][1].imshow(np.moveaxis((c_positive*mask), 0, -1).sum(axis=2))
        axs[1][2].imshow(np.moveaxis((cinv_positive*mask), 0, -1).sum(axis=2))
        # Figure-level title: plt.title would only label the last subplot.
        # NOTE(review): the label is printed twice; the second slot was
        # probably meant for another value (e.g. a prediction) -- confirm.
        fig.suptitle(f'{data["label"]} ({data["label"]})')
        plt.show()
        # Release the figure so that looping over many rankings does not
        # accumulate open figures.
        plt.close(fig)

        # Prepare shapes.
        a = c.sum(axis=0).flatten()
        s = np.expand_dims(data['s_mask'], axis=0).flatten().astype(bool)

        # Compute ratio.
        size_bbox = float(np.sum(s))
        size_data = s.size
        ratio = size_bbox / size_data

        # Compute inside/outside ratio.
        inside_attribution = np.sum(a[s])
        total_attribution = np.sum(a)
        inside_attribution_ratio = float(inside_attribution / total_attribution)

        print('inside_attribution', inside_attribution)
        print('total_attribution', total_attribution)
        print('inside_attribution_ratio', inside_attribution_ratio)

        # Metric for the ranking itself, for a random (standard normal)
        # attribution of the same shape, and for the inverse ranking.
        localization = _localisation_score(c, data['row'], data['label'], data['s_mask'])
        localization_random = _localisation_score(np.random.normal(size=c.shape),
                                                  data['row'], data['label'], data['s_mask'])
        print('localization:',localization)
        localization_inv = _localisation_score(c_inv, data['row'], data['label'], data['s_mask'])
        print('localization_inv:',localization_inv)
        print('localization_random:',localization_random)
        print('localization wrt inv:',localization - localization_inv)
        print('localization_bas:',localization - localization_random)
        #print((np.moveaxis(c_positive, 0, -1).sum(axis=2)*mask).sum())
        #print((np.moveaxis(cinv_positive, 0, -1).sum(axis=2)*mask).sum())
        print('qmeans',qmeans[i])
        print('qmeans_inv',qmeans_inv[i])
        print('-'*20)
        q_invs.append((localization, localization - localization_inv))
        q_bas.append((localization, localization - localization_random))
        q_invs_file.append((localization, qmeans_inv[i]))

if qmeans is None:
    raise FileNotFoundError(
        f"No results file for dataset '{DATASET}' and model '{MODEL_NAME}{GENERATION}' "
        f"found in {os.path.join(PROJ_DIR, 'results')}"
    )

# ---------------------------------------------------------------------------
# Summary plots
# ---------------------------------------------------------------------------
# NOTE(review): the statistics come from the LAST file loaded above, while the
# accumulated pairs may span several files -- confirm this is intended.
qmean_mean = np.mean(qmeans)
qmean_std = np.std(qmeans)

# Recomputed inverse-based measure vs. random-baseline measure.
_scatter_pairs(q_invs, 'qinvs')
_scatter_pairs(q_bas, 'qbas')
plt.legend()
plt.show()

# Inverse-based measure as stored in the file vs. random-baseline measure.
# Both coordinates now come from q_invs_file; taking y from q_invs would
# merely reproduce the figure above.
_scatter_pairs(q_invs_file, 'qinvs')
_scatter_pairs(q_bas, 'qbas')
plt.legend()
plt.show()


In [None]:
# Standardise the loaded measure values: z = (q - mean(q)) / std(q)
qmean_mean = np.mean(qmeans)
qmean_std = np.std(qmeans)
centred_qmeans = qmeans - qmean_mean
z_scores = (centred_qmeans / qmean_std).flatten()

# Inverse-based measure stored in the file, plotted against the z-score.
plt.scatter(z_scores, qmeans_inv)
plt.show()

In [None]:
# Compute qmeans_bas[2-10]
def compute_qbas(measure, num_samples, reference:np.ndarray):
    """Subtract from each measure value the mean of randomly drawn references.

    For every entry of `measure`, `num_samples` entries of `reference` are
    drawn uniformly at random (with replacement, via `np.random.randint`) and
    their mean is subtracted from that entry.

    Args:
        measure: Array of measure values; its first axis is the sample axis.
        num_samples: Number of reference values drawn per entry (>= 1).
        reference: Array the random baseline values are drawn from. It does
            not need to have the same length as `measure`.

    Returns:
        `measure` minus the per-entry mean of the sampled reference values.

    Raises:
        ValueError: If `num_samples` is smaller than 1 (the mean of zero
            samples is undefined).
    """
    if num_samples < 1:
        raise ValueError(f'num_samples must be >= 1, got {num_samples}')

    # Indices must be valid for `reference` (the array being indexed), not for
    # `measure`; the two only coincide when both arrays have the same length.
    random_indices = np.random.randint(0, reference.shape[0], (measure.shape[0], num_samples))
    random_qmeans = reference[random_indices]
    mean = np.mean(random_qmeans, axis=1)

    # The spread of the sampled references is deliberately ignored (divide by
    # 1). Alternatives that were considered for std == 0:
    #   * add an epsilon:  std = np.std(random_qmeans, axis=1) + 1e-10
    #   * replace zeros:   std = np.std(random_qmeans, axis=1); std[std==0] = 1
    std = 1
    return (measure - mean) / std

# Baseline-corrected measure using a single random reference per entry,
# plotted against the z-score of the same measure.
qbas_single_sample = compute_qbas(qmeans, 1, qmeans)
plt.scatter(z_scores, qbas_single_sample)
plt.show()