In [1]:
import opt_einsum as oe
import numpy as np
import torch
import sys
sys.path.append("../../")

In [2]:
from importlib import reload
from mps import simple_mps, tpcp_mps
reload(simple_mps)

<module 'mps.simple_mps' from '/Users/keisuke/Documents/presentation/QC_MPS/mps/notebooks/../../mps/simple_mps.py'>

In [3]:
def filter_digits(dataset, allowed_digits=[0, 1]):
    """Return a subset of MNIST dataset containing only allowed_digits (0 or 1)."""
    indices = []
    for i in range(len(dataset)):
        _, label = dataset[i]
        if label in allowed_digits:
            indices.append(i)
    return torch.utils.data.Subset(dataset, indices)


def filiter_single_channel(img: torch.Tensor) -> torch.Tensor:
    """
    MNIST is loaded as shape [C, H, W].
    Take only the first channel => shape [H, W].
    """
    return img[0, ...]


def embedding_pixel(batch, label: int = 0):
    """
    Flatten each image from shape [H, W] => [H*W],
    then embed x => [x, 1-x], and L2-normalize along last dim.
    """
    pixel_size = batch.shape[-1] * batch.shape[-2]
    x = batch.view(*batch.shape[:-2], pixel_size)
    x = torch.stack([x, 1 - x], dim=-1)
    x = x / torch.sum(x, dim=-1).unsqueeze(-1)
    return x

In [4]:
###############################################################################
# Loss & Accuracy
###############################################################################
def loss_batch(outputs, labels):
    """
    Binary cross-entropy style loss for outputs in [0, 1].
    For label=0 => prob=outputs[i], else => 1 - outputs[i].
    """
    device = outputs.device
    loss = torch.zeros(1, device=device, dtype=torch.float64)
    for i in range(len(outputs)):
        prob = outputs[i] if labels[i] == 0 else (1 - outputs[i])
        loss -= torch.log(prob + 1e-8)
        # Start of Selection
        if torch.isnan(loss):
            print(f"Loss is NaN at i={i}")
            print(prob, outputs[i], labels[i])
    return loss


def calculate_accuracy(outputs, labels):
    """
    Threshold 0.5 => label 0 or 1. Compare to true labels.
    """
    predictions = (outputs < 0.5).float()
    correct = (predictions == labels).float().sum()
    return correct / labels.numel()

from torchvision import transforms
import torchvision

img_size = 16
transform = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Lambda(filiter_single_channel),
        transforms.Lambda(embedding_pixel),
        transforms.Lambda(lambda x: x.to(torch.float64)),  # double precision
    ]
)

trainset = torchvision.datasets.MNIST(
    root="data", train=True, download=True, transform=transform
)
# Filter digits 0,1 only
trainset = filter_digits(trainset, allowed_digits=[0, 1])

batch_size = 128

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=False
)


In [5]:
# ---------- Build MPS model ----------
N = img_size * img_size
d = l = 2 #data input dimension and class label dimension 
chi_umps = 2
chi_max = 2
reload(simple_mps)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
smps = simple_mps.SimpleMPS(
    N, 
    2,
    d, 
    l, 
    layers=2,
    device=device, 
    dtype=torch.float64, 
    optimize="greedy",
)


Path is not set, setting...
Found the path
Initialized MPS with random matrices


In [6]:
def accuracy(outputs, target):
    return (outputs.argmax(dim=-1) == target).float().mean()
losses = []
running_loss = 0
running_accuracy = 0
logsoftmax = torch.nn.LogSoftmax(dim=-1)
nnloss = torch.nn.NLLLoss(reduction="mean")
optimizer = torch.optim.Adam(smps.parameters(), lr=0.001)
n_samples = 0
for epoch in range(10):
    for batch_idx, (data, target) in enumerate(trainloader):
        target = target.to(device).to(torch.int64)
        data = data.to(device).permute(1, 0, 2)
        optimizer.zero_grad()
        outputs = smps(data)
        outputs = logsoftmax(outputs)
        loss = nnloss(outputs, target)
        loss.backward()
        optimizer.step()

        data_size = data.shape[1]
        
        # Calculate accuracy
        # print(torch.exp(outputs[:10]), target[:10])
        
        running_loss += loss.item() * data_size
        n_samples += data_size
        
        if batch_idx % 1 == 0:
            avg_loss = running_loss / n_samples
            avg_accuracy = accuracy(outputs, target)
            losses.append(avg_loss)
            running_loss = 0
            running_accuracy = 0
            n_samples = 0
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} Accuracy: {:.2f}%'.format(
                epoch, batch_idx * data_size, len(trainloader.dataset),
                100. * batch_idx / len(trainloader), avg_loss, avg_accuracy * 100))



KeyboardInterrupt: 

In [7]:
from mps import tpcp_mps  

reload(tpcp_mps)

tpcp = tpcp_mps.MPSTPCP(N, K=1, d=2, with_identity=True, manifold=tpcp_mps.ManifoldType.EXACT)
tpcp.W.data[:, 1] = 0
tpcp.W.data[:, 0] = 1



In [17]:
tpcp.set_canonical_mps(smps)

In [18]:
def accuracy(outputs, target):
    correct = (outputs < 0).float() == target.float()
    return correct.float().sum() / target.numel()

data, target = next(iter(trainloader))
out = tpcp(data)

calculate_accuracy(out, target)

tensor(0.5781)

In [19]:
from mps.StiefelOptimizers import StiefelAdam
from mps.radam import RiemannianAdam
W = torch.zeros(tpcp.L, 2, dtype=torch.float64)
W[:, 0] = 1
W[:, 1] = 0
tpcp.initialize_W(W)
optimizer = RiemannianAdam(tpcp.parameters(), lr=0.0001, betas=(0.9, 0.999))
# optimzier = StiefelAdam(tpcp.parameters(), lr=0.0001, expm_method="ForwardEuler")
epochs = 100
for epoch in range(epochs):
    acc_tot = 0
    loss_tot = 0
    for data, target in trainloader:
        optimizer.zero_grad()
        outputs = tpcp(data)
        loss = loss_batch(outputs, target)
        loss.backward()
        optimizer.step()
        acc = calculate_accuracy(outputs, target)
        acc_tot += acc
        loss_tot += loss.item()
        print("Loss: ", loss.item(), "Accuracy: ", acc)

    print(f"Epoch {epoch} / {epochs} / Loss: {loss_tot / len(trainloader)} / Accuracy: {acc_tot / len(trainloader)}")


Loss:  152.3278759303782 Accuracy:  tensor(0.5781)
Loss:  163.4813242453831 Accuracy:  tensor(0.5391)
Loss:  168.54286556974026 Accuracy:  tensor(0.5391)
Loss:  154.6920416529844 Accuracy:  tensor(0.5312)
Loss:  148.96475346766428 Accuracy:  tensor(0.5547)
Loss:  155.0554851535941 Accuracy:  tensor(0.5156)
Loss:  134.8051555082037 Accuracy:  tensor(0.5547)
Loss:  136.01640378461934 Accuracy:  tensor(0.5391)
Loss:  155.10144471849856 Accuracy:  tensor(0.5078)
Loss:  130.80626742558394 Accuracy:  tensor(0.5312)
Loss:  132.78404800383308 Accuracy:  tensor(0.5547)
Loss:  150.14285403325962 Accuracy:  tensor(0.5312)
Loss:  109.59956925090549 Accuracy:  tensor(0.6016)
Loss:  150.49303731536605 Accuracy:  tensor(0.4531)
Loss:  95.44415424492747 Accuracy:  tensor(0.6250)
Loss:  111.04002203129404 Accuracy:  tensor(0.5547)
Loss:  144.69964718887564 Accuracy:  tensor(0.4766)
Loss:  108.45220235635139 Accuracy:  tensor(0.5469)
Loss:  123.5052850100724 Accuracy:  tensor(0.5625)
Loss:  126.44659005

KeyboardInterrupt: 

In [104]:
out = smps(data)
out = logsoftmax(out)
out


tensor([[-1.0300e+01, -3.3639e-05],
        [-1.0873e+01, -1.8957e-05],
        [-1.1319e+01, -1.2135e-05],
        [-1.0072e+01, -4.2227e-05],
        [-1.1331e+01, -1.2001e-05],
        [-8.8451e+00, -1.4410e-04],
        [ 0.0000e+00, -4.8987e+01],
        [-3.3434e-08, -1.7214e+01],
        [-9.8337e+00, -5.3613e-05],
        [-1.0920e+01, -1.8086e-05],
        [-3.4195e-14, -3.1010e+01],
        [-1.0579e+01, -2.5457e-05],
        [-6.0174e-14, -3.0443e+01],
        [-7.9001e+00, -3.7078e-04],
        [ 0.0000e+00, -3.8507e+01],
        [-9.9722e+00, -4.6679e-05],
        [-4.7244e-11, -2.3776e+01],
        [-5.5502e-12, -2.5917e+01],
        [-1.0454e+01, -2.8819e-05],
        [-3.4917e-05, -1.0263e+01],
        [-2.2871e-14, -3.1411e+01],
        [ 0.0000e+00, -4.2346e+01],
        [ 0.0000e+00, -4.4464e+01],
        [-1.0186e+01, -3.7694e-05],
        [-3.7323e-08, -1.7104e+01],
        [ 0.0000e+00, -7.0261e+01],
        [-8.3386e+00, -2.3913e-04],
        [-8.4054e+00, -2.236

In [108]:
out = tpcp(data.permute(1, 0, 2))

# loss_batch(out, target)
(torch.sgn(out) + 1) / 2 + target


tensor([1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 0., 1., 1., 1., 0., 1., 0., 0.,
        1., 1., 0., 0., 0., 1., 0., 0., 1., 1., 1., 1., 0., 0., 0., 1., 1., 1.,
        1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 0., 1., 1., 0., 1., 0., 0., 1.,
        0., 1., 0., 1., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1.,
        1., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 1., 1., 1., 1.,
        1., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 1., 1.,
        0., 0., 1., 0., 1., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1., 1., 1., 0.,
        1., 1.], dtype=torch.float64, grad_fn=<AddBackward0>)