In [1]:
import copy

import jax
import clrs
import numpy as np

rng = np.random.RandomState(1234)
rng_key = jax.random.PRNGKey(rng.randint(2 ** 32))


  from .autonotebook import tqdm as notebook_tqdm
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [13]:
# If you don't want BipartiteMatching, just pass empty generator list and
# length separately

train_sampler_spec = {
    'num_samples': 100,
    'batch_size': 32,
    'schematics': [
        {
            'generator': 'ER',
            'proportion': 1,
            'length': 16,
            'kwargs': {'p': 0.5, 'low': 0, 'high': 1, 'weighted': True}
        }
    ]
}

test_sampler_spec = {
    'num_samples': 40,
    'batch_size': 40,
    'schematics': [
        {
            'generator': 'GEOMETRIC',
            'proportion': 1,
            'length': 64,
            'kwargs': {'threshold': 0, 'scaling': 1 / np.sqrt(2)}
        },
    ]
}

def samplers(sampler_spec, **kwargs):
    batch_size = sampler_spec.get('batch_size', 1)
    num_samples = sampler_spec['num_samples']
    if batch_size > num_samples:
        batch_size = num_samples

    def _iterate_sampler(sampler, batch_size):
        while True:
            yield sampler.next(batch_size)

    sampler, spec = clrs.build_sampler(
        name = 'simplified_min_sum',
        sampler_spec = sampler_spec,
        **kwargs)  # number of nodes

    sampler = _iterate_sampler(sampler, batch_size = batch_size)
    return sampler, spec

train_sampler, spec = samplers(train_sampler_spec)
test_sampler, _ = samplers(test_sampler_spec)

In [14]:
def define_model(spec, train_sampler, model = "mpnn"):
    if model == "mpnn":
        processor_factory = clrs.get_processor_factory('mpnn', use_ln = True,
                                                   nb_triplet_fts = 0)  #use_ln => use layer norm
    elif model == "gat":
        processor_factory = clrs.get_processor_factory('gat', use_ln=True, nb_heads = 4, nb_triplet_fts = 0)

    elif model == "mpnndoublemax":
        processor_factory = clrs.get_processor_factory('mpnndoublemax', use_ln = True,
                                                       nb_triplet_fts = 0)  #use_ln => use layer norm

    model_params = dict(
        processor_factory = processor_factory,  # contains the processor_factory
        hidden_dim = 32,  # TODO put back to 32 if no difference
        encode_hints = True,
        decode_hints = True,
        #decode_diffs=False,
        #hint_teacher_forcing_noise=1.0,
        hint_teacher_forcing = 1.0,
        use_lstm = False,
        learning_rate = 0.001,
        checkpoint_path = '/tmp/checkpt',
        freeze_processor = False,  # Good for post step
        dropout_prob = 0.5,
        # nb_msg_passing_steps=3,
    )

    dummy_trajectory = next(train_sampler)  # jax needs a trajectory that is plausible looking to init

    model = clrs.models.BaselineModel(
        spec = spec,
        dummy_trajectory = dummy_trajectory,
        **model_params
    )

    model.init(dummy_trajectory.features, 1234)  # 1234 is a random seed

    return model

model = define_model(spec, train_sampler, "mpnn")

In [15]:
# No evaluation since we are postprocessing with soft: TO CHANGE -> baselines.py line 336 outs change hard to False
# step = 0
#
# while step <= 1:
#     feedback, test_feedback = next(train_sampler), next(test_sampler)
#     rng_key, new_rng_key = jax.random.split(rng_key) # jax needs new random seed at step
#     cur_loss = model.feedback(rng_key, feedback) # loss is contained in model somewhere
#     rng_key = new_rng_key
#     if step % 10 == 0:
#         print(step)
#     step += 1



In [16]:
def train(model, epochs, train_sampler, test_sampler):
    step = 0
    rng_key = jax.random.PRNGKey(rng.randint(2 ** 32))

    while step <= epochs:
        feedback, test_feedback = next(train_sampler), next(test_sampler)
        # TODO remove - testing if uses hints on tests
        # shape = test_feedback.features.hints[0].data[0].shape
        # test_feedback.features.hints[0].data = test_feedback.features.hints[0].data[0, :, :].reshape((1, *shape))

        rng_key, new_rng_key = jax.random.split(rng_key)  # jax needs new random seed at step
        cur_loss = model.feedback(rng_key, feedback)  # loss is contained in model somewhere
        rng_key = new_rng_key
        if step % 10 == 0:
            predictions_val, _ = model.predict(rng_key, feedback.features)
            out_val = clrs.evaluate(feedback.outputs, predictions_val)
            predictions, _ = model.predict(rng_key, test_feedback.features)
            out = clrs.evaluate(test_feedback.outputs, predictions)
            print(
                f'step = {step} | loss = {cur_loss} | val_acc = {out_val["score"]} | test_acc = {out["score"]}')  # here, val accuracy is actually training accuracy, not great but is example
        step += 1
    return model

In [17]:
model = train(model, 100, train_sampler, test_sampler)

step = 0 | loss = 6.619992733001709 | val_acc = 0.017578125 | test_acc = 0.0011718750465661287
step = 10 | loss = 2.4107420444488525 | val_acc = 0.11328125 | test_acc = 0.04140625149011612
step = 20 | loss = 1.3147313594818115 | val_acc = 0.3515625 | test_acc = 0.10312499850988388
step = 30 | loss = 0.7281531691551208 | val_acc = 0.28125 | test_acc = 0.11054687947034836
step = 40 | loss = 0.6583711504936218 | val_acc = 0.3125 | test_acc = 0.11093749850988388
step = 50 | loss = 0.48690590262413025 | val_acc = 0.529296875 | test_acc = 0.08945312350988388
step = 60 | loss = 0.46321335434913635 | val_acc = 0.82421875 | test_acc = 0.11796875298023224
step = 70 | loss = 0.3621886670589447 | val_acc = 0.75390625 | test_acc = 0.10195312649011612
step = 80 | loss = 0.29430145025253296 | val_acc = 0.697265625 | test_acc = 0.107421875
step = 90 | loss = 0.3907209038734436 | val_acc = 0.66015625 | test_acc = 0.11640625447034836
step = 100 | loss = 0.3608171343803406 | val_acc = 0.64453125 | test_a

In [18]:
from scipy.optimize import linear_sum_assignment


def matching_value(samples, predictions, partial = False, match_rest = False, opt_scipy = False):
    features = samples.features
    gt_matchings = samples.outputs[0].data
    # inputs for the matrix A are at index 1 (see spec.py)
    data = features.inputs[1].data
    masks = features.inputs[3].data
    pred_accuracy = 0
    greedy_accuracy = 0

    #TODO remove
    def _add_uniform_weights(adj, low, high):
        n, m = adj.shape
        weights = np.random.uniform(
            low=low, high=high, size=(n, m)
        )
        return adj * high + low

    # Iterating over all the samples
    for i in range(data.shape[0]):

        if opt_scipy:
            row_ind, col_ind = linear_sum_assignment(data[i], maximize = True)
            max_weight = data[i][row_ind, col_ind].sum() / 2  #TODO why /2
        else:
            max_weight = compute_greedy_matching_weight(i, data, masks, gt_matchings[i])

        predicted_matching = predictions["match"].data[i]

        if partial:
            preds_weight = compute_partial_matching_weight(i, data, masks, predicted_matching)
            print(f"opt: {max_weight}, greedy learned: {preds_weight}")
        else:
            preds_weight = compute_greedy_matching_weight(i, data, masks, predicted_matching, match_rest = match_rest)
            print(f"opt: {max_weight}, partial: {preds_weight}")

        # assert preds_weight <= max_weight
        greedy_matching_weight = naive_greedy(i, data, masks)
        print(f"Naive greedy: {greedy_matching_weight}")
        greedy_accuracy += greedy_matching_weight / max_weight
        pred_accuracy += preds_weight / max_weight

    return pred_accuracy / data.shape[0], greedy_accuracy / data.shape[0]

def naive_greedy(i, data, masks):
    """Computes a matching greedily by, for each node, adding the maximum neighbor that
    hasn't yet been added to the matching"""

    matching_weight = 0
    A = data[i]
    buyers_mask = masks[i]
    n = int(np.sum(buyers_mask))
    # At the start, all the right hand side values are possible matches
    matching_mask = np.full(A.shape[0], True)

    # for buyer in range(n):
    #     # Checking if there are more elements to match (if more buyers than goods)
    #     if A[buyer, matching_mask].shape[0] != 0:
    #         matching_weight += np.max(A[buyer, matching_mask])
    #         # Recovering the index of the maximum, inspired by http://seanlaw.github.io/2015/09/10/numpy-argmin-with-a-condition/
    #         subset_idx = np.argmax(A[buyer, matching_mask])
    #         good = np.arange(A.shape[1])[matching_mask][subset_idx]
    #         # The corresponding good cannot be used anymore
    #         matching_mask[good] = False

    # Second method of computing a greedy matching
    # Set of vertices already matched
    matching = set()
    # Get the indices of the weights in highest to lowest order (hence the negative sign), inspired by https://stackoverflow.com/questions/30577375/have-numpy-argsort-return-an-array-of-2d-indices
    indices = np.dstack(np.unravel_index(np.argsort(- A.ravel()), A.shape))
    for index in indices[0]:
        if index[0] not in matching and index[1] not in matching:
            matching_weight += A[tuple(index)]
            matching.add(index[0])
            matching.add(index[1])


    return matching_weight




def compute_greedy_matching_weight(i, data, masks, matching, match_rest = False):
    matching_weight = 0
    A = data[i]
    buyers_mask = masks[i]
    n = int(np.sum(buyers_mask))
    goods_mask = 1 - buyers_mask
    m = int(np.sum(goods_mask))

    # Only consider the matching values for consumers
    matching = np.where(goods_mask == 1, matching, -1)
    unmatched_goods = set(range(n, n + m))
    unmatched_buyers = set(range(n))

    for buyer in range(n):
        if buyer in matching:
            # If several goods point to the same buyer, keep the one with maximum weight
            mask = matching == buyer
            matching_weight += np.max(A[buyer, mask])
            # Recovering the index of the maximum, inspired by http://seanlaw.github.io/2015/09/10/numpy-argmin-with-a-condition/
            subset_idx = np.argmax(A[buyer, mask])
            good = np.arange(A.shape[1])[mask][subset_idx]
            unmatched_goods.remove(good)
            unmatched_buyers.remove(buyer)

    if match_rest and len(unmatched_goods) > 0 and len(unmatched_buyers) > 0:
        # Compute optimal matching on the remaining unmatched nodes
        mask = np.zeros(A.shape)
        # TODO this is a horrible solution, there's definitely a prettier solution
        mask[list(unmatched_buyers)] += 1
        mask[:, list(unmatched_goods)] += 1
        mask = np.where(mask == 2, True, False)
        remaining_bipartite_graph = A * mask
        row_ind, col_ind = linear_sum_assignment(remaining_bipartite_graph, maximize = True)
        opt = A[row_ind, col_ind].sum() / 2  #TODO do I always need the division by 2
        matching_weight += opt

    return matching_weight


def compute_partial_matching_weight(i, data, masks, matching):
    # Matching is expected to be a (n+m)x(n+m) matrix where each row sums to 1 (weights assigned to other nodes)

    matching_weight = 0
    A = data[i]
    buyers_mask = masks[i]
    n = int(np.sum(buyers_mask))
    goods_mask = 1 - buyers_mask
    m = int(np.sum(goods_mask))

    # We only care about the buyer -> good connections
    A_submatrix = A[:n, n:n + m]
    matching = matching[:n, n:n + m]

    max_weight = np.max(np.sum(matching, axis = 0))
    print(f"max weight: {max_weight}")
    matching /= max_weight
    return np.sum(matching * A_submatrix)

In [19]:
test_feedback = next(test_sampler)
predictions, _ = model.predict(rng_key, test_feedback.features)
matching_value(test_feedback, predictions, partial = False, match_rest = False, opt_scipy = True)


opt: 15.615172295331325, partial: 6.095603697533447
Naive greedy: 15.289470167028691
opt: 16.73994838212646, partial: 6.7020031255124035
Naive greedy: 16.327635592598387
opt: 15.615172295331325, partial: 6.095603697533447
Naive greedy: 15.289470167028691
opt: 18.09627159980599, partial: 9.653710925139084
Naive greedy: 17.956110937943848
opt: 16.82287331299731, partial: 7.005816185190747
Naive greedy: 16.58330043512991
opt: 17.007992729049086, partial: 8.577196068462275
Naive greedy: 16.515635135059178
opt: 17.420226410794875, partial: 8.043418746651392
Naive greedy: 16.876647551362748
opt: 15.741199280067438, partial: 8.24723344632167
Naive greedy: 15.582544100683762
opt: 16.93222620857793, partial: 7.970584109631478
Naive greedy: 16.70145328592178
opt: 16.93222620857793, partial: 7.970584109631478
Naive greedy: 16.70145328592178
opt: 16.58016942204692, partial: 6.397045235540737
Naive greedy: 16.311817952841317
opt: 15.94204068675385, partial: 7.048934657517263
Naive greedy: 15.684822

(0.4369315977049094, 0.9815892218661197)

In [None]:
import copy

def variation_testing(train_sampler_spec, test_sampler_spec, epochs = 300, model = None, bypass_training = False):
    if model is None and bypass_training:
        print("Need a model to bypass training")
        return


    matching_values = []
    for train_param, test_param in zip(train_sampler_spec, test_sampler_spec):
        test_param['num_samples'] = 40
        test_param['batch_size'] = 40
        schematics = test_param['schematics']
        schematics[0]['length'] = 64
        test_param['schematics'] = schematics

        test_sampler, _ = samplers(test_param)

        if not bypass_training:
            train_sampler, spec = samplers(train_param)
            model = define_model(spec, train_sampler, model="mpnn")
            train(model, epochs, train_sampler, test_sampler)
        else:
            print("Bypassing training")

        test_feedback = next(test_sampler)
        predictions, _ = model.predict(rng_key, test_feedback.features)
        accuracy = matching_value(test_feedback, predictions, partial = False, match_rest = False, opt_scipy = True)

        matching_values.append((train_param, test_param, accuracy))
    return model, matching_values

weight_params = [{"low": 0, "high": 0.001},
                 {"low": 1, "high": 1.001},
                 {"low": 1, "high": 1.1},
                 {"low": 1, "high": 2},
                 {"low": 0, "high": 0.1},
                 {"low": 0, "high": 1},
                 # {"low": 0, "high": 10},
                 # {"low": 0, "high": 100},
                 # {"low": 50, "high": 200},
                 # {"low": 500, "high": 2000},
                 # {"low": 5000, "high": 20000}
                 ]


train_sampler_spec = [
    {
        'num_samples': 100, 'batch_size': 32,
        'schematics': [
            {
                'generator': 'ER',
                'proportion': 1,
                'length': 8,
                'kwargs': {'low': 0, 'high': 0.001, 'weighted': True}
            }
        ]
    },
    {
        'num_samples': 100, 'batch_size': 32,
        'schematics': [
            {
                'generator': 'ER',
                'proportion': 1,
                'length': 8,
                'kwargs': {'low': 0, 'high': 0.01, 'weighted': True}
            }
        ]
    },
    {
        'num_samples': 100, 'batch_size': 32,
        'schematics': [
            {
                'generator': 'ER',
                'proportion': 1,
                'length': 8,
                'kwargs': {'low': 0, 'high': 0.1, 'weighted': True}
            }
        ]
    },
    {
        'num_samples': 100, 'batch_size': 32,
        'schematics': [
            {
                'generator': 'ER',
                'proportion': 1,
                'length': 8,
                'kwargs': {'low': 0, 'high': 1, 'weighted': True}
            }
        ]
    },
    {
        'num_samples': 100, 'batch_size': 32,
        'schematics': [
            {
                'generator': 'ER',
                'proportion': 1,
                'length': 8,
                'kwargs': {'low': 0, 'high': 10, 'weighted': True}
            }
        ]
    },
    {
        'num_samples': 100, 'batch_size': 32,
        'schematics': [
            {
                'generator': 'ER',
                'proportion': 1,
                'length': 8,
                'kwargs': {'low': 0, 'high': 100, 'weighted': True}
            }
        ]
    },
]



length_training = [{"generator": "ER"}]
length_testing = [{"generator": "ER", "length": 1000, "p": 0.01}]




model, results = variation_testing(train_sampler_spec, copy.deepcopy(train_sampler_spec), model = model, bypass_training = False)

results




step = 0 | loss = 5.587885856628418 | val_acc = 0.11328125 | test_acc = 0.021484375
step = 10 | loss = 2.453561782836914 | val_acc = 0.36328125 | test_acc = 0.06562500447034836
step = 20 | loss = 1.0959182977676392 | val_acc = 0.13671875 | test_acc = 0.05976562574505806
step = 30 | loss = 0.6259484887123108 | val_acc = 0.359375 | test_acc = 0.07109375298023224
step = 40 | loss = 0.33939117193222046 | val_acc = 0.05078125 | test_acc = 0.05390625074505806
step = 50 | loss = 0.2931690514087677 | val_acc = 0.1015625 | test_acc = 0.037109375
step = 60 | loss = 0.2331790030002594 | val_acc = 0.625 | test_acc = 0.03398437425494194
step = 70 | loss = 0.170087069272995 | val_acc = 0.6640625 | test_acc = 0.2925781309604645
step = 80 | loss = 0.13126277923583984 | val_acc = 0.7265625 | test_acc = 0.482421875
step = 90 | loss = 0.13931426405906677 | val_acc = 0.7421875 | test_acc = 0.4632812440395355
step = 100 | loss = 0.11292165517807007 | val_acc = 0.6953125 | test_acc = 0.48945313692092896
ste

In [16]:
# import copy
# model2 = copy.deepcopy(model)

ER p=0.25, 100 8x8 train and 40 32x32 test => 0.94 in 100 iterations

BA param=3, 100 8x8 train and 40 32x32 test => 0.97 in 100 iterations

BA param=5, 100 8x8 train and 40 32x32 test => 0.95 in 100 iterations (0.951 in 200 so has pretty much converged after 100)

BA param=7, 100 8x8 train and 40 32x32 test => 0.946 in 100 iterations

#### Cross training
BA param=7 to BA param=3

BA param=7 to ER p=0.25 0.946 with BA to 0.939 with ER (same as if trained only on BA)

ER p=0.25 to BA param=3 went from 0.939 with ER to 0.967 with BA (BA param 3 was 0.97 so basically nothing lost)

#### Weight variations
Uniform
* 0,0.001 -> 0.928
* 1,1.001 -> 0.962
* 0,0.1 -> 0.931
* 0,10 -> 0.883
* 0,100 -> 0.77
* 50, 200 -> 0.72
* 500, 2000 -> 0.69
* 5000, 20000 -> 0.7

Normal:
Basically same.

Gumbel
* 0,0.001 -> 0.323
* 1,1.001 -> 0.849
* 0,0.1 -> 0.498
* 5,10 -> 0.82
* 5,100 -> 0.8

#### Weight cross training
Train ER p=0.25 unif 0,1:
* 0,0.001 -> 0.948
* 1,1.001 -> 0.967
* 0,0.1 -> 0.916
* 0,10 -> 0.86
* 0,100 -> 0.75
* 50, 200 -> 0.72
* 500, 2000 -> 0.72
* 5000, 20000 -> 0.69

=> Seems to weight generalize quite well. Actually even better because basically no statistical difference with if we trained separately.

Train normal 5000, 20000:
* 0, 0.001 -> 0.39 (maybe it's the large to small that was a problem here? Also those values make little sense for a normal RV)

Other direction train normal 0, 0.001 (got to 0.78):
* 5000, 20000 -> 0.76

=> small to large seems better


#### Larger graphs
Same training
ER p=0.25 8x8 train:
* 100x100 test goes to 0.88
* 200x200 goes to 0.63 (only 12 prediction mismatches though)
* 200x200 p=0.3 =>
* 250x250 => 0.9448 (BUT p=0.1 to not kill my computer)
*
Try this but 16x16 train

#### RIDESHARE
8x8 train,
* 32x32 test => 0.96
* 50x50 test => 0.96
* 100x100 test => 0.938
* 250x250 test => 0.9

#### Double max
8x8 train 32x23 test
300 iterations gets us to 0.93 as normal max (though normal max takes 100 iterations to get there), 600 iterations gets us to 0.965
==> Testing single max on 600 iterations => 0.956
==> Testing single max with 64 hidden dim embeddings on 600 iterations 0.96 (already in 200) (seeing if gain is only from more parameters or if double max is actually more aligned)

Conclusion, it was mainly due to more iterations + some amount of more parameters but only 1% so probably not statistically significant.

#### Training with scaling
Train/test with 5000, 200000 weights ==> 0.76 accuracy
But if normalize 0, 1 on training (or just train on normalized) ==> 0.91 (same acc as had train/testing on normalized)

#### More weight scales training
300 epochs for all, 4x4 train, 32x32 test
* 0, 1: 0.946, 0.917
* 1, 1.01: 0.993, 0.970
* 1, 1.001: 0.956, 0.976
* 1, 1.1: 0.989, 0.972
* 1, 1.2: 0.986, 0.958
* 1, 1.5: 0.966, 0.949
* 1, 2: 0.957, 0.939
* 2, 2.1: 0.9927, 0.9757
* 10, 10.001: 0.4, 0.97
Realization: shifting just doesn't makes sense (val + 1000) / (opt + 1000) > val / opt


## Preliminary results
random permutation/matching: 0.18

MPNN:
learned predictions: 0.67

GAT:
learned predictions: 0.72

Got better with double ended predictions

Partial: 0.64 while greedy was doing about 0.92 on the same instance. Main reason seems to be that max weight is around 1.5 => can get at most 2/3 OPT


### Counting the number of matching constraints violated

In [27]:
# For two-way
def count_mismatches_two_way(predictions):
    count = 0
    data = predictions["match"].data
    nb_graphs = data.shape[0]
    for datapoint in range(data.shape[0]):
        for i in range(32):
            owner = data[datapoint][i]
            good = data[datapoint][int(owner)]
            if good != i:
                count += 1
    print(f"average number of edges contradicting matching: {count / nb_graphs}")

average number of edges contradicting matching: 12.2


In [17]:
# For self-loops
def count_mismatches_self_loop(predictions):
    count = 0
    data = predictions["match"].data
    nb_graphs = data.shape[0]
    for datapoint in range(data.shape[0]):
        owners = set(np.array(data[datapoint][32:64]))
        count += 32 - len(owners)
    print(f"average number of edges contradicting matching: {count / nb_graphs}")


average number of edges contradicting matching: 0.4


In [14]:
a = np.array([1, 2])
b = np.array([2, 3])
print(np.concatenate((a, b)))

[1 2 2 3]
