In [33]:
from __future__ import annotations
import os, sys
sys.path.insert(0, "..")
import torch, math
from HTorch.manifolds import Euclidean, PoincareBall, Lorentz, HalfSpace, Manifold, Sphere
from torch import Tensor
from torch.nn import Parameter
import functools
from typing import Union
from HTorch.MCTensor import MCTensor
from HTorch.HTensor import HTensor

manifold_maps = {
    'Euclidean': Euclidean, 
    'PoincareBall': PoincareBall,
    'Lorentz': Lorentz, 
    'HalfSpace': HalfSpace,
    'Sphere':Sphere
}

In [34]:
torch.set_printoptions(precision=8)

In [35]:
class MCHTensor(MCTensor):
    @staticmethod
    def __new__(cls, *args, manifold='PoincareBall', curvature=-1.0, **kwargs):
        ret = super().__new__(cls, *args, **kwargs)
        if isinstance(manifold, str):
            ret.manifold: Manifold = manifold_maps[manifold]()
        elif isinstance(manifold, Manifold):
            ret.manifold: Manifold = manifold
        else:
            raise NotImplemented
        ret.curvature = curvature
        return ret

    def __init__(self, *args, manifold='PoincareBall', curvature=-1.0, **kwargs):
        super().__init__(*args, **kwargs)
        if isinstance(manifold, str):
            self.manifold: Manifold = manifold_maps[manifold]()
        elif isinstance(manifold, Manifold):
            self.manifold: Manifold = manifold
        else:
            raise NotImplemented
        self.curvature = curvature

    def __repr__(self):
        return "{}, manifold={}, curvature={}".format(
            super().__repr__(), self.manifold.name, self.curvature)

    def to_other_manifold(self, name: str) -> MCHTensor:
        """Convert to the same point on the other manifold."""
        assert name != self.manifold.name
        if name == 'Lorentz':
            ret = self.manifold.to_lorentz(self, abs(self.curvature))
        elif name == 'HalfSpace':
            ret = self.manifold.to_halfspace(self, abs(self.curvature))
        elif name == 'PoincareBall':
            ret = self.manifold.to_poincare(self, abs(self.curvature))
        else:
            raise NotImplemented
        ret.manifold = manifold_maps[name]()
        return ret

    def Hdist(self, other: MCHTensor) -> Tensor:
        """Computes hyperbolic distance to other."""
        assert self.curvature == other.curvature, "Inputs should in models with same curvature!"
        if self.manifold.name == other.manifold.name:
            dist = self.manifold.distance(self, other, abs(self.curvature))
        else:
            #### transform to a self's manifold, combine with lazy evaulation?
            other_ = other.to_other_manifold(self.manifold.name)
            dist = self.manifold.distance(self, other_, abs(self.curvature))
        return dist.as_subclass(Tensor)

    def proj(self) -> MCHTensor:
        """Projects point p on the manifold."""
        return self.manifold.proj(self, abs(self.curvature))

    def proj_(self) -> MCHTensor:
        """Projects point p on the manifold."""
        return self.copy_(self.proj())

    def proj_tan(self, u: Tensor) -> Tensor:
        """Projects u on the tangent space of p."""
        return self.manifold.proj_tan(self, u, abs(self.curvature)).as_subclass(Tensor)

    def proj_tan0(self, u: Tensor) -> Tensor:
        """Projects u on the tangent space of the origin."""
        return self.manifold.proj_tan0(u, abs(self.curvature)).as_subclass(Tensor)

    def expmap(self, x: MCHTensor, u: Tensor) -> MCHTensor:
        """Exponential map."""
        return self.manifold.expmap(x, u, abs(self.curvature))

    def expmap0(self, u: Tensor) -> MCHTensor: ## wrap u to MCHTensor???
        """Exponential map, with x being the origin on the manifold."""
        res = self.manifold.expmap0(u, abs(self.curvature))
        return MCHTensor(res, manifold=self.manifold, curvature=self.curvature, nc=self.nc)

    def logmap(self, x: MCHTensor, y: MCHTensor) -> Tensor:
        """Logarithmic map, the inverse of exponential map."""
        return self.manifold.logmap(x, y, abs(self.curvature)).as_subclass(Tensor)

    def logmap0(self, y: MCHTensor) -> Tensor:
        """Logarithmic map, where x is the origin."""
        return self.manifold.logmap0(y, abs(self.curvature)).as_subclass(Tensor)

    def mobius_add(self, x: MCHTensor, y: MCHTensor, dim: int = -1) -> MCHTensor:
        """Performs hyperboic addition, adds points x and y."""
        return self.manifold.mobius_add(x, y, abs(self.curvature), dim=dim)

    def mobius_matvec(self, m: Tensor, x: MCHTensor) -> MCHTensor:
        """Performs hyperboic martrix-vector multiplication to m (matrix)."""
        return self.manifold.mobius_matvec(m, x, abs(self.curvature))

    def check_(self) -> Tensor:
        """Check if point on the specified manifold, project to the manifold if not."""
        check_result = self.manifold.check(
            self, abs(self.curvature)).as_subclass(Tensor)
        if not check_result:
            print('Warning: data not on the manifold, projecting ...')
            self.proj_()
        return check_result
    
    @staticmethod
    def find_mani_cur(args):
        for arg in args:
            if isinstance(arg, list) or isinstance(arg, tuple):
                # Recursively apply the function to each element of the list
                manifold, curvature = MCHTensor.find_mani_cur(arg)
                break
            elif isinstance(arg, MCHTensor):
                manifold, curvature = arg.manifold, arg.curvature
                break
        return manifold, curvature
    
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        tmp = super().__torch_function__(func, types, args, kwargs)
        if type(tmp) in [MCTensor, MCHTensor] and not hasattr(tmp, 'manifold'):
            ret = cls(tmp)
            ret._nc, ret.res = tmp.nc, tmp.res
            ret.manifold, ret.curvature = cls.find_mani_cur(args)
            return ret
        return tmp

In [36]:
x = MCTensor(torch.arange(30).reshape(3, 5, 2), nc=2)
print(x.index_select(-1, torch.tensor(0)))

print(torch.arange(30).reshape(3, 5, 2).index_select(-1, torch.tensor(0)))

MCTensor([[[ 0],
           [ 2],
           [ 4],
           [ 6],
           [ 8]],

          [[10],
           [12],
           [14],
           [16],
           [18]],

          [[20],
           [22],
           [24],
           [26],
           [28]]]), nc=2
tensor([[[ 0],
         [ 2],
         [ 4],
         [ 6],
         [ 8]],

        [[10],
         [12],
         [14],
         [16],
         [18]],

        [[20],
         [22],
         [24],
         [26],
         [28]]])


In [37]:
x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="Lorentz")
x.res.data.add_(1e-3)
x.normalize_()
print(x, x.res, x.shape)

x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="HalfSpace")
x.res.data.add_(1e-3)
x.normalize_()
print(x, x.res, x.shape)

MCHTensor([0.20200001, 0.10200001, 0.00200000]), nc=3, manifold=Lorentz, curvature=-1.0 tensor([[-3.95812094e-09,  0.00000000e+00],
        [-3.95812094e-09,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00]]) torch.Size([3])
MCHTensor([0.20200001, 0.10200001, 0.00200000]), nc=3, manifold=HalfSpace, curvature=-1.0 tensor([[-3.95812094e-09,  0.00000000e+00],
        [-3.95812094e-09,  0.00000000e+00],
        [ 0.00000000e+00,  0.00000000e+00]]) torch.Size([3])


In [38]:
mc_x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="Lorentz")
print(mc_x.proj())

x = HTensor([0.2, 0.1, 0.0], manifold="Lorentz")
print(x.proj())

mc_x = MCHTensor([0.2, 0.1, 0.1], nc=3, manifold="HalfSpace")
print(mc_x.proj().data)

x = HTensor([0.2, 0.1, 0.1], manifold="HalfSpace")
print(x.proj())

MCHTensor([0.20000000, 0.10000000, 1.02469516]), nc=3, manifold=Lorentz, curvature=-1.0
Hyperbolic HTensor([0.20000000, 0.10000000, 1.02469504]), manifold=Lorentz, curvature=-1.0
MCHTensor([0.20000000, 0.10000000, 0.10000000]), nc=3, manifold=HalfSpace, curvature=-1.0
Hyperbolic HTensor([0.20000000, 0.10000000, 0.10000000]), manifold=HalfSpace, curvature=-1.0


In [39]:
mc_x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="Lorentz")
print(mc_x.proj_tan(torch.tensor([0.2, 0.1, 0.3])))

x = HTensor([0.2, 0.1, 0.0], manifold="Lorentz")
print(x.proj_tan(torch.tensor([0.2, 0.1, 0.3])))

mc_x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="HalfSpace")
print(mc_x.proj_tan(torch.tensor([0.2, 0.1, 0.3])))

x = HTensor([0.2, 0.1, 0.0], manifold="HalfSpace")
print(x.proj_tan(torch.tensor([0.2, 0.1, 0.3])))


tensor([2.00000003e-01, 1.00000001e-01, 5.00000000e+05])
tensor([2.00000003e-01, 1.00000001e-01, 5.00000031e+05])
tensor([0.20000000, 0.10000000, 0.30000001])
tensor([0.20000000, 0.10000000, 0.30000001])


In [40]:
mc_x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="Lorentz")
print(mc_x.proj_tan0(torch.tensor([0.2, 0.1, 0.3])))

x = HTensor([0.2, 0.1, 0.0], manifold="Lorentz")
print(x.proj_tan0(torch.tensor([0.2, 0.1, 0.3])))

mc_x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="HalfSpace")
print(mc_x.proj_tan0(torch.tensor([0.2, 0.1, 0.3])))

x = HTensor([0.2, 0.1, 0.0], manifold="HalfSpace")
print(x.proj_tan0(torch.tensor([0.2, 0.1, 0.3])))


tensor([0.20000000, 0.10000000, 0.00000000])
tensor([0.20000000, 0.10000000, 0.00000000])
tensor([0.20000000, 0.10000000, 0.30000001])
tensor([0.20000000, 0.10000000, 0.30000001])


In [41]:
mc_x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="Lorentz")
exp_mc_x = mc_x.expmap(mc_x, torch.tensor([1.3254, 0.6693, 0.0000]))
print(exp_mc_x)

x = HTensor([0.2, 0.1, 0.0], manifold="Lorentz")
print(x.expmap(x, torch.tensor([1.3254, 0.6693, 0.0000])))

mc_x = MCHTensor([0.2, 0.1, 0.05], nc=3, manifold="HalfSpace")
exp_mc_x = mc_x.expmap(mc_x, torch.tensor([1.3254, 0.6693, 0.0000]))
print(exp_mc_x)

x = HTensor([0.2, 0.1, 0.05], manifold="HalfSpace")
print(x.expmap(x, torch.tensor([1.3254, 0.6693, 0.0000])))

MCHTensor([2.33306193, 1.17583787, 2.79745817]), nc=3, manifold=Lorentz, curvature=-1.0
Hyperbolic HTensor([2.33306193, 1.17583787, 2.79745817]), manifold=Lorentz, curvature=-1.0
MCTensor([2.44632110e-01, 1.22538306e-01, 3.05902361e-08]), nc=3
Hyperbolic HTensor([2.44632110e-01, 1.22538306e-01, 3.05902326e-08]), manifold=HalfSpace, curvature=-1.0


In [42]:
x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="Lorentz")
exp_x = x.expmap0(torch.tensor([1.3254, 0.6693, 0.0000]))
print(exp_x)

x = HTensor([0.2, 0.1, 0.0], manifold="Lorentz")
print(x.expmap0(torch.tensor([1.3254, 0.6693, 0.0000])))

x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="HalfSpace")
exp_x = x.expmap0(torch.tensor([1.3254, 0.6693, 0.0000]))
print(exp_x)

x = HTensor([0.2, 0.1, 0.0], manifold="HalfSpace")
print(x.expmap0(torch.tensor([1.3254, 0.6693, 0.0000])))

MCHTensor([1.86899650, 0.94380516, 2.32032681]), nc=3, manifold=Lorentz, curvature=-1.0
Hyperbolic HTensor([1.86899650, 0.94380516, 2.32032681]), manifold=Lorentz, curvature=-1.0
MCHTensor([0.80548853, 0.40675530, 0.43097377]), nc=3, manifold=HalfSpace, curvature=-1.0
Hyperbolic HTensor([0.80548853, 0.40675530, 0.43097377]), manifold=HalfSpace, curvature=-1.0


In [43]:
x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="Lorentz")
norm_t = x.manifold.norm_t(x)
print(norm_t)

x = HTensor([0.2, 0.1, 0.0], manifold="Lorentz")
norm_t = x.manifold.norm_t(x)
print(norm_t)

x = MCHTensor([0.2, 0.1, 0.3], nc=3, manifold="HalfSpace")
norm_t = x.manifold.norm_t(x, x, abs(x.curvature))
print(norm_t)

x = HTensor([0.2, 0.1, 0.3], manifold="HalfSpace")
norm_t = x.manifold.norm_t(x, x, abs(x.curvature))
print(norm_t)

MCTensor([0.22360681]), nc=3
Hyperbolic HTensor([0.22360681]), manifold=Lorentz, curvature=-1.0
MCHTensor([1.24721920]), nc=3, manifold=HalfSpace, curvature=-1.0
Hyperbolic HTensor([1.24721920]), manifold=HalfSpace, curvature=-1.0


In [44]:
x = MCHTensor([1.2, 1.1, 1.5], nc=3, manifold="Lorentz")
sq_dist = x.manifold.distance(x,  x + 2, 1)
print(sq_dist)

x = HTensor([1.2, 1.1, 1.5], manifold="Lorentz")
sq_dist = x.manifold.distance(x, x + 2, 1)
print(sq_dist)

x = MCHTensor([0.2, 0.1, 0.5], nc=3, manifold="HalfSpace")
sq_dist = x.manifold.distance(x,  x + 2, 1)
print(sq_dist)

x = HTensor([0.2, 0.1, 0.5], manifold="HalfSpace")
sq_dist = x.manifold.distance(x, x + 2, 1)
print(sq_dist)

MCTensor([0.00048828]), nc=3
Hyperbolic HTensor([0.00048828]), manifold=Lorentz, curvature=-1.0
MCHTensor([2.44348931]), nc=3, manifold=HalfSpace, curvature=-1.0
Hyperbolic HTensor([2.44348931]), manifold=HalfSpace, curvature=-1.0


In [65]:
x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="Lorentz")
exp_u = x.expmap(x, torch.tensor([1.3254, 0.6693, 0.0000]))
log_x = x.logmap(x, exp_u)
print(log_x)

x = HTensor([0.2, 0.1, 0.0], manifold="Lorentz")
exp_u = x.expmap(x, torch.tensor([1.3254, 0.6693, 0.0000]))
print(x.logmap(x, exp_u))

x = MCHTensor([0.2, 0.1, 0.3], nc=3, manifold="HalfSpace")
exp_u = x.expmap(x, torch.tensor([0.3254, 0.6693, 0.1000]))
log_x = x.logmap(x, exp_u)
print(log_x)

x = HTensor([0.2, 0.1, 0.3], manifold="HalfSpace")
exp_u = x.expmap(x, torch.tensor([0.3254, 0.6693, 0.1000]))
print(x.logmap(x, exp_u))


tensor([3.29362011e+00, 1.66118062e+00, 8.24842100e+06])
tensor([3.29362011e+00, 1.66118073e+00, 8.24842100e+06])
tensor([[0.32539895, 0.66929781, 0.10000819]])
tensor([0.32539999, 0.66930002, 0.09999998])


In [69]:
x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="Lorentz")
exp_u = x.expmap(x, torch.tensor([1.3254, 0.6693, 0.0000]))
log_x = x.logmap0(exp_u)
print(log_x)

x = HTensor([0.2, 0.1, 0.0], manifold="Lorentz")
exp_u = x.expmap(x, torch.tensor([1.3254, 0.6693, 0.0000]))
print(x.logmap0(exp_u))

x = MCHTensor([0.2, 0.1, 0.5], nc=3, manifold="HalfSpace")
exp_u = x.expmap(x, torch.tensor([0.3254, 0.3693, 0.1000]))
log_x = x.logmap0(exp_u)
print(log_x)

x = HTensor([0.2, 0.1, 0.5], manifold="HalfSpace")
exp_u = x.expmap(x, torch.tensor([0.3254, 0.3693, 0.1000]))
print(x.logmap0(exp_u))


tensor([1.50761509, 0.75982159, 0.00000000])
tensor([1.50761521, 0.75982165, 0.00000000])
tensor([[ 0.96742433,  0.84807873, -0.42057699]])
tensor([ 0.96744061,  0.84809279, -0.42060983])


In [47]:
x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="Lorentz")
print(x.mobius_add(x.unsqueeze(0), x.unsqueeze(0)))

x = HTensor([0.2, 0.1, 0.0], manifold="Lorentz")
print(x.mobius_add(x.unsqueeze(0), x.unsqueeze(0)))


MCHTensor([[0.20000002, 0.10000001, 1.02469516]]), nc=3, manifold=Lorentz, curvature=-1.0
Hyperbolic HTensor([[0.20000000, 0.10000000, 1.02469504]]), manifold=Lorentz, curvature=-1.0


In [71]:
x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="Lorentz")
print(x.mobius_matvec(torch.arange(6).reshape(2, 3).float(), x.unsqueeze(0)))

x = HTensor([0.2, 0.1, 0.0], manifold="Lorentz")
print(x.mobius_matvec(torch.arange(6).reshape(2, 3).float(), x.unsqueeze(0)))

x = MCHTensor([0.2, 0.1, 0.5], nc=3, manifold="HalfSpace")
print(x.mobius_matvec(torch.arange(6).reshape(2, 3).float(), x.unsqueeze(0)))

x = HTensor([0.2, 0.1, 0.5], manifold="HalfSpace")
print(x.mobius_matvec(torch.arange(6).reshape(2, 3).float(), x.unsqueeze(0)))


MCHTensor([[2.18366040e-04, 1.00000000e+00]]), nc=3, manifold=Lorentz, curvature=-1.0
Hyperbolic HTensor([[2.18365996e-04, 1.00000000e+00]]), manifold=Lorentz, curvature=-1.0
MCHTensor([[-0.33889443,  0.19450326]]), nc=3, manifold=HalfSpace, curvature=-1.0
Hyperbolic HTensor([[-0.33891195,  0.19457309]]), manifold=HalfSpace, curvature=-1.0


In [72]:
x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="Lorentz")
print(x.Hdist(MCHTensor([0.1, 0, 0.3], nc=3, manifold="Lorentz")))

x = HTensor([0.2, 0.1, 0.0], manifold="Lorentz")
print(x.Hdist(HTensor([0.1, 0, 0.3], manifold="Lorentz")))

x = MCHTensor([0.2, 0.1, 0.5], nc=3, manifold="HalfSpace")
print(x.Hdist(MCHTensor([0.1, 0, 0.3], nc=3, manifold="HalfSpace")))

x = HTensor([0.2, 0.1, 0.5], manifold="HalfSpace")
print(x.Hdist(HTensor([0.1, 0, 0.3], manifold="HalfSpace")))


tensor([0.00048828])
tensor([0.00048828])
tensor([0.62236243])
tensor([0.62236255])


In [73]:
x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="Lorentz")
print(x.manifold.origin(2, 1, size=(2,2)))

x = HTensor([0.2, 0.1, 0.0], manifold="Lorentz")
print(x.manifold.origin(2, 1, size=(2,2)))

x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="HalfSpace")
print(x.manifold.origin(2, 1, size=(2, 2)))

x = HTensor([0.2, 0.1, 0.0], manifold="HalfSpace")
print(x.manifold.origin(2, 1, size=(2, 2)))


tensor([[[0., 0., 1.],
         [0., 0., 1.]],

        [[0., 0., 1.],
         [0., 0., 1.]]])
tensor([[[0., 0., 1.],
         [0., 0., 1.]],

        [[0., 0., 1.],
         [0., 0., 1.]]])
tensor([[[0., 1.],
         [0., 1.]],

        [[0., 1.],
         [0., 1.]]])
tensor([[[0., 1.],
         [0., 1.]],

        [[0., 1.],
         [0., 1.]]])


In [76]:
x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="Lorentz")
u = MCHTensor([0.2, 0.1, 0.3], nc=3, manifold="Lorentz")
v = MCHTensor([0.2, 0.2, 0.3], nc=3, manifold="Lorentz")
print(x.manifold.inner(u, v, x, 1.0))

x = HTensor([0.2, 0.1, 0.0], manifold="Lorentz")
u = HTensor([0.2, 0.1, 0.3], manifold="Lorentz")
v = HTensor([0.2, 0.2, 0.3], manifold="Lorentz")
print(x.manifold.inner(u, v, x, 1.0))

x = MCHTensor([0.2, 0.1, 0.3], nc=3, manifold="HalfSpace")
u = MCHTensor([0.2, 0.1, 0.4], nc=3, manifold="HalfSpace")
v = MCHTensor([0.2, 0.2, 0.5], nc=3, manifold="HalfSpace")
print(x.manifold.inner(u, v, x, 1.0))

x = HTensor([0.2, 0.1, 0.3], manifold="HalfSpace")
u = HTensor([0.2, 0.1, 0.4], manifold="HalfSpace")
v = HTensor([0.2, 0.2, 0.5], manifold="HalfSpace")
print(x.manifold.inner(u, v, x, 1.0))


MCTensor([-0.03000000]), nc=3
Hyperbolic HTensor([-0.03000000]), manifold=Lorentz, curvature=-1.0
MCHTensor([2.88888860]), nc=3, manifold=HalfSpace, curvature=-1.0
Hyperbolic HTensor([2.88888860]), manifold=HalfSpace, curvature=-1.0


In [77]:
x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="Lorentz")
u = MCHTensor([0.2, 0.1, 0.3], nc=3, manifold="Lorentz")
print(x.manifold.norm_t(u, x, 1.0))

x = HTensor([0.2, 0.1, 0.0], manifold="Lorentz")
u = HTensor([0.2, 0.1, 0.3], manifold="Lorentz")
print(x.manifold.norm_t(u, x, 1.0))

x = MCHTensor([0.2, 0.1, 0.2], nc=3, manifold="HalfSpace")
u = MCHTensor([0.2, 0.1, 0.3], nc=3, manifold="HalfSpace")
print(x.manifold.norm_t(u, x, 1.0))

x = HTensor([0.2, 0.1, 0.2], manifold="HalfSpace")
u = HTensor([0.2, 0.1, 0.3], manifold="HalfSpace")
print(x.manifold.norm_t(u, x, 1.0))


MCTensor([0.00031623]), nc=3
Hyperbolic HTensor([0.00031623]), manifold=Lorentz, curvature=-1.0
MCHTensor([1.87082875]), nc=3, manifold=HalfSpace, curvature=-1.0
Hyperbolic HTensor([1.87082887]), manifold=HalfSpace, curvature=-1.0


In [78]:
x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="Lorentz")
y = MCHTensor([0.2, 0.1, 0.3], nc=3, manifold="Lorentz")
print(x.manifold.sqdist(u, x, 1.0))

x = HTensor([0.2, 0.1, 0.0], manifold="Lorentz")
y = HTensor([0.2, 0.1, 0.3], manifold="Lorentz")
print(x.manifold.sqdist(u, x, 1.0))

x = MCHTensor([0.2, 0.1, 0.15], nc=3, manifold="HalfSpace")
y = MCHTensor([0.2, 0.1, 0.3], nc=3, manifold="HalfSpace")
print(x.manifold.sqdist(u, x, 1.0))

x = HTensor([0.2, 0.1, 0.15], manifold="HalfSpace")
y = HTensor([0.2, 0.1, 0.3], manifold="HalfSpace")
print(x.manifold.sqdist(u, x, 1.0))


MCTensor([2.38418579e-07]), nc=3
Hyperbolic HTensor([2.38418579e-07]), manifold=HalfSpace, curvature=-1.0
MCHTensor([0.48045301]), nc=3, manifold=HalfSpace, curvature=-1.0
Hyperbolic HTensor([0.48045301]), manifold=HalfSpace, curvature=-1.0


In [79]:
x = MCHTensor([0.2, 0.1, 0.0], nc=3, manifold="Lorentz")
y = MCHTensor([0.2, 0.1, 0.3], nc=3, manifold="Lorentz")
print(x.manifold.distance(u, x, 1.0))

x = HTensor([0.2, 0.1, 0.0], manifold="Lorentz")
y = HTensor([0.2, 0.1, 0.3], manifold="Lorentz")
print(x.manifold.distance(u, x, 1.0))

x = MCHTensor([0.2, 0.1, 0.15], nc=3, manifold="HalfSpace")
y = MCHTensor([0.2, 0.1, 0.3], nc=3, manifold="HalfSpace")
print(x.manifold.distance(u, x, 1.0))

x = HTensor([0.2, 0.1, 0.15], manifold="HalfSpace")
y = HTensor([0.2, 0.1, 0.3], manifold="HalfSpace")
print(x.manifold.distance(u, x, 1.0))


MCTensor([0.00048828]), nc=3
Hyperbolic HTensor([0.00048828]), manifold=HalfSpace, curvature=-1.0
MCHTensor([0.69314718]), nc=3, manifold=HalfSpace, curvature=-1.0
Hyperbolic HTensor([0.69314718]), manifold=HalfSpace, curvature=-1.0


In [80]:
x = MCHTensor([0.2, 0.1, 0.4], nc=3, manifold="Lorentz")
dx = MCHTensor([0.2, 0.1, 0.3], nc=3, manifold="Lorentz")
print(x.manifold.egrad2rgrad(x, dx, 1.0))

x = HTensor([0.2, 0.1, 0.4], manifold="Lorentz")
dx = HTensor([0.2, 0.1, 0.3], manifold="Lorentz")
print(x.manifold.egrad2rgrad(x, dx, 1.0))

x = MCHTensor([[0.2, 0.1, 0.4]], nc=3, manifold="HalfSpace")
dx = MCHTensor([[0.2, 0.1, 0.3]], nc=3, manifold="HalfSpace")
print(x.manifold.egrad2rgrad(x, dx, 1.0))

x = HTensor([[0.2, 0.1, 0.4]], manifold="HalfSpace")
dx = HTensor([[0.2, 0.1, 0.3]], manifold="HalfSpace")
print(x.manifold.egrad2rgrad(x, dx, 1.0))


MCHTensor([0.18600000, 0.09300000, 0.27200001]), nc=3, manifold=Lorentz, curvature=-1.0
Hyperbolic HTensor([ 0.23400001,  0.11700001, -0.23199999]), manifold=Lorentz, curvature=-1.0
MCHTensor([[0.03200000, 0.01600000, 0.04800000]]), nc=3, manifold=HalfSpace, curvature=-1.0
Hyperbolic HTensor([[0.03200000, 0.01600000, 0.04800000]]), manifold=HalfSpace, curvature=-1.0


In [81]:
x = MCHTensor([0.2, 0.3, 0.4], nc=3, manifold="Lorentz")
y = MCHTensor([0.3, 0.1, 0.3], nc=3, manifold="Lorentz")
v = MCHTensor([0.2, 0.05, 0.3], nc=3, manifold="Lorentz")
print(x.manifold.ptransp(x, y, v, 1.0))

x = HTensor([0.2, 0.3, 0.4], manifold="Lorentz")
y = HTensor([0.3, 0.1, 0.3], manifold="Lorentz")
v = HTensor([0.2, 0.05, 0.3], manifold="Lorentz")
print(x.manifold.ptransp(x, y, v, 1.0))

x = MCHTensor([0.2, 0.3, 0.4], nc=3, manifold="HalfSpace")
y = MCHTensor([0.3, 0.1, 0.3], nc=3, manifold="HalfSpace")
v = MCHTensor([0.2, 0.05, 0.2], nc=3, manifold="HalfSpace")
print(x.manifold.ptransp(x, y, v, 1.0))

x = HTensor([0.2, 0.3, 0.4], manifold="HalfSpace")
y = HTensor([0.3, 0.1, 0.3], manifold="HalfSpace")
v = HTensor([0.2, 0.05, 0.2], manifold="HalfSpace")
print(x.manifold.ptransp(x, y, v, 1.0))


MCHTensor([0.20000005, 0.05000006, 0.21666673]), nc=3, manifold=Lorentz, curvature=-1.0
Hyperbolic HTensor([0.20000005, 0.05000010, 0.21666674]), manifold=Lorentz, curvature=-1.0
MCHTensor([[0.19567901, 0.05864198, 0.19135801]]), nc=3, manifold=HalfSpace, curvature=-1.0
Hyperbolic HTensor([0.19567901, 0.05864199, 0.19135801]), manifold=HalfSpace, curvature=-1.0


In [82]:
x = MCHTensor([0.2, 0.3, 0.4], nc=3, manifold="Lorentz")
v = MCHTensor([0.2, 0.05, 0.3], nc=3, manifold="Lorentz")
print(x.manifold.ptransp0(x, v, 1.0))

x = HTensor([0.2, 0.3, 0.4], manifold="Lorentz")
v = HTensor([0.2, 0.05, 0.3], manifold="Lorentz")
print(x.manifold.ptransp0(x, v, 1.0))

x = MCHTensor([[0.2, 0.3, 0.4]], nc=3, manifold="HalfSpace")
v = MCHTensor([[0.2, 0.05, 0.3]], nc=3, manifold="HalfSpace")
print(x.manifold.ptransp0(x, v, 1.0))

x = HTensor([[0.2, 0.3, 0.4]], manifold="HalfSpace")
v = HTensor([[0.2, 0.05, 0.3]], manifold="HalfSpace")
print(x.manifold.ptransp0(x, v, 1.0))

MCHTensor([ 0.14923078, -0.02615384,  0.05500000]), nc=3, manifold=Lorentz, curvature=-1.0
Hyperbolic HTensor([ 0.14923078, -0.02615385,  0.05500000]), manifold=Lorentz, curvature=-1.0
MCHTensor([[0.23379321, 0.10068983, 0.26761466]]), nc=3, manifold=HalfSpace, curvature=-1.0
Hyperbolic HTensor([[0.23379359, 0.10069039, 0.26761448]]), manifold=HalfSpace, curvature=-1.0


### Attentions:
1. test standard functions in MCTensor work
2. test HTensor functions
3. include arcosh, arsinh

### Test MCHTensor function 
In order to test a function, take it out from manifold definition, and run line by line to go through, 
until it returns a correct result, cross-check with torch results, below is an example;

In [60]:
def sq_norm(x, keepdim=True):
    return torch.norm(x, p=2, dim=-1, keepdim=keepdim) ** 2

def my_proj(x:Tensor, c:Union[float,Tensor]) -> Tensor:
    """Project a point outside manifold to the Lorentz manifold """
    d = x.size(-1) - 1
    y = x.narrow(-1, 0, d)
    print(y.shape)
    y_sqnorm = sq_norm(y)[..., 0] 
    print(y_sqnorm.shape)
    mask = torch.ones_like(x)
    mask[..., -1] = 0
    vals = torch.zeros_like(x)
    print(vals[..., -1])
    print((1. / c + y_sqnorm).shape)
    # print(torch.sqrt().shape)
    print(torch.sqrt(torch.clamp(1. / c + y_sqnorm, min=1e-7)))
    vals[..., -1] = torch.sqrt(torch.clamp(1. / c + y_sqnorm, min=1e-7))
    return vals + mask * x

### Line by line walk through the function

In [61]:
x = MCHTensor([0.2, 0.1, 0.0], nc=2, manifold="Lorentz")
x.res.data.add_(1e-3)
x.normalize_()
proj_x = my_proj(x, abs(x.curvature))
d = x.size(-1) - 1
y = x.narrow(-1, 0, d)
# print(y)
y_sqnorm = sq_norm(y)[..., 0]
# print(y_sqnorm)
mask = torch.ones_like(x)
mask[..., -1] = 0
vals = torch.zeros_like(x)
result = vals + mask * x
# print(result)

torch.Size([2])
torch.Size([])
MCTensor(0.), nc=2
torch.Size([])
MCTensor(1.02498877), nc=2


In [83]:
x = MCHTensor([0.2, 0.1, 0.3], nc=2, manifold="HalfSpace")
x.res.data.add_(1e-3)
x.normalize_()
proj_x = my_proj(x, abs(x.curvature))
d = x.size(-1) - 1
y = x.narrow(-1, 0, d)
# print(y)
y_sqnorm = sq_norm(y)[..., 0]
# print(y_sqnorm)
mask = torch.ones_like(x)
mask[..., -1] = 0
vals = torch.zeros_like(x)
result = vals + mask * x
# print(result)

torch.Size([2])
torch.Size([])
MCTensor(0.), nc=2
torch.Size([])
MCTensor(1.02498877), nc=2


### call the corresponding function

In [62]:
y = x.proj()
print(y, y.res)

MCHTensor([0.20100001, 0.10100000, 1.02498877]), nc=2, manifold=Lorentz, curvature=-1.0 tensor([[-1.97906047e-09],
        [-1.97906047e-09],
        [ 3.15145243e-08]])


#### please test through all manifolds functions, check, inner, ....

In [63]:
x = torch.tensor([0.2, 0.1, 0.05])
hx = HTensor(x, manifold='Lorentz', curvature=-1)
mc_x = MCHTensor([0.2, 0.1, 0.0], nc=2, manifold="Lorentz", curvature=-1)

def proj_tan(x:Tensor, v:Tensor, c:Union[float,Tensor]) -> Tensor:
    # not the standard way as x + c<x, dx>_L * x, here only modify the last dimension
    d = x.size(-1) - 1
    ux = torch.sum(x.narrow(-1, 0, d) * v.narrow(-1, 0, d), dim=-1)
    mask = torch.ones_like(v)
    mask[..., -1] = 0
    vals = torch.zeros_like(v)
    vals[..., -1] = ux / torch.clamp(x[..., -1], min=1e-6)
    return vals + mask * v


proj_tan(mc_x, mc_x, 1)

MCHTensor([2.00000003e-01, 1.00000001e-01, 5.00000000e+04]), nc=2, manifold=Lorentz, curvature=-1

#### After testing, move the MCHTensor definitions to MCHTensor.py