# TF model v4.0
HS04 model incorporating a non-stationary environment

In [None]:
# %load_ext lab_black
import os
import tensorflow as tf
from tqdm import tqdm
from time import sleep
import meta, data_wrangling, metrics, benchmark_hs04, modeling

# Limit GPU memory use through the project helper.
# NOTE(review): the unit of 3000 is not visible here (presumably MB) — confirm in `meta`.
meta.limit_gpu_memory_use(3000)

# Parameters block (for papermill)

In [None]:
# All names in this cell are plain module-level globals; the next cell hands
# `globals()` to `meta.ModelConfig.from_global`, so keep them as simple assignments.

# Run identifier; also passed to `benchmark_hs04.main` at the end of the notebook.
code_name = "OS_ff"
batch_name = None
tf_root = "/home/jupyter/tf"

# Model configs
# Layer sizes (unit counts).
# NOTE(review): ort/pho/sem presumably stand for orthography/phonology/semantics — confirm in `modeling`.
ort_units = 119
pho_units = 250
sem_units = 2446
hidden_os_units = 500
hidden_op_units = 100
hidden_ps_units = 500
hidden_sp_units = 500
pho_cleanup_units = 50
sem_cleanup_units = 50
pho_noise_level = 1.0
sem_noise_level = 1.0
activation = "sigmoid"

# Time-dynamics settings; their exact semantics are defined in `modeling` (not visible here).
tau = 1 / 3
max_unit_time = 4.0
output_ticks = 13
inject_error_ticks = 6

# Training configs
# `learning_rate` is passed to every per-task Adam optimizer.
learning_rate = 0.005
# When not None, selects `metrics.CustomBCE(radius=...)` instead of plain binary cross-entropy.
zero_error_radius = 0.1
# NOTE(review): presumably the number of epochs between weight checkpoints — confirm in `meta`.
save_freq = 10

# Environment configs
# tasks = ("pho_sem", "sem_pho", "pho_pho", "sem_sem", "triangle")
# Each task name is used as a key into `modeling.IN_OUT` and `modeling.WEIGHTS_AND_BIASES`.
tasks = ("exp_os_ff", "triangle")
wf_compression = "log"
wf_clip_low = 0
wf_clip_high = 999_999_999
oral_start_pct = 0.02
# oral_end_pct = 0.5
oral_end_pct = 1.0

# Previous five-task schedule, kept for reference:
# oral_sample = 900_000
# oral_tasks_ps = (0.4, 0.4, 0.1, 0.1, 0.)
# transition_sample = 400_000
# reading_sample = 4_100_000
# reading_tasks_ps = (0.2, 0.2, 0.05, 0.05, 0.5)

# NOTE(review): the *_tasks_ps tuples presumably hold one sampling probability per
# entry of `tasks`, in the same order — confirm against `data_wrangling.Sampler`.
oral_sample = 1_000_000
oral_tasks_ps = (1.0, 0.0)
transition_sample = 0
reading_sample = 1_000_000
reading_tasks_ps = (1.0, 0.0)

batch_size = 100
# Seed passed to `tf.random.set_seed` before the model is built.
rng_seed = 2021

In [None]:
# Build the run configuration object from this notebook's module-level globals.
cfg = meta.ModelConfig.from_global(globals_dict=globals())

# Build model and all supporting components

In [None]:
# Seed TensorFlow's global RNG before the model is built
tf.random.set_seed(cfg.rng_seed)
data = data_wrangling.MyData()

# Architecture
model = modeling.MyModel(cfg)
model.build()

# Non-stationary environment
# `batch_generator` is consumed with `next()` once per step in the training loop.
sampler = data_wrangling.Sampler(cfg, data)
batch_generator = sampler.generator()
sampler.plot_env()

## Core training modules

In [None]:
# Each sub-task keeps its own optimizer state, so every task gets a dedicated
# optimizer and loss instance instead of sharing one
# (https://github.com/tensorflow/tensorflow/issues/27120).
optimizers = {}
loss_fns = {}
train_losses = {}  # Mean loss per task (only for TensorBoard)
train_metrics = {}

# Task-specific accuracy classes, keyed by output name.
# Caution: PhoAccuracy is stateful (it only keeps the last batch value in an epoch);
# the Stateless* metrics are the average of all batches within an epoch.
acc = {"pho": metrics.PhoAccuracy, "sem": metrics.StatelessRightSideAccuracy}
sse = metrics.StatelessSumSquaredError

for task in cfg.tasks:
    optimizers[task] = tf.keras.optimizers.Adam(learning_rate=cfg.learning_rate)

    # A configured zero-error radius selects the custom BCE loss; otherwise
    # fall back to the stock Keras binary cross-entropy.
    if cfg.zero_error_radius is not None:
        loss_fns[task] = metrics.CustomBCE(radius=cfg.zero_error_radius)
    else:
        loss_fns[task] = tf.keras.losses.BinaryCrossentropy()

    train_losses[task] = tf.keras.metrics.Mean(
        f"train_loss_{task}", dtype=tf.float32
    )  # for tensorboard only

    task_output = modeling.IN_OUT[task][1]

    if isinstance(task_output, list):
        # Multi-output task: one [accuracy, SSE] pair per output name.
        train_metrics[task] = {
            out: [acc[out](f"{task}_{out}_acc"), sse(f"{task}_{out}_sse")]
            for out in task_output
        }
    else:
        # Single-output task: a flat [accuracy, SSE] pair.
        train_metrics[task] = [acc[task_output](f"{task}_acc"), sse(f"{task}_sse")]

# Train step

In [None]:
def get_train_step(task):
    """Build the compiled training step for ``task``.

    A separate ``tf.function`` is created per task. The "triangle" task sums
    the loss over its "pho" and "sem" outputs; every other task is treated as
    a single-output task whose output name comes from ``modeling.IN_OUT``.

    Args:
        task: Task name; must be a key of ``modeling.IN_OUT``.

    Returns:
        A ``tf.function`` with signature
        ``train_step(x, y, model, task, loss_fn, optimizer, train_metrics,
        train_losses)`` that runs one forward/backward pass and applies the
        gradients. It returns nothing.
    """
    _, output_name = modeling.IN_OUT[task]
    # Plain Python bool captured by the closure: the branch below is resolved
    # when the function is traced, not turned into a graph conditional.
    is_triangle = task == "triangle"

    @tf.function
    def train_step(x, y, model, task, loss_fn, optimizer, train_metrics, train_losses):
        """Train on one batch, updating only the weights that belong to ``task``.

        ``train_metrics`` and ``train_losses`` are accepted for interface
        compatibility but are NOT updated: metric/loss logging is disabled.
        To re-enable it, update each metric with the last time step, e.g.
        ``m.update_state(tf.cast(y[-1], tf.float32), y_pred[output_name][-1])``
        (for "triangle", do this per output with ``y["pho"]`` / ``y["sem"]``),
        and call ``train_losses.update_state(loss_value)``.
        """
        # Select this task's variables by name (variable names carry a ":0" suffix).
        trainable_names = {name + ":0" for name in modeling.WEIGHTS_AND_BIASES[task]}
        train_weights = [w for w in model.weights if w.name in trainable_names]

        # TF automatic differentiation
        with tf.GradientTape() as tape:
            # The training flag can change behaviour inside model()
            # (e.g., turn noise on/off).
            y_pred = model(x, training=True)

            if is_triangle:
                loss_value_pho = loss_fn(y["pho"], y_pred["pho"])
                loss_value_sem = loss_fn(y["sem"], y_pred["sem"])
                loss_value = loss_value_pho + loss_value_sem
            else:  # Single-output tasks
                loss_value = loss_fn(y, y_pred[output_name])

        grads = tape.gradient(loss_value, train_weights)

        # Weight update
        optimizer.apply_gradients(zip(grads, train_weights))

    return train_step


# One compiled train step per configured task, looked up by task name in the training loop.
train_steps = {task_name: get_train_step(task_name) for task_name in cfg.tasks}
# Alternative: train_steps = {task: modeling.get_train_step(task) for task in cfg.tasks}

## TensorBoard modules

In [None]:
# def write_scalar_to_tensorboard(task, step):
#     """Write metrics and loss to tensorboard"""
#     loss = train_losses[task]
#     tf.summary.scalar(loss.name, loss.result(), step=step)

#     maybe_metrics = train_metrics[task]
#     if task == "triangle":
#         [
#             tf.summary.scalar(m.name, m.result(), step=step)
#             for metrics in maybe_metrics.values()
#             for m in metrics
#         ]
#     else:
#         [tf.summary.scalar(m.name, m.result(), step=step) for m in maybe_metrics]


# def write_weight_histogram_to_tensorboard(step):
#     """Weight histogram"""
#     [tf.summary.histogram(f"{x.name}", x, step=step) for x in model.weights]


# def reset_metrics(task):
#     maybe_metrics = train_metrics[task]
#     if task == "triangle":
#         [m.reset_states() for metrics in maybe_metrics.values() for m in metrics]
#     else:
#         [m.reset_states() for m in maybe_metrics]

# Train model

In [None]:
# TensorBoard writer (currently disabled)
# train_summary_writer = tf.summary.create_file_writer(
#     os.path.join(cfg.tensorboard_folder, "train")
# )

# Main training loop: `cfg.total_number_of_epoch` epochs of `cfg.steps_per_epoch`
# steps each; every step draws one (task, batch) from the sampler's generator.
for epoch in tqdm(range(cfg.total_number_of_epoch)):
    for step in range(cfg.steps_per_epoch):
        # Draw task, create batch
        task, exposed_words_idx, x_batch_train, y_batch_train = next(batch_generator)

        # Task switching must be done outside the train_step function (will crash otherwise)
        model.set_active_task(task)

        # Run a train step with the task's own loss, optimizer and metric objects
        train_steps[task](
            x_batch_train,
            y_batch_train,
            model,
            task,
            loss_fns[task],
            optimizers[task],
            train_metrics[task],
            train_losses[task],
        )

    # TensorBoard logging (currently disabled)
    # with train_summary_writer.as_default():
    #     ## Write log to tensorboard (Once per epoch)
    #     [write_scalar_to_tensorboard(task, step=epoch) for task in cfg.tasks]
    #     write_weight_histogram_to_tensorboard(step=epoch)

    #     ## Reset metric and loss
    #     [train_losses[x].reset_states() for x in cfg.tasks]
    #     [reset_metrics(x) for x in cfg.tasks]

    # In training loop testset eval
    # Kind of fast, but hard to integrate and customize... maybe use offline eval again.
    # [strain(task, step=epoch) for task in strain.tasks]
    # [grain(task, step=epoch) for task in grain.tasks]

    ## Save weights
    # `epoch` is 0-based; saved epochs and the weight file name use 1-based numbering.
    one_indexing_epoch = epoch + 1
    if one_indexing_epoch in cfg.saved_epoches:
        weight_path = cfg.saved_weights_fstring.format(epoch=one_indexing_epoch)
        model.save_weights(weight_path, overwrite=True, save_format="tf")

# Run benchmark

In [None]:
# Run the HS04 benchmark entry point for this run, identified by `code_name`.
# NOTE(review): presumably evaluates the weights saved during training — confirm in `benchmark_hs04`.
benchmark_hs04.main(code_name)

In [None]:
# !gsutil -m rsync -d -r models/{code_name} gs://tf_mirror/{code_name}
# sleep(30)
# !sudo shutdown -P +0