# Run training of CLRS algorithm

The main training loop is copied from `clrs_train.py`; the helper functions it uses are located in `clrs_train_funcs.py`.

In [1]:
import functools
import os
import shutil
from typing import Any, Dict, List

import logging
import clrs
import jax
import jax.numpy as jnp
import numpy as np
import requests
import tensorflow as tf
import haiku as hk
import pickle
import sys

import model

import flags
from clrs_train_funcs import *

In [2]:
# Query GPU status through the shell. In the recorded run the command is not
# found by the shell, so no GPU information is shown.
!nvidia-smi

zsh:1: command not found: nvidia-smi


In [3]:
# Flag container taken from the `flags` module imported above; later cells
# read its attributes and override some of them.
FLAGS = flags.FLAGS

In [4]:
# Per-run overrides of the values carried by FLAGS.
FLAGS.random_pos = True
# FLAGS.algorithms = ["insertion_sort"]

# Sizes forwarded to the model as hidden_dim / msg_dim.
FLAGS.hidden_size = 128
FLAGS.msg_size = 32

# Loop length and evaluation cadence.
FLAGS.train_steps = 500  # 1000
FLAGS.eval_every = 50
FLAGS.test_every = 500

FLAGS.l1_weight = 0.01


def l1_weight_fn(step):
    """Return the l1 weight to apply at training step `step`.

    Constant schedule: `step` is ignored and `FLAGS.l1_weight` is read at call
    time, so later changes to the flag are picked up.
    """
    return FLAGS.l1_weight

# Disabled alternative: linear ramp of the l1 weight over the training run.
# l1_weight_fn_mult_start = 1
# l1_weight_fn_mult_end = 10
# l1_weight_fn = lambda step: FLAGS.l1_weight * (l1_weight_fn_mult_start + step / FLAGS.train_steps * (l1_weight_fn_mult_end - l1_weight_fn_mult_start))

In [5]:
# Each supported hint mode maps to its (encode_hints, decode_hints) pair.
_HINT_MODE_SWITCHES = {
    'encoded_decoded': (True, True),
    'decoded_only': (False, True),
    'none': (False, False),
}
if FLAGS.hint_mode not in _HINT_MODE_SWITCHES:
    raise ValueError(
        'Hint mode not in {encoded_decoded, decoded_only, none}.')
encode_hints, decode_hints = _HINT_MODE_SWITCHES[FLAGS.hint_mode]

# Normalise the configured training lengths to ints.
train_lengths = list(map(int, FLAGS.train_lengths))

# Seeded NumPy generator (handed to the sampler factory later on) and a JAX
# key derived from its first draw.
rng = np.random.RandomState(FLAGS.seed)
rng_key = jax.random.PRNGKey(rng.randint(2**32))

In [6]:
def ensure_stream_handler(target_logger, fmt, handler_name='clrs_notebook_stream'):
    """Return the named stream handler on `target_logger`, creating it only once.

    Calling `addHandler` unconditionally from a notebook cell stacks one more
    handler on the logger every time the cell is re-run, so each record is then
    printed once per run of the cell. Looking the handler up by name keeps the
    cell idempotent.

    Args:
      target_logger: `logging.Logger` the handler is attached to.
      fmt: `logging.Formatter` to apply. It is re-applied when the handler
        already exists, so edits to the format take effect on re-run.
      handler_name: name used to recognise a handler installed by an earlier
        call.

    Returns:
      The handler registered on `target_logger` under `handler_name`.
    """
    for existing in target_logger.handlers:
        if existing.get_name() == handler_name:
            existing.setFormatter(fmt)
            return existing
    handler = logging.StreamHandler()
    handler.set_name(handler_name)
    handler.setFormatter(fmt)
    target_logger.addHandler(handler)
    return handler


# Root logger at INFO so the logging.info calls in later cells are emitted.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# `ch` is still bound to the stream handler, as before, but re-running this
# cell now reuses the existing handler instead of adding a duplicate.
ch = ensure_stream_handler(logger, formatter)

In [7]:
# Build the train / validation / test samplers in one call. The helper hands
# back six values, unpacked below into the names used by the rest of the
# notebook.
_sampler_bundle = create_samplers(rng, train_lengths)
(train_samplers,
 val_samplers, val_sample_counts,
 test_samplers, test_sample_counts,
 spec_list) = _sampler_bundle

2023-03-27 15:48:18,731 - root - INFO - Creating samplers for algo binary_search


Metal device set to: Apple M1 Pro
Using randomized pos for split train




Using randomized pos for split train
Using randomized pos for split train


2023-03-27 15:48:19,385 - absl - INFO - Creating a dataset with 4096 samples.


Using randomized pos for split train
Using randomized pos for split train


2023-03-27 15:48:19,624 - absl - INFO - 1000 samples created
2023-03-27 15:48:19,754 - absl - INFO - 2000 samples created
2023-03-27 15:48:19,901 - absl - INFO - 3000 samples created
2023-03-27 15:48:20,031 - absl - INFO - 4000 samples created
2023-03-27 15:48:20,113 - root - INFO - Dataset found at /tmp/CLRS30/CLRS30_v1.0.0. Skipping download.
2023-03-27 15:48:20,114 - absl - INFO - Load dataset info from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/binary_search_test/1.0.0
2023-03-27 15:48:20,116 - absl - INFO - Load dataset info from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/binary_search_test/1.0.0
2023-03-27 15:48:20,117 - absl - INFO - Reusing dataset clrs_dataset (/tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/binary_search_test/1.0.0)
2023-03-27 15:48:20,117 - absl - INFO - Constructing tf.data.Dataset clrs_dataset for split test, from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/binary_search_test/1.0.0


Using randomized pos for split val
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


Not! using randomized pos for split test


2023-03-27 15:48:20.320261: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [8]:
# Processor factory selected by FLAGS.processor_type.
processor_factory = model.get_processor_factory(
    FLAGS.processor_type,
    use_ln=FLAGS.use_ln,
    nb_triplet_fts=FLAGS.nb_triplet_fts,
    nb_heads=FLAGS.nb_heads,
)

# Keyword arguments for the model constructor; apart from the factory and the
# two hint switches, every value is forwarded straight from FLAGS.
model_params = {
    'processor_factory': processor_factory,
    'hidden_dim': FLAGS.hidden_size,
    'msg_dim': FLAGS.msg_size,
    'encode_hints': encode_hints,
    'decode_hints': decode_hints,
    'encoder_init': FLAGS.encoder_init,
    'use_lstm': FLAGS.use_lstm,
    'learning_rate': FLAGS.learning_rate,
    'grad_clip_max_norm': FLAGS.grad_clip_max_norm,
    'checkpoint_path': FLAGS.checkpoint_path,
    'freeze_processor': FLAGS.freeze_processor,
    'dropout_prob': FLAGS.dropout_prob,
    'hint_teacher_forcing': FLAGS.hint_teacher_forcing,
    'hint_repred_mode': FLAGS.hint_repred_mode,
    'nb_msg_passing_steps': FLAGS.nb_msg_passing_steps,
    'l1_weight': FLAGS.l1_weight,
}

# One batch is drawn from every validation sampler and passed in as the dummy
# trajectory (this advances each validation sampler by one batch).
eval_model = model.BaselineMsgModel(
    spec=spec_list,
    dummy_trajectory=list(map(next, val_samplers)),
    **model_params,
)

# Training and evaluation use the very same model object.
train_model = eval_model

In [9]:
# Training loop.
#
# For FLAGS.train_steps steps: draw one feedback batch from every training
# sampler, call train_model.feedback once per algorithm, and whenever an
# evaluation is due (every FLAGS.eval_every steps, starting at step 0) score the
# model on the validation samplers and write 'best.pkl' if the summed
# validation score improved. After the loop the current model is written to
# 'final.pkl' and 'best.pkl' is loaded back into eval_model.
best_score = -1.0  # best *sum* of validation scores seen so far
current_train_items = [0] * len(FLAGS.algorithms)  # examples consumed, per algorithm
step = 0
next_eval = 0  # an evaluation runs as soon as step reaches this value
# Make sure scores improve on first step, but not overcome best score
# until all algos have had at least one evaluation.
val_scores = [-99999.9] * len(FLAGS.algorithms)
length_idx = 0

while step < FLAGS.train_steps:
    # One feedback batch per algorithm for this step.
    feedback_list = [next(t) for t in train_samplers]

    # Initialize model.
    if step == 0:
        all_features = [f.features for f in feedback_list]
        if FLAGS.chunked_training:
            # We need to initialize the model with samples of all lengths for
            # all algorithms. Also, we need to make sure that the order of these
            # sample sizes is the same as the order of the actual training sizes.
            # Note that building this list draws len(train_lengths) further
            # batches from every training sampler, and the last of them is
            # dropped by the [:-1] slice below.
            all_length_features = [all_features] + [
                [next(t).features for t in train_samplers]
                for _ in range(len(train_lengths))]
            train_model.init(all_length_features[:-1], FLAGS.seed + 1)
        else:
            train_model.init(all_features, FLAGS.seed + 1)

    # Refresh the model's l1 weight from the schedule before this step's updates.
    train_model.l1_weight = l1_weight_fn(step)
    # Training step.
    for algo_idx in range(len(train_samplers)):
        feedback = feedback_list[algo_idx]
        # Split the key: the first half is passed to feedback() below, the
        # second half becomes the key carried into the next iteration.
        rng_key, new_rng_key = jax.random.split(rng_key)
        if FLAGS.chunked_training:
            # In chunked training, we must indicate which training length we are
            # using, so the model uses the correct state.
            length_and_algo_idx = (length_idx, algo_idx)
        else:
            # In non-chunked training, all training lengths can be treated equally,
            # since there is no state to maintain between batches.
            length_and_algo_idx = algo_idx
        cur_loss = train_model.feedback(
            rng_key, feedback, length_and_algo_idx)
        rng_key = new_rng_key

        # Count how many examples this batch contributed for the algorithm.
        if FLAGS.chunked_training:
            examples_in_chunk = np.sum(feedback.features.is_last).item()
        else:
            examples_in_chunk = len(feedback.features.lengths)
        current_train_items[algo_idx] += examples_in_chunk
        # NOTE(review): the logging interval is hard-coded to 50 steps and is
        # independent of FLAGS.eval_every.
        if step % 50 == 0:
            logging.info('Algo %s step %i current loss %f, current_train_items %i, l1_weight %f.',
                        FLAGS.algorithms[algo_idx], step,
                        cur_loss, current_train_items[algo_idx], train_model.l1_weight)

    # Periodically evaluate model
    if step >= next_eval:
        # NOTE(review): when train_model and eval_model are bound to the same
        # object (as done where the model is constructed), this is a
        # self-assignment.
        eval_model.params = train_model.params
        for algo_idx in range(len(train_samplers)):
            # Extra fields passed to collect_and_eval; they show up in the
            # logged stats dict.
            common_extras = {'examples_seen': current_train_items[algo_idx],
                             'step': step,
                             'algorithm': FLAGS.algorithms[algo_idx]}

            # Validation info.
            # Here the first half of the split is consumed by the evaluation
            # and the second half is carried forward.
            new_rng_key, rng_key = jax.random.split(rng_key)
            val_stats = collect_and_eval(
                val_samplers[algo_idx],
                functools.partial(eval_model.predict,
                                  algorithm_index=algo_idx),
                val_sample_counts[algo_idx],
                new_rng_key,
                extras=common_extras)
            logging.info('(val) algo %s step %d: %s',
                         FLAGS.algorithms[algo_idx], step, val_stats)
            val_scores[algo_idx] = val_stats['score']

        next_eval += FLAGS.eval_every

        # If best total score, update best checkpoint.
        # Also save a best checkpoint on the first step.
        # The message is built before best_score is updated, so "best avg val
        # score was" reports the previous best (a sum, divided here by the
        # number of algorithms).
        msg = (f'best avg val score was '
               f'{best_score/len(FLAGS.algorithms):.3f}, '
               f'current avg val score is {np.mean(val_scores):.3f}, '
               f'val scores are: ')
        msg += ', '.join(
            ['%s: %.3f' % (x, y) for (x, y) in zip(FLAGS.algorithms, val_scores)])
        if (sum(val_scores) > best_score) or step == 0:
            best_score = sum(val_scores)
            logging.info('Checkpointing best model, %s', msg)
            train_model.save_model('best.pkl')
        else:
            logging.info('Not saving new best model, %s', msg)

    step += 1
    # Cycle through the training lengths; length_idx is only read in the
    # chunked-training branch above.
    length_idx = (length_idx + 1) % len(train_lengths)

# NOTE(review): `msg` is the text built during the most recent evaluation; it is
# not recomputed here, and it would be undefined if the loop body never ran.
logging.info('Checkpointing final model, %s', msg)
train_model.save_model('final.pkl')

# Continue the notebook with the best-scoring checkpoint rather than the final
# weights; the commented lines below switch to the final checkpoint instead.
logging.info('Restoring best model from checkpoint...')
eval_model.restore_model('best.pkl', only_load_processor=False)
# logging.info('Restoring final model from checkpoint...')
# eval_model.restore_model('final.pkl', only_load_processor=False)

2023-03-27 15:50:44,629 - root - INFO - Algo binary_search step 0 current loss 7.124762, current_train_items 32, l1_weight 0.010000.
2023-03-27 15:50:50,685 - root - INFO - (val) algo binary_search step 0: {'return': 0.127197265625, 'score': 0.127197265625, 'examples_seen': 32, 'step': 0, 'algorithm': 'binary_search'}
2023-03-27 15:50:50,686 - root - INFO - Checkpointing best model, best avg val score was -1.000, current avg val score is 0.127, val scores are: binary_search: 0.127
2023-03-27 15:51:07,568 - root - INFO - Algo binary_search step 50 current loss 2.920558, current_train_items 1632, l1_weight 0.010000.
2023-03-27 15:51:13,170 - root - INFO - (val) algo binary_search step 50: {'return': 0.533203125, 'score': 0.533203125, 'examples_seen': 1632, 'step': 50, 'algorithm': 'binary_search'}
2023-03-27 15:51:13,171 - root - INFO - Checkpointing best model, best avg val score was 0.127, current avg val score is 0.533, val scores are: binary_search: 0.533
2023-03-27 15:51:14,657 - ro

In [13]:
def restore_model(model, file_name):
    """Load parameters and optimizer state from a pickle file into `model`.

    Args:
      model: object whose `params` and `opt_state` attributes are overwritten.
        (Inside this function the name shadows the imported `model` module.)
      file_name: path of a pickle file containing a dict with 'params' and
        'opt_state' entries.

    Note:
      `pickle.load` can execute arbitrary code while loading; only restore
      checkpoints that come from a trusted source.
    """
    with open(file_name, 'rb') as checkpoint_file:
        checkpoint = pickle.load(checkpoint_file)
    # The file is closed by now; everything needed lives in `checkpoint`.
    model.params = hk.data_structures.merge(checkpoint['params'])
    model.opt_state = checkpoint['opt_state']

def save_model(model, file_name):
    """Pickle `model.params` and `model.opt_state` to `file_name`.

    Both attributes are written as-is under the keys 'params' and 'opt_state';
    this function applies no filtering (for example, it does not restrict the
    checkpoint to processor weights).

    Args:
      model: object exposing `params` and `opt_state` attributes.
      file_name: destination path; an existing file is overwritten.
    """
    checkpoint = dict(params=model.params, opt_state=model.opt_state)
    with open(file_name, 'wb') as checkpoint_file:
        pickle.dump(checkpoint, checkpoint_file)

In [None]:
# save_model(eval_model, 'eval_model_1e-3.pkl')

In [None]:
# restore_model(eval_model, 'eval_model_asdf.pkl')

In [14]:
# Score the current eval_model on the validation sampler of the first
# algorithm, with no extra fields merged into the stats.
algo_idx = 0
common_extras = {}

# First half of the split is used for this evaluation, second half is carried.
new_rng_key, rng_key = jax.random.split(rng_key)
_val_predict_fn = functools.partial(eval_model.predict, algorithm_index=algo_idx)
val_stats = collect_and_eval(
    val_samplers[algo_idx],
    _val_predict_fn,
    val_sample_counts[algo_idx],
    new_rng_key,
    extras=common_extras,
)

print(val_stats)

{'return': 0.8916015625, 'score': 0.8916015625}


In [15]:
# Score the current eval_model on the test sampler. This cell relies on
# `algo_idx` and `common_extras` already being defined in the notebook
# namespace.
new_rng_key, rng_key = jax.random.split(rng_key)
_test_predict_fn = functools.partial(eval_model.predict, algorithm_index=algo_idx)
test_stats = collect_and_eval(
    test_samplers[algo_idx],
    _test_predict_fn,
    test_sample_counts[algo_idx],
    new_rng_key,
    extras=common_extras,
)

print(test_stats)

{'return': 0.419921875, 'score': 0.419921875}


# Inspect inputs

In [16]:
# Draw one batch from the first validation sampler for inspection. This rebinds
# the notebook-level name `feedback` and advances that sampler by one batch.
feedback = next(val_samplers[0])

In [17]:
# Show the input DataPoints of the batch; the repr lists each one's name,
# location, type and array shape.
feedback.features.inputs

(DataPoint(name="pos",	location=node,	type=scalar,	data=Array(32, 16)),
 DataPoint(name="key",	location=node,	type=scalar,	data=Array(32, 16)),
 DataPoint(name="target",	location=graph,	type=scalar,	data=Array(32,)),
 DataPoint(name="pred",	location=node,	type=pointer,	data=Array(32, 16)))

In [18]:
# Raw values of the first input DataPoint (listed as "pos" in the recorded
# output of the previous cell).
# NOTE(review): this displays the whole array; slicing a few rows would keep
# the saved output short.
feedback.features.inputs[0].data

array([[0.03216641, 0.0887491 , 0.09141461, 0.14151338, 0.23066313,
        0.41496996, 0.44112883, 0.49271468, 0.6238474 , 0.62963874,
        0.63477921, 0.69024592, 0.77975909, 0.80418882, 0.94906204,
        0.9518903 ],
       [0.04055783, 0.09127235, 0.09635523, 0.13762894, 0.16389643,
        0.17309541, 0.20917447, 0.26774559, 0.38516773, 0.56366763,
        0.57524694, 0.69341529, 0.7526364 , 0.80916532, 0.82993744,
        0.90335583],
       [0.07683136, 0.15108231, 0.17222616, 0.18134717, 0.20762323,
        0.30275832, 0.53356587, 0.57324225, 0.58404184, 0.60452041,
        0.70197108, 0.75442818, 0.84153007, 0.92058679, 0.92677261,
        0.99021764],
       [0.19301804, 0.25978994, 0.27642331, 0.28177089, 0.30299001,
        0.30395177, 0.32041608, 0.35924949, 0.43099866, 0.61091876,
        0.63926805, 0.68065744, 0.78670552, 0.87959477, 0.90751811,
        0.98618097],
       [0.01880814, 0.2642722 , 0.26856692, 0.33684816, 0.45190865,
        0.50456194, 0.52807881, 

In [19]:
# Show the hint DataPoints of the same batch (name, location, type, shape).
feedback.features.hints

(DataPoint(name="low",	location=node,	type=mask_one,	data=Array(5, 32, 16)),
 DataPoint(name="high",	location=node,	type=mask_one,	data=Array(5, 32, 16)),
 DataPoint(name="mid",	location=node,	type=mask_one,	data=Array(5, 32, 16)))

# Get messages

In [20]:
# Run the helper `get_msgs` over the first algorithm's validation sampler
# (presumably collecting the model's messages — confirm in
# clrs_train_funcs.py) and keep the returned array for the cells below.
new_rng_key, rng_key = jax.random.split(rng_key)
_msg_predict_fn = functools.partial(eval_model.predict, algorithm_index=0)
batched_msgs = get_msgs(
    sampler=val_samplers[0],
    predict_fn=_msg_predict_fn,
    sample_count=val_sample_counts[0],
    rng_key=new_rng_key,
    sample_prob=None,
)

# Displayed as the cell result.
batched_msgs.shape

(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random samp

Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Thresholding: 3840 of 32768
Random sampling: 3840 of 32768
(32, 4, 16, 16, 32) 32 12
Threshol

(491520, 44)

In [21]:
# Size of the collected messages in megabytes (10**6 bytes).
# `nbytes` is the size of the array's data buffer. `sys.getsizeof` instead
# measures the Python object, which adds the array header and, for an array
# that does not own its buffer (e.g. a view), leaves the data out entirely.
print(batched_msgs.nbytes / 1000000)

173.015168


In [25]:
# Global NumPy print setting: suppress scientific notation for small floats.
np.set_printoptions(suppress=True)
# Show the first four rows of the message array as a tuple of 1-D arrays.
tuple(batched_msgs[row] for row in range(4))

(array([ 0.10069926,  0.02468263, -0.07856453,  0.03372188, -0.14106266,
        -0.1149468 , -0.02804468, -0.32331645, -0.07613028,  0.30354437,
         0.01924583, -0.41064122, -0.07536419, -0.44078687, -0.00260308,
         0.2742092 ,  0.22355607, -0.00355166, -0.08645447,  0.02887585,
        -0.99488211, -0.03649637,  0.16388178, -0.48418835, -0.00866786,
        -0.32873982, -0.17630227,  0.42477337,  0.33064786,  0.60467374,
         0.06773275, -0.01688233,  0.05084593,  0.00986762,  0.32295883,
         0.31027162,  0.75508583,  0.        ,  0.        ,  1.        ,
         0.        ,  0.        ,  0.        ,  0.        ]),
 array([ 0.19582251,  0.07174149, -0.20112674,  0.17868218,  0.08243616,
        -0.10620342, -0.0191057 , -0.25672385, -0.06998946,  0.09868141,
         0.05623136, -0.19035122,  0.07618684, -0.1113157 ,  0.02884557,
        -0.04922082, -0.05250286, -0.05462624, -0.21402006, -0.10652531,
        -0.95267791, -0.1359971 ,  0.11252074, -0.52002478, -0

In [None]:
# Write the collected message array to a pickle file; the relative path is
# resolved against the kernel's working directory.
file_name = 'binary_search_val_msgs.pkl'
with open(file_name, mode='wb') as msgs_file:
    pickle.dump(batched_msgs, msgs_file)