In [2]:
import os
import gc
import cv2
import torch
import numpy as np
import torchsummary

from sklearn.metrics import jaccard_score

from matplotlib import pyplot as plt
from segmentation_models_pytorch import DeepLabV3Plus
from torchvision import transforms, utils

from pytorch_grad_cam.utils.image import show_cam_on_image

%matplotlib inline

In [3]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
device

'cpu'

## Image Pre-processing

In [4]:
def preprocess_image(img: np.ndarray, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), resized=(224, 224)) -> torch.Tensor:
    """Convert an image array into a normalised, batched float tensor for the model.

    :param img: HxWxC image array; uint8 input is scaled to [0, 1] by ``ToTensor``.
    :param mean: per-channel mean passed to ``Normalize`` (defaults are the commonly used
        ImageNet statistics).
    :param std: per-channel standard deviation passed to ``Normalize``.
    :param resized: (height, width) the image is resized to before normalisation.
        Previously this argument was ignored and 224x224 was always used.
    :return: tensor of shape (1, C, resized[0], resized[1]).
    """
    # NOTE(review): callers pass cv2.imread output (BGR order) while mean/std are listed in
    # RGB order; confirm this matches how the model was trained before changing either.
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(resized),
        transforms.Normalize(mean, std),
    ])
    return preprocessing(img.copy()).unsqueeze(0)

In [5]:
def mask_img_to_mask(mask_path, bg_path, class_values=(52, 94, 162)):
    """Build a one-hot uint8 mask from a grey-level label image and a background mask.

    Output channels, in order:
      * one channel per entry of ``class_values``: pixels of the label image whose grey
        value equals that entry (with the defaults: tumor, stroma, normal);
      * a final background channel: pixels of the background mask equal to 255.

    :param mask_path: path of the label image.
    :param bg_path: path of the background mask; must have the same size as the label image.
    :param class_values: grey values of the foreground classes, in channel order.
    :return: uint8 array of shape (H, W, len(class_values) + 1) with values in {0, 1}.
    :raises FileNotFoundError: if either image cannot be read.
    :raises ValueError: if the two images do not have the same size.
    """
    def _read_gray(path):
        # Decode as BGR and convert with cvtColor (rather than IMREAD_GRAYSCALE): codec-internal
        # grey conversion may differ slightly, and the class values are matched exactly.
        bgr = cv2.imread(path)
        if bgr is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)

    gt_mask = _read_gray(mask_path)
    gt_bg = _read_gray(bg_path)
    if gt_bg.shape != gt_mask.shape:
        raise ValueError(
            f"Background mask {bg_path} has shape {gt_bg.shape}, "
            f"expected {gt_mask.shape} to match {mask_path}"
        )

    channels = [gt_mask == value for value in class_values]
    # Only pure-white (255) pixels are background; lighter greys count as foreground.
    channels.append(gt_bg == 255)

    return np.stack(channels, axis=-1).astype(np.uint8)

In [6]:
def calculate_slice_bboxes(
    image_height: int,
    image_width: int,
    slice_height: int = 512,
    slice_width: int = 512,
    overlap_height_ratio: float = 0.2,
    overlap_width_ratio: float = 0.2,
) -> list[list[int]]:
    """
    Given the height and width of an image, calculates how to divide the image into
    overlapping slices according to the height and width provided. These slices are returned
    as bounding boxes in xyxy format.

    Slices that would run past the right/bottom border are shifted back so they end on the
    border (and keep their full size whenever the image is at least one slice large).

    :param image_height: Height of the original image.
    :param image_width: Width of the original image.
    :param slice_height: Height of each slice (must be > 0).
    :param slice_width: Width of each slice (must be > 0).
    :param overlap_height_ratio: Fractional overlap in height of each slice, in [0, 1) (e.g. an overlap of 0.2 for a slice of size 100 yields an overlap of 20 pixels)
    :param overlap_width_ratio: Fractional overlap in width of each slice, in [0, 1) (e.g. an overlap of 0.2 for a slice of size 100 yields an overlap of 20 pixels)
    :return: a list of bounding boxes [x_min, y_min, x_max, y_max]; empty if the image has zero height.
    :raises ValueError: if a slice size is not positive or an overlap ratio is outside [0, 1);
        such values would make the tiling loop never advance.
    """
    if slice_height <= 0 or slice_width <= 0:
        raise ValueError(f"slice size must be positive, got {slice_height}x{slice_width}")
    if not (0 <= overlap_height_ratio < 1 and 0 <= overlap_width_ratio < 1):
        raise ValueError(
            "overlap ratios must be in [0, 1), got "
            f"{overlap_height_ratio} (height) and {overlap_width_ratio} (width)"
        )

    slice_bboxes = []
    y_overlap = int(overlap_height_ratio * slice_height)
    x_overlap = int(overlap_width_ratio * slice_width)

    y_min = y_max = 0
    while y_max < image_height:
        x_min = x_max = 0
        y_max = y_min + slice_height
        while x_max < image_width:
            x_max = x_min + slice_width
            if y_max > image_height or x_max > image_width:
                # Edge slice: clamp the far corner to the image border, then move the
                # near corner back so the slice keeps its full size where possible.
                clamped_x_max = min(image_width, x_max)
                clamped_y_max = min(image_height, y_max)
                slice_bboxes.append([
                    max(0, clamped_x_max - slice_width),
                    max(0, clamped_y_max - slice_height),
                    clamped_x_max,
                    clamped_y_max,
                ])
            else:
                slice_bboxes.append([x_min, y_min, x_max, y_max])
            # Step = slice size minus overlap; validated above to be >= 1 so the loop advances.
            x_min = x_max - x_overlap
        y_min = y_max - y_overlap
    return slice_bboxes

In [7]:
def img_resize(img, mask, factor):
    """Resize an image and its label mask to a whole number of tiles, and tile them.

    Each side is rounded to the nearest multiple of ``factor`` (never fewer than one tile),
    so the returned boxes cover the resized image exactly, without overlap.

    :param img: image array of shape (H, W) or (H, W, C), e.g. from cv2.imread.
    :param mask: label mask aligned with ``img`` (e.g. the one-hot output of mask_img_to_mask).
    :param factor: tile side length in pixels.
    :return: (resized_img, resized_mask, slice_boxes), boxes as [x_min, y_min, x_max, y_max].
    """
    img_height, img_width = img.shape[:2]

    # Round to the nearest number of tiles but keep at least one: otherwise a side shorter
    # than factor/2 became 0 and the image was shrunk to 1 px and got no tiles at all.
    new_height = int(factor * max(1, np.round(img_height / factor)))
    new_width = int(factor * max(1, np.round(img_width / factor)))

    img = cv2.resize(img, (new_width, new_height))
    # Nearest-neighbour keeps a label mask one-hot; bilinear interpolation blends
    # neighbouring classes at boundaries.
    mask = cv2.resize(mask, (new_width, new_height), interpolation=cv2.INTER_NEAREST)

    slice_boxes = calculate_slice_bboxes(new_height, new_width, factor, factor, 0, 0)

    return img, mask, slice_boxes

In [8]:
def predict_mask(img):
    """Run the notebook's global ``model`` on one image tile and return class probabilities.

    Uses the module-level ``model`` and ``device`` globals.

    :param img: HxWxC image tile; it is converted by ``preprocess_image``.
    :return: NumPy array of shape (H', W', C) holding softmax probabilities over the model's
        output channels, where (H', W') is the size ``preprocess_image`` resizes to.
    """
    image_tensor = preprocess_image(img).to(device)

    # Inference only: no_grad skips recording an autograd graph for every tile, which
    # otherwise saves intermediate activations (higher peak memory) and adds overhead.
    with torch.no_grad():
        logits = model(image_tensor)
        probs = torch.nn.functional.softmax(logits, dim=1)

    # (1, C, H', W') -> (H', W', C)
    return probs.squeeze(0).permute(1, 2, 0).cpu().numpy()

# DeepLabV3+

In [9]:
# NOTE(review): torch.load unpickles the checkpoint, so only load files you trust.
# For a plain state dict, PyTorch 1.13+ also accepts weights_only=True.
model = DeepLabV3Plus(
    encoder_name='resnet50',
    encoder_weights=None,  # no pretrained download; all weights come from the checkpoint below
    classes=4,
    activation=None,       # model outputs raw logits
).to(device)
model.load_state_dict(
    torch.load(
        f='models/deeplabv3plus_dJ_par_resnet50_01.pth',
        map_location=device,
    )
)

<All keys matched successfully>

In [10]:
# Switch layers such as BatchNorm/Dropout to inference behaviour; ';' hides the long module repr.
model.eval();

In [11]:
# Root of the test split; the prediction loop expects same-named files under
# img/, mask/ and background-mask/.
t_path = 'dataset/3.testing/3.testing/'
# Sorted for a reproducible processing order (os.listdir order is arbitrary); hidden
# entries such as .DS_Store or .ipynb_checkpoints are skipped because they are not images.
test_images = sorted(name for name in os.listdir(t_path + 'img') if not name.startswith('.'))

In [14]:
# For every test image: tile it, predict each tile, stitch the predictions back together and
# save a 2x2 figure comparing ground-truth and predicted tumor / normal areas as overlays.
TILE_SIZE = 224           # should match the size preprocess_image resizes tiles to
HEATMAP_DIR = 'heatmaps'
os.makedirs(HEATMAP_DIR, exist_ok=True)  # savefig fails if the output directory is missing

for img_name in test_images:
    img_path  = t_path + 'img/' + img_name
    mask_path = t_path + 'mask/' + img_name
    bg_path   = t_path + 'background-mask/' + img_name

    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    gt_mask = mask_img_to_mask(mask_path, bg_path)

    img, gt_mask, slice_boxes = img_resize(img, gt_mask, TILE_SIZE)

    # Stitch per-tile class probabilities into a full-size (H, W, C) map.
    pred = np.zeros(gt_mask.shape)
    for x_min, y_min, x_max, y_max in slice_boxes:
        pred[y_min:y_max, x_min:x_max] = predict_mask(img[y_min:y_max, x_min:x_max])

    # Rounding the probabilities marks a class only where its probability is above 0.5.
    pred = pred.round().astype('uint8')

    # "Tumor tissue" here is the union of channel 0 (tumor) and channel 1 (stroma).
    gt_t = np.logical_or(gt_mask[:, :, 0], gt_mask[:, :, 1]).astype(np.uint8)
    pred_t = np.logical_or(pred[:, :, 0], pred[:, :, 1]).astype(np.uint8)

    # cv2 loads BGR, but show_cam_on_image(..., use_rgb=True) blends an RGB heatmap,
    # so the base image must be RGB as well (and scaled to [0, 1]).
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255

    panels = [
        (gt_t,             "Actual Tumor Tissue Area"),
        (gt_mask[:, :, 2], "Actual Normal Tissue Area"),
        (pred_t,           "Predicted Tumor Tissue Area"),
        (pred[:, :, 2],    "Predicted Normal Tissue Area"),
    ]
    fig, axes = plt.subplots(2, 2, figsize=(20, 20))
    for ax, (area, title) in zip(axes.flat, panels):
        ax.imshow(show_cam_on_image(img_rgb, area, use_rgb=True))
        ax.set_title(title)

    # Same stem as the input image, saved as .jpg (independent of the extension's length).
    fig.savefig(os.path.join(HEATMAP_DIR, os.path.splitext(img_name)[0] + '.jpg'))
    plt.close(fig)

    # TODO: per-image Jaccard metrics are disabled; tumor_v_1 / normal_v_1 are not defined
    # in the visible cells and must be initialised before re-enabling these lines.
    #tumor_v_1.append(jaccard_score(gt_t, pred_t, average='macro', zero_division=1))
    #normal_v_1.append(jaccard_score(gt_mask[:, :, 2], pred[:, :, 2], average='macro', zero_division=1))

    gc.collect()