In [None]:
import torch.nn as nn
from torchvision import transforms
import numpy as np
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import os
from data import ImageFolderDataset
import data
import mps

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Paths
# All data paths are resolved relative to the kernel's working directory;
# the raw MNIST files are expected in ./dataset/MNIST/raw.
_mnist_subdirs = ("dataset", "MNIST", "raw")
cwd = os.path.abspath(os.getcwd())
mnist_root = os.path.join(cwd, *_mnist_subdirs)
mnist_root

In [None]:
# Transforms
# No preprocessing yet: with an empty step list, Compose hands samples back unchanged.
# Append torchvision transforms to this list to add preprocessing for every split.
_transform_steps = []
transform = transforms.Compose(_transform_steps)

In [None]:
# Transforms and loading
# All three splits live under mnist_root and share the same loader settings and
# transform; only the image/label file names differ.
def _load_split(images_file, labels_file):
    """Load one split from mnist_root with the project's ImageFolderDataset."""
    return data.ImageFolderDataset(
        root=mnist_root,
        images=images_file,
        labels=labels_file,
        force_download=False,
        verbose=True,
        transform=transform,
    )


train = _load_split('train_images.pt', 'train_labels.pt')
val = _load_split('val_images.pt', 'val_labels.pt')
test = _load_split('test_images.pt', 'test_labels.pt')

In [None]:
# Quick visual sanity check of the first training sample.
# Index [0, 0] takes sample 0 and slice 0 of dim 1 (presumably the channel axis — confirm).
first_train_image = train.images[0, 0]
plt.imshow(first_train_image)

In [None]:
# Only a cell's last expression is displayed, so return both shapes together
# (previously the image shape was computed and silently discarded).
train.images.shape, train.labels.shape

In [None]:
# Convert to numpy
# np.asarray converts the (CPU) tensors just like `.numpy()`, but it is a no-op on
# arrays that are already numpy. This makes the cell safe to re-run; the old
# version raised AttributeError on the second run.
for _split in (train, val, test):
    _split.images, _split.labels = np.asarray(_split.images), np.asarray(_split.labels)

In [None]:
# The conversion to numpy should leave the image array's shape unchanged.
np.shape(train.images)

In [None]:
# MPS of a single sample, used below for sanity checks and a one-off classification.
# NOTE(review): despite its name, this sample comes from the *training* split, and the
# class-superposition loop below also sums it into its own class. Consider a held-out
# sample such as test.images[0, 0] for an honest check.
sample_image = train.images[0, 0]
test_mps = mps.MNIST_MPS(sample_image, sample_image.shape, train.labels[0])

In [None]:
# Inspect `Bs` of the single-sample MPS (presumably its per-site tensors — confirm in mps.py).
sample_tensors = test_mps.Bs
print(sample_tensors)

In [None]:
# Superpose the single-sample state with itself using MNIST_MPS's `+` operator.
summand = test_mps
added_mps = summand + summand

In [None]:
# Inspect `Bs` of the summed state, for comparison with the single-sample tensors.
summed_tensors = added_mps.Bs
print(summed_tensors)

In [None]:
# Image attribute of the summed state.
summed_image = added_mps.img
plt.imshow(summed_image)

In [None]:
# overlap_theta of the single-sample state with itself, then with the summed state.
for _other in (test_mps, added_mps):
    print(mps.overlap_theta(test_mps, _other))


In [None]:
# Build one superposition state per digit class by summing the MPS of every
# training image carrying that label. Keys are the digit labels as strings.
quantum_digits = dict.fromkeys(str(digit) for digit in range(10))

# zip() would silently truncate on a length mismatch; the old index loop would have failed.
assert len(train.images) == len(train.labels), "images/labels length mismatch"

for image, label in zip(train.images, train.labels):
    # Pass the same 2-D slice the single-sample cell uses, so the image and the
    # `shape` argument agree. Previously the full images[i] was passed together
    # with the shape of images[i, 0].
    plane = image[0]
    qs = mps.MNIST_MPS(plane, plane.shape, label)
    key = str(label)
    # Test for None explicitly rather than relying on the truthiness of an
    # MNIST_MPS object (which would change if the class defines __len__/__bool__).
    if quantum_digits[key] is None:
        quantum_digits[key] = qs
    else:
        quantum_digits[key] = quantum_digits[key] + qs


In [None]:
# Check norm of sum of product states
# overlap_rC(state, state) is each class state's overlap with itself (presumably its
# squared norm — confirm against mps.overlap_rC).
for digit, superposition in quantum_digits.items():
    self_overlap = mps.overlap_rC(superposition, superposition)
    print(f"{digit}: {self_overlap} \n")

In [None]:
# Test overlap with keys:
# Overlap of every class superposition with the single sample, computed with both
# overlap routines from the `mps` module so their results can be compared.
digit_keys = [str(d) for d in range(10)]
quantum_overlaps = dict.fromkeys(digit_keys)
quantum_overlaps_rC = dict.fromkeys(digit_keys)

for digit, superposition in quantum_digits.items():
    quantum_overlaps[digit] = mps.overlap_theta(superposition, test_mps)
    quantum_overlaps_rC[digit] = mps.overlap_rC(superposition, test_mps)



In [None]:
# Classify the single sample. The predicted digit is the class superposition with the
# largest |overlap|, computed separately for overlap_theta and overlap_rC results.
# Only strictly positive magnitudes count, so a prediction stays None if every
# overlap is zero.
print(f"Test label {test_mps.label}")
classifier_score = 0
classifier_score_rC = 0
pred = None
pred_rC = None
# Look both tables up by key instead of zipping their items. That guarantees the two
# values printed and compared on each row belong to the same digit.
for key in quantum_overlaps:
    item = quantum_overlaps[key]
    irC = quantum_overlaps_rC[key]
    print(f"{key}: {item}  {irC} \n")
    score = np.abs(item)
    if score > classifier_score:
        classifier_score = score
        pred = key
    score_rC = np.abs(irC)
    if score_rC > classifier_score_rC:
        classifier_score_rC = score_rC
        pred_rC = key

print(f"Prediction {pred} Overlap {classifier_score}")
print(f"rC: Prediction {pred_rC} Overlap {classifier_score_rC}")

In [None]:
# Illustrate images
# Left: the single sample. Right: the superposition of its predicted class.
fig, (ax_sample, ax_pred) = plt.subplots(1, 2, figsize=(8, 4))
ax_sample.imshow(test_mps.img)
ax_sample.set_title(f"Sample (label {test_mps.label})")
if pred is not None:
    ax_pred.imshow(quantum_digits[pred].img)
    ax_pred.set_title(f"Predicted class {pred} superposition")
else:
    # The classification loop leaves `pred` as None when every overlap is zero;
    # indexing quantum_digits with None would raise KeyError.
    ax_pred.set_title("No prediction (all overlaps zero)")
plt.show()

# Every class superposition, laid out in a 2x5 grid.
fig, axes = plt.subplots(2, 5, figsize=(12, 6))
for ax, (digit, superposition) in zip(axes.flat, quantum_digits.items()):
    if superposition is not None:  # None means no training image had this label
        ax.imshow(superposition.img)
    ax.set_title(f"Digit {digit}")
    ax.axis("off")
fig.tight_layout()
plt.show()


In [None]:
# Even though some parts of the implementation are wrong, I want to check results.

# Do validation: classify every validation image as the class superposition with the
# largest |overlap_theta|, then tally the outcomes.

# validation_dict[digit] = [number correct, number wrong] for images whose TRUE label is `digit`.
validation_dict = {str(digit): np.zeros(2) for digit in range(10)}
# prediction_dict[digit] = how many validation images were PREDICTED as `digit`.
prediction_dict = {str(digit): 0 for digit in range(10)}
CORRECT, WRONG = 0, 1  # column indices into the validation_dict entries

assert len(val.images) == len(val.labels), "images/labels length mismatch"

# Distinct names (val_state, val_pred) so this loop does not clobber the `qs`/`pred`
# globals that the single-sample illustration cell reads.
for image, label in zip(val.images, val.labels):
    # Same 2-D slice as the single-sample cell, so the image and the `shape` argument agree.
    plane = image[0]
    val_state = mps.MNIST_MPS(plane, plane.shape, label)
    true_key = str(val_state.label)

    best_overlap = 0
    val_pred = None
    for key, item in quantum_digits.items():
        overlap = np.abs(mps.overlap_theta(val_state, item))
        if overlap > best_overlap:
            best_overlap = overlap
            val_pred = key

    # val_pred stays None when every overlap is zero. Count that as a miss instead
    # of crashing on prediction_dict[None].
    if val_pred is not None:
        prediction_dict[val_pred] += 1
    validation_dict[true_key][CORRECT if val_pred == true_key else WRONG] += 1




# Using Matrix Product States (MPS) to efficiently compute wave-function overlaps of MNIST images and to build a machine-learning (ML) classification model

In [None]:
# Per-digit [correct, wrong] tallies, followed by how often each digit was predicted.
for tally in (validation_dict, prediction_dict):
    print(tally)

In [None]:
# Plot results
# Grouped bars per true digit: correct (blue) vs. wrong (red) predictions, with the
# per-digit accuracy written above each "Correct" bar.
digits = list(validation_dict)
correct_counts = [counts[0] for counts in validation_dict.values()]
wrong_counts = [counts[1] for counts in validation_dict.values()]
positions = np.arange(len(digits))
width = 0.35


def _accuracy_label(correct, wrong):
    """Format correct / (correct + wrong) as a percentage; 'n/a' for a digit with no samples."""
    total = correct + wrong
    # Guard the 0/0 case, which previously produced a 'nan%' label and a RuntimeWarning.
    return '{:.1f}%'.format(100 * correct / total) if total else 'n/a'


fig, ax = plt.subplots()
bar1 = ax.bar(positions - width / 2, correct_counts, facecolor="blue", width=width, label="Correct")
bar2 = ax.bar(positions + width / 2, wrong_counts, facecolor="red", width=width, label="False")

ax.set_xlabel('Digit')
ax.set_ylabel('Predictions')
ax.set_title('Classifier Score Untrained and False Implementation')
ax.set_xticks(positions, digits)
ax.legend()

ax.bar_label(bar1, labels=[_accuracy_label(c, w) for c, w in zip(correct_counts, wrong_counts)], padding=3)

plt.show()