# Adversarial example using ONNX

In this example, we demonstrate finding adversarial examples for a neural network using Gurobi and gurobi-ml's ONNX support.

We load a pre-trained MNIST classifier (stored as an ONNX model) and use optimization to find small perturbations to an input image that cause misclassification.

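This example requires:
 - [matplotlib](https://matplotlib.org/)
 - [onnx](https://onnx.ai/)
 - [onnxruntime](https://onnxruntime.ai/) (to check the model's predictions)
 - [keras](https://keras.io/) (only for loading MNIST data; imported here through TensorFlow)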

In [None]:
from matplotlib import pyplot as plt
import numpy as np
import onnx
from tensorflow import keras

import gurobipy as gp
from gurobi_ml import add_predictor_constr

## Load MNIST data

We use Keras only to load the MNIST dataset.

In [None]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale pixels to [0, 1] and flatten each 28x28 image into a vector of 784 values
x_test = x_test.astype("float32") / 255.0
x_test_flat = x_test.reshape(-1, 28 * 28)
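The preprocessing above turns each 8-bit 28x28 image into a row of 784 floats in [0, 1]; later cells undo the flattening with `reshape(28, 28)` to display images. A small self-contained check of both steps, using random synthetic pixels as a stand-in for MNIST (no Keras needed):

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for x_test: 5 grayscale images with uint8 pixels in 0..255.
images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

# Same preprocessing as above: scale to [0, 1], then flatten each image row by row.
scaled = images.astype("float32") / 255.0
flat = scaled.reshape(-1, 28 * 28)

print(flat.shape, flat.dtype, flat.min() >= 0.0, flat.max() <= 1.0)

# Row-major flattening is undone by reshaping a single row back to 28x28.
restored = flat[2].reshape(28, 28)
print(np.array_equal(restored, scaled[2]))
```

The data stays `float32`, which is also the dtype passed to onnxruntime when the model is evaluated.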

## Load pre-trained ONNX model

We load a pre-trained neural network with 2 hidden layers of 50 neurons each and ReLU activations.
The model was trained on MNIST and converted to ONNX format.

In [None]:
onnx_model = onnx.load("mnist_model.onnx")
print("ONNX model loaded successfully")
print(f"Model has {len(onnx_model.graph.node)} operations")
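To make concrete what `add_predictor_constr` has to encode, here is a minimal NumPy sketch of the forward pass of a network with the architecture described above: 784 inputs, two hidden layers of 50 ReLU units, and 10 output scores. The weights are random placeholders, not the trained parameters stored in `mnist_model.onnx`, and writing each layer as `h @ W + b` is our assumption about how the exported graph is organized.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical random weights with the layer sizes described in the text:
# 784 -> 50 -> 50 -> 10. These are NOT the trained parameters of mnist_model.onnx.
layer_sizes = [784, 50, 50, 10]
weights = [
    rng.normal(scale=0.1, size=(n_in, n_out))
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
]
biases = [np.zeros(n_out) for n_out in layer_sizes[1:]]


def relu(z):
    """Elementwise max(0, z)."""
    return np.maximum(z, 0.0)


def forward(x):
    """Map a batch of flattened images (n, 784) to output scores (n, 10)."""
    h = x
    # Hidden layers: affine map followed by ReLU.
    for W, b in zip(weights[:-1], biases[:-1]):
        h = relu(h @ W + b)
    # Output layer: affine map only, giving the logits that `y` represents.
    return h @ weights[-1] + biases[-1]


image = rng.uniform(0.0, 1.0, size=(1, 784))  # stand-in for a normalized MNIST image
logits = forward(image)
print(logits.shape, int(np.argmax(logits)))
```

Roughly speaking, in the optimization model each affine layer becomes linear constraints linking the variables of consecutive layers and each ReLU becomes a piecewise-linear constraint, so the number and width of the hidden layers largely determine the model size reported by `print_stats`.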

## Verify model predictions

Let's verify the model works by making a prediction on a test sample.

In [None]:
# Use onnxruntime for inference
import onnxruntime as ort

session = ort.InferenceSession(
    onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
)
input_name = session.get_inputs()[0].name

# Predict on a test sample
sample_idx = 18
sample = x_test_flat[sample_idx : sample_idx + 1]
prediction = session.run(None, {input_name: sample})[0]

print(f"True label: {y_test[sample_idx]}")
print(f"Predicted: {np.argmax(prediction)}")

# Display the image
plt.imshow(x_test[sample_idx], cmap="gray")
plt.title(f"True: {y_test[sample_idx]}, Predicted: {np.argmax(prediction)}")
plt.axis("off")
plt.show()

## Select an example for adversarial attack

We reuse the test image from the previous cell, which the network classifies correctly, and choose the target label `wrong_label` that we want the perturbed image to be misclassified as.

In [None]:
example = x_test_flat[sample_idx : sample_idx + 1]
right_label = int(y_test[sample_idx])
wrong_label = 8

print(f"Original label: {right_label}")
print(f"Target (wrong) label: {wrong_label}")
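Rather than hard-coding `sample_idx`, one could pick any test image the network already classifies correctly. A minimal sketch, assuming the logits for a batch of images are available as a NumPy array of shape `(n, 10)` (for instance from `session.run`, if the exported model accepts a batch dimension). The helper `correctly_classified` is hypothetical and the logits below are made up for illustration.

```python
import numpy as np


def correctly_classified(logits, labels):
    """Return the indices of samples whose highest-scoring class equals the true label."""
    predicted = np.argmax(logits, axis=1)
    return np.flatnonzero(predicted == labels)


# Made-up logits for 4 images and 10 classes, for illustration only.
logits = np.zeros((4, 10))
logits[0, 3] = 5.0  # predicted 3
logits[1, 7] = 2.0  # predicted 7
logits[2, 1] = 4.0  # predicted 1
logits[3, 9] = 1.0  # predicted 9
labels = np.array([3, 2, 1, 4])

candidates = correctly_classified(logits, labels)
print(candidates)  # indices 0 and 2
```

Any index in `candidates` is a valid choice for `sample_idx`; `wrong_label` then only needs to differ from the true label of that image.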

## Build the optimization model

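We create a Gurobi model to find an adversarial example.
Its variables are the pixels `x` of the perturbed image and the network's output scores (logits) `y`.
The objective is to maximize the difference between the score of the wrong label and the score of the correct label,
subject to the perturbed image staying within L1 distance `delta` of the original.
A positive objective value means the network scores the wrong label above the correct one, so the perturbed image is misclassified.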
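Written out, the model built in the next cell is the following optimization problem. The notation is ours, not gurobi-ml's: $\bar{x}$ is the original image (`example`), $d$ is `abs_diff`, $\delta$ is `delta`, $y_{\text{wrong}}$ and $y_{\text{right}}$ are `y[0, wrong_label]` and `y[0, right_label]`, and $f$ is the function computed by the network.

```latex
\begin{aligned}
\max_{x,\,y,\,d}\quad & y_{\text{wrong}} - y_{\text{right}} \\
\text{s.t.}\quad & d \ge x - \bar{x}, \quad d \ge \bar{x} - x, \\
& \sum_{i=1}^{784} d_i \le \delta, \\
& y = f(x), \\
& 0 \le x \le 1, \quad 0 \le d \le 1.
\end{aligned}
```

The two inequalities on $d$ force $d_i \ge |x_i - \bar{x}_i|$ for every pixel, so the sum constraint bounds the L1 distance $\|x - \bar{x}\|_1$ by $\delta$; conversely, any image with pixels in $[0, 1]$ within that distance is feasible with $d_i = |x_i - \bar{x}_i|$. The constraint $y = f(x)$ is the part added by `add_predictor_constr`.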

In [None]:
m = gp.Model()
delta = 5  # Maximum L1 distance from original image

# Decision variables
x = m.addMVar(example.shape, lb=0.0, ub=1.0, name="x")  # Perturbed image, pixels in [0, 1]
y = m.addMVar((1, 10), lb=-gp.GRB.INFINITY, name="y")  # Network output logits
# abs_diff[i] >= |x[i] - example[i]|; pixel differences never exceed 1
abs_diff = m.addMVar(example.shape, lb=0, ub=1, name="abs_diff")

# Objective: maximize score of wrong label minus score of correct label
m.setObjective(y[0, wrong_label] - y[0, right_label], gp.GRB.MAXIMIZE)

# Constraints: bound L1 distance from original
m.addConstr(abs_diff >= x - example)
m.addConstr(abs_diff >= -x + example)
m.addConstr(abs_diff.sum() <= delta)

# Add neural network constraints
pred_constr = add_predictor_constr(m, onnx_model, x, y)

pred_constr.print_stats()

## Solve the optimization problem

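We solve the model to find an adversarial example.
Since we only need to know whether an adversarial example exists, not the most adversarial one, we stop early: `BestObjStop = 0.0` ends the solve as soon as a solution with objective value of at least 0 is found, and `BestBdStop = 0.0` ends it as soon as the upper bound on the objective drops to 0 or below, which proves that no adversarial example exists within distance `delta`.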

In [None]:
m.Params.BestBdStop = 0.0
m.Params.BestObjStop = 0.0
m.optimize()
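With both parameters set to 0, the result of the solve can be read from a few model attributes. A minimal pure-Python sketch of that logic; `classify_outcome` is a hypothetical helper (not part of gurobipy or gurobi-ml), while `SolCount`, `ObjVal` and `ObjBound` are standard Gurobi model attributes whose values would be passed in.

```python
def classify_outcome(sol_count, obj_val, obj_bound):
    """Interpret a maximization of y[wrong] - y[right].

    Hypothetical helper: pass m.SolCount, m.ObjVal (None if there is no
    solution) and m.ObjBound after m.optimize().
    """
    if sol_count > 0 and obj_val > 0.0:
        # The wrong label scores strictly higher than the right one.
        return "adversarial example found"
    if obj_bound <= 0.0:
        # The upper bound proves that no perturbation within delta
        # can make the wrong label score higher than the right one.
        return "no adversarial example within delta"
    return "undecided"


print(classify_outcome(1, 2.7, 5.1))
print(classify_outcome(1, -3.0, -0.4))
print(classify_outcome(0, None, 1.0))
```

The "undecided" branch covers, for example, a solve stopped by `BestObjStop` at a solution whose objective is exactly 0, where the two labels tie.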

## Display the adversarial example

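`get_error` returns the absolute differences between the output variables `y` in Gurobi's solution and the network's prediction for the input `x`; values close to zero confirm that the network is modeled accurately.
We then display the perturbed image next to the original and check its classification with onnxruntime.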

In [None]:
pred_constr.get_error()

In [None]:
adversarial_image = x.X.reshape(28, 28)

# Verify classification
adv_flat = x.X.reshape(1, -1).astype(np.float32)
adv_prediction = session.run(None, {input_name: adv_flat})[0]
predicted_label = np.argmax(adv_prediction)

# Display original and adversarial images
fig, axes = plt.subplots(1, 3, figsize=(12, 4))

axes[0].imshow(example.reshape(28, 28), cmap="gray")
axes[0].set_title(f"Original (label: {right_label})")
axes[0].axis("off")

axes[1].imshow(adversarial_image, cmap="gray")
axes[1].set_title(f"Adversarial (classified as: {predicted_label})")
axes[1].axis("off")

# Show difference
diff = np.abs(adversarial_image - example.reshape(28, 28))
axes[2].imshow(diff, cmap="hot")
axes[2].set_title(f"Difference (L1: {diff.sum():.2f})")
axes[2].axis("off")

plt.tight_layout()
plt.show()
if m.ObjVal > 0.0:
    print("\nAdversarial example found!")
    print(f"Original label: {right_label}")
    print(f"Predicted label: {predicted_label}")
    print(f"L1 distance: {diff.sum():.2f}")
else:
    print("No adversarial example exists within the specified distance bound.")
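A positive objective guarantees that the correct label is no longer the top-scoring class, but not that the network predicts `wrong_label` specifically: a third class could score higher still. A small NumPy illustration with made-up logits (not outputs of the actual model); `check_adversarial` is a hypothetical helper.

```python
import numpy as np


def check_adversarial(logits, right_label, wrong_label, perturbed, original, delta):
    """Return (objective value, predicted label, misclassified?, within L1 budget?)."""
    logits = np.asarray(logits).ravel()
    margin = logits[wrong_label] - logits[right_label]  # objective maximized by Gurobi
    predicted = int(np.argmax(logits))
    l1 = np.abs(np.asarray(perturbed) - np.asarray(original)).sum()
    # Small tolerance for floating-point round-off in the L1 distance.
    return margin, predicted, predicted != right_label, l1 <= delta + 1e-6


# Made-up logits: class 8 beats the true class 3, but class 5 scores highest.
logits = np.array([0.1, -1.0, 0.3, 1.2, 0.0, 2.5, -0.2, 0.4, 1.9, -0.5])
original = np.zeros(784)
perturbed = original.copy()
perturbed[:10] = 0.3  # L1 distance of about 3.0

print(check_adversarial(logits, 3, 8, perturbed, original, delta=5))
```

Here the objective is positive and the image is misclassified, yet the predicted label is 5 rather than 8, which is why the final cell reports the predicted label instead of assuming it equals `wrong_label`.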

copyright © 2023-2025 Gurobi Optimization, LLC