# Surgery

Purpose: Move all the weights and biases from MikeNet to TensorFlow.

In [None]:
import troubleshooting, meta, modeling
from importlib import reload
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
# Re-import the module so that edits to troubleshooting.py take effect without a kernel restart
reload(troubleshooting)
# Point MikeNetWeight at the v1 reading weights.
# Versions v1-v10 are available; they looked similar on inspection, so v1 is used.
mn_weight = troubleshooting.MikeNetWeight("mikenet/Reading_Weight_v1")

In [None]:
def plot_weight_heatmap(weight_name):
    """Show one MikeNet weight reshaped under each NumPy index order.

    The values in ``mn_weight.weights[weight_name]`` are reshaped to
    ``mn_weight.shape_map[weight_name]`` three times -- with ``order`` set
    to 'F', 'C' and 'A' -- and drawn side by side as heatmaps, to help
    guess which ordering MikeNet used when writing the weight.

    Args:
        weight_name: key into ``mn_weight.weights`` / ``mn_weight.shape_map``.
    """
    orders = ('F', 'C', 'A')
    fig, axs = plt.subplots(1, len(orders), figsize=(10, 6))

    for order, ax in zip(orders, axs):
        values = mn_weight.weights[weight_name]
        target_shape = mn_weight.shape_map[weight_name]
        reshaped = np.reshape(values, target_shape, order=order)

        ax.imshow(reshaped, cmap='hot', interpolation='none')
        ax.set_title(f"{weight_name}: {order}")

    # Force a white figure background
    fig.patch.set_facecolor('white')



In [None]:
# Compare the three candidate orderings on the 'Ortho -> osh' weight
plot_weight_heatmap('Ortho -> osh')

Conclusion: use Fortran (column-major) ordering, `order='F'`, when reshaping MikeNet weights.

## Shapes of weights in MikeNet model

In [None]:
# NOTE(review): presumably builds (and, as the cell's last expression, displays)
# the shape of each MikeNet weight -- confirm in troubleshooting.MikeNetWeight
mn_weight.create_weights_shapes()

# Create config file

In [None]:
# NOTE: every name assigned in this cell is a notebook global, and the globals
# are handed to meta.ModelConfig.from_global(globals_dict=globals()) in the
# "Build model" cell -- do not rename these variables casually.

# Run identifiers; `code_name` is also what troubleshooting.Diagnosis receives later in this notebook
code_name = "surgery"
batch_name = None
tf_root = "/home/jupyter/triangle_model"

# Model configs
# Layer sizes in units -- presumably orthography / phonology / semantics,
# the hidden layers between them, and the cleanup layers; confirm in modeling
ort_units = 364
pho_units = 200
sem_units = 2446
hidden_os_units = 500
hidden_op_units = 500
hidden_ps_units = 300
hidden_sp_units = 300
pho_cleanup_units = 50
sem_cleanup_units = 50
# Noise is set to zero for both layers in this run
pho_noise_level = 0.
sem_noise_level = 0.
activation = "sigmoid"

# Time-related settings.
# NOTE(review): max_unit_time / tau is ~12, one less than output_ticks --
# presumably 12 steps plus the initial state; confirm in modeling
tau = 1 / 3
max_unit_time = 4.0
output_ticks = 13
inject_error_ticks = 11

# Training configs
learning_rate = 0.005
zero_error_radius = 0.1
save_freq = 20

# Environment configs
# Seven tasks; the two *_tasks_ps tuples below each hold seven probabilities
# (presumably in this same order -- confirm in the sampling code)
tasks = ("pho_sem", "sem_pho", "pho_pho", "sem_sem", "ort_pho", "ort_sem", "triangle")
# wf_* settings -- presumably word-frequency compression and clipping; confirm in the data pipeline
wf_compression = "log"
wf_clip_low = 0
wf_clip_high = 999_999_999
oral_start_pct = 1.0
oral_end_pct = 1.0

oral_sample = 1_800_000
# Earlier 5-entry version, kept for reference (does not match the 7 tasks above):
# oral_tasks_ps = (0.4, 0.4, 0.1, 0.1, 0.)
# Sums to 1.0; the last three tasks (ort_pho, ort_sem, triangle) get probability 0
oral_tasks_ps = (0.4, 0.4, 0.05, 0.15, 0., 0., 0.)
transition_sample = 800_000
reading_sample = 15_000_000
# Earlier 5-entry version, kept for reference (does not match the 7 tasks above):
# reading_tasks_ps = (0.2, 0.2, 0.05, 0.05, 0.5)
# Sums to 1.0
reading_tasks_ps = (0.2, 0.2, 0.05, 0.05, .1, .1, .3)

batch_size = 100
rng_seed = 2021

## Build model

In [None]:
# Build the ModelConfig from the notebook's global variables (the config cell above)
cfg = meta.ModelConfig.from_global(globals_dict=globals())
model = modeling.MyModel(cfg)
# NOTE(review): presumably creates the variables so that model.weights is
# populated before grafting -- confirm in modeling.MyModel.build
model.build()

# Grafting

In [None]:
# Copy every MikeNet weight that has a same-named counterpart into the
# TensorFlow model, verifying each copy before reporting success.
for weight in model.weights:
    # TF variable names normally end in ":<index>" (e.g. "...:0"). Strip that
    # suffix only when it is present, rather than blindly dropping two characters.
    name = weight.name.rsplit(":", 1)[0]

    # Keep the try body to the lookup alone, so that a KeyError raised
    # anywhere else is not misreported as a missing MikeNet weight.
    try:
        mn_value = mn_weight.weights_tf[name]
    except KeyError:
        print(f"Missing weight {name} in mikenet")
        continue

    weight.assign(mn_value)

    # Post-load weight sanity check; raises if the variable does not match
    tf.debugging.assert_equal(mn_value, weight)

    # Report success only after the check above has passed
    print(f"Grafted mikenet weight {name} to tf.weights")



In [None]:
# Check grafted weight is the correct one
def my_heatmap(x):
    """Draw ``x`` as a 'hot' heatmap with a colorbar on the current axes."""
    plt.imshow(x, cmap='hot', interpolation='none')
    plt.colorbar()


# A plain loop rather than a list comprehension: the calls are made only for
# their plotting side effect, and a comprehension as the cell's last
# expression would also echo a list of None values as output.
for w in model.weights:
    if w.name.startswith("w_hos_oh"):
        # One figure per matching weight, so several matches do not overdraw
        # each other (and stack colorbars) on the same axes.
        plt.figure()
        my_heatmap(w)


# Parse MN pattern into TF format

# Input

In [None]:
import evaluate
from importlib import reload
reload(evaluate)

# NOTE(review): `mn_r100` is not defined in any cell visible in this part of
# the notebook -- it must already exist in the kernel before this cell runs.
testset_name = 'mn_r100'
task = 'triangle'

# Drive both the active task and the input lookup from `task`, so the two
# cannot drift apart when `task` is changed.
model.set_active_task(task)
input_name = modeling.IN_OUT[task][0]
# Feed the same input at each of the cfg.n_timesteps time steps. The data is
# read from `mn_r100` directly instead of through a temporary `testset`
# binding, because `testset` is the TestSet evaluator created below.
pred = model([mn_r100[input_name]] * cfg.n_timesteps)

testset = evaluate.TestSet(cfg)
df = testset.eval(testset_name, task)


In [None]:
# Display the result of testset.eval(testset_name, task)
df

In [None]:
import altair as alt

# Line chart of the `act1` column (Q = quantitative) against `timetick`
# (O = ordinal); the chart must stay the last expression so the notebook renders it
alt.Chart(df).mark_line().encode(
    x='timetick:O',
    y="act1:Q"
)

In [None]:
from IPython.display import clear_output
from ipywidgets import interact


In [None]:
# Re-import the module so that edits to troubleshooting.py take effect
reload(troubleshooting)
# Evaluate the test set on the triangle task at epoch 0 ...
d = troubleshooting.Diagnosis(code_name)
d.eval(testset_name, task='triangle', epoch=0)
# ... then select a single word to inspect
sel_word = 'wasps'
d.set_target_word(sel_word)

In [None]:
# NOTE(review): presumably the phonological data for the word chosen with
# set_target_word -- confirm in troubleshooting.Diagnosis
d.word_pho

In [None]:
reload(troubleshooting)
# This module-level Diagnosis is evaluated once so its word list can be used
# as the options of the `sel_word` widget below.
d = troubleshooting.Diagnosis(code_name)
d.eval(testset_name, task='triangle', epoch=0)


@interact(
    sel_word=d.testset_package['item'],
    layer=['pho', 'sem'],
    task=['triangle', 'ort_pho', 'exp_osp', 'ort_sem', 'exp_ops'],
)
def interactive_plot(sel_word, layer, task):
    """Widget callback: evaluate ``task`` afresh and plot ``layer`` for ``sel_word``.

    A new Diagnosis is created and evaluated on every call, so the module-level
    ``d`` is left untouched.
    """
    diagnosis = troubleshooting.Diagnosis(code_name)
    diagnosis.eval(testset_name, task=task, epoch=0)
    diagnosis.set_target_word(sel_word)
    # print(f"Output phoneme over timeticks: {diagnosis.list_output_phoneme}")
    return diagnosis.plot_one_layer(layer)
