# Run training of CLRS algorithm

The main function is copied from `clrs_train.py`; the helper functions it relies on are located in `clrs_train_funcs.py`.

In [1]:
import functools
import os
import shutil
from typing import Any, Dict, List

import logging
import clrs
import jax
import jax.numpy as jnp
import numpy as np
import requests
import tensorflow as tf
import haiku as hk

import model
import flags
from clrs_train_funcs import *

In [2]:
FLAGS = flags.FLAGS

In [3]:
# Configure the root logger so INFO-level (and higher) records are emitted.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Every record is rendered as "<time> - <logger name> - <level> - <message>".
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
# NOTE: each run of this cell attaches one more stream handler to the root
# logger, so re-running it makes every message appear once per handler.
logger.addHandler(ch)

In [4]:
# Map each supported hint mode to its (encode_hints, decode_hints) switches.
hint_switches = {
    'encoded_decoded': (True, True),
    'decoded_only': (False, True),
    'none': (False, False),
}
if FLAGS.hint_mode not in hint_switches:
    raise ValueError(
        'Hint mode not in {encoded_decoded, decoded_only, none}.')
encode_hints, decode_hints = hint_switches[FLAGS.hint_mode]

# The flag holds the training lengths as individual values; coerce each to int.
train_lengths = list(map(int, FLAGS.train_lengths))

# NumPy generator seeded from the flag; it also draws the seed of the JAX key.
rng = np.random.RandomState(FLAGS.seed)
rng_key = jax.random.PRNGKey(rng.randint(2**32))

2023-03-23 10:55:41,982 - jax._src.lib.xla_bridge - INFO - Remote TPU is not linked into jax; skipping remote TPU.
2023-03-23 10:55:41,982 - jax._src.lib.xla_bridge - INFO - Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
2023-03-23 10:55:41,983 - jax._src.lib.xla_bridge - INFO - Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2023-03-23 10:55:41,985 - jax._src.lib.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2023-03-23 10:55:41,986 - jax._src.lib.xla_bridge - INFO - Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
2023-03-23 10:55:41,990 - jax._src.lib.xla_bridge - INFO - Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (d

In [5]:
# Create samplers: training, validation and test samplers, the number of
# samples behind the validation/test samplers, and the list of specs.
# `create_samplers` is not defined in this notebook; presumably it comes from
# the star-import of clrs_train_funcs.
(train_samplers,
 val_samplers, val_sample_counts,
 test_samplers, test_sample_counts,
 spec_list) = create_samplers(rng, train_lengths)

2023-03-23 10:55:42,032 - root - INFO - Creating samplers for algo binary_search


Metal device set to: Apple M1 Pro


2023-03-23 10:55:42,655 - absl - INFO - Creating a dataset with 4096 samples.
2023-03-23 10:55:42,871 - absl - INFO - 1000 samples created
2023-03-23 10:55:42,997 - absl - INFO - 2000 samples created
2023-03-23 10:55:43,123 - absl - INFO - 3000 samples created
2023-03-23 10:55:43,246 - absl - INFO - 4000 samples created
2023-03-23 10:55:43,325 - root - INFO - Dataset found at /tmp/CLRS30/CLRS30_v1.0.0. Skipping download.
2023-03-23 10:55:43,326 - absl - INFO - Load dataset info from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/binary_search_test/1.0.0
2023-03-23 10:55:43,330 - absl - INFO - Load dataset info from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/binary_search_test/1.0.0
2023-03-23 10:55:43,331 - absl - INFO - Reusing dataset clrs_dataset (/tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/binary_search_test/1.0.0)
2023-03-23 10:55:43,332 - absl - INFO - Constructing tf.data.Dataset clrs_dataset for split test, from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/binary_search_test/1.0.0


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
2023-03-23 10:55:43.520796: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [6]:
# Override flag values for this run. These must be set before the model is
# built and the training loop starts, since both read them from FLAGS.
FLAGS.l1_weight = 0.001  # passed to the model as `l1_weight`
FLAGS.train_steps = 1000  # number of iterations of the training loop
FLAGS.hidden_size = 32  # passed to the model as `hidden_dim`
FLAGS.msg_size = 64  # passed to the model as `msg_dim`

In [7]:
# FLAGS.hidden_size = 8
# FLAGS.algorithms = ['dijkstra']

In [8]:
# Build the processor factory selected by FLAGS.processor_type.
processor_factory = model.get_processor_factory(
    FLAGS.processor_type,
    use_ln=FLAGS.use_ln,
    nb_triplet_fts=FLAGS.nb_triplet_fts,
    nb_heads=FLAGS.nb_heads
)
# Keyword arguments for the model constructor, taken from the flags and from
# the hint switches derived from FLAGS.hint_mode.
model_params = dict(
    processor_factory=processor_factory,
    hidden_dim=FLAGS.hidden_size,
    msg_dim=FLAGS.msg_size,
    encode_hints=encode_hints,
    decode_hints=decode_hints,
    encoder_init=FLAGS.encoder_init,
    use_lstm=FLAGS.use_lstm,
    learning_rate=FLAGS.learning_rate,
    grad_clip_max_norm=FLAGS.grad_clip_max_norm,
    checkpoint_path=FLAGS.checkpoint_path,
    freeze_processor=FLAGS.freeze_processor,
    dropout_prob=FLAGS.dropout_prob,
    hint_teacher_forcing=FLAGS.hint_teacher_forcing,
    hint_repred_mode=FLAGS.hint_repred_mode,
    nb_msg_passing_steps=FLAGS.nb_msg_passing_steps,
    l1_weight=FLAGS.l1_weight
)

# The dummy trajectory is one batch drawn from every validation sampler
# (this consumes one batch from each of them).
eval_model = model.BaselineMsgModel(
    spec=spec_list,
    dummy_trajectory=[next(t) for t in val_samplers],
    **model_params
)
# # We never use chunked training in this notebook.
# if FLAGS.chunked_training:
#     train_model = clrs.models.BaselineModelChunked(
#         spec=spec_list,
#         dummy_trajectory=[next(t) for t in train_samplers],
#         **model_params
#     )
# else:
#     train_model = eval_model
# Training and evaluation share one and the same model object.
train_model = eval_model

In [9]:
# Training loop.
# `best_score` holds the SUM of the validation scores at the last checkpoint.
best_score = -1.0
# Number of training examples consumed so far, one counter per algorithm.
current_train_items = [0] * len(FLAGS.algorithms)
step = 0
# Evaluation runs whenever `step` has reached this value.
next_eval = 0
# Make sure scores improve on first step, but not overcome best score
# until all algos have had at least one evaluation.
val_scores = [-99999.9] * len(FLAGS.algorithms)
# Index of the current training length; only used by the chunked branches.
length_idx = 0

while step < FLAGS.train_steps:
    # Draw one training batch from every sampler.
    feedback_list = [next(t) for t in train_samplers]

    # Initialize model.
    if step == 0:
        all_features = [f.features for f in feedback_list]
        if FLAGS.chunked_training:
            # We need to initialize the model with samples of all lengths for
            # all algorithms. Also, we need to make sure that the order of these
            # sample sizes is the same as the order of the actual training sizes.
            all_length_features = [all_features] + [
                [next(t).features for t in train_samplers]
                for _ in range(len(train_lengths))]
            train_model.init(all_length_features[:-1], FLAGS.seed + 1)
        else:
            train_model.init(all_features, FLAGS.seed + 1)

    # Training step.
    for algo_idx in range(len(train_samplers)):
        feedback = feedback_list[algo_idx]
        # One half of the split is consumed by this update; the other half
        # becomes `rng_key` once the update is done (see below).
        rng_key, new_rng_key = jax.random.split(rng_key)
        if FLAGS.chunked_training:
            # In chunked training, we must indicate which training length we are
            # using, so the model uses the correct state.
            length_and_algo_idx = (length_idx, algo_idx)
        else:
            # In non-chunked training, all training lengths can be treated equally,
            # since there is no state to maintain between batches.
            length_and_algo_idx = algo_idx
        cur_loss = train_model.feedback(
            rng_key, feedback, length_and_algo_idx)
        rng_key = new_rng_key

        # Count the examples contained in this batch.
        if FLAGS.chunked_training:
            examples_in_chunk = np.sum(feedback.features.is_last).item()
        else:
            examples_in_chunk = len(feedback.features.lengths)
        current_train_items[algo_idx] += examples_in_chunk
        logging.info('Algo %s step %i current loss %f, current_train_items %i.',
                     FLAGS.algorithms[algo_idx], step,
                     cur_loss, current_train_items[algo_idx])

    # Periodically evaluate model
    if step >= next_eval:
        # Copy the current training parameters into the evaluation model.
        eval_model.params = train_model.params
        for algo_idx in range(len(train_samplers)):
            # Extra fields handed to collect_and_eval alongside the samples.
            common_extras = {'examples_seen': current_train_items[algo_idx],
                             'step': step,
                             'algorithm': FLAGS.algorithms[algo_idx]}

            # Validation info.
            # Fresh key for this evaluation; the other half stays in `rng_key`.
            new_rng_key, rng_key = jax.random.split(rng_key)
            val_stats = collect_and_eval(
                val_samplers[algo_idx],
                functools.partial(eval_model.predict,
                                  algorithm_index=algo_idx),
                val_sample_counts[algo_idx],
                new_rng_key,
                extras=common_extras)
            logging.info('(val) algo %s step %d: %s',
                         FLAGS.algorithms[algo_idx], step, val_stats)
            val_scores[algo_idx] = val_stats['score']

        next_eval += FLAGS.eval_every

        # If best total score, update best checkpoint.
        # Also save a best checkpoint on the first step.
        # `best_score` is a sum over algorithms, hence the division when the
        # message reports it as an average.
        msg = (f'best avg val score was '
               f'{best_score/len(FLAGS.algorithms):.3f}, '
               f'current avg val score is {np.mean(val_scores):.3f}, '
               f'val scores are: ')
        msg += ', '.join(
            ['%s: %.3f' % (x, y) for (x, y) in zip(FLAGS.algorithms, val_scores)])
        if (sum(val_scores) > best_score) or step == 0:
            best_score = sum(val_scores)
            logging.info('Checkpointing best model, %s', msg)
            train_model.save_model('best.pkl')
        else:
            logging.info('Not saving new best model, %s', msg)

    step += 1
    # Cycle through the training lengths, one per step.
    length_idx = (length_idx + 1) % len(train_lengths)

# After training, load the checkpoint written by the last "best" evaluation.
logging.info('Restoring best model from checkpoint...')
eval_model.restore_model('best.pkl', only_load_processor=False)

2023-03-23 10:55:47,758 - root - INFO - Algo binary_search step 0 current loss 6.239127, current_train_items 32.
2023-03-23 10:55:50,400 - root - INFO - (val) algo binary_search step 0: {'return': 0.090087890625, 'score': 0.090087890625, 'examples_seen': 32, 'step': 0, 'algorithm': 'binary_search'}
2023-03-23 10:55:50,400 - root - INFO - Checkpointing best model, best avg val score was -1.000, current avg val score is 0.090, val scores are: binary_search: 0.090
2023-03-23 10:55:53,668 - root - INFO - Algo binary_search step 1 current loss 8.979127, current_train_items 64.
2023-03-23 10:55:56,864 - root - INFO - Algo binary_search step 2 current loss 11.884584, current_train_items 96.
2023-03-23 10:56:00,204 - root - INFO - Algo binary_search step 3 current loss 13.249916, current_train_items 128.
2023-03-23 10:56:03,675 - root - INFO - Algo binary_search step 4 current loss 14.738131, current_train_items 160.
2023-03-23 10:56:03,687 - root - INFO - Algo binary_search step 5 current los

2023-03-23 10:56:08,047 - root - INFO - Algo binary_search step 65 current loss 3.129440, current_train_items 2112.
2023-03-23 10:56:08,064 - root - INFO - Algo binary_search step 66 current loss 5.345869, current_train_items 2144.
2023-03-23 10:56:08,091 - root - INFO - Algo binary_search step 67 current loss 6.262424, current_train_items 2176.
2023-03-23 10:56:08,122 - root - INFO - Algo binary_search step 68 current loss 6.721148, current_train_items 2208.
2023-03-23 10:56:08,165 - root - INFO - Algo binary_search step 69 current loss 7.544115, current_train_items 2240.
2023-03-23 10:56:08,174 - root - INFO - Algo binary_search step 70 current loss 3.016263, current_train_items 2272.
2023-03-23 10:56:08,191 - root - INFO - Algo binary_search step 71 current loss 4.796534, current_train_items 2304.
2023-03-23 10:56:08,220 - root - INFO - Algo binary_search step 72 current loss 5.786991, current_train_items 2336.
2023-03-23 10:56:08,253 - root - INFO - Algo binary_search step 73 curre

2023-03-23 10:56:12,325 - root - INFO - Algo binary_search step 133 current loss 5.085920, current_train_items 4288.
2023-03-23 10:56:12,365 - root - INFO - Algo binary_search step 134 current loss 6.456422, current_train_items 4320.
2023-03-23 10:56:12,378 - root - INFO - Algo binary_search step 135 current loss 2.050768, current_train_items 4352.
2023-03-23 10:56:12,397 - root - INFO - Algo binary_search step 136 current loss 2.633639, current_train_items 4384.
2023-03-23 10:56:12,423 - root - INFO - Algo binary_search step 137 current loss 4.423470, current_train_items 4416.
2023-03-23 10:56:12,455 - root - INFO - Algo binary_search step 138 current loss 5.306309, current_train_items 4448.
2023-03-23 10:56:12,495 - root - INFO - Algo binary_search step 139 current loss 6.110987, current_train_items 4480.
2023-03-23 10:56:12,504 - root - INFO - Algo binary_search step 140 current loss 1.612501, current_train_items 4512.
2023-03-23 10:56:12,522 - root - INFO - Algo binary_search step 

2023-03-23 10:56:16,270 - root - INFO - Algo binary_search step 200 current loss 1.432955, current_train_items 6432.
2023-03-23 10:56:18,483 - root - INFO - (val) algo binary_search step 200: {'return': 0.78271484375, 'score': 0.78271484375, 'examples_seen': 6432, 'step': 200, 'algorithm': 'binary_search'}
2023-03-23 10:56:18,484 - root - INFO - Not saving new best model, best avg val score was 0.785, current avg val score is 0.783, val scores are: binary_search: 0.783
2023-03-23 10:56:18,502 - root - INFO - Algo binary_search step 201 current loss 2.762525, current_train_items 6464.
2023-03-23 10:56:18,529 - root - INFO - Algo binary_search step 202 current loss 3.426068, current_train_items 6496.
2023-03-23 10:56:18,559 - root - INFO - Algo binary_search step 203 current loss 4.566252, current_train_items 6528.
2023-03-23 10:56:18,597 - root - INFO - Algo binary_search step 204 current loss 6.586327, current_train_items 6560.
2023-03-23 10:56:18,605 - root - INFO - Algo binary_search

2023-03-23 10:56:22,415 - root - INFO - Algo binary_search step 264 current loss 4.955285, current_train_items 8480.
2023-03-23 10:56:22,424 - root - INFO - Algo binary_search step 265 current loss 0.889073, current_train_items 8512.
2023-03-23 10:56:22,441 - root - INFO - Algo binary_search step 266 current loss 2.501162, current_train_items 8544.
2023-03-23 10:56:22,473 - root - INFO - Algo binary_search step 267 current loss 3.584902, current_train_items 8576.
2023-03-23 10:56:22,505 - root - INFO - Algo binary_search step 268 current loss 4.120938, current_train_items 8608.
2023-03-23 10:56:22,545 - root - INFO - Algo binary_search step 269 current loss 3.977637, current_train_items 8640.
2023-03-23 10:56:22,554 - root - INFO - Algo binary_search step 270 current loss 0.836556, current_train_items 8672.
2023-03-23 10:56:22,571 - root - INFO - Algo binary_search step 271 current loss 2.543896, current_train_items 8704.
2023-03-23 10:56:22,599 - root - INFO - Algo binary_search step 

2023-03-23 10:56:26,236 - root - INFO - Algo binary_search step 331 current loss 1.864539, current_train_items 10624.
2023-03-23 10:56:26,264 - root - INFO - Algo binary_search step 332 current loss 3.175578, current_train_items 10656.
2023-03-23 10:56:26,296 - root - INFO - Algo binary_search step 333 current loss 4.009730, current_train_items 10688.
2023-03-23 10:56:26,333 - root - INFO - Algo binary_search step 334 current loss 4.103711, current_train_items 10720.
2023-03-23 10:56:26,341 - root - INFO - Algo binary_search step 335 current loss 0.905106, current_train_items 10752.
2023-03-23 10:56:26,360 - root - INFO - Algo binary_search step 336 current loss 1.994769, current_train_items 10784.
2023-03-23 10:56:26,387 - root - INFO - Algo binary_search step 337 current loss 3.212667, current_train_items 10816.
2023-03-23 10:56:26,419 - root - INFO - Algo binary_search step 338 current loss 3.244218, current_train_items 10848.
2023-03-23 10:56:26,457 - root - INFO - Algo binary_sear

2023-03-23 09:56:18,435 - root - INFO - Algo binary_search step 398 current loss 4.322250, current_train_items 12768.
2023-03-23 09:56:18,474 - root - INFO - Algo binary_search step 399 current loss 4.094738, current_train_items 12800.
2023-03-23 09:56:18,483 - root - INFO - Algo binary_search step 400 current loss 0.450725, current_train_items 12832.
2023-03-23 09:56:21,021 - root - INFO - (val) algo binary_search step 400: {'return': 0.6318359375, 'score': 0.6318359375, 'examples_seen': 12832, 'step': 400, 'algorithm': 'binary_search'}
2023-03-23 09:56:21,025 - root - INFO - Not saving new best model, best avg val score was 0.828, current avg val score is 0.632, val scores are: binary_search: 0.632
2023-03-23 09:56:21,044 - root - INFO - Algo binary_search step 401 current loss 2.042456, current_train_items 12864.
2023-03-23 09:56:21,073 - root - INFO - Algo binary_search step 402 current loss 4.357592, current_train_items 12896.
2023-03-23 09:56:21,111 - root - INFO - Algo binary_se

2023-03-23 09:56:24,995 - root - INFO - Algo binary_search step 462 current loss 2.550913, current_train_items 14816.
2023-03-23 09:56:25,025 - root - INFO - Algo binary_search step 463 current loss 3.123833, current_train_items 14848.
2023-03-23 09:56:25,064 - root - INFO - Algo binary_search step 464 current loss 4.423604, current_train_items 14880.
2023-03-23 09:56:25,072 - root - INFO - Algo binary_search step 465 current loss 0.715862, current_train_items 14912.
2023-03-23 09:56:25,088 - root - INFO - Algo binary_search step 466 current loss 1.510598, current_train_items 14944.
2023-03-23 09:56:25,140 - root - INFO - Algo binary_search step 467 current loss 2.273409, current_train_items 14976.
2023-03-23 09:56:25,191 - root - INFO - Algo binary_search step 468 current loss 2.992526, current_train_items 15008.
2023-03-23 09:56:25,230 - root - INFO - Algo binary_search step 469 current loss 3.514890, current_train_items 15040.
2023-03-23 09:56:25,238 - root - INFO - Algo binary_sear

2023-03-23 09:56:28,974 - root - INFO - Algo binary_search step 529 current loss 3.703395, current_train_items 16960.
2023-03-23 09:56:28,985 - root - INFO - Algo binary_search step 530 current loss 0.407960, current_train_items 16992.
2023-03-23 09:56:29,002 - root - INFO - Algo binary_search step 531 current loss 1.277289, current_train_items 17024.
2023-03-23 09:56:29,043 - root - INFO - Algo binary_search step 532 current loss 2.813943, current_train_items 17056.
2023-03-23 09:56:29,073 - root - INFO - Algo binary_search step 533 current loss 3.126002, current_train_items 17088.
2023-03-23 09:56:29,110 - root - INFO - Algo binary_search step 534 current loss 3.753290, current_train_items 17120.
2023-03-23 09:56:29,118 - root - INFO - Algo binary_search step 535 current loss 0.756774, current_train_items 17152.
2023-03-23 09:56:29,134 - root - INFO - Algo binary_search step 536 current loss 1.085056, current_train_items 17184.
2023-03-23 09:56:29,161 - root - INFO - Algo binary_sear

2023-03-23 09:56:32,839 - root - INFO - Algo binary_search step 596 current loss 1.235948, current_train_items 19104.
2023-03-23 09:56:32,867 - root - INFO - Algo binary_search step 597 current loss 1.993064, current_train_items 19136.
2023-03-23 09:56:32,901 - root - INFO - Algo binary_search step 598 current loss 2.892847, current_train_items 19168.
2023-03-23 09:56:32,941 - root - INFO - Algo binary_search step 599 current loss 3.398289, current_train_items 19200.
2023-03-23 09:56:32,950 - root - INFO - Algo binary_search step 600 current loss 0.882334, current_train_items 19232.
2023-03-23 09:56:35,167 - root - INFO - (val) algo binary_search step 600: {'return': 0.902099609375, 'score': 0.902099609375, 'examples_seen': 19232, 'step': 600, 'algorithm': 'binary_search'}
2023-03-23 09:56:35,167 - root - INFO - Checkpointing best model, best avg val score was 0.897, current avg val score is 0.902, val scores are: binary_search: 0.902
2023-03-23 09:56:35,188 - root - INFO - Algo binary

2023-03-23 09:56:38,877 - root - INFO - Algo binary_search step 660 current loss 0.500361, current_train_items 21152.
2023-03-23 09:56:38,894 - root - INFO - Algo binary_search step 661 current loss 1.326054, current_train_items 21184.
2023-03-23 09:56:38,922 - root - INFO - Algo binary_search step 662 current loss 1.817952, current_train_items 21216.
2023-03-23 09:56:38,955 - root - INFO - Algo binary_search step 663 current loss 2.660838, current_train_items 21248.
2023-03-23 09:56:38,996 - root - INFO - Algo binary_search step 664 current loss 3.487111, current_train_items 21280.
2023-03-23 09:56:39,005 - root - INFO - Algo binary_search step 665 current loss 0.978258, current_train_items 21312.
2023-03-23 09:56:39,034 - root - INFO - Algo binary_search step 666 current loss 1.141056, current_train_items 21344.
2023-03-23 09:56:39,061 - root - INFO - Algo binary_search step 667 current loss 2.616483, current_train_items 21376.
2023-03-23 09:56:39,093 - root - INFO - Algo binary_sear

2023-03-23 09:56:42,784 - root - INFO - Algo binary_search step 727 current loss 1.585537, current_train_items 23296.
2023-03-23 09:56:42,815 - root - INFO - Algo binary_search step 728 current loss 1.845163, current_train_items 23328.
2023-03-23 09:56:42,853 - root - INFO - Algo binary_search step 729 current loss 3.426191, current_train_items 23360.
2023-03-23 09:56:42,861 - root - INFO - Algo binary_search step 730 current loss 1.116132, current_train_items 23392.
2023-03-23 09:56:42,877 - root - INFO - Algo binary_search step 731 current loss 1.922843, current_train_items 23424.
2023-03-23 09:56:42,903 - root - INFO - Algo binary_search step 732 current loss 2.877601, current_train_items 23456.
2023-03-23 09:56:42,934 - root - INFO - Algo binary_search step 733 current loss 4.109115, current_train_items 23488.
2023-03-23 09:56:42,971 - root - INFO - Algo binary_search step 734 current loss 4.122854, current_train_items 23520.
2023-03-23 09:56:42,979 - root - INFO - Algo binary_sear

2023-03-23 09:56:46,635 - root - INFO - Algo binary_search step 794 current loss 2.729894, current_train_items 25440.
2023-03-23 09:56:46,642 - root - INFO - Algo binary_search step 795 current loss 0.858472, current_train_items 25472.
2023-03-23 09:56:46,660 - root - INFO - Algo binary_search step 796 current loss 0.745255, current_train_items 25504.
2023-03-23 09:56:46,686 - root - INFO - Algo binary_search step 797 current loss 1.564073, current_train_items 25536.
2023-03-23 09:56:46,720 - root - INFO - Algo binary_search step 798 current loss 1.781281, current_train_items 25568.
2023-03-23 09:56:46,760 - root - INFO - Algo binary_search step 799 current loss 2.925053, current_train_items 25600.
2023-03-23 09:56:46,769 - root - INFO - Algo binary_search step 800 current loss 0.414793, current_train_items 25632.
2023-03-23 09:56:49,007 - root - INFO - (val) algo binary_search step 800: {'return': 0.88037109375, 'score': 0.88037109375, 'examples_seen': 25632, 'step': 800, 'algorithm':

2023-03-23 09:56:52,610 - root - INFO - Algo binary_search step 858 current loss 2.959124, current_train_items 27488.
2023-03-23 09:56:52,648 - root - INFO - Algo binary_search step 859 current loss 3.625861, current_train_items 27520.
2023-03-23 09:56:52,657 - root - INFO - Algo binary_search step 860 current loss 0.879026, current_train_items 27552.
2023-03-23 09:56:52,674 - root - INFO - Algo binary_search step 861 current loss 0.980302, current_train_items 27584.
2023-03-23 09:56:52,700 - root - INFO - Algo binary_search step 862 current loss 1.773355, current_train_items 27616.
2023-03-23 09:56:52,732 - root - INFO - Algo binary_search step 863 current loss 2.163724, current_train_items 27648.
2023-03-23 09:56:52,770 - root - INFO - Algo binary_search step 864 current loss 3.323235, current_train_items 27680.
2023-03-23 09:56:52,777 - root - INFO - Algo binary_search step 865 current loss 0.499162, current_train_items 27712.
2023-03-23 09:56:52,805 - root - INFO - Algo binary_sear

2023-03-23 09:56:56,430 - root - INFO - Algo binary_search step 925 current loss 0.661848, current_train_items 29632.
2023-03-23 09:56:56,446 - root - INFO - Algo binary_search step 926 current loss 0.708756, current_train_items 29664.
2023-03-23 09:56:56,473 - root - INFO - Algo binary_search step 927 current loss 1.276768, current_train_items 29696.
2023-03-23 09:56:56,504 - root - INFO - Algo binary_search step 928 current loss 2.631158, current_train_items 29728.
2023-03-23 09:56:56,542 - root - INFO - Algo binary_search step 929 current loss 3.555871, current_train_items 29760.
2023-03-23 09:56:56,549 - root - INFO - Algo binary_search step 930 current loss 0.837018, current_train_items 29792.
2023-03-23 09:56:56,566 - root - INFO - Algo binary_search step 931 current loss 0.693261, current_train_items 29824.
2023-03-23 09:56:56,593 - root - INFO - Algo binary_search step 932 current loss 1.426298, current_train_items 29856.
2023-03-23 09:56:56,623 - root - INFO - Algo binary_sear

2023-03-23 09:57:00,211 - root - INFO - Algo binary_search step 992 current loss 1.665790, current_train_items 31776.
2023-03-23 09:57:00,242 - root - INFO - Algo binary_search step 993 current loss 1.789044, current_train_items 31808.
2023-03-23 09:57:00,280 - root - INFO - Algo binary_search step 994 current loss 2.585242, current_train_items 31840.
2023-03-23 09:57:00,289 - root - INFO - Algo binary_search step 995 current loss 0.324096, current_train_items 31872.
2023-03-23 09:57:00,306 - root - INFO - Algo binary_search step 996 current loss 0.568989, current_train_items 31904.
2023-03-23 09:57:00,334 - root - INFO - Algo binary_search step 997 current loss 1.345299, current_train_items 31936.
2023-03-23 09:57:00,364 - root - INFO - Algo binary_search step 998 current loss 2.552667, current_train_items 31968.
2023-03-23 09:57:00,403 - root - INFO - Algo binary_search step 999 current loss 4.119041, current_train_items 32000.
2023-03-23 09:57:00,403 - root - INFO - Restoring best m

In [None]:
import pickle

def restore_model(model, file_name):
    """Load the pickled checkpoint at `file_name` into `model`.

    The file must contain a dict with the entries 'params' and 'opt_state'.
    The parameters are passed through `hk.data_structures.merge` before being
    assigned to `model.params`; the optimizer state is assigned unchanged to
    `model.opt_state`.
    """
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    with open(file_name, 'rb') as checkpoint_file:
        checkpoint = pickle.load(checkpoint_file)
    stored_params = checkpoint['params']
    model.params = hk.data_structures.merge(stored_params)
    model.opt_state = checkpoint['opt_state']

def save_model(model, file_name):
    """Pickle the model's parameters and optimizer state to `file_name`.

    The file receives a dict with two entries: 'params' (`model.params`) and
    'opt_state' (`model.opt_state`). An existing file is overwritten.
    """
    # Gather the state before opening the file, so a model lacking one of the
    # attributes fails before anything is written.
    checkpoint = {
        'params': model.params,
        'opt_state': model.opt_state,
    }
    with open(file_name, 'wb') as checkpoint_file:
        pickle.dump(checkpoint, checkpoint_file)

In [None]:
# save_model(eval_model, 'eval_model_1e-3.pkl')

In [None]:
# restore_model(eval_model, 'eval_model_asdf.pkl')

In [10]:
# Re-evaluate on the validation sampler of the first algorithm, using
# `eval_model` (the training cell restores 'best.pkl' into it at its end).
algo_idx = 0
common_extras = {}  # no extra fields are attached to the statistics here

# Fresh key for this evaluation; the other half stays in `rng_key`.
new_rng_key, rng_key = jax.random.split(rng_key)
val_stats = collect_and_eval(
    val_samplers[algo_idx],
    functools.partial(eval_model.predict, algorithm_index=algo_idx),
    val_sample_counts[algo_idx],
    new_rng_key,
    extras=common_extras)

print(val_stats)

{'return': 0.906005859375, 'score': 0.906005859375}


In [11]:
# Evaluate on the test sampler. Reuses `algo_idx` and `common_extras` from
# the validation cell, so that cell must have been run first.
new_rng_key, rng_key = jax.random.split(rng_key)
test_stats = collect_and_eval(
    test_samplers[algo_idx],
    functools.partial(eval_model.predict, algorithm_index=algo_idx),
    test_sample_counts[algo_idx],
    new_rng_key,
    extras=common_extras)

print(test_stats)

{'return': 0.736328125, 'score': 0.736328125}


In [None]:
feedback = next(train_samplers[0])

In [None]:
# Attribute names of `feedback` that do not start with a double underscore.
[attr for attr in dir(feedback) if attr[:2] != '__']

In [None]:
# Attribute names of `feedback.features` that do not start with a double
# underscore.
[attr for attr in dir(feedback.features) if attr[:2] != '__']

In [None]:
feedback.features

In [None]:
feedback.features.inputs

In [None]:
feedback.features.inputs[0].data

In [None]:
feedback.features.inputs[1].data

In [None]:
feedback.features.inputs[2].data

In [None]:
feedback.features.inputs[3].data

In [None]:
# Keep the data of the third and fourth input for the comparison below.
# NOTE(review): the names suggest a weight matrix (`A`) and an adjacency
# matrix (`adj`); which inputs exist depends on the algorithm -- confirm.
A = feedback.features.inputs[2].data
adj = feedback.features.inputs[3].data

In [None]:
# Indices where `A` is non-zero although `adj` is zero.
jnp.where((A != 0) & (adj == 0))

In [None]:
feedback.features.hints

In [None]:
np.argmin(feedback.features.inputs[0].data, axis=1)

In [None]:
np.where(feedback.outputs[0].data == 1)

In [None]:
test_sample_counts[0]

In [None]:
def _concat(dps, axis):
    return jax.tree_util.tree_map(lambda *x: np.concatenate(x, axis), *dps)

def get_msgs(sampler, predict_fn, sample_count, rng_key, sample_prob=0.001):
    """Collect a random subsample of the messages produced by the model.

    Batches are drawn from `sampler` until at least `sample_count` samples
    have been processed. For every batch, the third and fourth outputs of
    `predict_fn` (messages and input messages) are flattened to 2-D while
    keeping their last axis, a random subset of rows is kept, and the kept
    message rows are joined with the matching input-message rows along the
    last axis.

    CAUTION: size of msgs can get large very quickly, so beware when
        running with a large number of samples.
    Use sample_prob to reduce the number of messages that are saved
    by randomly sampling messages.

    Args:
        sampler: iterator yielding feedback objects that expose
            `outputs[0].data` (its leading dimension is taken as the batch
            size) and `features`.
        predict_fn: callable `(rng_key, features)` returning a 4-tuple whose
            last two entries are the messages and the input messages. Both
            must flatten to the same number of rows.
        sample_count: minimum number of samples to process.
        rng_key: JAX PRNG key; split once per batch for the prediction and
            once per batch for the row subsampling.
        sample_prob: probability with which each row is kept.

    Returns:
        A NumPy array holding the kept rows of all batches, stacked along
        axis 0.

    Raises:
        ValueError: if no batch was processed (`sample_count` <= 0), since
            there is then nothing to concatenate.
    """
    processed_samples = 0
    msgs = []
    while processed_samples < sample_count:
        feedback = next(sampler)
        batch_size = feedback.outputs[0].data.shape[0]
        new_rng_key, rng_key = jax.random.split(rng_key)
        _, _, cur_msgs, cur_input_msg = predict_fn(new_rng_key, feedback.features)

        cur_msgs = jnp.reshape(cur_msgs, (-1, cur_msgs.shape[-1]))
        cur_input_msg = jnp.reshape(
            cur_input_msg, (-1, cur_input_msg.shape[-1]))

        new_rng_key, rng_key = jax.random.split(rng_key)
        mask = jax.random.choice(new_rng_key,
                                 a=jnp.array([False, True]),
                                 shape=(cur_msgs.shape[0],),
                                 p=jnp.array([1 - sample_prob, sample_prob]),
                                 replace=True)
        # Subsample BEFORE joining the two parts: the result is the same as
        # masking the joined matrix, but the full-size joined matrix is
        # never materialized.
        msgs.append(
            jnp.concatenate((cur_msgs[mask], cur_input_msg[mask]), axis=-1))
        processed_samples += batch_size

    if not msgs:
        raise ValueError(
            'No batches were processed; sample_count must be positive.')

    return np.concatenate(msgs, axis=0)


In [12]:
# Extract a subsample of the model's messages from the first test sampler.
new_rng_key, rng_key = jax.random.split(rng_key)
batched_msgs = get_msgs(
    test_samplers[0],
    functools.partial(eval_model.predict, algorithm_index=0),
    # Passing test_sample_counts[0] here (the whole test set) used too much
    # memory, so only a small fixed number of samples is processed.
    32*4,
    new_rng_key,
    0.01)  # sample_prob: probability of keeping each message row

batched_msgs.shape

64 128 13
64 128 13
64 128 13
64 128 13


(31436, 205)

In [None]:
# Persist the sampled messages. The array returned by get_msgs is bound to
# `batched_msgs` at notebook scope; `msgs` only exists inside get_msgs, so
# dumping `msgs` here would raise a NameError.
# NOTE(review): the file name says "val", but the cell that builds
# `batched_msgs` draws from test_samplers[0] -- confirm the intended name.
file_name = 'binary_search_val_msgs.pkl'
with open(file_name, 'wb') as f:
    pickle.dump(batched_msgs, f)

In [None]:
import sys
sys.getsizeof(batched_msgs)

In [None]:
test_sample_counts[0]

In [None]:
sys.getsizeof(batched_msgs) / 1000000