In [2]:
import pickle
import os.path as path
import os
import torch
import numpy as np

In [3]:
def compute_rank_correlation(map1, map2, flags=1, map_size=None):
    """
    Function that measures Spearman’s correlation coefficient between target and output logits:
    """

    def _rank_correlation_(map1_rank, map2_rank, map_size):
        n = torch.tensor(map_size)
        upper = 6 * torch.sum((map2_rank - map1_rank).pow(2), dim=1)  # [batch] or [batch, g]
        down = n * (n.pow(2) - 1.0)
        return (1.0 - (upper / down))

    if map_size is None:
        map_size = map1.shape[1]

    # get the rank for each element, we use sort function two times
    map1 = map1.argsort(dim=1)  # [batch , num_objs] or [batch, num_objs, g]
    map1_rank = map1.argsort(dim=1) # [batch , num_objs] or [batch, num_objs, g]

    map2 = map2.argsort(dim=1)  # [batch , num_objs] or [batch, num_objs, 1]
    map2_rank = map2.argsort(dim=1)  # [batch , num_objs] or [batch, num_objs, 1]

    correlation = _rank_correlation_(map1_rank.float(), map2_rank.float(), map_size)
    return correlation * flags


def calculate_violators(top_ans_deltas, weights, verbose=False):
    max_ind = weights.abs().topk(k=1, dim=1)[1] # first or last
    max_weights = weights.gather(dim=1, index=max_ind).squeeze()
    max_weight_deltas = top_ans_deltas.gather(dim=1, index=max_ind).squeeze()
    ind = (max_weights * max_weight_deltas)<0
    if verbose:
        print("num. of violators: {}, ratio:{:.2%}".format(ind.sum(), ind.sum()/top_ans_deltas.shape[0]))
    return ind

In [4]:
METHOD_NAME = {'attn_norm': 'AttIN', 'ig':'IG', 'raw_attn':'RawAtt', \
    'attn_grad':'AttGrad', 'inputGrad':'InputGrad', 'rollout':'Rollout', \
    'ours_no_lrp':'GenAtt', 'transformer_attribution':'TransAtt', 'partial_lrp':'PLRP', \
    'transformer_att':'TransAtt', 'rand0': 'Random', 'rand': 'Random'}

ordered_names = ['Random', 'AttIN','IG','RawAtt','AttGrad','InputGrad','Rollout','PLRP','GenAtt','TransAtt']

In [5]:
from sklearn.metrics import auc

def run_exp(output, exp_metric="AUCTP", plot=False, verbose=True, only_vio=False):
    assert exp_metric in ["AUCTP", "Sufficiency", "Comprehensiveness", "Correlation", "Violation", "RC"]

    preds, weights, top_ans_deltas, all_pred_deltas, newpred_accs, accs, mask_ratios = output
    exp_values = {}
    if exp_metric == "AUCTP":
        # Calculate Comprehensiveness Here
        comprehensiveness = top_ans_deltas.squeeze().mean(dim=0) # [N, K] -> [K]
        ratio_index = [0,1,2,5] # [5%, 10%, 20%, 50%]
        # print(mask_ratios[ratio_index])
        comprehensiveness = comprehensiveness[ratio_index].mean()
        exp_values["Comprehensiveness"] = comprehensiveness.item()

        # Calculate AUCTP Here (https://github.com/copenlu/xai-benchmark)
        scores = [accs.mean()]
        for i in range(len(mask_ratios)):
            scores.append(newpred_accs[:, i].mean()) # [0%, 5%, 10%, 20%, 30%, ..., 90%]

        mask_ratios = np.insert(mask_ratios, 0, 0)
        exp_values["AUCTP"] = auc(mask_ratios, scores)
        if verbose:
            print(exp_values)
            print(mask_ratios, scores)

        if plot:
            auctp = []
            all_accs = torch.cat([accs.unsqueeze(1), newpred_accs], dim=1).numpy()
            for i in range(accs.shape[0]):
                auctp.append(auc(mask_ratios, all_accs[i]))
            return {"Comprehensiveness": top_ans_deltas[:, ratio_index].mean(1).numpy(),\
                 "AUCTP": np.array(auctp)}, exp_values


    elif exp_metric in ["Sufficiency", "Comprehensiveness"]:
        if verbose:
            print("Calculate {} for {}".format(exp_metric, mask_ratios))
        scores = top_ans_deltas.squeeze().mean(dim=0) # eg, [5%, 10%, 20%, 50%]
        exp_values[exp_metric] = scores.mean().item()
        if plot:
            return {exp_metric: top_ans_deltas.mean(1).numpy()}, exp_values

    elif exp_metric == "Violation":
        vio_ind = calculate_violators(top_ans_deltas, weights)
        VioRatio = vio_ind.sum() / top_ans_deltas.shape[0]
        exp_values["Violation"] = VioRatio.item()
        if only_vio:
            return {"Violation": vio_ind, "weights": weights,"top_ans_deltas": top_ans_deltas, "preds":preds}, exp_values
        if verbose:
            print("num. of violators: {}, ratio:{:.2%}".format(vio_ind.sum(), VioRatio))
        if plot:
            return {"Violation": vio_ind}, exp_values

    elif exp_metric == "RC":
        RCC = compute_rank_correlation(weights.abs(), all_pred_deltas.abs())
        if verbose:
            print(" rcc:{}".format(RCC.mean()))
        exp_values["RC"] = RCC.mean().item()
        if plot:
            return {"RC": RCC}, exp_values

    return preds, weights, top_ans_deltas, all_pred_deltas, mask_ratios, exp_values

In [7]:
all_exp_metric_names = ["AUCTP", "Violation", "Sufficiency", "Comprehensiveness", "RC"]
all_replace_names = ["slice_out", "zeros_mask", "att_mask"]
all_exp_names = ["rand", "ig", "att", "att*grad", "att*grad_only_grad", "att*grad_sign", "grad_cam", "input*grad"]

exp_names = ["att_norm", "att", "att*grad", "input*grad", "ig"]
replace_names = ["slice_out", "zeros_mask", "att_mask"]
exp_metric_names = ["AUCTP", "Violation", "Sufficiency", "Comprehensiveness", "RC"]




def read_results(model_path, dir_path='./data/', return_list=False):
    cached_output_dir = path.join(dir_path, model_path)
    exp_to_directories = {}
    for d in os.listdir(cached_output_dir):
        dir_path = path.join(cached_output_dir, d)
        if os.path.isdir(dir_path):
            exp_to_directories[d] = dir_path

    all_output = {}
    exp_scores, exp_all_dict = {}, {}
    for exp_name in exp_to_directories:
        print("######################", exp_name, "######################")
        if exp_name not in exp_scores:
            exp_scores[exp_name] = {}
            exp_all_dict[exp_name] = {}
            all_output[exp_name] = {}
        
        dict_output = all_output[exp_name]
        scores = exp_scores[exp_name]
        dict_scores =  exp_all_dict[exp_name]

        for exp_metric in exp_metric_names:
            dict_scores[exp_metric] = {}
            dict_output[exp_metric] = {}
            scores[exp_metric] = 0
        
        dir_name = exp_to_directories[exp_name]
        for file_name in os.listdir(dir_name):
            mask_type, metric_name = '', ''
            for m in replace_names:
                list_names = file_name.split(m)
                if list_names[-1] in ['', '.pkl']:
                    metric_name = list_names[0].split('_')[0]
                    mask_type = m
            assert metric_name != ''

            with open(path.join(dir_name, file_name), 'rb') as f:
                outputs = pickle.load(f)
                top_ans_deltas = torch.from_numpy(outputs['top_ans_deltas'])
                newpred_accs = torch.from_numpy(outputs['newpred_accs'])
                all_pred_deltas = torch.from_numpy(outputs['all_pred_deltas'])
                preds = torch.from_numpy(outputs['preds'])
                weights = torch.from_numpy(outputs['weights'])
                accs = torch.from_numpy(outputs['accs'])
                mask_ratios = outputs['mask_ratios']
            
            output_list = [preds, weights, top_ans_deltas, all_pred_deltas, newpred_accs, accs, mask_ratios]
            results = run_exp(output_list, metric_name)
            if return_list:
                print('return', exp_name, metric_name, mask_type)
                dict_output[metric_name][mask_type] = output_list

            for calculated_metric in results[-1]:
                dict_scores[calculated_metric][mask_type] = results[-1][calculated_metric]
                scores[calculated_metric] += results[-1][calculated_metric]
            print(results[-1].keys(),)

        exp_scores[exp_name] = {n: (scores[n]/3) for n in scores}

    if return_list:
        return exp_scores, exp_all_dict, all_output
    else:
        return exp_scores, exp_all_dict, output_list

In [8]:
model_path = "visual_bert_gqa2" # lxmert_vqa2
exp_scores, exp_all_dict, output_list = read_results(model_path)

###################### inputGrad ######################
num. of violators: 4062, ratio:40.62%
dict_keys(['Violation'])
num. of violators: 4143, ratio:41.43%
dict_keys(['Violation'])
num. of violators: 4066, ratio:40.66%
dict_keys(['Violation'])
 rcc:0.12318931519985199
dict_keys(['RC'])
 rcc:0.5555236339569092
dict_keys(['RC'])
 rcc:0.5550675988197327
dict_keys(['RC'])
{'Comprehensiveness': 0.048775769770145416, 'AUCTP': 0.5115900069475173}
[0.   0.05 0.1  0.2  0.3  0.4  0.5  0.6  0.7  0.8  0.9 ] [tensor(0.6075), tensor(0.5966), tensor(0.5871), tensor(0.5777), tensor(0.5739), tensor(0.5682), tensor(0.5646), tensor(0.5599), tensor(0.5558), tensor(0.5524), tensor(0.5458)]
dict_keys(['Comprehensiveness', 'AUCTP'])
{'Comprehensiveness': 0.03237999975681305, 'AUCTP': 0.5202600002288819}
[0.   0.05 0.1  0.2  0.3  0.4  0.5  0.6  0.7  0.8  0.9 ] [tensor(0.6075), tensor(0.5961), tensor(0.5873), tensor(0.5824), tensor(0.5792), tensor(0.5771), tensor(0.5764), tensor(0.5754), tensor(0.5724), tenso