In [1]:
import os
import torch
import sys
sys.path.append('../../lib/exlib/src')

import sys
sys.path.append('../../src')
import sop

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sop.utils.seed_all(42)

# config
exp_config = sop.ImageNetConfig()
val_config = exp_config.get_config('val_sm')
val_config['evaluation']['batch_size'] = 16
val_config



{'dataset': {'name': 'imagenet_s',
  'root': '/shared_data0/weiqiuy/datasets/imagenet'},
 'training': {'batch_size': 16,
  'num_epochs': 20,
  'mask_batch_size': 64,
  'optimizer': {'name': 'adamw', 'lr': 5e-06, 'weight_decay': 0.01}},
 'evaluation': {'split': 'val', 'num_data': 1, 'batch_size': 16},
 'model': {'type': 'vit',
  'base': 'google/vit-base-patch16-224',
  'sop': '/shared_data0/weiqiuy/sop/exps/imagenet_lr5e-06_tgtnnz0.2_gg0.0600_gs0.0100_ft_identify_fixk_scratch_ks3/best',
  'num_classes': 1000}}

In [2]:
backbone_model, original_model, processor, backbone_config, model, config = sop.tasks.imagenet.get_model(val_config['model']['type'],
                                                                 backbone_model_name=val_config['model']['base'],
                                                                 backbone_processor_name=val_config['model']['base'],
                                                                 sop_model_name=val_config['model']['sop'], eval_mode=True
                                                                                                        )

backbone_model = backbone_model.to(device)
original_model = original_model.to(device)
model = model.to(device)

projection layer is not frozen
projection layer is not frozen
Loaded step 40100


In [4]:
methods = [
    'shap_20',
    'rise_20',
    'lime_20',
    'sop',
    'fullgrad',
    'gradcam',
    'intgrad',
    'attn',
    'archipelago',
    'mfaba',
    'agi',
    'ampe',
    'bcos',
    'xdnn',
    'bagnet',
]

In [6]:
from sop.metrics import get_acc

In [11]:
from sop.tasks.images.imagenet import get_explainer

debug = True
k = 0.2

# method = 'lime_20'
# explainer_name = method.split('_')[0]
method = 'shap_20'

if method == 'sop':
    explainer = model
else:
    explainer = get_explainer(original_model, backbone_model, method.split('_')[0], device)
    
method_list = method.split('_')
explainer_name = method_list[0]

if len(method_list) == 2:
    suffix = f'_{method_list[1]}'
else:
    suffix = ''

if method != 'sop':
    ATTR_VAL_DATA_DIR = f'/shared_data0/weiqiuy/sop/exps/imagenet_vit_1/attributions_seg/{explainer_name}_1_pred{suffix}/val'
else:
    ATTR_VAL_DATA_DIR = None
    
val_dataset, val_dataloader = sop.tasks.imagenet.get_dataset(val_config['dataset']['name'], 
                                          split=val_config['evaluation']['split'], 
                                          num_data=val_config['evaluation']['num_data'],
                                          batch_size=val_config['evaluation']['batch_size'],
                                                        attr_dir=ATTR_VAL_DATA_DIR,
                                          processor=processor, debug=debug)

  0%|          | 0/1001 [00:00<?, ?it/s]

Loaded 100 images and 100 classes


In [12]:
import numpy as np
from tqdm.auto import tqdm

results_all = {}
for k in tqdm(np.linspace(0.1, 1, 10)):
    results = get_acc(val_dataloader, explainer, method, device, k=k, eval_all=False)
    results_all[k] = results
    break

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

In [14]:
# save_dir = f'/shared_data0/weiqiuy/sop/results/sparsity/{val_config["dataset"]["name"]}/'
# os.makedirs(save_dir, exist_ok=True)

# results_path = f'{save_dir}/{method}.pt'

# torch.save(results, results_path)

In [16]:
save_dir, results_path

('/shared_data0/weiqiuy/sop/results/sparsity/imagenet_s/',
 '/shared_data0/weiqiuy/sop/results/sparsity/imagenet_s//shap_20.pt')

In [17]:
data = torch.load(results_path)

In [18]:
data

{'acc': 0.5625,
 'corrects': [False,
  False,
  False,
  True,
  False,
  True,
  False,
  True,
  True,
  False,
  True,
  True,
  True,
  True,
  True,
  False]}