In [None]:
# Reload edited modules (e.g. the steel_seg package imported below) before each
# cell executes, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import yaml
from datetime import datetime
import numpy as np
import tensorflow as tf
from steel_seg.dataset.severstal_steel_dataset import SeverstalSteelDataset, visualize_segmentations, dense_to_rle, rle_to_dense
from steel_seg.dataset.deep_q_data_generator import DeepQDataGenerator
from steel_seg.model.unet import build_unet_model
from steel_seg.model.deep_q_postprocessor import build_deep_q_model
from steel_seg.train import class_weighted_binary_crossentropy, dice_coef, dice_coeff_kaggle, eval
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# TODO
# Figure out how to submit the model
# Alternative loss functions?
#    - Weight boundary pixels more heavily
#    - Are class weights necessary?
#    - BCE + Dice loss
#    - Jaccard (IoU)? Others?
#    - Focal loss?
# Further post-processing
#    - Identify segments and process them one by one
#    - Use a different threshold for each class

In [None]:
# TensorFlow runtime switches, set via environment variables before any model is
# built. They were originally added as a CUDA 10 workaround — TODO confirm which
# ones are still needed.
#   TF_FORCE_GPU_ALLOW_GROWTH            -> grow GPU memory on demand instead of
#                                           reserving it all up front
#   TF_ENABLE_AUTO_MIXED_PRECISION*      -> automatic mixed-precision graph
#                                           rewrite and loss scaling
os.environ.update({
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    "TF_ENABLE_AUTO_MIXED_PRECISION": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING": "1",
})

In [None]:
# # To run in half-precision mode on GPU
# dtype='float16'
# K.set_floatx(dtype)

# # default is 1e-7 which is too small for float16.  Without adjusting the epsilon, we will get NaN predictions because of divide by zero problems
# K.set_epsilon(1e-4)

In [None]:
# Dataset wrapper configured from SETTINGS.yaml (the same file read into `cfg` below).
dataset = SeverstalSteelDataset.init_from_config('SETTINGS.yaml')

In [None]:
# Training / validation inputs. The second return value is passed to model.fit below
# as steps_per_epoch / validation_steps, i.e. the number of batches in one pass.
train_data, train_batches = dataset.create_dataset(dataset_type='training')
val_data, val_batches = dataset.create_dataset(dataset_type='validation')

In [None]:
# Shared settings (IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES, BATCH_SIZE, ...).
# Use safe_load: yaml.load() without an explicit Loader is deprecated in PyYAML 5.1+
# and raises TypeError in PyYAML 6. safe_load also refuses to construct arbitrary
# Python objects from the file.
with open('SETTINGS.yaml') as f:
    cfg = yaml.safe_load(f)

In [None]:
# U-Net segmentation model: single-channel input and NUM_CLASSES output maps
# (indexed below as y[..., c]). Hyper-parameters are collected in one dict so they
# are easy to inspect and tweak.
unet_params = dict(
    img_height=cfg['IMG_HEIGHT'],
    img_width=cfg['IMG_WIDTH'],
    img_channels=1,
    num_classes=cfg['NUM_CLASSES'],
    num_layers=4,
    activation=tf.keras.activations.elu,
    kernel_initializer='he_normal',
    kernel_size=(3, 3),
    pool_size=(2, 4),
    num_features=[8, 16, 32, 64],
    drop_prob=0.5,
)
model = build_unet_model(**unet_params)

In [None]:
# train_imgs = dataset.get_image_list('training')

# cls_pixel_counts = [0, 0, 0, 0]
# for img_name in train_imgs:
#     img, ann = dataset.get_example_from_img_name(img_name)
#     for i in range(ann.shape[-1]):
#         cls_pixel_counts[i] += np.sum(ann[:, :, i])

In [None]:
# Per-class weights passed to class_weighted_binary_crossentropy (one entry per class).
# The commented-out code derives them from inverse pixel frequency: each weight is
# proportional to 1 / cls_pixel_counts[i], scaled so class 2 gets 10.0. It needs
# cls_pixel_counts from the (also commented-out) cell above. The commented list is
# presumably what that computation produced — TODO confirm.
# cls_2_weight = 10.0
# cls_0_weight = (cls_pixel_counts[2] * cls_2_weight) / cls_pixel_counts[0]
# cls_1_weight = (cls_pixel_counts[2] * cls_2_weight) / cls_pixel_counts[1]
# cls_3_weight = (cls_pixel_counts[2] * cls_2_weight) / cls_pixel_counts[3]
# cls_weights = [cls_0_weight, cls_1_weight, cls_2_weight, cls_3_weight]
#cls_weights = [332.7316460371491, 1661.6470474376874, 10.0, 47.50019711767978]
# Currently used: hard-coded weights, much flatter than the frequency-based ones.
cls_weights = [30.0, 40.0, 10.0, 20.0]

In [None]:
# Layer-by-layer architecture and parameter counts of the U-Net.
model.summary()

In [None]:
# Adam (TF1 tf.train API) with the class-weighted BCE loss. Binary accuracy and
# dice are reported during training; dice_coef needs the batch size.
model.compile(
    optimizer=tf.train.AdamOptimizer(0.0005),
    loss=class_weighted_binary_crossentropy(cls_weights),
    metrics=[
        tf.keras.metrics.BinaryAccuracy(),
        dice_coef(batch_size=cfg['BATCH_SIZE']),
    ],
)

In [None]:
# Start from previously saved weights (presumably written by the ModelCheckpoint
# callback in the next cell on an earlier run, given the matching path pattern).
# Point this at the run to continue from; the commented-out path is an older run.
#model.load_weights('checkpoints/cp_20190814-224021.ckpt')
model.load_weights('checkpoints/cp_20190820-085420.ckpt')

In [None]:
# One timestamp names both the checkpoint and the TensorBoard run so they can be matched up.
date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
checkpoint_path = f'checkpoints/cp_{date_str}.ckpt'

# Only write a checkpoint when val_loss improves. Otherwise the file on disk is the
# last epoch, which after early stopping is `patience` epochs past the best one.
# restore_best_weights does the same for the in-memory model when training is
# stopped early, so the evaluation cells below use the best weights too.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, monitor='val_loss', save_best_only=True,
    save_weights_only=True, verbose=1)
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    patience=8, monitor='val_loss', restore_best_weights=True)

logdir = "logs/" + date_str
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir=logdir),
    checkpoint_cb,
    early_stopping_cb,
]

# validation_freq=1 validates every epoch; both callbacks above monitor val_loss.
results = model.fit(train_data,
                    epochs=200,
                    verbose=2,
                    callbacks=callbacks,
                    validation_data=val_data,
                    steps_per_epoch=train_batches,
                    validation_steps=val_batches,
                    validation_freq=1)

In [None]:
# Names of the validation images; the evaluation below scores them one at a time.
val_imgs = dataset.get_image_list('validation')

In [None]:
# Number of validation images.
len(val_imgs)

In [None]:
def onehottify(x, n=None, dtype=float):
    """One-hot encode integer labels along a new trailing axis.

    Args:
        x: Integer label(s), any array-like; values index the output classes.
        n: Number of classes. Defaults to ``max(x) + 1``.
        dtype: dtype of the returned array.

    Returns:
        Array of shape ``np.shape(x) + (n,)`` with a single 1 per label.
    """
    labels = np.asarray(x)
    num_classes = (labels.max() + 1) if n is None else n
    identity = np.eye(num_classes, dtype=dtype)
    return identity[labels]
# postprocess(y, thresh=0.8, upper_thresh=0.8, num_px_thresh=10000): 0.8795549804140013
# postprocess(y, thresh=0.7, upper_thresh=0.7, num_px_thresh=10000): 0.8685514979767833
# postprocess(y, thresh=0.8, upper_thresh=0.8, num_px_thresh=20000): 0.8597784574292987
# postprocess(y, thresh=0.8, upper_thresh=0.8, num_px_thresh=5000):  0.888272717085689
# postprocess(y, thresh=0.85, upper_thresh=0.85, num_px_thresh=5000): 0.8933982786371334
# postprocess(y, thresh=0.85, upper_thresh=0.85, num_px_thresh=2000): 0.8862070712564566
# postprocess(y, thresh=0.9, upper_thresh=0.9, num_px_thresh=2000):  0.8899766923286324

def postprocess(y, thresh=0.85, upper_thresh=0.85, num_px_thresh=5000):
    """Turn per-class score maps into mutually exclusive binary masks.

    Predicting a mask for an image whose ground truth is empty is very costly
    (that image's score goes from 1 to 0). So a class is only kept for an image
    if more than ``num_px_thresh`` of its pixels exceed ``upper_thresh``. Each
    pixel is then assigned to its highest-scoring kept class, and cleared if
    that score is below ``thresh``.

    Args:
        y: Array indexed (batch, height, width, class) — presumably the U-Net's
            per-class probabilities; TODO confirm.
        thresh: Minimum score for a pixel to be labelled.
        upper_thresh: Score a pixel must exceed to count toward ``num_px_thresh``.
        num_px_thresh: A class is kept for an image only if strictly more than
            this many of its pixels exceed ``upper_thresh``.

    Returns:
        Float64 array with the same shape as ``y``; at most one class is set
        per pixel.
    """
    y = np.asarray(y)
    num_classes = y.shape[-1]

    # Count confident pixels per (image, class). Summing over the whole batch
    # would let one image keep or suppress a class for every other image.
    confident_px = np.sum(y > upper_thresh, axis=(1, 2))  # (batch, class)
    keep = confident_px > num_px_thresh
    y_post = np.where(keep[:, np.newaxis, np.newaxis, :], y, np.zeros_like(y))

    # Only allow one class at each pixel.
    y_one_hot = np.eye(num_classes, dtype=float)[np.argmax(y_post, axis=-1)]
    # Threshold against y_post, not the raw y. Where every class was suppressed,
    # argmax falls back to class 0, and comparing with raw scores would let a
    # suppressed class 0 reappear wherever its raw score >= thresh.
    y_one_hot[y_post < thresh] = 0
    return y_one_hot

def eval(model, dataset, img_list):
    """Mean Kaggle dice of the U-Net + ``postprocess`` over a list of images.

    NOTE(review): this name shadows both the builtin ``eval`` and the ``eval``
    imported from steel_seg.train, and it is redefined later in the notebook
    with a different signature.

    Args:
        model: Segmentation model; ``model.predict`` is called on a batch of one image.
        dataset: Provides ``get_example_from_img_name(name) -> (img, ann)``.
        img_list: Image names to score.

    Returns:
        ``(mean_dice_coeff, dice_coeffs)``: the mean, and one score per image
        in ``img_list`` order. The mean is also printed.
    """
    scores = []
    for name in img_list:
        image, annotation = dataset.get_example_from_img_name(name)
        prediction = model.predict(np.expand_dims(image, axis=0))
        masks = postprocess(prediction)
        scores.append(dice_coeff_kaggle(masks[0], annotation))
    mean_score = np.mean(scores)
    print(f'Mean dice coeff: {mean_score}')
    return mean_score, scores

In [None]:
# Score the U-Net + threshold post-processing on every validation image.
mean_dice_coeff, dice_coeffs = eval(model, dataset, val_imgs)

In [None]:
# List the lowest-scoring validation images as "<index into val_imgs>: <dice>".
num_worst = 50
indices = np.argsort(dice_coeffs)[:num_worst]  # indices of the num_worst worst images
for i in indices:
    print(f'{i}: {dice_coeffs[i]}')

In [None]:
# Inspect one validation image: the ground-truth overlay, then each class's U-Net
# score map thresholded at `thresh`.
img_name = val_imgs[101]  # chosen by hand, e.g. from the worst-images list above
img, ann = dataset.get_example_from_img_name(img_name)
img_batch = np.expand_dims(img, axis=0)
y = model.predict(img_batch)

plt.figure(figsize=(10, 3))
plt.imshow(visualize_segmentations(np.repeat(img, 3, axis=-1), ann))
plt.title(f'{img_name}: ground truth')
plt.show()

thresh = 0.9
for c in range(y.shape[-1]):
    plt.figure(figsize=(10, 3))
    plt.imshow(y[0, :, :, c] > thresh)
    plt.title(f'Class {c}: U-Net score > {thresh}')
    plt.show()

# Train Post-Processing Model

In [None]:
# Small network (4 features per level) whose predictions on the post-processed
# masks drive deep_q_postprocess below. Hyper-parameters are collected in one dict.
deep_q_params = dict(
    img_height=cfg['IMG_HEIGHT'],
    img_width=cfg['IMG_WIDTH'],
    num_classes=cfg['NUM_CLASSES'],
    num_layers=4,
    activation=tf.keras.activations.elu,
    kernel_initializer='he_normal',
    kernel_size=(3, 3),
    pool_size=(2, 4),
    num_features=[4, 4, 4, 4],
    drop_prob=0.5,
)
post_model = build_deep_q_model(**deep_q_params)

In [None]:
# Layer-by-layer architecture and parameter counts of the post-processing model.
post_model.summary()

In [None]:
# The post-processing model regresses scores with an MSE loss; MAE is tracked as a
# metric. Use the metric class (a stateful running mean over the epoch) rather than
# passing the loss class tf.keras.losses.MeanAbsoluteError as a metric.
post_model.compile(optimizer=tf.train.AdamOptimizer(0.0001),
                   loss=tf.keras.losses.MeanSquaredError(),
                   metrics=[tf.keras.metrics.MeanAbsoluteError()])

In [None]:
from steel_seg.dataset.deep_q_data_generator import DeepQDataGenerator

In [None]:
# Training batches for the post-processing model, built by running the trained
# U-Net (`model`) over the training split. threshold=0.85 matches the value passed
# to postprocess in the evaluation below.
train_generator_kwargs = dict(
    base_model=model,
    steel_dataset=dataset,
    dataset_name='training',
    batch_size=cfg['BATCH_SIZE'],
    threshold=0.85,
    img_height=cfg['IMG_HEIGHT'],
    img_width=cfg['IMG_WIDTH'],
    shuffle=True,
)
train_data_generator = DeepQDataGenerator(**train_generator_kwargs)

In [None]:
# Validation batches for the post-processing model, built the same way as the
# training generator but over the validation split.
val_generator_kwargs = dict(
    base_model=model,
    steel_dataset=dataset,
    dataset_name='validation',
    batch_size=cfg['BATCH_SIZE'],
    threshold=0.85,
    img_height=cfg['IMG_HEIGHT'],
    img_width=cfg['IMG_WIDTH'],
    shuffle=True,
)
val_data_generator = DeepQDataGenerator(**val_generator_kwargs)

In [None]:
#cp_20190828-225036.ckpt
#cp_20190829-222109.ckpt
#cp_20190829-225130.ckpt

In [None]:
# One timestamp names both the checkpoint and the TensorBoard run so they can be matched up.
date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
checkpoint_path = f'postprocess_checkpoints/cp_{date_str}.ckpt'

# Only write a checkpoint when val_loss improves. Otherwise the file on disk is the
# last epoch, which after early stopping is `patience` epochs past the best one.
# restore_best_weights does the same for the in-memory model when training is
# stopped early, so the evaluation cells below use the best weights too.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, monitor='val_loss', save_best_only=True,
    save_weights_only=True, verbose=1)
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    patience=8, monitor='val_loss', restore_best_weights=True)

logdir = "postprocessing_logs/" + date_str
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir=logdir),
    checkpoint_cb,
    early_stopping_cb,
]

# One epoch = one full pass over each generator; validate every epoch.
results = post_model.fit(train_data_generator,
                         epochs=200,
                         verbose=2,
                         callbacks=callbacks,
                         validation_data=val_data_generator,
                         steps_per_epoch=len(train_data_generator),
                         validation_steps=len(val_data_generator),
                         validation_freq=1)

In [None]:
# Same validation list as earlier; repeated here, presumably so this section can be
# run without re-running the U-Net evaluation cells.
val_imgs = dataset.get_image_list('validation')

In [None]:
from steel_seg.dataset.deep_q_data_generator import postprocess
def deep_q_postprocess(y, post_model):
    """Blank whole class masks that the post-processing model prefers to leave empty.

    Args:
        y: Mask array indexed (batch, height, width, class).
        post_model: Object whose ``predict(y)`` returns scores indexed
            (batch, class, 2). Per the original comment, index 0 is the
            "mask" score and index 1 the "no mask" score — TODO confirm
            against DeepQDataGenerator.

    Returns:
        A copy of ``y``. Every (image, class) slice whose score at index 1 is
        strictly greater than its score at index 0 is set to zero. ``y`` itself
        is not modified.
    """
    scores = post_model.predict(y)
    masks = np.copy(y)
    n_images, n_classes = masks.shape[0], masks.shape[-1]
    prefer_empty = (scores[:, :, 1] > scores[:, :, 0])[:n_images, :n_classes]
    drop = np.broadcast_to(prefer_empty[:, np.newaxis, np.newaxis, :], masks.shape)
    masks[drop] = 0
    return masks

def eval(model, post_model, dataset, img_list):
    """Mean Kaggle dice of U-Net + threshold post-processing + deep-Q filtering.

    NOTE(review): this redefines the earlier notebook ``eval`` with a different
    signature and shadows the builtin. ``postprocess`` resolves to the version
    imported from steel_seg.dataset.deep_q_data_generator at the top of this
    cell, which replaces the notebook's own definition.

    Args:
        model: Segmentation model; ``model.predict`` is called on a batch of one image.
        post_model: Model passed to ``deep_q_postprocess``.
        dataset: Provides ``get_example_from_img_name(name) -> (img, ann)``.
        img_list: Image names to score.

    Returns:
        ``(mean_dice_coeff, dice_coeffs)``: the mean, and one score per image
        in ``img_list`` order. The mean is also printed.
    """
    scores = []
    for name in img_list:
        image, annotation = dataset.get_example_from_img_name(name)
        probs = model.predict(np.expand_dims(image, axis=0))
        masks = postprocess(probs, 0.85)  # TODO: move this threshold to the config file
        filtered = deep_q_postprocess(masks, post_model)
        scores.append(dice_coeff_kaggle(filtered[0], annotation))
    mean_score = np.mean(scores)
    print(f'Mean dice coeff: {mean_score}')
    return mean_score, scores

In [None]:
# Score the full pipeline (U-Net + thresholding + deep-Q filtering) on the validation set.
mean_dice_coeff, dice_coeffs = eval(model, post_model, dataset, val_imgs)

In [None]:
# List the lowest-scoring validation images as "<index into val_imgs>: <dice>".
num_worst = 50
indices = np.argsort(dice_coeffs)[:num_worst]  # indices of the num_worst worst images
for i in indices:
    print(f'{i}: {dice_coeffs[i]}')

In [None]:
# Inspect one validation image end to end: the ground-truth overlay, each class's
# raw U-Net score thresholded at `thresh`, then the masks left after deep-Q filtering.
img_name = val_imgs[903]  # chosen by hand, e.g. from the worst-images list above
img, ann = dataset.get_example_from_img_name(img_name)
img_batch = np.expand_dims(img, axis=0)
y = model.predict(img_batch)
thresh = 0.85  # TODO: move to the config file (shared by postprocess and the plots)
y_one_hot = postprocess(y, thresh)
y_post = deep_q_postprocess(y_one_hot, post_model)

plt.figure(figsize=(10, 3))
plt.imshow(visualize_segmentations(np.repeat(img, 3, axis=-1), ann))
plt.title(f'{img_name}: ground truth')
plt.show()

for c in range(y.shape[-1]):
    plt.figure(figsize=(10, 3))
    plt.imshow(y[0, :, :, c] > thresh)
    plt.title(f'Class {c}: U-Net score > {thresh}')
    plt.show()

for c in range(y_post.shape[-1]):
    plt.figure(figsize=(10, 3))
    plt.imshow(y_post[0, :, :, c])
    plt.title(f'Class {c}: after deep-Q post-processing')
    plt.show()