# Examine a single word  


In [None]:
import os
import altair as alt
from tqdm import tqdm
import troubleshooting 
import data_wrangling
from IPython.display import clear_output
from ipywidgets import interact

# Interactive input/activation

# Examine one

In [None]:
# Model run and test set under inspection; `d` is the baseline diagnosis
# (triangle task, epoch 200) that later cells reuse.
code_name = 'triangle_hope1'
testset_name = 'train_r100'
d = troubleshooting.Diagnosis(code_name)
d.eval(testset_name, task='triangle', epoch=200)


# Widgets: a dropdown of test-set items, layer and task dropdowns, and an
# integer slider (min, max, step) stepping by cfg.save_freq
# (presumably the checkpoint interval — confirm).
@interact(
    sel_word=d.testset_package['item'],
    layer=['pho', 'sem'],
    task=['triangle', 'ort_pho', 'exp_osp', 'ort_sem', 'exp_ops'],
    epoch=(10, d.cfg.total_number_of_epoch + 1, d.cfg.save_freq),
)
def interactive_plot(sel_word, layer, task, epoch):
    """Re-evaluate ``code_name`` with the widget settings and plot one layer.

    A fresh ``Diagnosis`` is built on every call, so the module-level ``d``
    above is left untouched by the widget.

    Args:
        sel_word: Test-set item to focus on.
        layer: Layer passed to ``plot_one_layer`` ('pho' or 'sem').
        task: Task name forwarded to ``Diagnosis.eval``.
        epoch: Checkpoint epoch forwarded to ``Diagnosis.eval``.

    Returns:
        Whatever ``Diagnosis.plot_one_layer`` returns; ``interact`` displays it.
    """
    diagnosis = troubleshooting.Diagnosis(code_name)
    diagnosis.eval(testset_name, task=task, epoch=epoch)
    diagnosis.set_target_word(sel_word)
    phonemes = diagnosis.list_output_phoneme
    print(f"Output phoneme over timeticks: {phonemes}")
    return diagnosis.plot_one_layer(layer)

# Compare bias init -5 vs. 0

In [None]:
a = "triangle_hope_fix"
b = "triangle_batch1"


def compare_plot(
    sel_word: str,
    layer: str,
    target_nodes: int,
    task: str,
    epoch: int,
    *,
    code_names=None,
    testset=None,
) -> alt.VConcatChart:
    """Stack one layer's activation plot for the same word from several models.

    Args:
        sel_word: Test-set item to inspect.
        layer: Layer to plot (e.g. 'pho' or 'sem').
        target_nodes: Forwarded as ``target_act`` to
            ``Diagnosis.plot_one_layer_by_target``.
        task: Task name forwarded to ``Diagnosis.eval``.
        epoch: Checkpoint epoch forwarded to ``Diagnosis.eval``.
        code_names: Model code names to compare, top to bottom. Defaults to
            the notebook globals ``(a, b)``, read at call time.
        testset: Test-set name. Defaults to the notebook global
            ``testset_name``, read at call time.

    Returns:
        A vertically concatenated chart with one panel per model, each titled
        with that model's ``code_name``.

    Raises:
        ValueError: If ``code_names`` is empty.
    """
    # Resolve defaults at call time (not definition time) so reassigning the
    # notebook globals a, b, testset_name still takes effect, as before.
    names = (a, b) if code_names is None else tuple(code_names)
    if not names:
        raise ValueError("code_names must contain at least one model name")
    ts_name = testset_name if testset is None else testset

    # Evaluate every model first, then build the plots (same order as before).
    diagnoses = []
    for name in names:
        diag = troubleshooting.Diagnosis(name)
        diag.eval(ts_name, task=task, epoch=epoch)
        diag.set_target_word(sel_word)
        diagnoses.append(diag)

    plots = [
        diag.plot_one_layer_by_target(layer, target_act=target_nodes)
        .properties(title=diag.code_name)
        for diag in diagnoses
    ]
    # For two charts this is exactly ``plot1 & plot2``.
    return alt.vconcat(*plots)

testset_name = 'train_r100'
ts = data_wrangling.load_testset(
    os.path.join("dataset", "testsets", f"{testset_name}.pkl.gz")
)

# Generate all plots to files
output_dir = os.path.join('issues', 'compare_batchsize_1')
# Create the target directory up front in case it does not exist yet;
# otherwise the first .save() below fails on a fresh checkout.
os.makedirs(output_dir, exist_ok=True)

for word in tqdm(ts['item']):
    for layer in ['pho', 'sem']:
        for target_node in [0, 1]:
            clear_output(wait=True)
            output_path = os.path.join(
                output_dir, f"{layer}_{word}_node{target_node}_compare.html"
            )
            compare_plot(word, layer, target_node, task='triangle', epoch=780).save(output_path)

# Tick 1 checking 

## Cleanup checking

In [None]:
# numpy is used below, but the cell that imports it comes later in the
# notebook; import it here so this cell runs in top-down order.
import numpy as np

sel_word = 'wasps'
d.set_target_word(sel_word)
print(f"All outputs: {d.list_outputs}\n")
print(f"All weights: {d.list_weights}")

# Get weights and biases
sc = d.get_weight("w_sc")
bias = d.get_weight("bias_css")

# Manually calculate the tick 1 TAI input: every semantic unit is set to 0.5
# (the assumed state before tick 1), projected through w_sc plus the bias.
sem_init = 0.5 * np.ones(sc.shape[0])
# NOTE(review): the division by 3 presumably reflects the model's
# time-averaging constant — confirm against the model config.
manual_css_1 = (np.matmul(sem_init, sc) + bias)/3

# Compare against the output of the model at tick 1
input_css = d.get_output('input_css')
print(f"From model: \n {input_css[1]} \n \n Checking: \n {manual_css_1}")
# Last expression: the notebook displays whether the two agree within 1e-3.
np.allclose(input_css[1], manual_css_1, atol=1e-3)


In [None]:
import numpy as np
import pandas as pd
import altair as alt

def act_to_df(activation: np.ndarray) -> pd.DataFrame:
    """Reshape a (tick x node) activation array into long format.

    Each row of ``activation`` is one tick and each column one node. The
    returned frame has one row per (tick, node) pair, with columns
    ``tick``, ``node`` and ``activation``.
    """
    frame = pd.DataFrame(activation)
    frame['tick'] = frame.index
    long_form = frame.melt(
        id_vars=['tick'],
        var_name='node',
        value_name='activation',
    )
    return long_form

def plot_activation(activation_df: pd.DataFrame) -> alt.Chart:
    """Draw one line per node showing activation across ticks.

    Expects a long-format frame with ``tick``, ``node`` and ``activation``
    columns, such as the one returned by ``act_to_df``. The y axis is fixed
    to the range [0, 1].
    """
    y_axis = alt.Y('activation:Q', scale=alt.Scale(domain=(0, 1)))
    return (
        alt.Chart(activation_df)
        .mark_line()
        .encode(x='tick:O', y=y_axis, color='node:N')
    )

# NOTE(review): `cs` is not defined anywhere in the visible notebook. It
# presumably should hold a (tick x node) activation array fetched from the
# Diagnosis object — confirm and define it before running this cell.
df = act_to_df(cs)
# Plot only node 4's activation over ticks.
plot_activation(df.loc[df.node==4])

## Tick 1 input density at SEM

In [None]:
w_hos_hs = d.get_weight(name="w_hos_hs")
w_ss = d.get_weight(name="w_ss")
bias_s = d.get_weight(name="bias_s")
w_cs = d.get_weight(name="w_cs")

# Tick 1 net input into each SEM unit, split by source. Every sending unit is
# taken to be at 0.5, so (0.5 * ones) @ W reduces to 0.5 * the column sums
# of W ("lazy matmul").
data = {
    'ss1': 0.5 * np.sum(w_ss, axis=0),
    'cs1': 0.5 * np.sum(w_cs, axis=0),
    'os1': 0.5 * np.sum(w_hos_hs, axis=0),
    'bias1': bias_s,
}

df = pd.DataFrame.from_dict(data)

# One density panel per input source, each with its own x axis. Plot through
# pandas: `plt` is never imported in this notebook, so the previous
# plt.subplots() call raised NameError.
axes = df.plot.density(
    subplots=True,
    layout=(1, len(data)),
    figsize=(15, 6),
    sharex=False,
    legend=False,
    title=list(data),
)
    


### Tick 1 input density at PHO

In [None]:
w_hop_hp = d.get_weight(name="w_hop_hp")
w_pp = d.get_weight(name="w_pp")
bias_p = d.get_weight(name="bias_p")
w_cp = d.get_weight(name="w_cp")

# Tick 1 net input into each PHO unit, split by source. Every sending unit is
# taken to be at 0.5, so (0.5 * ones) @ W reduces to 0.5 * the column sums
# of W ("lazy matmul").
data = {
    'pp1': 0.5 * np.sum(w_pp, axis=0),
    'cp1': 0.5 * np.sum(w_cp, axis=0),
    'op1': 0.5 * np.sum(w_hop_hp, axis=0),
    'bias1': bias_p,
}

df = pd.DataFrame.from_dict(data)

# One density panel per input source, each with its own x axis. Plot through
# pandas: `plt` is never imported in this notebook, so the previous
# plt.subplots() call raised NameError.
axes = df.plot.density(
    subplots=True,
    layout=(1, len(data)),
    figsize=(15, 6),
    sharex=False,
    legend=False,
    title=list(data),
)
    
