In [1]:
%load_ext autoreload
%autoreload 2
from Config import *
from Core.dataset import DatasetFactory
from Core.tasks import Segmentation
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from Core import metrics
from skimage.measure import label
import gc

In [2]:
# Hold-out split on which every fold model below is evaluated.
holdout = '/media/hdd/Kaggle/Pneumothorax/Data/Folds/holdout.csv'
configs = [
    # all images
    PANetDilatedResNet34_768_Fold0(),
    PANetDilatedResNet34_768_Fold1(),
    PANetDilatedResNet34_768_Fold3(),
    PANetDilatedResNet34_768_Fold4(),
    PANetResNet50_768_Fold0(),
    PANetResNet50_768_Fold1(),
    PANetResNet50_768_Fold2(),
    PANetResNet50_768_Fold3(),
    EMANetResNet101_768_Fold7(),
    EMANetResNet101_768_Fold8(),
]

# Fail fast: check that every config names a checkpoint before any model is
# built, so a missing path is reported immediately rather than after the
# preceding models have already run inference.
for cfg in configs:
    assert cfg.model.weights is not None, 'Weights is None!!'

# Equal-weight ensemble: after pass j, `mean_pred` holds the mean of the raw
# predictions of configs[0..j].
for j, cfg in enumerate(configs):
    dataset = DatasetFactory(
        holdout,
        # cfg.train.csv_path,
        cfg)
    val_loader = dataset.yield_loader(is_test=True)
    net = cfg.model.architecture(pretrained=cfg.model.pretrained)

    trainer = Segmentation(net,
                           mode='test',
                           criterion=cfg.loss,
                           debug=False,
                           fold=cfg.fold)

    trainer.load_model(cfg.model.weights)
    # `mask` is rebound on every pass and the value from the last model is what
    # later cells use as ground truth — presumably identical for every model,
    # since they all read the same hold-out csv (TODO confirm).
    _, _, pred, mask = trainer.predict(val_loader, cfg.test.TTA, raw=True, tgt_size=768, pbar=True)
    if j == 0:
        mean_pred = pred
    else:
        # Incremental mean: re-weight the mean of the previous j models and
        # fold in the new prediction.
        mean_pred = (j * mean_pred + pred) / (j + 1)

    # Drop the per-model references *before* collecting; while these names are
    # still bound, gc.collect() cannot reclaim the objects behind them.
    del dataset, val_loader, net, trainer, pred
    gc.collect()

Fold 0 selected
/media/hdd/Kaggle/Pneumothorax/Data/Folds/fold0.csv
Fold 1 selected
/media/hdd/Kaggle/Pneumothorax/Data/Folds/fold1.csv
Fold 3 selected
/media/hdd/Kaggle/Pneumothorax/Data/Folds/fold3.csv
Fold 4 selected
/media/hdd/Kaggle/Pneumothorax/Data/Folds/fold4.csv
Fold 0 selected
/media/hdd/Kaggle/Pneumothorax/Data/Folds/fold0.csv
Fold 1 selected
/media/hdd/Kaggle/Pneumothorax/Data/Folds/fold1.csv
Fold 2 selected
/media/hdd/Kaggle/Pneumothorax/Data/Folds/fold2.csv
Fold 3 selected
/media/hdd/Kaggle/Pneumothorax/Data/Folds/fold3.csv
Fold 7 selected
/media/hdd/Kaggle/Pneumothorax/Data/Folds/fold7.csv
Fold 8 selected
/media/hdd/Kaggle/Pneumothorax/Data/Folds/fold8.csv
Dataset comprised of 1067 images
Model checkpoint loaded from: ./Saves/PANetDilatedResNet34/Fold0/18Jul_14:21_v2_768/2019-07-18 17:21_Fold0_Epoch16_reset0_val0.833


36it [01:54,  2.62s/it]                                       


Dataset comprised of 1067 images
Model checkpoint loaded from: ./Saves/PANetDilatedResNet34/Fold1/21Jul_22:18/2019-07-22 07:32_Fold1_Epoch49_reset0_val0.847


36it [01:59,  2.80s/it]                                       


Dataset comprised of 1067 images
Model checkpoint loaded from: ./Saves/PANetDilatedResNet34/Fold3/31Jul_00:02/2019-07-31 03:11_Fold3_Epoch16_reset0_val0.851


36it [01:45,  2.48s/it]                                       


Dataset comprised of 1067 images
Model checkpoint loaded from: ./Saves/PANetDilatedResNet34/Fold4/01Aug_08:49/2019-08-01 11:08_Fold4_Epoch11_reset0_val0.855


36it [01:48,  2.56s/it]                                       


Dataset comprised of 1067 images
Model checkpoint loaded from: ./Saves/PANetResNet50/Fold0/29Jul_08:45/2019-07-29 19:12_Fold0_Epoch39_reset0_val0.840


89it [02:01,  1.26s/it]                                       


Dataset comprised of 1067 images
Model checkpoint loaded from: ./Saves/PANetResNet50/Fold1/28Jul_23:16/2019-07-29 08:21_Fold1_Epoch34_reset0_val0.842


89it [02:01,  1.28s/it]                                       


Dataset comprised of 1067 images
Model checkpoint loaded from: ./Saves/PANetResNet50/Fold2/28Jul_21:21/2019-07-28 22:58_Fold2_Epoch6_reset0_val0.842


89it [01:58,  1.28s/it]                                       


Dataset comprised of 1067 images
Model checkpoint loaded from: ./Saves/PANetResNet50/Fold3/28Jul_01:57/2019-07-28 09:59_Fold3_Epoch30_reset0_val0.854


89it [02:02,  1.31s/it]                                       


Dataset comprised of 1067 images
Model checkpoint loaded from: ./Saves/EMANetResNet101_v2/Fold7/14Aug_23:07/2019-08-15 13:15_Fold7_Epoch39_reset0_dice0.831_ptx_dice0.402


134it [02:40,  1.04it/s]                             


Dataset comprised of 1067 images
Model checkpoint loaded from: ./Saves/EMANetResNet101_v2/Fold8/16Aug_00:00/2019-08-16 10:36_Fold8_Epoch29_reset0_dice0.836_ptx_dice0.479


134it [02:55,  1.06s/it]                             


## Get masks

In [None]:
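# Disabled: restricts `mean_pred` and `mask` to the images whose ground-truth
# mask is non-empty and records their indices in `ptx_ix`.
# ths = 0.2
# ptx_ix = np.where(mask.reshape(mask.shape[0], -1).sum(1) > 0)[0]
# mean_pred = mean_pred[ptx_ix]
# mask = mask[ptx_ix]
# print('{} ptx images'.format(len(ptx_ix)))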

In [3]:
def do_validation(prediction, ground_truth, ths=0.2, cls_ths=0.6, noise_th=75.0 * (384 / 128.0) ** 2, mix_noise=None):
    """Binarise predictions, optionally blank whole images, then print and return Dice.

    Neither input is modified; both are copied first. Up to three per-image
    rejection filters are applied in this order, each skipped when its
    threshold is ``None``:

    1. ``noise_th``  - images whose summed raw score is below ``noise_th`` are
       zeroed *before* thresholding.
    2. ``cls_ths``   - images with no pixel scoring above ``cls_ths`` get an
       empty segmentation.
    3. ``mix_noise`` - images whose thresholded region has a smoothed mean
       score ``sum(score in region) / (region pixel count + 1)`` below
       ``mix_noise`` get an empty segmentation.

    Args:
        prediction: numpy array of scores with images along axis 0. The
            summary print uses ``.sum(1).sum(1)``, so three axes are expected.
        ground_truth: reference masks, passed (as a copy) to the ``metrics``
            helpers.
        ths: pixel threshold producing the binary segmentation.
        cls_ths: per-image classification threshold, or ``None``.
        noise_th: per-image minimum summed score, or ``None``.
        mix_noise: per-image minimum smoothed mean score, or ``None``.

    Returns:
        The first value returned by ``metrics.cmp_dice`` (its mean is printed).

    Prints the fraction of images with a non-empty segmentation, then the mean
    Dice, mean Ptx_Dice and mean classification accuracy.
    """
    def per_image(arr):
        # Collapse every axis except the leading (image) axis.
        return arr.reshape(arr.shape[0], -1)

    scores = prediction.copy()
    target = ground_truth.copy()

    # Filter 1: blank images whose total score is too small.
    if noise_th is not None:
        too_faint = per_image(scores).sum(-1) < noise_th
        scores[too_faint, ...] = 0.0

    pred_seg = (scores > ths)

    # Filter 2: blank images without a single confidently positive pixel.
    if cls_ths is not None:
        confident = (scores > cls_ths)
        nothing_confident = per_image(confident).sum(1) == 0
        pred_seg[nothing_confident] = 0.0

    # Filter 3: blank images whose segmented region is, on average, low-scoring.
    if mix_noise is not None:
        kept_scores = scores * pred_seg
        smoothed_mean = per_image(kept_scores).sum(-1) / (per_image(pred_seg).sum(-1) + 1)
        pred_seg[smoothed_mean < mix_noise, ...] = 0.0

    # Fraction of images predicted as containing any positive pixel.
    print((pred_seg.sum(1).sum(1) > 0).mean())
    dice, ptx_dice = metrics.cmp_dice(pred_seg, target)
    acc = metrics.cmp_cls_acc(pred_seg, target).mean()
    # NOTE: '{:4f}' sets a minimum width of 4, not 4 decimals; values print
    # with the default 6 decimals, matching the outputs recorded below.
    print('Dice: {:4f}\tPtx_Dice: {:4f}\tAcc: {:4f}'.format(dice.mean(), ptx_dice.mean(), acc))
    return dice

In [11]:
# Noise-filter sweep: blank images whose summed score is below 500.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=500)

0.23617619493908154
Dice: 0.846011	Ptx_Dice: 0.527889	Acc: 0.915651


array([1., 1., 1., ..., 1., 0., 1.])

In [12]:
# Noise-filter sweep: blank images whose summed score is below 750.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=750)

0.22024367385192128
Dice: 0.851952	Ptx_Dice: 0.516801	Acc: 0.916589


array([1., 1., 1., ..., 1., 0., 1.])

In [13]:
# Noise-filter sweep: blank images whose summed score is below 900.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=900)

0.2099343955014058
Dice: 0.855791	Ptx_Dice: 0.508872	Acc: 0.917526


array([1., 1., 1., ..., 1., 0., 1.])

In [9]:
# Noise-filter sweep: blank images whose summed score is below 1000.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=1000)

0.20337394564198688
Dice: 0.857484	Ptx_Dice: 0.503896	Acc: 0.916589


array([1., 1., 1., ..., 1., 0., 1.])

In [14]:
# Noise-filter sweep: blank images whose summed score is below 1250.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=1250)

0.18837863167760074
Dice: 0.857908	Ptx_Dice: 0.484948	Acc: 0.910965


array([1., 1., 1., ..., 1., 0., 1.])

In [8]:
# Noise-filter sweep: blank images whose summed score is below 1500.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=1500)

0.1780693533270853
Dice: 0.856838	Ptx_Dice: 0.467693	Acc: 0.906279


array([1., 1., 1., ..., 1., 0., 1.])

In [6]:
# Noise-filter sweep: blank images whose summed score is below 2000.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=2000)

0.15932521087160262
Dice: 0.855373	Ptx_Dice: 0.436177	Acc: 0.898782


array([1., 1., 1., ..., 1., 0., 1.])

In [7]:
# Noise-filter sweep: blank images whose summed score is below 2500.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=2500)

0.13964386129334583
Dice: 0.854372	Ptx_Dice: 0.398396	Acc: 0.894096


array([1., 1., 1., ..., 1., 0., 1.])

In [22]:
# Classification-filter sweep: blank images with no pixel scoring above 0.3.
do_validation(mean_pred, mask, ths=0.2, cls_ths=0.3, noise_th=None)

0.21087160262417995
Dice: 0.854859	Ptx_Dice: 0.513059	Acc: 0.914714


array([1., 1., 1., ..., 1., 0., 1.])

In [23]:
# Classification-filter sweep: blank images with no pixel scoring above 0.35.
do_validation(mean_pred, mask, ths=0.2, cls_ths=0.35, noise_th=None)

0.19681349578256796
Dice: 0.859397	Ptx_Dice: 0.499901	Acc: 0.915651


array([1., 1., 1., ..., 1., 0., 1.])

In [16]:
# Classification-filter sweep: blank images with no pixel scoring above 0.4.
do_validation(mean_pred, mask, ths=0.2, cls_ths=0.4, noise_th=None)

0.1893158388003749
Dice: 0.860697	Ptx_Dice: 0.489018	Acc: 0.915651


array([1., 1., 1., ..., 1., 0., 1.])

In [17]:
# Classification-filter sweep: blank images with no pixel scoring above 0.45.
do_validation(mean_pred, mask, ths=0.2, cls_ths=0.45, noise_th=None)

0.17994376757263356
Dice: 0.859728	Ptx_Dice: 0.476373	Acc: 0.910028


array([1., 1., 1., ..., 1., 0., 1.])

In [18]:
# Classification-filter sweep: blank images with no pixel scoring above 0.5.
do_validation(mean_pred, mask, ths=0.2, cls_ths=0.5, noise_th=None)

0.1668228678537957
Dice: 0.858604	Ptx_Dice: 0.454708	Acc: 0.904405


array([1., 1., 1., ..., 1., 0., 1.])

In [20]:
# Classification-filter sweep: blank images with no pixel scoring above 0.55.
do_validation(mean_pred, mask, ths=0.2, cls_ths=0.55, noise_th=None)

0.15463917525773196
Dice: 0.856911	Ptx_Dice: 0.430517	Acc: 0.899719


array([1., 1., 1., ..., 1., 0., 1.])

In [21]:
# Classification-filter sweep: blank images with no pixel scoring above 0.6.
do_validation(mean_pred, mask, ths=0.2, cls_ths=0.6, noise_th=None)

0.14901593252108716
Dice: 0.854874	Ptx_Dice: 0.417295	Acc: 0.895970


array([1., 1., 1., ..., 1., 0., 1.])

In [24]:
# Mean-score filter sweep: blank images whose segmented region averages below 0.1.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=None, mix_noise=0.1)

0.2436738519212746
Dice: 0.842145	Ptx_Dice: 0.531536	Acc: 0.913777


array([1., 1., 1., ..., 1., 0., 1.])

In [25]:
# Mean-score filter sweep: blank images whose segmented region averages below 0.25.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=None, mix_noise=0.25)

0.20524835988753515
Dice: 0.855272	Ptx_Dice: 0.506565	Acc: 0.912840


array([1., 1., 1., ..., 1., 0., 1.])

In [26]:
# Mean-score filter sweep: blank images whose segmented region averages below 0.30.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=None, mix_noise=0.30)

0.18088097469540768
Dice: 0.859170	Ptx_Dice: 0.478060	Acc: 0.909091


array([1., 1., 1., ..., 1., 0., 1.])

In [34]:
# Mean-score filter sweep: blank images whose segmented region averages below 0.325.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=None, mix_noise=0.325)

0.17057169634489222
Dice: 0.860882	Ptx_Dice: 0.464839	Acc: 0.908154


array([1., 1., 1., ..., 1., 0., 1.])

In [27]:
# Mean-score filter sweep: blank images whose segmented region averages below 0.35.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=None, mix_noise=0.35)

0.15651358950328023
Dice: 0.858576	Ptx_Dice: 0.437921	Acc: 0.901593


array([1., 1., 1., ..., 1., 0., 1.])

In [28]:
# Mean-score filter sweep: blank images whose segmented region averages below 0.40.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=None, mix_noise=0.40)

0.13495782567947517
Dice: 0.854274	Ptx_Dice: 0.393791	Acc: 0.891284


array([1., 1., 1., ..., 1., 0., 1.])

In [29]:
# Mean-score filter sweep: blank images whose segmented region averages below 0.45.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=None, mix_noise=0.45)

0.11715089034676664
Dice: 0.845534	Ptx_Dice: 0.346603	Acc: 0.877226


array([1., 1., 1., ..., 1., 0., 1.])

In [30]:
# Mean-score filter sweep: blank images whose segmented region averages below 0.475.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=None, mix_noise=0.475)

0.10871602624179943
Dice: 0.842031	Ptx_Dice: 0.326862	Acc: 0.870665


array([1., 1., 1., ..., 1., 0., 1.])

In [31]:
# Mean-score filter sweep: blank images whose segmented region averages below 0.5.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=None, mix_noise=0.5)

0.09559512652296158
Dice: 0.835267	Ptx_Dice: 0.292625	Acc: 0.859419


array([1., 1., 1., ..., 1., 0., 1.])

In [32]:
# Mean-score filter sweep: blank images whose segmented region averages below 0.55.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=None, mix_noise=0.55)

0.07403936269915651
Dice: 0.826670	Ptx_Dice: 0.241903	Acc: 0.843486


array([1., 1., 1., ..., 1., 0., 1.])

In [33]:
# Mean-score filter sweep: blank images whose segmented region averages below 0.6.
do_validation(mean_pred, mask, ths=0.2, cls_ths=None, noise_th=None, mix_noise=0.6)

0.053420805998125584
Dice: 0.812740	Ptx_Dice: 0.179974	Acc: 0.822868


array([1., 1., 1., ..., 1., 0., 1.])

In [None]:
# Mean-score filter at a higher pixel threshold (0.25) and cut-off 0.75.
# Evaluated against `mask`: the `mask_vec` name this cell used before is not
# defined anywhere in the notebook and raised NameError on a fresh kernel.
do_validation(mean_pred, mask, 0.25, cls_ths=None, noise_th=None, mix_noise=0.75)

In [None]:
# Mean-score filter with cut-off 0.55 at pixel threshold 0.2.
# Evaluated against `mask`: the `mask_vec` name this cell used before is not
# defined anywhere in the notebook and raised NameError on a fresh kernel.
do_validation(mean_pred, mask, 0.2, cls_ths=None, noise_th=None, mix_noise=0.55)

In [None]:
def split_instances(pred, bin_ths=0.25, connectivity=2, min_size=10, out_value=1):
    """Keep only the connected components of each thresholded map that score high enough.

    Each image is binarised at ``bin_ths`` and split into connected components
    with ``label``. A component is kept when the mean of ``pred`` over its
    pixels is strictly greater than ``min_size``; the kept components are
    merged back into a single mask per image.

    Args:
        pred: array of shape (B, H, W) holding per-pixel scores.
        bin_ths: threshold used to binarise ``pred`` before labelling.
        connectivity: forwarded to ``label``.
        min_size: despite the name, this is compared with the component's
            *mean score*, not its pixel area. For scores that never exceed 1
            the default of 10 rejects every component, so callers normally
            pass a value below 1 (the plotting cell uses 0.4).
        out_value: accepted for interface compatibility; currently unused.

    Returns:
        Array of shape (B, H, W). Its dtype is bool when every image kept at
        least one component; if any image kept none, that image contributes a
        float zero array and ``np.stack`` promotes the whole result to float64.
    """
    B, H, W = pred.shape
    # Builtin `int` is exactly what the `np.int` alias referred to; the alias
    # was removed in NumPy 1.24 and now raises AttributeError.
    bin_mask = (pred > bin_ths).astype(int)
    out = []
    for i in range(B):
        labelled, num = label(bin_mask[i], connectivity=connectivity, return_num=True)
        kept = []
        # Label 0 is background; components are numbered 1..num.
        for j in range(1, num + 1):
            component = (labelled == j)
            mean_score = pred[i, component].sum() / component.sum()
            if mean_score > min_size:
                kept.append(component)
        # Union of the surviving components, or an empty map if none survived.
        out.append(np.max(kept, axis=0) if len(kept) > 0 else np.zeros((H, W)))
    return np.stack(out)

In [None]:
# Visual check on images that actually contain a ground-truth mask:
# left = ensemble score map, right = filtered prediction with the ground
# truth overlaid in red.
#
# `ptx_ix` is computed here so the cell runs on a fresh kernel. If `mask` was
# already restricted to positive images, this is simply every index in order.
ptx_ix = np.where(mask.reshape(mask.shape[0], -1).sum(1) > 0)[0]
n_rows = len(ptx_ix) // 2

if n_rows > 0:
    # squeeze=False keeps `axs` two-dimensional even when there is one row.
    fig, axs = plt.subplots(n_rows, 2, figsize=(16, 4 * len(ptx_ix) // 2), squeeze=False)
    # Index through `ptx_ix` so each row shows a positive image, not simply
    # the first rows of the full hold-out set.
    for row, ix in enumerate(ptx_ix[:n_rows]):
        axs[row, 0].imshow(mean_pred[ix].astype(np.float32), cmap=plt.cm.bone)
        axs[row, 1].imshow(split_instances(mean_pred[ix:ix + 1], min_size=0.4)[0])
        axs[row, 1].imshow(mask[ix].astype(np.float32), alpha=0.4, cmap="Reds")

    plt.show()
else:
    print('Not enough images with a non-empty ground-truth mask to plot')