In [127]:
import torch

# Run on Apple-silicon GPU (Metal / MPS backend) when PyTorch reports it available; otherwise CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [155]:
def imatmul(a1: torch.Tensor, a2: torch.Tensor, b1: torch.Tensor, b2: torch.Tensor):
    """
    Interval matrix product ``[a1, a2] @ [b1, b2]`` via midpoint-radius arithmetic.

    Implements Algorithm 4.8 of Rump, https://www.tuhh.de/ti3/paper/rump/Ru11a.pdf:
    both operands are converted to midpoint/radius form, and the ``rho`` terms
    tighten the radius of the product compared to the plain midpoint-radius bound.

    All arithmetic uses PyTorch's default rounding (no directed rounding is
    applied), so the result approximates the enclosure; it is not a rigorously
    verified bound.

    :param a1: LHS lower bound
    :param a2: LHS upper bound (expected elementwise >= a1)
    :param b1: RHS lower bound
    :param b2: RHS upper bound (expected elementwise >= b1)
    :return: tuple ``(c1, c2)`` with the lower & upper bounds of the product
    """
    m_a = (a1 + a2) / 2
    # The midpoint is rounded, so it may sit closer to one endpoint; take the
    # larger one-sided distance so [m_a - r_a, m_a + r_a] still covers [a1, a2].
    r_a = torch.maximum(m_a - a1, a2 - m_a)
    m_b = (b1 + b2) / 2
    r_b = torch.maximum(m_b - b1, b2 - m_b)

    rho_a = torch.sign(m_a) * torch.minimum(torch.abs(m_a), r_a)
    rho_b = torch.sign(m_b) * torch.minimum(torch.abs(m_b), r_b)

    # The negative correction term is written as (-|rho_a|) @ |rho_b|, matching the algorithm's form.
    r_c = torch.abs(m_a) @ r_b + r_a @ (torch.abs(m_b) + r_b) + (-torch.abs(rho_a)) @ torch.abs(rho_b)
    # Midpoint of the product, shared by both bounds.
    m_c = m_a @ m_b + rho_a @ rho_b
    return m_c - r_c, m_c + r_c


In [158]:
# Two interval vectors, each given by elementwise lower/upper bounds; print their interval dot product.
t1_l = torch.tensor([1., 2., 3.])
t1_u = torch.tensor([3., 4., 6.])
t2_l = torch.tensor([-1., -2., -4.])
t2_u = torch.tensor([3., 2., 4.])

print(imatmul(t1_l, t1_u, t2_l, t2_u))


(tensor(-35.), tensor(41.))


In [130]:
class DualIntervalTensor:
    """
    A pair of interval tensors: a "real" part and a "dual" part.

    Each part is stored as elementwise lower/upper bound tensors. The class name
    and the scalar-addition rule (a scalar shifts only the real part) suggest the
    dual part encloses a derivative, as in forward-mode dual numbers -- TODO confirm.

    Supported operations: ``+`` and ``-`` with another ``DualIntervalTensor`` or
    with an ``int``/``float`` (on either side), and unary ``-``.
    """

    def __init__(self, real_lb: torch.Tensor, real_ub: torch.Tensor, dual_lb: torch.Tensor, dual_ub: torch.Tensor):
        """
        :param real_lb: lower bound of the real part
        :param real_ub: upper bound of the real part
        :param dual_lb: lower bound of the dual part
        :param dual_ub: upper bound of the dual part
        :raises TypeError: if the four tensors do not all share one size, or one device
        :raises ValueError: if any lower bound is greater than its upper bound
        """
        # The three pairwise comparisons link all four tensors together.
        if not (
                real_lb.size() == dual_lb.size() and real_lb.size() == real_ub.size() and dual_lb.size() == dual_ub.size()):
            raise TypeError("Input sizes not match.")

        if not (
                real_lb.device == dual_lb.device and real_lb.device == real_ub.device and dual_lb.device == dual_ub.device):
            raise TypeError("Input devices not match.")

        if not (torch.all(torch.le(real_lb, real_ub)) and torch.all(torch.le(dual_lb, dual_ub))):
            raise ValueError("Some lower bound greater than upper bound.")

        self.real_lb = real_lb
        self.real_ub = real_ub
        self.dual_lb = dual_lb
        self.dual_ub = dual_ub
        self.device = real_lb.device  # shared by all four tensors (checked above)
        self.dtype = real_lb.dtype  # taken from real_lb; dtypes are not cross-checked

    def __repr__(self) -> str:
        """Pretty-printed dict of the instance attributes, one entry per line."""
        from pprint import pformat
        return pformat(vars(self), indent=4, width=1, sort_dicts=False)

    def __add__(self, other):
        """
        Interval addition.

        - ``DualIntervalTensor``: bounds of both parts are added elementwise.
        - ``int``/``float``: the scalar shifts the real part; the dual part is unchanged.

        Any other operand type yields ``NotImplemented`` so Python can try the
        other operand's reflected method before raising ``TypeError``.
        """
        if isinstance(other, self.__class__):
            new_real_lb = self.real_lb + other.real_lb
            new_real_ub = self.real_ub + other.real_ub
            new_dual_lb = self.dual_lb + other.dual_lb
            new_dual_ub = self.dual_ub + other.dual_ub
            return DualIntervalTensor(new_real_lb, new_real_ub, new_dual_lb, new_dual_ub)
        if isinstance(other, (int, float)):
            return DualIntervalTensor(self.real_lb + other, self.real_ub + other, self.dual_lb, self.dual_ub)
        return NotImplemented

    # Addition is commutative, so the reflected form reuses __add__ (handles e.g. `3 + d`).
    __radd__ = __add__

    def __neg__(self):
        """Negation: each part's bounds are negated and swapped, so lb <= ub still holds."""
        return DualIntervalTensor(-self.real_ub, -self.real_lb, -self.dual_ub, -self.dual_lb)

    def __sub__(self, other):
        """``self - other`` computed as ``self + (-other)``."""
        return self + -other

    def __rsub__(self, other):
        """``other - self`` (e.g. ``4 - d``) computed as ``other + (-self)``."""
        return other + -self


In [131]:

# Constructor validation: an interval whose lower bound exceeds its upper bound
# (9 > 7 in the last real component) must be rejected with ValueError.
try:
    DualIntervalTensor(torch.tensor([1., 3., 9.], device=device), torch.tensor([2., 4., 7.], device=device),
                       torch.tensor([-2., -3., -5.], device=device), torch.tensor([1., 3., 5.], device=device))
except ValueError:
    pass
else:
    raise AssertionError("expected ValueError for real_lb > real_ub")

# Constructor validation: tensors on different devices must be rejected with TypeError.
# Only checkable when `device` is not the CPU.
if device.type != "cpu":
    try:
        DualIntervalTensor(torch.tensor([1., 3., 5.], device=torch.device("cpu")),
                           torch.tensor([2., 4., 7.], device=device),
                           torch.tensor([-2., -3., -5.], device=device), torch.tensor([1., 3., 5.], device=device))
    except TypeError:
        pass
    else:
        raise AssertionError("expected TypeError for mixed devices")

d1 = DualIntervalTensor(torch.tensor([1., 3., 5.], device=device), torch.tensor([2., 4., 7.], device=device),
                        torch.tensor([-2., -3., -5.], device=device), torch.tensor([1., 3., 5.], device=device))

d2 = DualIntervalTensor(torch.tensor([-1., -2., -4.], device=device), torch.tensor([-0.5, 4.3, 3.2], device=device),
                        torch.tensor([2.8, -2., -5.6], device=device), torch.tensor([3., -1.4, 3.], device=device))

print(d1)
print(d2)



{   'real_lb': tensor([1., 3., 5.], device='mps:0'),
    'real_ub': tensor([2., 4., 7.], device='mps:0'),
    'dual_lb': tensor([-2., -3., -5.], device='mps:0'),
    'dual_ub': tensor([1., 3., 5.], device='mps:0'),
    'device': device(type='mps', index=0),
    'dtype': torch.float32}
{   'real_lb': tensor([-1., -2., -4.], device='mps:0'),
    'real_ub': tensor([-0.5000,  4.3000,  3.2000], device='mps:0'),
    'dual_lb': tensor([ 2.8000, -2.0000, -5.6000], device='mps:0'),
    'dual_ub': tensor([ 3.0000, -1.4000,  3.0000], device='mps:0'),
    'device': device(type='mps', index=0),
    'dtype': torch.float32}


In [132]:
# Interval + interval, interval + scalar, and scalar + interval (the last one exercises __radd__).
for lhs, rhs in ((d1, d2), (d1, 3), (3, d1)):
    print(lhs + rhs)



{   'real_lb': tensor([0., 1., 1.], device='mps:0'),
    'real_ub': tensor([ 1.5000,  8.3000, 10.2000], device='mps:0'),
    'dual_lb': tensor([  0.8000,  -5.0000, -10.6000], device='mps:0'),
    'dual_ub': tensor([4.0000, 1.6000, 8.0000], device='mps:0'),
    'device': device(type='mps', index=0),
    'dtype': torch.float32}
{   'real_lb': tensor([4., 6., 8.], device='mps:0'),
    'real_ub': tensor([ 5.,  7., 10.], device='mps:0'),
    'dual_lb': tensor([-2., -3., -5.], device='mps:0'),
    'dual_ub': tensor([1., 3., 5.], device='mps:0'),
    'device': device(type='mps', index=0),
    'dtype': torch.float32}
{   'real_lb': tensor([4., 6., 8.], device='mps:0'),
    'real_ub': tensor([ 5.,  7., 10.], device='mps:0'),
    'dual_lb': tensor([-2., -3., -5.], device='mps:0'),
    'dual_ub': tensor([1., 3., 5.], device='mps:0'),
    'device': device(type='mps', index=0),
    'dtype': torch.float32}


In [133]:
# Unary negation, then interval - interval, interval - scalar, and scalar - interval (__rsub__).
print(-d1)

for lhs, rhs in ((d1, d2), (d1, 4), (4, d1)):
    print(lhs - rhs)

{   'real_lb': tensor([-2., -4., -7.], device='mps:0'),
    'real_ub': tensor([-1., -3., -5.], device='mps:0'),
    'dual_lb': tensor([-1., -3., -5.], device='mps:0'),
    'dual_ub': tensor([2., 3., 5.], device='mps:0'),
    'device': device(type='mps', index=0),
    'dtype': torch.float32}
{   'real_lb': tensor([ 1.5000, -1.3000,  1.8000], device='mps:0'),
    'real_ub': tensor([ 3.,  6., 11.], device='mps:0'),
    'dual_lb': tensor([-5.0000, -1.6000, -8.0000], device='mps:0'),
    'dual_ub': tensor([-1.8000,  5.0000, 10.6000], device='mps:0'),
    'device': device(type='mps', index=0),
    'dtype': torch.float32}
{   'real_lb': tensor([-3., -1.,  1.], device='mps:0'),
    'real_ub': tensor([-2.,  0.,  3.], device='mps:0'),
    'dual_lb': tensor([-2., -3., -5.], device='mps:0'),
    'dual_ub': tensor([1., 3., 5.], device='mps:0'),
    'device': device(type='mps', index=0),
    'dtype': torch.float32}
{   'real_lb': tensor([ 2.,  0., -3.], device='mps:0'),
    'real_ub': tensor([ 3.,  

In [134]:
# TODO: enable once DualIntervalTensor supports `*` (`__mul__`) and a `tanh` function exists:
# print(d1 * d2)
# print(tanh(d1))


In [135]:
# Quick look at torch.sign and scalar division on the lower-bound vectors of the imatmul example.
# (`t1`/`t2` are not defined anywhere in the notebook; the recorded output corresponds to t1_l/t2_l.)
s1 = torch.sign(t1_l)
s2 = torch.sign(t2_l)

print(t1_l / 2)
print(s2)


tensor([0.5000, 1.0000, 1.5000])
tensor([-1., -1., -1.])
