In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple

import torch
import torch.nn as nn

from networks.ResNet_3D_CPM import Resnet18, DetectionPostprocess
from utils.box_utils import nms_3D
from logic.utils import load_states

# Build the detector on the best available device and restore the pretrained weights.
model_path = './save/pretrained/best_model.pth'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Resnet18().to(device)  # nn.Module.to returns the module itself
load_states(model_path, device, model)

In [None]:
from dataload.utils import load_series_list, load_image, load_label, gen_dicom_path, gen_label_path, ALL_CLS, ALL_RAD, ALL_LOC

image_spacing = np.array([1.0, 0.8, 0.8])  # spacing in (z, y, x) order
series_list = load_series_list('./data/all_client_test.txt')

# Collect, for every test series, its DICOM path and its ground-truth label.
dicom_paths, labels = [], []
for folder, series_name in series_list:
    dicom_path = gen_dicom_path(folder, series_name)
    dicom_paths.append(dicom_path)

    label_path = gen_label_path(folder, series_name)
    # Unit spacing keeps label coordinates in pixel units instead of physical spacing.
    labels.append(load_label(label_path, np.ones(3)))

In [None]:
import math
from dataload.split_combine import SplitComb
from dataload.utils import load_image
from evaluationScript.nodule_finding import NoduleFinding

# Patch size fed to the model, in (z, y, x) order (see the split_images shape comment in inference()).
CROP_SIZE = [64, 128, 128]
# Neighbouring patches overlap by 25% of the crop size on every axis. Scale first, then cast:
# `int(s) * 0.25` would leave floats (16.0, 32.0, ...) where a whole voxel count is intended.
OVERLAP_SIZE = [int(s * 0.25) for s in CROP_SIZE]
batch_size = 16  # number of patches per forward pass
top_k = 40  # max detections kept by the final nms_3D call in inference()
# Evaluation mode for inference (affects layers such as dropout / batch-norm, if the model has them).
model.eval()
# Overlapping-patch splitter; pads with -1, the minimum of the [-1, 1]-scaled input used in inference().
split_comber = SplitComb(
    crop_size=CROP_SIZE,
    overlap_size=OVERLAP_SIZE,
    pad_value=-1,
)
# Turns raw model output for a batch of patches into candidate boxes.
detection_postprocess = DetectionPostprocess(
    crop_size=CROP_SIZE,
    topk=60,
    threshold=0.7,
    nms_threshold=0.05,
    nms_topk=20,
)

def output2nodulefinding(output: np.ndarray) -> List[NoduleFinding]:
    """Wrap detector output rows as NoduleFinding objects.

    Args:
        output: rows of (prob, z, y, x, d, h, w), one per detection.

    Returns:
        One NoduleFinding per row, in row order, each after ``auto_nodule_type()`` was called on it.
    """
    def _to_finding(row) -> NoduleFinding:
        # Unpack one detection; the column order must match the rows built in inference().
        prob, z, y, x, d, h, w = row
        finding = NoduleFinding(coordX=x, coordY=y, coordZ=z, w=w, h=h, d=d, CADprobability=prob)
        finding.auto_nodule_type()
        return finding

    return [_to_finding(row) for row in output]

def label2nodulefinding(label: np.ndarray) -> List[NoduleFinding]:
    """Convert a ground-truth label into NoduleFinding objects with CADprobability 1.0.

    Args:
        label: label as returned by ``load_label``. It is indexed with ``ALL_LOC`` for the
            per-nodule (z, y, x) centres and ``ALL_RAD`` for the per-nodule (d, h, w) sizes.
            NOTE(review): annotated as np.ndarray but indexed by these keys, so it is presumably
            dict-like; confirm against load_label.

    Returns:
        One NoduleFinding per nodule, in label order, each after ``auto_nodule_type()`` was called.
    """
    nodules = []
    # Centres and sizes are parallel sequences; pair them one nodule at a time.
    for (z, y, x), (d, h, w) in zip(label[ALL_LOC], label[ALL_RAD]):
        nodule_finding = NoduleFinding(coordX=x, coordY=y, coordZ=z, w=w, h=h, d=d, CADprobability=1.0)
        nodule_finding.auto_nodule_type()
        nodules.append(nodule_finding)
    return nodules

def nodule2cude(nodules: List["NoduleFinding"], shape: Tuple[int, int, int], stride=4) -> np.ndarray:
    """Convert centre/size nodule findings into integer voxel bounding boxes.

    Args:
        nodules: findings exposing coordZ/coordY/coordX centres and d/h/w sizes, in voxel units.
        shape: (Z, Y, X) volume shape. Lower corners are clipped to 0 and upper corners to this extent.
        stride: unused; kept so existing callers keep working.

    Returns:
        int64 array of shape [N, 6] with rows (z1, y1, x1, z2, y2, x2), or shape (0, 6) if
        ``nodules`` is empty. The values are integers because cv2 drawing calls require integer
        coordinates. With NumPy 1.x, ``round()`` on a numpy scalar returns a float.
    """
    bboxes = []
    for nodule in nodules:
        center = (nodule.coordZ, nodule.coordY, nodule.coordX)
        size = (nodule.d, nodule.h, nodule.w)
        lower = [max(int(round(c - s / 2)), 0) for c, s in zip(center, size)]
        upper = [min(int(round(c + s / 2)), limit) for c, s, limit in zip(center, size, shape)]
        bboxes.append((*lower, *upper))
    # reshape keeps a consistent [N, 6] layout even when there are no nodules.
    return np.array(bboxes, dtype=np.int64).reshape(-1, 6)

def inference(img_path: str) -> Tuple[np.ndarray, List[NoduleFinding]]:
    """Run patch-wise nodule detection on one series.

    The volume is rescaled with ``x * 2 - 1`` (to [-1, 1], assuming load_image returns values in
    [0, 1]; TODO confirm). It is then split into overlapping patches by ``split_comber`` and run
    through ``model`` in mini-batches of ``batch_size``. Each batch is decoded by
    ``detection_postprocess``, the results are stitched back with ``split_comber.combine``, and
    overlapping boxes are suppressed with ``nms_3D`` (keeping at most ``top_k``).

    Args:
        img_path: path of the series to load, e.g. as produced by ``gen_dicom_path``.

    Returns:
        (image, nodules): the rescaled volume and the predicted NoduleFinding list.
    """
    image = load_image(img_path)
    # Convert to -1 ~ 1; SplitComb is built with pad_value=-1 to match this range.
    image = image * 2.0 - 1.0
    # split_images: [N, 1, crop_z, crop_y, crop_x]
    split_images, nzhw = split_comber.split(image)
    patches = torch.from_numpy(split_images)

    outputlist = []
    with torch.no_grad():
        # Slicing past the end is safe, so the last batch is simply shorter.
        for start in range(0, patches.size(0), batch_size):
            batch = patches[start:start + batch_size].to(device)
            output = model(batch)
            # rows: (flag, prob, ctr_z, ctr_y, ctr_x, d, h, w)
            output = detection_postprocess(output, device=device)
            outputlist.append(output.detach().cpu().numpy())

    output = np.concatenate(outputlist, 0)
    output = split_comber.combine(output, nzhw=nzhw)
    output = torch.from_numpy(output).view(-1, 8)
    # Rows whose flag column is -1 carry no detection.
    output = output[output[:, 0] != -1.0]
    if len(output) > 0:
        keep = nms_3D(output[:, 1:], overlap=0.05, top_k=top_k)
        output = output[keep]
    # Drop the flag column -> rows of (prob, z, y, x, d, h, w).
    nodules = output2nodulefinding(output.numpy()[:, 1:])
    return image, nodules

In [None]:
idx = 5  # index into dicom_paths / labels of the series to visualise
img_path = dicom_paths[idx]
label = labels[idx]

gt_nodules = label2nodulefinding(label)
image, pred_nodules = inference(img_path)

gt_bboxes = nodule2cude(gt_nodules, image.shape)
pred_bboxes = nodule2cude(pred_nodules, image.shape)

# Map [-1, 1] back to 0..255 for display. Clip first: casting out-of-range floats to uint8
# would wrap (or be platform-dependent) instead of saturating.
mapped_image = ((np.clip(image, -1.0, 1.0) + 1) * 127.5).astype(np.uint8)
# copy 3D image to 3 channels so coloured boxes can be drawn on it
mapped_image = np.stack([mapped_image, mapped_image, mapped_image], axis=-1)  # [Z, Y, X, 3]

In [None]:

MAX_IMAGE = 8  # maximum number of z-slices shown per predicted box

def draw_bbox(image: np.ndarray, bboxes: np.ndarray, color = (255, 0, 0)) -> np.ndarray:
    """Draw each box outline on every z-slice the box spans.

    Args:
        image: a 3D image with shape [Z, Y, X, 3]; it is modified in place (pass a copy to keep
            the original).
        bboxes: rows of (z1, y1, x1, z2, y2, x2). The z range is half-open [z1, z2), while the
            rectangle corners (x1, y1) and (x2, y2) are both drawn by cv2.
        color: outline colour, in the channel order of ``image``.

    Returns:
        ``image`` itself, with the rectangles drawn.
    """
    for z1, y1, x1, z2, y2, x2 in bboxes:
        # cv2.rectangle rejects float points; box arrays may hold floats or numpy scalars.
        top_left = (int(x1), int(y1))
        bottom_right = (int(x2), int(y2))
        for z in range(int(z1), int(z2)):
            image[z] = cv2.rectangle(image[z].copy(), top_left, bottom_right, color, 1)
    return image

# One figure per predicted box. Each figure shows up to MAX_IMAGE evenly spaced z-slices through
# the box, cropped to the x-half of the volume that contains the box centre.
for bbox_idx, bbox in enumerate(pred_bboxes):
    z1, y1, x1, z2, y2, x2 = (int(v) for v in bbox)
    if z2 <= z1:
        continue  # box covers no whole slice; nothing to show

    bboxed_image = draw_bbox(mapped_image.copy(), bbox[np.newaxis, ...])
    half_x = bboxed_image.shape[2] // 2
    if (x1 + x2) / 2 < image.shape[2] / 2:  # left
        bboxed_image = bboxed_image[:, :, :half_x]
    else:  # right
        bboxed_image = bboxed_image[:, :, half_x:]

    # Evenly sample at most MAX_IMAGE slices in [z1, z2 - 1], always including both ends.
    n_slices = min(z2 - z1, MAX_IMAGE)
    zs = np.unique(np.linspace(z1, z2 - 1, num=n_slices).round().astype(int))

    fig, axes = plt.subplots(1, len(zs), figsize=(len(zs) * 3.5, 9), squeeze=False)
    for ax, z in zip(axes[0], zs):
        ax.imshow(bboxed_image[z])
        ax.set_title(f'z={z}')
        ax.axis('off')
    fig.suptitle(f'prediction {bbox_idx}: (z1, y1, x1, z2, y2, x2) = {(z1, y1, x1, z2, y2, x2)}')
    fig.tight_layout()
    plt.show()  # render now and release the figure instead of accumulating one per box

In [None]:
# Inspect the first few predicted boxes (z1, y1, x1, z2, y2, x2); slicing avoids an IndexError when there are fewer than 3.
pred_bboxes[:3]