# Box Transformer

In this notebook, you are going to implement a very simple box propagation scheme with the goal of verifying the neural network shown below.

![Simple Net](../imgs/box_network.png)

Numbers on the edges are weights. So if there's an edge from $a$ to $b$ with weight $2$, this means that $b = 2a$.

If multiple edges terminate at a single node, their contributions are added, so if edges from both $a$ and $b$ terminate at $c$, this means that $c = a + b$.

Finally, numbers next to nodes denote biases. For instance, if $a$ is connected to $b$ and there's a $3$ next to $b$, this means $b = a + 3$.

Remember that $\operatorname{ReLU}(x) = \max(0, x)$.

In this exercise, we want to prove that $o_5 > o_6$! So, after the input box has been propagated through the network,
we want all possible values of $o_5$ to be bigger than all possible values of $o_6$, i.e. the lower bound of $o_5$ must exceed the upper bound of $o_6$.

We've implemented all the "layers" of this network for you below.

In [28]:
import torch
import torch.nn as nn

class AddLayer(nn.Module):
    def __init__(self, bias=0.0):
        super(AddLayer, self).__init__()
        self.bias = nn.Parameter(torch.tensor(bias))

    def forward(self, x):
        return torch.sum(x, dim=1) + self.bias
    
class ScalarMulLayer(nn.Module):
    def __init__(self, weights=1.0):
        super(ScalarMulLayer, self).__init__()
        self.weights = nn.Parameter(torch.tensor(weights))

    def forward(self, x):
        return self.weights * x
    
class ReLULayer(nn.Module):
    def __init__(self):
        super(ReLULayer, self).__init__()

    def forward(self, x):
        return torch.max(x, torch.zeros_like(x))

As a warm-up, use the layers from above to implement the network shown in the image.
Note that the layers are `nn.Module`s, so you can use `nn.Sequential`!

In [85]:
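# TODO: Your code goes here :)
# Each ScalarMulLayer/AddLayer pair forms one affine layer: row i of the weight
# matrix holds the weights of the edges entering node i, and AddLayer sums them.
# Plain lists are passed because the layers call torch.tensor(...) themselves.
model = nn.Sequential(
    ScalarMulLayer([
        [1.0, 1.0],
        [1.0, -1.0],
    ]),
    AddLayer([0.0, 0.0]),
    ReLULayer(),
    ScalarMulLayer([
        [1.0, 1.0],
        [1.0, -1.0],
    ]),
    AddLayer([0.5, -0.5]),
)

model(torch.tensor([0.1, 0.1]))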

tensor([ 0.7000, -0.3000], grad_fn=<AddBackward0>)
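
Why does stacking a `ScalarMulLayer` and an `AddLayer` behave like one affine layer? The following sketch, which uses plain tensors instead of the notebook's modules (the names `W`, `b` and `x` are illustrative assumptions), suggests that the elementwise product followed by a sum over `dim=1` matches a matrix-vector product plus a bias.

```python
import torch

W = torch.tensor([[1.0, 1.0],
                  [1.0, -1.0]])
b = torch.tensor([0.5, -0.5])
x = torch.tensor([0.2, 0.0])

# ScalarMulLayer: elementwise product. x is broadcast across the rows of W,
# so scaled[i, j] = W[i, j] * x[j].
scaled = W * x

# AddLayer: sum over the incoming edges (dim=1), then add the bias.
via_layers = torch.sum(scaled, dim=1) + b

# The same affine map written as a matrix-vector product.
via_matmul = W @ x + b

print(via_layers, via_matmul)
```

Both expressions should agree, which is why the abstract layers later on only need the addition and scalar multiplication rules.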

Next, we'll use the box abstraction to verify this network. Here are the rules:

- Addition: $[a, b] +^\# [c, d] = [a + c, b + d]$
- Negation: $-^\#[a, b] = [-b, -a]$
- ReLU: $\operatorname{ReLU}^\#[a, b] = [\operatorname{ReLU}(a), \operatorname{ReLU}(b)]$
- Scalar multiplication: $\lambda\cdot^\#[a, b] = [\lambda\cdot a, \lambda\cdot b]$ for $\lambda > 0$
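
To get a feel for these rules before moving to tensors, here is a minimal sketch in plain Python that represents an interval as a `(lower, upper)` tuple. The helper names (`box_add`, `box_neg`, `box_relu`, `box_scale`) are hypothetical and not part of the notebook; negative factors are handled by combining the scaling and negation rules.

```python
from typing import Tuple

Interval = Tuple[float, float]  # (lower bound, upper bound)

def box_add(x: Interval, y: Interval) -> Interval:
    # Addition: lower bounds add up, upper bounds add up.
    return (x[0] + y[0], x[1] + y[1])

def box_neg(x: Interval) -> Interval:
    # Negation: the bounds swap roles and change sign.
    return (-x[1], -x[0])

def box_relu(x: Interval) -> Interval:
    # ReLU is monotone, so it can be applied to each bound separately.
    return (max(0.0, x[0]), max(0.0, x[1]))

def box_scale(lam: float, x: Interval) -> Interval:
    # Non-negative factors scale both bounds directly;
    # negative factors are reduced to scaling followed by negation.
    if lam >= 0:
        return (lam * x[0], lam * x[1])
    return box_neg(box_scale(-lam, x))

# Example: a - b for a in [0, 0.3] and b in [0.1, 0.4], followed by ReLU.
a, b = (0.0, 0.3), (0.1, 0.4)
diff = box_add(a, box_scale(-1.0, b))
print(diff, box_relu(diff))
```

The difference comes out as roughly $[-0.4, 0.2]$ and the ReLU clips it to roughly $[0, 0.2]$, up to floating-point rounding.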

A common pattern is to iterate over the layers in a module and build an "abstract copy" in parallel. One way of achieving this in PyTorch is to implement `nn.Module`s that perform the abstract operations.
Note that the input for the new, abstract layers is slightly different from the old, concrete ones: we have to pass both the lower and the upper bound. So, where we were passing a tensor of shape `(3, 3)` before, we will
now be passing one of shape `(2, 3, 3)` (two `3x3` tensors, one for the lower and one for the upper bounds).
Below we provide you with an example for the `AddLayer` and ask you to implement the missing layers.

In [108]:
class AbstractAddLayer(nn.Module):
    def __init__(self, concrete_layer: AddLayer):
        super(AbstractAddLayer, self).__init__()
        self.bias = concrete_layer.bias

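    def forward(self, bounds):
        # Addition rule: lower and upper bounds are summed independently over the
        # incoming edges (last dimension); the bias shifts both bounds equally.
        return torch.sum(bounds, dim=-1) + self.bias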
    
class AbstractScalarMulLayer(nn.Module):
    def __init__(self, concrete_layer: ScalarMulLayer):
        super(AbstractScalarMulLayer, self).__init__()
        self.weights = concrete_layer.weights

    def forward(self, bounds):
        # TODO: Implement the forward pass of the box abstraction of the scalar multiplication layer
        # NOTE: Make sure to handle negative weights correctly, i.e. using the negation rule!
        lower = self.weights * bounds[0]
        upper = self.weights * bounds[1]

        # Where a weight is negative, the scaled bounds swap roles (negation rule)
        return torch.stack([
            torch.where(self.weights < 0, upper, lower),
            torch.where(self.weights >= 0, upper, lower),
        ])

class AbstractReLULayer(nn.Module):
    def __init__(self, _concrete_layer: ReLULayer):
        super(AbstractReLULayer, self).__init__()

    def forward(self, bounds):
        # TODO: Implement the forward pass of the box abstraction of the ReLU layer
        return torch.max(bounds, torch.zeros_like(bounds))
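
A quick way to gain confidence in the negative-weight handling is to sample concrete inputs from a box and check that their scaled values stay inside the abstract bounds. The sketch below is self-contained and uses a hypothetical helper, `scale_bounds`, which expresses the swap with `torch.minimum`/`torch.maximum` rather than `torch.where`; for valid boxes (lower bound not above upper bound) the two formulations should agree.

```python
import torch

def scale_bounds(weights: torch.Tensor, bounds: torch.Tensor) -> torch.Tensor:
    """Box rule for elementwise scaling; bounds[0] is lower, bounds[1] is upper."""
    a = weights * bounds[0]
    b = weights * bounds[1]
    # A negative weight flips the interval, so take the elementwise min/max.
    return torch.stack([torch.minimum(a, b), torch.maximum(a, b)])

torch.manual_seed(0)
weights = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
bounds = torch.tensor([[0.0, 0.1], [0.3, 0.4]])
out = scale_bounds(weights, bounds)  # shape (2, 2, 2)

# Draw concrete inputs inside the box (clamped to guard against rounding).
samples = bounds[0] + (bounds[1] - bounds[0]) * torch.rand(1000, 2)
samples = torch.minimum(torch.maximum(samples, bounds[0]), bounds[1])

# Every concrete product must land inside the abstract bounds.
concrete = weights * samples[:, None, :]  # shape (1000, 2, 2)
sound = bool(((concrete >= out[0]) & (concrete <= out[1])).all())
print(sound)
```

Sampling can only find counterexamples; it does not prove soundness, but a failed check would point to a bug in the bound computation.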

With the layers implemented, you can now simply iterate over the original network:

In [109]:
abstract_model = []
for layer in model:
    if isinstance(layer, AddLayer):
        abstract_model.append(AbstractAddLayer(layer))
    elif isinstance(layer, ScalarMulLayer):
        abstract_model.append(AbstractScalarMulLayer(layer))
    elif isinstance(layer, ReLULayer):
        abstract_model.append(AbstractReLULayer(layer))
    else:
        raise ValueError(f"Unknown layer type: {layer}")
abstract_model = nn.Sequential(*abstract_model)

In [111]:
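# Row 0 holds the lower bounds and row 1 the upper bounds of the two inputs,
# i.e. the first input lies in [0.0, 0.3] and the second in [0.1, 0.4].
bounds = abstract_model(torch.tensor([[0.0, 0.1], [0.3, 0.4]]))
print(f'o_5 in [{bounds[0, 0]}, {bounds[1, 0]}] and o_6 in [{bounds[0, 1]}, {bounds[1, 1]}]')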

o_5 in [0.6000000238418579, 1.4000000953674316] and o_6 in [-0.6000000238418579, 0.20000004768371582]
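
Reading the verdict off these intervals comes down to a single comparison. As a minimal sketch (the function name `proves_o5_greater` is an assumption, and the tensor below just restates the rounded output printed above), the check could look like this:

```python
import torch

def proves_o5_greater(bounds: torch.Tensor) -> bool:
    """bounds has shape (2, 2): row 0 holds lower bounds, row 1 upper bounds;
    column 0 belongs to o_5 and column 1 to o_6."""
    o5_lower = bounds[0, 0]
    o6_upper = bounds[1, 1]
    # The property is proven only if even the smallest o_5 beats the largest o_6.
    return bool(o5_lower > o6_upper)

# Rounded output bounds from the cell above.
small_box_bounds = torch.tensor([[0.6, -0.6],
                                 [1.4, 0.2]])
verified = proves_o5_greater(small_box_bounds)
print(verified)
```

If the check returns `False`, the intervals overlap and the box analysis is inconclusive; it does not by itself show that the property is violated.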


As you can see, we succeeded in proving that $o_5 > o_6$ for our given input ranges: the lower bound of $o_5$ (about $0.6$) exceeds the upper bound of $o_6$ (about $0.2$).
What happens if you enlarge the input boxes to $[0, 0.6]$ and $[0.1, 0.7]$?


In [115]:
# Same layout as before: row 0 holds the lower bounds, row 1 the upper bounds.
bounds = abstract_model(torch.tensor([[0.0, 0.1], [0.6, 0.7]]))
print(f'o_5 in [{bounds[0, 0]}, {bounds[1, 0]}] and o_6 in [{bounds[0, 1]}, {bounds[1, 1]}]')


o_5 in [0.6000000238418579, 2.299999952316284] and o_6 in [-0.8999999761581421, 0.7999999523162842]


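Hm, this fails :( The lower bound of $o_5$ (about $0.6$) is now below the upper bound of $o_6$ (about $0.8$), so the two intervals overlap and the box analysis cannot rule out $o_5 \le o_6$.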

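This is why methods that are based on bounding box propagation are generally _incomplete_! Using Box, we fail to prove that the network classifies the inputs correctly, even though the property actually holds! In the network as implemented above, both outputs share the same first hidden term, so $o_5 - o_6$ equals twice the second ReLU output plus $1$, which is always at least $1$. Box treats the two output intervals as independent and loses exactly this relationship.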
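
To see the gap between the box result and the actual behaviour, one can sample concrete inputs from the larger box and look at the margin $o_5 - o_6$ directly. The sketch below re-implements the network with plain matrix products; the weights and biases are copied from the `model` cell above and are assumed to match the figure, and `concrete_forward` is a hypothetical helper name.

```python
import torch

def concrete_forward(x: torch.Tensor) -> torch.Tensor:
    """x has shape (n, 2); returns shape (n, 2) with columns o_5 and o_6."""
    W = torch.tensor([[1.0, 1.0], [1.0, -1.0]])  # same matrix in both layers
    b2 = torch.tensor([0.5, -0.5])               # bias of the output layer
    hidden = torch.relu(x @ W.T)                 # first affine layer + ReLU
    return hidden @ W.T + b2                     # second affine layer

torch.manual_seed(0)
lower = torch.tensor([0.0, 0.1])
upper = torch.tensor([0.6, 0.7])

# Uniform samples from the enlarged input box.
x = lower + (upper - lower) * torch.rand(10000, 2)
out = concrete_forward(x)

# Smallest observed margin between o_5 and o_6 over all samples.
margin = out[:, 0] - out[:, 1]
print(margin.min())
```

The smallest sampled margin should stay close to $1$ or above, consistent with the algebraic argument. Sampling is evidence rather than proof; a proof would need an analysis that keeps track of the dependency between the two outputs.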