Train a UNet segmentation model on the Severstal steel defect dataset.

# Initialize and load data

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import yaml
from datetime import datetime
import numpy as np
import tensorflow as tf
from steel_seg.utils import (
    dice_coeff_kaggle,
    rle_to_dense,
    dense_to_rle,
    visualize_segmentations,
    onehottify)
from steel_seg.dataset.severstal_steel_dataset import SeverstalSteelDataset
from steel_seg.model.unet import build_unet_model
from steel_seg.model.deep_q_postprocessor import build_deep_q_model
from steel_seg.train import (
    class_weighted_binary_crossentropy,
    weighted_binary_crossentropy,
    pixel_map_weighted_binary_crossentropy,
    dice_loss_multi_class,
    dice_coef,
    DiceCoefByClassAndEmptiness,
    eval)
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# TensorFlow runtime switches, exported through the process environment.
# (Original author's note: "Necessary for CUDA 10 or something?")
os.environ.update({
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    'TF_ENABLE_AUTO_MIXED_PRECISION': "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING": "1",
})

In [None]:
# Load the project settings; `cfg` is read further down for PATCH_SIZE,
# NUM_CLASSES and BATCH_SIZE.
# yaml.safe_load is used instead of yaml.load(f): calling yaml.load without an
# explicit Loader is deprecated (PyYAML >= 5.1 warns, >= 6.0 raises TypeError),
# and the full loader can construct arbitrary Python objects.
with open('SETTINGS.yaml') as f:
    cfg = yaml.safe_load(f)

In [None]:
# Build the dataset wrapper from the same settings file that `cfg` was read from.
dataset = SeverstalSteelDataset.init_from_config('SETTINGS.yaml')

In [None]:
# Create the training and validation inputs with use_patches=True.
# Each call returns a pair: the first element is fed to fit() as the data, the
# second is passed as steps_per_epoch / validation_steps in the training cell.
train_data, train_batches = dataset.create_dataset(dataset_type='training', use_patches=True)
val_data, val_batches = dataset.create_dataset(dataset_type='validation', use_patches=True)

# Build model

In [None]:
# Hyper-parameters of the patch-based UNet, collected in one mapping so they
# can be inspected before the model is built.
unet_kwargs = dict(
    img_height=cfg['PATCH_SIZE'],
    img_width=cfg['PATCH_SIZE'],
    img_channels=1,
    num_classes=cfg['NUM_CLASSES'],
    num_layers=4,
    activation=tf.keras.activations.elu,
    kernel_initializer='he_normal',
    kernel_size=(3, 3),
    pool_size=(2, 2),
    num_features=[32, 64, 128, 256],
    drop_prob=0.5,
)
seg_model = build_unet_model(**unet_kwargs)

# Prefix used for this model's checkpoint and log directory names.
model_checkpoint_name = 'patches'

In [None]:
# train_imgs = dataset.get_image_list('training')

# cls_pixel_counts = np.array([0, 0, 0, 0], dtype=np.float64)
# total_pixels = 0.0
# for img_name in train_imgs:
#     img, ann = dataset.get_example_from_img_name(img_name)
#     img_pixel_counts = np.sum(ann, axis=(0, 1))
#     cls_pixel_counts += img_pixel_counts
#     total_pixels += ann.size
# cls_pixel_counts

In [None]:
# cls_weights = total_pixels / cls_pixel_counts
# cls_weights

In [None]:
#cls_weights = [5274.42494602, 26340.24368548, 158.51888478, 752.96782738]
#cls_weights = np.array([5274.42494602, 26340.24368548, 158.51888478, 752.96782738]) / 20
#cls_weights = [263.7212473, 1317.01218427, 7.92594424, 37.64839137]

In [None]:
# Per-class weights handed to the loss when the model is compiled.
# The commented-out lines record earlier attempts (weights derived from the
# per-class pixel counts, then hand-tuned values); only the last assignment is live.
# cls_2_weight = 10.0
# cls_0_weight = (cls_pixel_counts[2] * cls_2_weight) / cls_pixel_counts[0]
# cls_1_weight = (cls_pixel_counts[2] * cls_2_weight) / cls_pixel_counts[1]
# cls_3_weight = (cls_pixel_counts[2] * cls_2_weight) / cls_pixel_counts[3]
# cls_weights = [cls_0_weight, cls_1_weight, cls_2_weight, cls_3_weight]
#cls_weights = [332.7316460371491, 1661.6470474376874, 10.0, 47.50019711767978]
#cls_weights = [30.0, 40.0, 10.0, 20.0]
cls_weights = [40.0, 50.0, 5.0, 20.0]

In [None]:
# Print the layer-by-layer summary of the model built above.
seg_model.summary()

In [None]:
def _per_class_dice_metrics(empty_masks_only, name_prefix, num_classes=4):
    """Build one DiceCoefByClassAndEmptiness metric per class id 0..num_classes-1."""
    return [
        DiceCoefByClassAndEmptiness(
            cls_id=class_id,
            empty_masks_only=empty_masks_only,
            batch_size=cfg['BATCH_SIZE'],
            name=f'{name_prefix}{class_id}')
        for class_id in range(num_classes)
    ]


# Metrics are reported overall and then split per class, separately for
# empty-mask and non-empty-mask examples.
seg_model.compile(
    optimizer=tf.train.AdamOptimizer(0.0001),
    loss=pixel_map_weighted_binary_crossentropy(cls_weights),
    metrics=[
        tf.keras.metrics.BinaryAccuracy(),
        dice_coef(batch_size=cfg['BATCH_SIZE']),
        *_per_class_dice_metrics(True, 'dice_coef_on_empty_cls'),
        *_per_class_dice_metrics(False, 'dice_coef_on_non_empty_cls'),
    ],
)

# Train

In [None]:
!ls checkpoints

In [None]:
# Timestamp identifying this training run; it becomes part of the checkpoint
# and log directory names in the next cells.
# To resume an earlier run, assign that run's timestamp instead (example kept below).
#date_str = '20190916-092052'
date_str = datetime.now().strftime("%Y%m%d-%H%M%S")

In [None]:
# Each run writes to checkpoints/<name>_<timestamp>/; the epoch number is
# filled into the file name by the checkpoint callback.
checkpoint_name = f'{model_checkpoint_name}_{date_str}'
checkpoint_path = f'checkpoints/{checkpoint_name}/cp-{checkpoint_name}' + '-{epoch:04d}.ckpt'
checkpoint_dir = os.path.dirname(checkpoint_path)

# Resume from the newest checkpoint of this run, if there is one.
initial_epoch = 0
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is not None:
    print(f'Loading weights from {latest_checkpoint}')
    # The path ends in "-<epoch>.ckpt": take what follows the last dash and
    # drop the extension to recover the epoch number.
    last_epoch = latest_checkpoint.rsplit('-', 1)[-1]
    last_epoch = last_epoch.partition('.')[0]
    initial_epoch = int(last_epoch)
    seg_model.load_weights(latest_checkpoint)
else:
    print('No checkpoints found. Starting from scratch.')


In [None]:
# Checkpoint callback: keep only weights that improve the monitored metric.
# NOTE(review): the double underscore in 'val__dice_coef' presumably matches the
# name of the metric returned by dice_coef() -- confirm against the fit() logs.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='val__dice_coef',  # alternative: 'val_loss'
    save_best_only=True,
    mode='max',  # alternative: 'auto'
    save_weights_only=True,
    verbose=1,
)

# One TensorBoard directory per (run, starting epoch) so resumed runs do not
# overwrite earlier logs.
logdir = f'logs/{checkpoint_name}-{initial_epoch}'
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=logdir)
callbacks = [tensorboard_cb, checkpoint_cb]

# Validation runs every third epoch (validation_freq=3).
results = seg_model.fit(
    train_data,
    epochs=400,
    verbose=2,
    callbacks=callbacks,
    validation_data=val_data,
    steps_per_epoch=train_batches,
    validation_steps=val_batches,
    validation_freq=3,
    initial_epoch=initial_epoch,
)

# Evaluate

In [None]:
# Names of the validation images; they are iterated one by one in the
# evaluation cells below. The bare len(...) displays how many there are.
val_imgs = dataset.get_image_list('validation')
len(val_imgs)

In [None]:

# TODO: Move these functions out to .py files

# postprocess(y, thresh=0.8, upper_thresh=0.8, num_px_thresh=10000): 0.8795549804140013
# postprocess(y, thresh=0.7, upper_thresh=0.7, num_px_thresh=10000): 0.8685514979767833
# postprocess(y, thresh=0.8, upper_thresh=0.8, num_px_thresh=20000): 0.8597784574292987
# postprocess(y, thresh=0.8, upper_thresh=0.8, num_px_thresh=5000):  0.888272717085689
# postprocess(y, thresh=0.85, upper_thresh=0.85, num_px_thresh=5000): 0.8933982786371334
# postprocess(y, thresh=0.85, upper_thresh=0.85, num_px_thresh=2000): 0.8862070712564566
# postprocess(y, thresh=0.9, upper_thresh=0.9, num_px_thresh=2000):  0.8899766923286324

def postprocess(y, thresh=None, upper_thresh=None, num_px_thresh=None):
    """Convert raw per-class scores into a binary mask with at most one class per pixel.

    Args:
        y: array of shape (1, height, width, classes); only a batch of one is supported.
        thresh: score a pixel must reach in its arg-max class to stay in the mask.
            Either one value applied to every class or a per-class sequence.
            Defaults to 0.85 for every class.
        upper_thresh: score used when counting "confident" pixels of a class.
            Scalar or per-class sequence; defaults to 0.85 for every class.
        num_px_thresh: minimum number of pixels above ``upper_thresh`` a class needs;
            below it the whole class mask is cleared. Scalar or per-class sequence;
            defaults to 5000 for every class.

    Returns:
        The integer array produced by ``onehottify`` for the arg-max classes of ``y``,
        with the rejected pixels / classes set to 0.

    Raises:
        AssertionError: if the batch dimension of ``y`` is not 1.
    """
    # TODO: handle batches properly
    batches, _, _, classes = y.shape
    assert batches == 1

    def _per_class(value, default):
        # Accept None (use the default), a single scalar (shared by all classes,
        # as in the tuning notes above this function) or a per-class sequence.
        if value is None:
            return [default] * classes
        if np.ndim(value) == 0:
            return [value] * classes
        return value

    thresh = _per_class(thresh, 0.85)
    upper_thresh = _per_class(upper_thresh, 0.85)
    num_px_thresh = _per_class(num_px_thresh, 5000)

    # Only allow one class at each pixel
    y_argmax = np.argmax(y, axis=-1)
    y_one_hot = onehottify(y_argmax, y.shape[-1], dtype=int)
    for c in range(classes):
        # Pixels whose winning score is below the class threshold become background.
        y_one_hot[:, :, :, c][y[:, :, :, c] < thresh[c]] = 0

    # Making predictions on an empty mask is very costly (score immediately goes from 1 to 0)
    # So, only predict a mask if there are many pixels (num_px_thresh) above a high threshold (upper_thresh)
    for c in range(classes):
        pixels_above_upper = np.sum(y[:, :, :, c] > upper_thresh[c])
        if pixels_above_upper < num_px_thresh[c]:
            y_one_hot[:, :, :, c] = 0
    return y_one_hot
    
#     # Segmentations seem to be drawn with a tool that has a minimum width of ~20 pixels
#     min_width_div_2 = 10
#     for c in range(classes):
#         for row in range(height):
#             mask_start = None
#             for col in range(width):
#                 if y_one_hot[0, row, col, c]: # Pixel is part of mask
#                     if mask_start is None: # First pixel in mask from left to right
#                         mask_start = col
#                 else: # Pixel is not part of mask
#                     if mask_start is not None: # End of mask from left to right
#                         mask_mid = int((col - mask_start) / 2) + col
#                         new_start = max(0, mask_mid - min_width_div_2)
#                         new_end = min(y_one_hot.shape[2], mask_mid + min_width_div_2)
#                         y_one_hot[0, row, new_start:new_end, c] = 1
#                     mask_start = None
#             # I'm lazy, so don't bother dealing with masks that go off the right edge

def eval(model, dataset, img_list, num_classes=4, thresh=None, upper_thresh=None, num_px_thresh=None, verbose=False):
    """Score post-processed predictions for a list of images.

    NOTE: this name shadows both the built-in ``eval`` and the ``eval`` imported from
    ``steel_seg.train`` at the top of the notebook; it is kept because later cells call it.

    Args:
        model: object exposing ``predict(batch)``; called with one image per batch.
        dataset: object exposing ``get_example_from_img_name(name)`` -> ``(img, ann)``.
        img_list: names of the images to evaluate.
        num_classes: number of channels inspected in ``ann`` and in the prediction.
        thresh, upper_thresh, num_px_thresh: forwarded unchanged to ``postprocess``.
        verbose: when true, print per-class empty/non-empty mask statistics.

    Returns:
        ``(mean_dice_coeff, dice_coeffs)``: the mean of the per-image scores and the
        list of per-image scores in ``img_list`` order.
    """
    def _ratio(count, total):
        # A class can have no empty (or no non-empty) ground-truth masks in
        # img_list; report NaN instead of raising ZeroDivisionError.
        return count / total if total else float('nan')

    num_empty_gt = [0] * num_classes
    num_empty_gt_mask_pred = [0] * num_classes
    num_mask_gt_empty_pred = [0] * num_classes
    mask_sizes = [[] for _ in range(num_classes)]

    dice_coeffs = []
    for img_name in img_list:
        img, ann = dataset.get_example_from_img_name(img_name)
        img_batch = np.expand_dims(img, axis=0)
        y = model.predict(img_batch)
        y_one_hot = postprocess(y, thresh=thresh, upper_thresh=upper_thresh, num_px_thresh=num_px_thresh)
        dice_coeffs.append(dice_coeff_kaggle(y_one_hot[0, :, :, :], ann))

        for c in range(num_classes):
            gt_mask_size = np.count_nonzero(ann[:, :, c])
            gt_is_empty = gt_mask_size == 0
            pred_is_empty = np.count_nonzero(y_one_hot[0, :, :, c]) == 0

            if gt_is_empty:
                num_empty_gt[c] += 1
                if not pred_is_empty:
                    # Predicted a mask where the ground truth has none.
                    num_empty_gt_mask_pred[c] += 1
            else:
                mask_sizes[c].append(gt_mask_size)
                if pred_is_empty:
                    # Missed a mask that exists in the ground truth.
                    num_mask_gt_empty_pred[c] += 1

    if verbose:
        for c in range(num_classes):
            num_non_empty_gt = len(img_list) - num_empty_gt[c]
            # Avoid numpy's empty-slice warning when the class never occurs.
            if mask_sizes[c]:
                mean_size, std_size = np.mean(mask_sizes[c]), np.std(mask_sizes[c])
            else:
                mean_size = std_size = float('nan')
            print(f'**** Class {c} ****')
            print(f'Num empty gt masks: {num_empty_gt[c]}')
            print(f'Num non-empty gt masks: {num_non_empty_gt}')
            print(f'Num empty gt mask and non-empty pred: {num_empty_gt_mask_pred[c]} '
                  f'({_ratio(num_empty_gt_mask_pred[c], num_empty_gt[c])})')
            print(f'Num non-empty gt mask and empty pred: {num_mask_gt_empty_pred[c]} '
                  f'({_ratio(num_mask_gt_empty_pred[c], num_non_empty_gt)})')
            print(f'Mean mask size: {mean_size} (stddev: {std_size})')

    mean_dice_coeff = np.mean(dice_coeffs)
    print(f'Mean dice coeff: {mean_dice_coeff}')
    return mean_dice_coeff, dice_coeffs

In [None]:
# Baseline evaluation on the validation images: the same setting for all four classes.
thresh = [0.5] * 4
upper_thresh = [0.85] * 4
num_px_thresh = [5000] * 4

mean_dice_coeff, dice_coeffs = eval(
    seg_model, dataset, val_imgs,
    thresh=thresh, upper_thresh=upper_thresh, num_px_thresh=num_px_thresh)

In [None]:
# Perform grid search for each class:
import threading
from itertools import product
from collections import defaultdict


# Candidate values tried for every class.
thresh_grid = [0.8, 0.9, 0.95]
upper_thresh_grid = [0.6, 0.7, 0.8, 0.9, 0.95]
num_px_thresh_grid = [2000, 5000, 10000, 20000]

# Number of (thresh, upper_thresh, num_px_thresh) combinations.
grid_size = len(thresh_grid) * len(upper_thresh_grid) * len(num_px_thresh_grid)

# grid_search_dice_scores[class][combination] collects one score per image;
# every slot is its own list so appends never alias each other.
grid_search_dice_scores = [[[] for _ in range(grid_size)] for _ in range(4)]

def postprocess_worker(thresh, upper_thresh, num_px_thresh, y, ann, grid_search_dice_scores, index):
    """Score one hyper-parameter combination on one image, class by class.

    Post-processes ``y`` with the given per-class settings, then appends the dice
    score of every class channel to ``grid_search_dice_scores[class][index]``.
    The result is delivered only through that in-place append; nothing is returned.
    """
    masks = postprocess(
        y, thresh=thresh, upper_thresh=upper_thresh, num_px_thresh=num_px_thresh)
    for class_id in range(masks.shape[-1]):
        # Slice (rather than index) the channel so the trailing class axis is kept.
        channel = slice(class_id, class_id + 1)
        score = dice_coeff_kaggle(masks[0, :, :, channel], ann[:, :, channel])
        grid_search_dice_scores[class_id][index].append(score)

for i, img_name in enumerate(val_imgs):
    if i % 10 == 0:
        print(f'Processing image {i + 1} / {len(val_imgs)}')

    # One forward pass per image; every combination re-uses the same prediction.
    img, ann = dataset.get_example_from_img_name(img_name)
    img_batch = np.expand_dims(img, axis=0)
    y = seg_model.predict(img_batch)

    threads = []
    index = 0
    for thresh, upper_thresh, num_px_thresh in product(thresh_grid, upper_thresh_grid, num_px_thresh_grid):
        # The same value is applied to all four classes; the worker records the
        # score of each class separately.
        thresh_array = [thresh] * 4
        upper_thresh_array = [upper_thresh] * 4
        num_pc_thresh_array = [num_px_thresh] * 4

        t = threading.Thread(
            target=postprocess_worker,
            args=(thresh_array, upper_thresh_array, num_pc_thresh_array, y, ann, grid_search_dice_scores, index),
        )
        t.start()
        threads.append(t)
        index += 1

    # Wait for every combination before moving on to the next image.
    for t in threads:
        t.join()

# Per class: print the mean score of every combination, then the best one.
for cls in range(4):
    print('********************')
    print(f'Results for class {cls}:')
    print('********************')
    cls_scores = grid_search_dice_scores[cls]
    max_score = 0
    max_score_str = ''
    index = 0
    # Same iteration order as the search loop, so `index` addresses the matching slot.
    for thresh, upper_thresh, num_px_thresh in product(thresh_grid, upper_thresh_grid, num_px_thresh_grid):
        total_score = np.mean(cls_scores[index])
        desc_str = f'thresh={thresh}, upper_thresh={upper_thresh}, num_px_thresh={num_px_thresh}:\t{total_score}'
        print(desc_str)
        if total_score > max_score:
            max_score, max_score_str = total_score, desc_str
        index += 1
    print('Best hyperparam combination:')
    print(max_score_str + '\n\n')
    

In [None]:
Class 0: thresh=0.9, upper_thresh=0.7, num_px_thresh=5000:	0.9281188718499144
Class 1: thresh=0.9, upper_thresh=0.7, num_px_thresh=2000:	0.9878215327643398
Class 2: thresh=0.8, upper_thresh=0.9, num_px_thresh=2000:	0.8120156858142843
Class 3: thresh=0.8, upper_thresh=0.9, num_px_thresh=5000:	0.9748260642241272

In [None]:
Class 0: thresh=0.95, upper_thresh=0.7, num_px_thresh=5000:	0.9359843396575916
Class 1: thresh=0.9, upper_thresh=0.7, num_px_thresh=2000:	0.9878215327643398
Class 2: thresh=0.9, upper_thresh=0.95, num_px_thresh=2000:	0.8149881087479656
Class 3: thresh=0.8, upper_thresh=0.95, num_px_thresh=2000:	0.9749012164048724

In [None]:
Mean dice coeff: 0.9140845065637913

In [None]:
import skopt
# Search space: for each of the four classes a mask threshold, an upper
# ("confident pixel") threshold and a minimum confident-pixel count.
# The order is thresholds, then upper thresholds, then pixel counts.
HP_SPACE = [
    *(skopt.space.Real(0.0, 1.0, name=dim_name)
      for dim_name in ('cls1_thresh', 'cls2_thresh', 'cls3_thresh', 'cls4_thresh')),
    *(skopt.space.Real(0.0, 1.0, name=dim_name)
      for dim_name in ('cls1_upper_thresh', 'cls2_upper_thresh', 'cls3_upper_thresh', 'cls4_upper_thresh')),
    *(skopt.space.Integer(0, 50000, name=dim_name)
      for dim_name in ('cls1_px_thresh', 'cls2_px_thresh', 'cls3_px_thresh', 'cls4_px_thresh')),
]

def build_eval_postprocess_hyperparams(model, dataset, val_imgs):
    """Return the objective function handed to the skopt minimizer.

    The returned callable receives one point of ``HP_SPACE`` (unpacked into named
    keyword arguments by ``use_named_args``), evaluates ``model`` on ``val_imgs``
    with those post-processing settings and returns ``1 - mean dice`` so that
    minimizing it maximizes the score.
    """
    @skopt.utils.use_named_args(HP_SPACE)
    def eval_postprocess_hyperparams(**kwargs):
        def pick(*dim_names):
            # Gather the named dimensions into a per-class list, in class order.
            return [kwargs[dim_name] for dim_name in dim_names]

        score, _ = eval(
            model, dataset, val_imgs,
            thresh=pick('cls1_thresh', 'cls2_thresh', 'cls3_thresh', 'cls4_thresh'),
            upper_thresh=pick('cls1_upper_thresh', 'cls2_upper_thresh', 'cls3_upper_thresh', 'cls4_upper_thresh'),
            num_px_thresh=pick('cls1_px_thresh', 'cls2_px_thresh', 'cls3_px_thresh', 'cls4_px_thresh'))
        return 1 - score

    return eval_postprocess_hyperparams

# Search HP_SPACE for the post-processing settings with the lowest (1 - mean dice).
# Every one of the n_calls evaluations runs the model over all of val_imgs,
# so this cell is slow. random_state is fixed so the search is repeatable.
res = skopt.gbrt_minimize(
    build_eval_postprocess_hyperparams(seg_model, dataset, val_imgs),
    HP_SPACE,
    n_calls=120,
    n_random_starts=10,
    random_state=123,
    verbose=True,
    n_jobs=6)
# Alternative optimizer kept for reference (Gaussian-process based).
# NOTE(review): it passes `model`, which is not defined in this notebook --
# use `seg_model` if this is re-enabled.
# res = skopt.gp_minimize(
#     build_eval_postprocess_hyperparams(model, dataset, val_imgs), # the function to minimize
#     HP_SPACE,                                                     # the bounds on each dimension of x
#     acq_func="EI",                                                # the acquisition function
#     n_calls=120,                                                  # the number of evaluations of f
#     n_random_starts=20,                                           # the number of random initialization points
#     random_state=123)                                             # the random seed

#skopt.dump(res, 'skopt_result.pkl')

In [None]:
# Display the best point found by the search; values are in HP_SPACE order.
res.x

In [None]:
# Per-class post-processing settings recorded from a hyper-parameter search
# (presumably copied from `res.x` above -- confirm before reuse).
# The first name used to be misspelled `thres`; it is `thresh` so that all three
# lists match the keyword arguments of eval() / postprocess().
thresh = [0.989482241473342, 0.4681867574192067, 0.8408261195643344, 0.7391752127314548]
upper_thresh = [0.02055732831801516, 0.8815747640346393, 0.8895357807949745, 0.84233544712227]
num_px_thresh = [17148, 29846, 5021, 34434]

In [None]:
# Persist the search result so it can be inspected later without re-running the search.
skopt.dump(res, 'skopt_result.pkl')

In [None]:
# argsort is ascending, so these are the indices (into dice_coeffs / val_imgs order)
# of the 50 lowest-scoring images, worst first.
indices = np.argsort(dice_coeffs)[:50]
for i in indices:
    print(f'{i}: {dice_coeffs[i]}')

In [None]:
# Visualize Image Prediction
img_id = 1018
thresh = 0.85

img_name = val_imgs[img_id]
img, ann = dataset.get_example_from_img_name(img_name)
img_batch = np.expand_dims(img, axis=0)
y = seg_model.predict(img_batch)


def _show(image):
    """Draw one image in its own wide 10x3 figure."""
    plt.figure(figsize=(10, 3))
    plt.imshow(image)
    plt.show()


# Input image (repeated to three channels) with the ground-truth annotation.
_show(visualize_segmentations(np.repeat(img, 3, axis=-1), ann))

# Raw score map of each of the four classes.
for channel in range(4):
    _show(y[0, :, :, channel])

# The same maps binarized at `thresh`.
for channel in range(4):
    _show(y[0, :, :, channel] > thresh)

# Save HDF5 Model

In [None]:
# Save the whole model to a timestamped HDF5 file in the working directory.
# NOTE: this rebinds `date_str`, which the checkpoint cell also reads; re-running
# that cell afterwards would point at a new, empty checkpoint directory.
date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
seg_model.save(f'seg_model_{date_str}.h5')

In [None]:
!ls