In [2]:
import clrs
import numpy as np
import jax
import jax.numpy as jnp

import pprint

# Seed NumPy's legacy generator with a fixed constant, then draw a single
# integer in [0, 2**32) from it to seed the JAX PRNG key, so both random
# streams are determined by the one number below.
rng = np.random.RandomState(seed=1234)
rng_key = jax.random.PRNGKey(seed=rng.randint(low=0, high=2**32))

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [5]:
# Training data: 100 pre-generated weighted bipartite-matching instances.
# The samplers are seeded explicitly so that Restart Kernel -> Run All
# regenerates the same graphs; without a seed every run draws different data.
train_sampler, spec = clrs.build_sampler(
    name='bipartite_matching',
    num_samples=100,
    length=16,  # number of nodes
    weighted=True,
    seed=1234)

# Test data: same algorithm, evaluated on much larger graphs than seen in
# training. A different seed keeps the two random streams distinct.
test_sampler, spec = clrs.build_sampler(
    name='bipartite_matching',
    num_samples=100,
    length=64,  # testing on much larger graphs
    weighted=True,
    seed=1235)
# TODO: when train and test use the same `length`, how do we verify that the
# two samplers are not generating identical graphs? (Not an issue here, since
# the sizes differ, but it matters in general.)

# `spec` is the algorithm specification: one entry per probe, mapping the probe
# name to its (stage, location, type) triple. Both build_sampler calls return a
# spec for 'bipartite_matching'; the second assignment overwrites the first.
pprint.pprint(spec)

def _iterate_sampler(sampler, batch_size):
  while True:
    yield sampler.next(batch_size)

# Rebind both names to endless batch generators: from here on, batches are
# obtained with next(train_sampler) / next(test_sampler).
train_sampler = _iterate_sampler(train_sampler, 32)
# The test batch size equals the num_samples=100 the test sampler was built with.
test_sampler = _iterate_sampler(test_sampler, 100)


{'A': ('input', 'edge', 'scalar'),
 'A_h': ('hint', 'edge', 'scalar'),
 'adj': ('input', 'edge', 'mask'),
 'adj_h': ('hint', 'edge', 'mask'),
 'd': ('hint', 'node', 'scalar'),
 'in_matching': ('output', 'edge', 'mask'),
 'in_matching_h': ('hint', 'edge', 'mask'),
 'msk': ('hint', 'node', 'mask'),
 'phase': ('hint', 'graph', 'mask'),
 'pi': ('hint', 'node', 'pointer'),
 'pos': ('input', 'node', 'scalar'),
 's': ('input', 'node', 'mask_one'),
 't': ('input', 'node', 'mask_one'),
 'u': ('hint', 'node', 'mask_one')}


In [8]:
# Bipartite graph is weighted: peek at one training batch and display the
# first sample of its second input probe, whose entries are real-valued.
# NOTE(review): inputs[1] is presumably the 'A' edge-scalar probe; the index is
# positional, so confirm against the actual order of features.inputs.
# Calling next() here consumes one batch from the training generator.
next(train_sampler).features.inputs[1].data[0]

array([[0.        , 0.40309255, 0.82758018, 0.91295964, 0.30944842,
        0.0315634 , 0.06428168, 0.00208438, 0.65677349, 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.42245016, 0.        , 0.20672695, 0.        ,
        0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.38261076, 0.        , 0.        ,
        0.        , 0.36597594, 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.71865507,
        0.        , 0.        , 0.        , 0.        , 0.14134505,
        0.        , 0.        , 0.        ],
    

In [None]:
# Processor used inside the model: 'mpnn' with layer norm switched on (use_ln).
processor_factory = clrs.get_processor_factory(
    'mpnn',
    use_ln=True,
    nb_triplet_fts=0,
)

# Keyword arguments forwarded verbatim to BaselineModel below.
# `decode_diffs` and `hint_teacher_forcing_noise` were options considered
# earlier and are deliberately not passed.
model_params = {
    'processor_factory': processor_factory,
    'hidden_dim': 32,
    'encode_hints': True,
    'decode_hints': True,
    'hint_teacher_forcing': 1.0,
    'use_lstm': False,
    'learning_rate': 0.001,
    'checkpoint_path': '/tmp/checkpt',
    'freeze_processor': False,  # original note: "Good for post step"
    'dropout_prob': 0.0,
}

# JAX needs a plausible-looking trajectory to initialise the model, so one
# batch is pulled from the training generator for that purpose.
dummy_trajectory = next(train_sampler)

model = clrs.models.BaselineModel(
    spec=spec, dummy_trajectory=dummy_trajectory, **model_params)

# The second argument (1234) is the random seed for initialisation.
model.init(dummy_trajectory.features, 1234)

In [None]:
# Training loop: steps 0..200 inclusive, one update per step, with
# train/test accuracy logged every 10 steps.
#
# PRNG handling: a JAX key must not be both consumed by a computation and
# split again later. Each use below therefore gets its own sub-key split off
# `rng_key`, and `rng_key` itself is only ever passed to `jax.random.split`.
step = 0

while step <= 200:
  feedback = next(train_sampler)

  # Fresh sub-key for this update step.
  rng_key, train_key = jax.random.split(rng_key)
  cur_loss = model.feedback(train_key, feedback)  # runs one training step and returns its loss

  if step % 10 == 0:
    # The test batch is only needed when logging, so draw it here rather than
    # on every step.
    test_feedback = next(test_sampler)
    rng_key, train_eval_key, test_eval_key = jax.random.split(rng_key, 3)

    # "val" here is really the accuracy on the batch that was just trained on
    # (not a held-out validation set) -- not great, but this is an example.
    predictions_val, _ = model.predict(train_eval_key, feedback.features)
    out_val = clrs.evaluate(feedback.outputs, predictions_val)

    predictions, _ = model.predict(test_eval_key, test_feedback.features)
    out = clrs.evaluate(test_feedback.outputs, predictions)

    print(f'step = {step} | loss = {cur_loss} | val_acc = {out_val["score"]} | test_acc = {out["score"]}')
  step += 1

step = 0 | loss = 18.544353485107422 | val_acc = 0.5307644605636597 | test_acc = 0.28374454379081726
step = 10 | loss = 8.695533752441406 | val_acc = 0.5333715081214905 | test_acc = 0.3007984757423401
step = 20 | loss = 6.970330238342285 | val_acc = 0.47899776697158813 | test_acc = 0.3936541974544525
step = 30 | loss = 6.004898548126221 | val_acc = 0.6611765027046204 | test_acc = 0.41111698746681213
step = 40 | loss = 5.4687700271606445 | val_acc = 0.577235758304596 | test_acc = 0.3613807260990143
step = 50 | loss = 4.589337348937988 | val_acc = 0.5734208226203918 | test_acc = 0.3562544584274292
step = 60 | loss = 4.119411468505859 | val_acc = 0.5963488817214966 | test_acc = 0.27703341841697693
step = 70 | loss = 3.780219316482544 | val_acc = 0.6741573214530945 | test_acc = 0.3141278624534607
step = 80 | loss = 3.339991331100464 | val_acc = 0.7116518616676331 | test_acc = 0.2799810469150543
step = 90 | loss = 3.2398462295532227 | val_acc = 0.7173252105712891 | test_acc = 0.444053679704