In [29]:
import torch.nn
import geoopt


# package.nn.modules.py
def create_ball(ball=None, c=None):
    """
    Helper to create a PoincareBall.
    Sometimes you may want to share a manifold across layers, e.g. you are using scaled PoincareBall.
    In this case you will require same curvature parameters for different layers or end up with nans.
    Parameters
    ----------
    ball : geoopt.PoincareBall
    c : float
    Returns
    -------
    geoopt.PoincareBall
    """
    if ball is None:
        assert c is not None, "curvature of the ball should be explicitly specified"
        ball = geoopt.PoincareBall(c)
    # else trust input
    return ball


class MobiusLinear(torch.nn.Linear):
    def __init__(self, *args, nonlin=None, ball=None, c=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        # for manifolds that have parameters like Poincare Ball
        # we have to attach them to the closure Module.
        # It is hard to implement device allocation for manifolds in other case.
        self.ball = create_ball(ball, c)
        if self.bias is not None:
            self.bias = geoopt.ManifoldParameter(self.bias, manifold=self.ball)
        self.nonlin = nonlin
        self.reset_parameters()

    def forward(self, input):
        return mobius_linear(
            input,
            weight=self.weight,
            bias=self.bias,
            nonlin=self.nonlin,
            ball=self.ball,
        )

    @torch.no_grad()
    def reset_parameters(self):
        torch.nn.init.eye_(self.weight)
        self.weight.add_(torch.rand_like(self.weight).mul_(1e-3))
        if self.bias is not None:
            self.bias.zero_()


# package.nn.functional.py
def mobius_linear(input, weight, bias=None, nonlin=None, *, ball: geoopt.PoincareBall):
    output = ball.mobius_matvec(weight, input)
    if bias is not None:
        output = ball.mobius_add(output, bias)
    if nonlin is not None:
        output = ball.logmap0(output)
        output = nonlin(output)
        output = ball.expmap0(output)
    return output

In [31]:
classifier = MobiusLinear()

TypeError: __init__() missing 2 required positional arguments: 'in_features' and 'out_features'

In [None]:
class Distance2PoincareHyperplanes(torch.nn.Module):
    n = 0
    # 1D, 2D versions of this class ara available with a one line change
    # class Distance2PoincareHyperplanes2d(Distance2PoincareHyperplanes):
    #     n = 2

    def __init__(
        self,
        plane_shape: int,
        num_planes: int,
        signed=True,
        squared=False,
        *,
        ball,
        std=1.0,
    ):
        super().__init__()
        self.signed = signed
        self.squared = squared
        # Do not forget to save Manifold instance to the Module
        self.ball = ball
        self.plane_shape = geoopt.utils.size2shape(plane_shape)
        self.num_planes = num_planes

        # In a layer we create Manifold Parameters in the same way we do it for
        # regular pytorch Parameters, there is no difference. But geoopt optimizer
        # will recognize the manifold and adjust to it
        self.points = geoopt.ManifoldParameter(
            torch.empty(num_planes, plane_shape), manifold=self.ball
        )
        self.std = std
        # following best practives, a separate method to reset parameters
        self.reset_parameters()

    def forward(self, input):
        input_p = input.unsqueeze(-self.n - 1)
        points = self.points.permute(1, 0)
        points = points.view(points.shape + (1,) * self.n)

        distance = self.ball.dist2plane(
            x=input_p, p=points, a=points, signed=self.signed, dim=-self.n - 2
        )
        if self.squared and self.signed:
            sign = distance.sign()
            distance = distance ** 2 * sign
        elif self.squared:
            distance = distance ** 2
        return distance

    def extra_repr(self):
        return (
            "plane_shape={plane_shape}, "
            "num_planes={num_planes}, "
            .format(**self.__dict__)
        )

    @torch.no_grad()
    def reset_parameters(self):
        direction = torch.randn_like(self.points)
        direction /= direction.norm(dim=-1, keepdim=True)
        distance = torch.empty_like(self.points[..., 0]).normal_(std=self.std)
        self.points.set_(self.ball.expmap0(direction * distance.unsqueeze(-1)))

In [1]:
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal

import geoopt

In [2]:
class UnitaryLinear(torch.nn.Module):
    def __init__(self, num_qubits):
        super().__init__()
        # for manifolds that have parameters like Poincare Ball
        # we have to attach them to the closure Module.
        # It is hard to implement device allocation for manifolds in other case.
        self.man = geoopt.Stiefel()
        self.dim = 2**num_qubits
        self.weight = geoopt.ManifoldParameter(torch.empty(self.dim,self.dim, dtype=torch.cfloat), manifold=geoopt.Stiefel())
        self.reset_parameters()

    def forward(self, input):
        return unitary_linear(
            input,
            weight=self.weight,
            man=self.man,
        )

    @torch.no_grad()
    def reset_parameters(self):
        torch.nn.init.eye_(self.weight)
        #self.weight.add_(torch.rand_like(self.weight).mul_(1e-3))

            
# package.nn.functional.py
def unitary_linear(input, weight,*, man:geoopt.Stiefel):
    #output = ball.mobius_matvec(weight, input)
    output = torch.matmul(weight, input.T).T
    return output

model = UnitaryLinear(num_qubits=1)
optimizer = geoopt.optim.RiemannianAdam(model.parameters(), lr=0.001)

In [3]:
for layer in model.named_parameters():
    print(layer)

('weight', Parameter on Stiefel(canonical) manifold containing:
tensor([[1.+0.j, 0.+0.j],
        [0.+0.j, 1.+0.j]], requires_grad=True))


In [21]:
# class UnitaryModel(nn.Module):
#     def __init__(self, num_qubits):
#         super(UnitaryModel, self).__init__()
#         d = 2**num_qubits
#         #self.U = orthogonal(nn.Linear(d,d, bias=False, dtype=torch.cfloat), orthogonal_map=None)
#         #self.U = geoopt.ManifoldParameter(self.U.weight, manifold=geoopt.Stiefel())
        
        
#     def forward(self, x):
#         x = self.U(x)
#         return x
    
# model = UnitaryModel(num_qubits=1)
# optimizer = geoopt.optim.RiemannianAdam(model.parameters(), lr=0.01)

In [4]:
dummy_x = torch.tensor([[1,0]], dtype=torch.cfloat, requires_grad=True)
dummy_y = torch.tensor([[1,-1j]],dtype=torch.cfloat)/torch.sqrt(torch.tensor(2))
#dummy_x = torch.tensor([[1,0],[1,0],[1,0],[1,0]], dtype=torch.cfloat, requires_grad=True)
#dummy_y = torch.tensor([[1,0],[1,0],[1,0],[1,0]], dtype=torch.cfloat)


print(dummy_x)
print(dummy_y)

tensor([[1.+0.j, 0.+0.j]], requires_grad=True)
tensor([[0.7071+0.0000j, -0.0000-0.7071j]])


In [118]:
len(output)

1

In [5]:
def fidelity_loss(output, target, model, reg=1e-6):
    loss = 0
    print('output: ', output)
    print('target: ', target)
    for i in range(len(output)):
        inner_prod = torch.dot(output[i].T.conj(), target[i])
        #loss += inner_prod.conj() * inner_prod
        loss -= torch.abs(inner_prod)**2 / len(output)
    
    U = list(model.parameters())[0].detach()
    orth_constraint = torch.matmul(U,U.T.conj()) - torch.eye(U.shape[0])
    return loss + reg*orth_constraint.abs().sum()

loss = fidelity_loss(dummy_x, dummy_y, model, 0)
print(loss)
loss.backward()

output:  tensor([[1.+0.j, 0.+0.j]], requires_grad=True)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
tensor(-0.5000, grad_fn=<AddBackward0>)


In [6]:
NUM_EPOCHS = 1000
fidelity = torch.zeros(NUM_EPOCHS, dtype=torch.cfloat)

for i in range(NUM_EPOCHS):
    optimizer.zero_grad()
    output = model(dummy_x)
    loss = fidelity_loss(output, dummy_y, model, reg=0)
    loss.backward()
    optimizer.step()
    fidelity[i] = loss.detach()


output:  tensor([[1.+0.j, 0.+0.j]], grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 1.0000-4.0000e-14j, -0.0020-2.0000e-11j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 1.0000-1.5996e-13j, -0.0040-3.9991e-11j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 1.0000-3.5965e-13j, -0.0060-5.9941e-11j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 1.0000-6.3877e-13j, -0.0080-7.9845e-11j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.9999-9.9735e-13j, -0.0100-9.9732e-11j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.9999-1.4334e-12j, -0.0120-1.1945e-10j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0

output:  tensor([[ 0.9709+4.7665e-09j, -0.2396+1.9310e-08j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.9704+4.9431e-09j, -0.2416+1.9855e-08j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.9699+5.1243e-09j, -0.2435+2.0409e-08j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.9694+5.3100e-09j, -0.2455+2.0971e-08j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.9689+5.5003e-09j, -0.2474+2.1541e-08j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.9684+5.6950e-09j, -0.2493+2.2119e-08j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.9679+5.8944e-09j, -0.2513+2.2705e-08j]],
       grad_fn=<PermuteBackward0>)
tar

output:  tensor([[ 0.8737+1.2919e-07j, -0.4864+2.3205e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.8727+1.3131e-07j, -0.4882+2.3475e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.8718+1.3346e-07j, -0.4899+2.3747e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.8708+1.3563e-07j, -0.4917+2.4021e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.8698+1.3783e-07j, -0.4934+2.4297e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.8688+1.4005e-07j, -0.4951+2.4573e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.8678+1.4228e-07j, -0.4969+2.4850e-07j]],
       grad_fn=<PermuteBackward0>)
tar

target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.7024+6.9280e-07j, -0.7118+6.8371e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.7010+6.9881e-07j, -0.7132+6.8688e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.6996+7.0486e-07j, -0.7146+6.9006e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.6981+7.1096e-07j, -0.7160+6.9325e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.6967+7.1712e-07j, -0.7174+6.9647e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.6953+7.2333e-07j, -0.7188+6.9969e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.6938+7.2958e-07j, -0.7201+

output:  tensor([[ 0.4801+1.8814e-06j, -0.8772+1.0297e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.4784+1.8921e-06j, -0.8782+1.0307e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.4766+1.9028e-06j, -0.8791+1.0316e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.4748+1.9135e-06j, -0.8801+1.0324e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.4731+1.9243e-06j, -0.8810+1.0333e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.4713+1.9351e-06j, -0.8820+1.0341e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.4695+1.9459e-06j, -0.8829+1.0349e-06j]],
       grad_fn=<PermuteBackward0>)
tar

output:  tensor([[ 0.2190+3.5442e-06j, -0.9757+7.9544e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.2170+3.5567e-06j, -0.9762+7.9075e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.2151+3.5691e-06j, -0.9766+7.8602e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.2131+3.5814e-06j, -0.9770+7.8124e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.2112+3.5938e-06j, -0.9775+7.7642e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.2092+3.6062e-06j, -0.9779+7.7154e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[ 0.2073+3.6185e-06j, -0.9783+7.6662e-07j]],
       grad_fn=<PermuteBackward0>)
tar

output:  tensor([[-0.0612+5.0734e-06j, -0.9981-3.1108e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.0632+5.0820e-06j, -0.9980-3.2181e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.0652+5.0905e-06j, -0.9979-3.3258e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.0672+5.0990e-06j, -0.9978-3.4337e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.0692+5.1074e-06j, -0.9976-3.5421e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.0712+5.1158e-06j, -0.9975-3.6507e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.0732+5.1242e-06j, -0.9973-3.7597e-07j]],
       grad_fn=<PermuteBackward0>)
tar

target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.3328+5.7769e-06j, -0.9430-2.0387e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.3347+5.7778e-06j, -0.9424-2.0520e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.3366+5.7787e-06j, -0.9417-2.0653e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.3384+5.7795e-06j, -0.9410-2.0787e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.3403+5.7803e-06j, -0.9403-2.0920e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.3422+5.7810e-06j, -0.9397-2.1053e-06j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
output:  tensor([[-0.3441+5.7816e-06j, -0.9390-

In [117]:
fidelity.abs()

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 

In [119]:
for layer in model.named_parameters():
    print(layer)

('weight', Parameter on Stiefel(canonical) manifold containing:
tensor([[-0.6535-6.8982e-07j, -0.7570+5.9561e-07j],
        [ 0.7570-5.9543e-07j, -0.6535-6.8975e-07j]], requires_grad=True))


In [120]:
U = model.state_dict()['weight']
print("Unitarity: ", torch.dist(U.T.conj() @ U, torch.eye(2)))
print(U)

print(fidelity_loss(model(dummy_x), dummy_y, model, reg=0))

Unitarity:  tensor(0.0001)
tensor([[-0.6535-6.8982e-07j, -0.7570+5.9561e-07j],
        [ 0.7570-5.9543e-07j, -0.6535-6.8975e-07j]])
output:  tensor([[-0.6535-6.8982e-07j,  0.7570-5.9543e-07j]],
       grad_fn=<PermuteBackward0>)
target:  tensor([[0.7071+0.0000j, -0.0000-0.7071j]])
tensor(-0.5000, grad_fn=<AddBackward0>)


In [121]:
print(model(dummy_x))
print(dummy_y)

tensor([[-0.6535-6.8982e-07j,  0.7570-5.9543e-07j]],
       grad_fn=<PermuteBackward0>)
tensor([[0.7071+0.0000j, -0.0000-0.7071j]])


In [26]:
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
b = a
torch.tensordot(a.T,b.T,dims=2)

tensor(285)