In [None]:
import os
import yaml
from datetime import datetime
import numpy as np
import tensorflow as tf
from dataset.severstal_steel_dataset import SeverstalSteelDataset
from model.unet import build_unet_model
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# TODO
# - Write better visualization code
# - Move training into a script that can be called from the command line
# - Read about approaches to parameter search
# - Export model and load in Kaggle kernel
# - Figure out why dice_coeff is wrong

In [None]:
# TensorFlow runtime switches, exported through the process environment:
# on-demand GPU memory growth plus the automatic mixed-precision flags.
# NOTE(review): presumably these have to be set before TensorFlow first
# touches the GPU / builds the graph -- confirm against the TF and NVIDIA docs.
os.environ.update({
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    'TF_ENABLE_AUTO_MIXED_PRECISION': "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING": "1",
})

In [None]:
# # To run in half-precision mode on GPU
# dtype='float16'
# K.set_floatx(dtype)

# # default is 1e-7 which is too small for float16.  Without adjusting the epsilon, we will get NaN predictions because of divide by zero problems
# K.set_epsilon(1e-4)

In [None]:
# Build the dataset wrapper from the project settings file. The same object is
# reused below for the tf.data pipelines and for per-image lookups.
dataset = SeverstalSteelDataset.init_from_config('SETTINGS.yaml')

In [None]:
# Each call returns a pair: the data object that is later handed to model.fit,
# and a batch count that is used there as steps_per_epoch / validation_steps.
train_data, train_batches = dataset.create_dataset(dataset_type='training')
val_data, val_batches = dataset.create_dataset(dataset_type='validation')

In [None]:
# Parse the settings file into `cfg` for the notebook's own use (image size,
# class count). yaml.safe_load replaces yaml.load(f): calling yaml.load without
# an explicit Loader is deprecated in PyYAML 5.1+ and raises a TypeError in
# PyYAML 6, and the safe loader is all a plain settings file needs.
with open('SETTINGS.yaml') as f:
    cfg = yaml.safe_load(f)

In [None]:
from tensorflow.keras import backend as K

# https://gist.github.com/wassname/7793e2058c5c9dacb5212c0ac0b18a8a
# def dice_coef(y_true, y_pred, smooth=1):
#     """
#     Dice = (2*|X & Y|)/ (|X|+ |Y|)
#          =  2*sum(|A*B|)/(sum(A^2)+sum(B^2))
#     ref: https://arxiv.org/pdf/1606.04797v1.pdf
#     """
#     intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
#     return (2. * intersection + smooth) / (K.sum(K.square(y_true),-1) + K.sum(K.square(y_pred),-1) + smooth)


# https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96
def jaccard_distance_loss(y_true, y_pred, smooth=100):
    """Smoothed Jaccard (IoU) distance, reduced over the last axis.

    Jaccard = |X & Y| / (|X| + |Y| - |X & Y|)
            = sum(|A*B|) / (sum(|A|) + sum(|B|) - sum(|A*B|))

    The result is ``(1 - jaccard) * smooth``, so a perfect overlap gives 0.
    ``smooth`` is added to numerator and denominator, which keeps the ratio
    defined when both inputs are empty and softens the gradient. The loss is
    useful for unbalanced datasets.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor, same shape as ``y_true``.
        smooth: smoothing constant (also the output scale).

    Returns:
        Tensor with the last axis of the inputs summed away.

    Ref: https://en.wikipedia.org/wiki/Jaccard_index
    Adapted from wassname:
    https://gist.github.com/wassname/f1452b748efcbeb4cb9b1d059dce6f96
    """
    overlap = K.sum(K.abs(y_true * y_pred), axis=-1)
    total = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
    union = total - overlap
    ratio = (overlap + smooth) / (union + smooth)
    return (1 - ratio) * smooth

def dice_coef(y_true, y_pred, smooth=0.0001):
    """Soft Dice coefficient over all elements of the two tensors.

    Both tensors are flattened first, so the score is a single scalar computed
    across every element passed in:

        (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor with the same number of elements.
        smooth: added to numerator and denominator so that two empty inputs
            do not divide by zero.

    Returns:
        Scalar tensor.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return numerator / denominator

def dice_coef_intersection(y_true, y_pred):
    """Sum of the element-wise product of the flattened inputs.

    This is the overlap term that dice_coef uses in its numerator, exposed on
    its own so it can be tracked as a metric.
    """
    return K.sum(K.flatten(y_true) * K.flatten(y_pred))

def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus dice_coef (called with its default smoothing)."""
    score = dice_coef(y_true, y_pred)
    return 1 - score

In [None]:
# Instantiate the U-Net. Image size and class count are read from cfg
# (SETTINGS.yaml); the remaining architecture hyper-parameters are fixed here.
model = build_unet_model(
    img_height=cfg['IMG_HEIGHT'],
    img_width=cfg['IMG_WIDTH'],
    img_channels=1,  # single input channel
    num_classes=cfg['NUM_CLASSES'],
    num_layers=4,
    activation=tf.keras.activations.elu,
    kernel_initializer='he_normal',
    kernel_size=(3, 3),
    pool_size=(2, 4),  # NOTE(review): asymmetric pooling, presumably for the wide images -- confirm
    num_features=[4, 4, 16, 32],  # one entry per layer (num_layers=4)
    drop_prob=0.5)

In [None]:
# Print the layer-by-layer summary of the model built above.
model.summary()

In [None]:
def focal_loss(alpha=0.25, gamma=2):
    """Create a binary focal loss that operates on probabilities.

    Args:
        alpha: weight of the positive-class term; negatives get ``1 - alpha``.
        gamma: focusing exponent applied to the modulating factor.

    Returns:
        A function ``loss(y_true, y_pred)`` returning the mean focal loss over
        all elements. ``y_pred`` holds probabilities; they are clipped away
        from 0 and 1 and converted back to logits internally.
    """
    def _weighted_terms(logits, targets, probs):
        # Modulating weights for the positive and the negative targets.
        pos_weight = alpha * (1 - probs) ** gamma * targets
        neg_weight = (1 - alpha) * probs ** gamma * (1 - targets)
        # log1p(exp(-|x|)) + relu(-x) is a numerically stable -log(sigmoid(x)).
        neg_log_sigmoid = tf.log1p(tf.exp(-tf.abs(logits))) + tf.nn.relu(-logits)
        return neg_log_sigmoid * (pos_weight + neg_weight) + logits * neg_weight

    def loss(y_true, y_pred):
        eps = tf.keras.backend.epsilon()
        probs = tf.clip_by_value(y_pred, eps, 1 - eps)
        # Inverse sigmoid: recover logits from the clipped probabilities.
        logits = tf.log(probs / (1 - probs))
        return tf.reduce_mean(_weighted_terms(logits, y_true, probs))

    return loss


# def weighted_cross_entropy(beta):
#     def convert_to_logits(y_pred):
#         # see https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/keras/backend.py#L3525
#         y_pred = tf.clip_by_value(y_pred, tf.keras.backend.epsilon(), 1 - tf.keras.backend.epsilon())

#         return tf.log(y_pred / (1 - y_pred))

#     def loss(y_true, y_pred):
#         y_pred = convert_to_logits(y_pred)
#         loss = tf.nn.weighted_cross_entropy_with_logits(logits=y_pred, targets=y_true, pos_weight=beta)

#         return tf.reduce_mean(loss)

#     return loss

def weighted_binary_crossentropy(beta, from_logits=False):
    """Create a binary cross-entropy loss whose positive term is weighted.

    Modelled on the Keras backend ``binary_crossentropy``:
    https://github.com/tensorflow/tensorflow/blob/r1.14/tensorflow/python/keras/backend.py#L4213-L4243

    Args:
        beta: ``pos_weight`` passed to
            ``tf.nn.weighted_cross_entropy_with_logits``.
        from_logits: set to True when the model output already is logits.

    Returns:
        A function ``(target, output) -> loss tensor`` (element-wise, not
        reduced).
    """
    def _weighted_binary_crossentropy(target, output):
        logits = output
        if not from_logits:
            # Eager tensors raise AttributeError on `.op`; treat that like any
            # other output that is not a graph-mode sigmoid.
            op = getattr(output, 'op', None)
            if op is not None and op.type == 'Sigmoid' and len(op.inputs) == 1:
                # When a sigmoid produced the output, use its input (the
                # logits) directly to prevent collapsing to zero when training.
                logits = op.inputs[0]
            else:
                # Any other producer: the op input is NOT guaranteed to be the
                # logits, so recover them from the clipped probabilities.
                eps = tf.keras.backend.epsilon()
                probs = tf.clip_by_value(output, eps, 1 - eps)
                logits = tf.log(probs / (1 - probs))
        return tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=target, pos_weight=beta)
    return _weighted_binary_crossentropy

def binary_crossentropy(target, output):
    """Element-wise sigmoid cross-entropy computed on logits.

    Args:
        target: ground-truth tensor.
        output: model output holding probabilities.

    Returns:
        The result of ``tf.nn.sigmoid_cross_entropy_with_logits`` (not
        reduced).
    """
    # Eager tensors raise AttributeError on `.op`; treat that like any other
    # output that is not a graph-mode sigmoid.
    op = getattr(output, 'op', None)
    if op is not None and op.type == 'Sigmoid' and len(op.inputs) == 1:
        # When a sigmoid produced the output, use its input (the logits)
        # directly to prevent collapsing to zero when training.
        logits = op.inputs[0]
    else:
        # Any other producer: the op input is not guaranteed to be the logits,
        # so recover them from the clipped probabilities instead.
        eps = tf.keras.backend.epsilon()
        probs = tf.clip_by_value(output, eps, 1 - eps)
        logits = tf.log(probs / (1 - probs))
    return tf.nn.sigmoid_cross_entropy_with_logits(labels=target, logits=logits)


In [None]:
# Compile for training: Adam with learning rate 0.001, the positive-weighted
# cross-entropy defined above (pos_weight 10), and three metrics tracked next
# to the loss. The trailing comments keep the previously tried alternatives.
model.compile(optimizer=tf.train.AdamOptimizer(0.001),
              loss=weighted_binary_crossentropy(10.0),#'binary_crossentropy',
              metrics=[tf.keras.metrics.BinaryAccuracy(), dice_coef, dice_coef_intersection])#[dice_coef, 'accuracy'])

In [None]:
# Load weights from a checkpoint written by an earlier run.
# NOTE(review): the path is hard-coded to one specific run; skip or edit this
# cell when that checkpoint is not available (e.g. when training from scratch).
#model.load_weights('checkpoints/cp_20190813-211504.ckpt')
model.load_weights('checkpoints/cp_20190814-224021.ckpt')

In [None]:
# One timestamp names both the checkpoint file and the TensorBoard log
# directory, so the artefacts of a run can be matched to each other.
date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
checkpoint_path = f'checkpoints/cp_{date_str}.ckpt'

# Create checkpoint callback (weights only, not the full model)
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, 
                                                 save_weights_only=True,
                                                 verbose=1)
# Early stopping is currently disabled:
#tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),

logdir = "logs/" + date_str
callbacks = [
  tf.keras.callbacks.TensorBoard(log_dir=logdir),
  cp_callback 
]

# Train for 20 epochs, validating after every epoch (validation_freq=1). The
# step counts are the batch counts returned by dataset.create_dataset.
results = model.fit(train_data,
                    epochs=20,
                    verbose=2,
                    callbacks=callbacks,
                    validation_data=val_data,
                    steps_per_epoch=train_batches,
                    validation_steps=val_batches,
                    validation_freq=1)

In [None]:
# Pull a single element out of the training dataset for manual inspection
# (TF1 graph mode: build the get_next op, then evaluate it in a session).
iterator = train_data.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    value = sess.run(next_element)

In [None]:
# Largest value in the second component of the fetched element
# (presumably the label/mask tensor -- confirm against create_dataset).
np.amax(value[1])

In [None]:
# Validation image names; each entry is later passed to
# dataset.get_example_from_img_name.
val_imgs = dataset.get_image_list('validation')

In [None]:
# Number of validation images.
len(val_imgs)

In [None]:
def dice_coeff(y_pred, y_true, threshold=0.5):
    """Mean per-class Dice score between a prediction and a ground-truth mask.

    Both inputs are indexed as ``[:, :, class]``. For every class channel of
    ``y_pred`` the Dice score ``2 * |P & T| / (|P| + |T|)`` is computed on
    boolean masks; a channel that is empty in both prediction and ground truth
    scores 1.0. The per-class scores are then averaged.

    Args:
        y_pred: array of per-pixel scores; values above ``threshold`` count as
            positive.
        y_true: ground-truth array. It is binarised at 0.5 so that masks stored
            as 0/255 score the same as 0/1 masks (without this the raw values
            would be summed and skew the result).
        threshold: cut-off applied to ``y_pred`` (default 0.5, the value that
            used to be hard-coded).

    Returns:
        The mean Dice score over the class channels of ``y_pred``.
    """
    pred_mask = np.asarray(y_pred) > threshold
    true_mask = np.asarray(y_true) > 0.5

    dice_scores = []
    for i in range(pred_mask.shape[-1]):
        pred_i = pred_mask[:, :, i]
        true_i = true_mask[:, :, i]
        denom = np.count_nonzero(pred_i) + np.count_nonzero(true_i)
        if denom == 0:
            # Nothing predicted and nothing to find counts as a perfect match.
            dice_scores.append(1.0)
            continue
        intersection = np.count_nonzero(pred_i & true_i)
        dice_scores.append(2 * intersection / denom)
    return np.mean(dice_scores)

def onehottify(x, n=None, dtype=float):
    """One-hot encode integer labels along a new trailing axis.

    Args:
        x: integer labels (any shape; converted with ``np.asarray``).
        n: number of classes. When None it is taken as ``max(x) + 1``.
        dtype: dtype of the returned array.

    Returns:
        Array of shape ``x.shape + (n,)`` holding a 1 at each label's index.
    """
    labels = np.asarray(x)
    if n is None:
        n = np.max(labels) + 1
    # Row k of the identity matrix is the one-hot vector for class k, so
    # indexing with the label array performs the encoding.
    identity = np.eye(n, dtype=dtype)
    return identity[labels]

In [None]:
# Score every validation image on its own, then report the average.
dice_coeffs = []
for img_name in val_imgs:
    img, ann = dataset.get_example_from_img_name(img_name)
    img_batch = np.expand_dims(img, axis=0)
    y = model.predict(img_batch)
    #dice_coeffs.append(dice_coeff(y[0, :, :, :], ann))
    # Keep only the highest-scoring class per pixel, and only where its score
    # reaches 0.5. The class count is taken from the prediction itself rather
    # than hard-coded, so the mask below always matches y's shape.
    y_argmax = np.argmax(y, axis=-1)
    y_one_hot = onehottify(y_argmax, y.shape[-1])
    y_one_hot[y < 0.5] = 0
    dice_coeffs.append(dice_coeff(y_one_hot[0, :, :, :], ann))

print(f'Mean dice coeff: {np.mean(dice_coeffs)}')

In [None]:
# Scratch: per-pixel index of the highest-scoring channel in the last
# prediction `y`.
am = np.argmax(y, axis=-1)

In [None]:
# Scratch: display the argmax map.
am

In [None]:
# Scratch: one-hot encode the argmax map.
# NOTE(review): the class count 4 is hard-coded here -- confirm it matches the
# number of channels the model outputs.
onehottify(am, 4)

In [None]:
# Scratch: one-hot encoding through integer ("fancy") indexing on a tiny
# example. The cell previously held a pasted REPL transcript (`>>>` prompts)
# and an argument-less np.arange() call, so it could not run at all.
a = np.array([1, 0, 3])  # example class labels
b = np.zeros((3, 4))
# Row i gets a 1 in column a[i].
b[np.arange(3), a] = 1
b

In [None]:
# Predict on one validation example and show the input image followed by a
# single output channel.
img_name = val_imgs[2]
img, ann = dataset.get_example_from_img_name(img_name)
img_batch = np.expand_dims(img, axis=0)  # add a leading batch axis
y = model.predict(img_batch)
# Repeat the last axis three times so imshow receives an RGB-shaped array
# (assumes img is (H, W, 1) -- TODO confirm).
plt.imshow(np.repeat(img, 3, axis=-1))
plt.show()
plt.imshow(y[0, :, :, 2])  # raw model output for channel index 2
plt.show()


In [None]:
# Scratch: display the raw prediction array from the cell above.
y

In [None]:
# Min-max normalise one prediction channel into [0, 1] for inspection.
# NOTE(review): channel index 4 needs at least five output channels, while
# onehottify(am, 4) and the four colormaps in visualize_prediction assume
# four -- confirm the intended index.
y_tmp = y[0, :, :, 4]
y_norm = (y_tmp - np.amin(y_tmp))
y_peak = np.amax(y_norm)
# A constant channel has zero range; dividing by it would fill y_norm with
# NaNs, so in that case the all-zero shifted values are kept as they are.
if y_peak > 0:
    y_norm = y_norm / y_peak

In [None]:
# Number of pixels whose normalised value exceeds 0.01.
np.sum(y_norm > 0.01)

In [None]:
def visualize_prediction(x, y_pred, y_true):
    """Plot an image together with per-class ground-truth and predicted masks.

    Layout: the first row shows the input image; each following row belongs to
    one class channel of ``y_pred``. In a class row the left panel shows the
    ground-truth mask on its own and the right panel shows the predicted mask
    drawn over the image.

    Args:
        x: image array; its last axis is repeated three times for display
            (assumes a single-channel (H, W, 1) image -- TODO confirm).
        y_pred: predicted masks, indexed as ``[:, :, class]``.
        y_true: ground-truth masks, indexed as ``[:, :, class]``.
    """
    rgb = np.repeat(x, 3, axis=-1)
    num_classes = y_pred.shape[-1]
    # squeeze=False keeps axs two-dimensional even when there is only one row.
    _, axs = plt.subplots(num_classes + 1, 2, figsize=(18, 10), squeeze=False)
    axs[0, 0].imshow(rgb)
    axs[0, 0].set_title('input')
    axs[0, 1].axis('off')  # nothing is drawn in this panel

    cmaps = ['Reds', 'Blues', 'Greens', 'Purples']
    for i in range(num_classes):
        # Cycle through the colormaps so more than four classes still plot.
        cmap = cmaps[i % len(cmaps)]
        # The image underlay is intentionally left out of the ground-truth panel.
        axs[i + 1, 0].imshow(y_true[:, :, i], alpha=0.4, cmap=cmap)
        axs[i + 1, 0].set_title(f'class {i}: ground truth')
        axs[i + 1, 1].imshow(rgb)
        axs[i + 1, 1].imshow(y_pred[:, :, i], alpha=0.4, cmap=cmap)
        axs[i + 1, 1].set_title(f'class {i}: prediction')
    plt.show()

In [None]:
# y_bin is not defined by any earlier cell, so on a fresh kernel this call
# raised a NameError. Derive it here by thresholding the latest prediction `y`
# at 0.5, the same cut-off dice_coeff applies to predictions.
y_bin = (y > 0.5).astype(np.float32)
visualize_prediction(img, y_bin[0, :, :, :], ann)
