# ReLU Formulation Comparison

Compare different ReLU formulations: the standard MIP formulation (globally optimal) vs. nonlinear formulations (locally optimal).

**Requirements:** torch, torchvision, skorch, scikit-learn, matplotlib, gurobipy >= 12.0, gurobi-machinelearning (`gurobi_ml`)

## Setup

In [None]:
import time
import numpy as np
import torch
import torch.nn as nn
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from skorch import NeuralNetClassifier
import gurobipy as gp
from gurobipy import GRB
from gurobi_ml import add_predictor_constr

## Load Data and Train Model

In [None]:
# Digits dataset: 8x8 images flattened into 64 features, 10 classes.
digits = load_digits()
# Divide by 16 so that every pixel value lies in [0, 1].
X = digits.data / 16.0
y = digits.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=42, test_size=0.2
)

n_features = X.shape[1]
n_classes = np.unique(y).size
print(f"Training samples: {len(X_train)}, Test samples: {len(X_test)}")
print(f"Input dimension: {n_features}, Classes: {n_classes}")

In [None]:
# Simple neural network: 64 -> 32 -> 16 -> 10
# Seed PyTorch's RNG so that weight initialization (and hence the trained
# network) is reproducible across runs; the example-selection cell below
# depends on which test images the network classifies correctly.
torch.manual_seed(0)
nn_model = nn.Sequential(
    nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 10)
)

# Train the network (skorch wraps nn_model, so the Gurobi predictor built
# from nn_model later sees the trained weights)
clf = NeuralNetClassifier(
    nn_model,
    max_epochs=20,
    lr=0.01,
    optimizer=torch.optim.Adam,
    criterion=nn.CrossEntropyLoss,
    verbose=0,
)
clf.fit(X_train.astype(np.float32), y_train.astype(np.int64))

# Train/test accuracy of the fitted classifier
train_acc = clf.score(X_train.astype(np.float32), y_train)
test_acc = clf.score(X_test.astype(np.float32), y_test)
print(f"Train accuracy: {train_acc:.3f}, Test accuracy: {test_acc:.3f}")

## Optimization Problem: Adversarial Example

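Find the minimal L1 perturbation of a correctly classified image such that every other class score exceeds the true-class score by at least 0.01, i.e., the true label becomes the lowest-scoring class. This is a sufficient (stronger than necessary) condition for misclassification. The same constraints are used for every formulation, so the objective values are directly comparable.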

In [None]:
# Pick a correctly classified test example: the first test index whose
# prediction matches its label (index 0 whenever it qualifies).
# Searching, rather than asserting on a fixed index, keeps the notebook
# usable when the trained network misclassifies X_test[0]. An explicit
# exception is used because `assert` is skipped under `python -O`.
test_predictions = clf.predict(X_test.astype(np.float32))
idx = next(
    (
        i
        for i, (pred, label) in enumerate(zip(test_predictions, y_test))
        if pred == label
    ),
    None,
)
if idx is None:
    raise RuntimeError("No correctly classified test example found!")

x_input = X_test[idx]
true_label = y_test[idx]
predicted_label = test_predictions[idx]
print(f"Test index: {idx}")
print(f"True label: {true_label}, Predicted: {predicted_label}")
print(f"Input L1 norm: {np.sum(np.abs(x_input)):.2f}")

## Create Gurobi Model

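A single Gurobi model is reused for all formulations. Between runs, the current predictor is removed with `pred_constr.remove()`, and the misclassification constraints are removed and re-added.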

In [None]:
# One Gurobi model shared by every formulation below; only the predictor
# and the misclassification constraints are swapped between runs.
m = gp.Model()
m.Params.OutputFlag = 1
m.Params.TimeLimit = 120

pixel_shape = x_input.shape

# Perturbed image, kept inside the normalized pixel range [0, 1]
x = m.addMVar(pixel_shape, lb=0, ub=1, name="x")

# |x - x_input| through a free difference variable bounded from above on
# both sides; minimizing abs_delta makes it equal |delta| at the optimum.
diff = m.addMVar(pixel_shape, lb=-GRB.INFINITY, name="delta")
abs_diff = m.addMVar(pixel_shape, name="abs_delta")
m.addConstr(diff == x - x_input)
m.addConstr(abs_diff >= diff)
m.addConstr(abs_diff >= -diff)

# Minimize the L1 norm of the perturbation
m.setObjective(abs_diff.sum(), sense=GRB.MINIMIZE)
m.update()

print(f"Model created with {m.NumVars} variables")

## 1. Standard MIP ReLU (Global Optimality)

- Models each ReLU with the piecewise-linear `max` formulation, using binary variables
- Solved to **global optimality** (unless the 120 s time limit is reached first)
- Baseline for comparison

In [None]:
# Add predictor with standard MIP formulation
pred_constr = add_predictor_constr(m, nn_model, x.reshape(1, -1))
output = pred_constr.output

# Misclassification constraints: output[0, i] >= output[0, true_label] + 0.01
# for EVERY i != true_label, i.e. the true label must become the
# lowest-scoring class. This is sufficient (stronger than necessary) for
# misclassification; the other formulations re-add the same constraints.
misclass_constrs = []
for i in range(10):
    if i != true_label:
        misclass_constrs.append(
            m.addConstr(
                output[0, i] >= output[0, true_label] + 0.01, name=f"misclass_{i}"
            )
        )

m.update()
print(
    f"MIP formulation: {m.NumVars} vars, {m.NumConstrs} constrs, {m.NumBinVars} binary vars"
)

# Solve to global optimality (unless the time limit stops the search first)
start = time.time()
m.optimize()
mip_time = time.time() - start

# ObjVal/MIPGap are available whenever an incumbent exists, including when
# the time limit stops the search before optimality is proven.
if m.SolCount > 0:
    mip_obj = m.ObjVal
    mip_gap = m.MIPGap
else:
    mip_obj = float("inf")
    mip_gap = 1.0
print(f"\n{'=' * 60}")
print(
    f"MIP ReLU: Obj = {mip_obj:.4f}, Time = {mip_time:.2f}s, Gap = {mip_gap * 100:.2f}%"
)
if m.Status != GRB.OPTIMAL:
    print(f"Warning: MIP stopped with status {m.Status}; baseline not proven optimal")
if m.SolCount > 0:
    # Sanity check: evaluate the perturbed image with the trained network itself
    adv_pred = clf.predict(x.X.reshape(1, -1).astype(np.float32))[0]
    print(f"Network prediction at solution: {adv_pred} (true label: {true_label})")
print(f"{'=' * 60}")

## 2. Sqrt ReLU (Local Optimality)

- Uses `f(x) = (x + sqrt(x²))/2`, mathematically equivalent to ReLU
- **Not smooth**: still non-differentiable at x=0 (since sqrt(x²) = |x|)
- No binary variables; uses nonlinear barrier solver
- Solved to **local optimality** (OptimalityTarget=1)

In [None]:
# Remove MIP predictor and misclassification constraints
pred_constr.remove()
for c in misclass_constrs:
    m.remove(c)

# Add sqrt ReLU formulation. These parameters stay set on the shared model
# for the remaining nonlinear formulations as well.
m.Params.NonConvex = 2
m.Params.OptimalityTarget = 1  # Local optimality
pred_constr = add_predictor_constr(
    m, nn_model, x.reshape(1, -1), relu_formulation="smooth"
)
output = pred_constr.output

# Re-add misclassification constraints (identical to the MIP run)
misclass_constrs = []
for i in range(10):
    if i != true_label:
        misclass_constrs.append(
            m.addConstr(
                output[0, i] >= output[0, true_label] + 0.01, name=f"misclass_{i}"
            )
        )

m.update()
print(
    f"Sqrt ReLU formulation: {m.NumVars} vars, {m.NumConstrs} constrs, {m.NumBinVars} binary vars"
)

start = time.time()
m.optimize()
sqrt_time = time.time() - start

# Use the incumbent whenever one exists: a local solve may end with a status
# other than OPTIMAL (e.g. at the time limit) and still provide a solution.
sqrt_obj = m.ObjVal if m.SolCount > 0 else float("inf")
print(f"\n{'=' * 60}")
print(f"Sqrt ReLU: Obj = {sqrt_obj:.4f}, Time = {sqrt_time:.2f}s")
print(f"Gap to MIP: {(sqrt_obj / mip_obj - 1) * 100:+.1f}%")
if m.SolCount > 0:
    # Evaluate the perturbed image with the trained network itself
    adv_pred = clf.predict(x.X.reshape(1, -1).astype(np.float32))[0]
    print(f"Network prediction at solution: {adv_pred} (true label: {true_label})")
print(f"{'=' * 60}")

## 3. Soft ReLU / Softplus (β=1.0, Local Optimality)

- Uses `f(x) = log(1 + exp(βx))/β`
- **Smooth approximation** of ReLU; differentiable everywhere
- Lower β = smoother but less accurate approximation
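- Solved to **local optimality**
- Because softplus only approximates ReLU, a solution of this model is not guaranteed to change the trained network's prediction. Its objective can also differ from (even fall below) the MIP optimum.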

In [None]:
# Remove previous predictor and constraints
pred_constr.remove()
for c in misclass_constrs:
    m.remove(c)

# Add soft ReLU with beta=1.0 (NonConvex/OptimalityTarget set for the sqrt
# formulation still apply to the shared model)
pred_constr = add_predictor_constr(
    m, nn_model, x.reshape(1, -1), relu_formulation="soft", soft_relu_beta=1.0
)
output = pred_constr.output

# Same misclassification constraints as for the other formulations
misclass_constrs = []
for i in range(10):
    if i != true_label:
        misclass_constrs.append(
            m.addConstr(
                output[0, i] >= output[0, true_label] + 0.01, name=f"misclass_{i}"
            )
        )

m.update()

start = time.time()
m.optimize()
soft1_time = time.time() - start

# Use the incumbent whenever one exists (e.g. also when the time limit
# stops the local solve early).
soft1_obj = m.ObjVal if m.SolCount > 0 else float("inf")
print(f"\n{'=' * 60}")
print(f"Soft ReLU (β=1.0): Obj = {soft1_obj:.4f}, Time = {soft1_time:.2f}s")
print(f"Gap to MIP: {(soft1_obj / mip_obj - 1) * 100:+.1f}%")
if m.SolCount > 0:
    # Softplus only approximates ReLU: check whether the solution actually
    # changes the trained network's prediction.
    adv_pred = clf.predict(x.X.reshape(1, -1).astype(np.float32))[0]
    print(f"Network prediction at solution: {adv_pred} (true label: {true_label})")
print(f"{'=' * 60}")

## 4. Soft ReLU / Softplus (β=5.0, Local Optimality)

- Same as above but with **higher β** → closer to ReLU
- Higher β = sharper transition, better approximation
- Solved to **local optimality**

In [None]:
# Remove previous predictor and constraints
pred_constr.remove()
for c in misclass_constrs:
    m.remove(c)

# Add soft ReLU with beta=5.0 (sharper transition, closer to ReLU)
pred_constr = add_predictor_constr(
    m, nn_model, x.reshape(1, -1), relu_formulation="soft", soft_relu_beta=5.0
)
output = pred_constr.output

# Same misclassification constraints as for the other formulations
misclass_constrs = []
for i in range(10):
    if i != true_label:
        misclass_constrs.append(
            m.addConstr(
                output[0, i] >= output[0, true_label] + 0.01, name=f"misclass_{i}"
            )
        )

m.update()

start = time.time()
m.optimize()
soft5_time = time.time() - start

# Use the incumbent whenever one exists (e.g. also when the time limit
# stops the local solve early).
soft5_obj = m.ObjVal if m.SolCount > 0 else float("inf")
print(f"\n{'=' * 60}")
print(f"Soft ReLU (β=5.0): Obj = {soft5_obj:.4f}, Time = {soft5_time:.2f}s")
print(f"Gap to MIP: {(soft5_obj / mip_obj - 1) * 100:+.1f}%")
if m.SolCount > 0:
    # Softplus only approximates ReLU: check whether the solution actually
    # changes the trained network's prediction.
    adv_pred = clf.predict(x.X.reshape(1, -1).astype(np.float32))[0]
    print(f"Network prediction at solution: {adv_pred} (true label: {true_label})")
print(f"{'=' * 60}")

## Summary

In [None]:
# Comparison table: the MIP result is the baseline; every local formulation
# is reported relative to it.
rule = "=" * 75
print("\n" + rule)
print(f"{'Formulation':<28} {'Objective':<12} {'Time (s)':<12} Gap to MIP")
print(rule)
print(f"{'MIP ReLU (global)':<28} {mip_obj:<12.4f} {mip_time:<12.2f} baseline")

local_results = [
    ("Sqrt ReLU (local)", sqrt_obj, sqrt_time),
    ("Soft ReLU β=1.0 (local)", soft1_obj, soft1_time),
    ("Soft ReLU β=5.0 (local)", soft5_obj, soft5_time),
]
for label, objective, seconds in local_results:
    relative_gap = (objective / mip_obj - 1) * 100
    print(f"{label:<28} {objective:<12.4f} {seconds:<12.2f} {relative_gap:+.1f}%")
print(rule)

---
Copyright © 2023-2026 Gurobi Optimization, LLC