# TF model 

# Parameters block (for papermill)

In [None]:
# Papermill parameter cell. These module-level names are later handed to
# meta.Config.from_global(**globals()), so they are looked up by name:
# do not rename them or wrap them in helper structures.

# Run identifiers; cfg.code_name / cfg.batch_name are passed to
# benchmark_hs04.main() after training.
code_name = "_boo"
batch_name = None

# Model configs
# Layer sizes (presumably number of units per layer; ort/pho/sem look like
# orthography/phonology/semantics -- confirm against modeling.MyModel).
ort_units = 119
pho_units = 250
sem_units = 2446
hidden_os_units = 500
hidden_op_units = 100
hidden_ps_units = 500
hidden_sp_units = 500
pho_cleanup_units = 50
sem_cleanup_units = 50
pho_noise_level = 0.
sem_noise_level = 0.
activation = "sigmoid"

# Time-averaging settings consumed by the model (semantics defined in
# modeling.MyModel -- not visible in this notebook).
tau = 1 / 3
max_unit_time = 4.0
output_ticks = 13
inject_error_ticks = 11

# Training configs
learning_rate = 0.001  # passed to the per-task Adam optimizers
# When not None, metrics.CustomBCE(radius=zero_error_radius) is used as the
# loss; when None, tf.keras.losses.BinaryCrossentropy() is used instead.
zero_error_radius = 0.1
save_freq = 10  # presumably drives cfg.saved_epochs (checkpoint schedule) -- confirm in meta.Config

# Environment configs
# Word-frequency compression / clipping used by the sampler (see the
# Environment section of this notebook).
wf_compression = "log"
wf_clip_low = 0
wf_clip_high = 999_999_999

# Tasks trained in the single stage defined below, and the value passed as
# Stage(task_probability_start=...); presumably one probability per task, in
# the same order as task_names.
task_names = ("pho_sem", "sem_pho", "pho_pho", "sem_sem", "triangle")
tasks_ps = (0.2, 0.2, 0.05, 0.05, 0.5)
# tasks_ps = (0.4, 0.4, 0.1, 0.1)

total_sample = 5_000_000  # passed as Stage(stage_sample=...)
batch_size = 4
rng_seed = 2021  # seeds both tf.random and np.random
which_gpu = 0  # forwarded to split_gpu(which_gpu=...)

# Allocate GPU resources
IMPORTANT: DO NOT IMPORT OTHER LIBS BEFORE THIS BLOCK!!!
Otherwise, the GPU will not be instantiated correctly.

In [None]:
# Allocate GPU resources through the project helper before TensorFlow is
# imported (TensorFlow is first imported in the next cell). `which_gpu` comes
# from the parameter cell.
from meta import split_gpu
split_gpu(which_gpu=which_gpu)

In [None]:
# TensorFlow is imported only after split_gpu() has run in the previous cell.
import tensorflow as tf
import numpy as np
print(tf.__version__)

# NOTE: determinism is only available in TF nightly (2.8)
# tf.keras.utils.set_random_seed(rng_seed)
# tf.config.experimental.enable_op_determinism()

# Seed TensorFlow's and NumPy's global RNGs with the same value. Per the note
# above, this alone does not make GPU runs fully deterministic.
tf.random.set_seed(rng_seed)
np.random.seed(rng_seed)

In [None]:
import os
from tqdm import tqdm
from time import sleep
import meta, data_wrangling, metrics, modeling
from dotenv import load_dotenv

load_dotenv()

# Config, Data, and Model modules
- About Config():
    - We build a Config() object from the parameters described in the parameter block above
    - As long as there is no change in the underlying source code (./src), we can use the same Config to produce a somewhat similar model in TF2.6
    - It is "somewhat" similar because running model on GPU is not deterministic (and thus the results are not identical)
    - Deterministic mode will be available in TF 2.8 release
- About MyData():
    - It is a class that contains all the data that is used by the model training
- About MyModel():
    - It is a class that contains the triangle model implementation on TensorFlow using subclass-level API
    - It is a subclass of tf.keras.Model
    - It contains multiple tasks (e.g., triangle, PS, SP, PP, SS, etc.)
    - The behavior of the model is defined by the task

In [None]:
# Build the run configuration from this notebook's global namespace (which
# includes the values set in the parameter cell). The commented alternative
# reloads a previously saved configuration from JSON instead.
cfg = meta.Config.from_global(**globals())
# cfg = meta.Config.from_json(os.path.join(tf_root, "models", batch_name, code_name, "model_config.json"))

# Training data container and the model; the model is built explicitly here,
# before any checkpoint is restored into it in the cells below.
data = data_wrangling.MyData()
model = modeling.MyModel(cfg)
model.build()

# (Training) Environment module
- Environment defines what the model is trained on. It consists of one or more stages. 
- Each stage describes what the model's training tasks are and their probabilities (how often each task is used during training). It contains one or more tasks. 
- Each task defines how fast the corpus (the actual words that can be sampled) is opened; it defaults to fully open (following the word-frequency compression sampling strategy defined in the config). 
- See the docstrings in each object for further details.

In [None]:
from environment import Task, Stage, Experience, Sampler

# Single-stage training environment: every configured task is available for
# the whole run (cfg.total_sample samples), drawn with the configured task
# probabilities.
stages = [
    Stage(
        name="one",
        tasks=list(map(Task, cfg.task_names)),
        stage_sample=cfg.total_sample,
        task_probability_start=cfg.tasks_ps,
    ),
]

# The sampler combines config, data and the staged experience; its generator
# is what the training loop pulls batches from.
experience = Experience(stages)
sampler = Sampler(cfg, data, experience)
batch_generator = sampler.generator()

## Plots to visualize environment

In [None]:
# experience.plot_corpus()

In [None]:
# experience.plot_task_probability()

# Training module

## Load pretrained model checkpoint

In [None]:
# Load weights from the "pretrain_3M" run (checkpoint "epoch-300") into `model`.
# Only the `model` entry is tracked by this Checkpoint object.
# NOTE(review): the status object returned by restore() is discarded, so
# unmatched or missing variables are not reported here -- consider asserting
# on it if silent partial restores are a concern.
pretrain_ckpt = tf.train.Checkpoint(model=model)
pretrain_ckpt.restore(os.path.join(cfg.tf_root, "models", "pretrain_3M", "checkpoints", "epoch-300"))

## Create optimizers, loss functions, and metrics
- Since each sub-task has its own states, it must be trained with a separate optimizer and loss instead of sharing the same optimizer instance; see this [issue](https://github.com/tensorflow/tensorflow/issues/27120)
- Regarding the metrics, there are two types:
    - Stateless metrics: the average metric of all batches within an epoch (semantic acc, sse)
    - Stateful metrics: only taking last step/batch value in an epoch (pho acc)

In [None]:
# Per-task training state, each dict keyed by task name.
optimizers = {}
loss_fns = {}
train_losses = {}  # Mean loss (only for TensorBoard)
train_metrics = {}

# Metric classes: accuracy depends on the output layer, SSE is shared.
acc = {"pho": metrics.PhoAccuracy, "sem": metrics.StatelessRightSideAccuracy}
sse = metrics.StatelessSumSquaredError

for task in cfg.task_names:
    # Each task gets its own optimizer instance (see the linked TF issue in
    # the notes above). SGD is a drop-in alternative:
    # optimizers[task] = tf.keras.optimizers.SGD(learning_rate=cfg.learning_rate)
    optimizers[task] = tf.keras.optimizers.Adam(learning_rate=cfg.learning_rate)

    # A configured zero-error radius selects the custom BCE loss; otherwise
    # fall back to the stock Keras binary cross-entropy.
    if cfg.zero_error_radius is not None:
        loss_fns[task] = metrics.CustomBCE(radius=cfg.zero_error_radius)
    else:
        loss_fns[task] = tf.keras.losses.BinaryCrossentropy()

    train_losses[task] = tf.keras.metrics.Mean(
        f"train_loss_{task}", dtype=tf.float32
    )  # for tensorboard only

    # Second element of IN_OUT[task] names the output(s) of the task.
    task_output = modeling.IN_OUT[task][1]

    if isinstance(task_output, (list, tuple)):
        # Multi-output task: one [accuracy, sse] pair per output name.
        train_metrics[task] = {
            out: [acc[out](f"{task}_{out}_acc"), sse(f"{task}_{out}_sse")]
            for out in task_output
        }
    else:
        # Single-output task: a flat [accuracy, sse] pair.
        train_metrics[task] = [acc[task_output](f"{task}_acc"), sse(f"{task}_sse")]


# Trainstep

In [None]:
def get_train_step(task):
    """Build the `tf.function`-compiled training step for `task`.

    A separate compiled function is created per task. The "triangle" task
    sums the loss over its "pho" and "sem" outputs; every other task uses the
    single output named by `modeling.IN_OUT[task]`.

    Args:
        task: Task name; must be a key of `modeling.IN_OUT`.

    Returns:
        A callable `train_step(x, y, model, task, loss_fn, optimizer,
        train_metrics, train_losses)` that runs one optimisation step and
        returns None.
    """
    # Only the output name is needed; the input name in IN_OUT is not used here.
    _, output_name = modeling.IN_OUT[task]
    is_triangle = task == "triangle"

    @tf.function()
    def train_step(
        x, y, model, task, loss_fn, optimizer, train_metrics, train_losses
    ):
        """Train on one batch and accumulate its loss in `train_losses`.

        `train_metrics` is accepted for interface compatibility but is
        currently unused: the per-metric updates are disabled. The disabled
        version fed each metric the last time step only, e.g.
        `m.update_state(tf.cast(y[-1], tf.float32), y_pred[output_name][-1])`.
        """
        # Update only the weights/biases registered for this task. Variable
        # names are matched with a ":0" suffix appended to the registered names.
        trainable_names = {
            name + ":0" for name in modeling.WEIGHTS_AND_BIASES[task]
        }
        train_weights = [w for w in model.weights if w.name in trainable_names]

        # TF Automatic differentiation
        with tf.GradientTape() as tape:
            # training flag can be accessed within model by K.in_train_phase()
            # it can change the behavior in model() (e.g., turn on/off noise)
            y_pred = model(x, training=True)

            if is_triangle:
                # Two outputs share one loss function; total loss is their sum.
                loss_value_pho = loss_fn(y["pho"], y_pred["pho"])
                loss_value_sem = loss_fn(y["sem"], y_pred["sem"])
                loss_value = loss_value_pho + loss_value_sem
            else:
                loss_value = loss_fn(y, y_pred[output_name])

        grads = tape.gradient(loss_value, train_weights)

        # Weight update
        optimizer.apply_gradients(zip(grads, train_weights))

        # Mean loss (for TensorBoard)
        train_losses.update_state(loss_value)

    return train_step


# One training-step function per configured task, keyed by task name.
train_steps = {name: get_train_step(name) for name in cfg.task_names}
# train_steps = {task: modeling.get_train_step(task) for task in cfg.tasks}

## Tensorboard module

In [None]:
# def write_scalar_to_tensorboard(task, step):
#     """Write metrics and loss to tensorboard"""
#     loss = train_losses[task]
#     tf.summary.scalar(loss.name, loss.result(), step=step)

#     maybe_metrics = train_metrics[task]
#     if task == "triangle":
#         [
#             tf.summary.scalar(m.name, m.result(), step=step)
#             for metrics in maybe_metrics.values()
#             for m in metrics
#         ]
#     else:
#         [tf.summary.scalar(m.name, m.result(), step=step) for m in maybe_metrics]


def write_weight_histogram_to_tensorboard(step):
    """Write one histogram summary per model weight.

    Reads the module-level `model` and tags each histogram with the weight's
    name. The training loop in this notebook calls it inside
    `train_summary_writer.as_default()`.

    Args:
        step: Step value recorded with every histogram (the training loop
            passes the current epoch number).
    """
    for weight in model.weights:
        tf.summary.histogram(weight.name, weight, step=step)


# def reset_metrics(task):
#     maybe_metrics = train_metrics[task]
#     if task == "triangle":
#         [m.reset_states() for metrics in maybe_metrics.values() for m in metrics]
#     else:
#         [m.reset_states() for m in maybe_metrics]

# TensorBoard writer for training summaries; event files are written under
# <cfg.tensorboard_folder>/train. The training loop makes it the default
# writer while logging weight histograms.
train_summary_writer = tf.summary.create_file_writer(
    os.path.join(cfg.tensorboard_folder, "train")
)

## Checkpoint (save) module

In [None]:
# Epoch counter stored in a tf.Variable so that it is tracked by the
# checkpoint below and comes back when a checkpoint is restored.
epoch = tf.Variable(0, name='epoch')

# Everything saved per checkpoint: epoch counter, model, and the per-task
# optimizer dict.
ckpt = tf.train.Checkpoint(
    epoch=epoch,
    model=model,
    optimizers=optimizers,
    )

# Manages checkpoint files in cfg.checkpoint_folder, named with the "epoch"
# prefix.
ckpt_manager = tf.train.CheckpointManager(
    ckpt,
    cfg.checkpoint_folder,
    max_to_keep=None,  # Keep all checkpoints
    checkpoint_name="epoch",
)

# Train model

## Resume from latest checkpoint
- The environment will no longer be identical when resuming from a checkpoint (it cannot be stored in the checkpoint for now)
- However, resume training is not very common, so it is not a big deal for now... 

In [None]:
# Resume from the newest managed checkpoint when one exists; otherwise keep
# the current state and start training from there.
if not ckpt_manager.latest_checkpoint:
    print("Initializing from scratch.")
else:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print("Restored from {}".format(ckpt_manager.latest_checkpoint))

In [None]:
progress_bar = tqdm(total=cfg.total_number_of_epoch, desc="Training")
# Start the bar at the current epoch so a run resumed from a checkpoint shows
# the right progress.
progress_bar.update(epoch.numpy())

try:
    while epoch.numpy() < cfg.total_number_of_epoch:

        # Train an epoch
        for step in range(cfg.steps_per_epoch):
            # Draw task, create batch
            task, exposed_words_idx, exposed_word, x_batch_train, y_batch_train = next(batch_generator)

            # task switching must be done outside train_step function (will crash otherwise)
            model.set_active_task(task)

            # Run a train step
            train_steps[task](
                x_batch_train,
                y_batch_train,
                model,
                task,
                loss_fns[task],
                optimizers[task],
                train_metrics[task],
                train_losses[task],
            )

        # Post epoch Ops
        progress_bar.update(1)
        epoch.assign_add(1)

        with train_summary_writer.as_default():
            # Write weight histograms to TensorBoard once per epoch. Scalar
            # loss/metric logging is currently disabled.
            write_weight_histogram_to_tensorboard(step=epoch.numpy())

        # Reset the per-task mean loss so each epoch starts from zero.
        for task_name in cfg.task_names:
            train_losses[task_name].reset_states()

        # In-loop test-set evaluation is disabled (offline eval is used):
        # [strain(task, step=epoch) for task in strain.tasks]
        # [grain(task, step=epoch) for task in grain.tasks]

        # Save a checkpoint on the epochs listed in cfg.saved_epochs; the
        # epoch variable is used as the checkpoint number.
        if epoch.numpy() in cfg.saved_epochs:
            ckpt_manager.save(epoch)
finally:
    # Always release the progress bar, including on interrupt or error.
    progress_bar.close()

# Run tests

In [None]:
# Run usual evals
# Imported here rather than at the top of the notebook; runs the benchmark
# entry point for this run's code_name / batch_name.
import benchmark_hs04
benchmark_hs04.main(cfg.code_name, cfg.batch_name)

In [None]:
# Just test1 (acc in triangle and oral)
# import benchmark_hs04
# benchmark_hs04.run_test1(cfg.code_name)

In [None]:
# `evaluate` is not imported in any other cell visible in this notebook, so it
# must be imported here; with the import commented out this cell fails with
# NameError.
import evaluate

# Evaluate the triangle task on the training set (to_bq=True is forwarded to
# TestSet.eval_train).
test = evaluate.TestSet(cfg)
test.eval_train("triangle", to_bq=True)

In [None]:
# Push to storage bucket and shutdown
# !gsutil -m rsync -d -r models/{code_name} gs://tf_mirror/{code_name}
# sleep(30)
# !sudo shutdown -P +0