In [18]:
import torch

# Pick the compute device once for the whole notebook: use the MPS backend when this
# PyTorch build and machine report it as available, otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")


In [48]:
class DualIntervalTensor:
    """Elementwise interval bounds on a real part and a dual part, stored as four tensors.

    ``real_lb``/``real_ub`` bound the real part and ``dual_lb``/``dual_ub`` bound the dual
    part. All four tensors must have the same size and live on the same device, and each
    lower bound must be ``<=`` its matching upper bound. Only addition is implemented.
    """

    def __init__(self, real_lb: torch.Tensor, real_ub: torch.Tensor, dual_lb: torch.Tensor, dual_ub: torch.Tensor):
        """Validate and store the four bound tensors.

        Args:
            real_lb: lower bounds of the real part.
            real_ub: upper bounds of the real part.
            dual_lb: lower bounds of the dual part.
            dual_ub: upper bounds of the dual part.

        Raises:
            TypeError: if the four tensors differ in size, or (checked second) in device.
            ValueError: if any lower bound is not ``<=`` its upper bound. NaN entries also
                fail this check, because a ``<=`` comparison involving NaN is False.
        """
        bounds = (real_lb, real_ub, dual_lb, dual_ub)

        # The checks run in this order, so the first kind of mismatch found decides the exception.
        if any(bound.size() != real_lb.size() for bound in bounds):
            raise TypeError("Input sizes not match.")

        if any(bound.device != real_lb.device for bound in bounds):
            raise TypeError("Input devices not match.")

        if not (bool(torch.all(real_lb <= real_ub)) and bool(torch.all(dual_lb <= dual_ub))):
            raise ValueError("Some lower bound greater than upper bound.")

        self.real_lb = real_lb
        self.real_ub = real_ub
        self.dual_lb = dual_lb
        self.dual_ub = dual_ub
        self.device = real_lb.device
        # NOTE(review): dtype is read from real_lb only; the other bounds' dtypes are not validated.
        self.dtype = real_lb.dtype

    def __repr__(self) -> str:
        """Pretty-print all instance attributes, one field per line, in definition order."""
        from pprint import pformat
        return pformat(vars(self), indent=4, width=1, sort_dicts=False)

    def __add__(self, other):
        """Return ``self + other`` as a new DualIntervalTensor.

        Supported operands:
            * another DualIntervalTensor: real and dual bounds are added componentwise;
            * a real scalar, i.e. any ``numbers.Real`` (int, float, bool, NumPy real scalars)
              or a non-complex 0-dim torch tensor. The scalar is added to the real bounds only,
              and the dual bounds are carried over unchanged.

        Raises:
            TypeError: for any other operand type, including tensors with one or more dims.
        """
        import numbers  # local import, mirroring the pprint import in __repr__

        if isinstance(other, self.__class__):
            return DualIntervalTensor(self.real_lb + other.real_lb, self.real_ub + other.real_ub,
                                      self.dual_lb + other.dual_lb, self.dual_ub + other.dual_ub)

        # 0-dim tensors behave as scalars under broadcasting; complex ones are excluded
        # because the bound ordering check cannot be applied to complex values.
        is_scalar_tensor = isinstance(other, torch.Tensor) and other.dim() == 0 and not other.is_complex()
        if isinstance(other, numbers.Real) or is_scalar_tensor:
            return DualIntervalTensor(self.real_lb + other, self.real_ub + other, self.dual_lb, self.dual_ub)

        raise TypeError(f"Unsupported operation `+` for class {self.__class__} and {type(other)}")

    # Addition is commutative for every supported operand, so the reflected form reuses __add__
    # (this also makes `scalar + x` and `sum([...])`, which starts from 0, work).
    __radd__ = __add__




In [49]:

# Two valid dual-interval tensors used by the examples below.
d1 = DualIntervalTensor(torch.tensor([1., 3., 5.], device=device), torch.tensor([2., 4., 7.], device=device),
                        torch.tensor([-2., -3., -5.], device=device), torch.tensor([1., 3., 5.], device=device))

d2 = DualIntervalTensor(torch.tensor([-1., -2., -4.], device=device), torch.tensor([-0.5, 4.3, 3.2], device=device),
                        torch.tensor([2.8, -2., -5.6], device=device), torch.tensor([3., -1.4, 3.], device=device))

print(d1)
print(d2)

print(d1 + d2)
print(d1 + 3)
print(3 + d1)

# Scalar addition must not depend on operand order (`3 + d1` goes through __radd__).
assert torch.equal((d1 + 3).real_lb, (3 + d1).real_lb) and torch.equal((d1 + 3).real_ub, (3 + d1).real_ub)

# Validation demos: both constructions below are intentionally invalid and must be rejected.
# Real part: the last lower bound (9.) exceeds its upper bound (7.).
try:
    DualIntervalTensor(torch.tensor([1., 3., 9.], device=device), torch.tensor([2., 4., 7.], device=device),
                       torch.tensor([-2., -3., -5.], device=device), torch.tensor([1., 3., 5.], device=device))
except ValueError as err:
    print(f"Rejected (ValueError): {err}")
else:
    raise AssertionError("expected a ValueError for lower bound > upper bound")

# real_lb is pinned to the CPU while the other bounds use `device`. When `device` is not the CPU,
# the constructor's device check rejects it (TypeError). On a CPU-only run, the devices match and
# the out-of-order last real bound is reported instead (ValueError).
try:
    DualIntervalTensor(torch.tensor([1., 3., 9.], device=torch.device("cpu")),
                       torch.tensor([2., 4., 7.], device=device),
                       torch.tensor([-2., -3., -5.], device=device), torch.tensor([1., 3., 5.], device=device))
except (TypeError, ValueError) as err:
    print(f"Rejected ({type(err).__name__}): {err}")
else:
    raise AssertionError("expected the mixed-device construction to be rejected")

# TODO: `-`, `*` and `tanh` are not implemented for DualIntervalTensor yet; intended usage:
#   d1 - d2, d1 * d2, tanh(d1)

{   'real_lb': tensor([1., 3., 5.], device='mps:0'),
    'real_ub': tensor([2., 4., 7.], device='mps:0'),
    'dual_lb': tensor([-2., -3., -5.], device='mps:0'),
    'dual_ub': tensor([1., 3., 5.], device='mps:0'),
    'device': device(type='mps', index=0),
    'dtype': torch.float32}
{   'real_lb': tensor([-1., -2., -4.], device='mps:0'),
    'real_ub': tensor([-0.5000,  4.3000,  3.2000], device='mps:0'),
    'dual_lb': tensor([ 2.8000, -2.0000, -5.6000], device='mps:0'),
    'dual_ub': tensor([ 3.0000, -1.4000,  3.0000], device='mps:0'),
    'device': device(type='mps', index=0),
    'dtype': torch.float32}
{   'real_lb': tensor([0., 1., 1.], device='mps:0'),
    'real_ub': tensor([ 1.5000,  8.3000, 10.2000], device='mps:0'),
    'dual_lb': tensor([  0.8000,  -5.0000, -10.6000], device='mps:0'),
    'dual_ub': tensor([4.0000, 1.6000, 8.0000], device='mps:0'),
    'device': device(type='mps', index=0),
    'dtype': torch.float32}
{   'real_lb': tensor([4., 6., 8.], device='mps:0'),
  