In [1]:
import os
import numpy as np
import cv2
import random
import glob
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms



# Maps the classifier's output index to a human-readable label
# (looked up with the argmax of the logits in HeatmapGenerator).
class_names = [
    'control',   # index 0
    'diabetic',  # index 1
]


In [2]:
class HeatmapGenerator:
    """Channel-weighted activation heatmaps for a DenseNet-121 fundus classifier.

    ``generate`` writes a side-by-side JPEG ``[original | heatmap | 50/50 blend]``
    annotated with the predicted label and the label parsed from the file name,
    and returns that composed image.
    """

    def __init__(self, pathModel, img_size, mean=0, std=0, nnClassCount=2):
        """Load the fine-tuned DenseNet-121 and build the input transform.

        Args:
            pathModel: path to a ``state_dict`` for a DenseNet-121 whose classifier
                is ``Linear(num_ftrs, 500) -> Linear(500, nnClassCount)``.
            img_size: inputs are resized to ``(img_size, img_size)`` before inference.
            mean, std: per-channel values for ``transforms.Normalize``; when both
                are left at 0, normalization is skipped.
            nnClassCount: number of outputs of the classifier head.
        """
        # Use the GPU when present, otherwise run on CPU instead of crashing.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        #---- Initialize the network; the head must match the saved state_dict
        #---- exactly, otherwise load_state_dict rejects the checkpoint.
        model = torchvision.models.densenet121(pretrained=False)
        num_ftrs = model.classifier.in_features
        model.classifier = nn.Sequential(nn.Linear(num_ftrs, 500), nn.Linear(500, nnClassCount))
        model = model.to(self.device)
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(pathModel, map_location=self.device))

        self.model = model
        self.model.eval()

        #---- Per-channel weights used to combine feature maps into the heatmap.
        # NOTE(review): this is the second-to-last parameter of `features`
        # (for torchvision's DenseNet-121 presumably the final BatchNorm weight),
        # not the classifier weights of the predicted class as in classic CAM.
        # Kept as-is; confirm this is intended.
        self.weights = list(self.model.features.parameters())[-2]

        #---- Initialize the image transform
        transformList = [transforms.Resize((img_size, img_size)), transforms.ToTensor()]
        if not (mean == 0 and std == 0):
            transformList.append(transforms.Normalize(mean, std))
        self.transformSequence = transforms.Compose(transformList)

    #--------------------------------------------------------------------------------

    def generate(self, pathImageFile, save_img, transCrop):
        """Render, annotate and save the heatmap panel for one image.

        Args:
            pathImageFile: image path; both '/' and '\\' separators are accepted.
            save_img: output directory (created if missing); the panel is written
                to ``<save_img>/<file stem>.jpg``.
            transCrop: unused; kept for backward compatibility.

        Returns:
            The composed ``PIL.Image`` panel.
        """
        with Image.open(pathImageFile) as pilImage:
            imageRGB = pilImage.convert('RGB')

        label, npHeatmap = self._predict(imageRGB)
        # Reuse the already-decoded image so the overlay matches what the model saw.
        target = self._compose(np.array(imageRGB), npHeatmap)

        stem = self._image_stem(pathImageFile)
        w = imageRGB.width
        draw = ImageDraw.Draw(target)
        draw.text((w, 36), 'pred label: ' + label + '\nreal label: ' + self._true_label(stem),
                  font=self._load_font(w), fill='black')

        os.makedirs(save_img, exist_ok=True)
        target.save(os.path.join(save_img, stem) + '.jpg')
        return target

    def _predict(self, imageRGB):
        """Run the model on one PIL image; return (predicted label, 2-D numpy heatmap)."""
        with torch.no_grad():
            batch = self.transformSequence(imageRGB).unsqueeze(0).to(self.device)
            logits = self.model(batch)
            features = self.model.features(batch)[0]  # (channels, h, w) for the single image
            label = class_names[int(logits.argmax(dim=1)[0])]
            # Channel-weighted sum of feature maps (vectorized form of a per-channel loop).
            n = self.weights.numel()
            heatmap = (self.weights.view(-1, 1, 1) * features[:n]).sum(dim=0)
        return label, heatmap.cpu().numpy()

    @staticmethod
    def _normalize_cam(raw):
        """Clip negatives to 0 and scale to [0, 1]; an all-non-positive map becomes zeros.

        Negative values would otherwise wrap around when cast to uint8.
        """
        cam = np.maximum(raw, 0)
        peak = cam.max()
        if peak > 0:
            cam = cam / peak
        return cam

    def _compose(self, imgOriginal, npHeatmap):
        """Build the [original | heatmap | blend] panel; every array here is RGB uint8."""
        h, w = imgOriginal.shape[:2]
        cam = cv2.resize(self._normalize_cam(npHeatmap), (w, h))
        # applyColorMap returns BGR: convert once so all three panels are RGB.
        heatmap = cv2.applyColorMap(np.uint8(np.clip(255 * cam, 0, 255)), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        blend = cv2.addWeighted(imgOriginal, .5, heatmap, .5, 0)

        target = Image.new('RGB', (3 * w, h))
        for col, panel in enumerate((imgOriginal, heatmap, blend)):
            target.paste(Image.fromarray(panel), box=(col * w, 0))
        return target

    @staticmethod
    def _image_stem(pathImageFile):
        """File name up to its first '.', accepting both '/' and '\\' separators."""
        name = pathImageFile.replace('\\', '/').rsplit('/', 1)[-1]
        return name.split('.')[0]

    @staticmethod
    def _true_label(stem):
        """Ground-truth label encoded in the file stem.

        'control' when the second '_'-separated token is 'c' or the last token is
        'h' or 'N'; 'diabetic' otherwise.
        """
        parts = stem.split('_')
        if (len(parts) > 1 and parts[1] == 'c') or parts[-1] in ('h', 'N'):
            return 'control'
        return 'diabetic'

    @staticmethod
    def _load_font(w):
        """Annotation font sized to the panel width, with PIL's default as fallback."""
        size = max(1, int(w / 15))
        try:
            return ImageFont.truetype('heatmap/FreeMono.ttf', size)
        except OSError:
            # Bundled font file missing: use PIL's built-in bitmap font instead.
            return ImageFont.load_default()
        
        

In [3]:
# Heatmaps for the vessel-segmented test set (control + retina folders).
# Sorted so the processing order (and the image left in `f`) is the same on every OS.
listpathInputImage = (sorted(glob.glob('data/data_m_vessel/test/control/*'))
                      + sorted(glob.glob('data/data_m_vessel/test/retina/*')))
assert listpathInputImage, 'no test images found under data/data_m_vessel/test/'
pathModel = "models/densenet121_vessel_e50_s300_b14.pt"

# No mean/std passed, so the generator skips Normalize for this model.
# 300 = input size; presumably the training size ("s300" in the checkpoint name).
h = HeatmapGenerator(pathModel, 300)
for img_path in listpathInputImage:
    f = h.generate(img_path, 'heatmap/vessel', 300)
f



In [4]:
# Heatmaps for the raw (unprocessed) test set (control + retina folders).
# Sorted so the processing order (and the image left in `f`) is the same on every OS.
listpathInputImage = (sorted(glob.glob('data/data_m/test/control/*'))
                      + sorted(glob.glob('data/data_m/test/retina/*')))
assert listpathInputImage, 'no test images found under data/data_m/test/'
pathModel = "models/densenet121_e50_s300_b14.pt"

# Per-channel mean/std below are presumably the training-set statistics for this model — confirm.
h = HeatmapGenerator(pathModel, 300, [0.3998, 0.1676, 0.0636], [0.2762, 0.1356, 0.0666])
for img_path in listpathInputImage:
    f = h.generate(img_path, 'heatmap/raw', 300)
f



In [5]:
# Heatmaps for the preprocessed test set (control + retina folders).
# The retina glob previously pointed at 'data_prep/...', inconsistent with the
# 'data/data_m_prep/...' control folder and the layout used by the other runs.
listpathInputImage = (sorted(glob.glob('data/data_m_prep/test/control/*'))
                      + sorted(glob.glob('data/data_m_prep/test/retina/*')))
assert listpathInputImage, 'no test images found under data/data_m_prep/test/'
pathModel = "models/densenet121_prep_e50_s300_b14.pt"

# Per-channel mean/std below are presumably the training-set statistics for this model — confirm.
h = HeatmapGenerator(pathModel, 300, [0.5023, 0.5017, 0.5019], [0.1245, 0.0934, 0.0581])
for img_path in listpathInputImage:
    f = h.generate(img_path, 'heatmap/prep', 300)
f

