In [1]:
import torch
import numpy as np
from neural_interaction_detection import get_interactions
from multilayer_perceptron import MLP, train, get_weights
from utils import (
    preprocess_data,
    get_pairwise_auc,
    get_anyorder_R_precision,
    set_seed,
    print_rankings,
)

In [32]:
# Experiment-wide settings read by the cells below.
num_features = 100  # width of the generated input matrix X
num_samples = 1_000  # number of rows generated for X
use_main_effect_nets = True  # switch passed to MLP for its "main effect" nets

## Generate synthetic data with ground truth interactions

In [42]:
import numpy as np

# Layout of the composite objective: NUM_BLOCKS blocks of BLOCK_DIM
# consecutive features each.
NUM_BLOCKS = 4
BLOCK_DIM = 25
TOTAL_DIM = NUM_BLOCKS * BLOCK_DIM

# ---- py-benchmark integration ----
# Benchmark functions are looked up by name in the installed py-benchmark
# module, so every entry below must exist there. Adjust POOL_NAMES to the
# function names your environment actually provides.
POOL_NAMES = [
    "sphere", "rastrigin", "ackley", "rosenbrock", "griewank",
    "schwefel", "levy", "zakharov", "michalewicz", "dixon_price",
    "sum_squares", "bent_cigar", "discus", "weierstrass", "ellipsoid",
    "alpine1", "alpine2", "katsuura", "salomon", "whitley",
    "bohachevsky", "perm", "trid", "powell", "styblinski_tang",
]


def _import_pybenchmark():
    # Try common module names used by py-benchmark distributions.
    for mod in ("pybenchmarks", "py_benchmark", "pybenchmark", "pybench"):
        try:
            return __import__(mod)
        except Exception:
            continue
    raise ModuleNotFoundError(
        "py-benchmark is required. Install it and ensure its module name is one of "
        "[pybenchmarks, py_benchmark, pybenchmark, pybench], or edit _import_pybenchmark()."
    )


def _random_rotation_matrix(dim, rng):
    H = rng.normal(size=(dim, dim))
    Q, _ = np.linalg.qr(H)
    if np.linalg.det(Q) < 0:
        Q[:, 0] *= -1
    return Q


def _resolve_shift(shift, rng, dim, scale):
    if shift is None or shift is False:
        return None
    if isinstance(shift, str):
        if shift == "random":
            return rng.uniform(-scale, scale, size=dim)
        raise ValueError(f"Unknown shift spec: {shift}")
    shift = np.asarray(shift, dtype=float)
    if shift.shape != (dim,):
        raise ValueError(f"shift must be shape ({dim},), got {shift.shape}")
    return shift


def _resolve_rotation(rotate, rng, dim):
    if rotate is None or rotate is False:
        return None
    if rotate is True or rotate == "random":
        return _random_rotation_matrix(dim, rng)
    rotate = np.asarray(rotate, dtype=float)
    if rotate.shape != (dim, dim):
        raise ValueError(f"rotation must be shape ({dim}, {dim}), got {rotate.shape}")
    return rotate


def _apply_transform(X, shift=None, rotation=None):
    X_t = X
    if shift is not None:
        X_t = X_t - shift
    if rotation is not None:
        X_t = X_t @ rotation
    return X_t


def _wrap_callable(fn, dim):
    # Normalize to a callable that accepts X with shape (N, dim) and returns (N,)
    # py-benchmark functions often accept a single vector; we adapt as needed.
    def _eval(X):
        X = np.asarray(X)
        if X.ndim != 2 or X.shape[1] != dim:
            raise ValueError(f"Expected X shape (N, {dim}), got {X.shape}")
        try:
            y = fn(X)
        except Exception:
            y = np.array([fn(x) for x in X])
        y = np.asarray(y)
        if y.ndim == 2 and y.shape[1] == 1:
            y = y[:, 0]
        if y.ndim == 0:
            y = np.full(X.shape[0], y)
        if y.ndim != 1:
            raise ValueError("Benchmark function must return (N,) or scalar per row")
        return y

    return _eval


def _get_pybenchmark_callable(name, dim, module=None):
    """Locate benchmark function ``name`` in py-benchmark and wrap it for batches.

    Three lookup strategies are tried in order:

    1. factory functions on the module (``get_function`` etc.) called as
       ``factory(name, dim)``;
    2. an attribute called ``name`` on a known submodule;
    3. an attribute called ``name`` on the module itself.

    For strategies 2 and 3 the attribute may be a class/factory expecting
    ``dim`` or a plain function. ``obj(dim)`` is adopted only when it yields
    something callable; otherwise the attribute itself is used. (Blindly
    adopting ``obj(dim)`` would replace a plain function that happens to
    accept a number with its numeric return value.)

    Args:
        name: Benchmark function name to look up.
        dim: Input dimensionality the function will be evaluated with.
        module: Already-imported py-benchmark module; imported on demand when
            omitted.

    Returns:
        The batch evaluator produced by ``_wrap_callable``.

    Raises:
        ValueError: if no strategy yields a callable for ``name``.
    """
    module = module or _import_pybenchmark()
    last_err = None

    def _materialize(obj):
        # Prefer the instantiated/configured object, but only if it is callable.
        try:
            built = obj(dim)
        except Exception:
            built = None  # not a factory (or not constructible from dim alone)
        chosen = built if callable(built) else obj
        if not callable(chosen):
            raise ValueError(f"Function '{name}' not found or unsupported: {obj!r} is not callable")
        return chosen

    # 1) Factory-style functions
    for attr in ("get_function", "get_benchmark", "get_benchmark_function", "get"):
        if hasattr(module, attr):
            try:
                candidate = getattr(module, attr)(name, dim)
                if not callable(candidate):
                    raise TypeError(f"{attr}() returned a non-callable object: {candidate!r}")
                return _wrap_callable(candidate, dim)
            except Exception as exc:
                # Remember why the factory failed; later strategies may still succeed.
                last_err = exc

    # 2) Submodules with named callables/classes
    for sub in ("functions", "benchmarks", "benchmark", "funcs"):
        if hasattr(module, sub):
            submod = getattr(module, sub)
            if hasattr(submod, name):
                return _wrap_callable(_materialize(getattr(submod, name)), dim)

    # 3) Direct attribute
    if hasattr(module, name):
        return _wrap_callable(_materialize(getattr(module, name)), dim)

    if last_err is not None:
        raise ValueError(f"Function '{name}' not found or unsupported: {last_err}")
    raise ValueError(f"Function '{name}' not found in py-benchmark module: {module}")


def make_pool_specs(names, rotate=True, shift="random", seed_base=0, weight=1.0, interaction=True):
    """Build one spec dict per benchmark name, all sharing the same settings.

    Args:
        names: Iterable of benchmark function names.
        rotate: Rotation setting stored in every spec.
        shift: Shift setting stored in every spec.
        seed_base: Seed of the first spec; the i-th spec gets ``seed_base + i``.
        weight: Weight stored in every spec.
        interaction: Interaction flag stored in every spec.

    Returns:
        A list of dicts with keys ``name``, ``rotate``, ``shift``, ``seed``,
        ``weight`` and ``interaction``, in the order of ``names``.
    """
    return [
        {
            "name": fn_name,
            "rotate": rotate,
            "shift": shift,
            "seed": seed_base + offset,
            "weight": weight,
            "interaction": interaction,
        }
        for offset, fn_name in enumerate(names)
    ]


def build_benchmark_pool(specs, dim, seed=123, shift_scale=1.0, module=None):
    """Resolve spec dicts into ready-to-evaluate benchmark pool entries.

    For every spec a dedicated random generator is created, the shift and
    rotation are resolved with it (shift first, then rotation), and the
    named benchmark function is looked up and wrapped.

    Args:
        specs: Iterable of spec mappings. ``name`` is required; ``seed``,
            ``shift``, ``rotate``, ``weight`` and ``interaction`` are optional.
        dim: Dimensionality of every benchmark function in the pool.
        seed: Master seed; supplies per-spec seeds for specs that have none.
        shift_scale: Half-width of the range used for ``"random"`` shifts.
        module: Already-imported py-benchmark module; imported on demand when
            omitted.

    Returns:
        A list of dicts with keys ``name``, ``fn``, ``shift``, ``rotation``,
        ``weight`` and ``interaction``, in the order of ``specs``.
    """
    module = module or _import_pybenchmark()
    rng_master = np.random.default_rng(seed)

    pool = []
    for spec in specs:
        spec = dict(spec)

        # Draw exactly one value from the master stream per spec, whether or
        # not it ends up being used, so the seeds handed to later specs do not
        # depend on which earlier specs carried their own seed.
        derived_seed = rng_master.integers(0, 2**32 - 1)
        spec_seed = spec.get("seed")
        if spec_seed is None:
            # A missing OR explicitly-None seed falls back to the master
            # stream; passing None on to default_rng would seed from OS
            # entropy and make the pool non-reproducible.
            spec_seed = derived_seed
        rng = np.random.default_rng(spec_seed)

        shift = _resolve_shift(spec.get("shift", None), rng, dim, shift_scale)
        rotation = _resolve_rotation(spec.get("rotate", False), rng, dim)

        fn = _get_pybenchmark_callable(spec["name"], dim, module=module)
        pool.append(
            {
                "name": spec["name"],
                "fn": fn,
                "shift": shift,
                "rotation": rotation,
                "weight": spec.get("weight", 1.0),
                "interaction": spec.get("interaction", True),
            }
        )

    return pool


def indices_from_names(pool, names):
    """Translate benchmark names into their positions in ``pool``.

    Args:
        pool: Sequence of pool entries, each a mapping with a ``"name"`` key.
        names: Names to look up.

    Returns:
        A list of pool indices, one per entry of ``names`` and in the same
        order. When a name occurs several times in ``pool``, the position of
        its last occurrence is returned.

    Raises:
        ValueError: if any requested name is absent from ``pool``.
    """
    position_of = {}
    for position, entry in enumerate(pool):
        position_of[entry["name"]] = position  # a repeated name keeps its last position

    absent = [wanted for wanted in names if wanted not in position_of]
    if len(absent) > 0:
        raise ValueError(f"Names not in pool: {absent}")

    return [position_of[wanted] for wanted in names]


def make_composite_objective(pool, selected_indices, block_dim=BLOCK_DIM, num_blocks=NUM_BLOCKS):
    """Compose ``num_blocks`` pool functions into one block-additive objective.

    Block ``i`` covers the columns ``[i * block_dim, (i + 1) * block_dim)`` of
    the input and is evaluated by the pool entry ``selected_indices[i]`` after
    that entry's shift and rotation have been applied.

    Args:
        pool: Pool entries with keys ``fn``, ``shift``, ``rotation``,
            ``weight`` and, optionally, ``interaction``.
        selected_indices: One pool index per block.
        block_dim: Number of features per block.
        num_blocks: Number of blocks.

    Returns:
        A function ``_objective(X)`` returning ``(Y, ground_truth)``: ``Y`` is
        the weighted sum of the block values per row, and ``ground_truth`` is
        a list holding, for every block whose entry is flagged as interacting,
        the set of that block's 1-based feature indices.

    Raises:
        ValueError: if ``selected_indices`` does not have ``num_blocks``
            entries, or (from ``_objective``) if ``X`` is not
            ``(N, block_dim * num_blocks)``.
    """
    selected_indices = list(selected_indices)
    if len(selected_indices) != num_blocks:
        raise ValueError(f"selected_indices must be length {num_blocks}")

    def _objective(X):
        X = np.asarray(X)
        expected_width = block_dim * num_blocks
        if X.ndim != 2 or X.shape[1] != expected_width:
            raise ValueError(f"X must be (N, {block_dim * num_blocks}). Got {X.shape}")

        Y = np.zeros(X.shape[0], dtype=float)
        ground_truth = []
        lo = 0  # first (0-based) column of the current block
        for pool_i in selected_indices:
            hi = lo + block_dim
            entry = pool[pool_i]
            block = _apply_transform(X[:, lo:hi], entry["shift"], entry["rotation"])
            Y = Y + entry["weight"] * entry["fn"](block)

            if entry.get("interaction", True):
                # Feature indices are reported 1-based, hence the +1 on both ends.
                ground_truth.append(set(range(lo + 1, hi + 1)))
            lo = hi

        return Y, ground_truth

    return _objective


# ---- User-configurable pool and selection ----
# Every name in POOL_NAMES becomes one randomly shifted and rotated
# BLOCK_DIM-dimensional pool entry.
POOL_SPECS = make_pool_specs(names=POOL_NAMES, rotate=True, shift="random", seed_base=0)
BENCHMARK_POOL = build_benchmark_pool(
    specs=POOL_SPECS, dim=BLOCK_DIM, seed=123, shift_scale=1.0
)

# The composite objective needs one pool entry per block
# (4 functions x 25 dimensions = 100 dimensions).
# Option A: pick the entries by their position in the pool.
# NOTE(review): index 0 is "sphere". If that is the usual sum of squares, it
# stays additively separable under any shift/rotation, yet its block is still
# reported as a ground-truth interaction — consider interaction=False for it.
SELECTED_POOL_INDICES = [0, 1, 2, 3]

# Option B: pick the entries by name (uncomment; requires unique names in the pool).
# SELECTED_POOL_NAMES = ["sphere", "rastrigin", "ackley", "rosenbrock"]
# SELECTED_POOL_INDICES = indices_from_names(BENCHMARK_POOL, SELECTED_POOL_NAMES)

synth_func = make_composite_objective(
    BENCHMARK_POOL, SELECTED_POOL_INDICES, block_dim=BLOCK_DIM, num_blocks=NUM_BLOCKS
)


In [43]:
# Seed first so the uniform draw below is repeatable.
# NOTE(review): set_seed comes from utils; presumably it seeds NumPy's global RNG — confirm.
set_seed(42)
# Inputs are drawn uniformly from [-1, 1); the width must equal BLOCK_DIM * NUM_BLOCKS,
# otherwise synth_func raises a ValueError.
X = np.random.uniform(low=-1, high=1, size=(num_samples, num_features))
# ground_truth: one set of 1-based feature indices per block flagged as interacting.
Y, ground_truth = synth_func(X)
# NOTE(review): preprocess_data is defined in utils; presumably it holds out 100
# validation and 100 test rows, standardizes, and returns torch loaders — confirm.
data_loaders = preprocess_data(
    X, Y, valid_size=100, test_size=100, std_scale=True, get_torch_loaders=True
)

## Train a multilayer perceptron (MLP)

In [44]:
# Everything in this notebook runs on the CPU.
device = torch.device("cpu")
# MLP comes from multilayer_perceptron. NOTE(review): the list is presumably the
# hidden-layer widths (140 -> 100 -> 60 -> 20) — confirm against MLP's signature.
model = MLP(
    num_features, [140, 100, 60, 20], use_main_effect_nets=use_main_effect_nets
).to(device)

In [45]:
# Fit the MLP on the prepared loaders; train() hands back the model and a loss value.
# The recorded log below shows early stopping enabled and a 100-epoch budget.
# NOTE(review): l1_const is presumably the L1 regularization strength — confirm in train().
model, mlp_loss = train(
    model, data_loaders, device=device, learning_rate=1e-2, l1_const=5e-5, verbose=True
)

starting to train
early stopping enabled
[epoch 1, total 100] train loss: 3.4727, val loss: 1.0503
[epoch 3, total 100] train loss: 1.1290, val loss: 0.9284
[epoch 5, total 100] train loss: 0.9798, val loss: 0.8819
[epoch 7, total 100] train loss: 0.8218, val loss: 0.7768
[epoch 9, total 100] train loss: 0.5839, val loss: 0.4960
[epoch 11, total 100] train loss: 0.1865, val loss: 0.1349
[epoch 13, total 100] train loss: 0.0822, val loss: 0.0891
[epoch 15, total 100] train loss: 0.0501, val loss: 0.0641
[epoch 17, total 100] train loss: 0.0296, val loss: 0.0449
[epoch 19, total 100] train loss: 0.0201, val loss: 0.0311
[epoch 21, total 100] train loss: 0.0148, val loss: 0.0258
[epoch 23, total 100] train loss: 0.0128, val loss: 0.0224
[epoch 25, total 100] train loss: 0.0100, val loss: 0.0190
[epoch 27, total 100] train loss: 0.0105, val loss: 0.0153
[epoch 29, total 100] train loss: 0.0076, val loss: 0.0131
[epoch 31, total 100] train loss: 0.0067, val loss: 0.0137
[epoch 33, total 100

## Get the MLP's learned weights

In [46]:
# Extract the trained model's weights in the form get_interactions() consumes below.
# NOTE(review): the exact structure returned by get_weights is not visible here — confirm.
model_weights = get_weights(model)

## Detect interactions from the weights

In [47]:
# Rank interactions from the learned weights. one_indexed=True matches the
# 1-based feature indices used in ground_truth.
anyorder_interactions = get_interactions(model_weights, one_indexed=True)
pairwise_interactions = get_interactions(model_weights, pairwise=True, one_indexed=True)


# Show the top-10 entries of both rankings side by side.
# NOTE(review): the recorded output prints every strength as 0.0000 and the
# indices as np.int64(...) reprs — worth checking the score scale and formatting.
print_rankings(pairwise_interactions, anyorder_interactions, top_k=10, spacing=14)

Pairwise interactions              Arbitrary-order interactions
(np.int64(16), np.int64(84))0.0000                      (np.int64(15), np.int64(18))0.0000        
(np.int64(6), np.int64(16))0.0000                      (np.int64(15), np.int64(18), np.int64(92))0.0000        
(np.int64(15), np.int64(29))0.0000                      (np.int64(6), np.int64(32))0.0000        
(np.int64(6), np.int64(32))0.0000                      (np.int64(15), np.int64(18), np.int64(76), np.int64(92))0.0000        
(np.int64(5), np.int64(18))0.0000                      (np.int64(6), np.int64(32), np.int64(37))0.0000        
(np.int64(6), np.int64(29))0.0000                      (np.int64(15), np.int64(18), np.int64(35), np.int64(76), np.int64(92))0.0000        
(np.int64(5), np.int64(29))0.0000                      (np.int64(3), np.int64(6), np.int64(32), np.int64(37))0.0000        
(np.int64(29), np.int64(70))0.0000                      (np.int64(3), np.int64(6), np.int64(32), np.int64(37), np.int64(42))0.

## Evaluate the interactions

In [48]:
# Score both rankings against the block-wise ground truth.
# NOTE(review): the recorded run reports a pairwise AUC of about 0.51 and an
# any-order R-precision of 0.0, i.e. roughly what a random ranking would score;
# verify the ground truth and the training setup before drawing conclusions.
auc = get_pairwise_auc(pairwise_interactions, ground_truth)
r_prec = get_anyorder_R_precision(anyorder_interactions, ground_truth)

print("Pairwise AUC", auc, ", Any-order R-Precision", r_prec)

Pairwise AUC 0.5105806451612903 , Any-order R-Precision 0.0
