# TF model v4.0

HS04 model incorporating a non-stationary environment

- *50 sem_cleanup*
- *50 pho_cleanup*
- *500 sem_pho_hidden_units*
- *500 pho_sem_hidden_units*
- *4 output_ticks* 
- *No auto-connection lock*
- *Attractor clamped for 8 steps, free for last 4 steps*
- Implemented only in modeling.py, not in the generator; the extra time ticks produced by the generator are simply dropped
- Both phase 1 (oral) and phase 2 (reading) can be run in this notebook

In [None]:
%load_ext lab_black
import pickle, os, time
import tensorflow as tf
import numpy as np
import pandas as pd
from IPython.display import clear_output
import meta, data_wrangling, modeling, metrics, evaluate

# meta.limit_gpu_memory_use(7000)

# Parameters block (for papermill)

In [None]:
# Papermill parameters cell: keep these as plain `name = value` assignments.
# The next cell copies them into ModelConfig by looking each name up in
# globals() (meta.CORE_CONFIGS / ENV_CONFIGS / OPTIONAL_CONFIGS), so names
# must match those lists exactly.
code_name = "interleave_003lr_test"
tf_root = "/home/jupyter/tf"

# Model configs
ort_units = 119
pho_units = 250
sem_units = 2446
hidden_os_units = 500
hidden_op_units = 100
hidden_ps_units = 500
hidden_sp_units = 500
pho_cleanup_units = 50
sem_cleanup_units = 50
pho_noise_level = 0.0
sem_noise_level = 0.0
activation = "sigmoid"
# NOTE(review): presumably cfg.n_timesteps = max_unit_time / tau (= 12 here);
# the plots use cfg.n_timesteps as the timetick slider max — confirm in meta.ModelConfig.
tau = 1 / 3
max_unit_time = 4.0
# NOTE(review): the notebook header lists "4 output_ticks" but this is 12 — confirm intended value.
output_ticks = 12
inject_error_ticks = 11

# Training configs
learning_rate = 0.003  # Adam learning rate, used for every task's optimizer
zero_error_radius = 0.1  # passed as `radius` to metrics.CustomBCE
save_freq = 10  # checkpoint every save_freq epochs (the first 10 epochs are always saved)
batch_name = None

# Environment configs
# Task names; each gets its own optimizer, loss, loss tracker and train step.
tasks = ("pho_sem", "sem_pho", "pho_pho", "sem_sem", "triangle")
wf_clipping_edges = None
wf_compression = "log"
wf_clip_low = 0
wf_clip_high = 999_999_999
oral_start_pct = 0.02
oral_end_pct = 0.5
oral_sample = 900_000
# Presumably per-task sampling probabilities aligned with `tasks` order (sums to 1.0) — confirm in Sampler.
oral_tasks_ps = (0.4, 0.4, 0.1, 0.1, 0.0)
transition_sample = 400_000
reading_sample = 2_000_000
# Same layout as oral_tasks_ps (sums to 1.0).
reading_tasks_ps = (0.2, 0.2, 0.05, 0.05, 0.5)
batch_size = 100
rng_seed = 2021  # fed to tf.random.set_seed before the model is built

In [None]:
# cfg = meta.ModelConfig.from_json(os.path.join(tf_root, 'models', code_name, 'model_config.json'))

In [None]:
# Load global cfg variables into a dictionary for feeding into ModelConfig()
# Required settings (core + environment) must all be defined in the parameters
# cell; optional settings are forwarded only when they are defined.
_namespace = globals()
_required = list(meta.CORE_CONFIGS) + list(meta.ENV_CONFIGS)
_missing = [name for name in _required if name not in _namespace]
if _missing:
    raise KeyError(f"Missing required config variable(s): {_missing}")

config_dict = {name: _namespace[name] for name in _required}
config_dict.update(
    {name: _namespace[name] for name in meta.OPTIONAL_CONFIGS if name in _namespace}
)

# Construct ModelConfig object
cfg = meta.ModelConfig(**config_dict)
cfg.save()
del config_dict, _namespace, _required, _missing

# Build model and all supporting components

In [None]:
# Seed TensorFlow's global RNG first: everything built below (presumably
# including the model's weight initialisation) then draws from a seeded state.
tf.random.set_seed(cfg.rng_seed)
data = data_wrangling.MyData()
model = modeling.HS04Model(cfg)
model.build()
# Sampler draws tasks/batches from `data` according to the environment configs.
sampler = data_wrangling.Sampler(cfg, data)

In [None]:
generator = sampler.generator()

# Full set of task specific components

# Per-task state: every task gets its own Adam optimizer, its own CustomBCE
# loss and its own running-mean loss tracker (train_losses feeds TensorBoard).
optimizers, loss_fns, train_losses = {}, {}, {}

for task in cfg.tasks:
    optimizers[task] = tf.keras.optimizers.Adam(learning_rate=cfg.learning_rate)
    loss_fns[task] = metrics.CustomBCE(radius=cfg.zero_error_radius)
    train_losses[task] = tf.keras.metrics.Mean(
        f"train_loss_{task}", dtype=tf.float32
    )

# Training accuracy depends on the output being scored: phonological outputs
# use PhoAccuracy, semantic outputs use RightSideAccuracy, and the two-output
# triangle task (pho first, then sem) carries one metric per output.
train_acc = dict(
    pho_pho=metrics.PhoAccuracy("acc_pho_pho"),
    sem_sem=metrics.RightSideAccuracy("acc_sem_sem"),
    pho_sem=metrics.RightSideAccuracy("acc_pho_sem"),
    sem_pho=metrics.PhoAccuracy("acc_sem_pho"),
    triangle=[
        metrics.PhoAccuracy("acc_triangle_pho"),
        metrics.RightSideAccuracy("acc_triangle_sem"),
    ],
)

## Train step for each task

In [None]:
# Each sub-task keeps its own optimizer state, so each one is trained with a separate
# optimizer instead of sharing one optimizer instance (https://github.com/tensorflow/tensorflow/issues/27120)


def get_train_step():
    """Build a fresh ``tf.function`` training step.

    The returned step works for every task. For ``"triangle"`` the model has two
    outputs, ``(pho, sem)``, and the loss is the sum of both; every other task
    scores a single output. Only the variables listed in
    ``modeling.WEIGHTS_AND_BIASES[task]`` are updated.

    A new closure is created per call, so each task can own its own traced step.
    """

    @tf.function
    def train_step(
        x,
        y,
        model,
        task,
        loss_fn,
        optimizer,
        train_metrics,
        train_losses,
    ):
        """Run one optimisation step on batch ``(x, y)`` for ``task``.

        ``train_metrics`` is a single metric, or a list of metrics (one per
        output, same order as ``y``) for multi-output tasks.
        """
        # Keras variable names carry a ":0" suffix; a set makes the lookup O(1)
        train_weights_name = {
            name + ":0" for name in modeling.WEIGHTS_AND_BIASES[task]
        }
        train_weights = [w for w in model.weights if w.name in train_weights_name]

        with tf.GradientTape() as tape:
            y_pred = model(x, training=True)
            if task == "triangle":
                # Caution: order matters -- outputs and targets are (pho, sem)
                loss_value_pho = loss_fn(y[0], y_pred[0])
                loss_value_sem = loss_fn(y[1], y_pred[1])
                loss_value = loss_value_pho + loss_value_sem
            else:
                loss_value = loss_fn(y, y_pred)

        grads = tape.gradient(loss_value, train_weights)
        optimizer.apply_gradients(zip(grads, train_weights))

        # Mean loss for Tensorboard
        train_losses.update_state(loss_value)

        # Metric for last time step (output first dimension is time ticks,
        # from -cfg.output_ticks to end) for live results
        if isinstance(train_metrics, (list, tuple)):
            for i, metric in enumerate(train_metrics):
                metric.update_state(tf.cast(y[i][-1], tf.float32), y_pred[i][-1])
        else:
            train_metrics.update_state(tf.cast(y[-1], tf.float32), y_pred[-1])

    return train_step


train_steps = {task: get_train_step() for task in cfg.tasks}

# Train model

In [None]:
# TensorBoard writer
train_summary_writer = tf.summary.create_file_writer(cfg.path["tensorboard_folder"])

# Batches trained between two logging / checkpoint points (one "epoch" below)
batches_per_epoch = 100


def _all_train_acc():
    """Yield every training-accuracy metric in train_acc order, expanding list entries (triangle)."""
    for acc in train_acc.values():
        if isinstance(acc, list):
            yield from acc
        else:
            yield acc


# int() guards against sampler.total_batches being a float; a trailing remainder
# of fewer than batches_per_epoch batches is not trained.
for epoch in range(int(sampler.total_batches / batches_per_epoch)):
    start_time = time.time()

    for step in range(batches_per_epoch):
        # Run batches_per_epoch batches before every logging
        # Draw task, create batch
        task, exposed_words_idx, x_batch_train, y_batch_train = next(generator)
        model.set_active_task(task)  # task switching must be done outside trainstep...
        train_steps[task](
            x_batch_train,
            y_batch_train,
            model,
            task,
            loss_fns[task],
            optimizers[task],
            train_acc[task],
            train_losses[task],
        )

    # End of epoch operations

    ## Write log to tensorboard (TensorBoard step = 0-based epoch index)
    with train_summary_writer.as_default():
        ### Losses
        for loss_name, loss in train_losses.items():
            tf.summary.scalar(f"loss_{loss_name}", loss.result(), step=epoch)

        ### Metrics
        for acc in _all_train_acc():
            tf.summary.scalar(acc.name, acc.result(), step=epoch)

        ### Weight histogram
        for weight in model.weights:
            tf.summary.histogram(weight.name, weight, step=epoch)

    ## Print status (clear_output(wait=True) replaces it when the next epoch prints)
    print(f"Epoch {epoch + 1} trained for {time.time() - start_time:.0f}s")
    print(
        "Losses:",
        [f"{x}: {train_losses[x].result().numpy()}" for x in cfg.tasks],
    )
    clear_output(wait=True)

    ## Save weights: each of the first 10 epochs, then every cfg.save_freq epochs
    if (epoch < 10) or ((epoch + 1) % cfg.save_freq == 0):
        weight_path = cfg.path["weights_checkpoint_fstring"].format(epoch=epoch + 1)
        model.save_weights(weight_path, overwrite=True, save_format="tf")

    ## Reset metric and loss
    for loss in train_losses.values():
        loss.reset_states()
    for acc in _all_train_acc():
        acc.reset_states()


# End of training ops
train_summary_writer.flush()  # make sure the last epoch's events reach disk
print("Done")

# Evaluate model

In [None]:
class EvalReading:
    """Bundle of reading testsets evaluated on the triangle task.

    Each testset has an ``_eval_<name>`` method that runs the evaluation and
    writes an item-level CSV (``<name>_item_df.csv``) and a condition-level
    CSV (``<name>_mean_df.csv``) under ``<model_folder>/eval/``. It also caches
    the condition-level frame on ``self.<name>_mean_df``. Mean results already
    on disk are loaded at construction, so ``eval`` can skip re-running them.

    NOTE(review): ``TestSet`` is not defined in this notebook; it presumably
    comes from the ``evaluate`` module. Later cells use ``evaluate.EvalReading``
    rather than this class — confirm which one is intended.
    """

    TESTSETS_NAME = ("train", "strain", "grain", "taraban", "cortese")

    # Item-level columns identifying one condition-level (mean) row
    GROUP_COLS = ["code_name", "task", "testset", "epoch", "timetick", "y"]

    def __init__(self, cfg, model, data):
        self.cfg = cfg
        self.model = model
        self.data = data

        # One cached condition-level frame per testset (e.g. self.strain_mean_df).
        # All are initialised, so eval() works for every testset in run_eval;
        # None means "not evaluated yet and no file on disk".
        for testset_name in self.TESTSETS_NAME:
            setattr(self, f"{testset_name}_mean_df", self._load_mean_df(testset_name))

        # Bundle testsets into dictionary
        self.run_eval = {
            "train": self._eval_train,
            "strain": self._eval_strain,
            "grain": self._eval_grain,
            "taraban": self._eval_taraban,
            "cortese": self._eval_cortese,
        }

    def eval(self, testset_name):
        """Evaluate ``testset_name`` unless its mean results are already cached.

        Args:
            testset_name: one of ``TESTSETS_NAME``.

        Returns:
            The condition-level (mean) DataFrame for the testset.
        """
        if getattr(self, f"{testset_name}_mean_df") is None:
            self.run_eval[testset_name]()
        else:
            print("Evaluation results found, loaded from file.")
        return getattr(self, f"{testset_name}_mean_df")

    # --- file helpers -----------------------------------------------------

    def _eval_file(self, filename):
        """Return the path of ``filename`` inside the model's ``eval`` folder."""
        return os.path.join(self.cfg.path["model_folder"], "eval", filename)

    def _save_csv(self, df, filename):
        """Write ``df`` into the eval folder, creating the folder if needed."""
        path = self._eval_file(filename)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        df.to_csv(path)

    def _load_mean_df(self, testset_name):
        """Load a previously saved mean CSV, or return None if it cannot be read."""
        try:
            # to_csv() writes the index as the first column; read it back as
            # the index so cached frames match freshly computed ones.
            return pd.read_csv(
                self._eval_file(f"{testset_name}_mean_df.csv"), index_col=0
            )
        except OSError:  # includes FileNotFoundError
            return None

    # --- evaluation helpers -----------------------------------------------

    def _run_testset(self, testset_name, pho_key="pho"):
        """Evaluate one testset on the triangle task and return its item-level results.

        ``pho_key`` selects the phonological target that is scored (the grain
        testsets provide "pho_small_grain" / "pho_large_grain").
        """
        testset = self.data.testsets[testset_name]
        t = TestSet(
            name=testset_name,
            cfg=self.cfg,
            model=self.model,
            task="triangle",
            testitems=testset["item"],
            x_test=testset["ort"],
            y_test=[testset[pho_key], testset["sem"]],
        )
        t.eval_all()
        return t.result

    def _aggregate(self, df, extra_cols=()):
        """Average item-level results into condition-level rows.

        ``numeric_only=True`` skips text columns (e.g. item), which pandas >= 2.0
        would otherwise reject in ``mean``.
        """
        return (
            df.groupby(self.GROUP_COLS + list(extra_cols))
            .mean(numeric_only=True)
            .reset_index()
        )

    def _eval_testsets(self, prefix, testset_names):
        """Evaluate ``testset_names`` and save and cache the results under ``prefix``.

        Writes ``<prefix>_item_df.csv`` and ``<prefix>_mean_df.csv``, stores the
        mean frame on ``self.<prefix>_mean_df`` and returns the item-level frame.
        """
        df = pd.concat([self._run_testset(name) for name in testset_names])
        self._save_csv(df, f"{prefix}_item_df.csv")

        # Condition level aggregate
        mean_df = self._aggregate(df)
        self._save_csv(mean_df, f"{prefix}_mean_df.csv")
        setattr(self, f"{prefix}_mean_df", mean_df)
        return df

    # --- testsets -----------------------------------------------------------

    def _eval_train(self):
        return self._eval_testsets("train", ("train",))

    def _eval_strain(self):
        return self._eval_testsets(
            "strain",
            (
                "strain_hf_con_hi",
                "strain_hf_inc_hi",
                "strain_hf_con_li",
                "strain_hf_inc_li",
                "strain_lf_con_hi",
                "strain_lf_inc_hi",
                "strain_lf_con_li",
                "strain_lf_inc_li",
            ),
        )

    def _eval_grain(self):
        results = []
        for testset_name in ("grain_unambiguous", "grain_ambiguous"):
            for grain_size in ("pho_small_grain", "pho_large_grain"):
                result = self._run_testset(testset_name, pho_key=grain_size)
                result["y_test"] = grain_size
                results.append(result)
        df = pd.concat(results)

        # Pho only
        pho_df = df.loc[df.y == "pho"]

        # Calculate pho acc by summing large and small grain response
        pho_acc_df = (
            pho_df.groupby(
                ["code_name", "task", "y", "testset", "epoch", "timetick", "item"]
            )
            .sum(numeric_only=True)
            .reset_index()
        )
        pho_acc_df["y_test"] = "pho"

        # Sem only (semantics was evaluated once per grain size; keep one copy)
        sem_df = (
            df.loc[(df.y == "sem") & (df.y_test == "pho_small_grain")]
            .drop(columns="y_test")
            .assign(y_test="sem")
        )

        df = pd.concat([pho_df, pho_acc_df, sem_df])
        self._save_csv(df, "grain_item_df.csv")

        mean_df = self._aggregate(df, extra_cols=["y_test"])
        self._save_csv(mean_df, "grain_mean_df.csv")
        self.grain_mean_df = mean_df

        return df

    def _eval_taraban(self):
        return self._eval_testsets(
            "taraban",
            (
                "taraban_hf-exc",
                "taraban_hf-reg-inc",
                "taraban_lf-exc",
                "taraban_lf-reg-inc",
                "taraban_ctrl-hf-exc",
                "taraban_ctrl-hf-reg-inc",
                "taraban_ctrl-lf-exc",
                "taraban_ctrl-lf-reg-inc",
            ),
        )

    def _eval_cortese(self):
        return self._eval_testsets("cortese", ("cortese_hi_img", "cortese_low_img"))

    # --- plots ---------------------------------------------------------------

    def plot_reading_acc(self, df):
        """Accuracy-by-epoch line chart with a time-tick slider (initially the last tick)."""
        import altair as alt  # imported here: the notebook imports altair only in a later cell

        timetick_selection = alt.selection_single(
            bind=alt.binding_range(min=0, max=self.cfg.n_timesteps, step=1),
            fields=["timetick"],
            init={"timetick": self.cfg.n_timesteps},
            name="timetick",
        )

        p = (
            alt.Chart(df)
            .mark_line()
            .encode(
                x="epoch:Q",
                y=alt.Y("acc:Q", scale=alt.Scale(domain=(0, 1))),
                color="y",
            )
            .add_selection(timetick_selection)
            .transform_filter(timetick_selection)
        )

        return p

    def plot_grain_by_resp(self):
        """Grain accuracy split by small- vs large-grain target (dashed by y_test)."""
        df = self.grain_mean_df.loc[
            self.grain_mean_df.y_test.isin(["pho_large_grain", "pho_small_grain"])
        ]
        p = self.plot_reading_acc(df).encode(color="testset", strokeDash="y_test")
        return p

In [None]:
# NOTE(review): this instantiates evaluate.EvalReading from the evaluate module,
# not the EvalReading class defined in the previous cell — confirm which is intended.
test = evaluate.EvalReading(cfg, model, data)
test.eval("strain")
test.eval("grain")

In [None]:
# Display grain condition-level results (the cell's last expression is rendered)
test.grain_mean_df

In [None]:
# List available config paths (this notebook uses model_folder, plot_folder,
# tensorboard_folder and weights_checkpoint_fstring)
cfg.path.keys()

In [None]:
import altair as alt


def plot_reading_acc(df, cfg):
    """Interactive accuracy-by-epoch line chart, one line per testset.

    A slider picks the time tick (0..cfg.n_timesteps, initially the last tick)
    and a dropdown picks the output ("pho" or "sem", initially "pho"). Only rows
    matching both selections are drawn, with ``acc`` averaged per testset/epoch.
    """
    timetick_selection = alt.selection_single(
        bind=alt.binding_range(min=0, max=cfg.n_timesteps, step=1),
        fields=["timetick"],
        init={"timetick": cfg.n_timesteps},
        name="timetick",
    )

    # Start on "pho" (as plot() does): with no initial value the selection is
    # empty, the filter lets every row through, and mean(acc) would mix pho and sem.
    y_selection = alt.selection_single(
        bind=alt.binding_select(options=["pho", "sem"]),
        init={"y": "pho"},
        fields=["y"],
    )

    return (
        alt.Chart(df)
        .mark_line()
        .encode(
            x="epoch:Q",
            y=alt.Y("mean(acc):Q", scale=alt.Scale(domain=(0, 1))),
            color="testset",
        )
        .add_selection(timetick_selection)
        .add_selection(y_selection)
        .transform_filter(timetick_selection)
        .transform_filter(y_selection)
    )


strain_plot = plot_reading_acc(test.strain_mean_df, cfg)
strain_plot.save(os.path.join(cfg.path["plot_folder"], "strain.html"))

In [None]:
def plot_grain_acceptable(df, cfg=None):
    """Line chart of phonological accuracy on the grain testsets, one line per testset.

    Keeps rows with ``y == "pho"`` whose ``y_test`` is "pho" or "sem". In a
    grain mean frame, presumably these are the rows where small- and large-grain
    accuracy were summed ("acceptable" pronunciation) — confirm against the
    evaluate module.

    Args:
        df: grain condition-level results (e.g. ``test.grain_mean_df``).
        cfg: model config providing ``n_timesteps``; defaults to the notebook's ``cfg``.
    """
    if cfg is None:
        cfg = globals()["cfg"]  # fall back to the notebook-level config

    df = df.loc[df.y_test.isin(["pho", "sem"]) & (df.y == "pho")]

    timetick_selection = alt.selection_single(
        bind=alt.binding_range(min=0, max=cfg.n_timesteps, step=1),
        fields=["timetick"],
        init={"timetick": cfg.n_timesteps},
        name="timetick",
    )

    return (
        alt.Chart(df)
        .mark_line()
        .encode(
            x="epoch:Q",
            y=alt.Y("acc:Q", scale=alt.Scale(domain=(0, 1))),
            color="testset",
        )
        .add_selection(timetick_selection)
        .transform_filter(timetick_selection)
    )


grain_plot_acc = plot_grain_acceptable(test.grain_mean_df, cfg)
grain_plot_acc.save(os.path.join(cfg.path["plot_folder"], "grain1.html"))


def plot_grain_by_resp(df, cfg):
    """Grain accuracy split by target grain size.

    Only small- and large-grain target rows are kept; lines are coloured by
    testset and dashed by ``y_test`` (the grain size).
    """
    grain_rows = df[df["y_test"].isin(("pho_large_grain", "pho_small_grain"))]
    base_chart = plot_reading_acc(grain_rows, cfg)
    return base_chart.encode(color="testset", strokeDash="y_test")


grain_plot_resp = plot_grain_by_resp(test.grain_mean_df, cfg)
grain_plot_resp.save(os.path.join(cfg.path["plot_folder"], "grain2.html"))

In [None]:
# Display strain condition-level results (the cell's last expression is rendered)
test.strain_mean_df

In [None]:
def plot(df, use_y="acc", y_max=1, task="triangle", cfg=None):
    """Strain results by condition, next to frequency / consistency / imageability contrasts.

    Args:
        df: strain condition-level results (e.g. ``test.strain_mean_df``) with
            task, testset, epoch, timetick, y and ``use_y`` columns.
        use_y: column plotted on the y axis.
        y_max: upper y limit; the contrast panels use ``[-y_max, y_max]``.
        task: only rows of this task are plotted.
        cfg: model config providing ``n_timesteps``; defaults to the notebook's ``cfg``.

    Returns:
        An Altair chart: per-condition lines horizontally concatenated with the
        three contrast panels. Each contrast is (sum of 4 conditions - sum of
        the other 4) / 4, computed after pivoting testsets into columns per epoch.
    """
    if cfg is None:
        cfg = globals()["cfg"]  # fall back to the notebook-level config

    df = df.loc[(df.task == task)]

    # Slider range follows cfg.n_timesteps, like the other plots in this notebook
    timetick_selection = alt.selection_single(
        bind=alt.binding_range(min=0, max=cfg.n_timesteps, step=1),
        fields=["timetick"],
        init={"timetick": cfg.n_timesteps},
        name="timetick",
    )

    y_selection = alt.selection_single(
        bind=alt.binding_select(options=["pho", "sem"]),
        init={"y": "pho"},
        fields=["y"],
    )

    # Clicking the legend highlights conditions (others fade to opacity 0.1)
    cond_selection = alt.selection_multi(bind="legend", fields=["testset"])

    # Plot by condition
    plot_by_cond = (
        alt.Chart(df)
        .mark_line()
        .encode(
            x="epoch:Q",
            y=alt.Y(f"{use_y}:Q", scale=alt.Scale(domain=(0, y_max))),
            color="testset:N",
            opacity=alt.condition(cond_selection, alt.value(1), alt.value(0.1)),
        )
        .add_selection(timetick_selection)
        .add_selection(y_selection)
        .add_selection(cond_selection)
        .transform_filter(timetick_selection)
        .transform_filter(y_selection)
    )

    # Plot contrasts
    contrasts = {}
    contrasts[
        "F contrast"
    ] = """(datum.strain_hf_con_hi + datum.strain_hf_con_li + datum.strain_hf_inc_hi + datum.strain_hf_inc_li - 
        (datum.strain_lf_con_hi + datum.strain_lf_con_li + datum.strain_lf_inc_hi + datum.strain_lf_inc_li))/4"""
    contrasts[
        "CON contrast"
    ] = """(datum.strain_hf_con_hi + datum.strain_hf_con_li + datum.strain_lf_con_hi + datum.strain_lf_con_li - 
        (datum.strain_hf_inc_hi + datum.strain_hf_inc_li + datum.strain_lf_inc_hi + datum.strain_lf_inc_li))/4"""
    contrasts[
        "IMG contrast"
    ] = """(datum.strain_hf_con_hi + datum.strain_lf_con_hi + datum.strain_hf_inc_hi + datum.strain_lf_inc_hi - 
        (datum.strain_hf_con_li + datum.strain_lf_con_li + datum.strain_hf_inc_li + datum.strain_lf_inc_li))/4"""

    def create_contrast_plot(name):
        """Small panel plotting one contrast expression per epoch."""
        return (
            plot_by_cond.encode(
                y=alt.Y("difference:Q", scale=alt.Scale(domain=(-y_max, y_max)))
            )
            .transform_pivot("testset", value=use_y, groupby=["epoch"])
            .transform_calculate(difference=contrasts[name])
            .properties(title=name, width=100, height=100)
        )

    contrast_plots = alt.hconcat()
    for c in contrasts:
        contrast_plots |= create_contrast_plot(c)

    return plot_by_cond | contrast_plots

In [None]:
# Render the strain condition + contrast chart and write it as HTML
strain_contrast_path = os.path.join(cfg.path["plot_folder"], "strain_contrast.html")
plot(test.strain_mean_df).save(strain_contrast_path)