# Individual differences (part 3)

In [None]:
%load_ext lab_black
import pandas as pd
import numpy as np
import altair as alt
import helper
from altair.expr import datum

# Switch off Altair's max-rows check so the charts below can embed the full
# simulation data frames (NOTE(review): this makes the saved HTML larger).
alt.data_transformers.disable_max_rows()

In [None]:
class Select_Model:
    """Select simulated models and plot them as an interactive Altair dashboard.

    (The original note describes this as a "helper class for defining TD".)

    I. Selection
        Filter the models by control-space parameters (``hidden_units``,
        ``p_noise``, ``learning_rate``) and by condition (``cond``).

    II. Plotting
        1. Where the selected models sit in the control space.
        2. Their average performance (per condition, and word vs. nonword).
        3. Basic descriptives, shown in the chart title.

    The methods read these columns from the data frame: ``code_name``,
    ``hidden_units``, ``p_noise``, ``learning_rate``, ``cond``, ``epoch``,
    ``score`` and ``rank_pc``.
    """

    # Parameters that together identify one cell of the control space.
    _CONTROL_VARS = ("hidden_units", "p_noise", "learning_rate")

    def __init__(self, df):
        """Wrap ``df`` and label every row with its control-space cell.

        Note: the ``cell_code`` column is added to ``df`` itself (in place),
        so the caller's frame gains that column as well.
        """
        self.df = df
        self.df["cell_code"] = self._cell_code(self.df)

    # Internal helpers

    @staticmethod
    def _cell_code(frame):
        """Return the ``h<hidden>_p<noise>_l<rate>`` label for each row of ``frame``."""
        return (
            "h"
            + frame.hidden_units.astype(str)
            + "_p"
            + frame.p_noise.astype(str)
            + "_l"
            + frame.learning_rate.astype(str)
        )

    @staticmethod
    def _as_list(values):
        """Wrap a scalar (including a single string) so it can be fed to ``isin``."""
        return values if pd.api.types.is_list_like(values) else [values]

    @staticmethod
    def _has_baseline(baseline):
        """Tell whether a baseline chart was supplied (``False``/``None`` mean no)."""
        return baseline is not None and baseline is not False

    def _report_selection(self, n_pre):
        """Print how many models remain compared with ``n_pre`` before filtering."""
        n_post = self.count_model()
        print(f"Selected {n_post} models from the original {n_pre} models")

    def _describe_model_means(self, column):
        """Format descriptives of the per-model mean of ``column``."""
        # Aggregating only the requested column avoids averaging the string
        # columns (cond, cell_code), which recent pandas refuses to do.
        desc = self.df.groupby("code_name")[column].mean().describe()
        return f"M:{desc['mean']:.3f} SD: {desc['std']:.3f} Min: {desc['min']:.3f} Max: {desc['max']:.3f}"

    def _filter_by_selection(self, chart):
        """Restrict ``chart`` to the cells clicked in the control-space plot.

        The selection only exists once ``plot_control_space`` has been called;
        before that the chart is returned unfiltered instead of failing.
        """
        selection = getattr(self, "select_control_space", None)
        if selection is None:
            return chart
        return chart.transform_filter(selection)

    # Selection related functions

    def count_model(self):
        """Return the number of distinct ``code_name`` values currently selected."""
        return len(self.df.code_name.unique())

    def select_by_control(self, hidden_units=None, p_noise=None, learning_rate=None):
        """Keep only the rows whose control parameters are among the given values.

        Each argument is a list-like of accepted values (a single scalar is
        accepted too); ``None`` leaves that parameter unfiltered.  The filter
        is applied to ``self.df`` and a summary line is printed.
        """
        n_pre = self.count_model()
        requested = {
            "hidden_units": hidden_units,
            "p_noise": p_noise,
            "learning_rate": learning_rate,
        }
        for column, values in requested.items():
            if values is not None:
                self.df = self.df.loc[self.df[column].isin(self._as_list(values))]
        self._report_selection(n_pre)

    def select_by_cond(self, conds):
        """Keep only the rows whose ``cond`` is in ``conds`` (list-like or one name)."""
        n_pre = self.count_model()
        self.df = self.df.loc[self.df.cond.isin(self._as_list(conds))]
        self._report_selection(n_pre)

    # Descriptives related functions

    def get_rankpc_desc(self):
        """Return mean/SD/min/max of the per-model mean ``rank_pc`` as text."""
        return self._describe_model_means("rank_pc")

    def get_acc_desc(self):
        """Return mean/SD/min/max of the per-model mean ``score`` as text."""
        return self._describe_model_means("score")

    def stat_header(self):
        """Return the title lines: model count, rank and accuracy descriptives."""
        n = self.count_model()

        t = [
            "Grand mean rank: " + self.get_rankpc_desc(),
            "Grand mean acc  : " + self.get_acc_desc(),
        ]

        return [f" (n={n})"] + t

    # Plotting related functions

    def plot_control_space(self):
        """Plot selected models at control space.

        Draws one rectangle per control-space cell, coloured by the number of
        models in it.  Clicking cells drives ``self.select_control_space``,
        which the other plots use as a filter (nothing is selected initially).
        """
        # numeric_only keeps the per-model averages working on pandas versions
        # that no longer drop string columns silently.
        pdf = (
            self.df.groupby("code_name").mean(numeric_only=True).round(3).reset_index()
        )
        # Rebuild the label from the rounded per-model values.
        pdf["cell_code"] = self._cell_code(pdf)

        self.select_control_space = alt.selection(
            type="multi", on="click", empty="none", fields=["cell_code"],
        )

        control_space = (
            alt.Chart(pdf)
            .mark_rect()
            .encode(
                x="p_noise:O",
                y=alt.Y("hidden_units:O", sort="descending"),
                column=alt.Column("learning_rate:O", sort="descending"),
                color="count(code_name)",
                detail="cell_code",
                opacity=alt.condition(
                    self.select_control_space, alt.value(1), alt.value(0.2)
                ),
            )
            .add_selection(self.select_control_space)
        )
        return control_space

    def plot_mean_development(self, show_sd, baseline=False):
        """Plot the mean development of all selected models.

        Args:
            show_sd: when truthy, layer a +/-1 SD error band under the mean line.
            baseline: optional chart layered on top; ``False`` or ``None`` for none.
        """
        development_space_sd = self._filter_by_selection(
            alt.Chart(self.df)
            .mark_errorband(extent="stdev")
            .encode(
                y=alt.Y("score:Q", scale=alt.Scale(domain=(0, 1))),
                x="epoch:Q",
                color=alt.Color("cond:N", legend=alt.Legend(orient="top")),
            )
            .properties(
                title="Developmental space: Accuracy in each condition over epoch"
            )
        )

        # Same data, filter and colours as the band; only mark and y differ.
        development_space_mean = development_space_sd.mark_line().encode(
            y="mean(score):Q"
        )

        this_plot = (
            (development_space_mean + development_space_sd)
            if show_sd
            else development_space_mean
        )

        if self._has_baseline(baseline):
            this_plot = this_plot + baseline
        return this_plot

    def make_wnw(self):
        """Return word vs. nonword accuracy per model and epoch.

        Pivots the ``HF_INC`` and ``NW_UN`` conditions into the columns
        ``word_acc`` and ``nonword_acc`` (mean ``score`` per control cell,
        epoch and model) and adds ``word_advantage`` = word - nonword.
        """
        keys = list(self._CONTROL_VARS) + ["epoch", "code_name", "cell_code"]
        is_wnw = self.df.cond.isin(["HF_INC", "NW_UN"])

        df_wnw = (
            self.df.loc[is_wnw, keys + ["cond", "score"]]
            .pivot_table(index=keys, columns="cond", values="score")
            .rename_axis(columns=None)  # drop the leftover "cond" axis label
            .reset_index()
            .rename(columns={"HF_INC": "word_acc", "NW_UN": "nonword_acc"})
        )

        df_wnw["word_advantage"] = df_wnw.word_acc - df_wnw.nonword_acc
        return df_wnw

    def plot_wnw(self, baseline=False):
        """Performance space plot: mean nonword vs. mean word accuracy per epoch.

        Args:
            baseline: optional chart layered on top; ``False`` or ``None`` for none.
        """
        df = self.make_wnw()

        wnw_line = (
            self._filter_by_selection(
                alt.Chart(df)
                .mark_line(color="black")
                .encode(
                    y=alt.Y("mean_nw:Q", scale=alt.Scale(domain=(0, 1))),
                    x=alt.X("mean_w:Q", scale=alt.Scale(domain=(0, 1))),
                    tooltip=["epoch", "mean_w:Q", "mean_nw:Q"],
                )
            )
            # Aggregate after filtering so only the clicked cells are averaged.
            .transform_aggregate(
                mean_nw="mean(nonword_acc)", mean_w="mean(word_acc)", groupby=["epoch"]
            )
        )

        # Label every point of the trajectory with its epoch.
        text = wnw_line.mark_text(align="left", dx=1, size=16).encode(text="epoch")
        wnw_line = wnw_line + text

        if self._has_baseline(baseline):
            wnw_line = wnw_line + baseline

        # Identity line: word accuracy == nonword accuracy.
        diagonal = (
            alt.Chart(pd.DataFrame({"x": [0, 1], "y": [0, 1]}))
            .mark_line(color="#D3D3D3")
            .encode(
                x=alt.X("x", axis=alt.Axis(title="word")),
                y=alt.Y("y", axis=alt.Axis(title="nonword")),
            )
        )

        return (diagonal + wnw_line).properties(
            title="Performance space: Nonword accuracy vs. Word accuracy"
        )

    def plot(self, title=None, show_sd=True, base_dev=None, base_wnw=None):
        """Plot all relevant stuffs.

        Stacks the control-space plot above the developmental and performance
        plots, with the descriptive header as the chart title.

        Args:
            title: optional first title line; the statistics header is always shown.
            show_sd: forwarded to ``plot_mean_development``.
            base_dev: baseline chart for the developmental plot.
            base_wnw: baseline chart for the performance plot.

        When a baseline argument is omitted, a module-level variable of the
        same name is used if one exists (the behaviour existing notebook cells
        rely on); otherwise that baseline is left out.
        """
        if base_dev is None:
            base_dev = globals().get("base_dev", False)
        if base_wnw is None:
            base_wnw = globals().get("base_wnw", False)

        header = self.stat_header()
        t = header if title is None else [title] + header

        all_plot = (
            self.plot_control_space()
            & (
                self.plot_mean_development(show_sd=show_sd, baseline=base_dev)
                | self.plot_wnw(baseline=base_wnw)
            )
        ).properties(title=t)

        return all_plot

### Load baseline data

In [None]:
# Read the 1520 simulation results through the project helper.
df_1520 = helper.parse_from_file("../sims/1520_sims.csv")
# Hyper-parameter column names (not referenced again in the cells shown here).
hpar = ["hidden_units", "cleanup_units", "p_noise", "learning_rate"]
baseline = Select_Model(df_1520)
# Restrict the baseline to a 3 x 3 x 3 grid of control parameters.
baseline.select_by_control(
    hidden_units=[100, 150, 200], p_noise=[1, 2, 3], learning_rate=[0.004, 0.006, 0.008]
)
# Keep only the two conditions that make_wnw() pivots into word / nonword accuracy.
baseline.select_by_cond(["HF_INC", "NW_UN"])

### Baseline development

In [None]:
# Baseline developmental curve: dashed, semi-transparent line of the mean score
# per epoch, one line per condition.
line = (
    alt.Chart(baseline.df)
    .mark_line(strokeDash=[10, 10])
    .encode(
        x="epoch:Q",
        y="mean(score)",
        color=alt.Color("cond:N", legend=alt.Legend(orient="top")),
        opacity=alt.value(0.3),
    )
)

# Only the mean line is kept as the baseline layer.
base_dev = line

### Baseline performance

In [None]:
# Word vs. nonword accuracy of the baseline models, one row per model and epoch.
df = baseline.make_wnw()

# Dashed trajectory of the epoch-wise mean (word accuracy on x, nonword on y).
base_wnw = (
    alt.Chart(df)
    .transform_aggregate(
        mean_nw="mean(nonword_acc)", mean_w="mean(word_acc)", groupby=["epoch"]
    )
    .mark_line(strokeDash=[10, 10], color="black")
    .encode(
        x=alt.X("mean_w:Q", scale=alt.Scale(domain=(0, 1))),
        y=alt.Y("mean_nw:Q", scale=alt.Scale(domain=(0, 1))),
        tooltip=["epoch", "mean_nw:Q", "mean_w:Q"],
        opacity=alt.value(0.5),
    )
)

# Epoch labels placed next to each point of the trajectory.
words = base_wnw.mark_text(size=16, dx=1, align="left").encode(text="epoch")

base_wnw = base_wnw + words

### Load Part 3 data file

In [None]:
df = helper.parse_from_file("../sims/part3_1750.csv")

# At-risk mask: count, per row, how many control parameters fall outside the
# baseline grid (fewer than 100 hidden units, p_noise above 3, learning rate
# below 0.004) and keep the rows with at least one such parameter.
df["risk_count"] = sum(
    flag * 1
    for flag in (
        df.hidden_units < 100,
        df.p_noise > 3,
        df.learning_rate < 0.004,
    )
)

df = df.loc[df.risk_count.ge(1)]

atrisk = Select_Model(df)
atrisk.select_by_cond(["HF_INC", "NW_UN"])

In [None]:
# Build the combined at-risk dashboard (mean lines only, no SD band) and write
# it to an HTML file; plot() picks up the base_dev / base_wnw charts defined above.
atrisk.plot("at risk", show_sd=False).save("new_interactive.html")