In [1]:
# Import the necessary libraries
import sys
import os
# Project root: the parent of the notebook's working directory.
PROJ_DIR = os.path.realpath(os.path.dirname(os.path.abspath('')))
SRC_DIR = os.path.join(PROJ_DIR, 'src')
# Guarded so that re-running this cell does not keep appending duplicates to sys.path.
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)
import xai_faithfulness_experiments_lib_edits as fl
import numpy as np
import captum_generator as cg
import quantus
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Using {device}')

# Display names of the attribution methods. Later cells index this with the
# position of a ranking in cg.generate_rankings' output — TODO confirm the order matches.
METHOD_NAMES = ['Saliency', 'IntegratedGradients', 'InputXGradient', 'GuidedBackprop', 'Deconvolution', 'LayerGradCam', 'GuidedGradCam']

# A sample is explored only if the network output for its true label exceeds this.
ACTIVATION_THRESHOLD = 0.9
# Z-score (vs. random rankings) above which an explanation counts as exceptional.
Z_SCORE_THRESHOLD = 4
# Stop collecting once this many exceptional explanations have been kept.
DESIRED_EXPLANATIONS = 1000

# Chunk size passed to fl._get_chunky_random_ranking_row for "chunky" random rankings.
CHUNKINESS = 32

DATASET = 'imagenet'
MODEL_NAME = 'resnet18'

KNOWN_DATASETS = ('imagenet', 'mnist', 'cifar', '20newsgroups-truncated')
# Fail fast on a misspelled dataset name, before any data or model loading happens.
if DATASET not in KNOWN_DATASETS:
    raise ValueError(f'ERROR: Unknown dataset {DATASET}')

# Load dataset
if DATASET == '20newsgroups-truncated':
    DATASET_PATH = os.path.join(PROJ_DIR,'assets', 'data', f'{DATASET}.npz')
    file_data = np.load(DATASET_PATH)
    x_train = torch.from_numpy(file_data['x_train']).float().to(device)
    y_train = torch.from_numpy(file_data['y_train']).to(device)
    # The whole dataset is served as a single batch.
    test_loader = [(x_train, y_train)]
else:
    #torch.manual_seed(0)
    # Shuffled batches of 1000 samples. With the seed above commented out, the explored
    # samples presumably differ between runs — TODO confirm against fl.get_image_test_loader.
    test_loader = fl.get_image_test_loader(DATASET, 1000, PROJ_DIR, shuffle = True)


# Load model (and, for image datasets, the per-channel normalisation statistics that
# later cells use to de-normalise images for display).
if DATASET == 'imagenet':
    # A '-logits' suffix selects the same architecture with use_logits=True.
    network = fl.load_pretrained_imagenet_model(arch = MODEL_NAME.replace('-logits', ''), use_logits = '-logits' in MODEL_NAME)
    DATA_MEAN = [0.485, 0.456, 0.406]
    DATA_STD = [0.229, 0.224, 0.225]
elif DATASET == 'mnist':
    MODEL_PATH = os.path.join(PROJ_DIR,'assets', 'models', f'{DATASET}-{MODEL_NAME}-mlp.pth')
    network = fl.load_pretrained_mnist_model(MODEL_PATH)
    DATA_MEAN = [0.1307]
    DATA_STD = [0.3081]
elif DATASET == 'cifar':
    MODEL_PATH = os.path.join(PROJ_DIR,'assets', 'models', f'{DATASET}-{MODEL_NAME}-mlp.pth')
    network = fl.load_pretrained_cifar_model(MODEL_PATH)
    # NOTE(review): DATA_MEAN / DATA_STD are not set for this dataset, so the cell that
    # de-normalises images for display will raise NameError. TODO: add the statistics
    # that fl.get_image_test_loader uses for CIFAR.
elif DATASET == '20newsgroups-truncated':
    # NOTE(review): unlike the other branches there is no '-' between DATASET and
    # MODEL_NAME here — confirm this matches the file name on disk.
    MODEL_PATH = os.path.join(PROJ_DIR,'assets', 'models', f'{DATASET}{MODEL_NAME}-mlp.pth')
    network = fl.load_pretrained_mlp_large_model(MODEL_PATH, x_train.shape[1], 20, [1000, 1000, 800, 500])
    # NOTE(review): these are 3-channel image statistics on a text dataset; they resemble
    # CIFAR-100's std/mean with the two swapped, and possibly belong in the 'cifar' branch.
    DATA_MEAN = [0.2675, 0.2565, 0.2761]
    DATA_STD = [0.5071, 0.4867, 0.4408]
else:
    raise ValueError(f'ERROR: Unknown dataset {DATASET}')

Using cuda:0
Loading Resnet18


In [5]:
import pickle
from tqdm import tqdm

#FILENAME = f'{DATASET}_{MODEL_NAME}_noise_exceptionals_chunky.pkl'
FILENAME = f'{DATASET}_{MODEL_NAME}_exceptionals_fc.pkl'
RESULTS_DIR = os.path.join(PROJ_DIR, 'results')
RESULTS_PATH = os.path.join(RESULTS_DIR, FILENAME)


def _faithfulness_correlation(x_batch, y_batch, a_batch):
    """Score a single attribution with Quantus' FaithfulnessCorrelation.

    A new metric object is built on every call, with the same configuration the
    original inline calls used.

    Args:
        x_batch: numpy input batch holding one sample (leading batch dim of 1).
        y_batch: numpy label batch matching ``x_batch``.
        a_batch: numpy attribution/ranking batch for ``x_batch``.

    Returns:
        Element 0 of the per-sample scores returned by the metric.
    """
    metric = quantus.FaithfulnessCorrelation(
        nr_runs=10,
        subset_size=4,
        perturb_baseline="black",
        perturb_func=quantus.perturb_func.baseline_replacement_by_indices,
        similarity_func=quantus.similarity_func.correlation_pearson,
        abs=False,
        return_aggregate=False,
        disable_warnings=True,
    )
    return metric(model=network,
                  x_batch=x_batch,
                  y_batch=y_batch,
                  a_batch=a_batch,
                  device=device,
                  channel_first=True)[0]


if os.path.isfile(RESULTS_PATH):
    # Cached results: skip the slow collection below.
    # NOTE: unpickling can execute arbitrary code — only load files this notebook wrote.
    with open(RESULTS_PATH, 'rb') as fIn:
        results = pickle.load(fIn)
else:
    # Masked features are replaced by zeros (the data is standardised, so zero is the mean).
    num_vars = None
    masking_values = None

    with torch.no_grad():
        valid_elements = []

        for batch_idx, (x_train, y_train) in enumerate(test_loader):
            print(f'Loaded batch  {batch_idx}')
            if masking_values is None:
                # First batch: derive the per-sample shape and an all-zero mask from it.
                masking_values = torch.from_numpy(np.zeros(x_train.shape[1:])).float().to(device)
                num_vars = 1
                for d in x_train.shape[1:]:
                    num_vars *= d
                num_samples = min(fl.NUM_SAMPLES, num_vars)
                input_shape = x_train.shape[1:]
            # Find elements from the batch whose output for their true label is high enough
            outputs = network(x_train.to(device))
            activated_indices = (outputs[torch.arange(x_train.shape[0]), y_train] > ACTIVATION_THRESHOLD).nonzero().flatten()

            for i, sample_index in enumerate(activated_indices):
                if (i+1) % 10 == 0:
                    print(f' Exploring activating sample {i+1}/{activated_indices.size()[0]}')
                row = x_train[sample_index.item()].to(device)
                label = y_train[sample_index.item()].to(device)

                x_batch = row.unsqueeze(0).cpu().numpy()
                y_batch = label.unsqueeze(0).cpu().numpy()

                print('Estimating mean and std...')
                # Score 100 random rankings to estimate the metric's mean/std under chance.
                qmeans = []
                for _ in tqdm(range(100)):
                    a_batch = fl._get_random_ranking_row(row.shape).unsqueeze(0)
                    qmeans.append(_faithfulness_correlation(x_batch, y_batch, a_batch.cpu().numpy()))
                qmean_mean = np.mean(qmeans)
                qmean_std = np.std(qmeans)
                print('mean', qmean_mean, 'std', qmean_std)

                # Score each captum explanation against that random baseline.
                captum_rankings = torch.tensor(cg.generate_rankings(row, label, network)).to(device)
                for method_index, ranking in enumerate(captum_rankings):
                    fc = _faithfulness_correlation(x_batch, y_batch, ranking.unsqueeze(0).cpu().numpy())
                    zscore = (fc - qmean_mean) / qmean_std

                    print(f'Method #{method_index}: {zscore}')

                    if zscore > Z_SCORE_THRESHOLD:
                        valid_elements.append({'row': row,
                                               'ranking': ranking,
                                               'label': label,
                                               'qmean_mean': qmean_mean,
                                               'qmean_std': qmean_std,
                                               'method': method_index,
                                               # The score this explanation was selected on, so later
                                               # analysis does not have to recompute it from other metrics.
                                               'fc': fc,
                                               'zscore': zscore,
                                               })
                        print(f'{len(valid_elements)}/{DESIRED_EXPLANATIONS}')
                        if len(valid_elements) >= DESIRED_EXPLANATIONS:
                            break
                if len(valid_elements) >= DESIRED_EXPLANATIONS:
                    break
            # Also stop pulling batches; otherwise collection continues past the target.
            if len(valid_elements) >= DESIRED_EXPLANATIONS:
                break

    # Attach the fl measures (ranking, inverse and baseline) to every kept explanation.
    results = []
    for v in valid_elements:
        row = v['row']
        ranking = v['ranking']
        label = v['label']
        measures = fl.get_measures_for_ranking(row, ranking, label, network, with_inverse=True, with_random=True, masking_values=masking_values, noisy_inverse=True)
        v['qmean'] = measures['mean']
        v['qinv'] = measures['mean_inv']
        v['qbas'] = measures['mean_bas']
        results.append(v)
    # Create results/ if needed so a long run does not fail at the very last step.
    os.makedirs(RESULTS_DIR, exist_ok=True)
    with open(RESULTS_PATH, 'wb') as fOut:
        pickle.dump(results, fOut)
                

Loaded batch  0
Estimating mean and std...


100%|██████████| 100/100 [00:04<00:00, 20.77it/s]


mean 0.03830202331943748 std 0.3306022372672858
Method #0: 0.35844624097077227
Method #1: -2.223316608934815
Method #2: 0.07022504853255701
Method #3: 1.7396248933815515
Method #4: -1.2338932601738555
Method #5: -0.41872867666989705
Method #6: -0.45402814856493495
Estimating mean and std...


100%|██████████| 100/100 [00:04<00:00, 21.06it/s]


mean -0.03969074446130821 std 0.3345781660762455
Method #0: 1.6142461829915913
Method #1: -0.9731085074880945
Method #2: 1.8220839182220863
Method #3: -0.6001254367297141
Method #4: 1.6239881312929938
Method #5: 1.0587795720430302
Method #6: -0.8211470074107878
Estimating mean and std...


 25%|██▌       | 25/100 [00:01<00:03, 20.54it/s]


KeyboardInterrupt: 

In [None]:
#filtered = [x for x in results if x['qmean']>0.5]
filtered = results
#filtered = [x for x in results if x['method']==5]

from scipy.stats import spearmanr
from matplotlib import pyplot as plt


def _explanation_zscore(element):
    """Return the z-score of one kept explanation.

    Uses a precomputed ``'zscore'`` entry when the element has one. Otherwise the
    z-score is recomputed from ``'qmean'`` and the stored ``'qmean_mean'`` /
    ``'qmean_std'``.

    NOTE(review): in the collection cell, ``'qmean_mean'``/``'qmean_std'`` come from
    Quantus FaithfulnessCorrelation scores of random rankings, while ``'qmean'`` is
    fl's ``measures['mean']``. The fallback therefore mixes two different metrics.
    """
    if 'zscore' in element:
        return element['zscore']
    return (element['qmean'] - element['qmean_mean']) / element['qmean_std']


zindices = [_explanation_zscore(x) for x in filtered]
#zindices = [x['qmean'] for x in filtered]
qinv = [x['qinv'] for x in filtered]
qbas = [x['qbas'] for x in filtered]

# Rank correlation of the z-score with the inverse ('qinv') and baseline ('qbas') measures.
for other_values, other_name in ((qinv, 'qinv'), (qbas, 'qbas')):
    print(spearmanr(zindices, other_values).statistic)
    fig, ax = plt.subplots()
    ax.scatter(zindices, other_values, s=0.1)
    ax.set(xlabel='z-score', ylabel=other_name, title=f'z-score vs {other_name}')
    plt.show()

In [None]:
# Distribution of the z-scores of the kept explanations.
fig, ax = plt.subplots()
ax.hist(zindices, bins=100)
ax.set(xlabel='z-score', ylabel='count', title='Z-scores of kept explanations')
plt.show()

In [None]:
NUM_ELEMS = 10
NUM_CURVES = 10
# Fractions of the ranking that are masked in the masked-input previews.
MASK_FRACTIONS = [0.2, 0.4, 0.6, 0.8]

#indices = torch.tensor(qinv).topk(NUM_ELEMS, largest=False).indices
# The NUM_ELEMS highest z-scores; k is clamped so topk does not fail when fewer were kept.
indices = torch.tensor(zindices).topk(min(NUM_ELEMS, len(zindices)), largest=True).indices

# Per-channel statistics shaped (C, 1, 1) so they broadcast over a (C, H, W) image.
data_std_tensor = torch.tensor(np.reshape(DATA_STD, (-1, 1, 1))).to(device)
data_mean_tensor = torch.tensor(np.reshape(DATA_MEAN, (-1, 1, 1))).to(device)


def _plot_output_curves(row, label, ranking, image, masking_values,
                        measures, measures_inverse, make_random_ranking,
                        title, with_legend):
    """Plot output curves next to the image and the channel-summed ranking.

    Draws the ranking ('r'), its inverse from ``measures`` ('i'), the separately
    measured inverse ranking ('i-ff'), the baseline curve, and NUM_CURVES curves for
    random rankings produced by ``make_random_ranking()`` (a zero-argument callable).
    """
    fig, axs = plt.subplots(1, 3, figsize=(10, 2))
    for curve, color, name in ((measures['output_curve'], 'green', 'r'),
                               (measures['output_curve_inv'], 'orange', 'i'),
                               (measures_inverse['output_curve'], 'red', 'i-ff')):
        axs[0].plot(curve, color=color, **({'label': name} if with_legend else {}))
    axs[0].plot(measures['output_curve_bas'], color='gray', linewidth=0.2)
    axs[0].set_title(title)
    if with_legend:
        axs[0].legend()
    for _ in range(NUM_CURVES):
        measures_random = fl.get_measures_for_ranking(row, make_random_ranking(), label, network, with_inverse=False, with_random=False, masking_values=masking_values)
        axs[0].plot(measures_random['output_curve'], color='gray', linewidth=0.2)
    axs[1].imshow(np.moveaxis(image.detach().cpu().numpy(), 0, -1))
    axs[2].imshow(ranking.sum(axis=0).detach().cpu().numpy(), cmap='plasma')
    plt.show()


def _plot_masked_grid(row, image, label, masking_values, display_masking_values, ranking_row, title):
    """Show ``image`` masked by ``ranking_row`` at every MASK_FRACTIONS level.

    Each panel is titled with the network output for ``label`` on the
    correspondingly masked normalised input ``row``.
    """
    masked = fl._get_masked_inputs(row, masking_values, ranking_row, torch.tensor(MASK_FRACTIONS).to(device))
    masked_printable = fl._get_masked_inputs(image, display_masking_values, ranking_row, torch.tensor(MASK_FRACTIONS).to(device))
    fig, axs = plt.subplots(1, len(MASK_FRACTIONS), figsize=(10, 2))
    # Inference only: no need to build autograd graphs for the panel titles.
    with torch.no_grad():
        for ax, m, r in zip(axs, masked, masked_printable):
            ax.imshow(np.moveaxis(r.detach().cpu().numpy(), 0, -1))
            ax.set_yticks([])
            ax.set_xticks([])
            ax.set_title(f'{network(m.unsqueeze(0)).squeeze()[label].item():.4f}')
    plt.suptitle(title)
    plt.show()


for i in indices:
    v = filtered[i]
    print(METHOD_NAMES[v['method']])
    row = v['row']
    ranking = v['ranking']
    label = v['label']
    # Inverse ranking recomputed from the explanation (reverse=True), as in the 'Inverse' panel.
    computed_inverse_ranking = torch.tensor(fl._attributions_to_ranking_row(ranking.flatten().detach().cpu(), reverse=True).reshape(ranking.shape)).to(device)
    # The collection cell does not store 'inverse_ranking'; fall back to the recomputed
    # inverse instead of raising KeyError.
    has_stored_inverse = 'inverse_ranking' in v
    inverse_ranking = v['inverse_ranking'] if has_stored_inverse else computed_inverse_ranking

    masking_values = torch.from_numpy(np.zeros(row.shape)).float().to(device)
    measures = fl.get_measures_for_ranking(row, ranking, label, network, with_inverse=True, with_random=True, masking_values=masking_values)
    measures_inverse = fl.get_measures_for_ranking(row, inverse_ranking, label, network, with_inverse=False, with_random=False, masking_values=masking_values)

    # NOTE(review): despite the name, this sets channel index 1 (green for RGB images),
    # so masked pixels presumably show up green in the previews.
    masking_values_red = torch.clone(masking_values)
    masking_values_red[1,:,:] = 1
    printable_image_tensor = row * data_std_tensor + data_mean_tensor

    print('Mean', measures['mean'])
    print('qmean', v['qmean_mean'])
    print('qstd', v['qmean_std'])
    print(network(row.unsqueeze(0)).max())

    _plot_output_curves(row, label, ranking, printable_image_tensor, masking_values,
                        measures, measures_inverse,
                        lambda: fl._get_chunky_random_ranking_row(ranking.shape, CHUNKINESS, CHUNKINESS, True),
                        'Chunky randoms', with_legend=True)
    _plot_output_curves(row, label, ranking, printable_image_tensor, masking_values,
                        measures, measures_inverse,
                        lambda: fl._get_random_ranking_row(ranking.shape),
                        'Regular randoms', with_legend=False)

    _plot_masked_grid(row, printable_image_tensor, label, masking_values, masking_values_red, ranking, 'Ranking')
    _plot_masked_grid(row, printable_image_tensor, label, masking_values, masking_values_red, computed_inverse_ranking, 'Inverse')
    if has_stored_inverse:
        _plot_masked_grid(row, printable_image_tensor, label, masking_values, masking_values_red, inverse_ranking, 'Inverse in file')
    random_row = fl._get_random_ranking_row(ranking.shape)
    _plot_masked_grid(row, printable_image_tensor, label, masking_values, masking_values_red, random_row, 'Random')
    # One random ranking shared by every channel.
    random_row = fl._get_random_ranking_row((1, *ranking.shape[1:])).repeat((ranking.shape[0], 1, 1))
    _plot_masked_grid(row, printable_image_tensor, label, masking_values, masking_values_red, random_row, '1c Random')
    chunky_random_row = fl._get_chunky_random_ranking_row(ranking.shape, CHUNKINESS, CHUNKINESS, True)
    _plot_masked_grid(row, printable_image_tensor, label, masking_values, masking_values_red, chunky_random_row, 'Chunky random')

In [None]:
from captum.attr import GuidedGradCam,LayerGradCam,LayerAttribution

# Grad-CAM target layer: the last Conv2d registered in the network. This replaces the
# hard-coded network.network.features[28], which only exists for models exposing a
# `.features` container (presumably VGG-style; torchvision ResNets do not — TODO confirm
# the same layer is picked for the VGG model the index was written for).
conv_layers = [module for module in network.modules() if isinstance(module, torch.nn.Conv2d)]
if not conv_layers:
    raise ValueError('GuidedGradCam needs a Conv2d layer, but the network has none')
gradcam_layer = conv_layers[-1]

# Uses `row`, `label`, `masking_values`, `masking_values_red` and `printable_image_tensor`
# left by the last iteration of the inspection loop above.
explanation = GuidedGradCam(network, gradcam_layer).attribute(inputs=row.unsqueeze(0), target=label.unsqueeze(0))
print(explanation.shape)
# Resize the attribution to the input's spatial size.
explanation = LayerAttribution.interpolate(explanation, row.shape[1:])
#explanation = torch.stack([explanation, explanation, explanation], dim=1).squeeze()
explanation = explanation.squeeze()
print(printable_image_tensor.shape)
print(explanation.shape)
plt.imshow(np.moveaxis(printable_image_tensor.detach().cpu().numpy(),0,-1))
plt.show()
# Min-max scale to [0, 1] so imshow can render the attribution as an RGB image.
printable_explanation = (explanation - explanation.min())/(explanation.max()-explanation.min())
plt.imshow(np.moveaxis(printable_explanation.detach().cpu().numpy(),0,-1))
plt.show()

# Reshape to the sample's own shape rather than relying on a `ranking` left from earlier cells.
ranking = torch.tensor(fl._attributions_to_ranking_row(explanation.flatten().detach().cpu(), reverse=False).reshape(row.shape)).to(device)
plt.imshow(np.moveaxis(ranking.detach().cpu().numpy(),0,-1))
plt.show()
masked = fl._get_masked_inputs(row, masking_values, ranking, torch.tensor([0.2, 0.4, 0.6, 0.8]).to(device))
masked_printable = fl._get_masked_inputs(printable_image_tensor, masking_values_red, ranking, torch.tensor([0.2, 0.4, 0.6, 0.8]).to(device))

fig, axs = plt.subplots(1, 4, figsize=(10, 2))
# Inference only: no autograd graphs needed for the panel titles.
with torch.no_grad():
    for pos, (m, r) in enumerate(zip(masked, masked_printable)):
        axs[pos].imshow(np.moveaxis(r.detach().cpu().numpy(), 0, -1))
        axs[pos].set_yticks([])
        axs[pos].set_xticks([])
        axs[pos].set_title(f'{network(m.unsqueeze(0)).squeeze()[label].item():.4f}')
plt.suptitle('Ranking')
plt.show()

In [None]:
# Index of the last Conv2d inside network.network.features (None if it has none).
max((i for i, layer in enumerate(network.network.features) if isinstance(layer, torch.nn.Conv2d)), default=None)

In [None]:
# Print the class of every module in the network (the network itself first, then its
# submodules in nn.Module.modules() order), e.g. to pick a Grad-CAM target layer.
module_types = [type(module) for module in network.modules()]
for module_type in module_types:
    print(module_type)