# Initialize and load data

In [None]:
# Reload imported modules automatically before each cell executes, so edits to
# the project package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import yaml
from datetime import datetime
import numpy as np
import tensorflow as tf
from steel_seg.utils import (
    dice_coeff_kaggle,
    rle_to_dense,
    dense_to_rle,
    visualize_segmentations,
    onehottify)
from steel_seg.dataset.severstal_steel_dataset import SeverstalSteelDataset
from steel_seg.model.unet import build_unet_model
from steel_seg.model.classification_wrapper import build_classification_model
from steel_seg.train import (
    class_weighted_binary_classification_crossentropy,
    binary_accuracy_by_class)
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# TensorFlow runtime switches, set through the environment:
#   * TF_FORCE_GPU_ALLOW_GROWTH: grow GPU memory on demand instead of
#     reserving it all up front.
#   * TF_ENABLE_AUTO_MIXED_PRECISION*: opt in to automatic mixed precision.
# NOTE(review): these are set after `import tensorflow`; presumably they are
# read lazily when the first GPU device/session is created -- TODO confirm.
os.environ.update({
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    "TF_ENABLE_AUTO_MIXED_PRECISION": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING": "1",
})

In [None]:
# Read the project settings into `cfg`.
# `yaml.safe_load` is used instead of bare `yaml.load(f)`: calling `load`
# without an explicit Loader is deprecated (and an error in PyYAML 6), and the
# safe loader refuses to construct arbitrary Python objects from the file.
with open('SETTINGS.yaml') as f:
    cfg = yaml.safe_load(f)

In [None]:
# Build the dataset wrapper from the project settings file.
settings_path = 'SETTINGS.yaml'
dataset = SeverstalSteelDataset.init_from_config(settings_path)

In [None]:
# Training and validation pipelines, both requested with
# dense_segmentation=False. Each call returns the data object to feed to
# `fit` together with its number of batches.
train_data, train_batches = dataset.create_dataset(
    dataset_type='training',
    dense_segmentation=False,
)
val_data, val_batches = dataset.create_dataset(
    dataset_type='validation',
    dense_segmentation=False,
)

# Build Model

In [None]:
# The backbone is built for 3-channel input at the configured resolution.
IMG_SHAPE = (cfg['IMG_HEIGHT'], cfg['IMG_WIDTH'], 3)

# MobileNetV2 feature extractor: ImageNet weights, without the ImageNet
# classification head (include_top=False).
backbone_options = {
    'input_shape': IMG_SHAPE,
    'include_top': False,
    'weights': 'imagenet',
}
base_model = tf.keras.applications.MobileNetV2(**backbone_options)

In [None]:
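# Uncomment to freeze the MobileNetV2 backbone so that only the layers added
# on top of it are trained (left commented out: the whole network is fine-tuned).
#base_model.trainable = False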

In [None]:
# Single-channel image input. Named `input_tensor` rather than `input` so the
# Python builtin `input()` is not shadowed for the rest of the kernel session.
input_tensor = tf.keras.Input(shape=(cfg['IMG_HEIGHT'], cfg['IMG_WIDTH'], 1))

# Rescale pixel values with x / 127.5 - 1.0 and repeat the single channel three
# times so the tensor matches the 3-channel input the backbone was built with.
# Necessary to wrap in keras.layers.Lambda so that save_model works
x = tf.keras.layers.Lambda(
    lambda img: tf.tile(img / 127.5 - 1.0, [1, 1, 1, 3]))(input_tensor)
x = base_model(x)

# Collapse the spatial feature map to one vector per image, then emit one
# independent sigmoid score per class (multi-label head).
x = tf.keras.layers.GlobalAveragePooling2D()(x)
output = tf.keras.layers.Dense(cfg['NUM_CLASSES'], activation=tf.keras.activations.sigmoid)(x)
model = tf.keras.Model(inputs=[input_tensor], outputs=[output])

In [None]:
# Base name of this model; combined with a timestamp further down to form the
# checkpoint directory, checkpoint file names and the TensorBoard log dir.
model_checkpoint_name = 'mobilenet_finetune'

In [None]:
# Per-class weights handed to the weighted binary cross-entropy loss.
# Alternative that was tried: [30.0, 40.0, 10.0, 20.0]
cls_weights = [1.0, 1.0, 1.0, 1.0]

# One binary-accuracy metric per class, in class-index order.
per_class_accuracy = [
    binary_accuracy_by_class(class_idx)
    for class_idx in range(len(cls_weights))
]

# Alternatives that were tried: tf.keras.optimizers.RMSprop(lr=0.0001) as the
# optimizer and plain 'binary_crossentropy' as the loss.
model.compile(
    optimizer=tf.train.AdamOptimizer(0.0001),
    loss=class_weighted_binary_classification_crossentropy(cls_weights),
    metrics=per_class_accuracy,
)

In [None]:
# Print the layer-by-layer summary of the assembled classifier.
model.summary()

# Load Initial Weights

In [None]:
# List the existing runs under the checkpoint root, to pick a `date_str` below.
!ls classification_checkpoints

In [None]:
# Echo the model name that the checkpoint lookup below will use.
model_checkpoint_name

In [None]:
# Timestamp suffix identifying the run whose checkpoints should be loaded.
date_str = '20191004-090653' # First try, probably not great weights
# Alternative: use the current time, i.e. a run with no checkpoints yet.
#date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
date_str

In [None]:
# Run identifier, checkpoint filename template ('{epoch:04d}' is filled in by
# the ModelCheckpoint callback) and the directory that holds the checkpoints.
checkpoint_name = f'{model_checkpoint_name}_{date_str}'
checkpoint_path = f'classification_checkpoints/{checkpoint_name}/cp-{checkpoint_name}' + '-{epoch:04d}.ckpt'
checkpoint_dir = os.path.dirname(checkpoint_path)

# Resume from the most recent checkpoint of this run when one exists.
initial_epoch = 0
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is not None:
    print(f'Loading weights from {latest_checkpoint}')
    # The epoch number is the text between the final '-' and the first '.'
    # after it, e.g. '...-0012.ckpt' -> 12.
    last_epoch = latest_checkpoint.split('-')[-1].split('.')[0]
    initial_epoch = int(last_epoch)
    model.load_weights(latest_checkpoint)
else:
    print('No checkpoints found. Starting from scratch.')

## Use new model name?

In [None]:
# Start a fresh run: new timestamp and epoch counter reset to 0. Any weights
# already loaded into `model` are kept, so this fine-tunes from them.
#model_checkpoint_name = 'classification'
date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
initial_epoch = 0

# Re-derive everything that depends on the name/timestamp. Without this the
# training cell keeps using the previous run's `checkpoint_path` and
# `checkpoint_name`, so the new timestamp would have no effect and checkpoints
# and logs would be written into the old run's directories.
checkpoint_name = f'{model_checkpoint_name}_{date_str}'
checkpoint_path = f'classification_checkpoints/{checkpoint_name}/cp-{checkpoint_name}' + '-{epoch:04d}.ckpt'
checkpoint_dir = os.path.dirname(checkpoint_path)

# Train

In [None]:
# Checkpointing: write weights only, and only when the monitored validation
# loss improves on the best value seen so far.
checkpoint_options = {
    'monitor': 'val_loss',
    'save_best_only': True,
    'mode': 'auto',
    'save_weights_only': True,
    'verbose': 1,
}
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, **checkpoint_options)

# TensorBoard logs go to a directory named after the run and its start epoch.
logdir = f'logs/{checkpoint_name}-{initial_epoch}'
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=logdir)
callbacks = [tensorboard_cb, checkpoint_cb]

# Train up to epoch 400, validating after every epoch; `initial_epoch` lets a
# resumed run continue its epoch numbering.
fit_options = {
    'epochs': 400,
    'verbose': 2,
    'callbacks': callbacks,
    'validation_data': val_data,
    'steps_per_epoch': train_batches,
    'validation_steps': val_batches,
    'validation_freq': 1,
    'initial_epoch': initial_epoch,
}
results = model.fit(train_data, **fit_options)

# Evaluate

In [None]:
# Image names in the validation split; the cell output shows how many there are.
val_imgs = dataset.get_image_list('validation')
len(val_imgs)

In [None]:
# Collect per-image class scores and ground-truth labels for the validation set.
# The width comes from the config (the model head is sized by
# cfg['NUM_CLASSES']) instead of a hard-coded 4, so the buffers always match
# the model output.
num_classes = cfg['NUM_CLASSES']
y_preds = np.zeros((len(val_imgs), num_classes), dtype=np.float32)
y_true = np.zeros((len(val_imgs), num_classes), dtype=np.uint8)

for i, img_name in enumerate(val_imgs):
    img, ann = dataset.get_example_from_img_name(img_name)
    # Add a leading batch axis of size 1 before calling the model.
    img_batch = np.expand_dims(img, axis=0)
    y_cls = model.predict(img_batch)

    # Reduce the annotation over its first two axes: a class counts as present
    # if its maximum there is non-zero.
    # NOTE(review): assumes `ann` is (height, width, num_classes) -- confirm.
    y_true[i, :] = np.amax(ann, axis=(0, 1))
    y_preds[i, :] = y_cls[0, :]

In [None]:
def print_cm(cm, labels, hide_zeroes=False, hide_diagonal=False, hide_threshold=None):
    """Pretty-print a confusion matrix as a fixed-width text table.

    Args:
        cm: 2-D array indexable as ``cm[i, j]`` (e.g. a numpy array) with at
            least ``len(labels)`` rows and columns.
        labels: Row/column labels, in matrix order. Non-string labels are
            converted with ``str``.
        hide_zeroes: If true, blank out cells whose value is exactly zero.
        hide_diagonal: If true, blank out the cells on the main diagonal.
        hide_threshold: If not ``None``, blank out every cell whose value is
            not strictly greater than this threshold. A threshold of ``0`` is
            honoured (it is not treated as "disabled").
    """
    labels = [str(label) for label in labels]
    columnwidth = max([len(label) for label in labels] + [5])  # 5 is value length
    empty_cell = " " * columnwidth
    row_indent = "    "

    # Header: a blank corner cell followed by the right-aligned column labels.
    print(row_indent + empty_cell, end=" ")
    for label in labels:
        print(label.rjust(columnwidth), end=" ")
    print()

    # One line per row: the row label, then the (possibly blanked) values.
    for i, row_label in enumerate(labels):
        print(row_indent + row_label.rjust(columnwidth), end=" ")
        for j in range(len(labels)):
            value = cm[i, j]
            hidden = (
                (hide_zeroes and float(value) == 0)
                or (hide_diagonal and i == j)
                # `is not None` so that a threshold of 0 still applies.
                or (hide_threshold is not None and not value > hide_threshold)
            )
            cell = empty_cell if hidden else "%{0}.2f".format(columnwidth) % value
            print(cell, end=" ")
        print()

In [None]:
from sklearn.metrics import roc_curve, auc, confusion_matrix

# Decision threshold applied to each class's sigmoid score.
thresholds = [0.5, 0.5, 0.5, 0.5]
for i in range(y_true.shape[-1]):
    # Work on this class's column only instead of re-thresholding the whole
    # prediction matrix on every iteration.
    y_true_cls = y_true[:, i]
    y_score_cls = y_preds[:, i]
    y_pred_cls = (y_score_cls > thresholds[i]).astype(np.uint8)

    # labels=[0, 1] pins the matrix to 2x2 even when only one label value
    # occurs for this class; otherwise sklearn returns a 1x1 matrix and
    # print_cm fails when indexing it.
    cm = confusion_matrix(y_true_cls, y_pred_cls, labels=[0, 1])
    print(f'Confusion matrix for class {i}\n(Actual labels on left)')
    print_cm(cm, ['0', '1'])

    # Row-normalise (fractions of each actual label); rows with no samples are
    # reported as 0 rather than dividing by zero.
    row_totals = np.sum(cm, axis=-1, keepdims=True)
    cm_norm = np.divide(
        cm,
        row_totals,
        out=np.zeros(cm.shape, dtype=np.float64),
        where=row_totals != 0,
    )
    print_cm(cm_norm, ['0', '1'])

    # ROC curve and its area, computed from the raw (un-thresholded) scores.
    fpr, tpr, _ = roc_curve(y_true_cls, y_score_cls)
    roc_auc = auc(fpr, tpr)

    plt.figure()
    plt.plot(fpr, tpr, color='darkorange',
             lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
    # Diagonal reference line (TPR == FPR).
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC for class {i}')
    plt.legend(loc="lower right")
    plt.show()

# Save HDF5 Model

In [None]:
# Export the full model to a single HDF5 file stamped with the time of export.
date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
hdf5_path = f'mobilenet_classification_model_{date_str}.h5'
model.save(hdf5_path)