In [53]:
import torch

device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")


In [54]:
class Zonotope:
    def __init__(self, centers, generators):
        if centers.size(1) != generators.size(1):
            raise TypeError("The second dimension of input tensors do not match.")
        if centers.device != generators.device:
            raise TypeError("Input devices do not match.")
        self.centers = centers
        self.generators = generators
        self.device = generators.device

    def __repr__(self) -> str:
        return "Centers: \n" + self.centers.__str__() + "\n" + "Generators: \n" + self.generators.__str__() + "\n"

    def __add__(self, other):
        if isinstance(other, self.__class__):
            if self.get_num_noise() < other.get_num_noise():
                self.expand(other.get_num_noise() - self.get_num_noise())
            else:
                other.expand(self.get_num_noise() - other.get_num_noise())
            return Zonotope(self.centers + other.centers, self.generators + other.generators)
        if isinstance(other, (int, float)):
            return Zonotope(self.centers + other, self.generators)
        if isinstance(other, torch.Tensor):
            if self.centers.size() != other.size():
                raise TypeError("Invalid size of input tensor.")
            return Zonotope(self.centers + other, self.generators)
        raise TypeError(f"Unsupported operation `+` for class {self.__class__} and {type(other)}")

    __radd__ = __add__

    def get_num_vars(self):
        return self.centers.size(1)

    def get_num_noise(self):
        return self.generators.size(0)

    def expand(self, n):
        if n < 0:
            raise ValueError("Invalid size!")
        if n != 0:
            self.generators = torch.cat((self.generators, torch.zeros(n, self.get_num_vars(), device=device)), dim=0)
        return


In [55]:
z1 = Zonotope(centers=torch.tensor([[1., -3., 4.]], device=device),
              generators=torch.tensor([[1., 3., 4.], [2., 5., 8.]], device=device))
z2 = Zonotope(centers=torch.tensor([[-1., 3., 5.]], device=device),
              generators=torch.tensor([[1.5, 5.5, 4.8], [2.9, 5.2, 4.5], [3.4, 2.2, 7.7]], device=device))
print(z1)
print(z2)

Centers: 
tensor([[ 1., -3.,  4.]], device='mps:0')
Generators: 
tensor([[1., 3., 4.],
        [2., 5., 8.]], device='mps:0')

Centers: 
tensor([[-1.,  3.,  5.]], device='mps:0')
Generators: 
tensor([[1.5000, 5.5000, 4.8000],
        [2.9000, 5.2000, 4.5000],
        [3.4000, 2.2000, 7.7000]], device='mps:0')



In [56]:
print(z1 + z2)
print(z1 + 3)
print(z1 + torch.tensor([[2., 4., 6.]], device=device))


Centers: 
tensor([[0., 0., 9.]], device='mps:0')
Generators: 
tensor([[ 2.5000,  8.5000,  8.8000],
        [ 4.9000, 10.2000, 12.5000],
        [ 3.4000,  2.2000,  7.7000]], device='mps:0')

Centers: 
tensor([[4., 0., 7.]], device='mps:0')
Generators: 
tensor([[1., 3., 4.],
        [2., 5., 8.],
        [0., 0., 0.]], device='mps:0')

Centers: 
tensor([[ 3.,  1., 10.]], device='mps:0')
Generators: 
tensor([[1., 3., 4.],
        [2., 5., 8.],
        [0., 0., 0.]], device='mps:0')

