In [None]:
import numpy as np
from tensorflow.keras.layers import Activation, Conv2D, Dense, Flatten, Reshape
from tensorflow.keras.models import Sequential

In [None]:
import torch
import torch.nn.functional as F
from torch import nn

In [None]:
# Keras reference network: a flat 36-vector is viewed as a 6x6x1 image, passed
# through one 6x6 convolution (32 filters) and two dense layers. Every ReLU is a
# separate Activation layer placed after a layer declared with a linear activation.
model = Sequential(
    [
        Reshape((6, 6, 1), input_dim=36),
        Conv2D(32, (6, 6), activation="linear", bias_initializer="zeros"),
        Activation("relu"),
        Flatten(),
        Dense(130, activation="linear"),
        Activation("relu"),
        Dense(1, activation="linear"),
    ]
)

In [None]:
import sys

sys.path.append("../..")
from decomon.models.convert import clone as convert

In [None]:
class NeuralNet(nn.Module):
    """PyTorch counterpart of the Keras network built earlier in this notebook.

    Architecture: reshape to (N, 1, 6, 6) -> 6x6 convolution with 32 filters ->
    ReLU -> flatten to (N, 32) -> Linear(32, 130) -> ReLU -> Linear(130, 1).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(6, 6))
        self.fc1 = nn.Linear(32, 130)
        self.fc2 = nn.Linear(130, 1)
        # Parameterised layers in forward order; ``reset_weights`` pairs them
        # positionally with the Keras layers that own weights.
        self.layers = [self.conv1, self.fc1, self.fc2]

    def forward(self, x):
        """Run the network on ``x``, a tensor holding 36 values per sample."""
        x = torch.reshape(x, (-1, 1, 6, 6))
        x = self.conv1(x)
        x = F.relu(x)
        # The 6x6 kernel covers the whole 6x6 input, so only the 32 channels remain.
        x = torch.reshape(x, (-1, 32))
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

    def reset_weights(self, model):
        """Copy the weights of a Keras model into this module's layers.

        Args:
            model: object exposing ``layers``, each layer providing
                ``get_weights()``. Layers returning an empty list are skipped;
                the remaining ones are matched, in order, with ``self.layers``
                and must each return ``[kernel, bias]``.

        Raises:
            ValueError: if the number of weighted layers differs from
                ``len(self.layers)``, or if a converted kernel/bias does not
                have the shape of the parameter it should fill.
        """
        # Query each layer once and keep only those that actually hold weights.
        weighted = [w for w in (layer.get_weights() for layer in model.layers) if len(w)]
        if len(weighted) != len(self.layers):
            raise ValueError(
                f"expected {len(self.layers)} weighted layers, got {len(weighted)}"
            )

        # copy_ under no_grad writes into the existing parameters: dtype, device
        # and contiguity are preserved and no NumPy buffer ends up aliased.
        with torch.no_grad():
            for layer_torch, weights in zip(self.layers, weighted):
                kernel, bias = weights[0], weights[1]
                if kernel.ndim == 2:
                    # dense layer: the kernel is transposed to match nn.Linear.weight
                    kernel = np.transpose(kernel)
                else:
                    # convolution: move the last two kernel axes to the front to
                    # match nn.Conv2d.weight
                    kernel = np.transpose(kernel, (3, 2, 0, 1))
                self._load(layer_torch.weight, kernel)
                self._load(layer_torch.bias, bias)

    @staticmethod
    def _load(param, array):
        """Copy ``array`` into ``param`` after an exact shape check."""
        # Checked explicitly because copy_ would otherwise broadcast silently.
        if tuple(array.shape) != tuple(param.shape):
            raise ValueError(
                f"shape mismatch: got {tuple(array.shape)}, expected {tuple(param.shape)}"
            )
        param.copy_(torch.from_numpy(array))

In [None]:
# PyTorch twin of the Keras model; it keeps PyTorch's own initial weights until
# reset_weights() is called on it.
model_torch = NeuralNet()

In [None]:
# Copy the Keras weights into the PyTorch twin so both networks can be compared.
model_torch.reset_weights(model)

In [None]:
# 100 samples, each a single random scalar in [0, 1) repeated across all 36
# input features. A seeded generator keeps Restart & Run All reproducible.
x = np.random.default_rng(0).random((100, 1)) * np.ones((100, 36))

In [None]:
# Turn the NumPy batch into a float32 CPU tensor, run the PyTorch twin on it and
# bring the predictions back as a NumPy array for the comparison with Keras.
x_train_tensor = torch.from_numpy(x).to(device="cpu", dtype=torch.float32)
y_pred_torch = model_torch(x_train_tensor).detach().cpu().numpy()

In [None]:
# Reference predictions from the Keras model on the same batch.
y_pred_keras = model.predict(x)

In [None]:
from numpy.testing import assert_almost_equal, assert_array_less

In [None]:
# Sanity check: with identical weights both frameworks must give the same outputs.
# NOTE(review): this uses the default tolerance (decimal=7) while the bound checks
# further down use decimal=5 — possibly tight for float32 across two frameworks; confirm.
assert_almost_equal(y_pred_keras.flatten(), y_pred_torch.flatten())

In [None]:
import sys

sys.path.append("../..")
from decomon import get_range_box
from decomon.models.convert import clone as convert

In [None]:
# Convert the Keras model into a decomon model using method="crown-ibp";
# it is compared below against auto_LiRPA run with the same method name.
decomon_model_0 = convert(model, method="crown-ibp")

In [None]:
# Second decomon model, built with method="crown" and the flags ibp=True,
# forward=False; it is compared below against auto_LiRPA's "crown" bounds.
decomon_model_1 = convert(model, ibp=True, forward=False, method="crown")

In [None]:
from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm

In [None]:
# define the intervals


def get_range_box_comparison(method, model_decomon_1, model_torch, x_=x, eps=0.1):
    """Bound the network on L-inf boxes with both auto_LiRPA and decomon.

    Args:
        method: bounding method name forwarded to auto_LiRPA's
            ``compute_bounds`` (the notebook uses "crown" and "crown-ibp").
        model_decomon_1: decomon model handed to ``get_range_box``.
        model_torch: PyTorch module wrapped in an auto_LiRPA ``BoundedModule``.
        x_: NumPy array of box centres. The default is the notebook-level ``x``
            as it was when this function was defined.
        eps: half-width of the box around every entry of ``x_``.

    Returns:
        Tuple ``(centres, keras_predictions, lirpa_lower, lirpa_upper,
        decomon_lower, decomon_upper)``. ``keras_predictions`` comes from the
        notebook-level Keras ``model`` evaluated at the box centres.
    """
    # Build the box from the argument; the original read the global ``x`` here,
    # so a caller-supplied ``x_`` was silently ignored.
    X_min = x_ - eps
    X_max = x_ + eps
    X_lirpa_ = (X_min + X_max) / 2.0

    # convert X_lirpa into a pytorch tensor
    X_lirpa = torch.from_numpy(X_lirpa_).float().to("cpu")

    model_lirpa = BoundedModule(model_torch, X_lirpa)
    ptb = PerturbationLpNorm(norm=np.inf, eps=eps)
    input_lirpa = BoundedTensor(X_lirpa, ptb)

    # The IBP flag is switched off only for "crown" and on for every other method.
    IBP = method != "crown"

    lb, ub = model_lirpa.compute_bounds(x=(input_lirpa,), IBP=IBP, method=method)

    lb_ = lb.cpu().detach().numpy()
    ub_ = ub.cpu().detach().numpy()

    # The result is unpacked as (upper, lower) and swapped into (lower, upper)
    # order in the return value below.
    upper_, lower_ = get_range_box(model_decomon_1, X_min, X_max, fast=True)

    return X_lirpa_, model.predict(X_lirpa_), lb_, ub_, lower_, upper_

In [None]:
# "crown-ibp" comparison. Following the function's return order, the *_p_0 arrays
# are the auto_LiRPA (PyTorch) bounds and the *_t_0 arrays are the decomon bounds.
x_samples, y_samples, lb_p_0, ub_p_0, lb_t_0, ub_t_0 = get_range_box_comparison(
    "crown-ibp", decomon_model_0, model_torch
)

In [None]:
# "crown" comparison. Following the function's return order, the *_p_1 arrays are
# the auto_LiRPA (PyTorch) bounds and the *_t_1 arrays are the decomon bounds.
x_samples, y_samples, lb_p_1, ub_p_1, lb_t_1, ub_t_1 = get_range_box_comparison("crown", decomon_model_1, model_torch)

In [None]:
# crown-ibp: auto_LiRPA and decomon upper bounds must agree to 5 decimals.
assert_almost_equal(ub_p_0, ub_t_0, decimal=5)

In [None]:
# crown-ibp: auto_LiRPA and decomon lower bounds must agree to 5 decimals.
assert_almost_equal(lb_p_0, lb_t_0, decimal=5)

In [None]:
# crown: auto_LiRPA and decomon upper bounds must agree to 5 decimals.
assert_almost_equal(ub_p_1, ub_t_1, decimal=5)

In [None]:
# crown: auto_LiRPA and decomon lower bounds must agree to 5 decimals.
assert_almost_equal(lb_p_1, lb_t_1, decimal=5)

In [None]:
# Largest absolute gap between the decomon and auto_LiRPA "crown" upper bounds.
np.abs(ub_t_1 - ub_p_1).max()

In [None]:
# Largest absolute gap between the decomon and auto_LiRPA "crown" lower bounds.
np.abs(lb_t_1 - lb_p_1).max()

In [None]:
# Requires every decomon upper bound to be strictly below the auto_LiRPA one.
# NOTE(review): an earlier cell asserts these arrays agree to 5 decimals, and this
# strict "<" check fails wherever the two are exactly equal — confirm the intent.
assert_array_less(ub_t_1, ub_p_1)

In [None]:
# Requires every auto_LiRPA lower bound to be strictly below the decomon one.
# NOTE(review): an earlier cell asserts these arrays agree to 5 decimals, and this
# strict "<" check fails wherever the two are exactly equal — confirm the intent.
assert_array_less(lb_p_1, lb_t_1)