In [1]:
import copy

import jax
import clrs
import numpy as np

rng = np.random.RandomState(1234)
rng_key = jax.random.PRNGKey(rng.randint(2 ** 32))


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [68]:
# If you don't want BipartiteMatching, just pass empty generator list and
# length separately

train_sampler_spec = {
    'num_samples': 100,
    'batch_size': 32,
    'schematics': [
        {
            'generator': 'ER',
            'proportion': 1,
            'length': 8,
            'kwargs': {'low': 0, 'high': 1, 'weighted': True}
        },
        {
            'generator': 'ER',
            'proportion': 0,
            'length': 8,
            'kwargs': {'p': 1, 'low': 1, 'high': 1.0001, 'weighted': True}
        },
    ]
}

test_sampler_spec = {
    'num_samples': 40,
    'batch_size': 40,
    'schematics': [
        {
            'generator': 'ER',
            'proportion': 1,
            'length': 64,
            'kwargs': {'low': 0, 'high': 1, 'weighted': True}
        }
    ]
}

def samplers(sampler_spec, **kwargs):
    batch_size = sampler_spec.get('batch_size', 1)
    num_samples = sampler_spec['num_samples']
    if batch_size > num_samples:
        batch_size = num_samples

    def _iterate_sampler(sampler, batch_size):
        while True:
            yield sampler.next(batch_size)

    sampler, spec = clrs.build_sampler(
        name = 'simplified_min_sum',
        sampler_spec = sampler_spec,
        **kwargs)  # number of nodes

    sampler = _iterate_sampler(sampler, batch_size = batch_size)
    return sampler, spec

train_sampler, spec = samplers(train_sampler_spec)
test_sampler, _ = samplers(test_sampler_spec)

In [69]:
sample = next(test_sampler)
sample.features.inputs[1].data[0]

array([[0.        , 0.        , 0.        , ..., 0.85262814, 0.        ,
        0.57133077],
       [0.        , 0.        , 0.        , ..., 0.        , 0.51040028,
        0.2183404 ],
       [0.        , 0.        , 0.        , ..., 0.92918532, 0.57401117,
        0.        ],
       ...,
       [0.85262814, 0.        , 0.92918532, ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.51040028, 0.57401117, ..., 0.        , 0.        ,
        0.        ],
       [0.57133077, 0.2183404 , 0.        , ..., 0.        , 0.        ,
        0.        ]])

In [81]:
def define_model(spec, train_sampler, model = "mpnn"):
    if model == "mpnn":
        processor_factory = clrs.get_processor_factory('mpnn', use_ln = True,
                                                   nb_triplet_fts = 0)  #use_ln => use layer norm
    elif model == "gat":
        processor_factory = clrs.get_processor_factory('gat', use_ln=True, nb_heads = 4, nb_triplet_fts = 0)

    elif model == "mpnndoublemax":
        processor_factory = clrs.get_processor_factory('mpnndoublemax', use_ln = True,
                                                       nb_triplet_fts = 0)  #use_ln => use layer norm

    model_params = dict(
        processor_factory = processor_factory,  # contains the processor_factory
        hidden_dim = 32,  # TODO put back to 32 if no difference
        encode_hints = True,
        decode_hints = True,
        #decode_diffs=False,
        #hint_teacher_forcing_noise=1.0,
        hint_teacher_forcing = 1.0,
        use_lstm = False,
        learning_rate = 0.001,
        checkpoint_path = '/tmp/checkpt',
        freeze_processor = False,  # Good for post step
        dropout_prob = 0.5,
        # nb_msg_passing_steps=3,
    )

    dummy_trajectory = next(train_sampler)  # jax needs a trajectory that is plausible looking to init

    model = clrs.models.BaselineModel(
        spec = spec,
        dummy_trajectory = dummy_trajectory,
        **model_params
    )

    model.init(dummy_trajectory.features, 1234)  # 1234 is a random seed

    return model

model = define_model(spec, train_sampler, "mpnn")

In [82]:
# No evaluation since we are postprocessing with soft: TO CHANGE -> baselines.py line 336 outs change hard to False
# step = 0
#
# while step <= 1:
#     feedback, test_feedback = next(train_sampler), next(test_sampler)
#     rng_key, new_rng_key = jax.random.split(rng_key) # jax needs new random seed at step
#     cur_loss = model.feedback(rng_key, feedback) # loss is contained in model somewhere
#     rng_key = new_rng_key
#     if step % 10 == 0:
#         print(step)
#     step += 1



In [83]:
def train(model, epochs, train_sampler, test_sampler):
    step = 0
    rng_key = jax.random.PRNGKey(rng.randint(2 ** 32))

    while step <= epochs:
        feedback, test_feedback = next(train_sampler), next(test_sampler)
        # TODO remove - testing if uses hints on tests
        # shape = test_feedback.features.hints[0].data[0].shape
        # test_feedback.features.hints[0].data = test_feedback.features.hints[0].data[0, :, :].reshape((1, *shape))

        rng_key, new_rng_key = jax.random.split(rng_key)  # jax needs new random seed at step
        cur_loss = model.feedback(rng_key, feedback)  # loss is contained in model somewhere
        rng_key = new_rng_key
        if step % 10 == 0:
            predictions_val, _ = model.predict(rng_key, feedback.features)
            out_val = clrs.evaluate(feedback.outputs, predictions_val)
            predictions, _ = model.predict(rng_key, test_feedback.features)
            out = clrs.evaluate(test_feedback.outputs, predictions)
            print(
                f'step = {step} | loss = {cur_loss} | val_acc = {out_val["score"]} | test_acc = {out["score"]}')  # here, val accuracy is actually training accuracy, not great but is example
        step += 1
    return model

In [84]:
model = train(model, 200, train_sampler, test_sampler)

step = 0 | loss = 4.965429782867432 | val_acc = 0.0390625 | test_acc = 0.0
step = 10 | loss = 1.6784148216247559 | val_acc = 0.015625 | test_acc = 0.00937500037252903
step = 20 | loss = 0.8443288803100586 | val_acc = 0.12890625 | test_acc = 0.02695312537252903
step = 30 | loss = 0.5100876688957214 | val_acc = 0.296875 | test_acc = 0.09531249850988388
step = 40 | loss = 0.42952391505241394 | val_acc = 0.3515625 | test_acc = 0.12187500298023224
step = 50 | loss = 0.32434675097465515 | val_acc = 0.39453125 | test_acc = 0.154296875
step = 60 | loss = 0.34040892124176025 | val_acc = 0.30859375 | test_acc = 0.171875
step = 70 | loss = 0.2824343144893646 | val_acc = 0.33984375 | test_acc = 0.21875
step = 80 | loss = 0.29514825344085693 | val_acc = 0.421875 | test_acc = 0.255859375
step = 90 | loss = 0.3518792688846588 | val_acc = 0.5234375 | test_acc = 0.4046874940395355
step = 100 | loss = 0.23638272285461426 | val_acc = 0.671875 | test_acc = 0.5953125357627869
step = 110 | loss = 0.22414666

In [98]:
from scipy.optimize import linear_sum_assignment


def matching_value(samples, predictions, partial = False, match_rest = False, opt_scipy = False):
    features = samples.features
    gt_matchings = samples.outputs[0].data
    # inputs for the matrix A are at index 1 (see spec.py)
    data = features.inputs[1].data
    masks = features.inputs[3].data
    pred_accuracy = 0
    greedy_accuracy = 0

    #TODO remove
    def _add_uniform_weights(adj, low, high):
        n, m = adj.shape
        weights = np.random.uniform(
            low=low, high=high, size=(n, m)
        )
        return adj * high + low

    # Iterating over all the samples
    for i in range(data.shape[0]):

        if opt_scipy:
            row_ind, col_ind = linear_sum_assignment(data[i], maximize = True)
            max_weight = data[i][row_ind, col_ind].sum() / 2  #TODO why /2
        else:
            max_weight = compute_greedy_matching_weight(i, data, masks, gt_matchings[i])

        predicted_matching = predictions["match"].data[i]

        if partial:
            preds_weight = compute_partial_matching_weight(i, data, masks, predicted_matching)
            print(f"opt: {max_weight}, greedy learned: {preds_weight}")
        else:
            preds_weight = compute_greedy_matching_weight(i, data, masks, predicted_matching, match_rest = match_rest)
            print(f"opt: {max_weight}, partial: {preds_weight}")

        # assert preds_weight <= max_weight
        greedy_matching_weight = naive_greedy(i, data, masks)
        print(f"Naive greedy: {greedy_matching_weight}")
        greedy_accuracy += greedy_matching_weight / max_weight
        pred_accuracy += preds_weight / max_weight

    return pred_accuracy / data.shape[0], greedy_accuracy / data.shape[0]

def naive_greedy(i, data, masks):
    """Computes a matching greedily by for each node adding the maximum neighbor that
    hasn't yet been added to the matching"""

    matching_weight = 0
    A = data[i]
    buyers_mask = masks[i]
    n = int(np.sum(buyers_mask))
    # At the start, all the right hand side values are possible matches
    matching_mask = np.full(A.shape[0], True)

    for buyer in range(n):
        # Checking if there are more elements to match (if more buyers than goods)
        if A[buyer, matching_mask].shape[0] != 0:
            matching_weight += np.max(A[buyer, matching_mask])
            # Recovering the index of the maximum, inspired by http://seanlaw.github.io/2015/09/10/numpy-argmin-with-a-condition/
            subset_idx = np.argmax(A[buyer, matching_mask])
            good = np.arange(A.shape[1])[matching_mask][subset_idx]
            # The corresponding good cannot be used anymore
            matching_mask[good] = False
    return matching_weight




def compute_greedy_matching_weight(i, data, masks, matching, match_rest = False):
    matching_weight = 0
    A = data[i]
    buyers_mask = masks[i]
    n = int(np.sum(buyers_mask))
    goods_mask = 1 - buyers_mask
    m = int(np.sum(goods_mask))

    # Only consider the matching values for consumers
    matching = np.where(goods_mask == 1, matching, -1)
    unmatched_goods = set(range(n, n + m))
    unmatched_buyers = set(range(n))

    for buyer in range(n):
        if buyer in matching:
            # If several goods point to the same buyer, keep the one with maximum weight
            mask = matching == buyer
            matching_weight += np.max(A[buyer, mask])
            # Recovering the index of the maximum, inspired by http://seanlaw.github.io/2015/09/10/numpy-argmin-with-a-condition/
            subset_idx = np.argmax(A[buyer, mask])
            good = np.arange(A.shape[1])[mask][subset_idx]
            unmatched_goods.remove(good)
            unmatched_buyers.remove(buyer)

    if match_rest and len(unmatched_goods) > 0 and len(unmatched_buyers) > 0:
        # Compute optimal matching on the remaining unmatched nodes
        mask = np.zeros(A.shape)
        # TODO this is a horrible solution, there's definitely a prettier solution
        mask[list(unmatched_buyers)] += 1
        mask[:, list(unmatched_goods)] += 1
        mask = np.where(mask == 2, True, False)
        remaining_bipartite_graph = A * mask
        row_ind, col_ind = linear_sum_assignment(remaining_bipartite_graph, maximize = True)
        opt = A[row_ind, col_ind].sum() / 2  #TODO do I always need the division by 2
        matching_weight += opt

    return matching_weight


def compute_partial_matching_weight(i, data, masks, matching):
    # Matching is expected to be a (n+m)x(n+m) matrix where each row sums to 1 (weights assigned to other nodes)

    matching_weight = 0
    A = data[i]
    buyers_mask = masks[i]
    n = int(np.sum(buyers_mask))
    goods_mask = 1 - buyers_mask
    m = int(np.sum(goods_mask))

    # We only care about the buyer -> good connections
    A_submatrix = A[:n, n:n + m]
    matching = matching[:n, n:n + m]

    max_weight = np.max(np.sum(matching, axis = 0))
    print(f"max weight: {max_weight}")
    matching /= max_weight
    return np.sum(matching * A_submatrix)

In [100]:
test_feedback = next(test_sampler)
predictions, _ = model.predict(rng_key, test_feedback.features)
matching_value(test_feedback, predictions, partial = False, match_rest = False, opt_scipy = True)


opt: 5860606.017943202, partial: 5375879.24805133
Naive greedy: 5491054.91271316
opt: 5885273.654798187, partial: 5349249.835306774
Naive greedy: 5166104.137059483
opt: 5940231.253988897, partial: 5609311.9237106275
Naive greedy: 5392231.790562748
opt: 5992254.803463874, partial: 5800732.558806262
Naive greedy: 5672137.031411302
opt: 5896399.517961098, partial: 5426647.9185259305
Naive greedy: 5497540.657817004
opt: 6115802.803925209, partial: 5468496.158563082
Naive greedy: 5701231.657479361
opt: 5811205.017066159, partial: 5416207.144849619
Naive greedy: 5210911.086525901
opt: 5917393.924111161, partial: 5376373.462757005
Naive greedy: 5276987.545165191
opt: 5917393.924111161, partial: 5376373.462757005
Naive greedy: 5276987.545165191
opt: 5885856.320126977, partial: 5663459.02090307
Naive greedy: 5197508.121336856
opt: 6021942.3973772, partial: 5331541.7674708655
Naive greedy: 5529343.665181269
opt: 5811205.017066159, partial: 5416207.144849619
Naive greedy: 5210911.086525901
opt: 5

(0.915517376042058, 0.9039568638776888)

In [119]:
import copy

def variation_testing(train_sampler_spec, test_sampler_spec, epochs = 300, model = None, bypass_training = False):
    if model is None and bypass_training:
        print("Need a model to bypass training")
        return


    matching_values = []
    for train_param, test_param in zip(train_sampler_spec, test_sampler_spec):
        test_param['num_samples'] = 40
        test_param['batch_size'] = 40
        schematics = test_param['schematics']
        schematics[0]['length'] = 64
        test_param['schematics'] = schematics

        test_sampler, _ = samplers(test_param)

        if not bypass_training:
            train_sampler, spec = samplers(train_param)
            model = define_model(spec, train_sampler, model="mpnn")
            train(model, epochs, train_sampler, test_sampler)
        else:
            print("Bypassing training")

        test_feedback = next(test_sampler)
        predictions, _ = model.predict(rng_key, test_feedback.features)
        accuracy = matching_value(test_feedback, predictions, partial = False, match_rest = False, opt_scipy = True)

        matching_values.append((train_param, test_param, accuracy))
    return model, matching_values

weight_params = [{"low": 0, "high": 0.001},
                 {"low": 1, "high": 1.001},
                 {"low": 1, "high": 1.1},
                 {"low": 1, "high": 2},
                 {"low": 0, "high": 0.1},
                 {"low": 0, "high": 1},
                 # {"low": 0, "high": 10},
                 # {"low": 0, "high": 100},
                 # {"low": 50, "high": 200},
                 # {"low": 500, "high": 2000},
                 # {"low": 5000, "high": 20000}
                 ]


train_sampler_spec = [
    {
        'num_samples': 100, 'batch_size': 32,
        'schematics': [
            {
                'generator': 'ER',
                'proportion': 1,
                'length': 8,
                'kwargs': {'low': 0, 'high': 0.001, 'weighted': True}
            }
        ]
    },
    {
        'num_samples': 100, 'batch_size': 32,
        'schematics': [
            {
                'generator': 'ER',
                'proportion': 1,
                'length': 8,
                'kwargs': {'low': 0, 'high': 0.01, 'weighted': True}
            }
        ]
    },
    {
        'num_samples': 100, 'batch_size': 32,
        'schematics': [
            {
                'generator': 'ER',
                'proportion': 1,
                'length': 8,
                'kwargs': {'low': 0, 'high': 0.1, 'weighted': True}
            }
        ]
    },
    {
        'num_samples': 100, 'batch_size': 32,
        'schematics': [
            {
                'generator': 'ER',
                'proportion': 1,
                'length': 8,
                'kwargs': {'low': 0, 'high': 1, 'weighted': True}
            }
        ]
    },
    {
        'num_samples': 100, 'batch_size': 32,
        'schematics': [
            {
                'generator': 'ER',
                'proportion': 1,
                'length': 8,
                'kwargs': {'low': 0, 'high': 10, 'weighted': True}
            }
        ]
    },
    {
        'num_samples': 100, 'batch_size': 32,
        'schematics': [
            {
                'generator': 'ER',
                'proportion': 1,
                'length': 8,
                'kwargs': {'low': 0, 'high': 100, 'weighted': True}
            }
        ]
    },
]



length_training = [{"generator": "ER"}]
length_testing = [{"generator": "ER", "length": 1000, "p": 0.01}]




model, results = variation_testing(train_sampler_spec, copy.deepcopy(train_sampler_spec), model = model, bypass_training = False)

results




step = 0 | loss = 5.39536190032959 | val_acc = 0.01953125 | test_acc = 0.0
step = 10 | loss = 2.2167346477508545 | val_acc = 0.12109375 | test_acc = 0.014453125186264515
step = 20 | loss = 1.2875858545303345 | val_acc = 0.3203125 | test_acc = 0.06171875074505806
step = 30 | loss = 0.9298624992370605 | val_acc = 0.359375 | test_acc = 0.09687500447034836
step = 40 | loss = 0.6615515351295471 | val_acc = 0.375 | test_acc = 0.08203125
step = 50 | loss = 0.6140495538711548 | val_acc = 0.38671875 | test_acc = 0.06562500447034836
step = 60 | loss = 0.49350741505622864 | val_acc = 0.2421875 | test_acc = 0.04726562649011612
step = 70 | loss = 0.5167531967163086 | val_acc = 0.21875 | test_acc = 0.03750000149011612
step = 80 | loss = 0.43201586604118347 | val_acc = 0.23828125 | test_acc = 0.05117187649011612
step = 90 | loss = 0.4791484475135803 | val_acc = 0.5 | test_acc = 0.07343750447034836
step = 100 | loss = 0.4280708134174347 | val_acc = 0.60546875 | test_acc = 0.32539063692092896
step = 11

[({'num_samples': 100,
   'batch_size': 32,
   'schematics': [{'generator': 'ER',
     'proportion': 1,
     'length': 8,
     'kwargs': {'low': 1, 'high': 1.1, 'weighted': True}}]},
  {'num_samples': 40,
   'batch_size': 40,
   'schematics': [{'generator': 'ER',
     'proportion': 1,
     'length': 64,
     'kwargs': {'low': 1, 'high': 1.1, 'weighted': True}}]},
  (0.9846031227966948, 0.9744017733208385)),
 ({'num_samples': 100,
   'batch_size': 32,
   'schematics': [{'generator': 'ER',
     'proportion': 1,
     'length': 8,
     'kwargs': {'low': 1, 'high': 1.01, 'weighted': True}}]},
  {'num_samples': 40,
   'batch_size': 40,
   'schematics': [{'generator': 'ER',
     'proportion': 1,
     'length': 64,
     'kwargs': {'low': 1, 'high': 1.01, 'weighted': True}}]},
  (0.99344147749463, 0.969707455327214)),
 ({'num_samples': 100,
   'batch_size': 32,
   'schematics': [{'generator': 'ER',
     'proportion': 1,
     'length': 8,
     'kwargs': {'low': 1, 'high': 1.2, 'weighted': True}}

In [16]:
# import copy
# model2 = copy.deepcopy(model)

ER p=0.25, 100 8x8 train and 40 32x32 test => 0.94 in 100 iterations

BA param=3, 100 8x8 train and 40 32x32 test => 0.97 in 100 iterations

BA param=5, 100 8x8 train and 40 32x32 test => 0.95 in 100 iterations (0.951 in 200 so has pretty much converged after 100)

BA param=7, 100 8x8 train and 40 32x32 test => 0.946 in 100 iterations

#### Cross training
BA param=7 to BA param=3

BA param=7 to ER p=0.25 0.946 with BA to 0.939 with ER (same as if trained only on BA)

ER p=0.25 to BA param=3 went from 0.939 with ER to 0.967 with BA (BA param 3 was 0.97 so basically nothing lost)

#### Weight variations
Uniform
* 0,0.001 -> 0.928
* 1,1.001 -> 0.962
* 0,0.1 -> 0.931
* 0,10 -> 0.883
* 0,100 -> 0.77
* 50, 200 -> 0.72
* 500, 2000 -> 0.69
* 5000, 20000 -> 0.7

Normal:
Basically same.

Gumbel
* 0,0.001 -> 0.323
* 1,1.001 -> 0.849
* 0,0.1 -> 0.498
* 5,10 -> 0.82
* 5,100 -> 0.8

#### Weight cross training
Train ER p=0.25 unif 0,1:
* 0,0.001 -> 0.948
* 1,1.001 -> 0.967
* 0,0.1 -> 0.916
* 0,10 -> 0.86
* 0,100 -> 0.75
* 50, 200 -> 0.72
* 500, 2000 -> 0.72
* 5000, 20000 -> 0.69

=> Seems to weight generalize quite well. Actually even better because basically no statistical difference with if we trained separately.

Train normal 5000, 20000:
* 0, 0.001 -> 0.39 (maybe it's the large to small that was a problem here? Also those values make little sense for a normal RV)

Other direction train normal 0, 0.001 (got to 0.78):
* 5000, 20000 -> 0.76

=> small to large seems better


#### Larger graphs
Same training
ER p=0.25 8x8 train:
* 100x100 test goes to 0.88
* 200x200 goes to 0.63 (only 12 prediction mismatches though)
* 200x200 p=0.3 =>
* 250x250 => 0.9448 (BUT p=0.1 to not kill my computer)
*
Try this but 16x16 train

#### RIDESHARE
8x8 train,
* 32x32 test => 0.96
* 50x50 test => 0.96
* 100x100 test => 0.938
* 250x250 test => 0.9

#### Double max
8x8 train 32x23 test
300 iterations gets us to 0.93 as normal max (though normal max takes 100 iterations to get there), 600 iterations gets us to 0.965
==> Testing single max on 600 iterations => 0.956
==> Testing single max with 64 hidden dim embeddings on 600 iterations 0.96 (already in 200) (seeing if gain is only from more parameters or if double max is actually more aligned)

Conclusion, it was mainly due to more iterations + some amount of more parameters but only 1% so probably not statistically significant.

#### Training with scaling
Train/test with 5000, 200000 weights ==> 0.76 accuracy
But if normalize 0, 1 on training (or just train on normalized) ==> 0.91 (same acc as had train/testing on normalized)

#### More weight scales training
300 epochs for all, 4x4 train, 32x32 test
* 0, 1: 0.946, 0.917
* 1, 1.01: 0.993, 0.970
* 1, 1.001: 0.956, 0.976
* 1, 1.1: 0.989, 0.972
* 1, 1.2: 0.986, 0.958
* 1, 1.5: 0.966, 0.949
* 1, 2: 0.957, 0.939
* 2, 2.1: 0.9927, 0.9757
* 10, 10.001: 0.4, 0.97
Realization: shifting just doesn't makes sense (val + 1000) / (opt + 1000) > val / opt


## Preliminary results
random permutation/matching: 0.18

MPNN:
learned predictions: 0.67

GAT:
learned predictions: 0.72

Got better with double ended predictions

Partial: 0.64 while greedy was doing about 0.92 on the same instance. Main reason seems to be that max weight is around 1.5 => can get at most 2/3 OPT


### Counting the number of matching constraints violated

In [27]:
# For two-way
def count_mismatches_two_way(predictions):
    count = 0
    data = predictions["match"].data
    nb_graphs = data.shape[0]
    for datapoint in range(data.shape[0]):
        for i in range(32):
            owner = data[datapoint][i]
            good = data[datapoint][int(owner)]
            if good != i:
                count += 1
    print(f"average number of edges contradicting matching: {count / nb_graphs}")

average number of edges contradicting matching: 12.2


In [17]:
# For self-loops
def count_mismatches_self_loop(predictions):
    count = 0
    data = predictions["match"].data
    nb_graphs = data.shape[0]
    for datapoint in range(data.shape[0]):
        owners = set(np.array(data[datapoint][32:64]))
        count += 32 - len(owners)
    print(f"average number of edges contradicting matching: {count / nb_graphs}")


average number of edges contradicting matching: 0.4


In [14]:
a = np.array([1, 2])
b = np.array([2, 3])
print(np.concatenate((a, b)))

[1 2 2 3]
