In [1]:
import sys
import torch
sys.path.append('../')
from experiment import OCTClassification
from nntools.utils import load_yaml
sys.path.append('../pytorch-grad-cam')
from gradcam import GradCam, GuidedBackpropReLUModel
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2

In [2]:
trained_savedict = '/home/clement/Documents/Clement/runs/mlruns/1/e3c337f9fe6647cf9d61e81109fd6644/artifacts/iteration_380000_mIoU_0916.pth'
from opticNet import OpticNet
network1 = OpticNet(n_classes=4, n_channels=3)

network1.load(trained_savedict)
network1 = network1.cuda()
network1.eval()


Loading model from  /home/clement/Documents/Clement/runs/mlruns/1/e3c337f9fe6647cf9d61e81109fd6644/artifacts/iteration_380000_mIoU_0916.pth


OpticNet(
  (first_layer): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ResConv(
      (branch_1): Sequential(
        (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
        (2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): ReLU()
        (5): ReflectionPad2d((1, 0, 1, 0))
        (6): Conv2d(64, 64, kernel_size=(2, 2), stride=(1, 1))
        (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (8): ReLU()
        (9): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (branch_2): Sequential(
        (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): ReLU()
        (2)

In [3]:
config_path = '../config.yaml'
config = load_yaml(config_path)
experiment = OCTClassification(config)
dataset = experiment.test_dataset

labels = {} 
for k, v in dataset.map_class.items():
    labels[v] = k
    
def show_cam_on_image(img, mask, alpha=0.5):
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap*(1-alpha) + alpha*np.float32(img)
    cam = cam / np.max(cam)
    return cam



In [4]:
gradcam = GradCam(model=network1, feature_module=network1.mid4,
                       target_layer_names=None, use_cuda=True)

In [5]:
%matplotlib inline
preds = []
gts = []
index = np.arange(1000)
np.random.shuffle(index)    
output = 'OpticNet-Explained-GradCam_Kermany/'

if not os.path.exists(output):
    os.makedirs(output)
    
list_files = os.listdir(output)
for i in index:
    filename = dataset.filename(i)
    img, gt = dataset[i]
    img = img.unsqueeze(0).cuda()
    
    pred = network1(img)
    prob = torch.softmax(pred, dim=1)
    pred = torch.argmax(prob, dim=1)
    argsort = torch.argsort(prob, dim=1, descending=True)
    
    preds.append(pred)
    gts.append(gt)  
    grayscale_cam, grad_pred = gradcam(img, int(pred))
    
    img = img.squeeze(0).data.cpu().numpy().transpose((1,2,0))
    mix = show_cam_on_image(img, grayscale_cam, 0.75) 
    fig, ax = plt.subplots()
    fig.set_size_inches(9,9)
    title = 'Predicted: '
    for arg in argsort.squeeze(0)[:2]:
        title = title + ' ' + labels[int(arg)]+', Pr: {:.2%}'.format((prob[0][arg]))
    
    vis =  np.uint8(255 * mix)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)

    ax.imshow(vis)
    ax.set_title(title)
    
    # Hide grid lines
    ax.grid(False)
    # Hide axes ticks
    ax.set_xticks([])
    ax.set_yticks([])
    plt.tight_layout()
    filepath = os.path.join(output, filename)
    plt.savefig(filepath)
#     plt.show()
    plt.close(fig)

  "See the documentation of nn.Upsample for details.".format(mode))


In [6]:
preds = np.asarray([int(_) for _ in preds])
gts = np.asarray([int(_) for _ in gts])

In [7]:
np.sum(preds==gts)/1000

0.962

In [8]:
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(gts, preds)
cm

array([[249,   1,   0,   0],
       [  7, 243,   0,   0],
       [ 28,   0, 222,   0],
       [  0,   0,   2, 248]])