In [1]:
import matplotlib.pyplot as plt
import yaml
import pandas as pd
import cv2
import os
import torch
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T

from mafat_radar_challenge.main import get_instance

In [2]:
import mafat_radar_challenge.data_loader.augmentation as module_aug
import mafat_radar_challenge.data_loader.data_loaders as module_data
import mafat_radar_challenge.model.loss as module_loss
import mafat_radar_challenge.model.metric as module_metric
import mafat_radar_challenge.model.model as module_arch
import mafat_radar_challenge.data_loader.data_splitter as module_splitter
from mafat_radar_challenge.cli import load_config
from mafat_radar_challenge.main import get_instance

In [26]:
# Checkpoint to visualise; the experiment's config.yml is loaded from the same
# directory (see the loading cell). Set the MAFAT_MODEL_PATH environment
# variable to point the notebook at a checkpoint on another machine.
MODEL = os.environ.get(
    "MAFAT_MODEL_PATH",
    "/mnt/agarcia_HDD/mafat-radar-challenge-experiment/MAFAT_replica_aug_eff_b2_more_aux_more_synth_specaug_simple_aug_v5_adam/0808-140236/checkpoints/model_best.pth",
)

# PyTorch visualisation helpers (guided and vanilla backprop)

In [27]:
from pytorch_visualizations_master.src.guided_backprop import GuidedBackprop
from pytorch_visualizations_master.src.vanilla_backprop import VanillaBackprop
from pytorch_visualizations_master.src.misc_functions import save_gradient_images, convert_to_grayscale, get_positive_negative_saliency

In [28]:
# Rebuild the training-time augmentation, data loaders and architecture from
# the experiment config stored alongside the checkpoint, then load its weights.
cfg = load_config(os.path.join(os.path.dirname(MODEL), "config.yml"))
transforms = get_instance(module_aug, "augmentation", cfg)
if "sampler" in cfg:
    from functools import partial

    # NOTE(review): `module_sampler` is not imported anywhere in the visible
    # notebook, so this branch still raises NameError until the sampler module
    # is imported next to the other data_loader modules.
    sampler = getattr(module_sampler, cfg["sampler"]["type"])
    sampler = partial(sampler, **cfg["sampler"]["args"])
else:
    sampler = None
# `sampler` is not passed to the loaders below; they are built exactly as the
# config's data_loader / val_data_loader sections describe.
# cfg["data_loader"]["args"]["sampler"] = sampler
data_loader = get_instance(module_data, "data_loader", cfg, transforms)
valid_data_loader = get_instance(module_data, "val_data_loader", cfg, transforms)
model = get_instance(module_arch, "arch", cfg)
# map_location="cpu": the model is never moved to a GPU in this notebook, and
# this lets checkpoints saved from CUDA load on CPU-only hosts. torch.load
# unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load(MODEL, map_location="cpu")
# Inference mode for visualisation: the network has BatchNorm layers
# (_bn0/_bn1), which in train mode use per-batch statistics and update their
# running averages on every forward pass.
model.eval()
model.load_state_dict(checkpoint["state_dict"])

Loaded pretrained weights for efficientnet-b2


<All keys matched successfully>

In [29]:
# Expose the wrapped network's layers, in forward order, as a single
# `features` container attached to the inner model. The shared `_swish`
# module appears twice: after the stem and after the head.
input_model = model.model
features = nn.ModuleList(
    [
        input_model._conv_stem,
        input_model._bn0,
        input_model._swish,
        input_model._blocks,
        input_model._conv_head,
        input_model._bn1,
        input_model._swish,
    ]
)

input_model.features = features

In [30]:
def format_for_plotting(tensor):
    """Formats the shape of tensor for plotting.
    Tensors typically have a shape of :math:`(N, C, H, W)` or :math:`(C, H, W)`
    which is not suitable for plotting as images. This function rearranges
    the input into :math:`(H, W, C)` for multi-channel data and
    :math:`(H, W)` for mono-channel data.
    Args:
        tensor (torch.Tensor, torch.float32): Image tensor. A 4-D input must
            hold a single image (N == 1); with N > 1 the batch axis cannot be
            squeezed and the final permute raises.
    Shape:
        Input: :math:`(1, C, H, W)` or :math:`(C, H, W)`
        Output: :math:`(H, W, C)`, or :math:`(H, W)` when C == 1
    Return:
        torch.Tensor (torch.float32): Formatted image tensor, detached from
            the autograd graph and backed by its own copy of the data, so
            modifying it never alters ``tensor``.
    Note:
        Symbols used to describe dimensions:
            - N: number of images in a batch
            - C: number of channels
            - H: height of the image
            - W: width of the image
    """
    # Detach before cloning so the copy is not recorded by autograd; the clone
    # guarantees the result never shares storage with the caller's tensor,
    # whichever branch below is taken.
    formatted = tensor.detach().clone()

    if formatted.dim() == 4:
        formatted = formatted.squeeze(0)

    if formatted.shape[0] == 1:
        return formatted.squeeze(0)
    return formatted.permute(1, 2, 0)
    
def denormalize(tensor, means=(0.485, 0.456, 0.406), stds=(0.229, 0.224, 0.225)):
    """Reverses the normalisation on a tensor.
    Performs a reverse operation on a tensor so that pixel values return to
    their pre-normalisation range (between 0 and 1 when the input was
    normalised with these statistics). Useful when plotting a tensor as an
    image.
    Normalisation: (image - mean) / std
    Denormalisation: image * std + mean
    Args:
        tensor (torch.Tensor, dtype=torch.float32): Normalized image tensor
        means (sequence of float): Per-channel means used for normalisation.
            Defaults are the commonly used ImageNet statistics.
        stds (sequence of float): Per-channel standard deviations used for
            normalisation. Defaults are the commonly used ImageNet statistics.
    Shape:
        Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`
        Output: same shape as input
    Return:
        torch.Tensor (torch.float32): Denormalised copy of ``tensor``; the
            input itself is not modified. Every image in the batch is
            denormalised. Only the first ``min(C, len(means), len(stds))``
            channels are rescaled.
    Note:
        Symbols used to describe dimensions:
            - N: number of images in a batch
            - C: number of channels
            - H: height of the image
            - W: width of the image
    """
    denormalized = tensor.clone()

    # Channel axis is 1 for batched input and 0 for a single (C, H, W) image.
    channel_dim = 1 if denormalized.dim() == 4 else 0
    n_channels = min(denormalized.shape[channel_dim], len(means), len(stds))

    for channel in range(n_channels):
        # select() is a view, so the in-place ops update every image in the
        # batch for this channel at once.
        denormalized.select(channel_dim, channel).mul_(stds[channel]).add_(means[channel])

    return denormalized

def gradient_to_image(gradient):
    """Min-max scale a channel-first gradient array into an 8-bit image.

    Args:
        gradient (np.ndarray): Gradient map laid out as (C, H, W); the
            transpose below requires exactly three axes.

    Returns:
        np.ndarray: uint8 array of shape (H, W, C) with values in [0, 255]
        and the channel axis reversed. A constant input (e.g. all-zero
        gradients) maps to an all-zero image instead of NaNs.
    """
    gradient = np.asarray(gradient)
    if not np.issubdtype(gradient.dtype, np.floating):
        # Scaling to [0, 1] needs a float array; integer input would
        # otherwise fail or truncate.
        gradient = gradient.astype(np.float64)
    gradient = gradient - gradient.min()
    peak = gradient.max()
    # Skip the division for a constant map to avoid 0 / 0 -> NaN.
    if peak > 0:
        gradient = gradient / peak
    image = (gradient * 255).astype(np.uint8).transpose(1, 2, 0)
    # Reverse the channel order (RGB <-> BGR).
    # NOTE(review): plt.imshow expects RGB, so this swap may be unwanted when
    # the result is displayed with matplotlib; confirm.
    return image[..., ::-1]

In [None]:
# Upper bound on the number of validation batches to visualise; None walks the
# whole loader. Every image produces its own figure, so a full pass can make
# the notebook output very large.
MAX_BATCHES = None

for idx, (image_batch, label_batch) in enumerate(valid_data_loader):
    if MAX_BATCHES is not None and idx >= MAX_BATCHES:
        break
    print("Batch {}".format(idx))
    for idx_2, (image, label) in enumerate(zip(image_batch, label_batch)):
        print("Image {}".format(idx_2))
        print("Label {}".format(label))
        # Restore the batch axis and mark the input as requiring grad so
        # gradients with respect to the input image can be computed.
        image = image.unsqueeze(0)
        image.requires_grad = True
        # First entry of this sample's label tensor, passed to GuidedBackprop
        # as the target class index.
        label = int(label[0])
        # Guided backprop for the target class, then scale to an 8-bit image
        GBP = GuidedBackprop(input_model, image, label)
        guided_grads = GBP.generate_gradients()
        gradient_image = gradient_to_image(guided_grads)
        # NOTE(review): the input is shown exactly as fed to the model; if the
        # transforms normalise it, apply denormalize(image) here first.
        image = image.detach().cpu().numpy()[0].transpose(1, 2, 0)

        fig, (ax_grad, ax_input) = plt.subplots(1, 2, figsize=(8, 8))
        ax_grad.imshow(gradient_image)
        ax_grad.set_title("Guided backprop (class {})".format(label))
        ax_grad.axis("off")
        ax_input.imshow(image)
        ax_input.set_title("Model input")
        ax_input.axis("off")
        plt.show()
        # One figure is created per image; release each explicitly.
        plt.close(fig)