1. Align with accuracy instead of epoch
- Select the one epoch that is closest to 80% accuracy on PHO
2. Plot individual “network” difference beta over grid
- Taraban : y~lm(freq x cons)
- IMG-HS04 : y~lm(fxcximg)
- Nonword Glushko overall: just acc
3. Big stat model on the entire grid
- y ~ batch_size or epsilon (check whether they lie on the same dimension or not…)
- y ~ lm/lmer(batch_size  or epsilon * stimprop)  | testset x
4. Also summarize DoL within the same grid [raw, same epoch at 1]
- P: intact, OP, OSP
- S: intact, OS, OPS


# Get merged data

In [None]:
import meta
import os
import pandas as pd
import numpy as np
import altair as alt
from itertools import chain
from tqdm import tqdm
import statsmodels.formula.api as smf
import statsmodels.api as sm
from scipy.stats.mstats import zscore

In [None]:
class Batch:
    """Batch object that takes care of the data manipulation in the results of a batch.

    The batch settings are read from ``models/<batch_name>/batch_config.json``
    and evaluation CSVs from
    ``<tf_root>/models/<batch_name>/<code_name>/eval/<csv_name>``.
    """

    def __init__(self, batch_name: str, tf_root: str = None):
        self.batch_name = batch_name
        # Path to the batch config (tf_root is handed to meta.batch_json_to_df separately)
        self.json = os.path.join("models", batch_name, "batch_config.json")
        self.tf_root = tf_root if tf_root else "./"
        # One row of settings per model (code_name)
        self.cfg_df = self.parse_batch_config()
        self.code_names = self.cfg_df.code_name.unique().tolist()

        # Dataframe to be loaded
        self.df = None
        self.backup_df = None

    def mount_testset(self, csv: list):
        """Load the eval files named in `csv` into self.df and checkpoint the result."""
        self.df = self.parse_df(csv)
        self.checkpoint_df()

    def checkpoint_df(self):
        """Make a df checkpoint copy"""
        self.backup_df = self.df.copy()

    def restore_df(self):
        """Restore self.df to the checkpointed dataframe.

        A copy is handed out so that later in-place edits of self.df
        cannot silently change the checkpoint itself.
        """
        self.df = None if self.backup_df is None else self.backup_df.copy()

    def subset_df(
        self,
        code_name: str = None,
        epoch: int = None,
        output_name: str = None,
        timetick: list = None,
        cond: list = None,
        train_task: str = None,
    ):
        """Subset self.df to spec.

        Scalar arguments select rows equal to the value, list arguments
        (`timetick`, `cond`) select rows whose value is in the list.
        Arguments left as None are not filtered on. self.df is not modified.
        """
        df = self.df

        exact_filters = (
            ("code_name", code_name),
            ("epoch", epoch),
            ("output_name", output_name),
            ("train_task", train_task),
        )
        for column, value in exact_filters:
            if value is not None:
                df = df.loc[df[column] == value]

        for column, values in (("timetick", timetick), ("cond", cond)):
            if values is not None:
                df = df.loc[df[column].isin(values)]

        return df

    def subset_by_epoch_dict(self, sel_epoch: dict):
        """Return a subset of the dataframe using a epoch dictionary.
        args:
            sel_epoch: dictionary of epochs to select with k=code_name, v=epoch

        Entries whose epoch is None (no epoch met the selection criterion) are
        skipped; passing None on to subset_df would otherwise disable the epoch
        filter and return every epoch of that model.
        """
        dfs = [
            self.subset_df(code_name=k, epoch=v)
            for k, v in sel_epoch.items()
            if v is not None
        ]
        if not dfs:
            # Nothing selected: empty frame that still carries the columns
            return self.df.iloc[0:0].copy()
        return self.concat_dfs(dfs)

    def parse_batch_config(self):
        """Read the batch config into a settings dataframe (one row per model)."""
        df = meta.batch_json_to_df(self.json, tf_root=self.tf_root)
        assert (
            self.batch_name == "task_effect"
        )  # Just in case I forgot to change below line in other batches
        df["train_task"] = [
            "OP",
            "OS",
            "Triangle",
        ] * 12  # Caution: this is a hack to get around list type config, only works for this batch
        return df[["code_name", "batch_size", "learning_rate", "train_task"]]

    def parse_df(self, csv: list) -> pd.DataFrame:
        """Read every eval file named in `csv` for every model and attach the settings."""
        files = chain.from_iterable([self.get_eval_file_names(x) for x in csv])
        df = self.merge_from_file_names(files)
        return df.merge(self.cfg_df, on="code_name", how="left")

    def get_eval_file_names(self, csv_name: str) -> list:
        """Return the path of `csv_name` in the eval folder of every model in the batch."""
        return [
            os.path.join(
                self.tf_root, "models", self.batch_name, code_name, "eval", csv_name
            )
            for code_name in self.code_names
        ]

    def find_code_name(self, criteria: dict) -> list:
        """Return the code_names whose settings match every criterion.

        args:
            criteria: dictionary with k=column of self.cfg_df and v=accepted
                value(s); a scalar is treated as a single accepted value.
        """
        mask = pd.Series(True, index=self.cfg_df.index)
        for column, accepted in criteria.items():
            if not pd.api.types.is_list_like(accepted):
                accepted = [accepted]
            mask &= self.cfg_df[column].isin(accepted)

        return self.cfg_df.code_name.loc[mask].tolist()

    def find_epoch(self, code_name: str, outputs: list, fn: callable, sse: float = None, acc: float = None) -> int:
        """Return the epoch selected by `fn` on the per-epoch mean of a metric.

        args:
            code_name: model to search in
            outputs: output_name values to average over
            fn: callable(values, target) -> positional index or None,
                e.g. find_nearest / find_first_less_than / find_first_more_than
            sse, acc: target value; exactly one of the two must be given

        Returns None when `fn` finds no epoch or the model has no matching rows.
        """
        if (sse is None) == (acc is None):
            raise ValueError("Specify exactly one of `sse` or `acc`.")
        metric, target = ("acc", acc) if acc is not None else ("sse", sse)

        rows = self.df.loc[
            (self.df.code_name == code_name) & self.df.output_name.isin(outputs)
        ]
        # Average only the metric of interest: a frame-wide mean() would
        # trip over string columns. groupby sorts, so epochs are ascending.
        means = rows.groupby("epoch")[metric].mean()
        if means.empty:
            return None

        idx = fn(means.reset_index(drop=True), target)
        return None if idx is None else means.index[idx]

    @staticmethod
    def merge_from_file_names(filenames: list) -> pd.DataFrame:
        """Read every csv in `filenames` and merge them into one dataframe."""
        dfs = [pd.read_csv(f) for f in filenames]
        return Batch.concat_dfs(dfs)

    @staticmethod
    def concat_dfs(dfs: list) -> pd.DataFrame:
        """Return a dataframe from a list of dataframes (with a fresh 0..n-1 index)."""
        return pd.concat(dfs, ignore_index=True)

    @staticmethod
    def find_nearest(array, value):
        """Returning the index of an array that is closest to a given value."""
        array = np.array(array)
        return (np.abs(array - value)).argmin()

    @staticmethod
    def find_first_less_than(array, value):
        """Returning the first index of an array that has the value lower than a given value
        Return None if no value is found.
        """
        test = np.array(array) < value
        return test.argmax() if test.any() else None

    @staticmethod
    def find_first_more_than(array, value):
        """Returning the first index of an array that has the value higher than a given value
        Return None if no value is found.
        """
        test = np.array(array) > value
        return test.argmax() if test.any() else None

    def get_acc_based_df(self, acc: float, code_name: str = None, outputs: list = None) -> pd.DataFrame:
        """Return the rows of self.df at the epoch whose mean accuracy is nearest to `acc`.

        args:
            acc: target accuracy
            code_name: model to select; None selects every model in self.df,
                each at its own nearest epoch
            outputs: output_name values to average accuracy over; None uses
                every output_name present in self.df
        """
        code_names = self.df.code_name.unique() if code_name is None else [code_name]
        if outputs is None:
            outputs = self.df.output_name.unique().tolist()

        sel_epoch = {
            name: self.find_epoch(name, outputs=outputs, fn=self.find_nearest, acc=acc)
            for name in code_names
        }
        return self.subset_by_epoch_dict(sel_epoch)


# Instantiate the batch; Batch.__init__ parses models/task_effect/batch_config.json.
b = Batch("task_effect")


# Examine the correlation between PHO and SEM

In [None]:
# Load the training-set evaluation of every model, then keep only
# time ticks 8-12 of the networks trained on the "Triangle" task.
b.mount_testset(['train_r100_triangle.csv'])
b.df = b.subset_df(timetick=range(8, 13), train_task="Triangle")

In [None]:
# Mean per word within each (batch_size, learning_rate, epoch, output_name).
# numeric_only=True keeps the string columns of b.df from breaking mean()
# on recent pandas versions.
df = (
    b.df.groupby(["batch_size", "learning_rate", "word", "epoch", "output_name"])
    .mean(numeric_only=True)
    .reset_index()
)
df = df[["batch_size", "learning_rate", "epoch", "word", "output_name", "acc", "sse", "act1"]]
# Wide format: one accuracy column per output_name, averaged over words
# (pivot_table aggregates with the mean by default).
df = df.pivot_table(index=["batch_size", "learning_rate", "epoch"], columns="output_name", values="acc").reset_index()

In [None]:
# Rank correlation between the PHO and SEM accuracy columns of the pivoted df.
from scipy.stats import spearmanr
spearmanr(df['pho'], df['sem'])

# Finding the earliest epoch that reaches below 0.4 SEM SSE

In [None]:
# Per network: first epoch whose mean SEM SSE is below 0.4 (None if there is none).
sel_epoch_sem = {
    name: b.find_epoch(name, outputs=["sem"], fn=b.find_first_less_than, sse=0.4)
    for name in tqdm(b.df.code_name.unique())
}
print(sel_epoch_sem)

# Finding the earliest epoch that reaches below 0.05 PHO SSE

In [None]:
# Per network: first epoch whose mean PHO SSE is below 0.05 (None if there is none).
sel_epoch_pho = {
    name: b.find_epoch(name, outputs=["pho"], fn=b.find_first_less_than, sse=0.05)
    for name in tqdm(b.df.code_name.unique())
}
print(sel_epoch_pho)

# Find the epoch that is nearest to 0.9 SEM ACC

In [None]:
# Per network: epoch whose mean SEM accuracy is nearest to 0.90.
sel_epoch_sem09 = {
    name: b.find_epoch(name, outputs=["sem"], fn=b.find_nearest, acc=0.90)
    for name in tqdm(b.df.code_name.unique())
}
print(sel_epoch_sem09)

In [None]:
class TarabanTest:
    """Frequency x regularity analysis of the Taraban test set over a batch grid.

    On construction the Taraban evaluation is mounted on `batch` (replacing
    batch.df and its checkpoint) and reduced to the epoch selected for each
    network in `sel_epoch`.
    """

    # Patsy term / setting column -> name used in the beta table
    _BETA_NAMES = {
        "Intercept": "intercept",
        "freq_num": "freq_effect",
        "reg_num": "reg_effect",
        "freq_num:reg_num": "interactions",
        "learning_rate": "epsilon",
    }
    _BETA_VARS = ["intercept", "freq_effect", "reg_effect", "interactions"]

    def __init__(self, batch: Batch, sel_epoch: dict):
        self.batch = batch
        self.sel_epoch = sel_epoch
        self.df = None # Cleaned dataframe selected to sel_epoch from tidy()
        self.mdf = None # mean within cell of self.df from tidy()
        self.taraban_beta = {} # beta values from run_cell_glm(), keyed by metric
        self.tidy()

    def tidy(self):
        """Mount and clean the Taraban test set, then build self.df and self.mdf."""
        # Tidy up Taraban testset
        self.batch.mount_testset(['taraban_triangle.csv'])
        sel_conds = [
            "High-frequency exception",
            "Regular control for High-frequency exception",
            "Low-frequency exception",
            "Regular control for Low-frequency exception",
        ]
        high_freq_conds = sel_conds[:2]

        # Copy before adding columns so we never write into a slice of the mounted frame
        df = self.batch.subset_df(
            output_name="pho", timetick=range(8, 13), cond=sel_conds, train_task="Triangle"
        ).copy()
        df["freq"] = np.where(df.cond.isin(high_freq_conds), "High", "Low")
        df["reg"] = np.where(df.cond.str.startswith("Regular"), "Regular", "Exception")
        self.batch.df = df
        self.batch.checkpoint_df()

        # Create different dfs
        self.df = self.batch.subset_by_epoch_dict(self.sel_epoch)
        self.df = self.df[['batch_size', 'learning_rate', 'code_name', 'epoch', 'timetick', 'freq', 'reg', 'word', 'acc', 'sse']].copy()
        # numeric_only: the string column `word` cannot be averaged
        self.mdf = (
            self.df.groupby(['batch_size', 'learning_rate', 'code_name', 'freq', 'reg'])
            .mean(numeric_only=True)
            .reset_index()
        )

    @staticmethod
    def _with_numeric_conds(df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of df with +/-0.5 coded reg_num and freq_num columns."""
        return df.assign(
            reg_num=np.where(df.reg == 'Regular', 0.5, -0.5),
            freq_num=np.where(df.freq == 'High', 0.5, -0.5),
        )

    def plot_selection_acc(self):
        """Heatmap of mean accuracy over the grid, annotated with accuracy and selected epoch."""
        acc_txt = alt.Chart(self.mdf).mark_text(dy=6).encode(
            x='learning_rate:O',
            y=alt.Y('batch_size:O'),
            text=alt.Text('mean(acc):Q', format='.2f'),
        ).properties(title="Selected epoch and mean accuracy in Taraban testset", width=200, height=200)

        epoch_txt = acc_txt.mark_text(dy=-6).encode(
            text=alt.Text('mean(epoch):Q', format='.0f'),
        )

        heatmap = acc_txt.mark_rect().encode(
            color="mean(acc):Q"
        )

        return heatmap + acc_txt + epoch_txt

    def plot_interaction(self, metric: str = 'acc'):
        """Line plot of frequency x regularity for `metric`, one panel per grid cell."""
        metric_specific_scale = alt.Scale(domain=(0, 1)) if metric == "acc" else alt.Scale()
        return alt.Chart(self.mdf).mark_line().encode(
                x=alt.X("freq:N", scale=alt.Scale(reverse=True)),
                y=alt.Y(f"mean({metric}):Q", scale=metric_specific_scale),
                row="batch_size:O",
                column="learning_rate:O",
                color="reg:N",
            ).properties(width=150, height=150)

    def run_grid_glm(self, metric: str = 'acc'):
        """Fit and print one GLM on the cell means of `metric` over the whole grid.

        raises:
            ValueError: if `metric` is not a column of self.mdf
        """
        if metric not in self.mdf.columns:
            raise ValueError(f"Unknown metric: {metric!r}")

        # Numeric condition for lm
        data = self._with_numeric_conds(self.mdf)
        formula = f'zscore({metric}) ~ zscore(learning_rate) * zscore(batch_size) * reg_num * freq_num'
        m = smf.glm(formula=formula, data=data).fit()

        print(f"===== Grid level GLM on average {metric} =====")
        print(m.summary())

    def run_cell_glm(self, metric: str = 'acc'):
        """Fit one GLM per network and store the betas in self.taraban_beta[metric].

        The stored table is in long format with the columns code_name,
        batch_size, epsilon (the learning_rate), variable and value.

        raises:
            ValueError: if `metric` is not 'acc' or 'sse', or no GLM could be fitted
        """
        if metric not in ('acc', 'sse'):
            raise ValueError(f"metric must be 'acc' or 'sse', got {metric!r}")
        get_beta = self.get_taraban_params_acc if metric == 'acc' else self.get_taraban_params_sse

        data = self._with_numeric_conds(self.df)

        # Run cell level GLMs; networks whose fit failed return None and are left out
        params = [get_beta(data, code_name=x) for x in tqdm(data.code_name.unique())]
        params = [p for p in params if p is not None]
        if not params:
            raise ValueError("No cell level GLM could be fitted.")
        betas = pd.concat(params, ignore_index=True)

        setting_map = self.mdf[['code_name', 'batch_size', 'learning_rate']].groupby(['code_name']).mean().reset_index()
        betas = betas.merge(setting_map, on='code_name')

        # Restructure: rename by term name rather than by column position
        betas = betas.rename(columns=self._BETA_NAMES)
        self.taraban_beta[metric] = betas.melt(
            id_vars=['code_name', 'batch_size', 'epsilon'], value_vars=self._BETA_VARS
        )

    def plot_glm_betas(self, metric: str = 'acc', color_range: float = 25):
        """Plot the betas on grid.

        raises:
            RuntimeError: if run_cell_glm(metric) has not been run yet
        """
        if metric not in self.taraban_beta:
            raise RuntimeError(f"Run run_cell_glm('{metric}') first.")

        # Plot betas
        return alt.Chart(self.taraban_beta[metric]).mark_rect().encode(
            x='epsilon:O',
            y='batch_size:O',
            color=alt.Color('value:Q', scale=alt.Scale(domain=(-color_range, color_range), scheme='redblue')),
            column='variable:N',
        ).properties(width=200, height=200)

    @staticmethod
    def _fit_cell_glm(df, code_name, formula, **glm_kwargs):
        """Fit `formula` on one network; return a one-row frame of betas or None on failure."""
        try:
            m = smf.glm(formula=formula, data=df.loc[df.code_name == code_name], **glm_kwargs).fit()
        except Exception as e:
            # Best effort: report which network failed and why, then carry on
            print(f"Error in {code_name}: {e}")
            return None
        row = m.params.to_dict()
        row['code_name'] = code_name
        return pd.DataFrame([row])

    @staticmethod
    def get_taraban_params_acc(df, code_name):
        """Betas of a binomial GLM acc ~ freq_num * reg_num for one network."""
        return TarabanTest._fit_cell_glm(
            df, code_name, "acc ~ freq_num * reg_num", family=sm.families.Binomial()
        )

    @staticmethod
    def get_taraban_params_sse(df, code_name):
        """Betas of a GLM (default family) sse ~ freq_num * reg_num for one network."""
        return TarabanTest._fit_cell_glm(df, code_name, "sse ~ freq_num * reg_num")

In [None]:
# Taraban analysis at the epochs stored in sel_epoch_sem09.
t = TarabanTest(b, sel_epoch_sem09)

In [None]:
# Save the Taraban charts to HTML files in the working directory.
t.plot_selection_acc().save('sel_sem09_taraban_acc.html')
t.plot_interaction(metric='acc').save('sel_sem09_taraban_acc_interaction.html')
t.plot_interaction(metric='sse').save('sel_sem09_taraban_sse_interaction.html')

In [None]:
# Grid-level GLM on the cell means of accuracy (prints the model summary).
t.run_grid_glm('acc')

In [None]:
# Grid-level GLM on the cell means of SSE (prints the model summary).
t.run_grid_glm('sse')

In [None]:
# Per-network GLMs; the betas are stored in t.taraban_beta['acc'] / ['sse'].
t.run_cell_glm('acc')
t.run_cell_glm('sse')

In [None]:
# Per-network accuracy betas on the grid; color scale domain is -30..30.
t.plot_glm_betas('acc',color_range=30)

In [None]:
# Per-network SSE betas on the grid; color scale domain is -1..1.
t.plot_glm_betas('sse', color_range=1)

# Find the epoch that is closest to 80% accuracy in each network

- Define by Taraban
- at 8-12 ticks
- Train task: Triangle
- Output at PHO

# Nonword

In [None]:
class GlushkoTest:
    """Nonword analysis of the Glushko test set over a batch grid.

    On construction the Glushko evaluation is mounted on `batch` (replacing
    batch.df and its checkpoint) and reduced to the epoch selected for each
    network in `sel_epoch`.
    """

    def __init__(self, batch: Batch, sel_epoch: dict):
        self.batch = batch
        self.sel_epoch = sel_epoch
        self.df = None # Cleaned dataframe selected to sel_epoch from tidy()
        self.mdf = None # mean within cell of self.df from tidy()
        self.tidy()

    def tidy(self):
        """Mount and clean the Glushko test set, then build self.df and self.mdf."""
        # Tidy up
        self.batch.mount_testset(['glushko_triangle.csv'])
        self.batch.df = self.batch.subset_df(output_name="pho", timetick=range(8, 13), train_task="Triangle")
        self.batch.checkpoint_df()

        # Create different dfs
        self.df = self.batch.subset_by_epoch_dict(self.sel_epoch)
        # +/-0.5 coding of the condition for the GLM
        self.df['cond_num'] = self.df.cond.apply(lambda x: 0.5 if x == 'Regular' else -0.5)
        # numeric_only: string columns cannot be averaged; cond_num is constant
        # within each cond group, so its mean is the code itself
        self.mdf = (
            self.df.groupby(['batch_size', 'learning_rate', 'code_name', 'cond'])
            .mean(numeric_only=True)
            .reset_index()
        )

    def plot_selection_acc(self):
        """Heatmap of mean accuracy over the grid, annotated with accuracy and selected epoch."""
        acc_txt = alt.Chart(self.mdf).mark_text(dy=6).encode(
            x='learning_rate:O',
            y='batch_size:O',
            text=alt.Text('mean(acc):Q', format='.2f'),
        ).properties(title="Selected epoch and mean accuracy in Glushko testset", width=200, height=200)

        epoch_txt = acc_txt.mark_text(dy=-6).encode(
            text=alt.Text('mean(epoch):Q', format='.0f'),
        )

        heatmap = acc_txt.mark_rect().encode(
            color="mean(acc):Q"
        )

        return heatmap + acc_txt + epoch_txt

    def plot_cond_heatmap(self, metric: str = 'acc'):
        """Heatmap of `metric` over the grid, one panel per condition."""
        metric_specific_scale = alt.Scale(domain=(0, 1)) if metric == "acc" else alt.Scale()
        return alt.Chart(self.mdf).mark_rect().encode(
                x='learning_rate:O',
                y='batch_size:O',
                color=alt.Color(f'{metric}:Q', scale=metric_specific_scale),
                column='cond:N',
            ).properties(title=metric.upper(), width=200, height=200)

    def run_grid_glm(self, metric: str = 'acc'):
        """Fit and print one GLM on the cell means of `metric` over the whole grid."""
        formula = f'zscore({metric}) ~ zscore(learning_rate) * zscore(batch_size) * cond_num'
        m = smf.glm(formula=formula, data=self.mdf).fit()
        print(f"===== Grid level GLM on average {metric.upper()} =====")
        print(m.summary())


In [None]:
# Glushko nonword analysis at the epochs stored in sel_epoch_sem09.
g = GlushkoTest(b, sel_epoch_sem09)

In [None]:
# Selected epoch and mean accuracy per grid cell.
g.plot_selection_acc()

In [None]:
# Accuracy over the grid, one panel per condition.
g.plot_cond_heatmap(metric='acc')

In [None]:
# Grid-level GLM on the cell means of accuracy (prints the model summary).
g.run_grid_glm(metric='acc')

In [None]:
# SSE over the grid, one panel per condition.
g.plot_cond_heatmap(metric='sse')

In [None]:
# Grid-level GLM on the cell means of SSE (prints the model summary).
g.run_grid_glm('sse')

# Img-HS04

In [None]:
class ImgTest:
    """Three-way condition analysis of the hs04_img_240 test set over a batch grid.

    The `cond` label is split on "_" into the factors freq, op and img.
    On construction the evaluation is mounted on `batch` (replacing batch.df
    and its checkpoint) and reduced to the epoch selected for each network
    in `sel_epoch`.
    """

    # Patsy term / setting column -> name used in the beta table
    _BETA_NAMES = {
        "Intercept": "intercept",
        "freq_num": "freq_effect",
        "op_num": "reg_effect",
        "freq_num:op_num": "fxr",
        "img_num": "img_effect",
        "freq_num:img_num": "fxi",
        "op_num:img_num": "rxi",
        "freq_num:op_num:img_num": "fxrxi",
        "learning_rate": "epsilon",
    }
    _BETA_VARS = [
        "intercept",
        "freq_effect",
        "reg_effect",
        "fxr",
        "img_effect",
        "fxi",
        "rxi",
        "fxrxi",
    ]

    def __init__(self, batch: Batch, sel_epoch: dict):
        self.batch = batch
        self.sel_epoch = sel_epoch
        self.df = None  # Cleaned dataframe selected to sel_epoch from tidy()
        self.mdf = None  # mean within cell of self.df from tidy()
        self.beta = {}  # beta values from run_cell_glm(), keyed by metric
        self.tidy()

    def tidy(self):
        """Mount and clean the test set, then build self.df and self.mdf."""
        # Tidy up
        self.batch.mount_testset(["hs04_img_240_triangle.csv"])
        self.batch.df = self.batch.subset_df(
            output_name="pho", timetick=range(8, 13), train_task="Triangle"
        )
        self.batch.checkpoint_df()

        # Create different dfs
        self.df = self.batch.subset_by_epoch_dict(self.sel_epoch)
        self.df[["freq", "op", "img"]] = self.df.cond.str.split("_", expand=True)
        # First five characters of cond
        self.df["fc"] = self.df.cond.apply(lambda x: x[:5])
        # +/-0.5 coding of the three factors for the GLMs
        self.df["freq_num"] = np.where(self.df.freq == "hf", 0.5, -0.5)
        self.df["op_num"] = np.where(self.df.op == "ls", 0.5, -0.5)
        self.df["img_num"] = np.where(self.df.img == "hi", 0.5, -0.5)

        cell_keys = [
            "batch_size",
            "learning_rate",
            "code_name",
            "cond",
            "fc",
            "freq",
            "op",
            "img",
        ]
        # numeric_only: remaining string columns cannot be averaged
        self.mdf = self.df.groupby(cell_keys).mean(numeric_only=True).reset_index()

    def plot_selection_acc(self):
        """Heatmap of mean accuracy over the grid, annotated with accuracy and selected epoch."""
        acc_txt = (
            alt.Chart(self.mdf)
            .mark_text(dy=6)
            .encode(
                x="learning_rate:O",
                y="batch_size:O",
                text=alt.Text("mean(acc):Q", format=".2f"),
            )
            .properties(
                title="Selected epoch and mean accuracy", width=200, height=200
            )
        )

        epoch_txt = acc_txt.mark_text(dy=-6).encode(
            text=alt.Text("mean(epoch):Q", format=".0f"),
        )

        heatmap = acc_txt.mark_rect().encode(color="mean(acc):Q")

        return heatmap + acc_txt + epoch_txt

    def plot_bar_img(self, metric: str = "acc"):
        """Bar chart of `metric` by condition, colored by img, one panel per grid cell."""
        metric_specific_scale = (
            alt.Scale(domain=(0, 1)) if metric == "acc" else alt.Scale()
        )
        return (
            alt.Chart(self.mdf)
            .mark_bar()
            .encode(
                x="cond:N",
                y=alt.Y(f"mean({metric}):Q", scale=metric_specific_scale),
                color="img:N",
                row="batch_size:O",
                column="learning_rate:O",
            )
            .properties(title=metric.upper(), width=200, height=200)
        )

    def run_grid_glm(self, metric: str = "acc"):
        """Fit and print one GLM on the cell means of `metric` over the whole grid."""
        m = smf.glm(
            formula=f"zscore({metric}) ~ zscore(learning_rate) * zscore(batch_size) * freq_num * op_num * img_num",
            data=self.mdf,
        ).fit()
        print(f"===== Grid level GLM on average {metric.upper()} =====")
        print(m.summary())

    def run_cell_glm(self, metric: str = "acc"):
        """Fit one GLM per network and store the betas in self.beta[metric].

        The stored table is in long format with the columns code_name,
        batch_size, epsilon (the learning_rate), variable and value.

        raises:
            ValueError: if `metric` is not 'acc' or 'sse', or no GLM could be fitted
        """
        if metric not in ("acc", "sse"):
            raise ValueError(f"metric must be 'acc' or 'sse', got {metric!r}")
        get_beta = (
            self.get_img_params_acc if metric == "acc" else self.get_img_params_sse
        )

        # Networks whose fit failed return None and are left out
        params = [
            get_beta(self.df, code_name=x) for x in tqdm(self.df.code_name.unique())
        ]
        params = [p for p in params if p is not None]
        if not params:
            raise ValueError("No cell level GLM could be fitted.")
        betas = pd.concat(params, ignore_index=True)

        setting_map = (
            self.mdf[["code_name", "batch_size", "learning_rate"]]
            .groupby(["code_name"])
            .mean()
            .reset_index()
        )
        betas = betas.merge(setting_map, on="code_name")

        # Restructure: rename by term name rather than by column position
        betas = betas.rename(columns=self._BETA_NAMES)
        self.beta[metric] = betas.melt(
            id_vars=["code_name", "batch_size", "epsilon"],
            value_vars=self._BETA_VARS,
        )

    def plot_glm_betas(self, metric: str = "acc", color_range: float = 25):
        """Plot the betas on grid.

        raises:
            RuntimeError: if run_cell_glm(metric) has not been run yet
        """
        if metric not in self.beta:
            raise RuntimeError(f"Run run_cell_glm('{metric}') first.")

        # Plot betas
        return (
            alt.Chart(self.beta[metric])
            .mark_rect()
            .encode(
                x="epsilon:O",
                y="batch_size:O",
                color=alt.Color(
                    "value:Q",
                    scale=alt.Scale(
                        domain=(-color_range, color_range), scheme="redblue"
                    ),
                ),
                column="variable:N",
            )
            .properties(width=200, height=200)
        )

    @staticmethod
    def _fit_cell_glm(df, code_name, formula, **glm_kwargs):
        """Fit `formula` on one network; return a one-row frame of betas or None on failure."""
        try:
            m = smf.glm(
                formula=formula,
                data=df.loc[df.code_name == code_name],
                **glm_kwargs,
            ).fit()
        except Exception as e:
            # Best effort: report which network failed and why, then carry on
            print(f"Error in {code_name}: {e}")
            return None
        row = m.params.to_dict()
        row["code_name"] = code_name
        return pd.DataFrame([row])

    @staticmethod
    def get_img_params_acc(df, code_name):
        """Betas of a binomial GLM acc ~ freq_num * op_num * img_num for one network."""
        return ImgTest._fit_cell_glm(
            df,
            code_name,
            "acc ~ freq_num * op_num * img_num",
            family=sm.families.Binomial(),
        )

    @staticmethod
    def get_img_params_sse(df, code_name):
        """Betas of a GLM (default family) sse ~ freq_num * op_num * img_num for one network."""
        return ImgTest._fit_cell_glm(
            df, code_name, "sse ~ freq_num * op_num * img_num"
        )


In [None]:
# hs04_img_240 analysis at the epochs stored in sel_epoch_sem09.
i = ImgTest(b, sel_epoch_sem09)

In [None]:
# Mean accuracy by condition, one panel per grid cell.
i.plot_bar_img(metric='acc')

In [None]:
# Grid-level GLM on the cell means of accuracy (prints the model summary).
i.run_grid_glm(metric='acc')

In [None]:
# Mean SSE by condition, one panel per grid cell.
i.plot_bar_img(metric='sse')

In [None]:
# Grid-level GLM on the cell means of SSE (prints the model summary).
i.run_grid_glm(metric='sse')

In [None]:
# Per-network GLMs; the betas are stored in i.beta['acc'] / i.beta['sse'].
i.run_cell_glm(metric='acc')
i.run_cell_glm(metric='sse')

In [None]:
# Per-network accuracy betas on the grid (default color scale domain -25..25).
i.plot_glm_betas(metric='acc')

In [None]:
# Per-network SSE betas on the grid; color scale domain is -0.5..0.5.
i.plot_glm_betas(metric='sse',color_range=0.5)

In [None]:
# Selected epoch and mean accuracy per grid cell.
i.plot_selection_acc()

# DoL

### PHO output

In [None]:
# PHO accuracy at the last time tick (12) for the three evaluation files,
# taken at the epochs stored in sel_epoch_sem09.
b.mount_testset(['train_r100_ort_pho.csv', 'train_r100_exp_osp.csv', 'train_r100_triangle.csv'])
b.df = b.subset_df(timetick=[12], output_name='pho', train_task="Triangle")
df = b.subset_by_epoch_dict(sel_epoch_sem09)
# numeric_only=True keeps string columns from breaking mean() on recent pandas versions
dol_pho_mdf = (
    df.groupby(['batch_size', 'learning_rate', 'code_name', 'task'])
    .mean(numeric_only=True)
    .reset_index()
)

alt.Chart(dol_pho_mdf).mark_rect().encode(
    x='learning_rate:O',
    y='batch_size:O',
    color=alt.Color('acc:Q', scale=alt.Scale(domain=(0, 1))),
    column='task:N',
).properties(width=200, height=200)

In [None]:
# SEM accuracy at the last time tick (12) for the three evaluation files,
# taken at the epochs stored in sel_epoch_sem09.
b.mount_testset(['cos_train_r100_ort_sem.csv', 'cos_train_r100_exp_ops.csv', 'cos_train_r100_triangle.csv'])
b.df = b.subset_df(timetick=[12], output_name='sem', train_task="Triangle")
df = b.subset_by_epoch_dict(sel_epoch_sem09)
# numeric_only=True keeps string columns from breaking mean() on recent pandas versions
dol_sem_mdf = (
    df.groupby(['batch_size', 'learning_rate', 'code_name', 'task'])
    .mean(numeric_only=True)
    .reset_index()
)

alt.Chart(dol_sem_mdf).mark_rect().encode(
    x='learning_rate:O',
    y='batch_size:O',
    color=alt.Color('acc:Q', scale=alt.Scale(domain=(0, 1))),
    column='task:N',
).properties(width=200, height=200)