# TF model v4.0

HS04 model incorporating a non-stationary environment

- *50 sem_cleanup*
- *50 pho_cleanup*
- *500 sem_pho_hidden_units*
- *500 pho_sem_hidden_units*
- *4 output_ticks* 
- *No auto-connection lock*
- *Attractor clamped for 8 steps, free for last 4 steps*
- Implemented only in modeling.py, not in the generator; the extra generated time ticks are simply dropped
- can do both phase 1 (oral) and 2 (reading) in this notebook

In [None]:
# %load_ext lab_black
import os, time
import tensorflow as tf
from IPython.display import clear_output
import meta, data_wrangling, modeling, metrics, evaluate

# meta.limit_gpu_memory_use(7000)

# Parameters block (for papermill)

In [None]:
# Run identity and project root
code_name = "ort_sem"
tf_root = "/home/jupyter/tf"

# ---------------------------------------------------------------------------
# Model configs
# ---------------------------------------------------------------------------
# Unit counts
ort_units = 119
pho_units = 250
sem_units = 2446
hidden_os_units = 500
hidden_op_units = 100
hidden_ps_units = 500
hidden_sp_units = 500
pho_cleanup_units = 50
sem_cleanup_units = 50

# Noise levels (0.0 in this run)
pho_noise_level = 0.0
sem_noise_level = 0.0

# Activation and time settings
activation = "sigmoid"
tau = 1 / 3
max_unit_time = 4.0
output_ticks = 12
inject_error_ticks = 11

# ---------------------------------------------------------------------------
# Training configs
# ---------------------------------------------------------------------------
learning_rate = 0.005
zero_error_radius = 0.1
save_freq = 10  # checkpoint every `save_freq` epochs (see the training loop)
batch_name = None

# ---------------------------------------------------------------------------
# Environment configs
# ---------------------------------------------------------------------------
# NOTE: `tasks` is a plain string (a single task name), not a tuple; the
# notebook later uses cfg.tasks directly as a dictionary key.
tasks = "ort_sem"

# Word-frequency handling
wf_clipping_edges = None
wf_compression = "log"
wf_clip_low = 0
wf_clip_high = 999_999_999

# Oral phase
oral_start_pct = 0.02
oral_end_pct = 1.0
oral_sample = 900_000
oral_tasks_ps = 1.0  # a single float, not a tuple

# Transition and reading phases
transition_sample = 0
reading_sample = 0
reading_tasks_ps = 1.0  # a single float, not a tuple

# Batching and reproducibility
batch_size = 100
rng_seed = 2021


In [None]:
# cfg = meta.ModelConfig.from_json(os.path.join(tf_root, 'models', code_name, 'model_config.json'))

In [None]:
# Gather the notebook-level configuration variables (defined in the parameters
# cell, possibly overridden by papermill) into keyword arguments for
# meta.ModelConfig().

config_dict = {}

# Core and environment settings are mandatory: fail loudly, naming the
# variable that is missing, instead of surfacing a bare KeyError.
for v in meta.CORE_CONFIGS + meta.ENV_CONFIGS:
    try:
        config_dict[v] = globals()[v]
    except KeyError as err:
        raise KeyError(
            f"Required config variable '{v}' is not defined in the notebook"
        ) from err

# Optional settings are forwarded only when the notebook defines them.
# An explicit membership test replaces the former bare `except: pass`,
# which would also have hidden unrelated errors.
for v in meta.OPTIONAL_CONFIGS:
    if v in globals():
        config_dict[v] = globals()[v]

# Construct ModelConfig object
cfg = meta.ModelConfig(**config_dict)
cfg.save()
del config_dict  # from here on, cfg is the single source of configuration

# Build model and all supporting components

In [None]:
# Seed TensorFlow's global random generator from the config before any
# component is created.
tf.random.set_seed(cfg.rng_seed)
# Project data container (defined in data_wrangling).
data = data_wrangling.MyData()
# Create the HS04 model from the config, then build it.
model = modeling.HS04Model(cfg)
model.build()
# Debug helper: uncomment to re-import data_wrangling after editing it,
# without restarting the kernel.
# from importlib import reload
# reload(data_wrangling)
# The sampler is constructed from the config and the data; it provides the
# batch generator and `total_batches` used by the training loop.
sampler = data_wrangling.Sampler(cfg, data)
sampler.plot()  # presumably visualises the sampling schedule — TODO confirm

In [None]:
# Batch source consumed by the training loop via next(generator)
generator = sampler.generator()

# Full set of task specific components, each keyed by task name.
# Only the "ort_sem" task is set up in this notebook.

optimizers = {
    "ort_sem": tf.keras.optimizers.Adam(
        learning_rate=cfg.learning_rate, beta_1=0.9
    ),
}

loss_fns = {
    "ort_sem": metrics.CustomBCE(radius=cfg.zero_error_radius),
}

# Mean loss (for TensorBoard)
train_losses = {
    "ort_sem": tf.keras.metrics.Mean("train_loss_ort_sem", dtype=tf.float32),
}

# Training acc is output specific
train_acc = {"ort_sem": metrics.RightSideAccuracy("acc_ort_sem")}

## Train step for each task

In [None]:
# Since each sub-task has its own states, it must be trained with separate optimizer,
# instead of sharing the same optimizer instance (https://github.com/tensorflow/tensorflow/issues/27120)


def get_train_step():
    """Create a new, independently traced training step.

    Every call returns a fresh ``tf.function``, so each task can own its own
    traced step.

    Returns:
        A callable ``train_step(x, y, model, task, loss_fn, optimizer,
        train_metrics, train_losses)`` that runs one optimisation step and
        updates the given loss tracker and metric(s) in place. It returns
        ``None``.
    """

    @tf.function
    def train_step(
        x,
        y,
        model,
        task,
        loss_fn,
        optimizer,
        train_metrics,
        train_losses,
    ):
        """Run one gradient update for `task` and update its trackers.

        Args:
            x: Model input batch.
            y: Target batch; for the "triangle" task, a pair indexed as
                ``y[0]`` (pho) and ``y[1]`` (sem).
            model: Model to train; called as ``model(x, training=True)``.
            task: Task name, used to select the trainable variables from
                ``modeling.WEIGHTS_AND_BIASES``.
            loss_fn: Callable ``loss_fn(y_true, y_pred)``.
            optimizer: Optimizer used to apply the gradients.
            train_metrics: A single metric, or a list/tuple of metrics (one
                per output) for multi-output tasks.
            train_losses: Running-mean tracker updated with the step loss.
        """
        # Only the variables registered for this task are updated.
        trainable_names = {name + ":0" for name in modeling.WEIGHTS_AND_BIASES[task]}
        train_weights = [w for w in model.weights if w.name in trainable_names]

        # `task` is a Python string, so this branch is resolved at trace time.
        with tf.GradientTape() as tape:
            y_pred = model(x, training=True)
            if task == "triangle":
                loss_value_pho = loss_fn(y[0], y_pred[0])  # Caution order matter
                loss_value_sem = loss_fn(y[1], y_pred[1])
                loss_value = loss_value_pho + loss_value_sem
            else:
                loss_value = loss_fn(y, y_pred)

        grads = tape.gradient(loss_value, train_weights)
        optimizer.apply_gradients(zip(grads, train_weights))

        # Mean loss for Tensorboard
        train_losses.update_state(loss_value)

        # Metric for last time step (output first dimension is time ticks,
        # from -cfg.output_ticks to end) for live results.
        # The loop variable is deliberately not called `x`, so the input
        # batch parameter is never rebound.
        if isinstance(train_metrics, (list, tuple)):
            for i, metric in enumerate(train_metrics):
                metric.update_state(tf.cast(y[i][-1], tf.float32), y_pred[i][-1])
        else:
            train_metrics.update_state(tf.cast(y[-1], tf.float32), y_pred[-1])

    return train_step


# One traced train step per task, keyed by task name. cfg.tasks is used
# directly as the key, which works because `tasks` is set to a single task
# name string in the parameters cell.
# NOTE(review): assumes ModelConfig stores `tasks` unchanged — confirm.
train_steps = {cfg.tasks: get_train_step()}

# Train model

In [None]:
# TensorBoard writer
train_summary_writer = tf.summary.create_file_writer(cfg.path["tensorboard_folder"])

# An "epoch" here is a block of 100 batches followed by logging/checkpointing
for epoch in range(int(sampler.total_batches / 100)):
    start_time = time.time()

    # Run 100 batches before every logging
    for step in range(100):
        # Draw task, create batch
        task, exposed_words_idx, x_batch_train, y_batch_train = next(generator)

        # Task switching must be done outside the train step
        model.set_active_task(task)

        train_steps[task](
            x_batch_train,
            y_batch_train,
            model,
            task,
            loss_fns[task],
            optimizers[task],
            train_acc[task],
            train_losses[task],
        )

    # ---- End of epoch operations ----

    # Write log to TensorBoard
    with train_summary_writer.as_default():
        # Losses
        for loss_name, loss_tracker in train_losses.items():
            tf.summary.scalar(f"loss_{loss_name}", loss_tracker.result(), step=epoch)

        # Metrics: the "triangle" task holds a collection of metrics,
        # every other task holds a single metric
        for task, acc_entry in train_acc.items():
            for acc_metric in acc_entry if task == "triangle" else [acc_entry]:
                tf.summary.scalar(acc_metric.name, acc_metric.result(), step=epoch)

        # Weight histogram
        for weight_var in model.weights:
            tf.summary.histogram(f"{weight_var.name}", weight_var, step=epoch)

    # Print status
    print(f"Epoch {epoch + 1} trained for {time.time() - start_time:.0f}s")
    print(f"Losses: {cfg.tasks}: {train_losses[cfg.tasks].result().numpy()}")
    clear_output(wait=True)

    # Save weights: every one of the first 10 epochs, then every
    # cfg.save_freq epochs
    if epoch < 10 or (epoch + 1) % cfg.save_freq == 0:
        weight_path = cfg.path["weights_checkpoint_fstring"].format(epoch=epoch + 1)
        model.save_weights(weight_path, overwrite=True, save_format="tf")

    # Reset losses and metrics so the next epoch starts from zero
    for loss_tracker in train_losses.values():
        loss_tracker.reset_states()

    for task, acc_entry in train_acc.items():
        for acc_metric in acc_entry if task == "triangle" else [acc_entry]:
            acc_metric.reset_states()


# End of training ops
print("Done")

# Evaluate model
Eval 3.0 is under construction.
Planned features:
- Speed (the 2.0 code is easy to read but far too slow)
- No separation between oral and reading
- More plots built in