In [1]:
import os
import cv2
import tqdm
import time
import string
import pathlib
import numpy as np
import pandas as pd

from sklearn.metrics import jaccard_score

from typing import List
from matplotlib import pyplot as plt

from utils.rgb import mask2rgb
from utils.prediction.evaluations import visualize, preload_image_data
from utils.prediction.predict import Prediction

# Logging
from utils.logging import logging

log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

  from .autonotebook import tqdm as notebook_tqdm


## Variables

In [2]:
metrics_models = [
    { 
        'model_name': 'U-Net 256x256',
        'model_path': r'checkpoints/avid-forest-323/best-checkpoint.pth.tar',
        'patch_size': 256,
    },
    { 
        'model_name': 'U-Net 512x512',
        'model_path': r'checkpoints/helpful-sky-334/best-checkpoint.pth.tar',
        'patch_size': 512,
    },
   { 
        'model_name': 'U-Net 640x640',
        'model_path': r'checkpoints/graceful-snowball-337/best-checkpoint.pth.tar',
        'patch_size': 640,
    }, 
    { 
        'model_name': 'U-Net 768x768',
        'model_path': r'checkpoints/silvery-serenity-371/best-checkpoint.pth.tar',
        'patch_size': 768,
    },
    { 
        'model_name': 'U-Net 800x800',
        'model_path': r'checkpoints/kind-totem-369/best-checkpoint.pth.tar',
        'patch_size': 800,
    },
    { 
        'model_name': 'U-Net 864x864',
        'model_path': r'checkpoints/swept-field-374/best-checkpoint.pth.tar',
        'patch_size': 864,
    },
    { 
        'model_name': 'U-Net 960x960',
        'model_path': r'checkpoints/giddy-leaf-375/best-checkpoint.pth.tar',
        'patch_size': 960,
    },
    { 
        'model_name': 'U-Net 1088x1088',
        'model_path': r'checkpoints/masked-orb-376/best-checkpoint.pth.tar',
        'patch_size': 1088,
    },
]
metrics_model_index = 7

metrics_output = pathlib.Path('metrics_output')
model_metrics_output = pathlib.Path(metrics_output, metrics_models[metrics_model_index]['model_name'])

# Create directory if it doesn't exists
if not os.path.isdir(model_metrics_output):
    os.makedirs(model_metrics_output)

## Util Functions

In [3]:
log.info('[DATA]: Started preloading test images and labels!')
test_imgs, _ = preload_image_data(r'data', r'imgs', False, metrics_models[metrics_model_index]['patch_size'])
test_labels, test_label_names = preload_image_data(r'data', r'imgs', True, metrics_models[metrics_model_index]['patch_size'])

[DATA]: Started preloading test images and labels!


## Model prediction

In [4]:
model_params = {
    'model_name': metrics_models[metrics_model_index]['model_path'],
    'patch_width': metrics_models[metrics_model_index]['patch_size'],
    'patch_height': metrics_models[metrics_model_index]['patch_size'],
    'n_channels': 3,
    'n_classes': 3
}
model = Prediction(model_params)
model.initialize()

log.info('[PREDICTION]: Model loaded!')
log.info(f'[PREDICTION]: Starting prediction on {len(test_imgs)} image(s).')

predicted_labels = []
img_process_time_list = []
m_ious = []

batch_start_time = time.time()
pbar = tqdm.tqdm(enumerate(test_imgs), total=len(test_imgs))
for i, img in pbar:
    img_start_time = time.time()
    mask_predict = model.predict_image(img)
    img_process_time = time.time() - img_start_time

    predicted_labels.append(mask_predict)
    img_process_time_list.append(img_process_time * 1000)

pbar.close()
batch_process_time = time.time() - batch_start_time


[PREDICTION]: Loading model checkpoints/masked-orb-376/best-checkpoint.pth.tar
[PREDICTION]: Model loaded!
[PREDICTION]: Starting prediction on 683 image(s).


100%|██████████| 683/683 [01:38<00:00,  6.96it/s]


### Getting Metrics

In [5]:
# Vars
dice_scores = []

log.info('[METRICS]: Started calculating Jacaard Index!')

pbar = tqdm.tqdm(enumerate(test_labels), total=len(test_labels))
for i, label in pbar:
    # # Fire
    # ground_truth_fire = cv2.inRange(label, 1, 1)
    # prediction_fire = cv2.inRange(predicted_labels[i], 1, 1)

    # gt_mapped_fire = ground_truth_fire.flatten().astype('float') / 255
    # pred_mapped_fire = prediction_fire.flatten().astype('float') / 255

    # dice_scores.append(jaccard_score(gt_mapped_fire, pred_mapped_fire))

    # Smoke
    ground_truth_smoke = cv2.inRange(label, 2, 2)
    prediction_smoke = cv2.inRange(predicted_labels[i], 2, 2)

    gt_mapped_smoke = ground_truth_smoke.flatten().astype('float') / 255
    pred_mapped_smoke = prediction_smoke.flatten().astype('float') / 255

    dice_scores.append(jaccard_score(gt_mapped_smoke, pred_mapped_smoke))

pbar.close()

data = {}
data['Image'] = test_label_names
data[metrics_models[metrics_model_index]['model_name']] = dice_scores

df = pd.DataFrame(data)
print(df)

# Global Vars
# gt_flatten_fire, pred_flatten_fire = np.asarray(gt_list_fire).flatten(), np.asarray(pred_list_fire).flatten()
# gt_flatten_smoke, pred_flatten_smoke = np.asarray(gt_list_smoke).flatten(), np.asarray(pred_list_smoke).flatten()

# Remove Unused Data From Memory
# log.info('[METRICS]: Started converting RGB masks to binary masks!')
# del test_imgs
# del test_labels

# Calculate Metrics
log.info('[METRICS]: Done calculating Jacaard Index!')

[METRICS]: Started calculating Jacaard Index!


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

                                             Image  U-Net 1088x1088
0           flame-11b63f62f7620ee5b62ccc02f0435a25         0.000000
1           flame-ac20ece1c160bf1b9e910fb516c99ddc         0.240769
2           flame-b83dca6a6d2b44c2bb4199e66fdf264d         0.221305
3           flame-18a5b7523c334b32822be34413e6080e         0.000000
4           flame-8ae2ca32ce1f24c6b279aecc3afb3971         0.000000
..                                             ...              ...
678  normal_state-8e485f6d814f2f054463045b4696d8b3         0.000000
679  normal_state-2bdd920269fe65951436573772398d32         0.000000
680  normal_state-b80759dc9dbac8a6d5bc54b8e597c026         0.000000
681  normal_state-c0886c95d0b091d9d8813610ebf8e948         0.000000
682  normal_state-1efab71ef6bc6c07d0f515e397068c32         0.000000

[683 rows x 2 columns]
[METRICS]: Done calculating Jacaard Index!





### Saving report data to Excel

In [6]:
print(df)

if not os.path.isfile(f'{str(model_metrics_output.resolve())}/dice_score_{metrics_models[metrics_model_index]["model_name"]}.xlsx'):
    with pd.ExcelWriter(f'{str(model_metrics_output.resolve())}/dice_score_{metrics_models[metrics_model_index]["model_name"]}.xlsx', engine='openpyxl') as writer:
        df.to_excel(writer, sheet_name="DiceScore", index = False)
else:
    with pd.ExcelWriter(f'{str(model_metrics_output.resolve())}/dice_score_{metrics_models[metrics_model_index]["model_name"]}.xlsx', engine='openpyxl', mode='a', if_sheet_exists='overlay') as writer:
        df.to_excel(writer, sheet_name="DiceScore", index = False)

                                             Image  U-Net 1088x1088
0           flame-11b63f62f7620ee5b62ccc02f0435a25         0.000000
1           flame-ac20ece1c160bf1b9e910fb516c99ddc         0.240769
2           flame-b83dca6a6d2b44c2bb4199e66fdf264d         0.221305
3           flame-18a5b7523c334b32822be34413e6080e         0.000000
4           flame-8ae2ca32ce1f24c6b279aecc3afb3971         0.000000
..                                             ...              ...
678  normal_state-8e485f6d814f2f054463045b4696d8b3         0.000000
679  normal_state-2bdd920269fe65951436573772398d32         0.000000
680  normal_state-b80759dc9dbac8a6d5bc54b8e597c026         0.000000
681  normal_state-c0886c95d0b091d9d8813610ebf8e948         0.000000
682  normal_state-1efab71ef6bc6c07d0f515e397068c32         0.000000

[683 rows x 2 columns]
