# Individual differences simulation paper part III

In [None]:
%load_ext lab_black
import helper
import pandas as pd
import numpy as np
import altair as alt
from altair.expr import datum

# Load the simulation results used throughout this notebook.
df = helper.parse_from_file("../sims/1520_sims.csv")

# Turn off Altair's max-rows check so charts can be built from the full table.
alt.data_transformers.disable_max_rows()

### Goal: Define a reasonably meaningful RD group

### Strategy:

1. TD is already defined (C2 criterion: the 3x3x3 middle cells of the control space)
    
2. Filter on accuracy at a chosen epoch
    - Keep models whose overall accuracy at that epoch falls within a range (x, y) defined relative to TD


### Master Select Model Class

In [None]:
class Select_Model:
    """Helper class for defining a group of models (e.g. TD) and plotting it.

    I: Selection (every ``select_by_*`` method narrows ``self.df`` and prints
       how many models remain):
    1. Control space filter
    2. Rank filter
    3. Accuracy filter (developmental)
    4. Condition filter

    II: Plotting:
    1. Where the selected models are in the control space
    2. Their average performance (in each cond / mean of all conds)
    3. Some basic descriptives in the title

    df: long-format data. The methods below read the columns ``code_name``,
    ``hidden_units``, ``p_noise``, ``learning_rate``, ``epoch``, ``cond``,
    ``score`` and (for the rank helpers) ``rank_pc``.
    """

    def __init__(self, df):
        self.df = df

    def count_model(self):
        """Return the number of distinct models (``code_name``) in ``self.df``."""
        return len(self.df.code_name.unique())

    def _report_selection(self, n_pre):
        """Print how many of the ``n_pre`` models survived a selection step."""
        n_post = self.count_model()
        print(f"Selected {n_post} models from the original {n_pre} models")

    # Selection related functions

    def select_by_performance(self, threshold_low, threshold_hi, t_low, t_hi):
        """Keep models that are still low early on and high later.

        A model is kept when its score (averaged over conditions) is below
        ``threshold_low`` at epoch ``t_low`` and above ``threshold_hi`` at
        epoch ``t_hi``.
        """
        n_pre = self.count_model()
        wide = self.pivot_to_wide(self.df, t_low, t_hi)
        keep = wide.loc[
            (wide.t_low < threshold_low) & (wide.t_hi > threshold_hi), "code_name"
        ]

        # Full long-format data of the selected models. drop=True stops the old
        # row labels from leaking into an extra "index" column that would then
        # be averaged along with the real variables.
        self.df = (
            self.df.loc[self.df.code_name.isin(keep)]
            .sort_values(by=["code_name", "cond", "epoch"])
            .reset_index(drop=True)
        )
        self._report_selection(n_pre)

    def select_by_control(self, hidden_units=None, p_noise=None, learning_rate=None):
        """Keep rows whose control parameters are in the given lists.

        A parameter left as ``None`` is not filtered on.
        """
        n_pre = self.count_model()
        if hidden_units is not None:
            self.df = self.df.loc[self.df.hidden_units.isin(hidden_units)]
        if p_noise is not None:
            self.df = self.df.loc[self.df.p_noise.isin(p_noise)]
        if learning_rate is not None:
            self.df = self.df.loc[self.df.learning_rate.isin(learning_rate)]
        self._report_selection(n_pre)

    def select_by_rankpc(self, minpc, maxpc):
        """Keep rows with ``minpc <= rank_pc <= maxpc``."""
        n_pre = self.count_model()
        self.df = self.df.loc[(self.df.rank_pc >= minpc) & (self.df.rank_pc <= maxpc)]
        self._report_selection(n_pre)

    def select_by_cond(self, conds):
        """Keep rows whose ``cond`` is in ``conds``."""
        n_pre = self.count_model()
        self.df = self.df.loc[self.df.cond.isin(conds)]
        self._report_selection(n_pre)

    # Descriptives related functions

    def _describe_model_means(self, column):
        """Format mean/SD/min/max of ``column`` after averaging within each model."""
        # numeric_only=True: string columns such as ``cond`` cannot be averaged
        # (pandas >= 2.0 raises instead of silently dropping them).
        desc = self.df.groupby("code_name").mean(numeric_only=True)[column].describe()
        return (
            f"M:{desc['mean']:.3f} SD: {desc['std']:.3f} "
            f"Min: {desc['min']:.3f} Max: {desc['max']:.3f}"
        )

    def get_rankpc_desc(self):
        """Descriptive string of the per-model mean ``rank_pc``."""
        return self._describe_model_means("rank_pc")

    def get_acc_desc(self):
        """Descriptive string of the per-model mean ``score``."""
        return self._describe_model_means("score")

    # Plotting related functions

    def pivot_to_wide(self, df, t_low, t_hi):
        """Create one row per model with its score at ``t_low`` and ``t_hi`` as columns.

        df: input datafile
        t_low: epoch used in applying threshold_low
        t_hi : epoch used in applying threshold_hi

        Scores are averaged over all rows of a model at that epoch (i.e. over
        conditions). Returned columns: code_name, hidden_units, p_noise,
        learning_rate, t_low, t_hi.
        """
        index_names = ["code_name", "hidden_units", "p_noise", "learning_rate"]

        tmp = df.loc[df.epoch.isin([t_low, t_hi])]
        pvt = tmp.pivot_table(index=index_names, columns="epoch", values="score")

        # Select the epoch columns by value, not by position: pivot_table sorts
        # them ascending, so positional renaming would swap t_low/t_hi when
        # t_low > t_hi and fail outright when t_low == t_hi.
        wide = pvt[[t_low, t_hi]].copy()
        wide.columns = ["t_low", "t_hi"]
        return wide.reset_index()

    def plot_control_space(self):
        """Plot the number of selected models in each cell of the control space."""
        pdf = (
            self.df.groupby("code_name").mean(numeric_only=True).round(3).reset_index()
        )

        control_space = (
            alt.Chart(pdf)
            .mark_rect()
            .encode(
                x="p_noise:O",
                y=alt.Y("hidden_units:O", sort="descending"),
                column=alt.Column("learning_rate:O", sort="descending"),
                color="count(code_name)",
            )
        )
        return control_space

    def plot_mean_development(self, show_sd):
        """Plot the mean accuracy per condition over epoch of all selected models.

        show_sd: also draw a +/- 1 SD error band around each line.
        """
        development_space_sd = (
            alt.Chart(self.df)
            .mark_errorband(extent="stdev")
            .encode(
                y=alt.Y("score:Q", scale=alt.Scale(domain=(0, 1))),
                x="epoch:Q",
                color="cond:N",
            )
            .properties(
                title="Developmental space: Accuracy in each condition over epoch"
            )
        )

        development_space_mean = development_space_sd.mark_line().encode(
            y="mean(score):Q"
        )

        if show_sd:
            return development_space_mean + development_space_sd
        return development_space_mean

    def plot_all_cond_mean(self, show_sd):
        """Plot the accuracy averaged over all conditions, over epoch, of all selected models.

        show_sd: also draw a +/- 1 SD (across models) error band.
        """
        group_var = ["code_name", "hidden_units", "p_noise", "learning_rate", "epoch"]
        pdf = self.df.groupby(group_var).mean(numeric_only=True).reset_index()

        dev_all_sd = (
            alt.Chart(pdf)
            .mark_errorband(extent="stdev")
            .encode(y=alt.Y("score:Q", scale=alt.Scale(domain=(0, 1))), x="epoch:Q")
            .properties(
                title="Developmental space: Mean Accuracy in all conditions over epoch"
            )
        )

        dev_all_m = dev_all_sd.mark_line().encode(y="mean(score):Q")

        return (dev_all_m + dev_all_sd) if show_sd else dev_all_m

    def make_wnw(self):
        """Build a word vs. nonword table: one row per model and epoch.

        ``word_acc`` is the HF_INC score, ``nonword_acc`` the NW_UN score and
        ``word_advantage`` their difference.
        """
        variates = ["hidden_units", "p_noise", "learning_rate"]

        df_wnw = self.df.loc[
            self.df.cond.isin(["HF_INC", "NW_UN"]),
            variates + ["code_name", "epoch", "cond", "score"],
        ]

        # values="score" gives flat column labels (one per cond), so no
        # MultiIndex flattening is needed.
        df_wnw = df_wnw.pivot_table(
            index=variates + ["epoch", "code_name"], columns="cond", values="score"
        ).reset_index()
        df_wnw.columns.name = None
        df_wnw = df_wnw.rename(columns={"HF_INC": "word_acc", "NW_UN": "nonword_acc"})

        df_wnw["word_advantage"] = df_wnw.word_acc - df_wnw.nonword_acc
        return df_wnw

    def plot_wnw(self, mean=False):
        """Performance space plot: nonword vs. word accuracy, with the diagonal.

        mean: average over all models at each epoch instead of one line per model.
        """
        df = self.make_wnw()

        if mean:
            # Averaging drops the string column code_name, so the color/tooltip
            # fields that reference it are empty in this mode.
            df = df.groupby("epoch").mean(numeric_only=True).reset_index()

        wnw_line = (
            alt.Chart(df)
            .mark_line()
            .encode(
                y=alt.Y("nonword_acc:Q", scale=alt.Scale(domain=(0, 1))),
                x=alt.X("word_acc:Q", scale=alt.Scale(domain=(0, 1))),
                tooltip=["code_name", "epoch", "word_acc", "nonword_acc"],
                color="code_name:N",
            )
        )

        diagonal = (
            alt.Chart(pd.DataFrame({"x": [0, 1], "y": [0, 1]}))
            .mark_line(color="#D3D3D3")
            .encode(
                x=alt.X("x", axis=alt.Axis(title="word")),
                y=alt.Y("y", axis=alt.Axis(title="nonword")),
            )
        )

        return (diagonal + wnw_line).properties(
            title="Performance space: Nonword accuracy vs. Word accuracy"
        )

    def plot(self, title=None, show_sd=True):
        """Plot the control space above the per-condition and all-condition development.

        title: optional heading; the number of selected models is appended.
        show_sd: passed to the development plots.
        """
        n = self.count_model()

        t = [
            "Grand mean rank: " + self.get_rankpc_desc(),
            "Grand mean acc  : " + self.get_acc_desc(),
        ]

        if title is not None:
            t = [title + f" (n={n})"] + t

        all_plot = (
            self.plot_control_space()
            & (
                self.plot_mean_development(show_sd=show_sd)
                | self.plot_all_cond_mean(show_sd=show_sd)
            )
        ).properties(title=t)

        return all_plot

In [None]:
class Select_RD(Select_Model):
    """Select models that underperform a typically developing (TD) reference group.

    df: long-format data of the candidate models (same columns as Select_Model)
    td_df: long-format data of the TD reference models (e.g. ``Select_Model().df``)

    After ``select_by_relative_sd``, ``self.cadf`` holds the selected models'
    condition-averaged data together with a per-epoch z-score against TD.
    """

    def __init__(self, df, td_df):
        super().__init__(df)
        self.td_df = td_df
        # Per-epoch TD mean/std of the per-model mean score (see get_stat).
        self.td_stat = self.get_stat()
        # Condition-averaged data with z_deviance. It is filled by
        # select_by_relative_sd, or lazily by plot_deviance_heatmap.
        self.cadf = None

    def plot_deviance_heatmap(self):
        """Heatmap of mean z-deviance from TD over the control space.

        One column per epoch and one row per learning rate.
        """
        if self.cadf is None:
            self.cadf = self.make_condition_averaged_df()

        hm = (
            alt.Chart(self.cadf)
            .mark_rect()
            .encode(
                x="p_noise:O",
                y=alt.Y("hidden_units:O", sort="descending"),
                row=alt.Row("learning_rate:O", sort="descending"),
                column="epoch:O",
                color=alt.Color(
                    "mean(z_deviance)",
                    scale=alt.Scale(domain=(-5, 5), scheme="redyellowgreen"),
                ),
                tooltip=["mean(z_deviance)", "mean(score)"],
            )
        )

        return hm

    def make_condition_averaged_df(self):
        """Average ``self.df`` over conditions per model and epoch, and add z-deviance.

        Adds ``epoch_idx`` (the epoch's row position in ``self.td_stat``) and
        ``z_deviance`` = (score - TD mean) / TD SD at that epoch.

        Raises KeyError if ``self.df`` contains an epoch absent from the TD data.
        """
        cadf = (
            self.df.groupby(["code_name", "epoch"]).mean(numeric_only=True).reset_index()
        )

        # td_stat is keyed by the row position of each (ascending) TD epoch.
        # Derive that position from the TD data itself instead of hard-coding
        # the epoch grid, which silently misaligns for any other grid.
        epoch_to_idx = {ep: idx for idx, ep in self.td_stat["epoch"].items()}
        unknown = sorted(set(cadf.epoch) - set(epoch_to_idx))
        if unknown:
            raise KeyError(f"Epochs not present in the TD data: {unknown}")
        cadf["epoch_idx"] = cadf.epoch.map(epoch_to_idx).astype(int)

        # Vectorised version of calcuate_z_deviance, including its zero-SD guard.
        td_mean = cadf.epoch_idx.map(self.td_stat["mean"])
        td_sd = cadf.epoch_idx.map(self.td_stat["std"]).replace(0, 1e-6)
        cadf["z_deviance"] = (cadf.score - td_mean) / td_sd
        return cadf

    def calcuate_z_deviance(self, row):
        """
        Calculate the z score of ``row["score"]`` relative to TD at ``row["epoch_idx"]``.
        """
        m = self.td_stat["mean"][row["epoch_idx"]]
        sd = self.td_stat["std"][row["epoch_idx"]]

        # Avoid zero division
        if sd == 0:
            sd = 1e-6

        return (row["score"] - m) / sd

    def get_stat(self):
        """TD statistics: mean and std across TD models of each model's mean score, per epoch.

        Returns ``DataFrame.to_dict()`` output: {"epoch": {...}, "mean": {...},
        "std": {...}}, each keyed by row position (epochs sorted ascending).
        """
        per_model = (
            self.td_df.groupby(["code_name", "epoch"]).score.mean().reset_index()
        )
        stat = per_model.groupby("epoch").score.agg(["mean", "std"]).reset_index()
        return stat.to_dict()

    def get_acc_cut(self, epoch, xsd, cond=None):
        """Get the accuracy cut-off at ``xsd`` SD below the TD mean.

        td_df: data file of typically developing readers (created by Select_Model().df)
        epoch: at what epoch to classify RD [list]
        xsd: how many sd below mean of TD
        cond: include what condition, default = all conditions (no filtering)

        Mean and SD are taken across TD models of each model's mean score over
        the selected rows.
        """
        mask = self.td_df.epoch.isin(epoch)
        if cond is not None:
            mask &= self.td_df.cond.isin(cond)

        stat = (
            self.td_df.loc[mask].groupby("code_name").score.mean().agg(["mean", "std"])
        )
        return stat["mean"] - xsd * stat["std"]

    def select_by_relative_sd(self, epoch, xsd, cond=None):
        """Select the models whose mean score is at least ``xsd`` SD below the TD mean.

        epoch: epochs to classify at [list]
        xsd: how many SD below the TD mean
        cond: conditions to include, default = all conditions

        Also rebuilds ``self.cadf`` for the selected models.
        """
        n_pre = self.count_model()

        mask = self.df.epoch.isin(epoch)
        if cond is not None:
            mask &= self.df.cond.isin(cond)

        model_means = self.df.loc[mask].groupby("code_name").score.mean()
        cut = self.get_acc_cut(epoch, xsd, cond)
        selected = model_means.index[model_means < cut]
        self.df = self.df.loc[self.df.code_name.isin(selected)]

        n_post = self.count_model()
        print(f"Selected {n_post} models from the original {n_pre} models")

        # Make deviance
        self.cadf = self.make_condition_averaged_df()

    def plot_bundle(self, baseline_model):
        """Plot all panels with ``baseline_model`` (e.g. TD) as the reference group.

        Layout: control space, then development per condition / mean over
        conditions / word vs. nonword (this group without SD, baseline with SD
        where applicable), then the z-deviance heatmap.
        """
        dev_cond = (
            self.plot_mean_development(show_sd=False)
            + baseline_model.plot_mean_development(show_sd=True)
        ).properties(title="Each condition")

        dev_mean = (
            self.plot_all_cond_mean(show_sd=False)
            + baseline_model.plot_all_cond_mean(show_sd=True)
        ).properties(title="Mean in all conditions")

        wnw_mean = (
            self.plot_wnw(mean=True) + baseline_model.plot_wnw(mean=True)
        ).properties(title="Word vs. NW")

        return (
            self.plot_control_space()
            & (dev_cond | dev_mean | wnw_mean)
            & self.plot_deviance_heatmap()
        )

### Results: Control space selection

In [None]:
# TD reference group (C2): keep only the middle 3x3x3 cells of the control space,
# then save the overview plot of that group.
c2 = Select_Model(df)
c2.select_by_control(
    hidden_units=[100, 150, 200],
    p_noise=[1, 2, 3],
    learning_rate=[0.004, 0.006, 0.008],
)
c2.plot("C2: Middle levels").save("td.html")

In [None]:
# Sweep every (classification epoch, SD cut-off) pair. Each pair starts from a
# fresh Select_RD because select_by_relative_sd narrows rd.df.
for epoch, cut in [(e, c) for e in [0.05, 0.1, 0.2, 0.3] for c in [1.0, 1.5, 2, 2.5]]:
    rd = Select_RD(df, c2.df)
    rd.select_by_relative_sd([epoch], cut)
    rd.plot_bundle(baseline_model=c2).save(
        f"Epoch{epoch:.2f}_Cut{cut:.2f}.html"
    )