In [None]:
%load_ext blackcellmagic

# Decorator for counting words

In [None]:
def word_count(func, *args, **kwargs):
    """Wrap ``func`` so that calling the wrapper also tallies its words.

    The arguments are bound at wrap time: ``word_count(f, a, b=1)`` returns a
    zero-argument callable that evaluates ``f(a, b=1)`` when invoked.

    Args:
        func: Callable returning an iterable of hashable items (words).
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        A zero-argument callable returning ``(words, counter)``. ``words`` is
        the result of ``func`` (a one-shot iterator is returned as a list) and
        ``counter`` is a dict mapping each distinct word to its number of
        occurrences, in first-seen order.
    """

    def wrapper():
        # Call func exactly once: calling it a second time just to build the
        # return value doubles the cost and repeats any side effects.
        words = func(*args, **kwargs)
        if iter(words) is words:
            # One-shot iterator (e.g. a generator): counting would exhaust it,
            # so materialise it to keep the returned words usable.
            words = list(words)

        counter = {}
        for word in words:
            counter[word] = counter.get(word, 0) + 1
        return words, counter

    return wrapper

# Stitching altair plots

In [None]:
import benchmark_hs04, meta
import os
import altair as alt

In [None]:
def plot1(test_obj, null, metric, timetick=12):
    """Draw the mean of ``metric`` against epoch at one time tick.

    Args:
        test_obj: Evaluation object exposing ``cfg`` (``code_name``,
            ``batch_size``, ``learning_rate``) and ``eval(testset, task)``.
        null: Unused placeholder; it keeps the positional signature aligned
            with the other plotters, which receive a selection in this slot.
        metric: Column to average on the y axis; ``"acc"`` pins the axis
            to the range 0-1.
        timetick: Only rows whose ``timetick`` equals this value are plotted.

    Returns:
        A line chart with point markers, one line per ``output_name``.
    """
    code_name = test_obj.cfg.code_name
    raw_df = test_obj.eval("train_r100", "triangle")
    averaged = benchmark_hs04.make_mean_df(raw_df)
    at_tick = averaged.loc[averaged.timetick == timetick]

    # Accuracy is bounded, so fix its axis; other metrics autoscale.
    if metric == "acc":
        y_scale = alt.Scale(domain=(0, 1))
    else:
        y_scale = alt.Scale()

    title = f"{code_name}: (batch size: {test_obj.cfg.batch_size}; learning rate: {test_obj.cfg.learning_rate})"

    lines = alt.Chart(at_tick).mark_line(point=True)
    encoded = lines.encode(
        x="epoch:Q",
        y=alt.Y(f"mean({metric}):Q", scale=y_scale),
        color="output_name:N",
    )
    return encoded.properties(title=title)

def plot2(test_obj, epoch_selector, metric, timetick=range(4, 13)):
    """Plot the frequency-by-regularity interaction on the "taraban" test set.

    Two charts are stacked vertically: an epoch strip that carries
    ``epoch_selector``, and a line chart of the mean ``metric`` per frequency
    level (one line per regularity level) filtered by that selection.

    Args:
        test_obj: Evaluation object exposing ``cfg`` (``code_name``,
            ``batch_size``, ``learning_rate``) and ``eval(testset, task)``.
        epoch_selector: Altair selection attached to the epoch strip and used
            to filter the interaction chart.
        metric: Column to average on the y axis; ``"acc"`` pins the axis
            to the range 0-1.
        timetick: Time ticks to keep (anything accepted by ``Series.isin``).

    Returns:
        The vertically concatenated chart, titled with the run's settings.
    """
    high_freq_conds = (
        "High-frequency exception",
        "Regular control for High-frequency exception",
    )
    low_freq_conds = (
        "Low-frequency exception",
        "Regular control for Low-frequency exception",
    )

    code_name = test_obj.cfg.code_name
    df = test_obj.eval("taraban", "triangle")
    mean_df = benchmark_hs04.make_cond_mean_df(df)

    keep = (
        mean_df.timetick.isin(timetick)
        & (mean_df.output_name == "pho")
        & mean_df.cond.isin(high_freq_conds + low_freq_conds)
    )
    # Copy the filtered rows so the derived columns below are written to an
    # independent frame instead of a slice of the original (avoids pandas'
    # chained-assignment warning and its ambiguity).
    mean_df = mean_df.loc[keep].copy()

    mean_df["freq"] = mean_df.cond.apply(
        lambda x: "High" if x in high_freq_conds else "Low"
    )
    mean_df["reg"] = mean_df.cond.apply(
        lambda x: "Regular" if x.startswith("Regular") else "Exception"
    )

    metric_specific_scale = alt.Scale(domain=(0, 1)) if metric == "acc" else alt.Scale()
    plot_title = f"{code_name}: (batch size: {test_obj.cfg.batch_size}; learning rate: {test_obj.cfg.learning_rate})"

    # Upper strip: exists only to host the epoch selection.
    epoch_selection = (
        alt.Chart(mean_df).mark_rect().encode(x="epoch:Q").add_selection(epoch_selector)
    ).properties(width=400)

    # Lower chart: x axis is reversed so "Low" is drawn before "High".
    plot_fxc_interact = (
        alt.Chart(mean_df)
        .mark_line()
        .encode(
            x=alt.X("freq:N", scale=alt.Scale(reverse=True), axis=alt.Axis(labels=False)),
            y=alt.Y(f"mean({metric}):Q", scale=metric_specific_scale),
            color="reg:N",
        )
        .transform_filter(epoch_selector)
        .properties(width=400)
    )

    return (epoch_selection & plot_fxc_interact).properties(title=plot_title)


def plot4(test_obj, epoch_selector, metric, timetick=range(4, 13)):
    """Plot the mean ``metric`` per condition on the "hs04_img_240" test set.

    Two charts are stacked vertically: an epoch strip that carries
    ``epoch_selector``, and a bar chart faceted into columns by the ``fc``
    label with one bar per ``img`` label, filtered by that selection.

    Args:
        test_obj: Evaluation object exposing ``cfg`` (``code_name``,
            ``batch_size``, ``learning_rate``) and ``eval(testset, task)``.
        epoch_selector: Altair selection attached to the epoch strip and used
            to filter the bar chart.
        metric: Column to average on the y axis; ``"acc"`` pins the axis
            to the range 0-1.
        timetick: Time ticks to keep (anything accepted by ``Series.isin``).

    Returns:
        The vertically concatenated chart, titled with the run's settings.
    """
    code_name = test_obj.cfg.code_name

    df = test_obj.eval("hs04_img_240", "triangle")
    mean_df = benchmark_hs04.make_cond_mean_df(df)

    keep = mean_df.timetick.isin(timetick) & (mean_df.output_name == "pho")
    # Copy the filtered rows so the derived columns below are written to an
    # independent frame instead of a slice of the original (avoids pandas'
    # chained-assignment warning and its ambiguity).
    mean_df = mean_df.loc[keep].copy()

    # Split the condition label: its first five characters select the facet
    # column (e.g. "hf_ls") and its last two characters select the bar.
    # NOTE(review): presumably frequency/consistency and imageability codes —
    # confirm against the test-set definition.
    mean_df["fc"] = mean_df.cond.apply(lambda x: x[:5])
    mean_df["img"] = mean_df.cond.apply(lambda x: x[-2:])

    metric_specific_scale = alt.Scale(domain=(0, 1)) if metric == "acc" else alt.Scale()

    # Upper strip: exists only to host the epoch selection.
    epoch_selection = (
        alt.Chart(mean_df).mark_rect().encode(x="epoch:Q").add_selection(epoch_selector)
    ).properties(width=300)

    bar = (
        alt.Chart(mean_df)
        .mark_bar()
        .encode(
            x="img:N",
            y=alt.Y(f"mean({metric}):Q", scale=metric_specific_scale),
            column=alt.Column("fc:N", sort=["hf_ls", "lf_ls", "hf_hs", "lf_hs"]),
            color="img:N",
        )
        .transform_filter(epoch_selector)
    ).properties(height=200)

    plot_title = f"{code_name}: (batch size: {test_obj.cfg.batch_size}; learning rate: {test_obj.cfg.learning_rate})"
    return (epoch_selection & bar).properties(title=plot_title)

def plot6(test_obj, null, metric, cond='hf', timetick=12, testset='train_r100'):
    """Plot the mean ``metric`` in the "sem" output over epochs, per task.

    The test set is evaluated under the "triangle", "exp_ops" and "ort_sem"
    tasks (all with the "cos" save-file prefix); the results are stacked and
    drawn as one line per ``task``.

    Args:
        test_obj: Evaluation object exposing ``cfg`` (``code_name``,
            ``batch_size``, ``learning_rate``) and
            ``eval(testset, task, save_file_prefix=...)``.
        null: Unused placeholder; it keeps the positional signature aligned
            with the other plotters, which receive a selection in this slot.
        metric: Column to average on the y axis; ``"acc"`` pins the axis
            to the range 0-1.
        cond: Only rows whose ``cond`` equals this value are plotted.
        timetick: Only rows whose ``timetick`` equals this value are plotted.
        testset: Name of the test set passed to ``test_obj.eval``.

    Returns:
        A line chart with point markers, titled with the run's settings.
    """
    # Imported here because the notebook's import cell does not bring pandas
    # into scope, which made the ``pd.concat`` call below raise NameError.
    import pandas as pd

    code_name = test_obj.cfg.code_name

    # NOTE(review): the colour encoding relies on a "task" column that is
    # presumably produced by eval / make_cond_mean_df — confirm.
    task_frames = [
        test_obj.eval(testset, task, save_file_prefix="cos")
        for task in ("triangle", "exp_ops", "ort_sem")
    ]
    df_sem = pd.concat(task_frames, ignore_index=True)
    mean_df = benchmark_hs04.make_cond_mean_df(df_sem)

    mean_df = mean_df.loc[mean_df.timetick.isin([timetick])]
    mean_df = mean_df.loc[mean_df.output_name == "sem"]
    mean_df = mean_df.loc[mean_df.cond == cond]

    metric_specific_scale = alt.Scale(domain=(0, 1)) if metric == "acc" else alt.Scale()

    plot_title = f"{code_name}: (batch size: {test_obj.cfg.batch_size}; learning rate: {test_obj.cfg.learning_rate})"

    return (
        alt.Chart(mean_df)
        .mark_line(point=True)
        .encode(
            x="epoch:Q",
            y=alt.Y(f"mean({metric}):Q", scale=metric_specific_scale),
            color="task:N",
        )
        .properties(title=plot_title)
    )



In [None]:
class BatchPlot:
    """Build grids of benchmark plots, one cell per run of a training batch.

    The batch configuration is read from
    ``models/<batch_name>/batch_config.json`` and reduced to a table of
    ``code_name``, ``batch_size``, ``learning_rate`` and ``task``; runs are
    then looked up in that table and plotted with one of the ``PLOTTER``
    functions.
    """

    # Human-readable benchmark names keyed by benchmark id.
    BENCHMARKS = {1: "Accuracy", 2: "Taraban", 4: "HS04-IMG", 6: "Cosine lesion in SEM"}
    # Plotting function used for each benchmark id.
    PLOTTER = {1: plot1, 2: plot2, 4: plot4, 6: plot6}

    def __init__(self, batch_name: str, tf_root: str):
        """Load the batch configuration and create the shared epoch selection.

        Args:
            batch_name: Name of the batch folder under ``models``.
            tf_root: Root directory handed to the config parser and to each
                run's ``cfg.tf_root``.
        """
        self.batch_name = batch_name
        self.json = os.path.join("models", batch_name, "batch_config.json")
        self.tf_root = tf_root
        self.cfg_df = self.parse_batch_config()
        # A single selection object is passed to every sub-plot.
        # NOTE(review): newer Altair releases prefer `value=` over `init=`
        # here — confirm against the installed version before changing.
        self.interval_epoch = alt.selection_interval(init={"epoch": (100, 200)})

    def parse_batch_config(self):
        """Return the batch configuration as a four-column DataFrame.

        Returns:
            DataFrame with columns ``code_name``, ``batch_size``,
            ``learning_rate`` and ``task``.

        Raises:
            AssertionError: If the batch is not "task_effect", the only batch
                for which the hard-coded task labels below are valid.
        """
        df = meta.batch_json_to_df(self.json, tf_root=self.tf_root)
        # Guard against reusing the hard-coded labels below for another batch.
        assert self.batch_name == "task_effect", (
            "hard-coded task labels are only valid for the 'task_effect' batch"
        )
        # HACK: works around the list-typed task config; only valid for this
        # batch, which must contain exactly 36 runs in OP/OS/Triangle order.
        df["task"] = ["OP", "OS", "Triangle"] * 12
        return df[["code_name", "batch_size", "learning_rate", "task"]]

    def plot(self, code_name, plot_fn, **kwargs):
        """Plot one run with ``plot_fn``.

        Args:
            code_name: Identifier of the run to evaluate.
            plot_fn: Plotter called as
                ``plot_fn(test, self.interval_epoch, **kwargs)``.
            **kwargs: Extra keyword arguments for ``plot_fn`` (e.g. ``metric``).

        Returns:
            Whatever ``plot_fn`` returns (an Altair chart for the plotters in
            ``PLOTTER``).
        """
        test = benchmark_hs04.init(code_name, batch_name=self.batch_name)
        test.cfg.tf_root = self.tf_root
        return plot_fn(test, self.interval_epoch, **kwargs)

    def find_code_name(self, criteria: dict) -> str:
        """Return the code_name of the single run matching all ``criteria``.

        Args:
            criteria: Mapping of ``cfg_df`` column name to required value.

        Returns:
            The matching run's ``code_name`` as a string.

        Raises:
            ValueError: If zero or more than one run matches. The previous
                ``Series.to_string`` rendering silently produced an unusable
                name in those cases.
        """
        mask = None
        for column, value in criteria.items():
            hit = self.cfg_df[column] == value
            mask = hit if mask is None else mask & hit

        code_names = self.cfg_df.code_name
        matches = code_names if mask is None else code_names.loc[mask]
        if len(matches) != 1:
            raise ValueError(
                f"Expected exactly one run matching {criteria}, found {len(matches)}"
            )
        return str(matches.iloc[0])

    def plot_grid(
        self,
        benchmark_id: int,
        metric: str,
        col: str = "learning_rate",
        row: str = "batch_size",
        task: str = "Triangle",
    ) -> alt.VConcatChart:
        """Plot one benchmark for every (row, col) setting of a task.

        Args:
            benchmark_id: Key into ``PLOTTER`` selecting the plotting function.
            metric: Metric name forwarded to the plotter.
            col: ``cfg_df`` column whose sorted unique values form the columns.
            row: ``cfg_df`` column whose sorted unique values form the rows.
            task: Value of the ``task`` column the runs must have.

        Returns:
            A vertical concatenation of rows, each row a horizontal
            concatenation of the per-run charts.

        Raises:
            KeyError: If ``benchmark_id`` is not in ``PLOTTER``.
            ValueError: If a grid cell does not map to exactly one run.
        """
        plotter = self.PLOTTER[benchmark_id]

        cols_val = sorted(self.cfg_df[col].unique())
        rows_val = sorted(self.cfg_df[row].unique())

        # Build each row with hconcat and stack rows with vconcat directly;
        # seeding with the opposite (empty) concat type and using |= / &=
        # left an empty chart at the head of every row and of the figure.
        grid_rows = []
        for vr in rows_val:
            cells = []
            for vc in cols_val:
                criterion = {col: vc, row: vr, "task": task}
                code_name = self.find_code_name(criterion)
                cells.append(self.plot(code_name, plot_fn=plotter, metric=metric))
            grid_rows.append(alt.hconcat(*cells))

        return alt.vconcat(*grid_rows)


# Plot helper for the "task_effect" batch; the config is loaded on construction.
bp = BatchPlot(batch_name="task_effect", tf_root="/home/jupyter/triangle_model")

In [None]:
# Benchmark 1 ("Accuracy"): save one HTML grid per metric.
bp.plot_grid(1, metric="acc").save('interactive_1_acc.html')
bp.plot_grid(1, metric="sse").save('interactive_1_sse.html')

In [None]:
# Benchmark 2 ("Taraban"): save one HTML grid per metric.
bp.plot_grid(2, metric="acc").save('interactive_2_acc.html')
bp.plot_grid(2, metric="csse").save('interactive_2_csse.html')

In [None]:
# Benchmark 4 ("HS04-IMG"): save one HTML grid per metric.
bp.plot_grid(4, metric="acc").save('interactive_4_acc.html')
bp.plot_grid(4, metric="csse").save('interactive_4_csse.html')