# COMPARISON AUTO_LIRPA VS DECOMON: FULLY CONNECTED NETWORK ON A SINUSOID

# PART A: TENSORFLOW

In [None]:
import sys

import numpy as np
import tensorflow.keras as keras
from tensorflow.keras.layers import Activation, Dense
from tensorflow.keras.models import Sequential

# Add the parent directory to the module search path
# (presumably to reach local project modules — confirm).
sys.path.append("..")
import time

# Used below to check that Keras/PyTorch predictions and auto_LiRPA/decomon bounds agree.
from numpy.testing import assert_almost_equal

In [None]:
import matplotlib.pyplot as plt
import tensorflow.keras as keras

%matplotlib inline
import numpy as np
from tensorflow.keras.layers import Activation, Dense
from tensorflow.keras.models import Sequential

# Record which Keras version produced the results of this notebook.
print("Notebook run using keras:", keras.__version__)
import sys

# NOTE(review): presumably makes a source checkout of `decomon` (two levels up) importable — confirm.
sys.path.append("../..")
# decomon's `clone` is aliased to `convert`; it builds the decomon bound models from the Keras model below.
from decomon.models.convert import clone as convert

In [None]:
from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm

### Build and Train a Neural Network on a Sinusoid

The sinusoid function $y = \sin(10x)$ is defined on the interval $[-1, 1]$. The factor $10$ inside the sine gives several periods of oscillation over this interval.


In [None]:
# 1000 evenly spaced sample points on [-1, 1] and the target sin(10 x).
x = np.linspace(start=-1, stop=1, num=1000)
y = np.sin(x * 10)

We approximate this function with a fully connected network composed of 4 hidden layers of sizes 100, 100, 20 and 20, followed by a single linear output unit. Rectified Linear Units (ReLU) are used as the activation function of every hidden layer.

In [None]:
# Four hidden Dense layers (100, 100, 20, 20), each followed by its own ReLU
# Activation layer, then one linear output unit.
layers = [
    Dense(100, activation="linear", input_dim=1),  # the input space is 1-dimensional
    Activation("relu"),
    Dense(100, activation="linear"),
    Activation("relu"),
    Dense(20, activation="linear"),
    Activation("relu"),
    Dense(20, activation="linear"),
    Activation("relu"),
    Dense(1, activation="linear"),
]
model = Sequential(layers)

We specify the optimizer (Adam) and the loss, here the classical Mean Squared Error (MSE).

In [None]:
model.compile(optimizer="adam", loss="mse")  # Adam optimizer, mean-squared-error loss

We train the neural network.

In [None]:
# 100 epochs of shuffled mini-batches of 32 samples; verbose=0 silences the progress output.
model.fit(x, y, epochs=100, batch_size=32, shuffle=True, verbose=0)

In [None]:
import torch
import torch.nn.functional as F
from torch import nn

In [None]:
class NeuralNet(nn.Module):
    """PyTorch twin of the Keras network: 1 -> 100 -> 100 -> 20 -> 20 -> 1 with ReLU.

    The layer sizes mirror the Keras ``Sequential`` model so that its trained
    weights can be copied over with :meth:`reset_weights`.
    """

    def __init__(self):
        super(NeuralNet, self).__init__()
        self.hidden_0 = nn.Linear(1, 100)  # input_dim = 1; output_dim = 100
        self.hidden_1 = nn.Linear(100, 100)
        self.hidden_2 = nn.Linear(100, 20)
        self.hidden_3 = nn.Linear(20, 20)
        # Output layer: 20 features in, one value out (matches the Keras Dense(1)).
        self.hidden_4 = nn.Linear(20, 1)

        # Plain list, not nn.ModuleList: the layers are already registered as
        # attributes above, so this only fixes their order for forward() and
        # reset_weights() without duplicating entries in state_dict().
        self.layers = [self.hidden_0, self.hidden_1, self.hidden_2, self.hidden_3, self.hidden_4]

    def forward(self, x):
        """Apply every linear layer, with a ReLU after each one except the last."""
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        return self.layers[-1](x)

    def reset_weights(self, model):
        """Copy the weights of the Keras ``model`` into this network.

        Keras layers that have weights are matched, in order, to ``self.layers``.
        Each layer's name is printed as it is copied. Keras stores a Dense kernel
        as ``(in, out)`` while ``nn.Linear.weight`` is ``(out, in)``, hence the
        transpose.

        Raises:
            ValueError: if the number of weighted Keras layers differs from
                ``len(self.layers)``, or if a kernel/bias shape does not match
                the corresponding torch layer.
        """
        weighted_layers = [layer for layer in model.layers if layer.get_weights()]
        if len(weighted_layers) != len(self.layers):
            raise ValueError(
                f"expected {len(self.layers)} Keras layers with weights, got {len(weighted_layers)}"
            )

        # In-place copies into leaf parameters must not be recorded by autograd.
        with torch.no_grad():
            for layer_keras, layer_torch in zip(weighted_layers, self.layers):
                print(layer_keras.name)
                weights = layer_keras.get_weights()
                new_weight = torch.from_numpy(np.transpose(weights[0]))
                new_bias = torch.from_numpy(np.asarray(weights[1]))
                # copy_ would silently broadcast e.g. a (1,) bias, so check shapes explicitly.
                if new_weight.shape != layer_torch.weight.shape or new_bias.shape != layer_torch.bias.shape:
                    raise ValueError(
                        f"shape mismatch for Keras layer {layer_keras.name!r}: "
                        f"kernel {tuple(new_weight.shape)} / bias {tuple(new_bias.shape)} vs "
                        f"torch weight {tuple(layer_torch.weight.shape)} / bias {tuple(layer_torch.bias.shape)}"
                    )
                layer_torch.weight.copy_(new_weight)
                layer_torch.bias.copy_(new_bias)

In [None]:
# Build the PyTorch twin of the Keras network and copy the trained Keras weights into it.
model_torch = NeuralNet()
model_torch.reset_weights(model)

In [None]:
# NOTE(review): repeats the weight copy of the previous cell. reset_weights only copies the
# Keras weights into the torch layers (printing the layer names), so this reloads the same values.
model_torch.reset_weights(model)

In [None]:
# decomon CROWN-IBP model (ibp=True, forward=False), compared below with auto_LiRPA's "IBP+backward".
decomon_model_0 = convert(model, forward=False, ibp=True, method="crown-ibp")

In [None]:
# NOTE(review): stray scratch arithmetic — the result is not assigned or used
# (possibly two timings summed by hand); consider deleting this cell.
0.22840400000000116 + 0.000354999999998995

In [None]:
# Second decomon model, this time with the "crown" method (ibp=True, forward=False);
# compared below with auto_LiRPA's "crown" bounds.
decomon_model_1 = convert(model, forward=False, ibp=True, method="crown")

### Check that the PyTorch and Keras predictions agree

In [None]:
# Sample points as a float32 column tensor of shape (N, 1) on the CPU.
x_train_tensor = torch.from_numpy(np.expand_dims(x, axis=1)).to(device="cpu", dtype=torch.float32)
y_pred_torch = model_torch(x_train_tensor).detach().cpu().numpy()

In [None]:
# Predictions of both frameworks on the same sample points
# (the torch prediction recomputes the one from the previous cell).
y_pred_torch = model_torch(x_train_tensor).detach().cpu().numpy()
y_pred_keras = model.predict(x)

In [None]:
# The Keras and PyTorch networks must agree to 6 decimals on every sample point.
assert_almost_equal(actual=y_pred_keras, desired=y_pred_torch, decimal=6)

In [None]:
# Overlay: PyTorch predictions as 'x' markers (no line), Keras predictions as a line.
plt.plot(x, y_pred_torch, marker="x", linestyle="None")
plt.plot(x, y_pred_keras)

# AUTO LIRPA

In [None]:
# Bound comparison helper: split the input interval into boxes and bound the
# network output on each box with both auto_LiRPA and decomon.


def get_range_box_comparison(
    method, model_decomon_1, model_torch, x_min=x.min(), x_max=x.max(), n_split=10, model_keras=None
):
    """Compare auto_LiRPA and decomon output bounds on a uniform partition of [x_min, x_max].

    The interval is split into ``n_split`` equal boxes. auto_LiRPA receives each
    box as an L-infinity perturbation of radius ``eps`` around its centre, and
    decomon receives the same boxes as explicit (lower, upper) corner pairs.

    Args:
        method: auto_LiRPA bound method passed to ``compute_bounds``
            (e.g. "IBP+backward" or "crown").
        model_decomon_1: decomon model; its ``predict`` output on the boxes is
            unpacked as (upper, lower).
        model_torch: PyTorch network wrapped in auto_LiRPA's ``BoundedModule``.
        x_min, x_max: ends of the interval. The defaults are the range of the
            notebook's global ``x``, evaluated once when this cell runs.
        n_split: number of boxes; must be >= 1.
        model_keras: Keras model used for the reference predictions at the box
            centres. Defaults to the notebook's global ``model``.

    Returns:
        ``(centres, predictions, lirpa_lower, lirpa_upper, decomon_lower, decomon_upper)``.

    Raises:
        ValueError: if ``n_split < 1`` (there would be no box and ``eps`` would divide by zero).
    """
    if n_split < 1:
        raise ValueError(f"n_split must be >= 1, got {n_split}")
    if model_keras is None:
        # Backward compatible: previously the global Keras model was always used.
        model_keras = model

    # n_split + 1 evenly spaced points; consecutive pairs are the box corners.
    alpha = np.linspace(0, 1, n_split + 1)
    x_samples = (1 - alpha) * x_min + alpha * x_max
    X_min = x_samples[:-1][:, None]  # (n_split, 1) lower corners
    X_max = x_samples[1:][:, None]  # (n_split, 1) upper corners
    X_lirpa_ = (X_min + X_max) / 2.0  # box centres
    eps = 0.5 * (x_max - x_min) / n_split  # box half-width = L-inf radius

    X_lirpa = torch.from_numpy(X_lirpa_).float().to("cpu")

    # NOTE: the auto_LiRPA time includes building the BoundedModule, while the decomon
    # model was converted beforehand, so the two timings are not like-for-like.
    start_time_torch = time.process_time()
    model_lirpa = BoundedModule(model_torch, X_lirpa)
    ptb = PerturbationLpNorm(norm=np.inf, eps=eps)
    input_lirpa = BoundedTensor(X_lirpa, ptb)
    lb, ub = model_lirpa.compute_bounds(x=(input_lirpa,), method=method)
    lb_ = lb.cpu().detach().numpy()
    ub_ = ub.cpu().detach().numpy()
    end_time_torch = time.process_time()

    start_time_decomon = time.process_time()
    # Shape (n_split, 2, 1): [:, 0] holds the lower corners, [:, 1] the upper corners.
    boxes = np.stack([X_min, X_max], axis=1)
    upper_, lower_ = model_decomon_1.predict(boxes)
    end_time_decomon = time.process_time()

    # Printed order: decomon CPU time, then auto_LiRPA CPU time (seconds).
    print(end_time_decomon - start_time_decomon, end_time_torch - start_time_torch)

    return X_lirpa_, model_keras.predict(X_lirpa_), lb_, ub_, lower_, upper_

In [None]:
# CROWN-IBP: auto_LiRPA "IBP+backward" (lb_p_0/ub_p_0) vs decomon_model_0 (lb_t_0/ub_t_0).
x_samples, y_samples, lb_p_0, ub_p_0, lb_t_0, ub_t_0 = get_range_box_comparison(
    method="IBP+backward",
    model_decomon_1=decomon_model_0,
    model_torch=model_torch,
    n_split=10,
)

In [None]:
# CROWN: auto_LiRPA "crown" (lb_p_1/ub_p_1) vs decomon_model_1 (lb_t_1/ub_t_1).
x_samples, y_samples, lb_p_1, ub_p_1, lb_t_1, ub_t_1 = get_range_box_comparison(
    method="crown",
    model_decomon_1=decomon_model_1,
    model_torch=model_torch,
    n_split=10,
)

In [None]:
# CROWN-IBP upper bounds: auto_LiRPA vs decomon, to 5 decimals.
assert_almost_equal(actual=ub_p_0, desired=ub_t_0, decimal=5)

In [None]:
# CROWN-IBP lower bounds: auto_LiRPA vs decomon, to 5 decimals.
assert_almost_equal(actual=lb_p_0, desired=lb_t_0, decimal=5)

In [None]:
# CROWN lower bounds: auto_LiRPA vs decomon, to 5 decimals.
assert_almost_equal(actual=lb_p_1, desired=lb_t_1, decimal=5)

In [None]:
# CROWN upper bounds: auto_LiRPA vs decomon, to 5 decimals.
assert_almost_equal(actual=ub_p_1, desired=ub_t_1, decimal=5)