In [1]:
import torch
import numpy as np
from neural_interaction_detection import get_interactions
from multilayer_perceptron import MLP, train, get_weights
from utils import (
    preprocess_data,
    get_pairwise_auc,
    get_anyorder_R_precision,
    set_seed,
    print_rankings,
)

In [32]:
use_main_effect_nets = True  # toggle this to use "main effect" nets
num_samples = 1000
num_features = 100

## Generate synthetic data with ground truth interactions

In [42]:
import numpy as np


def _ellipsoid(X):
    # X: (N, n)
    n = X.shape[1]
    if n == 1:
        w = np.array([1.0], dtype=X.dtype)
    else:
        w = 10.0 ** (6.0 * np.arange(n) / (n - 1))  # [1, 1e6]まで指数的に増加
    return np.sum(w * (X**2), axis=1)


def _rastrigin(X):
    n = X.shape[1]
    return 10.0 * n + np.sum(X**2 - 10.0 * np.cos(2.0 * np.pi * X), axis=1)


def _ackley(X):
    n = X.shape[1]
    s1 = np.sum(X**2, axis=1)
    s2 = np.sum(np.cos(2.0 * np.pi * X), axis=1)
    term1 = -20.0 * np.exp(-0.2 * np.sqrt(s1 / n))
    term2 = -np.exp(s2 / n)
    return term1 + term2 + 20.0 + np.e


def _rosenbrock(X):
    # X: (N, n)
    # sum_{i=1}^{n-1} [100(x_{i+1} - x_i^2)^2 + (x_i - 1)^2]
    xi = X[:, :-1]
    xnext = X[:, 1:]
    return np.sum(100.0 * (xnext - xi**2) ** 2 + (xi - 1.0) ** 2, axis=1)


def synth_func(X):
    """
    X: shape (N, 100)
    Returns:
      Y: shape (N,)
      ground_truth: list[set[int]] (1-indexed variable indices)
    """
    X = np.asarray(X)
    if X.ndim != 2 or X.shape[1] != 100:
        raise ValueError(f"X must be (N, 100). Got {X.shape}")

    X_e = X[:, 0:25]  # 1..25
    X_r = X[:, 25:50]  # 26..50
    X_a = X[:, 50:75]  # 51..75
    X_rb = X[:, 75:100]  # 76..100

    Y = _ellipsoid(X_e) + _rastrigin(X_r) + _ackley(X_a) + _rosenbrock(X_rb)

    ground_truth = [
        set(range(76, 101)),
    ]

    return Y, ground_truth

In [43]:
set_seed(42)
X = np.random.uniform(low=-1, high=1, size=(num_samples, num_features))
Y, ground_truth = synth_func(X)
data_loaders = preprocess_data(
    X, Y, valid_size=100, test_size=100, std_scale=True, get_torch_loaders=True
)

## Train a multilayer perceptron (MLP)

In [44]:
device = torch.device("cpu")
model = MLP(
    num_features, [140, 100, 60, 20], use_main_effect_nets=use_main_effect_nets
).to(device)

In [45]:
model, mlp_loss = train(
    model, data_loaders, device=device, learning_rate=1e-2, l1_const=5e-5, verbose=True
)

starting to train
early stopping enabled
[epoch 1, total 100] train loss: 3.4727, val loss: 1.0503
[epoch 3, total 100] train loss: 1.1290, val loss: 0.9284
[epoch 5, total 100] train loss: 0.9798, val loss: 0.8819
[epoch 7, total 100] train loss: 0.8218, val loss: 0.7768
[epoch 9, total 100] train loss: 0.5839, val loss: 0.4960
[epoch 11, total 100] train loss: 0.1865, val loss: 0.1349
[epoch 13, total 100] train loss: 0.0822, val loss: 0.0891
[epoch 15, total 100] train loss: 0.0501, val loss: 0.0641
[epoch 17, total 100] train loss: 0.0296, val loss: 0.0449
[epoch 19, total 100] train loss: 0.0201, val loss: 0.0311
[epoch 21, total 100] train loss: 0.0148, val loss: 0.0258
[epoch 23, total 100] train loss: 0.0128, val loss: 0.0224
[epoch 25, total 100] train loss: 0.0100, val loss: 0.0190
[epoch 27, total 100] train loss: 0.0105, val loss: 0.0153
[epoch 29, total 100] train loss: 0.0076, val loss: 0.0131
[epoch 31, total 100] train loss: 0.0067, val loss: 0.0137
[epoch 33, total 100

## Get the MLP's learned weights

In [46]:
model_weights = get_weights(model)

## Detect interactions from the weights

In [47]:
anyorder_interactions = get_interactions(model_weights, one_indexed=True)
pairwise_interactions = get_interactions(model_weights, pairwise=True, one_indexed=True)


print_rankings(pairwise_interactions, anyorder_interactions, top_k=10, spacing=14)

Pairwise interactions              Arbitrary-order interactions
(np.int64(16), np.int64(84))0.0000                      (np.int64(15), np.int64(18))0.0000        
(np.int64(6), np.int64(16))0.0000                      (np.int64(15), np.int64(18), np.int64(92))0.0000        
(np.int64(15), np.int64(29))0.0000                      (np.int64(6), np.int64(32))0.0000        
(np.int64(6), np.int64(32))0.0000                      (np.int64(15), np.int64(18), np.int64(76), np.int64(92))0.0000        
(np.int64(5), np.int64(18))0.0000                      (np.int64(6), np.int64(32), np.int64(37))0.0000        
(np.int64(6), np.int64(29))0.0000                      (np.int64(15), np.int64(18), np.int64(35), np.int64(76), np.int64(92))0.0000        
(np.int64(5), np.int64(29))0.0000                      (np.int64(3), np.int64(6), np.int64(32), np.int64(37))0.0000        
(np.int64(29), np.int64(70))0.0000                      (np.int64(3), np.int64(6), np.int64(32), np.int64(37), np.int64(42))0.

## Evaluate the interactions

In [48]:
auc = get_pairwise_auc(pairwise_interactions, ground_truth)
r_prec = get_anyorder_R_precision(anyorder_interactions, ground_truth)

print("Pairwise AUC", auc, ", Any-order R-Precision", r_prec)

Pairwise AUC 0.5105806451612903 , Any-order R-Precision 0.0
