In [1]:
import pytest
import torch
import random
import numpy as np

import geoopt
import sys
sys.path.append("..")

from umps import uMPS
from tpcp_mps import MPSTPCP
import unitary_optimizer
from tpcp_mps import ManifoldType




In [2]:
# Define the loss and accuracy functions as provided
def loss_batch(outputs, labels):
    device = outputs.device
    loss = torch.zeros(1, device=device, dtype=torch.float64)
    for i in range(len(outputs)):
        # For each sample, if label==0 then probability=outputs[i], else 1-outputs[i]
        prob = outputs[i] if labels[i] == 0 else 1 - outputs[i]
        loss -= torch.log(prob + 1e-8)
    return loss

def calculate_accuracy(outputs, labels):
    predictions = (outputs < 0.5).float()  # output < 0.5 means prediction of 1 (or vice versa)
    correct = (predictions == labels).float().sum()
    accuracy = correct / labels.numel()
    return accuracy.item()

# Integration test marked with @pytest.mark.integtest

In [3]:

# Fix random seeds for reproducibility.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Dataset parameters
N = 2       # number of qubits per data sample
bs = 3      # batch size

# Generate random input data with shape (N, bs, 2) and normalize the 2D vectors.
data = torch.randn(bs, N, 2, dtype=torch.float64)
data = data / torch.norm(data, dim=-1, keepdim=True)
target = torch.randint(0, 2, (bs,), dtype=torch.float64)

chi = 2
layers = 1
umps_model = uMPS(N=N, chi=chi, d=2, l=2, layers=layers, device=torch.device("cpu"), init_with_identity=True)
mpstpcp_model = MPSTPCP(N=N, K=1, d=2, with_identity=True, manifold=ManifoldType.EXACT)

Path is not set, setting...
Found the path
Initialized MPS params


In [16]:
with torch.no_grad():
  mpstpcp_model.kraus_ops[0].copy_(torch.randn(4, 4, dtype=torch.float64))

In [19]:
mpstpcp_model.manifold.check_point_on_manifold(mpstpcp_model.kraus_ops[0])

True

In [18]:
mpstpcp_model.proj_stiefel(check_on_manifold=True)

In [3]:
# Fix random seeds for reproducibility.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Dataset parameters
N = 2       # number of qubits per data sample
bs = 3      # batch size

# Generate random input data with shape (N, bs, 2) and normalize the 2D vectors.
data = torch.randn(bs, N, 2, dtype=torch.float64)
data = data / torch.norm(data, dim=-1, keepdim=True)
target = torch.randint(0, 2, (bs,), dtype=torch.float64)

chi = 2
layers = 1
umps_model = uMPS(N=N, chi=chi, d=2, l=2, layers=layers, device=torch.device("cpu"), init_with_identity=True)
mpstpcp_model = MPSTPCP(N=N, K=1, d=2, with_identity=True, manifold=ManifoldType.EXACT)

# Assert if both models return the same outputs for data
assert torch.allclose(umps_model(data.permute(1, 0, 2)), mpstpcp_model(data), atol=1e-6), \
    "The outputs of uMPS and MPSTPCP models do not match."
lr = 0.1
optimizer_umps = unitary_optimizer.SGD(umps_model, lr=lr)
optimizer_mpstpcp = geoopt.optim.RiemannianSGD(mpstpcp_model.parameters(), lr=lr)

# Number of epochs (randomly chosen between 5 and 10)
# Set both models to train mode.
umps_model.train()
mpstpcp_model.train()

# For uMPS, the expected input shape is (N, bs, 2); for MPSTPCP it is (bs, N, 2).
# Prepare inputs accordingly.
input_for_umps = data.permute(1, 0, 2)  # shape (N, bs, 2)
input_for_tpcp = data  # shape (bs, N, 2)

# Zero the gradients
optimizer_umps.zero_grad()
optimizer_mpstpcp.zero_grad()

# Forward pass
outputs_umps = umps_model(input_for_umps)
outputs_tpcp = mpstpcp_model(input_for_tpcp)

# Compute loss for each model.
loss_umps = loss_batch(outputs_umps, target)
loss_tpcp = loss_batch(outputs_tpcp, target)

# print(loss_umps)
# print(loss_tpcp)

# # Backward pass and parameter update.
loss_umps.backward()
loss_tpcp.backward()
# optimizer_umps.step()
# optimizer_mpstpcp.step()

Path is not set, setting...
Found the path
Initialized MPS params


In [23]:
u_mps = umps_model.params[0].reshape(4,4)
rg_mps = unitary_optimizer.riemannian_gradient(u_mps, umps_model.params[0].grad.reshape(4,4))

u_tpcp = mpstpcp_model.kraus_ops[0]
rg_tpcp = u_tpcp.manifold.egrad2rgrad(u_tpcp, u_tpcp.grad)
print(rg_mps)
print(rg_tpcp)

tensor([[ 0.0000,  1.9855,  0.0000, -0.0674],
        [-1.9855,  0.0000,  0.0674,  0.0000],
        [ 0.0000, -0.0674,  0.0000, -0.2032],
        [ 0.0674,  0.0000,  0.2032,  0.0000]], dtype=torch.float64,
       grad_fn=<DivBackward0>)
tensor([[ 0.0000, -1.9855,  0.0000,  0.0674],
        [ 1.9855,  0.0000, -0.0674,  0.0000],
        [ 0.0000,  0.0674,  0.0000,  0.2032],
        [-0.0674,  0.0000, -0.2032,  0.0000]], dtype=torch.float64,
       grad_fn=<SubBackward0>)


In [27]:
new_u_tpcp = u_tpcp.manifold.retr(u_tpcp, -rg_tpcp * 0.1)
new_u_mps = unitary_optimizer.exp_map(u_mps, rg_mps.T * 0.1)

print(new_u_tpcp)
print(new_u_mps)

tensor([[ 9.8033e-01,  1.9725e-01,  5.9838e-04, -6.6969e-03],
        [-1.9725e-01,  9.8033e-01,  6.6969e-03,  5.9838e-04],
        [ 5.9838e-04, -6.6969e-03,  9.9977e-01, -2.0319e-02],
        [ 6.6969e-03,  5.9838e-04,  2.0319e-02,  9.9977e-01]],
       dtype=torch.float64, grad_fn=<MmBackward0>)
tensor([[ 9.8033e-01,  1.9725e-01,  5.9838e-04, -6.6969e-03],
        [-1.9725e-01,  9.8033e-01,  6.6969e-03,  5.9838e-04],
        [ 5.9838e-04, -6.6969e-03,  9.9977e-01, -2.0319e-02],
        [ 6.6969e-03,  5.9838e-04,  2.0319e-02,  9.9977e-01]],
       dtype=torch.float64, grad_fn=<MmBackward0>)


In [12]:
optimizer_umps.step()
optimizer_mpstpcp.step()

In [14]:
umps_model.params[0].reshape(4,4)

tensor([[ 9.8033e-01, -1.9725e-01,  5.9838e-04,  6.6969e-03],
        [ 1.9725e-01,  9.8033e-01, -6.6969e-03,  5.9838e-04],
        [ 5.9838e-04,  6.6969e-03,  9.9977e-01,  2.0319e-02],
        [-6.6969e-03,  5.9838e-04, -2.0319e-02,  9.9977e-01]],
       dtype=torch.float64, grad_fn=<ViewBackward0>)

In [15]:
mpstpcp_model.kraus_ops[0].reshape(4,4)

tensor([[ 9.8033e-01,  1.9725e-01,  5.9838e-04, -6.6969e-03],
        [-1.9725e-01,  9.8033e-01,  6.6969e-03,  5.9838e-04],
        [ 5.9838e-04, -6.6969e-03,  9.9977e-01, -2.0319e-02],
        [ 6.6969e-03,  5.9838e-04,  2.0319e-02,  9.9977e-01]],
       dtype=torch.float64, grad_fn=<ViewBackward0>)