# Individual differences (part 3)

In [None]:
%load_ext lab_black
import pandas as pd
import numpy as np
import altair as alt
import helper
from altair.expr import datum

# Turn off Altair's max-rows check so the large simulation tables used below
# can be passed to charts directly.
alt.data_transformers.disable_max_rows()

In [None]:
class Select_Model:
    """Select typically-developing (TD) models and plot their behaviour.

    I: Selection (every ``select_by_*`` method narrows ``self.df`` in place
       and prints how many models are left):
    1. Control space filter
    2. Rank filter
    3. Accuracy filter (developmental)
    4. Condition filter

    II: Plotting:
    1. Where the selected models sit in the control space
    2. Their average performance (in each cond / mean of all conds)
    3. Some basic descriptives in the title

    The methods below read these columns from the table: code_name,
    hidden_units, p_noise, learning_rate, cond, epoch, score and rank_pc.
    """

    def __init__(self, df):
        """Store a private copy of ``df`` and tag every row with its cell code.

        df: long-format results table (one row per model, condition and epoch)
        """
        # Copy so that adding `cell_code` never mutates the caller's table.
        self.df = df.copy()
        self.df["cell_code"] = self._make_cell_code(self.df)
        # Created by plot_control_space(); the other plots filter on it.
        self.select_control_space = None

    @staticmethod
    def _make_cell_code(df):
        """Return the control-space cell label (h<units>_p<noise>_l<rate>) per row."""
        return (
            "h"
            + df.hidden_units.astype(str)
            + "_p"
            + df.p_noise.astype(str)
            + "_l"
            + df.learning_rate.astype(str)
        )

    def count_model(self):
        """Return the number of distinct models (code_name values) in self.df."""
        return len(self.df.code_name.unique())

    def _report_selection(self, n_pre):
        """Print how many models survived a selection step that started with n_pre."""
        n_post = self.count_model()
        print(f"Selected {n_post} models from the original {n_pre} models")

    def _filter_by_control_selection(self, chart):
        """Filter ``chart`` by the control-space selection, if one exists yet.

        Without this guard the development / performance plots could not be
        built before plot_control_space() had been called.
        """
        selection = getattr(self, "select_control_space", None)
        if selection is None:
            return chart
        return chart.transform_filter(selection)

    # Selection related functions

    def select_by_performance(self, threshold_low, threshold_hi, t_low, t_hi):
        """Keep models scoring below threshold_low at t_low and above threshold_hi at t_hi.

        Scores are averaged over conditions per model (see pivot_to_wide).
        """
        n_pre = self.count_model()
        wide = self.pivot_to_wide(self.df, t_low, t_hi)
        passed = wide.loc[(wide.t_low < threshold_low) & (wide.t_hi > threshold_hi)]

        # Full long-format rows of the selected models. drop=True keeps the
        # old row index from leaking into the data as a stray "index" column.
        self.df = (
            self.df.loc[self.df.code_name.isin(passed.code_name)]
            .sort_values(by=["code_name", "cond", "epoch"])
            .reset_index(drop=True)
        )
        self._report_selection(n_pre)

    def select_by_control(self, hidden_units=None, p_noise=None, learning_rate=None):
        """Keep models whose control parameters are in the given lists.

        A parameter left as None is not filtered on.
        """
        n_pre = self.count_model()
        if hidden_units is not None:
            self.df = self.df.loc[self.df.hidden_units.isin(hidden_units)]
        if p_noise is not None:
            self.df = self.df.loc[self.df.p_noise.isin(p_noise)]
        if learning_rate is not None:
            self.df = self.df.loc[self.df.learning_rate.isin(learning_rate)]
        self._report_selection(n_pre)

    def select_by_rankpc(self, minpc, maxpc):
        """Keep rows whose rank_pc lies in the closed interval [minpc, maxpc]."""
        n_pre = self.count_model()
        self.df = self.df.loc[(self.df.rank_pc >= minpc) & (self.df.rank_pc <= maxpc)]
        self._report_selection(n_pre)

    def select_by_cond(self, conds):
        """Keep only the rows belonging to the listed conditions."""
        n_pre = self.count_model()
        self.df = self.df.loc[self.df.cond.isin(conds)]
        self._report_selection(n_pre)

    # Descriptives related functions

    def _describe_model_means(self, column):
        """Summarise the per-model mean of ``column`` as a short text line."""
        # Selecting the column first avoids averaging string columns, which
        # raises on recent pandas versions.
        desc = self.df.groupby("code_name")[column].mean().describe()
        return f"M:{desc['mean']:.3f} SD: {desc['std']:.3f} Min: {desc['min']:.3f} Max: {desc['max']:.3f}"

    def get_rankpc_desc(self):
        """Return mean / SD / min / max of the per-model mean rank_pc."""
        return self._describe_model_means("rank_pc")

    def get_acc_desc(self):
        """Return mean / SD / min / max of the per-model mean score."""
        return self._describe_model_means("score")

    # Plotting related functions

    def pivot_to_wide(self, df, t_low, t_hi):
        """Create a table with each model's score at t_low and t_hi as columns.

        df: input datafile
        t_low: epoch used in applying threshold_low
        t_hi : epoch used in applying threshold_hi

        Scores are averaged over conditions. Raises ValueError when one of
        the two epochs has no rows in ``df``.
        """
        index_names = ["code_name", "hidden_units", "p_noise", "learning_rate"]

        subset = df.loc[df.epoch.isin([t_low, t_hi])]
        pvt = subset.pivot_table(
            index=index_names, columns="epoch", values="score",
        ).reset_index()

        missing = [e for e in (t_low, t_hi) if e not in pvt.columns]
        if missing:
            raise ValueError(f"No rows found for epoch(s): {missing}")

        # Pick the columns by epoch value rather than by position, so the
        # labels stay right when t_low > t_hi or when both are the same epoch.
        wide = pvt[index_names].copy()
        wide["t_low"] = pvt[t_low]
        wide["t_hi"] = pvt[t_hi]
        wide.columns.name = None
        return wide

    def plot_control_space(self):
        """Plot selected models at control space"""
        # Reuse the cell_code already stored in self.df: the development and
        # performance plots filter on that exact string, so rebuilding it
        # from averaged / rounded numbers could produce labels that never match.
        pdf = (
            self.df.groupby(["code_name", "cell_code"], as_index=False)[
                ["hidden_units", "p_noise", "learning_rate"]
            ]
            .mean()
            .round(3)
        )

        self.select_control_space = alt.selection(
            type="multi", on="click", empty="none", fields=["cell_code"],
        )

        control_space = (
            alt.Chart(pdf)
            .mark_rect()
            .encode(
                x="p_noise:O",
                y=alt.Y("hidden_units:O", sort="descending"),
                column=alt.Column("learning_rate:O", sort="descending"),
                color="count(code_name)",
                detail="cell_code",
                opacity=alt.condition(
                    self.select_control_space, alt.value(1), alt.value(0.2)
                ),
            )
            .add_selection(self.select_control_space)
        )
        return control_space

    def plot_mean_development(self, show_sd):
        """Plot the mean development of all selected models

        show_sd: when true, overlay a +/- 1 SD band on the mean lines
        """
        development_space_sd = (
            alt.Chart(self.df)
            .mark_errorband(extent="stdev")
            .encode(
                y=alt.Y("score:Q", scale=alt.Scale(domain=(0, 1))),
                x="epoch:Q",
                color=alt.Color("cond:N", legend=alt.Legend(orient="top")),
            )
            .properties(
                title="Developmental space: Accuracy in each condition over epoch"
            )
        )
        development_space_sd = self._filter_by_control_selection(development_space_sd)

        development_space_mean = development_space_sd.mark_line().encode(
            y="mean(score):Q"
        )

        if show_sd:
            return development_space_mean + development_space_sd
        return development_space_mean

    def plot_all_cond_mean(self, show_sd):
        """Plot the average accuracy in all conditions over epoch of all selected models

        show_sd: when true, overlay a +/- 1 SD band on the mean line
        """
        group_var = ["code_name", "hidden_units", "p_noise", "learning_rate", "epoch"]
        # numeric_only: string columns (cond, cell_code) cannot be averaged.
        pdf = self.df.groupby(group_var).mean(numeric_only=True).reset_index()

        dev_all_sd = (
            alt.Chart(pdf)
            .mark_errorband(extent="stdev")
            .encode(y=alt.Y("score:Q", scale=alt.Scale(domain=(0, 1))), x="epoch:Q",)
            .properties(
                title="Developmental space: Mean Accuracy in all conditions over epoch"
            )
        )

        dev_all_m = dev_all_sd.mark_line().encode(y="mean(score):Q")

        return (dev_all_m + dev_all_sd) if show_sd else dev_all_m

    def make_wnw(self):
        """Return a wide word (HF_INC) vs. nonword (NW_UN) accuracy table.

        One row per model and epoch with columns word_acc, nonword_acc and
        word_advantage (word_acc - nonword_acc).
        """
        variates = ["hidden_units", "p_noise", "learning_rate"]

        long_wnw = self.df.loc[
            self.df.cond.isin(["HF_INC", "NW_UN"]),
            variates + ["code_name", "cell_code", "epoch", "cond", "score"],
        ]

        df_wnw = long_wnw.pivot_table(
            index=variates + ["epoch", "code_name", "cell_code"],
            columns="cond",
            values="score",
        ).reset_index()
        df_wnw.columns.name = None
        df_wnw = df_wnw.rename(
            columns={"HF_INC": "word_acc", "NW_UN": "nonword_acc"}
        )

        df_wnw["word_advantage"] = df_wnw.word_acc - df_wnw.nonword_acc
        return df_wnw

    def plot_wnw(self, mean=True):
        """ Performance space plot

        mean: True (default) draws one line of epoch-wise means over the
              selected models; False draws one line per model.
        """
        df = self.make_wnw()
        base = alt.Chart(df).mark_line(color="black")

        if mean:
            wnw_line = base.encode(
                y=alt.Y("mean_nw:Q", scale=alt.Scale(domain=(0, 1))),
                x=alt.X("mean_w:Q", scale=alt.Scale(domain=(0, 1))),
                # Only these fields still exist after the aggregate below.
                tooltip=["epoch:Q", "mean_w:Q", "mean_nw:Q"],
                opacity=alt.value(0.5),
            )
            # Filter first: cell_code is gone once rows are aggregated by epoch.
            wnw_line = self._filter_by_control_selection(wnw_line).transform_aggregate(
                mean_nw="mean(nonword_acc)", mean_w="mean(word_acc)", groupby=["epoch"]
            )
        else:
            wnw_line = base.encode(
                y=alt.Y("nonword_acc:Q", scale=alt.Scale(domain=(0, 1))),
                x=alt.X("word_acc:Q", scale=alt.Scale(domain=(0, 1))),
                detail="code_name:N",
                tooltip=["code_name", "epoch", "word_acc", "nonword_acc"],
                opacity=alt.value(0.5),
            )
            wnw_line = self._filter_by_control_selection(wnw_line)

        # Reference line where word and nonword accuracy are equal.
        diagonal = (
            alt.Chart(pd.DataFrame({"x": [0, 1], "y": [0, 1]}))
            .mark_line(color="#D3D3D3")
            .encode(
                x=alt.X("x", axis=alt.Axis(title="word")),
                y=alt.Y("y", axis=alt.Axis(title="nonword")),
            )
        )

        return (diagonal + wnw_line).properties(
            title="Performance space: Nonword accuracy vs. Word accuracy"
        )

    def stat_header(self):
        """Return the title lines: model count, rank and accuracy descriptives."""
        n = self.count_model()

        t = [
            "Grand mean rank: " + self.get_rankpc_desc(),
            "Grand mean acc  : " + self.get_acc_desc(),
        ]

        return [f" (n={n})"] + t

    def plot(self, title=None, show_sd=True):
        """Plot all relevant stuffs

        title: optional first title line; the descriptive header is always shown
        show_sd: passed on to plot_mean_development
        """
        header = self.stat_header()
        # Previously the title list was left undefined when title was None.
        t = header if title is None else [title] + header

        all_plot = (
            self.plot_control_space()
            & (self.plot_mean_development(show_sd=show_sd) | self.plot_wnw())
        ).properties(title=t)

        return all_plot

### Load baseline data

In [None]:
# Parse the baseline simulation results. `hpar` lists the hyper-parameter
# columns that are later passed to helper.count_grid.
df_all = helper.parse_from_file("../sims/1520_sims.csv")
hpar = ["hidden_units", "cleanup_units", "p_noise", "learning_rate"]

In [None]:
# NOTE(review): `baseline` is first assigned two cells further down, so this
# cell only works after that cell has been run (out-of-order execution).
baseline.select_by_cond(["HF_INC", "NW_UN"])

In [None]:
# NOTE(review): the Select_Model class defined in this notebook has `plot`
# but no `plot2` — confirm which class version this call targets.
baseline.plot2("Baseline: Middle control space 3x3x3")

In [None]:
# Baseline group: the 3x3x3 block of control settings listed below,
# restricted to the HF_INC and NW_UN conditions.
baseline = Select_Model(df_all)
baseline.select_by_control(
    hidden_units=[100, 150, 200], p_noise=[1, 2, 3], learning_rate=[0.004, 0.006, 0.008]
)
baseline.select_by_cond(["HF_INC", "NW_UN"])
# NOTE(review): `plot2` is not defined on the Select_Model class in this notebook.
baseline.plot2("Baseline: Middle control space 3x3x3")

In [None]:
# Rebuild `baseline` from the full table with a much narrower control selection
# (this replaces the 3x3x3 selection made in the previous cell).
baseline = Select_Model(df_all)
baseline.select_by_control(hidden_units=[200], p_noise=[2, 3], learning_rate=[0.004])
# baseline.plot1("Baseline: Middle control space 3x3x3")

# NOTE(review): `plot2` is not defined on the Select_Model class in this notebook.
baseline.plot2("tmp")

In [None]:
# NOTE(review): select_by_control (as defined in this notebook) filters
# `baseline` in place and has no return statement, so `x` is None here.
x = baseline.select_by_control(hidden_units=[200], p_noise=[3], learning_rate=[0.004])

In [None]:
# NOTE(review): `x` is whatever select_by_control returned in the previous
# cell; with the class defined in this notebook that is None, so this call fails.
x.plot1("a")

In [None]:
# Re-plot `baseline` in its current (filtered) state.
# NOTE(review): `plot2` is not defined on the Select_Model class in this notebook.
baseline.plot2("tmp")

### Load Part3 Datafile (n=1225)

In [None]:
df = helper.parse_from_file("../sims/part3_1750.csv")

# Mask at-risk: count, per row, how many control parameters fall on the
# at-risk side of the thresholds below (fewer hidden units, more noise,
# lower learning rate).
fewer_units = (df.hidden_units < 100) * 1
more_noise = (df.p_noise > 3) * 1
slower_learning = (df.learning_rate < 0.004) * 1
df["risk_count"] = fewer_units + more_noise + slower_learning

# Keep only rows with at least one at-risk parameter, then save the count grid.
at_risk = df.risk_count >= 1
df = df.loc[at_risk]
helper.count_grid(df, hpar).save("count_models_masked.html")

### RD object

### Instantiate RD analysis class

In [None]:
# Build the RD analysis object from the masked Part3 table, with the baseline
# models as reference, restricted to the four word conditions.
# NOTE(review): `Select_RD` is not defined in the visible part of this
# notebook (the class further down is named Select_Subtype) — confirm the name.
rd_word = Select_RD(
    df, baseline.df, include_conds=["HF_INC", "LF_INC", "HF_CON", "LF_CON"]
)

In [None]:
# Save the combined overview figure, with `baseline` passed as the reference group.
rd_word.plot_bundle(baseline).save("bundle.html")

In [None]:
# Rows with 200 or 250 hidden units, noise level 4 and learning rate 0.004.
has_many_units = df.hidden_units.isin([200, 250])
has_noise_4 = df.p_noise == 4
has_rate_004 = df.learning_rate == 0.004
compensator_df = df.loc[has_many_units & has_noise_4 & has_rate_004]

In [None]:
# Wrap the compensator rows in the helper module's Select_Model
# (not the class of the same name defined in this notebook).
compensator = helper.Select_Model(compensator_df)

In [None]:
# Side by side (`|`): mean development of the compensator models vs. the baseline.
compensator.plot_mean_development(show_sd=True) | baseline.plot_mean_development(
    show_sd=True
)

In [None]:
# Layer (`+`) the word vs. nonword plots of the compensator and baseline models.
# NOTE(review): plot_wnw of the Select_Model class defined in this notebook
# takes no `mean` argument — confirm `baseline` is a helper.Select_Model here.
compensator.plot_wnw(mean=True) + baseline.plot_wnw(mean=True)

### Static heatmaps

In [None]:
# Save one static heatmap per variable; order and file names are fixed.
heatmap_targets = [
    ("score", "score.html"),
    ("pc", "pc.html"),
    ("z_deviance", "z.html"),
]
for heatmap_var, heatmap_file in heatmap_targets:
    rd_word.plot_heatmap(heatmap_var).save(heatmap_file)

In [None]:
# Wrap the masked Part3 table in the helper module's Select_Model.
rd = helper.Select_Model(df)

In [None]:
# Inspect which columns the RD object's table carries.
rd_word.df.columns

In [None]:
def reduce_epoch_resolution(df):
    """Return the rows of ``df`` whose epoch is one of the retained checkpoints.

    The checkpoint list is fixed; rows at any other epoch are dropped and the
    original row index is kept.
    """
    kept_epochs = [0.01, 0.02, 0.03, 0.05, 0.07, 0.09, 0.2, 0.4, 0.6, 0.8, 1.0]
    is_kept = df["epoch"].isin(kept_epochs)
    return df[is_kept]


# Average the numeric columns within every (epoch, control setting, condition)
# cell. numeric_only=True is needed because the table also carries string
# columns (e.g. code_name), which recent pandas versions refuse to average.
# groupby does not modify its input, so no defensive copy is required.
rd_mean_df = (
    rd_word.df.groupby(["epoch", "hidden_units", "p_noise", "learning_rate", "cond"])
    .mean(numeric_only=True)
    .reset_index()
)

# Keep only the selected epoch checkpoints so the faceted heatmaps stay small.
rd_mean_df = reduce_epoch_resolution(df=rd_mean_df)

### Plot all cond raw score

In [None]:
def save_raw_heat(var):
    """Save a faceted heatmap of the mean raw score for one condition.

    var: condition label; rows of the module-level ``rd_mean_df`` with this
         ``cond`` value are plotted and written to raw_score_<var>.html
    """
    p = (
        alt.Chart(rd_mean_df)
        .mark_rect()
        .encode(
            x="p_noise:O",
            y=alt.Y("hidden_units:O", sort="descending"),
            row=alt.Row("learning_rate:O", sort="descending"),
            column="epoch:O",
            color=alt.Color(
                "score", scale=alt.Scale(domain=(0, 1), scheme="redyellowgreen"),
            ),
            tooltip=["score"],
        )
        .transform_filter(datum.cond == var)
        # A chart has no callable `.title()`; the title is set via properties.
        .properties(title=f"Raw score: {var}")
    )

    p.save(f"raw_score_{var}.html")


# Iterate over the conditions present in the plotted table itself, so every
# saved chart has data behind it.
for v in rd_mean_df.cond.unique():
    save_raw_heat(v)


### Interactive grouping plots

In [None]:
# Save the interactive grouping heatmap once per cutoff definition.
for grouping_version in ("z", "pc"):
    grouping_chart = rd_word.plot_interactive_group_heatmap(version=grouping_version)
    grouping_chart.save(f"grouping_{grouping_version}.html")

### Old interactive plot

In [None]:
variates = ["hidden_units", "p_noise", "learning_rate"]

# Keep only the word (HF_INC) and nonword (NW_UN) conditions.
is_wnw_cond = df.cond.isin(["HF_INC", "NW_UN"])
wnw_columns = variates + ["code_name", "epoch", "cond", "score"]
df_wnw = df.loc[is_wnw_cond, wnw_columns]

# One row per model and epoch, one score column per condition.
df_wnw = df_wnw.pivot_table(
    index=variates + ["epoch", "code_name"], columns="cond"
).reset_index()

# Flatten the two-level column labels into single names such as "scoreHF_INC".
df_wnw.columns = ["".join(pair).strip() for pair in df_wnw.columns.values]
df_wnw = df_wnw.rename(
    columns={"scoreHF_INC": "word_acc", "scoreNW_UN": "nonword_acc"}
)

df_wnw["word_advantage"] = df_wnw.word_acc - df_wnw.nonword_acc
df_wnw

In [None]:
# Click-to-select models in the control space; shift-click adds more models.
select_control_space = alt.selection(
    type="multi",
    on="click",
    empty="none",
    fields=["code_name"],
    init=[{"code_name": "n0_h100_l0.01"}],
)

# Control space: one cell per setting, coloured by word accuracy at the last epoch
df_overview = df_wnw.loc[df_wnw.epoch == df_wnw.epoch.max()]

control_space = (
    alt.Chart(df_overview)
    .mark_rect()
    .encode(
        x="p_noise:O",
        y=alt.Y("hidden_units:O", sort="descending"),
        column=alt.Column("learning_rate:O", sort="descending"),
        color=alt.Color(
            "word_acc", scale=alt.Scale(scheme="redyellowgreen", domain=(0, 1))
        ),
        opacity=alt.condition(select_control_space, alt.value(1), alt.value(0.3)),
        tooltip=["code_name", "word_acc", "nonword_acc", "word_advantage"],
    )
    .add_selection(select_control_space)
    .properties(title="Select a control parameter setting:")
)

# Development space
df.sort_values(by=["code_name", "cond"], inplace=True)

development_space = (
    alt.Chart(df)
    .mark_line()
    .encode(
        y=alt.Y("score:Q", scale=alt.Scale(domain=(0, 1))),
        x="epoch:Q",
        color="cond:N",
        # One line per model: without this, several selected models that
        # share a condition are joined into a single zigzag line.
        detail="code_name:N",
        tooltip=["code_name", "epoch", "score"],
    )
    .transform_filter(select_control_space)
    .properties(title="Developmental space: Accuracy in each condition over epoch")
)

# Performance space
wnw_line = (
    alt.Chart(df_wnw)
    .mark_line(color="black")
    .encode(
        y=alt.Y("nonword_acc:Q", scale=alt.Scale(domain=(0, 1))),
        x=alt.X("word_acc:Q", scale=alt.Scale(domain=(0, 1))),
        # Keep each selected model's trajectory as its own line.
        detail="code_name:N",
        tooltip=["code_name", "epoch", "word_acc", "nonword_acc"],
    )
    .transform_filter(select_control_space)
)

# Reference line where word and nonword accuracy are equal.
diagonal = (
    alt.Chart(pd.DataFrame({"x": [0, 1], "y": [0, 1]}))
    .mark_line(color="#D3D3D3")
    .encode(
        x=alt.X("x", axis=alt.Axis(title="word")),
        y=alt.Y("y", axis=alt.Axis(title="nonword")),
    )
)

performance_space = (diagonal + wnw_line).properties(
    title="Performance space: Nonword accuracy vs. Word accuracy"
)

# Placeholder chart; it is not part of the dashboard assembled below.
dev_heat = alt.Chart()


# Merge dashboard
dashboard = control_space & (development_space | performance_space)
dashboard.save("dashboard_with_baseline.html")

### What are the underlying causes that give rise to different subtypes?

In [None]:
# Inspect which columns the masked Part3 table carries.
df.columns

In [None]:
class Select_Subtype(helper.Select_Model):
    """Compare a set of models against a typically-developing (TD) reference.

    On construction the class derives, from ``df`` and ``td_df``:
      td_stat : per-epoch mean / SD of the TD score (see get_stat)
      cadf    : condition-averaged table at reduced epoch resolution
      zdf     : cadf with z-score deviance and group_10 ... group_30 flags
      pcdf    : cadf with percentage of baseline and pc_group_50 ... _90 flags
      mzdf / mpcdf : long (melted) versions of the flag columns

    zdf, pcdf and cadf are the same DataFrame object; the make_* methods add
    their columns in place.

    NOTE(review): the parent initialiser is not called here, so anything
    helper.Select_Model.__init__ sets up beyond ``self.df`` is not available
    — confirm the inherited plotting methods do not depend on it.
    """

    # Columns kept as identifiers when melting the group flags.
    _ID_VARS = [
        "code_name",
        "epoch",
        "hidden_units",
        "cleanup_units",
        "p_noise",
        "learning_rate",
    ]
    # Percentage cutoffs 50, 55, ..., 90 (columns pc_group_<pct>).
    _PC_CUTOFFS = range(50, 95, 5)
    # z cutoffs in tenths: 10 ... 30 stand for -1.0 ... -3.0 (columns group_<tenths>).
    _Z_CUTOFFS = range(10, 31)

    def __init__(self, df, td_df, include_conds=None):
        """
        df: table of the models to classify
        td_df: table of the TD reference models
        include_conds: conditions averaged in the derived tables;
                       default = HF_CON, HF_INC, LF_CON, LF_INC
        """
        # None instead of a list literal: a list default would be one object
        # shared by every instance created without the argument.
        if include_conds is None:
            include_conds = ["HF_CON", "HF_INC", "LF_CON", "LF_INC"]
        self.include_conds = list(include_conds)
        self.df = df
        self.td_df = td_df
        self.td_stat = self.get_stat()
        self._refresh_derived()

    def _refresh_derived(self):
        """(Re)build every table derived from self.df, in dependency order."""
        self.cadf = self.make_cadf()
        self.zdf = self.make_zdf(self.cadf)
        self.pcdf = self.make_pcdf(self.cadf)
        self.mzdf = self.melt_zdf(self.zdf)
        self.mpcdf = self.melt_pcdf(self.pcdf)

    def get_stat(self):
        """Baseline statistics

        Return the TD score mean and SD at each epoch as a dict of columns
        ("epoch", "mean", "std"), each keyed by the row position of the epoch
        in ascending order. Scores are first averaged per model, epoch and
        condition.
        """
        per_model = (
            self.td_df.groupby(["code_name", "epoch", "cond"]).score.mean().reset_index()
        )
        # Aggregating the score column only keeps string columns out of the
        # mean / std computation, which recent pandas versions reject.
        return (
            per_model.groupby("epoch").score.agg(["mean", "std"]).reset_index().to_dict()
        )

    def make_cadf(self):
        """Make condition average df (aggregate cond) with reduced epoch resolution"""
        in_conds = self.df.cond.isin(self.include_conds)
        cadf = (
            self.df.loc[in_conds]
            .groupby(["code_name", "epoch"])
            .mean(numeric_only=True)
            .reset_index()
        )
        cadf["learning_rate"] = cadf.learning_rate.round(4)

        # Index used to look the epoch up in td_stat: epochs up to 0.1 map to
        # epoch * 100, later epochs to epoch * 10 + 9.
        # NOTE(review): this only lines up with td_stat if the TD table has
        # exactly the matching set of epochs — confirm against the data.
        raw_idx = np.where(cadf.epoch <= 0.1, cadf.epoch * 100, cadf.epoch * 10 + 9)
        # Round before casting so float noise just below an integer is not
        # truncated to the previous index.
        cadf["epoch_idx"] = np.rint(raw_idx).astype(int)

        # Copy: the make_* methods add columns, which must not target a slice.
        return self.reduce_epoch_resolution(cadf).copy()

    def make_pcdf(self, cadf):
        """ Make percentage based df (adds the columns to ``cadf`` in place)"""
        cadf["pc"] = cadf.apply(self.calcuate_percetage_of_baseline, axis=1)

        # Different cutoff of RDs by percentage (0 = RD, 1 = TD)
        for pct in self._PC_CUTOFFS:
            cadf[f"pc_group_{pct}"] = 1 * (cadf.pc > pct / 100)

        return cadf

    def melt_pcdf(self, df):
        """Melt the pc_group_* flags to long format with a numeric ``cutoff`` column."""
        mdf = df.melt(
            id_vars=self._ID_VARS,
            value_vars=[f"pc_group_{pct}" for pct in self._PC_CUTOFFS],
        )

        # The last two characters of the column name are the percentage.
        mdf["cutoff"] = mdf.variable.str[-2:].astype(float)

        return mdf

    def make_zdf(self, cadf):
        """Make z-score based df (adds the columns to ``cadf`` in place)"""

        cadf["z_deviance"] = cadf.apply(self.calcuate_z_deviance, axis=1)

        # Different cutoff of RDs (0 = RD, 1 = TD)
        for tenths in self._Z_CUTOFFS:
            cadf[f"group_{tenths}"] = 1 * (cadf.z_deviance > -tenths / 10)

        return cadf

    def melt_zdf(self, df):
        """Melt the group_* flags to long format with a numeric ``cutoff`` column."""
        mdf = df.melt(
            id_vars=self._ID_VARS,
            value_vars=[f"group_{tenths}" for tenths in self._Z_CUTOFFS],
        )

        # The last two characters are the cutoff in tenths (e.g. "15" -> 1.5).
        mdf["cutoff"] = mdf.variable.str[-2:].astype(float) / 10

        return mdf

    def plot_heatmap(self, var):
        """Heatmap of the mean of ``var`` over control settings and epochs

        var: "z_deviance", "score" or "pc"; any other value raises ValueError
        """
        settings = {
            "z_deviance": ((-5, 5), self.zdf),
            "score": ((0, 1), self.zdf),
            "pc": ((0, 1), self.pcdf),
        }
        if var not in settings:
            raise ValueError(f"var must be one of {sorted(settings)}, got {var!r}")
        domain, df = settings[var]

        mean_var = f"mean({var})"

        hm = (
            alt.Chart(df)
            .mark_rect()
            .encode(
                x="p_noise:O",
                y=alt.Y("hidden_units:O", sort="descending"),
                row=alt.Row("learning_rate:O", sort="descending"),
                column="epoch:O",
                color=alt.Color(
                    mean_var, scale=alt.Scale(domain=domain, scheme="redyellowgreen"),
                ),
                tooltip=["mean(score)", mean_var],
            )
        )

        return hm

    def reduce_epoch_resolution(self, df):
        """Return the rows of ``df`` at the fixed list of retained epochs."""
        sel_epoch = [0.01, 0.02, 0.03, 0.05, 0.07, 0.09, 0.2, 0.4, 0.6, 0.8, 1.0]
        return df.loc[df.epoch.isin(sel_epoch)]

    def calcuate_z_deviance(self, row):
        """Calculate z score relative to TD at each epoch

        row: a cadf row providing "score" and "epoch_idx"
        """
        m = self.td_stat["mean"][row["epoch_idx"]]
        sd = self.td_stat["std"][row["epoch_idx"]]

        # Avoid zero division
        if sd == 0:
            sd = 1e-6

        return (row["score"] - m) / sd

    def calcuate_percetage_of_baseline(self, row):
        """Calculate the score as a proportion of the TD mean at each epoch

        row: a cadf row providing "score" and "epoch_idx"
        """
        m = self.td_stat["mean"][row["epoch_idx"]]

        return row["score"] / m

    def get_acc_cut(self, epoch, xsd, cond=None):
        """Get accuracy cut off value with reference to xsd below mean of TD

        epoch: at what epoch to classify RD [list]
        xsd: how many sd below mean of TD
        cond: include what condition, default = all conditions (no filtering)
        """
        keep = self.td_df.epoch.isin(epoch)
        if cond is not None:
            keep = keep & self.td_df.cond.isin(cond)

        # Mean score per TD model, then mean and SD across models.
        stat = self.td_df.loc[keep].groupby("code_name").score.mean().agg(["mean", "std"])
        return stat["mean"] - xsd * stat["std"]

    def select_by_relative_sd(self, epoch, xsd, cond=None):
        """Select the models that are at least
        <xsd> SD below mean of TD at <epoch>

        epoch: list of epochs, cond: optional list of conditions (as in get_acc_cut)
        """
        keep = self.df.epoch.isin(epoch)
        if cond is not None:
            keep = keep & self.df.cond.isin(cond)

        model_means = self.df.loc[keep].groupby("code_name").score.mean()
        below_cut = model_means.index[model_means < self.get_acc_cut(epoch, xsd, cond)]
        self.df = self.df.loc[self.df.code_name.isin(below_cut)]

        # Rebuild every derived table so it matches the narrowed self.df
        # (this step used to call a method that does not exist).
        self._refresh_derived()

    def plot_bundle(self, baseline_model):
        """ Plot all with baseline model as reference group

        baseline_model: object providing plot_mean_development,
                        plot_all_cond_mean and plot_wnw for the reference group
        """

        dev_cond = (
            self.plot_mean_development(show_sd=False)
            + baseline_model.plot_mean_development(show_sd=True)
        ).properties(title="Each condition")

        dev_mean = (
            self.plot_all_cond_mean(show_sd=False)
            + baseline_model.plot_all_cond_mean(show_sd=True)
        ).properties(title="Mean in all conditions")

        wnw_mean = (
            self.plot_wnw(mean=True) + baseline_model.plot_wnw(mean=True)
        ).properties(title="Word vs. NW")

        return (
            self.plot_control_space() & (dev_cond | dev_mean | wnw_mean)
        ).properties(title=self.stat_header())

    def plot_interactive_group_heatmap(self, version="pc"):
        """ Plot interactive grouping heatmap

        version: "pc" percentage definition
                 "z" z-score definition
        Any other value raises ValueError.
        """
        if version == "pc":
            use_df = self.mpcdf
            slider = alt.binding_range(
                min=50.0, max=90.0, step=5.0, name="percentage cutoff:"
            )
            start_cutoff = 80.0
        elif version == "z":
            use_df = self.mzdf
            slider = alt.binding_range(min=1.0, max=3.0, step=0.1, name="z cutoff:")
            start_cutoff = 2.0
        else:
            raise ValueError(f'version must be "pc" or "z", got {version!r}')

        selector = alt.selection_single(
            name="SelectorName",
            fields=["cutoff"],
            bind=slider,
            init={"cutoff": start_cutoff},
        )

        # Share of models flagged as TD (value = 1) in every cell and cutoff;
        # numeric_only keeps the string columns out of the mean.
        df = (
            use_df.groupby(
                ["hidden_units", "p_noise", "learning_rate", "epoch", "cutoff"]
            )
            .mean(numeric_only=True)
            .reset_index()
        )

        interactive_group_heatmap = (
            alt.Chart(df)
            .mark_rect()
            .encode(
                x="p_noise:O",
                y=alt.Y("hidden_units:O", sort="descending"),
                row=alt.Row("learning_rate:O", sort="descending"),
                column="epoch:O",
                color=alt.Color(
                    "value", scale=alt.Scale(domain=(0, 1), scheme="redyellowgreen"),
                ),
            )
            .add_selection(selector)
            .transform_filter(selector)
        )

        return interactive_group_heatmap

In [None]:
# Mean baseline score per (epoch, cond) group, shown with a running row number.
# The score column is selected before averaging so string columns such as
# code_name are never averaged (recent pandas versions raise on that).
baseline.df.groupby(["epoch", "cond"]).score.mean().reset_index(drop=True).reset_index()

In [None]:
baseline.df

# NOTE: the lines below are a dangling fragment of an unfinished method chain.
# As live code they were a syntax error that stopped the whole cell from
# running, so they are kept only as a comment for reference.
# .groupby(["epoch"])
#
#             .score.reset_index()
#         ).to_dict()