In [18]:
import torch

# Prefer Apple's Metal (MPS) backend when this PyTorch build and machine support it; otherwise fall back to the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")


In [23]:
class DualIntervalTensor:
    """Element-wise interval bounds on the real part and the dual part of a tensor.

    ``[real_lb, real_ub]`` bounds the real part and ``[dual_lb, dual_ub]`` bounds
    the dual part. All four tensors must share one size, device and dtype.
    """

    def __init__(self, real_lb: torch.Tensor, real_ub: torch.Tensor, dual_lb: torch.Tensor, dual_ub: torch.Tensor):
        """Validate and store the four bound tensors.

        Args:
            real_lb: lower bound of the real part.
            real_ub: upper bound of the real part.
            dual_lb: lower bound of the dual part.
            dual_ub: upper bound of the dual part.

        Raises:
            TypeError: if any input is not a ``torch.Tensor``, or the inputs
                differ in size, device, or dtype.
            ValueError: if any lower bound exceeds its upper bound. NaN entries
                also fail this check, because comparisons with NaN are False.
        """
        bounds = (real_lb, real_ub, dual_lb, dual_ub)

        # Reject non-tensors up front instead of failing later on `.size()`.
        if not all(isinstance(t, torch.Tensor) for t in bounds):
            raise TypeError("Inputs must be torch.Tensor.")

        if any(t.size() != real_lb.size() for t in bounds):
            raise TypeError("Input sizes not match.")

        if any(t.device != real_lb.device for t in bounds):
            raise TypeError("Input devices not match.")

        # `self.dtype` below is read from real_lb alone, so it only describes the
        # whole object when all four tensors share a single dtype.
        if any(t.dtype != real_lb.dtype for t in bounds):
            raise TypeError("Input dtypes not match.")

        if not (torch.all(real_lb <= real_ub) and torch.all(dual_lb <= dual_ub)):
            raise ValueError("Some lower bound greater than upper bound.")

        self.real_lb = real_lb
        self.real_ub = real_ub
        self.dual_lb = dual_lb
        self.dual_ub = dual_ub
        self.device = real_lb.device
        self.dtype = real_lb.dtype

    def __repr__(self) -> str:
        """Pretty-print every instance attribute as a dict, one entry per line."""
        from pprint import pformat
        # width=1 makes pformat break after every dict entry; sort_dicts=False keeps insertion order.
        return pformat(vars(self), indent=4, width=1, sort_dicts=False)




In [27]:

# A valid interval: every lower bound is <= its matching upper bound, and all
# four tensors live on the same `device`.
d1 = DualIntervalTensor(torch.tensor([1., 3., 5.], device=device), torch.tensor([2., 4., 7.], device=device),
                        torch.tensor([-2., -3., -5.], device=device), torch.tensor([1., 3., 5.], device=device))
print(d1)

# Sanity check (silent on success): real_lb[2] = 9 > real_ub[2] = 7, so the
# constructor must reject this input with ValueError.
try:
    DualIntervalTensor(torch.tensor([1., 3., 9.], device=device), torch.tensor([2., 4., 7.], device=device),
                       torch.tensor([-2., -3., -5.], device=device), torch.tensor([1., 3., 5.], device=device))
except ValueError:
    pass
else:
    raise AssertionError("lower bound > upper bound was not rejected")

# Sanity check (silent on success): mixing a CPU tensor with `device` tensors must
# raise TypeError. This only exercises a real mismatch when `device` is not the CPU.
if device.type != "cpu":
    try:
        DualIntervalTensor(torch.tensor([1., 3., 5.], device=torch.device("cpu")),
                           torch.tensor([2., 4., 7.], device=device),
                           torch.tensor([-2., -3., -5.], device=device), torch.tensor([1., 3., 5.], device=device))
    except TypeError:
        pass
    else:
        raise AssertionError("device mismatch was not rejected")

# TODO: arithmetic (d1 + d2, d1 - d2, d1 * d2) and tanh(d1) are not implemented
# on DualIntervalTensor yet.

{   'real_lb': tensor([1., 3., 5.], device='mps:0'),
    'real_ub': tensor([2., 4., 7.], device='mps:0'),
    'dual_lb': tensor([-2., -3., -5.], device='mps:0'),
    'dual_ub': tensor([1., 3., 5.], device='mps:0'),
    'device': device(type='mps', index=0),
    'dtype': torch.float32}
