### Visualizing gradient changes from fire class

### Using GradCAM

- Let's flatten the nested model into a plain list of its leaf layers, in the order they run, so that any layer can be picked by index

In [None]:
def tracesubmodules(module, inputs):
    """Return the leaf sub-modules of ``module`` in the order they are called.

    A forward hook is attached to every leaf (a module without children), one
    forward pass is run on ``inputs`` and every leaf is recorded each time its
    hook fires.  A leaf that runs several times therefore appears several
    times, and a leaf that is never reached does not appear at all.

    The forward pass is run as-is: the train/eval mode of ``module`` and the
    autograd state are left untouched.

    Args:
        module: The ``torch.nn.Module`` to trace.
        inputs: Passed unchanged as the single positional argument of
            ``module``.

    Returns:
        list: The leaf modules, in call order.
    """
    handles, modules = [], []

    def trace(module, inputs, outputs):
        modules.append(module)

    def traverse(module):
        children = list(module.children())
        for child in children:
            traverse(child)
        if not children:
            handles.append(module.register_forward_hook(trace))

    traverse(module)

    try:
        _ = module(inputs)
    finally:
        # Always detach the hooks, even when the forward pass fails, so the
        # model is never left with stray tracing hooks.
        for handle in handles:
            handle.remove()

    return modules

In [None]:
def draw_single_mask(batch_img, n_channels, batch_size, caption:str):
    """Plot every sample of a batch next to the same sample with its mask.

    One figure row is drawn per sample: the left axis shows the image, the
    right axis shows the image again with the mask blended on top
    (``jet`` colormap, ``alpha=0.5``).

    Args:
        batch_img: Pair ``(x, y)``; ``x[i]`` is the i-th image tensor and
            ``y[i]`` the i-th mask tensor.  Each image must hold
            ``_INPUT_SIZE * _INPUT_SIZE * n_channels`` values and each mask
            ``_INPUT_SIZE * _INPUT_SIZE`` values.
        n_channels: Number of image channels.
        batch_size: Number of samples to draw.
        caption: Figure title.
    """
    x, y = batch_img
    # One row per sample and two columns (plain / overlay).  squeeze=False
    # keeps ``axes`` two-dimensional for every batch size, so it can always
    # be indexed as axes[row, col].
    fig, axes = plt.subplots(batch_size, 2, figsize=(15, 4 * batch_size), squeeze=False)
    fig.suptitle(f"{caption}", fontsize=20)
    for i in range(batch_size):
        image = x[i].squeeze()
        if image.dim() == 3 and image.shape[0] == n_channels:
            # Channel-first (C, H, W) -> channel-last (H, W, C).  A bare
            # view() would only reinterpret memory and scramble the picture.
            image = image.permute(1, 2, 0)
        image = image.reshape(_INPUT_SIZE, _INPUT_SIZE, n_channels).detach().cpu().numpy()
        mask = y[i].reshape(_INPUT_SIZE, _INPUT_SIZE, 1).detach().cpu().numpy()

        axes[i, 0].imshow(image)
        axes[i, 1].imshow(image)
        axes[i, 1].imshow(mask, cmap='jet', alpha=0.5)
    plt.show()

In [None]:
from torch.autograd import Variable

class GradCam():
    """Grad-CAM heat-maps for one layer of a network.

    Calling the instance runs a forward and a backward pass through
    ``module`` while hooks on the chosen layer capture its output and the
    gradient flowing into that output.  The two are combined into one
    heat-map per sample, which is drawn with ``draw_single_mask`` and
    returned.

    Attributes are plain instance fields:
        module / device: the network and the device the inputs are moved to.
        handles: hook handles currently registered by this object.
        gradients / conv_outputs: last captured gradient / output of the layer.
        cam: heat-maps produced by the last call.
    """

    def __init__(self, module, device, *args, **kwargs):
        # Extra positional / keyword arguments are accepted and ignored.
        self.module, self.device = module, device
        self.handles = []
        self.gradients = None
        self.conv_outputs = None

    def store_outputs_and_grad(self, layer):
        """Hook ``layer`` so its output and output-gradient are captured."""
        def store_grads(module, grad_in, grad_out):
            self.gradients = grad_out[0]

        def store_outputs(module, input, outputs):
            if module == layer:
                self.conv_outputs = outputs

        self.handles.append(layer.register_forward_hook(store_outputs))
        self.handles.append(layer.register_full_backward_hook(store_grads))

    def guide(self, module):
        """Attach a gradient-clamping backward hook to every ``nn.ReLU``."""
        def guide_relu(module, grad_in, grad_out):
            # NOTE(review): the replacement gradient is built from grad_out,
            # not grad_in -- confirm this is the intended guided-backprop rule.
            return (torch.clamp(grad_out[0], min=0.0),)

        for submodule in module.modules():
            if isinstance(submodule, nn.ReLU):
                # Legacy hook kept on purpose; presumably because full
                # backward hooks reject in-place ReLUs -- TODO confirm.
                self.handles.append(submodule.register_backward_hook(guide_relu))

    def clean(self):
        """Remove every hook registered by this object and forget the handles."""
        for handle in self.handles:
            handle.remove()
        # Drop the spent handles so the list does not grow with every call.
        self.handles.clear()

    @staticmethod
    def _normalize(cam, eps=1e-8):
        """Squash ``cam`` (N, H, W) with a sigmoid, then min-max scale each sample to [0, 1]."""
        cam = torch.sigmoid(cam)
        flat = cam.flatten(start_dim=1)
        low = flat.min(dim=1).values.view(-1, 1, 1)
        high = flat.max(dim=1).values.view(-1, 1, 1)
        # eps keeps a constant map finite (it becomes all zeros).
        return (cam - low) / (high - low).clamp_min(eps)

    def __call__(self, input_image, target_image, layer, guide=False, target_class=None, postprocessing=lambda x: x, regression=False):
        """Compute, draw and return the Grad-CAM heat-maps for a batch.

        Args:
            input_image: Batch fed to the network.  Assumed to be laid out as
                (N, C, H, W) -- TODO confirm against the data loader.
            target_image: Tensor used as the upstream gradient of the
                prediction when ``regression`` is false.
            layer: Module whose activation is explained.  ``None`` selects the
                last ``nn.Conv2d`` that runs during a traced forward pass.
            guide: When true, also install the ReLU hooks of ``guide``.
            target_class: Upstream gradient for the ``regression`` branch and
                the value reported back; defaults to the arg-max of the
                prediction over dim 1.
            postprocessing: Optional callable applied to the squeezed CPU copy
                of ``input_image`` before drawing; ``None`` skips it.
            regression: Back-propagate ``target_class`` instead of
                ``target_image``.

        Returns:
            ``((image, cam), {'prediction': target_class})`` where ``cam``
            holds one heat-map per sample, scaled to [0, 1] and resized to the
            spatial size of the input.

        Raises:
            ValueError: ``layer`` is ``None`` and no ``nn.Conv2d`` was traced.
            RuntimeError: the hooked layer produced no output or gradient.
        """
        self.clean()
        self.module.zero_grad()

        if layer is None:
            # Trace on the same device the real forward pass will use.
            traced = tracesubmodules(self.module, input_image.to(self.device))
            convs = [m for m in traced if isinstance(m, nn.Conv2d)]
            if not convs:
                raise ValueError("no nn.Conv2d layer found; pass `layer` explicitly")
            layer = convs[-1]

        try:
            self.gradients = None
            self.conv_outputs = None
            self.store_outputs_and_grad(layer)

            if guide: self.guide(self.module)

            # Fresh leaf tensor on the target device that tracks gradients.
            input_var = input_image.detach().to(self.device).requires_grad_(True)
            predictions = self.module(input_var)

            if target_class is None: _, target_class = torch.max(predictions, dim=1)
            if regression: predictions.backward(gradient=target_class, retain_graph=True)
            else:
                target = target_image.to(self.device)
                predictions.backward(gradient=target, retain_graph=True)

            if self.conv_outputs is None or self.gradients is None:
                raise RuntimeError("the selected layer did not take part in the forward/backward pass")

            with torch.no_grad():
                # Channel weights: gradient averaged over the spatial dims -> (N, C, 1, 1).
                avg_channel_grad = F.adaptive_avg_pool2d(self.gradients, 1)
                # Weighted sum over channels, one map per sample -> (N, h, w).
                cam = F.relu(torch.sum(self.conv_outputs * avg_channel_grad, dim=1))
                cam = self._normalize(cam.detach().cpu())

                # draw_single_mask needs the maps at input resolution.
                size = tuple(input_var.shape[-2:])
                if tuple(cam.shape[-2:]) != size:
                    cam = F.interpolate(cam.unsqueeze(1), size=size,
                                        mode='bilinear', align_corners=False).squeeze(1)
                self.cam = cam

                if postprocessing is not None:
                    input_image = postprocessing(input_image.squeeze().cpu())
                image_with_heatmap = input_image.squeeze().cpu(), self.cam
                draw_single_mask(image_with_heatmap, 3, _BATCH_SIZE, "Gradient CAM")
        finally:
            # Never leave hooks on the model, even when something above fails.
            self.clean()

        return image_with_heatmap, { 'prediction': target_class}

In [None]:
# Take one batch from the training loader to visualise.
input_t, y = next(iter(train_loader))
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source (consider weights_only=True where the installed torch has it).
check_point = torch.load("fcn_model.pt", map_location=torch.device('cpu'))
fcn_model.load_state_dict(check_point)
# Put the restored model in inference mode right away, so that forward passes
# run for inspection neither update BatchNorm running statistics nor apply
# train-time behaviour.
fcn_model.eval()

In [None]:
# Show the dimensions of the sampled batch.
print(input_t.size())

torch.Size([2, 3, 512, 512])


In [None]:
# Leaf layers of the model, in the order they run on this batch.
traced_fcn = tracesubmodules(module=fcn_model, inputs=input_t)

In [None]:
# Dump the traced layers, then how many there are (one item per line).
print(traced_fcn, len(traced_fcn), sep="\n")

[Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False), BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True), MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False), Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False), BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True), Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True), Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False), BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False), BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), ReLU(inplace=True), Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False), BatchNorm2d(64, eps=1e-0

In [None]:
# Run the sampled batch through the model once and keep the raw output.
# NOTE(review): this runs with autograd enabled and in whatever train/eval
# mode the model is currently in -- confirm that is intended.
output_t = fcn_model(input_t)

In [None]:
# Switch the model to inference mode before computing the heat-map.
fcn_model.eval()
# Fourth entry from the end of the traced leaf list.
# NOTE(review): the position depends on the model's layer order -- verify
# against the printed trace that this really is the intended ReLU.
layer = traced_fcn[-4] # Last ReLU of the conv layers


In [None]:
# Grad-CAM helper bound to the restored model and the working device.
fcn_cam = GradCam(module=fcn_model, device=device)

In [None]:
# Heat-map of the chosen layer for the sampled batch; the raw input tensor is
# used for display because no post-processing callable is supplied.
layer_viz = fcn_cam(
    input_image=input_t,
    target_image=y,
    layer=layer,
    postprocessing=None,
)