In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import resnet101
from torchvision.models.densenet import densenet201
from torchvision.models.vgg import vgg19
from data import ct_mean, ct_std
from argparse import Namespace
import numpy as np
from skimage.io import imsave, imread
import os
import random
from data import get_dataset
from time import time

from models import GradCAM 
from models import MainModel

In [29]:
def jet(image):
    n = 4 * image[:, :1]
    r = torch.clamp(torch.min(n-1.5,-n+4.5), 0, 1)
    g = torch.clamp(torch.min(n-0.5,-n+3.5), 0, 1)
    b = torch.clamp(torch.min(n+0.5,-n+2.5), 0, 1)
    return torch.cat((r,g,b), 1)


def main(opts):
    # load image and convert to tensor
    image = imread(os.path.join(opts.data_dir, 'images', '%s.png' % opts.img_id))
    image = torch.tensor(image[None, None, ...], dtype=torch.float32) / 255
    image = (image - ct_mean) / ct_std
    image = image.expand(-1, 3, -1, -1)
    ind = torch.tensor([[opts.ind]])
    # construct CNN model
    model = MainModel(opts.arch, 6)

    # load weights
    model.load_state_dict(torch.load(opts.dict_file))
    # Grad CAM
    print(image.shape)
    print(np.unique(image[0][1] == image[0][2]))
    print(image[0][2])
#     print(image[0][3])
    
    grad_cam = GradCAM(model)
    cam = grad_cam(image, ind)
    # output image with cam
    cam = jet(cam)
    image = torch.clamp(image * ct_std + ct_mean, 0, 1)
    imsave('%s.png' % opts.img_id, np.around(image[0,0].cpu().numpy()*255).astype(np.uint8))
    image = image + cam
    
    image = np.moveaxis(image[0].cpu().numpy(), 0, 2)
    image = image / image.max()
    image = np.around(image*255).astype(np.uint8)
    print(image.shape)
    
    imsave('%s-cam.png' % opts.img_id, image)

In [30]:
main(
    Namespace(
    data_dir = '../',
    use_gpu = False,
    arch = 'densenet201',
    img_id = '793d85d63',
    ind = 5,
    dict_file = 'densenet201_pretrained/model_densenet201.pt'
))

torch.Size([1, 3, 256, 256])
[ True]
tensor([[-0.5968, -0.5968, -0.5968,  ..., -0.5968, -0.5968, -0.5968],
        [-0.5968, -0.5968, -0.5968,  ..., -0.5968, -0.5968, -0.5968],
        [-0.5968, -0.5968, -0.5968,  ..., -0.5968, -0.5968, -0.5968],
        ...,
        [-0.5968, -0.5968, -0.5968,  ..., -0.5968, -0.5968, -0.5968],
        [-0.5968, -0.5968, -0.5968,  ..., -0.5968, -0.5968, -0.5968],
        [-0.5968, -0.5968, -0.5968,  ..., -0.5968, -0.5968, -0.5968]])
(256, 256, 3)
