# TF model v4.0
HS04 model incorporating a non-stationary environment.

In [None]:
%load_ext lab_black
import os
import tensorflow as tf
from tqdm import tqdm
import meta, data_wrangling, modeling, metrics

# meta.limit_gpu_memory_use(7000)

# Parameters block (for papermill)

In [None]:
# NOTE: the variables in this cell are read back *by name* through globals() when the
# ModelConfig is built further down (names come from meta.CORE_CONFIGS / ENV_CONFIGS /
# OPTIONAL_CONFIGS), so renaming one here must be mirrored in `meta`.

# Identifier of this run; presumably also the model folder name under
# <tf_root>/models/ (see the commented-out from_json path below) -- TODO confirm in meta
code_name = "triangle_high_time_res_4M"
tf_root = "/home/jupyter/tf"

# Model configs
# Layer sizes (number of units) for each representation / hidden / cleanup layer
ort_units = 119
pho_units = 250
sem_units = 2446
hidden_os_units = 500
hidden_op_units = 100
hidden_ps_units = 500
hidden_sp_units = 500
pho_cleanup_units = 50
sem_cleanup_units = 50
pho_noise_level = 0.0
sem_noise_level = 0.0
activation = "sigmoid"
# Time-related settings; exact semantics are defined in `modeling` -- not visible here
tau = 1 / 12
max_unit_time = 4.0
output_ticks = 11
inject_error_ticks = 11

# Training configs
learning_rate = 0.005  # passed to each per-task Adam optimizer
zero_error_radius = 0.1  # passed as `radius` to metrics.CustomBCE
save_freq = 10  # after the first 10 epochs, weights are saved every `save_freq` epochs
batch_name = None

# Environment configs
tasks = ("pho_sem", "sem_pho", "pho_pho", "sem_sem", "triangle")
wf_compression = "log"
wf_clip_low = 0
wf_clip_high = 999_999_999  # presumably "no effective upper clip" -- TODO confirm
oral_start_pct = 0.02
oral_end_pct = 0.5
oral_sample = 900_000
# One value per entry of `tasks`, in the same order; each tuple sums to 1.0
# (assumed to be task sampling probabilities -- confirm in data_wrangling.Sampler)
oral_tasks_ps = (0.4, 0.4, 0.1, 0.1, 0.0)
transition_sample = 400_000
reading_sample = 3_100_000
reading_tasks_ps = (0.2, 0.2, 0.05, 0.05, 0.5)
batch_size = 100
rng_seed = 2021  # used to seed TensorFlow via tf.random.set_seed

In [None]:
# cfg = meta.ModelConfig.from_json(os.path.join(tf_root, 'models', code_name, 'model_config.json'))

In [None]:
# Collect the notebook-level parameter variables into a dictionary for feeding into
# ModelConfig(). Values are looked up by name in the notebook's global namespace.

config_dict = {}

# Required settings: every listed name must exist as a global. A missing one still
# raises KeyError (as the plain lookup did), but now says which variable is absent.
for v in meta.CORE_CONFIGS + meta.ENV_CONFIGS:
    try:
        config_dict[v] = globals()[v]
    except KeyError as err:
        raise KeyError(
            f"Required config variable '{v}' is not defined in the notebook"
        ) from err

# Optional settings: copied only when defined. An explicit membership test replaces
# the former bare `except: pass`, which would also have hidden unrelated errors
# (including KeyboardInterrupt).
for v in meta.OPTIONAL_CONFIGS:
    if v in globals():
        config_dict[v] = globals()[v]

# Construct ModelConfig object and persist it
cfg = meta.ModelConfig(**config_dict)
cfg.save()
del config_dict  # no longer needed; `cfg` is the single source of settings from here on

# Build model and all supporting components

In [None]:
# Seed TensorFlow's global RNG before any other object in this cell is created.
# Keep this statement first: everything below is constructed after seeding.
tf.random.set_seed(cfg.rng_seed)
data = data_wrangling.MyData()

# Architecture
model = modeling.MyModel(cfg)
model.build()

# Non-stationary Environment
# The sampler is constructed from the config and the loaded data; its generator is
# consumed with next() in the training loop, yielding
# (task, exposed_words_idx, x_batch, y_batch) tuples.
sampler = data_wrangling.Sampler(cfg, data)
batch_generator = sampler.generator()
sampler.plot()  # presumably visualises the sampling schedule -- TODO confirm

## Core training modules

In [None]:
# Every sub-task keeps its own state, so each one gets a dedicated optimizer and loss
# object rather than sharing one optimizer instance
# (https://github.com/tensorflow/tensorflow/issues/27120)
optimizers, loss_fns, train_losses = {}, {}, {}

for task in cfg.tasks:
    optimizers[task] = tf.keras.optimizers.Adam(learning_rate=cfg.learning_rate)
    loss_fns[task] = metrics.CustomBCE(radius=cfg.zero_error_radius)
    # Running mean of the loss; only consumed by TensorBoard
    train_losses[task] = tf.keras.metrics.Mean(f"train_loss_{task}", dtype=tf.float32)


def _acc_and_sse(prefix, accuracy_cls):
    """Return [accuracy, sum-squared-error] metrics named '<prefix>_acc'/'<prefix>_sse'."""
    return [
        accuracy_cls(f"{prefix}_acc"),
        metrics.StatelessSumSquaredError(f"{prefix}_sse"),
    ]


# Task-specific training metrics.
# Caution: PhoAccuracy is stateful (it only keeps the last batch value in an epoch);
# all Stateless* metrics are the average over every batch within an epoch.
# The "triangle" task has two outputs, so its entry is a dict of metric lists keyed by
# output ("pho", "sem") instead of a flat list.
train_metrics = {
    "pho_pho": _acc_and_sse("pho_pho", metrics.PhoAccuracy),
    "sem_sem": _acc_and_sse("sem_sem", metrics.StatelessRightSideAccuracy),
    "pho_sem": _acc_and_sse("pho_sem", metrics.StatelessRightSideAccuracy),
    "sem_pho": _acc_and_sse("sem_pho", metrics.PhoAccuracy),
    "triangle": {
        "pho": _acc_and_sse("triangle_pho", metrics.PhoAccuracy),
        "sem": _acc_and_sse("triangle_sem", metrics.StatelessRightSideAccuracy),
    },
}

# One train-step callable per task
train_steps = {}
for task in cfg.tasks:
    train_steps[task] = modeling.get_train_step(task)

## Tensorboard modules

In [None]:
def write_scalar_to_tensorboard(task, step):
    """Write the running loss and every metric of one task as TensorBoard scalars.

    Reads the module-level ``train_losses`` and ``train_metrics`` dictionaries.
    The training loop calls this inside ``train_summary_writer.as_default()``.

    Args:
        task: Task name; a key of ``train_losses`` and ``train_metrics``.
        step: Value forwarded as the ``step`` of each scalar summary.
    """
    loss = train_losses[task]
    tf.summary.scalar(loss.name, loss.result(), step=step)

    task_metrics = train_metrics[task]
    if task == "triangle":
        # The triangle entry is a dict of metric lists (one list per output);
        # flatten it so it can be handled like every other task.
        task_metrics = [m for group in task_metrics.values() for m in group]

    # Plain loop: the calls are made for their side effect only, and no local name
    # shadows the imported `metrics` module any more.
    for m in task_metrics:
        tf.summary.scalar(m.name, m.result(), step=step)


def write_weight_histogram_to_tensorboard(step):
    """Write one TensorBoard histogram per entry of the global ``model.weights``.

    Each histogram is tagged with the weight's ``name``.

    Args:
        step: Value forwarded as the ``step`` of each histogram summary.
    """
    # Plain loop instead of a throwaway list built only for its side effects
    for weight in model.weights:
        tf.summary.histogram(weight.name, weight, step=step)


def reset_metrics(task):
    """Reset the state of every training metric registered for one task.

    Reads the module-level ``train_metrics`` dictionary.

    Args:
        task: Task name; a key of ``train_metrics``.
    """
    task_metrics = train_metrics[task]
    if task == "triangle":
        # The triangle entry is a dict of metric lists (one list per output);
        # flatten it so it can be handled like every other task.
        task_metrics = [m for group in task_metrics.values() for m in group]

    # Plain loop: called for side effects only, and avoids shadowing the imported
    # `metrics` module with a local name.
    for m in task_metrics:
        m.reset_states()

# Train model

In [None]:
# TensorBoard writer for the training summaries
train_summary_writer = tf.summary.create_file_writer(
    os.path.join(cfg.path["tensorboard_folder"], "train")
)

for epoch in tqdm(range(cfg.total_number_of_epoch)):
    for step in range(cfg.steps_per_epoch):
        # Draw task, create batch (the sampler decides which task this step trains)
        task, exposed_words_idx, x_batch_train, y_batch_train = next(batch_generator)

        # task switching must be done outside train_step function (will crash otherwise)
        model.set_active_task(task)

        # Run a train step with the task's own loss, optimizer, metrics and mean loss
        train_steps[task](
            x_batch_train,
            y_batch_train,
            model,
            task,
            loss_fns[task],
            optimizers[task],
            train_metrics[task],
            train_losses[task],
        )

    with train_summary_writer.as_default():
        ## Write log to tensorboard (once per epoch).
        ## Note: the TensorBoard step is the 0-based epoch index, whereas checkpoint
        ## file names below use epoch + 1.
        for logged_task in cfg.tasks:
            write_scalar_to_tensorboard(logged_task, step=epoch)
        write_weight_histogram_to_tensorboard(step=epoch)

        ## Reset metric and loss so the next epoch starts from a clean state
        for logged_task in cfg.tasks:
            train_losses[logged_task].reset_states()
        for logged_task in cfg.tasks:
            reset_metrics(logged_task)

    # In training loop testset eval
    # Kind of fast, but hard to integrate and customize... maybe use offline eval again.
    # [strain(task, step=epoch) for task in strain.tasks]
    # [grain(task, step=epoch) for task in grain.tasks]

    ## Save weights: after each of the first 10 epochs, then every `save_freq` epochs
    if (epoch < 10) or ((epoch + 1) % cfg.save_freq == 0):
        weight_path = cfg.path["weights_checkpoint_fstring"].format(epoch=epoch + 1)
        model.save_weights(weight_path, overwrite=True, save_format="tf")