In [None]:
import os

# Must be set before `segmentation_models` is imported in a later cell; the
# library uses this variable to choose which Keras framework to bind to.
os.environ.update(SM_FRAMEWORK="tf.keras")

In [None]:
%load_ext autoreload
%autoreload 2
from datetime import datetime
from functools import partial
from pathlib import Path

import segmentation_models as sm
import tensorflow as tf

from watch_recognition.data_preprocessing import load_binary_masks_from_coco_dataset
from watch_recognition.datasets import view_image
from watch_recognition.models import get_unet_model, IouLoss2
from watch_recognition.reports import run_on_image_debug, visualize_high_loss_examples
%matplotlib inline

# Paths and hyper-parameters a reader may want to tune.
ROOT_DIR = Path("../download_data/")
SAVE_DIR = Path("..")
EPOCHS = 100
image_size = (96, 96)
mask_size = image_size

# Fix TensorFlow's global seed before any dataset shuffle or model weights are
# created further down, so Restart & Run All gives repeatable results
# (bit-exact determinism on GPU is still not guaranteed).
SEED = 42
tf.random.set_seed(SEED)

In [None]:
# Training split: images and binary masks from the COCO-format export.
# The loader's third return value is not used in this notebook.
X, y, _ = load_binary_masks_from_coco_dataset(
    str(ROOT_DIR / "segmentation/train/result.json"),
    image_size=image_size,
)
# tf.data.Dataset.from_tensor_slices((X, y)) below needs matching first dims;
# fail here with a clear message instead of deep inside tf.data.
assert X.shape[0] == y.shape[0], f"image/mask count mismatch: {X.shape[0]} vs {y.shape[0]}"
X.shape, y.shape

In [None]:
# Validation split, loaded with the same `image_size` as the training split.
# The loader's third return value is discarded.
val_annotations = ROOT_DIR / "segmentation" / "validation" / "result.json"
X_val, y_val, _ = load_binary_masks_from_coco_dataset(
    str(val_annotations),
    image_size=image_size,
)
X_val.shape, y_val.shape

In [None]:
# Training pipeline: shuffle through a 256-example buffer (8 batches of 32),
# then batch by 32.
dataset_train = (
    tf.data.Dataset.from_tensor_slices((X, y))
    .shuffle(8 * 32)
    .batch(32)
)

In [None]:
# dataset_train = get_watch_keypoints_dataset(X, y, augment=False, image_size=image_size,
#                                             mask_size=mask_size)

In [None]:
# view_image(dataset_train)

In [None]:
# Validation pipeline: fixed order (no shuffle), batched by 32 and cached in
# memory once the dataset has been fully iterated.
dataset_val = (
    tf.data.Dataset.from_tensor_slices((X_val, y_val))
    .batch(32)
    .cache()
)

In [None]:
# view_image(dataset_val)


In [None]:
# Pull one validation batch as NumPy arrays for interactive inspection.
batch = next(iter(dataset_val.as_numpy_iterator()))

In [None]:
# U-Net with a single sigmoid output channel, matching the binary masks loaded above.
model = get_unet_model(
    image_size=image_size,
    n_outputs=1,
    output_activation='sigmoid',
    unet_output_layer=None,
)

In [None]:
# Layer-by-layer architecture and parameter counts.
model.summary()

In [None]:
# Names used to build the TensorBoard log dir and the model save path.
# NOTE(review): MODEL_NAME hardcodes "efficientnetb0"; confirm it matches the
# backbone get_unet_model actually builds.
TYPE = "segmentation"
MODEL_NAME = f"efficientnetb0-unet-sigmoid-{image_size[0]}"

optimizer = tf.keras.optimizers.Adam()
# Combined objective: Jaccard (IoU) loss plus binary cross-entropy.
loss = sm.losses.JaccardLoss() + sm.losses.BinaryCELoss()

In [None]:
# Compile and train. Logs epoch metrics to TensorBoard, writes high-loss example
# visualisations for train (every 5 epochs) and validation, and checkpoints the
# best-by-val_loss model to `model_path` for the reload cell further down.
model.compile(
    loss=loss,
    optimizer=optimizer,
    metrics=[sm.metrics.iou_score, sm.metrics.f1_score],
)

start = datetime.now()
# One run id shared by the log dir and the model dir so the two can be matched.
run_id = f"run_{start.timestamp()}"

logdir = SAVE_DIR / f"tensorboard_logs/{TYPE}/{MODEL_NAME}/{run_id}"
print(logdir)
file_writer_distance_metrics_train = tf.summary.create_file_writer(str(logdir / "train"))
file_writer_distance_metrics_validation = tf.summary.create_file_writer(
    str(logdir / "validation")
)

model_path = SAVE_DIR / f"models/{TYPE}/{MODEL_NAME}/{run_id}"
history = model.fit(
    dataset_train,
    epochs=EPOCHS * 2,
    validation_data=dataset_val,
    callbacks=[
        tf.keras.callbacks.TensorBoard(
            log_dir=logdir,
            update_freq="epoch",
        ),
        tf.keras.callbacks.LambdaCallback(
            on_epoch_end=partial(
                visualize_high_loss_examples,
                dataset=dataset_train,
                loss=loss,
                file_writer=file_writer_distance_metrics_train,
                model=model,
                every_n_epoch=5,
            )
        ),
        tf.keras.callbacks.LambdaCallback(
            on_epoch_end=partial(
                visualize_high_loss_examples,
                dataset=dataset_val,
                loss=loss,
                file_writer=file_writer_distance_metrics_validation,
                model=model,
            )
        ),
        # Without this nothing is written to `model_path`, and the later
        # `load_model(model_path)` cell fails on a fresh Restart & Run All.
        tf.keras.callbacks.ModelCheckpoint(
            filepath=str(model_path),
            save_weights_only=False,
            monitor="val_loss",
            save_best_only=True,
        ),
    ],
)

# Make sure the custom visualisation summaries are on disk before inspecting them.
file_writer_distance_metrics_train.flush()
file_writer_distance_metrics_validation.flush()

# total_seconds() (not .seconds, which drops whole days); average over the
# epochs actually run rather than the EPOCHS constant.
elapsed = (datetime.now() - start).total_seconds()
epochs_run = max(len(history.epoch), 1)
print(
    f"total training time: {elapsed / 60} minutes, average: {elapsed / 60 / epochs_run} minutes/epoch"
)

In [None]:
# Debug-view the trained model on the first image of one shuffled training batch.
train_X, train_y = next(dataset_train.as_numpy_iterator())
run_on_image_debug(model, train_X[0])

In [None]:
# Debug-view the in-memory model on validation images 10-19.
for val_image in X_val[10:20]:
    run_on_image_debug(model, val_image)

In [None]:
# Reload the model saved at `model_path` (inference only, so skip compiling the
# custom loss/metrics) and debug-view it on training images 10-19.
print(model_path)
loaded_model = tf.keras.models.load_model(model_path, compile=False)

for train_image in X[10:20]:
    run_on_image_debug(loaded_model, train_image)

In [None]:
# Same validation images 10-19, now through the reloaded model.
for val_image in X_val[10:20]:
    run_on_image_debug(loaded_model, val_image)

