# HS04 model

## Phase 2

> The weights that were obtained at the end of the Phase 1 model were frozen and embedded in the larger reading model. Thus, only the connections from orthography to other units were trained in Phase 2. Freezing the weights is not strictly necessary; earlier work (Harm & Seidenberg, 1997) used a process of intermixing in which comprehension trials were used along with reading trials. Weight freezing has the same effect but is simpler and less computationally burdensome to implement. Intermixing is effective and realistic but adds substantially to network training time.

- *Pretraining is necessary, and the pretrained weights are frozen in Phase 2*

> One set of 500 hidden units mediated the mapping from these orthographic units to semantics...

- *500 sem_hidden_units*

> ...a second set of 100 hidden units mediated the orth-phon pathway.

- *100 pho_hidden_units*

> To computationally instantiate the principle that the reading system is under pressure to perform rapidly as well as accurately, we injected error into the semantic and phonological representations early, from time samples 2 to 12.
- *11 output_ticks*

## Phase 3

- Modeling individual differences
- Simulating ERPs
- Link to reliance of OP vs OS
- Use equation to model semantic / phonetic input to P/S

In [None]:
# %reload_ext lab_black
import pickle, os, time
import tensorflow as tf
import numpy as np
import pandas as pd
import altair as alt
from IPython.display import clear_output

import meta, data_wrangling, modeling, metrics, evaluate

# meta.set_gpu_mem_cap()

# Parameters block (for papermill)

In [None]:
# Identifier of this run; the config-reload cell looks for models/<code_name>/model_config.json
code_name = "hs04_phase2_test3"
tf_root = "/home/jupyter/tf"

# Model architecture: unit counts for the ort / pho / sem layers
ort_units = 119
pho_units = 250
sem_units = 2446

# Hidden-layer sizes; 500 and 100 follow the HS04 notes at the top of this notebook
# (500 hidden units for orth->sem, 100 hidden units for orth->phon)
hidden_os_units = 500  # P2 (presumably "Phase 2" -- confirm)
hidden_op_units = 100  # P2
hidden_ps_units = 500
hidden_sp_units = 500

pho_cleanup_units = 50
sem_cleanup_units = 50

pho_noise_level = 0.0  # P3 (presumably "Phase 3" -- confirm)
sem_noise_level = 0.0  # P3

activation = "sigmoid"
tau = 1 / 3
max_unit_time = 4.0
# Matches the "11 output_ticks" note above (error injected from time samples 2 to 12)
output_ticks = 11

# Pretraining: checkpoint loaded into the model before the training loop starts
# NOTE(review): this absolute path repeats the value of tf_root instead of being derived from it
pretrained_checkpoint = (
    "/home/jupyter/tf/models/hs04_phase1_selected_fix_attractor/weights/ep0200"
)

# Training
sample_name = "hs04"

rng_seed = 2021
learning_rate = 0.01
n_mil_sample = 1.5  # presumably the training budget in millions of samples -- TODO confirm
batch_size = 100
# NOTE(review): the training loop in this notebook hard-codes a 10-epoch checkpoint
# interval rather than reading this value -- confirm how save_freq is consumed
save_freq = 10

In [None]:
# Load this run's configuration from models/<code_name>/model_config.json
config_json_path = os.path.join("models", code_name, "model_config.json")
cfg = meta.ModelConfig.from_json(config_json_path)

In [None]:
# Collect the notebook-level settings that ModelConfig expects from this
# notebook's global namespace.

# Core settings are mandatory: a missing one surfaces as a KeyError that names
# the missing variable.
config_dict = {name: globals()[name] for name in meta.CORE_CONFIGS}

# Optional settings are copied only when the notebook actually defines them.
# An explicit membership test replaces the former bare `except: pass`, which
# would also have hidden unrelated errors.
config_dict.update(
    {name: globals()[name] for name in meta.OPTIONAL_CONFIGS if name in globals()}
)

# Construct ModelConfig object
cfg = meta.ModelConfig(**config_dict)
cfg.save()
del config_dict

# Build model and all supporting components

In [None]:
# Seed TensorFlow's global RNG before the model and sampler are created
tf.random.set_seed(cfg.rng_seed)
data = data_wrangling.MyData()
model = modeling.HS04Model(cfg)

# Every dictionary below is keyed by task name; "triangle" is the only task used in this notebook
sampler = data_wrangling.FastSampling(cfg, data)
# Batch generator requested with x="ort" and y=["pho", "sem"]
generators = {"triangle": sampler.sample_generator(x="ort", y=["pho", "sem"])}
optimizers = {"triangle": tf.keras.optimizers.Adam(learning_rate=cfg.learning_rate)}
# The same loss object is applied to both outputs inside the train step
loss_fns = {"triangle": tf.keras.losses.BinaryCrossentropy()}

# Mean loss (for TensorBoard)
train_losses = {
    "triangle": tf.keras.metrics.Mean("train_loss_triangle", dtype=tf.float32)
}

# Train metrics: one accuracy tracker per output of the triangle task
train_acc = {
    "triangle_pho": metrics.PhoAccuracy("acc_triangle_pho"),
    "triangle_sem": metrics.RightSideAccuracy("acc_triangle_sem"),
}

# Train step for triangle model 

In [None]:
@tf.function
def train_step_triangle(
    x,
    y,
    model,
    task,
    loss_fn,
    optimizer,
    train_metric_pho,
    train_metric_sem,
    train_losses,
):
    """Run one optimisation step of the triangle task and update its trackers.

    Args:
        x: input batch, passed straight to ``model``.
        y: targets; ``y[0]`` is scored against the phonological output and
            ``y[1]`` against the semantic output.
        model: callable returning ``(pho_pred, sem_pred)`` for ``x``.
        task: key into ``modeling.WEIGHTS_AND_BIASES`` naming the variables
            that receive gradient updates.
        loss_fn: loss applied to each output separately; the two values are
            summed into the training loss.
        optimizer: optimizer used to apply the gradients.
        train_metric_pho: metric updated with the last element of the pho
            target and prediction.
        train_metric_sem: metric updated with the last element of the sem
            target and prediction.
        train_losses: running tracker updated with the summed loss.
    """
    # Only the variables listed for this task receive gradient updates
    trainable_names = [name + ":0" for name in modeling.WEIGHTS_AND_BIASES[task]]
    trainable_vars = [w for w in model.weights if w.name in trainable_names]

    pho_target = y[0]
    sem_target = y[1]

    with tf.GradientTape() as tape:
        pho_pred, sem_pred = model(x, training=True)
        pho_loss = loss_fn(pho_target, pho_pred)
        sem_loss = loss_fn(sem_target, sem_pred)
        total_loss = pho_loss + sem_loss

    gradients = tape.gradient(total_loss, trainable_vars)
    optimizer.apply_gradients(zip(gradients, trainable_vars))

    # Mean loss for Tensorboard
    train_losses.update_state(total_loss)

    # Metric for last time step (output first dimension is time ticks, from -cfg.output_ticks to end)
    train_metric_pho.update_state(tf.cast(pho_target[-1], tf.float32), pho_pred[-1])
    train_metric_sem.update_state(tf.cast(sem_target[-1], tf.float32), sem_pred[-1])


# Registry mapping a task name to its train-step function
train_steps = {"triangle": train_step_triangle}

# Train model

In [None]:
model.build()
# Start from the pretrained checkpoint before training the triangle task
model.load_weights(pretrained_checkpoint)
task = "triangle"
model.set_active_task(task)


# TensorBoard writer
train_summary_writer = tf.summary.create_file_writer(cfg.path["tensorboard_folder"])

for epoch in range(cfg.total_number_of_epoch):
    start_time = time.time()

    for step in range(cfg.steps_per_epoch):

        x_batch_train, y_batch_train = next(generators[task])

        train_steps[task](
            x_batch_train,
            y_batch_train,
            model,
            task,
            loss_fns[task],
            optimizers[task],
            train_acc["triangle_pho"],
            train_acc["triangle_sem"],
            train_losses[task],
        )

    # End of epoch operations

    ## Log all scalar metrics (losses and accuracies) and histograms (weights and biases) to TensorBoard
    with train_summary_writer.as_default():
        # Plain loops: these calls are made for their side effects only
        for loss_name, loss_tracker in train_losses.items():
            tf.summary.scalar(f"loss_{loss_name}", loss_tracker.result(), step=epoch)

        for acc_name, acc_tracker in train_acc.items():
            tf.summary.scalar(f"acc_{acc_name}", acc_tracker.result(), step=epoch)

        for weight in model.weights:
            tf.summary.histogram(f"{weight.name}", weight, step=epoch)

    ## Print status
    compute_time = time.time() - start_time
    print(f"Epoch {epoch + 1} trained for {compute_time:.0f}s")
    print(f"Losses: {train_losses[task].result().numpy()}")
    # wait=True defers clearing until the next output arrives
    clear_output(wait=True)

    ## Save weights: every one of the first 10 epochs, then every 10th epoch
    # NOTE(review): the interval is hard-coded here; confirm whether it should
    # follow the save_freq parameter before changing it, since evaluation code
    # may expect exactly these checkpoints.
    if (epoch < 10) or ((epoch + 1) % 10 == 0):
        weight_path = cfg.path["weights_checkpoint_fstring"].format(epoch=epoch + 1)
        model.save_weights(weight_path, overwrite=True, save_format="tf")

    ## Reset metric and loss so each epoch reports its own averages
    for loss_tracker in train_losses.values():
        loss_tracker.reset_states()
    for acc_tracker in train_acc.values():
        acc_tracker.reset_states()

# End of training ops
# model.save(cfg.path["save_model_folder"])
print("Done")

# Evaluate model

In [None]:
# Rebuild the data helper and a fresh model instance for evaluation.
# NOTE(review): no weights are loaded in this cell -- presumably EvalReading
# restores the saved checkpoints itself; confirm.
data = data_wrangling.MyData()
model = modeling.HS04Model(cfg)
model.build()

In [None]:
# Development helper: re-import the project modules so source edits are picked
# up, then rebuild the data helper from the reloaded module.
import importlib

for project_module in (evaluate, data_wrangling):
    importlib.reload(project_module)

data = data_wrangling.MyData()


In [None]:
# Evaluation helper bound to the current config, model and data
test = evaluate.EvalReading(cfg, model, data)
# NOTE(review): later cells read test.train_mean_df, test.strain_mean_df and
# test.grain_mean_df; presumably those are produced by the calls commented out
# below -- re-enable them before running those cells (confirm).
# test.eval_train()
# test.eval_strain()
# test.eval_grain()
test.eval_train_cortese_img()

In [None]:
# Save the train_cortese_img results twice: once with mean SSE and once with
# mean accuracy on the y axis, coloured by test set and faceted by output.
for y_encoding, html_name in (
    ("mean(sse)", "train_cortese_img_sse_output.html"),
    ("mean(acc)", "train_cortese_img_acc_output.html"),
):
    cortese_chart = test.plot_reading_acc(test.train_cortese_img_mean_df).encode(
        y=y_encoding, color="testset", column="y"
    )
    cortese_chart.save(os.path.join(cfg.path["plot_folder"], html_name))


In [None]:
# Train ACC by OUTPUT: mean accuracy on the training-set results, saved as HTML
train_acc_chart = test.plot_reading_acc(test.train_mean_df).encode(y="mean(acc)")
train_acc_html = os.path.join(cfg.path["plot_folder"], "train_acc_output.html")
train_acc_chart.save(train_acc_html)

In [None]:
# Strain ACC by OUTPUT: mean accuracy on the strain results, saved as HTML
strain_acc_chart = test.plot_reading_acc(test.strain_mean_df).encode(y="mean(acc)")
strain_acc_html = os.path.join(cfg.path["plot_folder"], "strain_acc_output.html")
strain_acc_chart.save(strain_acc_html)

In [None]:
# Grain PHO ACC by COND: keep only rows scored against "pho", colour by test set
is_pho_response = test.grain_mean_df.y_test.isin(["pho"])
df = test.grain_mean_df.loc[is_pho_response]
grain_pho_chart = test.plot_reading_acc(df).encode(color="testset")
grain_pho_chart.save(os.path.join(cfg.path["plot_folder"], "grain_pho_acc.html"))

In [None]:
# Grain ACC by RESP x COND: keep the large- and small-grain responses,
# colour by test set and vary the stroke dash by response type
grain_responses = ["pho_large_grain", "pho_small_grain"]
is_grain_response = test.grain_mean_df.y_test.isin(grain_responses)
df = test.grain_mean_df.loc[is_grain_response]
grain_resp_chart = test.plot_reading_acc(df).encode(
    color="testset", strokeDash="y_test"
)
grain_resp_chart.save(
    os.path.join(cfg.path["plot_folder"], "grain_pho_acc_by_resp.html")
)

### HS04 experiments (Fig 10) 

In [None]:
# Subset of the strain results used for the HS04 experiment plots:
# "pho" output only, from epoch 50 and time tick 8 onwards.
strain_df = test.strain_mean_df
from_epoch_50 = strain_df.epoch >= 50
from_tick_8 = strain_df.timetick >= 8
is_pho_output = strain_df.y == "pho"
exp_df = strain_df.loc[from_epoch_50 & from_tick_8 & is_pho_output]

In [None]:
# Line chart of summed SSE against frequency (sorted descending), one line per
# pho_consistency level. Left as the final expression so the notebook shows it.
line_encoding = {
    "x": alt.X("frequency:N", sort="descending"),
    "y": "sum(sse):Q",
    "color": "pho_consistency:N",
}
alt.Chart(exp_df).mark_line().encode(**line_encoding).properties(width=180, height=180)

# .save(
#     os.path.join(cfg.path["plot_folder"], "strain_acc_output.html")
# )

In [None]:
# Combined "<frequency>-<pho_consistency>" label used to facet the charts below.
# exp_df is a filtered selection of test.strain_mean_df, so build a new frame
# with assign() instead of writing a column into that selection; this avoids
# pandas' SettingWithCopyWarning and any ambiguity about touching the parent.
exp_df = exp_df.assign(fc=exp_df.frequency + "-" + exp_df.pho_consistency)

In [None]:
# Mean SSE by imageability, one panel per frequency-consistency label ("fc").
# The facet uses alt.Column, the channel class that matches the `column`
# encoding (it was previously built with alt.X).
alt.Chart(exp_df).mark_bar().encode(
    column=alt.Column("fc:N", sort=["HF-CON", "LF-CON", "HF-INC", "LF-INC"]),
    y="mean(sse):Q",
    x="imageability:N",
    color="imageability:N",
).properties(width=180, height=180)

In [None]:
# Interactive on epoch and timetick: two sliders filter the data behind the chart
epoch_selection = alt.selection_single(
    bind=alt.binding_range(min=50, max=150, step=10),
    fields=["epoch"],
    init={"epoch": 150},
    name="epoch",
)


# NOTE(review): the slider starts at 0, but exp_df only keeps timetick >= 8,
# so lower slider positions presumably show an empty chart -- confirm intent.
timetick_selection = alt.selection_single(
    bind=alt.binding_range(min=0, max=cfg.n_timesteps, step=1),
    fields=["timetick"],
    init={"timetick": cfg.n_timesteps},
    name="timetick",
)


# Summed SSE by imageability, one panel per frequency-consistency label ("fc").
# The facet uses alt.Column, the channel class that matches the `column`
# encoding (it was previously built with alt.X).
(
    alt.Chart(exp_df)
    .mark_bar()
    .encode(
        column=alt.Column("fc:N", sort=["HF-CON", "LF-CON", "HF-INC", "LF-INC"]),
        y="sum(sse):Q",
        x="imageability:N",
        color="imageability:N",
    )
    .add_selection(timetick_selection)
    .add_selection(epoch_selection)
    .transform_filter(timetick_selection)
    .transform_filter(epoch_selection)
).save(os.path.join(cfg.path["plot_folder"], "interactive_strain_sse.html"))

In [None]:
# Sum across timeticks, interactive on epoch (reuses the epoch_selection slider
# defined in the previous cell).
# The facet uses alt.Column, the channel class that matches the `column`
# encoding (it was previously built with alt.X).
(
    alt.Chart(exp_df)
    .mark_bar()
    .encode(
        column=alt.Column("fc:N", sort=["HF-CON", "LF-CON", "HF-INC", "LF-INC"]),
        y="sum(sse):Q",
        x="imageability:N",
        color="imageability:N",
    )
    .add_selection(epoch_selection)
    .transform_filter(epoch_selection)
).save(os.path.join(cfg.path["plot_folder"], "interactive_epoch_strain_sse.html"))

In [None]:
# local ssh to cloud tensorboard
# gcloud compute ssh tensorflow-2-4-20210120-000018 --zone us-east4-b -- -L 6006:localhost:6006
# !tensorboard dev upload --logdir tensorboard_log