# XAI notebook
This notebook defines a way to test different CNN explainability techniques. The metric uses masking and a GAN to change the background of the given object in a classification task.

In [None]:
import torch
import torchvision
import numpy as np
import cv2
import random
from PIL import Image
from utils import map_from_list
from torchvision.datasets import VOCDetection, Caltech101, VOCSegmentation, ImageFolder
import torchvision.transforms as transforms
%matplotlib inline 
import matplotlib.pyplot as plt

In [None]:
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

## Load datasets and pretrained networks

In [None]:
# Directory handed to the dataset constructors below; replace it with your own data root.
root = 'C:/Users/pette/Documents/jupterNotebooks/machinelearning/datasets' # Own data root directory here
# Dataset selected by the loading cell: 'VOC', 'Caltech' or 'ImageNet'.
choose_dataset = 'Caltech' # VOC, Caltech, ImageNet

In [None]:
# Transforms for ResNet and DenseNet. The output transform is only needed for segmentation:
def grey_scale_to_rgb(x):
    """Return ``x`` with three entries along dimension 0.

    A tensor whose first dimension already has size 3 is returned as is;
    any other tensor is tiled three times along its first dimension.
    """
    needs_tiling = x.size(dim=0) != 3
    return x.repeat(3, 1, 1) if needs_tiling else x

def image_grey_scale_to_rgb(im):
    """Return the result of converting image ``im`` to mode ``'RGB'``."""
    return im.convert(mode='RGB')
    
def VOC_to_label(x):
    """Turn an annotation dict into a multi-hot class vector.

    ``x['annotation']['object']`` is iterated and the ``'name'`` of every
    entry sets the matching position of a 20-element float tensor to 1.
    A name that is not in the class list raises ``ValueError``.
    """
    all_classes = ['horse', 'person', 'bottle', 'dog', 'tvmonitor', 'car', 'aeroplane', 'bicycle',
                   'boat', 'chair', 'diningtable', 'pottedplant', 'train', 'cat', 'sofa', 'bird',
                   'sheep', 'motorbike', 'bus', 'cow']
    final_labels = torch.zeros(len(all_classes))
    object_names = (entry['name'] for entry in x['annotation']['object'])
    for object_name in object_names:
        class_index = all_classes.index(object_name)
        final_labels[class_index] = 1
    return final_labels

# Model input pipeline: resize to 300x300, convert to a tensor, force three
# channels, then normalise each channel with the mean/std given below.
transform_input = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.ToTensor(),
    transforms.Lambda(grey_scale_to_rgb),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Target pipeline: annotation dict -> multi-hot label vector.
transform_output = transforms.Compose([transforms.Lambda(VOC_to_label)])

# Un-normalised Caltech images for plotting: resized, then converted to RGB.
caltech_image_transform = transforms.Compose([
    transforms.Resize((300, 300)),
    transforms.Lambda(image_grey_scale_to_rgb),
])

In [None]:
# Build two views of the selected dataset:
#   dataset          - transformed tensors that are fed to the model
#   dataset_original - the same samples without normalisation, used for plotting/saving
if choose_dataset == 'VOC':
    num_of_classes = 20
    dataset = VOCDetection(root, year='2012', image_set='train', download=True, transform=transform_input, target_transform=transform_output)
    dataset_original = VOCDetection(root, year='2012', image_set='train', download=True, transform=None, target_transform=None)
    #dataset_segmentation = VOCSegmentation(root, year='2012', image_set='train', download=True, transform=None, target_transform=transform_output)
elif choose_dataset == 'Caltech':
    num_of_classes = 101
    dataset = Caltech101(root, download=True, transform=transform_input, target_transform=None)
    dataset_original = Caltech101(root, download=True, transform=caltech_image_transform, target_transform=None)
elif choose_dataset == 'ImageNet':
    num_of_classes = 1000
    # Keep `root` untouched so that re-running this cell does not append the
    # sub-directory a second time.
    imagenet_root = root + '/imagenet_images'
    dataset = ImageFolder(imagenet_root, transform=transform_input, target_transform=None)
    dataset_original = ImageFolder(imagenet_root, transform=None, target_transform=None)
else:
    # Fail here instead of with a NameError on `dataset` in a later cell.
    raise ValueError(
        f"Unknown choose_dataset {choose_dataset!r}; expected 'VOC', 'Caltech' or 'ImageNet'"
    )

In [None]:
# Sanity check: show sample 10 of the transformed and of the plotting dataset
for sample_source in (dataset, dataset_original):
    print(sample_source[10])

In [None]:
# Choose model: the construction cells below check for 'ResNet' and 'DenseNet'
model_name = 'ResNet'

In [None]:
if model_name == 'ResNet':
    # Imported lazily so the dependency is only touched when this model is chosen
    from torchvision.models import resnet34, ResNet34_Weights

    # Default pre-trained weights (ImageNet); other datasets need further training
    weights = ResNet34_Weights.DEFAULT

    # Drop the last two children of the network. Global average pooling must
    # go because it reduces HiResCAM to GradCAM.
    backbone_layers = list(resnet34(weights=weights).children())[:-2]
    model = torch.nn.Sequential(*backbone_layers)

    # New head: flatten the feature map and classify with one linear layer.
    # 51200 is the flattened feature size this head expects.
    model.add_module("flatten", torch.nn.Flatten())
    model.add_module("linear_end", torch.nn.Linear(51200, num_of_classes))

In [None]:
# Alternative model: DenseNet-121 with its default pre-trained weights.
# NOTE(review): unlike the ResNet cell, no layers are removed or replaced here
# (the output size is not tied to num_of_classes), and the CAM cells index the
# model as `model[7][2]` - confirm this branch works with them before using it.
if model_name == 'DenseNet':
    # Import only if model used
    from torchvision.models import densenet121, DenseNet121_Weights

    # Loads best possible pre-trained weights for ImageNet dataset (further training needed for other datasets)
    weights = DenseNet121_Weights.DEFAULT
    # Init model with weights
    model = densenet121(weights=weights)

In [None]:
# Loader that feeds shuffled mini-batches of the transformed dataset to training
batch_size = 4

data_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=batch_size)

## Train the model
The model can be trained further on the loaded dataset. The pretrained weights of all classification models were trained on the ImageNet dataset.

In [None]:
# Switches for the next cell: train further, or load previously stored weights
train = False
load_weights = True

# Optimiser hyper-parameters and the number of epochs used for further training
momentum = 0
lr = 0.001
epochs = 2

# VOC targets are multi-hot vectors, so that dataset gets a per-class binary
# loss; every other dataset uses cross entropy.
crit = torch.nn.BCEWithLogitsLoss() if choose_dataset == 'VOC' else torch.nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

In [None]:
# Either train the model further (a checkpoint is stored after every epoch) or
# load a previously stored checkpoint for the selected dataset.
if train:
    import os

    # torch.save does not create missing directories
    os.makedirs('weights', exist_ok=True)

    # Make sure training runs in train mode even if an earlier cell switched
    # the model to eval mode.
    model.train()

    for epoch in range(epochs):
        running_loss = 0
        for i, data in enumerate(data_loader, 0):
            inputs, labels = data[0], data[1]

            optimizer.zero_grad()
            outputs = model(inputs)

            loss = crit(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            print(f"{i}/{len(data_loader)}")
        torch.save(model.state_dict(), f'weights/latest{choose_dataset}.pth')
        # `crit` is built with the default reduction, so every loss value is
        # already a batch mean: average over the number of batches only.
        print(f"Loss in epoch {epoch}: {running_loss/len(data_loader)}")
elif load_weights:
    model.load_state_dict(torch.load(f'weights/latest{choose_dataset}.pth'))
# Otherwise keep the pretrained weights for the XAI method evaluation

# GradCAM example

In [None]:
# First define a couple of functions to make the code more readable
def get_difference_mask(cam_1, cam_2, threshold=0.3):
    """Mark the positions that only the second CAM highlights.

    Both maps are binarised with ``value > threshold``. The result is 1 where
    ``cam_2`` is above the threshold while ``cam_1`` is not, and 0 elsewhere.

    Args:
        cam_1: numpy array with the first activation map.
        cam_2: numpy array with the second activation map, broadcastable
            against ``cam_1``.
        threshold: cut-off used to binarise both maps (default 0.3).

    Returns:
        Integer numpy array of zeros and ones.
    """
    active_1 = cam_1 > threshold
    active_2 = cam_2 > threshold

    # "masks differ and cam_2 is active" is the same as "cam_2 active, cam_1 not"
    return (active_2 & ~active_1).astype(int)

def mask_input_tensor(input_tensor, difference_mask, random_mask = True):
    """Multiply ``input_tensor`` with a mask derived from ``difference_mask``.

    Args:
        input_tensor: torch tensor to be masked.
        difference_mask: numpy array of zeros and ones that broadcasts against
            ``input_tensor``; ones mark the positions to change.
        random_mask: if True, marked positions are multiplied by values drawn
            uniformly between the minimum and maximum of ``input_tensor`` and
            all other positions by 1. If False, marked positions are
            multiplied by 0 (the mask is ``1 - difference_mask``).

    Returns:
        Tuple ``(output_tensor, mask)``. ``mask`` is the numpy array that was
        applied. For a floating point ``input_tensor`` the output keeps the
        input's dtype instead of being promoted to float64 by the numpy mask.
    """
    if random_mask:
        # Hand plain Python floats (not 0-d tensors) to numpy
        low = torch.min(input_tensor).item()
        high = torch.max(input_tensor).item()
        random_array = np.random.uniform(low=low, high=high,
                                         size=tuple(input_tensor.size()))
        mask = difference_mask * random_array
        # Positions outside the difference mask must leave the input unchanged
        mask[mask == 0] = 1
    else:
        mask = 1 - difference_mask

    # Only force the dtype for floating point inputs; otherwise let torch infer
    # it from the mask as before.
    target_dtype = input_tensor.dtype if input_tensor.is_floating_point() else None
    mask_tensor = torch.as_tensor(mask, dtype=target_dtype, device=input_tensor.device)

    output_tensor = torch.mul(input_tensor, mask_tensor)
    return output_tensor, mask
        

In [None]:
#print(dict(model.named_modules()))

# First run: for a random sample of images, predict a class, compute GradCAM
# and HiResCAM for that prediction, and build a masked copy of the input from
# the regions that only HiResCAM highlights. Results are kept in `new_tensors`
# as {dataset index: (masked tensor, prediction, first label, mask)}.
num_of_images = 10
# Never ask for more samples than the dataset holds
all_images = random.sample(range(0, len(dataset)), min(num_of_images, len(dataset)))

plotting = False
new_tensors = {}

# Everything below is inference: set eval mode explicitly so that the very
# first prediction is not computed with train-mode layer behaviour.
model.eval()

# Layer whose activations the CAM methods explain (the same for every image)
target_layers = [model[7][2]]

for number in all_images:
    # Get transformed tensor with index and add the batch dimension
    (input_tensor, labels) = dataset[number]
    input_tensor = input_tensor.unsqueeze(0)

    # Get original image with index and reshape (for plotting)
    (img, _) = dataset_original[number]
    img = cv2.resize(np.array(img), (300, 300))
    img = np.float32(img) / 255 # Assume 8 bit pixels

    if choose_dataset == 'VOC':
        # Multi-hot vector -> class indices (a single index comes back as int)
        labels = (labels==1).nonzero().squeeze().tolist()

    if isinstance(labels, int):
        labels = [labels]

    # Only one prediction per image is needed; skip images without any label
    if not labels:
        continue
    label = labels[0]

    # The prediction itself needs no gradients
    with torch.no_grad():
        output = torch.argmax(model(input_tensor))

    # Set target as our predicted label
    targets = [ClassifierOutputTarget(output)]

    # Run model cams
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_grad_cams = cam(input_tensor=input_tensor, targets=targets)
        grad_cam_image = show_cam_on_image(img, grayscale_grad_cams[0, :], use_rgb=True)

    with HiResCAM(model=model, target_layers=target_layers) as cam:
        grayscale_hires_cams = cam(input_tensor=input_tensor, targets=targets)
        hires_cam_image = show_cam_on_image(img, grayscale_hires_cams[0, :], use_rgb=True)

    if plotting:
        # Make images the same format and plot original, greyscale and heatmap:
        grad_cam = np.uint8(255*grayscale_grad_cams[0, :])
        grad_cam = cv2.merge([grad_cam, grad_cam, grad_cam])

        hires_cam = np.uint8(255*grayscale_hires_cams[0, :])
        hires_cam = cv2.merge([hires_cam, hires_cam, hires_cam])

        images = np.hstack((np.uint8(255*img), grad_cam, grad_cam_image, hires_cam, hires_cam_image))
        Image.fromarray(images).show()

    difference_mask = get_difference_mask(grayscale_grad_cams, grayscale_hires_cams)

    masked_tensor, mask = mask_input_tensor(input_tensor, difference_mask, random_mask = False)
    # Store prediction and new tensor:
    new_tensors[number] = (masked_tensor, output, label, mask)

## Store masked input image (and load the GAN results)

In [None]:
def mask_input_image(input_image, mask):
    """Apply ``mask`` to ``input_image`` and return the result as a PIL image.

    Args:
        input_image: image accepted by ``transforms.Resize`` / ``ToTensor``.
        mask: array with a leading singleton dimension (removed with
            ``squeeze(0)``) whose last two dimensions give the height and
            width the image is resized to before the multiplication.

    Returns:
        The masked image, converted back with ``transforms.ToPILImage``.
    """
    mask = mask.squeeze(0)

    # Resize to the mask's own spatial size rather than a fixed 300x300 so the
    # element-wise product always lines up.
    to_tensor = transforms.Compose([
        transforms.Resize(tuple(mask.shape[-2:])),
        transforms.ToTensor()
    ])
    image_tensor = to_tensor(input_image)

    # Match the image dtype explicitly instead of relying on type promotion
    mask_tensor = torch.as_tensor(mask, dtype=image_tensor.dtype)
    output_image = torch.mul(image_tensor, mask_tensor)

    return transforms.ToPILImage()(output_image)
    

In [None]:
# Write every sampled image and its mask to the working directory as
# image<index>.png / mask<index>.png.
store_pic_and_mask = True
if store_pic_and_mask:
    for key, value in new_tensors.items():
        # Only the mask from the first run is needed here
        _, _, _, mask = value

        # Load the un-normalised image once and save that same object
        (img, _) = dataset_original[key]
        img.save(f'image{key}.png')

        # NOTE(review): this cell used to build a masked image with
        # mask_input_image() but never wrote it anywhere; add a save call here
        # if the GAN step needs the masked picture as well.

        # The mask is stored as an 8-bit image: 255 where pixels are kept, 0 where masked
        im = Image.fromarray((np.reshape(mask, (300,300)) * 255).astype(np.uint8))
        im.save(f'mask{key}.png')

## Rerun and calculate metrics

Run model again with the masked image and calculate metrics based on the predicted labels

In [None]:
# Second run: feed every masked tensor through the model again, recompute both
# CAMs for the class predicted in the first run and print
# "new prediction   original prediction   label".

# Inference only; harmless if the model is already in eval mode
model.eval()

# Layer whose activations the CAM methods explain (the same for every image)
target_layers = [model[7][2]]

for key, value in new_tensors.items():
    # Get stored values from first run:
    input_tensor, original_prediction, label, mask = value
    number = key
    input_tensor = input_tensor.type(torch.float)

    # Run the model (the prediction itself needs no gradients)
    with torch.no_grad():
        output = torch.argmax(model(input_tensor))

    # Get original image with index and reshape (for plotting). The dataset's
    # own target is ignored so the label stored in the first run is the one
    # that gets printed.
    (img, _) = dataset_original[number]
    numpy_img = cv2.resize(np.array(img), (300, 300))
    # Scale the *resized* array; scaling `img` would throw the resize away
    numpy_img = np.float32(numpy_img) / 255 # Assume 8 bit pixels

    # Explain the class predicted in the first run so both runs are comparable.
    # TODO: decide whether the new prediction should be the target instead.
    targets = [ClassifierOutputTarget(original_prediction)]

    # Run model cams
    with GradCAM(model=model, target_layers=target_layers) as cam:
        grayscale_grad_cams = cam(input_tensor=input_tensor, targets=targets)
        grad_cam_image = show_cam_on_image(numpy_img, grayscale_grad_cams[0, :], use_rgb=True)

    with HiResCAM(model=model, target_layers=target_layers) as cam:
        grayscale_hires_cams = cam(input_tensor=input_tensor, targets=targets)
        hires_cam_image = show_cam_on_image(numpy_img, grayscale_hires_cams[0, :], use_rgb=True)

    if plotting:
        # The masked picture is only needed for the plot; convert it to an
        # array so every tile passed to hstack is a numpy array.
        masked_img = np.array(mask_input_image(img, mask))

        # Make images the same format and plot original, greyscale and heatmap:
        grad_cam = np.uint8(255*grayscale_grad_cams[0, :])
        grad_cam = cv2.merge([grad_cam, grad_cam, grad_cam])

        hires_cam = np.uint8(255*grayscale_hires_cams[0, :])
        hires_cam = cv2.merge([hires_cam, hires_cam, hires_cam])

        images = np.hstack((masked_img, grad_cam, grad_cam_image, hires_cam, hires_cam_image))
        Image.fromarray(images).show()
    print(output, "  ", original_prediction, "  ", label)