In [9]:
import torch

device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")


In [10]:
class Zonotope:
    def __init__(self, centers, generators):
        if centers.size(1) != generators.size(1):
            raise TypeError("The second dimension of input tensors do not match.")
        if centers.device != generators.device:
            raise TypeError("Input devices do not match.")
        self.centers = centers
        self.generators = generators
        self.device = generators.device

    def __repr__(self) -> str:
        return "Centers: \n" + self.centers.__str__() + "\n" + "Generators: \n" + self.generators.__str__() + "\n"

    def __add__(self, other):
        if isinstance(other, self.__class__):
            if self.get_num_noise() < other.get_num_noise():
                self.expand(other.get_num_noise() - self.get_num_noise())
            else:
                other.expand(self.get_num_noise() - other.get_num_noise())
            return Zonotope(self.centers + other.centers, self.generators + other.generators)
        if isinstance(other, (int, float)):
            return Zonotope(self.centers + other, self.generators)
        if isinstance(other, torch.Tensor):
            if self.centers.size() != other.size():
                raise TypeError("Invalid size of input tensor.")
            return Zonotope(self.centers + other, self.generators)
        raise TypeError(f"Unsupported operation `+` for class {self.__class__} and {type(other)}")

    __radd__ = __add__

    def __neg__(self):
        return Zonotope(-self.centers, self.generators)

    def __sub__(self, other):
        return -other + self

    def __mul__(self, other):
        if isinstance(other, self.__class__):
            if self.get_num_noise() < other.get_num_noise():
                self.expand(other.get_num_noise() - self.get_num_noise())
            else:
                other.expand(self.get_num_noise() - other.get_num_noise())
            _temp = self.centers * other.generators + other.centers * self.generators
            _new_noise = torch.diag(self.get_rad() * other.get_rad())
            _generators = torch.cat((_temp, _new_noise), dim=0)
            return Zonotope(self.centers * other.centers, _generators)
        if isinstance(other, (int, float)):
            return Zonotope(self.centers * other, self.generators * other)
        if isinstance(other, torch.Tensor):
            if self.centers.size() != other.size():
                raise TypeError("Invalid size of input tensor.")
            return Zonotope(self.centers * other, self.generators * other)
        raise TypeError(f"Unsupported operation `*` for class {self.__class__} and {type(other)}")

    def get_num_vars(self):
        return self.centers.size(1)

    def get_num_noise(self):
        return self.generators.size(0)

    def get_rad(self):
        return torch.sum(torch.abs(self.generators), dim=0)

    def get_lb(self):
        return self.centers - self.get_rad()

    def get_ub(self):
        return self.centers + self.get_rad()

    def expand(self, n):
        if n < 0:
            raise ValueError("Invalid size!")
        if n != 0:
            self.generators = torch.cat((self.generators, torch.zeros(n, self.get_num_vars(), device=device)), dim=0)
        return


In [11]:
z1 = Zonotope(centers=torch.tensor([[1., -3., 4.]], device=device),
              generators=torch.tensor([[1., 3., -4.], [2., 5., 8.]], device=device))
z2 = Zonotope(centers=torch.tensor([[-1., 3., 5.]], device=device),
              generators=torch.tensor([[1.5, -5.5, 4.8], [2.9, -5.2, 4.5], [-3.4, 2.2, 7.7]], device=device))
print(z1)
print(z2)

Centers: 
tensor([[ 1., -3.,  4.]], device='mps:0')
Generators: 
tensor([[ 1.,  3., -4.],
        [ 2.,  5.,  8.]], device='mps:0')

Centers: 
tensor([[-1.,  3.,  5.]], device='mps:0')
Generators: 
tensor([[ 1.5000, -5.5000,  4.8000],
        [ 2.9000, -5.2000,  4.5000],
        [-3.4000,  2.2000,  7.7000]], device='mps:0')



In [12]:
print(z1 + z2)
print(z1 + 3)
print(z1 + torch.tensor([[2., 4., 6.]], device=device))

Centers: 
tensor([[0., 0., 9.]], device='mps:0')
Generators: 
tensor([[ 2.5000, -2.5000,  0.8000],
        [ 4.9000, -0.2000, 12.5000],
        [-3.4000,  2.2000,  7.7000]], device='mps:0')

Centers: 
tensor([[4., 0., 7.]], device='mps:0')
Generators: 
tensor([[ 1.,  3., -4.],
        [ 2.,  5.,  8.],
        [ 0.,  0.,  0.]], device='mps:0')

Centers: 
tensor([[ 3.,  1., 10.]], device='mps:0')
Generators: 
tensor([[ 1.,  3., -4.],
        [ 2.,  5.,  8.],
        [ 0.,  0.,  0.]], device='mps:0')



In [13]:
print(z1.get_rad())
print(z1.get_lb())
print(z1.get_ub())

tensor([ 3.,  8., 12.], device='mps:0')
tensor([[ -2., -11.,  -8.]], device='mps:0')
tensor([[ 4.,  5., 16.]], device='mps:0')


In [14]:
print(z1 * 3)
print(z1 * torch.tensor([[1., 2., 3.]], device=device))
print(z1 * z2)

Centers: 
tensor([[ 3., -9., 12.]], device='mps:0')
Generators: 
tensor([[  3.,   9., -12.],
        [  6.,  15.,  24.],
        [  0.,   0.,   0.]], device='mps:0')

Centers: 
tensor([[ 1., -6., 12.]], device='mps:0')
Generators: 
tensor([[  1.,   6., -12.],
        [  2.,  10.,  24.],
        [  0.,   0.,   0.]], device='mps:0')

Centers: 
tensor([[-1., -9., 20.]], device='mps:0')
Generators: 
tensor([[  0.5000,  25.5000,  -0.8000],
        [  0.9000,  30.6000,  58.0000],
        [ -3.4000,  -6.6000,  30.8000],
        [ 23.4000,   0.0000,   0.0000],
        [  0.0000, 103.2000,   0.0000],
        [  0.0000,   0.0000, 204.0000]], device='mps:0')



In [15]:
def sigmoid(zono: Zonotope):
    _lb = zono.get_lb()
    _ub = zono.get_ub()
    _f_lb = torch.sigmoid(_lb)
    _f_ub = torch.sigmoid(_ub)
    _lambda_opt = torch.min((1 - _f_lb) * _f_lb, (1 - _f_ub) * _f_ub)
    _mu1 = 0.5 * (_f_ub + _f_lb - _lambda_opt * (_ub + _lb))
    _mu2 = 0.5 * (_f_ub - _f_lb - _lambda_opt * (_ub - _lb))
    _centers = _lambda_opt * zono.centers + _mu1
    _generators = torch.cat((_lambda_opt * zono.generators, torch.diag(torch.squeeze(_mu2))), dim=0)
    zono.expand(_generators.size(0) - zono.get_num_noise())
    return Zonotope(_centers, _generators)


In [16]:
print(sigmoid(z1))
print(z1)

Centers: 
tensor([[0.5506, 0.4967, 0.5002]], device='mps:0')
Generators: 
tensor([[ 1.7663e-02,  5.0103e-05, -4.7684e-07],
        [ 3.5325e-02,  8.3506e-05,  9.5367e-07],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 3.7842e-01,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  4.9651e-01,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  4.9983e-01]], device='mps:0')

Centers: 
tensor([[ 1., -3.,  4.]], device='mps:0')
Generators: 
tensor([[ 1.,  3., -4.],
        [ 2.,  5.,  8.],
        [ 0.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 0.,  0.,  0.]], device='mps:0')

