# Surgery

Move all the weights and biases from MikeNet to TensorFlow.

In [None]:
from mnutil import MikeNetWeight

# Load the MikeNet weights identified by "Reading_Weight_v1"; `mn_weight` is
# the source of every value grafted into the TensorFlow model later on.
mn_weight = MikeNetWeight("Reading_Weight_v1")

# Show the stored weight shapes. The bare expression on the last line is what
# Jupyter renders as this cell's output.
print("\nShapes of weights:")
mn_weight.shape_map

# Define model and environment configs

In [None]:
# NOTE: every module-level name defined in this cell ends up in the model
# config, because the next cell calls `meta.Config.from_dict(**globals())`.
# Do not introduce helper variables here.
code_name = "surgery"
batch_name = None
tf_root = "/home/jupyter/triangle_model"

# --- Model configs ---------------------------------------------------------
# Layer sizes (number of units).
ort_units = 364
pho_units = 200
sem_units = 2446
hidden_os_units = hidden_op_units = 500
hidden_ps_units = hidden_sp_units = 300
pho_cleanup_units = sem_cleanup_units = 50

pho_noise_level = sem_noise_level = 0.0
activation = "sigmoid"

# Time-related settings.
tau = 1 / 3
max_unit_time = 4.0
output_ticks = 13
inject_error_ticks = 11

# The parameters below do not matter here because this notebook never trains
# the model; they are still collected into the config with everything else.

# --- Training configs ------------------------------------------------------
learning_rate = 0.005
zero_error_radius = 0.1
save_freq = 20

# --- Environment configs ---------------------------------------------------
task_names = [
    "pho_sem",
    "sem_pho",
    "pho_pho",
    "sem_sem",
    "ort_pho",
    "ort_sem",
    "triangle",
]
wf_compression = "log"
wf_clip_low = 0
wf_clip_high = 999_999_999
oral_start_pct = oral_end_pct = 1.0

total_sample = 1_800_000
# One sampling proportion per entry of `task_names`, in the same order.
tasks_ps = (0.2, 0.2, 0.05, 0.05, 0.1, 0.1, 0.3)

batch_size = 100
rng_seed = 2021

## Build TF model

In [None]:
import meta
import modeling

# Build the config from the notebook's global namespace: every variable
# defined so far is passed to `Config.from_dict` as a keyword argument.
cfg = meta.Config.from_dict(**globals())

# Instantiate the TF model and build it. Presumably this is what creates the
# weight variables that the grafting cell assigns into (TODO confirm in
# `modeling.MyModel.build`).
model = modeling.MyModel(cfg)
model.build()

# Grafting

In [None]:
import tensorflow as tf

# Copy each MikeNet weight into the TF variable with the matching name.
# Variables with no MikeNet counterpart are never assigned, so they keep the
# values they had after `model.build()`. Collect them and report them
# explicitly instead of only logging them line by line.
missing_weights = []

for weight in model.weights:
    # TF variable names usually carry an output-index suffix (e.g. "w:0").
    # Strip everything after the last ":" so the name matches the
    # `mn_weight.weights_tf` keys. Unlike `[:-2]`, this leaves a name without
    # a suffix intact.
    name = weight.name.rsplit(":", 1)[0]

    # Keep the `try` to the single lookup that can raise KeyError.
    try:
        mikenet_value = mn_weight.weights_tf[name]
    except KeyError:
        print(f"Missing weight {name} in mikenet")
        missing_weights.append(name)
        continue

    weight.assign(mikenet_value)

    # Post-load sanity check: verify the copy before reporting success.
    tf.debugging.assert_equal(mikenet_value, weight)
    print(f"Grafted {name} from mikenet to TensorFlow")

if missing_weights:
    print(
        f"{len(missing_weights)} weight(s) were not grafted and were left "
        f"untouched: {missing_weights}"
    )

### Check mean

In [None]:
# Print the mean of every model weight as a quick plausibility check on the
# grafted values. A plain loop is used instead of a list comprehension, which
# would build a throwaway list and make Jupyter display a list of `None`s as
# the cell output.
for w in model.weights:
    print(f'{w.name} mean: {w.numpy().mean()}')

### Save the grafted model in TF checkpoint format

In [None]:
# Track the grafted model in a checkpoint object so its variables can be
# written to disk.
ckpt = tf.train.Checkpoint(model=model)

# The manager writes into the configured checkpoint folder and uses "epoch" as
# the checkpoint file prefix.
ckpt_manager = tf.train.CheckpointManager(
    checkpoint=ckpt,
    directory=cfg.checkpoint_folder,
    checkpoint_name="epoch",
    max_to_keep=None,  # never delete older checkpoints
)

# Save once, numbered 1. The returned save path is the cell's displayed output.
ckpt_manager.save(checkpoint_number=1)