In [None]:
import os
import random
import sys
from pathlib import Path

import numpy as np

import torch
import torchvision
import torchvision.transforms as T

import matplotlib.pyplot as plt
from tqdm import tqdm, trange

from methods import gradcam3d as gradcam
from methods import rise3d as rise
from methods import sidu3d as sidu
from metrics import insertion3d as insertion
from metrics import deletion3d as deletion
from util import normalize

# Dataset root. Set the SOIB_ROOT environment variable to point elsewhere
# instead of editing this hard-coded default.
dataset_path = Path(os.environ.get('SOIB_ROOT', '/data/SOIB'))
# Put the dataset directory on sys.path so the `SOIB` loader module can be
# imported from it. Guarded so that re-running this cell does not keep
# appending duplicate entries.
if str(dataset_path) not in sys.path:
    sys.path.append(str(dataset_path))

from SOIB import SOIB_Dataset

# Seed every RNG the pipeline may draw from, so that evaluation runs are
# reproducible. Examples are `RandomSequenceCrop` in the dataset transform and,
# presumably, RISE's random masks.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def norm(x):
    """Min-max scale ``x`` into [0, 1].

    The 1e-13 term in the denominator avoids a division by zero when ``x`` is
    constant; such inputs map to all zeros.
    """
    return (x - x.min()) / (x.max() - x.min() + 1e-13)


# `modality` is the stream the dataset loads.
# `net_modality` selects which checkpoint file the network weights come from.
modality = 'IR'
net_modality = 'IR'

assert modality in ['IR', 'TV', 'IRTV']
assert net_modality in ['IR', 'TV', 'IRTV']

In [None]:
from transformations import RandomSequenceCrop

        
def read_frames(frame_paths):
    """Decode every frame path into an image tensor, stack them along a new
    leading axis and rescale the 0-255 pixel values to [0, 1]."""
    frames = [torchvision.io.read_image(str(frame_path)) for frame_path in frame_paths]
    return torch.stack(frames) / 255


# Per-clip pipeline:
#   1. crop the frame-path sequence to 16 entries ('clamp' padding);
#   2. decode the frames;
#   3. resize and center-crop each frame to 224x224.
transform = T.Compose([
    RandomSequenceCrop(length=16, padding='clamp'),
    T.Lambda(read_frames),
    T.Resize(256),
    T.CenterCrop((224, 224)),
])

ds = SOIB_Dataset(str(dataset_path), modality=modality, train=False, transforms=transform)

In [None]:
checkpoint_path = f'../WACV2022/weights/C3D_{net_modality}_15.pth'

model = torchvision.models.video.r3d_18()
# Replace the classification head with an 8-way linear layer matching the checkpoint.
model.fc = torch.nn.Linear(model.fc.in_features, 8)
model.cuda()
model.eval()
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load(checkpoint_path))

In [None]:
def evaluate(input, target=None):
    """Compute Grad-CAM, RISE and SIDU maps for one clip and score each of them
    with the deletion and insertion metrics.

    Args:
        input: one clip as yielded by ``ds``. That is a (T, C, H, W) tensor of
            frames scaled to [0, 1], because the dataset transform stacks decoded
            frames and divides by 255. It is permuted to (C, T, H, W) where the
            network and RISE need that layout.
        target: class index to explain. If ``None``, the network's top-1
            prediction for ``input`` is used.

    Returns:
        (scores, values): two dicts keyed by ``'gcam'``, ``'rcam'`` and ``'scam'``.
        ``scores[k]`` is ``{'insertion': iscore, 'deletion': dscore}``, and
        ``values[k]`` holds the corresponding curves, both as returned by
        ``insertion`` / ``deletion``.
    """
    if target is None:
        # Only the predicted class is needed here, so skip building an autograd graph.
        with torch.no_grad():
            output = model(normalize(input).permute(1, 0, 2, 3).cuda().unsqueeze(0))
        target = output.argmax().item()

    # Every saliency method and metric below operates on `input`, the same clip
    # that was classified above.
    # NOTE(review): only the forward pass above applies `normalize` -- confirm
    # that the methods and metrics normalize internally.
    gcam = gradcam(model, model.avgpool, input, target)
    rcams = rise(model, input.permute(1, 0, 2, 3).unsqueeze(0), N=100)
    scams = sidu(model, model.avgpool, input)

    scores = {}
    values = {}

    with torch.cuda.amp.autocast():
        # gradcam already received `target`; RISE and SIDU results are indexed by class.
        for desc, cam in [('gcam', gcam), ('rcam', rcams[target]), ('scam', scams[target])]:
            dels, dscore = deletion(model, input, cam, target)
            inss, iscore = insertion(model, input, cam, target, factor=4)

            scores[desc] = {'insertion': iscore, 'deletion': dscore}
            values[desc] = {'insertion': inss, 'deletion': dels}

    return scores, values

In [None]:
# Explain the network's own top-1 prediction for every test clip; the dataset
# label is not passed to `evaluate`. Metric scores and curves are kept per clip.
scores, values = [], []
for vid, target in tqdm(ds):
    clip_scores, clip_curves = evaluate(vid)
    scores.append(clip_scores)
    values.append(clip_curves)

In [None]:
def curve_records(values, cam_types=('gcam', 'scam', 'rcam')):
    """Flatten per-clip insertion/deletion curves into long-format records.

    Args:
        values: one entry per evaluated clip, shaped like the second item
            returned by ``evaluate``: ``{cam_type: {'insertion': curve,
            'deletion': curve}}``. Curve elements must support ``.item()``
            (e.g. tensor or NumPy scalars).
        cam_types: saliency methods to export. This order also fixes the row
            order within each step.

    Returns:
        A list of dicts with keys ``experiment``, ``cam_type``, ``step``,
        ``insertion`` and ``deletion``. Rows are ordered by clip, then step,
        then ``cam_types``. An empty ``values`` gives an empty list.
    """
    records = []
    for experiment, cams in enumerate(values):
        # Read the step count per clip, so clips whose curves have different
        # lengths are exported in full rather than raising or being truncated.
        n_steps = len(cams[cam_types[0]]['insertion'])
        for step in range(n_steps):
            for cam_type in cam_types:
                curves = cams[cam_type]
                records.append({
                    'experiment': experiment,
                    'cam_type': cam_type,
                    'step': step,
                    'insertion': curves['insertion'][step].item(),
                    'deletion': curves['deletion'][step].item(),
                })
    return records


# List of row dicts; converted to a DataFrame in the next cell.
df = curve_records(values)

In [None]:
import pandas
import seaborn as sns

# Long-format table: one row per (experiment, step, cam_type).
df = pandas.DataFrame(data=df)

In [None]:
from sklearn.metrics import auc
def _per_experiment_auc(cam_rows):
    """Trapezoidal AUC of every experiment's deletion and insertion curve.

    Returns ``{'deletion': [...], 'insertion': [...]}`` with one value per
    experiment, in experiment order.

    NOTE(review): the x-axis is the raw ``step`` index rather than the fraction
    of the clip inserted/removed, so these AUCs scale with the number of steps.
    Confirm this is the intended convention.
    """
    areas = {'deletion': [], 'insertion': []}
    for _, run in cam_rows.groupby('experiment'):
        for metric, metric_areas in areas.items():
            metric_areas.append(auc(run['step'], run[metric]))
    return areas


aucs = {cam_type: _per_experiment_auc(cam_rows) for cam_type, cam_rows in df.groupby('cam_type')}

In [None]:
# Legend name for each saliency method. Iteration order matches the order in
# which the curves were originally drawn (so line colours are unchanged).
CAM_LABELS = {'gcam': 'Grad-CAM', 'scam': 'SIDU', 'rcam': 'RISE'}


def auc_label(name, metric_aucs):
    """Legend text giving the mean and standard deviation of per-clip AUCs,
    formatted as matplotlib mathtext."""
    # Raw f-string: a plain '\p' is an invalid escape sequence
    # (DeprecationWarning; SyntaxWarning since Python 3.12).
    return rf'{name} AUC=${np.mean(metric_aucs):0.2f}\pm {np.std(metric_aucs):0.2f}$'


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
ax1.set_title('Insertion')
ax2.set_title('Deletion')

# One line per saliency method on each panel. seaborn aggregates the repeated
# x values (one per clip) at each step.
for ax, metric in ((ax1, 'insertion'), (ax2, 'deletion')):
    for cam_type, name in CAM_LABELS.items():
        sns.lineplot(
            y=metric,
            x='step',
            data=df[df.cam_type == cam_type],
            ax=ax,
            label=auc_label(name, aucs[cam_type][metric]),
        )