In [2]:
import os
import umps


In [1]:
import torchvision
import torchvision.transforms as transforms
import torch
import torch.utils.data

def embedding_pixel(batch, label: int = 0):
    pixel_size = batch.shape[-1] * batch.shape[-2]
    x = batch.view(*batch.shape[:-2], pixel_size)
    # x[:] = 0
    x = torch.stack([x, 1-x], dim=-1)
    # x = x / torch.sum(x, dim=-1).unsqueeze(-1)
    x = x / torch.norm(x, dim=-1).unsqueeze(-1)
    return x

def embedding_label(labels: torch.Tensor):
    emb = torch.zeros(labels.shape[0], 2)
    emb[torch.arange(labels.shape[0]), labels] = 1
    return emb

def filiter_single_channel(batch):
    return batch[0, ...]

def filter_dataset(dataset, allowed_digits=[0, 1]):
    indices = []
    for i in range(len(dataset)):
        _, label = dataset[i]
        if label in allowed_digits:
            indices.append(i)
    return torch.utils.data.Subset(dataset, indices)

img_size = 16
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Lambda(filiter_single_channel),
    transforms.Lambda(embedding_pixel),
])

trainset = torchvision.datasets.QMNIST(
    root="data",
    train=True,
    download=True,
    transform=transform
)

trainset = filter_dataset(trainset, allowed_digits=[0, 1])

trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
)


In [3]:
import unitary_optimizer
umpsm = umps.uMPS(N = 16 * 16, chi = 2, d = 2, l = 2, layers = 1, device = "cpu")


Path is not set, setting...
Found the path
Initialized MPS unitaries


In [4]:
def loss_batch(outputs, labels):
    device = outputs.device
    loss = torch.zeros(1, device=device, dtype=torch.float64)

    for i in range(len(outputs)):
        prob = outputs[i] if labels[i] == 0 else 1 - outputs[i]
        loss -= torch.log(prob + 1e-8)
    return loss

def calculate_accuracy(outputs, labels):
    predictions = (outputs < 0.5).float()
    correct = (predictions == labels).float().sum()
    accuracy = correct / labels.numel()
    return accuracy.item()

In [9]:
outputs

tensor([0.7637, 0.9795, 0.8926, 0.0836, 0.8505, 0.9312, 0.0141, 0.9725, 0.1513,
        0.9215, 0.0045, 0.0138, 0.0089, 0.0070, 0.6724, 0.9093, 0.0056, 0.0247,
        0.0412, 0.0244, 0.0280, 0.0170, 0.8263, 0.0167, 0.0047, 0.0053, 0.0180,
        0.9278, 0.9438, 0.0029, 0.0059, 0.0206, 0.0076, 0.8622, 0.0115, 0.2721,
        0.8460, 0.0073, 0.0073, 0.9246, 0.0516, 0.0091, 0.9750, 0.9407, 0.8581,
        0.0070, 0.9598, 0.0050, 0.6729, 0.9233, 0.0349, 0.0216, 0.9407, 0.0320,
        0.8074, 0.0115, 0.8249, 0.8394, 0.8526, 0.0065, 0.0150, 0.1316, 0.9226,
        0.9621, 0.0381, 0.0073, 0.0118, 0.9372, 0.8304, 0.9543, 0.9461, 0.0072,
        0.0116, 0.7962, 0.9524, 0.0862, 0.9462, 0.3711, 0.9059, 0.9561, 0.9215,
        0.0224, 0.0299, 0.0226, 0.0125, 0.0085, 0.0345, 0.0175, 0.0083, 0.0250,
        0.0045, 0.8125, 0.0155, 0.9516, 0.9834, 0.1243, 0.0123, 0.9089, 0.0109,
        0.9692, 0.0099, 0.9248, 0.8843, 0.9015, 0.0178, 0.0118, 0.0097, 0.0158,
        0.9392, 0.7801, 0.0062, 0.0061, 

In [10]:
umpsm_op = unitary_optimizer.Adam(umpsm, lr=0.01)
umpsm = umps.uMPS(N = 16 * 16, chi = 2, d = 2, l = 2, layers = 1, device = "cpu")

data, target = next(iter(trainloader))
data = data.permute(1, 0, 2)


for epoch in range(100):
    acc = 0
    # for data, target in trainloader:    
    #     data = data.permute(1, 0, 2)
    umpsm_op.zero_grad()
    outputs = umpsm(data)
    loss = loss_batch(outputs, target)
    loss.backward()
    umpsm_op.step()

    # Calculate accuracy
    accuracy = calculate_accuracy(outputs, target)
    print(f"Accuracy: {accuracy:.4f}")
    print(f"loss: {loss.item()}")

Path is not set, setting...
Found the path
Initialized MPS unitaries
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0.5391
loss: 88.5568615402778
Accuracy: 0

KeyboardInterrupt: 

In [12]:
umpsm_op = unitary_optimizer.Adam(umpsm, lr=0.01)
loss_list = []

for epoch in range(100):
    acc = 0
    for data, target in trainloader:    
        data = data.permute(1, 0, 2)
        umpsm_op.zero_grad()
        outputs = umpsm(data)
        loss = loss_batch(outputs, target)
        loss.backward()
        loss_list.append(loss.item())
        umpsm_op.step()

        # Calculate accuracy
        accuracy = calculate_accuracy(outputs, target)
        # print(f"Accuracy: {accuracy:.4f}")
        acc += accuracy
    acc /= len(trainloader)
    print(f"Accuracy: {acc:.4f}")
    print(f"loss: {loss.item()}")
    # print("grad", umpsm.params[-1].grad)







89.62225885873355
88.400779478013
88.60617995000045
88.87919172283672
88.8166169799111
88.9230582490079
88.89673588191654
88.30784067465723
88.58697651412227
88.89301631927417
92.62701037763468
89.68329130806143
87.97742245863435
88.02188398893549
88.6090710640717
88.13267963698254
88.48040841418795
88.10945438398502
88.9771047835331
88.11855071648897
88.87948038653596
89.16618582039055
87.84996391477794
88.13714208482457
89.55143128984768
88.48058798526598
88.98256957168863
87.17232935990128
87.804648284048
90.08918941731231
88.79090859634914
88.4849981176141
88.63517737951305
89.06399800213697
88.3416708962494
88.59250719697431
88.2635701363373
88.07611812424696
87.83689866820531
87.96770101142275
87.87484267284829
89.0732690084382
88.36076242587293
87.53418404821701
87.50203539425091
88.21176195451984
87.20474369363372
85.81242761084728
89.88397296528669
90.54964370692309
90.33466888457934
88.95224481754441
88.16064475455956
88.33229514799883
89.57185240803437
88.41370311437367
88.6

KeyboardInterrupt: 

(tensor([[0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
         [0.0000, 1.0000],
 

In [49]:
import tpcp_mps
from importlib import reload
reload(tpcp_mps)

mps = tpcp_mps.MPSTPCP(N=16*16, K=2, d=2)

rho = mps.forward(data.to(torch.float64))

tensor(1.0000, dtype=torch.float64, grad_fn=<TraceBackward0>)
tensor(1.0000, dtype=torch.float64, grad_fn=<TraceBackward0>)
tensor(1.0000, dtype=torch.float64, grad_fn=<TraceBackward0>)
tensor(1.0000, dtype=torch.float64, grad_fn=<TraceBackward0>)
tensor(1.0000, dtype=torch.float64, grad_fn=<TraceBackward0>)
tensor(1.0000, dtype=torch.float64, grad_fn=<TraceBackward0>)
tensor(1.0000, dtype=torch.float64, grad_fn=<TraceBackward0>)
tensor(1.0000, dtype=torch.float64, grad_fn=<TraceBackward0>)
tensor(1.0000, dtype=torch.float64, grad_fn=<TraceBackward0>)
tensor(1.0000, dtype=torch.float64, grad_fn=<TraceBackward0>)
tensor(1.0000, dtype=torch.float64, grad_fn=<TraceBackward0>)
tensor(1.0000, dtype=torch.float64, grad_fn=<TraceBackward0>)
tensor(1.0000, dtype=torch.float64, grad_fn=<TraceBackward0>)
tensor(1.0000, dtype=torch.float64, grad_fn=<TraceBackward0>)
tensor(1.0000, dtype=torch.float64, grad_fn=<TraceBackward0>)
tensor(1.0000, dtype=torch.float64, grad_fn=<TraceBackward0>)
tensor(1