In [1]:
import os
import torch
import sys
sys.path.append('../../lib/exlib/src')

import sys
sys.path.append('../../src')
import sop

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sop.utils.seed_all(42)

# config
exp_config = sop.ImageNetConfig()
val_config = exp_config.get_config('val_sm')
val_config['evaluation']['batch_size'] = 16
val_config

[nltk_data] Downloading package punkt to /home/runai-home/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/runai-home/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


{'dataset': {'name': 'imagenet_s',
  'root': '/shared_data0/weiqiuy/datasets/imagenet'},
 'training': {'batch_size': 16,
  'num_epochs': 20,
  'mask_batch_size': 64,
  'optimizer': {'name': 'adamw', 'lr': 5e-06, 'weight_decay': 0.01}},
 'evaluation': {'split': 'val', 'num_data': 1, 'batch_size': 16},
 'model': {'type': 'vit',
  'base': 'google/vit-base-patch16-224',
  'sop': '/shared_data0/weiqiuy/sop/exps/imagenet_lr5e-06_tgtnnz0.2_gg0.0600_gs0.0100_ft_identify_fixk_scratch_ks3/best',
  'num_classes': 1000}}

In [2]:
backbone_model, original_model, processor, backbone_config, model, config = sop.tasks.imagenet.get_model(val_config['model']['type'],
                                                                 backbone_model_name=val_config['model']['base'],
                                                                 backbone_processor_name=val_config['model']['base'],
                                                                 sop_model_name=val_config['model']['sop'], 
                                                                 # sop_model_name='/shared_data0/weiqiuy/sop/exps/imagenet_lr5e-06_tgtnnz0.2_gg0.0600_gs0.0100_ft_identify_fixk_scratch_ks3/last',
                                                                 eval_mode=True, wrap_proj=True
                                                                                                        )

backbone_model = backbone_model.to(device)
original_model = original_model.to(device)
model = model.to(device)

projection layer is not frozen
Loaded step 40100


In [3]:
from sop.utils.metric_utils import get_entropy, get_prob_obj, get_prob_obj_coverage, get_iou

In [7]:
import numpy as np
from tqdm.auto import tqdm
from sop.metrics import get_acc
from sop.tasks.images.imagenet import get_explainer
from exlib.modules.sop import convert_idx_masks_to_bool

methods = [
    'bcos',
    'xdnn',
    'bagnet',
    'sop',
    'shap_20',
    'rise_20',
    'lime_20',
    'fullgrad',
    'gradcam',
    'intgrad',
    'attn',
    'archipelago',
    'mfaba',
    'agi',
    'ampe',
]


debug = False
k = 0.2

# method = 'lime_20'
# explainer_name = method.split('_')[0]
# method = 'shap_20'
results_all = {}

for method in methods:
    print('method', method)
    if method == 'sop':
        explainer = model
    else:
        explainer = get_explainer(original_model, backbone_model, method.split('_')[0], device)

    method_list = method.split('_')
    explainer_name = method_list[0]

    if len(method_list) == 2:
        suffix = f'_{method_list[1]}'
    else:
        suffix = ''

    if method not in ['sop', 'bagnet']: #, 'xdnn', 'bagnet']:
        ATTR_VAL_DATA_DIR = f'/shared_data0/weiqiuy/sop/exps/imagenet_vit_1/attributions_seg/{explainer_name}_1_pred{suffix}/val'
    else:
        ATTR_VAL_DATA_DIR = None
    
    val_dataset, val_dataloader = sop.tasks.imagenet.get_dataset(val_config['dataset']['name'], 
                                              split=val_config['evaluation']['split'], 
                                              num_data=val_config['evaluation']['num_data'],
                                              batch_size=val_config['evaluation']['batch_size'],
                                                            attr_dir=ATTR_VAL_DATA_DIR,
                                              processor=processor, debug=debug)

    # for k in tqdm(np.linspace(0.1, 1, 10)):
    k = 0.2
    # results = get_acc(val_dataloader, explainer, method, device, k=k, eval_all=True, built_in=True)
    
    entropies = []
    ratios_obj_coverage = []
    ratios_obj = []
    ious = []
    accs = []
    
    for bi, batch in tqdm(enumerate(val_dataloader), total=len(val_dataloader)):
        if bi > 10:
            break
        if len(batch) == 5:
            inputs, labels, segs, attrs, idxs = batch
        else:
            inputs, labels, segs, idxs = batch
        inputs, labels, segs = inputs.to(device), labels.to(device), segs.to(device)
        # with torch.no_grad():
        #     original_logits = original_model(inputs)
        # preds = torch.argmax(original_logits, dim=-1)
        # Get explanation
        # expln = explainer(inputs, preds)

        # if method != 'backbone':
        
        masks_all = []
        for idx in range(len(inputs)):
            if method in ['bagnet', 'sop']:
                # print('recompute')
                if method == 'bagnet':
                    expln = explainer(inputs[idx:idx+1], return_groups=True)
                    masks = expln.group_masks[0]
                    mask_weights = expln.group_attributions[0].flatten()
                else: # sop
                    expln = explainer(inputs[idx:idx+1], return_tuple=True)
                    pred = expln.logits.argmax(-1)
                    masks = expln.masks[0]
                    # print('expln.masks', expln.masks.shape)
                    # print('expln.mask_weights', expln.mask_weights.shape)
                    mask_weights = expln.mask_weights[0,:,pred].flatten()
                # print('masks', masks.shape)
                # print('mask_weights', mask_weights.shape)
                
                
                # Sort the masks based on mask_weights
                sort_idxs = torch.argsort(mask_weights)
                masks = masks[sort_idxs]  # Sort masks accordingly
                mask_weights = mask_weights[sort_idxs]

                # Cumulative sum of sorted masks
                topk = int(masks.shape[-1] * masks.shape[-2] * k)
                masks_cumsum = masks.cumsum(dim=0).bool().float()
                # print('masks_cumsum', masks_cumsum.shape)
                # Calculate the sum along the last two dimensions
                masks_cumsum_sum = masks_cumsum.sum((-1, -2))
                # print('masks_cumsum_sum', masks_cumsum_sum.shape)

                # Find the index where the sum first exceeds topk
                mask_index = torch.searchsorted(masks_cumsum_sum, topk)

                # Ensure we don't exceed the tensor's first dimension
                mask_index = torch.clamp(mask_index, max=masks_cumsum.shape[0] - 1)

                mask = masks_cumsum[mask_index]
            else:
                # Create a mask of size (28, 28) with values from 1 to 28*28
                cell_size = 14
                image_size = 224
                mask = torch.arange(1, cell_size*cell_size + 1, dtype=torch.int).reshape(cell_size, cell_size)

                # Resize the mask to (224, 224) without using intermediate floating point numbers
                # This can be achieved by repeating each value in both dimensions to scale up the mask
                scale_factor = image_size // cell_size  # Calculate scale factor
                resized_mask = mask.repeat_interleave(scale_factor, dim=0).repeat_interleave(scale_factor, dim=1)

                masks = convert_idx_masks_to_bool(resized_mask[None]).to(device)
                # print('inputs', inputs.shape)
                # print('masks', masks.shape)
                # print('attrs', attrs.shape)
                mask_weights = (masks.to(device) * attrs[idx][0:1].to(device)).sum(-1).sum(-1).to(device)
                sort_idxs = torch.argsort(mask_weights).flip(-1)
                masks = masks[sort_idxs]
                mask_weights = mask_weights[sort_idxs]

                topk = int(masks.shape[0] * k)
                masks_use = masks[:topk]
                mask = masks_use.sum(0)

            # metrics
            entropy = get_entropy(mask[None], segs[idx])
            entropies.append(entropy.item())
            ratios_obj_coverage.extend(get_prob_obj_coverage(mask[None], segs[idx]).view(-1).cpu().numpy().tolist())
            ratios_obj.extend(get_prob_obj(mask[None], segs[idx]).view(-1).cpu().numpy().tolist())
            ious.extend(get_iou(mask[None], segs[idx]).view(-1).cpu().numpy().tolist())

            masks_all.append(mask)

        masks_all = torch.stack(masks_all, dim=0)

    results_all[method] = {
        'entropies': entropies,
        'ratios_obj_coverage': ratios_obj_coverage,
        'ratios_obj': ratios_obj,
        'ious': ious
        # 'accs': accs
    }
    
    print(method, np.mean(results_all[method]['ious']))
    
    # results_all[method] = results
    # print(method, results_all[method]['acc'])
    # break

method bcos


Using cache found in /home/runai-home/.cache/torch/hub/B-cos_B-cos-v2_main


  0%|          | 0/1001 [00:00<?, ?it/s]

Loaded 919 images and 1000 classes


  0%|          | 0/58 [00:00<?, ?it/s]

bcos 0.2280362893083789
method xdnn


  0%|          | 0/1001 [00:00<?, ?it/s]

Loaded 919 images and 1000 classes


  0%|          | 0/58 [00:00<?, ?it/s]

xdnn 0.3363624789371071
method bagnet


  0%|          | 0/1001 [00:00<?, ?it/s]

Loaded 919 images and 1000 classes


  0%|          | 0/58 [00:00<?, ?it/s]

bagnet 0.3129487893522971
method sop


  0%|          | 0/1001 [00:00<?, ?it/s]

Loaded 919 images and 1000 classes


  0%|          | 0/58 [00:00<?, ?it/s]

sop 0.6320557505770755
method shap_20


  0%|          | 0/1001 [00:00<?, ?it/s]

Loaded 919 images and 1000 classes


  0%|          | 0/58 [00:00<?, ?it/s]

shap_20 0.39217175372323254
method rise_20


  0%|          | 0/1001 [00:00<?, ?it/s]

Loaded 919 images and 1000 classes


  0%|          | 0/58 [00:00<?, ?it/s]

rise_20 0.1287941942821569
method lime_20


  0%|          | 0/1001 [00:00<?, ?it/s]

Loaded 919 images and 1000 classes


  0%|          | 0/58 [00:00<?, ?it/s]

lime_20 0.3596765278330581
method fullgrad
init fullgrad
init fullgrad done


  0%|          | 0/1001 [00:00<?, ?it/s]

Loaded 919 images and 1000 classes


  0%|          | 0/58 [00:00<?, ?it/s]

fullgrad 0.5198513122079682
method gradcam


  0%|          | 0/1001 [00:00<?, ?it/s]

Loaded 919 images and 1000 classes


  0%|          | 0/58 [00:00<?, ?it/s]

gradcam 0.39432331937066256
method intgrad


  0%|          | 0/1001 [00:00<?, ?it/s]

Loaded 919 images and 1000 classes


  0%|          | 0/58 [00:00<?, ?it/s]

intgrad 0.3209664839880795
method attn


  0%|          | 0/1001 [00:00<?, ?it/s]

Loaded 919 images and 1000 classes


  0%|          | 0/58 [00:00<?, ?it/s]

attn 0.46815120977010916
method archipelago


  0%|          | 0/1001 [00:00<?, ?it/s]

Loaded 919 images and 1000 classes


  0%|          | 0/58 [00:00<?, ?it/s]

archipelago 0.2870745841673025
method mfaba


  0%|          | 0/1001 [00:00<?, ?it/s]

Loaded 919 images and 1000 classes


  0%|          | 0/58 [00:00<?, ?it/s]

mfaba 0.39315565751696174
method agi


  0%|          | 0/1001 [00:00<?, ?it/s]

Loaded 919 images and 1000 classes


  0%|          | 0/58 [00:00<?, ?it/s]

agi 0.4401441326534206
method ampe


  0%|          | 0/1001 [00:00<?, ?it/s]

Loaded 919 images and 1000 classes


  0%|          | 0/58 [00:00<?, ?it/s]

ampe 0.41836462559347803


In [8]:
# torch.save(results_all, 'bagnet_purity.pt')
torch.save(results_all, 'imagenets_purity_iou_sm.pt')

In [8]:
results_all[method] = {
        'entropies': entropies,
        'ratios_obj_coverage': ratios_obj_coverage,
        'ratios_obj': ratios_obj,
        # 'accs': accs
    }

In [14]:
for k in results_all[method]:
    print(k, np.mean(results_all[method][k]))

entropies 0.6845371479450262
ratios_obj_coverage 0.2725739830633513
ratios_obj 0.3585904936863047


In [4]:
# torch.save(results_all, 'bagnet_purity.pt')
results_all = torch.load('bagnet_purity.pt')