In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensornetwork as tn
import numpy as np
from tensorflow.keras.datasets import mnist

In [2]:
# Pre-settings
######################################

# Training hyperparameters (not yet referenced by the cells shown so far,
# except num_train).
num_train = 4000
num_test = 1000
batch_size = 100
num_epochs = 20
learning_rate = 1e-4
l2_reg = 0.

# Reproducibility: seed NumPy (presumably the RNG behind tn.FiniteMPS.random
# with tensornetwork's default numpy backend -- confirm) and TensorFlow.
seed = 42
np.random.seed(seed)
tf.random.set_seed(seed)

# Loss func and Optimizer
# TODO: choose a loss. tf.nn.softmax_cross_entropy_with_logits needs
# labels=... and logits=... and cannot be called without arguments.

#######################################

# load mnist dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Scale uint8 pixels to [0, 1] as float32 so the features match the float32
# MPS tensors created later; TensorFlow does not implicitly mix float32 and
# float64 inside one op, and float32 also halves the memory footprint.
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0

# Alternative loader (not used): tfds.load('mnist', split=['train', 'test'],
# data_dir='./mnist_data/', shuffle_files=True), followed by .take(num_train)
# and .take(num_test) on the returned datasets.

########################################


In [4]:
# Local map f(x) = [x, 1-x]

# Flatten each image into one row of 28*28 = 784 pixels (one row per sample).
x_train_array = np.reshape(x_train, (-1, 28 * 28))

def local_feature_map(X):
    """Apply the local feature map f(x) = [x, 1 - x] to every element of X.

    Args:
        X: array-like of pixel values, presumably scaled to [0, 1]. Any rank is
            accepted; in this notebook it is (num_samples, num_pixels).

    Returns:
        A tf.Tensor with X's shape plus a trailing axis of size 2 holding
        ``[x, 1 - x]`` for each element.
    """
    # Stack on the last axis instead of a fixed axis=2. For rank-2 input the
    # result is identical, and the map now also works on a single flattened
    # image (rank 1) or on unflattened image batches (rank 3).
    return tf.stack([X, 1 - X], axis=-1)

# Featurize every flattened image, then keep the first `num_train` samples
# and their labels as the training set.
# NOTE(review): this featurizes the whole training set before slicing;
# slicing x_train_array first would avoid materializing unused rows.
x_train_featured = local_feature_map(x_train_array)
x_train_batch, y_train_batch = x_train_featured[:num_train], y_train[:num_train]

In [5]:
# MPS initialization
# construct_MPS builds an MPS with `rank` feature sites plus one extra label
# site (physical dimension `label_size`) inserted at index `label_site`:
#
#                           label_size
#                               |
#       bond_dim                |
# A1 --------- A2 ---------- L --------- A3 --------- An
# |            |                         |            |
# feature      feature                   feature      feature
def construct_MPS(bond_dim, feature, rank, label_site=None, label_size=10):
    """Build a random MPS and connect the bond edges of consecutive tensors.

    Args:
        bond_dim: bond (virtual) dimension between neighbouring tensors.
        feature: physical dimension of each feature site.
        rank: number of feature sites.
        label_site: index at which the label tensor is inserted. ``None`` (the
            default) places it in the middle of the chain, at ``rank // 2``.
        label_size: physical dimension of the label site (number of classes).

    Returns:
        (nodes, connected_bonds): the ``rank + 1`` tn.Node objects named
        ``block_{i}`` and the list of edges joining consecutive nodes.

    Raises:
        ValueError: if ``label_site`` is outside ``[0, rank]``.
    """
    if label_site is None:
        label_site = rank // 2
    if not 0 <= label_site <= rank:
        raise ValueError(f"label_site must be in [0, {rank}], got {label_site}")

    # rank feature legs + 1 label leg = rank + 1 sites, so rank bond dimensions.
    mps = tn.FiniteMPS.random(
        d=[feature] * label_site + [label_size] + [feature] * (rank - label_site),
        D=[bond_dim] * rank,
        dtype=np.float32,
    )

    # Wrap each site tensor in a named Node; edges[0] is used as the left bond
    # and edges[2] as the right bond below.
    nodes = [tn.Node(tensor, f'block_{i}') for i, tensor in enumerate(mps.tensors)]

    # Join each node's right bond to the next node's left bond. k = -1 also
    # joins the last node back to the first. NOTE(review): FiniteMPS pads the
    # outer bonds with size 1, so this presumably only closes the trivial
    # boundary edges into a ring -- confirm this is intended.
    connected_bonds = [nodes[k].edges[2] ^ nodes[k + 1].edges[0] for k in range(-1, rank)]

    return nodes, connected_bonds

# TODO: contract the network over the bond edges, e.g.
# tn.contractors.auto(nodes, ignore_edge_order=True).

# MPS parameters
bond_dim = 3
feature = 2
rank = 28**2
label_size = 10

# Put the label tensor in the middle of the chain; pass label_size by keyword
# so it is not mistaken for the label_site position.
nodes, bonds = construct_MPS(bond_dim, feature, rank, label_site=rank // 2, label_size=label_size)


In [None]:
# Inspect the training batch. Slice with `num_train` rather than a separate
# hard-coded count, so the images stay aligned one-to-one with y_train_batch.
x_train_batch = x_train_featured[:num_train]
assert x_train_batch.shape[0] == y_train_batch.shape[0], "images and labels are misaligned"
x_train_batch.shape

In [None]:
# Sanity check of the tensornetwork API: contracting the shared edge of two
# length-2 all-ones vectors gives their inner product, 1*1 + 1*1 = 2.
a = tn.Node(np.ones(2))
b = tn.Node(np.ones(2))
edge = a[0] ^ b[0]
c = tn.contract(edge)
assert np.isclose(c.tensor, 2.0), c.tensor
c.tensor

In [None]:
class mpsTrain(tf.Module):
    def __init__(self, )