# HPO experiment: Grid Search on LRP-epsilon

In [1]:
import os
from collections import defaultdict

from tqdm import tqdm
import torch
import torchvision
import pandas as pd
import matplotlib.pyplot as plt

from open_xai import Project
from open_xai.explainers import LRPEpsilon


# models
def get_model(model_name):
    """Load a pretrained torchvision classifier together with its preprocessing.

    Args:
        model_name: Name understood by ``torchvision.models.get_model``
            (e.g. ``"vgg16"``, ``"resnet18"``).

    Returns:
        ``(model, transform)``: the model in eval mode with the default
        pretrained weights loaded, and the preprocessing transform bundled
        with those weights.
    """
    default_weights = torchvision.models.get_model_weights(model_name).DEFAULT
    net = torchvision.models.get_model(model_name, weights=default_weights)
    net.eval()
    preprocess = default_weights.transforms()
    return net, preprocess

# input images
def get_images(num_images, img_dir="../data/imagenet/images/"):
    """Load the first ``num_images`` images of ``img_dir`` as one uint8 batch.

    Files are taken in sorted filename order (``os.listdir`` order is
    arbitrary), so the same images are selected on every run. Hidden files
    such as ``.DS_Store`` are skipped. Every image is decoded as 3-channel
    RGB so grayscale or RGBA files do not break stacking.

    Args:
        num_images: Maximum number of images to load.
        img_dir: Directory containing the image files.

    Returns:
        Tensor of shape ``(N, 3, H, W)``. All selected images must share the
        same ``H`` and ``W`` for ``torch.stack`` to succeed.
    """
    fnames = sorted(f for f in os.listdir(img_dir) if not f.startswith("."))
    return torch.stack([
        torchvision.io.read_image(
            os.path.join(img_dir, fnm), mode=torchvision.io.ImageReadMode.RGB
        )
        for fnm in fnames[:num_images]
    ])

def img_to_np(img):
    """Convert a ``(C, H, W)`` image tensor into an ``(H, W, C)`` NumPy array.

    The channel-last layout is what ``matplotlib``'s ``imshow`` expects.
    This is a named function rather than a lambda assignment (PEP 8 E731),
    and it moves the data to CPU first so GPU tensors also work.
    """
    return img.permute(1, 2, 0).detach().cpu().numpy()

# get infd from run
def get_infd_from_run(model, run, n_perturb=200, noise_scale=.2, batch_size=32):
    """Estimate the normalized infidelity of the attribution stored in ``run``.

    Follows the infidelity metric of Yeh et al. (2019) in its normalized form
    (as in Captum). For random perturbations ``I`` with perturbed inputs
    ``x_p = x - I``:

        beta = E[<I, a> * (f(x) - f(x_p))] / E[<I, a>^2]
        infd = E[(beta * <I, a> - (f(x) - f(x_p)))^2]

    Here ``f`` is the model output for the explained target class and ``a``
    is the attribution.

    Args:
        model: Callable mapping an input batch to a ``(batch, classes)`` output.
        run: Explanation run exposing ``inputs`` (the explained input; assumed
            to hold a single example, batch dim 1), ``outputs[0]`` (the
            attribution, broadcastable against ``inputs``) and
            ``explainer_config.kwargs["target"]``.
        n_perturb: Number of random perturbations to sample.
        noise_scale: Multiplier applied to the sampled Gaussian noise.
        batch_size: Number of perturbed inputs per forward pass.

    Returns:
        0-dim tensor with the infidelity estimate (lower is better). It is NaN
        if every sampled ``<I, a>`` is zero.
    """
    t = run.explainer_config.kwargs["target"]
    x = run.inputs.detach()
    a = run.outputs[0].detach()
    std, mean = torch.std_mean(x)

    # Gaussian noise with the input's global mean/std. Clamping keeps every
    # perturbed input inside [x - 1, x], so the perturbation I = x - x_p lies
    # in [0, 1] elementwise.
    noise = mean + std * torch.randn(
        (n_perturb, *x.shape[1:]), dtype=x.dtype, device=x.device
    )
    x_p = x + noise_scale * noise
    x_p = torch.minimum(x, x_p)
    x_p = torch.maximum(x - 1, x_p)
    perturbation = x - x_p

    # Only output values are needed here, so skip building autograd graphs.
    with torch.no_grad():
        y = model(x)[:, t]
        y_p = torch.cat([
            model(chunk)[:, t] for chunk in torch.split(x_p, batch_size)
        ])
    y_d = y - y_p

    # <I, a>: the dot product must use the perturbation itself, not the
    # perturbed input, to be comparable with f(x) - f(x - I).
    dot_prod = (perturbation * a).flatten(1).sum(dim=1)
    mu = torch.ones(n_perturb, dtype=dot_prod.dtype, device=dot_prod.device)
    scaling_factor = torch.mean(mu * y_d * dot_prod) / torch.mean(mu * dot_prod * dot_prod)
    dot_prod = dot_prod * scaling_factor
    return torch.mean(mu * torch.square(y_d - dot_prod)) / torch.mean(mu)

# get sens from run
def get_sens_from_run(model, run, epsilon=.2, n_perturb=10):
    """Estimate the max-sensitivity of the attribution stored in ``run``.

    The run's explainer is re-run on ``n_perturb`` copies of the input. Each
    copy is perturbed by uniform noise in ``[-epsilon, epsilon)``. The result
    is

        max_j ||a - a_j|| / ||a||

    where ``a`` is the original attribution and ``a_j`` is the attribution
    of the j-th perturbed input (Yeh et al., 2019).

    Args:
        model: Unused; kept so the signature parallels ``get_infd_from_run``.
        run: Explanation run exposing ``inputs``, ``outputs[0]`` (the
            attribution), ``explainer`` (with a ``run(data=..., **kwargs)``
            method whose result's ``[0]`` is an attribution with a leading
            batch dim) and ``explainer_config.kwargs``.
        epsilon: Half-width of the uniform input noise.
        n_perturb: Number of perturbed inputs to explain.

    Returns:
        0-dim tensor with the max-sensitivity (lower is better).
    """
    x = run.inputs.detach()
    a = run.outputs[0].detach()

    def _perturb(inp):
        # Uniform noise in [-epsilon, epsilon), on the input's device/dtype.
        noise = torch.rand_like(inp) * (2 * epsilon) - epsilon
        return inp + noise

    a_p = torch.cat([
        run.explainer.run(
            data=_perturb(x),
            **run.explainer_config.kwargs
        )[0].detach()
        for _ in range(n_perturb)
    ])
    # Take one norm per perturbation and keep the worst case. A single norm
    # over the stacked differences would grow with sqrt(n_perturb) instead.
    diff_norms = torch.linalg.vector_norm((a - a_p).flatten(1), dim=1)
    return diff_norms.max() / torch.linalg.vector_norm(a)

# run
def implement_by_project_of_model(project_name, model, inputs, grid):
    """Explain every input with LRP-epsilon for each epsilon in ``grid`` and score it.

    For each epsilon, an experiment is created via ``proj.explain`` on a
    ``Project`` named ``project_name``. Each input is explained with respect
    to the model's top-1 predicted class. The resulting run is annotated with
    ``explainer_config.epsilon``, ``infd`` (from ``get_infd_from_run``) and
    ``sens`` (from ``get_sens_from_run``).

    Args:
        project_name: Name passed to ``Project``.
        model: Classifier to explain.
        inputs: Sized, iterable collection of single inputs without a batch
            dim (e.g. an ``(N, C, H, W)`` tensor); a batch dim is added per
            input.
        grid: Iterable of epsilon values to evaluate.

    Returns:
        The populated ``Project``.
    """
    proj = Project(project_name)
    for eps in grid:
        experiment = proj.explain(LRPEpsilon(model, epsilon=eps))
        for img in tqdm(inputs, total=len(inputs)):
            x = img.unsqueeze(0)
            # Only the predicted label is needed here; no autograd graph.
            with torch.no_grad():
                target = model(x).argmax(1).item()
            run = experiment.run(x, target=target)
            run.explainer_config.epsilon = eps
            run.infd = get_infd_from_run(model, run)
            run.sens = get_sens_from_run(model, run)
    return proj

# rearrange result to visualize and tabulate
def rearrange_by_input(proj):
    """Group per-run summaries by input position across all experiments.

    Returns:
        ``defaultdict(list)`` mapping a run's index within its experiment
        (the input index) to a list with one dict per experiment, in
        ``proj.experiments`` order. Each dict has the keys ``epsilon``,
        ``target``, ``outputs`` (the run's ``outputs[0]``), ``infd`` and
        ``sens``.
    """
    by_input = defaultdict(list)
    runs_per_experiment = (exp.runs for exp in proj.experiments)
    for runs in runs_per_experiment:
        for idx, r in enumerate(runs):
            cfg = r.explainer_config
            by_input[idx].append({
                "epsilon": cfg.epsilon,
                "target": cfg.kwargs["target"],
                "outputs": r.outputs[0],
                "infd": r.infd,
                "sens": r.sens,
            })
    return by_input

# post process for viz
def post_process(attr):
    """Turn a ``(C, H, W)`` attribution map into a ``[0, 1]`` heatmap for display.

    Negative relevance is clipped to zero, channels are summed, and the
    result is min-max normalized. A constant map (e.g. all relevance
    negative, so everything is clipped to 0) is returned as all zeros rather
    than NaNs from a 0/0 division.

    Returns:
        ``(H, W)`` NumPy array.
    """
    heat = torch.nn.functional.relu(attr).permute((1, 2, 0)).sum(dim=-1)
    lo, hi = torch.min(heat), torch.max(heat)
    span = hi - lo
    if span > 0:
        heat = (heat - lo) / span
    else:
        heat = torch.zeros_like(heat)
    return heat.cpu().detach().numpy()

# viz
def visualize_proj(project, imgs):
    """Plot each input next to its LRP-epsilon heatmaps, one column per experiment.

    One row per input: the original image followed by one heatmap per
    experiment. Heatmaps are titled with their epsilon (first row only) and
    labelled with the run's infidelity and sensitivity.

    Args:
        project: Project whose runs carry ``epsilon``/``infd``/``sens``
            annotations (see ``rearrange_by_input``).
        imgs: Image batch indexed by input position; each ``imgs[i]`` is a
            ``(C, H, W)`` tensor.

    Returns:
        The matplotlib ``Figure``, or ``None`` if the project has no runs.
    """
    rearranged = rearrange_by_input(project)
    if not rearranged:
        return None
    nrows = len(rearranged)
    ncols = 1 + max(len(runinfos) for runinfos in rearranged.values())
    # squeeze=False keeps ``axes`` 2-D even when there is a single input row.
    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False
    )

    for input_idx, runinfos in rearranged.items():
        axes[input_idx, 0].imshow(img_to_np(imgs[input_idx]))
        for c_1, runinfo in enumerate(runinfos):
            ax = axes[input_idx, c_1 + 1]
            if input_idx == 0:
                ax.set_title(f"lrp-e={runinfo['epsilon']}")
            ax.set_xlabel(
                f"infd: {runinfo['infd'].item():.4f} / sens: {runinfo['sens'].item():.4f}",
                fontsize=15,
            )
            ax.imshow(post_process(runinfo["outputs"].squeeze()), cmap="gray")

    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
    return fig

# summary table
def tabulate_proj(project):
    """Summarize infidelity and sensitivity per epsilon, averaged over inputs.

    Returns:
        DataFrame indexed by ``epsilon`` with columns ``infd`` and ``sens``,
        each the mean over all inputs for that epsilon.
    """
    records = [
        {
            "input_idx": input_idx,
            "epsilon": info["epsilon"],
            "infd": info["infd"].item(),
            "sens": info["sens"].item(),
        }
        for input_idx, infos in rearrange_by_input(project).items()
        for info in infos
    ]
    table = pd.DataFrame.from_records(records)
    table = table.set_index(["input_idx", "epsilon"])
    return table.groupby("epsilon").mean()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_result(model_name, num_images, grid=(1e-2, .25, .5, 1.)):
    """Run the LRP-epsilon grid search for one model and summarize it.

    The function:
    - loads ``model_name`` with its default pretrained weights;
    - explains the first ``num_images`` images for every epsilon in ``grid``;
    - draws the heatmap grid;
    - returns the per-epsilon mean infidelity and sensitivity table.

    Args:
        model_name: torchvision model name, e.g. ``"vgg16"``.
        num_images: Number of images to load.
        grid: Epsilon values to evaluate. It defaults to a tuple rather than a
            list to avoid a shared mutable default argument; any iterable of
            numbers works.

    Returns:
        The DataFrame produced by ``tabulate_proj``.
    """
    model, transform = get_model(model_name)
    imgs = get_images(num_images)
    inputs = transform(imgs)
    proj_name = f"hpo_lrpe_{model_name}"
    proj = implement_by_project_of_model(proj_name, model, inputs, grid)
    visualize_proj(proj, imgs)
    return tabulate_proj(proj)

In [3]:
import warnings

# Silence every warning process-wide so the progress bars and plots below
# stay readable.
# NOTE(review): this blanket filter also hides genuinely useful warnings.
# Consider narrowing it with `category=` / `module=` if results look off.
warnings.filterwarnings(action="ignore")

In [4]:
# Grid search over LRP epsilon on torchvision's "vgg16" using 4 images.
get_result("vgg16", 4)

100%|██████████| 4/4 [01:56<00:00, 29.22s/it]
 25%|██▌       | 1/4 [00:31<01:35, 31.79s/it]

: 

In [None]:
# Same grid search on torchvision's "resnet18" using 4 images.
get_result("resnet18", 4)

In [None]:
# Same grid search on torchvision's "vit_b_16" using 4 images.
get_result("vit_b_16", 4)