In [1]:
import jax
import clrs
import numpy as np

# Fixed-seed NumPy generator: every random seed used later in the notebook is drawn from it.
rng = np.random.RandomState(1234)
# JAX PRNG key seeded with an integer drawn from the NumPy generator above.
rng_key = jax.random.PRNGKey(rng.randint(2 ** 32))

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [19]:
def samplers(num_samples, length, batch_size, **kwargs):
    """Build a 'simplified_min_sum' CLRS sampler and expose it as an endless batch generator.

    Args:
        num_samples: forwarded to ``clrs.build_sampler`` as ``num_samples``.
        length: forwarded to ``clrs.build_sampler`` as ``length`` (number of nodes).
        batch_size: size requested from ``sampler.next``; capped at ``num_samples``.
        **kwargs: forwarded unchanged to ``clrs.build_sampler`` (e.g. ``generator = "ER"``).

    Returns:
        A pair ``(batches, spec)``: ``batches`` is a generator that yields
        ``sampler.next(batch_size)`` forever, ``spec`` is the spec returned by
        ``clrs.build_sampler``.
    """
    # A batch can never be larger than the number of available samples
    batch_size = min(batch_size, num_samples)

    base_sampler, spec = clrs.build_sampler(
        name = 'simplified_min_sum',
        num_samples = num_samples,
        length = length,  # number of nodes
        weighted = True,
        **kwargs)

    def _batches():
        # Never exhausts: each next() asks the underlying sampler for one more batch
        while True:
            yield base_sampler.next(batch_size)

    return _batches(), spec

# Training batches: num_samples = 100, length = 8, batch_size = 32, with generator = "ER" forwarded to the sampler.
train_sampler, spec = samplers(100, 8, 32, generator = "ER")
# Test batches: num_samples = 40, length = 64, one batch of 40; the returned spec is discarded.
test_sampler, _ = samplers(40, 64, 40)

In [20]:
# Draw one test batch and display input 1 of its first sample
# (input index 1 is the matrix A according to the comment in matching_value).
sample = next(test_sampler)
sample.features.inputs[1].data[0]

array([[0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.00847525, 0.41870585,
        0.40326243],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.00847525, 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.41870585, 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.40326243, 0.        , ..., 0.        , 0.        ,
        0.        ]])

In [21]:
def define_model(spec, train_sampler, model = "mpnn"):
    """Create and initialise a ``clrs.models.BaselineModel`` for the given processor type.

    Args:
        spec: algorithm spec, as returned by ``samplers``.
        train_sampler: batch generator; one batch is consumed from it to serve as the
            dummy trajectory used for initialisation.
        model: processor name, one of ``"mpnn"``, ``"gat"`` or ``"mpnndoublemax"``.

    Returns:
        The initialised ``BaselineModel``.

    Raises:
        ValueError: if ``model`` is not one of the supported processor names.
    """
    # Keyword arguments handed to clrs.get_processor_factory for each supported processor
    # (use_ln => use layer norm)
    processor_kwargs = {
        "mpnn": dict(use_ln = True, nb_triplet_fts = 0),
        "gat": dict(use_ln = True, nb_heads = 4, nb_triplet_fts = 0),
        "mpnndoublemax": dict(use_ln = True, nb_triplet_fts = 0),
    }
    if model not in processor_kwargs:
        # Fail early with a readable message instead of an UnboundLocalError further down
        raise ValueError(
            f"Unknown model {model!r}; expected one of {sorted(processor_kwargs)}")

    processor_factory = clrs.get_processor_factory(model, **processor_kwargs[model])

    model_params = dict(
        processor_factory = processor_factory,  # contains the processor_factory
        hidden_dim = 32,
        encode_hints = True,
        decode_hints = True,
        #decode_diffs=False,
        #hint_teacher_forcing_noise=1.0,
        hint_teacher_forcing = 1.0,
        use_lstm = False,
        learning_rate = 0.001,
        checkpoint_path = '/tmp/checkpt',
        freeze_processor = False,  # Good for post step
        dropout_prob = 0.5,
        # nb_msg_passing_steps=3,
    )

    dummy_trajectory = next(train_sampler)  # jax needs a trajectory that is plausible looking to init

    # Separate local name so the `model` argument (a string) is not rebound to the model object
    baseline_model = clrs.models.BaselineModel(
        spec = spec,
        dummy_trajectory = dummy_trajectory,
        **model_params
    )

    baseline_model.init(dummy_trajectory.features, 1234)  # 1234 is a random seed

    return baseline_model

# Build the model with the "mpnndoublemax" processor; this consumes one batch of train_sampler for initialisation.
model = define_model(spec, train_sampler, "mpnndoublemax")

reduction: doublemax
Double max!
reduction: doublemax
Double max!
reduction: doublemax
Double max!


In [22]:
# No evaluation since we are postprocessing with soft: TO CHANGE -> baselines.py line 336 outs change hard to False
# step = 0
#
# while step <= 1:
#     feedback, test_feedback = next(train_sampler), next(test_sampler)
#     rng_key, new_rng_key = jax.random.split(rng_key) # jax needs new random seed at step
#     cur_loss = model.feedback(rng_key, feedback) # loss is contained in model somewhere
#     rng_key = new_rng_key
#     if step % 10 == 0:
#         print(step)
#     step += 1



In [23]:
def train(model, epochs, train_sampler, test_sampler):
    """Run the training loop and log scores every 10 steps.

    Steps run from 0 to ``epochs`` inclusive. At every step one batch is drawn from
    both samplers and ``model.feedback`` is called on the training batch. The initial
    JAX key is seeded from the notebook-level ``rng``.

    Args:
        model: object exposing ``feedback(rng_key, feedback)`` and ``predict(rng_key, features)``.
        epochs: index of the last training step (inclusive).
        train_sampler: iterator yielding training batches.
        test_sampler: iterator yielding test batches.

    Returns:
        The same ``model`` object that was passed in.
    """
    rng_key = jax.random.PRNGKey(rng.randint(2 ** 32))

    step = 0
    while step <= epochs:
        feedback = next(train_sampler)
        test_feedback = next(test_sampler)

        # jax needs a new random key at each step: one half is spent on this update,
        # the other half is carried over (and used for the predictions below)
        step_key, rng_key = jax.random.split(rng_key)
        cur_loss = model.feedback(step_key, feedback)  # loss is contained in model somewhere

        if step % 10 == 0:
            # "val" accuracy is actually measured on the training batch: not great, but it is an example
            train_predictions, _ = model.predict(rng_key, feedback.features)
            train_scores = clrs.evaluate(feedback.outputs, train_predictions)
            test_predictions, _ = model.predict(rng_key, test_feedback.features)
            test_scores = clrs.evaluate(test_feedback.outputs, test_predictions)
            print(
                f'step = {step} | loss = {cur_loss} | val_acc = {train_scores["score"]} | test_acc = {test_scores["score"]}')
        step += 1
    return model

In [35]:
# Train for steps 0..100 (train loops while step <= epochs); train returns the model it was given.
model = train(model, 100, train_sampler, test_sampler)

step = 0 | loss = 0.15068848431110382 | val_acc = 0.875 | test_acc = 0.592578113079071
step = 10 | loss = 0.14301234483718872 | val_acc = 0.8125 | test_acc = 0.604687511920929
step = 20 | loss = 0.15729369223117828 | val_acc = 0.90625 | test_acc = 0.6285156607627869
step = 30 | loss = 0.15455292165279388 | val_acc = 0.79296875 | test_acc = 0.5980468988418579
step = 40 | loss = 0.14386148750782013 | val_acc = 0.78515625 | test_acc = 0.595703125
step = 50 | loss = 0.14222665131092072 | val_acc = 0.82421875 | test_acc = 0.5703125
step = 60 | loss = 0.1405799686908722 | val_acc = 0.80078125 | test_acc = 0.596875011920929
step = 70 | loss = 0.12779049575328827 | val_acc = 0.83203125 | test_acc = 0.599609375
step = 80 | loss = 0.15173453092575073 | val_acc = 0.796875 | test_acc = 0.598437488079071
step = 90 | loss = 0.17027558386325836 | val_acc = 0.859375 | test_acc = 0.6117187738418579
step = 100 | loss = 0.13329805433750153 | val_acc = 0.84765625 | test_acc = 0.6078125238418579


In [36]:
from scipy.optimize import linear_sum_assignment


def matching_value(samples, predictions, partial = False, match_rest = False, opt_scipy = False):
    """Compare the weight of the predicted matchings with a reference weight, sample by sample.

    Args:
        samples: a batch (feedback) with ``features.inputs`` and ``outputs``.
        predictions: mapping whose ``"match"`` entry holds the predicted matchings in ``.data``.
        partial: if True, score the prediction with ``compute_partial_matching_weight``;
            otherwise with ``compute_greedy_matching_weight``.
        match_rest: forwarded to ``compute_greedy_matching_weight`` (only used when
            ``partial`` is False).
        opt_scipy: if True the reference weight is computed with
            ``scipy.optimize.linear_sum_assignment``; otherwise it is the weight of the
            ground-truth matching stored in ``samples.outputs[0]``.

    Returns:
        ``(pred_ratio, greedy_ratio)``: the averages over the batch of
        ``predicted weight / reference weight`` and ``naive greedy weight / reference weight``.

    Raises:
        ZeroDivisionError: if the batch contains no sample.
    """
    features = samples.features
    gt_matchings = samples.outputs[0].data
    # inputs for the matrix A are at index 1 (see spec.py)
    data = features.inputs[1].data
    masks = features.inputs[3].data
    nb_samples = data.shape[0]
    pred_accuracy = 0
    greedy_accuracy = 0

    # Iterating over all the samples
    for i in range(nb_samples):
        if opt_scipy:
            row_ind, col_ind = linear_sum_assignment(data[i], maximize = True)
            # Assumes data[i] is the full symmetric adjacency matrix: an optimal assignment
            # then uses every matched edge in both directions, hence the division by 2.
            max_weight = data[i][row_ind, col_ind].sum() / 2
        else:
            max_weight = compute_greedy_matching_weight(i, data, masks, gt_matchings[i])

        predicted_matching = predictions["match"].data[i]

        # The labels used to be swapped between the two branches
        if partial:
            preds_weight = compute_partial_matching_weight(i, data, masks, predicted_matching)
            print(f"opt: {max_weight}, partial: {preds_weight}")
        else:
            preds_weight = compute_greedy_matching_weight(i, data, masks, predicted_matching, match_rest = match_rest)
            print(f"opt: {max_weight}, greedy learned: {preds_weight}")

        # assert preds_weight <= max_weight
        greedy_matching_weight = naive_greedy(i, data, masks)
        print(f"Naive greedy: {greedy_matching_weight}")
        greedy_accuracy += greedy_matching_weight / max_weight
        pred_accuracy += preds_weight / max_weight

    return pred_accuracy / nb_samples, greedy_accuracy / nb_samples

def naive_greedy(i, data, masks):
    """Weight of the matching built greedily, buyer by buyer.

    Buyers are visited in index order; each one takes the heaviest entry of its row
    among the nodes that no earlier buyer has taken.

    Args:
        i: index of the sample inside ``data`` / ``masks``.
        data: per-sample weight matrices; ``data[i]`` is the matrix A.
        masks: per-sample masks; the sum of ``masks[i]`` is the number of buyers.

    Returns:
        The total weight collected.
    """
    weights = data[i]
    nb_buyers = int(np.sum(masks[i]))
    # At the start, every node is still a possible match
    available = np.full(weights.shape[0], True)
    node_ids = np.arange(weights.shape[1])

    total = 0
    for buyer in range(nb_buyers):
        candidates = weights[buyer, available]
        # Nothing left to match with (more buyers than goods)
        if candidates.shape[0] == 0:
            continue
        # Position of the maximum inside the filtered row, mapped back to a node index
        best = np.argmax(candidates)
        total += candidates[best]
        # The chosen node cannot be used anymore
        available[node_ids[available][best]] = False
    return total




def compute_greedy_matching_weight(i, data, masks, matching, match_rest = False):
    """Weight of the matching obtained by rounding per-good pointers to buyers.

    Only the entries of ``matching`` that belong to goods are used; each is read as the
    index of the buyer that good points to. When several goods point to the same buyer,
    the buyer keeps the heaviest one.

    Assumes buyers occupy node indices ``0..n-1`` and goods ``n..n+m-1``
    (the bookkeeping sets are built that way).

    Args:
        i: index of the sample inside ``data`` / ``masks``.
        data: per-sample weight matrices; ``data[i]`` is the matrix A.
        masks: per-sample masks with 1 for buyers and 0 for goods.
        matching: one pointer per node for sample ``i``.
        match_rest: if True, buyers and goods left unmatched by the rounding are matched
            optimally among themselves and that weight is added.

    Returns:
        The total weight of the resulting matching.
    """
    A = data[i]
    buyers_mask = masks[i]
    n = int(np.sum(buyers_mask))
    goods_mask = 1 - buyers_mask
    m = int(np.sum(goods_mask))

    # Only consider the matching values for goods; every other entry becomes -1
    matching = np.where(goods_mask == 1, matching, -1)
    unmatched_goods = set(range(n, n + m))
    unmatched_buyers = set(range(n))
    node_ids = np.arange(A.shape[1])

    matching_weight = 0
    for buyer in range(n):
        claimed_by = matching == buyer
        if not claimed_by.any():
            continue
        # If several goods point to the same buyer, keep the one with maximum weight
        candidate_weights = A[buyer, claimed_by]
        best = np.argmax(candidate_weights)
        matching_weight += candidate_weights[best]
        unmatched_goods.remove(int(node_ids[claimed_by][best]))
        unmatched_buyers.remove(buyer)

    if match_rest and unmatched_goods and unmatched_buyers:
        # Optimal matching restricted to the leftover buyers x leftover goods block.
        # The block holds each edge exactly once, so the assignment value is the matching
        # weight as is: no division by 2, and no dependence on entries outside the block.
        rest_buyers = sorted(unmatched_buyers)
        rest_goods = sorted(unmatched_goods)
        remaining_bipartite_graph = A[np.ix_(rest_buyers, rest_goods)]
        row_ind, col_ind = linear_sum_assignment(remaining_bipartite_graph, maximize = True)
        matching_weight += remaining_bipartite_graph[row_ind, col_ind].sum()

    return matching_weight


def compute_partial_matching_weight(i, data, masks, matching):
    """Weight of a fractional (partial) matching after rescaling it.

    ``matching`` is expected to be a (n+m)x(n+m) matrix where each row sums to 1
    (weights assigned to other nodes). Only the buyer -> good block is used. It is
    divided by its largest column sum (the heaviest total weight received by a single
    good) and the result is the sum of ``matching * A`` over that block.

    The input ``matching`` is left untouched.

    Args:
        i: index of the sample inside ``data`` / ``masks``.
        data: per-sample weight matrices; ``data[i]`` is the matrix A.
        masks: per-sample masks with 1 for buyers and 0 for goods.
        matching: fractional matching matrix for sample ``i``.

    Returns:
        The weight of the rescaled fractional matching, or 0.0 when the largest
        column sum of the buyer -> good block is zero.
    """
    A = data[i]
    buyers_mask = masks[i]
    n = int(np.sum(buyers_mask))
    goods_mask = 1 - buyers_mask
    m = int(np.sum(goods_mask))

    # We only care about the buyer -> good connections
    A_submatrix = A[:n, n:n + m]
    buyer_to_good = matching[:n, n:n + m]

    max_weight = np.max(np.sum(buyer_to_good, axis = 0))
    print(f"max weight: {max_weight}")
    if max_weight == 0:
        # Nothing to normalise by: avoid a 0/0 that would turn the result into NaN
        return 0.0
    # Out-of-place division: an in-place `/=` would write through the slice into the caller's array
    return np.sum((buyer_to_good / max_weight) * A_submatrix)

In [38]:
# Draw a fresh test batch, predict on it with the trained model and score the predictions.
test_feedback = next(test_sampler)
predictions, _ = model.predict(rng_key, test_feedback.features)
# opt_scipy = True: the reference weight comes from scipy's linear_sum_assignment.
matching_value(test_feedback, predictions, partial = False, match_rest = False, opt_scipy = True)


opt: 28.987174894406635, partial: 27.698218043886605
Naive greedy: 26.751829752404866
opt: 28.673538026677384, partial: 26.631343399451104
Naive greedy: 25.375112883780695
opt: 29.28512827497548, partial: 27.08840981100459
Naive greedy: 26.767052757473845
opt: 28.567972726482722, partial: 27.516866746498213
Naive greedy: 26.555935021967624
opt: 28.285468675657334, partial: 27.03739412638852
Naive greedy: 25.51905125840595
opt: 29.807240647746205, partial: 26.892068653328785
Naive greedy: 28.63823640433621
opt: 28.77835963216673, partial: 26.642746485970438
Naive greedy: 25.591410127184165
opt: 29.50758037200704, partial: 27.644329485438174
Naive greedy: 27.663127212119207
opt: 29.578480584670196, partial: 27.18969576791876
Naive greedy: 26.788630524188175
opt: 28.845570090536263, partial: 26.271818073117043
Naive greedy: 26.308393553431728
opt: 29.50758037200704, partial: 27.644329485438174
Naive greedy: 27.663127212119207
opt: 28.922458960290797, partial: 26.264473318838363
Naive gree

(0.9377233284672846, 0.9233354604051114)

In [None]:
def variation_testing(train_params, test_params, epochs = 100, model = None, bypass_training = False):
    """Train and/or evaluate on a series of (train, test) sampler configurations.

    ``train_params`` and ``test_params`` are iterated in lockstep. Each element is a dict of
    keyword arguments for ``samplers``; an optional ``"length"`` entry gives the graph
    length (default 16 for training, 64 for testing). The dicts are not modified.

    Args:
        train_params: list of dicts describing the training samplers.
        test_params: list of dicts describing the test samplers.
        epochs: value passed to ``train`` as its ``epochs`` argument for each configuration.
        model: existing model; required when ``bypass_training`` is True, replaced by a
            freshly trained "mpnn" model otherwise.
        bypass_training: if True, skip training and only evaluate ``model``.

    Returns:
        ``(model, matching_values)`` where ``matching_values`` is a list of
        ``(train_param, test_param, accuracy)`` tuples, ``accuracy`` being the value
        returned by ``matching_value``. Returns ``None`` (after printing a message) when
        ``bypass_training`` is True but no model was given.
    """
    if model is None and bypass_training:
        print("Need a model to bypass training")
        return

    matching_values = []
    for train_param, test_param in zip(train_params, test_params):
        # Work on copies so the caller's dicts (and the recorded results) keep their "length"
        train_kwargs = dict(train_param)
        test_kwargs = dict(test_param)
        train_length = train_kwargs.pop("length", 16)
        test_length = test_kwargs.pop("length", 64)

        test_num_samples = 10
        batch_size = test_num_samples

        test_sampler, _ = samplers(test_num_samples, test_length, batch_size, **test_kwargs)

        if not bypass_training:
            train_sampler, spec = samplers(100, train_length, 40, **train_kwargs)
            model = define_model(spec, train_sampler, model = "mpnn")
            # Honour the `epochs` argument (it used to be ignored in favour of a fixed 150)
            train(model, epochs, train_sampler, test_sampler)
        else:
            print("Bypassing training")

        test_feedback = next(test_sampler)
        predictions, _ = model.predict(rng_key, test_feedback.features)
        accuracy = matching_value(test_feedback, predictions, partial = False, match_rest = False, opt_scipy = True)

        matching_values.append((train_param, test_param, accuracy))
    return model, matching_values

# Candidate "low"/"high" weight ranges (not referenced by the call below in this cell).
weight_params = [{"low": 0, "high": 0.001},
                 {"low": 1, "high": 1.001},
                 {"low": 0, "high": 0.1},
                 {"low": 0, "high": 1},
                 {"low": 0, "high": 10},
                 {"low": 0, "high": 100},
                 {"low": 50, "high": 200},
                 {"low": 500, "high": 2000},
                 {"low": 5000, "high": 20000}
                 ]

# One (train, test) configuration pair; each dict is forwarded to samplers as keyword arguments.
length_training = [{"generator": "ER"}]
length_testing = [{"generator": "ER", "length": 1000, "p": 0.01}]




# Evaluation only: reuses the already-trained `model` (bypass_training = True).
model, results = variation_testing(length_training, length_testing, model = model, bypass_training = True)

results




In [16]:
# import copy
# model2 = copy.deepcopy(model)

ER p=0.25, 100 8x8 train and 40 32x32 test => 0.94 in 100 iterations

BA param=3, 100 8x8 train and 40 32x32 test => 0.97 in 100 iterations

BA param=5, 100 8x8 train and 40 32x32 test => 0.95 in 100 iterations (0.951 in 200 so has pretty much converged after 100)

BA param=7, 100 8x8 train and 40 32x32 test => 0.946 in 100 iterations

#### Cross training
BA param=7 to BA param=3

BA param=7 to ER p=0.25 0.946 with BA to 0.939 with ER (same as if trained only on BA)

ER p=0.25 to BA param=3 went from 0.939 with ER to 0.967 with BA (BA param 3 was 0.97 so basically nothing lost)

#### Weight variations
Uniform
* 0,0.001 -> 0.928
* 1,1.001 -> 0.962
* 0,0.1 -> 0.931
* 0,10 -> 0.883
* 0,100 -> 0.77
* 50, 200 -> 0.72
* 500, 2000 -> 0.69
* 5000, 20000 -> 0.7

Normal:
Basically same.

Gumbel
* 0,0.001 -> 0.323
* 1,1.001 -> 0.849
* 0,0.1 -> 0.498
* 5,10 -> 0.82
* 5,100 -> 0.8

#### Weight cross training
Train ER p=0.25 unif 0,1:
* 0,0.001 -> 0.948
* 1,1.001 -> 0.967
* 0,0.1 -> 0.916
* 0,10 -> 0.86
* 0,100 -> 0.75
* 50, 200 -> 0.72
* 500, 2000 -> 0.72
* 5000, 20000 -> 0.69

=> The model seems to generalize across weight ranges quite well: there is basically no statistical difference from training separately on each weight range.

Train normal 5000, 20000:
* 0, 0.001 -> 0.39 (maybe it's the large to small that was a problem here? Also those values make little sense for a normal RV)

Other direction train normal 0, 0.001 (got to 0.78):
* 5000, 20000 -> 0.76

=> small to large seems better


#### Larger graphs
Same training
ER p=0.25 8x8 train:
* 100x100 test goes to 0.88
* 200x200 goes to 0.63 (only 12 prediction mismatches though)
* 200x200 p=0.3 =>
* 250x250 => 0.9448 (BUT p=0.1 to not kill my computer)
*
Try this but 16x16 train

#### RIDESHARE
8x8 train,
* 32x32 test => 0.96
* 50x50 test => 0.96
* 100x100 test => 0.938
* 250x250 test => 0.9



## Preliminary results
random permutation/matching: 0.18

MPNN:
learned predictions: 0.67

GAT:
learned predictions: 0.72

Got better with double ended predictions

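Partial: 0.64, while greedy was achieving about 0.92 on the same instance. The main reason seems to be that the max weight (the normalizer used for the partial matching) is around 1.5, so after dividing by it the partial matching can reach at most 2/3 of OPT.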


### Counting the number of matching constraints violated

In [27]:
# For two-way
def count_mismatches_two_way(predictions, nb_buyers = 32):
    """Print the average number of non-reciprocated pointers per graph (two-way predictions).

    For each of the first ``nb_buyers`` nodes, the node it points to must point back to
    it; every node for which this fails counts as one mismatch.

    Args:
        predictions: mapping whose ``"match"`` entry holds, in ``.data``, one row of
            pointers per graph.
        nb_buyers: number of leading nodes to check in each graph (32 by default).

    Raises:
        ZeroDivisionError: if ``predictions`` contains no graph.
    """
    count = 0
    data = predictions["match"].data
    nb_graphs = data.shape[0]
    for datapoint in range(nb_graphs):
        pointers = data[datapoint]
        for i in range(nb_buyers):
            partner = pointers[i]
            # The pointer stored at the partner should lead back to node i
            back_pointer = pointers[int(partner)]
            if back_pointer != i:
                count += 1
    print(f"average number of edges contradicting matching: {count / nb_graphs}")

average number of edges contradicting matching: 12.2


In [17]:
# For self-loops
def count_mismatches_self_loop(predictions, nb_buyers = 32, nb_goods = 32):
    """Print the average number of duplicated pointers per graph (self-loop predictions).

    Looks at the pointers stored at positions ``nb_buyers .. nb_buyers + nb_goods - 1``
    of every graph and counts how many of them repeat a value already used by another
    position in that range.

    Args:
        predictions: mapping whose ``"match"`` entry holds, in ``.data``, one row of
            pointers per graph.
        nb_buyers: offset of the first inspected position (32 by default).
        nb_goods: number of inspected positions (32 by default).

    Raises:
        ZeroDivisionError: if ``predictions`` contains no graph.
    """
    count = 0
    data = predictions["match"].data
    nb_graphs = data.shape[0]
    for datapoint in range(nb_graphs):
        owners = set(np.array(data[datapoint][nb_buyers:nb_buyers + nb_goods]))
        # Every value shared by several positions removes one element from the set
        count += nb_goods - len(owners)
    print(f"average number of edges contradicting matching: {count / nb_graphs}")


average number of edges contradicting matching: 0.4


In [14]:
# Scratch check: np.concatenate joins two 1-D arrays end to end.
a = np.array([1, 2])
b = np.array([2, 3])
print(np.concatenate((a, b)))

[1 2 2 3]
