In [None]:
!pip install git+https://github.com/PLanza/johnsons-clrs-gnn.git@bf-dijkstra

Collecting git+https://github.com/PLanza/johnsons-clrs-gnn.git@bfs-dijkstra
  Cloning https://github.com/PLanza/johnsons-clrs-gnn.git (to revision bfs-dijkstra) to /tmp/pip-req-build-t___z1ki
  Running command git clone --filter=blob:none --quiet https://github.com/PLanza/johnsons-clrs-gnn.git /tmp/pip-req-build-t___z1ki
  Running command git checkout -b bfs-dijkstra --track origin/bfs-dijkstra
  Switched to a new branch 'bfs-dijkstra'
  Branch 'bfs-dijkstra' set up to track remote branch 'bfs-dijkstra' from 'origin'.
  Resolved https://github.com/PLanza/johnsons-clrs-gnn.git to commit 01aaf92337c0d2c9e98cd15e8acb87c1dee7cf3d
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [None]:
from clrs.examples import run
import numpy as np
import jax
import clrs
import functools
import random
import sys

In [None]:
def train_model(algorithms, seed=None):
  """Train a CLRS baseline model on one algorithm and return the model.

  Args:
    algorithms: Name of the single algorithm to train on (for example
      'bellman_ford'). It is wrapped in a one-element list and stored in
      `run.FLAGS.algorithms`, and is also used in the checkpoint file name.
    seed: Optional integer seed. When omitted, a seed is drawn with
      `random.randint(0, 1000)`, so repeated calls differ from run to run.
      The chosen value is written to `run.FLAGS.seed`.

  Returns:
    The `clrs.models.BaselineModel` instance holding the parameters from the
    last training step. The best-scoring parameters are written separately
    through `save_model('best<algorithms>.pkl')`.
  """
  if seed is None:
    seed = random.randint(0, 1000)

  # Hand absl only the program name so that every flag keeps its default and
  # none of the notebook kernel's own command-line arguments are parsed.
  run.FLAGS(sys.argv[:1])
  run.FLAGS.seed = seed
  run.FLAGS.algorithms = [algorithms]

  FLAGS = run.FLAGS
  train_lengths = [int(x) for x in FLAGS.train_lengths]

  rng = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(rng.randint(2**32))

  print(FLAGS.algorithms)
  encode_hints = True
  decode_hints = True

  processor_factory = clrs.get_processor_factory(
      FLAGS.processor_type,
      use_ln=FLAGS.use_ln,
      nb_triplet_fts=FLAGS.nb_triplet_fts,
      nb_heads=FLAGS.nb_heads
  )
  model_params = dict(
      processor_factory=processor_factory,
      hidden_dim=FLAGS.hidden_size,
      encode_hints=encode_hints,
      decode_hints=decode_hints,
      encoder_init=FLAGS.encoder_init,
      use_lstm=FLAGS.use_lstm,
      learning_rate=FLAGS.learning_rate,
      grad_clip_max_norm=FLAGS.grad_clip_max_norm,
      checkpoint_path=FLAGS.checkpoint_path,
      freeze_processor=FLAGS.freeze_processor,
      dropout_prob=FLAGS.dropout_prob,
      hint_teacher_forcing=FLAGS.hint_teacher_forcing,
      hint_repred_mode=FLAGS.hint_repred_mode,
      nb_msg_passing_steps=FLAGS.nb_msg_passing_steps,
      )

  (train_samplers,
   val_samplers, val_sample_counts,
   test_samplers, test_sample_counts,
   spec_list) = run.create_samplers(rng, train_lengths)

  eval_model = clrs.models.BaselineModel(
      spec=spec_list,
      dummy_trajectory=[next(train_samplers[0])],
      **model_params
  )
  # Training and evaluation share one model object. A distinct local name is
  # used so the enclosing function's own name is not shadowed.
  model = eval_model

  num_algos = len(FLAGS.algorithms)

  # Training loop.
  best_score = -1.0
  current_train_items = [0] * num_algos
  step = 0
  next_eval = 0
  # Make sure scores improve on first step, but not overcome best score
  # until all algos have had at least one evaluation.
  val_scores = [-99999.9] * num_algos

  # Per-algorithm histories. They are collected here but not returned.
  losses = {name: [] for name in FLAGS.algorithms}
  evals = {name: [] for name in FLAGS.algorithms}

  # Index of the algorithm being trained; only one algorithm is configured.
  algo_idx = 0

  while step < FLAGS.train_steps:
    feedback = next(train_samplers[algo_idx])

    # Initialize model.
    if step == 0:
      model.init([feedback.features], FLAGS.seed + 1)

    # Training step.
    rng_key, new_rng_key = jax.random.split(rng_key)
    length_and_algo_idx = algo_idx
    cur_loss = model.feedback(rng_key, feedback, length_and_algo_idx)
    rng_key = new_rng_key
    losses[FLAGS.algorithms[algo_idx]].append(cur_loss.item())

    examples_in_chunk = len(feedback.features.lengths)
    current_train_items[algo_idx] += examples_in_chunk
    print(f'Algo {FLAGS.algorithms[algo_idx]} step {step} current loss {cur_loss}, current_train_items {current_train_items[algo_idx]}.')

    # Periodically evaluate model
    if step >= next_eval:
      # Both names refer to the same object here, so this is a self-assignment;
      # it is kept so evaluation stays in sync if the two are ever separated.
      eval_model.params = model.params
      # A dedicated loop variable keeps the training index `algo_idx` intact.
      for eval_idx in range(len(train_samplers)):
        algo_name = FLAGS.algorithms[eval_idx]
        common_extras = {'examples_seen': current_train_items[eval_idx],
                         'step': step,
                         'algorithm': algo_name}

        # Validation info.
        new_rng_key, rng_key = jax.random.split(rng_key)
        val_stats = run.collect_and_eval(
            val_samplers[eval_idx],
            functools.partial(eval_model.predict, algorithm_index=eval_idx),
            val_sample_counts[eval_idx],
            new_rng_key,
            extras=common_extras)
        print(f'(val) algo {algo_name} step {step}: {val_stats}')
        val_scores[eval_idx] = val_stats['score']
        evals[algo_name].append(val_stats['score'])

      next_eval += FLAGS.eval_every

      # If best total score, update best checkpoint.
      # Also save a best checkpoint on the first step.
      msg = (f'best avg val score was '
             f'{best_score/len(FLAGS.algorithms):.3f}, '
             f'current avg val score is {np.mean(val_scores):.3f}, '
             f'val scores are: ')
      msg += ', '.join(
          ['%s: %.3f' % (x, y) for (x, y) in zip(FLAGS.algorithms, val_scores)])
      if (sum(val_scores) > best_score) or step == 0:
        best_score = sum(val_scores)
        print('Checkpointing best model, ', msg)
        model.save_model('best' + str(algorithms) + '.pkl')
      else:
        print('Not saving new best model, ', msg)

    step += 1

  return eval_model

In [None]:
# Train one model per algorithm. Each call reconfigures run.FLAGS (seed and
# algorithm list) and saves its best checkpoint as 'best<algorithm>.pkl'.
bf_model = train_model('bellman_ford')
dijkstra_model = train_model('dijkstra')

In [None]:
# Shared NumPy random state for all test samplers built below; seeded from the
# seed currently stored in run.FLAGS.
rng = np.random.RandomState(run.FLAGS.seed)

def get_test_sampler(algorithm):
  """Build the test-split sampler for `algorithm` with fixed settings.

  Args:
    algorithm: Algorithm name forwarded to `run.make_multi_sampler`.

  Returns:
    Whatever `run.make_multi_sampler` returns; callers in this notebook
    unpack it into three values (sampler, number of samples, spec).
  """
  return run.make_multi_sampler(
      # Test-split configuration.
      sizes=[64],
      split='test',
      batch_size=32,
      multiplier=2,
      randomize_pos=False,
      chunked=False,
      sampler_kwargs={},
      # Settings common to every sampler; `rng` is the module-level state.
      algorithm=algorithm,
      rng=rng,
      enforce_pred_as_input=True,
      enforce_permutations=True,
      chunk_length=16,
  )




In [None]:
# Test sampler for Bellman-Ford; bf_num_samples later bounds the inference loop.
bf_sampler, bf_num_samples, bf_spec = get_test_sampler('bellman_ford')

In [None]:
def get_rewritten_A(A, pi):
  """Reweight a graph using potentials read off a predecessor array.

  The last node (index `pi.shape[-1] - 1`) is treated as the source. For every
  node `i`, the potential `d[i]` is the sum of the weights `A[parent, child]`
  along the predecessor chain from `i` back to the source. Every non-zero
  entry `A[i, j]` is then replaced by `A[i, j] + d[i] - d[j]`.

  Args:
    A: Square 2-D NumPy weight matrix with one row/column per entry of `pi`;
      `A[u, v]` is the weight of edge u -> v and 0 marks "no edge".
    pi: 1-D array of predecessor indices (numeric; each entry is cast to int).

  Returns:
    The reweighted matrix with the source's row and column removed. The input
    `A` is not modified, and the result has the same dtype as `A`.

  Raises:
    ValueError: If a predecessor chain does not reach the source, i.e. `pi`
      contains a cycle that excludes it. Without this check the walk would
      loop forever.
    AssertionError: If any potential is not below 0.0001, or if a reweighted
      entry is below -1e6.
  """
  num_nodes = pi.shape[-1]
  source = num_nodes - 1

  d = np.zeros(pi.shape)
  for i in range(num_nodes):
    prev_node = i
    node = int(pi[prev_node])
    d[i] += A[node, prev_node]
    hops = 1
    while node != source:
      # A walk of `num_nodes` hops that has not reached the source must have
      # revisited a node, so the chain is cyclic and would never terminate.
      if hops >= num_nodes:
        raise ValueError(
            f'predecessor chain starting at node {i} never reaches the '
            f'source node {source}; pi contains a cycle')
      prev_node = node
      node = int(pi[prev_node])
      d[i] += A[node, prev_node]
      hops += 1
  assert np.all(d < 0.0001)

  # Shift only existing (non-zero) edges by d[i] - d[j]; cast back so the
  # result keeps A's dtype, as the in-place update on a copy of A would.
  shift = d[:, None] - d[None, :]
  A_rw = np.where(A != 0, A + shift, A).astype(A.dtype)

  # NOTE(review): -1e6 is minus one million, so this only rejects hugely
  # negative weights. If a non-negativity tolerance was intended, -1e-6 was
  # probably meant; confirm before tightening, since imperfect predecessor
  # predictions can produce negative reweighted edges.
  assert np.all(A_rw >= -1e6)
  return A_rw[:-1, :-1]


In [None]:
# Run the trained Bellman-Ford model over the test sampler and, for each graph,
# record the classical Johnson's output and the graph reweighted from the
# model's predicted predecessors.

# Re-seed from run.FLAGS.seed, which holds the seed set by the most recent
# train_model call.
rng = np.random.RandomState(run.FLAGS.seed)
rng_key = jax.random.PRNGKey(rng.randint(2**32))

import pickle
import haiku as hk
from clrs._src import algorithms

# Restore the best Bellman-Ford checkpoint. pickle.load can execute arbitrary
# code, so only load checkpoint files produced by this notebook.
with open("/tmp/CLRS30/bestbellman_ford.pkl", 'rb') as f:
  restored_state = pickle.load(f)
  bf_model.params = restored_state['params']
  # params were just set to the restored values, so this merges the restored
  # parameters with themselves.
  bf_model.params = hk.data_structures.merge(bf_model.params, restored_state['params'])
  bf_model.opt_state = restored_state['opt_state']

predict_fn = functools.partial(bf_model.predict, algorithm_index=0)

processed_samples = 0
bf_truths = []   # ground-truth outputs, one entry per batch
outputs = []     # first element of algorithms.johnsons(...) for each graph
A_rws = []       # reweighted matrices from get_rewritten_A, one per graph
As = []          # original per-graph matrices
bf_preds = []    # model predictions, one entry per batch

while processed_samples < bf_num_samples:
  feedback = next(bf_sampler)
  # presumably inputs[2] is the weighted adjacency matrix; verify against the
  # bellman_ford spec in the installed fork
  A = feedback.features.inputs[2].data
  batch_size = feedback.outputs[0].data.shape[0]
  bf_truths.append(feedback.outputs)
  new_rng_key, rng_key = jax.random.split(rng_key)
  cur_preds, _ = predict_fn(new_rng_key, feedback.features)
  for i in range(len(A)):
    # Johnson's is run on the matrix without its last row and column, the same
    # node that get_rewritten_A treats as the source and strips from its result.
    outputs.append(algorithms.johnsons(A[i][:-1,:-1])[0])
    A_rws.append(get_rewritten_A(A[i], cur_preds['pi'].data[i]))
    As.append(A[i])
  bf_preds.append(cur_preds)
  processed_samples += batch_size


In [None]:
# Join the per-batch truths and predictions along axis 0 and score them.
# The list names are rebound to the concatenated results, so this cell should
# not be re-run without first re-running the inference cell that fills them.
bf_truths = run._concat(bf_truths, axis=0)
bf_preds = run._concat(bf_preds, axis=0)
clrs.evaluate(bf_truths, bf_preds)

{'pi': 0.971923828125, 'score': 0.971923828125}

In [None]:
# Test sampler for Dijkstra, built with the same fixed settings as bf_sampler.
d_sampler, d_num_samples, d_spec = get_test_sampler('dijkstra')

In [None]:
# Score the Dijkstra model on its own test sampler. rng_key is passed as-is
# here rather than being split first.
run.collect_and_eval(d_sampler, functools.partial(dijkstra_model.predict, algorithm_index=0), d_num_samples, rng_key, {})

{'pi': 0.93505859375, 'score': 0.93505859375}

In [None]:
# Run the Dijkstra model from every source node of every reweighted graph,
# collecting ground truths from the classical algorithm alongside predictions.
from clrs._src import probing

# Restore the best Dijkstra checkpoint. pickle.load can execute arbitrary code,
# so only load checkpoint files produced by this notebook.
with open("/tmp/CLRS30/bestdijkstra.pkl", 'rb') as f:
  restored_state = pickle.load(f)
  dijkstra_model.params = restored_state['params']
  # params were just set to the restored values, so this merges the restored
  # parameters with themselves.
  dijkstra_model.params = hk.data_structures.merge(dijkstra_model.params, restored_state['params'])
  dijkstra_model.opt_state = restored_state['opt_state']
# Rebinds predict_fn, which previously wrapped the Bellman-Ford model.
predict_fn = functools.partial(dijkstra_model.predict, algorithm_index=0)

preds = []     # per graph: list of predicted 'pi' arrays, one per source node
d_preds = []   # raw prediction dicts, one per (graph, source) pair
d_truths = []  # classical Dijkstra outputs, one per (graph, source) pair
for A_rw in A_rws:
  result = []
  for s in range(len(A_rw)):
    # Element [1] of the classical algorithm's return value is passed to
    # split_stages, which separates it into inputs, outputs and hints.
    dijkstra_out = algorithms.dijkstra(A_rw, s)[1]
    inp, outp, hint = probing.split_stages(dijkstra_out, clrs.SPECS['dijkstra'])
    d_truths.append(outp)

    # NOTE(review): lengths is set to len(hint), the number of hint entries
    # returned by split_stages - confirm this is what the model expects here.
    features = clrs._src.samplers.Features(tuple(inp), tuple(hint), np.array([len(hint)]))
    new_rng_key, rng_key = jax.random.split(rng_key)
    cur_preds, _ = predict_fn(new_rng_key, features)
    d_preds.append(cur_preds)
    result.append(cur_preds['pi'].data[0])
  preds.append(result)

In [None]:
# Join the per-(graph, source) truths and predictions along axis 0. The list
# names are rebound, so re-run the collection cell before re-running this one.
d_truths = run._concat(d_truths, axis=0)
d_preds = run._concat(d_preds, axis=0)

In [None]:
# Score the Dijkstra model's predictions on the reweighted graphs against the
# classical Dijkstra outputs computed on those same graphs.
clrs.evaluate(d_truths, d_preds)

{'pi': 0.48778423406399596, 'score': 0.48778423406399596}

In [None]:
# Wrap the stacked predictions and the classical Johnson's outputs as DataPoints
# named "Pi" so they can be passed to the evaluation call. Both names are
# rebound from lists to DataPoints, so this cell is not safe to re-run alone.
preds = probing.DataPoint("Pi","edge","pointer",np.array(preds))
outputs = probing.DataPoint("Pi","edge","pointer",np.array(outputs))

In [None]:
# End-to-end score: the chained model predictions versus classical Johnson's.
clrs._src.evaluation.evaluate([outputs], {"Pi": preds})

{'Pi': 0.19944019274376418, 'score': 0.19944019274376418}