In [1]:
import clrs
import numpy as np
import jax
import jax.numpy as jnp

import pprint

rng = np.random.RandomState(1234)
rng_key = jax.random.PRNGKey(rng.randint(2**32))

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [77]:
train_sampler, spec = clrs.build_sampler(
    name='auction_matching',
    num_samples=100,
    length=16,
    weighted=True) # number of nodes

test_sampler, spec = clrs.build_sampler(
    name='auction_matching',
    num_samples=40, # TODO set back to more
    length=64,
    weighted=True) # testing on much larger
# TODO how do you know aren't generating same graphs? (well not possible here since different size but in general?)

pprint.pprint(spec) # spec is the algorithm specification, all the probes

def _iterate_sampler(sampler, batch_size):
  while True:
    yield sampler.next(batch_size)

train_sampler = _iterate_sampler(train_sampler, batch_size=32)
test_sampler = _iterate_sampler(test_sampler, batch_size=40) # full batch for the test set


{'A': ('input', 'edge', 'scalar'),
 'adj': ('input', 'edge', 'mask'),
 'buyers': ('input', 'node', 'mask'),
 'in_queue': ('hint', 'node', 'mask'),
 'owners': ('output', 'node', 'pointer'),
 'owners_h': ('hint', 'node', 'pointer'),
 'p': ('hint', 'node', 'scalar'),
 'pos': ('input', 'node', 'scalar')}


In [35]:
processor_factory = clrs.get_processor_factory('mpnn', use_ln=True, nb_triplet_fts=0) #use_ln => use layer norm 
# processor_factory = clrs.get_processor_factory('gat', use_ln=True, nb_heads = 4, nb_triplet_fts = 0)
model_params = dict(
    processor_factory=processor_factory, # contains the processor_factory
    hidden_dim=32,
    encode_hints=True,
    decode_hints=True,
    #decode_diffs=False,
    #hint_teacher_forcing_noise=1.0,
    hint_teacher_forcing=1.0,
    use_lstm=False,
    learning_rate=0.001,
    checkpoint_path='/tmp/checkpt',
    freeze_processor=False, # Good for post step
    dropout_prob=0.5,
)

dummy_trajectory = next(train_sampler) # jax needs a trajectory that is plausible looking to init

model = clrs.models.BaselineModel(
    spec=spec,
    dummy_trajectory=dummy_trajectory,
    **model_params
)

model.init(dummy_trajectory.features, 1234) # 1234 is a random seed

In [63]:
# No evaluation since we are postprocessing with soft: TO CHANGE -> baselines.py line 336 outs change hard to False
step = 0

while step <= 100:
    feedback, test_feedback = next(train_sampler), next(test_sampler)
    rng_key, new_rng_key = jax.random.split(rng_key) # jax needs new random seed at step
    cur_loss = model.feedback(rng_key, feedback) # loss is contained in model somewhere
    rng_key = new_rng_key
    if step % 10 == 0:
        print(step)
    step += 1


0
10
20
30
40
50
60
70
80
90
100


In [6]:
step = 0

while step <= 100:
  feedback, test_feedback = next(train_sampler), next(test_sampler)
  rng_key, new_rng_key = jax.random.split(rng_key) # jax needs new random seed at step
  cur_loss = model.feedback(rng_key, feedback) # loss is contained in model somewhere
  rng_key = new_rng_key
  if step % 10 == 0:
    predictions_val, _ = model.predict(rng_key, feedback.features)
    out_val = clrs.evaluate(feedback.outputs, predictions_val)
    predictions, _ = model.predict(rng_key, test_feedback.features)
    out = clrs.evaluate(test_feedback.outputs, predictions)
    print(f'step = {step} | loss = {cur_loss} | val_acc = {out_val["score"]} | test_acc = {out["score"]}') # here, val accuracy is actually training accuracy, not great but is example
  step += 1

AssertionError: 

## Some intermediate results
ALL: 0.5 dropout

MPNN 100 train, 40 test, 100 epochs, self-loops -> loss = 0.9931782484054565 | val_acc = 0.8515625 | test_acc = 0.7484375238418579 | accuracy = 0.65, average nb non-matched: 13.45

MPNN 200 train, 40 test, 100 epochs, self-loops -> loss = 0.8129420280456543 | val_acc = 0.8125 | test_acc = 0.77734375 | accuracy = 0.72, average nb non-matched: 11.125

MPNN 100 train, 40 test, 100 epochs, double links -> loss = 0.8689386248588562 | val_acc = 0.7109375 | test_acc = 0.42695313692092896 | accuracy =? NOTE: only started "learning" in the last epochs => trying more, interestingly has less loss than self-loops but less accuracy too

MPNN 100 train, 40 test, 200 epochs, double links -> step = 100 | loss = 0.6802611351013184 | val_acc = 0.806640625 | test_acc = 0.681640625 | accuracy = 0.89, average nb non-matched: 5.65

MPNN 300 train, 40 test, 400 epochs, double links -> loss = 0.5485531091690063 | val_acc = 0.775390625 | test_acc = 0.6910156607627869 | accuracy = 0.928, average nb non-matched: 4.075 Note: best test_acc 0.727, similar test_acc to 100 train 200 epochs but better accuracy + still does not converge on training accuracy though

Diff: length 100 testing instead of 64
MPNN 100 train, 40 test LENGTH 100, 200 epochs, double links -> loss = 0.6958761215209961 | val_acc = 0.787109375 | test_acc = 0.503250002861023 | accuracy = 0.759, average nb non-matched: 7.9/100


#### Now with actually bipartite graph (no owner-owner / good-good edges)

In [71]:
def matching_value(samples, predictions, partial = False):
    features = samples.features
    gt_matchings = samples.outputs[0].data
    # inputs for the matrix A are at index 1 (see spec.py)
    data = features.inputs[1].data
    masks = features.inputs[3].data
    pred_accuracy = 0

    # Iterating over all the samples
    for i in range(data.shape[0]):
        max_weight = compute_greedy_matching_weight(i, data, masks, gt_matchings[i])

        # TODO remove
        predicted_matching = predictions["owners"].data[i]
        # buyers_mask = masks[i]
        # n = int(np.sum(buyers_mask))
        # permutation = np.random.permutation(np.arange(np.sum(buyers_mask == 0)))
        # predicted_matching = np.concatenate((np.zeros(n), permutation))
        if partial:
            preds_weight = compute_partial_matching_weight(i, data, masks, predicted_matching)
            print(f"opt: {max_weight}, partial: {preds_weight}")
        else:
            preds_weight = compute_greedy_matching_weight(i, data, masks, predicted_matching)
            print(f"opt: {max_weight}, partial: {preds_weight}")

        assert preds_weight <= max_weight
        pred_accuracy += preds_weight / max_weight

    return pred_accuracy / data.shape[0]

def compute_greedy_matching_weight(i, data, masks, matching):
    matching_weight = 0
    A = data[i]
    buyers_mask = masks[i]
    n = int(np.sum(buyers_mask))
    goods_mask = 1 - buyers_mask



    # Only consider the matching values for consumers
    matching = np.where(goods_mask == 1, matching, -1)

    for buyer in range(n):
        if buyer in matching:
            # If several goods point to the same buyer, keep the one with maximum weight
            matching_weight += np.max(A[buyer, matching == buyer])

    return matching_weight

def compute_partial_matching_weight(i, data, masks, matching):
    # Matching is expected to be a (n+m)x(n+m) matrix where each row sums to 1 (weights assigned to other nodes)

    matching_weight = 0
    A = data[i]
    buyers_mask = masks[i]
    n = int(np.sum(buyers_mask))
    goods_mask = 1 - buyers_mask
    m = int(np.sum(goods_mask))

    # We only care about the buyer -> good connections
    A_submatrix = A[:n, n:n+m]
    matching = matching[:n, n:n+m]

    max_weight = np.max(np.sum(matching, axis = 0))
    print(f"max weight: {max_weight}")
    matching /= max_weight
    return np.sum(matching * A_submatrix)

In [64]:
test_feedback = next(test_sampler)
predictions, _ = model.predict(rng_key, test_feedback.features)

In [72]:
matching_value(test_feedback, predictions, partial = True)

max weight: 1.5297893285751343
opt: 22.509909997050908, partial: 14.029775619506836, greedy: 0
max weight: 1.5107792615890503
opt: 23.135114167854553, partial: 14.419025421142578, greedy: 0
max weight: 1.5107792615890503
opt: 23.135114167854553, partial: 14.419025421142578, greedy: 0
max weight: 1.5683262348175049
opt: 23.027825681330196, partial: 14.078963279724121, greedy: 0
max weight: 1.3711823225021362
opt: 22.79800465355557, partial: 15.573226928710938, greedy: 0
max weight: 1.5052690505981445
opt: 24.698609651980103, partial: 15.195359230041504, greedy: 0
max weight: 1.3506972789764404
opt: 23.933260816797336, partial: 16.73806381225586, greedy: 0
max weight: 1.5683262348175049
opt: 23.027825681330196, partial: 14.078963279724121, greedy: 0
max weight: 1.700014352798462
opt: 21.86007317728456, partial: 11.941967010498047, greedy: 0
max weight: 1.352318286895752
opt: 23.53177149762247, partial: 16.173660278320312, greedy: 0
max weight: 1.3669893741607666
opt: 24.734511097181954, 

Array(0.6384877, dtype=float32)

## Preliminary results
random permutation/matching: 0.18

MPNN:
learned predictions: 0.67

GAT:
learned predictions: 0.72

Got better with double ended predictions

Partial: 0.64 while greedy was doing about 0.92 on the same instance. Main reason seems to be that max weight is around 1.5 => can get at most 2/3 OPT


### Counting the number of matching constraints violated

In [None]:
# For two-way
count = 0
data = predictions["owners"].data
nb_graphs = data.shape[0]
for datapoint in range(data.shape[0]):
    for i in range(32):
        owner = data[datapoint][i]
        good = data[datapoint][int(owner)]
        if good != i:
            count += 1
print(f"average number of edges contradicting matching: {count / nb_graphs}")

In [68]:
# For self-loops
count = 0
data = predictions["owners"].data
nb_graphs = data.shape[0]
for datapoint in range(data.shape[0]):
    owners = set(np.array(data[datapoint][32:64]))
    count += 32 - len(owners)
print(f"average number of edges contradicting matching: {count / nb_graphs}")


TypeError: unhashable type: 'numpy.ndarray'

In [73]:
# Whyyyyy are you not bipartite
data = test_feedback.features.inputs[1].data[0]
#data[:32, :32]
data[32:64, 32:64]



array([[0.        , 0.        , 0.        , ..., 0.28533341, 0.26362947,
        0.84964214],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.44046488,
        0.        ],
       ...,
       [0.28533341, 0.        , 0.        , ..., 0.50471434, 0.        ,
        0.        ],
       [0.26362947, 0.        , 0.44046488, ..., 0.        , 0.23655434,
        0.        ],
       [0.84964214, 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]])

In [79]:
predictions, hints = model.predict(rng_key, test_feedback.features, return_hints = True, return_all_outputs = True)


In [86]:
predictions['owners'].data[::100].shape

(5, 40, 64, 64)

In [143]:
hints[10]['owners_h'][0].shape


(16, 16)

In [88]:
arr = np.array([1, 2, 3, 4, 5, 6])
arr[np.arange(len(arr)) % 2 == 0]
arr[::2]

array([6, 4, 2])

In [93]:
test_feedback.features.inputs[1].data[0]

array([[0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.45111127,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.74242566,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.50471434, 0.        ,
        0.        ],
       [0.        , 0.45111127, 0.74242566, ..., 0.        , 0.23655434,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]])