In [None]:
# Automatically reload imported project modules (e.g. steel_seg) when their
# source changes, so library edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Initialize and load data

In [None]:
import os
import yaml
import random
import json
from datetime import datetime
import numpy as np
import tensorflow as tf
from steel_seg.utils import (
    dice_coeff_kaggle,
    per_class_dice_coeff,
    rle_to_dense,
    dense_to_rle,
    visualize_segmentations,
    onehottify)
from steel_seg.dataset.severstal_steel_dataset_generator import \
    SeverstalSteelDatasetGenerator
from steel_seg.dataset.dataset_utils import load_annotations, split_data
from steel_seg.model.unet import build_unet_model
from steel_seg.train import (
    class_weighted_binary_crossentropy,
    weighted_binary_crossentropy,
    pixel_map_weighted_binary_crossentropy,
    dice_loss_multi_class,
    dice_coef,
    DiceCoefByClassAndEmptiness,
    eval)
import matplotlib.pyplot as plt
import keras
%matplotlib inline

In [None]:
# TensorFlow runtime switches, set through environment variables:
#   TF_FORCE_GPU_ALLOW_GROWTH      - grow GPU memory on demand instead of
#                                    reserving it all at start-up (the original
#                                    author noted this was needed with CUDA 10).
#   TF_ENABLE_AUTO_MIXED_PRECISION - enable TF's automatic mixed precision.
# NOTE(review): these are set after `import tensorflow`; they only take effect
# if no GPU session/graph has been created yet - confirm when reordering cells.
os.environ.update({
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    'TF_ENABLE_AUTO_MIXED_PRECISION': "1",
})
# Finer-grained mixed-precision switches, intentionally left disabled:
#os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"
#os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING"] = "1"

In [None]:
# Load the project settings. yaml.safe_load builds only plain Python objects
# (dicts, lists, scalars). A bare yaml.load(f) emits a warning on PyYAML >= 5.1
# and raises TypeError on PyYAML >= 6, where the Loader argument is mandatory.
with open('SETTINGS.yaml') as f:
    cfg = yaml.safe_load(f)

In [None]:
# Every annotated image name is a candidate. split_data partitions the names
# into test / validation / training lists. load_cached=True presumably reuses a
# previously saved split - confirm in steel_seg.dataset.dataset_utils.
anns_dict = load_annotations(cfg['TRAIN_ANNOTATIONS_FILE'])
imgs = [img_name for img_name in anns_dict.keys()]

test_imgs, val_imgs, train_imgs = split_data(
    imgs,
    test_split=cfg['TEST_SPLIT'],
    val_split=cfg['VAL_SPLIT'],
    batch_size=cfg['SEGMENTATION_BATCH_SIZE'],
    load_cached=True,
)

In [None]:
def _make_steel_dataset(img_names, is_training, batch_size, **augmentation_kwargs):
    """Create a SeverstalSteelDatasetGenerator over `img_names`.

    The image directory, annotation file, image size and class count always
    come from `cfg`. The batch size and the augmentation / class-balancing
    options vary per split and are forwarded unchanged.
    """
    return SeverstalSteelDatasetGenerator(
        img_names,
        is_training=is_training,
        train_img_dir=cfg['TRAIN_IMAGE_DIR'],
        train_anns_file=cfg['TRAIN_ANNOTATIONS_FILE'],
        img_height=cfg['IMG_HEIGHT'],
        img_width=cfg['IMG_WIDTH'],
        num_classes=cfg['NUM_CLASSES'],
        batch_size=batch_size,
        **augmentation_kwargs)


# Validation and test splits disable every augmentation / balancing option.
_no_augmentation = dict(
    brightness_max_delta=None,
    contrast_lower_factor=None,
    contrast_upper_factor=None,
    balance_classes=None,
    max_oversample_rate=None)

train_data = _make_steel_dataset(
    train_imgs,
    is_training=True,
    batch_size=cfg['SEGMENTATION_BATCH_SIZE'],
    brightness_max_delta=cfg['BRIGHTNESS_MAX_DELTA'],
    contrast_lower_factor=cfg['CONTRAST_LOWER_FACTOR'],
    contrast_upper_factor=cfg['CONTRAST_UPPER_FACTOR'],
    balance_classes=cfg['SEGMENTATION_BALANCE_CLASSES'],
    max_oversample_rate=cfg['SEGMENTATION_MAX_OVERSAMPLE_RATE'])

val_data = _make_steel_dataset(
    val_imgs,
    is_training=False,
    batch_size=cfg['SEGMENTATION_BATCH_SIZE'],
    **_no_augmentation)

# One image per batch for the test split.
test_data = _make_steel_dataset(
    test_imgs,
    is_training=False,
    batch_size=1,
    **_no_augmentation)

In [None]:
# count = 0
# for ex in train_data:
#     count += 1
#     if count == 4:
#         break
# img, ann = ex

In [None]:
# c = 1

# plt.figure(figsize=(20, 10))
# plt.imshow(np.repeat(img[c, :, :, :], 3, axis=-1))
# plt.show()
# for i in range(4):
#     plt.figure(figsize=(20, 10))
#     plt.imshow(ann[c, :, :, i])
#     plt.show()

# Build Model

In [None]:
# Number of test batches (the same length the fit call uses as a step count).
# test_data was built with batch_size=1, so this is also the number of test images.
len(test_data)

In [None]:
# Per the segmentation_models README, SM_FRAMEWORK selects keras vs tf.keras
# and must be set before the package is imported.
# NOTE(review): `keras` below is the standalone package and is used to wrap the
# U-Net in the next cells, while SM_FRAMEWORK asks segmentation_models for
# tf.keras. Mixing the two model types can fail - confirm which framework the
# installed segmentation_models version actually uses.
os.environ['SM_FRAMEWORK'] = 'tf.keras'
import segmentation_models as sm
import keras

In [None]:
# The model input is (height, width, channels), i.e. channels-last.
tf.keras.backend.set_image_data_format('channels_last')

In [None]:
# Encoder backbone of the U-Net. Alternatives such as 'mobilenetv2' are listed at
# https://github.com/qubvel/segmentation_models#Backbones
backbone = 'resnet34'

# ImageNet-pretrained encoder with one sigmoid output channel per class, so
# each class is an independent binary mask (matching the binary losses below).
base_model = sm.Unet(
    backbone,
    classes=cfg['NUM_CLASSES'],
    encoder_weights='imagenet',
    activation='sigmoid',
)

In [None]:
# Wrap the pretrained U-Net so that it accepts single-channel images:
# x / 127.5 - 1.0 rescales the input (to [-1, 1] assuming pixel values in
# [0, 255]), and the grey channel is tiled to the 3 channels the encoder takes.
# TODO confirm that this is the correct preprocessing; segmentation_models
# exposes sm.get_preprocessing(backbone) for the backbone's own convention.
# The input tensor is deliberately not named `input`: a global with that name
# would shadow the Python builtin for the rest of the kernel session.
img_input = keras.Input(shape=(cfg['IMG_HEIGHT'], cfg['IMG_WIDTH'], 1))
# Necessary to wrap in keras.layers.Lambda so that save_model works
x = keras.layers.Lambda(lambda t: tf.tile(t / 127.5 - 1.0, [1, 1, 1, 3]))(img_input)
output = base_model(x)
seg_model = keras.Model(inputs=[img_input], outputs=[output])

# Prefix of the checkpoint directory / file names for this architecture.
model_checkpoint_name = 'resnet34_unet_pretrained'

In [None]:
# Per-class weights passed to pixel_map_weighted_binary_crossentropy. They are
# uniform, with one entry per class, so the list tracks cfg['NUM_CLASSES']
# instead of a hard-coded 4.
cls_weights = [1.0] * cfg['NUM_CLASSES']

In [None]:
batch_size = cfg['SEGMENTATION_BATCH_SIZE']

# set class weights for dice_loss (0, 1, 2, 3, background)
#dice_loss = sm.losses.DiceLoss(class_weights=np.array([1.0, 1.0, 1.0, 1.0, 1.0]))
# JaccardLoss
# See https://github.com/qubvel/segmentation_models/blob/master/examples/multiclass%20segmentation%20(camvid).ipynb
# See: https://github.com/qubvel/segmentation_models/blob/master/segmentation_models/losses.py

# One DiceCoefByClassAndEmptiness metric per (emptiness flag, class) pair.
# The order is all empty_masks_only=True metrics for classes 0..N-1, then all
# empty_masks_only=False ones. Names follow the pattern
# dice_coef_on_{empty|non_empty}_cls<k>. Generating them from
# cfg['NUM_CLASSES'] keeps the list in sync with the model's class count.
per_class_dice_metrics = [
    DiceCoefByClassAndEmptiness(
        cls_id=cls_id,
        empty_masks_only=empty_masks_only,
        batch_size=batch_size,
        name=f'dice_coef_on_{"empty" if empty_masks_only else "non_empty"}_cls{cls_id}')
    for empty_masks_only in (True, False)
    for cls_id in range(cfg['NUM_CLASSES'])
]

# NOTE(review): tf.train.AdamOptimizer is the TF1 optimizer API; it is not
# available under TF2 (tf.keras.optimizers.Adam is the replacement there).
seg_model.compile(
    optimizer=tf.train.AdamOptimizer(0.0001),
    loss=pixel_map_weighted_binary_crossentropy(cls_weights),  # alternative: dice_loss_multi_class
    metrics=[
        tf.keras.metrics.BinaryAccuracy(),
        dice_coef(batch_size=batch_size),
        *per_class_dice_metrics,
    ],
)

# Load Initial Weights

In [None]:
# List the run directories under checkpoints/. The cells below name each one
# <model_checkpoint_name>_<date_str>.
!ls checkpoints

In [None]:
# Display the name prefix that, combined with date_str, selects the checkpoint
# directory to resume from.
model_checkpoint_name

In [None]:
# Timestamp identifying the run to resume. The fixed value points at an
# existing run's checkpoint directory. The commented-out line would generate a
# fresh timestamp instead.
#date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
date_str = '20191005-175252'

In [None]:
checkpoint_name = f'{model_checkpoint_name}_{date_str}'
checkpoint_path = f'checkpoints/{checkpoint_name}/cp-{checkpoint_name}' + '-{epoch:04d}.ckpt'
checkpoint_dir = os.path.dirname(checkpoint_path)

# Resume from the most recent weights saved in checkpoint_dir, if there are any.
# tf.train.latest_checkpoint reads the `checkpoint` state file that
# ModelCheckpoint/save_weights maintain, and returns the checkpoint *prefix*
# (".../cp-<run>-<epoch>.ckpt"). Taking the last entry of a sorted
# os.listdir() would pick the ".ckpt.index" file instead, which is not a
# valid prefix for load_weights. A missing directory means "no checkpoints"
# rather than raising FileNotFoundError.
initial_epoch = 0
latest_checkpoint = (
    tf.train.latest_checkpoint(checkpoint_dir)
    if os.path.isdir(checkpoint_dir) else None)
if latest_checkpoint is None:
    print('No checkpoints found. Starting from scratch.')
else:
    print(f'Loading weights from {latest_checkpoint}')
    # The prefix ends in "-<epoch:04d>.ckpt": drop the extension, then take the
    # text after the final '-'.
    last_epoch = os.path.splitext(latest_checkpoint)[0].rsplit('-', 1)[-1]
    initial_epoch = int(last_epoch)
    seg_model.load_weights(latest_checkpoint)

# Optionally start a new run under a new model name

In [None]:
# Start a new run. seg_model is not touched, so any weights loaded above are
# kept. Checkpoints and logs go under a fresh name/timestamp, and epochs are
# counted from zero.
model_checkpoint_name = 'resnet34_unet_pretrained_imgaug'
date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
initial_epoch = 0

# Train

In [None]:
# Checkpoints for this run live in checkpoints/<run>/ as cp-<run>-<epoch>.ckpt.
# The literal `{epoch:04d}` placeholder is left for ModelCheckpoint to fill in.
checkpoint_name = '{}_{}'.format(model_checkpoint_name, date_str)
checkpoint_path = 'checkpoints/{0}/cp-{0}-{{epoch:04d}}.ckpt'.format(checkpoint_name)
checkpoint_dir = os.path.dirname(checkpoint_path)

In [None]:
# Ensure this run's checkpoint directory exists (no error if it already does).
os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
# Keep only the weights of the epoch with the highest validation dice.
# NOTE(review): 'val__dice_coef' (double underscore) must equal "val_" plus the
# metric name Keras reports for dice_coef(...). If it does not, Keras skips
# saving and no checkpoint is written - confirm against the training log.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='val__dice_coef',  # alternative: 'val_loss'
    mode='max',  # alternative: 'auto'
    save_best_only=True,
    save_weights_only=True,
    verbose=1)

# TensorBoard logs are grouped by run name and starting epoch.
logdir = 'logs/{}-{}'.format(checkpoint_name, initial_epoch)
callbacks = [tf.keras.callbacks.TensorBoard(log_dir=logdir), checkpoint_cb]

# One full pass over train_data / val_data per epoch, with batches produced by
# 6 worker processes.
results = seg_model.fit_generator(
    train_data,
    steps_per_epoch=len(train_data),
    epochs=400,
    initial_epoch=initial_epoch,
    validation_data=val_data,
    validation_steps=len(val_data),
    validation_freq=1,  # alternative: 3
    callbacks=callbacks,
    workers=6,
    use_multiprocessing=True,
    verbose=2)

# Tune Thresholds on Validation Set

In [None]:
def per_class_binary_cross_entropy(y_pred, y_true):
    """Mean binary cross entropy of one image, computed separately per class.

    Args:
        y_pred: Predicted probabilities, shape (H, W, num_classes).
        y_true: Ground-truth mask of the same shape (presumably 0/1 values).

    Returns:
        Array of shape (num_classes,) holding the mean over all pixels of
        -[t * log(p) + (1 - t) * log(1 - p)] for each class (>= 0).
    """
    assert len(y_true.shape) == 3
    # Clip so log() never sees exactly 0 or 1.
    eps = 0.000001
    y_pred = np.clip(y_pred, eps, 1 - eps)
    per_pixel_class_log_likelihood = \
        y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)
    # Cross entropy is the *negative* log-likelihood. Without the minus sign the
    # value is <= 0 and "lower is better" comparisons are inverted.
    per_class_cross_entropy = -np.mean(per_pixel_class_log_likelihood, axis=(0, 1))
    return per_class_cross_entropy

def per_class_iou_score(y_pred, y_true):
    """Intersection-over-union for each class of one (H, W, C) mask pair.

    Sums are taken over the two spatial axes. A class where both the prediction
    and the ground truth are empty (union of zero) scores a perfect 1.0.
    """
    assert len(y_true.shape) == 3
    spatial_axes = (0, 1)
    overlap = np.sum(y_pred * y_true, axis=spatial_axes)
    combined = np.sum(y_pred, axis=spatial_axes) + np.sum(y_true, axis=spatial_axes) - overlap

    return np.array([
        1.0 if total == 0 else shared / total  # empty union: both masks empty
        for shared, total in zip(overlap, combined)
    ])

def per_class_mask_not_empty(y_true):
    """Return one boolean per class of an (H, W, C) mask.

    The flag is True when the class's pixel sum exceeds 0.5, i.e. the
    ground-truth mask has foreground.
    """
    pixel_totals = np.sum(y_true, axis=(0, 1))
    return pixel_totals > 0.5

def apply_tta(img):
    """Build a flip test-time-augmentation batch from one (H, W, C) image.

    Returns a (4, H, W, C) stack in this order: original, left-right flip,
    top-bottom flip, and flip along both axes. combine_tta_preds undoes the
    flips in the same order.
    """
    _height, _width, _channels = img.shape  # unpacking fails unless img is 3-D
    variants = (
        img,
        img[:, ::-1, :],     # mirror left-right
        img[::-1, :, :],     # mirror top-bottom
        img[::-1, ::-1, :],  # mirror both
    )
    return np.stack(variants)

def combine_tta_preds(y_tta):
    """Undo the flips applied by apply_tta and average the four predictions.

    Args:
        y_tta: (4, H, W, C) predictions for the batch produced by apply_tta.

    Returns:
        (1, H, W, C) mean prediction in the original orientation.
    """
    # Each flip is its own inverse, so flipping again restores the orientation.
    unflipped = np.stack([
        y_tta[0],
        y_tta[1][:, ::-1, :],
        y_tta[2][::-1, :, :],
        y_tta[3][::-1, ::-1, :],
    ])
    return unflipped.mean(axis=0, keepdims=True)

def eval_segmentation(
    model,
    dataset,
    img_list,
    thresholds=None,
    num_classes=4):
    """Evaluate a segmentation model image by image, using flip TTA.

    For each image name:
    - load the example with dataset.get_example_from_img_name;
    - predict on the 4-way flip batch from apply_tta and average back with
      combine_tta_preds;
    - binarize each class channel with `thresholds` and score against the
      ground-truth mask.

    Overall and per-class summaries are printed, the per-class ones split by
    whether the ground-truth mask is empty.

    Args:
        model: Object with a Keras-style `predict(batch)` method.
        dataset: Object with `get_example_from_img_name(name) -> (img, ann)`.
        img_list: Image names to evaluate.
        thresholds: Per-class probability thresholds. Defaults to 0.5 for each
            of the `num_classes` classes.
        num_classes: Number of output channels / classes.

    Returns:
        Tuple (binary_cross_entropy, dice_coeff, iou_score, mask_not_empty).
        Each array has shape (len(img_list), num_classes).
    """
    num_imgs = len(img_list)
    binary_cross_entropy = np.zeros((num_imgs, num_classes), dtype=np.float32)
    dice_coeff = np.zeros((num_imgs, num_classes), dtype=np.float32)
    iou_score = np.zeros((num_imgs, num_classes), dtype=np.float32)
    # Builtin bool: the np.bool alias was removed in NumPy 1.24.
    mask_not_empty = np.zeros((num_imgs, num_classes), dtype=bool)

    if thresholds is None:
        thresholds = [0.5] * num_classes
    # Broadcasts against the trailing class axis of the predictions.
    thresholds = np.array(thresholds)

    for i, img_name in enumerate(img_list):
        img, ann = dataset.get_example_from_img_name(img_name)

        # Predict the original plus three flipped copies and average them.
        y = combine_tta_preds(model.predict(apply_tta(img)))

        # Binarize predictions per class.
        y_bin = (y > thresholds).astype(np.uint8)

        binary_cross_entropy[i, :] = per_class_binary_cross_entropy(y[0], ann)
        dice_coeff[i, :] = per_class_dice_coeff(y_bin[0], ann)
        iou_score[i, :] = per_class_iou_score(y_bin[0], ann)
        mask_not_empty[i, :] = per_class_mask_not_empty(ann)

    _print_segmentation_report(binary_cross_entropy, dice_coeff, iou_score, mask_not_empty)
    return binary_cross_entropy, dice_coeff, iou_score, mask_not_empty


def _print_segmentation_report(binary_cross_entropy, dice_coeff, iou_score, mask_not_empty):
    """Print overall means, then per-class means split by mask emptiness."""
    print(f'Mean dice coeff: {np.mean(dice_coeff)}')
    print(f'Mean binary cross entropy: {np.mean(binary_cross_entropy)}')
    print(f'Mean IoU: {np.mean(iou_score)}')

    # Iterate over the arrays' own class axis rather than the global cfg, so
    # the report always matches the num_classes that was evaluated.
    for i in range(dice_coeff.shape[1]):
        has_mask = mask_not_empty[:, i]
        print('*******************')
        print(f'***** Class {i} *****')
        print('*******************')
        print(f'Mean dice coeff: {np.mean(dice_coeff[:, i])}')
        print(f'Mean dice coeff (with mask): {np.mean(dice_coeff[has_mask, i])}')
        print(f'Mean dice coeff (no mask): {np.mean(dice_coeff[~has_mask, i])}')
        print(f'Mean IoU: {np.mean(iou_score[:, i])}')
        print(f'Mean IoU (with mask): {np.mean(iou_score[has_mask, i])}')
        print(f'Mean IoU (no mask): {np.mean(iou_score[~has_mask, i])}')
        print('')

In [None]:
# Sweep candidate probability thresholds on the validation set. One shared
# threshold is used for every class; extend the list (e.g. 0.6, 0.7) to compare.
# The class count comes from cfg rather than a literal 4.
num_classes = cfg['NUM_CLASSES']
for threshold in [0.5]:
    thresholds = [threshold] * num_classes
    print(f'\n\nThresholds: {thresholds}')
    binary_cross_entropy, dice_coeff, iou_score, mask_not_empty = eval_segmentation(
        seg_model,
        val_data,
        val_imgs,
        thresholds=thresholds,
        num_classes=num_classes)

# Evaluate on Test Set

In [None]:
# Replace the in-memory model with a previously exported HDF5 model.
# custom_objects supplies the `tf` name used inside the preprocessing Lambda.
# NOTE(review): this discards whatever seg_model was trained above, and the file
# name must match one actually written by the save cells at the end.
saved_model_path = 'resnet_seg_model_20191009-220108.h5'
seg_model = keras.models.load_model(saved_model_path, custom_objects={'tf': tf})

In [None]:
# Score the model on the held-out test split with a 0.5 threshold for every
# class. The class count comes from cfg rather than a literal 4.
binary_cross_entropy, dice_coeff, iou_score, mask_not_empty = eval_segmentation(
    seg_model,
    test_data,
    test_imgs,
    thresholds=[0.5] * cfg['NUM_CLASSES'],
    num_classes=cfg['NUM_CLASSES'])

In [None]:
# Rank evaluated images from worst to best dice score. With class_id=None the
# score is averaged over all classes; otherwise it is that class's score alone.
# With mask_only=True and a class selected, only images whose ground-truth mask
# for that class is non-empty are listed.
class_id = None
mask_only = True

if class_id is not None:
    print(f'Worst scores for class {class_id}')
    scores = dice_coeff[:, class_id]
else:
    print('Worst scores for all classes:')
    scores = np.mean(dice_coeff, axis=-1)

# Ascending sort: worst images first.
indices = np.argsort(scores)

if mask_only and class_id is not None:
    print('Including scores for non-empty ground truth masks only.')
    # Boolean-mask the sorted indices; the worst-first order is preserved.
    indices = indices[mask_not_empty[indices, class_id]]

for i in indices:
    print(f'{i}: {scores[i]}')

In [None]:
# Visualize Image Prediction
# img_id indexes test_imgs, the same order as the rows scored above, so one of
# the worst indices printed above can be pasted here. The cell shows the ground
# truth, then per-class probability maps, then per-class thresholded masks.
img_id = 145
thresh = [0.5] * cfg['NUM_CLASSES']
# Use the same flip TTA as eval_segmentation so the maps shown are the ones
# that produced the scores above. Set False for a single forward pass.
use_tta = True

img_name = test_imgs[img_id]
img, ann = test_data.get_example_from_img_name(img_name)
if use_tta:
    y = combine_tta_preds(seg_model.predict(apply_tta(img)))
else:
    y = seg_model.predict(np.expand_dims(img, axis=0))

plt.figure(figsize=(10, 3))
plt.imshow(visualize_segmentations(np.repeat(img, 3, axis=-1), ann))
plt.title(f'{img_name}: ground truth')
plt.show()

for i in range(y.shape[-1]):
    plt.figure(figsize=(12.5, 3))
    plt.imshow(y[0, :, :, i])
    plt.colorbar()
    plt.title(f'Class {i}: predicted probability')
    plt.show()

for i in range(y.shape[-1]):
    plt.figure(figsize=(10, 3))
    plt.imshow(y[0, :, :, i] > thresh[i])
    plt.title(f'Class {i}: prediction > {thresh[i]}')
    plt.show()

# Save HDF5 Model

In [None]:
# Export the model to a timestamped HDF5 file via the model's own save().
date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
hdf5_path = f'resnet_seg_model_imgaug_{date_str}.h5'
seg_model.save(hdf5_path)

In [None]:
# Export again through tf.keras.models.save_model, without optimizer state,
# to a separately named timestamped file.
date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
tf_hdf5_path = f'tf_resnet_seg_model_{date_str}.h5'
tf.keras.models.save_model(seg_model, tf_hdf5_path, include_optimizer=False)