In [None]:
import cv2
import torch
import string
import pathlib
import numpy as np
import pandas as pd
import torchvision
import time
import tqdm

from sklearn.metrics import roc_auc_score, precision_score, recall_score, confusion_matrix

from typing import List
from predict2 import Prediction
from matplotlib import pyplot as plt

from unet.model import UNet
from utils.dataset import Dataset
from utils.rgb import rgb2mask, mask2rgb, LABEL_COLORS
from utils.plots import plot_img_and_mask

# Logging
from utils.logging import logging

log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

## Variables

In [None]:
patch_size = 768

## Util Functions

In [None]:
def create_mask_image(bbox_list, width, height):
    image = np.zeros((width, height, 1), np.uint8)
    for box in bbox_list:
        x, y, w, h = box
        cv2.rectangle(image, (x, y), (x + w, y + h), (255), -1)
    return image

def retrieve_bounding_boxes(input_image):
    count, hierarchy = cv2.findContours(input_image, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    bounding_boxes: List = []
    for contours in count:
        x, y, w, h = cv2.boundingRect(contours)
        if (w * h > 40):
            bounding_boxes.append([x, y, w, h])
    return bounding_boxes, len(bounding_boxes)

def preload_image_data(data_dir: string, img_dir: string, is_mask: bool = False):
    dataset_files: List = []
    with open(pathlib.Path(data_dir, 'test_dataset.txt'), mode='r', encoding='utf-8') as file:
        for i, line in enumerate(file):
            path = pathlib.Path(data_dir, img_dir, line.strip(), f'Image/{line.strip()}.png' if is_mask == False else f'Mask/0.png')

            # Load image
            img = cv2.imread(str(path))
            img = Dataset._resize_and_pad(img, (patch_size, patch_size), (0, 0, 0))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            if is_mask:
                img = rgb2mask(img)
            dataset_files.append(img)
    return dataset_files

test_imgs = preload_image_data(r'data', r'imgs', False)
test_labels = preload_image_data(r'data', r'imgs', True)
model_name = r'checkpoints/silvery-serenity-371/checkpoint.pth.tar'

In [None]:
def compute_IoU(cm):
    '''
    Adapted from:
        https://github.com/davidtvs/PyTorch-ENet/blob/master/metric/iou.py
        https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/keras/metrics.py#L2716-L2844
    '''
    
    sum_over_row = cm.sum(axis=0)
    sum_over_col = cm.sum(axis=1)
    true_positives = np.diag(cm)

    # sum_over_row + sum_over_col = 2 * true_positives + false_positives + false_negatives.
    denominator = sum_over_row + sum_over_col - true_positives

    with np.errstate(divide='ignore', invalid='ignore'):
        iou = np.true_divide(true_positives, denominator)
        iou[iou == np.inf] = 0.0
        iou = np.nan_to_num(iou)
    return iou, np.nanmean(iou)

def IoU(mask_true, mask_pred, n_classes=2):
        labels = np.arange(n_classes)
        cm = confusion_matrix(mask_true.flatten(), mask_pred.flatten(), labels=labels)        
        return compute_IoU(cm)

## Model prediction

In [None]:
model_params = {
    'model_name': model_name,
    'patch_width': patch_size,
    'patch_height': patch_size,
    'n_channels': 3,
    'n_classes': 3
}
model = Prediction(model_params)
model.initialize()

log.info('[PREDICTION]: Model loaded!')
log.info(f'[PREDICTION]: Starting prediction on {len(test_imgs)} image(s).')

predicted_labels = []
img_process_time_list = []
m_ious = []

batch_start_time = time.time()
pbar = tqdm.tqdm(enumerate(test_imgs), total=len(test_imgs))
for i, img in pbar:
    img_start_time = time.time()
    mask_predict = model.predict_image(img)
    img_process_time = time.time() - img_start_time

    predicted_labels.append(mask_predict)
    img_process_time_list.append(img_process_time * 1000)

pbar.close()
batch_process_time = time.time() - batch_start_time

### Getting Metrics

In [None]:
def roc_auc_score_multiclass(actual_class, pred_class, average = "macro"):
  # Creating a set of all the unique classes using the actual class list
  unique_class = set(actual_class)
  roc_auc_list = [0.0, 0.0, 0.0]

  for per_class in unique_class:
    # Creating a list of all the classes except the current class 
    other_class = [x for x in unique_class if x != per_class]

    # Marking the current class as 1 and all other classes as 0
    new_actual_class = [0 if x in other_class else 1 for x in actual_class]
    new_pred_class = [0 if x in other_class else 1 for x in pred_class]

    # Using the sklearn metrics method to calculate the roc_auc_score
    try:
      roc_auc = roc_auc_score(new_actual_class, new_pred_class, average = average)
    except:
      roc_auc = 0.0    
    roc_auc_list[int(per_class)] = roc_auc

  return np.mean(roc_auc_list), roc_auc_list

In [49]:
def getting_confusion_matrix_data(mask_true, mask_pred, n_classes=2):
    labels = np.arange(n_classes)
    cm = confusion_matrix(mask_true.flatten(), mask_pred.flatten(), labels=labels)        
    
    FP = cm.sum(axis=0) - np.diag(cm)  
    FN = cm.sum(axis=1) - np.diag(cm)
    TP = np.diag(cm)
    TN = cm.sum() - (FP + FN + TP)
    return FP, FN, TP, TN

def perf_measure(y_actual, y_hat):
   TP = 0
   FP = 0
   TN = 0
   FN = 0

   y_actual = y_actual.flatten()
   y_hat = y_hat.flatten()

   for i in range(len(y_hat)): 
      if y_actual[i]==y_hat[i]==1:
         TP += 1
      if y_hat[i]==1 and y_actual[i]!=y_hat[i]:
         FP += 1
      if y_actual[i]==y_hat[i]==0:
         TN += 1
      if y_hat[i]==0 and y_actual[i]!=y_hat[i]:
         FN += 1

   return TP, FP, TN, FN

def get_data_list(y_actual, y_hat):
   data_list = [] 
   true_list = []

   y_actual = y_actual.flatten()
   y_hat = y_hat.flatten()

   for i in range(len(y_hat)): 
      if y_actual[i] == y_hat[i]==1:
         TP += 1
      if y_hat[i] == 1 and y_actual[i] != y_hat[i]:
         FP += 1
      if y_actual[i] == y_hat[i] == 0:
         TN += 1
      if y_hat[i] == 0 and y_actual[i] != y_hat[i]:
         FN += 1

   return data_list, true_list


In [51]:
precision_list = []
recall_list = []
auc_list = []
iou_list = []

pbar = tqdm.tqdm(enumerate(test_labels), total=len(test_labels))
for i, label in pbar:
    true_label = label.flatten()
    predict_label = predicted_labels[i].flatten()
    labels = np.arange(model_params['n_classes'])

    precision_t = precision_score(true_label, predict_label, average=None, labels=labels, zero_division=1)
    # recall = recall_score(true_label, predict_label, average=None, labels=labels, zero_division=1)
    # auc_s, auc_per_class = roc_auc_score_multiclass(true_label, predict_label)

    FP, FN, TP, TN = getting_confusion_matrix_data(label, predicted_labels[i], model_params['n_classes'])
    FP1, FN1, TP1, TN1 = perf_measure(label, predicted_labels[i])
    # class_iou, mean_iou = IoU(label, predicted_labels[i], model_params['n_classes'])

    precision = TP / (TP + FP)
    precision1 = TP1 / (TP1 + FP1)

    print('precision', precision)
    print('precision1', precision1)
    print('precision_t', precision_t)

    # pbar.desc = f'Precision: {precision[1]:.4f} | Recall: {recall[1]:.4f} | AUC: {auc_per_class[1]:.4f} | Mean IoU: {mean_iou:.4f} | Processing Time: {img_process_time_list[i]:.3f}ms'

    # precision_list.append(precision)
    # recall_list.append(recall)
    # auc_list.append(auc_per_class)
    # iou_list.append(class_iou)
    break

pbar.close()

print(precision_list)

# String log evaluation metrics
# log.info(
# f'Precision: {np.mean(precision_list):.5f} | \
# Recall: {np.mean(recall_list):.5f} | \
# AUC: {np.mean(auc_list):.5f} | \
# Processing Time: {batch_process_time * 1000}s'
# )

  0%|          | 0/683 [00:00<?, ?it/s]

precision [0.99906186 0.84162963 0.        ]
precision1 0.9903417861830249
precision_t [0.99906186 0.84162963 0.        ]
[]





### Visualization of prediction data

In [None]:
# Image Based eval Precision-Recall curve
recall_list.sort()
precision_list.sort()

plt.plot(recall_list, precision_list, marker='.', color='blue', label=f'U-NET: {np.mean(auc_list):.2f}')
plt.title('Image based evaluation Precision-Recall curve')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.legend()
plt.savefig('results/img_based_eval_precision_recall_curve.svg', format='svg', dpi=300)

### Saving data to Excel

In [None]:
import pandas as pd


metrics_frame = pd.DataFrame()
iou_frame = pd.DataFrame()

### Za Franka

In [None]:
import PIL
import matplotlib.image as mpimg

franko_dir = pathlib.Path(r'franko')
images_dir = pathlib.Path(franko_dir, r'images')
ground_truth_dir = pathlib.Path(franko_dir, r'ground_truths')
predictions_dir = pathlib.Path(franko_dir, r'predictions')

for i, img in tqdm.tqdm(enumerate(test_imgs), total=len(test_imgs)):
    save_img = PIL.Image.fromarray(img)
    save_img.save(str(pathlib.Path(images_dir, f'{i}.png')))

for i, ground_truth in tqdm.tqdm(enumerate(test_labels), total=len(test_labels)):
    mpimg.imsave(str(pathlib.Path(ground_truth_dir, f'{i}.png')), ground_truth, cmap='gray')

for i, predicted_img in tqdm.tqdm(enumerate(predicted_labels), total=len(predicted_labels)):
    mpimg.imsave(str(pathlib.Path(predictions_dir, f'{i}.png')), predicted_img, cmap='gray')