# Triangle model
This interactive notebook runs a triangle model.


## Run parameters 
This block is necessary for running with [papermill](https://papermill.readthedocs.io/en/latest/)

In [None]:
# Run parameters. Every global assigned in this cell is later handed to
# `meta.Config.from_dict(**globals())`, so avoid defining scratch variables here.

# Run identifiers; `code_name` is also passed to `benchmark_hs04.run_test1`
# at the end of the notebook.
code_name = "usual_pretrain"
batch_name = None

# Model configs
# NOTE(review): the *_units values are presumably layer sizes (ort/pho/sem =
# orthography/phonology/semantics) — confirm against `modeling.MyModel`.
ort_units = 119
pho_units = 250
sem_units = 2446
hidden_os_units = 500
hidden_op_units = 100
hidden_ps_units = 500
hidden_sp_units = 500
pho_cleanup_units = 50
sem_cleanup_units = 50
pho_noise_level = 0.0
sem_noise_level = 0.0
activation = "sigmoid"

# NOTE(review): time-related model settings; their exact semantics are defined
# in the model implementation, not in this notebook.
tau = 1 / 3
max_unit_time = 4.0
output_ticks = 13
inject_error_ticks = 11

# Training configs
learning_rate = 0.001  # Passed to every per-task Adam optimizer
# Passed as `radius` to `metrics.CustomBCE`; when None, a plain
# `tf.keras.losses.BinaryCrossentropy` is used instead.
zero_error_radius = 0.1
save_freq = 10  # NOTE(review): presumably drives `cfg.saved_epochs` — confirm in `meta.Config`

# Environment configs
# NOTE(review): wf_* presumably control word-frequency scaling/clipping in the
# sampler — confirm in `environment.Sampler`.
wf_compression = "log"
wf_clip_low = 0
wf_clip_high = 999_999_999

# Tasks to train on and their sampling probabilities: `task_names` becomes the
# stage's `tasks` and `tasks_ps` its `task_probability_start`, so the two
# tuples are expected to line up entry by entry.
task_names = ("pho_sem", "sem_pho", "pho_pho", "sem_sem")
tasks_ps = (0.4, 0.4, 0.1, 0.1)
# Alternative setup that includes the full "triangle" task:
# task_names = ("pho_sem", "sem_pho", "pho_pho", "sem_sem", "triangle")
# tasks_ps = (0.0, 0.0, 0.1, 0.1, 0.5)

total_sample = 1_000_000  # Used as `stage_sample` of the single training stage
batch_size = 1
rng_seed = 2021  # Seeds PYTHONHASHSEED, TensorFlow and NumPy
which_gpu = 0  # Passed to `meta.split_gpu`


## System environment

In [None]:
# GPU selection has to happen first: `split_gpu` must run before TensorFlow is
# imported, which is why the remaining imports sit below the call.
from meta import split_gpu

split_gpu(which_gpu=which_gpu)  # IMPORTANT: do not import TensorFlow before this line

import os
import tensorflow as tf
import numpy as np
from tqdm import tqdm
from time import sleep
from dotenv import load_dotenv
import meta, data_wrangling, metrics, modeling, training

# Seed every random source with the same run-level seed.
# NOTE(review): PYTHONHASHSEED is read when the interpreter starts, so setting
# it here presumably only affects child processes, not this kernel — confirm
# whether hash determinism inside the notebook is actually required.
os.environ["PYTHONHASHSEED"] = str(rng_seed)
tf.random.set_seed(rng_seed)
np.random.seed(rng_seed)

load_dotenv()  # Loads .env file


## Create `Config()`, `MyData()`, and `MyModel()`
- `Config()` stores all the run settings in a class. `**globals()` is used to access all global variables in the parameter block.
- `MyData()` contains all the static data sets.
- `MyModel()` contains the triangle model implementation in TensorFlow.

In [None]:
# Alternative: reload the configuration of an existing run from its JSON file.
# NOTE(review): `tf_root` is not defined in this notebook; it must be provided
# before this line can be re-enabled.
# cfg = meta.Config.from_json(os.path.join(tf_root, "models", batch_name, code_name, "model_config.json"))   # Load from json

# Build the run configuration from every global currently defined in the
# notebook (the parameter cell plus anything imported or created so far).
cfg = meta.Config.from_dict(**globals())
data = data_wrangling.MyData()
model = modeling.MyModel(cfg)
model.build()
print(cfg)

## Create sample generator
- `Experience()` defines what the model is trained on. It consists of one or more `Stage()` objects.
- Each `Stage()` describes which tasks the model is trained on and how often each task is used during training. It contains one or more `Task()` objects.
- Each `Task()` specifies how fast the corpus (the set of words that can be sampled) is opened; it defaults to fully open.
- See the docstrings in each object for further details.

In [None]:
from environment import Task, Stage, Experience, Sampler

# The whole run is one stage that mixes every configured task, each drawn
# with its configured probability.
only_stage = Stage(
    name="one",
    tasks=list(map(Task, cfg.task_names)),
    stage_sample=cfg.total_sample,
    task_probability_start=cfg.tasks_ps,
)
stages = [only_stage]

# stages -> experience -> sampler -> batch generator
experience = Experience(stages)
sampler = Sampler(cfg, data, experience)
batch_generator = sampler.generator()


## Create optimizers, loss functions, metrics, and train steps
- Since each sub-task has its own states, each one must be trained with a separate optimizer.
- There are two types of metrics:
    - Stateless metrics: the average of the metric over all batches within an epoch (semantic acc, sse)
    - Stateful metrics: only the value from the last step/batch in an epoch is taken (pho acc)

In [None]:
# Per-task training components, each dictionary keyed by task name.
optimizers = {}
loss_fns = {}
train_losses = {}  # Mean loss (only for TensorBoard)
train_metrics = {}
train_steps = {}

# Accuracy metric class for each output layer, plus the shared SSE metric class.
acc = {"pho": metrics.PhoAccuracy, "sem": metrics.StatelessRightSideAccuracy}
sse = metrics.StatelessSumSquaredError

for task in cfg.task_names:

    # Every task gets its own optimizer instance.
    optimizers[task] = tf.keras.optimizers.Adam(learning_rate=cfg.learning_rate)

    # Plain binary cross-entropy unless a zero-error radius is configured.
    loss_fns[task] = (
        tf.keras.losses.BinaryCrossentropy()
        if cfg.zero_error_radius is None
        else metrics.CustomBCE(radius=cfg.zero_error_radius)
    )

    train_losses[task] = tf.keras.metrics.Mean(f"train_loss_{task}", dtype=tf.float32)

    task_output = modeling.IN_OUT[task][1]

    # Single-output tasks: a flat [accuracy, sse] list and the basic train step.
    if task != "triangle":
        train_metrics[task] = [acc[task_output](f"{task}_acc"), sse(f"{task}_sse")]
        train_steps[task] = training.basic_train_step(task)
        continue

    # Triangle task: one [accuracy, sse] pair per output, keyed by output name.
    train_metrics[task] = {
        output: [acc[output](f"{task}_{output}_acc"), sse(f"{task}_{output}_sse")]
        for output in task_output
    }
    train_steps[task] = training.triangle_train_step()


## Create tensorboard writer

In [None]:
# def write_scalar_to_tensorboard(task, step):
#     """Write metrics and loss to tensorboard"""
#     loss = train_losses[task]
#     tf.summary.scalar(loss.name, loss.result(), step=step)

#     maybe_metrics = train_metrics[task]
#     if task == "triangle":
#         [
#             tf.summary.scalar(m.name, m.result(), step=step)
#             for metrics in maybe_metrics.values()
#             for m in metrics
#         ]
#     else:
#         [tf.summary.scalar(m.name, m.result(), step=step) for m in maybe_metrics]


def write_weight_histogram_to_tensorboard(step):
    """Write one histogram summary for every weight of the global `model`.

    Each summary is tagged with the weight's own name. The summaries go to
    whichever summary writer is the default at call time; the training loop
    calls this inside `train_summary_writer.as_default()`.

    Args:
        step: Step value forwarded to `tf.summary.histogram`.
    """
    # A plain loop: the calls are made for their side effect only, so there is
    # no reason to build (and discard) a list of their return values.
    for weight in model.weights:
        tf.summary.histogram(weight.name, weight, step=step)


# def reset_metrics(task):
#     maybe_metrics = train_metrics[task]
#     if task == "triangle":
#         [m.reset_states() for metrics in maybe_metrics.values() for m in metrics]
#     else:
#         [m.reset_states() for m in maybe_metrics]

# TensorBoard writer: training summaries go to the "train" sub-folder of the
# run's TensorBoard directory.
train_log_dir = os.path.join(cfg.tensorboard_folder, "train")
train_summary_writer = tf.summary.create_file_writer(train_log_dir)


## Load pretrained checkpoint

In [None]:
# pretrain_ckpt = tf.train.Checkpoint(model=model)
# pretrain_ckpt.restore(
#     os.path.join(cfg.tf_root, "models", "pretrain_3M", "checkpoints", "epoch-300")
# )


## Create checkpoint (save) manager

In [None]:
# Epoch counter; tracked by the checkpoint together with the model and optimizers
epoch = tf.Variable(0, name="epoch")

ckpt = tf.train.Checkpoint(epoch=epoch, model=model, optimizers=optimizers)

# Checkpoint files are written to the run's checkpoint folder with the
# "epoch" prefix.
ckpt_manager = tf.train.CheckpointManager(
    checkpoint=ckpt,
    directory=cfg.checkpoint_folder,
    checkpoint_name="epoch",
    max_to_keep=None,  # Keep all checkpoints
)


## Resume training from latest checkpoint
- CAUTION: The environment will no longer be identical when resuming from a checkpoint (it cannot be stored in the checkpoint for now).
- However, resuming training is not very common, so it is not a big deal for now...

In [None]:
# Pick up from the most recent checkpoint when one exists; otherwise start fresh.
latest_ckpt = ckpt_manager.latest_checkpoint
if not latest_ckpt:
    print("Initializing from scratch.")
else:
    ckpt.restore(latest_ckpt)
    print(f"Restored from {latest_ckpt}")


## Train model

In [None]:
# Main training loop: runs until the epoch counter reaches
# `cfg.total_number_of_epoch`. The counter may start above zero when the run
# was restored from a checkpoint, so the progress bar is advanced to match.
# Using tqdm as a context manager closes the bar when training finishes or
# is interrupted by an exception.
with tqdm(total=cfg.total_number_of_epoch, desc="Training") as progress_bar:
    progress_bar.update(epoch.numpy())

    while epoch.numpy() < cfg.total_number_of_epoch:

        # Train an epoch
        for step in range(cfg.steps_per_epoch):
            # Draw task, create batch
            (
                task,
                exposed_words_idx,
                exposed_word,
                x_batch_train,
                y_batch_train,
            ) = next(batch_generator)

            # task switching must be done outside train_step function
            # (will crash otherwise)
            model.set_active_task(task)

            # Run a train step
            train_steps[task](
                x_batch_train,
                y_batch_train,
                model,
                task,
                loss_fns[task],
                optimizers[task],
                train_metrics[task],
                train_losses[task],
            )

        # Post epoch Ops
        progress_bar.update(1)
        epoch.assign_add(1)
        current_epoch = epoch.numpy()  # Read once; reused for logging and saving

        with train_summary_writer.as_default():
            ## Write log to tensorboard (Once per epoch)
            # [write_scalar_to_tensorboard(task, step=epoch) for task in cfg.task_names]
            write_weight_histogram_to_tensorboard(step=current_epoch)

        ## Reset loss (metric reset is currently disabled)
        for task_name in cfg.task_names:
            train_losses[task_name].reset_states()
        # [reset_metrics(x) for x in cfg.task_names]

        # Save a checkpoint only on the epochs listed in the configuration.
        if current_epoch in cfg.saved_epochs:
            ckpt_manager.save(epoch)


## Run tests

In [None]:
# Post-training evaluation. Only the basic test runs by default; the heavier
# evaluations below are kept commented out and can be enabled manually.
import benchmark_hs04
# Basic test
benchmark_hs04.run_test1(cfg.code_name)  # Basic accuracy test only

## All benchmarks
# benchmark_hs04.main(cfg.code_name, cfg.batch_name)


## Full training set test
# import evaluate
# test = evaluate.Test(cfg)
# test.eval_train("triangle", to_bq=True)
