### Training robust deep learning model using IBP (MNIST)

This notebook uses IBP to train robust models, following the approach of [Gowal et al., 2019](https://arxiv.org/pdf/1810.12715) with the improvements proposed by [Shi et al., 2021](https://proceedings.neurips.cc/paper/2021/hash/988f9153ac4fd966ea302dd9ab9bae15-Abstract.html).
Fundamentally, the code is based on [this script](https://github.com/google-deepmind/interval-bound-propagation/blob/master/examples/train.py) from [Gowal et al., 2019](https://arxiv.org/pdf/1810.12715) and is translated into the JAX framework.

1) Data loading and train/test split. 
2) IBP initialization (Shi et al., 2021)
3) Model definition
4) Loss function
5) Evaluation function
6) Training & Evaluation

#### Run using Papermill
```bash
papermill ibp_mnist.ipynb -f PARAMETERS_FILE.yaml OUTPUT_FILE.ipynb
```
For example:
```bash
papermill ibp_mnist.ipynb -p dataset FashionMNIST ibp_fashion_mnist_out.ipynb
```
or
```bash
papermill ibp_mnist.ipynb -f emnist_params.yaml ibp_emnist_out.ipynb
```


In [None]:
import math
import pickle
from functools import partial
from typing import Any, Sequence

import flax.training.train_state
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from jax.nn.initializers import Initializer
from torchvision.datasets import EMNIST, MNIST, FashionMNIST

from formalax import Box, ibp

*Parameter Cell*: Parameterize and execute this notebook using [papermill](https://papermill.readthedocs.io/en/latest/).
- Supported datasets: MNIST, EMNIST and FashionMNIST.
- `eps`: Robustness perturbation radius.
- `num_epochs`: A 3-tuple of epoch numbers: `(prepare, warmup, main)`.
  The training process first runs `prepare` epochs without the robustness loss, then 
  linearly increases the perturbation radius `eps` from zero to the final value in
  the `warmup` epochs and finally trains with the full `eps` for `main` epochs.
- `model_arch`: Architecture of the neural network to train.
  The available layers are:
    - `("conv", filters, kernel_size)`: A convolutional layer with a square kernel, followed by a batch normalization layer.
    - `("relu",)`: A ReLU layer.
    - `("avg_pool", kernel_size, stride)`: An average pooling layer.
    - `("dense", features)`: A dense layer, followed by a batch normalization layer.
  Flattening layers are added automatically.
  Implicitly, the final layer is always a dense layer with the number of classes as features.
  You do not have to include this layer.

For EMNIST, we use the `bymerge` split.

In [None]:
# Dataset to train on: "MNIST", "EMNIST" or "FashionMNIST".
dataset: str = "MNIST"
# Final robustness perturbation radius.
eps: float = 0.1
# Step size of the Adam optimizer.
learning_rate: float = 0.001
# Weight of the robustness loss; the natural loss (including the L2 penalty)
# is weighted by `1 - robustness_loss_weight`.
robustness_loss_weight: float = 0.5
# Coefficient of the L2 penalty that is added to the natural loss.
weight_decay: float = 1e-5
# Number of training examples per gradient step.
batch_size: int = 128
# Epoch counts for the three training phases: (prepare, warmup, main).
num_epochs: tuple[int, int, int] = (0, 20, 50)
# Seed of the JAX PRNG key used for initialization and shuffling.
seed: int = 0

# Hidden layers of the network, in order. The final dense layer mapping to the
# number of classes is appended implicitly by the model.
model_arch: list[tuple[str, ...]] = [
    ("conv", 32, 3),
    ("relu",),
    ("avg_pool", 2, 2),
    ("conv", 64, 3),
    ("relu",),
    ("avg_pool", 2, 2),
    ("dense", 256),
    ("relu",),
]

# Directory the datasets are downloaded to / loaded from.
data_dir: str = "../.datasets"
# Batch size used when evaluating (certified) accuracy on a whole dataset.
eval_batch_size: int = 2**11

#### 1) Data
We use PyTorch (torchvision) to load the datasets and
normalize the pixel values to [0, 1].

In [None]:
match dataset:
    case "MNIST":
        dataset_cls = MNIST
        input_shape = (28, 28, 1)
        num_classes = 10
    case "EMNIST":
        dataset_cls = partial(EMNIST, split="bymerge")
        input_shape = (28, 28, 1)
        num_classes = 42
    case "FashionMNIST":
        dataset_cls = FashionMNIST
        input_shape = (28, 28, 1)
        num_classes = 10
    case _:
        raise ValueError(f"Unknown dataset: {dataset}.")


def load_data() -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Fetch the selected dataset and return its training and test splits.

    The dataset is downloaded into `data_dir` if necessary. Images are
    divided by 255 to normalize the pixel values; labels are returned
    unchanged. All four arrays are obtained through `.numpy()`, i.e. they are
    NumPy arrays.

    Returns: Training set images, training set labels, test set images,
        test set labels
    """
    train_split, test_split = (
        dataset_cls(data_dir, train=is_train, download=True)
        for is_train in (True, False)
    )

    # Rescale the raw pixel values.
    train_images = train_split.data.numpy() / 255.0
    test_images = test_split.data.numpy() / 255.0

    train_labels = train_split.targets.numpy()
    test_labels = test_split.targets.numpy()
    return train_images, train_labels, test_images, test_labels

#### 2) IBP Initialization
This code is based on https://github.com/jax-ml/jax/blob/5a2e5a5a94f78c96871a63a8a730164f1445d7a6/jax/_src/nn/initializers.py


In [None]:
def _compute_fans(
    shape: Sequence[int],
    in_axis: int | Sequence[int] = -2,
    out_axis: int | Sequence[int] = -1,
    batch_axis: int | Sequence[int] = (),
) -> tuple[float, float]:
    """
    Compute effective input and output sizes for a linear or convolutional layer.

    Axes not in in_axis, out_axis, or batch_axis are assumed to constitute the
    "receptive field" of a convolution (kernel spatial dimensions).
    """
    if len(shape) <= 1:
        raise ValueError(
            f"Can't compute input and output sizes of a {len(shape)}"
            "-dimensional weights tensor. Must be at least 2D."
        )

    if isinstance(in_axis, int):
        in_size = shape[in_axis]
    else:
        in_size = math.prod([shape[i] for i in in_axis])
    if isinstance(out_axis, int):
        out_size = shape[out_axis]
    else:
        out_size = math.prod([shape[i] for i in out_axis])
    if isinstance(batch_axis, int):
        batch_size = shape[batch_axis]
    else:
        batch_size = math.prod([shape[i] for i in batch_axis])
    receptive_field_size = math.prod(shape) / in_size / out_size / batch_size
    fan_in = in_size * receptive_field_size
    fan_out = out_size * receptive_field_size
    return fan_in, fan_out


def ibp_init(
    in_axis: int | Sequence[int] = -2,
    out_axis: int | Sequence[int] = -1,
    batch_axis: Sequence[int] = (),
) -> Initializer:
    """
    Build the IBP-specific weight initializer of Shi et al. (2021).

    Weights are drawn from a normal distribution truncated to (-2, 2) and
    rescaled so that their standard deviation is `sqrt(2 * pi) / fan_in`.

    Args:
      in_axis: axis or sequence of axes of the input dimension in the weights
        array.
      out_axis: axis or sequence of axes of the output dimension in the weights
        array.
      batch_axis: axis or sequence of axes in the weight array that should be
        ignored.

    Returns: An initializer `init(key, shape, dtype)`. It raises `ValueError`
        for non-floating dtypes.
    """
    # Standard deviation of a standard normal truncated to (-2, 2); dividing
    # by it makes the truncated samples have unit standard deviation.
    truncated_normal_std = 0.87962566103423978

    def init(key, shape, dtype=jnp.float32):
        dtype = jax.dtypes.canonicalize_dtype(dtype)
        fan_in, _ = _compute_fans(shape, in_axis, out_axis, batch_axis)

        if not jnp.issubdtype(dtype, jnp.floating):
            raise ValueError(f"Unsupported dtype: {dtype}")

        target_std = jnp.sqrt(2 * jnp.pi) / fan_in
        scale = target_std / jnp.array(truncated_normal_std, dtype)
        return jax.random.truncated_normal(key, -2, 2, shape, dtype) * scale

    return init

#### 3) Model definition 

In [None]:
class CNN(nn.Module):
    """CNN model definition with the architecture from `net_arch`."""

    @nn.compact
    def __call__(self, x, training=True):
        x = x.reshape(-1, *input_shape)

        prev_is_image = len(input_shape) > 1
        for layer in model_arch:
            match layer:
                case ("conv", filters, kernel_size):
                    # Omit bias, as batch norm would effectively remove it anyway
                    x = nn.Conv(
                        features=filters,
                        kernel_size=(kernel_size, kernel_size),
                        use_bias=False,
                        kernel_init=ibp_init(),
                    )(x)
                    x = nn.BatchNorm(use_running_average=not training)(x)
                case ("relu",):
                    x = nn.relu(x)
                case ("avg_pool", window_size, stride):
                    x = nn.avg_pool(
                        x,
                        window_shape=(window_size, window_size),
                        strides=(stride, stride),
                    )
                case ("dense", features):
                    if prev_is_image:
                        x = x.reshape(x.shape[0], -1)
                        prev_is_image = False
                    x = nn.Dense(
                        features=features, use_bias=False, kernel_init=ibp_init()
                    )(x)
                    x = nn.BatchNorm(use_running_average=not training)(x)
        return nn.Dense(features=num_classes)(x)

#### 4) Loss Function


In [None]:
def robustness_loss(model_fn, images, labels_one_hot, eps):
    """
    Cross-entropy loss on the worst-case logits obtained from IBP bounds.

    Args:
      model_fn: the model with its variables already bound; called by IBP as
        `model_fn(inputs, training=False)`.
      images: batch of clean inputs.
      labels_one_hot: one-hot encoded labels.
      eps: perturbation radius; perturbed inputs are clipped to [0, 1].

    Returns: The pair `(robust_loss, robust_scores)`, where `robust_scores`
        are the worst-case logits.
    """
    # Admissible inputs: an eps-box around each image, clipped to [0, 1].
    perturbation_set = Box(
        jnp.clip(images - eps, 0.0, 1.0),
        jnp.clip(images + eps, 0.0, 1.0),
    )

    # Do not update batch norm stats during IBP => training=False
    logits_lb, logits_ub = ibp(model_fn)(perturbation_set, training=False)

    # Worst case for the classifier: the true class takes its lower bound,
    # every other class takes its upper bound.
    worst_case_logits = (
        labels_one_hot * logits_lb + (1.0 - labels_one_hot) * logits_ub
    )

    per_example_log_prob = jnp.sum(
        labels_one_hot * nn.log_softmax(worst_case_logits), axis=-1
    )
    return -jnp.mean(per_example_log_prob), worst_case_logits


def natural_loss(model_fn, images, labels_one_hot):
    """
    Computes the cross-entropy loss on the natural/clean data.

    The model is run in training mode with a mutable `batch_stats`
    collection, so the updated batch statistics are returned alongside the
    loss.

    Args:
      model_fn: the model with its variables already bound; called as
        `model_fn(images, training=True, mutable=["batch_stats"])` and
        expected to return `(scores, variable_updates)`.
      images: batch of clean inputs.
      labels_one_hot: one-hot encoded labels.

    Returns: The pair `(cross_entropy, (scores, variable_updates))`.
    """
    scores, variable_updates = model_fn(images, training=True, mutable=["batch_stats"])
    log_probs = jax.nn.log_softmax(scores)
    # Sum over the class axis and average over the batch only. Averaging over
    # all elements would scale the loss by 1 / num_classes and make it
    # inconsistent with the robustness loss.
    cross_entropy = -jnp.mean(jnp.sum(labels_one_hot * log_probs, axis=-1))
    return cross_entropy, (scores, variable_updates)


def l2_loss(param):
    """Mean of the squared entries of `param` (L2 regularization term)."""
    squared = jnp.square(param)
    return jnp.mean(squared)


def loss_fn(state, batch, model_params, *, robustness_eps):
    """
    Combined training objective: natural loss, L2 penalty and robustness loss.

    The natural loss plus `weight_decay` times the L2 penalty is weighted by
    `1 - robustness_loss_weight`; the robustness loss is weighted by
    `robustness_loss_weight`.

    Args:
      state: train state providing `apply_fn` and `batch_stats`.
      batch: dict with the entries "images" and "labels".
      model_params: model parameters the loss is evaluated at.
      robustness_eps: perturbation radius for the robustness loss. It is
        compared in a Python `if`, so it has to be a concrete number. For
        values <= 0 the robustness loss and accuracy are reported as 0.0.

    Returns: `(loss, (nat_loss, rob_loss, nat_acc, robust_acc, updates))`
        where `updates` are the variable updates returned by the natural
        forward pass.
    """
    images, labels = batch["images"], batch["labels"]
    labels_one_hot = jax.nn.one_hot(labels, num_classes)

    # Bind the variables once so the loss helpers only have to pass inputs.
    variables = {"params": model_params, "batch_stats": state.batch_stats}
    model_fn = partial(state.apply_fn, variables)

    # Standard forward pass on the clean data
    nat_loss, (scores, updates) = natural_loss(model_fn, images, labels_one_hot)
    nat_acc = jnp.mean(jnp.argmax(scores, axis=-1) == labels)

    # NOTE(review): the sum runs over all parameter leaves but is divided by
    # len(model_params) - presumably the number of top-level entries, not the
    # number of leaves. Confirm that this normalization is intended.
    leaves = jax.tree.leaves(model_params)
    l2_reg = sum(l2_loss(leaf) for leaf in leaves) / len(model_params)

    # we may set robustness_eps to 0 during warmup
    if robustness_eps > 0:
        rob_loss, robust_scores = robustness_loss(
            model_fn, images, labels_one_hot, robustness_eps
        )
        robust_acc = jnp.mean(jnp.argmax(robust_scores, axis=-1) == labels)
    else:
        rob_loss = robust_acc = 0.0

    clean_term = (1 - robustness_loss_weight) * (nat_loss + weight_decay * l2_reg)
    loss = clean_term + robustness_loss_weight * rob_loss
    return loss, (nat_loss, rob_loss, nat_acc, robust_acc, updates)

#### 5) Evaluation Functions

In [None]:
@jax.jit
def eval_step(state, batch):
    """
    Run the model in inference mode (`training=False`) on `batch["images"]`.

    Returns: The scores produced by `state.apply_fn`.
    """
    variables = {"params": state.params, "batch_stats": state.batch_stats}
    return state.apply_fn(variables, batch["images"], training=False)


def accuracy(state, batch):
    """Fraction of `batch` whose highest-scoring class equals the label."""
    predicted_classes = jnp.argmax(eval_step(state, batch), axis=-1)
    return jnp.mean(predicted_classes == batch["labels"])


# Certified accuracy computation using IBP bounds
@jax.jit
def certified_accuracy(state, batch, eps):
    """
    Fraction of `batch` that IBP certifies as correctly classified.

    An example counts as certified if the IBP lower bound of its true-class
    logit is strictly greater than the IBP upper bounds of all other logits
    for every input in the eps-box around it (clipped to [0, 1]).

    Args:
      state: train state providing `apply_fn`, `params` and `batch_stats`.
      batch: dict with the entries "images" and "labels".
      eps: perturbation radius.
    """
    images, labels = batch["images"], batch["labels"]

    # Perturbation set, clipped to [0, 1]
    input_box = Box(
        jnp.clip(images - eps, 0.0, 1.0),
        jnp.clip(images + eps, 0.0, 1.0),
    )

    # Output bounds with IBP, with the model in inference mode
    variables = {"params": state.params, "batch_stats": state.batch_stats}
    out_lb, out_ub = ibp(state.apply_fn)(variables, input_box, training=False)

    label_column = labels[:, None]

    # Lower bound of the true-class logit
    true_class_lb = jnp.take_along_axis(out_lb, label_column, axis=-1).squeeze(
        axis=-1
    )

    # Largest upper bound among the other classes; the true class is masked
    # out with -inf so that it can never be the maximum.
    is_true_class = jnp.arange(out_lb.shape[-1]) == label_column
    strongest_rival_ub = jnp.max(
        jnp.where(is_true_class, -jnp.inf, out_ub), axis=-1
    )

    return jnp.mean(true_class_lb > strongest_rival_ub)


def full(eval_fn, state, images, labels, batch_size, **kwargs):
    """
    Evaluate a per-batch mean metric (e.g. accuracy) on an entire dataset.

    The dataset is processed in chunks of `batch_size`. Each batch result is
    weighted by the number of examples in the batch, so a smaller last batch
    is accounted for correctly.

    Args:
      eval_fn: called as `eval_fn(state, batch, **kwargs)` with `batch` being
        a dict with the entries "images" and "labels".
      state: passed through to `eval_fn`.
      images: all inputs.
      labels: all labels, aligned with `images`.
      batch_size: number of examples per evaluation batch.
      **kwargs: extra keyword arguments for `eval_fn`.

    Returns: The example-weighted average of the batch metrics.
    """
    num_examples = len(images)

    def weighted_batch_metrics():
        for start in range(0, num_examples, batch_size):
            stop = start + batch_size
            chunk = {"images": images[start:stop], "labels": labels[start:stop]}
            yield eval_fn(state, chunk, **kwargs) * len(chunk["images"])

    return sum(weighted_batch_metrics(), 0.0) / num_examples

#### 6) Training

Training step function:
- Pass inputs through model 
- Compute loss (cross entropy using log_softmax outputs + l2_regularization + robust loss)
- Compute gradients
- Update params in training state via gradients


In [None]:
# Root PRNG key, derived from the `seed` parameter.
rng_key = jax.random.PRNGKey(seed)

In [None]:
# Load both splits and bring the images into the shape (N, *input_shape).
x_train, y_train, x_test, y_test = load_data()
x_train = x_train.reshape((-1, *input_shape))
x_test = x_test.reshape((-1, *input_shape))

In [None]:
model = CNN()
# Use a fresh subkey for initialization and keep `rng_key` for later splits.
rng_key, subkey = jax.random.split(rng_key)
# Initialize both params and running batch statistics
# (a single dummy input of ones is enough for the initialization).
model_variables = model.init(subkey, jnp.ones((1, *input_shape)), training=False)

In [None]:
# Adam optimizer with a constant learning rate. The `weight_decay` parameter
# is not passed here; it enters training through the L2 term of the loss.
optim = optax.adam(learning_rate)

In [None]:
class TrainState(flax.training.train_state.TrainState):
    """Flax train state that additionally carries the batch statistics."""

    # Contents of the "batch_stats" variable collection of the model.
    batch_stats: Any


# `model_variables` is unpacked into keyword arguments of `create`
# (the loss reads `state.params` and `state.batch_stats` from the result).
train_state = TrainState.create(apply_fn=model.apply, **model_variables, tx=optim)

In [None]:
@partial(jax.jit, static_argnames="robustness_eps")
def train_step(state, batch, *, robustness_eps):
    """
    Perform one gradient step on `batch`.

    `robustness_eps` is a static argument of the jitted function, so it is
    available as a concrete Python value inside `loss_fn`.

    Returns: `(new_state, loss, nat_loss, robust_loss, nat_acc, robust_acc)`
        where `new_state` holds the updated parameters, optimizer state and
        batch statistics.
    """
    objective = partial(loss_fn, state, batch, robustness_eps=robustness_eps)
    (loss, aux), grads = jax.value_and_grad(objective, has_aux=True)(state.params)
    nat_loss, robust_loss, nat_acc, robust_acc, updates = aux

    # Apply the gradients, then store the batch statistics from the forward pass.
    new_state = state.apply_gradients(grads=grads).replace(
        batch_stats=updates["batch_stats"]
    )
    return new_state, loss, nat_loss, robust_loss, nat_acc, robust_acc

Running the training loop below for EMNIST takes approx. **80 min** on a machine without a GPU.

In [None]:
prepare_epochs, warmup_epochs, _ = num_epochs
total_epochs = sum(num_epochs)
# Number of full batches per epoch; a trailing incomplete batch is dropped.
epoch_len = len(x_train) // batch_size

# Number of training steps between two log/evaluation points (about five per
# epoch). Clamped to at least 1 so that epochs with fewer than five batches
# do not cause a division by zero below.
log_epochs = max(1, epoch_len // 5)

for epoch in range(total_epochs):
    # Perturbation radius schedule: 0 during the prepare phase, a linear
    # ramp during warmup, and the full `eps` afterwards.
    if epoch < prepare_epochs:
        eps_ = 0.0
    elif epoch < prepare_epochs + warmup_epochs:
        eps_ = eps * (epoch - prepare_epochs + 1) / (warmup_epochs + 1)
    else:
        eps_ = eps

    # Running averages over the steps since the last log point
    running_loss = running_nat_loss = running_rob_loss = running_nat_acc = (
        running_rob_acc
    ) = 0.0

    # Reshuffle the training set every epoch
    rng_key, subkey = jax.random.split(rng_key)
    train_idx = jax.random.permutation(subkey, len(x_train))
    for i in range(epoch_len):
        batch_idx = train_idx[i * batch_size : (i + 1) * batch_size]
        batch_images = x_train[batch_idx]
        batch_labels = y_train[batch_idx]

        batch = {"images": batch_images, "labels": batch_labels}
        train_state, loss, nat_loss, rob_loss, nat_acc, rob_acc = train_step(
            train_state, batch, robustness_eps=eps_
        )

        running_loss += loss * (1 / log_epochs)
        running_nat_loss += nat_loss * (1 / log_epochs)
        running_rob_loss += rob_loss * (1 / log_epochs)
        running_nat_acc += nat_acc * (1 / log_epochs)
        running_rob_acc += rob_acc * (1 / log_epochs)

        if i % log_epochs == log_epochs - 1:
            # The test metrics always use the final `eps`, not the current
            # warmup value `eps_`.
            test_nat_acc = full(accuracy, train_state, x_test, y_test, eval_batch_size)
            test_rob_acc = full(
                certified_accuracy,
                train_state,
                x_test,
                y_test,
                eval_batch_size,
                eps=eps,
            )
            print(
                f"[Epoch {epoch + 1}, {100 * i / epoch_len:3.0f}% (eps={eps_:.4f})] "
                f"Loss: {running_loss:.6f}, "
                f"Nat. Loss: {running_nat_loss:.6f}, "
                f"Rob. Loss: {running_rob_loss:.6f}, "
                f"Train Nat. Accuracy: {running_nat_acc * 100:.2f}%, "
                f"Train Rob. Accuracy: {running_rob_acc * 100:.2f}%, "
                f"Test Nat. Accuracy: {test_nat_acc * 100:.2f}%, "
                f"Test Rob. Accuracy: {test_rob_acc * 100:.2f}%"
            )
            running_loss = running_nat_loss = running_rob_loss = running_nat_acc = (
                running_rob_acc
            ) = 0.0

In [None]:
# Final natural and certified accuracy on the training and test sets.
train_acc = full(accuracy, train_state, x_train, y_train, eval_batch_size)
train_cert_acc = full(
    certified_accuracy, train_state, x_train, y_train, eval_batch_size, eps=eps
)
test_acc = full(accuracy, train_state, x_test, y_test, eval_batch_size)
test_cert_acc = full(
    certified_accuracy, train_state, x_test, y_test, eval_batch_size, eps=eps
)

print("=" * 80)
print(
    f"Training Set: "
    f"Natural Accuracy {train_acc * 100:.2f}%, "
    f"Certified Accuracy {train_cert_acc * 100:.2f}%"
)
# Report the test-set metrics here (not the training-set ones again).
print(
    f"Test Set:     "
    f"Natural Accuracy {test_acc * 100:.2f}%, "
    f"Certified Accuracy {test_cert_acc * 100:.2f}%"
)

In [None]:
# Persist the trained parameters as a pickle file named after the dataset.
# Only `train_state.params` is written; the batch statistics are not.
params_path = f"robust_{dataset.lower()}_cnn_flax.pkl"
with open(params_path, "wb") as params_file:
    pickle.dump(train_state.params, params_file)