The Dice Coefficient metric used for the Severstal challenge heavily penalizes a single predicted mask pixel when the ground truth is an empty mask:
 - If the ground truth mask is empty and the predicted mask is empty, dice_score = 1.0
 - If the ground truth mask is empty and the predicted mask has a single mask pixel, dice_score = 0.0

This notebook trains a post-processing model that, given the prediction made by the segmentation model, estimates the Dice score that would be achieved under two conditions: keeping the predicted mask (`mask`) or replacing it with an empty mask (`no_mask`).

# Initialize and load segmentation data

In [None]:
# TODO:
# Make a submission using the heuristic-based post-processing
# Force the minimum width of predicted masks
# Only save best-so-far checkpoints
# Split the datasets properly!!!!!
# seg_train, post_train, val
# Try training to predict the better action? Maybe weight it based on the difference in score?
#    - Might work better than training a regression model?
# Try reducing the lr
# Try training on thresholded input again?
# Try q model with different thresholds

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import yaml
from datetime import datetime
import numpy as np
import tensorflow as tf
from steel_seg.utils import dice_coeff_kaggle, rle_to_dense, dense_to_rle, visualize_segmentations
from steel_seg.dataset.severstal_steel_dataset import SeverstalSteelDataset, SeverstalSteelPostprocessDataset
from steel_seg.model.unet import build_unet_model
from steel_seg.model.deep_q_postprocessor import build_deep_q_model
from steel_seg.train import class_weighted_binary_crossentropy, weighted_binary_crossentropy, dice_coef, eval, empty_mask_loss, empty_mask_accuracy
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# TensorFlow-related environment variables (GPU memory growth and automatic
# mixed precision, going by their names). Original note: "Necessary for CUDA 10
# or something?" -- TODO confirm which of these are actually required.
os.environ.update({
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    "TF_ENABLE_AUTO_MIXED_PRECISION": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING": "1",
})

In [None]:
# Parse the project settings into `cfg`.
# safe_load is used instead of yaml.load: calling yaml.load without an explicit
# Loader can construct arbitrary Python objects and is rejected by PyYAML >= 6.
with open('SETTINGS.yaml') as f:
    cfg = yaml.safe_load(f)

In [None]:
# Build the segmentation dataset wrapper from the same settings file that was
# parsed into `cfg`.
seg_dataset = SeverstalSteelDataset.init_from_config('SETTINGS.yaml')

In [None]:
# Training / validation splits of the segmentation dataset. Each call returns a
# (data, batches) pair; `batches` is presumably the number of batches per epoch
# -- TODO confirm against SeverstalSteelDataset.create_dataset.
# NOTE(review): these four names are rebound to the post-processing datasets in
# the "Train" section further down.
train_data, train_batches = seg_dataset.create_dataset(dataset_type='training')
val_data, val_batches = seg_dataset.create_dataset(dataset_type='validation')

# Load segmentation model from checkpoint

In [None]:
# Copy these hyper-parameters from `train_segmentation_model.ipynb`.
# Presumably they must match the architecture the checkpoint was trained with
# -- TODO confirm.
_unet_kwargs = {
    'img_height': cfg['IMG_HEIGHT'],
    'img_width': cfg['IMG_WIDTH'],
    'img_channels': 1,
    'num_classes': cfg['NUM_CLASSES'],
    'num_layers': 4,
    'activation': tf.keras.activations.elu,
    'kernel_initializer': 'he_normal',
    'kernel_size': (3, 3),
    'pool_size': (2, 4),
    'num_features': [8, 16, 32, 64],
    'drop_prob': 0.5,
}
seg_model = build_unet_model(**_unet_kwargs)

In [None]:
# Load from checkpoint
# Restores previously trained weights into `seg_model`. The commented-out line
# is an alternative (older-dated) checkpoint; note it still refers to `model`
# rather than `seg_model`.
#model.load_weights('checkpoints/cp_20190814-224021.ckpt')
seg_model.load_weights('checkpoints/cp_20190820-085420.ckpt')

# Generate Q function dataset tfrecords (if necessary)

In [None]:
# Dataset for the post-processing (Q function) model, built on top of the
# segmentation model and the segmentation dataset.
# NOTE(review): postprocess_thresh=0.85 is the same value that is hard-coded as
# the mask threshold in the evaluation cells below; keep them in sync.
post_dataset = SeverstalSteelPostprocessDataset.init_from_config(
    config_path='SETTINGS.yaml',
    seg_model=seg_model,
    seg_dataset=seg_dataset,
    postprocess_thresh=0.85)

In [None]:
# post_dataset.create_tfrecords()

# Build Postprocessing Q Function Model

In [None]:
# Hyper-parameters of the post-processing Q model. Its input size and class
# count come from the shared settings file.
_deep_q_kwargs = {
    'img_height': cfg['IMG_HEIGHT'],
    'img_width': cfg['IMG_WIDTH'],
    'num_classes': cfg['NUM_CLASSES'],
    'num_layers': 4,
    'activation': tf.keras.activations.elu,
    'kernel_initializer': 'he_normal',
    'kernel_size': (3, 3),
    'pool_size': (2, 4),
    'num_features': [4, 8, 16, 16],
    'drop_prob': 0.5,
}
post_model = build_deep_q_model(**_deep_q_kwargs)

In [None]:
# Print the layer-by-layer summary of the post-processing model.
post_model.summary()

In [None]:
# Train the Q model as a regressor: mean squared error is the loss and mean
# absolute error is reported alongside it. The metric uses the
# `tf.keras.metrics` class; the original passed a `tf.keras.losses` object in
# the metrics list.
post_model.compile(
    optimizer=tf.train.AdamOptimizer(0.00001),
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=[tf.keras.metrics.MeanAbsoluteError()],
)

# Train

In [None]:
# Training / validation splits of the post-processing dataset.
# NOTE(review): this rebinds train_data / train_batches / val_data / val_batches,
# which earlier held the segmentation datasets.
train_data, train_batches = post_dataset.create_dataset(dataset_type='training')
val_data, val_batches = post_dataset.create_dataset(dataset_type='validation')

In [None]:
#post_model.load_weights('postprocess_checkpoints/cp_20190907-171107.ckpt')

In [None]:
# One timestamp names both the checkpoint file and the TensorBoard log dir.
date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
checkpoint_path = f'postprocess_checkpoints/cp_{date_str}.ckpt'

# Create checkpoint callback.
# The path has no epoch placeholder, so every save overwrites the same file.
# Saving only when val_loss improves keeps the best-so-far weights on disk
# instead of those from the last epoch before early stopping fires.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='val_loss',
    save_best_only=True,
    save_weights_only=True,
    verbose=1)
# Stop once val_loss has not improved for 8 consecutive epochs.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    patience=8, monitor='val_loss')

logdir = "postprocessing_logs/" + date_str
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir=logdir),
    checkpoint_cb,
    early_stopping_cb,
]

# Validation runs after every epoch (validation_freq=1), which both callbacks
# above rely on for their val_loss readings.
results = post_model.fit(train_data,
                         epochs=200,
                         verbose=2,
                         callbacks=callbacks,
                         validation_data=val_data,
                         steps_per_epoch=train_batches,
                         validation_steps=val_batches,
                         validation_freq=1)

# Evaluate

In [None]:
# Image names of the validation split; evaluated one by one below and indexed
# by position in the visualisation cell.
val_imgs = seg_dataset.get_image_list('validation')

In [None]:
from steel_seg.model.unet import postprocess
def deep_q_postprocess(y, q_scores):
    """Blank out predicted masks wherever the Q model prefers an empty mask.

    Args:
        y: 4-D array indexed as [example, :, :, class]. It is copied, never
            modified in place.
        q_scores: array indexed as [example, class, action], where action 0
            holds the score for keeping the mask and action 1 the score for
            emptying it.

    Returns:
        A copy of `y` in which the slice [example, :, :, class] is set to 0
        for every (example, class) pair whose no-mask score is strictly
        greater than its mask score.
    """
    cleaned = np.copy(y)
    num_examples, num_classes = cleaned.shape[0], cleaned.shape[-1]
    for example_idx, class_idx in np.ndindex(num_examples, num_classes):
        keep_score = q_scores[example_idx, class_idx, 0]
        drop_score = q_scores[example_idx, class_idx, 1]
        # Ties keep the predicted mask: only a strictly better no-mask score
        # clears it.
        if drop_score > keep_score:
            cleaned[example_idx, :, :, class_idx] = 0
    return cleaned

def eval(seg_model, post_model, dataset, img_list, thresh=0.85):
    """Score the segmentation + Q post-processing pipeline on a list of images.

    For every image the segmentation model predicts on a batch of one, the
    prediction is thresholded with `postprocess`, the Q model scores the raw
    (un-thresholded) prediction, and `deep_q_postprocess` clears the masks the
    Q model votes against. The result is compared with the annotation using
    `dice_coeff_kaggle`.

    NOTE(review): this name shadows both the builtin `eval` and the `eval`
    imported from `steel_seg.train` at the top of the notebook.

    Args:
        seg_model: segmentation model exposing `predict`.
        post_model: Q model exposing `predict`; it is fed the segmentation
            model's output.
        dataset: object exposing `get_example_from_img_name(name)`, which
            returns an `(img, ann)` pair.
        img_list: iterable of image names to evaluate.
        thresh: threshold passed to `postprocess`. Defaults to 0.85, the value
            that was previously hard-coded here (TODO: move to config file).

    Returns:
        `(mean_dice_coeff, dice_coeffs)`: the mean over all images and the
        per-image list, in `img_list` order. The mean is NaN when `img_list`
        is empty.
    """
    dice_coeffs = []
    for img_name in img_list:
        img, ann = dataset.get_example_from_img_name(img_name)
        img_batch = np.expand_dims(img, axis=0)
        y = seg_model.predict(img_batch)
        y_one_hot = postprocess(y, thresh)
        q_scores = post_model.predict(y)
        y_post = deep_q_postprocess(y_one_hot, q_scores)
        dice_coeffs.append(dice_coeff_kaggle(y_post[0, :, :, :], ann))
    # np.mean([]) emits a RuntimeWarning before returning NaN; make the empty
    # case explicit instead.
    mean_dice_coeff = np.mean(dice_coeffs) if dice_coeffs else float('nan')
    print(f'Mean dice coeff: {mean_dice_coeff}')
    return mean_dice_coeff, dice_coeffs

In [None]:
# Evaluate the full pipeline on the validation images; `dice_coeffs` keeps the
# per-image scores in the same order as `val_imgs`.
mean_dice_coeff, dice_coeffs = eval(seg_model, post_model, seg_dataset, val_imgs)

In [None]:
# Positions (into `val_imgs` / `dice_coeffs`) of the 50 lowest-scoring images,
# printed worst first together with their Dice score.
num_worst = 50
indices = np.argsort(dice_coeffs)[:num_worst]
for i in indices:
    score = dice_coeffs[i]
    print(f'{i}: {score}')

In [None]:
# Visualize preds
# Pick one validation image by position (see the worst-image listing above for
# candidates) and run it through the same pipeline as `eval`.
index = 824
# Single threshold used for both the post-processing input and the plots below.
thresh = 0.85 #TODO: add to cfg file

img_name = val_imgs[index]
img, ann = seg_dataset.get_example_from_img_name(img_name)
img_batch = np.expand_dims(img, axis=0)
y = seg_model.predict(img_batch)
y_one_hot = postprocess(y, thresh)
q_scores = post_model.predict(y)
y_post = deep_q_postprocess(y_one_hot, q_scores)


def _show(image):
    """Display `image` in its own wide figure."""
    plt.figure(figsize=(10, 3))
    plt.imshow(image)
    plt.show()


# Input image (channel repeated 3x for display) with the annotation overlaid.
_show(visualize_segmentations(np.repeat(img, 3, axis=-1), ann))

# Thresholded raw prediction, one figure per class channel.
for class_idx in range(y.shape[-1]):
    _show(y[0, :, :, class_idx] > thresh)

# Prediction after Q-model post-processing, one figure per class channel.
for class_idx in range(y_post.shape[-1]):
    _show(y_post[0, :, :, class_idx])