# Run training of CLRS algorithm

The main function is copied from clrs_train.py; the helper functions are located in clrs_train_funcs.py.

In [37]:
import functools
import os
import shutil
from typing import Any, Dict, List

import logging
import clrs
import jax
import jax.numpy as jnp
import numpy as np
import requests
import tensorflow as tf
import haiku as hk

import model
import flags
from clrs_train_funcs import *

In [2]:
FLAGS = flags.FLAGS

In [3]:
# Configure the root logger so that INFO-level records are emitted.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Attach a stream handler that prefixes each record with time, logger name and level.
# NOTE(review): every re-run of this cell adds one more handler to the same
# logger, which duplicates each log line.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)

In [4]:
# Map each supported hint mode to its (encode_hints, decode_hints) switches.
_HINT_MODE_SWITCHES = {
    'encoded_decoded': (True, True),
    'decoded_only': (False, True),
    'none': (False, False),
}
if FLAGS.hint_mode not in _HINT_MODE_SWITCHES:
    raise ValueError(
        'Hint mode not in {encoded_decoded, decoded_only, none}.')
encode_hints, decode_hints = _HINT_MODE_SWITCHES[FLAGS.hint_mode]

# Coerce every configured training length to an int.
train_lengths = list(map(int, FLAGS.train_lengths))

# Seed NumPy's generator from the flag, then draw the initial JAX key from it.
rng = np.random.RandomState(FLAGS.seed)
rng_key = jax.random.PRNGKey(rng.randint(2**32))

2023-03-17 10:01:01,938 - jax._src.lib.xla_bridge - INFO - Remote TPU is not linked into jax; skipping remote TPU.
2023-03-17 10:01:01,939 - jax._src.lib.xla_bridge - INFO - Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
2023-03-17 10:01:01,939 - jax._src.lib.xla_bridge - INFO - Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2023-03-17 10:01:01,939 - jax._src.lib.xla_bridge - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2023-03-17 10:01:01,940 - jax._src.lib.xla_bridge - INFO - Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
2023-03-17 10:01:01,940 - jax._src.lib.xla_bridge - INFO - Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (d

In [5]:
# Create samplers
# create_samplers is presumably provided by the clrs_train_funcs star import.
# It receives the NumPy generator and the training lengths and returns the
# train/val/test samplers, the val/test sample counts and the spec list.
(train_samplers,
 val_samplers, val_sample_counts,
 test_samplers, test_sample_counts,
 spec_list) = create_samplers(rng, train_lengths)

2023-03-17 10:01:01,981 - root - INFO - Creating samplers for algo binary_search


Metal device set to: Apple M1 Pro


2023-03-17 10:01:02,638 - absl - INFO - Creating a dataset with 4096 samples.
2023-03-17 10:01:02,871 - absl - INFO - 1000 samples created
2023-03-17 10:01:03,010 - absl - INFO - 2000 samples created
2023-03-17 10:01:03,143 - absl - INFO - 3000 samples created
2023-03-17 10:01:03,265 - absl - INFO - 4000 samples created
2023-03-17 10:01:03,340 - root - INFO - Dataset found at /tmp/CLRS30/CLRS30_v1.0.0. Skipping download.
2023-03-17 10:01:03,341 - absl - INFO - Load dataset info from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/binary_search_test/1.0.0
2023-03-17 10:01:03,342 - absl - INFO - Load dataset info from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/binary_search_test/1.0.0
2023-03-17 10:01:03,343 - absl - INFO - Reusing dataset clrs_dataset (/tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/binary_search_test/1.0.0)
2023-03-17 10:01:03,343 - absl - INFO - Constructing tf.data.Dataset clrs_dataset for split test, from /tmp/CLRS30/CLRS30_v1.0.0/clrs_dataset/binary_search_test/1.0.0


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
2023-03-17 10:01:03.540859: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


In [11]:
# Override flag values for this run: the L1 weight handed to the model and
# the number of iterations of the training loop.
FLAGS.l1_weight = 0.001
FLAGS.train_steps = 500

In [12]:
# Build the processor factory from the processor-related flags.
processor_factory = model.get_processor_factory(
    FLAGS.processor_type,
    use_ln=FLAGS.use_ln,
    nb_triplet_fts=FLAGS.nb_triplet_fts,
    nb_heads=FLAGS.nb_heads
)
# Keyword arguments shared by every model constructed in this notebook.
model_params = dict(
    processor_factory=processor_factory,
    hidden_dim=FLAGS.hidden_size,
    encode_hints=encode_hints,
    decode_hints=decode_hints,
    encoder_init=FLAGS.encoder_init,
    use_lstm=FLAGS.use_lstm,
    learning_rate=FLAGS.learning_rate,
    grad_clip_max_norm=FLAGS.grad_clip_max_norm,
    checkpoint_path=FLAGS.checkpoint_path,
    freeze_processor=FLAGS.freeze_processor,
    dropout_prob=FLAGS.dropout_prob,
    hint_teacher_forcing=FLAGS.hint_teacher_forcing,
    hint_repred_mode=FLAGS.hint_repred_mode,
    nb_msg_passing_steps=FLAGS.nb_msg_passing_steps,
    l1_weight=FLAGS.l1_weight
)

# The dummy trajectory consumes one batch from each validation sampler.
eval_model = model.BaselineMsgModel(
    spec=spec_list,
    dummy_trajectory=[next(t) for t in val_samplers],
    **model_params
)
# Chunked training is never used in this notebook, so the chunked model
# construction below is kept only for reference and stays disabled.
# if FLAGS.chunked_training:
#     train_model = clrs.models.BaselineModelChunked(
#         spec=spec_list,
#         dummy_trajectory=[next(t) for t in train_samplers],
#         **model_params
#     )
# else:
#     train_model = eval_model
# Training and evaluation share one model object.
train_model = eval_model

In [13]:
# Training loop.
# Runs FLAGS.train_steps iterations. Each iteration draws one batch per
# algorithm, performs a training step on it, and evaluates on the validation
# samplers whenever `step` reaches `next_eval`. The model with the best summed
# validation score is checkpointed to 'best.pkl' and restored after the loop.
best_score = -1.0
current_train_items = [0] * len(FLAGS.algorithms)
step = 0
next_eval = 0
# Make sure scores improve on first step, but not overcome best score
# until all algos have had at least one evaluation.
val_scores = [-99999.9] * len(FLAGS.algorithms)
length_idx = 0

while step < FLAGS.train_steps:
    # One batch of feedback per algorithm.
    feedback_list = [next(t) for t in train_samplers]

    # Initialize model.
    # NOTE(review): train_model is assigned from eval_model earlier in this
    # notebook and the chunked model is commented out, so the
    # FLAGS.chunked_training branches in this loop are presumably never
    # taken; confirm the flag stays False.
    if step == 0:
        all_features = [f.features for f in feedback_list]
        if FLAGS.chunked_training:
            # We need to initialize the model with samples of all lengths for
            # all algorithms. Also, we need to make sure that the order of these
            # sample sizes is the same as the order of the actual training sizes.
            all_length_features = [all_features] + [
                [next(t).features for t in train_samplers]
                for _ in range(len(train_lengths))]
            train_model.init(all_length_features[:-1], FLAGS.seed + 1)
        else:
            train_model.init(all_features, FLAGS.seed + 1)

    # Training step.
    for algo_idx in range(len(train_samplers)):
        feedback = feedback_list[algo_idx]
        # The first half of the split is consumed by this training step; the
        # second half becomes the key carried into the next step.
        rng_key, new_rng_key = jax.random.split(rng_key)
        if FLAGS.chunked_training:
            # In chunked training, we must indicate which training length we are
            # using, so the model uses the correct state.
            length_and_algo_idx = (length_idx, algo_idx)
        else:
            # In non-chunked training, all training lengths can be treated equally,
            # since there is no state to maintain between batches.
            length_and_algo_idx = algo_idx
        cur_loss = train_model.feedback(
            rng_key, feedback, length_and_algo_idx)
        rng_key = new_rng_key

        # Count how many examples this batch contributed for the algorithm.
        if FLAGS.chunked_training:
            examples_in_chunk = np.sum(feedback.features.is_last).item()
        else:
            examples_in_chunk = len(feedback.features.lengths)
        current_train_items[algo_idx] += examples_in_chunk
        logging.info('Algo %s step %i current loss %f, current_train_items %i.',
                     FLAGS.algorithms[algo_idx], step,
                     cur_loss, current_train_items[algo_idx])

    # Periodically evaluate model
    if step >= next_eval:
        eval_model.params = train_model.params
        for algo_idx in range(len(train_samplers)):
            # Extra fields passed to collect_and_eval alongside the metrics.
            common_extras = {'examples_seen': current_train_items[algo_idx],
                             'step': step,
                             'algorithm': FLAGS.algorithms[algo_idx]}

            # Validation info.
            # Here the first half of the split goes to evaluation and the
            # second half stays as the running key.
            new_rng_key, rng_key = jax.random.split(rng_key)
            val_stats = collect_and_eval(
                val_samplers[algo_idx],
                functools.partial(eval_model.predict,
                                  algorithm_index=algo_idx),
                val_sample_counts[algo_idx],
                new_rng_key,
                extras=common_extras)
            logging.info('(val) algo %s step %d: %s',
                         FLAGS.algorithms[algo_idx], step, val_stats)
            val_scores[algo_idx] = val_stats['score']

        next_eval += FLAGS.eval_every

        # If best total score, update best checkpoint.
        # Also save a best checkpoint on the first step.
        # best_score holds the SUM of validation scores, so it is divided by
        # the number of algorithms when reported as an average.
        msg = (f'best avg val score was '
               f'{best_score/len(FLAGS.algorithms):.3f}, '
               f'current avg val score is {np.mean(val_scores):.3f}, '
               f'val scores are: ')
        msg += ', '.join(
            ['%s: %.3f' % (x, y) for (x, y) in zip(FLAGS.algorithms, val_scores)])
        if (sum(val_scores) > best_score) or step == 0:
            best_score = sum(val_scores)
            logging.info('Checkpointing best model, %s', msg)
            train_model.save_model('best.pkl')
        else:
            logging.info('Not saving new best model, %s', msg)

    step += 1
    # Cycle through the training lengths (only consulted in chunked training).
    length_idx = (length_idx + 1) % len(train_lengths)

# Reload the checkpoint written to 'best.pkl' by the loop above.
logging.info('Restoring best model from checkpoint...')
eval_model.restore_model('best.pkl', only_load_processor=False)

(32, 2, 4, 4, 128)


2023-03-17 10:03:38,267 - root - INFO - Algo binary_search step 0 current loss 6.506681, current_train_items 32.
2023-03-17 10:03:40,861 - root - INFO - (val) algo binary_search step 0: {'return': 0.09521484375, 'score': 0.09521484375, 'examples_seen': 32, 'step': 0, 'algorithm': 'binary_search'}
2023-03-17 10:03:40,861 - root - INFO - Checkpointing best model, best avg val score was -1.000, current avg val score is 0.095, val scores are: binary_search: 0.095


(32, 3, 7, 7, 128)


2023-03-17 10:03:44,304 - root - INFO - Algo binary_search step 1 current loss 9.225758, current_train_items 64.


(32, 4, 11, 11, 128)


2023-03-17 10:03:47,739 - root - INFO - Algo binary_search step 2 current loss 13.320404, current_train_items 96.


(32, 4, 13, 13, 128)


2023-03-17 10:03:51,240 - root - INFO - Algo binary_search step 3 current loss 13.820802, current_train_items 128.


(32, 4, 16, 16, 128)


2023-03-17 10:03:54,674 - root - INFO - Algo binary_search step 4 current loss 15.492587, current_train_items 160.
2023-03-17 10:03:54,690 - root - INFO - Algo binary_search step 5 current loss 5.844983, current_train_items 192.
2023-03-17 10:03:54,725 - root - INFO - Algo binary_search step 6 current loss 7.556711, current_train_items 224.
2023-03-17 10:03:54,795 - root - INFO - Algo binary_search step 7 current loss 9.790084, current_train_items 256.
2023-03-17 10:03:54,898 - root - INFO - Algo binary_search step 8 current loss 10.749120, current_train_items 288.
2023-03-17 10:03:55,007 - root - INFO - Algo binary_search step 9 current loss 12.213992, current_train_items 320.
2023-03-17 10:03:55,022 - root - INFO - Algo binary_search step 10 current loss 5.405589, current_train_items 352.
2023-03-17 10:03:55,059 - root - INFO - Algo binary_search step 11 current loss 6.864581, current_train_items 384.
2023-03-17 10:03:55,128 - root - INFO - Algo binary_search step 12 current loss 8.8

2023-03-17 10:04:01,399 - root - INFO - Algo binary_search step 72 current loss 4.815959, current_train_items 2336.
2023-03-17 10:04:01,482 - root - INFO - Algo binary_search step 73 current loss 4.907534, current_train_items 2368.
2023-03-17 10:04:01,594 - root - INFO - Algo binary_search step 74 current loss 5.452176, current_train_items 2400.
2023-03-17 10:04:01,607 - root - INFO - Algo binary_search step 75 current loss 1.807760, current_train_items 2432.
2023-03-17 10:04:01,643 - root - INFO - Algo binary_search step 76 current loss 2.645643, current_train_items 2464.
2023-03-17 10:04:01,717 - root - INFO - Algo binary_search step 77 current loss 4.997289, current_train_items 2496.
2023-03-17 10:04:01,819 - root - INFO - Algo binary_search step 78 current loss 5.114048, current_train_items 2528.
2023-03-17 10:04:01,933 - root - INFO - Algo binary_search step 79 current loss 5.380050, current_train_items 2560.
2023-03-17 10:04:01,948 - root - INFO - Algo binary_search step 80 curre

2023-03-17 10:04:08,181 - root - INFO - Algo binary_search step 140 current loss 0.685631, current_train_items 4512.
2023-03-17 10:04:08,215 - root - INFO - Algo binary_search step 141 current loss 1.513210, current_train_items 4544.
2023-03-17 10:04:08,285 - root - INFO - Algo binary_search step 142 current loss 2.753431, current_train_items 4576.
2023-03-17 10:04:08,370 - root - INFO - Algo binary_search step 143 current loss 3.698956, current_train_items 4608.
2023-03-17 10:04:08,483 - root - INFO - Algo binary_search step 144 current loss 4.164170, current_train_items 4640.
2023-03-17 10:04:08,497 - root - INFO - Algo binary_search step 145 current loss 0.903493, current_train_items 4672.
2023-03-17 10:04:08,532 - root - INFO - Algo binary_search step 146 current loss 2.017942, current_train_items 4704.
2023-03-17 10:04:08,601 - root - INFO - Algo binary_search step 147 current loss 3.391451, current_train_items 4736.
2023-03-17 10:04:08,684 - root - INFO - Algo binary_search step 

2023-03-17 10:04:17,072 - root - INFO - Algo binary_search step 204 current loss 3.773260, current_train_items 6560.
2023-03-17 10:04:17,087 - root - INFO - Algo binary_search step 205 current loss 1.066801, current_train_items 6592.
2023-03-17 10:04:17,123 - root - INFO - Algo binary_search step 206 current loss 1.166779, current_train_items 6624.
2023-03-17 10:04:17,196 - root - INFO - Algo binary_search step 207 current loss 2.960557, current_train_items 6656.
2023-03-17 10:04:17,279 - root - INFO - Algo binary_search step 208 current loss 2.679141, current_train_items 6688.
2023-03-17 10:04:17,388 - root - INFO - Algo binary_search step 209 current loss 3.843515, current_train_items 6720.
2023-03-17 10:04:17,402 - root - INFO - Algo binary_search step 210 current loss 0.945793, current_train_items 6752.
2023-03-17 10:04:17,435 - root - INFO - Algo binary_search step 211 current loss 1.361296, current_train_items 6784.
2023-03-17 10:04:17,506 - root - INFO - Algo binary_search step 

2023-03-17 10:04:23,484 - root - INFO - Algo binary_search step 271 current loss 2.417202, current_train_items 8704.
2023-03-17 10:04:23,554 - root - INFO - Algo binary_search step 272 current loss 3.040916, current_train_items 8736.
2023-03-17 10:04:23,640 - root - INFO - Algo binary_search step 273 current loss 3.035372, current_train_items 8768.
2023-03-17 10:04:23,808 - root - INFO - Algo binary_search step 274 current loss 4.233149, current_train_items 8800.
2023-03-17 10:04:23,822 - root - INFO - Algo binary_search step 275 current loss 1.457255, current_train_items 8832.
2023-03-17 10:04:23,856 - root - INFO - Algo binary_search step 276 current loss 2.291582, current_train_items 8864.
2023-03-17 10:04:23,924 - root - INFO - Algo binary_search step 277 current loss 2.449690, current_train_items 8896.
2023-03-17 10:04:24,006 - root - INFO - Algo binary_search step 278 current loss 2.946670, current_train_items 8928.
2023-03-17 10:04:24,115 - root - INFO - Algo binary_search step 

2023-03-17 10:04:30,068 - root - INFO - Algo binary_search step 338 current loss 3.390095, current_train_items 10848.
2023-03-17 10:04:30,175 - root - INFO - Algo binary_search step 339 current loss 3.460988, current_train_items 10880.
2023-03-17 10:04:30,190 - root - INFO - Algo binary_search step 340 current loss 0.672797, current_train_items 10912.
2023-03-17 10:04:30,224 - root - INFO - Algo binary_search step 341 current loss 1.010101, current_train_items 10944.
2023-03-17 10:04:30,296 - root - INFO - Algo binary_search step 342 current loss 2.547439, current_train_items 10976.
2023-03-17 10:04:30,381 - root - INFO - Algo binary_search step 343 current loss 2.866984, current_train_items 11008.
2023-03-17 10:04:30,492 - root - INFO - Algo binary_search step 344 current loss 4.187172, current_train_items 11040.
2023-03-17 10:04:30,505 - root - INFO - Algo binary_search step 345 current loss 0.880275, current_train_items 11072.
2023-03-17 10:04:30,539 - root - INFO - Algo binary_sear

2023-03-17 10:04:38,911 - root - INFO - Algo binary_search step 402 current loss 2.297888, current_train_items 12896.
2023-03-17 10:04:38,998 - root - INFO - Algo binary_search step 403 current loss 2.882795, current_train_items 12928.
2023-03-17 10:04:39,107 - root - INFO - Algo binary_search step 404 current loss 3.857188, current_train_items 12960.
2023-03-17 10:04:39,122 - root - INFO - Algo binary_search step 405 current loss 0.466323, current_train_items 12992.
2023-03-17 10:04:39,155 - root - INFO - Algo binary_search step 406 current loss 0.688842, current_train_items 13024.
2023-03-17 10:04:39,225 - root - INFO - Algo binary_search step 407 current loss 2.567268, current_train_items 13056.
2023-03-17 10:04:39,307 - root - INFO - Algo binary_search step 408 current loss 2.775641, current_train_items 13088.
2023-03-17 10:04:39,422 - root - INFO - Algo binary_search step 409 current loss 3.407589, current_train_items 13120.
2023-03-17 10:04:39,436 - root - INFO - Algo binary_sear

2023-03-17 10:04:45,523 - root - INFO - Algo binary_search step 469 current loss 3.134935, current_train_items 15040.
2023-03-17 10:04:45,537 - root - INFO - Algo binary_search step 470 current loss 1.265869, current_train_items 15072.
2023-03-17 10:04:45,572 - root - INFO - Algo binary_search step 471 current loss 0.863863, current_train_items 15104.
2023-03-17 10:04:45,641 - root - INFO - Algo binary_search step 472 current loss 1.681667, current_train_items 15136.
2023-03-17 10:04:45,722 - root - INFO - Algo binary_search step 473 current loss 1.818089, current_train_items 15168.
2023-03-17 10:04:45,840 - root - INFO - Algo binary_search step 474 current loss 3.439990, current_train_items 15200.
2023-03-17 10:04:45,853 - root - INFO - Algo binary_search step 475 current loss 0.725968, current_train_items 15232.
2023-03-17 10:04:45,885 - root - INFO - Algo binary_search step 476 current loss 0.716647, current_train_items 15264.
2023-03-17 10:04:45,953 - root - INFO - Algo binary_sear

In [14]:
import pickle

def restore_model(model, file_name):
    """Load pickled parameters and optimizer state from `file_name` into `model`.

    Args:
      model: object whose `params` and `opt_state` attributes are overwritten.
      file_name: path to a pickle containing a dict with the keys 'params'
        and 'opt_state'.

    Raises:
      KeyError: if the pickled dict lacks 'params' or 'opt_state'.

    NOTE: unpickling can execute arbitrary code; only restore trusted files.
    """
    with open(file_name, 'rb') as checkpoint_file:
        checkpoint = pickle.load(checkpoint_file)
    # The file is closed at this point; only the model is touched from here on.
    model.params = hk.data_structures.merge(checkpoint['params'])
    model.opt_state = checkpoint['opt_state']

def save_model(model, file_name):
    """Pickle the model's parameters and optimizer state to `file_name`.

    Args:
      model: object exposing `params` and `opt_state` attributes.
      file_name: destination path; an existing file is overwritten.

    The file holds a dict with the keys 'params' and 'opt_state', which is the
    layout `restore_model` reads back.
    """
    # Read both attributes before opening the file, so a missing attribute
    # fails without creating or truncating the destination.
    checkpoint = dict(params=model.params, opt_state=model.opt_state)
    with open(file_name, 'wb') as checkpoint_file:
        pickle.dump(checkpoint, checkpoint_file)

In [15]:
# Persist the current eval_model params and optimizer state to disk.
save_model(eval_model, 'eval_model_1e-3.pkl')

In [8]:
# Load a previously saved checkpoint into eval_model.
# NOTE(review): no cell in this notebook writes 'eval_model_asdf.pkl', and the
# execution count shows this cell was run out of order; confirm the file exists.
restore_model(eval_model, 'eval_model_asdf.pkl')

In [16]:
# Evaluate eval_model on the validation sampler of the first algorithm.
algo_idx = 0
common_extras = {}  # no extra fields are attached to the returned stats

# The first half of the split is used for this evaluation; the second half
# remains the running key.
new_rng_key, rng_key = jax.random.split(rng_key)
val_stats = collect_and_eval(
    val_samplers[algo_idx],
    functools.partial(eval_model.predict, algorithm_index=algo_idx),
    val_sample_counts[algo_idx],
    new_rng_key,
    extras=common_extras)

val_stats

{'return': 0.912353515625, 'score': 0.912353515625}

In [17]:
# Evaluate eval_model on the test sampler, reusing algo_idx and common_extras
# from the validation cell.
new_rng_key, rng_key = jax.random.split(rng_key)
test_stats = collect_and_eval(
    test_samplers[algo_idx],
    functools.partial(eval_model.predict, algorithm_index=algo_idx),
    test_sample_counts[algo_idx],
    new_rng_key,
    extras=common_extras)

test_stats

{'return': 0.70068359375, 'score': 0.70068359375}

In [17]:
# Draw one batch from the first validation sampler for inspection.
feedback = next(val_samplers[0])

In [18]:
# Attribute names of the feedback object, excluding names that start with '__'.
[attr for attr in dir(feedback) if not attr.startswith('__')]

['_asdict',
 '_field_defaults',
 '_fields',
 '_make',
 '_replace',
 'count',
 'features',
 'index',
 'outputs']

In [28]:
# Attribute names of feedback.features, excluding names that start with '__'.
[attr for attr in dir(feedback.features) if not attr.startswith('__')]

['_asdict',
 '_field_defaults',
 '_fields',
 '_make',
 '_replace',
 'count',
 'hints',
 'index',
 'inputs',
 'lengths']

In [42]:
# Display the full features tuple (inputs, hints, lengths) of the batch.
feedback.features

Features(inputs=(DataPoint(name="pos",	location=node,	type=scalar,	data=Array(32, 16)), DataPoint(name="key",	location=node,	type=scalar,	data=Array(32, 16)), DataPoint(name="target",	location=graph,	type=scalar,	data=Array(32,)), DataPoint(name="pred",	location=node,	type=pointer,	data=Array(32, 16))), hints=(DataPoint(name="low",	location=node,	type=mask_one,	data=Array(5, 32, 16)), DataPoint(name="high",	location=node,	type=mask_one,	data=Array(5, 32, 16)), DataPoint(name="mid",	location=node,	type=mask_one,	data=Array(5, 32, 16))), lengths=array([5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5.,
       5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5.]))

In [37]:
# Display only the input data points of the batch.
feedback.features.inputs

(DataPoint(name="pos",	location=node,	type=scalar,	data=Array(32, 16)),
 DataPoint(name="key",	location=node,	type=scalar,	data=Array(32, 16)),
 DataPoint(name="target",	location=graph,	type=scalar,	data=Array(32,)),
 DataPoint(name="pred",	location=node,	type=pointer,	data=Array(32, 16)))

In [43]:
# Raw array of the third input; the inputs repr in this notebook names it "target".
feedback.features.inputs[2].data

array([0.14317903, 0.07972956, 0.67909305, 0.98567252, 0.93329147,
       0.28345592, 0.43966783, 0.07153851, 0.12399225, 0.76594674,
       0.36091068, 0.97857869, 0.28576997, 0.40300221, 0.26544856,
       0.52667835, 0.12452213, 0.07972956, 0.93578709, 0.51574709,
       0.3870473 , 0.58158791, 0.68471948, 0.92190682, 0.18154095,
       0.38265265, 0.65144341, 0.42893379, 0.05218584, 0.86252671,
       0.059794  , 0.44887807])

In [48]:
# Display the hint data points of the batch.
feedback.features.hints

(DataPoint(name="low",	location=node,	type=mask_one,	data=Array(5, 32, 16)),
 DataPoint(name="high",	location=node,	type=mask_one,	data=Array(5, 32, 16)),
 DataPoint(name="mid",	location=node,	type=mask_one,	data=Array(5, 32, 16)))

In [20]:
# Index of the smallest value along axis 1 of the first input; the inputs
# repr in this notebook names it "pos".
np.argmin(feedback.features.inputs[0].data, axis=1)

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [21]:
# Index arrays (one per axis) of the entries equal to 1 in the first output.
np.where(feedback.outputs[0].data == 1)

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]),
 array([ 2,  1, 11, 15, 14,  3,  9,  0,  3, 11,  5, 15,  6,  5,  6,  6,  3,
         1, 14,  9,  7, 12, 10, 13,  2,  5,  9, 10,  1, 14,  1,  8]))

In [18]:
# Sample count reported for the first algorithm's test sampler.
test_sample_counts[0]

2048

In [21]:
# Run get_msgs on the first test sampler with eval_model's predict function.
# NOTE(review): get_msgs is presumably provided by clrs_train_funcs; the
# meaning of its trailing 0.01 argument is not visible here, so confirm it.
new_rng_key, rng_key = jax.random.split(rng_key)
batched_msgs = get_msgs(
    test_samplers[0],
    functools.partial(eval_model.predict, algorithm_index=0),
    # test_sample_counts[0] (the full test set) is deliberately not passed:
    # it exhausted memory, so only 32*4 samples are requested.
    32*4,
    new_rng_key,
    0.01)

batched_msgs.shape

0
32
64
96


(128, 6, 64, 64, 128)

In [57]:
# Persist the collected messages to disk.
# `msgs` is never defined in this notebook, so dumping it raises NameError on
# a clean top-to-bottom run; `batched_msgs` is the array produced by get_msgs.
# NOTE(review): the file name says "val", but batched_msgs is computed from
# the test sampler; confirm which split is intended before relying on it.
file_name = 'binary_search_val_msgs.pkl'
with open(file_name, 'wb') as f:
    pickle.dump(batched_msgs, f)

In [22]:
import sys
# Size in bytes that Python reports for the batched_msgs object itself.
sys.getsizeof(batched_msgs)

1610612912

In [44]:
# Sample count reported for the first algorithm's test sampler.
test_sample_counts[0]

2048

In [46]:
# Same size expressed in units of 10**6 bytes.
sys.getsizeof(batched_msgs) / 1000000

25.906304