# HS04 model

This is an orthography-to-phonology (ort → pho) model written under the `modeling.HS04` framework. The purpose of this model is to examine the … *(TODO: this description was left unfinished)*



In [None]:
%reload_ext lab_black
import pickle, os, time
import tensorflow as tf
import numpy as np
import pandas as pd
import altair as alt
from IPython.display import clear_output

import meta, data_wrangling, modeling, metrics, evaluate

# meta.limit_gpu_memory_use(7000)

# Parameters block (for papermill)

In [None]:
# Papermill parameters cell: the values below are defaults that papermill may
# override when the notebook is executed. Those whose names appear in
# meta.CORE_CONFIGS / meta.OPTIONAL_CONFIGS are later collected from the
# notebook globals into the ModelConfig object `cfg`.
code_name = "ort_pho"
tf_root = "/home/jupyter/tf"

# Model architecture
# NOTE(review): the "# P2" / "# P3" tags below are not explained in this
# notebook; presumably they mark parameters varied in a particular study
# phase — confirm before relying on them.
ort_units = 119
pho_units = 250
sem_units = 2446

hidden_os_units = 500  # P2
hidden_op_units = 100  # P2
hidden_ps_units = 500
hidden_sp_units = 500

pho_cleanup_units = 20
sem_cleanup_units = 50

pho_noise_level = 0.0  # P3
sem_noise_level = 0.0  # P3

activation = "sigmoid"
tau = 1 / 3
max_unit_time = 4.0
inject_error_ticks = 2
output_ticks = 11

# Training
sample_name = "jay"

# Presumably surfaced as cfg.rng_seed / cfg.learning_rate /
# cfg.zero_error_radius / cfg.save_freq, which the training cells use for the
# TF seed, the Adam learning rate, the CustomBCE radius and the checkpoint
# frequency respectively — confirm against meta.ModelConfig.
rng_seed = 2021
learning_rate = 0.01
n_mil_sample = 1.0
zero_error_radius = 0.1
batch_size = 100
save_freq = 10

batch_name = None

In [None]:
# cfg = meta.ModelConfig.from_json(
#     os.path.join(tf_root, "models", code_name, "model_config.json")
# )

In [None]:
# Collect the notebook-level parameters named in meta.CORE_CONFIGS (required)
# and meta.OPTIONAL_CONFIGS (used only when defined) into a dictionary for
# feeding into ModelConfig().
missing_core_configs = [v for v in meta.CORE_CONFIGS if v not in globals()]
if missing_core_configs:
    # Report every missing required parameter at once, not just the first.
    raise KeyError(
        f"Required config variable(s) not defined: {missing_core_configs}"
    )

config_dict = {v: globals()[v] for v in meta.CORE_CONFIGS}

# Optional parameters are skipped only when the name is undefined; no other
# error is swallowed.
config_dict.update(
    {v: globals()[v] for v in meta.OPTIONAL_CONFIGS if v in globals()}
)

# Construct ModelConfig object
cfg = meta.ModelConfig(**config_dict)
cfg.save()
del config_dict, missing_core_configs

# Build model and all supporting components

In [None]:
# Re-execute the `modeling` module so that edits to its source since the kernel
# started are picked up without restarting the notebook.
from importlib import reload
reload(modeling)

In [None]:
task = "ort_pho"

# Seed TensorFlow's global RNG before constructing the data, model and sampler.
tf.random.set_seed(cfg.rng_seed)
data = data_wrangling.MyData()
model = modeling.HS04Model(cfg)

# Batch sampler (the uniform-sampling variant is kept for reference)
sampler = data_wrangling.FastSampling(cfg, data)
# sampler = data_wrangling.FastSampling_uniform(cfg, data)

# Per-task training components, each dictionary keyed by `task`
generators, optimizers, loss_fns = (
    {task: sampler.sample_generator(x="ort", y="pho")},
    {task: tf.keras.optimizers.Adam(learning_rate=cfg.learning_rate)},
    {task: metrics.CustomBCE(radius=cfg.zero_error_radius)},
)

# Per-task training metrics: running loss and phonological accuracy
train_losses, train_acc = (
    {task: tf.keras.metrics.Mean("train_loss", dtype=tf.float32)},
    {task: metrics.PhoAccuracy()},
)

# Train step for the ort → pho task

In [None]:
@tf.function
def train_step_ort_pho(
    x,
    y,
    model,
    task,
    loss_fn,
    optimizer,
    train_metric,
    train_losses,
):
    """Run one optimization step on a single batch for ``task``.

    Gradients are computed and applied only for the model variables named in
    ``modeling.WEIGHTS_AND_BIASES[task]``; every other variable is left as is.

    Args:
        x: Input batch, passed to ``model(x, training=True)``.
        y: Target batch, passed to ``loss_fn``; ``y[-1]`` is used for the metric.
        model: Model whose ``weights`` are filtered by variable name.
        task: Key into ``modeling.WEIGHTS_AND_BIASES``.
        loss_fn: Called as ``loss_fn(y, y_pred)``.
        optimizer: Applies the gradients to the selected variables.
        train_metric: Updated with the last element of targets and predictions.
        train_losses: Updated with the batch loss value.

    Raises:
        ValueError: If no model variable matches the task's weight names; the
            step would otherwise silently update nothing.
    """
    # Variable names carry a ":0" output-index suffix. Use distinct names in
    # the comprehensions so the `x` argument is not shadowed, and a set for
    # membership tests.
    wanted_names = {name + ":0" for name in modeling.WEIGHTS_AND_BIASES[task]}
    train_weights = [w for w in model.weights if w.name in wanted_names]

    if not train_weights:
        # With an empty variable list apply_gradients would do nothing, so
        # training would appear to run while no weights ever change.
        raise ValueError(
            f"No model weights match modeling.WEIGHTS_AND_BIASES[{task!r}]"
        )

    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        loss_value = loss_fn(y, y_pred)

    grads = tape.gradient(loss_value, train_weights)
    optimizer.apply_gradients(zip(grads, train_weights))

    # Mean loss for Tensorboard
    train_losses.update_state(loss_value)

    # Metric for last time step (output first dimension is time ticks, from -cfg.output_ticks to end)
    train_metric.update_state(tf.cast(y[-1], tf.float32), y_pred[-1])


# Dispatch table: task name -> train-step function
train_steps = {task: train_step_ort_pho}

# Train model

In [None]:
model.build()
model.set_active_task(task)


# TensorBoard writer
train_summary_writer = tf.summary.create_file_writer(cfg.path["tensorboard_folder"])

for epoch in range(cfg.total_number_of_epoch):
    start_time = time.time()

    for _ in range(cfg.steps_per_epoch):

        x_batch_train, y_batch_train = next(generators[task])

        train_steps[task](
            x_batch_train,
            y_batch_train,
            model,
            task,
            loss_fns[task],
            optimizers[task],
            train_acc[task],
            train_losses[task],
        )

    # End of epoch operations

    ## Log all scalar metrics (losses and metrics) and histograms (weights and
    ## biases) to TensorBoard. NOTE: TensorBoard steps use the 0-based `epoch`,
    ## whereas the printed status and checkpoint names use `epoch + 1`.
    with train_summary_writer.as_default():

        ### Loss
        for loss_task, loss_metric in train_losses.items():
            tf.summary.scalar(f"loss_{loss_task}", loss_metric.result(), step=epoch)

        ### Metrics
        for acc_task, acc_metric in train_acc.items():
            tf.summary.scalar(f"acc_{acc_task}", acc_metric.result(), step=epoch)

        ### Weights histogram
        for weight_var in model.weights:
            tf.summary.histogram(f"{weight_var.name}", weight_var, step=epoch)

    ## Print status (clear_output(wait=True) replaces it when the next output arrives)
    compute_time = time.time() - start_time
    print(f"Epoch {epoch + 1} trained for {compute_time:.0f}s")
    print(f"Losses: {train_losses[task].result().numpy()}")
    clear_output(wait=True)

    ## Save weights: every one of the first 10 epochs, then every cfg.save_freq epochs
    if (epoch < 10) or ((epoch + 1) % cfg.save_freq == 0):
        weight_path = cfg.path["weights_checkpoint_fstring"].format(epoch=epoch + 1)
        model.save_weights(weight_path, overwrite=True, save_format="tf")

    ## Reset metric and loss
    for loss_metric in train_losses.values():
        loss_metric.reset_states()
    for acc_metric in train_acc.values():
        acc_metric.reset_states()

# End of training ops
# Flush so the last epochs' summaries are on disk for TensorBoard right away;
# otherwise the writer only flushes on its own timer or when closed.
train_summary_writer.flush()
# model.save(cfg.path["save_model_folder"])
print("Done")

# Evaluate model

In [None]:
# Display the names of the available test sets (the keys of data.testsets).
data.testsets.keys()

In [None]:
def run_test(testset_name):
    """Evaluate the notebook's model on one test set for the ort -> pho task.

    Builds an ``evaluate.TestSet`` from the "item", "ort" and "pho" entries of
    ``data.testsets[testset_name]`` (using the notebook-level ``cfg``,
    ``model`` and ``data``), runs ``eval_all()`` and returns its ``result``.
    """
    testset = data.testsets[testset_name]

    testset_object = evaluate.TestSet(
        name=testset_name,
        cfg=cfg,
        model=model,
        task="ort_pho",
        testitems=testset["item"],
        x_test=testset["ort"],
        y_test=testset["pho"],
    )

    testset_object.eval_all()
    return testset_object.result


# Strain conditions, named by hf/lf, con/inc and hi/li codes (presumably
# frequency, consistency and imageability).
testsets = [
    "strain_hf_con_hi",
    "strain_hf_inc_hi",
    "strain_hf_con_li",
    "strain_hf_inc_li",
    "strain_lf_con_hi",
    "strain_lf_inc_hi",
    "strain_lf_con_li",
    "strain_lf_inc_li",
]

df = pd.concat([run_test(x) for x in testsets], ignore_index=True)

# index=False: the RangeIndex carries no information, and writing it would come
# back as an extra "Unnamed: 0" column (then averaged) when the CSV is re-read.
df.to_csv(os.path.join(cfg.path["eval_folder"], "ort_pho.csv"), index=False)

In [None]:
# df = pd.read_csv("ort_sem_uniform_sampling.csv")
df = pd.read_csv(os.path.join(cfg.path["eval_folder"], "ort_pho.csv"))

# Drop a stray "Unnamed: 0" column left behind when a CSV was written together
# with its index.
df = df.loc[:, ~df.columns.str.startswith("Unnamed")]

# Average all rows sharing the same epoch/timetick/y/testset/task. With
# numeric_only=True, non-numeric columns are left out of the mean; pandas >= 2.0
# raises on them otherwise, while older versions silently dropped them.
df = (
    df.groupby(["epoch", "timetick", "y", "testset", "task"])
    .mean(numeric_only=True)
    .reset_index()
)


def my_plot(y="acc", source=None):
    """Return a line chart of ``y`` against timetick, one column per epoch.

    Args:
        y: Column to plot on the y axis.
        source: DataFrame to plot; defaults to the notebook-level ``df`` as it
            is at call time.
    """
    return (
        alt.Chart(df if source is None else source)
        .mark_line(point=True)
        .encode(x="timetick", y=y, color="testset", column="epoch:O")
    )

In [None]:
# Save one HTML chart per metric into the plot folder (op_<metric>.html).
for plotted_metric in ("acc", "sse", "act0", "act1"):
    my_plot(plotted_metric).save(
        os.path.join(cfg.path["plot_folder"], f"op_{plotted_metric}.html")
    )

In [None]:
# Stack the four metric charts vertically; `&=` builds the same chart as
# chaining `&` left to right.
all_plots = my_plot("acc")
for next_metric in ("sse", "act0", "act1"):
    all_plots &= my_plot(next_metric)
all_plots.save(os.path.join(cfg.path["plot_folder"], "op_all.html"))

In [None]:
# data = data_wrangling.MyData()
# model = modeling.HS04Model(cfg)
# model.build()
# model.set_active_task("triangle")

In [None]:
# test = evaluate.EvalReading(cfg, model, data)
# test.eval('train')
# test.eval("cortese")
# test.eval("strain")
# test.eval("grain")
# test.eval("taraban")

In [None]:
# Temp fix for pd float64 new data type error, read from disk as a work around
# test.eval('train')
# test.eval("cortese")
# test.eval("strain")
# test.eval("grain")
# test.eval("taraban")

## Basic accuracy over epoch

In [None]:
# Train ACC by OUTPUT
# test.plot_reading_acc(test.train_mean_df).encode(y="mean(acc):Q").save(
#     os.path.join(cfg.path["plot_folder"], "train_acc.html")
# )

In [None]:
# # Strain ACC by OUTPUT
# test.plot_reading_acc(test.strain_mean_df).encode(y="mean(acc)").save(
#     os.path.join(cfg.path["plot_folder"], "strain_acc.html")
# )

In [None]:
# # Grain PHO ACC by COND
# df = test.grain_mean_df.loc[test.grain_mean_df.y_test.isin(["pho"])]
# test.plot_reading_acc(df).encode(color="testset").save(
#     os.path.join(cfg.path["plot_folder"], "grain_acc.html")
# )

In [None]:
# # Grain ACC by RESP x COND
# df = test.grain_mean_df.loc[
#     test.grain_mean_df.y_test.isin(["pho_large_grain", "pho_small_grain"])
# ]
# test.plot_reading_acc(df).encode(color="testset", strokeDash="y_test").save(
#     os.path.join(cfg.path["plot_folder"], "grain_acc_by_resp.html")
# )

## Freq x Consistency

In [None]:
# epoch_selection = alt.selection_single(
#     bind=alt.binding_range(min=10, max=150, step=10),
#     fields=["epoch"],
#     init={"epoch": 150},
#     name="epoch",
# )

# timetick_selection = alt.selection_single(
#     bind=alt.binding_range(min=0, max=cfg.n_timesteps, step=1),
#     fields=["timetick"],
#     init={"timetick": cfg.n_timesteps},
#     name="timetick",
# )

In [None]:
# # Taraban
# taraban_selected_conditions = [
#     "taraban_hf-exc",
#     "taraban_hf-reg-inc",
#     "taraban_lf-exc",
#     "taraban_lf-reg-inc",
# ]

# df = test.taraban_mean_df.copy()
# df = df.loc[
#     (df.testset.isin(taraban_selected_conditions))
#     & (df.timetick >= 4)
#     & (df.y == "pho")
# ]

# df["frequency"] = df.testset.str.slice(8, 10)
# df["regularity"] = df.testset.str.slice(11, 14)


# (
#     alt.Chart(df)
#     .mark_line()
#     .encode(
#         x=alt.X("frequency:N", sort="descending"),
#         y="mean(conditional_sse):Q",
#         color="regularity:N",
#     )
#     .add_selection(epoch_selection)
#     .add_selection(timetick_selection)
#     .transform_filter(epoch_selection)
#     .transform_filter(timetick_selection)
#     .properties(width=180, height=180)
# )

# # .save(os.path.join(cfg.path["plot_folder"], "replication_hs04_fig10_taraban.html"))

In [None]:
# # Strain
# df = test.strain_mean_df.loc[
#     (test.strain_mean_df.timetick >= 4) & (test.strain_mean_df.y == "pho")
# ]

# alt.Chart(df).mark_line().encode(
#     x=alt.X("frequency:N", sort="descending"),
#     y="sum(sse):Q",
#     color="pho_consistency:N",
# ).add_selection(epoch_selection).transform_filter(epoch_selection).properties(
#     width=180, height=180
# ).save(
#     os.path.join(cfg.path["plot_folder"], "replication_hs04_fig10_strain.html")
# )

## Nonword

In [None]:
# import evaluate_old

# glushko = evaluate_old.glushko_eval(cfg, data, model)
# glushko.start_evaluate()

# mdf = glushko.i_hist.groupby(["epoch", "timestep", "cond"]).mean().reset_index()



# # ACC
# alt.Chart(mdf).mark_line().encode(x="epoch", y="acc", color="cond").add_selection(
#     timetick_selection
# ).transform_filter(timetick_selection).save(
#     os.path.join(cfg.path["plot_folder"], "glushko_acc.html")
# )

# # SSE
# alt.Chart(mdf).mark_line().encode(x="epoch", y="sse", color="cond").add_selection(
#     timetick_selection
# ).transform_filter(timetick_selection).save(
#     os.path.join(cfg.path["plot_folder"], "glushko_sse.html")
# )

## Imageability

In [None]:
# # Strain imageability
# df = test.strain_mean_df.copy()
# df["fc"] = df.frequency + "-" + df.pho_consistency
# df = df.loc[
#     df.timetick >= 4,
# ]

# y_selection = alt.selection_single(
#     bind=alt.binding_radio(options=["pho", "sem"]), fields=["y"], init={"y": "pho"}
# )

# epoch_selection = alt.selection_single(
#     bind=alt.binding_range(min=10, max=150, step=10),
#     fields=["epoch"],
#     init={"epoch": 150},
#     name="epoch",
# )

# # timetick_selection = alt.selection_single(
# #     bind=alt.binding_range(min=2, max=12, step=1),
# #     fields=["timetick"],
# #     init={"timetick": 12},
# #     name="timetick",
# # )

# fig11 = (
#     alt.Chart(df)
#     .mark_bar()
#     .encode(
#         column=alt.X("fc:N", sort=["HF-CON", "LF-CON", "HF-INC", "LF-INC"]),
#         y="mean(conditional_sse):Q",
#         x="imageability:N",
#         color="imageability:N",
#     )
#     .add_selection(epoch_selection)
#     .add_selection(y_selection)
#     .transform_filter(epoch_selection)
#     .transform_filter(y_selection)
# )

# fig11.save(os.path.join(cfg.path["plot_folder"], "replication_hs04_fig11_csse.html"))

# fig11.encode(y="mean(sse):Q").save(
#     os.path.join(cfg.path["plot_folder"], "replication_hs04_fig11_sse.html")
# )

# fig11.encode(y="mean(acc):Q").save(
#     os.path.join(cfg.path["plot_folder"], "replication_hs04_fig11_acc.html")
# )

In [None]:
# # Imageability only within Strain
# timetick_selection = alt.selection_single(
#     bind=alt.binding_range(min=0, max=cfg.n_timesteps, step=1),
#     fields=["timetick"],
#     init={"timetick": cfg.n_timesteps},
#     name="timetick",
# )

# alt.Chart(test.strain_mean_df).mark_line().encode(
#     x="epoch", y="mean(sse)", color="imageability", column="y"
# ).add_selection(timetick_selection).transform_filter(timetick_selection).save(
#     os.path.join(cfg.path["plot_folder"], "Strain_sse_img_by_output.html")
# )

In [None]:
# test.plot_reading_acc(test.cortese_mean_df).encode(
#     y="mean(conditional_sse)", color="testset", column="y"
# ).save(os.path.join(cfg.path["plot_folder"], "cortese_csse.html"))

# test.plot_reading_acc(test.cortese_mean_df).encode(
#     y="mean(sse)", color="testset", column="y"
# ).save(os.path.join(cfg.path["plot_folder"], "cortese_sse.html"))

# test.plot_reading_acc(test.cortese_mean_df).encode(
#     y="mean(acc)", color="testset", column="y"
# ).save(os.path.join(cfg.path["plot_folder"], "cortese_acc.html"))

In [None]:
# gcloud compute ssh tensorflow-2-4-20210120-000018 --zone us-east4-b -- -L 6006:localhost:6006
# !tensorboard --logdir tensorboard_log

# !tensorboard dev upload --logdir tensorboard_log