In [29]:
import torch.nn
import geoopt


# package.nn.modules.py
def create_ball(ball=None, c=None):
    """
    Helper to create a PoincareBall.
    Sometimes you may want to share a manifold across layers, e.g. you are using scaled PoincareBall.
    In this case you will require same curvature parameters for different layers or end up with nans.
    Parameters
    ----------
    ball : geoopt.PoincareBall
    c : float
    Returns
    -------
    geoopt.PoincareBall
    """
    if ball is None:
        assert c is not None, "curvature of the ball should be explicitly specified"
        ball = geoopt.PoincareBall(c)
    # else trust input
    return ball


class MobiusLinear(torch.nn.Linear):
    def __init__(self, *args, nonlin=None, ball=None, c=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        # for manifolds that have parameters like Poincare Ball
        # we have to attach them to the closure Module.
        # It is hard to implement device allocation for manifolds in other case.
        self.ball = create_ball(ball, c)
        if self.bias is not None:
            self.bias = geoopt.ManifoldParameter(self.bias, manifold=self.ball)
        self.nonlin = nonlin
        self.reset_parameters()

    def forward(self, input):
        return mobius_linear(
            input,
            weight=self.weight,
            bias=self.bias,
            nonlin=self.nonlin,
            ball=self.ball,
        )

    @torch.no_grad()
    def reset_parameters(self):
        torch.nn.init.eye_(self.weight)
        self.weight.add_(torch.rand_like(self.weight).mul_(1e-3))
        if self.bias is not None:
            self.bias.zero_()


# package.nn.functional.py
def mobius_linear(input, weight, bias=None, nonlin=None, *, ball: geoopt.PoincareBall):
    output = ball.mobius_matvec(weight, input)
    if bias is not None:
        output = ball.mobius_add(output, bias)
    if nonlin is not None:
        output = ball.logmap0(output)
        output = nonlin(output)
        output = ball.expmap0(output)
    return output

In [31]:
classifier = MobiusLinear()

TypeError: __init__() missing 2 required positional arguments: 'in_features' and 'out_features'

In [None]:
class Distance2PoincareHyperplanes(torch.nn.Module):
    n = 0
    # 1D, 2D versions of this class ara available with a one line change
    # class Distance2PoincareHyperplanes2d(Distance2PoincareHyperplanes):
    #     n = 2

    def __init__(
        self,
        plane_shape: int,
        num_planes: int,
        signed=True,
        squared=False,
        *,
        ball,
        std=1.0,
    ):
        super().__init__()
        self.signed = signed
        self.squared = squared
        # Do not forget to save Manifold instance to the Module
        self.ball = ball
        self.plane_shape = geoopt.utils.size2shape(plane_shape)
        self.num_planes = num_planes

        # In a layer we create Manifold Parameters in the same way we do it for
        # regular pytorch Parameters, there is no difference. But geoopt optimizer
        # will recognize the manifold and adjust to it
        self.points = geoopt.ManifoldParameter(
            torch.empty(num_planes, plane_shape), manifold=self.ball
        )
        self.std = std
        # following best practives, a separate method to reset parameters
        self.reset_parameters()

    def forward(self, input):
        input_p = input.unsqueeze(-self.n - 1)
        points = self.points.permute(1, 0)
        points = points.view(points.shape + (1,) * self.n)

        distance = self.ball.dist2plane(
            x=input_p, p=points, a=points, signed=self.signed, dim=-self.n - 2
        )
        if self.squared and self.signed:
            sign = distance.sign()
            distance = distance ** 2 * sign
        elif self.squared:
            distance = distance ** 2
        return distance

    def extra_repr(self):
        return (
            "plane_shape={plane_shape}, "
            "num_planes={num_planes}, "
            .format(**self.__dict__)
        )

    @torch.no_grad()
    def reset_parameters(self):
        direction = torch.randn_like(self.points)
        direction /= direction.norm(dim=-1, keepdim=True)
        distance = torch.empty_like(self.points[..., 0]).normal_(std=self.std)
        self.points.set_(self.ball.expmap0(direction * distance.unsqueeze(-1)))

In [8]:
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal

import geoopt

In [110]:
class UnitaryLinear(torch.nn.Module):
    def __init__(self, num_qubits):
        super().__init__()
        # for manifolds that have parameters like Poincare Ball
        # we have to attach them to the closure Module.
        # It is hard to implement device allocation for manifolds in other case.
        self.man = geoopt.Stiefel()
        self.dim = 2**num_qubits
        self.weight = geoopt.ManifoldParameter(torch.empty(self.dim,self.dim, dtype=torch.cfloat), manifold=geoopt.Stiefel())
        self.reset_parameters()

    def forward(self, input):
        return unitary_linear(
            input,
            weight=self.weight,
            man=self.man,
        )

    @torch.no_grad()
    def reset_parameters(self):
        torch.nn.init.eye_(self.weight)
        #self.weight.add_(torch.rand_like(self.weight).mul_(1e-3))

            
# package.nn.functional.py
def unitary_linear(input, weight,*, man:geoopt.Stiefel):
    #output = ball.mobius_matvec(weight, input)
    output = torch.matmul(weight, input.T).T
    return output

model = UnitaryLinear(num_qubits=1)
optimizer = geoopt.optim.RiemannianAdam(model.parameters(), lr=0.001)

In [111]:
for layer in model.named_parameters():
    print(layer)

('weight', Parameter on Stiefel(canonical) manifold containing:
tensor([[1.+0.j, 0.+0.j],
        [0.+0.j, 1.+0.j]], requires_grad=True))


In [21]:
# class UnitaryModel(nn.Module):
#     def __init__(self, num_qubits):
#         super(UnitaryModel, self).__init__()
#         d = 2**num_qubits
#         #self.U = orthogonal(nn.Linear(d,d, bias=False, dtype=torch.cfloat), orthogonal_map=None)
#         #self.U = geoopt.ManifoldParameter(self.U.weight, manifold=geoopt.Stiefel())
        
        
#     def forward(self, x):
#         x = self.U(x)
#         return x
    
# model = UnitaryModel(num_qubits=1)
# optimizer = geoopt.optim.RiemannianAdam(model.parameters(), lr=0.01)

In [112]:
dummy_x = torch.tensor([[1,0]], dtype=torch.cfloat, requires_grad=True)
dummy_y = torch.tensor([[1,-1j]],dtype=torch.cfloat)/torch.sqrt(torch.tensor(2))
#dummy_x = torch.tensor([[1,0],[1,0],[1,0],[1,0]], dtype=torch.cfloat, requires_grad=True)
#dummy_y = torch.tensor([[1,0],[1,0],[1,0],[1,0]], dtype=torch.cfloat)


print(dummy_x)
print(dummy_y)

tensor([[1.+0.j, 0.+0.j]], requires_grad=True)
tensor([[0.7071+0.0000j, -0.0000-0.7071j]])


In [118]:
len(output)

1

In [115]:
def fidelity_loss(output, target, model, reg=1e-6):
    loss = 0
    print('output: ', output)
    print('target: ', target)
    for i in range(len(output)):
        inner_prod = torch.dot(output[i].T.conj(), target[i])
        #loss += inner_prod.conj() * inner_prod
        loss -= torch.abs(inner_prod)**2 / len(output)
    
    U = list(model.parameters())[0].detach()
    orth_constraint = torch.matmul(U,U.T.conj()) - torch.eye(U.shape[0])
    return loss + reg*orth_constraint.abs().sum()

loss = fidelity_loss(dummy_x, dummy_y, model, 0)
print(loss)
loss.backward()

output:  tensor([[1.+0.j, 0.+0.j]], requires_grad=True)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
tensor(-0.5000, grad_fn=<AddBackward0>)


In [116]:
NUM_EPOCHS = 1000
fidelity = torch.zeros(NUM_EPOCHS, dtype=torch.cfloat)

for i in range(NUM_EPOCHS):
    optimizer.zero_grad()
    output = model(dummy_x)
    loss = fidelity_loss(output, dummy_y, model, reg=0)
    loss.backward()
    optimizer.step()
    fidelity[i] = loss.detach()


output:  tensor([[-0.4162+5.7543e-06j, -0.9093-2.6339e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.4180+5.7524e-06j, -0.9085-2.6469e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.4198+5.7503e-06j, -0.9076-2.6600e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.4217+5.7482e-06j, -0.9068-2.6730e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.4235+5.7461e-06j, -0.9059-2.6860e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.4253+5.7439e-06j, -0.9051-2.6989e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.4271+5.7416e-06j, -0.9042-2.7119e-06j]],
       grad_fn=<PermuteBackward0>)
tar

output:  tensor([[-0.6467+4.9344e-06j, -0.7628-4.1839e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.6483+4.9246e-06j, -0.7615-4.1925e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.6498+4.9147e-06j, -0.7602-4.2011e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.6513+4.9047e-06j, -0.7589-4.2096e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.6528+4.8947e-06j, -0.7576-4.2181e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.6543+4.8847e-06j, -0.7562-4.2264e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.6558+4.8745e-06j, -0.7549-4.2348e-06j]],
       grad_fn=<PermuteBackward0>)
tar

output:  tensor([[-0.8268+3.3115e-06j, -0.5626-4.8665e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.8279+3.2980e-06j, -0.5609-4.8676e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.8290+3.2845e-06j, -0.5593-4.8686e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.8301+3.2710e-06j, -0.5576-4.8696e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.8312+3.2575e-06j, -0.5560-4.8705e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.8324+3.2439e-06j, -0.5543-4.8713e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.8335+3.2304e-06j, -0.5526-4.8720e-06j]],
       grad_fn=<PermuteBackward0>)
tar

target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9507+1.4419e-06j, -0.3103-4.4179e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9513+1.4298e-06j, -0.3084-4.4107e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9519+1.4177e-06j, -0.3065-4.4035e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9525+1.4056e-06j, -0.3046-4.3962e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9531+1.3936e-06j, -0.3027-4.3889e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9537+1.3816e-06j, -0.3008-4.3815e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9543+1.3697e-06j, -0.2989-

target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9995+1.0227e-07j, -0.0335-3.0575e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9995+9.5803e-08j, -0.0315-3.0466e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9996+8.9382e-08j, -0.0295-3.0357e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9997+8.3005e-08j, -0.0275-3.0248e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9997+7.6673e-08j, -0.0255-3.0139e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9998+7.0385e-08j, -0.0235-3.0030e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9998+6.4141e-08j, -0.0215-

target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9717-4.0260e-07j,  0.2363-1.6551e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9712-4.0389e-07j,  0.2382-1.6461e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9708-4.0516e-07j,  0.2402-1.6371e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9703-4.0639e-07j,  0.2421-1.6281e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9698-4.0760e-07j,  0.2441-1.6191e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9693-4.0878e-07j,  0.2460-1.6102e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.9688-4.0992e-07j,  0.2479-

output:  tensor([[-0.8705-4.3475e-07j,  0.4922-7.6875e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.8695-4.3468e-07j,  0.4939-7.6504e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.8686-4.3461e-07j,  0.4957-7.6138e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.8676-4.3455e-07j,  0.4974-7.5775e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.8666-4.3451e-07j,  0.4991-7.5418e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.8656-4.3448e-07j,  0.5009-7.5065e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.8646-4.3445e-07j,  0.5026-7.4714e-07j]],
       grad_fn=<PermuteBackward0>)
tar

target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.7020-5.8016e-07j,  0.7122-5.7173e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.7006-5.8285e-07j,  0.7136-5.7209e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.6992-5.8558e-07j,  0.7150-5.7247e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.6977-5.8835e-07j,  0.7164-5.7288e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.6963-5.9117e-07j,  0.7178-5.7333e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.6949-5.9405e-07j,  0.7192-5.7382e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.6934-5.9698e-07j,  0.7206-

In [117]:
fidelity.abs()

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 

In [107]:
for layer in model.named_parameters():
    print(layer)

('weight', Parameter on Stiefel(canonical) manifold containing:
tensor([[-0.4162+5.7543e-06j,  0.9093+2.6338e-06j],
        [-0.9093-2.6339e-06j, -0.4162+5.7544e-06j]], requires_grad=True))


In [108]:
U = model.state_dict()['weight']
print("Unitarity: ", torch.dist(U.T.conj() @ U, torch.eye(2)))
print(U)

print(fidelity_loss(model(dummy_x), dummy_y, model, reg=0))

Unitarity:  tensor(7.1770e-05)
tensor([[-0.4162+5.7543e-06j,  0.9093+2.6338e-06j],
        [-0.9093-2.6339e-06j, -0.4162+5.7544e-06j]])
tensor(-0.5000, grad_fn=<AddBackward0>)


In [348]:
print(model(dummy_x))
print(dummy_y)

tensor([[7.3449e-01-6.7862e-01j, 1.8477e-06-1.2517e-06j],
        [7.3449e-01-6.7862e-01j, 1.8477e-06-1.2517e-06j],
        [7.3449e-01-6.7862e-01j, 1.8477e-06-1.2517e-06j],
        [7.3449e-01-6.7862e-01j, 1.8477e-06-1.2517e-06j]],
       grad_fn=<MmBackward0>)
tensor([[1.+0.j, 0.+0.j],
        [1.+0.j, 0.+0.j],
        [1.+0.j, 0.+0.j],
        [1.+0.j, 0.+0.j]])
