In [95]:
import copy

import jax
import clrs
import numpy as np

%load_ext autoreload
%autoreload 2

rng = np.random.RandomState(1234)
rng_key = jax.random.PRNGKey(rng.randint(2 ** 32))


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [96]:
# If you don't want BipartiteMatching, just pass empty generator list and
# length separately

train_sampler_spec = {
    'num_samples': 100,
    'batch_size':  32,
    'schematics':  [
        {
            'generator':  'ER',
            'proportion': 0,
            'length':     16,
            'kwargs':     {'p': 0.8, 'low': 0, 'high': 1, 'weighted': True},
            'online':     True
        },
        {
            'generator':  'ER',
            'proportion': 1,
            'length':     3,
            'length_2':   15,
            'kwargs':     {'p': 0.8, 'low': 0, 'high': 1, 'weighted': True},
            'online':     True
        },
        {
            'generator':  'ER',
            'proportion': 0,
            'length':     10,
            'length_2':   50,
            'kwargs':     {'p': 0.1, 'low': 0, 'high': 1, 'weighted': True},
            'online':     True
        },
    ]
}

test_sampler_spec = {
    'num_samples': 40,
    'batch_size':  40,
    'schematics':  [
        {
            'generator':  'ER',
            'proportion': 0,
            'length':     100,
            'kwargs':     {'p': 0.05, 'low': 0, 'high': 1, 'weighted': True},
            'online':     True
        },
        {
            'generator':  'ER',
            'proportion': 0,
            'length':     16,
            'kwargs':     {'p': 0.8, 'low': 0, 'high': 1, 'weighted': True},
            'online':     True
        },
        {
            'generator':  'ER',
            'proportion': 1,
            'length':     3,
            'length_2':   15,
            'kwargs':     {'p': 0.8, 'low': 0, 'high': 1, 'weighted': True},
            'online':     True
        },
    ]
}


def samplers(sampler_spec, **kwargs):
    batch_size = sampler_spec.get('batch_size', 1)
    num_samples = sampler_spec['num_samples']
    if batch_size > num_samples:
        batch_size = num_samples

    def _iterate_sampler(sampler, batch_size):
        while True:
            yield sampler.next(batch_size)

    sampler, spec = clrs.build_sampler(
        name = 'online_bipartite_matching',
        sampler_spec = sampler_spec,
        **kwargs)  # number of nodes

    sampler = _iterate_sampler(sampler, batch_size = batch_size)
    return sampler, spec


train_sampler, spec = samplers(train_sampler_spec)
test_sampler, _ = samplers(test_sampler_spec)

119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119
119


In [97]:
sample = next(test_sampler)
sample.features.inputs[1].data[0].shape

(19, 19)

In [98]:
def define_model(spec, train_sampler, model = "mpnn"):
    if model == "mpnn":
        processor_factory = clrs.get_processor_factory('mpnn', use_ln = True,
                                                       nb_triplet_fts = 4)  #use_ln => use layer norm
    elif model == "gat":
        processor_factory = clrs.get_processor_factory('gat', use_ln = True, nb_heads = 4, nb_triplet_fts = 4)

    elif model == "mpnndoublemax":
        processor_factory = clrs.get_processor_factory('mpnndoublemax', use_ln = True,
                                                       nb_triplet_fts = 0)  #use_ln => use layer norm

    elif model == "gmpnn":
        processor_factory = clrs.get_processor_factory('gmpnn', use_ln = True,
                                                       nb_triplet_fts = 4)  #use_ln => use layer norm
    elif model == "pgn":
        processor_factory = clrs.get_processor_factory('pgn', use_ln = True,
                                                       nb_triplet_fts = 32)  #use_ln => use layer norm
    elif model == "triplet_pgn_mask":
        processor_factory = clrs.get_processor_factory('triplet_pgn_mask', use_ln = True,
                                                       nb_triplet_fts = 32)  #use_ln => use layer norm

    model_params = dict(
        processor_factory = processor_factory,  # contains the processor_factory
        hidden_dim = 32,  # TODO put back to 32 if no difference
        encode_hints = True,
        decode_hints = True,
        #decode_diffs=False,
        #hint_teacher_forcing_noise=1.0,
        hint_teacher_forcing = 0.5,
        use_lstm = False,
        learning_rate = 0.001,
        checkpoint_path = '/tmp/checkpt',
        freeze_processor = False,  # Good for post step
        dropout_prob = 0.5,
        # nb_msg_passing_steps=3,
    )

    dummy_trajectory = next(train_sampler)  # jax needs a trajectory that is plausible looking to init

    model = clrs.models.BaselineModel(
        spec = spec,
        dummy_trajectory = dummy_trajectory,
        **model_params
    )

    model.init(dummy_trajectory.features, 1234)  # 1234 is a random seed

    return model


model = define_model(spec, train_sampler, "mpnn")

In [99]:
# No evaluation since we are postprocessing with soft: TO CHANGE -> baselines.py line 336 outs change hard to False
# step = 0
#
# while step <= 1:
#     feedback, test_feedback = next(train_sampler), next(test_sampler)
#     rng_key, new_rng_key = jax.random.split(rng_key) # jax needs new random seed at step
#     cur_loss = model.feedback(rng_key, feedback) # loss is contained in model somewhere
#     rng_key = new_rng_key
#     if step % 10 == 0:
#         print(step)
#     step += 1



In [100]:
def train(model, epochs, train_sampler, test_sampler):
    step = 0
    rng_key = jax.random.PRNGKey(rng.randint(2 ** 32))

    while step <= epochs:
        feedback, test_feedback = next(train_sampler), next(test_sampler)

        rng_key, new_rng_key = jax.random.split(rng_key)  # jax needs new random seed at step
        cur_loss = model.feedback(rng_key, feedback)  # loss is contained in model somewhere
        rng_key = new_rng_key
        if step % 10 == 0:
            predictions_val, _ = model.predict(rng_key, feedback.features)
            out_val = clrs.evaluate(feedback.outputs, predictions_val)
            predictions, _ = model.predict(rng_key, test_feedback.features)
            out = clrs.evaluate(test_feedback.outputs, predictions)
            print(
                f'step = {step} | loss = {cur_loss} | val_acc = {out_val["score"]} | test_acc = {out["score"]}')  # here, val accuracy is actually training accuracy, not great but is example
        # if step % 150 == 0:
        # learned, greedy = matching_value(test_feedback, predictions, partial = False, match_rest = False, opt_scipy = True)
        # print(f"**learned: {learned}, greedy: {greedy}**")
        step += 1
    return model

In [133]:
model = train(model, 300, train_sampler, test_sampler)

step = 0 | loss = 2.5965096950531006 | val_acc = 0.8667763471603394 | test_acc = 0.8447368741035461
step = 10 | loss = 2.49609375 | val_acc = 0.8996710777282715 | test_acc = 0.8565790057182312
step = 20 | loss = 2.5253994464874268 | val_acc = 0.8782894611358643 | test_acc = 0.8302631974220276
step = 30 | loss = 2.554109811782837 | val_acc = 0.8585526347160339 | test_acc = 0.8381579518318176
step = 40 | loss = 2.468383312225342 | val_acc = 0.9029605388641357 | test_acc = 0.8565790057182312
step = 50 | loss = 2.469703435897827 | val_acc = 0.9013158082962036 | test_acc = 0.8684210777282715
step = 60 | loss = 2.366101026535034 | val_acc = 0.8815789818763733 | test_acc = 0.8460526466369629
step = 70 | loss = 2.2869222164154053 | val_acc = 0.8832237124443054 | test_acc = 0.8697368502616882
step = 80 | loss = 2.364321231842041 | val_acc = 0.8634868264198303 | test_acc = 0.8815789818763733
step = 90 | loss = 2.220386028289795 | val_acc = 0.8914473652839661 | test_acc = 0.8315789699554443
step 

In [134]:
def matching_value(samples, predicted_hints, predicted_outputs, predict_outputs = False, predict_match_h = False):
    features = samples.features
    gt_matchings = samples.outputs[0].data
    predicted_matchings = predicted_outputs['match'].data
    # inputs for the matrix A are at index 1 (see spec.py)
    matrices = features.inputs[1].data
    masks = features.inputs[4].data
    current_nodes = features.hints[3].data
    # If there is no matching of weight > 0, don't count it
    count_non_zero_matchings = 0
    pred_accuracy = 0
    greedy_accuracy = 0

    # Iterating over all the samples
    for i in range(matrices.shape[0]):
        max_weight = compute_matching_weight_from_matching(matrices[i], masks[i], gt_matchings[i], predicted = False)

        if max_weight > 0:
            if predict_outputs:
                preds_weight = compute_matching_weight_from_matching(matrices[i], masks[i], predicted_matchings[i], predicted = True)
            elif predict_match_h:
                preds_weight = compute_matching_weight_from_match_h(i, matrices[i], masks[i], predicted_hints, current_nodes)
            else:
                preds_weight = compute_hint_matching_weight(i, matrices[i], masks[i], predicted_hints, current_nodes)



            greedy_weight = compute_greedy_matching_weight(i, matrices[i], masks[i], current_nodes, random_match = False)

            print(f"max weight: {max_weight} predicted weight: {preds_weight}, greedy weight: {greedy_weight}")

            # assert preds_weight <= max_weight
            # assert greedy_weight <= max_weight

            greedy_accuracy += greedy_weight / max_weight
            pred_accuracy += preds_weight / max_weight

            count_non_zero_matchings += 1



    return pred_accuracy / count_non_zero_matchings, greedy_accuracy / count_non_zero_matchings


def compute_matching_weight_from_matching(A, mask, matching, predicted = False):
    matching_weight = 0
    matched = set()
    # m is the number of offline nodes, -1 to not count the no match node
    m = int(np.sum(mask)) - 1
    # n is the number of online nodes
    online_mask = 1 - mask
    n = int(np.sum(online_mask))
    unmatched_node = m + n

    #TODO remove, is not great but here for eval, need 0-valued edges to the unmatched node
    A[unmatched_node, :] = 0
    A[:, unmatched_node] = 0


    for online_node in range(m, m+n):
        match = int(matching[online_node])
        if match != unmatched_node:
            # If points to self => weight is 0
            if online_node != match and match not in matched:
                matching_weight += A[online_node, match]
            if not predicted:
                # Checking that a same offline node is not assigned twice if using opt
                assert match not in matched
            matched.add(match)

    # print(f"opt matching: {matching}")

    return matching_weight


def compute_matching_weight_from_match_h(i, A, mask, predicted_hints, current_nodes):
    matching_weight = 0
    # m is the number of offline nodes, -1 to not count the no match node
    m = int(np.sum(mask)) - 1
    # n is the number of online nodes
    online_mask = 1 - mask
    n = int(np.sum(online_mask))
    unmatched_node = m + n
    matched = np.ones(A.shape[0])
    matched[m:m+n] = 0

    matching = np.arange(m+n+1)

    #TODO remove, is not great but here for eval, need 0-valued edges to the unmatched node
    A[unmatched_node, :] = 0
    A[:, unmatched_node] = 0

    for iter, iter_hint in enumerate(predicted_hints):
        hint = iter_hint["match_h"][i]
        match = np.argmax(hint)
        print(f"match predicted: {match}")
        online_node = np.argmax(current_nodes[iter, i])
        if np.sum(current_nodes[iter, i]) != 0 and  A[online_node, match] != 0:
            # If not the case, we have reached the end of the hints for this instance (there can be more if other instances require more hints) but then current_nodes[iter, i] will be the all 0s vector
            # Also if  A[online_node, greedy_match] is 0, we can equivalently say that the node isn't matched (and it avoids using up a match for nothing). Due to sparsity
            if online_node != match:
                matching_weight += A[online_node, match]
            if match != unmatched_node:
                # Checking that a same offline node is not assigned twice
                assert matched[match] == 1
                matched[match] = 0
                matching[match] = online_node
            matching[online_node] = match
        # TODO remove Checking if the model isn't matching with a stupid node
        # assert match < m


    # print(f"hint matching: {matching}")
    return matching_weight




def compute_hint_matching_weight(i, A, mask, predicted_hints, current_nodes):
    matching_weight = 0
    # m is the number of offline nodes, -1 to not count the no match node
    m = int(np.sum(mask)) - 1
    # n is the number of online nodes
    online_mask = 1 - mask
    n = int(np.sum(online_mask))
    unmatched_node = m + n
    matched = np.ones(A.shape[0])
    matched[m:m+n] = 0

    matching = np.arange(m+n+1)

    #TODO remove, is not great but here for eval, need 0-valued edges to the unmatched node
    A[unmatched_node, :] = 0
    A[:, unmatched_node] = 0

    for iter, iter_hint in enumerate(predicted_hints):
        hint = iter_hint["value_to_go_h"][i]
        # Shifting hints to be positive so that the maximum is always > 0 if there are still edges to match
        shifted_hint = hint + abs(np.min(hint))
        # print(shifted_hint / np.max(shifted_hint))
        masked_hint = shifted_hint * matched
        match = np.argmax(masked_hint)
        # current_nodes[iter, i] is a one-hot encoding of the current online node
        online_node = np.argmax(current_nodes[iter, i])
        if np.sum(current_nodes[iter, i]) != 0 and  A[online_node, match] != 0:
            # If not the case, we have reached the end of the hints for this instance (there can be more if other instances require more hints) but then current_nodes[iter, i] will be the all 0s vector
            # Also if  A[online_node, greedy_match] is 0, we can equivalently say that the node isn't matched (and it avoids using up a match for nothing). Due to sparsity
            if online_node != match:
                matching_weight += A[online_node, match]
            if match != unmatched_node:
                # Checking that a same offline node is not assigned twice
                assert matched[match] == 1
                matched[match] = 0
                matching[match] = online_node
            matching[online_node] = match
        # TODO remove Checking if the model isn't matching with a stupid node
        # assert match < m


    # print(f"hint matching: {matching}")

    return matching_weight


def compute_greedy_matching_weight(i, A, mask, current_nodes, random_match = False):
    matching_weight = 0
    matched = np.ones(A.shape[0])
    # m is the number of offline nodes, -1 to not count the no match node
    m = int(np.sum(mask)) - 1
    # n is the number of online nodes
    online_mask = 1 - mask
    n = int(np.sum(online_mask))
    unmatched_node = m + n
    matched[m:m+n] = 0

    matching = np.arange(m+n+1)

    #TODO remove, is not great but here for eval, need 0-valued edges to the unmatched node
    A[unmatched_node, :] = 0
    A[:, unmatched_node] = 0

    # current_nodes shape is iteration x samples x nodes
    for iter in range(current_nodes.shape[0]):
        # current_nodes[iter, i] is a one-hot encoding of the current online node
        online_node = np.argmax(current_nodes[iter, i])
        possible_matches = A[online_node, :] * matched

        if random_match:
            choices = np.arange(m+n+1)[mask == 1]
            greedy_match = np.random.choice(choices)
        else:
            greedy_match = np.argmax(possible_matches)


        if np.sum(current_nodes[iter, i]) != 0 and possible_matches[greedy_match] != 0:
            # If not the case, we have reached the end of the hints for this instance (there can be more if other instances require more hints) but then current_nodes[iter, i] will be the all 0s vector
            # Also if possible_matches[greedy_match] is 0, we can equivalently say that the node isn't matched (and it avoids using up a match for nothing). Due to sparsity
            if online_node != greedy_match:
                matching_weight += A[online_node, greedy_match]
            if greedy_match != unmatched_node:
                # Checking that a same offline node is not assigned twice
                assert matched[greedy_match] == 1
                matched[greedy_match] = 0
                matching[greedy_match] = online_node
            matching[online_node] = greedy_match
            # TODO remove Checking if doesn't do something stupid (this should have weight 0)
            assert greedy_match < m

    # print(f"greedy matching: {matching}")

    return matching_weight

In [135]:
test_feedback = next(test_sampler)


In [136]:
predictions, hints = model.predict(rng_key, test_feedback.features, return_hints = True)
matching_value(test_feedback, hints, predictions, predict_match_h = True)


match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
max weight: 1.7282347637124968 predicted weight: 0, greedy weight: 1.5924672515305591
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
max weight: 2.8167674621222405 predicted weight: 0, greedy weight: 1.6082363686321732
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
max weight: 2.516411003995997 predicted weight: 0, greedy weight: 1.8535205493567268
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
match predicted: 18
max weight: 1.4370672667245779 predicted weight: 0, greedy weight: 1.7110854265254556
match predicted: 

(0.0, 0.7102440164429182)

In [108]:
predictions['match'].data[0]

Array([10.,  5.,  3.,  2., 18.,  1.,  6.,  7., 18.,  0.,  0., 11., 12.,
       13., 14., 15., 16., 17., 18.], dtype=float32)

In [110]:
hints[0]['match_h'][0]

Array([ 5.256347  ,  5.3603277 ,  2.412511  ,  0.76202863,  0.32679403,
        0.4426093 , -0.7362897 , -0.73782516,  0.33354574,  0.37478045,
        0.35014975, -0.72401756, -0.69578147, -0.654595  , -0.66573054,
       -0.57426596, -0.55346304, -0.56289846,  6.020763  ], dtype=float32)

In [115]:
test_feedback.features.hints

[DataPoint(name="value_to_go_h",	location=node,	type=scalar,	data=Array(9, 40, 19)),
 DataPoint(name="L_h",	location=node,	type=mask,	data=Array(9, 40, 19)),
 DataPoint(name="match_h",	location=node,	type=mask_one,	data=Array(9, 40, 19)),
 DataPoint(name="modified_node",	location=node,	type=mask_one,	data=Array(9, 40, 19))]

In [128]:
iter = 3
test_feedback.features.hints[0].data[iter, 0]

array([0.27700264, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.94055817])

In [129]:
test_feedback.features.hints[2].data[iter, 0]

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 1.])

In [109]:
np.sum(mat[8:16] == 0)

75

In [110]:

np.sum(mat[8:16] != 0)


61

In [111]:
test_feedback.features.inputs[1].data[3]

array([[0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.89949414, 0.98475698,
        0.60281735, 0.72707501, 0.1797913 , 0.5798067 , 0.29400526,
        0.        , 1.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.32709095, 0.64599815,
        0.46296552, 0.68987385, 0.33884224, 0.        , 0.84047621,
        0.59821317, 1.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.        , 0.62142034,
        0.26787452, 0.97212067, 0.54769404, 0.73009603, 0.        ,
        0.        , 1.        ],
       [0.        , 0.        , 0.        , 0.        , 0.        ,
        0.        , 0.        , 0.        , 0.90102657, 0.        ,
        0.98141306, 0.55263826, 0.33889225, 0.17873712, 0.        ,
        0.62468549, 1.        ],
       [0.        , 0.        , 0.        , 0.      

In [38]:
test_feedback.features.hints

[DataPoint(name="value_to_go_h",	location=node,	type=scalar,	data=Array(7, 40, 17)),
 DataPoint(name="L_h",	location=node,	type=mask,	data=Array(7, 40, 17)),
 DataPoint(name="match_h",	location=node,	type=pointer,	data=Array(7, 40, 17)),
 DataPoint(name="modified_node",	location=node,	type=mask_one,	data=Array(7, 40, 17))]

In [45]:
hints[5]["modified_node"][0]

Array([-2.4598014, -2.433341 , -2.343672 , -2.4668386, -2.469148 ,
       -2.4711394, -2.472333 , -2.4693463, -1.9480118, -1.7099818,
       -1.8171515, -2.936937 , -2.936448 , -2.9389527, -2.9396846,
       -2.0308576, -2.378499 ], dtype=float32)

In [21]:
import copy


def variation_testing(train_sampler_spec, test_sampler_spec, epochs = 300, model = None, bypass_training = False):
    if model is None and bypass_training:
        print("Need a model to bypass training")
        return

    matching_values = []
    for train_param, test_param in zip(train_sampler_spec, test_sampler_spec):
        # test_param['num_samples'] = 40
        # test_param['batch_size'] = 40
        # schematics = test_param['schematics']
        # schematics[0]['length'] = 1000
        # test_param['schematics'] = schematics

        print("starting test generation")
        test_sampler, _ = samplers(test_param)
        print("finished test generation")

        if not bypass_training:
            train_sampler, spec = samplers(train_param)
            model = define_model(spec, train_sampler, model = "mpnn")
            train(model, epochs, train_sampler, test_sampler)
        else:
            print("Bypassing training")

        test_feedback = next(test_sampler)
        predictions, _ = model.predict(rng_key, test_feedback.features)
        accuracy = matching_value(test_feedback, predictions, partial = False, match_rest = False, opt_scipy = True)

        matching_values.append((train_param, test_param, accuracy))
    return model, matching_values


weight_params = [{"low": 0, "high": 0.001},
                 {"low": 1, "high": 1.001},
                 {"low": 1, "high": 1.1},
                 {"low": 1, "high": 2},
                 {"low": 0, "high": 0.1},
                 {"low": 0, "high": 1},
                 # {"low": 0, "high": 10},
                 # {"low": 0, "high": 100},
                 # {"low": 50, "high": 200},
                 # {"low": 500, "high": 2000},
                 # {"low": 5000, "high": 20000}
                 ]

train_sampler_spec = [
    {
        'num_samples': 100, 'batch_size': 32,
        'schematics':  [
            {
                'generator':  'ER',
                'proportion': 1,
                'length':     64,
                'kwargs':     {'low': 0, 'high': 0.001, 'weighted': True}
            }
        ]
    },
]

test_sampler = [
    {
        'num_samples': 10, 'batch_size': 10,
        'schematics':  [
            {
                'generator':  'ER',
                'proportion': 1,
                'length':     1000,
                'kwargs':     {'p': 0.1, 'low': 0, 'high': 1, 'weighted': True}
            }
        ]
    },
]

length_training = [{"generator": "ER"}]
length_testing = [{"generator": "ER", "length": 1000, "p": 0.01}]

model, results = variation_testing(train_sampler_spec, copy.deepcopy(test_sampler), model = model,
                                   bypass_training = True)

results




starting test generation
finished test generation
Bypassing training


[({'num_samples': 100,
   'batch_size': 32,
   'schematics': [{'generator': 'ER',
     'proportion': 1,
     'length': 64,
     'kwargs': {'low': 0, 'high': 0.001, 'weighted': True}}]},
  {'num_samples': 10,
   'batch_size': 10,
   'schematics': [{'generator': 'ER',
     'proportion': 1,
     'length': 1000,
     'kwargs': {'p': 0.1, 'low': 0, 'high': 1, 'weighted': True}}]},
  (0.8577524175395072, 0.9537646702986899))]

In [9]:
# import copy
# model2 = copy.deepcopy(model)

ER p=0.25, 100 8x8 train and 40 32x32 test => 0.94 in 100 iterations

BA param=3, 100 8x8 train and 40 32x32 test => 0.97 in 100 iterations

BA param=5, 100 8x8 train and 40 32x32 test => 0.95 in 100 iterations (0.951 in 200 so has pretty much converged after 100)

BA param=7, 100 8x8 train and 40 32x32 test => 0.946 in 100 iterations

#### Cross training
BA param=7 to BA param=3

BA param=7 to ER p=0.25 0.946 with BA to 0.939 with ER (same as if trained only on BA)

ER p=0.25 to BA param=3 went from 0.939 with ER to 0.967 with BA (BA param 3 was 0.97 so basically nothing lost)

#### Weight variations
Uniform
* 0,0.001 -> 0.928
* 1,1.001 -> 0.962
* 0,0.1 -> 0.931
* 0,10 -> 0.883
* 0,100 -> 0.77
* 50, 200 -> 0.72
* 500, 2000 -> 0.69
* 5000, 20000 -> 0.7

Normal:
Basically same.

Gumbel
* 0,0.001 -> 0.323
* 1,1.001 -> 0.849
* 0,0.1 -> 0.498
* 5,10 -> 0.82
* 5,100 -> 0.8

#### Weight cross training
Train ER p=0.25 unif 0,1:
* 0,0.001 -> 0.948
* 1,1.001 -> 0.967
* 0,0.1 -> 0.916
* 0,10 -> 0.86
* 0,100 -> 0.75
* 50, 200 -> 0.72
* 500, 2000 -> 0.72
* 5000, 20000 -> 0.69

=> Seems to weight generalize quite well. Actually even better because basically no statistical difference with if we trained separately.

Train normal 5000, 20000:
* 0, 0.001 -> 0.39 (maybe it's the large to small that was a problem here? Also those values make little sense for a normal RV)

Other direction train normal 0, 0.001 (got to 0.78):
* 5000, 20000 -> 0.76

=> small to large seems better


#### Larger graphs
Same training
ER p=0.25 8x8 train:
* 100x100 test goes to 0.88
* 200x200 goes to 0.63 (only 12 prediction mismatches though)
* 200x200 p=0.3 =>
* 250x250 => 0.9448 (BUT p=0.1 to not kill my computer)
*
Try this but 16x16 train

#### RIDESHARE
8x8 train,
* 32x32 test => 0.96
* 50x50 test => 0.96
* 100x100 test => 0.938
* 250x250 test => 0.9

#### Double max
8x8 train 32x23 test
300 iterations gets us to 0.93 as normal max (though normal max takes 100 iterations to get there), 600 iterations gets us to 0.965
==> Testing single max on 600 iterations => 0.956
==> Testing single max with 64 hidden dim embeddings on 600 iterations 0.96 (already in 200) (seeing if gain is only from more parameters or if double max is actually more aligned)

Conclusion, it was mainly due to more iterations + some amount of more parameters but only 1% so probably not statistically significant.

#### Training with scaling
Train/test with 5000, 200000 weights ==> 0.76 accuracy
But if normalize 0, 1 on training (or just train on normalized) ==> 0.91 (same acc as had train/testing on normalized)

#### More weight scales training
300 epochs for all, 4x4 train, 32x32 test
* 0, 1: 0.946, 0.917
* 1, 1.01: 0.993, 0.970
* 1, 1.001: 0.956, 0.976
* 1, 1.1: 0.989, 0.972
* 1, 1.2: 0.986, 0.958
* 1, 1.5: 0.966, 0.949
* 1, 2: 0.957, 0.939
* 2, 2.1: 0.9927, 0.9757
* 10, 10.001: 0.4, 0.97
Realization: shifting just doesn't makes sense (val + 1000) / (opt + 1000) > val / opt

==> find the best range
* 0, 0.001: 0.79, 0.929
* 0, 0.01: 0.79, 0.922
* 0, 0.1: 0.76, 0.934
* 0, 1: 0.94, 0.9334
* 0, 10: 0.897, 0.928
* 0, 100: 0.7, 0.93

### Teacher forcing
1.0:  100 epochs => 0.92   | 200 epochs => 0.968
0.75: 100 epochs => 0.956 | 200 epochs => 0.959 | 300 => 0.957 | 400 => 0.966 | 500 => 0.954 | 600 => 0.964
0.5: 100 epochs => 0.943 | 200 => 0.943 | 300 => 0.952 | 400 => 0.938 | 500 => 0.928 | 0.953
0.25: 100 => 0.923 | 200 => 0.935 | 300 => 0.932 | 400 => 0.936 | 500 => 0.948 | 600 => 0.94
0: 100 => 0.9 | 200 => 0.977 | 300 => 0.937 | 400 => 0.924 | 500 => 0.923 | 600 => 0.958 | 800 => 0.98

### GMPNN
MPNN: 100 => 0.92  | 200 => 0.968
GMPNN 100 => 0.956 | 200 => 0.957 | 300 => 0.948 | 400 => 0.935

### Larger graphs
GPMNN with 16 node (8x8) graphs as training
100 => 0.965 | 200 => 0.968 | 300 => 0.969 | 400 => 0.971 | 500 => 0.972 | 600 => 0.971 | 700 => 0.981 | 800 => 0.973

### Train on larger
GPMNN with 16 node (8x8) graphs as training
100 => 0.965 | 200 => 0.968 | 300 => 0.969 | 400 => 0.971 | 500 => 0.972 | 600 => 0.971 | 700 => 0.981 | 800 => 0.973

### Soft pointers


#### Cross training with p value for ER
GMPNN 400 epochs
ER
0.05: 0.87, 0.96
0.1: 0.76, 0.94
0.2: 0.75, 0.91
0.5: 0.93, 0.93
0.75: 0.94, 0.95
1: 0.88, 0.96

Compare to if learned directly: (train on those parameters then test i.e. no cross-training) + is MPNN, not GMPNN
0.05: 0.95, 0.96
0.1: 0.86, 0.93
0.2: 0.7, 0.9
0.5: 0.9, 0.94
0.75: 0.93, 0.95
1: 0.85, 0.95

Already got ER + Rideshare generalization + BA generalization to other parameters to larger graphs


## Preliminary results
random permutation/matching: 0.18

MPNN:
learned predictions: 0.67

GAT:
learned predictions: 0.72

Got better with double ended predictions

Partial: 0.64 while greedy was doing about 0.92 on the same instance. Main reason seems to be that max weight is around 1.5 => can get at most 2/3 OPT


### Counting the number of matching constraints violated

In [27]:
# For two-way
def count_mismatches_two_way(predictions):
    count = 0
    data = predictions["match"].data
    nb_graphs = data.shape[0]
    for datapoint in range(data.shape[0]):
        for i in range(32):
            owner = data[datapoint][i]
            good = data[datapoint][int(owner)]
            if good != i:
                count += 1
    print(f"average number of edges contradicting matching: {count / nb_graphs}")

average number of edges contradicting matching: 12.2


In [17]:
# For self-loops
def count_mismatches_self_loop(predictions):
    count = 0
    data = predictions["match"].data
    nb_graphs = data.shape[0]
    for datapoint in range(data.shape[0]):
        owners = set(np.array(data[datapoint][32:64]))
        count += 32 - len(owners)
    print(f"average number of edges contradicting matching: {count / nb_graphs}")


average number of edges contradicting matching: 0.4


In [14]:
a = np.array([1, 2])
b = np.array([2, 3])
print(np.concatenate((a, b)))

[1 2 2 3]
