# This notebook uses Papermill to batch run models

In [None]:
%load_ext lab_black
import papermill as pm
import multiprocessing
import pandas as pd
import altair as alt
import os, itertools, json
from meta import model_cfg, connect_gbq
from tqdm import tqdm
from evaluate import vis, make_df_wnw

## Batch configurations

In [None]:
# import random
# seeds = [int(random.random() * 1e5) for x in range(10)]

batch_name = "O2P_noreg_pch"

# Hyperparameters swept in this batch; every combination becomes one model run.
param_grid = {
    'p_noise': [0., 1., 2., 4.],
    'hidden_units': [25, 100],
    'cleanup_units': [10, 20, 50, 100],
}

varying_hpar_names, varying_hpar_values = zip(*param_grid.items())

# Parameters shared by every run in the batch
# (p_noise, hidden_units and cleanup_units come from param_grid instead).
static_hpar = {
    'sample_name': 'hs04',
    'sample_rng_seed': 4321,
    'tf_rng_seed': 4444,
    'use_semantic': False,
    'input_dim': 119,
    'output_dim': 250,
    'use_attractor': False,
    'rnn_activation': 'sigmoid',
    'regularizer_const': None,
    'tau': 0.2,
    'max_unit_time': 4.,
    'n_mil_sample': 1.,
    'batch_size': 128,
    'learning_rate': 0.005,
    'save_freq': 10,
}

batch_cfgs = []

# Serial numbers start at 1 and follow the order of itertools.product
for sn, combo in enumerate(itertools.product(*varying_hpar_values), start=1):
    code_name = batch_name + "_model_{:04d}".format(sn)
    model_folder = "models/" + code_name + "/"

    # Varying hyperparameters first, then run identifiers, then the static ones
    hpar = dict(zip(varying_hpar_names, combo))
    hpar['code_name'] = code_name
    hpar['bq_dataset'] = batch_name
    hpar.update(static_hpar)

    batch_cfgs.append(
        {
            'sn': sn,
            'in_notebook': "OSP_master.ipynb",
            'code_name': code_name,
            'model_folder': model_folder,
            'out_notebook': model_folder + "output.ipynb",
            'params': hpar,
        }
    )

## Run batch

In [None]:
# Run
def run_batch(cfg):
    """Execute one model notebook with Papermill, reporting failures instead of raising.

    Args:
        cfg: dict describing a single run. The keys read here are ``sn``,
            ``code_name``, ``model_folder``, ``in_notebook``, ``out_notebook``
            and ``params`` (passed to the notebook as Papermill parameters).

    Returns:
        None. A failure is printed together with its cause so the remaining
        runs in the batch can still proceed.
    """
    try:
        print("Running model {}".format(cfg['sn']))

        # makedirs also creates a missing parent folder (e.g. "models/") and
        # tolerates a folder that already exists, unlike a guarded os.mkdir.
        os.makedirs(cfg['model_folder'], exist_ok=True)

        pm.execute_notebook(
            cfg['in_notebook'],
            cfg['out_notebook'],
            parameters=cfg['params'],
        )

    except Exception as err:
        # Best-effort batch: report which run failed and why, but do not
        # swallow KeyboardInterrupt/SystemExit as a bare ``except`` would.
        print("Error occur in {}: {!r}".format(cfg['code_name'], err))


# Fan the batch out over 4 worker processes; map() blocks until every cfg
# in batch_cfgs has been handed to run_batch and has returned.
with multiprocessing.Pool(processes=4) as worker_pool:
    worker_pool.map(run_batch, batch_cfgs)

In [None]:
# Upload the run configurations of this batch to BQ
connect_gbq().push_cfgs(batch_name, batch_cfgs)

#### Shutdown compute engine

In [None]:
# Wait 30 seconds, then power off the machine through an IPython shell escape.
# NOTE(review): the pause presumably gives pending output/uploads time to
# finish before shutdown — confirm the intended reason.
from time import sleep
sleep(30)
!sudo poweroff  

## Compile results

In [None]:
import os, json
import pandas as pd
from meta import connect_gbq
from evaluate import vis
from tqdm import tqdm

# Read the run-level cfgs of this batch from BQ
conn = connect_gbq()
cfgs = conn.read_bq_cfg(batch_name)

# Report which cfg columns actually vary across the runs of this batch;
# the identifier columns differ for every run by construction and are skipped.
print('===== Batch level hyperparams (columns that have >1 unique value) =====')
for x in cfgs.columns:
    if x in ('code_name', 'uuid'):
        continue

    unique_values = cfgs[x].unique()  # computed once, reused for test and print
    if len(unique_values) > 1:
        print(
            'Column <{}> has these unique values: {}'.format(x, unique_values)
        )

# Parse each run by batch_eval, which aggregates item-level data to condition level
# and merges Grain and Strain into one single file (using local files instead of BQ;
# BQ may be worth using for much larger data, > 5 GB I guess)


def parse_batch_results(cfgs, batch_name):
    """
    Parse and concatenate the condition-level results of every run in a batch,
    then merge them with the run-level metadata in ``cfgs``.

    Args:
        cfgs: DataFrame with one row per run; it must contain a ``code_name``
            column, which is used as the merge key. Only its length is used
            to enumerate runs, so run ``i`` (1-based) is read from
            ``models/<batch_name>_model_<i, zero-padded to 4 digits>``.
        batch_name: name of the batch, used as the model folder prefix.

    Returns:
        DataFrame of all condition-level results, left-merged with ``cfgs``
        on ``code_name``.
    """
    # Collect the per-run frames and concatenate once at the end; growing a
    # DataFrame with pd.concat inside the loop re-copies all earlier rows on
    # every iteration.
    cond_frames = []

    for i in tqdm(range(len(cfgs))):

        model_path = 'models/' + batch_name + '_model_{0:04d}'.format(i + 1)

        this_eval = vis(
            model_path, 'result_strain_item.csv', 'result_grain_item.csv'
        )  # Eval lesion and grain
        # presumably populates this_eval.cdf with condition-level data — confirm in evaluate.vis
        this_eval.parse_cond_df()
        cond_frames.append(this_eval.cdf)

    if cond_frames:
        batch_cdf = pd.concat(cond_frames, ignore_index=True)
    else:
        # pd.concat rejects an empty list; keep the original empty-frame result
        batch_cdf = pd.DataFrame()

    return pd.merge(batch_cdf, cfgs, how='left', on='code_name')


# Compile the condition-level results of the whole batch and store them as HDF5.
df = parse_batch_results(cfgs, batch_name)
# exist_ok=True lets this cell be rerun without failing on the existing folder
os.makedirs('batch_eval/{}'.format(batch_name), exist_ok=True)
df.to_hdf('batch_eval/{}/bcdf.h5'.format(batch_name), key='df', mode='w')

# Plotting

Create a reusable overview heatmap and a word vs. nonword DataFrame

In [None]:
# Use Altair's default data transformer without its row limit
alt.data_transformers.enable("default")
alt.data_transformers.disable_max_rows()

# Reload the compiled batch results from the batch's HDF5 file
bcdf_path = 'batch_eval/{}/bcdf.h5'.format(batch_name)
df = pd.read_hdf(bcdf_path, 'df')

# Click selections reused by the plots below:
# sel_run picks runs by code_name, sel_cond picks conditions (bound to the legend)
sel_run = alt.selection(type="multi", on="click", fields=["code_name"])
sel_cond = alt.selection(type="multi", on="click", fields=["cond"], bind="legend")

# Overview data: only rows at the last epoch and the last timestep
at_last_epoch = df.epoch == df.epoch.max()
at_last_step = df.timestep == df.timestep.max()
df_ov = df[at_last_epoch & at_last_step]

# Shared master overview: accuracy heatmap over the swept hyperparameters;
# runs outside the current sel_run selection are faded
acc_color = alt.Color("acc", scale=alt.Scale(scheme="redyellowgreen"))
run_opacity = alt.condition(sel_run, alt.value(1), alt.value(0.1))

overview = (
    alt.Chart(df_ov)
    .mark_rect()
    .encode(
        x="p_noise:O",
        y="cleanup_units:O",
        row="hidden_units:O",
        color=acc_color,
        opacity=run_opacity,
        tooltip=["code_name", "acc"],
    )
    .add_selection(sel_run)
    .properties(title="Overall accuracy")
)

# Accuracy Word (HF-INC) vs. Nonwords
wnw_conds = ['INC_HF', 'ambiguous', 'unambiguous']
df_wnw = make_df_wnw(df, selected_cond=wnw_conds)

### Single run plots

In [None]:
# Data at the last time step only, used for the accuracy-over-epoch plot
is_last_step = df.timestep == df.timestep.max()
df_laststep = df[is_last_step]

# Both accuracy axes share the same fixed 0-1 scale
unit_scale = alt.Scale(domain=(0, 1))

# Accuracy over epoch per condition for the run(s) picked in the overview;
# conditions outside the sel_cond selection are faded
acc_plot = (
    alt.Chart(df_laststep)
    .mark_line(point=True)
    .encode(
        y=alt.Y("acc:Q", scale=unit_scale),
        x="epoch",
        color="cond",
        opacity=alt.condition(sel_cond, alt.value(1), alt.value(0.1)),
        tooltip=["code_name", "acc"],
    )
    .add_selection(sel_cond)
    .transform_filter(sel_run)
    .properties(title="Full model at final time step")
)

# Word accuracy against nonword accuracy, coloured by epoch, for the picked run(s)
wnw_plot = (
    alt.Chart(df_wnw)
    .mark_point()
    .encode(
        y=alt.Y("nonword_acc:Q", scale=unit_scale),
        x=alt.X("word_acc:Q", scale=unit_scale),
        color=alt.Color("epoch", scale=alt.Scale(scheme="redyellowgreen")),
        tooltip=["epoch", "word_acc", "nonword_acc"],
    )
    .transform_filter(sel_run)
    .properties(title="Word vs. Nonword accuracy at final time step")
)

# Diagonal reference line from (0, 0) to (1, 1)
diag_df = pd.DataFrame({'x': [0, 1], 'y': [0, 1]})
diagline = alt.Chart(diag_df).mark_line().encode(x='x', y='y')

wnw_with_diag = wnw_plot + diagline

# Layout: overview on the left, the two detail plots stacked on the right
mainplots = acc_plot & wnw_with_diag
splot = overview | mainplots

splot.save('batch_eval/{}/single_run.html'.format(batch_name))
splot

### Multi-run plots

In [None]:
# Long format: one row per (run, epoch, word/nonword) with its accuracy
wnw_mdf = df_wnw.melt(
    id_vars=['code_name', 'epoch'],
    value_vars=['word_acc', 'nonword_acc'],
    var_name='wnw',
    value_name='acc',
)

# Group label = last 4 characters of code_name + '_' + the wnw name
# without its trailing 4 characters ('word_acc' -> 'word')
run_suffix = wnw_mdf.code_name.str.slice(-4)
wnw_label = wnw_mdf.wnw.str.slice(stop=-4)
wnw_mdf['group'] = run_suffix + '_' + wnw_label

# Both accuracy axes share the same fixed 0-1 scale
acc_scale = alt.Scale(domain=(0, 1))

# Word and nonword accuracy over epoch, one line per group, for the selected runs
plot_epoch = (
    alt.Chart(wnw_mdf)
    .mark_line(point=True)
    .encode(
        y=alt.Y("acc:Q", scale=acc_scale),
        x="epoch:Q",
        color="group:N",
        tooltip=["code_name", "epoch", "acc"],
    )
    .transform_filter(sel_run)
    .properties(title="Plot word and nonword accuracy by epoch")
)

# Word vs. nonword accuracy, one line per run, for the selected runs
plot_wnw = (
    alt.Chart(df_wnw)
    .mark_line(point=True)
    .encode(
        y=alt.Y("nonword_acc:Q", scale=acc_scale),
        x=alt.X("word_acc:Q", scale=acc_scale),
        color="code_name:N",
        opacity=alt.condition(sel_run, alt.value(1), alt.value(0.1)),
        tooltip=["code_name", "epoch", "word_acc", "nonword_acc"],
    )
    .add_selection(sel_run)
    .transform_filter(sel_run)
    .properties(title="Word vs. Nonword accuracy at final time step")
)

# Layout: overview on the left, the two multi-run plots stacked on the right
detail_plots = plot_epoch & plot_wnw
multi_plot = overview | detail_plots
multi_plot.save('batch_eval/{}/multi_runs.html'.format(batch_name))
multi_plot

In [None]:
# Convert batch.ipynb (presumably this notebook — confirm the filename) to HTML
# and write it into the batch's evaluation folder via an IPython shell escape.
batch_dir = 'batch_eval/{}'.format(batch_name)
!jupyter nbconvert --output-dir=$batch_dir --to html batch.ipynb