In [1]:
import clrs
import numpy as np
import jax
import jax.numpy as jnp

import pprint

rng = np.random.RandomState(1234)
rng_key = jax.random.PRNGKey(rng.randint(2**32))

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
train_sampler, spec = clrs.build_sampler(
    name='auction_matching',
    num_samples=100,
    length=16,
    weighted=True) # number of nodes

test_sampler, spec = clrs.build_sampler(
    name='auction_matching',
    num_samples=40, # TODO set back to more
    length=64,
    weighted=True) # testing on much larger
# TODO how do you know aren't generating same graphs? (well not possible here since different size but in general?)

pprint.pprint(spec) # spec is the algorithm specification, all the probes

def _iterate_sampler(sampler, batch_size):
  while True:
    yield sampler.next(batch_size)

train_sampler = _iterate_sampler(train_sampler, batch_size=32)
test_sampler = _iterate_sampler(test_sampler, batch_size=40) # full batch for the test set


{'A': ('input', 'edge', 'scalar'),
 'adj': ('input', 'edge', 'mask'),
 'buyers': ('input', 'node', 'mask'),
 'in_queue': ('hint', 'node', 'mask'),
 'owners': ('output', 'node', 'pointer'),
 'owners_h': ('hint', 'node', 'pointer'),
 'p': ('hint', 'node', 'scalar'),
 'pos': ('input', 'node', 'scalar')}


In [3]:
# Bipartite graph is weighted
print(next(train_sampler).features.inputs[1].data[0])

[[0.         0.         0.         0.         0.         0.
  0.         0.         0.13720626 0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.92963277 0.         0.88031202 0.         0.53336933
  0.         0.52068571 0.         0.         0.82835848 0.49224233
  0.         0.         0.30007914 0.        ]
 [0.         0.         0.         0.38189772 0.         0.
  0.         0.05240804 0.         0.         0.         0.4720396
  0.53658954 0.         0.         0.        ]
 [0.         0.88031202 0.38189772 0.         0.         0.
  0.         0.         0.         0.         0.40886447 0.
  0.73795361 0.5591565  0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.62631406 0.         0.51914184 0.88036373 0.79552802 0.33534672
  0.         0.         0.06154556 0.80690847]
 [0.         0.53336933 0.         0.         0.         0.78045274
  0.33451041 0.         0.         0.         0.         0.
  0.  

In [19]:
processor_factory = clrs.get_processor_factory('mpnn', use_ln=True, nb_triplet_fts=0) #use_ln => use layer norm 
processor_factory = clrs.get_processor_factory('gat', use_ln=True, nb_heads = 4, nb_triplet_fts = 0)
model_params = dict(
    processor_factory=processor_factory, # contains the processor_factory
    hidden_dim=32,
    encode_hints=True,
    decode_hints=True,
    #decode_diffs=False,
    #hint_teacher_forcing_noise=1.0,
    hint_teacher_forcing=1.0,
    use_lstm=False,
    learning_rate=0.001,
    checkpoint_path='/tmp/checkpt',
    freeze_processor=False, # Good for post step
    dropout_prob=0.0,
)

dummy_trajectory = next(train_sampler) # jax needs a trajectory that is plausible looking to init

model = clrs.models.BaselineModel(
    spec=spec,
    dummy_trajectory=dummy_trajectory,
    **model_params
)

model.init(dummy_trajectory.features, 1234) # 1234 is a random seed

In [25]:
step = 0

while step <= 100:
  feedback, test_feedback = next(train_sampler), next(test_sampler)
  rng_key, new_rng_key = jax.random.split(rng_key) # jax needs new random seed at step
  cur_loss = model.feedback(rng_key, feedback) # loss is contained in model somewhere
  rng_key = new_rng_key
  if step % 10 == 0:
    predictions_val, _ = model.predict(rng_key, feedback.features)
    out_val = clrs.evaluate(feedback.outputs, predictions_val)
    predictions, _ = model.predict(rng_key, test_feedback.features)
    out = clrs.evaluate(test_feedback.outputs, predictions)
    print(f'step = {step} | loss = {cur_loss} | val_acc = {out_val["score"]} | test_acc = {out["score"]}') # here, val accuracy is actually training accuracy, not great but is example
  step += 1

step = 0 | loss = 0.6523988246917725 | val_acc = 0.82421875 | test_acc = 0.790234386920929
step = 10 | loss = 0.6065563559532166 | val_acc = 0.86328125 | test_acc = 0.762890636920929
step = 20 | loss = 0.6038270592689514 | val_acc = 0.861328125 | test_acc = 0.7730469107627869
step = 30 | loss = 0.6365153789520264 | val_acc = 0.845703125 | test_acc = 0.76953125
step = 40 | loss = 0.591829240322113 | val_acc = 0.81640625 | test_acc = 0.7796875238418579
step = 50 | loss = 0.6535412669181824 | val_acc = 0.849609375 | test_acc = 0.7671875357627869
step = 60 | loss = 0.5626905560493469 | val_acc = 0.861328125 | test_acc = 0.780078113079071
step = 70 | loss = 0.5786242485046387 | val_acc = 0.8203125 | test_acc = 0.775390625
step = 80 | loss = 0.5910295248031616 | val_acc = 0.8203125 | test_acc = 0.783203125
step = 90 | loss = 0.5957251787185669 | val_acc = 0.830078125 | test_acc = 0.7640625238418579
step = 100 | loss = 0.5553303360939026 | val_acc = 0.806640625 | test_acc = 0.770312488079071


In [26]:
# TODO verify code and check what it gives with random permutation
def matching_value(samples, predictions):
    features = samples.features
    gt_matchings = samples.outputs[0].data
    # inputs for the matrix A are at index 1 (see spec.py)
    data = features.inputs[1].data
    masks = features.inputs[3].data
    pred_accuracy = 0

    # Iterating over all the samples
    for i in range(data.shape[0]):
        max_weight = compute_matching_weight(i, data, masks, gt_matchings[i])

        # TODO remove
        predicted_matching = predictions["owners"].data[i]
        buyers_mask = masks[i]
        n = int(np.sum(buyers_mask))
        permutation = np.random.permutation(np.arange(np.sum(buyers_mask == 0)))
        # predicted_matching = np.concatenate((np.zeros(n), permutation))
        preds_weight = compute_matching_weight(i, data, masks, predicted_matching)
        print(f"max: {max_weight}, pred: {preds_weight}")
        assert preds_weight <= max_weight
        pred_accuracy += preds_weight / max_weight

    return pred_accuracy / data.shape[0]

def compute_matching_weight(i, data, masks, matching):
    matching_weight = 0
    A = data[i]
    buyers_mask = masks[i]
    n = int(np.sum(buyers_mask))
    consumers_mask = 1 - buyers_mask



    # Only consider the matching values for consumers
    matching = np.where(consumers_mask == 1, matching, -1)

    for buyer in range(n):
        if buyer in matching:
            # If several goods point to the same buyer, keep the one with maximum weight
            matching_weight += np.max(A[buyer, matching == buyer])

    return matching_weight


In [27]:
test_feedback = next(test_sampler)
predictions, _ = model.predict(rng_key, test_feedback.features)

In [28]:
matching_value(test_feedback, predictions)

max: 24.27956714466754, pred: 16.34168389134737
max: 22.102475176188705, pred: 16.059375099667196
max: 23.86077967779625, pred: 16.820780982633178
max: 23.047660039579167, pred: 18.383731617695002
max: 22.102475176188705, pred: 16.059375099667196
max: 23.25031575548427, pred: 15.89369440318066
max: 21.975819935846754, pred: 14.496811062052748
max: 23.25031575548427, pred: 15.89369440318066
max: 23.15032587271355, pred: 17.09720336740702
max: 22.655820454391158, pred: 16.33835209746942
max: 22.246620952133604, pred: 15.210615682279373
max: 23.622342726948837, pred: 16.52460606280363
max: 23.35910025981948, pred: 16.420581118740042
max: 24.36949590014474, pred: 17.73089264444693
max: 23.132478080934842, pred: 18.370195596097687
max: 22.699991401473866, pred: 15.062203166478351
max: 22.617621061457204, pred: 15.237357107627743
max: 23.35910025981948, pred: 16.420581118740042
max: 23.25031575548427, pred: 15.89369440318066
max: 23.35910025981948, pred: 16.420581118740042
max: 21.3457761562

0.7209796596579074

## Preliminary results
random permutation/matching: 0.18

MPNN:
learned predictions: 0.67

GAT:
learned predictions: 0.72
