In [349]:
import torch

In [341]:
class UnitaryModel(nn.Module):
    def __init__(self, num_qubits):
        super(UnitaryModel, self).__init__()
        d = 2**num_qubits
        self.U = orthogonal(nn.Linear(d,d, bias=False, dtype=torch.cfloat), orthogonal_map='matrix_exp')

    def forward(self, x):
        x = self.U(x)
        return x
    
model = UnitaryModel(num_qubits=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [350]:
#dummy_x = torch.tensor([[1,0],[0,1]], dtype=torch.cfloat, requires_grad=True)
#dummy_y = torch.tensor([[1,-1j],[1,1j]],dtype=torch.cfloat)/torch.sqrt(torch.tensor(2))
dummy_x = torch.tensor([[1,0],[1,0],[1,0],[1,0]], dtype=torch.cfloat, requires_grad=True)
dummy_y = torch.tensor([[1,0],[1,0],[1,0],[1,0]], dtype=torch.cfloat)


print(dummy_x)
print(dummy_y)

tensor([[1.+0.j, 0.+0.j],
        [1.+0.j, 0.+0.j],
        [1.+0.j, 0.+0.j],
        [1.+0.j, 0.+0.j]], requires_grad=True)
tensor([[1.+0.j, 0.+0.j],
        [1.+0.j, 0.+0.j],
        [1.+0.j, 0.+0.j],
        [1.+0.j, 0.+0.j]])


In [351]:
def fidelity_loss(output, target, model, reg=1e-6):
    loss = 0
    for i in range(len(output)):
        inner_prod = torch.dot(output[i].T.conj(), target[i])
        #loss += inner_prod.conj() * inner_prod
        loss -= torch.abs(inner_prod)**2 / len(output)
    
    U = list(model.parameters())[0].detach()
    orth_constraint = torch.matmul(U,U.T.conj()) - torch.eye(U.shape[0])
    return loss + reg*orth_constraint.abs().sum()

loss = fidelity_loss(dummy_x, dummy_y, model, 1)
print(loss)
loss.backward()

tensor(1.3331, grad_fn=<AddBackward0>)


In [352]:
NUM_EPOCHS = 1000
fidelity = torch.zeros(NUM_EPOCHS, dtype=torch.cfloat)

for i in range(NUM_EPOCHS):
    optimizer.zero_grad()
    output = model(dummy_x)
    loss = fidelity_loss(output, dummy_y, model, reg=1)
    loss.backward()
    optimizer.step()
    fidelity[i] = loss.detach()


In [353]:
fidelity.abs()

tensor([1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331,
        1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331,
        1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331,
        1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331,
        1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331,
        1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331,
        1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331,
        1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331,
        1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331,
        1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331,
        1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331,
        1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331, 1.3331,
        1.3331, 1.3331, 1.3331, 1.3331, 

In [354]:
for layer in model.named_parameters():
    print(layer)

('U.parametrizations.weight.original', Parameter containing:
tensor([[-1.0000-0.0030j,  0.0000+0.0000j],
        [ 0.8103+0.1589j, -1.0000-0.0021j]], requires_grad=True))


In [355]:
U = model.state_dict()['U.parametrizations.weight.original']
print("Unitarity: ", torch.dist(U.T.conj() @ U, torch.eye(2)))
print(U)

print(fidelity_loss(model(dummy_x), dummy_y, model, reg=0))

Unitarity:  tensor(1.3521)
tensor([[-1.0000-0.0030j,  0.0000+0.0000j],
        [ 0.8103+0.1589j, -1.0000-0.0021j]])
tensor(-1., grad_fn=<AddBackward0>)


In [356]:
print(model(dummy_x))
print(dummy_y)

tensor([[7.3442e-01-6.7870e-01j, 8.4132e-05-6.6310e-05j],
        [7.3442e-01-6.7870e-01j, 8.4132e-05-6.6310e-05j],
        [7.3442e-01-6.7870e-01j, 8.4132e-05-6.6310e-05j],
        [7.3442e-01-6.7870e-01j, 8.4132e-05-6.6310e-05j]],
       grad_fn=<MmBackward0>)
tensor([[1.+0.j, 0.+0.j],
        [1.+0.j, 0.+0.j],
        [1.+0.j, 0.+0.j],
        [1.+0.j, 0.+0.j]])
