# ReLU Formulation Comparison

Compare different ReLU formulations: the standard MIP formulation (solved to global optimality) vs. nonlinear formulations (solved to local optimality).

**Requirements:** torch, torchvision, skorch, matplotlib, numpy, gurobi-machinelearning (`gurobi_ml`), gurobipy >= 12.0

## Setup

In [None]:
from matplotlib import pyplot as plt

import time
import numpy as np
import torch
import torchvision
from skorch import NeuralNetClassifier
import gurobipy as gp
from gurobipy import GRB
from gurobi_ml import add_predictor_constr

## Load Data and Train Model

In [None]:
# Fetch the MNIST digit-recognition data set (train and test splits),
# downloading it under ./MNIST when requested by download=True.
mnist_train = torchvision.datasets.MNIST(root="./MNIST", train=True, download=True)
mnist_test = torchvision.datasets.MNIST(root="./MNIST", train=False, download=True)

# Flatten every image into a single row of floats and rescale by 1/255.
x_train = torch.flatten(mnist_train.data.float(), start_dim=1) / 255.0
x_test = torch.flatten(mnist_test.data.float(), start_dim=1) / 255.0

# Labels that go with the rows above.
y_train = mnist_train.targets
y_test = mnist_test.targets

print(f"Training samples: {len(x_train)}, Test samples: {len(x_test)}")
print(f"Input dimension: {x_train.shape[1]}, Classes: {len(np.unique(y_train))}")

In [None]:
# Fully connected classifier: flattened 28x28 input, two hidden ReLU layers
# of equal width, and a Softmax over the 10 outputs.
n_inputs, n_hidden, n_classes = 28 * 28, 50, 10
nn_model = torch.nn.Sequential(
    torch.nn.Linear(n_inputs, n_hidden),
    torch.nn.ReLU(),
    torch.nn.Linear(n_hidden, n_hidden),
    torch.nn.ReLU(),
    torch.nn.Linear(n_hidden, n_classes),
    torch.nn.Softmax(dim=1),
)

In [None]:
# Training settings passed to the skorch classifier that wraps nn_model.
fit_settings = {
    "max_epochs": 15,
    "lr": 0.1,
    "iterator_train__shuffle": True,
}
clf = NeuralNetClassifier(nn_model, **fit_settings)

# Train on the full training split.
clf.fit(X=x_train, y=y_train)

In [None]:
# Score the fitted classifier on the training split and on the test split.
train_score = clf.score(x_train, y_train)
test_score = clf.score(x_test, y_test)
print(f"Training score: {train_score:.4}")
print(f"Validation set score: {test_score:.4}")

In [None]:
# Rebuild the network from every layer except the last one (the Softmax), so
# the optimization model works with the outputs of the final Linear layer.
nn_regression = torch.nn.Sequential(*list(nn_model.children())[:-1])

## Optimization Problem: Adversarial Example

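Starting from a correctly classified image, search within an L1 ball of radius `delta` around it for a perturbed image that maximizes the margin of a wrong label's output over the correct label's output.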

In [None]:
# Index of the training image used as the example, shown in grayscale.
imageno = 10000
image = mnist_train.data[imageno]
plt.imshow(image, cmap="gray")

In [None]:
# Evaluate the Softmax-free network on the chosen image.  The module is called
# directly (not through .forward) so that any registered hooks run, and under
# no_grad because only the output values are needed here.
with torch.no_grad():
    ex_prob = nn_regression(x_train[imageno, :])

# argsort is ascending: the last entry is the label with the largest output
# (the network's prediction) and the one before it is the runner-up.
sorted_labels = torch.argsort(ex_prob)
right_label = sorted_labels[-1]
wrong_label = sorted_labels[-2]

## Create Gurobi Model

A single Gurobi model is reused for all formulations: before each new formulation is added, the previous predictor constraints are removed with `pred_constr.remove()`.

In [None]:
# Optimization model with solver logging enabled and a time limit per solve.
m = gp.Model()
m.Params.OutputFlag = 1
m.Params.TimeLimit = 120

# Budget on the norm-1 distance between the perturbed and the original image.
delta = 5

# numpy copy of the selected (flattened) image, used as constraint data.
image = x_train[imageno, :].numpy()

# x: perturbed image, y: network outputs, abs_diff: bound on |x - image|.
x = m.addMVar(image.shape, lb=0.0, ub=1.0, name="x")
y = m.addMVar(tuple(ex_prob.shape), lb=-GRB.INFINITY, name="y")
abs_diff = m.addMVar(image.shape, lb=0.0, ub=1.0, name="abs_diff")

# Push the wrong label's output as far above the right label's as possible.
m.setObjective(y[wrong_label] - y[right_label], GRB.MAXIMIZE)

# abs_diff dominates both x - image and image - x elementwise, and its sum
# (the norm-1 distance bound) may not exceed delta.
m.addConstr(abs_diff >= x - image)
m.addConstr(abs_diff >= -x + image)
m.addConstr(abs_diff.sum() <= delta)

## 1. Standard MIP ReLU (Global Optimality)

- Uses piecewise-linear max formulation with binary variables
- Solved to **global optimality**
- Baseline for comparison

In [None]:
# Embed the network in the model with the default ReLU formulation, linking
# the output variables y to the input variables x, then print its statistics.
pred_constr = add_predictor_constr(m, nn_regression, input_vars=x, output_vars=y)
pred_constr.print_stats()

In [None]:
# Solve the MIP formulation, timing only the optimize() call.
start = time.perf_counter()
m.optimize()
mip_time = time.perf_counter() - start

# The model carries a TimeLimit, so the solve can stop before optimality is
# proven while still holding an incumbent.  Report that incumbent and its gap
# instead of discarding them; the sentinels are used only when no solution
# was found at all.
if m.SolCount > 0:
    mip_obj, mip_gap = m.ObjVal, m.MIPGap
else:
    mip_obj, mip_gap = float("inf"), 1.0

print(f"\n{'=' * 60}")
print(
    f"MIP ReLU: Obj = {mip_obj:.4f}, Time = {mip_time:.2f}s, Gap = {mip_gap * 100:.2f}%"
)
if m.Status != GRB.OPTIMAL:
    # Make it visible that the baseline is not a proven global optimum.
    print(f"Warning: solve ended with status {m.Status}; result is not proven optimal")
print(f"{'=' * 60}")

## 2. Sqrt ReLU (Local Optimality)

- Uses `f(x) = (x + sqrt(x²))/2`, mathematically equivalent to ReLU
- **Not smooth**: still non-differentiable at x=0 (since sqrt(x²) = |x|)
- No binary variables; uses nonlinear barrier solver
- Solved to **local optimality** (OptimalityTarget=1)

In [None]:
# Swap formulations: drop what the previous predictor added to the model,
# then embed the same network again with relu_formulation="smooth".
pred_constr.remove()
pred_constr = add_predictor_constr(
    m, nn_regression, input_vars=x, output_vars=y, relu_formulation="smooth"
)
pred_constr.print_stats()

In [None]:
m.update()
print(
    f"Sqrt ReLU formulation: {m.NumVars} vars, {m.NumConstrs} constrs, {m.NumBinVars} binary vars"
)

# Set the optimality target before starting the timer so that only the
# solve itself is measured.
m.Params.OptimalityTarget = 1

start = time.perf_counter()
m.optimize()
sqrt_time = time.perf_counter() - start

# Accept the same statuses as the other local solves in this notebook;
# without LOCALLY_OPTIMAL a successful local solve would be reported as inf.
solved = m.Status in (GRB.OPTIMAL, GRB.LOCALLY_OPTIMAL, GRB.USER_OBJ_LIMIT)
sqrt_obj = m.ObjVal if solved else float("inf")

# The relative gap is only meaningful for finite values and a nonzero
# baseline; otherwise the ratio is nan, misleading, or raises.
if np.isfinite(sqrt_obj) and np.isfinite(mip_obj) and mip_obj != 0:
    gap_text = f"{(sqrt_obj / mip_obj - 1) * 100:+.1f}%"
else:
    gap_text = "n/a"

print(f"\n{'=' * 60}")
print(f"Sqrt ReLU: Obj = {sqrt_obj:.4f}, Time = {sqrt_time:.2f}s")
print(f"Gap to MIP: {gap_text}")
print(f"{'=' * 60}")

## 3. Soft ReLU / Softplus (β=1.0, Local Optimality)

- Uses `f(x) = log(1 + exp(βx))/β`
- **Smooth approximation** of ReLU; differentiable everywhere
- Lower β = smoother but less accurate approximation
- Solved to **local optimality**

In [None]:
# Drop what the previous predictor added to the model.
pred_constr.remove()

# Embed the network again with the soft ReLU formulation, beta=1.0.
pred_constr = add_predictor_constr(
    m, nn_regression, x, y, relu_formulation="soft", soft_relu_beta=1.0
)

m.update()

# Time only the optimize() call.
start = time.perf_counter()
m.optimize()
soft1_time = time.perf_counter() - start

solved = m.Status in (GRB.OPTIMAL, GRB.LOCALLY_OPTIMAL, GRB.USER_OBJ_LIMIT)
soft1_obj = m.ObjVal if solved else float("inf")

# The relative gap is only meaningful for finite values and a nonzero
# baseline; otherwise the ratio is nan, misleading, or raises.
if np.isfinite(soft1_obj) and np.isfinite(mip_obj) and mip_obj != 0:
    gap_text = f"{(soft1_obj / mip_obj - 1) * 100:+.1f}%"
else:
    gap_text = "n/a"

print(f"\n{'=' * 60}")
print(f"Soft ReLU (β=1.0): Obj = {soft1_obj:.4f}, Time = {soft1_time:.2f}s")
print(f"Gap to MIP: {gap_text}")
print(f"{'=' * 60}")

## 4. Soft ReLU / Softplus (β=5.0, Local Optimality)

- Same as above but with **higher β** → closer to ReLU
- Higher β = sharper transition, better approximation
- Solved to **local optimality**

In [None]:
# Drop what the previous predictor added to the model.
pred_constr.remove()

# Embed the network again with the soft ReLU formulation, beta=5.0.
pred_constr = add_predictor_constr(
    m, nn_regression, x, y, relu_formulation="soft", soft_relu_beta=5.0
)
m.update()

# Time only the optimize() call.
start = time.perf_counter()
m.optimize()
soft5_time = time.perf_counter() - start

# Accept the same statuses as the beta=1.0 solve; without LOCALLY_OPTIMAL a
# successful local solve would be reported as inf.
solved = m.Status in (GRB.OPTIMAL, GRB.LOCALLY_OPTIMAL, GRB.USER_OBJ_LIMIT)
soft5_obj = m.ObjVal if solved else float("inf")

# The relative gap is only meaningful for finite values and a nonzero
# baseline; otherwise the ratio is nan, misleading, or raises.
if np.isfinite(soft5_obj) and np.isfinite(mip_obj) and mip_obj != 0:
    gap_text = f"{(soft5_obj / mip_obj - 1) * 100:+.1f}%"
else:
    gap_text = "n/a"

print(f"\n{'=' * 60}")
print(f"Soft ReLU (β=5.0): Obj = {soft5_obj:.4f}, Time = {soft5_time:.2f}s")
print(f"Gap to MIP: {gap_text}")
print(f"{'=' * 60}")

## Summary

In [None]:
def _gap_to_mip(obj):
    """Format the relative difference of ``obj`` to ``mip_obj`` as a signed percentage.

    Returns "n/a" when either value is not finite or the baseline is zero,
    cases in which the ratio would be nan, misleading, or raise
    ZeroDivisionError.
    """
    if not (np.isfinite(obj) and np.isfinite(mip_obj)) or mip_obj == 0:
        return "n/a"
    return f"{(obj / mip_obj - 1) * 100:+.1f}%"


# One row per formulation: label, objective, solve time, gap column text.
summary_rows = [
    ("MIP ReLU (global)", mip_obj, mip_time, "baseline"),
    ("Sqrt ReLU (local)", sqrt_obj, sqrt_time, _gap_to_mip(sqrt_obj)),
    ("Soft ReLU β=1.0 (local)", soft1_obj, soft1_time, _gap_to_mip(soft1_obj)),
    ("Soft ReLU β=5.0 (local)", soft5_obj, soft5_time, _gap_to_mip(soft5_obj)),
]

print("\n" + "=" * 75)
print(f"{'Formulation':<28} {'Objective':<12} {'Time (s)':<12} {'Gap to MIP'}")
print("=" * 75)
for label, obj, seconds, gap in summary_rows:
    print(f"{label:<28} {obj:<12.4f} {seconds:<12.2f} {gap}")
print("=" * 75)

---
Copyright © 2023-2026 Gurobi Optimization, LLC