# Adversarial Example with PyTorch CNN

This notebook demonstrates how to embed a trained PyTorch CNN into a Gurobi model using `gurobi_ml` and construct an adversarial example.

We train a small CNN on MNIST, then, given a correctly classified image, we formulate a MILP that seeks a nearby image (in L1 norm) that flips the predicted label by maximizing the margin between a target wrong class and the correct class.

In [None]:
import os
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import gurobipy as gp
from gurobi_ml import add_predictor_constr
from matplotlib import pyplot as plt

# Reproducibility: one seed shared by PyTorch (weight init, data shuffling) and NumPy.
SEED = 0
device = torch.device("cpu")  # everything in this notebook runs on the CPU
torch.manual_seed(SEED)
np.random.seed(SEED)

## Define CNN model

We use only layers supported by the MILP embedding: Conv2d (padding=0), ReLU, MaxPool2d, Flatten, and Linear.

In [None]:
# Spatial sizes for a 1x28x28 input (padding=0 everywhere):
#   Conv2d 3x3 -> 10x26x26, MaxPool2d(2) -> 10x13x13
#   Conv2d 3x3 -> 20x11x11, MaxPool2d(2) -> 20x5x5, Flatten -> 500 features
# Modules are created in the same order as a single flat nn.Sequential, so weight
# initialization and the Sequential indices (0..9) are unchanged.
feature_layers = [
    nn.Conv2d(1, 10, kernel_size=3, padding=0),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(10, 20, kernel_size=3, padding=0),
    nn.ReLU(),
    nn.MaxPool2d(2),
]
head_layers = [
    nn.Flatten(),
    nn.Linear(20 * 5 * 5, 64),
    nn.ReLU(),
    nn.Linear(64, 10),  # one logit per MNIST digit
]
model = nn.Sequential(*feature_layers, *head_layers)
model.to(device)

## Load MNIST data

Note: The first run may download MNIST. If running in an offline environment, ensure MNIST is cached locally.

In [None]:
# ToTensor turns each image into a float tensor with values in [0, 1].
transform = transforms.ToTensor()
mnist_kwargs = dict(root="./data", download=True, transform=transform)
train_set = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
test_set = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

# Shuffle only the training data; the test order stays fixed so that example
# selection later in the notebook is reproducible.
train_loader = DataLoader(train_set, shuffle=True, batch_size=64)
test_loader = DataLoader(test_set, shuffle=False, batch_size=256)

## Quick training (optional)

A single epoch (the default) is enough for demonstration. Set the `EPOCHS` environment variable to train longer for better accuracy, or replace this cell with code that loads pre-trained weights if you have them.

In [None]:
# Number of passes over the training set; override with the EPOCHS env var.
epochs = int(os.environ.get("EPOCHS", "1"))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(epochs):
    running = 0.0  # summed per-sample loss over the epoch
    for images, targets in train_loader:
        images = images.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(images), targets)
        batch_loss.backward()
        optimizer.step()
        # criterion returns the batch mean (default reduction), so scale by the
        # batch size to accumulate a sum; dividing by the dataset size below
        # then gives a per-sample average.
        running += batch_loss.item() * images.size(0)
    print(f"Epoch {epoch+1}: loss={running/len(train_loader.dataset):.4f}")

# Held-out accuracy, without gradient tracking.
model.eval()
correct = total = 0
with torch.no_grad():
    for images, targets in test_loader:
        predicted = model(images.to(device)).argmax(dim=1).cpu()
        correct += int((predicted == targets).sum())
        total += len(targets)
print(f"Test accuracy: {correct/total:.3f}")

## Select a correctly classified example

We pick the first test example that the model currently classifies correctly. The target wrong label is the class with the second-highest logit, i.e. the highest-scoring class other than the true one.

In [None]:
# Pick the first correctly classified test image. The logits are kept from the
# same forward pass that decided "correctly classified", so the label ranking
# below is guaranteed to agree with that decision. Re-running the model on the
# single image could differ numerically and silently break it.
example_img = None
true_label = None
logits = None
model.eval()
with torch.no_grad():
    for xb, yb in test_loader:
        batch_logits = model(xb.to(device)).cpu()
        correct_idx = (batch_logits.argmax(dim=1) == yb).nonzero(as_tuple=True)[0]
        if correct_idx.numel() > 0:
            idx = int(correct_idx[0])
            example_img = xb[idx : idx + 1].cpu()  # shape (1,1,28,28)
            true_label = int(yb[idx])
            logits = batch_logits[idx]
            break

assert example_img is not None, "No correctly classified example found."

# Predicted class; equal to true_label because the example was selected that way.
right_label = int(logits.argmax())
assert right_label == true_label, "Selected example is not correctly classified."

# Adversarial target: the highest-scoring class other than the true one. Masking
# the true class guarantees wrong_label != right_label. Otherwise the MILP
# objective y[wrong] - y[right] would be identically zero.
masked_logits = logits.clone()
masked_logits[true_label] = float("-inf")
wrong_label = int(masked_logits.argmax())

print(f"True label={true_label}, predicted={right_label}, target wrong={wrong_label}")
plt.imshow(example_img[0, 0].numpy(), cmap="gray")
plt.axis("off")
plt.show()

## Build MILP for adversarial example

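We constrain the perturbed image to an L1 ball of radius `delta` around the original image, with pixels kept in $[0, 1]$. We then maximize the logit margin `y[target] - y[true]`. A strictly positive optimum means the network no longer predicts the true label.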

Important: The MILP embedding expects NHWC input (batch,height,width,channels). We convert the MNIST sample accordingly.

In [None]:
# Convert the sample to the NHWC layout expected by the MILP embedding:
# (1, 1, 28, 28) NCHW -> (1, 28, 28, 1) NHWC.
example_np_nhwc = example_img.permute(0, 2, 3, 1).cpu().numpy()

m = gp.Model()
# L1 radius of the allowed perturbation (tune as desired).
# NOTE(review): 0.001 summed over all 28*28 pixels is a very small budget;
# larger values make a label flip more likely to be feasible.
delta = 0.001

# Decision variables:
# - the perturbed image, with pixels in [0, 1];
# - the network logits;
# - per-pixel |x - x0|, used to linearize the L1 norm.
x = m.addMVar(example_np_nhwc.shape, lb=0.0, ub=1.0, name="x")
y = m.addMVar((1, 10), lb=-gp.GRB.INFINITY, name="y")
abs_diff = m.addMVar(example_np_nhwc.shape, lb=0.0, ub=1.0, name="abs_diff")

# Maximize how far the target class's logit exceeds the true class's logit.
m.setObjective(y[0, wrong_label] - y[0, right_label], sense=gp.GRB.MAXIMIZE)

# ||x - x0||_1 <= delta, expressed as abs_diff >= |x - x0| elementwise plus a budget.
m.addConstr(abs_diff >= x - example_np_nhwc)
m.addConstr(abs_diff >= -x + example_np_nhwc)
m.addConstr(abs_diff.sum() <= delta)

# Encode the trained nn.Sequential so that y is the network's output on x.
pred_constr = add_predictor_constr(m, model, x, y)
pred_constr.print_stats()

In [None]:
# Display the per-layer objects gurobi_ml created for the embedded network.
# The integer indices used when tightening bounds in a later cell refer to this
# list (presumably one entry per nn.Sequential module; verify in the output).
pred_constr.layers

In [None]:
# Warm start: give Gurobi the original, unperturbed image as the start value for x.
x.Start = example_np_nhwc

In [None]:
# Tighten bounds on the inputs of layers 2 and 5 of the embedded network.
# In this architecture those indices appear to be the MaxPool2d layers, which
# are fed by ReLU outputs, so a lower bound of 0 is valid.
# NOTE(review): the upper bound 10.0 is an assumed cap on activation size, not a
# proven bound. If real activations can exceed it, the MILP may cut off valid
# inputs. Confirm, e.g. with interval bound propagation.
for layer_idx in (2, 5):
    layer_input = pred_constr.layers[layer_idx].input
    layer_input.lb = 0.0
    layer_input.ub = 10.0

In [None]:
# Early stopping for maximizing the margin y[wrong] - y[right]:
#   BestBdStop=0  -> stop once the bound shows the margin cannot exceed 0
#                    (no counterexample within the radius);
#   BestObjStop=0 -> stop once an incumbent with margin at least 0 is found
#                    (candidate counterexample).
m.setParam("BestBdStop", 0.0)
m.setParam("BestObjStop", 0.0)
m.optimize()

## Visualize result

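If the solver returned a solution, plot it next to the original image together with the perturbation. Then re-run the PyTorch model on it to check the predicted label.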

In [None]:
if m.SolCount > 0:
    x_adv = x.X.squeeze(0)  # (28,28,1)
    x_adv_img = x_adv[..., 0]
    orig_img = example_np_nhwc.squeeze(0)[..., 0]
    pert = x_adv_img - orig_img

    # The objective is y[wrong] - y[right]. Only a strictly positive margin means
    # the network ranks another class above the true one. A solution with
    # margin <= 0 can still be returned (e.g. possibly from the warm start, or
    # when a stop criterion fires), so do not call it "adversarial" unconditionally.
    margin = m.ObjVal
    is_adversarial = margin > 0

    fig, axs = plt.subplots(1, 3, figsize=(10, 3))
    axs[0].set_title("Original")
    # Fixed [0, 1] range so both images share the same grayscale mapping.
    axs[0].imshow(orig_img, cmap="gray", vmin=0.0, vmax=1.0)
    axs[0].axis("off")
    axs[1].set_title("Adversarial" if is_adversarial else "MILP solution (not adversarial)")
    axs[1].imshow(x_adv_img, cmap="gray", vmin=0.0, vmax=1.0)
    axs[1].axis("off")
    axs[2].set_title("Perturbation")
    # Symmetric limits so that white in the diverging "bwr" map means zero change.
    lim = float(np.abs(pert).max()) or 1.0
    im = axs[2].imshow(pert, cmap="bwr", vmin=-lim, vmax=lim)
    axs[2].axis("off")
    plt.colorbar(im, ax=axs[2], fraction=0.046, pad=0.04)
    plt.show()

    # Check the PyTorch model's prediction on the MILP solution.
    with torch.no_grad():
        t_in = (
            torch.from_numpy(x_adv[None, ...]).float().permute(0, 3, 1, 2).to(device)
        )  # NHWC -> NCHW
        pred_adv = model(t_in).argmax(dim=1).item()
    print(f"MILP margin y[{wrong_label}] - y[{right_label}] = {margin:.6g}")
    print(f"Adversarial predicted label: {pred_adv}")
    if pred_adv == true_label:
        print("The model still predicts the true label; this solution is not adversarial.")
else:
    print("No adversarial example found within given L1 radius.")

In [None]:
# Sanity check of the embedding: pin the MILP input to the original image,
# re-solve for feasibility only, and compare the MILP output variables with the
# PyTorch forward pass (pred_constr.get_error()); the result should be close to 0.
# NOTE: this permanently modifies `m` (bounds on x, objective, parameters). Re-run
# the model-building cells before solving the adversarial problem again.
# Bounds are set without try/except: if pinning fails, the check below is meaningless.
x.lb = example_np_nhwc
x.ub = example_np_nhwc

# Reset parameters first. The BestBdStop/BestObjStop values set for the
# adversarial search would otherwise still apply. With a constant objective of 0,
# the bound test can stop the solve before any solution exists, and
# get_error() needs a solution.
m.resetParams()
m.setObjective(0.0)
m.setParam("TimeLimit", 1.0)
m.optimize()

if m.SolCount == 0:
    print(f"No feasible solution found (status {m.Status}); cannot compute embedding error.")
else:
    err = pred_constr.get_error()
    print("Max error:", float(np.max(np.abs(err.astype(float)))))