In [1]:
import clrs
import numpy as np
import jax
import jax.numpy as jnp

import pprint

# Seed NumPy's legacy generator once, then draw a single integer from it to
# seed the JAX PRNG key that the training and prediction cells consume.
rng = np.random.RandomState(seed=1234)
jax_seed = rng.randint(2**32)
rng_key = jax.random.PRNGKey(jax_seed)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
# Both samplers use the same algorithm; `length` is the number of nodes.
# Training uses small graphs, testing uses much larger ones.
ALGORITHM_NAME = 'auction_matching'

train_sampler, spec = clrs.build_sampler(
    name=ALGORITHM_NAME, num_samples=100, length=16, weighted=True)

# TODO set num_samples back to more
test_sampler, spec = clrs.build_sampler(
    name=ALGORITHM_NAME, num_samples=40, length=64, weighted=True)
# TODO how do you know aren't generating same graphs? (well not possible here since different size but in general?)

# spec is the algorithm specification: one (stage, location, type) triple per probe.
pprint.pprint(spec)

def _iterate_sampler(sampler, batch_size):
  while True:
    yield sampler.next(batch_size)

# TODO put back normal batch values
TRAIN_BATCH_SIZE = 32
TEST_BATCH_SIZE = 40  # full batch for the test set

# From here on both names refer to infinite batch generators, not samplers.
train_sampler = _iterate_sampler(train_sampler, batch_size=TRAIN_BATCH_SIZE)
test_sampler = _iterate_sampler(test_sampler, batch_size=TEST_BATCH_SIZE)



{'A': ('input', 'edge', 'scalar'),
 'adj': ('input', 'edge', 'mask'),
 'buyers': ('input', 'node', 'mask'),
 'in_queue': ('hint', 'node', 'mask'),
 'owners': ('output', 'node', 'pointer'),
 'owners_h': ('hint', 'node', 'pointer'),
 'p': ('hint', 'node', 'scalar'),
 'pos': ('input', 'node', 'scalar')}


In [3]:
# use_ln => use layer norm
processor_factory = clrs.get_processor_factory(
    'mpnn', use_ln=True, nb_triplet_fts=0)
# Alternative processor that was also tried:
# processor_factory = clrs.get_processor_factory('gat', use_ln=True, nb_heads = 4, nb_triplet_fts = 0)

# Keyword arguments forwarded to the baseline model; commented-out keys are
# options that are currently left at their defaults.
model_params = {
    'processor_factory': processor_factory,
    'hidden_dim': 32,  # TODO put back to 32 if no difference, indeed not much diff for MPNN
    'encode_hints': True,
    'decode_hints': True,
    # 'decode_diffs': False,
    # 'hint_teacher_forcing_noise': 1.0,
    'hint_teacher_forcing': 1.0,
    'use_lstm': False,
    'learning_rate': 0.001,
    'checkpoint_path': '/tmp/checkpt',
    'freeze_processor': False,  # Good for post step
    'dropout_prob': 0.5,
    # 'nb_msg_passing_steps': 3,
}

# jax needs a plausible-looking trajectory to initialise the model.
dummy_trajectory = next(train_sampler)

model = clrs.models.BaselineModel(
    spec=spec, dummy_trajectory=dummy_trajectory, **model_params)

# 1234 is a random seed
model.init(dummy_trajectory.features, 1234)

In [111]:
# No evaluation since we are postprocessing with soft: TO CHANGE -> baselines.py line 336 outs change hard to False
NUM_TRAIN_STEPS = 100  # the loop bound is inclusive, so 101 updates are run
LOG_EVERY = 10

step = 0
while step <= NUM_TRAIN_STEPS:
    feedback = next(train_sampler)
    test_feedback = next(test_sampler)
    # jax needs a fresh random key at every step: one half is consumed by the
    # update, the other half is carried over to the next iteration.
    step_key, carried_key = jax.random.split(rng_key)
    cur_loss = model.feedback(step_key, feedback)  # loss is contained in model somewhere
    rng_key = carried_key
    if not step % LOG_EVERY:
        print(step)
    step += 1


0
10
20
30
40
50
60
70
80
90
100


In [5]:
# import copy
#
# step = 0
# while step <= 20:
#   feedback, test_feedback = next(train_sampler), next(test_sampler)
#
#   rng_key, new_rng_key = jax.random.split(rng_key) # jax needs new random seed at step
#   cur_loss = model.feedback(rng_key, feedback) # loss is contained in model somewhere
#   rng_key = new_rng_key
#   # if step % 10 == 0:
#   #   predictions_val, _ = model.predict(rng_key, feedback.features)
#   #   out_val = clrs.evaluate(feedback.outputs, predictions_val)
#   #   predictions, _ = model.predict(rng_key, test_feedback.features)
#   #   out = clrs.evaluate(test_feedback.outputs, predictions)
#   #   print(f'step = {step} | loss = {cur_loss} | val_acc = {out_val["score"]} | test_acc = {out["score"]}') # here, val accuracy is actually training accuracy, not great but is example
#   step += 1
# model2 = copy.deepcopy(model)

-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
-------------- step -----------------
------------

## Some intermediate results
ALL: 0.5 dropout

MPNN 100 train, 40 test, 100 epochs, self-loops -> loss = 0.9931782484054565 | val_acc = 0.8515625 | test_acc = 0.7484375238418579 | accuracy = 0.65, average nb non-matched: 13.45

MPNN 200 train, 40 test, 100 epochs, self-loops -> loss = 0.8129420280456543 | val_acc = 0.8125 | test_acc = 0.77734375 | accuracy = 0.72, average nb non-matched: 11.125

MPNN 100 train, 40 test, 100 epochs, double links -> loss = 0.8689386248588562 | val_acc = 0.7109375 | test_acc = 0.42695313692092896 | accuracy =? NOTE: only started "learning" in the last epochs => trying more, interestingly has less loss than self-loops but less accuracy too

MPNN 100 train, 40 test, 200 epochs, double links -> step = 100 | loss = 0.6802611351013184 | val_acc = 0.806640625 | test_acc = 0.681640625 | accuracy = 0.89, average nb non-matched: 5.65

MPNN 300 train, 40 test, 400 epochs, double links -> loss = 0.5485531091690063 | val_acc = 0.775390625 | test_acc = 0.6910156607627869 | accuracy = 0.928, average nb non-matched: 4.075 Note: best test_acc 0.727, similar test_acc to 100 train 200 epochs but better accuracy + still does not converge on training accuracy though

MPNN 200 train, 40 test, 200 epochs, double links -> 94.4% acc without added matches, 95.7% with added matches

Diff: length 100 testing instead of 64
MPNN 100 train, 40 test LENGTH 100, 200 epochs, double links -> loss = 0.6958761215209961 | val_acc = 0.787109375 | test_acc = 0.503250002861023 | accuracy = 0.759, average nb non-matched: 7.9/100


#### Now with actually bipartite graph (no owner-owner / good-good edges)
Doesn't really change results

ALL with 0 dropout

#### No hints
0 dropout Can get up to 0.78 of OPT, average nb non-matched: 10.5/64


#### Training with 64 hidden dimensions
MPNN 100 train, 40 test, 0 dropout, double links -> Get to 90% acc in 30 iterations, 93% in 60
GAT 100 train, 40 test, 0 dropout, double links -> Get to 0.78 in 30 iterations, 91% in 60, 92% in 100, 92% in 200

Same with 0.5 dropout
MPNN 100 train, 40 test, 0.5 dropout, double links -> Get to 85% in 30 iterations, 93% in 60 iterations, 93% in 100 iterations
GAT 100 train, 40 test, 0.5 dropout, double links -> 94.7% in 200

#### 3 message passing steps
TLDR not great
64 dims, 3 message passing steps
MPNN 100 train, 40 test, 0.5 dropout, double links -> 91.7% in 100 iterations (worse than 1 MP step), 93% in 200 iterations

Back to 1 message passing step

#### Larger MLP
TLDR not great
[out_size, out_size, out_size] MLP
MPNN 100 train, 40 test, 0.5 dropout, double links, 3 layer MLP -> loss = 0.6642395257949829 | val_acc = 0.75390625 | test_acc = 0.699999988079071 | 90% in 100 iterations (worse than smaller MLP) | 93% in 200 iterations

# Add max matching to greedy

MPNN 100 train, 40 test, 0.5 dropout, double links, 64 hidden -> trained on 100 iterations, goes from 91% to 94% with added max matching, 200 iterations 91.7 -> 94.7 (exactly 3% more again)

#### On 32 vs 64
32 learns slower (will need 200 iterations to get to good acc) but basically peaks at same values as 64

#### Partial matchings
All tested with the usual 200 epochs of MPNN where greedy gets up to 0.92-ish

Softmax gets 0.66, with a max weight edge of around 1.4 on average (so this is close to the best this scheme can do: with a maximum of 1.5 the upper bound would be 1/1.5 ≈ 0.66).

Normalized gets 0.34, with a max weight edge of around 1.55 on average. This is not very close to optimal, and it is surprising that it does so much worse than softmax.

Min with softmax does much better: it gets 0.86, with an average outgoing edge weight of 0.86 (the optimum would be 1). Dividing by the max outgoing weight afterwards makes basically no difference to the result.

Min with normalized does not do well: it gets 0.47, even though the average outgoing edge weight is similar, at around 0.85. It seems the predictions are simply worse for some reason.

My intuition for why normalized is worse: the model is trained so that, after the softmax, the best candidate gets the weight and the others are pushed to 0. Softmax does this even when the other candidates' predicted scores are close to the maximum, so those near-maximal alternatives are less of a problem. Normalizing without softmax removes this bias towards the maximum, so the other, worse predictions end up with a higher weight than they would have under softmax.

Demonstrated by the fact that min with argmax instead of softmax (i.e. basically greedy but worse) does 0.93

In [140]:
from scipy.optimize import linear_sum_assignment


def matching_value(samples, predictions, partial = False, match_rest = False):
    """Print, per sample and on average, predicted matching weight relative to the reference.

    For every sample the reference weight is obtained by running
    ``compute_greedy_matching_weight`` on the ground-truth pointers taken from
    ``samples.outputs[0]``. The prediction is then scored either as a
    fractional ("partial") matching under three normalisations, or as a greedy
    integral matching. The ratio prediction / reference is averaged over the
    batch and printed.

    Args:
        samples: feedback batch exposing ``features.inputs`` and ``outputs``.
        predictions: mapping with an ``"owners"`` entry whose ``.data`` holds
            one prediction per sample.
        partial: if True, score the three fractional variants; otherwise score
            the greedy matching.
        match_rest: forwarded to ``compute_greedy_matching_weight`` in the
            greedy branch; not used when ``partial`` is True.

    Returns:
        None. All results are printed.

    Raises:
        AssertionError: if a predicted weight exceeds the reference weight.
    """
    features = samples.features
    gt_matchings = samples.outputs[0].data
    # inputs for the matrix A are at index 1 (see spec.py)
    data = features.inputs[1].data
    # presumably index 3 is the 'buyers' mask, since the helpers treat it as one - TODO confirm against spec.py
    masks = features.inputs[3].data
    nb_samples = data.shape[0]

    pred_acc_greedy = 0
    pred_acc_softmax = 0
    pred_acc_normalized = 0
    pred_acc_min = 0

    # Iterating over all the samples
    for i in range(nb_samples):
        max_weight = compute_greedy_matching_weight(i, data, masks, gt_matchings[i])

        predicted_matching = predictions["owners"].data[i]

        if partial:
            preds_weight_softmax = compute_partial_matching_weight_softmax(i, data, masks, predicted_matching)
            preds_weight_normalized = compute_partial_matching_weight_normalized(i, data, masks, predicted_matching)
            preds_weight_min = compute_partial_matching_weight_min_edges(i, data, masks, predicted_matching)
            print(f"opt: {max_weight}, partial softmax: {preds_weight_softmax}, partial normalized: {preds_weight_normalized}, partial minimum: {preds_weight_min}")
            assert preds_weight_softmax <= max_weight
            assert preds_weight_normalized <= max_weight
            assert preds_weight_min <= max_weight
            pred_acc_softmax += preds_weight_softmax / max_weight
            pred_acc_normalized += preds_weight_normalized / max_weight
            pred_acc_min += preds_weight_min / max_weight
        else:
            preds_weight_greedy = compute_greedy_matching_weight(i, data, masks, predicted_matching, match_rest = match_rest)
            # This branch scores the greedy matching, so label it as such
            # (the message previously said "partial").
            print(f"opt: {max_weight}, greedy: {preds_weight_greedy}")

            assert preds_weight_greedy <= max_weight
            pred_acc_greedy += preds_weight_greedy / max_weight

    print("--------------------")
    if partial:
        pred_acc_softmax /= nb_samples
        pred_acc_normalized /= nb_samples
        pred_acc_min /= nb_samples
        print(f"average accuracy: softmax {pred_acc_softmax:.4f}, normalized {pred_acc_normalized:.4f}, min: {pred_acc_min:.4f}")
    else:
        pred_acc_greedy /= nb_samples
        print(f"average accuracy: greedy {pred_acc_greedy}")

def compute_greedy_matching_weight(i, data, masks, matching, match_rest = False):
    """Weight of the matching obtained by greedily resolving the goods' pointers.

    Each good points at a buyer. When several goods point at the same buyer,
    only the heaviest of those edges is kept. Optionally, the buyers and goods
    left unmatched afterwards are matched optimally among themselves.

    The code enumerates buyers as ``range(n)`` and goods as ``range(n, n+m)``,
    i.e. it relies on buyers occupying the first ``n`` node indices.

    Args:
        i: index of the sample inside the batch.
        data: batched weight matrices; ``data[i]`` is read as ``A[buyer, good]``.
        masks: batched node masks; entries equal to 1 mark buyers, the others goods.
        matching: per-node pointer array for sample ``i``; only the entries of
            goods are taken into account.
        match_rest: if True, add the weight of a maximum-weight assignment
            between the buyers and goods that the greedy pass left unmatched.

    Returns:
        The total weight of the resulting matching.
    """
    matching_weight = 0
    A = data[i]
    buyers_mask = masks[i]
    n = int(np.sum(buyers_mask))
    goods_mask = 1 - buyers_mask
    m = int(np.sum(goods_mask))

    # Only keep the pointers of goods; every other entry is set to -1 so it
    # can never be equal to a buyer index.
    matching = np.where(goods_mask == 1, matching, -1)
    unmatched_goods = set(range(n, n+m))
    unmatched_buyers = set(range(n))

    for buyer in range(n):
        candidate_goods = np.flatnonzero(matching == buyer)
        if candidate_goods.size == 0:
            continue
        # If several goods point to the same buyer, keep the one with maximum weight
        candidate_weights = A[buyer, candidate_goods]
        best = int(np.argmax(candidate_weights))
        matching_weight += candidate_weights[best]
        unmatched_goods.remove(int(candidate_goods[best]))
        unmatched_buyers.remove(buyer)

    if match_rest and unmatched_goods and unmatched_buyers:
        # Compute the optimal matching on the remaining unmatched nodes by
        # restricting A to (unmatched buyers) x (unmatched goods). The weight is
        # read from that same sub-matrix, so only edges of the leftover graph
        # are counted, each exactly once (no halving needed).
        rest_buyers = sorted(unmatched_buyers)
        rest_goods = sorted(unmatched_goods)
        remaining_bipartite_graph = A[np.ix_(rest_buyers, rest_goods)]
        row_ind, col_ind = linear_sum_assignment(remaining_bipartite_graph, maximize = True)
        matching_weight += remaining_bipartite_graph[row_ind, col_ind].sum()

    return matching_weight

def compute_partial_matching_weight_softmax(i, data, masks, matching):
    """Weight of a fractional matching read directly from the (softmaxed) predictions.

    The buyer -> good block of ``matching`` is divided by its largest column
    sum, so that the good receiving the most total weight receives exactly 1.
    The entries are then used as fractional edge selections.

    Args:
        i: index of the sample inside the batch.
        data: batched weight matrices; ``data[i]`` is read as ``A[buyer, good]``.
        masks: batched node masks; entries equal to 1 mark buyers.
        matching: (n+m)x(n+m) matrix, expected to have rows summing to 1
            (weights assigned to other nodes). It is not modified.

    Returns:
        Sum over buyer -> good edges of (rescaled matching weight * edge weight).
    """
    A = data[i]
    buyers_mask = masks[i]
    n = int(np.sum(buyers_mask))
    goods_mask = 1 - buyers_mask
    m = int(np.sum(goods_mask))

    # We only care about the buyer -> good connections
    A_submatrix = A[:n, n:n+m]
    buyer_to_good = matching[:n, n:n+m]

    max_weight_edge = np.max(np.sum(buyer_to_good, axis = 0))
    print(f"max weight edge softmax: {max_weight_edge}")
    # Divide out of place: the slice above can be a view of the caller's array,
    # and an in-place `/=` would silently rescale the predictions that the
    # other scoring functions (and any re-run) read afterwards.
    rescaled = buyer_to_good / max_weight_edge
    return np.sum(rescaled * A_submatrix)

def compute_partial_matching_weight_normalized(i, data, masks, matching):
    """Weight of a fractional matching built from linearly normalised log-scores.

    ``matching`` is expected to be a (n+m)x(n+m) matrix whose rows sum to 1 and
    which has already been softmaxed. The buyer -> good block is mapped back to
    log values, shifted so every row's minimum is 0, and divided by its row
    sum. The block is finally divided by its largest column sum.

    Args:
        i: index of the sample inside the batch.
        data: batched weight matrices; ``data[i]`` is read as ``A[buyer, good]``.
        masks: batched node masks; entries equal to 1 mark buyers.
        matching: per-sample prediction matrix as described above.

    Returns:
        Sum over buyer -> good edges of (normalised weight * edge weight).
    """
    weights = data[i]
    is_buyer = masks[i]
    n = int(np.sum(is_buyer))
    m = int(np.sum(1 - is_buyer))

    # We only care about the buyer -> good connections
    A_submatrix = weights[:n, n:n+m]
    log_scores = np.log(matching[:n, n:n+m])

    # Shift each row so that its smallest log-score becomes 0 (TODO test out with +1),
    # then turn each row into proportions of its own total.
    shifted = log_scores - np.min(log_scores, axis = 1, keepdims = True)
    proportions = shifted / np.sum(shifted, axis = 1, keepdims = True)

    max_weight_edge = np.max(np.sum(proportions, axis = 0))
    print(f"max weight edge normalized: {max_weight_edge}")
    return np.sum((proportions / max_weight_edge) * A_submatrix)


def compute_partial_matching_weight_min_edges(i, data, masks, matching):
    """Weight of a fractional matching that keeps only mutually agreed weight.

    ``matching`` is expected to be a (n+m)x(n+m) matrix whose rows sum to 1.
    For each buyer/good pair the smaller of the buyer -> good and the
    good -> buyer entry is kept. The result is then scaled by the reciprocal
    of its largest row (buyer) sum.

    Args:
        i: index of the sample inside the batch.
        data: batched weight matrices; ``data[i]`` is read as ``A[buyer, good]``.
        masks: batched node masks; entries equal to 1 mark buyers.
        matching: per-sample prediction matrix as described above.

    Returns:
        Sum over buyer -> good edges of (agreed weight * edge weight).
    """
    weights = data[i]
    is_buyer = masks[i]
    n = int(np.sum(is_buyer))
    m = int(np.sum(1 - is_buyer))

    # For testing argmax instead of softmax, replace `matching` by its one-hot
    # row-wise argmax before slicing:
    # new_matching = np.zeros(matching.shape)
    # new_matching[np.arange(n+m), np.argmax(matching, axis = 1)] = 1
    # matching = new_matching

    # We only care about the buyer -> good connections
    A_submatrix = weights[:n, n:n+m]
    buyer_to_good = matching[:n, n:n+m]
    good_to_buyer = matching[n:n+m, :n].T

    agreed = np.minimum(buyer_to_good, good_to_buyer)
    # NOTE(review): only the row (buyer) totals enter this scale factor; the
    # column (good) totals are not checked - confirm this is intended.
    largest_buyer_total = np.max(np.sum(agreed, axis = 1))
    agreed *= 1/largest_buyer_total

    print(f"average outgoing edge weight: {np.mean(np.sum(agreed, axis = 1))}")

    return np.sum(agreed * A_submatrix)


In [123]:
# Draw one batch from the test generator and run the model on its features.
# The second value returned by predict is discarded; rng_key is reused as-is.
test_feedback = next(test_sampler)
predictions, _ = model.predict(rng_key, test_feedback.features)

In [141]:
# Score the predictions as fractional ("partial") matchings; match_rest is only read by the greedy branch.
matching_value(test_feedback, predictions, partial = True, match_rest = False)

max weight edge softmax: 1.283907413482666
max weight edge normalized: 1.6004301309585571
average outgoing edge weight: 0.8716641664505005
opt: 24.122013578126086, partial softmax: 18.43136978149414, partial normalized: 8.58033682688231, partial minimum: 21.37261755190633
max weight edge softmax: 1.6178152561187744
max weight edge normalized: 1.71533203125
average outgoing edge weight: 0.8057221174240112
opt: 24.27531631767341, partial softmax: 14.904389381408691, partial normalized: 8.065111701376873, partial minimum: 20.15252750675249
max weight edge softmax: 1.3881453275680542
max weight edge normalized: 1.7349904775619507
average outgoing edge weight: 0.8779746294021606
opt: 24.09069655259647, partial softmax: 17.274770736694336, partial normalized: 7.91917318669128, partial minimum: 21.651993481193635
max weight edge softmax: 1.4463163614273071
max weight edge normalized: 1.3664538860321045
average outgoing edge weight: 0.8771687746047974
opt: 23.77244764558711, partial softmax: 1

## Preliminary results
random permutation/matching: 0.18

MPNN:
learned predictions: 0.67

GAT:
learned predictions: 0.72

Got better with double ended predictions

Partial: 0.64 while greedy was doing about 0.92 on the same instance. Main reason seems to be that max weight is around 1.5 => can get at most 2/3 OPT


### Counting the number of matching constraints violated

In [11]:
# For two-way
count = 0
data = predictions["owners"].data
nb_graphs = data.shape[0]
for datapoint in range(data.shape[0]):
    for i in range(32):
        owner = data[datapoint][i]
        good = data[datapoint][int(owner)]
        if good != i:
            count += 1
print(f"average number of edges contradicting matching: {count / nb_graphs}")

average number of edges contradicting matching: 32.0


In [12]:
# For self-loops
count = 0
data = predictions["owners"].data
nb_graphs = data.shape[0]
for datapoint in range(data.shape[0]):
    owners = set(np.array(data[datapoint][32:64]))
    count += 32 - len(owners)
print(f"average number of edges contradicting matching: {count / nb_graphs}")


average number of edges contradicting matching: 27.0


In [13]:
# Re-run prediction on the same test batch, this time also requesting the hints and all outputs.
predictions, hints = model.predict(rng_key, test_feedback.features, return_hints = True, return_all_outputs = True)


repred: True
HINTS FROM NETS
[DataPoint(name="owners_h",	location=node,	type=pointer,	data=Array(220, 1, 64)), DataPoint(name="p",	location=node,	type=scalar,	data=Array(220, 1, 64)), DataPoint(name="in_queue",	location=node,	type=mask,	data=Array(220, 1, 64))]


SyntaxError: incomplete input (118648624.py, line 3)

In [None]:
# Scratch: number of entries in the hints returned by the previous predict call.
len(hints)

In [None]:
# Scratch: two equivalent ways of selecting the even-indexed elements.
# Only the value of the last expression is displayed by the notebook.
arr = np.array([1, 2, 3, 4, 5, 6])
even_positions = np.arange(len(arr)) % 2 == 0
arr[even_positions]
arr[::2]

In [None]:
# Scratch: first sample of input index 1, the index matching_value reads the matrix A from.
test_feedback.features.inputs[1].data[0]

In [44]:
# Scratch: axis = 0 takes the minimum down each column.
arr = np.array([[1, 2], [3, 4]])
arr.min(axis = 0)

array([1, 2])