In [None]:
from pathlib import Path

import sys
sys.path.append(str(Path('../WACV2022/PAN-PyTorch').resolve()))

import torch.nn.functional as F

from ops.dataset import PANDataSet
from ops.models import PAN
from ops.transforms import *
from opts import parser
from ops import dataset_config
from ops.utils import AverageMeter, accuracy
from ops.temporal_shift import make_temporal_pool


from methods import gradcamPAN as gradcam
from methods import risePAN as rise
from methods import siduPAN as sidu

from metrics import insertionPAN as insertion
from metrics import deletionPAN as deletion
from util import normalize
from util import groupNormalize as norm

import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Spectrum of the evaluation data, and spectrum the checkpoint was trained on.
spectrum = 'TV'
model_spectrum = 'IRTV'

num_classes = 8

# Options of the PAN "Lite" run; only a subset is read by the cells below
# (model construction, data loading).
cfg_LITE = {
    'dataset': f'{spectrum.lower()}',
    'modality': 'Lite',
    'train_list': f'/data/SOIB/file_lists/{spectrum}_train.txt',
    'val_list': f'/data/SOIB/file_lists/{spectrum}_test.txt',
    'root_path': '/data/',
    'store_name': 'PAN_Lite_irtv_resnet50_shift8_blockres_avg_segment8_e50',
    'lmdb': False,
    'arch': 'resnet50',
    'num_segments': 8,
    'consensus_type': 'avg',
    'k': 3,
    'dropout': 0.5,
    'loss_type': 'nll',
    'img_feature_dim': 256,
    'suffix': None,
    'pretrain': 'imagenet',
    'tune_from': None,
    'base': 'TSM',
    'epochs': 50,
    'batch_size': 22,
    'lr': 0.001,
    'lr_type': 'step',
    'lr_steps': [30.0, 40.0],
    'momentum': 0.9,
    'weight_decay': 0.0001,
    'clip_gradient': 20,
    'no_partialbn': True,
    'iter_size': 1,
    'print_freq': 20,
    'eval_freq': 1,
    'workers': 18,
    'resume': '',
    'evaluate': False,
    'snapshot_pref': '',
    'start_epoch': 0,
    'gpus': None,
    'flow_prefix': '',
    'root_log': 'log',
    'root_model': 'checkpoint',
    'shift': True,
    'shift_div': 8,
    'shift_place': 'blockres',
    'temporal_pool': False,
    'non_local': False,
    'dense_sample': False,
    'VAP': True,
}

cfg = cfg_LITE

In [None]:
# Value passed as `data_length` to PAN and `new_length` to the dataset:
# 1 for the 'RGB' modality, 4 for any other modality (e.g. 'Lite').
if cfg['modality'] != 'RGB':
    data_length = 4
else:
    data_length = 1

In [None]:
# Build the PAN network from `cfg` so that its architecture matches the
# checkpoint loaded in the next cell, then move it to the GPU.
pan_kwargs = dict(
    base_model=cfg['arch'],
    consensus_type=cfg['consensus_type'],
    dropout=cfg['dropout'],
    img_feature_dim=cfg['img_feature_dim'],
    partial_bn=not cfg['no_partialbn'],
    pretrain=cfg['pretrain'],
    is_shift=cfg['shift'],
    shift_div=cfg['shift_div'],
    shift_place=cfg['shift_place'],
    fc_lr5=not (cfg['tune_from'] and cfg['dataset'] in cfg['tune_from']),
    temporal_pool=cfg['temporal_pool'],
    non_local=cfg['non_local'],
    data_length=data_length,
    has_VAP=cfg['VAP'],
)
model = PAN(num_classes, cfg['num_segments'], cfg['modality'], **pan_kwargs).cuda()

In [None]:
# Load the best checkpoint of the model trained on `model_spectrum` data.
weights = torch.load(f'../WACV2022/PAN-PyTorch/checkpoint/PAN_Lite_{model_spectrum.lower()}_resnet50_shift8_blockres_avg_segment8_e50/ckpt.best.pth.tar')

# Checkpoint keys may carry one leading wrapper component (presumably 'module.'
# from nn.DataParallel -- verify). Strip it only from keys the bare model does
# not already know, instead of blindly dropping the first component of every
# key; a key that cannot be fixed is then reported by load_state_dict.
expected_keys = set(model.state_dict())
state_dict = {
    (k if k in expected_keys else k.split('.', 1)[-1]): v
    for k, v in weights['state_dict'].items()
}
model.load_state_dict(state_dict)
model.eval();  # trailing ';' hides the very long module repr

In [None]:
# Evaluation loader over the `spectrum` test split: no shuffling, scaled and
# centre-cropped frames. batch_size=1 matters: `evaluate` takes a flattened
# argmax over the model output, which is a class index only for one clip.
inception_archs = ('BNInception', 'InceptionV3')
uses_inception = cfg['arch'] in inception_archs

val_loader = torch.utils.data.DataLoader(
    PANDataSet('/data/SOIB', cfg['val_list'],  # same list the config declares
               num_segments=cfg['num_segments'],
               new_length=data_length,
               modality=cfg['modality'],
               image_tmpl='{:04d}.png',
               random_shift=False,
               transform=torchvision.transforms.Compose([
                   GroupScale(int(model.scale_size)),
                   GroupCenterCrop(model.crop_size),
                   Stack(roll=uses_inception),
                   ToTorchFormatTensor(div=not uses_inception),
               ]),
               dense_sample=False, is_lmdb=False),
    batch_size=1, shuffle=False,
    num_workers=4, pin_memory=True)


In [None]:
import time
def evaluate(input, target=None):
    """Score Grad-CAM, RISE and SIDU saliency maps on one clip with insertion/deletion.

    Args:
        input: a clip batch as yielded by ``val_loader``. It is passed as-is to
            the saliency methods and metrics, and through ``norm`` for the
            class-prediction forward pass.
        target: class index to explain. If ``None``, the model's top-scoring
            class is used (flattened argmax, i.e. assumes a batch of one clip).

    Returns:
        ``(scores, values)``: two dicts keyed by ``'gcam'``, ``'rcam'`` and
        ``'scam'``, each mapping to ``{'insertion': ..., 'deletion': ...}``.
        ``scores`` holds the summary score and ``values`` the per-step curve
        returned by the ``insertion``/``deletion`` metrics.
    """
    # Use the argument, not a global: the previous version read the notebook
    # loop variable `vid` and silently ignored `input`.
    clip = input  # `input` shadows the builtin; the name is kept for callers
    if target is None:
        # Prediction only -- no autograd graph needed for this forward pass.
        with torch.no_grad():
            output = model(norm(clip).cuda())
        target = output.argmax().item()

    layer = model.base_model.avgpool
    gcam = gradcam(model, layer, clip, target)
    rcams = rise(model, clip)
    scams = sidu(model, layer, clip)
    scores = {}
    values = {}

    with torch.cuda.amp.autocast():
        # RISE and SIDU return one map per class; pick the target's map.
        for desc, cam in [('gcam', gcam), ('rcam', rcams[target]), ('scam', scams[target])]:
            dels, dscore = deletion(model, clip, cam, target)
            inss, iscore = insertion(model, clip, cam, target, factor=4)

            scores[desc] = {'insertion': iscore, 'deletion': dscore}
            values[desc] = {'insertion': inss, 'deletion': dels}

    return scores, values

In [None]:
# Evaluate every test clip. The loader's label is not passed to `evaluate`,
# so each clip is explained for the class the model predicts.
scores, values = [], []
for i, (vid, target) in enumerate(tqdm(val_loader)):
    clip_scores, clip_values = evaluate(vid)
    scores.append(clip_scores)
    values.append(clip_values)

In [None]:
# Flatten the nested per-clip results into long format: one record per
# (clip, saliency method, step), which the DataFrame/AUC/plot cells consume.
# Each method's step count is taken from its own insertion curve, so clips or
# methods with different curve lengths are neither truncated nor out of range,
# and an empty `values` simply yields no records.
# Assumes a method's insertion and deletion curves have the same length.
df = [
    {
        'experiment': i,
        'cam_type': cam_type,
        'step': step,
        'insertion': curves['insertion'][step].item(),
        'deletion': curves['deletion'][step].item(),
    }
    for i, v in enumerate(values)
    for cam_type, curves in (('gcam', v['gcam']), ('scam', v['scam']), ('rcam', v['rcam']))
    for step in range(len(curves['insertion']))
]

In [None]:
import pandas
import seaborn as sns

# Turn the list of per-step records into a DataFrame for grouping and plotting.
df = pandas.DataFrame(data=df)

In [None]:
from sklearn.metrics import auc
# Per saliency method and metric: one area-under-curve value per clip.
# NOTE(review): x is the raw integer step index, so each AUC scales with the
# number of steps rather than lying in [0, 1].
aucs = {
    cam_type: {
        metric: [auc(clip_rows['step'], clip_rows[metric])
                 for _, clip_rows in cam_rows.groupby('experiment')]
        for metric in ('deletion', 'insertion')
    }
    for cam_type, cam_rows in df.groupby('cam_type')
}

In [None]:
# Mean insertion (left) and deletion (right) curves per saliency method; each
# legend entry shows the mean +/- std of the per-clip AUCs.
# Plot order (and thus legend order) follows this dict.
cam_labels = {'gcam': 'Grad-CAM', 'scam': 'SIDU', 'rcam': 'RISE'}

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
ax1.set_title('Insertion')
ax2.set_title('Deletion')

for metric, ax in (('insertion', ax1), ('deletion', ax2)):
    for cam_type, name in cam_labels.items():
        metric_aucs = aucs[cam_type][metric]
        # rf-string: '\pm' is LaTeX for mathtext, not a Python escape sequence.
        label = rf'{name} AUC=${np.mean(metric_aucs):0.2f}\pm {np.std(metric_aucs):0.2f}$'
        sns.lineplot(y=metric, x='step', data=df[df.cam_type == cam_type], ax=ax, label=label)

plt.show()