In [None]:
import clrs
import numpy as np
import jax
import jax.numpy as jnp

import pprint

# Seed NumPy's legacy generator with a fixed constant, then draw one integer from it
# to seed the JAX PRNG key, so both random streams derive from the same seed (1234).
rng = np.random.RandomState(1234)
rng_key = jax.random.PRNGKey(rng.randint(2**32))

In [None]:
# =================== BIPARTITE GRAPH GENERATOR SPECS =========================
# - Erdos-Renyi (ER)
#     Generates an Erdos-Renyi random graph with edge probability [p]. If
#       [weighted], edge weights are iid uniform([low], [high])
#
#     {'generator': ER, 'weighted': False, 'p': 0.5, 'low': 0.0, 'high': 1.0}
#
# - Barabasi-Albert (BA)
#     Generates a Barabasi-Albert random graph with parameter [ba_param]. If
#       [weighted], edge weights are iid uniform([low], [high])
#
#     {'generator': BA, 'ba_param': 1, 'weighted': False, 'low': 0.0, 'high': 1.0}
#
# - Geometric (GEOMETRIC)
#     Generates a random graph by embedding nodes uniformly over the unit
#       square. Edge weights are the euclidean distance between two nodes,
#       with weights below [threshold] set to 0, then scaled by [scaling].
#
#     {'generator': GEOMETRIC, 'threshold': 0.25, 'scaling': 1.0}
#
# - Flow (FLOW)
#     Generates a random ER reduction to a max flow input, with edge
#       probability [p].
#     {'generator': FLOW, 'p': 0.5}



# Training sampler: a single sample of length 16 for 'simplified_min_sum', built with the
# 'DATASET' generator from the edge file data/gmission_edges.txt.
train_sampler, spec = clrs.build_sampler(
    name='simplified_min_sum',
    num_samples=1,
    length=16,
    generator='DATASET',
    filepath='data/gmission_edges.txt'
    )

# Test sampler: a single weighted sample of length 4 built with the 'ER' generator
# (see the generator specs listed above).
# NOTE: this call rebinds `spec`; presumably both specs are identical because the
# algorithm name is the same — TODO confirm.
test_sampler, spec = clrs.build_sampler(
    name='simplified_min_sum',
    num_samples=1,
    length=4,
    weighted=True,
    generator='ER')

# Show the spec returned by the last build_sampler call.
pprint.pprint(spec)

def _iterate_sampler(sampler, batch_size):
  while True:
    yield sampler.next(batch_size)

# Rebind both names from sampler objects to endless generators yielding batches of size 1.
# After this point `train_sampler` / `test_sampler` must be consumed with next(...).
train_sampler = _iterate_sampler(train_sampler, batch_size=1)
test_sampler = _iterate_sampler(test_sampler, batch_size=1)



In [None]:
# Draw one training batch (this consumes a batch from the generator) and count the
# non-zero entries of one of its inputs.
# NOTE(review): [0][0][1] presumably addresses features.inputs[1], i.e. the matrix A — confirm.
x = next(train_sampler)[0][0][1].data
print(np.sum(x != 0))

In [None]:
# Processor used by the model: an 'mpnn' processor with layer norm and no triplet features.
# The commented-out line below is the 'gat' alternative.
processor_factory = clrs.get_processor_factory('mpnn', use_ln=True, nb_triplet_fts=0) #use_ln => use layer norm
# processor_factory = clrs.get_processor_factory('gat', use_ln=True, nb_heads = 4, nb_triplet_fts = 0)
# Keyword arguments forwarded to clrs.models.BaselineModel below.
model_params = dict(
    processor_factory=processor_factory, # contains the processor_factory
    hidden_dim=64, # TODO put back to 32 if no difference
    encode_hints=True,
    decode_hints=True,
    #decode_diffs=False,
    #hint_teacher_forcing_noise=1.0,
    hint_teacher_forcing=1.0,
    use_lstm=False,
    learning_rate=0.001,
    checkpoint_path='/tmp/checkpt',
    freeze_processor=False, # Good for post step
    dropout_prob=0.5,
    # nb_msg_passing_steps=3,
)

# Note: this consumes one batch from the training generator.
dummy_trajectory = next(train_sampler) # jax needs a trajectory that is plausible looking to init

# `spec` is whatever the most recent build_sampler call returned.
model = clrs.models.BaselineModel(
    spec=spec,
    dummy_trajectory=dummy_trajectory,
    **model_params
)

model.init(dummy_trajectory.features, 1234) # 1234 is a random seed

In [None]:
import copy

# Step counter for the loop below, which runs for steps 0..100 inclusive (101 updates).
step = 0

while step <= 100:
  # Draw a fresh training batch and a fresh test batch on every step.
  feedback, test_feedback = next(train_sampler), next(test_sampler)
  # TODO remove - testing if uses hints on tests
  # shape = test_feedback.features.hints[0].data[0].shape
  # test_feedback.features.hints[0].data = test_feedback.features.hints[0].data[0, :, :].reshape((1, *shape))

  rng_key, new_rng_key = jax.random.split(rng_key) # jax needs new random seed at step
  cur_loss = model.feedback(rng_key, feedback) # loss is contained in model somewhere
  rng_key = new_rng_key
  if step % 10 == 0:
    # Every 10th step: evaluate on the batch just trained on ("val") and on the test batch,
    # using the same rng_key for both predict calls.
    predictions_val, _ = model.predict(rng_key, feedback.features)
    out_val = clrs.evaluate(feedback.outputs, predictions_val)
    predictions, _ = model.predict(rng_key, test_feedback.features)
    out = clrs.evaluate(test_feedback.outputs, predictions)
    print(predictions)
    print(f'step = {step} | loss = {cur_loss} | val_acc = {out_val["score"]} | test_acc = {out["score"]}') # here, val accuracy is actually training accuracy, not great but is example
  step += 1
# Keep a deep copy of the trained model under a second name.
model2 = copy.deepcopy(model)

## Some intermediate results
ALL: 0.5 dropout

MPNN 100 train, 40 test, 100 epochs, self-loops -> loss = 0.9931782484054565 | val_acc = 0.8515625 | test_acc = 0.7484375238418579 | accuracy = 0.65, average nb non-matched: 13.45

MPNN 200 train, 40 test, 100 epochs, self-loops -> loss = 0.8129420280456543 | val_acc = 0.8125 | test_acc = 0.77734375 | accuracy = 0.72, average nb non-matched: 11.125

MPNN 100 train, 40 test, 100 epochs, double links -> loss = 0.8689386248588562 | val_acc = 0.7109375 | test_acc = 0.42695313692092896 | accuracy = ? NOTE: the model only started "learning" in the last epochs, so more epochs are tried next. Interestingly, the loss is lower than with self-loops, but the accuracy is lower too.

MPNN 100 train, 40 test, 200 epochs, double links -> step = 100 | loss = 0.6802611351013184 | val_acc = 0.806640625 | test_acc = 0.681640625 | accuracy = 0.89, average nb non-matched: 5.65

MPNN 300 train, 40 test, 400 epochs, double links -> loss = 0.5485531091690063 | val_acc = 0.775390625 | test_acc = 0.6910156607627869 | accuracy = 0.928, average nb non-matched: 4.075. Note: the best test_acc was 0.727. The test_acc is similar to the 100 train / 200 epochs run, but the accuracy is better; the training accuracy still does not converge, though.

Difference: test graphs of length 100 instead of 64
MPNN 100 train, 40 test LENGTH 100, 200 epochs, double links -> loss = 0.6958761215209961 | val_acc = 0.787109375 | test_acc = 0.503250002861023 | accuracy = 0.759, average nb non-matched: 7.9/100


#### Now with an actually bipartite graph (no owner-owner / good-good edges)
This does not really change the results.

ALL with 0 dropout

#### No hints
0 dropout Can get up to 0.78 of OPT, average nb non-matched: 10.5/64


#### Training with 64 hidden dimensions
MPNN 100 train, 40 test, 0 dropout, double links -> Get to 90% acc in 30 iterations, 93% in 60
GAT 100 train, 40 test, 0 dropout, double links -> Get to 0.78 in 30 iterations, 91% in 60, 92% in 100, 92% in 200

Same with 0.5 dropout
MPNN 100 train, 40 test, 0.5 dropout, double links -> Get to 85% in 30 iterations, 93% in 60 iterations, 93% in 100 iterations
GAT 100 train, 40 test, 0.5 dropout, double links -> 94.7% in 200

#### 3 message passing steps
64 dims, 3 message passing steps
MPNN 100 train, 40 test, 0.5 dropout, double links -> 91.7% in 100 iterations (worse than 1 MP step), 93% in 200 iterations

Back to 1 message passing step

#### Larger MLP
[out_size, out_size, out_size] MLP
MPNN 100 train, 40 test, 0.5 dropout, double links, 3 layer MLP -> loss = 0.6642395257949829 | val_acc = 0.75390625 | test_acc = 0.699999988079071 | 90% in 100 iterations (worse than smaller MLP) | 93% in 200 iterations


In [None]:
from scipy.optimize import linear_sum_assignment


def matching_value(samples, predictions, partial = False, match_rest = False, opt_scipy = False):
    """Average ratio between the predicted matching weight and the reference weight.

    For every sample in the batch, the reference weight is the greedy weight of
    the ground-truth matching and the predicted weight is derived from
    ``predictions["match"]``. Both values are printed for each sample.

    Args:
        samples: feedback batch. Reads ``samples.features.inputs[1].data``
            (the matrix A, see spec.py), ``samples.features.inputs[2].data``
            (passed to the helpers as the buyers mask) and
            ``samples.outputs[0].data`` (ground-truth matchings).
        predictions: mapping with a ``"match"`` entry whose ``.data`` is
            indexed per sample.
        partial: if True, score the prediction with
            ``compute_partial_matching_weight``; otherwise with
            ``compute_greedy_matching_weight``.
        match_rest: forwarded to ``compute_greedy_matching_weight`` (only used
            when ``partial`` is False).
        opt_scipy: currently unused; kept so existing callers keep working.

    Returns:
        The mean over the batch of ``predicted weight / reference weight``.
    """
    features = samples.features
    gt_matchings = samples.outputs[0].data
    # inputs for the matrix A are at index 1 (see spec.py)
    data = features.inputs[1].data
    masks = features.inputs[2].data
    pred_accuracy = 0

    # Iterating over all the samples
    for i in range(data.shape[0]):
        # Reference value: greedy weight of the ground-truth matching.
        max_weight = compute_greedy_matching_weight(i, data, masks, gt_matchings[i])
        predicted_matching = predictions["match"].data[i]

        if partial:
            preds_weight = compute_partial_matching_weight(i, data, masks, predicted_matching)
            # The label names the scoring mode that was actually used.
            print(f"opt: {max_weight}, partial: {preds_weight}")
        else:
            preds_weight = compute_greedy_matching_weight(i, data, masks, predicted_matching, match_rest = match_rest)
            print(f"opt: {max_weight}, greedy: {preds_weight}")

        # Note: preds_weight <= max_weight is deliberately not asserted here.
        pred_accuracy += preds_weight / max_weight

    return pred_accuracy / data.shape[0]


def compute_greedy_matching_weight(i, data, masks, matching, match_rest = False):
    """Weight of the matching induced by the goods' pointers for sample ``i``.

    Only the entries of ``matching`` that belong to goods are considered. When
    several goods point to the same buyer, only the heaviest of those edges is
    kept and the other goods stay unmatched.

    Assumes buyers are nodes ``0..n-1`` and goods are nodes ``n..n+m-1``, where
    ``n`` and ``m`` are the number of ones and zeros in the buyers mask.

    Args:
        i: index of the sample inside ``data`` / ``masks``.
        data: per-sample weight matrices; ``data[i][buyer, good]`` is the edge weight.
        masks: per-sample buyers masks (1 for buyers, 0 for goods).
        matching: 1-D array; for a good ``g``, ``matching[g]`` is the node it points to.
        match_rest: if True, buyers and goods left unmatched by ``matching`` are
            afterwards matched greedily, heaviest remaining edge first; edges
            with non-positive weight are ignored.

    Returns:
        The total weight of the resulting matching (``0`` if nothing is matched).
    """
    matching_weight = 0
    A = np.asarray(data[i])
    buyers_mask = np.asarray(masks[i])
    n = int(np.sum(buyers_mask))
    goods_mask = 1 - buyers_mask
    m = int(np.sum(goods_mask))

    # Only consider the pointers of goods; every other entry is neutralised with -1,
    # which can never equal a buyer index.
    matching = np.where(goods_mask == 1, matching, -1)
    unmatched_goods = set(range(n, n + m))
    unmatched_buyers = set(range(n))

    for buyer in range(n):
        candidate_goods = np.flatnonzero(matching == buyer)
        if candidate_goods.size == 0:
            continue
        # If several goods point to the same buyer, keep the one with maximum weight
        best_good = int(candidate_goods[np.argmax(A[buyer, candidate_goods])])
        matching_weight += A[buyer, best_good]
        unmatched_buyers.discard(buyer)
        unmatched_goods.discard(best_good)

    if match_rest:
        # Greedy completion over the nodes that are still free.
        remaining_edges = sorted(
            ((A[b, g], b, g) for b in sorted(unmatched_buyers) for g in sorted(unmatched_goods)),
            key=lambda edge: -edge[0],
        )
        for weight, b, g in remaining_edges:
            if weight <= 0:
                # Edges are sorted by decreasing weight: nothing useful is left.
                break
            if b in unmatched_buyers and g in unmatched_goods:
                matching_weight += weight
                unmatched_buyers.remove(b)
                unmatched_goods.remove(g)

    return matching_weight


def compute_partial_matching_weight(i, data, masks, matching):
    """Weight of a fractional (partial) matching for sample ``i``.

    ``matching`` is expected to be a (n+m)x(n+m) matrix where each row sums to 1
    (weights assigned to other nodes). Only the buyer -> good block
    ``[:n, n:n+m]`` is used. That block is rescaled by its largest column sum
    and then multiplied element-wise with the corresponding block of edge weights.

    The caller's ``matching`` array is left untouched.

    Args:
        i: index of the sample inside ``data`` / ``masks``.
        data: per-sample weight matrices.
        masks: per-sample buyers masks (1 for buyers, 0 for goods).
        matching: square matrix of fractional assignments for sample ``i``.

    Returns:
        The weighted sum of the rescaled buyer -> good assignments, or ``0.0``
        when the block is empty or contains no assignment mass.
    """
    A = data[i]
    buyers_mask = masks[i]
    n = int(np.sum(buyers_mask))
    goods_mask = 1 - buyers_mask
    m = int(np.sum(goods_mask))

    # We only care about the buyer -> good connections
    A_submatrix = A[:n, n:n + m]
    matching = matching[:n, n:n + m]

    if matching.size == 0:
        # No buyers or no goods: nothing can be matched.
        return 0.0

    max_weight = np.max(np.sum(matching, axis = 0))
    print(f"max weight: {max_weight}")
    if max_weight == 0:
        # Avoid dividing by zero when no good receives any assignment mass.
        return 0.0

    # Out-of-place division: `matching` is a view of the caller's array, and an
    # in-place `/=` would both overwrite it and fail for integer dtypes.
    normalized = matching / max_weight
    return np.sum(normalized * A_submatrix)

In [None]:
# Draw a fresh test batch and run the trained model on its features.
# The second return value of predict is discarded.
test_feedback = next(test_sampler)
predictions, _ = model.predict(rng_key, test_feedback.features)

In [None]:
# Display the predicted 'match' output for the first sample of the batch.
predictions['match'].data[0]

In [None]:
# Display the first ground-truth output for the first sample, for comparison with the prediction.
test_feedback.outputs[0].data[0]

In [None]:
# Score the predictions against the test batch using the partial (fractional) matching weight.
matching_value(test_feedback, predictions, partial = True)

## Preliminary results
random permutation/matching: 0.18

MPNN:
learned predictions: 0.67

GAT:
learned predictions: 0.72

Got better with double ended predictions

Partial: 0.64, while greedy reached about 0.92 on the same instance. The main reason seems to be that the max weight is around 1.5, so the partial matching can get at most 2/3 of OPT.


### Counting the number of matching constraints violated

In [None]:
# For two-way
# For each of the first 32 nodes i, follow its pointer and then the pointer of the node it
# points to; count i as a violation when that second pointer does not lead back to i.
# NOTE(review): the node count 32 is hard-coded and the output key is "owners" (other cells
# read "match") — confirm both against the sampler length and the spec actually in use.
count = 0
data = predictions["owners"].data
nb_graphs = data.shape[0]
for datapoint in range(data.shape[0]):
    for i in range(32):
        owner = data[datapoint][i]
        good = data[datapoint][int(owner)]
        if good != i:
            count += 1
print(f"average number of edges contradicting matching: {count / nb_graphs}")

# For self-loops
# Look at the pointers stored at positions 32..63 and count how many of them are duplicates
# (32 minus the number of distinct targets).
# NOTE(review): the 32:64 slice is hard-coded — confirm it matches the graph size in use.
count = 0
data = predictions["owners"].data
nb_graphs = data.shape[0]
for datapoint in range(data.shape[0]):
    owners = set(np.array(data[datapoint][32:64]))
    count += 32 - len(owners)
print(f"average number of edges contradicting matching: {count / nb_graphs}")


In [None]:
# Re-run prediction on the same test batch, this time also requesting the hints and all outputs.
# This rebinds the global `predictions`.
predictions, hints = model.predict(rng_key, test_feedback.features, return_hints = True, return_all_outputs = True)


In [None]:
# Shape of the 'owners' output after keeping every 100th entry along its first axis.
predictions['owners'].data[::100].shape

In [None]:
# Position of the largest value along the last axis of entry 0 of the 'owners_h' hint at index 10.
# NOTE(review): the original call was left unfinished (`axis =`), which is a syntax error;
# axis=-1 is assumed here — confirm the intended axis.
np.argmax(hints[10]['owners_h'][0], axis=-1)



In [None]:
# Number of entries in the hints returned by model.predict.
len(hints)

In [None]:
# Scratch check: two equivalent ways to select the elements at even positions.
# Only the last expression is displayed by the notebook.
arr = np.array([1, 2, 3, 4, 5, 6])
arr[np.arange(len(arr)) % 2 == 0]
arr[::2]

In [None]:
# Display input 1 of the first test sample (the matrix A according to the comment in matching_value).
test_feedback.features.inputs[1].data[0]

In [None]:
# Scratch check: np.max with axis=0 returns the maximum of each column. Rebinds `arr`.
arr = np.array([[1,2],[2,3]])
np.max(arr, axis = 0)