# Examine effect of error injection timing and rng seed variance
I have run 10 models per error_injection_timing setting:
- fewer error ticks : 2 (11 - 12 tick)
- more error ticks : 11 (2 - 12 tick)
- Phase 1 training is done with the same error tick setting


static_hpar = {
    "tf_root": "/home/jupyter/tf",
    "ort_units": 119,
    "pho_units": 250,
    "sem_units": 2446,
    "hidden_os_units": 500,
    "hidden_op_units": 100,
    "hidden_ps_units": 500,
    "hidden_sp_units": 500,
    "pho_cleanup_units": 20,
    "sem_cleanup_units": 50,
    "pho_noise_level": 0.0,
    "sem_noise_level": 0.0,
    "activation": "sigmoid",
    "tau": 1 / 3,
    "max_unit_time": 4.0,
    "output_ticks": 11,
    "learning_rate": 0.01,
    "zero_error_radius": 0.1,
    "n_mil_sample": 2.0,
    "batch_size": 100,
    "save_freq": 10,
    "batch_name": batch_name,
}

In [None]:
%reload_ext lab_black
import sqlite3
import pandas as pd

In [None]:
# Name of the batch run whose results database is analysed in this notebook.
batch_name = "error_injection_timing_test"

# Open the batch's SQLite results file; `con` is reused by the later queries.
results_db_path = f"/home/jupyter/tf/models/batch_run/{batch_name}/batch_results.sqlite"
con = sqlite3.connect(results_db_path)

# Ask SQLite's catalogue table which result tables the file contains.
cur = con.cursor()
cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
print(f'Result tables: {cur.fetchall()}')

In [None]:
# Load the whole batch_config table into a DataFrame.
batch_config = pd.read_sql_query(sql="SELECT * FROM batch_config", con=con)

In [None]:
# Average acc, sse and conditional_sse from the `strain` table within each
# (code_name, epoch, timetick, y, testset) group. The aggregation is done by
# SQLite, so only the grouped means are loaded into pandas.
query = """
    SELECT 
        code_name, 
        epoch, 
        timetick, 
        y, 
        testset, 
        AVG(acc) as acc, 
        AVG(sse) as sse, 
        AVG(conditional_sse) as conditional_sse
    FROM strain
    GROUP BY
        code_name, 
        epoch,
        timetick,
        y,
        testset
    """

# Run the query on the open connection and keep the result as a DataFrame.
strain_df = pd.read_sql_query(query, con)


In [None]:
# Caution: there are 4 y_test values in grain:
# ['pho_small_grain', 'pho_large_grain', 'pho', 'sem'].
# 'pho' is pho_small_grain or pho_large_grain.
# The WHERE clause below keeps only the 'pho' and 'sem' rows, so the two
# grain-specific phonology variants are excluded from the averages.

# Same aggregation as for the strain table: mean acc, sse and conditional_sse
# within each (code_name, epoch, timetick, y, testset) group.
query = """
    SELECT 
        code_name, 
        epoch, 
        timetick, 
        y, 
        testset, 
        AVG(acc) as acc, 
        AVG(sse) as sse, 
        AVG(conditional_sse) as conditional_sse
    FROM grain
    WHERE y_test IN ('pho', 'sem')
    GROUP BY
        code_name, 
        epoch,
        timetick,
        y,
        testset
    """

# Run the query on the open connection and keep the result as a DataFrame.
grain_df = pd.read_sql_query(query, con)


In [None]:
# Tag each result set with its wordness label before stacking them.
strain_df['wordness'] = 'word'
grain_df['wordness'] = 'nonword'

# Stack both result sets and attach each model's inject_error_ticks setting
# (looked up by code_name in batch_config).
df = pd.concat([strain_df, grain_df])
df = df.merge(
    batch_config[['code_name', 'inject_error_ticks']], how="left", on="code_name"
)

# Average by wordness
# `testset` is a string column that is not a grouping key here. Older pandas
# silently dropped it from the mean; pandas >= 2.0 raises TypeError instead,
# so restrict the aggregation to numeric columns explicitly.
mean_df = (
    df.groupby(['code_name', 'epoch', 'timetick', 'y', 'wordness'])
    .mean(numeric_only=True)
    .reset_index()
)


## Looking at Phonology output
- Similar to what we found in O2P model
    - Late timetick --> Similar between inject_error_ticks
    - At earlier time tick, inject_error_ticks has more influence on word than nonword
- Somewhat different from O2P model
    - More inject_error_ticks is relatively more stable over time ticks (vs. less inject_error, and vs. O2P model)


In [None]:
import altair as alt

# Average the per-model means within each inject_error_ticks condition.
# numeric_only=True keeps any leftover non-numeric column (e.g. the
# `code_name` model label) out of the mean; pandas >= 2.0 raises TypeError on
# such columns instead of silently dropping them.
stress_contrast_df = (
    mean_df.groupby(['inject_error_ticks', 'epoch', 'timetick', 'wordness', 'y'])
    .mean(numeric_only=True)
    .reset_index()
)

def stress_plot(df, variable_of_interest):
    """Line chart of `variable_of_interest` over epoch, with a timetick slider.

    Args:
        df: Data to plot. The chart reads the columns `epoch`, `timetick`,
            `wordness`, `inject_error_ticks` and `variable_of_interest`.
        variable_of_interest: Name of the quantitative column drawn on the
            y axis.

    Returns:
        An Altair chart with one column facet per `inject_error_ticks` value
        and one coloured line per `wordness` value, filtered to the timetick
        picked on the slider (0 to 12, starting at 12).
    """
    # Slider widget that drives the single-value timetick selection.
    slider = alt.binding_range(min=0, max=12, step=1)
    timetick_selection = alt.selection_single(
        name="timetick",
        fields=["timetick"],
        bind=slider,
        init={"timetick": 12},
    )

    lines = alt.Chart(df).mark_line()
    encoded = lines.encode(
        x="epoch:Q",
        y=f"{variable_of_interest}:Q",
        color="wordness:N",
        column="inject_error_ticks:N",
    )

    # Register the selection first, then show only rows matching the slider.
    with_slider = encoded.add_selection(timetick_selection)
    return with_slider.transform_filter(timetick_selection)

# Accuracy of the phonology output; word and nonword curves are both drawn.
pho_contrast_df = stress_contrast_df.loc[stress_contrast_df.y == 'pho']
stress_plot(pho_contrast_df, 'acc')

## Looking at Semantic output
- Fewer inject_error_ticks simply kills the accuracy at the timeticks without error injection (e.g., tick 9: 0%, tick 10: 70%)

In [None]:
# Accuracy of the semantic output, restricted to word items.
is_sem_word = (stress_contrast_df.y == 'sem') & (stress_contrast_df.wordness == 'word')
stress_plot(stress_contrast_df.loc[is_sem_word], 'acc')

# Variance

In [None]:
# Keep only the models trained with inject_error_ticks == 11.
hs_df = df.loc[df.inject_error_ticks == 11]

# Step 1: one mean per model (code_name) in each epoch / timetick / wordness /
# y cell. `testset` is a string column that is not a grouping key here;
# pandas >= 2.0 raises TypeError on it instead of silently dropping it, so the
# mean is restricted to numeric columns.
mean_df = (
    hs_df.groupby(['code_name', 'epoch', 'timetick', 'wordness', 'y'])
    .mean(numeric_only=True)
    .reset_index()
)

# Step 2: variance of those per-model means across models. The `code_name`
# label is dropped explicitly for the same pandas >= 2.0 reason; dropping the
# column (rather than var(numeric_only=True)) also works on pandas < 1.5,
# where var() does not accept that argument.
variance_df = (
    mean_df.drop(columns='code_name')
    .groupby(['epoch', 'timetick', 'wordness', 'y'])
    .var()
    .reset_index()
)


# Heatmap of the accuracy variance; the colour scale is fixed to [0, 0.008].
alt.Chart(variance_df).mark_rect().encode(
    x='epoch:O',
    y='timetick:O',
    color=alt.Color('acc', scale=alt.Scale(domain=(0, 0.008))),
    column='wordness',
    row='y',
    tooltip=['acc'],
).properties(title='RNG Seed variance: Variance in the mean of wordness over Epoch and Timetick')
    




- ignore bottom-left panel
- similar pattern as O2P, but slightly higher variance?

## Zoom-in: last timetick

In [None]:
# Keep only the last timetick (12) of the per-model means.
is_last_tick = mean_df.timetick == 12
sel_df = mean_df.loc[is_last_tick]

# One accuracy curve per model, faceted by wordness (columns) and y (rows).
per_model_lines = alt.Chart(sel_df).mark_line()
per_model_lines.encode(
    x='epoch:Q',
    y='acc:Q',
    column='wordness',
    row='y',
    color='code_name',
)

- Variance is related to accuracy: it is highest when acc is near 0.7-0.8


### Phonology output variance by condition

In [None]:
# Step 1: one mean per model (code_name) in each epoch / timetick / testset /
# y cell. `wordness` is a string column that is not a grouping key here;
# pandas >= 2.0 raises TypeError on it instead of silently dropping it, so the
# mean is restricted to numeric columns.
mean_df = (
    hs_df.groupby(['code_name', 'epoch', 'timetick', 'testset', 'y'])
    .mean(numeric_only=True)
    .reset_index()
)

# Step 2: variance of those per-model means across models. The `code_name`
# label is dropped explicitly for the same pandas >= 2.0 reason; dropping the
# column (rather than var(numeric_only=True)) also works on pandas < 1.5,
# where var() does not accept that argument.
variance_df = (
    mean_df.drop(columns='code_name')
    .groupby(['epoch', 'timetick', 'testset', 'y'])
    .var()
    .reset_index()
)

# Heatmap of the accuracy variance for the phonology output, one row per
# testset; the colour scale is fixed to [0, 0.02].
alt.Chart(variance_df.loc[variance_df.y == 'pho']).mark_rect().encode(
    x='epoch:O',
    y='timetick:O',
    color=alt.Color('acc', scale=alt.Scale(domain=(0, 0.02))),
    row='testset',
    tooltip=['acc'],
).properties(title='RNG Seed variance: Variance by condition over Epoch and Timetick')

### Semantic output variance by condition

In [None]:
# The eight strain test sets shown in the semantic plots.
strain_subtests = (
    'strain_hf_con_hi',
    'strain_hf_con_li',
    'strain_hf_inc_hi',
    'strain_hf_inc_li',
    'strain_lf_con_hi',
    'strain_lf_con_li',
    'strain_lf_inc_hi',
    'strain_lf_inc_li',
)

# Semantic-output rows that belong to one of those test sets.
is_sem = variance_df.y == 'sem'
in_strain_subtests = variance_df.testset.isin(strain_subtests)
sem_variance_plot_df = variance_df.loc[is_sem & in_strain_subtests]


# Heatmap of the accuracy variance, one row per testset; the colour scale is
# fixed to [0, 0.02].
variance_heatmap = alt.Chart(sem_variance_plot_df).mark_rect()
variance_heatmap.encode(
    x='epoch:O',
    y='timetick:O',
    color=alt.Color('acc', scale=alt.Scale(domain=(0, 0.02))),
    row='testset',
    tooltip=['acc'],
).properties(title='RNG Seed variance: Variance by condition over Epoch and Timetick')

### Last timetick phonology by condition

In [None]:
# Per-model means at the last timetick (12) for the phonology output.
is_last_tick = mean_df.timetick == 12
is_pho = mean_df.y == "pho"
sel_df = mean_df.loc[is_last_tick & is_pho]

# One accuracy curve per model, one row per testset.
per_model_lines = alt.Chart(sel_df).mark_line()
per_model_lines.encode(
    x='epoch:Q',
    y='acc:Q',
    row='testset:N',
    color='code_name',
)

### Last time tick semantic by condition

In [None]:
# Per-model means at the last timetick (12) for the semantic output, limited
# to the test sets listed in strain_subtests.
is_last_tick = mean_df.timetick == 12
is_sem = mean_df.y == "sem"
in_strain_subtests = mean_df.testset.isin(strain_subtests)
sel_df = mean_df.loc[is_last_tick & is_sem & in_strain_subtests]

# One accuracy curve per model, one row per testset.
per_model_lines = alt.Chart(sel_df).mark_line()
per_model_lines.encode(
    x='epoch:Q',
    y='acc:Q',
    row='testset:N',
    color='code_name',
)

- Semantics has a stronger variance
- May need more runs per cell than O2P