In [1]:
import numpy as np

from keras_unet_models import wuunet, unet
from keras_utils import jaccard_acc, bin_jaccard_acc, bce_jaccard_loss, bce_bin_jaccard_loss
from keras_dataset import Dataset, augm_train, augm_val
from keras_validation import Validation

from tensorflow.python.client import device_lib

import json

print(device_lib.list_local_devices())

[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 6614445994182473656
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 7416424704
locality {
  bus_id: 1
  links {
  }
}
incarnation: 7161694883038686479
physical_device_desc: "device: 0, name: GeForce RTX 2070, pci bus id: 0000:1c:00.0, compute capability: 7.5"
]


In [2]:
%matplotlib inline
np.set_printoptions(precision=49)

In [3]:
import os

def get_snapshots(ml_model_name):
    all_snapshots = os.listdir('output/keras/' + ml_model_name + '/snapshots')
    snapshots_map = {
        int(snapshot[snapshot.rfind('_') + 1 : snapshot.rfind('.')]) : snapshot 
        for snapshot in all_snapshots
    }
    return snapshots_map.values()

In [4]:
def evaluate():
    test_model_metrics = {}
    for dataset_variant in DATASET_VARIANTS:
        for ml_model_name in ML_MODELS:
            print('================================================')
            print(ml_model_name)
            jb_mean_max, jb_mean_max_thr, jb_mean_max_model = 0.0, 0.0, ''
            jb_std_min,  jb_std_min_thr,  jb_std_min_model  = 1.0, 0.0, ''
            jm_mean_max, jm_mean_max_thr, jm_mean_max_model = 0.0, 0.0, ''
            jm_std_min,  jm_std_min_thr,  jm_std_min_model  = 1.0, 0.0, ''
            validation = Validation(dataset_variant)
            snapshots = get_snapshots(ml_model_name)
            for i, snapshot_variant in enumerate(snapshots):
                print('-----------------------------------------------')
                print('{}/{}'.format(i, len(snapshots)))
                print(ml_model_name + ": " + snapshot_variant)
                model_full_name = ml_model_name + '_' + snapshot_variant
                thr, jm_mean, jm_std, jb_mean, jb_std = validation.calculate_thresholds(
                    ml_model_name, snapshot_variant
                )
                test_model_metrics[model_full_name] = (thr, jm_mean, jm_std, jb_mean, jb_std)
                if (jb_mean_max < jb_mean):
                    jb_mean_max = jb_mean
                    jb_mean_max_thr = thr
                    jb_mean_max_model = model_full_name
                if (jb_std_min > jb_std):
                    jb_std_min = jb_std
                    jb_std_min_thr = thr
                    jb_std_min_model = model_full_name
                if (jm_mean_max < jm_mean):
                    jm_mean_max = jm_mean
                    jm_mean_max_thr = thr
                    jm_mean_max_model = model_full_name
                if (jm_std_min > jm_std):
                    jm_std_min = jm_std
                    jm_std_min_thr = thr
                    jm_std_min_model = model_full_name
            #validation.draw_barchart(test_model_metrics)
            print('================================================')
            print('SUMMARY:')
            print('Jaccard bin  mean max: name = {}, value = {}, thr = {}'.format(jb_mean_max_model, jb_mean_max, jb_mean_max_thr))
            print('Jaccard bin  std min:  name = {}, value = {}, thr = {}'.format(jb_std_min_model,  jb_std_min, jb_std_min_thr))
            print('Jaccard mult mean max: name = {}, value = {}, thr = {}'.format(jm_mean_max_model, jm_mean_max, jm_mean_max_thr))
            print('Jaccard mult std min:  name = {}, value = {}, thr = {}'.format(jm_std_min_model,  jm_std_min, jm_std_min_thr))
            print('================================================')

        test_model_metrics_json = {
            key: [str(test_model_metrics[key][0]), str(test_model_metrics[key][1]), str(test_model_metrics[key][2]), 
                  str(test_model_metrics[key][3]), str(test_model_metrics[key][4])]
            for key in test_model_metrics
        }

        json.dump(test_model_metrics_json, 
                  open('output/evaluation/ow_224/{}_validation_metrics.json'.format(dataset_variant), 'w'))

In [5]:
ML_MODELS = [
    'unet_224_ow',
    'wuunet_224_ow',
    'wuunet_light_224_ow'
]


INPUT_SIZE = (224, 224)

DATASET_VARIANTS = ['test']  #['train_test']

COMBINATION_TYPES = ['nonint']
print(COMBINATION_TYPES)

evaluate()

['nonint']
unet_224_ow
-----------------------------------------------
0/54
unet_224_ow: best_val_loss_51.hdf5
[unet_224_ow_best_val_loss_51.hdf5, threshold = 0.27]:
MULTICLASS JACCARD [val = 0.8227515816688538, std = 0.10845281928777695]
BINARY JACCARD [val = 0.8654728531837463, std = 0.09086626023054123]

-----------------------------------------------
1/54
unet_224_ow: best_val_acc_60.hdf5
[unet_224_ow_best_val_acc_60.hdf5, threshold = 0.22]:
MULTICLASS JACCARD [val = 0.8208708763122559, std = 0.11016514897346497]
BINARY JACCARD [val = 0.8616129159927368, std = 0.09434057772159576]

-----------------------------------------------
2/54
unet_224_ow: best_val_acc_11.hdf5
[unet_224_ow_best_val_acc_11.hdf5, threshold = 0.96]:
MULTICLASS JACCARD [val = 0.6780094504356384, std = 0.1659507155418396]
BINARY JACCARD [val = 0.7059192061424255, std = 0.15229350328445435]

-----------------------------------------------
3/54
unet_224_ow: best_val_acc_343.hdf5
[unet_224_ow_best_val_acc_343.hdf5, 

[wuunet_light_224_ow_best_val_mult_acc_14.hdf5, threshold = 0.37]:
MULTICLASS JACCARD [val = 0.7010120749473572, std = 0.15753582119941711]
BINARY JACCARD [val = 0.7743555903434753, std = 0.13074597716331482]

-----------------------------------------------
6/85
wuunet_light_224_ow: best_val_mult_acc_13.hdf5
[wuunet_light_224_ow_best_val_mult_acc_13.hdf5, threshold = 0.34]:
MULTICLASS JACCARD [val = 0.69073086977005, std = 0.17395810782909393]
BINARY JACCARD [val = 0.7620433568954468, std = 0.14487461745738983]

-----------------------------------------------
7/85
wuunet_light_224_ow: best_val_bin_acc_279.hdf5
[wuunet_light_224_ow_best_val_bin_acc_279.hdf5, threshold = 0.6900000000000001]:
MULTICLASS JACCARD [val = 0.8504899740219116, std = 0.08713644742965698]
BINARY JACCARD [val = 0.8894743919372559, std = 0.06899359822273254]

-----------------------------------------------
8/85
wuunet_light_224_ow: best_val_mult_acc_373.hdf5
[wuunet_light_224_ow_best_val_mult_acc_373.hdf5, threshol