# Run training of CLRS algorithm

Copied main function from clrs_train.py. Helper functions are located in clrs_train_funcs.py

In [46]:
import functools
import os
import shutil
from typing import Any, Dict, List

import logging
import clrs
import jax
import jax.numpy as jnp
import numpy as np
import requests
import tensorflow as tf
import haiku as hk

import model
import flags
from clrs_train_funcs import *

In [47]:
!nvidia-smi

Fri Mar 24 18:59:39 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.65       Driver Version: 527.56       CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0 Off |                  N/A |
| N/A   46C    P8    16W / 147W |   8022MiB /  8192MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [48]:
FLAGS = flags.FLAGS

In [49]:
logger = logging.getLogger()
logger.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)

In [50]:
if FLAGS.hint_mode == 'encoded_decoded':
    encode_hints = True
    decode_hints = True
elif FLAGS.hint_mode == 'decoded_only':
    encode_hints = False
    decode_hints = True
elif FLAGS.hint_mode == 'none':
    encode_hints = False
    decode_hints = False
else:
    raise ValueError(
        'Hint mode not in {encoded_decoded, decoded_only, none}.')

train_lengths = [int(x) for x in FLAGS.train_lengths]

rng = np.random.RandomState(FLAGS.seed)
rng_key = jax.random.PRNGKey(rng.randint(2**32))

In [51]:
# Create samplers
(train_samplers,
 val_samplers, val_sample_counts,
 test_samplers, test_sample_counts,
 spec_list) = create_samplers(rng, train_lengths)

2023-03-24 18:59:44,917 - root - INFO - Creating samplers for algo binary_search
2023-03-24 18:59:44,917 - root - INFO - Creating samplers for algo binary_search
2023-03-24 18:59:45,837 - absl - INFO - Creating a dataset with 4096 samples.
2023-03-24 18:59:45,837 - absl - INFO - Creating a dataset with 4096 samples.
2023-03-24 18:59:46,046 - absl - INFO - 1000 samples created
2023-03-24 18:59:46,046 - absl - INFO - 1000 samples created
2023-03-24 18:59:46,254 - absl - INFO - 2000 samples created
2023-03-24 18:59:46,254 - absl - INFO - 2000 samples created
2023-03-24 18:59:46,487 - absl - INFO - 3000 samples created
2023-03-24 18:59:46,487 - absl - INFO - 3000 samples created
2023-03-24 18:59:46,733 - absl - INFO - 4000 samples created
2023-03-24 18:59:46,733 - absl - INFO - 4000 samples created
2023-03-24 18:59:47,131 - root - INFO - Dataset found at /tmp/CLRS30/CLRS30_v1.0.0. Skipping download.
2023-03-24 18:59:47,131 - root - INFO - Dataset found at /tmp/CLRS30/CLRS30_v1.0.0. Skippin

In [52]:
FLAGS.hidden_size = 128 # 32
FLAGS.msg_size = 32 # 32

FLAGS.train_steps = 10000 # 1000
FLAGS.eval_every = 50
FLAGS.test_every = 500

FLAGS.l1_weight = 0.01 # 0.001 

# l1_weight_fn_mult_start = 1
# l1_weight_fn_mult_end = 10
# l1_weight_fn = lambda step: FLAGS.l1_weight * (l1_weight_fn_mult_start + step / FLAGS.train_steps * (l1_weight_fn_mult_end - l1_weight_fn_mult_start))

l1_weight_fn = lambda step: FLAGS.l1_weight

In [53]:
# FLAGS.hidden_size = 8
# FLAGS.algorithms = ['dijkstra']

In [54]:
processor_factory = model.get_processor_factory(
    FLAGS.processor_type,
    use_ln=FLAGS.use_ln,
    nb_triplet_fts=FLAGS.nb_triplet_fts,
    nb_heads=FLAGS.nb_heads
)
model_params = dict(
    processor_factory=processor_factory,
    hidden_dim=FLAGS.hidden_size,
    msg_dim=FLAGS.msg_size,
    encode_hints=encode_hints,
    decode_hints=decode_hints,
    encoder_init=FLAGS.encoder_init,
    use_lstm=FLAGS.use_lstm,
    learning_rate=FLAGS.learning_rate,
    grad_clip_max_norm=FLAGS.grad_clip_max_norm,
    checkpoint_path=FLAGS.checkpoint_path,
    freeze_processor=FLAGS.freeze_processor,
    dropout_prob=FLAGS.dropout_prob,
    hint_teacher_forcing=FLAGS.hint_teacher_forcing,
    hint_repred_mode=FLAGS.hint_repred_mode,
    nb_msg_passing_steps=FLAGS.nb_msg_passing_steps,
    l1_weight=FLAGS.l1_weight
)

eval_model = model.BaselineMsgModel(
    spec=spec_list,
    dummy_trajectory=[next(t) for t in val_samplers],
    **model_params
)
# # we will never used chunked training
# if FLAGS.chunked_training:
#     train_model = clrs.models.BaselineModelChunked(
#         spec=spec_list,
#         dummy_trajectory=[next(t) for t in train_samplers],
#         **model_params
#     )
# else:
#     train_model = eval_model
train_model = eval_model

In [55]:
# Training loop.
best_score = -1.0
current_train_items = [0] * len(FLAGS.algorithms)
step = 0
next_eval = 0
# Make sure scores improve on first step, but not overcome best score
# until all algos have had at least one evaluation.
val_scores = [-99999.9] * len(FLAGS.algorithms)
length_idx = 0

while step < FLAGS.train_steps:
    feedback_list = [next(t) for t in train_samplers]

    # Initialize model.
    if step == 0:
        all_features = [f.features for f in feedback_list]
        if FLAGS.chunked_training:
            # We need to initialize the model with samples of all lengths for
            # all algorithms. Also, we need to make sure that the order of these
            # sample sizes is the same as the order of the actual training sizes.
            all_length_features = [all_features] + [
                [next(t).features for t in train_samplers]
                for _ in range(len(train_lengths))]
            train_model.init(all_length_features[:-1], FLAGS.seed + 1)
        else:
            train_model.init(all_features, FLAGS.seed + 1)

    train_model.l1_weight = l1_weight_fn(step)
    # Training step.
    for algo_idx in range(len(train_samplers)):
        feedback = feedback_list[algo_idx]
        rng_key, new_rng_key = jax.random.split(rng_key)
        if FLAGS.chunked_training:
            # In chunked training, we must indicate which training length we are
            # using, so the model uses the correct state.
            length_and_algo_idx = (length_idx, algo_idx)
        else:
            # In non-chunked training, all training lengths can be treated equally,
            # since there is no state to maintain between batches.
            length_and_algo_idx = algo_idx
        cur_loss = train_model.feedback(
            rng_key, feedback, length_and_algo_idx)
        rng_key = new_rng_key

        if FLAGS.chunked_training:
            examples_in_chunk = np.sum(feedback.features.is_last).item()
        else:
            examples_in_chunk = len(feedback.features.lengths)
        current_train_items[algo_idx] += examples_in_chunk
        if step % 50 == 0:
            logging.info('Algo %s step %i current loss %f, current_train_items %i, l1_weight %f.',
                        FLAGS.algorithms[algo_idx], step,
                        cur_loss, current_train_items[algo_idx], train_model.l1_weight)

    # Periodically evaluate model
    if step >= next_eval:
        eval_model.params = train_model.params
        for algo_idx in range(len(train_samplers)):
            common_extras = {'examples_seen': current_train_items[algo_idx],
                             'step': step,
                             'algorithm': FLAGS.algorithms[algo_idx]}

            # Validation info.
            new_rng_key, rng_key = jax.random.split(rng_key)
            val_stats = collect_and_eval(
                val_samplers[algo_idx],
                functools.partial(eval_model.predict,
                                  algorithm_index=algo_idx),
                val_sample_counts[algo_idx],
                new_rng_key,
                extras=common_extras)
            logging.info('(val) algo %s step %d: %s',
                         FLAGS.algorithms[algo_idx], step, val_stats)
            val_scores[algo_idx] = val_stats['score']

        next_eval += FLAGS.eval_every

        # If best total score, update best checkpoint.
        # Also save a best checkpoint on the first step.
        msg = (f'best avg val score was '
               f'{best_score/len(FLAGS.algorithms):.3f}, '
               f'current avg val score is {np.mean(val_scores):.3f}, '
               f'val scores are: ')
        msg += ', '.join(
            ['%s: %.3f' % (x, y) for (x, y) in zip(FLAGS.algorithms, val_scores)])
        if (sum(val_scores) > best_score) or step == 0:
            best_score = sum(val_scores)
            logging.info('Checkpointing best model, %s', msg)
            train_model.save_model('best.pkl')
        else:
            logging.info('Not saving new best model, %s', msg)

    step += 1
    length_idx = (length_idx + 1) % len(train_lengths)

logging.info('Checkpointing final model, %s', msg)
train_model.save_model('final.pkl')

logging.info('Restoring best model from checkpoint...')
eval_model.restore_model('best.pkl', only_load_processor=False)
# logging.info('Restoring final model from checkpoint...')
# eval_model.restore_model('final.pkl', only_load_processor=False)

2023-03-24 19:00:03,496 - root - INFO - Algo binary_search step 0 current loss 7.124803, current_train_items 32, l1_weight 0.010000.
2023-03-24 19:00:03,496 - root - INFO - Algo binary_search step 0 current loss 7.124803, current_train_items 32, l1_weight 0.010000.
2023-03-24 19:00:05,187 - root - INFO - (val) algo binary_search step 0: {'return': 0.127197265625, 'score': 0.127197265625, 'examples_seen': 32, 'step': 0, 'algorithm': 'binary_search'}
2023-03-24 19:00:05,187 - root - INFO - (val) algo binary_search step 0: {'return': 0.127197265625, 'score': 0.127197265625, 'examples_seen': 32, 'step': 0, 'algorithm': 'binary_search'}
2023-03-24 19:00:05,188 - root - INFO - Checkpointing best model, best avg val score was -1.000, current avg val score is 0.127, val scores are: binary_search: 0.127
2023-03-24 19:00:05,188 - root - INFO - Checkpointing best model, best avg val score was -1.000, current avg val score is 0.127, val scores are: binary_search: 0.127
2023-03-24 19:00:34,711 - ro

In [56]:
import pickle

def restore_model(model, file_name):
    """Restore model from `file_name`."""
    with open(file_name, 'rb') as f:
        restored_state = pickle.load(f)
        restored_params = restored_state['params']
        model.params = hk.data_structures.merge(restored_params)
        model.opt_state = restored_state['opt_state']

def save_model(model, file_name):
    """Save model (processor weights only) to `file_name`."""
    to_save = {'params': model.params, 'opt_state': model.opt_state}
    with open(file_name, 'wb') as f:
        pickle.dump(to_save, f)

In [57]:
# save_model(eval_model, 'eval_model_1e-3.pkl')

In [58]:
# restore_model(eval_model, 'eval_model_asdf.pkl')

In [59]:
algo_idx = 0
common_extras = {}

new_rng_key, rng_key = jax.random.split(rng_key)
val_stats = collect_and_eval(
    val_samplers[algo_idx],
    functools.partial(eval_model.predict, algorithm_index=algo_idx),
    val_sample_counts[algo_idx],
    new_rng_key,
    extras=common_extras)

print(val_stats)

{'return': 0.974609375, 'score': 0.974609375}


In [60]:
new_rng_key, rng_key = jax.random.split(rng_key)
test_stats = collect_and_eval(
    test_samplers[algo_idx],
    functools.partial(eval_model.predict, algorithm_index=algo_idx),
    test_sample_counts[algo_idx],
    new_rng_key,
    extras=common_extras)

print(test_stats)

{'return': 0.8212890625, 'score': 0.8212890625}


In [61]:
feedback = next(train_samplers[0])

In [62]:
list(filter(lambda x: x[:2] != '__', dir(feedback)))

['_asdict',
 '_field_defaults',
 '_fields',
 '_fields_defaults',
 '_make',
 '_replace',
 'count',
 'features',
 'index',
 'outputs']

In [63]:
list(filter(lambda x: x[:2] != '__', dir(feedback.features)))

['_asdict',
 '_field_defaults',
 '_fields',
 '_fields_defaults',
 '_make',
 '_replace',
 'count',
 'hints',
 'index',
 'inputs',
 'lengths']

In [64]:
feedback.features

Features(inputs=(DataPoint(name="pos",	location=node,	type=scalar,	data=Array(32, 4)), DataPoint(name="key",	location=node,	type=scalar,	data=Array(32, 4)), DataPoint(name="target",	location=graph,	type=scalar,	data=Array(32,)), DataPoint(name="pred",	location=node,	type=pointer,	data=Array(32, 4))), hints=(DataPoint(name="low",	location=node,	type=mask_one,	data=Array(3, 32, 4)), DataPoint(name="high",	location=node,	type=mask_one,	data=Array(3, 32, 4)), DataPoint(name="mid",	location=node,	type=mask_one,	data=Array(3, 32, 4))), lengths=array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
       3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]))

In [65]:
feedback.features.inputs

(DataPoint(name="pos",	location=node,	type=scalar,	data=Array(32, 4)),
 DataPoint(name="key",	location=node,	type=scalar,	data=Array(32, 4)),
 DataPoint(name="target",	location=graph,	type=scalar,	data=Array(32,)),
 DataPoint(name="pred",	location=node,	type=pointer,	data=Array(32, 4)))

In [66]:
feedback.features.inputs[0].data

array([[0.07944593, 0.41586608, 0.71447369, 0.77809141],
       [0.08085861, 0.13429643, 0.27291493, 0.41807869],
       [0.12472272, 0.13972125, 0.542758  , 0.99247434],
       [0.48136395, 0.58456102, 0.72940445, 0.91156752],
       [0.16545865, 0.35328275, 0.69469487, 0.78535665],
       [0.07005675, 0.33551755, 0.53910825, 0.60737041],
       [0.22549218, 0.50906486, 0.73207958, 0.8340708 ],
       [0.44082032, 0.80129438, 0.90084925, 0.95391742],
       [0.02581133, 0.39771323, 0.7993167 , 0.89175966],
       [0.06990053, 0.27263508, 0.80853305, 0.81788009],
       [0.08547658, 0.135077  , 0.14821431, 0.32904758],
       [0.04718299, 0.69072433, 0.81743709, 0.92777011],
       [0.43525787, 0.66466334, 0.83766448, 0.96035262],
       [0.3995087 , 0.4379399 , 0.56940683, 0.72093139],
       [0.10209387, 0.61299073, 0.6903771 , 0.74456705],
       [0.66177456, 0.68484313, 0.6915208 , 0.81013657],
       [0.66159714, 0.77898996, 0.94729679, 0.98304125],
       [0.03490931, 0.05979691,

In [67]:
feedback.features.inputs[1].data

array([[0.14505245, 0.50669634, 0.69746885, 0.99827601],
       [0.04907691, 0.55939811, 0.60407716, 0.69362913],
       [0.22850282, 0.32995465, 0.45680348, 0.4631161 ],
       [0.60183918, 0.61901712, 0.75530681, 0.84752412],
       [0.06522265, 0.49855728, 0.71561711, 0.74078881],
       [0.13983012, 0.26721134, 0.34307139, 0.61253297],
       [0.44517601, 0.59939087, 0.70310817, 0.71624527],
       [0.21499737, 0.22490709, 0.2290317 , 0.27179855],
       [0.0659541 , 0.36794665, 0.64325931, 0.66634537],
       [0.00141337, 0.02005159, 0.42050811, 0.75774222],
       [0.47937064, 0.76829195, 0.77577855, 0.89999182],
       [0.11322815, 0.4436976 , 0.45798428, 0.73927385],
       [0.02718076, 0.39560593, 0.47574331, 0.57584161],
       [0.25199846, 0.42218847, 0.66165925, 0.75193187],
       [0.17945598, 0.19686362, 0.23018133, 0.98835444],
       [0.03165351, 0.19228618, 0.82584994, 0.9468503 ],
       [0.17438562, 0.74685971, 0.74929382, 0.85605148],
       [0.10119071, 0.32314454,

In [68]:
feedback.features.inputs[2].data

array([0.90789178, 0.76675769, 0.46466379, 0.21048436, 0.19953704,
       0.71214855, 0.72076151, 0.11210543, 0.72746734, 0.47947478,
       0.50966803, 0.45263618, 0.77380012, 0.48942764, 0.96081892,
       0.95727694, 0.10890818, 0.5357713 , 0.80139012, 0.57620145,
       0.93474738, 0.89697058, 0.57118694, 0.48263026, 0.2236446 ,
       0.6904246 , 0.91944913, 0.11814233, 0.64556645, 0.37133494,
       0.4398061 , 0.70607906])

In [69]:
feedback.features.inputs[3].data

array([[0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.],
       [0., 0., 1., 2.]])

In [70]:
A = feedback.features.inputs[2].data
adj = feedback.features.inputs[3].data

In [71]:
jnp.where(jnp.expand_dims((A != 0), 1) & (adj == 0))

(Array([ 0,  0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,
         8,  9,  9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16,
        17, 17, 18, 18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 24, 24, 25,
        25, 26, 26, 27, 27, 28, 28, 29, 29, 30, 30, 31, 31], dtype=int32),
 Array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
        0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
        0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],      dtype=int32))

In [72]:
feedback.features.hints

(DataPoint(name="low",	location=node,	type=mask_one,	data=Array(3, 32, 4)),
 DataPoint(name="high",	location=node,	type=mask_one,	data=Array(3, 32, 4)),
 DataPoint(name="mid",	location=node,	type=mask_one,	data=Array(3, 32, 4)))

In [73]:
np.argmin(feedback.features.inputs[0].data, axis=1)

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [74]:
np.where(feedback.outputs[0].data == 1)

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]),
 array([3, 3, 3, 0, 1, 3, 3, 0, 3, 3, 1, 2, 3, 2, 3, 3, 0, 3, 3, 1, 3, 3,
        1, 3, 1, 2, 3, 0, 2, 3, 3, 3]))

In [75]:
test_sample_counts[0]

2048

In [76]:
feedback.features

Features(inputs=(DataPoint(name="pos",	location=node,	type=scalar,	data=Array(32, 4)), DataPoint(name="key",	location=node,	type=scalar,	data=Array(32, 4)), DataPoint(name="target",	location=graph,	type=scalar,	data=Array(32,)), DataPoint(name="pred",	location=node,	type=pointer,	data=Array(32, 4))), hints=(DataPoint(name="low",	location=node,	type=mask_one,	data=Array(3, 32, 4)), DataPoint(name="high",	location=node,	type=mask_one,	data=Array(3, 32, 4)), DataPoint(name="mid",	location=node,	type=mask_one,	data=Array(3, 32, 4))), lengths=array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
       3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]))

In [77]:
def _concat(dps, axis):
    return jax.tree_util.tree_map(lambda *x: np.concatenate(x, axis), *dps)

def get_msgs(sampler, predict_fn, sample_count, rng_key, sample_prob=0.001, threshold_max_elem=0.001):
    """Get messages from model.
    
    CAUTION: size of msgs can get large very quickly, so beware when
        running with a large number of samples.
    Use sample_prob to reduce the number of messages that are saved
    by randomly sampling messages
    
    Note: we can only perform symbolic regression on a maximum of
    ~5000 messages anyways, so we can use sample_prob to get to this threshold
    """
    processed_samples = 0
    msgs = []
    while processed_samples < sample_count:
        feedback = next(sampler)
        batch_size = feedback.outputs[0].data.shape[0]
        new_rng_key, rng_key = jax.random.split(rng_key)
        _, _, cur_msgs, cur_input_msg, cur_input_algo = predict_fn(new_rng_key, feedback.features)
        
        print(cur_msgs.shape, cur_msgs.shape[-1], cur_input_msg.shape[-1], cur_input_algo.shape[-1])
        
        cur_msgs = cur_msgs.reshape(-1, cur_msgs.shape[-1])
        cur_input_msg = cur_input_msg.reshape(-1, cur_input_msg.shape[-1])
        cur_input_algo = cur_input_algo.reshape(-1, cur_input_algo.shape[-1])
        cur_msg_concat = jnp.concatenate((cur_msgs, cur_input_msg, cur_input_algo), axis=-1)
        
        N = cur_msg_concat.shape[0]
        max_element_mask = jnp.max(jnp.abs(cur_msgs), axis=1) > threshold_max_elem
        print("Thresholding:", jnp.sum(max_element_mask), "of", N)
        cur_msg_concat = cur_msg_concat[max_element_mask]

        # sampled_rate = jnp.sum(max_element_mask) / cur_msg_concat.shape[0]
        # new_sample_prob = max(1., sample_prob / sampled_rate)
        # print("New sample prob:", new_sample_prob)

        new_rng_key, rng_key = jax.random.split(rng_key)
        mask = jax.random.choice(new_rng_key,
                                 a=jnp.array([False, True]),
                                 shape=(cur_msg_concat.shape[0],),
                                 p=jnp.array([1 - sample_prob, sample_prob]),
                                 replace=True,)
        cur_msg_concat = cur_msg_concat[mask]
        print("Random sampling:", cur_msg_concat.shape[0], "of", N)
        
        msgs.append(cur_msg_concat)
        processed_samples += batch_size
    msgs = _concat(msgs, axis=0)
    
    return msgs


In [78]:
new_rng_key, rng_key = jax.random.split(rng_key)
batched_msgs = get_msgs(
    test_samplers[0],
    functools.partial(eval_model.predict, algorithm_index=0),
#     test_sample_counts[0],  # EXPLODING MEMORY LOL
    32*4,
    new_rng_key,
    0.1)

batched_msgs.shape

(32, 6, 64, 64, 32) 32 512 13
Thresholding: 36480 of 786432
Random sampling: 3676 of 786432
(32, 6, 64, 64, 32) 32 512 13
Thresholding: 36480 of 786432
Random sampling: 3705 of 786432
(32, 6, 64, 64, 32) 32 512 13
Thresholding: 36480 of 786432
Random sampling: 3681 of 786432
(32, 6, 64, 64, 32) 32 512 13
Thresholding: 36480 of 786432
Random sampling: 3638 of 786432


(14700, 557)

In [79]:
jnp.sum(jnp.linalg.norm(batched_msgs[:,:32],axis=1) < 0.01)

Array(11118, dtype=int32)

In [80]:
jnp.sum(jnp.max(jnp.abs(batched_msgs[:,:32]),axis=1) < 0.001)

Array(0, dtype=int32)

In [81]:
batched_msgs[0][:32], batched_msgs[1][:32], batched_msgs[2][:32]

(array([ 0.00382642, -0.04599392, -0.03524892,  0.00504001, -0.03600484,
        -0.04021579,  0.07054279,  0.20388949, -0.02077518,  0.01939277,
         0.11014283, -0.07487595, -0.02024115, -0.05271044, -0.04856784,
        -0.0711588 , -0.05519877, -0.07215293, -0.06703275,  0.14277695,
        -0.6742578 , -0.12658611, -0.07090352, -0.00759287,  0.05890476,
         0.04787841,  0.04071072, -0.05187277, -0.00110583,  0.03360759,
        -0.02287874, -0.00142055], dtype=float32),
 array([-0.00133758, -0.04075012, -0.03303221,  0.00242467, -0.00619409,
        -0.02537211,  0.05130481,  0.13965325, -0.01104111,  0.01242436,
         0.07080761, -0.07218133, -0.01166432, -0.04023448, -0.03310691,
        -0.05876958, -0.03926387, -0.05544468, -0.05155095,  0.10095557,
        -0.58233196, -0.10178439, -0.04886349, -0.00464496,  0.04269987,
         0.0484421 ,  0.03763483, -0.03981557,  0.00494035,  0.01634523,
        -0.0141501 ,  0.00116069], dtype=float32),
 array([-2.0105876e-03

In [82]:
file_name = 'binary_search_val_msgs.pkl'
with open(file_name, 'wb') as f:
    pickle.dump(batched_msgs, f)

In [83]:
import sys
sys.getsizeof(batched_msgs)

32751728

In [84]:
test_sample_counts[0]

2048

In [85]:
sys.getsizeof(batched_msgs) / 1000000

32.751728