In [1]:
#!/usr/bin/env python3
import sys
sys.path.append("../../")

import argparse
import os
import torch
import numpy as np
import torchvision
import torchvision.transforms as transforms
import torch.utils.data
import geoopt
import matplotlib.pyplot as plt
import time
import sys

# If your mps/ package is local, ensure it’s on sys.path or installed in editable mode:
# sys.path.append("/path/to/your/project/root")

from mps.tpcp_mps import MPSTPCP, ManifoldType
from mps.StiefelOptimizers import StiefelAdam, StiefelSGD  # make sure this is importable
from mps.radam import RiemannianAdam 

In [2]:
from importlib import reload
###############################################################################
# MNIST dataset utilities
###############################################################################
def filter_digits(dataset, allowed_digits=(0, 1)):
    """Return a subset of ``dataset`` containing only samples with an allowed label.

    Args:
        dataset: Indexable dataset whose items are ``(sample, label)`` pairs.
        allowed_digits: Iterable of labels to keep. Defaults to ``(0, 1)``.

    Returns:
        ``torch.utils.data.Subset`` over ``dataset`` holding the indices of the
        matching samples, in their original order.
    """
    # Immutable snapshot: avoids the shared-mutable-default pitfall and lets
    # callers pass any iterable (list, tuple, set, generator).
    allowed = tuple(allowed_digits)

    def _scalar(label):
        # 0-d tensors / numpy scalars -> plain Python numbers; others unchanged.
        return label.item() if hasattr(label, "item") else label

    # Fast path: when the dataset exposes its raw labels and no target_transform
    # would alter them, read the labels directly instead of calling
    # ``dataset[i]``, which would run the full sample transform for every item.
    raw_targets = getattr(dataset, "targets", None)
    has_target_transform = getattr(dataset, "target_transform", None) is not None
    if (
        raw_targets is not None
        and not has_target_transform
        and len(raw_targets) == len(dataset)
    ):
        labels = (_scalar(t) for t in raw_targets)
    else:
        labels = (_scalar(dataset[i][1]) for i in range(len(dataset)))

    indices = [i for i, label in enumerate(labels) if label in allowed]
    return torch.utils.data.Subset(dataset, indices)


def filiter_single_channel(img: torch.Tensor) -> torch.Tensor:
    """Keep only channel 0 of a ``[C, H, W]`` image, yielding shape ``[H, W]``."""
    first_channel = img[0]
    return first_channel


def embedding_pixel(batch, label: int = 0):
    """
    Flatten each image from shape [..., H, W] => [..., H*W],
    then embed every pixel x => [x, 1-x] and L2-normalize along the last dim.

    Args:
        batch: Tensor whose last two dimensions are the image height and width.
            Any leading dimensions are preserved.
        label: Unused; kept so existing callers that pass it keep working.

    Returns:
        Tensor of shape [..., H*W, 2] whose last-dim vectors have unit L2 norm.
    """
    pixel_size = batch.shape[-1] * batch.shape[-2]
    # reshape (not view) so non-contiguous inputs, e.g. a transposed image,
    # are accepted instead of raising.
    flat = batch.reshape(*batch.shape[:-2], pixel_size)
    embedded = torch.stack([flat, 1 - flat], dim=-1)
    return embedded / torch.norm(embedded, dim=-1, keepdim=True)


###############################################################################
# Loss & Accuracy
###############################################################################
def loss_batch(outputs, labels):
    """
    Binary cross-entropy style loss for outputs in [0, 1], summed over the batch.
    For label=0 => prob=outputs[i], else => 1 - outputs[i].

    Args:
        outputs: Tensor with one value per sample (shape [B] or [B, 1]).
        labels: Per-sample labels (tensor or sequence) of the same length.

    Returns:
        float64 tensor of shape (1,) holding ``-sum_i log(prob_i + 1e-8)``.
        The result is a sum, not a mean, so it scales with the batch size.
    """
    flat_outputs = outputs.reshape(-1)
    flat_labels = torch.as_tensor(labels, device=outputs.device).reshape(-1)

    # Probability assigned to the true class of every sample, computed in one
    # vectorised op instead of a Python loop over the batch.
    prob = torch.where(flat_labels == 0, flat_outputs, 1 - flat_outputs)
    # The epsilon keeps log finite when prob is exactly 0; accumulate in float64.
    per_sample = -torch.log(prob + 1e-8).to(torch.float64)

    nan_mask = torch.isnan(per_sample)
    if nan_mask.any():
        # Report the first offending sample to help debugging.
        i = int(torch.nonzero(nan_mask)[0])
        print(f"Loss is NaN at i={i}")
        print(prob[i], flat_outputs[i], flat_labels[i])

    return per_sample.sum().reshape(1)


def calculate_accuracy(outputs, labels):
    """
    Fraction of samples whose thresholded output matches the label.

    An output below 0.5 is read as class 1, anything else as class 0; the
    predictions are then compared element-wise with ``labels``.
    """
    predicted = torch.lt(outputs, 0.5).float()
    n_hits = torch.eq(predicted, labels).float().sum()
    return n_hits / labels.numel()


In [4]:
# Image side length passed to Resize, and mini-batch size for the loader.
img_size, batch_size = 16, 128

# Per-sample preprocessing, applied in the listed order.
# NOTE(review): the embedding/normalisation step runs before the cast to
# float64, so it is presumably computed in ToTensor's output dtype — confirm
# whether the cast should come first.
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Lambda(filiter_single_channel),
    transforms.Lambda(embedding_pixel),
    transforms.Lambda(lambda t: t.to(torch.float64)),  # double precision
])

# MNIST training split, restricted to the digits 0 and 1.
trainset = filter_digits(
    torchvision.datasets.MNIST(
        root="data", train=True, download=True, transform=transform
    ),
    allowed_digits=[0, 1],
)

trainloader = torch.utils.data.DataLoader(
    trainset, shuffle=True, batch_size=batch_size
)

In [9]:
# Train a K=1 MPSTPCP model on the 0-vs-1 subset with RiemannianAdam.
# The globals N, d and model defined here are read by the K=3 transfer cell.
from mps.tpcp_mps import MPSTPCP, ManifoldType
import numpy as np
# N equals the flattened pixel count of a square img_size x img_size image.
N = img_size * img_size
K = 1
d = 2
# Seed torch and numpy before the model is constructed.
seed = 2025
torch.manual_seed(seed)
np.random.seed(seed)
model = MPSTPCP(
    N=N,
    K=K,
    d=d,
    with_identity=False,  # or True, depending on your preference
    manifold=ManifoldType.EXACT,
)
model.train()

optimizer = RiemannianAdam(model.parameters(), lr=0.005)

epochs = 100
for epoch in range(epochs):
    for data, target in trainloader:
        optimizer.zero_grad()
        outputs = model(data)
        # loss_batch returns a sum over the batch, so values scale with batch size.
        loss = loss_batch(outputs, target)
        loss.backward()
        optimizer.step()
        # Printed once per mini-batch, not once per epoch.
        print(loss.item())

91.471186163515
90.58992780145758
90.64970388422005
90.17960656064757
88.85069000922996
89.14160715394425
88.70721054196494
88.77232470335193
88.70934791150289
88.5380156609986
88.5951266764257
88.58316495778342
88.27025293591217
89.33308296596097
88.60920329693712
89.64273073699528
90.28030406316451
86.12099849295677
87.87539189062213
88.64757382792315
88.98072573116697
88.6608948299592
88.33290124429467
87.17429564522283
89.17958262968608
88.84420978166311
87.99615862016093
88.16515988109435
87.65706988642607
89.8760846349806
87.6594897902322
90.34937888216777
89.14241380005802
88.64079771466244
88.18658772790059
88.33597461317275
87.6656990551811
88.87004647612672
89.75757538345697
87.62742914234934
88.0057305800472
88.93577237681349
88.13840423957085
87.92050212568692
89.93536705684174
88.80139751249762
88.4804682902366
88.3848550668743
88.86980209779014
88.7657726603463
89.10616409602481
88.18228079992174
89.43286612771884
88.45372534295011
88.21735399249515
88.17155049698042
89.0

KeyboardInterrupt: 

In [13]:
# Scratch inspection. `new_kraus_op` is bound only as the loop variable of the
# K=3 parameter-transfer cell, so this needs that cell to have run first.
new_kraus_op[:]

torch.Size([12, 4])

In [47]:
# Create a new model instance with K = 3.
# Relies on the globals N, d and model defined by the K=1 training cell.
new_model = MPSTPCP(
    N=N,        # same N as before
    K=3,        # now with 3 Kraus operators
    d=d,
    with_identity=True,  # or False, as preferred
    manifold=ManifoldType.EXACT,
)
# (Optionally, initialize its parameters)
new_model.kraus_ops.init_params()

# Transfer the trained parameters and zero-out the additional entries.
with torch.no_grad():
    # Pair each parameter of the new model with the trained one at the same position.
    for new_kraus_op, kraus_op in zip(new_model.kraus_ops, model.kraus_ops):
        trained_param = kraus_op.detach().clone()
        # Zero the whole new parameter, then copy the trained values into its
        # leading d**2 x d**2 block.
        new_kraus_op.data[:] = 0 
        new_kraus_op.data[:d**2, :d**2] = trained_param

        # Add small Gaussian noise (scale 0.01) to EVERY entry of the new
        # parameter, including the block that was just copied.
        rand = torch.randn(new_kraus_op.shape)
        new_kraus_op.data[:] += 0.01 * rand

# The noise above perturbs the parameters; presumably this projects them back
# onto the constraint manifold — confirm against MPSTPCP.proj_stiefel.
new_model.proj_stiefel()


In [None]:
# Train the K=3 model (built in the transfer cell) and report per-epoch averages.
new_model.train()

# Rebinds the global `optimizer` used by earlier cells.
optimizer = RiemannianAdam(new_model.parameters(), lr=0.005)

epochs = 100
for epoch in range(epochs):
    # Running sums over the mini-batches of this epoch.
    acc_tot = 0
    loss_tot = 0
    for data, target in trainloader:
        optimizer.zero_grad()
        outputs = new_model(data)
        loss = loss_batch(outputs, target)
        loss.backward()
        optimizer.step()
        acc = calculate_accuracy(outputs, target)
        acc_tot += acc
        loss_tot += loss.item()

    # Both values are averaged over the number of batches; the loss of each
    # batch is itself a sum over its samples.
    print(f"Epoch {epoch} / {epochs} / Loss: {loss_tot / len(trainloader)} / Accuracy: {acc_tot / len(trainloader)}")

Epoch 0 / 100 / Loss: 15.742275344467558 / Accuracy: 0.9583756923675537
Epoch 1 / 100 / Loss: 15.210885631339412 / Accuracy: 0.9688015580177307


In [5]:
# Pull a single mini-batch; `data` and `target` are reused by later scratch cells.
data, target = next(iter(trainloader))

In [12]:
# Scratch inspection; `mps` is the model instance created in the In [84] cell.
mps.kraus_ops

torch.Size([2, 4, 4])

In [15]:
# Accumulate sum_k K_k^T K_k over the operators in mps.kraus_ops[0].
# The recorded output is the 4x4 identity up to ~1e-8.
I = torch.zeros(4, 4)

for kraus_op in mps.kraus_ops[0]:
    I += kraus_op.T @ kraus_op

I


tensor([[ 1.0000e+00, -1.5151e-09,  3.6242e-10, -8.2355e-09],
        [-1.5151e-09,  1.0000e+00,  3.9718e-09,  5.4807e-09],
        [ 3.6242e-10,  3.9718e-09,  1.0000e+00,  7.3616e-09],
        [-8.2355e-09,  5.4807e-09,  7.3616e-09,  1.0000e+00]],
       grad_fn=<AddBackward0>)

In [24]:
# NOTE(review): `rho` is not defined in any visible cell; this relies on leftover kernel state.
rho[0]

tensor([[0.4929, 0.1871],
        [0.1871, 0.5071]], dtype=torch.float64, grad_fn=<SelectBackward0>)

In [84]:
# Reload the model module and build a fresh K=1 model for interactive experiments.
# NOTE(review): this imports `tpcp_mps` as a top-level module, whereas the first
# cell imports it as `mps.tpcp_mps`; both must be importable for the notebook to run.
import tpcp_mps
from importlib import reload
reload(tpcp_mps)

mps = tpcp_mps.MPSTPCP(N=16 * 16, K=1, d=2)

# Uses the `data` batch left in the kernel by another cell; `[:, :, :]` is a full slice.
outputs = mps.forward(data[:, :, :].to(torch.float64))

In [85]:
# Feed 100 random inputs of shape (16*16, 2), each length-2 vector normalised to
# unit L2 norm, through the model.
rand_gen = torch.randn(100, 16 * 16, 2)
rand_gen /= torch.norm(rand_gen, dim=-1).unsqueeze(-1)

mps.forward(rand_gen.to(torch.float64))

tensor([0.3217, 0.5300, 0.6689, 0.4137, 0.4518, 0.4848, 0.4619, 0.6848, 0.8434,
        0.4206, 0.6056, 0.4578, 0.6694, 0.7811, 0.2808, 0.4834, 0.4791, 0.4612,
        0.0448, 0.7090, 0.4482, 0.6645, 0.3723, 0.4935, 0.3612, 0.5337, 0.6680,
        0.3314, 0.4152, 0.8303, 0.4398, 0.4487, 0.7049, 0.9043, 0.7533, 0.2612,
        0.8024, 0.4165, 0.2396, 0.3580, 0.3231, 0.2031, 0.5297, 0.4053, 0.1495,
        0.6114, 0.3828, 0.1742, 0.3090, 0.6286, 0.4317, 0.6117, 0.6485, 0.6831,
        0.4552, 0.5844, 0.5443, 0.6155, 0.3997, 0.4054, 0.1457, 0.1112, 0.3017,
        0.4169, 0.4252, 0.3610, 0.7113, 0.3329, 0.3020, 0.2280, 0.4437, 0.6049,
        0.2926, 0.2494, 0.5580, 0.5982, 0.6938, 0.4332, 0.3964, 0.6479, 0.6917,
        0.5261, 0.6665, 0.5378, 0.5206, 0.6813, 0.5700, 0.5287, 0.3308, 0.7020,
        0.4728, 0.6162, 0.4287, 0.5274, 0.7051, 0.4734, 0.4715, 0.1893, 0.6181,
        0.4499], dtype=torch.float64, grad_fn=<SelectBackward0>)

In [132]:
# NOTE(review): the recorded shape [2, 1, 4, 4] does not match the 2-D `B` built
# with torch.cat in the In [157] cell; `B` presumably came from a cell no longer present.
B.shape

torch.Size([2, 1, 4, 4])

In [157]:
# Manual one-step update of the last Kraus parameter, followed by an
# orthogonality check of the result.
# NOTE: `K` here rebinds the global `K` that the training cell set to 1.
G = mps.kraus_ops[-1].grad[0]
G = G / torch.norm(G)  # normalise the gradient to unit norm
K = mps.kraus_ops[-1].detach()[0]


# Concatenate along columns; torch.eye(8) below requires B.T @ A to be 8x8.
A = torch.cat([G, K], dim=1)
B = torch.cat([K, -G], dim=1)

# 0.05 appears in both lines below, presumably as the step size — keep them in sync.
rg = A @ torch.linalg.inv(torch.eye(8) + 0.5 * 0.05 * B.T @ A) @ B.T @ K
Kp = K - rg * 0.05

# The recorded output is the identity to ~1e-16.
Kp @ Kp.T

tensor([[ 1.0000e+00,  5.5511e-17, -5.5511e-17,  0.0000e+00],
        [ 5.5511e-17,  1.0000e+00, -2.7756e-17,  2.7756e-17],
        [-5.5511e-17, -2.7756e-17,  1.0000e+00,  5.5511e-17],
        [ 0.0000e+00,  2.7756e-17,  5.5511e-17,  1.0000e+00]],
       dtype=torch.float64)

In [165]:
# Inspect the gradient of the first Kraus parameter; in the recorded output only
# the last column is non-zero.
mps.kraus_ops[0].grad

tensor([[[ 0.0000,  0.0000,  0.0000,  1.9005],
         [ 0.0000,  0.0000,  0.0000, -0.1128],
         [ 0.0000,  0.0000,  0.0000, 21.8310],
         [ 0.0000,  0.0000,  0.0000,  2.1302]]], dtype=torch.float64)

In [176]:
# Inspect the first parameter held by whichever `optimizer` is currently bound in the kernel.
optimizer.params[0]

Parameter containing:
tensor([[[-0.0781, -0.2654,  0.9404, -0.1976],
         [ 0.5015, -0.7846, -0.2378, -0.2761],
         [-0.4079,  0.1087, -0.1895, -0.8865],
         [-0.7589, -0.5496, -0.1522,  0.3143]]], dtype=torch.float64,
       requires_grad=True)

In [191]:
# Orthogonality check U @ U.T on the optimizer's first parameter; the recorded
# output is the identity to ~1e-16.
optimizer.params[0][0] @ optimizer.params[0][0].T

tensor([[ 1.0000e+00, -3.3307e-16, -2.6368e-16,  4.1633e-17],
        [-3.3307e-16,  1.0000e+00,  2.2898e-16, -3.6082e-16],
        [-2.6368e-16,  2.2898e-16,  1.0000e+00, -2.2204e-16],
        [ 4.1633e-17, -3.6082e-16, -2.2204e-16,  1.0000e+00]],
       dtype=torch.float64, grad_fn=<MmBackward0>)

In [189]:
# Inspect the first Kraus parameter of `mps`.
mps.kraus_ops[0]

Parameter containing:
tensor([[[-0.0462, -0.5530, -0.5585, -0.6166],
         [ 0.3898, -0.3660,  0.7532, -0.3832],
         [ 0.6462, -0.4353, -0.2547,  0.5727],
         [-0.6544, -0.6089,  0.2366,  0.3808]]], dtype=torch.float64,
       requires_grad=True)

In [201]:
# Reshape the first Kraus parameter to an 8x4 matrix and check K^T K; the
# recorded output is the 4x4 identity.
# NOTE: rebinds the global `K`, which the training cell set to 1.
K = mps.kraus_ops[0].reshape(8, 4) 

K.T @ K


tensor([[ 1.0000e+00, -1.6662e-16, -6.0210e-17, -4.3451e-17],
        [-1.6662e-16,  1.0000e+00,  1.8244e-17, -6.5638e-17],
        [-6.0210e-17,  1.8244e-17,  1.0000e+00,  1.2042e-16],
        [-4.3451e-17, -6.5638e-17,  1.2042e-16,  1.0000e+00]],
       dtype=torch.float64, grad_fn=<MmBackward0>)

In [221]:
# Gram-matrix check K^T K on the first Kraus parameter; the recorded output is
# the identity to ~1e-16.
mps.kraus_ops[0].T @ mps.kraus_ops[0]

tensor([[ 1.0000e+00, -9.8817e-19,  1.8542e-17,  1.5244e-16],
        [-9.8817e-19,  1.0000e+00, -2.7013e-18, -3.7641e-17],
        [ 1.8542e-17, -2.7013e-18,  1.0000e+00, -1.6334e-17],
        [ 1.5244e-16, -3.7641e-17, -1.6334e-17,  1.0000e+00]],
       dtype=torch.float64, grad_fn=<MmBackward0>)

In [230]:
# Repeatedly optimise `mps` on ONE fixed batch (`data`, `target` from the kernel)
# with CayleySGDMomentum. In the recorded output the loss stays near 88.33 and
# the accuracy stays at 0.5390625.
import kraus_optimizer
reload(kraus_optimizer)
reload(tpcp_mps)
# NOTE(review): `mps` is not rebuilt after the reload (the constructor call below
# is commented out), so it is still the instance created by an earlier cell.
# mps = tpcp_mps.MPSTPCP(N=16 * 16, K=2, d=2)
# optimizer = kraus_optimizer.Adam(mps.kraus_ops, lr=0.0001)
optimizer = kraus_optimizer.CayleySGDMomentum(mps.kraus_ops, lr=0.0001, beta=0.95, q=0.5, s=4)

for _ in range(1000):
    optimizer.zero_grad()
    outputs = mps.forward(data[:, :, :].to(torch.float64))
    loss = loss_batch(outputs, target)
    loss.backward()
    optimizer.step()
    print(loss.item())
    print(calculate_accuracy(outputs, target))

# Disabled alternative: full-dataset training with per-epoch averages.
# for epoch in range(100):
#     acc_tot = 0
#     loss_tot = 0
#     for data, target in trainloader:
#         optimizer.zero_grad()
#         outputs = mps.forward(data[:, :, :].to(torch.float64))
#         loss = loss_batch(outputs, target)
#         loss.backward()
#         optimizer.step()

#         acc = calculate_accuracy(outputs, target)
#         acc_tot += acc
#         loss_tot += loss.item()
#     acc_tot /= len(trainloader)
#     loss_tot /= len(trainloader)
#     print(f"Accuracy: {acc_tot:.4f}")
#     print(f"Loss: {loss_tot:.4f}")


88.33191124933472
0.5390625
88.33190862632068
0.5390625
88.33190369198766
0.5390625
88.33189689400702
0.5390625
88.3318887714528
0.5390625
88.33187990826016
0.5390625
88.331870888128
0.5390625
88.3318622538474
0.5390625
88.33185447344013
0.5390625
88.3318479147869
0.5390625
88.33184282967609
0.5390625
88.33183934745813
0.5390625
88.33183747780937
0.5390625
88.33183712152645
0.5390625
88.33183808782029
0.5390625
88.33184011627898
0.5390625
88.33184290152242
0.5390625
88.33184611857837
0.5390625
88.33184944715126
0.5390625
88.33185259321003
0.5390625
88.33185530665929
0.5390625
88.33185739425168
0.5390625
88.33185872730965
0.5390625
88.33185924422692
0.5390625
88.3318589480855
0.5390625
88.3318579000275
0.5390625
88.33185620925777
0.5390625
88.3318540206942
0.5390625
88.33185150135407
0.5390625
88.33184882654099
0.5390625
88.33184616680894
0.5390625
88.33184367653107
0.5390625
88.33184148470836
0.5390625
88.3318396884362
0.5390625
88.33183834922204
0.5390625
88.33183749213111
0.5390625
8

KeyboardInterrupt: 

In [236]:
# Inspect the gradient accumulated on the last Kraus parameter.
# (The attribute name was truncated to `.gra`, which is not the tensor gradient attribute.)
mps.kraus_ops[-1].grad

tensor([[     0.0000, -20810.3705,      0.0000,  19047.8718],
        [     0.0000,      0.0000,      0.0000,      0.0000],
        [     0.0000, -25382.7747,      0.0000,   7647.4490],
        [     0.0000,      0.0000,      0.0000,      0.0000]],
       dtype=torch.float64)

In [271]:
# Orthogonality check on the 10th-from-last Kraus parameter of `mps_tpcp`; the
# recorded output is the identity to ~1e-14.
u = mps_tpcp.kraus_ops[-10]
u @ u.T

tensor([[ 1.0000e+00,  7.2720e-15, -1.4322e-14,  2.0721e-15],
        [ 7.2720e-15,  1.0000e+00,  2.4009e-15, -2.2520e-14],
        [-1.4322e-14,  2.4009e-15,  1.0000e+00, -5.4816e-15],
        [ 2.0721e-15, -2.2520e-14, -5.4816e-15,  1.0000e+00]],
       dtype=torch.float64, grad_fn=<MmBackward0>)

In [279]:
# Inspect the first parameter of the unitary-MPS model `umpsm`, which is created
# in the uMPS training cell.
umpsm.params[0]

Parameter containing:
tensor([[[[ 0.5022, -0.2951],
          [ 0.5226,  0.6225]],

         [[-0.5633, -0.2183],
          [-0.4017,  0.6883]]],


        [[[ 0.4872, -0.5938],
          [-0.6219, -0.1525]],

         [[ 0.4394,  0.7160],
          [-0.4228,  0.3399]]]], dtype=torch.float64, requires_grad=True)

tensor([0.2863, 0.5068, 0.4480, 0.3460, 0.5092, 0.4719, 0.6662, 0.3924, 0.6893,
        0.5020], dtype=torch.float64, grad_fn=<SelectBackward0>)

In [324]:
# Copy the trained uMPS parameters into a fresh K=1 MPSTPCP model and print both
# models' outputs on the same batch for comparison.
reload(tpcp_mps)
N = 16 * 16
mps_tpcp = tpcp_mps.MPSTPCP(N=N, K=1, d=2)
# mps_unitary = umps.uMPS(N=N, chi=2, d=2, l=2, layers=1, device="cpu")
mps_unitary = umpsm
for i in range(len(mps_unitary.params)):
    # Each uMPS parameter is reshaped to 4x4 and transposed before being written in place.
    mps_tpcp.kraus_ops[i].data[:] = mps_unitary.params[i].reshape(4,4).T
# random_gen = torch.randn(10, N, 2)
# Despite the name, this is the current `data` batch, not random input.
random_gen = data.to(torch.float64)
# random_gen /= torch.norm(random_gen, dim=-1).unsqueeze(-1)

# The uMPS forward call is given the input with its first two axes swapped.
print(mps_unitary.forward(random_gen.to(torch.float64).permute(1, 0, 2)))
print(mps_tpcp.forward(random_gen.to(torch.float64)))


tensor([0.8430, 0.8851, 0.8424, 0.8830, 0.0178, 0.9027, 0.0251, 0.8285, 0.8845,
        0.0244, 0.0148, 0.0168, 0.0300, 0.0149, 0.9047, 0.1952, 0.0144, 0.1937,
        0.8536, 0.8440, 0.8338, 0.7988, 0.7240, 0.0137, 0.0134, 0.8092, 0.0148,
        0.8996, 0.8683, 0.8271, 0.0218, 0.8844, 0.0131, 0.0134, 0.0241, 0.9100,
        0.0253, 0.7668, 0.8456, 0.8256, 0.0136, 0.0143, 0.0164, 0.7096, 0.9464,
        0.9265, 0.7339, 0.7295, 0.7967, 0.9186, 0.0278, 0.8579, 0.1444, 0.0309,
        0.7035, 0.0141, 0.8855, 0.0148, 0.7287, 0.0274, 0.8310, 0.0129, 0.0217,
        0.8733, 0.0219, 0.0436, 0.9204, 0.3600, 0.1956, 0.7671, 0.0143, 0.8237,
        0.0142, 0.0132, 0.0165, 0.7980, 0.4413, 0.0176, 0.7909, 0.1937, 0.0148,
        0.8425, 0.0761, 0.9172, 0.0302, 0.0144, 0.8631, 0.8260, 0.7069, 0.8629,
        0.0128, 0.0527, 0.0152, 0.8021, 0.2644, 0.8366, 0.4279, 0.0167, 0.8583,
        0.5997, 0.0132, 0.7826, 0.0138, 0.0157, 0.0142, 0.0141, 0.8855, 0.1429,
        0.0293, 0.8130, 0.0581, 0.7953, 

In [382]:
# Optimise the existing `mps_tpcp` on one fixed batch (`data`, `target`) with CayleyAdam.
import kraus_optimizer

reload(kraus_optimizer)

# NOTE(review): this rebinds the global N to 16, but the model is not rebuilt
# here (the constructor below is commented out); other cells that read N are affected.
N = 16
# mps_tpcp = tpcp_mps.MPSTPCP(N=N, K=2, d=2)
# data = torch.randn(100, N, 2)
# target = torch.randint(0, 2, (100,))
optimizer = kraus_optimizer.CayleyAdam(mps_tpcp.kraus_ops, lr=0.3, betas = (0.9, 0.99), q = 0.01)

for epoch in range(2000):
    # acc_tot / loss_tot are reset every iteration and never read afterwards.
    acc_tot = 0
    loss_tot = 0
    optimizer.zero_grad()
    outputs = mps_tpcp.forward(data.to(torch.float64))
    loss = loss_batch(outputs, target)
    loss.backward()
    optimizer.step()

    acc = calculate_accuracy(outputs, target)
    acc_tot += acc
    loss_tot += loss.item()
    # print(acc)
    print(f"Accuracy: {acc:.4f} / Loss: {loss.item():.4f}")


# K @ K.T

Accuracy: 0.8600 / Loss: 45.5356
Accuracy: 0.8600 / Loss: 45.5349
Accuracy: 0.8600 / Loss: 45.5342
Accuracy: 0.8600 / Loss: 45.5335
Accuracy: 0.8600 / Loss: 45.5328
Accuracy: 0.8600 / Loss: 45.5321
Accuracy: 0.8600 / Loss: 45.5314
Accuracy: 0.8600 / Loss: 45.5307
Accuracy: 0.8600 / Loss: 45.5300
Accuracy: 0.8600 / Loss: 45.5293
Accuracy: 0.8600 / Loss: 45.5287
Accuracy: 0.8600 / Loss: 45.5280
Accuracy: 0.8600 / Loss: 45.5273
Accuracy: 0.8600 / Loss: 45.5266
Accuracy: 0.8600 / Loss: 45.5259
Accuracy: 0.8600 / Loss: 45.5252
Accuracy: 0.8600 / Loss: 45.5246
Accuracy: 0.8600 / Loss: 45.5239
Accuracy: 0.8600 / Loss: 45.5232
Accuracy: 0.8600 / Loss: 45.5225
Accuracy: 0.8600 / Loss: 45.5218
Accuracy: 0.8600 / Loss: 45.5211
Accuracy: 0.8600 / Loss: 45.5205
Accuracy: 0.8600 / Loss: 45.5198
Accuracy: 0.8600 / Loss: 45.5191
Accuracy: 0.8600 / Loss: 45.5184
Accuracy: 0.8600 / Loss: 45.5177
Accuracy: 0.8600 / Loss: 45.5171
Accuracy: 0.8600 / Loss: 45.5164
Accuracy: 0.8600 / Loss: 45.5157
Accuracy: 

In [394]:
# Train a freshly constructed K=1 MPSTPCP on the full loader with CayleySGDMomentum.
# The recorded accuracy stays around 0.50.
import kraus_optimizer

reload(kraus_optimizer)

mps_tpcp = tpcp_mps.MPSTPCP(N=16 * 16, K=1, d=2)
# for i in range(len(umpsm.params)):
#     mps_tpcp.kraus_ops[i].data[:] = umpsm.params[i].reshape(4,4).T
optimizer = kraus_optimizer.CayleySGDMomentum(mps_tpcp.kraus_ops, lr=0.01, beta=0.9, q=0.5, s=2)
# optimizer = kraus_optimizer.CayleyAdam(mps_tpcp.kraus_ops, lr=0.005, betas = (0.9, 0.999), q = 0.5, s = 2)

for epoch in range(100):
    # Running sums over the mini-batches of this epoch.
    acc_tot = 0
    loss_tot = 0
    for data, target in trainloader:
        optimizer.zero_grad()
        outputs = mps_tpcp.forward(data.to(torch.float64))
        loss = loss_batch(outputs, target)
        loss.backward()
        optimizer.step()

        acc = calculate_accuracy(outputs, target)
        acc_tot += acc
        loss_tot += loss.item()
        # print(acc)
        # print(loss.item())
        # print(mps.kraus_ops[-1].sum())
    # Average over batches; each batch loss is itself a sum over its samples.
    acc_tot /= len(trainloader)
    loss_tot /= len(trainloader)
    print(f"Accuracy: {acc_tot:.4f}")
    print(f"Loss: {loss_tot:.4f}")


Accuracy: 0.5088
Loss: 90.9863
Accuracy: 0.5060
Loss: 92.5615
Accuracy: 0.4993
Loss: 92.6686
Accuracy: 0.5030
Loss: 91.1258
Accuracy: 0.5018
Loss: 91.4921
Accuracy: 0.5094
Loss: 90.2015
Accuracy: 0.5080
Loss: 91.6243
Accuracy: 0.5038
Loss: 90.4939
Accuracy: 0.5081
Loss: 90.6116
Accuracy: 0.5005
Loss: 91.6860
Accuracy: 0.5039
Loss: 90.9248
Accuracy: 0.5025
Loss: 93.0975


KeyboardInterrupt: 

In [264]:
# Overwrite every Kraus parameter of `mps_tpcp` in place with the matching uMPS
# parameter, reshaped to 4x4 and transposed.
for i, kraus_op in enumerate(mps_tpcp.kraus_ops):
    kraus_op.data[:] = umpsm.params[i].reshape(4,4).T
    # print(mps_tpcp.forward(data.to(torch.float64)))

In [263]:
# Orthogonality check U @ U.T on the first parameter of mps_tpcp.kraus_ops; the
# recorded output is the identity to ~1e-16.
U = list(mps_tpcp.kraus_ops.parameters())[0].data[:]

U @ U.T

tensor([[ 1.0000e+00,  3.8858e-16,  1.3878e-17, -3.5405e-16],
        [ 3.8858e-16,  1.0000e+00, -1.2490e-16,  2.7756e-17],
        [ 1.3878e-17, -1.2490e-16,  1.0000e+00, -1.1102e-16],
        [-3.5405e-16,  2.7756e-17, -1.1102e-16,  1.0000e+00]],
       dtype=torch.float64)

In [252]:
# Inspect the first parameter of the unitary-MPS model `umpsm`.
umpsm.params[0]

Parameter containing:
tensor([[[[ 0.5022, -0.2951],
          [ 0.5226,  0.6225]],

         [[-0.5633, -0.2183],
          [-0.4017,  0.6883]]],


        [[[ 0.4872, -0.5938],
          [-0.6219, -0.1525]],

         [[ 0.4394,  0.7160],
          [-0.4228,  0.3399]]]], dtype=torch.float64, requires_grad=True)

In [None]:
# NOTE(review): incomplete scratch expression with no recorded output; `mp` is
# not defined in any visible cell. Complete it or delete the cell.
mp

In [329]:
# Train the unitary-MPS baseline `umpsm` on the same loader.
# NOTE(review): `umps` and `unitary_optimizer` are not imported in any visible cell.
umpsm = umps.uMPS(N = 16 * 16, chi = 2, d = 2, l = 2, layers = 1, device = "cpu")
umpsm_op = unitary_optimizer.Adam(umpsm, lr=0.01)
# Loss of every mini-batch across all epochs.
loss_list = []

for epoch in range(100):
    acc = 0
    for data, target in trainloader:    
        # Swap the first two axes before the forward call; this rebinds the
        # loop variable `data`, which leaks into the kernel afterwards.
        data = data.permute(1, 0, 2)
        umpsm_op.zero_grad()
        outputs = umpsm(data)
        loss = loss_batch(outputs, target)
        loss.backward()
        loss_list.append(loss.item())
        umpsm_op.step()

        # Calculate accuracy
        accuracy = calculate_accuracy(outputs, target)
        # print(f"Accuracy: {accuracy:.4f}")
        acc += accuracy
    # Accuracy is averaged over batches; the loss printed below is that of the
    # LAST batch of the epoch only, not an epoch average.
    acc /= len(trainloader)
    print(f"Accuracy: {acc:.4f}")
    print(f"loss: {loss.item()}")

Path is not set, setting...
Found the path
Initialized MPS unitaries
Accuracy: 0.5238
loss: 83.79042052731249
Accuracy: 0.5281
loss: 83.30355859970851
Accuracy: 0.6275
loss: 50.921918496208235


KeyboardInterrupt: 