In [2]:
from JAXEmbeddings import ComplexPreference, train_on_func, evaluate, unlabeled_batch_generator, get_ranges_and_evals
import jax
import jax.numpy as jnp
from jax import random
from Stackelberg.src.utils.utility_functions import ackley, branin, eggholder, hoelder, matyas, michalewicz, rosenbrock, bukin
from collections import namedtuple
import json
import flax.serialization as serialization

In [3]:
# 1) Define your “function” to optimize (in NumPy). E.g., a simple 2D paraboloid:
def my_paraboloid(X: jnp.ndarray) -> jnp.ndarray:
    # X has shape [N, 2]; return f(x,y) = x^2 + y^2
    return jnp.sum(X**2, axis=1)

In [13]:
# 2) Create a Flax/Linen module instance
in_dim = 2
factor = 2      # must be even
hidden_sizes = [256, 128, 64]
# 3) Create a PRNGKey for initialization and training randomness
key = random.PRNGKey(1)
AffineTransform = namedtuple("AffineTransform", ["affine_transform"])
affine_transform = AffineTransform(jnp.array([1.,0.]))
funcs = {
    "ackley": {
        "func": ackley,
        "bounds": [[-32.768, 32.768],]*2,
        "sizes": [512, 256, 128 ,64 ,32],
        "branches": 3
    },
    "branin": {
        "func": branin,
        "bounds": [[-5, 10], [0, 15]],
        "sizes": [256, 128 ,64]
    },
    "eggholder": {
        "func": eggholder,
        "bounds": [[-512,512],]*2,
        "sizes": [512, 256, 128 ,64, 32]
    },
    "hoelder":{
        "func": hoelder,
        "bounds": [[-10,10],]*2,
        "sizes": [512, 256, 128 ,64],
        "branches": 2
    },
    "matyas": {
        "func": matyas,
        "bounds": [[-10,10],]*2,
        "sizes": [256, 128 ,64]
    },
    "michalewicz": {
        "func": michalewicz,
        "bounds": [[0,jnp.pi],]*2,
        "sizes": [256, 128 ,64]
    },
    "rosenbrock": {
        "func": rosenbrock,
        "bounds": [[-5,10],]*2,
        "sizes": [256, 128 ,64]
    },
    "bukin": {
        "func": bukin,
        "bounds": [[-15, -5],[-3,3]],
        "sizes": [512, 256, 128, 64],
        "branches": 2
    }
}

In [None]:
# 4) Train
#for name, func_data in funcs.items():
name = "matyas"
func_data = funcs[name]
print("-"*40)
print(f"# {name}")
print("-"*40)
sizes = func_data["sizes"]
bounds = func_data["bounds"]
branches = func_data.get("branches", 1)
func = jax.vmap(lambda x: -func_data["func"](x, affine_transform))
model_def = ComplexPreference(in_dim=in_dim, factor=factor, sizes=sizes, branches=branches)
results = train_on_func(
    rng_key=key,
    model_def=model_def,
    func=func,
    bounds=bounds,  # domain for both x and y
    in_dim=in_dim,
    num_pairs=50_000,
    batch_size=512,
    epochs=300,
    lr=1e-3,
    patience=20
)
learned_params = results['params']

In [17]:
# 5) Evaluate
eval_metrics = evaluate(
    rng_key=key,
    model_def=model_def,
    params=learned_params,
    func= func,
    bounds=bounds,
    in_dim=in_dim,
    num_pairs=100_000,
    batch_size=512
)
print("Eval metrics:", eval_metrics)
hparams = {
    "in_dim": in_dim,
    "factor": factor,
    "sizes": sizes,  # e.g. [128, 64]
    "branches": branches,
}

Accuracy : 0.9972
Precision: 0.9973
Recall   : 0.9971
F1-score : 0.9972
ROC-AUC  : 1.0000
Confusion Matrix:
[[49918   136]
 [  144 49802]]
Eval metrics: {'accuracy': 0.9972, 'precision': 0.9972766230125355, 'recall': 0.9971168862371361, 'f1': 0.9971967482279445, 'roc_auc': 0.9999596541529406, 'confusion_matrix': array([[49918,   136],
       [  144, 49802]])}


In [15]:
for dict_key in eval_metrics:
    #cast to json friendly types
    if dict_key == "confusion_matrix":
        eval_metrics[dict_key] = eval_metrics[dict_key].tolist()
    else:
        eval_metrics[dict_key] = float(eval_metrics[dict_key])
func_metadata ={
    "hparams": hparams,
    "eval_metrics": eval_metrics,
}
#write hyperparams + eval results for later reference
hparam_path = f"../Embedding_Model_Weights/{name}.json"
with open(hparam_path, "w") as fp:
    json.dump(func_metadata, fp, indent=2)
#Write actual func params
param_bytes = serialization.to_bytes(learned_params)
with open(f"../Embedding_Model_Weights/{name}.msgpack", "wb") as f:
    f.write(param_bytes)

In [16]:
func_metadata ={
    "hparams": hparams,
    "eval_metrics": eval_metrics,
}
with open(hparam_path, "w") as fp:
    json.dump(func_metadata, fp, indent=2)

In [8]:
X, Y, fx, fy = get_ranges_and_evals(num_pairs=1000, func=func, in_dim=in_dim, bounds=bounds)

In [9]:
batcher = unlabeled_batch_generator(X, Y, fx, fy, 1024, True)
for batch in batcher:
    x, y, l = batch
print(x.shape, y.shape, l.shape)

(1000, 2) (1000, 2) (1000,)
