In [1]:
import copy
%load_ext autoreload
%autoreload 2
import jax
import clrs
import numpy as np

# Seeded NumPy generator; integers drawn from it are used to seed JAX PRNG keys
# (here and again inside `train`).
rng = np.random.RandomState(1234)
# Notebook-level JAX key, built from a single integer drawn from `rng`.
rng_key = jax.random.PRNGKey(rng.randint(2 ** 32))



No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [23]:
# If you don't want BipartiteMatching, just pass empty generator list and
# length separately

# Sampler configuration for training batches. 'batch_size' and 'num_samples'
# are read by `samplers`; the whole dict is forwarded to `clrs.build_sampler`.
# NOTE(review): the meaning of each 'schematics' entry ('length', 'proportion')
# is defined by the sampler implementation, which is not visible here.
train_sampler_spec = {
    'num_samples': 100,
    'batch_size': 32,
    'schematics': [{'length': 16, 'proportion': 1}],
}

# Sampler configuration for test batches: same layout, with a larger 'length'
# and a batch size equal to the number of samples.
test_sampler_spec = {
    'num_samples': 40,
    'batch_size': 40,
    'schematics': [{'length': 64, 'proportion': 1}],
}


def samplers(sampler_spec, **kwargs):
    """Build an 'online_testing' sampler and wrap it in an endless batch generator.

    Args:
        sampler_spec: Mapping that must contain 'num_samples' and may contain
            'batch_size' (defaults to 1). It is forwarded unchanged to
            `clrs.build_sampler`.
        **kwargs: Extra keyword arguments passed through to
            `clrs.build_sampler`.

    Returns:
        A `(batches, spec)` pair: `batches` is a generator that yields
        `sampler.next(batch_size)` forever, and `spec` is the second value
        returned by `clrs.build_sampler`.
    """
    requested = sampler_spec.get('batch_size', 1)
    # A batch can never be larger than the number of available samples.
    effective_batch = min(requested, sampler_spec['num_samples'])

    underlying, spec = clrs.build_sampler(
        name='online_testing',
        sampler_spec=sampler_spec,
        **kwargs,
    )

    def _endless_batches():
        # Lazily pulls one batch per `next()` call; never raises StopIteration.
        while True:
            yield underlying.next(effective_batch)

    return _endless_batches(), spec


# Build the endless batch generators. The spec returned alongside the training
# sampler is kept (it is later passed to `define_model`).
train_sampler, spec = samplers(train_sampler_spec)
test_sampler, _ = samplers(test_sampler_spec)  # spec of the test sampler is discarded

In [24]:
# Pull one batch from the test generator for inspection (this advances the generator).
sample = next(test_sampler)
# NOTE(review): this bare expression is not the last line of the cell, so it is
# evaluated but not displayed.
sample.features.inputs[0].data[0]
# Last expression of the cell: index 0 of the first hint's data.
sample.features.hints[0].data[0]

array([[0., 1., 1., ..., 1., 1., 1.],
       [0., 0., 1., ..., 1., 1., 1.],
       [1., 0., 0., ..., 1., 0., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [0., 0., 1., ..., 1., 0., 1.],
       [0., 1., 0., ..., 0., 0., 1.]])

In [25]:
def define_model(spec, train_sampler, model = "mpnn"):
    """Build and initialise a `clrs.models.BaselineModel` for a given processor.

    Args:
        spec: Passed as `spec` to `clrs.models.BaselineModel`.
        train_sampler: Iterator of training batches. Exactly one batch is
            consumed and used as the dummy trajectory for initialisation.
        model: Processor name. One of "mpnn", "gat", "mpnndoublemax",
            "gmpnn", "pgn", "triplet_pgn_mask".

    Returns:
        The initialised `clrs.models.BaselineModel` instance.

    Raises:
        ValueError: If `model` is not a supported processor name. This is
            raised before any batch is consumed from `train_sampler`.
    """
    # Processor-specific keyword arguments. `use_ln = True` (layer norm) is
    # common to every processor and is added at the call site below.
    processor_kwargs = {
        "mpnn": dict(nb_triplet_fts = 4),
        "gat": dict(nb_heads = 4, nb_triplet_fts = 4),
        "mpnndoublemax": dict(nb_triplet_fts = 0),
        "gmpnn": dict(nb_triplet_fts = 4),
        "pgn": dict(nb_triplet_fts = 32),
        "triplet_pgn_mask": dict(nb_triplet_fts = 32),
    }

    # Fail early with an explicit message instead of hitting an unbound
    # `processor_factory` further down.
    if not isinstance(model, str) or model not in processor_kwargs:
        raise ValueError(
            f"Unknown model {model!r}; expected one of {sorted(processor_kwargs)}"
        )

    processor_factory = clrs.get_processor_factory(
        model, use_ln = True, **processor_kwargs[model])

    model_params = dict(
        processor_factory = processor_factory,
        hidden_dim = 32,  # NOTE(review): an earlier TODO said "put back to 32 if no difference"; it is 32 now
        encode_hints = True,
        decode_hints = True,
        #decode_diffs=False,
        #hint_teacher_forcing_noise=1.0,
        hint_teacher_forcing = 1.0,
        use_lstm = False,
        learning_rate = 0.001,
        checkpoint_path = '/tmp/checkpt',
        freeze_processor = False,  # Good for post step
        dropout_prob = 0.5,
        # nb_msg_passing_steps=3,
    )

    # The model needs a plausible-looking trajectory to initialise from.
    dummy_trajectory = next(train_sampler)

    # Use a separate local name so the `model` argument (a string) is not
    # rebound to the model object.
    baseline = clrs.models.BaselineModel(
        spec = spec,
        dummy_trajectory = dummy_trajectory,
        **model_params
    )

    baseline.init(dummy_trajectory.features, 1234)  # 1234 is the init seed

    return baseline


# Build the baseline with the "mpnn" processor. Note that `define_model`
# consumes one batch from `train_sampler` for initialisation.
model = define_model(spec, train_sampler, "mpnn")

In [26]:
# No evaluation since we are postprocessing with soft: TO CHANGE -> baselines.py line 336 outs change hard to False
# step = 0
#
# while step <= 1:
#     feedback, test_feedback = next(train_sampler), next(test_sampler)
#     rng_key, new_rng_key = jax.random.split(rng_key) # jax needs new random seed at step
#     cur_loss = model.feedback(rng_key, feedback) # loss is contained in model somewhere
#     rng_key = new_rng_key
#     if step % 10 == 0:
#         print(step)
#     step += 1



In [27]:
def train(model, epochs, train_sampler, test_sampler, log_every = 50):
    """Run the training loop, periodically printing the loss and evaluation scores.

    Args:
        model: Object exposing `feedback(rng_key, batch)` and
            `predict(rng_key, features)`.
        epochs: Index of the last training step. The loop runs while
            `step <= epochs`, i.e. steps `0..epochs` inclusive, one batch
            per step.
        train_sampler: Iterator yielding training batches.
        test_sampler: Iterator yielding test batches. One batch is drawn on
            every step, including steps that do not log.
        log_every: Evaluate and print whenever `step % log_every == 0`.
            Defaults to 50, the previously hard-coded interval.

    Returns:
        The `model` object that was passed in.

    Raises:
        ValueError: If `log_every` is not positive.
    """
    if log_every <= 0:
        raise ValueError(f"log_every must be positive, got {log_every!r}")

    step = 0
    # NOTE: the seed comes from the notebook-level NumPy generator `rng`,
    # so results depend on how many draws were made from it earlier.
    rng_key = jax.random.PRNGKey(rng.randint(2 ** 32))

    while step <= epochs:
        feedback, test_feedback = next(train_sampler), next(test_sampler)

        # One half of the split drives this update; the other half is carried
        # forward as the key for the next step (and for the predictions below).
        rng_key, new_rng_key = jax.random.split(rng_key)
        cur_loss = model.feedback(rng_key, feedback)
        rng_key = new_rng_key

        if step % log_every == 0:
            # "val" here is scored on the current *training* batch, not on a
            # held-out validation set.
            predictions_val, _ = model.predict(rng_key, feedback.features)
            out_val = clrs.evaluate(feedback.outputs, predictions_val)
            predictions, _ = model.predict(rng_key, test_feedback.features)
            out = clrs.evaluate(test_feedback.outputs, predictions)
            print(
                f'step = {step} | loss = {cur_loss} | val_acc = {out_val["score"]} | test_acc = {out["score"]}')
        step += 1
    return model

In [28]:
# Train for steps 0..100 inclusive (`train` loops while `step <= epochs`);
# `model` is rebound to the object `train` returns.
model = train(model, 100, train_sampler, test_sampler)

step = 0 | loss = 2.1660819053649902 | val_acc = 0.47500887513160706 | test_acc = 0.513396143913269
step = 50 | loss = 0.4169905185699463 | val_acc = 0.006059996783733368 | test_acc = 0.004073990974575281
step = 100 | loss = 0.27615106105804443 | val_acc = 0.003404786344617605 | test_acc = 0.002677247626706958


In [29]:
# Draw another batch from the test generator to inspect predictions on.
test_feedback = next(test_sampler)


In [30]:
# test_feedback = next(test_sampler)
# Predict on the batch drawn in the previous cell. `rng_key` is the
# notebook-level key from the first cell (`train` only uses its own local key);
# the second return value of `predict` is discarded.
predictions, _ = model.predict(rng_key, test_feedback.features)

In [31]:
# Display index 2 of the predicted 'value' output
# (presumably the third sample in the batch — TODO confirm the axis meaning).
predictions['value'].data[2]

Array([ 0.98087317, -0.07416291,  0.98156744,  0.9806936 ,  0.98143446,
       -0.06535578, -0.07103633,  0.9777561 , -0.06063077,  0.98111415,
       -0.06765325,  0.9769704 ,  0.98019904, -0.06847873, -0.0817865 ,
        0.9802213 ,  0.97885627,  0.9815149 , -0.06642177,  0.9826355 ,
        0.979281  , -0.07321703, -0.07755595, -0.06382804, -0.07131772,
       -0.06748506, -0.06530993,  0.97771525,  0.9815438 , -0.07130258,
        0.97742194,  0.9817032 , -0.06612809, -0.07532345,  0.9818983 ,
       -0.06972069,  0.97804874, -0.07019711,  0.98192054, -0.06860285,
       -0.07033233,  0.9796916 , -0.06734654, -0.07062507, -0.07609523,
        0.9808937 , -0.06427146,  0.9773147 ,  0.9782223 , -0.0806498 ,
       -0.06241468, -0.06299652,  0.9802721 ,  0.98304   , -0.06592128,
        0.9815593 , -0.07300188,  0.97947204,  0.98258513,  0.979904  ,
        0.98187655,  0.9812439 ,  0.9805051 ,  0.9758394 ], dtype=float32)