In [1]:
from JAXEmbeddings import ComplexPreference, train_on_func, evaluate, unlabeled_batch_generator, get_ranges_and_evals
import jax
import jax.numpy as jnp
from jax import random
from Stackelberg.src.utils.utility_functions import ackley, branin, eggholder, hoelder, matyas, michalewicz, rosenbrock, bukin
from collections import namedtuple
import json
import flax.serialization as serialization

In [2]:
# 1) Define your “function” to optimize (in NumPy). E.g., a simple 2D paraboloid:
def my_paraboloid(X: jnp.ndarray) -> jnp.ndarray:
    # X has shape [N, 2]; return f(x,y) = x^2 + y^2
    return jnp.sum(X**2, axis=1)

In [3]:
# 2) Create a Flax/Linen module instance
in_dim = 2
factor = 2      # must be even
# 3) Create a PRNGKey for initialization and training randomness
key = random.PRNGKey(1)
AffineTransform = namedtuple("AffineTransform", ["affine_transform"])
affine_transform = AffineTransform(jnp.array([1.,0.]))
funcs = {
    "ackley": {
        "func": ackley,
        "bounds": [[-5,5],]*2,#[[-32.768, 32.768],]*2, Imitating the bounds in ackley/config.yaml
        "sizes": [512, 256, 128 ,64 ,32],
        "branches": 3
    },
    "branin": {
        "func": branin,
        "bounds": [[-5, 10], [0, 15]],
        "sizes": [256, 128 ,64]
    },
    "eggholder": {
        "func": eggholder,
        "bounds": [[-512,512],]*2,
        "sizes": [512, 256, 128 ,64, 32]
    },
    "hoelder":{
        "func": hoelder,
        "bounds": [[-10,10],]*2,
        "sizes": [512, 256, 128 ,64],
        "branches": 2
    },
    "matyas": {
        "func": matyas,
        "bounds": [[-10,10],]*2,
        "sizes": [256, 128 ,64]
    },
    "michalewicz": {
        "func": michalewicz,
        "bounds": [[0,jnp.pi],]*2,
        "sizes": [256, 128 ,64]
    },
    "rosenbrock": {
        "func": rosenbrock,
        "bounds": [[-5,10],]*2,
        "sizes": [256, 128 ,64]
    },
    "bukin": {
        "func": bukin,
        "bounds": [[-15, -5],[-3,3]],
        "sizes": [512, 256, 128, 64],
        "branches": 2
    }
}

In [5]:
# 4) Train
#for name, func_data in funcs.items():
name = "ackley"
func_data = funcs[name]
print("-"*40)
print(f"# {name}")
print("-"*40)
sizes = func_data["sizes"]
bounds = func_data["bounds"]
branches = func_data.get("branches", 1)
func = jax.vmap(lambda x: -func_data["func"](x, affine_transform))
model_def = ComplexPreference(in_dim=in_dim, factor=factor, sizes=sizes, branches=branches)
results = train_on_func(
    rng_key=key,
    model_def=model_def,
    func=func,
    bounds=bounds,  # domain for both x and y
    in_dim=in_dim,
    num_pairs=50_000,
    batch_size=512,
    epochs=500,
    lr=1e-3,
    patience=20,
    #ackley=True
)
learned_params = results['params']

----------------------------------------
# ackley
----------------------------------------
Epoch 010 — avg loss: 0.189206
Epoch 020 — avg loss: 0.185711
Epoch 030 — avg loss: 0.184028
Epoch 040 — avg loss: 0.182583
Epoch 050 — avg loss: 0.181041
Epoch 060 — avg loss: 0.178187
Epoch 070 — avg loss: 0.175363
Epoch 080 — avg loss: 0.173812
Epoch 090 — avg loss: 0.170714
Epoch 100 — avg loss: 0.165158
Epoch 110 — avg loss: 0.156835
Epoch 120 — avg loss: 0.146599
Epoch 130 — avg loss: 0.126765
Epoch 140 — avg loss: 0.114719
Epoch 150 — avg loss: 0.104170
Epoch 160 — avg loss: 0.098410
Epoch 170 — avg loss: 0.095269
Epoch 180 — avg loss: 0.085000
Epoch 190 — avg loss: 0.078343
Epoch 200 — avg loss: 0.079700
Epoch 210 — avg loss: 0.071028
Epoch 220 — avg loss: 0.070845
Epoch 230 — avg loss: 0.063968
Epoch 240 — avg loss: 0.060685
Epoch 250 — avg loss: 0.055072
Epoch 260 — avg loss: 0.073334
Epoch 270 — avg loss: 0.046586
Epoch 280 — avg loss: 0.054085
Epoch 290 — avg loss: 0.041218
Epoch 300 

In [6]:
  # 5) Evaluate
eval_metrics = evaluate(
    rng_key=key,
    model_def=model_def,
    params=learned_params,
    func= func,
    bounds=bounds,
    in_dim=in_dim,
    num_pairs=100_000,
    batch_size=512
)
print("Eval metrics:", eval_metrics)
hparams = {
    "in_dim": in_dim,
    "factor": factor,
    "sizes": sizes,  # e.g. [128, 64]
    "branches": branches,
}

Accuracy : 0.9676
Precision: 0.9666
Recall   : 0.9683
F1-score : 0.9675
ROC-AUC  : 0.9953
Confusion Matrix:
[[48571  1663]
 [ 1576 48190]]
Eval metrics: {'accuracy': 0.96761, 'precision': 0.9666419272661625, 'recall': 0.9683317927902584, 'f1': 0.9674861221252974, 'roc_auc': 0.9953349622244766, 'confusion_matrix': array([[48571,  1663],
       [ 1576, 48190]])}


In [7]:
for dict_key in eval_metrics:
    #cast to json friendly types
    if dict_key == "confusion_matrix":
        eval_metrics[dict_key] = eval_metrics[dict_key].tolist()
    else:
        eval_metrics[dict_key] = float(eval_metrics[dict_key])
func_metadata ={
    "hparams": hparams,
    "eval_metrics": eval_metrics,
}
#write hyperparams + eval results for later reference
hparam_path = f"../Embedding_Model_Weights/{name}.json"
with open(hparam_path, "w") as fp:
    json.dump(func_metadata, fp, indent=2)
#Write actual func params
param_bytes = serialization.to_bytes(learned_params)
with open(f"../Embedding_Model_Weights/{name}.msgpack", "wb") as f:
    f.write(param_bytes)

In [8]:
func_metadata ={
    "hparams": hparams,
    "eval_metrics": eval_metrics,
}
with open(hparam_path, "w") as fp:
    json.dump(func_metadata, fp, indent=2)

In [8]:
X, Y, fx, fy = get_ranges_and_evals(num_pairs=1000, func=func, in_dim=in_dim, bounds=bounds)

In [9]:
batcher = unlabeled_batch_generator(X, Y, fx, fy, 1024, True)
for batch in batcher:
    x, y, l = batch
print(x.shape, y.shape, l.shape)

(1000, 2) (1000, 2) (1000,)
