In [1]:
from JAXEmbeddings import ComplexPreference, train_on_func, evaluate
import jax
import jax.numpy as jnp
from jax import random
from Stackelberg.src.utils.utility_functions import ackley, branin, eggholder, hoelder, matyas, michalewicz, rosenbrock, bukin
from collections import namedtuple
import json
import flax.serialization as serialization

In [2]:
# 1) Define your “function” to optimize (in NumPy). E.g., a simple 2D paraboloid:
def my_paraboloid(X: jnp.ndarray) -> jnp.ndarray:
    # X has shape [N, 2]; return f(x,y) = x^2 + y^2
    return jnp.sum(X**2, axis=1)

In [3]:
# 2) Create a Flax/Linen module instance
in_dim = 2
factor = 2      # must be even
hidden_sizes = [256, 128, 64]
# 3) Create a PRNGKey for initialization and training randomness
key = random.PRNGKey(0)
AffineTransform = namedtuple("AffineTransform", ["affine_transform"])
affine_transform = AffineTransform(jnp.array([1.,0.]))
funcs = {
    "ackley": {
        "func": ackley,
        "bounds": [[-32.768, 32.768],]*2,
        "sizes": [512, 256, 128 ,64 ,32],
        "branches": 3
    },
    "branin": {
        "func": branin,
        "bounds": [[-5, 10], [0, 15]],
        "sizes": [256, 128 ,64]
    },
    "eggholder": {
        "func": eggholder,
        "bounds": [[-512,512],]*2,
        "sizes": [512, 256, 128 ,64, 32]
    },
    "hoelder":{
        "func": hoelder,
        "bounds": [[-10,10],]*2,
        "sizes": [512, 256, 128 ,64],
        "branches": 2
    },
    "matyas": {
        "func": matyas,
        "bounds": [[-10,10],]*2,
        "sizes": [256, 128 ,64]
    },
    "michalewicz": {
        "func": michalewicz,
        "bounds": [[0,jnp.pi],]*2,
        "sizes": [256, 128 ,64]
    },
    "rosenbrock": {
        "func": rosenbrock,
        "bounds": [[-5,10],]*2,
        "sizes": [256, 128 ,64]
    },
    "bukin": {
        "func": bukin,
        "bounds": [[-15, -5],[-3,3]],
        "sizes": [512, 256, 128, 64],
        "branches": 2
    }
}

In [13]:
# 4) Train
for name, func_data in funcs.items():
    print("-"*40)
    print(f"# {name}")
    print("-"*40)
    sizes = func_data["sizes"]
    bounds = func_data["bounds"]
    branches = func_data.get("branches", 1)
    func = jax.vmap(lambda x: func_data["func"](x, affine_transform))
    model_def = ComplexPreference(in_dim=in_dim, factor=factor, sizes=sizes, branches=branches)
    results = train_on_func(
        rng_key=key,
        model_def=model_def,
        func=func,
        bounds=bounds,  # domain for both x and y
        in_dim=in_dim,
        factor=factor,
        sizes=hidden_sizes,
        num_pairs=50_000,
        batch_size=512,
        epochs=200,
        lr=1e-3,
        patience=20
    )
    learned_params = results['params']
    # 5) Evaluate
    eval_metrics = evaluate(
        rng_key=key,
        model_def=model_def,
        params=learned_params,
        func= func,
        bounds=bounds,
        in_dim=in_dim,
        num_pairs=10_000,
        batch_size=512
    )
    print("Eval metrics:", eval_metrics)
    hparams = {
        "in_dim": in_dim,
        "factor": factor,
        "sizes": sizes,  # e.g. [128, 64]
        "branches": branches,
    }
    for dict_key in eval_metrics:
        #cast to json friendly types
        if dict_key == "confusion_matrix":
            eval_metrics[dict_key] = eval_metrics[dict_key].tolist()
        else:
            eval_metrics[dict_key] = float(eval_metrics[dict_key])
    func_metadata ={
        "hparams": hparams,
        "eval_metrics": eval_metrics,
    }
    #write hyperparams + eval results for later reference
    hparam_path = f"../Embedding_Model_Weights/{name}.json"
    with open(hparam_path, "w") as fp:
        json.dump(func_metadata, fp, indent=2)
    #Write actual func params
    param_bytes = serialization.to_bytes(learned_params)
    with open(f"../Embedding_Model_Weights/{name}.msgpack", "wb") as f:
        f.write(param_bytes)

----------------------------------------
# ackley
----------------------------------------
Epoch 010 — avg loss: 0.398116
Epoch 020 — avg loss: 0.384889
Epoch 030 — avg loss: 0.381876
Epoch 040 — avg loss: 0.376256
Epoch 050 — avg loss: 0.374469
Epoch 060 — avg loss: 0.374608
Epoch 070 — avg loss: 0.373753
Epoch 080 — avg loss: 0.373312
Epoch 090 — avg loss: 0.372846
Epoch 100 — avg loss: 0.372446
Epoch 110 — avg loss: 0.372101
Epoch 120 — avg loss: 0.371991
Epoch 130 — avg loss: 0.371106
Epoch 140 — avg loss: 0.370338
Epoch 150 — avg loss: 0.369448
Epoch 160 — avg loss: 0.368703
Epoch 170 — avg loss: 0.368222
Epoch 180 — avg loss: 0.368045
Epoch 190 — avg loss: 0.368044
Epoch 200 — avg loss: 0.368044
Accuracy : 0.8013
Precision: 0.8040
Recall   : 0.7945
F1-score : 0.7992
ROC-AUC  : 0.9041
Confusion Matrix:
[[4059  964]
 [1023 3954]]
Eval metrics: {'accuracy': 0.8013, 'precision': 0.8039853599023994, 'recall': 0.7944544906570223, 'f1': 0.7991915108640728, 'roc_auc': 0.9041296913842698,

In [9]:
func_metadata ={
    "hparams": hparams,
    "eval_metrics": eval_metrics,
}
with open(hparam_path, "w") as fp:
    json.dump(func_metadata, fp, indent=2)

NameError: name 'hparam_path' is not defined

In [58]:
type(eval_metrics["confusion_matrix"][0])

list

In [11]:
print([type(v) for v in eval_metrics.values()])

[<class 'float'>, <class 'float'>, <class 'float'>, <class 'float'>, <class 'float'>, <class 'numpy.ndarray'>]


In [8]:
with open(hparam_path, "w") as fp:
    json.dump(func_metadata, fp, indent=2)

NameError: name 'hparam_path' is not defined