Train a UNet segmentation model on the Severstal steel defect dataset.

# Initialize and load data

In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so that edits to imported source files take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# TODO
# - Dataset generator so that we can easily oversample different classes
# - Turn on contrast and brightness augmentations
# - Train full classification model (i.e. don't freeze weights) to see if it does any/much better
# - Option to resume training (without breaking tensorboard)
# - Resize before feeding into model, don't let the width get quite so small
# - Efficient per-class postprocessing grid search
#    - Run model once, figure out score for each class for a range of values
#    - At the end, pick the best values
# - Cropping augmentations
# - Alternative loss functions?
#    - Weight boundary pixels more heavily
#    - Are class weights necessary?
#    - BCE + dice loss
#    - Jaccard? Others?
#    - Focal loss?

In [None]:
import os
import yaml
from datetime import datetime
import numpy as np
import tensorflow as tf
from steel_seg.utils import (
    dice_coeff_kaggle,
    per_class_dice_coeff,
    rle_to_dense,
    dense_to_rle,
    visualize_segmentations,
    onehottify)
from steel_seg.dataset.severstal_steel_dataset import SeverstalSteelDataset
from steel_seg.model.unet import build_unet_model
from steel_seg.model.deep_q_postprocessor import build_deep_q_model
from steel_seg.train import (
    class_weighted_binary_crossentropy,
    weighted_binary_crossentropy,
    pixel_map_weighted_binary_crossentropy,
    dice_loss_multi_class,
    dice_coef,
    DiceCoefByClassAndEmptiness,
    eval)
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# TensorFlow runtime switches, exported through the process environment.
# Original author's note: possibly necessary for CUDA 10 (unconfirmed).
# Going by their names, the first flag lets GPU memory grow on demand and the
# other three relate to automatic mixed precision -- TODO confirm against the
# TensorFlow version in use.
os.environ.update({
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    "TF_ENABLE_AUTO_MIXED_PRECISION": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING": "1",
})

In [None]:
# Read the project settings into `cfg`.
# yaml.safe_load replaces the bare yaml.load(f): calling yaml.load without an
# explicit Loader is deprecated (and a TypeError on PyYAML >= 6), and the
# default full loader can construct arbitrary Python objects from the file.
with open('SETTINGS.yaml') as f:
    cfg = yaml.safe_load(f)

In [None]:
# Build the dataset wrapper from the same settings file that was loaded into `cfg`.
dataset = SeverstalSteelDataset.init_from_config('SETTINGS.yaml')

In [None]:
# create_dataset returns a pair for each split. Judging by how the values are
# passed to fit() later in this notebook, the first item is the data fed to
# Keras and the second is the number of batches per epoch -- TODO confirm
# against SeverstalSteelDataset.create_dataset.
train_data, train_batches = dataset.create_dataset(dataset_type='training')
val_data, val_batches = dataset.create_dataset(dataset_type='validation')

# Build model

In [None]:
# Earlier, narrower configuration ('basic'), kept commented out for reference
# together with the checkpoints that were trained with it.
# seg_model = build_unet_model(
#     img_height=cfg['IMG_HEIGHT'],
#     img_width=cfg['IMG_WIDTH'],
#     img_channels=1,
#     num_classes=cfg['NUM_CLASSES'],
#     num_layers=4,
#     activation=tf.keras.activations.elu,
#     kernel_initializer='he_normal',
#     kernel_size=(3, 3),
#     pool_size=(2, 4),
#     num_features=[8, 16, 32, 64],
#     drop_prob=0.5)
#model_checkpoint_name = 'basic'
#checkpoints/cp_20190910-230338.ckpt
#checkpoints/cp_20190913-090408.ckpt

# Active configuration ('deep'): identical settings to the one above except
# that every entry of num_features is four times larger.
seg_model = build_unet_model(
    img_height=cfg['IMG_HEIGHT'],
    img_width=cfg['IMG_WIDTH'],
    img_channels=1,
    num_classes=cfg['NUM_CLASSES'],
    num_layers=4,
    activation=tf.keras.activations.elu,
    kernel_initializer='he_normal',
    kernel_size=(3, 3),
    pool_size=(2, 4),
    num_features=[32, 64, 128, 256],
    drop_prob=0.5)
# Prefix used further down when building checkpoint and log directory names.
model_checkpoint_name = 'deep'
# # checkpoints/cp_20190914-142429.ckpt

In [None]:
# train_imgs = dataset.get_image_list('training')

# cls_pixel_counts = np.array([0, 0, 0, 0], dtype=np.float64)
# total_pixels = 0.0
# for img_name in train_imgs:
#     img, ann = dataset.get_example_from_img_name(img_name)
#     img_pixel_counts = np.sum(ann, axis=(0, 1))
#     cls_pixel_counts += img_pixel_counts
#     total_pixels += ann.size
# cls_pixel_counts

In [None]:
# cls_weights = total_pixels / cls_pixel_counts
# cls_weights

In [None]:
#cls_weights = [5274.42494602, 26340.24368548, 158.51888478, 752.96782738]
#cls_weights = np.array([5274.42494602, 26340.24368548, 158.51888478, 752.96782738]) / 20
#cls_weights = [263.7212473, 1317.01218427, 7.92594424, 37.64839137]

In [None]:
# Per-class weights handed to the loss function when the model is compiled.
# The commented-out lines record earlier attempts: weights scaled relative to
# class 2 using per-class pixel counts, and the resulting literal values.
# The active list and the alternative below it are round numbers that appear
# to be hand-tuned -- NOTE(review): confirm how they were chosen.
# cls_2_weight = 10.0
# cls_0_weight = (cls_pixel_counts[2] * cls_2_weight) / cls_pixel_counts[0]
# cls_1_weight = (cls_pixel_counts[2] * cls_2_weight) / cls_pixel_counts[1]
# cls_3_weight = (cls_pixel_counts[2] * cls_2_weight) / cls_pixel_counts[3]
# cls_weights = [cls_0_weight, cls_1_weight, cls_2_weight, cls_3_weight]
#cls_weights = [332.7316460371491, 1661.6470474376874, 10.0, 47.50019711767978]
cls_weights = [30.0, 40.0, 10.0, 20.0]
#cls_weights = [40.0, 50.0, 5.0, 20.0]

In [None]:
# Print the Keras layer-by-layer summary as a sanity check of the architecture.
seg_model.summary()

In [None]:
# Compile the model: Adam at a fixed learning rate, the class-weighted pixel
# map cross-entropy as the loss, and a set of dice-based monitoring metrics.
seg_model.compile(
    optimizer=tf.train.AdamOptimizer(0.0001),
    loss=pixel_map_weighted_binary_crossentropy(cls_weights), #dice_loss_multi_class,
    metrics=[
        tf.keras.metrics.BinaryAccuracy(),
        dice_coef(batch_size=cfg['BATCH_SIZE']),
        # One DiceCoefByClassAndEmptiness metric per (emptiness flag, class)
        # pair, generated instead of written out by hand so the list follows
        # NUM_CLASSES. Order and names match the previous explicit list:
        # all "empty" metrics (cls0..clsN-1) first, then all "non_empty" ones.
        *[
            DiceCoefByClassAndEmptiness(
                cls_id=cls_id,
                empty_masks_only=empty_only,
                batch_size=cfg['BATCH_SIZE'],
                name=f"dice_coef_on_{'empty' if empty_only else 'non_empty'}_cls{cls_id}")
            for empty_only in (True, False)
            for cls_id in range(cfg['NUM_CLASSES'])
        ],
    ]
)

# Load Initial Weights

In [None]:
# Shell escape: list the contents of the checkpoints directory so an existing
# run can be picked for resuming.
!ls checkpoints

In [None]:
# Display the checkpoint name prefix currently in effect.
model_checkpoint_name

In [None]:
# Timestamp part of the run to resume. The second assignment overrides the
# first; commenting it out leaves date_str as None, in which case the
# checkpoint directory name built from it presumably matches nothing and
# training starts from scratch -- TODO confirm.
date_str = None
date_str = '20190916-092052'

In [None]:
# Derive the checkpoint location for the selected run, then restore the most
# recent checkpoint (if any) and continue epoch numbering from it.
checkpoint_name = f'{model_checkpoint_name}_{date_str}'
checkpoint_path = f'checkpoints/{checkpoint_name}/cp-{checkpoint_name}' + '-{epoch:04d}.ckpt'
checkpoint_dir = os.path.dirname(checkpoint_path)

initial_epoch = 0
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is not None:
    print(f'Loading weights from {latest_checkpoint}')
    # Checkpoint names end in "-<epoch>.ckpt": take the text after the last
    # dash and drop everything from the first dot onwards.
    last_epoch = latest_checkpoint.rsplit('-', 1)[-1].partition('.')[0]
    initial_epoch = int(last_epoch)
    seg_model.load_weights(latest_checkpoint)
else:
    print('No checkpoints found. Starting from scratch.')

## Use new model name?

In [None]:
# Optional cell: start a new run under a different name instead of continuing
# the one selected above. Weights already loaded into seg_model are untouched;
# only the naming and the epoch counter are reset.
model_checkpoint_name = 'deep_w_aug'
date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
initial_epoch = 0

# Train

In [None]:
# Recompute the checkpoint naming, because model_checkpoint_name and date_str
# may have been changed since the weights were loaded.
# '{epoch:04d}' is left unformatted on purpose: it is appended as a plain
# string so that the checkpoint callback can fill in the epoch number.
checkpoint_name = f'{model_checkpoint_name}_{date_str}'
checkpoint_path = f'checkpoints/{checkpoint_name}/cp-{checkpoint_name}' + '-{epoch:04d}.ckpt'
checkpoint_dir = os.path.dirname(checkpoint_path)

In [None]:
# Create checkpoint callback
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='val__dice_coef',#'val_loss',
    save_best_only=True,
    mode='max',#'auto',
    save_weights_only=True,
    verbose=1)

logdir = f'logs/{checkpoint_name}-{initial_epoch}'
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir=logdir),
    checkpoint_cb,
]

results = seg_model.fit(
    train_data,
    epochs=400,
    verbose=2,
    callbacks=callbacks,
    validation_data=val_data,
    steps_per_epoch=train_batches,
    validation_steps=val_batches,
    validation_freq=3,
    initial_epoch=initial_epoch)

# Evaluate

In [None]:
# Image names of the validation split; the cell displays how many there are.
val_imgs = dataset.get_image_list('validation')
len(val_imgs)

In [None]:
def per_class_binary_cross_entropy(y_pred, y_true):
    """Mean binary cross-entropy of a single prediction, one value per class.

    Note the argument order: predictions first, ground truth second.

    Args:
        y_pred: Array of predicted probabilities, broadcastable against
            ``y_true``.
        y_true: 3-D ground-truth array. The first two axes are averaged over;
            the last axis indexes classes. Values are assumed to be 0/1
            targets -- TODO confirm.

    Returns:
        1-D array holding the mean cross-entropy of each class (non-negative
        for targets in [0, 1]).

    Raises:
        AssertionError: if ``y_true`` is not 3-dimensional.
    """
    assert len(y_true.shape) == 3
    # Keep predictions strictly inside (0, 1) so both logarithms stay finite.
    eps = 0.000001
    y_pred = np.clip(y_pred, eps, 1 - eps)
    # Cross-entropy is the *negative* log-likelihood. The leading minus sign
    # was missing before, which made this function return the negated loss.
    per_pixel_class_cross_entropy = -(
        y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    per_class_cross_entropy = np.mean(per_pixel_class_cross_entropy, axis=(0, 1))
    return per_class_cross_entropy

def per_class_iou_score(y_pred, y_true):
    """Intersection-over-union of a single prediction, one value per class.

    Args:
        y_pred: Predicted mask array, multiplied element-wise with ``y_true``.
        y_true: 3-D ground-truth array; the first two axes are summed over and
            the last axis indexes classes.

    Returns:
        1-D array of IoU values. A class whose union is zero (nothing
        predicted and nothing labelled) scores 1.0.
    """
    assert len(y_true.shape) == 3
    overlap = np.sum(y_pred * y_true, axis=(0, 1))
    combined = np.sum(y_pred, axis=(0, 1)) + np.sum(y_true, axis=(0, 1)) - overlap

    def _single_class_iou(inter, uni):
        # An empty union means prediction and ground truth are both empty.
        return 1.0 if uni == 0 else inter / uni

    return np.array(
        [_single_class_iou(overlap[c], combined[c]) for c in range(overlap.shape[0])])

def per_class_mask_not_empty(y_true):
    """Flag, for each class, whether the ground-truth mask has any content.

    Sums ``y_true`` over its first two axes and reports which of the
    resulting per-class totals exceed 0.5.
    """
    per_class_total = np.sum(y_true, axis=(0, 1))
    return per_class_total > 0.5

def eval_segmentation(
    model,
    dataset,
    img_list,
    thresholds=None,
    num_classes=4):
    """Score a model image by image and class by class.

    Each image is fetched from ``dataset``, run through ``model`` as a batch
    of one, and compared against its annotation.

    Args:
        model: Object with a ``predict(batch)`` method; the output is indexed
            as ``y[0, :, :, :]``, i.e. a leading batch axis followed by three
            more axes.
        dataset: Object providing ``get_example_from_img_name(name)``, which
            returns an ``(image, annotation)`` pair.
        img_list: Sequence of image names to evaluate.
        thresholds: Per-class cut-offs used to binarize the predictions
            (broadcast against the model output). Defaults to 0.5 for each of
            the ``num_classes`` classes.
        num_classes: Number of classes, i.e. the number of columns in every
            returned array.

    Returns:
        Tuple ``(binary_cross_entropy, dice_coeff, iou_score, mask_not_empty)``
        of arrays shaped ``(len(img_list), num_classes)``. The first three are
        float32; the last is boolean and marks images whose ground-truth mask
        for that class is non-empty. Cross-entropy is computed on the raw
        predictions, dice and IoU on the binarized ones.
    """
    num_images = len(img_list)
    binary_cross_entropy = np.zeros((num_images, num_classes), dtype=np.float32)
    dice_coeff = np.zeros((num_images, num_classes), dtype=np.float32)
    iou_score = np.zeros((num_images, num_classes), dtype=np.float32)
    # Builtin bool instead of the np.bool alias, which NumPy 1.24 removed.
    mask_not_empty = np.zeros((num_images, num_classes), dtype=bool)

    if thresholds is None:
        # Follow num_classes rather than assuming exactly four classes.
        thresholds = [0.5] * num_classes
    thresholds = np.array(thresholds)

    for i, img_name in enumerate(img_list):
        img, ann = dataset.get_example_from_img_name(img_name)
        img_batch = np.expand_dims(img, axis=0)
        y = model.predict(img_batch)

        # Binarize predictions
        y_bin = np.zeros_like(y, dtype=np.uint8)
        y_bin[y > thresholds] = 1

        binary_cross_entropy[i, :] = per_class_binary_cross_entropy(y[0, :, :, :], ann)
        dice_coeff[i, :] = per_class_dice_coeff(y_bin[0, :, :, :], ann)
        iou_score[i, :] = per_class_iou_score(y_bin[0, :, :, :], ann)
        mask_not_empty[i, :] = per_class_mask_not_empty(ann)

    return binary_cross_entropy, dice_coeff, iou_score, mask_not_empty

In [None]:
# Evaluate the trained model on every validation image. Predictions are
# binarized at 0.9 for all four classes here, which is stricter than
# eval_segmentation's default of 0.5.
binary_cross_entropy, dice_coeff, iou_score, mask_not_empty = eval_segmentation(
    seg_model,
    dataset,
    val_imgs,
    thresholds=[0.9, 0.9, 0.9, 0.9],
    num_classes=4)

In [None]:
# Headline numbers: each metric averaged over all images and all classes.
print(
    f'Mean dice coeff: {np.mean(dice_coeff)}',
    f'Mean binary cross entropy: {np.mean(binary_cross_entropy)}',
    f'Mean IoU: {np.mean(iou_score)}',
    sep='\n')

In [None]:
# Per-class report: dice and IoU averaged over all images, then separately
# over images with and without a ground-truth mask for that class.
for i in range(cfg['NUM_CLASSES']):
    cls_dice = dice_coeff[:, i]
    cls_iou = iou_score[:, i]
    has_mask = mask_not_empty[:, i]
    no_mask = ~has_mask

    print('*******************')
    print(f'***** Class {i} *****')
    print('*******************')
    print(f'Mean dice coeff: {np.mean(cls_dice)}')
    print(f'Mean dice coeff (with mask): {np.mean(cls_dice[has_mask])}')
    print(f'Mean dice coeff (no mask): {np.mean(cls_dice[no_mask])}')
    print(f'Mean IoU: {np.mean(cls_iou)}')
    print(f'Mean IoU (with mask): {np.mean(cls_iou[has_mask])}')
    print(f'Mean IoU (no mask): {np.mean(cls_iou[no_mask])}')
    print('')

In [None]:
0.8
*******************
***** Class 0 *****
*******************
Mean dice coeff: 0.8249650597572327
Mean dice coeff (with mask): 0.5211997032165527
Mean dice coeff (no mask): 0.85056471824646
Mean IoU: 0.813450276851654
Mean IoU (with mask): 0.3730500340461731
Mean IoU (no mask): 0.85056471824646

*******************
***** Class 1 *****
*******************
Mean dice coeff: 0.9306501746177673
Mean dice coeff (with mask): 0.5132780075073242
Mean dice coeff (no mask): 0.9423393607139587
Mean IoU: 0.9266247153282166
Mean IoU (with mask): 0.365519642829895
Mean IoU (no mask): 0.9423393607139587

*******************
***** Class 2 *****
*******************
Mean dice coeff: 0.706938624382019
Mean dice coeff (with mask): 0.6717405915260315
Mean dice coeff (no mask): 0.7313432693481445
Mean IoU: 0.651603102684021
Mean IoU (with mask): 0.5365963578224182
Mean IoU (no mask): 0.7313432693481445

*******************
***** Class 3 *****
*******************
Mean dice coeff: 0.9221725463867188
Mean dice coeff (with mask): 0.7034797668457031
Mean dice coeff (no mask): 0.9367521405220032
Mean IoU: 0.9135252833366394
Mean IoU (with mask): 0.5651221871376038
Mean IoU (no mask): 0.9367521405220032

In [None]:
0.7
*******************
***** Class 0 *****
*******************
Mean dice coeff: 0.7857795357704163
Mean dice coeff (with mask): 0.5015755891799927
Mean dice coeff (no mask): 0.8097306489944458
Mean IoU: 0.7742512822151184
Mean IoU (with mask): 0.3532538414001465
Mean IoU (no mask): 0.8097306489944458

*******************
***** Class 1 *****
*******************
Mean dice coeff: 0.9162784218788147
Mean dice coeff (with mask): 0.485749751329422
Mean dice coeff (no mask): 0.9283360838890076
Mean IoU: 0.9122956991195679
Mean IoU (with mask): 0.3395600914955139
Mean IoU (no mask): 0.9283360838890076

*******************
***** Class 2 *****
*******************
Mean dice coeff: 0.6629075407981873
Mean dice coeff (with mask): 0.6581381559371948
Mean dice coeff (no mask): 0.6662144064903259
Mean IoU: 0.6058319211006165
Mean IoU (with mask): 0.5187441110610962
Mean IoU (no mask): 0.6662144064903259

*******************
***** Class 3 *****
*******************
Mean dice coeff: 0.9067474603652954
Mean dice coeff (with mask): 0.7002676129341125
Mean dice coeff (no mask): 0.9205127954483032
Mean IoU: 0.8978906273841858
Mean IoU (with mask): 0.5585583448410034
Mean IoU (no mask): 0.9205127954483032

In [None]:
0.6
*******************
***** Class 0 *****
*******************
Mean dice coeff: 0.7556981444358826
Mean dice coeff (with mask): 0.4856829345226288
Mean dice coeff (no mask): 0.778453528881073
Mean IoU: 0.7441787719726562
Mean IoU (with mask): 0.3374749422073364
Mean IoU (no mask): 0.778453528881073

*******************
***** Class 1 *****
*******************
Mean dice coeff: 0.903589129447937
Mean dice coeff (with mask): 0.4611539840698242
Mean dice coeff (no mask): 0.9159802198410034
Mean IoU: 0.8996537327766418
Mean IoU (with mask): 0.31670063734054565
Mean IoU (no mask): 0.9159802198410034

*******************
***** Class 2 *****
*******************
Mean dice coeff: 0.6304072141647339
Mean dice coeff (with mask): 0.6433426737785339
Mean dice coeff (no mask): 0.6214382648468018
Mean IoU: 0.571916937828064
Mean IoU (with mask): 0.5004937648773193
Mean IoU (no mask): 0.6214382648468018

*******************
***** Class 3 *****
*******************
Mean dice coeff: 0.8894842863082886
Mean dice coeff (with mask): 0.6932864785194397
Mean dice coeff (no mask): 0.9025641083717346
Mean IoU: 0.8804334998130798
Mean IoU (with mask): 0.5484738349914551
Mean IoU (no mask): 0.9025641083717346

In [None]:
0.5
*******************
***** Class 0 *****
*******************
Mean dice coeff: 0.7198657393455505
Mean dice coeff (with mask): 0.46796393394470215
Mean dice coeff (no mask): 0.741094708442688
Mean IoU: 0.7083962559700012
Mean IoU (with mask): 0.32039639353752136
Mean IoU (no mask): 0.741094708442688

*******************
***** Class 1 *****
*******************
Mean dice coeff: 0.895077645778656
Mean dice coeff (with mask): 0.4428459107875824
Mean dice coeff (no mask): 0.907742977142334
Mean IoU: 0.8911918997764587
Mean IoU (with mask): 0.3002220690250397
Mean IoU (no mask): 0.907742977142334

*******************
***** Class 2 *****
*******************
Mean dice coeff: 0.5963675379753113
Mean dice coeff (with mask): 0.6287020444869995
Mean dice coeff (no mask): 0.5739484429359436
Mean IoU: 0.5370190143585205
Mean IoU (with mask): 0.4837566316127777
Mean IoU (no mask): 0.5739484429359436

*******************
***** Class 3 *****
*******************
Mean dice coeff: 0.8759523034095764
Mean dice coeff (with mask): 0.6819044947624207
Mean dice coeff (no mask): 0.8888888955116272
Mean IoU: 0.8667181730270386
Mean IoU (with mask): 0.5341587662696838
Mean IoU (no mask): 0.8888888955116272

In [None]:
# Rank validation images from worst to best dice score.
class_id = None   # None: average over all classes; otherwise a class index
mask_only = True  # only takes effect when class_id is set

if class_id is None:
    # Message fixed: it used to read "Worst scores for class all classes:".
    print('Worst scores across all classes:')
    scores = np.mean(dice_coeff, axis=-1)
else:
    print(f'Worst scores for class {class_id}')
    scores = dice_coeff[:, class_id]

indices = np.argsort(scores) # Indices of worst images (ascending score)

if mask_only and class_id is not None:
    print('Including scores for non-empty ground truth masks only.')
    # Keep the sorted order but drop images whose ground-truth mask for this
    # class is empty.
    indices = indices[mask_not_empty[indices, class_id]]

for i in indices:
    print(f'{i}: {scores[i]}')

In [None]:
# Visualize Image Prediction
img_id = 423
thresh = [0.5, 0.5, 0.5, 0.5]

img_name = val_imgs[img_id]
img, ann = dataset.get_example_from_img_name(img_name)
img_batch = np.expand_dims(img, axis=0)
y = seg_model.predict(img_batch)
plt.figure(figsize=(10, 3))
plt.imshow(visualize_segmentations(np.repeat(img, 3, axis=-1), ann))
plt.show()

for i in range(y.shape[-1]):
    plt.figure(figsize=(12.5, 3))
    plt.imshow(y[0, :, :, i])
    plt.colorbar()

for i in range(y.shape[-1]):
    plt.figure(figsize=(10, 3))
    plt.imshow(y[0, :, :, i] > thresh[i])
    plt.show()

# Save HDF5 Model

In [None]:
# Save the full model to an HDF5 file named after the current time.
# NOTE(review): this rebinds the notebook-wide date_str, which the checkpoint
# naming cells also read; re-running those cells afterwards would pick up this
# new timestamp.
date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
seg_model.save(f'seg_model_{date_str}.h5')

In [None]:
# Shell escape: list the working directory, e.g. to check that the saved model
# file is there.
!ls