# Run training of CLRS algorithm

The main function is copied from `clrs_train.py`; the helper functions are located in `clrs_train_funcs.py`.

In [1]:
import functools
import os
import shutil
from typing import Any, Dict, List

import logging
import clrs
import jax
import numpy as np
import requests
import tensorflow as tf
import haiku as hk

import model
import flags
from clrs_train_funcs import *

In [2]:
# Alias the flag container from the imported `flags` module; the cells below
# read every run setting (hint mode, seed, train steps, ...) through it.
FLAGS = flags.FLAGS

In [3]:
# Configure the root logger: INFO level and above, written through a stream
# handler that prefixes each record with time, logger name and level.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Notebook cells get re-run. Without this cleanup every execution would attach
# one more handler to the root logger and each record would be printed once
# per execution. The handler is tagged with a name so that only the handler
# installed by this cell is removed; handlers added by other libraries stay.
_NOTEBOOK_HANDLER_NAME = 'clrs_notebook_stream'
for _stale_handler in [h for h in logger.handlers
                       if h.get_name() == _NOTEBOOK_HANDLER_NAME]:
    logger.removeHandler(_stale_handler)

ch = logging.StreamHandler()
ch.set_name(_NOTEBOOK_HANDLER_NAME)
ch.setFormatter(formatter)
logger.addHandler(ch)

In [4]:
# Translate the hint-mode flag into the two booleans handed to the model.
# Unknown modes are rejected up front.
hint_mode = FLAGS.hint_mode
if hint_mode not in ('encoded_decoded', 'decoded_only', 'none'):
    raise ValueError(
        'Hint mode not in {encoded_decoded, decoded_only, none}.')
# Hints are encoded only in 'encoded_decoded' mode and decoded in every mode
# except 'none'.
encode_hints = hint_mode == 'encoded_decoded'
decode_hints = hint_mode != 'none'

train_lengths = list(map(int, FLAGS.train_lengths))

# Seed NumPy from the flag, then draw the seed of the JAX key from that
# generator so both random streams derive from FLAGS.seed.
rng = np.random.RandomState(FLAGS.seed)
rng_key = jax.random.PRNGKey(rng.randint(2**32))

2023-03-08 02:13:06,764 - jax._src.lib.xla_bridge - INFO - Remote TPU is not linked into jax; skipping remote TPU.
2023-03-08 02:13:06,765 - jax._src.lib.xla_bridge - INFO - Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
2023-03-08 02:13:06,765 - jax._src.lib.xla_bridge - INFO - Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2023-03-08 02:13:06,765 - jax._src.lib.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2023-03-08 02:13:06,766 - jax._src.lib.xla_bridge - INFO - Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
2023-03-08 02:13:06,766 - jax._src.lib.xla_bridge - INFO - Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (d

In [5]:
# Create samplers.
# One call to `create_samplers` (presumably provided by the
# `clrs_train_funcs` star import) returns the train/validation/test samplers,
# the validation and test sample counts, and the spec list. The samplers and
# counts are indexed per algorithm by the cells below.
(train_samplers,
 val_samplers, val_sample_counts,
 test_samplers, test_sample_counts,
 spec_list) = create_samplers(rng, train_lengths)

2023-03-08 02:13:07,306 - root - INFO - Creating samplers for algo binary_search


Metal device set to: Apple M1 Pro


2023-03-08 02:13:07,957 - absl - INFO - Creating a dataset with 4096 samples.
2023-03-08 02:13:08,175 - absl - INFO - 1000 samples created
2023-03-08 02:13:08,302 - absl - INFO - 2000 samples created
2023-03-08 02:13:08,434 - absl - INFO - 3000 samples created
2023-03-08 02:13:08,563 - absl - INFO - 4000 samples created
2023-03-08 02:13:08,642 - root - INFO - Dataset found at /tmp/CLRS30/CLRS30_v1.0.0. Skipping download.
2023-03-08 02:13:08,643 - absl - INFO - Load dataset info from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/binary_search_test/1.0.0
2023-03-08 02:13:08,645 - absl - INFO - Load dataset info from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/binary_search_test/1.0.0
2023-03-08 02:13:08,646 - absl - INFO - Reusing dataset clrs_dataset (/tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/binary_search_test/1.0.0)
2023-03-08 02:13:08,646 - absl - INFO - Constructing tf.data.Dataset clrs_dataset for split test, from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/binary_search_test/1.0.0


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
2023-03-08 02:13:08.842100: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [6]:
# Build the processor factory from the processor-related flags.
processor_factory = model.get_processor_factory(
    FLAGS.processor_type,
    use_ln=FLAGS.use_ln,
    nb_triplet_fts=FLAGS.nb_triplet_fts,
    nb_heads=FLAGS.nb_heads,
)

# Keyword arguments for the model constructor: the processor factory, the
# hint switches computed earlier, and the remaining values read from FLAGS.
model_params = {
    'processor_factory': processor_factory,
    'hidden_dim': FLAGS.hidden_size,
    'encode_hints': encode_hints,
    'decode_hints': decode_hints,
    'encoder_init': FLAGS.encoder_init,
    'use_lstm': FLAGS.use_lstm,
    'learning_rate': FLAGS.learning_rate,
    'grad_clip_max_norm': FLAGS.grad_clip_max_norm,
    'checkpoint_path': FLAGS.checkpoint_path,
    'freeze_processor': FLAGS.freeze_processor,
    'dropout_prob': FLAGS.dropout_prob,
    'hint_teacher_forcing': FLAGS.hint_teacher_forcing,
    'hint_repred_mode': FLAGS.hint_repred_mode,
    'nb_msg_passing_steps': FLAGS.nb_msg_passing_steps,
    'l1_weight': FLAGS.l1_weight,
}

# Chunked training is never used in this notebook, so no separate chunked
# training model is built: one model instance serves both training and
# evaluation. The dummy trajectory takes one batch from every validation
# sampler.
eval_model = train_model = model.BaselineMsgModel(
    spec=spec_list,
    dummy_trajectory=list(map(next, val_samplers)),
    **model_params
)

In [11]:
# Training loop.
num_algos = len(FLAGS.algorithms)
best_score = -1.0
step = next_eval = length_idx = 0
current_train_items = [0] * num_algos
# Start every per-algorithm score far below any real score: the summed score
# still improves at the first evaluation, but cannot beat a recorded best
# until every algorithm has been evaluated at least once.
val_scores = [-99999.9] * num_algos

while step < FLAGS.train_steps:
    feedback_list = [next(sampler) for sampler in train_samplers]

    # Initialize model on the very first batch.
    if step == 0:
        all_features = [fb.features for fb in feedback_list]
        if not FLAGS.chunked_training:
            train_model.init(all_features, FLAGS.seed + 1)
        else:
            # Chunked training needs samples of every training length for
            # every algorithm, drawn in the same order the lengths are used
            # during training.
            extra_length_features = [
                [next(sampler).features for sampler in train_samplers]
                for _ in range(len(train_lengths))]
            all_length_features = [all_features] + extra_length_features
            train_model.init(all_length_features[:-1], FLAGS.seed + 1)

    # Training step: one update per algorithm.
    for algo_idx, feedback in enumerate(feedback_list):
        # The first half of the split drives this update, the second half
        # becomes the running key.
        feedback_key, rng_key = jax.random.split(rng_key)
        # Chunked training keeps state per training length, so the model must
        # be told which length is in use; otherwise the algorithm index alone
        # is enough because no state is carried between batches.
        length_and_algo_idx = (
            (length_idx, algo_idx) if FLAGS.chunked_training else algo_idx)
        cur_loss = train_model.feedback(
            feedback_key, feedback, length_and_algo_idx)

        examples_in_chunk = (
            np.sum(feedback.features.is_last).item()
            if FLAGS.chunked_training
            else len(feedback.features.lengths))
        current_train_items[algo_idx] += examples_in_chunk
        logging.info('Algo %s step %i current loss %f, current_train_items %i.',
                     FLAGS.algorithms[algo_idx], step,
                     cur_loss, current_train_items[algo_idx])

    # Periodically evaluate model.
    if step >= next_eval:
        eval_model.params = train_model.params
        for algo_idx in range(len(train_samplers)):
            algo_name = FLAGS.algorithms[algo_idx]
            common_extras = {'examples_seen': current_train_items[algo_idx],
                             'step': step,
                             'algorithm': algo_name}

            # Validation info. The first half of the split is consumed by the
            # evaluation, the second half becomes the running key.
            eval_key, rng_key = jax.random.split(rng_key)
            predict_fn = functools.partial(eval_model.predict,
                                           algorithm_index=algo_idx)
            val_stats = collect_and_eval(
                val_samplers[algo_idx],
                predict_fn,
                val_sample_counts[algo_idx],
                eval_key,
                extras=common_extras)
            logging.info('(val) algo %s step %d: %s',
                         algo_name, step, val_stats)
            val_scores[algo_idx] = val_stats['score']

        next_eval += FLAGS.eval_every

        # Checkpoint whenever the summed validation score beats the best one
        # so far; the first step always checkpoints.
        total_score = sum(val_scores)
        score_summary = ', '.join(
            '%s: %.3f' % pair for pair in zip(FLAGS.algorithms, val_scores))
        msg = (
            f'best avg val score was {best_score / num_algos:.3f}, '
            f'current avg val score is {np.mean(val_scores):.3f}, '
            f'val scores are: '
            + score_summary)
        if step == 0 or total_score > best_score:
            best_score = total_score
            logging.info('Checkpointing best model, %s', msg)
            train_model.save_model('best.pkl')
        else:
            logging.info('Not saving new best model, %s', msg)

    step += 1
    length_idx = (length_idx + 1) % len(train_lengths)

logging.info('Restoring best model from checkpoint...')
eval_model.restore_model('best.pkl', only_load_processor=False)

2023-03-08 00:34:14,765 - root - INFO - Algo binary_search step 0 current loss 26.699396, current_train_items 32.
2023-03-08 00:34:17,156 - root - INFO - (val) algo binary_search step 0: {'return': 0.065673828125, 'score': 0.065673828125, 'examples_seen': 32, 'step': 0, 'algorithm': 'binary_search'}
2023-03-08 00:34:17,157 - root - INFO - Checkpointing best model, best avg val score was -1.000, current avg val score is 0.066, val scores are: binary_search: 0.066
2023-03-08 00:34:20,726 - root - INFO - Algo binary_search step 1 current loss 29.408669, current_train_items 64.
2023-03-08 00:34:24,353 - root - INFO - Algo binary_search step 2 current loss 33.881187, current_train_items 96.
2023-03-08 00:34:27,933 - root - INFO - Algo binary_search step 3 current loss 30.668726, current_train_items 128.
2023-03-08 00:34:31,602 - root - INFO - Algo binary_search step 4 current loss 28.633110, current_train_items 160.
2023-03-08 00:34:31,619 - root - INFO - Algo binary_search step 5 current l

2023-03-08 00:34:37,505 - root - INFO - Algo binary_search step 65 current loss 3.746989, current_train_items 2112.
2023-03-08 00:34:37,538 - root - INFO - Algo binary_search step 66 current loss 4.623101, current_train_items 2144.
2023-03-08 00:34:37,609 - root - INFO - Algo binary_search step 67 current loss 5.608706, current_train_items 2176.
2023-03-08 00:34:37,693 - root - INFO - Algo binary_search step 68 current loss 6.138997, current_train_items 2208.
2023-03-08 00:34:37,801 - root - INFO - Algo binary_search step 69 current loss 7.137094, current_train_items 2240.
2023-03-08 00:34:37,814 - root - INFO - Algo binary_search step 70 current loss 2.691839, current_train_items 2272.
2023-03-08 00:34:37,848 - root - INFO - Algo binary_search step 71 current loss 4.469884, current_train_items 2304.
2023-03-08 00:34:37,917 - root - INFO - Algo binary_search step 72 current loss 4.819474, current_train_items 2336.
2023-03-08 00:34:37,999 - root - INFO - Algo binary_search step 73 curre

2023-03-08 00:34:43,892 - root - INFO - Algo binary_search step 133 current loss 3.362348, current_train_items 4288.
2023-03-08 00:34:44,000 - root - INFO - Algo binary_search step 134 current loss 3.238188, current_train_items 4320.
2023-03-08 00:34:44,013 - root - INFO - Algo binary_search step 135 current loss 1.427442, current_train_items 4352.
2023-03-08 00:34:44,046 - root - INFO - Algo binary_search step 136 current loss 1.516623, current_train_items 4384.
2023-03-08 00:34:44,113 - root - INFO - Algo binary_search step 137 current loss 3.138206, current_train_items 4416.
2023-03-08 00:34:44,197 - root - INFO - Algo binary_search step 138 current loss 3.622734, current_train_items 4448.
2023-03-08 00:34:44,303 - root - INFO - Algo binary_search step 139 current loss 4.602278, current_train_items 4480.
2023-03-08 00:34:44,315 - root - INFO - Algo binary_search step 140 current loss 1.105246, current_train_items 4512.
2023-03-08 00:34:44,347 - root - INFO - Algo binary_search step 

2023-03-08 00:34:50,113 - root - INFO - Algo binary_search step 200 current loss 1.320550, current_train_items 6432.
2023-03-08 00:34:52,173 - root - INFO - (val) algo binary_search step 200: {'return': 0.543701171875, 'score': 0.543701171875, 'examples_seen': 6432, 'step': 200, 'algorithm': 'binary_search'}
2023-03-08 00:34:52,174 - root - INFO - Not saving new best model, best avg val score was 0.691, current avg val score is 0.544, val scores are: binary_search: 0.544
2023-03-08 00:34:52,207 - root - INFO - Algo binary_search step 201 current loss 2.350540, current_train_items 6464.
2023-03-08 00:34:52,276 - root - INFO - Algo binary_search step 202 current loss 3.514669, current_train_items 6496.
2023-03-08 00:34:52,358 - root - INFO - Algo binary_search step 203 current loss 2.819115, current_train_items 6528.
2023-03-08 00:34:52,489 - root - INFO - Algo binary_search step 204 current loss 3.543144, current_train_items 6560.
2023-03-08 00:34:52,503 - root - INFO - Algo binary_sear

2023-03-08 00:34:58,284 - root - INFO - Algo binary_search step 264 current loss 3.085880, current_train_items 8480.
2023-03-08 00:34:58,297 - root - INFO - Algo binary_search step 265 current loss 0.818194, current_train_items 8512.
2023-03-08 00:34:58,329 - root - INFO - Algo binary_search step 266 current loss 1.106221, current_train_items 8544.
2023-03-08 00:34:58,422 - root - INFO - Algo binary_search step 267 current loss 1.694093, current_train_items 8576.
2023-03-08 00:34:58,504 - root - INFO - Algo binary_search step 268 current loss 2.146522, current_train_items 8608.
2023-03-08 00:34:58,611 - root - INFO - Algo binary_search step 269 current loss 2.601979, current_train_items 8640.
2023-03-08 00:34:58,625 - root - INFO - Algo binary_search step 270 current loss 0.520964, current_train_items 8672.
2023-03-08 00:34:58,657 - root - INFO - Algo binary_search step 271 current loss 1.433023, current_train_items 8704.
2023-03-08 00:34:58,726 - root - INFO - Algo binary_search step 

2023-03-08 00:35:04,439 - root - INFO - Algo binary_search step 331 current loss 1.295012, current_train_items 10624.
2023-03-08 00:35:04,507 - root - INFO - Algo binary_search step 332 current loss 2.418581, current_train_items 10656.
2023-03-08 00:35:04,590 - root - INFO - Algo binary_search step 333 current loss 3.245784, current_train_items 10688.
2023-03-08 00:35:04,698 - root - INFO - Algo binary_search step 334 current loss 3.150851, current_train_items 10720.
2023-03-08 00:35:04,710 - root - INFO - Algo binary_search step 335 current loss 0.589139, current_train_items 10752.
2023-03-08 00:35:04,744 - root - INFO - Algo binary_search step 336 current loss 1.220576, current_train_items 10784.
2023-03-08 00:35:04,813 - root - INFO - Algo binary_search step 337 current loss 2.042824, current_train_items 10816.
2023-03-08 00:35:04,894 - root - INFO - Algo binary_search step 338 current loss 2.309875, current_train_items 10848.
2023-03-08 00:35:05,001 - root - INFO - Algo binary_sear

2023-03-08 00:35:12,004 - root - INFO - Algo binary_search step 398 current loss 2.076700, current_train_items 12768.
2023-03-08 00:35:12,118 - root - INFO - Algo binary_search step 399 current loss 3.815567, current_train_items 12800.
2023-03-08 00:35:12,119 - root - INFO - Restoring best model from checkpoint...


In [7]:
import pickle

def restore_model(model, file_name):
    """Load parameters and optimizer state from a pickle file into `model`.

    The file must contain a dict with the keys 'params' and 'opt_state'. The
    stored parameters are passed through `hk.data_structures.merge` before
    being assigned to `model.params`; the optimizer state is assigned as is.

    Args:
        model: Object whose `params` and `opt_state` attributes are
            overwritten.
        file_name: Path of the pickle file to read.
    """
    # NOTE: pickle can execute arbitrary code while loading; only restore
    # checkpoint files that you created yourself.
    with open(file_name, 'rb') as checkpoint_file:
        checkpoint = pickle.load(checkpoint_file)
    model.params = hk.data_structures.merge(checkpoint['params'])
    model.opt_state = checkpoint['opt_state']

def save_model(model, file_name):
    """Pickle the parameters and optimizer state of `model` to `file_name`.

    The file holds a dict with the keys 'params' and 'opt_state', taken from
    the attributes of the same names on `model`.

    Args:
        model: Object exposing `params` and `opt_state` attributes.
        file_name: Path of the pickle file to write; an existing file is
            overwritten.
    """
    # Read both attributes before opening the file so that a missing
    # attribute does not leave an empty file behind.
    checkpoint = dict(params=model.params, opt_state=model.opt_state)
    with open(file_name, 'wb') as checkpoint_file:
        pickle.dump(checkpoint, checkpoint_file)

In [13]:
# save_model(eval_model, 'eval_model_asdf.pkl')

In [8]:
# Load previously saved parameters and optimizer state into `eval_model`.
# NOTE(review): the matching `save_model` call above is commented out, so
# this assumes 'eval_model_asdf.pkl' already exists on disk — confirm.
restore_model(eval_model, 'eval_model_asdf.pkl')

In [9]:
# Evaluate the model on the validation sampler of the first algorithm.
# `algo_idx` and `common_extras` are also reused by the test-evaluation cell.
algo_idx = 0
common_extras = {}

# The first half of the split is used for this evaluation, the second half
# stays as the running key.
new_rng_key, rng_key = jax.random.split(rng_key)
val_predict_fn = functools.partial(eval_model.predict, algorithm_index=algo_idx)
val_stats = collect_and_eval(
    val_samplers[algo_idx],
    val_predict_fn,
    val_sample_counts[algo_idx],
    new_rng_key,
    extras=common_extras,
)

val_stats

4


{'return': 0.766357421875, 'score': 0.766357421875}

In [10]:
# Evaluate the model on the test sampler, reusing `algo_idx` and
# `common_extras` from the validation cell. A fresh key is split off for this
# evaluation and the remainder is kept as the running key.
new_rng_key, rng_key = jax.random.split(rng_key)
test_predict_fn = functools.partial(eval_model.predict, algorithm_index=algo_idx)
test_stats = collect_and_eval(
    test_samplers[algo_idx],
    test_predict_fn,
    test_sample_counts[algo_idx],
    new_rng_key,
    extras=common_extras,
)

test_stats

(1, 32, 64, 64, 128)
6
(5, 1, 32, 64, 64, 128)
(7, 1, 32, 64, 64, 128)


{'return': 0.33544921875, 'score': 0.33544921875}

In [17]:
# Draw one batch from the first validation sampler; the cells below inspect it.
feedback = next(val_samplers[0])

In [18]:
# Attribute names of `feedback`, leaving out those that start with '__'.
[attr for attr in dir(feedback) if not attr.startswith('__')]

['_asdict',
 '_field_defaults',
 '_fields',
 '_make',
 '_replace',
 'count',
 'features',
 'index',
 'outputs']

In [28]:
# Attribute names of `feedback.features`, leaving out those that start with '__'.
[attr for attr in dir(feedback.features) if not attr.startswith('__')]

['_asdict',
 '_field_defaults',
 '_fields',
 '_make',
 '_replace',
 'count',
 'hints',
 'index',
 'inputs',
 'lengths']

In [42]:
# Display the feature container of the batch (its inputs, hints and lengths).
feedback.features

Features(inputs=(DataPoint(name="pos",	location=node,	type=scalar,	data=Array(32, 16)), DataPoint(name="key",	location=node,	type=scalar,	data=Array(32, 16)), DataPoint(name="target",	location=graph,	type=scalar,	data=Array(32,)), DataPoint(name="pred",	location=node,	type=pointer,	data=Array(32, 16))), hints=(DataPoint(name="low",	location=node,	type=mask_one,	data=Array(5, 32, 16)), DataPoint(name="high",	location=node,	type=mask_one,	data=Array(5, 32, 16)), DataPoint(name="mid",	location=node,	type=mask_one,	data=Array(5, 32, 16))), lengths=array([5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5.,
       5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5.]))

In [37]:
# Display the input data points of the batch.
feedback.features.inputs

(DataPoint(name="pos",	location=node,	type=scalar,	data=Array(32, 16)),
 DataPoint(name="key",	location=node,	type=scalar,	data=Array(32, 16)),
 DataPoint(name="target",	location=graph,	type=scalar,	data=Array(32,)),
 DataPoint(name="pred",	location=node,	type=pointer,	data=Array(32, 16)))

In [43]:
# Raw data of the third input data point (named "target" in the listing of
# `feedback.features.inputs` above).
feedback.features.inputs[2].data

array([0.14317903, 0.07972956, 0.67909305, 0.98567252, 0.93329147,
       0.28345592, 0.43966783, 0.07153851, 0.12399225, 0.76594674,
       0.36091068, 0.97857869, 0.28576997, 0.40300221, 0.26544856,
       0.52667835, 0.12452213, 0.07972956, 0.93578709, 0.51574709,
       0.3870473 , 0.58158791, 0.68471948, 0.92190682, 0.18154095,
       0.38265265, 0.65144341, 0.42893379, 0.05218584, 0.86252671,
       0.059794  , 0.44887807])

In [48]:
# Display the hint data points of the batch.
feedback.features.hints

(DataPoint(name="low",	location=node,	type=mask_one,	data=Array(5, 32, 16)),
 DataPoint(name="high",	location=node,	type=mask_one,	data=Array(5, 32, 16)),
 DataPoint(name="mid",	location=node,	type=mask_one,	data=Array(5, 32, 16)))

In [20]:
# For each row of the first input data point (named "pos" in the listing
# above), the index of its smallest value along axis 1.
np.argmin(feedback.features.inputs[0].data, axis=1)

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [21]:
# Indices of the entries equal to 1 in the first output data point, returned
# as one index array per dimension.
np.where(feedback.outputs[0].data == 1)

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]),
 array([ 2,  1, 11, 15, 14,  3,  9,  0,  3, 11,  5, 15,  6,  5,  6,  6,  3,
         1, 14,  9,  7, 12, 10, 13,  2,  5,  9, 10,  1, 14,  1,  8]))

In [10]:
# Run `get_msgs` on the first test sampler with a freshly split key; the
# remainder of the split stays as the running key.
new_rng_key, rng_key = jax.random.split(rng_key)
# Deliberately capped at 32*8 samples: asking for the full
# test_sample_counts[0] made memory usage explode.
num_msg_samples = 32*8
msg_predict_fn = functools.partial(eval_model.predict, algorithm_index=0)
batched_msgs = get_msgs(
    test_samplers[0],
    msg_predict_fn,
    num_msg_samples,
    new_rng_key,
)

batched_msgs.shape

0
6
32
64
96
128
160
192
224


(256, 6, 64, 64, 128)

In [57]:
# Persist the collected messages to disk with pickle.
# The array produced by the `get_msgs` cell is named `batched_msgs`; the
# previous name `msgs` is not defined anywhere in this notebook, so the cell
# failed with a NameError when the notebook was run top to bottom.
# NOTE(review): the file name says "val", but `batched_msgs` is collected from
# test_samplers[0] — confirm which split this file is meant to hold.
file_name = 'binary_search_val_msgs.pkl'
with open(file_name, 'wb') as f:
    pickle.dump(batched_msgs, f)