In [None]:
import os
import time
import pandas as pd
import numpy as np
import yaml
import tensorflow as tf
from tensorflow import keras

from src.model_creation import get_train_dataset
from src.model_creation import neural_net
from src.visualization import plot_history

# Sentinel telling tf.data to tune parallelism / prefetch buffer sizes at runtime.
# Prefer the stable `tf.data.AUTOTUNE` and fall back to the deprecated
# experimental alias on TensorFlow releases that predate it.
AUTOTUNE = getattr(tf.data, "AUTOTUNE", tf.data.experimental.AUTOTUNE)

In [None]:
#### load config file: every hyper-parameter and data path used below is read from this mapping
with open('model_config.yaml') as config_stream:
    config = yaml.safe_load(config_stream)

In [None]:
#### define learning rate function
def lr_scheduler_fn(epoch):
    """Step-decay learning-rate schedule driven by the global ``config``.

    lr(epoch) = INITIAL_LEARNING_RATE * LR_DECAY_FACTOR ** (epoch // EPOCHS_PER_LR_DECAY)

    Config values are coerced with ``float()`` because PyYAML parses exponent
    literals without a decimal point (e.g. ``1e-3``) as strings.

    Args:
        epoch: zero-based epoch index supplied by ``LearningRateScheduler``.

    Returns:
        The learning rate for ``epoch`` as a Python float.

    Raises:
        ValueError: if ``EPOCHS_PER_LR_DECAY`` is not positive.
    """
    epochs_per_decay = float(config['EPOCHS_PER_LR_DECAY'])
    if epochs_per_decay <= 0:
        raise ValueError(
            "EPOCHS_PER_LR_DECAY must be positive, got {!r}".format(config['EPOCHS_PER_LR_DECAY'])
        )
    initial_lr = float(config['INITIAL_LEARNING_RATE'])
    decay_factor = float(config['LR_DECAY_FACTOR'])
    # Number of completed decay periods; the LR is multiplied by decay_factor once per period.
    return initial_lr * decay_factor ** (epoch // epochs_per_decay)


#### define callbacks
callbacks = [
    # Re-evaluates lr_scheduler_fn at the start of every epoch and logs the new LR.
    tf.keras.callbacks.LearningRateScheduler(lr_scheduler_fn, verbose=1),
    # Stops training once the monitored quantity stops improving for `patience` epochs.
    # The defaults reproduce the original behaviour (training loss, patience 3); set
    # EARLY_STOPPING_MONITOR: val_loss in the config to stop on validation loss instead.
    tf.keras.callbacks.EarlyStopping(
        monitor=config.get('EARLY_STOPPING_MONITOR', 'loss'),
        patience=config.get('EARLY_STOPPING_PATIENCE', 3),
    ),
]

In [None]:
def _make_strategy(use_tpu):
    """Return the tf.distribute strategy the model is built and trained under.

    When ``use_tpu`` is truthy, this connects to the TPU cluster, initialises it
    and returns a ``TPUStrategy``. Otherwise it returns ``tf.distribute.get_strategy()``.
    """
    if not use_tpu:
        return tf.distribute.get_strategy()
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    # This is the TPU initialization code that has to be at the beginning.
    tf.tpu.experimental.initialize_tpu_system(resolver)
    print("All devices: ", tf.config.list_logical_devices('TPU'))
    return tf.distribute.TPUStrategy(resolver)


def _make_split(data_path, augment):
    """Build one dataset split from ``data_path``; returns ``(dataset, num_examples)``."""
    return get_train_dataset.make_dataset(
        data_path,
        label_column_name=config["LABEL_COLUMN_NAME"],
        image_column_name=config["IMAGE_COLUMN_NAME"],
        image_size=config["IMAGE_SIZE"],
        batch_size=config["BATCH_SIZE"],
        repeat_forever=True,
        augment=augment,
    )


def _num_steps(num_examples, batch_size):
    """Return the batches needed to see every example once (ceil division), as a Python int."""
    return int(np.ceil(num_examples / batch_size))


def run_model(use_tpu=None):
    """Build, train and save the network described by the global ``config``.

    Args:
        use_tpu: train under a ``TPUStrategy`` when truthy. ``None`` (the default)
            falls back to a notebook-level ``TPU`` variable if one is defined,
            and otherwise to ``config["TPU"]`` (False when that key is absent).

    Returns:
        The ``History`` returned by ``model.fit``, or ``None`` when no model
        could be built.
    """
    if use_tpu is None:
        # The original cell read a global `TPU` that no cell defines (NameError on a
        # fresh kernel). Keep honouring it when present; otherwise use the config.
        use_tpu = globals().get("TPU", config.get("TPU", False))
    # The strategy (and TPU initialisation) must exist before any model variables are created.
    strategy = _make_strategy(use_tpu)

    train_ds, num_train_examples = _make_split(config["TRAINING_DATA"], augment=True)
    val_ds, num_val_examples = _make_split(config["VAL_DATA"], augment=False)

    with strategy.scope():
        # create neural network
        # NOTE(review): "effecientnetv2l" looks misspelled. It is kept as the default
        # because the accepted names live in src.model_creation.neural_net; confirm there.
        model = neural_net.make_neural_network(
            base_arch_name=config.get("BASE_ARCH_NAME", "effecientnetv2l"),
            weights=config["PRETRAINED_MODEL"],
            image_size=config["IMAGE_SIZE"],
            dropout_pct=config["DROPOUT_PCT"],
            n_classes=config["NUM_CLASSES"],
            input_dtype=tf.float32,
            train_full_network=True,
        )

        # Must run before anything touches the model (the original loaded weights first).
        if model is None:
            print("No model to train.")
            return None

        # Load pretrained weights from a local checkpoint. "imagenet" is passed to the
        # builder above and presumably handled there. A null/empty value means no checkpoint.
        pretrained = config["PRETRAINED_MODEL"]
        if pretrained and pretrained != "imagenet" and os.path.exists(pretrained):
            model.load_weights(pretrained)

        # Use `learning_rate` rather than the deprecated `lr` alias, which newer Keras
        # optimizers ignore with only a warning. float() guards against PyYAML
        # returning strings for literals such as `1e-7`.
        optimizer = keras.optimizers.RMSprop(
            learning_rate=float(config["INITIAL_LEARNING_RATE"]),
            rho=float(config["RMSPROP_RHO"]),
            momentum=float(config["RMSPROP_MOMENTUM"]),
            epsilon=float(config["RMSPROP_EPSILON"]),
        )

        # compile the network for training
        # NOTE(review): TopKCategoricalAccuracy expects one-hot labels; confirm this
        # matches config['LOSS'] and the label encoding produced by make_dataset.
        model.compile(
            loss=config['LOSS'],
            optimizer=optimizer,
            metrics=[
                "accuracy",
                # Underscores rather than spaces: Keras uses metric names as History
                # keys and may use them as TF name scopes, which reject spaces.
                tf.keras.metrics.TopKCategoricalAccuracy(k=3, name="top3_accuracy"),
                tf.keras.metrics.TopKCategoricalAccuracy(k=10, name="top10_accuracy"),
            ],
        )

        # make_dataset is asked to repeat forever, so explicit integer step counts
        # are what end each epoch. np.ceil alone would return floats.
        steps_per_epoch = _num_steps(num_train_examples, config["BATCH_SIZE"])
        val_steps = _num_steps(num_val_examples, config["BATCH_SIZE"])

        start = time.perf_counter()
        history = model.fit(
            train_ds,
            validation_data=val_ds,
            validation_steps=val_steps,
            epochs=config["NUM_EPOCHS"],
            steps_per_epoch=steps_per_epoch,
            callbacks=callbacks,
        )
        elapsed = time.perf_counter() - start

        print("time elapsed during fit: {:.1f}".format(elapsed))
        print(history.history)
        model.save(config["FINAL_SAVE_DIR"])

    return history

In [None]:
#### train the model, then plot its learning curves
# The original cell plotted undefined names `accuracy`/`loss` and run_model() was
# never called, so the notebook never trained on Restart & Run All.
history = run_model()
if history is None:
    print("Training did not produce a History; nothing to plot.")
else:
    # NOTE(review): passing the per-epoch metric lists is an assumption; confirm
    # against the signature of src.visualization.plot_history.plot_history.
    plot_history.plot_history(history.history["accuracy"])
    plot_history.plot_history(history.history["loss"])