# NN surrogate for the Rosenbrock function

This notebook provides an example of training a neural network as a surrogate model to an analytic function, and running the Airobas verification pipeline to determine the stability of the trained model.

We use the [Rosenbrock function](https://fr.wikipedia.org/wiki/Fonction_de_Rosenbrock) as our toy use case. 

## Mathematical definition

The Rosenbrock function is an N-dimensional non-convex function such that:
$$
f(\mathbf{x}) = \sum_{i=1}^{N-1} \left[ 100 \left(x_{i+1} - x_i^2 \right)^2 + (1-x_i)^2 \right] \quad \text{where} \quad \mathbf{x} = [x_1, \ldots, x_N] \in \mathbb{R}^N.
$$

## SMT library

For the surrogate model training step, we also encourage readers to look at the Surrogate Modeling Toolbox ([SMT](https://github.com/SMTorg/smt/blob/master/tutorial/SMT_Tutorial.ipynb)) and its tutorials, which provide a range of surrogate options for the same function.

Some utility functions used in this tutorial are borrowed from the SMT library, e.g., the creation of the training/testing datasets and the 3D plotting routine, which are directly adapted from the code of the SMT tutorial. Special thanks to the SMT creators for allowing us to reuse their code.

## Useful imports

In [None]:
import sys

sys.path.append("../")  # Import the airobas lib, edit the path accordingly if needed.
from IPython.display import clear_output

clear_output()
import random

import keras
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from keras.layers import Dense
from keras.models import Sequential
from keras.optimizers import Adam
from keras.regularizers import l2
from rosenbrock_verification import decomon_computation, image_dump_folder
from sklearn.metrics import mean_squared_error, root_mean_squared_error

In [None]:
# 3D plot routine


def plot_3d(xt, yt, xtest, ytest, fun, name_figure="Rosenbrock"):
    """Plot ``fun`` as a 3D surface over [-2, 2] x [-2, 2], with optional sample points.

    Args:
        xt: Training inputs, read as ``xt[:, 0]`` and ``xt[:, 1]``; ``None`` to
            draw no training points.
        yt: Training outputs used as z coordinates (only read when ``xt`` is given).
        xtest: Validation inputs, same layout as ``xt``; ``None`` to draw none.
        ytest: Validation outputs used as z coordinates (only read when
            ``xtest`` is given).
        fun: Callable evaluated once per grid node on an array of shape ``(1, 2)``.
            Each call must yield a single value, since the results are reshaped
            to the 50 x 50 grid.
        name_figure: Title of the figure.
    """
    n_grid = 50
    axis = np.linspace(-2, 2, n_grid)
    values = np.array([fun(np.array([[x0, x1]])) for x0 in axis for x1 in axis])
    # The evaluation order above is x0-major; transposing aligns the values with
    # the (row = x2, column = x1) layout produced by np.meshgrid.
    surface = values.reshape((n_grid, n_grid)).T
    X, Y = np.meshgrid(axis, axis)

    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(projection="3d")
    ax.plot_surface(X, Y, surface, cmap=matplotlib.colormaps["viridis"], linewidth=0, antialiased=False, alpha=0.5)

    has_labelled_points = False
    if xt is not None:
        ax.scatter(xt[:, 0], xt[:, 1], yt, zdir="z", marker="x", c="b", s=200, label="Training point")
        has_labelled_points = True
    if xtest is not None:
        ax.scatter(xtest[:, 0], xtest[:, 1], ytest, zdir="z", marker=".", c="k", s=200, label="Validation point")
        has_labelled_points = True

    ax.set_title(name_figure)
    ax.set_xlabel("x1")
    ax.set_ylabel("x2")
    # The surface carries no label, so only request a legend when at least one
    # labelled scatter was drawn; otherwise Matplotlib warns about an empty legend.
    if has_labelled_points:
        ax.legend()

In [None]:
from smt.problems import Rosenbrock

In [None]:
# Draw the 2-D Rosenbrock surface alone: no training or validation points are passed.
plot_3d(None, None, None, None, Rosenbrock(ndim=2))
# Clear the output produced by this cell so far.
clear_output()

## Training a neural network as surrogate to the Rosenbrock function

### Data generation

In order to train the neural network, we first create a training and testing dataset. 

We implement 2 different methods to build these datasets:

- A basic grid-based solution and
- One based on Latin Hypercube Sampling (LHS; see SMT documentation [here](https://smt.readthedocs.io/en/latest/_src_docs/sampling_methods/lhs.html)) to generate quasi-random sampling distribution.


In [None]:
# Grid-based sampling


def create_points_grid(grid_size: int = 51, fraction_training: float = 0.2):
    """Build training/testing sets from a regular grid over [-2, 2] x [-2, 2].

    The 2-D Rosenbrock function is evaluated on every node of a
    ``grid_size`` x ``grid_size`` grid. A random subset of the nodes becomes
    the training set and all remaining nodes form the testing set.

    Args:
        grid_size: Number of grid nodes along each axis.
        fraction_training: Fraction of the grid nodes drawn (without
            replacement) for training; the count is truncated with ``int``.

    Returns:
        ``(xt, yt, xtest, ytest, fun)``: training inputs/outputs, testing
        inputs/outputs and the Rosenbrock problem instance used to label them.
    """
    fun = Rosenbrock(ndim=2)
    axis = np.linspace(-2, 2, grid_size)
    points = [[x0, x1] for x0 in axis for x1 in axis]
    # Keep only the first row of each single-point evaluation. The original code
    # did this for the test targets only, which left the training targets with
    # one extra axis compared to the test targets.
    values = [fun(np.array([point]))[0] for point in points]

    n_points = len(points)
    training_indexes = set(random.sample(range(n_points), k=int(fraction_training * n_points)))

    xt = np.array([points[i] for i in training_indexes])
    yt = np.array([values[i] for i in training_indexes])
    xtest = np.array([points[i] for i in range(n_points) if i not in training_indexes])
    ytest = np.array([values[i] for i in range(n_points) if i not in training_indexes])
    return xt, yt, xtest, ytest, fun

In [None]:
from smt.sampling_methods import LHS

In [None]:
# LHS-based sampling


def create_points(n_training: int = 20, n_test=200):
    """Build training/testing sets for the 2-D Rosenbrock function with LHS.

    Args:
        n_training: Number of points requested from the training sampler.
        n_test: Number of points requested from the validation sampler.

    Returns:
        ``(xt, yt, xtest, ytest, fun)``: training inputs/outputs, validation
        inputs/outputs and the Rosenbrock problem instance used to label them.
    """
    fun = Rosenbrock(ndim=2)

    def _sample_and_evaluate(n_points):
        # A fresh sampler with random_state=1 is created for each design so
        # that the same LHS points are obtained on every run.
        sampler = LHS(xlimits=fun.xlimits, criterion="ese", random_state=1)
        inputs = sampler(n_points)
        return inputs, fun(inputs)

    # Training design first, then the validation design.
    xt, yt = _sample_and_evaluate(n_training)
    xtest, ytest = _sample_and_evaluate(n_test)
    return xt, yt, xtest, ytest, fun

In [None]:
# Build both datasets. `fun` is assigned twice; the value kept is the one
# returned by the LHS-based create_points call.
xt_g, yt_g, xtest_g, ytest_g, fun = create_points_grid(grid_size=51, fraction_training=0.2)
xt_lhs, yt_lhs, xtest_lhs, ytest_lhs, fun = create_points(n_training=100, n_test=500)

# One 2-D scatter figure per sampling strategy, showing where the training and
# testing inputs lie in the (x1, x2) plane.
for xt, xtest, tag in [(xt_g, xtest_g, "grid"), (xt_lhs, xtest_lhs, "Latin Hypercube")]:
    fig = plt.figure(figsize=(6, 6))
    plt.scatter(xt[:, 0], xt[:, 1], marker="x", c="b", s=50, label="Training points")
    plt.scatter(xtest[:, 0], xtest[:, 1], marker=".", c="k", s=50, label="Testing points")
    plt.title(f"Training & testing points with {tag} sampling")
    plt.xlabel("x1")
    plt.ylabel("x2")
    plt.legend()

### Training a feedforward neural network

We train a feedforward model with a few ReLU-activated layers and a final linear layer. 

We note that the number of neurons and layers was chosen arbitrarily here and could be optimized for better performance.

In [None]:
def train_model(xt, yt, xtest, ytest, nb_epoch: int = 5000):
    """Train a feedforward surrogate and evaluate it on the test set.

    The network has four 64-unit ReLU dense layers followed by one linear
    output unit. It is trained with Adam (learning rate 0.001) on the MSE
    loss, with a batch size of 32 and 10% of the training data held out for
    validation.

    Args:
        xt: Training inputs with 2 features per sample.
        yt: Training targets.
        xtest: Test inputs passed to ``model.predict``.
        ytest: Test targets compared against the predictions.
        nb_epoch: Number of training epochs.

    Returns:
        ``(model, y_pred)``: the trained Keras model and its predictions on
        ``xtest``.
    """
    # Seeding NumPy alone does not make Keras weight initialisation
    # reproducible. set_random_seed seeds Python's `random`, NumPy and the
    # backend framework in one call (note: this resets those global RNG states).
    keras.utils.set_random_seed(42)

    # Define the neural network model (no kernel regularisation is applied).
    model = Sequential()
    model.add(Dense(64, input_dim=2, activation="relu"))
    for _ in range(3):
        model.add(Dense(64, activation="relu"))
    model.add(Dense(1, activation="linear"))

    # Compile the model
    model.compile(optimizer=Adam(learning_rate=0.001), loss="mse")

    # Train the model
    history = model.fit(xt, yt, epochs=nb_epoch, batch_size=32, validation_split=0.1, verbose=1)

    print("Training loss over epochs:")
    print(history.history["loss"])

    print("Validation loss over epochs:")
    print(history.history["val_loss"])

    y_pred = model.predict(xtest)
    test_loss = root_mean_squared_error(ytest, y_pred)
    print("Test Loss (Root Mean Squared Error):", test_loss)
    print()
    return model, y_pred

In [None]:
# Train on the grid-based dataset for 300 epochs (instead of the 5000-epoch default).
# `y_pred` holds the model predictions on xtest_g and is reused by the verification cells.
model, y_pred = train_model(xt_g, yt_g, xtest_g, ytest_g, nb_epoch=300)
# Discard the training log printed by this cell.
clear_output()

In [None]:
# Print the layer-by-layer summary of the trained network.
model.summary()

## Local stability analysis

In the present study, the concept of local stability refers to how much the predicted values vary in the immediate neighborhood of a given test point. 

In real-life applications, system and safety requirements may impose that for two points close in value in the input space (e.g., a reference test point and a point generated from this reference point by perturbing it with a small input noise), the difference in the model's predicted values stays within a given stability range.

We wish to assess and hopefully guarantee this local stability property for the neural network surrogate of the Rosenbrock function.

### Surrogate lower and upper bounds obtained via abstract interpretation method

We derive lower and upper bound estimates of the prediction for a set of grid points (using the `decomon_computation` function of the `rosenbrock_verification` script).

By default, the allowed input noise is an $L_{\infty}$ box of +/- 0.1 on each input dimension.
The default method used to derive bounds is [CROWN](https://github.com/IBM/CROWN-Robustness-Certification), implemented in the open-source library [decomon](https://github.com/airbus/decomon).

In [None]:
# Helper imported from rosenbrock_verification (implementation not shown here).
# Per the surrounding notes, it derives lower/upper output bounds on grid points
# and saves analysis plots — confirm the details in that module.
decomon_computation(model)

We invite you to have a look at the ```images/``` folder where a number of analysis plots are saved.

The following figures show the $x1 = 0.0$, $x1 = 2.0$, $x2 = 0.0$ and $x2 = 2.0$ slices, respectively, for the Rosenbrock function, the surrogate model, and the estimated upper and lower bounds. 

![im](rosenbrock_images/slice_x1eq0.0.png)
![im](rosenbrock_images/slice_x1eq2.0.png)
![im](rosenbrock_images/slice_x2eq0.0.png)
![im](rosenbrock_images/slice_x2eq2.0.png)

<div class="alert alert-warning">If the images are not rendered, see slice_x1eq0.0.png, slice_x1eq2.0.png, slice_x2eq0.0.png and slice_x2eq2.0.png in the "images" folder.</div>

These results provide insight into the model's stability (or lack thereof). The visualisations allow the identification of areas where the gap between the lower and upper bounds on the output might be too large.

### Stability property assessment using a combination of verification techniques

##### Defining the problem container

Let's start by defining the stability property we want to assess/guarantee.
Here, we will consider: 

- an input perturbation of +/- ```abs_noise_input``` for each tested input point (x1,x2)
- a stability property depending on the predicted output value range $pred$:
     - if $abs(pred)\leq $ ```threshold_for_abs_noise```: the "stable" output range is [$pred$-```abs_noise_output```, $pred$+```abs_noise_output```]
     - else: the "stable" output range is [$pred$ - ```rel_noise_output```.abs($pred$), $pred$+```rel_noise_output```.abs($pred$)]

In [None]:
from airobas.verif_pipeline import (
    BoundsDomainBoxParameter,
    BoundsDomainBoxParameterPerValueInterval,
    ProblemContainer,
)


class RosenbrockContainer(ProblemContainer):
    """Problem container pairing the Rosenbrock surrogate with its stability property."""

    @classmethod
    def create_rosenbrock_container(
        cls,
        model: keras.Model,
        abs_noise_input: float = 0.03,
        abs_noise_output: float = 10.0,
        rel_noise_output: float = 0.2,
        threshold_for_abs_noise: float = 200,
        use_different_zones_for_output: bool = True,
    ) -> "RosenbrockContainer":
        """Build a container tagged ``"rosenbrock"`` for ``model``.

        Args:
            model: Keras model to verify.
            abs_noise_input: Absolute noise used for the input bounds domain.
            abs_noise_output: Absolute noise used for the output bounds domain
                (in the central zone when zones are enabled, everywhere otherwise).
            rel_noise_output: Relative noise used for the output bounds domain
                in the two outer zones (only when zones are enabled).
            threshold_for_abs_noise: Boundary ``t`` separating the zones
                ``(-inf, -t)``, ``(-t, t)`` and ``(t, inf)``.
            use_different_zones_for_output: If true, use relative noise in the
                outer zones and absolute noise in the central one; otherwise
                use a single absolute-noise output domain.

        Returns:
            The container holding ``model`` and the resulting stability property.
        """
        # Imported locally so this factory does not rely on a later notebook
        # cell having imported StabilityProperty into the global namespace.
        from airobas.verif_pipeline import StabilityProperty

        if use_different_zones_for_output:
            output_bound_domain_param = BoundsDomainBoxParameterPerValueInterval(
                [
                    (
                        -float("inf"),
                        -threshold_for_abs_noise,
                        BoundsDomainBoxParameter(rel_noise=rel_noise_output, use_relative=True),
                    ),
                    (
                        -threshold_for_abs_noise,
                        threshold_for_abs_noise,
                        BoundsDomainBoxParameter(abs_noise=abs_noise_output, use_relative=False),
                    ),
                    (
                        threshold_for_abs_noise,
                        float("inf"),
                        BoundsDomainBoxParameter(rel_noise=rel_noise_output, use_relative=True),
                    ),
                ]
            )
        else:
            output_bound_domain_param = BoundsDomainBoxParameter(abs_noise=abs_noise_output, use_relative=False)
        stability_property = StabilityProperty(
            input_bound_domain_param=BoundsDomainBoxParameter(abs_noise=abs_noise_input, use_relative=False),
            output_bound_domain_param=output_bound_domain_param,
        )
        # `cls` rather than a hard-coded class name, so subclasses get their own type.
        return cls(tag_id="rosenbrock", model=model, stability_property=stability_property)

To create the output property, we make use of the ```BoundsDomainBoxParameterPerValueInterval``` class, which allows us to define different output properties per interval of values. 
For example, 
```python

(-float("inf"), -threshold_for_abs_noise, BoundsDomainBoxParameter(rel_noise=rel_noise_output, use_relative=True))
```
means that if the expected value is between $-\infty$ and  ```-threshold_for_abs_noise```, then the output property will use a relative noise.

#### Verification pipeline definition

The verification pipeline is built from a sequence of individual blocks of verification that are executed on the remaining test points for which the stability property has not been assessed yet (See e.g., Figure 1 of [Airobas](https://arxiv.org/pdf/2401.06821)'s paper).



In [None]:
import time

from airobas.blocks_hub.adv_block import CleverHansMultiIndexAdvBlock
from airobas.blocks_hub.decomon_block import DecomonBlock
from airobas.blocks_hub.marabou_block import MarabouBlock
from airobas.verif_pipeline import (
    BoundsDomainBoxParameter,
    BoundsDomainBoxParameterPerValueInterval,
    ProblemContainer,
    StabilityProperty,
    StatusVerif,
    compute_bounds,
    full_verification_pipeline,
)

# Build the problem container for the trained model. abs_noise_output (20) and
# rel_noise_output (0.1) differ from the factory defaults (10.0 and 0.2); the
# other arguments repeat the defaults explicitly.
container = RosenbrockContainer.create_rosenbrock_container(
    model,
    abs_noise_input=0.03,
    abs_noise_output=20,
    rel_noise_output=0.1,
    threshold_for_abs_noise=200,
    use_different_zones_for_output=True,
)

Let's first build a verification pipeline consisting only of an adversarial attack generation step.

In [None]:
# Pipeline made of a single adversarial-attack block. One parameter dict is
# generated per column of yt_g (i.e. per index in range(yt_g.shape[1])).
blocks = [
    (
        CleverHansMultiIndexAdvBlock,
        {"list_params_adv_block": [{"index_target": i, "attack_up": True, "fgs": True} for i in range(yt_g.shape[1])]},
    )
]

In [None]:
# Run the pipeline on the grid test inputs, using the model's own predictions
# as reference outputs. t1/t2 bracket the run; the elapsed time (t2 - t1) is
# not printed in this cell.
t1 = time.perf_counter()
global_verif = full_verification_pipeline(
    problem=container,
    input_points=xtest_g,
    output_points=y_pred,  # or ytest if you target ground truth
    blocks_verifier=blocks,
    verbose=True,
)
t2 = time.perf_counter()
# Discard the verbose log printed by the pipeline.
clear_output()

In [None]:
from airobas.verif_pipeline import StatusVerif

# Count the test points per verification status.
print(np.sum(global_verif.status == StatusVerif.VERIFIED), " verified points")
print(np.sum(global_verif.status == StatusVerif.VIOLATED), " violated points")
print(np.sum(global_verif.status == StatusVerif.TIMEOUT), " timeout points")
print(np.sum(global_verif.status == StatusVerif.UNKNOWN), " unknown points")

We observe that the adversarial attack brick identifies test points whose stability property is disproven. It does not provide any robustness guarantee.

Let's now add a layer of abstract interpretation-based incomplete formal verification in order to derive fast verification guarantees for a fraction of the test points.

In [None]:
# Same adversarial-attack block as before (one parameter dict per column of yt_g)...
blocks = [
    (
        CleverHansMultiIndexAdvBlock,
        {"list_params_adv_block": [{"index_target": i, "attack_up": True, "fgs": True} for i in range(yt_g.shape[1])]},
    )
]
# ...followed by a decomon block with default parameters.
blocks += [(DecomonBlock, {})]

# Re-run the pipeline from scratch with the two-block sequence; the elapsed
# time (t2 - t1) is recorded but not printed here.
t1 = time.perf_counter()
global_verif = full_verification_pipeline(
    problem=container,
    input_points=xtest_g,
    output_points=y_pred,  # or ytest if you target ground truth
    blocks_verifier=blocks,
    verbose=True,
)
t2 = time.perf_counter()
# Discard the verbose log printed by the pipeline.
clear_output()

from airobas.verif_pipeline import StatusVerif

# Count the test points per verification status.
print(np.sum(global_verif.status == StatusVerif.VERIFIED), " verified points")
print(np.sum(global_verif.status == StatusVerif.VIOLATED), " violated points")
print(np.sum(global_verif.status == StatusVerif.TIMEOUT), " timeout points")
print(np.sum(global_verif.status == StatusVerif.UNKNOWN), " unknown points")

We observe that the incomplete formal verification provides robustness guarantees for a number of test points, on top of the non-stable cases identified by the adversarial attack generation alone.

Let's finally build a verification pipeline consisting of:

- a first step of adversarial attack generation, followed by
- an abstract verification method step in order to converge on verification guarantees for a fraction of the test points and 
- a complete/exact method step based on satisfiability modulo theories, using [Marabou](https://github.com/NeuralNetworkVerification/Marabou).

In [None]:
# Same adversarial-attack block as before (one parameter dict per column of yt_g)...
blocks = [
    (
        CleverHansMultiIndexAdvBlock,
        {"list_params_adv_block": [{"index_target": i, "attack_up": True, "fgs": True} for i in range(yt_g.shape[1])]},
    )
]
# ...followed by a decomon block and a Marabou block configured with time_out=100
# (unit not visible here — presumably seconds; confirm against MarabouBlock).
blocks += [(DecomonBlock, {}), (MarabouBlock, {"time_out": 100})]

# Re-run the pipeline from scratch with the three-block sequence; the elapsed
# time (t2 - t1) is recorded but not printed here.
t1 = time.perf_counter()
global_verif = full_verification_pipeline(
    problem=container,
    input_points=xtest_g,
    output_points=y_pred,  # or ytest if you target ground truth
    blocks_verifier=blocks,
    verbose=True,
)
t2 = time.perf_counter()
# Discard the verbose log printed by the pipeline.
clear_output()

from airobas.verif_pipeline import StatusVerif

# Count the test points per verification status.
print(np.sum(global_verif.status == StatusVerif.VERIFIED), " verified points")
print(np.sum(global_verif.status == StatusVerif.VIOLATED), " violated points")
print(np.sum(global_verif.status == StatusVerif.TIMEOUT), " timeout points")
print(np.sum(global_verif.status == StatusVerif.UNKNOWN), " unknown points")

The verification is here complete with all test points having been evaluated and labeled as stable/non-stable.

Let's now look in detail at which method of the pipeline was able to conclude on each test point.

In [None]:
# Map each test point to the method that concluded on it, by indexing the
# array of methods with the per-point concluding block index.
# NOTE(review): assumes every entry of index_block_that_concluded is a valid
# index into `methods` — confirm how non-concluded points are encoded.
methods = np.array(global_verif.methods)
index_that_concluded = global_verif.index_block_that_concluded
methods_concluded = methods[index_that_concluded]
# Count the unique values and their counts
unique_values, counts = np.unique(methods_concluded, return_counts=True)
# Print the results
print("Methods: ", unique_values)
print("Points where the methods concluded (nb_met1, nb_met2, nb_met3) : ", counts)

In this final experiment: 
- the adversarial attack brick was able to find counter examples for nb_met1 points out of the N test points (N = 2081 with the grid settings used above: 51 x 51 = 2601 grid points minus the 520 used for training); N-nb_met1 remain to be assessed.
- decomon (abstract interpretation) concluded on nb_met2 test points; N-nb_met1-nb_met2 remain to be assessed.
- The Marabou solver concludes on the last nb_met3 points.

#### Visualisation of non-robust test points
In this final section, we propose an additional visualisation in 2d of potential adversarial attacks.
If a given point $(x1,x2)$ has been successfully attacked, we focus on a zone $[x=x1, y\in [x2-\delta, x2+\delta]]$. For different values of $y$ discretized between $x2-\delta$ and $x2+\delta$, we can compute and plot : 
- the expected output bounds computed considering our input perturbation domain
-  bounds found by abstract interpretation
-  ground truth and actual prediction of the model.
- Finally we draw a cross where the y-axis value is the output value of the found adversarial attack and we annotate the input coordinate of the attack.

In [None]:
import os

from decomon.models import clone

# Positions (in the test set) of the points whose status is VIOLATED.
indexes = np.nonzero(global_verif.status == StatusVerif.VIOLATED)
# Keep the non-None entries of the pipeline's inputs/outputs — presumably the
# counter-example inputs and the model outputs at those inputs; confirm against
# full_verification_pipeline. The loop below pairs the k-th kept entry with the
# k-th violated index.
cnt_x = np.array([x for x in global_verif.inputs if x is not None])
cnt_y = np.array(
    [global_verif.outputs[i] for i in range(len(global_verif.outputs)) if global_verif.outputs[i] is not None]
)
# NOTE(review): y_exp_min / y_exp_max are not used anywhere below in this cell.
y_exp_min, y_exp_max = compute_bounds(container.stability_property, y_pred, is_input=False)
# decomon's clone() of the Keras model; its predict() is used below to obtain
# output bounds for input boxes.
decomon_model = clone(model)
# Plot at most 20 counter examples.
for index_counter_example in range(min(cnt_x.shape[0], 20)):
    original_point = xtest_g[indexes[0][index_counter_example]]
    # x_val = cnt_x[index_counter_example, 0]
    x_val = original_point[0]
    y_val = original_point[1]
    # NOTE(review): expected_value is not used anywhere below in this cell.
    expected_value = fun(np.array([original_point]))
    found_value = cnt_y[index_counter_example]
    # 100 values of x2 within +/- 0.2 of the original point, clipped to [-2, 2],
    # with x1 held fixed at the original point's value.
    y = np.linspace(max(-2, y_val - 0.2), min(2, y_val + 0.2), 100)
    vals = np.array([[x_val, yi] for yi in y])
    # Input perturbation bounds for every point of the slice, stacked along a
    # new axis 1 as (min, max) before being passed to the decomon model.
    x_min_, x_max_ = compute_bounds(container.stability_property, vals, is_input=True)
    box = np.concatenate([x_min_[:, None], x_max_[:, None]], 1)
    # Unpacked as (upper, lower), as the variable names indicate.
    y_up_, y_low_ = decomon_model.predict(box)
    ground_truth = fun(vals)
    output_property = model.predict(vals)
    # Output bounds of the stability property around the surrogate prediction.
    y_min_, y_max_ = compute_bounds(container.stability_property, output_property, is_input=False)
    fig, ax = plt.subplots(1)
    ax.plot(y, ground_truth, color="blue", label="ground truth")
    ax.plot(y, output_property, color="green", label="surrogate")

    # Dashed lines: stability-property bounds; solid lines: decomon bounds.
    ax.plot(y, y_min_, color="orange", linestyle="--", label="lower bound stability")
    ax.plot(y, y_max_, color="red", linestyle="--", label="upper bound stability")

    ax.plot(y, y_low_, color="orange", label="lower bound decomon")
    ax.plot(y, y_up_, color="red", label="upper bound decomon")

    # Cross at the original x2 with the output value found for the counter
    # example; the annotation gives both the counter-example and original inputs.
    ax.scatter([y_val], [found_value], marker="x", s=500)
    ax.annotate(
        f"""
                coor cnt example = ({cnt_x[index_counter_example][0]:.3g}, {cnt_x[index_counter_example][1]:.3g})
                coor orig point = ({x_val:.3g}, {y_val:.3g})
                """,
        (y_val, found_value),
        xytext=(0, 10),
        textcoords="offset points",
        ha="center",
        va="bottom",
        bbox=dict(boxstyle="round,pad=0.", fc="blue", ec="black", alpha=0.3),
    )
    ax.legend(loc="lower left")
    ax.set_title(f"Slice around counter example, x1={x_val}")
    ax.set_xlabel("x2")
    plt.tight_layout()
    # Save every figure into image_dump_folder (imported from rosenbrock_verification).
    fig.savefig(os.path.join(image_dump_folder, f"cnt_example_{index_counter_example}.png"))

    # Every figure is closed after saving except the one with index 5, which is left open.
    if index_counter_example != 5:
        plt.close(fig)

# Discard the output produced by this cell so far.
clear_output()

The ```images/``` folder contains some examples of visualisation of counter examples.

Examples of counter examples:

![im](rosenbrock_images/cnt_example_0.png)
![im](rosenbrock_images/cnt_example_1.png)
![im](rosenbrock_images/cnt_example_2.png)
![im](rosenbrock_images/cnt_example_3.png)
