# Examine the effects of stimulus properties
On PHO:
- F + OP + IMG + OP x F + OP x IMG + F x IMG 

On SEM:
- IMG x F

Steps:
1. Get output at tick 12
2. Run lm on each rng_seed
    - Logistic regression for accuracy
    - Linear regression for SSE
3. Extract all betas
4. Average the betas over rng_seed
5. Plot developmental and performance space
    - add zero horizontal line
    - add epoch info
    - add SEM accuracy to the PHO output plot, and vice versa
6. Make an interactive heatmap if time permits

In [None]:
# Utilities
%load_ext google.cloud.bigquery
import json
import os
import meta
from tqdm import tqdm

# Tidy and visualize
import pandas as pd
import numpy as np
import altair as alt

# Statistics
from scipy.stats.mstats import zscore
import statsmodels.formula.api as smf
import statsmodels.api as sm

# Variance

In [None]:
%%bigquery df_pho
-- Mean accuracy of the 'pho' output per model (code_name) and epoch,
-- restricted to timetick 12. The result is stored in `df_pho`.
SELECT
  code_name, epoch, AVG(acc) as mean_acc
FROM
  `majestic-camp-303620.station_3.train`
WHERE
  timetick = 12
  AND output_name = 'pho'
GROUP BY
  code_name, epoch

In [None]:
%%bigquery df_sem
-- Mean accuracy of the 'sem' output per model (code_name) and epoch,
-- restricted to timetick 12. The result is stored in `df_sem`.
SELECT
  code_name, epoch, AVG(acc) as mean_acc
FROM
  `majestic-camp-303620.station_3.train`
WHERE
  timetick = 12
  AND output_name = 'sem'
GROUP BY
  code_name, epoch

In [None]:
# Tag each frame with the output layer it describes, then stack them into one
# long frame so both outputs can be plotted together.
df_pho['y'], df_sem['y'] = 'pho', 'sem'
df = pd.concat((df_pho, df_sem))

In [None]:
# Load the batch configuration of the station_3 models and reduce it to the
# hyperparameters that are plotted below (batch size and learning rate).
json_file = "models/station_3/batch_config.json"

with open(json_file) as f:
    batch_cfgs = json.load(f)

# One frame per config entry. The former guard
# `type(cfg["params"].values()) is not list` compared a dict view against
# `list`, which is always true, so it never filtered anything and is dropped.
all_params = [pd.DataFrame(cfg["params"]) for cfg in batch_cfgs]
cfgs = pd.concat(all_params, ignore_index=True)

# Keep one row per (code_name, batch_size, learning_rate), sorted by those keys.
# De-duplicating the key columns directly avoids averaging unrelated config
# columns, which raises TypeError in pandas >= 2.0 when any of them is
# non-numeric.
cfg_keys = ['code_name', 'batch_size', 'learning_rate']
cfgs = cfgs[cfg_keys].dropna().drop_duplicates().sort_values(cfg_keys).reset_index(drop=True)

In [None]:
# Attach batch_size and learning_rate to every accuracy row; the left join
# keeps rows whose code_name has no matching config.
df = df.merge(cfgs, on=['code_name'], how='left')

In [None]:
df

In [None]:
# Lift Altair's row limit so the full frame can be embedded in the chart.
alt.data_transformers.disable_max_rows()

# Mean accuracy over epochs: one colour per model, dash pattern per output
# layer, faceted by batch size (rows) and learning rate (columns).
acc_encoding = dict(
    x='epoch:Q',
    y='mean(mean_acc):Q',
    color="code_name:N",
    strokeDash='y:N',
    row='batch_size:O',
    column='learning_rate:O',
)
acc = alt.Chart(df).mark_line().encode(**acc_encoding)
acc.save('acc_stroke.html')


In [None]:
# Spread of accuracy across models over epochs: one colour per output layer,
# faceted by batch size (rows) and learning rate (columns).
stdev_encoding = dict(
    x='epoch:Q',
    y='stdev(mean_acc):Q',
    color='y:N',
    row='batch_size:O',
    column='learning_rate:O',
)
stdev = alt.Chart(df).mark_line().encode(**stdev_encoding)
stdev.save('stdev.html')

# Hyperparameter of interest

In [None]:
# Hyperparameter of interest: used for the model directory (models/{poi}/...),
# as the column name of the hyperparameter in the beta tables, and as the
# facet column of the beta plots.
poi = 'hidden_os'
# Name of the corresponding column in batch_config.json's params.
poi_cfg_name = f"{poi}_units"

# Get PHO beta

In [None]:
%%bigquery df
-- Item-level accuracy and SSE for every output at timetick 12; the result is
-- stored in `df`.
-- NOTE(review): the dataset name `hidden_os` is hard-coded here and presumably
-- has to match the notebook's `poi` value; confirm when changing `poi`.
SELECT
  code_name, epoch, word, acc, sse, output_name
FROM
  `majestic-camp-303620.hidden_os.train`
WHERE
  timetick = 12

In [None]:
# df.to_csv("models/hidden_op/pho_lasttick.csv")

In [None]:
# Load the batch configuration of the `poi` models and reduce it to a
# code_name -> hyperparameter-of-interest lookup table.
json_file = f"models/{poi}/batch_config.json"

with open(json_file) as f:
    batch_cfgs = json.load(f)

# One frame per config entry. The former guard
# `type(cfg["params"].values()) is not list` compared a dict view against
# `list`, which is always true, so it never filtered anything and is dropped.
all_params = [pd.DataFrame(cfg["params"]) for cfg in batch_cfgs]
cfgs = pd.concat(all_params, ignore_index=True)

# Keep one row per (code_name, hyperparameter value), sorted by those keys.
# De-duplicating the key columns directly avoids averaging unrelated config
# columns, which raises TypeError in pandas >= 2.0 when any of them is
# non-numeric.
cfg_keys = ['code_name', poi_cfg_name]
cfgs = cfgs[cfg_keys].dropna().drop_duplicates().sort_values(cfg_keys).reset_index(drop=True)

In [None]:
# Project root; the corpus and dataset files are resolved relative to it.
tf_root = os.environ.get("TF_ROOT")
if tf_root is None:
    # Fail with a clear message instead of the TypeError that
    # os.path.join(None, ...) would raise further down.
    raise EnvironmentError("TF_ROOT environment variable is not set")

# word -> unconditional surprisal (the `op` predictor)
surprisal = pd.read_csv(os.path.join(tf_root, "corpus/noam_surprisal.csv"))
word2op_dict = dict(zip(surprisal.word, surprisal["uncond.surprisal"]))

# word -> wf (log-transformed later by word2wf)
df_train = pd.read_csv(os.path.join(tf_root, "dataset/df_train.csv"))
word2wf_dict = dict(zip(df_train.word, df_train.wf))

# The first img value is the mean replacement value used in the dataset;
# words carrying that value are left out of the img lookup.
img_replacement_value = df_train.img[0]
word2img_dict = {word: img for word, img in zip(df_train.word, df_train.img) if img != img_replacement_value}


def word2op(word):
    """Return the `uncond.surprisal` value stored for `word`, or None if unavailable.

    Only lookup failures count as "no value": a word missing from
    `word2op_dict` (KeyError) or an unhashable key (TypeError). The previous
    bare `except:` also swallowed KeyboardInterrupt and SystemExit.
    """
    try:
        return word2op_dict[word]
    except (KeyError, TypeError):
        return None

def word2wf(word):
    """Return log10(wf + 1) for `word`, or None if it cannot be computed.

    None is returned when the word is missing from `word2wf_dict` (KeyError)
    or when the key or the stored value has an unusable type (TypeError).
    The previous bare `except:` also swallowed KeyboardInterrupt and
    SystemExit.
    """
    try:
        return np.log10(word2wf_dict[word] + 1)
    except (KeyError, TypeError):
        return None

def word2img(word):
    """Return the img value stored for `word`, or None if unavailable.

    Only lookup failures count as "no value": a word missing from
    `word2img_dict` (KeyError) or an unhashable key (TypeError). The previous
    bare `except:` also swallowed KeyboardInterrupt and SystemExit.
    """
    try:
        return word2img_dict[word]
    except (KeyError, TypeError):
        return None

# Words for which all three predictors (op, wf, img) are available.
selected_words = set(word2op_dict) & set(word2wf_dict) & set(word2img_dict)

In [None]:
# Keep only words that have all three predictors. `.copy()` makes the subset an
# independent frame, so the column assignments below neither trigger pandas'
# SettingWithCopyWarning nor risk writing to a temporary view.
df = df[df.word.isin(selected_words)].copy()

# Conditional SSE: SSE on correct items only (NaN where acc != 1)
df['csse'] = df.sse.where(df.acc == 1)

# Get wf, op and img for each word
df['wf'] = df.word.apply(word2wf)
df['op'] = df.word.apply(word2op)
df['img'] = df.word.apply(word2img)

# Attach the hyperparameter of interest of each model
df = df.merge(cfgs, on=['code_name'], how='left')

# checkpoint
df.to_csv(f"models/{poi}/parsed_df.csv")

In [None]:
# Average the numeric columns per hyperparameter value, epoch and output.
# numeric_only=True skips the string columns (code_name, word), which would
# otherwise raise TypeError in pandas >= 2.0.
mdf = df.groupby([poi_cfg_name, 'epoch', 'output_name']).mean(numeric_only=True).reset_index()

In [None]:
# Mean accuracy over epochs, one line per output, with one panel per value of
# the hyperparameter of interest. Left as a bare expression so the notebook
# displays the chart.
alt.Chart(mdf).mark_line().encode(
    x='epoch:Q',
    y='mean(acc):Q',
    color='output_name:N',
    column=poi_cfg_name
)

In [None]:
def get_beta(df: pd.DataFrame, output_name: str, code_name:str, epoch:int, metric:str) -> pd.DataFrame:
    """Fit one GLM for a single model, epoch and output; return its betas as one row.

    Args:
        df: item-level data with the columns `epoch`, `code_name`,
            `output_name`, `word`, `op`, `wf`, `img`, the hyperparameter column
            named by the module-level `poi_cfg_name`, and the `metric` column.
        output_name: 'pho' or 'sem'; selects the predictors of the model.
        code_name: model to select.
        epoch: epoch to select.
        metric: dependent variable. 'acc' is fitted with a binomial GLM;
            'sse' and 'csse' are z-scored and fitted with the default family.

    Returns:
        A one-row DataFrame with the fitted parameters plus `epoch`,
        `code_name`, the hyperparameter value (column named by `poi`) and
        `metric`. Returns None when no row matches the selection or when the
        model cannot be fitted.

    Raises:
        ValueError: if `metric` or `output_name` is not a supported value.
    """
    # Validate the arguments up front. Previously these checks sat inside a
    # blanket `except`, so a typo silently produced "no result".
    if metric not in ('acc', 'sse', 'csse'):
        raise ValueError(f"metric must be 'acc', 'sse' or 'csse', got {metric!r}")

    # Right-hand side of the formula, by output (no intercept in either case).
    rhs_by_output = {
        'pho': "zscore(op) * zscore(wf) + zscore(op) * zscore(img) + zscore(wf) * zscore(img) + 0",
        'sem': "zscore(wf) * zscore(img) + 0",
    }
    if output_name not in rhs_by_output:
        raise ValueError(f"output_name must be 'pho' or 'sem', got {output_name!r}")
    rhs = rhs_by_output[output_name]

    sdf = df.loc[(df.epoch == epoch) & (df.code_name == code_name) & (df.output_name == output_name)]
    if sdf.empty:
        return None
    poi_value = sdf[poi_cfg_name].unique()[0]

    # Rows with a missing DV or predictor cannot enter the model.
    sdf = sdf[['word', metric, 'op', 'wf', 'img']].dropna()

    try:
        if metric == 'acc':
            m = smf.glm(formula=f"acc ~ {rhs}", family=sm.families.Binomial(), data=sdf).fit()
        else:
            # Use the requested metric. The DV used to be hard-coded to `csse`,
            # so metric='sse' could never be fitted.
            m = smf.glm(formula=f"zscore({metric}) ~ {rhs}", data=sdf).fit()
    except Exception:
        # Best effort: a fit can fail for many data-dependent reasons (e.g. no
        # rows left, no variance, perfect separation); such cells are skipped.
        return None

    # Copy so the fitted result's own parameter Series is not modified.
    p = m.params.copy()
    p['epoch'] = epoch
    p['code_name'] = code_name
    p[poi] = poi_value
    p['metric'] = metric

    return pd.DataFrame(p).T

In [None]:

def make_beta_df(df, func, output_name:str, acc_label:str):
    """Collect the betas of every code_name, epoch and metric (acc, csse) into one long frame.

    Args:
        df: item-level raw data.
        func: callable invoked as `func(df, output_name, code_name, epoch, metric)`
            that returns a one-row DataFrame of betas, or None to skip that
            combination (e.g. `get_beta`).
        output_name: output whose betas are computed; also selects the rows
            used for the mean accuracy column.
        acc_label: name of the column holding the mean accuracy of
            `output_name` for each code_name and epoch.

    Returns:
        A long DataFrame with one row per code_name, epoch, metric and
        parameter (`param`, `beta`), plus the hyperparameter column named by
        `poi` and the `acc_label` column. Rows with missing values are dropped.

    Raises:
        ValueError: if `func` produced no result for any combination.
    """
    # Mean accuracy per model and epoch for this output. Selecting `acc` before
    # averaging keeps string columns (word, output_name) out of the mean, which
    # would raise TypeError in pandas >= 2.0.
    epoch_acc_map = (
        df.loc[df.output_name == output_name]
        .groupby(['code_name', 'epoch'])['acc']
        .mean()
        .reset_index()
        .rename(columns={'acc': acc_label})
    )

    code_names = sorted(df.code_name.unique())
    epochs = sorted(df.epoch.unique())
    metrics = ['acc', 'csse']

    # Do the job: one model fit per (code_name, epoch, metric)
    beta_rows = [func(df, output_name, code_name, epoch, metric) for code_name in tqdm(code_names) for epoch in epochs for metric in metrics]

    # Combinations that could not be fitted come back as None; skip them.
    beta_rows = [row for row in beta_rows if row is not None]
    if not beta_rows:
        raise ValueError(f"No betas could be computed for output {output_name!r}")
    beta_df = pd.concat(beta_rows, ignore_index=True)

    beta_df = beta_df.melt(id_vars=['code_name', 'epoch', poi, 'metric'], var_name='param', value_name='beta')
    beta_df = pd.merge(beta_df, epoch_acc_map, on=['code_name', 'epoch'], how='left').dropna()

    return beta_df


In [None]:
# Betas of the PHO output for every model and epoch; mean PHO accuracy is
# attached as `pho_acc`.
pho_beta = make_beta_df(df, get_beta, 'pho', acc_label='pho_acc')

In [None]:
# checkpoint
pho_beta.to_csv(f"models/{poi}/pho_beta.csv")

# Get SEM betas

In [None]:
# Betas of the SEM output for every model and epoch; mean SEM accuracy is
# attached as `sem_acc`. The result is checkpointed to disk.
sem_beta = make_beta_df(df, get_beta, 'sem', acc_label='sem_acc')
sem_beta.to_csv(f"models/{poi}/sem_beta.csv")

# Exchange mean accuracy between PHO and SEM

In [None]:
# Reload both beta tables from their checkpoints; the first CSV column is the
# saved index.
sem_beta = pd.read_csv(f"models/{poi}/sem_beta.csv", index_col=0)
pho_beta = pd.read_csv(f"models/{poi}/pho_beta.csv", index_col=0)

In [None]:
# Mean accuracy of each output per model and epoch. Selecting the accuracy
# column before averaging keeps the string columns (metric, param) out of the
# mean, which would raise TypeError in pandas >= 2.0.
sem_acc_map = sem_beta.groupby(['code_name', 'epoch'])['sem_acc'].mean().reset_index()
pho_acc_map = pho_beta.groupby(['code_name', 'epoch'])['pho_acc'].mean().reset_index()

In [None]:
# Give each beta table the mean accuracy of the *other* output.
# NOTE(review): this is presumably not safe to re-run on the CSVs written by
# the following cell, because a table that already has the accuracy column
# would get suffixed duplicates from the merge; confirm before re-running.
pho_beta = pho_beta.merge(sem_acc_map, on=['code_name', 'epoch'], how='left')
sem_beta = sem_beta.merge(pho_acc_map, on=['code_name', 'epoch'], how='left')

In [None]:
# Overwrite the checkpoints with the tables that now carry both accuracies.
pho_beta.to_csv(f"models/{poi}/pho_beta.csv")
sem_beta.to_csv(f"models/{poi}/sem_beta.csv")

# Plotting

In [None]:
def plot_beta(df, x:str, metric:str, additional_acc: str, epoch_offset: int = 300, marker_epoch: int = 100):
    """Build the faceted beta plot for one metric.

    Args:
        df: long beta table with the columns `metric`, `epoch`, `param`,
            `beta`, the `x` and `additional_acc` columns, and the
            hyperparameter column named by the module-level `poi`.
        x: column for the x axis ('epoch' for developmental space, an
            accuracy column for performance space).
        metric: value of the `metric` column to plot.
        additional_acc: accuracy column drawn as a black reference line
            (its mean at each x).
        epoch_offset: epochs below this value are dropped and the remaining
            epochs are shifted so that this value becomes 0.
        marker_epoch: shifted epoch at which a red vertical rule is drawn.

    Returns:
        An interactive Altair chart with one column per hyperparameter value.
        Clicking legend entries toggles which parameters are visible.
    """
    df = df.loc[(df.metric == metric) & (df.epoch >= epoch_offset)].copy()
    df['epoch'] = df.epoch - epoch_offset

    selection = alt.selection_multi(fields=['param'], bind='legend')

    # One line of betas per parameter; unselected parameters are hidden
    beta_lines = alt.Chart().mark_line(point=True).encode(
        x=f"{x}:Q",
        y="beta:Q",
        color="param:N",
        opacity=alt.condition(selection, alt.value(1), alt.value(0.))
    ).add_selection(selection)

    # Line of additional accuracy
    acc_line = alt.Chart().mark_line(color='black').encode(
        x=f"{x}:Q",
        y=f"mean({additional_acc}):Q",
    )

    # Red vertical rule at the x position(s) where the shifted epoch equals
    # marker_epoch
    epoch_marker = (
        alt.Chart()
        .transform_filter(alt.datum.epoch == marker_epoch)
        .mark_rule(color='red')
        .encode(x=f"{x}:Q")
    )

    # h-line at zero for easier reference
    zero_line = alt.Chart().mark_rule().encode(y='zero:Q').transform_calculate(zero="0")

    return (
        alt.layer(zero_line, beta_lines, epoch_marker, acc_line, data=df)
        .facet(column=f"{poi}:O")
        .interactive()
    ).properties(title=f"{metric}_by_{x}. Red vertical line indicate epoch == {marker_epoch}")


## Plot PHO

In [None]:
# Reload the PHO betas and average them over models that share the same
# hyperparameter value. numeric_only=True skips the string column code_name,
# which would otherwise raise TypeError in pandas >= 2.0.
pho_beta = pd.read_csv(f"models/{poi}/pho_beta.csv", index_col=0)
pho_beta = pho_beta.groupby(['epoch', poi, 'metric', 'param']).mean(numeric_only=True).reset_index()

In [None]:
# Make sure the output directory for the plots exists.
os.makedirs(f"models/{poi}/plots", exist_ok=True)

In [None]:
# Save the four PHO beta plots: developmental space (x = epoch) and
# performance space (x = PHO accuracy), each for accuracy and conditional SSE.
# Mean SEM accuracy is drawn as the reference line.
pho_plot_specs = [
    ('epoch', 'acc', f"models/{poi}/plots/pho_beta_dev_acc.html"),
    ('epoch', 'csse', f"models/{poi}/plots/pho_beta_dev_csse.html"),
    ('pho_acc', 'acc', f"models/{poi}/plots/pho_beta_per_acc.html"),
    ('pho_acc', 'csse', f"models/{poi}/plots/pho_beta_per_csse.html"),
]
for x_axis, beta_metric, out_path in pho_plot_specs:
    plot_beta(pho_beta, x=x_axis, metric=beta_metric, additional_acc='sem_acc').save(out_path)

## Plot SEM

In [None]:
# Reload the SEM betas and average them over models that share the same
# hyperparameter value. numeric_only=True skips the string column code_name,
# which would otherwise raise TypeError in pandas >= 2.0.
sem_beta = pd.read_csv(f"models/{poi}/sem_beta.csv", index_col=0)
sem_beta = sem_beta.groupby(['epoch', poi, 'metric', 'param']).mean(numeric_only=True).reset_index()

In [None]:
# Save the four SEM beta plots: developmental space (x = epoch) and
# performance space (x = SEM accuracy), each for accuracy and conditional SSE.
# Mean PHO accuracy is drawn as the reference line.
sem_plot_specs = [
    ('epoch', 'acc', f"models/{poi}/plots/sem_beta_dev_acc.html"),
    ('epoch', 'csse', f"models/{poi}/plots/sem_beta_dev_csse.html"),
    ('sem_acc', 'acc', f"models/{poi}/plots/sem_beta_per_acc.html"),
    ('sem_acc', 'csse', f"models/{poi}/plots/sem_beta_per_csse.html"),
]
for x_axis, beta_metric, out_path in sem_plot_specs:
    plot_beta(sem_beta, x=x_axis, metric=beta_metric, additional_acc='pho_acc').save(out_path)