In [None]:
# Automatically reload edited Python modules (e.g. steel_seg) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import yaml
from datetime import datetime
import numpy as np
import tensorflow as tf
from steel_seg.utils import (
    dice_coeff_kaggle,
    rle_to_dense,
    dense_to_rle,
    visualize_segmentations,
    onehottify)
from steel_seg.dataset.severstal_steel_dataset import SeverstalSteelDataset
from steel_seg.model.unet import build_unet_model
from steel_seg.model.classification_wrapper import build_classification_model
from steel_seg.train import (
    class_weighted_binary_classification_crossentropy,
    binary_accuracy_by_class)
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# TensorFlow runtime settings, passed via environment variables.
# - TF_FORCE_GPU_ALLOW_GROWTH: allocate GPU memory incrementally instead of
#   reserving it all up front (the original author noted this seemed necessary
#   with CUDA 10 — exact reason unconfirmed).
# - TF_ENABLE_AUTO_MIXED_PRECISION*: request the automatic mixed-precision graph
#   rewrite with loss scaling. NOTE(review): the *_GRAPH_REWRITE / *_LOSS_SCALING
#   variants appear to be NVIDIA-container specific — confirm they take effect here.
os.environ.update({
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    "TF_ENABLE_AUTO_MIXED_PRECISION": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING": "1",
})

In [None]:
# Load project settings into a plain dict (keys such as IMG_HEIGHT, IMG_WIDTH
# and NUM_CLASSES are read by the model-building cell).
# yaml.safe_load: yaml.load() without an explicit Loader is deprecated since
# PyYAML 5.1 and a TypeError in PyYAML >= 6; safe_load also refuses to build
# arbitrary Python objects from the file.
with open('SETTINGS.yaml') as f:
    cfg = yaml.safe_load(f)

In [None]:
# Dataset helper, configured from the same settings file that populated `cfg`.
settings_path = 'SETTINGS.yaml'
dataset = SeverstalSteelDataset.init_from_config(settings_path)

In [None]:
# Training and validation pipelines. The second value returned by each call is
# passed to fit() as steps_per_epoch / validation_steps.
# NOTE(review): dense_segmentation=False presumably selects a non-dense mask
# encoding — confirm in SeverstalSteelDataset.create_dataset.
dataset_opts = dict(dense_segmentation=False)
train_data, train_batches = dataset.create_dataset(dataset_type='training', **dataset_opts)
val_data, val_batches = dataset.create_dataset(dataset_type='validation', **dataset_opts)

In [None]:
# U-Net segmentation network. Its weights are restored from a segmentation
# checkpoint further down, after which it is wrapped into a classifier.
# NOTE(review): num_features presumably gives the filter count at each of the
# num_layers levels (4 entries for 4 layers) — confirm in build_unet_model.
unet_config = dict(
    img_height=cfg['IMG_HEIGHT'],
    img_width=cfg['IMG_WIDTH'],
    img_channels=1,
    num_classes=cfg['NUM_CLASSES'],
    num_layers=4,
    num_features=[32, 64, 128, 256],
    activation=tf.keras.activations.elu,
    kernel_initializer='he_normal',
    kernel_size=(3, 3),
    pool_size=(2, 4),
    drop_prob=0.5,
)
seg_model = build_unet_model(**unet_config)

# Prefix used for both the segmentation and the classification checkpoint directories.
model_checkpoint_name = 'deep'

In [None]:
# List the available segmentation runs; pick one and set date_str accordingly.
!ls checkpoints

In [None]:
# Timestamp suffix of the segmentation run to restore; the weights are read from
# checkpoints/<model_checkpoint_name>_<date_str>/.
date_str = '20190916-092052'

In [None]:
# Restore the trained segmentation weights that the classifier below is built on.
# Checkpoint files live in checkpoints/<name>/ as cp-<name>-<epoch:04d>.ckpt.
seg_checkpoint_name = f'{model_checkpoint_name}_{date_str}'
seg_checkpoint_dir = f'checkpoints/{seg_checkpoint_name}'

seg_latest_checkpoint = tf.train.latest_checkpoint(seg_checkpoint_dir)
if seg_latest_checkpoint is None:
    # Fail loudly: carrying on would wrap a randomly initialised U-Net and
    # silently train the classifier on meaningless features.
    raise FileNotFoundError(
        f'Error. No segmentation checkpoints found in {seg_checkpoint_dir}')

print(f'Loading weights from {seg_latest_checkpoint}')
seg_model.load_weights(seg_latest_checkpoint)

In [None]:
# Wrap the segmentation U-Net into a classifier.
# NOTE(review): 'conv2d_7' is presumably the seg_model layer the classification
# head attaches to (check the layer names in seg_model.summary()), and 4 is
# presumably the number of classifier outputs, matching the four cls_weights
# and per-class metrics used when compiling.
feature_layer_name = 'conv2d_7'
num_cls_outputs = 4
cls_model = build_classification_model(seg_model, feature_layer_name, num_cls_outputs)

In [None]:
# Print the layer-by-layer summary (names, output shapes, parameter counts) of the classifier.
cls_model.summary()

In [None]:
# Per-class weights for the class-weighted binary cross-entropy loss, one per
# class index 0-3. NOTE(review): presumably larger weights up-weight rarer
# defect classes — confirm against the label statistics.
cls_weights = [30.0, 40.0, 10.0, 20.0]

In [None]:
# Compile with the class-weighted BCE loss and one binary-accuracy metric per class.
# The metric list is derived from cls_weights (instead of hard-coding indices
# 0-3), so the loss and the metrics always cover the same set of classes.
learning_rate = 1e-4
cls_model.compile(
    optimizer=tf.train.AdamOptimizer(learning_rate),
    loss=class_weighted_binary_classification_crossentropy(cls_weights),
    metrics=[
        binary_accuracy_by_class(class_idx)
        for class_idx in range(len(cls_weights))
    ],
)

In [None]:
# List existing classification runs; set date_str to one of their timestamps to resume it.
!ls classification_checkpoints

In [None]:
# A fresh timestamp selects a new classification_checkpoints/ sub-directory, so
# training starts from scratch. To resume an earlier run instead, set date_str
# to that run's timestamp (e.g. '20190916-092052').
date_str = f'{datetime.now():%Y%m%d-%H%M%S}'
date_str

In [None]:
# Classification run directory: classification_checkpoints/<name>_<date_str>/,
# holding files named cp-<name>-<epoch:04d>.ckpt. If the run already has
# checkpoints, restore the newest one and resume epoch counting from it.
checkpoint_name = f'{model_checkpoint_name}_{date_str}'
# '{epoch:04d}' stays unformatted here; Keras' ModelCheckpoint fills it in when saving.
checkpoint_path = (
    f'classification_checkpoints/{checkpoint_name}/cp-{checkpoint_name}'
    + '-{epoch:04d}.ckpt'
)
checkpoint_dir = os.path.dirname(checkpoint_path)

latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
initial_epoch = 0
if latest_checkpoint is not None:
    print(f'Loading weights from {latest_checkpoint}')
    # '.../cp-<name>-0012.ckpt' -> '0012.ckpt' -> '0012' -> 12
    last_epoch = latest_checkpoint.rpartition('-')[2].partition('.')[0]
    initial_epoch = int(last_epoch)
    cls_model.load_weights(latest_checkpoint)
else:
    print('No checkpoints found. Starting from scratch.')

In [None]:
# Train the classifier with TensorBoard logging and best-val_loss weight checkpoints.
# NOTE(review): with validation_freq=3, val_loss only exists every third epoch,
# so ModelCheckpoint(save_best_only=True) can only save on those epochs and
# should log a "val_loss not available, skipping" warning on the others.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='val_loss',
    mode='auto',
    verbose=1,
    save_best_only=True,
    save_weights_only=True)

# The log dir is keyed by run name and starting epoch, so a resumed run logs
# to its own directory.
logdir = f'logs/{checkpoint_name}-{initial_epoch}'
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=logdir)
callbacks = [tensorboard_cb, checkpoint_cb]

num_epochs = 400
results = cls_model.fit(
    train_data,
    validation_data=val_data,
    epochs=num_epochs,
    initial_epoch=initial_epoch,
    steps_per_epoch=train_batches,
    validation_steps=val_batches,
    validation_freq=3,
    callbacks=callbacks,
    verbose=2)