In this first experiment, we compare RangeGrad with several existing attribution methods, using a VGG19 network pretrained on the ImageNet dataset. The baselines are the simple input gradient, SmoothGrad, and Grad-CAM.

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from random import choices
from minmax.models import densenet, resnet, vgg
from torchvision.models import vgg as vggn
from time import time

# Standard ImageNet per-channel statistics (RGB order).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing for the validation images. No Resize/CenterCrop is applied,
# so every image is fed to the networks at its native resolution.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Inverse of the normalisation above, used only for display.
# Normalize computes (x - mean) / std, so the first step multiplies by the
# original std and the second adds the original mean back.
transform_reverse = transforms.Compose([
    transforms.Normalize(mean=[0, 0, 0], std=[1 / s for s in IMAGENET_STD]),
    transforms.Normalize(mean=[-m for m in IMAGENET_MEAN], std=[1, 1, 1]),
])

# ImageNet validation split, preprocessed with transform_test.
testset = torchvision.datasets.ImageNet(
    root='./files/ImageNet',
    split="val",
    transform=transform_test,
)

# Fixed validation indices followed by 10 random ones.
# NOTE: `choices` draws with replacement and is unseeded, so the random part
# may repeat indices and differs from run to run.
FIXED_SAMPLES = [8922, 13899, 16096, 4753, 9301, 37782, 35062, 25899, 23031, 44597]
samples = FIXED_SAMPLES + choices(range(len(testset)), k=10)

# One image per batch, iterated in the order of `samples`, so batch i of the
# loader corresponds to dataset index samples[i].
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=1,
    num_workers=2,
    sampler=samples,
)

# `net`: the minmax VGG19 used for RangeGrad.
# `netn`: the stock torchvision VGG19 used for the Gradient, Grad-CAM and
# SmoothGrad baselines.
net = vgg.vgg19(pretrained=True)
netn = vggn.vgg19(pretrained=True)
# Fill value of the RangeGrad input mask (see torch.full(..., r) in the loop).
r = 1e-6

net.eval()
net.zero_grad()
net.requires_grad_(False)

# torchvision's VGG19 classifier contains Dropout layers. Without eval(), the
# baseline prediction and gradients of the first sample would be computed with
# dropout active (non-deterministic). netn's parameters are intentionally left
# trainable: Grad-CAM backpropagates into intermediate activations and needs an
# autograd graph even when the input tensor does not require grad.
netn.eval()
print("Loading Complete")

# Timestamps recorded by stopwatch(), oldest first.
times = []


def stopwatch(title=""):
    """Record a timestamp and report the time elapsed since the previous one.

    The very first call only records a timestamp; every later call prints
    `title` together with the seconds elapsed since the last recorded mark.
    """
    current = time()
    if times:
        elapsed = current - times[-1]
        print(title, " - time: ", elapsed)
    times.append(current)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2
from minmax.mm import mmTensor
from minmax.ImageNetClasses import classes
import scipy.ndimage as ndimage
from pytorch_grad_cam import GradCAM
import os
from datetime import datetime

# Directory (relative to the working directory) receiving all PNG/NPY outputs.
OUTPUT_FOLDER = "Experiment_1"
try:
    os.mkdir(OUTPUT_FOLDER)
except FileExistsError:
    print("Output Folder Exists")
else:
    print("Output Folder Created")

def set_scale(m):
    """Set ``autoscale_relu = 1/1000`` on module ``m``.

    Meant to be used as ``net.apply(set_scale)``, which calls it on every
    submodule. The attribute is presumably read by the minmax layers to
    configure RangeGrad's ReLU handling — confirm against minmax.
    """
    setattr(m, "autoscale_relu", 1 / 1000)

def remove_scale(m):
    """Delete any RangeGrad configuration attributes present on module ``m``.

    Meant for ``net.apply(remove_scale)``; attributes that are absent are
    simply skipped.
    """
    for attr in ("scale_factor", "debug_range", "debug_relu", "autoscale_relu"):
        if hasattr(m, attr):
            delattr(m, attr)

def _normalize_map(saliency):
    """Min-max scale a saliency map to [0, 1].

    A constant map (max == min) would otherwise be divided by zero and become
    all-NaN; it is returned as all zeros instead.
    """
    saliency = np.asarray(saliency)
    saliency = saliency - saliency.min()
    peak = saliency.max()
    if peak > 0:
        saliency = saliency / peak
    return saliency


for i, (data, target) in enumerate(testloader):
    sample_id = samples[i]

    # --- Standard gradient (baseline) -----------------------------------
    stopwatch()
    net.apply(remove_scale)
    inputs = torch.clone(data)
    inputs.requires_grad = True
    outputs = netn(inputs)

    prediction = outputs.argmax().item()
    one_hot = torch.zeros((1, 1000))
    one_hot[0, prediction] = 1

    # batch_size is 1, so `target` holds exactly one label.
    if prediction != target.item():
        print("Wrong classification, skipping")
        continue
    print("Sample ID: ", sample_id)
    print("True Class: ", classes[target.item()])
    print("Prediction: ", classes[prediction])

    outputs.backward(gradient=one_hot)
    # L2 norm over the channel axis collapses the (C, H, W) gradient to (H, W).
    expl_gradient = _normalize_map(np.linalg.norm(inputs.grad.detach()[0].numpy(), axis=0))
    stopwatch("Standard Gradient")

    # --- Grad-CAM --------------------------------------------------------
    target_layers = [netn.features[-1]]  # last layer of VGG's feature extractor
    cam = GradCAM(model=netn, target_layers=target_layers)
    # targets=None lets Grad-CAM explain the highest-scoring class.
    expl_gradcam = cam(input_tensor=torch.clone(data), targets=None)[0]
    stopwatch("GradCam")

    # --- RangeGrad -------------------------------------------------------
    net.apply(set_scale)
    inputs = mmTensor(data)
    mask = torch.full(data[0].shape, float(r), requires_grad=True)
    inputs.add_mask(mask, mask)

    outputs = net(inputs)
    output_range = outputs.upper() - outputs.lower()
    output_range.backward(gradient=one_hot)

    expl_rangegrad = _normalize_map(np.linalg.norm(mask.grad.detach().numpy(), axis=0))
    stopwatch("RangeGrad")

    # --- SmoothGrad ------------------------------------------------------
    net.apply(remove_scale)

    SAMPLES = 35  # number of noisy copies of the input
    shape = list(data.shape)
    shape[0] = SAMPLES
    # Gaussian noise (std 0.15) is added in the *normalised* input space.
    inputs = torch.Tensor(np.random.normal(0, 0.15, shape).astype(np.float32))
    inputs += data
    inputs.requires_grad = True
    outputs = netn(inputs)
    outputs.backward(gradient=one_hot.expand(SAMPLES, -1))

    # Sum the gradients over the noisy copies, then take the channel norm.
    expl_smoothgrad = inputs.grad.detach().numpy().sum(axis=0)
    expl_smoothgrad = _normalize_map(np.linalg.norm(expl_smoothgrad, axis=0))
    stopwatch("SmoothGrad")

    # --- Draw and save ---------------------------------------------------
    # Undo the ImageNet normalisation; clamp away float round-off so imshow
    # receives values inside [0, 1].
    image = transform_reverse(data)[0].clamp(0, 1).movedim(0, 2).numpy()
    plt.clf()
    plt.axis("off")
    plt.imshow(image)
    plt.savefig("{}/{}_Original.png".format(OUTPUT_FOLDER, sample_id), bbox_inches='tight', pad_inches=0)

    explanations = [
        ("Gradient", expl_gradient),
        ("SmoothGrad", expl_smoothgrad),
        ("GradCam", expl_gradcam),
        ("RangeGrad", expl_rangegrad),
    ]
    for method, saliency in explanations:
        plt.clf()
        plt.axis("off")
        plt.imshow(saliency, cmap=plt.cm.Reds)
        plt.savefig("{}/{}_{}.png".format(OUTPUT_FOLDER, sample_id, method), bbox_inches='tight', pad_inches=0)
        np.save("{}/{}_{}.npy".format(OUTPUT_FOLDER, sample_id, method), saliency)

    # Side-by-side comparison: original image followed by each explanation.
    fig, axs = plt.subplots(1, 5)
    fig.set_figwidth(20)
    panels = [("Original", image, None)] + [(method, saliency, plt.cm.Reds) for method, saliency in explanations]
    for ax, (title, picture, cmap) in zip(axs, panels):
        ax.set_title(title)
        ax.axis("off")
        ax.imshow(picture, cmap=cmap)
    plt.show()
    # Close explicitly so figures do not accumulate under non-inline backends.
    plt.close(fig)