In [96]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [102]:
import torch.nn as nn

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """Self-distillation based on resnet"""

    def __init__(self, block, layers, branch_layers, num_classes=2):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        inplanes_head1 = self.inplanes
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        inplanes_head2 = self.inplanes
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        inplanes_head3 = self.inplanes
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.fc_main = nn.Linear(512 * block.expansion, num_classes)

        # side branch 1
        self.inplanes = inplanes_head1
        self.sb11 = self._make_layer(block, 128, branch_layers[0][0], stride=2)
        self.sb12 = self._make_layer(block, 256, branch_layers[0][1], stride=2)
        self.sb13 = self._make_layer(block, 512, branch_layers[0][2], stride=2)
        self.fc_head1 = nn.Linear(512 * block.expansion, num_classes)

        # side branch 2
        self.inplanes = inplanes_head2
        self.sb21 = self._make_layer(block, 256, branch_layers[1][0], stride=2)
        self.sb22 = self._make_layer(block, 512, branch_layers[1][1], stride=2)
        self.fc_head2 = nn.Linear(512 * block.expansion, num_classes)

        # side branch 3
        self.inplanes = inplanes_head3
        self.sb31 = self._make_layer(block, 512, branch_layers[2][0], stride=2)
        self.fc_head3 = nn.Linear(512 * block.expansion, num_classes)

        # CAM-attention
        self.cam_conv1 = nn.Conv2d(512 * block.expansion, num_classes, kernel_size=1, padding=0,
                                   bias=False)
        self.cam_bn1 = nn.BatchNorm2d(num_classes)
        self.cam_out = nn.Conv2d(num_classes, num_classes, kernel_size=1, padding=0,
                                 bias=False)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.cam_conv2 = nn.Conv2d(num_classes, 1, kernel_size=3, padding=1,
                                   bias=False)
        self.cam_bn2 = nn.BatchNorm2d(1)
        self.sigmoid = nn.Sigmoid()

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = node1 = self.layer1(x)
        x = node2 = self.layer2(x)
        x = node3 = self.layer3(x)
        x = self.layer4(x)

        # CAM branch
        cam1 = self.relu(self.cam_bn1(self.cam_conv1(x)))
        out_cam = self.gap(self.cam_out(cam1))
        out_cam = out_cam.view(out_cam.size(0), -1)

        # main branch
        cam2 = self.sigmoid(self.cam_bn2(self.cam_conv2(cam1)))
        main_feature = x * cam2 + x
        m = self.gap(main_feature)
        m = m.view(m.size(0), -1)
        out_main = self.fc_main(m)

        # side branch 1
        hide_feature1 = self.sb13(self.sb12(self.sb11(node1)))
        h1 = self.gap(hide_feature1)
        h1 = h1.view(h1.size(0), -1)
        side_out1 = self.fc_head1(h1)

        # side branch 2
        hide_feature2 = self.sb22(self.sb21(node2))
        h2 = self.gap(hide_feature2)
        h2 = h2.view(h2.size(0), -1)
        side_out2 = self.fc_head2(h2)

        # side branch 3
        hide_feature3 = self.sb31(node3)
        h3 = self.gap(hide_feature3)
        h3 = h3.view(h3.size(0), -1)
        side_out3 = self.fc_head3(h3)

        return [out_cam, main_feature, out_main], [hide_feature1, side_out1], [hide_feature2, side_out2], [
            hide_feature3, side_out3]


        # Main  Branch1   Branch2  Branch3
        # 1.84GFlops  1.36GFlops  1.59GFlops 1.82GFlops   


def resnet18(num_classes, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], [[1, 1, 2], [1, 2], [2]], num_classes, **kwargs)
    return model


def resnet34(num_classes, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], [[2, 3, 3], [3, 3], [3]], num_classes, **kwargs)
    return model


def resnet50(num_classes, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], [[2, 3, 3], [3, 3], [3]], num_classes, **kwargs)
    return model


def resnet101(num_classes, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], [[2, 6, 3], [6, 3], [3]], num_classes, **kwargs)
    return model


def resnet152(num_classes, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], [[4, 12, 3], [12, 3], [3]], num_classes, **kwargs)
    return model


# if __name__ == '__main__':
#     import torch
#     import time
#     # from torchstat import stat

#     # net = resnet50(num_classes=4)
#     # # stat(net, (3, 224, 224))

#     # # Freeze some layers
#     # ct = 0
#     # for name, child in net.named_children():
#     #     ct += 1
#     #     if ct < 6:
#     #         for names, params in child.named_children():
#     #             params.requires_grad = False
#     N = 500
#     input = torch.randn(N, 3, 224 ,224).cuda()
#     net = resnet18(num_classes=2).cuda()
#     # stat(net, (3, 224, 224))

#     torch.cuda.synchronize()
#     start = time.time()

#     with torch.no_grad():
#         net(input)
    
#     torch.cuda.synchronize()
#     dur = time.time() - start

#     print(dur / N)

In [104]:
model = ResNet(BasicBlock, [2, 2, 2, 2], [[1, 1, 2], [1, 2], [2]], num_classes=2)

In [152]:


# import torch

# import numpy as np
# import cv2


# class GradCAM(object):

#     def __init__(self, net, layer_name):
#         self.net = net
#         self.layer_name = layer_name
#         self.feature = None
#         self.gradient = None
#         self.net.eval()
#         self.handlers = []
#         self._register_hook()

#     def _get_features_hook(self, module, input, output):
#         self.feature = output
#         print("feature shape:{}".format(output.size()))

#     def _get_grads_hook(self, module, input_grad, output_grad):
#         """
#         :param input_grad: tuple, input_grad[0]: None
#                                    input_grad[1]: weight
#                                    input_grad[2]: bias
#         :param output_grad:tuple
#         :return:
#         """
#         self.gradient = output_grad[0]

#     def _register_hook(self):
#         for (name, module) in self.net.named_modules():
#             if name == self.layer_name:
#                 self.handlers.append(module.register_forward_hook(self._get_features_hook))
#                 self.handlers.append(module.register_backward_hook(self._get_grads_hook))

#     def remove_handlers(self):
#         for handle in self.handlers:
#             handle.remove()

#     def __call__(self, inputs, index):
#         """
#         :param inputs: [1,3,H,W]
#         :param index: class id
#         :return:
#         """
#         self.net.zero_grad()
#         # inputs = torch.tensor(inputs, dtype=torch.float32)
#         output = self.net(inputs)[0][0]  # [1,num_classes]
#         if index is None:
#             index = np.argmax(output.cpu().data.numpy())
#         target = output[0][index]
#         target.backward()

#         gradient = self.gradient[0].cpu().data.numpy()  # [C,H,W]
#         weight = np.mean(gradient, axis=(1, 2))  # [C]

#         feature = self.feature[0].cpu().data.numpy()  # [C,H,W]

#         cam = feature * weight[:, np.newaxis, np.newaxis]  # [C,H,W]
#         cam = np.sum(cam, axis=0)  # [H,W]
#         cam = np.maximum(cam, 0)  # ReLU

#         cam -= np.min(cam)
#         cam /= np.max(cam)
#         # resize to 224*224
#         cam = cv2.resize(cam, (224, 224))
#         return cam

class GradCAM(object):

    def __init__(self, net, layer_name):
        self.net = net
        self.layer_name = layer_name
        self.feature = None
        self.gradient = None
        self.net.eval()
        self.handlers = []
        self._register_hook()

    def _get_features_hook(self, module, input, output):
        self.feature = output
        print("feature shape:{}".format(output.size()))

    def _get_grads_hook(self, module, grad_input, grad_output):
        self.gradient = grad_output[0]
        print("gradient shape:{}".format(grad_output[0].size()))

    def _register_hook(self):
        for (name, module) in self.net.named_modules():
            if name == self.layer_name:
                self.handlers.append(module.register_forward_hook(self._get_features_hook))
                self.handlers.append(module.register_full_backward_hook(self._get_grads_hook))

    def remove_handlers(self):
        for handle in self.handlers:
            handle.remove()

    def __call__(self, inputs):

        self.net.zero_grad()
        output = self.net(inputs)
        
        output = output[0] # select first tensor from the tuple
        output = torch.Tensor(output[0].unsqueeze(0))
        index = torch.argmax(output, dim=1)
        target = output[0, index]
        target.backward()

        gradient = self.gradient.cpu().data.numpy()  # [C,H,W]
        weight = np.mean(gradient, axis=(1, 2))  # [C]

        feature = self.feature.cpu().data.numpy()  # [C,H,W]

        cam = feature * weight[:, np.newaxis, np.newaxis]  # [C,H,W]
        cam = np.sum(cam, axis=0)  # [H,W]
        cam = np.maximum(cam, 0)  # ReLU

        cam -= np.min(cam)
        cam /= np.max(cam)
        # resize to 224*224
        cam = cv2.resize(cam, (224, 224))
        return cam




In [148]:

import torch
from torch import nn
import numpy as np


class GuidedBackPropagation(object):

    def __init__(self, net):
        self.net = net
        for (name, module) in self.net.named_modules():
            if isinstance(module, nn.ReLU):
                module.register_backward_hook(self.backward_hook)
        self.net.eval()

    @classmethod
    def backward_hook(cls, module, grad_in, grad_out):
        """
        :param module:
        :param grad_in: tuple
        :param grad_out: tuple
        :return: tuple(new_grad_in,)
        """
        return torch.clamp(grad_in[0], min=0.0),

    def __call__(self, inputs, index=None):
        """
        :param inputs: [1,3,H,W]
        :param index: class_id
        :return:
        """
        self.net.zero_grad()
        output = self.net(inputs)[0][0]  # [1,num_classes]
        if index is None:
            index = np.argmax(output.cpu().data.numpy())
        target = output[0][index]

        target.backward()

        return inputs.grad[0]  # [3,H,W]

In [149]:


import argparse
import os
import re

import cv2
import numpy as np
import torch
from skimage import io
from torch import nn
from torchvision import models
import sys
sys.path.append("..")


from PIL import Image
import torchvision.transforms as transforms


def get_net():

    net = ResNet(BasicBlock, [2, 2, 2, 2], [[1, 1, 2], [1, 2], [2]], num_classes=2)

    return net


def get_last_conv_name(net):
    """
    :param net:
    :return:
    """
    layer_name = None
    for name, m in net.named_modules():
        if isinstance(m, nn.Conv2d):
            layer_name = name
    return layer_name


def prepare_input(image):
    image = image.copy()
    means = np.array([0.485, 0.456, 0.406])
    stds = np.array([0.229, 0.224, 0.225])
    image -= means
    image /= stds

    image = np.ascontiguousarray(np.transpose(image, (2, 0, 1)))
    image = image[np.newaxis, ...]

    return torch.tensor(image, requires_grad=True)


def gen_cam(image, mask):
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    heatmap = heatmap[..., ::-1]  # gbr to rgb

    cam = heatmap + np.float32(image)
    return norm_image(cam), (heatmap * 255).astype(np.uint8)


def norm_image(image):
    image = image.copy()
    image -= np.max(np.min(image), 0)
    image /= np.max(image)
    image *= 255.
    return np.uint8(image)


def gen_gb(grad):
    grad = grad.data.numpy()
    gb = np.transpose(grad, (1, 2, 0))
    return gb


def save_image(image_dicts, input_image_name, network, output_dir):
    prefix = os.path.splitext(input_image_name)[0]
    for key, image in image_dicts.items():
        io.imsave(os.path.join(output_dir, '{}-{}-{}.jpg'.format(prefix, network, key)), image)






In [154]:
# img = io.imread(args.image_path)
img = Image.open('/content/drive/MyDrive/Data/0/20051019_38557_0100_PP.tif').convert("RGB")
trans = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224)]
)
img = trans(img)
img.save("/content/drive/MyDrive/Data/rgb.jpg")
img = np.float32(cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)) / 255
# img = np.float32(cv2.resize(img, (224, 224))) / 255
inputs = prepare_input(img)
print(inputs.type())
image_dict = {}
net = get_net()
# Grad-CAM
layer_name = get_last_conv_name(net) #if args.layer_name is None else args.layer_name
grad_cam = GradCAM(net, layer_name)

mask = grad_cam(inputs)  # cam mask


torch.FloatTensor


RuntimeError: ignored

torch.FloatTensor


In [96]:
image_dict['cam'], image_dict['heatmap'] = gen_cam(img, mask)
grad_cam.remove_handlers()
# # Grad-CAM++
# grad_cam_plus_plus = GradCamPlusPlus(net, layer_name)
# mask_plus_plus = grad_cam_plus_plus(inputs, args.class_id)  # cam mask
# image_dict['cam++'], image_dict['heatmap++'] = gen_cam(img, mask_plus_plus)
# grad_cam_plus_plus.remove_handlers()

# GuidedBackPropagation
gbp = GuidedBackPropagation(net)
inputs.grad.zero_()
grad = gbp(inputs)

gb = gen_gb(grad)
image_dict['gb'] = norm_image(gb)
# cam_gb = gb * mask[..., np.newaxis]
# image_dict['cam_gb'] = norm_image(cam_gb)

save_image(image_dict, os.path.basename('/content/drive/MyDrive/Data/0/20051019_38557_0100_PP.tif'), 'Res', '/content/drive/MyDrive/Data/images')

In [151]:
inputs

tensor([[[[-2.1008, -2.1008, -2.0837,  ..., -2.1008, -2.1008, -2.1008],
          [-2.0837, -2.1008, -2.0837,  ..., -2.1008, -2.0837, -2.0837],
          [-2.0837, -2.0837, -2.0837,  ..., -2.1008, -2.0837, -2.1008],
          ...,
          [-2.0837, -2.0837, -2.0665,  ..., -2.0837, -2.1008, -2.0837],
          [-2.0837, -2.0837, -2.0665,  ..., -2.0837, -2.0837, -2.0837],
          [-2.0837, -2.0837, -2.0837,  ..., -2.0837, -2.0837, -2.0837]],

         [[-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
          [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
          [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
          ...,
          [-2.0182, -2.0357, -2.0182,  ..., -2.0182, -2.0357, -2.0182],
          [-2.0182, -2.0357, -2.0182,  ..., -2.0182, -2.0182, -2.0357],
          [-2.0182, -2.0182, -2.0357,  ..., -2.0357, -2.0357, -2.0357]],

         [[-1.7870, -1.7870, -1.7696,  ..., -1.7870, -1.7870, -1.7870],
          [-1.7696, -1.7870, -