In [9]:
import albumentations as A
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from collections import defaultdict
from tqdm import tqdm

from modules import *
from networks import *
from training import *

SIZE = 224

transform = A.Compose([
    A.Resize(height=SIZE, width=SIZE, interpolation=cv.INTER_AREA),
    A.Lambda(image=sharpen, p=1.0),
    A.Lambda(image=polar_transform, mask=polar_transform),
    A.Normalize(),
    ToTensorV2(),
])
loader = load_dataset('../data/DRISHTI/ROI/TestImages', '../data/DRISHTI/ROI/TestMasks', transform, 1, shuffle=False)

Loaded dataset with 51 samples in 51 batches.


In [10]:
arch = 'swin'

path = rf"..\models\polar\{arch}\binary.pth"
checkpoint = load_checkpoint(path, map_location='cpu')
base_model = checkpoint['model']

path = rf"..\models\polar\{arch}\cascade.pth"
checkpoint = load_checkpoint(path, map_location='cpu')
model1 = checkpoint['model']
model1 = model1.eval()

path = rf"..\models\polar\{arch}\dual.pth"
checkpoint = load_checkpoint(path, map_location='cpu')
model2 = checkpoint['model']
model2 = model2.eval()

=> Loading checkpoint: ..\models\polar\swin\binary.pth
=> Loading checkpoint: ..\models\polar\swin\cascade.pth
=> Loading checkpoint: ..\models\polar\swin\dual.pth


In [11]:
print('Default')
res = evaluate(
    'dual',
    model2,
    loader,
    inverse_transform=undo_polar_transform,
    # post_process_fn=postprocess,
    # tta=True,
)
print(f'Dual {arch} - Dice OD: {res["dice_OD"]:.4f}, Dice OC: {res["dice_OC"]:.4f}')

res = evaluate(
    'binary',
    base_model,
    loader,
    inverse_transform=undo_polar_transform,
    # post_process_fn=postprocess,
    # tta=True,
)
print(f'Binary {arch} - Dice OD: {res["dice_OD"]:.4f}')

res = evaluate(
    'cascade',
    model1,
    loader,
    base_model=base_model,
    inverse_transform=undo_polar_transform,
    # post_process_fn=postprocess,
    # tta=True,
)
print(f'Cascade {arch} - Dice OD: {res["dice_OD"]:.4f}, Dice OC: {res["dice_OC"]:.4f}')

Default


Evaluating dual segmentation: 100%|██████████| 51/51 [00:58<00:00,  1.15s/it, accuracy_OC=0.949, accuracy_OD=0.96, balance_accuracy_OC=0.889, balance_accuracy_OD=0.962, dice_OC=0.802, dice_OD=0.926, fn_OC=1.86e+3, fn_OD=781, fnr_OC=0.206, fnr_OD=0.0447, fp_OC=721, fp_OD=1.23e+3, fpr_OC=0.0158, fpr_OD=0.0319, iou_OC=0.685, iou_OD=0.866, npv_OC=0.958, npv_OD=0.978, precision_OC=0.881, precision_OD=0.908, sensitivity_OC=0.794, sensitivity_OD=0.955, specificity_OC=0.984, specificity_OD=0.968, tn_OC=4.24e+4, tn_OD=3.6e+4, tp_OC=5.24e+3, tp_OD=1.21e+4]       


Dual swin - Dice OD: 0.9258, Dice OC: 0.8022


Evaluating binary segmentation: 100%|██████████| 51/51 [00:44<00:00,  1.15it/s, accuracy_OD=0.97, balance_accuracy_OD=0.958, dice_OD=0.942, fn_OD=1e+3, fnr_OD=0.07, fp_OD=509, fpr_OD=0.0136, iou_OD=0.894, npv_OD=0.973, precision_OD=0.962, sensitivity_OD=0.93, specificity_OD=0.986, tn_OD=3.68e+4, tp_OD=1.19e+4]       


Binary swin - Dice OD: 0.9421


Evaluating cascade segmentation: 100%|██████████| 51/51 [01:17<00:00,  1.52s/it, accuracy_OC=0.93, accuracy_OD=0.97, balance_accuracy_OC=0.801, balance_accuracy_OD=0.958, dice_OC=0.698, dice_OD=0.942, fn_OC=3.18e+3, fn_OD=1.01e+3, fnr_OC=0.391, fnr_OD=0.0699, fp_OC=336, fp_OD=508, fpr_OC=0.00727, fpr_OD=0.0136, iou_OC=0.55, iou_OD=0.894, npv_OC=0.931, npv_OD=0.973, precision_OC=0.933, precision_OD=0.962, sensitivity_OC=0.609, sensitivity_OD=0.93, specificity_OC=0.993, specificity_OD=0.986, tn_OC=4.27e+4, tn_OD=3.68e+4, tp_OC=3.92e+3, tp_OD=1.19e+4]    

Cascade swin - Dice OD: 0.9421, Dice OC: 0.6975





In [12]:
print('Postprocess')
res = evaluate(
    'dual',
    model2,
    loader,
    inverse_transform=undo_polar_transform,
    post_process_fn=postprocess,
    # tta=True,
)
print(f'Dual {arch} - Dice OD: {res["dice_OD"]:.4f}, Dice OC: {res["dice_OC"]:.4f}')

res = evaluate(
    'binary',
    base_model,
    loader,
    inverse_transform=undo_polar_transform,
    post_process_fn=postprocess,
    # tta=True,
)
print(f'Binary {arch} - Dice OD: {res["dice_OD"]:.4f}')

res = evaluate(
    'cascade',
    model1,
    loader,
    base_model=base_model,
    inverse_transform=undo_polar_transform,
    post_process_fn=postprocess,
    # tta=True,
)
print(f'Cascade {arch} - Dice OD: {res["dice_OD"]:.4f}, Dice OC: {res["dice_OC"]:.4f}')

Postprocess


Evaluating dual segmentation: 100%|██████████| 51/51 [01:05<00:00,  1.28s/it, accuracy_OC=0.948, accuracy_OD=0.905, balance_accuracy_OC=0.894, balance_accuracy_OD=0.927, dice_OC=0.805, dice_OD=0.876, fn_OC=1.79e+3, fn_OD=722, fnr_OC=0.195, fnr_OD=0.04, fp_OC=796, fp_OD=4.04e+3, fpr_OC=0.0175, fpr_OD=0.106, iou_OC=0.689, iou_OD=0.801, npv_OC=0.96, npv_OD=0.979, precision_OC=0.872, precision_OD=0.838, sensitivity_OC=0.805, sensitivity_OD=0.96, specificity_OC=0.982, specificity_OD=0.894, tn_OC=4.23e+4, tn_OD=3.32e+4, tp_OC=5.31e+3, tp_OD=1.22e+4]        


Dual swin - Dice OD: 0.8757, Dice OC: 0.8049


Evaluating binary segmentation: 100%|██████████| 51/51 [00:46<00:00,  1.09it/s, accuracy_OD=0.96, balance_accuracy_OD=0.955, dice_OD=0.931, fn_OD=877, fnr_OD=0.0606, fp_OD=1.11e+3, fpr_OD=0.03, iou_OD=0.877, npv_OD=0.976, precision_OD=0.934, sensitivity_OD=0.939, specificity_OD=0.97, tn_OD=3.62e+4, tp_OD=1.2e+4]        


Binary swin - Dice OD: 0.9309


Evaluating cascade segmentation: 100%|██████████| 51/51 [01:09<00:00,  1.35s/it, accuracy_OC=0.928, accuracy_OD=0.955, balance_accuracy_OC=0.808, balance_accuracy_OD=0.952, dice_OC=0.699, dice_OD=0.923, fn_OC=3.07e+3, fn_OD=853, fnr_OC=0.373, fnr_OD=0.0587, fp_OC=515, fp_OD=1.4e+3, fpr_OC=0.0113, fpr_OD=0.0375, iou_OC=0.552, iou_OD=0.864, npv_OC=0.933, npv_OD=0.976, precision_OC=0.911, precision_OD=0.919, sensitivity_OC=0.627, sensitivity_OD=0.941, specificity_OC=0.989, specificity_OD=0.962, tn_OC=4.26e+4, tn_OD=3.59e+4, tp_OC=4.02e+3, tp_OD=1.2e+4]       

Cascade swin - Dice OD: 0.9225, Dice OC: 0.6992





In [5]:
print('TTA')
res = evaluate(
    'dual',
    model2,
    loader,
    # inverse_transform=undo_polar_transform,
    # post_process_fn=postprocess,
    tta=True,
)
print(f'Dual {arch} - Dice OD: {res["dice_OD"]:.4f}, Dice OC: {res["dice_OC"]:.4f}')

res = evaluate(
    'binary',
    base_model,
    loader,
    # inverse_transform=undo_polar_transform,
    # post_process_fn=postprocess,
    tta=True,
)
print(f'Binary {arch} - Dice OD: {res["dice_OD"]:.4f}')

res = evaluate(
    'cascade',
    model1,
    loader,
    base_model=base_model,
    # inverse_transform=undo_polar_transform,
    # post_process_fn=postprocess,
    tta=True,
)
print(f'Cascade {arch} - Dice OD: {res["dice_OD"]:.4f}, Dice OC: {res["dice_OC"]:.4f}')

TTA


Evaluating dual segmentation: 100%|██████████| 51/51 [07:17<00:00,  8.59s/it, accuracy_OC=0.968, accuracy_OD=0.976, balance_accuracy_OC=0.934, balance_accuracy_OD=0.965, dice_OC=0.874, dice_OD=0.954, fn_OC=1.26e+3, fn_OD=1.2e+3, fnr_OC=0.118, fnr_OD=0.0618, fp_OC=820, fp_OD=367, fpr_OC=0.014, fpr_OD=0.00725, iou_OC=0.784, iou_OD=0.915, npv_OC=0.978, npv_OD=0.975, precision_OC=0.896, precision_OD=0.976, sensitivity_OC=0.882, sensitivity_OD=0.938, specificity_OC=0.986, specificity_OD=0.993, tn_OC=5.55e+4, tn_OD=4.83e+4, tp_OC=8e+3, tp_OD=1.56e+4]     


Dual ref - Dice OD: 0.9540, Dice OC: 0.8741


Evaluating binary segmentation: 100%|██████████| 51/51 [04:01<00:00,  4.73s/it, accuracy_OD=0.966, balance_accuracy_OD=0.958, dice_OD=0.935, fn_OD=1.38e+3, fnr_OD=0.0686, fp_OD=823, fpr_OD=0.0164, iou_OD=0.883, npv_OD=0.972, precision_OD=0.95, sensitivity_OD=0.931, specificity_OD=0.984, tn_OD=4.79e+4, tp_OD=1.55e+4] 


Binary ref - Dice OD: 0.9353


Evaluating cascade segmentation: 100%|██████████| 51/51 [07:20<00:00,  8.64s/it, accuracy_OC=0.946, accuracy_OD=0.967, balance_accuracy_OC=0.914, balance_accuracy_OD=0.958, dice_OC=0.802, dice_OD=0.936, fn_OC=1.58e+3, fn_OD=1.35e+3, fnr_OC=0.139, fnr_OD=0.0671, fp_OC=1.93e+3, fp_OD=845, fpr_OC=0.0333, fpr_OD=0.0169, iou_OC=0.686, iou_OD=0.884, npv_OC=0.972, npv_OD=0.972, precision_OC=0.807, precision_OD=0.949, sensitivity_OC=0.861, sensitivity_OD=0.933, specificity_OC=0.967, specificity_OD=0.983, tn_OC=5.43e+4, tn_OD=4.79e+4, tp_OC=7.68e+3, tp_OD=1.55e+4]

Cascade ref - Dice OD: 0.9357, Dice OC: 0.8025





In [6]:
print('TTA + Postprocess')
res = evaluate(
    'dual',
    model2,
    loader,
    # inverse_transform=undo_polar_transform,
    post_process_fn=postprocess,
    tta=True,
)
print(f'Dual {arch} - Dice OD: {res["dice_OD"]:.4f}, Dice OC: {res["dice_OC"]:.4f}')

res = evaluate(
    'binary',
    base_model,
    loader,
    # inverse_transform=undo_polar_transform,
    post_process_fn=postprocess,
    tta=True,
)
print(f'Binary {arch} - Dice OD: {res["dice_OD"]:.4f}')

res = evaluate(
    'cascade',
    model1,
    loader,
    base_model=base_model,
    # inverse_transform=undo_polar_transform,
    post_process_fn=postprocess,
    tta=True,
)
print(f'Cascade {arch} - Dice OD: {res["dice_OD"]:.4f}, Dice OC: {res["dice_OC"]:.4f}')

TTA + Postprocess


Evaluating dual segmentation: 100%|██████████| 51/51 [06:30<00:00,  7.66s/it, accuracy_OC=0.969, accuracy_OD=0.976, balance_accuracy_OC=0.935, balance_accuracy_OD=0.966, dice_OC=0.876, dice_OD=0.954, fn_OC=1.25e+3, fn_OD=1.18e+3, fnr_OC=0.117, fnr_OD=0.0612, fp_OC=813, fp_OD=386, fpr_OC=0.0139, fpr_OD=0.00759, iou_OC=0.786, iou_OD=0.915, npv_OC=0.978, npv_OD=0.975, precision_OC=0.897, precision_OD=0.975, sensitivity_OC=0.883, sensitivity_OD=0.939, specificity_OC=0.986, specificity_OD=0.992, tn_OC=5.55e+4, tn_OD=4.83e+4, tp_OC=8.02e+3, tp_OD=1.57e+4]


Dual ref - Dice OD: 0.9538, Dice OC: 0.8755


Evaluating binary segmentation: 100%|██████████| 51/51 [03:43<00:00,  4.38s/it, accuracy_OD=0.966, balance_accuracy_OD=0.957, dice_OD=0.934, fn_OD=1.38e+3, fnr_OD=0.0683, fp_OD=848, fpr_OD=0.0169, iou_OD=0.882, npv_OD=0.972, precision_OD=0.949, sensitivity_OD=0.932, specificity_OD=0.983, tn_OD=4.78e+4, tp_OD=1.55e+4]


Binary ref - Dice OD: 0.9344


Evaluating cascade segmentation: 100%|██████████| 51/51 [07:31<00:00,  8.84s/it, accuracy_OC=0.946, accuracy_OD=0.965, balance_accuracy_OC=0.913, balance_accuracy_OD=0.957, dice_OC=0.8, dice_OD=0.933, fn_OC=1.59e+3, fn_OD=1.37e+3, fnr_OC=0.139, fnr_OD=0.0675, fp_OC=1.97e+3, fp_OD=895, fpr_OC=0.0339, fpr_OD=0.0178, iou_OC=0.683, iou_OD=0.881, npv_OC=0.972, npv_OD=0.972, precision_OC=0.805, precision_OD=0.947, sensitivity_OC=0.861, sensitivity_OD=0.932, specificity_OC=0.966, specificity_OD=0.982, tn_OC=5.43e+4, tn_OD=4.78e+4, tp_OC=7.67e+3, tp_OD=1.55e+4]  

Cascade ref - Dice OD: 0.9333, Dice OC: 0.7998



