In [1]:
import collections
from pathlib import Path

import PIL.Image
import torch.nn as nn
import torchvision
from torchvision.models import resnet_cbam

class MyResNet(nn.Module):
    """Pretrained torchvision ResNet-34 with a frozen backbone and an MLP head.

    All parameters of ``torchvision.models.resnet34(pretrained=True)`` are
    frozen. Two submodules are then replaced. Because they are created after
    the freeze, they remain trainable:

    * ``fc``    -> Linear(512, 256) -> ReLU -> Linear(256, 128) -> ReLU
                   -> Linear(128, 2) -> Sigmoid
    * ``conv1`` -> a new 7x7 / stride-2 / padding-3 stem convolution that
                   accepts ``in_channels`` input channels (64 output channels,
                   no bias), discarding the pretrained stem.

    Args:
        in_channels: number of channels of the input tensor (default 1).
    """

    def __init__(self, in_channels=1):
        super().__init__()
        backbone = torchvision.models.resnet34(pretrained=True)
        # Freeze every pretrained weight before swapping layers in.
        backbone.requires_grad_(False)

        # Keyword order is preserved, so the state_dict keys remain
        # fc.fc1 / fc.activation1 / ... / fc.out.
        backbone.fc = nn.Sequential(collections.OrderedDict(
            fc1=nn.Linear(512, 256),
            activation1=nn.ReLU(),
            fc2=nn.Linear(256, 128),
            activation2=nn.ReLU(),
            fc3=nn.Linear(128, 2),
            out=nn.Sigmoid(),
        ))
        backbone.conv1 = nn.Conv2d(
            in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped ResNet and return its head output."""
        return self.model(x)

class MyResNet_CBAM(nn.Module):
    """CBAM ResNet-34 from ``resnet_cbam`` with a frozen backbone and an MLP head.

    Same construction as ``MyResNet`` but the backbone is
    ``resnet_cbam.resnet34_cbam(pretrained=True)``. All of its parameters are
    frozen. ``fc`` and ``conv1`` are then replaced by new, trainable modules:

    * ``fc``    -> Linear(512, 256) -> ReLU -> Linear(256, 128) -> ReLU
                   -> Linear(128, 2) -> Sigmoid
    * ``conv1`` -> 7x7 / stride-2 / padding-3 conv from ``in_channels`` to
                   64 channels (no bias).

    Args:
        in_channels: number of channels of the input tensor (default 1).
    """

    def __init__(self, in_channels=1):
        super(MyResNet_CBAM, self).__init__()
        # NOTE(review): assumes resnet34_cbam exposes `conv1` and `fc` and
        # feeds a 512-d feature into `fc`, like torchvision's ResNet-34.
        # Confirm against the resnet_cbam module.
        self.model = resnet_cbam.resnet34_cbam(pretrained=True)
        # Freeze every pretrained parameter; the replacement layers below are
        # created afterwards and therefore stay trainable.
        self.model.requires_grad_(False)

        layers_resnet = nn.Sequential(collections.OrderedDict([
            ('fc1', nn.Linear(512, 256)),
            ('activation1', nn.ReLU()),
            ('fc2', nn.Linear(256, 128)),
            ('activation2', nn.ReLU()),
            ('fc3', nn.Linear(128, 2)),
            ('out', nn.Sigmoid())
        ]))

        self.model.fc = layers_resnet
        # New stem so the network accepts `in_channels`-channel input; the
        # pretrained stem weights are discarded.
        self.model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

    def forward(self, x):
        """Run ``x`` through the wrapped CBAM ResNet and return its head output."""
        return self.model(x)


# Cell completion message, at module level rather than inside the class body.
print("model ready.")

model ready.


In [2]:
import PIL
import torch
from torchvision import transforms
from torchvision.utils import make_grid

from gradcam.utils import visualize_cam
from gradcam import GradCAM, GradCAMpp

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Test-set image paths (relative to the notebook's working directory) fed to
# the Grad-CAM cell: NORMAL images followed by two PNEUMONIA cases.
diagnose_path = [
    'Data/test/NORMAL/IM-0007-0001.jpeg',
    'Data/test/NORMAL/IM-0011-0001.jpeg',
    'Data/test/NORMAL/IM-0013-0001.jpeg',
    'Data/test/NORMAL/IM-0015-0001.jpeg',
    'Data/test/NORMAL/IM-0023-0001.jpeg',
    'Data/test/NORMAL/IM-0029-0001.jpeg',
    'Data/test/NORMAL/IM-0030-0001.jpeg',
    'Data/test/NORMAL/IM-0033-0001.jpeg',
    'Data/test/NORMAL/IM-0041-0001.jpeg',
    'Data/test/NORMAL/IM-0043-0001.jpeg',
    'Data/test/NORMAL/IM-0046-0001.jpeg',
    'Data/test/NORMAL/IM-0049-0001.jpeg',
    'Data/test/NORMAL/IM-0073-0001.jpeg',
    'Data/test/NORMAL/IM-0084-0001.jpeg',
    'Data/test/NORMAL/IM-0087-0001.jpeg',
    'Data/test/NORMAL/IM-0110-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0028-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0041-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0059-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0060-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0102-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0107-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0117-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0120-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0135-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0275-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0290-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0292-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0294-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0319-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0331-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0339-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0369-0001.jpeg',
    'Data/test/NORMAL/NORMAL2-IM-0378-0001.jpeg',
    'Data/test/PNEUMONIA/person112_bacteria_539.jpeg',
    'Data/test/PNEUMONIA/person15_virus_46.jpeg',
]

# PNEUMONIA test images (file names mark bacteria / virus cases). Not iterated
# by the Grad-CAM cell shown in this notebook; presumably intended for the
# triage checkpoints that are commented out in the model-loading cell.
triage_path = [
    'Data/test/PNEUMONIA/person113_bacteria_543.jpeg',
    'Data/test/PNEUMONIA/person119_bacteria_568.jpeg',
    'Data/test/PNEUMONIA/person123_bacteria_587.jpeg',
    'Data/test/PNEUMONIA/person151_bacteria_718.jpeg',
    'Data/test/PNEUMONIA/person153_bacteria_726.jpeg',
    'Data/test/PNEUMONIA/person1612_virus_2797.jpeg',
    'Data/test/PNEUMONIA/person1612_virus_2798.jpeg',
    'Data/test/PNEUMONIA/person161_bacteria_762.jpeg',
    'Data/test/PNEUMONIA/person1628_virus_2822.jpeg',
    'Data/test/PNEUMONIA/person1629_virus_2823.jpeg',
    'Data/test/PNEUMONIA/person1633_virus_2829.jpeg',
    'Data/test/PNEUMONIA/person1635_virus_2831.jpeg',
    'Data/test/PNEUMONIA/person1637_virus_2834.jpeg',
    'Data/test/PNEUMONIA/person1643_virus_2843.jpeg',
    'Data/test/PNEUMONIA/person1645_virus_2845.jpeg',
    'Data/test/PNEUMONIA/person1649_virus_2850.jpeg',
    'Data/test/PNEUMONIA/person1651_virus_2855.jpeg',
    'Data/test/PNEUMONIA/person1653_virus_2858.jpeg',
    'Data/test/PNEUMONIA/person1663_virus_2876.jpeg',
    'Data/test/PNEUMONIA/person1670_virus_2886.jpeg',
    'Data/test/PNEUMONIA/person1675_virus_2891.jpeg',
    'Data/test/PNEUMONIA/person175_bacteria_833.jpeg',
    'Data/test/PNEUMONIA/person38_virus_83.jpeg',
    'Data/test/PNEUMONIA/person64_virus_122.jpeg',
    'Data/test/PNEUMONIA/person71_virus_131.jpeg',
    'Data/test/PNEUMONIA/person78_bacteria_381.jpeg',
    'Data/test/PNEUMONIA/person78_virus_140.jpeg',
]

In [4]:
torch.cuda.empty_cache()

# Load the fine-tuned weights for both models.
# map_location='cpu': without it, torch.load restores tensors onto the device
# they were saved from. A GPU-saved checkpoint would then fail on a CPU-only
# machine, even though `device` falls back to 'cpu'. load_state_dict copies
# the values into the freshly built (CPU) modules; later cells move the models
# to `device`.
resnet_1 = MyResNet()
# NOTE(review): this *diagnose* checkpoint's file name ends in "_triage"
# (the CBAM one ends in "_diagnose_cbam"). Confirm it is the intended file.
resnet_1.load_state_dict(torch.load(
    "2x_diagnose/v1_resnet34_224px/resnet34_epochs40_triage.pth", map_location='cpu'))
# Triage variant:
# resnet_1.load_state_dict(torch.load("2x_triage/v1_resnet34_224px/resnet34_epochs40_triage.pth", map_location='cpu'))

resnet_2 = MyResNet_CBAM()
resnet_2.load_state_dict(torch.load(
    "2x_diagnose/v3_resnet34_CBAM/resnet34_epochs40_diagnose_cbam.pth", map_location='cpu'))
# Triage variant:
# resnet_2.load_state_dict(torch.load("2x_triage/v3_resnet34_CBAM/resnet34_epochs40_triage_cbam.pth", map_location='cpu'))

<All keys matched successfully>

In [8]:
# Grad-CAM / Grad-CAM++ visualisation of both diagnose models.
# For every image in `diagnose_path`, one PNG is written to ./grad_cam/ with
# one row per model: [Grad-CAM heatmap, overlay, Grad-CAM++ heatmap, overlay].
configs = [
    dict(model_type='resnet', arch=resnet_1.model, layer_name='layer4'),
    dict(model_type='resnet', arch=resnet_2.model, layer_name='layer4'),
]
cams = [
    [cls.from_config(**config) for cls in (GradCAM, GradCAMpp)]
    for config in configs
]

# Loop-invariant setup: done once rather than once per image.
for config in configs:
    config['arch'].to(device).eval()

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
# NOTE(review): 0.485 / 0.229 look like ImageNet's first-channel mean/std.
# Presumably the same single-channel normalisation used in training; confirm.
normalize = transforms.Normalize([0.485], [0.229])

out_dir = Path('grad_cam')
out_dir.mkdir(parents=True, exist_ok=True)  # save() fails if the folder is missing

# Target class index per image, keyed on the image's parent directory.
# NOTE(review): NORMAL -> 1, PNEUMONIA -> 0 is the reverse of torchvision
# ImageFolder's alphabetical encoding (NORMAL=0, PNEUMONIA=1). Confirm it
# matches the label encoding the checkpoints were trained with.
# For the triage checkpoints, the notebook keyed on the file name instead:
# 'bacteria' -> 1, 'virus' -> 0.
class_to_idx = {'NORMAL': 1, 'PNEUMONIA': 0}

for _path in diagnose_path:
    img_path = Path(_path)
    # convert('L') guarantees the single-channel input conv1 was built for,
    # even if a JPEG is stored as RGB. The context manager closes the file.
    with PIL.Image.open(_path) as pil_img:
        torch_img = preprocess(pil_img.convert('L')).to(device)
    normed_torch_img = normalize(torch_img)[None]  # add a batch dimension

    idx = class_to_idx.get(img_path.parent.name, 0)

    images = []
    for gradcam, gradcam_pp in cams:
        mask, _ = gradcam(normed_torch_img, class_idx=idx)
        heatmap, result = visualize_cam(mask, torch_img)

        mask_pp, _ = gradcam_pp(normed_torch_img, class_idx=idx)
        heatmap_pp, result_pp = visualize_cam(mask_pp, torch_img)

        # To also show the input image, prepend torch_img.cpu() and use nrow=5.
        images.extend([heatmap, result, heatmap_pp, result_pp])

    grid_image = make_grid(images, nrow=4)
    # e.g. Data/test/NORMAL/IM-0007-0001.jpeg -> grad_cam/diagnose_NORMAL_IM-0007-0001.png
    out_name = 'diagnose_{}_{}.png'.format(img_path.parent.name, img_path.stem)
    transforms.ToPILImage()(grid_image).save(str(out_dir / out_name))

In [6]:
# Dump both module trees as plain text, one after the other.
print(resnet_1, resnet_2, sep='\n')

MyResNet(
  (model): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn