Train a UNet segmentation model on the Severstal steel defect dataset.

# Initialize and load data

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import yaml
from datetime import datetime
import numpy as np
import tensorflow as tf
from steel_seg.utils import (
    dice_coeff_kaggle,
    per_class_dice_coeff,
    rle_to_dense,
    dense_to_rle,
    visualize_segmentations,
    onehottify)
from steel_seg.dataset.severstal_steel_dataset import SeverstalSteelDataset
from steel_seg.dataset.severstal_steel_dataset_patch_generator import \
    SeverstalSteelDatasetPatchGenerator
from steel_seg.model.unet import build_unet_model
from steel_seg.model.deep_q_postprocessor import build_deep_q_model
from steel_seg.train import (
    class_weighted_binary_crossentropy,
    weighted_binary_crossentropy,
    pixel_map_weighted_binary_crossentropy,
    dice_loss_multi_class,
    dice_coef,
    DiceCoefByClassAndEmptiness,
    eval)
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# TensorFlow runtime switches, applied through environment variables.
# Original author's note: "Necessary for CUDA 10 or something?"
# NOTE(review): TF_FORCE_GPU_ALLOW_GROWTH presumably lets TensorFlow grow GPU
# memory on demand, and the TF_ENABLE_AUTO_MIXED_PRECISION* flags presumably
# turn on automatic mixed precision -- confirm against the docs of the
# installed TensorFlow build.
os.environ.update({
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    "TF_ENABLE_AUTO_MIXED_PRECISION": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING": "1",
})

In [None]:
# Parse the project settings into `cfg`, which the cells below index by key
# (e.g. cfg['PATCH_SIZE'], cfg['NUM_CLASSES']).
# yaml.safe_load is used because yaml.load() without an explicit Loader is
# deprecated, is rejected outright by newer PyYAML releases, and may construct
# arbitrary Python objects from the file.
with open('SETTINGS.yaml') as f:
    cfg = yaml.safe_load(f)

In [None]:
# The full-image dataset is only instantiated to obtain its train/validation
# image lists, so the patch generators below reuse exactly the same split.
dataset = SeverstalSteelDataset.init_from_config('SETTINGS.yaml')
train_imgs, val_imgs = (
    dataset.get_image_list(split) for split in ('training', 'validation'))
# Drop the reference: only the two image lists are needed from here on.
dataset = None

In [None]:
# One patch generator per split, both configured from SETTINGS.yaml; only the
# training generator is created with is_training=True.
train_data_gen, val_data_gen = (
    SeverstalSteelDatasetPatchGenerator.init_from_config(
        'SETTINGS.yaml',
        split_imgs,
        is_training=training_flag)
    for split_imgs, training_flag in ((train_imgs, True), (val_imgs, False)))

# Build model

In [None]:
# UNet over square, single-channel patches whose side length and class count
# come from the settings file. Feature widths double at each of the 4 levels
# (32, 64, 128, 256).
seg_model = build_unet_model(
    num_classes=cfg['NUM_CLASSES'],
    img_channels=1,
    img_height=cfg['PATCH_SIZE'],
    img_width=cfg['PATCH_SIZE'],
    num_layers=4,
    num_features=[32 * 2 ** depth for depth in range(4)],
    kernel_size=(3, 3),
    pool_size=(2, 2),
    kernel_initializer='he_normal',
    activation=tf.keras.activations.elu,
    drop_prob=0.5)
# Prefix used when composing checkpoint and log names for this model.
model_checkpoint_name = 'patches'

In [None]:
# train_imgs = dataset.get_image_list('training')

# cls_pixel_counts = np.array([0, 0, 0, 0], dtype=np.float64)
# total_pixels = 0.0
# for img_name in train_imgs:
#     img, ann = dataset.get_example_from_img_name(img_name)
#     img_pixel_counts = np.sum(ann, axis=(0, 1))
#     cls_pixel_counts += img_pixel_counts
#     total_pixels += ann.size
# cls_pixel_counts

In [None]:
# cls_weights = total_pixels / cls_pixel_counts
# cls_weights

In [None]:
#cls_weights = [5274.42494602, 26340.24368548, 158.51888478, 752.96782738]
#cls_weights = np.array([5274.42494602, 26340.24368548, 158.51888478, 752.96782738]) / 20
#cls_weights = [263.7212473, 1317.01218427, 7.92594424, 37.64839137]

In [None]:
# cls_2_weight = 10.0
# cls_0_weight = (cls_pixel_counts[2] * cls_2_weight) / cls_pixel_counts[0]
# cls_1_weight = (cls_pixel_counts[2] * cls_2_weight) / cls_pixel_counts[1]
# cls_3_weight = (cls_pixel_counts[2] * cls_2_weight) / cls_pixel_counts[3]
# cls_weights = [cls_0_weight, cls_1_weight, cls_2_weight, cls_3_weight]
#cls_weights = [332.7316460371491, 1661.6470474376874, 10.0, 47.50019711767978]
#cls_weights = [30.0, 40.0, 10.0, 20.0]
#cls_weights = [40.0, 50.0, 5.0, 20.0]
# Class weights currently in use (four entries, presumably one per defect
# class -- TODO confirm ordering). They are passed to
# pixel_map_weighted_binary_crossentropy when the model is compiled.
cls_weights = [2.0, 4.0, 1.0, 2.0]

In [None]:
# Show the architecture summary of the freshly built model.
seg_model.summary()

In [None]:
# Configure training: TF1 Adam optimizer (constructed with 0.0001), the
# pixel-map weighted binary cross-entropy loss parameterised by cls_weights,
# and two metrics (Keras binary accuracy plus the project's dice coefficient,
# which is built with the configured batch size).
seg_model.compile(
    optimizer=tf.train.AdamOptimizer(0.0001),
    loss=pixel_map_weighted_binary_crossentropy(cls_weights),
    metrics=[
        tf.keras.metrics.BinaryAccuracy(),
        dice_coef(batch_size=cfg['BATCH_SIZE']),
        # Optional per-class dice metrics, split by empty / non-empty ground-truth
        # masks. Currently disabled; un-comment individual lines to track them.
        #DiceCoefByClassAndEmptiness(cls_id=0, empty_masks_only=True, batch_size=cfg['BATCH_SIZE'], name='dice_coef_on_empty_cls0'),
        #DiceCoefByClassAndEmptiness(cls_id=1, empty_masks_only=True, batch_size=cfg['BATCH_SIZE'], name='dice_coef_on_empty_cls1'),
        #DiceCoefByClassAndEmptiness(cls_id=2, empty_masks_only=True, batch_size=cfg['BATCH_SIZE'], name='dice_coef_on_empty_cls2'),
        #DiceCoefByClassAndEmptiness(cls_id=3, empty_masks_only=True, batch_size=cfg['BATCH_SIZE'], name='dice_coef_on_empty_cls3'),
        #DiceCoefByClassAndEmptiness(cls_id=0, empty_masks_only=False, batch_size=cfg['BATCH_SIZE'], name='dice_coef_on_non_empty_cls0'),
        #DiceCoefByClassAndEmptiness(cls_id=1, empty_masks_only=False, batch_size=cfg['BATCH_SIZE'], name='dice_coef_on_non_empty_cls1'),
        #DiceCoefByClassAndEmptiness(cls_id=2, empty_masks_only=False, batch_size=cfg['BATCH_SIZE'], name='dice_coef_on_non_empty_cls2'),
        #DiceCoefByClassAndEmptiness(cls_id=3, empty_masks_only=False, batch_size=cfg['BATCH_SIZE'], name='dice_coef_on_non_empty_cls3'),
    ]
)

# Load Initial Weights

In [None]:
!ls checkpoints

In [None]:
# Echo the checkpoint name prefix currently in effect (compare it with the
# directory listing before choosing which run to resume).
model_checkpoint_name

In [None]:
# Timestamp suffix of the checkpoint run to resume from.
# The None assignment is immediately overridden by the next line.
# NOTE(review): presumably the second line is commented out when no earlier
# run should be resumed, so that no matching checkpoint directory is found --
# confirm this is the intended toggle.
date_str = None
date_str = '20190928-210154'

In [None]:
# Compose the checkpoint location for the selected run. The "{epoch:04d}"
# placeholder is kept literal so it can be filled in per epoch later.
checkpoint_name = '{}_{}'.format(model_checkpoint_name, date_str)
checkpoint_path = (
    'checkpoints/{0}/cp-{0}'.format(checkpoint_name) + '-{epoch:04d}.ckpt')
checkpoint_dir = os.path.dirname(checkpoint_path)

# Restore the most recent weights of that run, if there are any, and continue
# epoch numbering from the epoch encoded in the checkpoint file name.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
initial_epoch = 0
if latest_checkpoint is not None:
    print(f'Loading weights from {latest_checkpoint}')
    # File names end in "-<epoch>.ckpt": take what follows the last dash and
    # cut it at the first dot.
    last_epoch = latest_checkpoint.rsplit('-', 1)[-1].partition('.')[0]
    initial_epoch = int(last_epoch)
    seg_model.load_weights(latest_checkpoint)
else:
    print('No checkpoints found. Starting from scratch.')

## Use new model name?

In [None]:
# Start a new checkpoint/log series: fresh timestamp and a new name prefix.
# The epoch counter is deliberately left as-is; un-comment the last line to
# restart epoch numbering from zero as well.
date_str = f'{datetime.now():%Y%m%d-%H%M%S}'
model_checkpoint_name = 'patches_lower_weight'
#initial_epoch = 0

# Train

In [None]:
# Recompute the checkpoint location from the (possibly updated) name prefix
# and timestamp. The "{epoch:04d}" placeholder stays literal so the checkpoint
# callback can fill it in for each saved epoch.
checkpoint_name = '{}_{}'.format(model_checkpoint_name, date_str)
checkpoint_path = (
    'checkpoints/{0}/cp-{0}'.format(checkpoint_name) + '-{epoch:04d}.ckpt')
checkpoint_dir = os.path.dirname(checkpoint_path)

In [None]:
# Create checkpoint callback
# Checkpoint callback: writes weights only, and only when the monitored
# validation loss improves.
# NOTE(review): validation runs only every 3rd epoch (validation_freq below),
# so 'val_loss' is presumably unavailable on the other epochs and the
# callback cannot save there -- confirm this is acceptable.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='val_loss', #'val__dice_coef'
    save_best_only=True,
    mode='auto',#'auto', 'max',
    save_weights_only=True,
    verbose=1)

# TensorBoard logs go to a directory tagged with the run name and the epoch
# this session starts from, so resumed sessions get separate log folders.
logdir = f'logs/{checkpoint_name}-{initial_epoch}'
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir=logdir),
    checkpoint_cb,
]

# Train up to epoch 400, resuming the epoch count from initial_epoch. One
# epoch is a full pass over each generator (steps taken from their lengths).
results = seg_model.fit(
    train_data_gen,
    epochs=400,
    verbose=2,
    callbacks=callbacks,
    validation_data=val_data_gen,
    steps_per_epoch=len(train_data_gen),
    validation_steps=len(val_data_gen),
    validation_freq=3,
    initial_epoch=initial_epoch,
    workers=6,
    use_multiprocessing=True)

# Evaluate

In [None]:
from steel_seg.dataset.severstal_steel_dataset_patch_generator import get_image_patches

def per_class_binary_cross_entropy(y_pred, y_true):
    """Return the mean binary cross-entropy of every class channel.

    Args:
        y_pred: Predicted probabilities, same layout as ``y_true``. Values are
            clipped to ``[eps, 1 - eps]`` so the logarithms stay finite.
        y_true: Ground-truth mask with exactly three dimensions; the mean is
            taken over the first two axes, leaving one value per entry of the
            last axis.

    Returns:
        1-D array of non-negative cross-entropy values, one per class channel.
    """
    assert len(y_true.shape) == 3
    eps = 0.000001
    y_pred = np.clip(y_pred, eps, 1 - eps)
    per_pixel_log_likelihood = \
        y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)
    # Cross-entropy is the NEGATIVE mean log-likelihood; without the minus
    # sign the reported "loss" is <= 0 and improves in the wrong direction.
    return -np.mean(per_pixel_log_likelihood, axis=(0, 1))

def per_class_iou_score(y_pred, y_true):
    """Compute intersection-over-union separately for each class channel.

    Both inputs are reduced over their first two axes; ``y_true`` must have
    exactly three dimensions. A channel in which prediction and ground truth
    are both empty (union of zero) scores 1.0.

    Returns:
        1-D array holding one IoU value per class channel.
    """
    assert len(y_true.shape) == 3
    overlap = np.sum(y_pred * y_true, axis=(0, 1))
    combined = np.sum(y_pred, axis=(0, 1)) + np.sum(y_true, axis=(0, 1)) - overlap

    return np.array([
        1.0 if total == 0 else shared / total  # empty vs. empty counts as perfect
        for shared, total in zip(overlap, combined)
    ])

def per_class_mask_not_empty(y_true):
    """Tell, for each class channel, whether the ground-truth mask is non-empty.

    The mask is summed over its first two axes; a channel counts as non-empty
    when that sum exceeds 0.5.
    """
    per_class_area = np.sum(y_true, axis=(0, 1))
    return per_class_area > 0.5

def predict_patches(model, img, patch_size, num_patches_per_image, num_classes):
    """Predict a full image patch by patch and merge the patch outputs.

    The image is cut into patches with ``get_image_patches``, all patches are
    predicted in one ``model.predict`` call, and each patch prediction is
    pasted back at its horizontal offset (``index * x_step_size``). Where
    patches overlap, the element-wise maximum is kept.

    Args:
        model: Object exposing ``predict`` for a stacked batch of patches.
        img: Image array with three dimensions (height, width, channels).
        patch_size: Width in pixels of the slot each patch is pasted into.
        num_patches_per_image: Forwarded to ``get_image_patches``.
        num_classes: Size of the last axis of the merged prediction.

    Returns:
        float32 array of shape ``(1, height, width, num_classes)``.
    """
    h, w, _ = img.shape
    img_patches, x_step_size = get_image_patches(img, patch_size, num_patches_per_image)
    y_patches = model.predict(np.stack(img_patches))

    # Keep a running maximum on a single canvas instead of allocating one
    # full-size canvas per patch and reducing afterwards; the result is the
    # same but peak memory no longer grows with the number of patches.
    # The canvas starts at zero, so uncovered pixels (and any negative model
    # output) end up as 0 -- assumes predictions are non-negative scores.
    merged = np.zeros((h, w, num_classes), dtype=np.float32)
    for i, patch_pred in enumerate(y_patches):
        x_start = i * x_step_size
        region = merged[:, x_start:x_start + patch_size, :]
        region[...] = np.maximum(region, patch_pred)
    # Restore the leading batch axis callers index with y[0, ...].
    return merged[np.newaxis]

def eval_segmentation(
    model,
    dataset,
    img_list,
    patch_size,
    num_patches_per_image,
    thresholds=None,
    num_classes=4):
    """Score the model on every image of ``img_list``, class by class.

    For each image the full-size prediction is assembled with
    ``predict_patches``, binarized with the per-class thresholds, and compared
    against the annotation returned by ``dataset.get_example_from_img_name``.

    Args:
        model: Model handed to ``predict_patches``.
        dataset: Object providing ``get_example_from_img_name(img_name)``,
            which returns an ``(image, annotation)`` pair.
        img_list: Sequence of image names to evaluate.
        patch_size: Forwarded to ``predict_patches``.
        num_patches_per_image: Forwarded to ``predict_patches``.
        thresholds: Per-class binarization thresholds; defaults to 0.5 for
            each of the ``num_classes`` classes.
        num_classes: Number of class channels.

    Returns:
        Tuple ``(binary_cross_entropy, dice_coeff, iou_score, mask_not_empty)``
        of arrays shaped ``(len(img_list), num_classes)``. The first three are
        float32; the last is boolean and marks non-empty ground-truth masks.
    """
    num_images = len(img_list)
    binary_cross_entropy = np.zeros((num_images, num_classes), dtype=np.float32)
    dice_coeff = np.zeros((num_images, num_classes), dtype=np.float32)
    iou_score = np.zeros((num_images, num_classes), dtype=np.float32)
    # Builtin bool: the np.bool alias was removed from recent NumPy releases.
    mask_not_empty = np.zeros((num_images, num_classes), dtype=bool)

    if thresholds is None:
        # Follow num_classes rather than assuming exactly four classes.
        thresholds = [0.5] * num_classes
    thresholds = np.asarray(thresholds)

    for i, img_name in enumerate(img_list):
        img, ann = dataset.get_example_from_img_name(img_name)
        y = predict_patches(model, img, patch_size, num_patches_per_image, num_classes)

        # Binarize predictions; thresholds broadcast along the class axis.
        y_bin = (y > thresholds).astype(np.uint8)

        # Index 0 drops the leading batch axis of the merged prediction.
        binary_cross_entropy[i, :] = per_class_binary_cross_entropy(y[0, :, :, :], ann)
        dice_coeff[i, :] = per_class_dice_coeff(y_bin[0, :, :, :], ann)
        iou_score[i, :] = per_class_iou_score(y_bin[0, :, :, :], ann)
        mask_not_empty[i, :] = per_class_mask_not_empty(ann)

    return binary_cross_entropy, dice_coeff, iou_score, mask_not_empty

In [None]:
# Evaluate on the validation images. The validation patch generator is passed
# as the dataset because it is only used to fetch (image, annotation) pairs
# by name. The class count comes from the settings file, like everywhere
# else in this notebook, instead of a hard-coded 4.
binary_cross_entropy, dice_coeff, iou_score, mask_not_empty = eval_segmentation(
    seg_model,
    val_data_gen,
    val_imgs,
    patch_size=cfg['PATCH_SIZE'],
    num_patches_per_image=cfg['NUM_PATCHES_PER_IMAGE_VAL'],
    thresholds=None,
    num_classes=cfg['NUM_CLASSES'])

In [None]:
# Headline numbers: each metric averaged over all validation images and all
# classes at once.
print('\n'.join([
    f'Mean dice coeff: {np.mean(dice_coeff)}',
    f'Mean binary cross entropy: {np.mean(binary_cross_entropy)}',
    f'Mean IoU: {np.mean(iou_score)}',
]))

In [None]:
# Per-class report. Dice and IoU are shown three ways: over all images, over
# images whose ground truth contains the class ("with mask"), and over images
# where it does not ("no mask").
for i in range(cfg['NUM_CLASSES']):
    has_mask = mask_not_empty[:, i]
    cls_dice = dice_coeff[:, i]
    cls_iou = iou_score[:, i]

    print('*******************')
    print(f'***** Class {i} *****')
    print('*******************')
    print(f'Mean dice coeff: {np.mean(cls_dice)}')
    print(f'Mean dice coeff (with mask): {np.mean(cls_dice[has_mask])}')
    print(f'Mean dice coeff (no mask): {np.mean(cls_dice[~has_mask])}')
    print(f'Mean IoU: {np.mean(cls_iou)}')
    print(f'Mean IoU (with mask): {np.mean(cls_iou[has_mask])}')
    print(f'Mean IoU (no mask): {np.mean(cls_iou[~has_mask])}')
    print('')

In [None]:
# List validation images from worst to best dice score.
#   class_id:  None ranks by the dice score averaged over all classes;
#              an integer ranks by that class alone.
#   mask_only: when a class is selected, keep only images whose ground truth
#              actually contains that class.
class_id = None
mask_only = True

if class_id is None:
    print('Worst scores for all classes:')
    scores = np.mean(dice_coeff, axis=-1)
else:
    print(f'Worst scores for class {class_id}')
    scores = dice_coeff[:, class_id]

indices = np.argsort(scores)  # ascending order: worst images first

if mask_only and class_id is not None:
    print('Including scores for non-empty ground truth masks only.')
    # Boolean-mask the ranking itself so the worst-first order is preserved.
    indices = indices[mask_not_empty[indices, class_id]]

for i in indices:
    print(f'{i}: {scores[i]}')

In [None]:
# Visualize Image Prediction
# Visualize the prediction for one validation image.
#   img_id: position of the image in val_imgs
#   thresh: per-class cut-off used for the binarized plots at the bottom
img_id = 872
thresh = [0.5, 0.5, 0.5, 0.5]

img_name = val_imgs[img_id]
img, ann = val_data_gen.get_example_from_img_name(img_name)
y = predict_patches(seg_model, img, cfg['PATCH_SIZE'], cfg['NUM_PATCHES_PER_IMAGE_VAL'], cfg['NUM_CLASSES'])
# Ground truth overlay. The image's last axis is repeated 3 times before it is
# handed to visualize_segmentations (presumably grayscale -> RGB -- TODO confirm).
plt.figure(figsize=(10, 3))
plt.imshow(visualize_segmentations(np.repeat(img, 3, axis=-1), ann))
plt.show()

# Raw per-class prediction maps, each with a colorbar.
for i in range(y.shape[-1]):
    plt.figure(figsize=(12.5, 3))
    plt.imshow(y[0, :, :, i])
    plt.colorbar()

# Binarized per-class masks using the thresholds above.
for i in range(y.shape[-1]):
    plt.figure(figsize=(10, 3))
    plt.imshow(y[0, :, :, i] > thresh[i])
    plt.show()

# Save HDF5 Model

In [None]:
# Save the model to an .h5 file in the working directory, tagged with the
# current timestamp so earlier exports are not overwritten.
date_str = f'{datetime.now():%Y%m%d-%H%M%S}'
seg_model.save(f'patches_seg_model_{date_str}.h5')

In [None]:
!ls