In [1]:
import sys, json, os, argparse
import warnings
from skimage.io import imsave, imread
from skimage import img_as_ubyte
import os.path as osp
import pandas as pd
from tqdm import tqdm
from tqdm import trange
import numpy as np
import torch
from models.get_model import get_arch
from utils.get_loaders import get_inference_seg_loader

from utils.model_saving_loading import str2bool, load_model
from utils.reproducibility import set_seeds
from scipy.ndimage import binary_fill_holes as bfh
from skimage.transform import resize
from skimage import img_as_float
from skimage.segmentation import find_boundaries
from skimage.color import gray2rgb, label2rgb
from skimage.morphology import dilation, square
from scipy.ndimage import zoom
from skimage.measure import label

In [2]:
import matplotlib.pyplot as plt
from skimage.filters import threshold_otsu as threshold

In [3]:
def dice_score(actual, predicted):
    """Dice coefficient between a ground-truth mask and a predicted mask.

    Both inputs are converted to boolean arrays before being compared.

    Special cases:
      * both masks empty (true negative)            -> 1
      * ground truth non-empty, prediction empty    -> 0 (false negative)
      * ground truth empty, prediction non-empty    -> 0 (false positive)

    Otherwise the result is ``2 * |actual & predicted| / (|actual| + |predicted|)``.
    """
    truth_mask = np.asarray(actual).astype(bool)
    pred_mask = np.asarray(predicted).astype(bool)

    n_truth = truth_mask.sum()
    n_pred = pred_mask.sum()
    total = n_truth + n_pred

    # Nothing to find and nothing found: perfect agreement.
    if total == 0:
        return 1
    # Exactly one of the two masks is empty: complete disagreement.
    if n_truth == 0 or n_pred == 0:
        return 0

    overlap = np.logical_and(truth_mask, pred_mask).sum()
    return 2. * overlap / total

In [4]:
def mark_boundaries_ad(image, label_img, color=(1, 1, 0), outline_color=None, rad=8, mode='outer', background_label=0):
    """Return a float copy of ``image`` with the label boundaries painted on it.

    Args:
        image: Image to draw on; a 2-D input is expanded with ``gray2rgb``.
        label_img: Label image whose boundaries are located with
            ``find_boundaries``.
        color: Value written on the boundary pixels.
        outline_color: If not ``None``, value written on the boundary mask
            dilated with ``square(rad)``; the boundary itself is painted
            afterwards, on top of this outline.
        rad: Argument passed to ``square`` to build the dilation footprint.
        mode: Forwarded to ``find_boundaries``. With ``'subpixel'`` the image
            is first upsampled so that it lines up with the boundary mask.
        background_label: Forwarded to ``find_boundaries`` as ``background``.

    Returns:
        The annotated floating-point image (the input is not modified because
        ``img_as_float`` is called with ``force_copy=True``).
    """
    canvas = img_as_float(image, force_copy=True)
    canvas = gray2rgb(canvas) if canvas.ndim == 2 else canvas

    if mode == 'subpixel':
        # Interpose one extra row/column between every pair of original ones
        # along each axis except the last, which is left at its original
        # size (factor 1). ``zoom`` interpolates the inserted pixels.
        zoom_factors = [2 - 1/extent for extent in canvas.shape[:-1]]
        zoom_factors.append(1)
        canvas = zoom(canvas, zoom_factors, mode='reflect')

    edge_mask = find_boundaries(label_img, mode=mode, background=background_label)

    if outline_color is not None:
        # Thicker halo first, so the boundary colour ends up on top of it.
        canvas[dilation(edge_mask, square(rad))] = outline_color
    canvas[edge_mask] = color
    return canvas

In [5]:
from skimage.io import imsave
from skimage import img_as_ubyte
from skimage.transform import resize
from skimage import io

In [6]:
def print_cv_results(base_path, n_folds=5):
    """Print the mean validation Dice over the cross-validation folds.

    For every fold ``k`` in ``1..n_folds`` the file
    ``<base_path><k>/val_metrics.txt`` is read. The Dice entry is the text
    between the first ``|`` and the next ``,``; its first five characters are
    parsed as that fold's average Dice.

    The value printed after ``+/-`` is the standard deviation of the per-fold
    averages (spread across folds), not the per-fold deviation stored in the
    metrics files.

    Args:
        base_path: Experiment folder path without the trailing fold number.
        n_folds: Number of folds to read; defaults to 5.

    Raises:
        FileNotFoundError: If a fold's ``val_metrics.txt`` does not exist.
        IndexError: If a metrics file contains no ``|`` separator.
        ValueError: If the Dice entry does not start with a parsable number.
    """
    fold_dice = np.zeros((n_folds,))
    for fold in range(n_folds):
        file_name = osp.join(base_path + str(fold+1), 'val_metrics.txt')
        # Context manager so the handle is closed even if parsing fails below.
        with open(file_name, "r") as metrics_file:
            contents = metrics_file.read()
        dsc = contents.split('|')[1].split(',')[0]
        fold_dice[fold] = float(dsc[:5])
    print('Average DSC = {:.2f}+/-{:.2f}'.format(fold_dice.mean(), fold_dice.std()))

In [7]:
def evaluate_test_preds(model, list_folders, test_loader):
    """Evaluate an ensemble of checkpoints on ``test_loader``.

    A ``model_checkpoint.pth`` is loaded from every folder in ``list_folders``.
    For each batch, every checkpoint's weights are loaded into ``model`` in
    turn, the sigmoid of the model output is taken, and the resulting
    probabilities are averaged over the checkpoints. Channel 0 of each averaged
    prediction is thresholded at 0.5 and compared with the target (also
    thresholded at 0.5) using ``dice_score``.

    Args:
        model: Network whose weights are overwritten by each checkpoint's
            ``'model_state_dict'`` entry.
        list_folders: Folders that each contain a ``model_checkpoint.pth``.
        test_loader: Iterable of 4-tuples whose first two items are the inputs
            and the targets; the remaining two items are ignored. It must also
            support ``len()`` and expose a ``dataset`` attribute.

    Returns:
        ``(mean_dice, rejected_fraction)``: the mean per-image Dice score, and
        the number of images whose thresholded prediction is empty divided by
        ``len(test_loader.dataset)``.
    """
    print('* Evaluating test predictions')
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    states = [torch.load(osp.join(folder, 'model_checkpoint.pth'), map_location=device)
              for folder in list_folders]
    print('* States loaded')

    def member_probs(state, batch):
        # Swap in one ensemble member's weights and return its probabilities on CPU.
        # Note: this reloads the weights for every batch and every checkpoint.
        model.load_state_dict(state['model_state_dict'])
        model.eval()
        return model(batch).sigmoid().detach().cpu()

    per_image_dice = []
    n_empty = 0
    with torch.no_grad():
        model.to(device)
        progress = tqdm(enumerate(test_loader), total=len(test_loader))

        for _, (inputs, targets, _, _) in progress:
            inputs = inputs.to(device)
            stacked = torch.stack([member_probs(state, inputs) for state in states], dim=0)
            mean_probs = stacked.mean(dim=0)

            for j, prob in enumerate(mean_probs):
                pred_mask = prob.numpy()[0] > 0.5
                true_mask = targets[j].numpy() > 0.5
                per_image_dice.append(dice_score(true_mask, pred_mask))

                # An empty prediction counts as a "rejected" image.
                if not pred_mask.any():
                    n_empty += 1

    return np.mean(per_image_dice), n_empty/len(test_loader.dataset)

In [9]:
def print_all(NAME, model_name, experiments_dir='experiments/endotect',
              data_source='data_endotect/test.csv', batch_size=4, im_size=(480, 640)):
    """Print cross-validation and test-set results for every training variant.

    Experiment folders inside ``experiments_dir`` whose name contains ``NAME``
    are grouped by training variant (``only_ce``, ``only_dice``,
    ``ce_finetune_dice``, ``ce_linear_dice``, ``ce_combo_dice``). For each
    variant the function first prints the cross-validation summary via
    ``print_cv_results`` and then, using the model built by ``get_arch``, the
    ensemble test results returned by ``evaluate_test_preds``.

    Args:
        NAME: Substring that selects the experiment folders of one backbone.
        model_name: Architecture name passed to ``get_arch``.
        experiments_dir: Directory that holds the experiment folders.
        data_source: CSV file passed to ``get_inference_seg_loader``.
        batch_size: Batch size of the test loader.
        im_size: ``(height, width)`` pair passed to the loader as ``tg_size``.

    Raises:
        IndexError: If any variant has no matching experiment folder. This is
            checked before anything is printed or loaded.
    """
    # (folder-name tag, cross-validation heading, test-report label)
    variants = [
        ('only_ce', 'only_ce:', 'ONLY CE'),
        ('only_dice', 'only_dice:', 'ONLY DICE'),
        ('ce_finetune_dice', 'ce - ft - dice:', 'CE ft DICE'),
        ('ce_linear_dice', 'ce - lin - dice:', 'CE lin DICE'),
        ('ce_combo_dice', 'ce+dice:', 'CE+DICE'),
    ]

    matching_names = sorted(n for n in os.listdir(experiments_dir) if NAME in n)
    folders = {
        tag: [osp.join(experiments_dir, n) for n in matching_names if tag in n]
        for tag, _, _ in variants
    }

    # Fail early with a readable message instead of a bare "list index out of range".
    missing = [tag for tag, _, _ in variants if not folders[tag]]
    if missing:
        raise IndexError('No experiment folders matching {!r} found in {!r} for: {}'.format(
            NAME, experiments_dir, ', '.join(missing)))

    print('--- 5-FOLD CROSS-VALIDATION AVERAGE RESULTS ----')
    for tag, heading, _ in variants:
        print(heading)
        # Drop the trailing fold digit; print_cv_results appends 1..5 itself.
        print_cv_results(folders[tag][0][:-1])

    tg_size = (im_size[0], im_size[1])
    n_classes = 1
    print('* Instantiating a {} model'.format(model_name))
    model, mean, std = get_arch(model_name, n_classes=n_classes, pretrained=False)
    model.mode = 'eval'

    test_loader = get_inference_seg_loader(data_source, batch_size=batch_size, mean=mean, std=std, tg_size=tg_size)

    separator = 50*'*'
    # The run of spaces after the dash reproduces the original report format.
    report = '{}: Average dice with OoD data = {:.3f} -     Percentage of rejected images = {:.3f}'

    print(separator)
    for tag, _, label in variants:
        avg_dice, rejected_pctg = evaluate_test_preds(model, folders[tag], test_loader)
        print(report.format(label, 100*avg_dice, 100*rejected_pctg))
        print(separator)

## Resnet18

In [10]:
# NAME selects the experiment folders; model_name is the architecture to build.
NAME, model_name = 'resnet18', 'fpnet_resnet18_W'
print_all(NAME, model_name)

--- 5-FOLD CROSS-VALIDATION AVERAGE RESULTS ----
only_ce:
Average DSC = 89.19+/-1.18
only_dice:
Average DSC = 88.83+/-1.04
ce - ft - dice:
Average DSC = 89.06+/-1.30
ce - lin - dice:
Average DSC = 89.76+/-1.12
ce+dice:
Average DSC = 89.81+/-1.16
* Instantiating a fpnet_resnet18_W model
**************************************************
* Evaluating test predictions
* States loaded


100%|████████████████████████████████████████████████████████████████| 55/55 [00:12<00:00,  4.26it/s]


ONLY CE: Average dice with OoD data = 87.333 -     Percentage of rejected images = 6.364
**************************************************
* Evaluating test predictions
* States loaded


100%|████████████████████████████████████████████████████████████████| 55/55 [00:12<00:00,  4.38it/s]


ONLY DICE: Average dice with OoD data = 81.914 -     Percentage of rejected images = 1.818
**************************************************
* Evaluating test predictions
* States loaded


100%|████████████████████████████████████████████████████████████████| 55/55 [00:12<00:00,  4.38it/s]


CE ft DICE: Average dice with OoD data = 86.616 -     Percentage of rejected images = 5.909
**************************************************
* Evaluating test predictions
* States loaded


100%|████████████████████████████████████████████████████████████████| 55/55 [00:12<00:00,  4.36it/s]


CE lin DICE: Average dice with OoD data = 86.939 -     Percentage of rejected images = 5.455
**************************************************
* Evaluating test predictions
* States loaded


100%|████████████████████████████████████████████████████████████████| 55/55 [00:12<00:00,  4.36it/s]

CE+DICE: Average dice with OoD data = 84.889 -     Percentage of rejected images = 4.091
**************************************************





## MobileNet

In [None]:
# NAME selects the experiment folders; model_name is the architecture to build.
NAME, model_name = 'mobilenet', 'fpnet_mobilenet_W'
print_all(NAME, model_name)

## Resnet34

In [None]:
# NAME selects the experiment folders; model_name is the architecture to build.
NAME, model_name = 'resnet34', 'fpnet_resnet34_W'
print_all(NAME, model_name)

## Resnet50

In [None]:
# NAME selects the experiment folders; model_name is the architecture to build.
NAME, model_name = 'resnet50', 'fpnet_resnet50_W'
print_all(NAME, model_name)

## Resnext50

In [None]:
# NAME selects the experiment folders; model_name is the architecture to build.
NAME, model_name = 'resnext50', 'fpnet_resnext50_W_imagenet'
print_all(NAME, model_name)

## Resnet101

In [None]:
# NAME selects the experiment folders; model_name is the architecture to build.
NAME, model_name = 'resnet101', 'fpnet_resnet101_W'
print_all(NAME, model_name)

## Resnext101

In [None]:
# NAME selects the experiment folders; model_name is the architecture to build.
NAME, model_name = 'resnext101', 'fpnet_resnext101_32x4d_W_ssl'
print_all(NAME, model_name)

## Resnet152

In [None]:
# NAME selects the experiment folders; model_name is the architecture to build.
NAME, model_name = 'resnet152', 'fpnet_resnet152_W'
print_all(NAME, model_name)