In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# project modules (e.g. steel_seg) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import yaml
from datetime import datetime
import numpy as np
import tensorflow as tf
from steel_seg.utils import (
    dice_coeff_kaggle,
    rle_to_dense,
    dense_to_rle,
    visualize_segmentations,
    onehottify)
from steel_seg.dataset.severstal_steel_dataset import SeverstalSteelDataset
from steel_seg.model.pretrained_unet import build_pretrained_unet_model
from steel_seg.train import (
    class_weighted_binary_crossentropy,
    weighted_binary_crossentropy,
    pixel_map_weighted_binary_crossentropy,
    dice_coef,
    eval)
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# TensorFlow runtime switches, applied through the process environment before
# any model or session is created further down the notebook.
#  - TF_FORCE_GPU_ALLOW_GROWTH: ask TF to grow GPU memory on demand rather than
#    reserving it all up front (the original note suggests this was needed for
#    the CUDA 10 setup -- TODO confirm it is still required).
#  - TF_ENABLE_AUTO_MIXED_PRECISION*: presumably switch on automatic mixed
#    precision and its graph-rewrite / loss-scaling parts -- verify that the
#    installed TensorFlow build honours all three names.
os.environ.update({
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    "TF_ENABLE_AUTO_MIXED_PRECISION": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE": "1",
    "TF_ENABLE_AUTO_MIXED_PRECISION_LOSS_SCALING": "1",
})

In [None]:
# Parse the project settings file into `cfg`, which later cells index by key
# (IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES, BATCH_SIZE).
# safe_load replaces the bare yaml.load(f): calling yaml.load without an
# explicit Loader is deprecated since PyYAML 5.1 and raises TypeError on
# PyYAML 6, and the old default loader could construct arbitrary Python
# objects from the file. A settings file only needs plain scalars/maps/lists.
with open('SETTINGS.yaml') as f:
    cfg = yaml.safe_load(f)

In [None]:
# Build the dataset object from the same settings file that is parsed into `cfg`.
dataset = SeverstalSteelDataset.init_from_config('SETTINGS.yaml')

In [None]:
# Create the training split first, then the validation split. Each call
# returns a pair that is unpacked into the data object and a batch count
# (the counts are later used as steps_per_epoch / validation_steps in fit()).
(train_data, train_batches), (val_data, val_batches) = (
    dataset.create_dataset(dataset_type=split_name)
    for split_name in ('training', 'validation')
)

In [None]:
# Gather the model geometry from the settings, then build the U-Net.
# The channel count is fixed at 3 here rather than read from cfg.
unet_kwargs = dict(
    img_height=cfg['IMG_HEIGHT'],
    img_width=cfg['IMG_WIDTH'],
    img_channels=3,
    num_classes=cfg['NUM_CLASSES'],
)
seg_model = build_pretrained_unet_model(**unet_kwargs)

In [None]:
# Weights handed to the loss factory when the model is compiled.
# Presumably one entry per defect class -- TODO confirm the length matches
# cfg['NUM_CLASSES'].
cls_weights = [30.0, 40.0, 10.0, 20.0]

In [None]:
# Show the model's layer summary as a sanity check before compiling/training.
seg_model.summary()

In [None]:
# Assemble the training configuration piece by piece, then compile.
# Optimizer: TF1-style Adam, constructed with 0.0001 as its first argument.
adam_optimizer = tf.train.AdamOptimizer(0.0001)

# Loss: built by the steel_seg.train factory from the class weights.
segmentation_loss = pixel_map_weighted_binary_crossentropy(cls_weights)

# Metrics: Keras binary accuracy plus the project's dice metric, which is
# constructed with the configured batch size.
tracked_metrics = [
    tf.keras.metrics.BinaryAccuracy(),
    dice_coef(batch_size=cfg['BATCH_SIZE']),
]

seg_model.compile(
    optimizer=adam_optimizer,
    loss=segmentation_loss,
    metrics=tracked_metrics,
)

In [None]:
# A single timestamp names both the checkpoint file and the TensorBoard
# log directory for this run.
date_str = datetime.now().strftime("%Y%m%d-%H%M%S")
checkpoint_path = f'checkpoints/cp_{date_str}.ckpt'
logdir = "logs/" + date_str

# Save weights only, and only when the monitored validation loss improves.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='val_loss',
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
)

# Early stopping on validation loss (patience of 8 epochs). The callback is
# constructed but NOT registered in `callbacks` below; add it to the list to
# enable it.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=8,
)

tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=logdir)

# Callbacks actually passed to fit(): TensorBoard logging and checkpointing.
callbacks = [tensorboard_cb, checkpoint_cb]

# Train for up to 200 epochs, validating after every epoch. The step counts
# come from the batch counts returned alongside each dataset split.
results = seg_model.fit(
    train_data,
    validation_data=val_data,
    steps_per_epoch=train_batches,
    validation_steps=val_batches,
    validation_freq=1,
    epochs=200,
    callbacks=callbacks,
    verbose=2,
)