# XAI notebook
This notebook defines a way to test different CNN explainability techniques. The metric uses masking and a GAN to change the background of the given object in a classification task.

In [None]:
import torch
import torchvision
import numpy as np
import cv2
from PIL import Image
from utils import map_from_list
from torchvision.datasets import VOCDetection, Caltech101, VOCSegmentation, ImageFolder
import torchvision.transforms as transforms
%matplotlib inline 
import matplotlib.pyplot as plt

In [None]:
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

## Load datasets and pretrained networks

In [None]:

# Choose model: 'ResNet' or 'DenseNet' (the two names checked by the model-loading cells below)
model_name = 'ResNet'

In [None]:
if model_name == 'ResNet':
    # Import only if this model is used
    from torchvision.models import resnet34, ResNet34_Weights
    
    # Select the default pre-trained weights (trained on ImageNet; further training is needed for other datasets)
    weights = ResNet34_Weights.DEFAULT
    # Initialise the model with those weights
    model = resnet34(weights=weights)

In [None]:
if model_name == 'DenseNet':
    # Import only if this model is used
    from torchvision.models import densenet121, DenseNet121_Weights

    # Select the default pre-trained weights (trained on ImageNet; further training is needed for other datasets)
    weights = DenseNet121_Weights.DEFAULT
    # Initialise the model with those weights
    model = densenet121(weights=weights)

In [None]:
# Directory that holds (or will receive, when a dataset is downloaded) the datasets.
# The value below is machine-specific: replace it with your own path.
root = 'C:/Users/pette/Documents/jupterNotebooks/machinelearning/datasets' # Own data root directory here
# Dataset to use; the dataset cell recognises 'VOC', 'Caltech' and 'ImageNet'.
choose_dataset = 'VOC' # VOC, Caltech, ImageNet

In [None]:
# Input transforms for ResNet and DenseNet. The output (target) transform is only used with the VOC datasets:
def grey_scale_to_rgb(x):
    """Return ``x`` as a three-channel image tensor.

    The channel axis is dimension 0 (``(C, H, W)``). Three-channel input is
    returned unchanged; other supported layouts are converted so that the
    result can always be passed to a three-channel ``Normalize``.

    Args:
        x: Image tensor with 1, 2, 3 or 4 channels in dimension 0.

    Returns:
        A tensor with exactly three channels. For three-channel input this
        is ``x`` itself (no copy).

    Raises:
        ValueError: If the channel count is not 1, 2, 3 or 4.
    """
    channels = x.size(dim=0)
    if channels == 3:
        return x
    if channels == 1:
        # Plain greyscale: replicate the single channel.
        return x.repeat(3, 1, 1)
    if channels == 2:
        # Assumed to be grey + alpha (PIL mode "LA"): drop the alpha channel,
        # then replicate the grey channel.
        return x[:1].repeat(3, 1, 1)
    if channels == 4:
        # Assumed to be RGBA: drop the trailing alpha channel.
        return x[:3]
    raise ValueError(
        f"Cannot convert a tensor with {channels} channels to RGB"
    )
    
def VOC_to_label(x):
    """Convert a VOC annotation dict into a multi-hot label vector.

    Args:
        x: Mapping where ``x['annotation']['object']`` is an iterable of
            dicts, each carrying the class name under the ``'name'`` key.

    Returns:
        A float tensor of length 20 whose entry is 1 for every class that
        appears at least once in the annotation and 0 otherwise. The index
        of a class is its position in the class tuple below.

    Raises:
        ValueError: If an object name is not one of the 20 known classes.
    """
    class_names = (
        'horse', 'person', 'bottle', 'dog', 'tvmonitor',
        'car', 'aeroplane', 'bicycle', 'boat', 'chair',
        'diningtable', 'pottedplant', 'train', 'cat', 'sofa',
        'bird', 'sheep', 'motorbike', 'bus', 'cow',
    )
    # Resolve every annotated object to its class position first ...
    positions = [class_names.index(entry['name']) for entry in x['annotation']['object']]
    # ... then switch the matching entries on (duplicates are harmless).
    multi_hot = torch.zeros(len(class_names))
    for position in positions:
        multi_hot[position] = 1
    return multi_hot

# Input pipeline: resize to 300x300, convert to a tensor, force three channels,
# then normalise each channel with the mean/std values given below.
transform_input = transforms.Compose([
    transforms.Resize(size=(300, 300)),
    transforms.ToTensor(),
    transforms.Lambda(grey_scale_to_rgb),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Target pipeline: turn an annotation dict into a multi-hot label vector.
transform_output = transforms.Compose([transforms.Lambda(VOC_to_label)])

In [None]:
# Build two views of the selected dataset:
#   dataset          -- transformed samples that are fed to the network
#   dataset_original -- untransformed samples, used for plotting
# num_of_classes is the size of the classification head needed for the dataset.
if choose_dataset == 'VOC':
    num_of_classes = 20
    dataset = VOCDetection(root, year='2012', image_set='train', download=True, transform=transform_input, target_transform=transform_output)
    dataset_original = VOCDetection(root, year='2012', image_set='train', download=True, transform=None, target_transform=None)
    #dataset_segmentation = VOCSegmentation(root, year='2012', image_set='train', download=True, transform=None, target_transform=transform_output)
elif choose_dataset == 'Caltech':
    num_of_classes = 101
    dataset = Caltech101(root, download=True, transform=transform_input, target_transform=None)
    dataset_original = Caltech101(root, download=True, transform=None, target_transform=None)
elif choose_dataset == 'ImageNet':
    num_of_classes = 1000
    # Point root at the image folder. Guard the append so that re-running
    # this cell does not keep extending the path.
    imagenet_subdir = '/imagenet_images'
    if not root.endswith(imagenet_subdir):
        root += imagenet_subdir
    dataset = ImageFolder(root, transform=transform_input, target_transform=None)
    dataset_original = ImageFolder(root, transform=None, target_transform=None)
else:
    # Fail here rather than with a NameError in a later cell.
    raise ValueError(
        f"Unknown dataset {choose_dataset!r}; expected 'VOC', 'Caltech' or 'ImageNet'"
    )

In [None]:
# Sanity check: show one (input, target) sample from the selected dataset
print(dataset[10])

In [None]:
# Create dataloader: serves the transformed dataset in shuffled mini-batches of `batch_size` samples
batch_size = 4

data_loader = torch.utils.data.DataLoader(dataset,
                                          batch_size=batch_size,
                                          shuffle=True)

## Train the model
The model can be further trained on the loaded dataset. The pretrained weights of all classification models were trained on the ImageNet dataset.

In [None]:
# Switches read by the training cell: fine-tune the model, or (when not
# training) optionally load previously saved weights.
train = True
load_weights = False

# SGD hyper-parameters
lr = 0.001
momentum = 0

# Number of epochs used for further training
epochs = 3

# Loss function: VOC targets are multi-hot vectors (several classes per image),
# so a per-class binary loss is used there; the other datasets use cross-entropy.
crit = torch.nn.BCEWithLogitsLoss() if choose_dataset == 'VOC' else torch.nn.CrossEntropyLoss()

# NOTE: the optimizer only tracks the parameters the model has at this point;
# layers that are swapped into the model later are not updated by it unless
# the optimizer is rebuilt.
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

In [None]:
import os

if train or load_weights:
    # Replace the ImageNet classification head with one sized for the chosen
    # dataset. ResNet exposes the head as `fc`, DenseNet as `classifier`; the
    # input width is read from the existing layer instead of being hard-coded.
    # This is also needed before loading saved weights, because the saved
    # state dict contains the resized head.
    if hasattr(model, 'fc'):
        model.fc = torch.nn.Linear(model.fc.in_features, num_of_classes)
    else:
        model.classifier = torch.nn.Linear(model.classifier.in_features, num_of_classes)

if train:
    # The optimizer must be created after the head is replaced, otherwise the
    # new layer's parameters are never updated.
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    # Make sure training-mode behaviour is active even if the model was put
    # into eval mode earlier in the session.
    model.train()

    # torch.save does not create missing directories.
    os.makedirs('weights', exist_ok=True)

    num_batches = len(data_loader)
    for epoch in range(epochs):
        running_loss = 0
        for i, data in enumerate(data_loader, 0):
            inputs, labels = data[0], data[1]

            optimizer.zero_grad()
            outputs = model(inputs)

            loss = crit(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            print(f"{i}/{num_batches}")
        # Checkpoint after every epoch (overwrites the previous checkpoint)
        torch.save(model.state_dict(), f'weights/latest{choose_dataset}.pth')
        # Each loss value is already averaged over its batch (default
        # reduction of the loss modules), so average over batches only.
        print(f"Loss in epoch {epoch}: {running_loss/num_batches}")
else:
    if load_weights:
        model.load_state_dict(torch.load(f'weights/latest{choose_dataset}.pth'))
    else:
        pass # Use pretrained weights for XAI method evaluation

# GradCAM example

In [None]:
# Index of the sample to explain
i = 1000
# Get transformed tensor with index
(input_tensor, labels) = dataset[i]


# Get original image with index and reshape (for plotting). The image is
# converted to RGB first so that greyscale originals also yield an
# (H, W, 3) array that the heatmap can be overlaid on.
# Assumes dataset_original returns PIL images (it is built with transform=None).
(img, label) = dataset_original[i]
img = cv2.resize(np.array(img.convert('RGB')), (300, 300))
img = np.float32(img) / 255 # Assume 8 bit pixels

# Add the batch dimension expected by the model
input_tensor = input_tensor.unsqueeze(0)

if choose_dataset == 'VOC':
    # Multi-hot vector -> list of class indices that are present
    labels = (labels == 1).nonzero().flatten().tolist()

# Single-label datasets return a bare int
if isinstance(labels, int):
    labels = [labels]

if not labels:
    raise ValueError(f"Sample {i} has no labels to explain")

print(labels)

# Define target layer: ResNet exposes `layer4`; fall back to `features`
# (DenseNet) when that attribute does not exist.
target_layers = [model.layer4 if hasattr(model, 'layer4') else model.features]

# One row of images per ground-truth label
image_rows = []
for label in labels:
    if choose_dataset == 'ImageNet':
        label, name = map_from_list(label, dataset.find_classes(root)[0])
    # Set target as our ground truth label
    targets = [ClassifierOutputTarget(label)]

    # Run model with given cam
    with GradCAM(model=model, target_layers=target_layers) as cam_extractor:
        grayscale_cams = cam_extractor(input_tensor=input_tensor, targets=targets)
        cam_image = show_cam_on_image(img, grayscale_cams[0, :], use_rgb=True)

    # Make images the same format and put original, greyscale and heatmap side by side:
    cam = np.uint8(255*grayscale_cams[0, :])
    cam = cv2.merge([cam, cam, cam])
    images = np.hstack((np.uint8(255*img), cam, cam_image))
    image_rows.append(images)

# Show every label's row, not only the last one. `grayscale_cams`, `cam`,
# `cam_image` and `images` still hold the values for the last label.
Image.fromarray(np.vstack(image_rows))

## Threshold and create a new image

Threshold the Grad-CAM activation map and multiply the resulting binary mask with the input tensor to obtain a new image to feed to the network.

In [None]:
print(grayscale_cams)

# Binary mask: 1 where the CAM value exceeds the threshold, 0 elsewhere
threshold = 0.3
mask = (grayscale_cams > threshold).astype(int)

# Apply the mask to the network input ...
masked_tensor = input_tensor * torch.tensor(mask)
# ... and to the 8-bit display image (mask reshaped to height x width x 1 so it broadcasts over the colour channels)
masked_image = np.uint8(255*img) * mask.reshape(300, 300, 1)

In [None]:
# Display the masked image without interpolation smoothing
plt.imshow(masked_image, interpolation='nearest')
plt.show()