In [None]:
import os
import yaml
from datetime import datetime
import numpy as np
import tensorflow as tf
from dataset.severstal_steel_dataset import SeverstalSteelDataset
from model.unet import build_unet_model
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# TensorFlow runtime switches, passed through the process environment.
#  - TF_FORCE_GPU_ALLOW_GROWTH: the original note only says this was
#    "necessary for CUDA 10 or something?" -- TODO confirm it is still needed.
#  - TF_ENABLE_AUTO_MIXED_PRECISION*: presumably turn on automatic mixed
#    precision (graph rewrite + loss scaling); verify against the installed
#    TensorFlow build.
os.environ.update({
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    "TF_ENABLE_AUTO_MIXED_PRECISION": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING": "1",
})

In [None]:
# # To run in half-precision mode on GPU
# dtype='float16'
# K.set_floatx(dtype)

# # default is 1e-7 which is too small for float16.  Without adjusting the epsilon, we will get NaN predictions because of divide by zero problems
# K.set_epsilon(1e-4)

In [None]:
# Build the dataset helper from the project settings file. The same
# 'SETTINGS.yaml' is read again further down to obtain the raw `cfg` dict.
dataset = SeverstalSteelDataset.init_from_config('SETTINGS.yaml')

In [None]:
# Each call returns a (data, batch_count) pair. The data objects are fed to
# model.fit / model.predict below and the batch counts are used as
# steps_per_epoch / validation_steps.
# NOTE(review): presumably the data objects are tf.data pipelines -- confirm
# against SeverstalSteelDataset.create_dataset.
train_data, train_batches = dataset.create_dataset(dataset_type='training')
val_data, val_batches = dataset.create_dataset(dataset_type='validation')

In [None]:
# Load the raw settings dict; the model cell reads IMG_HEIGHT, IMG_WIDTH and
# NUM_CLASSES from it.
# safe_load is used instead of yaml.load: calling yaml.load without an explicit
# Loader is deprecated (and an error in recent PyYAML releases) and can
# construct arbitrary Python objects from the file.
with open('SETTINGS.yaml') as f:
    cfg = yaml.safe_load(f)

In [None]:
from tensorflow.keras import backend as K

# https://gist.github.com/wassname/7793e2058c5c9dacb5212c0ac0b18a8a
# def dice_coef(y_true, y_pred, smooth=1):
#     """
#     Dice = (2*|X & Y|)/ (|X|+ |Y|)
#          =  2*sum(|A*B|)/(sum(A^2)+sum(B^2))
#     ref: https://arxiv.org/pdf/1606.04797v1.pdf
#     """
#     intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
#     return (2. * intersection + smooth) / (K.sum(K.square(y_true),-1) + K.sum(K.square(y_pred),-1) + smooth)


# https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96
def jaccard_distance_loss(y_true, y_pred, smooth=100):
    """
    Smoothed Jaccard (intersection-over-union) distance, reduced over the last axis.

    For every slice along the last axis:

        jaccard = sum(|A*B|) / (sum(|A|) + sum(|B|) - sum(|A*B|))
        loss    = (1 - jaccard) * smooth

    ``smooth`` is added to both the numerator and the denominator before the
    division, so the ratio stays defined when both inputs are all zeros.
    Identical binary masks give a loss of 0. The loss is intended for
    unbalanced data, where the positive class is rare.

    Args:
        y_true: ground-truth tensor.
        y_pred: predicted tensor, same shape as ``y_true``.
        smooth: smoothing constant; also scales the returned loss.

    Returns:
        Tensor of losses with the last axis of the inputs reduced away.

    Ref: https://en.wikipedia.org/wiki/Jaccard_index

    @url: https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96
    @author: wassname
    """
    overlap = K.sum(K.abs(y_true * y_pred), axis=-1)
    total = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
    # |A or B| = |A| + |B| - |A and B|
    union = total - overlap
    smoothed_jaccard = (overlap + smooth) / (union + smooth)
    return (1 - smoothed_jaccard) * smooth

def dice_coef(y_true, y_pred, smooth=1):
    """
    Smoothed Dice coefficient computed over all elements of both tensors.

        dice = (2 * sum(A*B) + smooth) / (sum(A) + sum(B) + smooth)

    Both inputs are flattened first, so the result is a single scalar for the
    whole batch rather than one value per sample.

    Args:
        y_true: ground-truth tensor.
        y_pred: predicted tensor with the same number of elements.
        smooth: constant added to numerator and denominator; keeps the ratio
            defined (and equal to 1) when both inputs are all zeros.

    Returns:
        Scalar tensor holding the Dice coefficient.
    """
    truth_flat, pred_flat = K.flatten(y_true), K.flatten(y_pred)
    overlap = K.sum(truth_flat * pred_flat)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth_flat) + K.sum(pred_flat) + smooth
    return numerator / denominator

def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus ``dice_coef(y_true, y_pred)`` (default smoothing)."""
    score = dice_coef(y_true, y_pred)
    return 1 - score

In [None]:
# Build the U-Net. Input geometry and class count come from SETTINGS.yaml;
# the architecture hyper-parameters are fixed here.
model = build_unet_model(
    # --- input / output shape ---
    img_height=cfg['IMG_HEIGHT'],
    img_width=cfg['IMG_WIDTH'],
    img_channels=1,
    num_classes=cfg['NUM_CLASSES'],
    # --- architecture ---
    # NOTE(review): num_features has 4 entries, matching num_layers=4 --
    # presumably one filter count per level; confirm in build_unet_model.
    num_layers=4,
    num_features=[4, 4, 16, 32],
    kernel_size=(3, 3),
    pool_size=(2, 4),
    # --- layer behaviour ---
    activation=tf.keras.activations.elu,
    kernel_initializer='he_normal',
    drop_prob=0.5,
)

In [None]:
# Show the layer-by-layer summary of the model built above.
model.summary()

In [None]:
# Compile with binary cross-entropy and the TF1 Adam optimizer (lr = 1e-3).
model.compile(
    loss='binary_crossentropy',
    optimizer=tf.train.AdamOptimizer(learning_rate=0.001),
    # dice_coef is currently disabled as a metric; to enable it use
    # metrics=[dice_coef, 'accuracy'].
    metrics=['accuracy'],
)

In [None]:
# Weights-only checkpoints, always written to the same fixed path.
checkpoint_path = "checkpoints/cp.ckpt"
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
)

# Early stopping is currently disabled; to enable it, add this to `callbacks`:
# tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),

# One TensorBoard log directory per run, named by the start timestamp.
logdir = "logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)

callbacks = [tensorboard_callback, cp_callback]

In [None]:
# model.fit_generator(
#     train_data,
#     epochs=10,
#     verbose=2,
#     callbacks=callbacks,
#     validation_data=val_data,
#     validation_freq=1,
#     max_queue_size=10,
#     workers=6,
#     use_multiprocessing=True,
#     shuffle=False,
# )

In [None]:
# Train for 10 epochs. The batch counts returned by create_dataset bound each
# epoch and each validation pass.
results = model.fit(
    x=train_data,
    validation_data=val_data,
    steps_per_epoch=train_batches,
    validation_steps=val_batches,
    validation_freq=2,  # validate on every 2nd epoch only
    epochs=10,
    callbacks=callbacks,
    verbose=2,
)

In [None]:
# Run the model over the validation data one step (batch) at a time.
# Previously every iteration overwrote `y`, so only the final prediction
# survived the loop. Each batch's output is now kept in `val_predictions`;
# `y` is still bound to the most recent batch, as before.
# NOTE(review): whether successive predict(..., steps=1) calls advance through
# val_data or restart from its first batch depends on the TF version and on
# how the dataset was built -- confirm. If they restart, replace the loop
# with a single model.predict(val_data, steps=val_batches).
val_predictions = []
for i in range(val_batches):
    y = model.predict(
        val_data,
        verbose=2,
        steps=1)
    val_predictions.append(y)