In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensornetwork as tn
import numpy as np
from tensorflow.keras.datasets import mnist
import time

In [2]:
# Pre-settings
######################################

# Training parameters
num_train = 4000
num_test = 1000
batch_size = 100
num_epochs = 20
learning_rate = 1e-4
l2_reg = 0.

# TODO: loss function and optimizer are not chosen yet
# (candidate loss: tf.nn.softmax_cross_entropy_with_logits).

# Seed NumPy and TensorFlow so random initialisations are reproducible across runs.
seed = 0
np.random.seed(seed)
tf.random.set_seed(seed)

#######################################

# Load MNIST with the Keras loader and scale pixel values by 1/255 (to [0, 1] for 8-bit images).
# (An alternative tfds.load('mnist', ...) pipeline was sketched here but is not used.)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0


In [3]:
# Local feature map f(x) = [x, 1 - x]
# Flatten every 28x28 image into a row vector of 28*28 pixel values.
x_train_array = np.reshape(x_train, (-1, 28 * 28))

def local_feature_map(X):
    """Map each pixel value x to the 2-dimensional local feature [x, 1 - x].

    Args:
        X: array-like of pixel values supporting arithmetic (in this notebook a
           NumPy array of shape (num_images, num_pixels) with values in [0, 1]).

    Returns:
        A tf.Tensor with X's shape plus a trailing axis of size 2 holding
        (x, 1 - x) for every entry of X.
    """
    # axis=-1 appends the feature axis last for inputs of any rank; for the 2-D
    # input used in this notebook it is identical to the previous axis=2.
    return tf.stack([X, 1 - X], axis=-1)

# Featurize only the images actually used for training: featurizing the whole
# training set first and slicing afterwards builds a much larger tensor than needed.
# x_train_featured therefore now holds just the first num_train images.
x_train_featured = local_feature_map(x_train_array[:num_train])

x_train_batch = x_train_featured
y_train_batch = y_train[:num_train]

In [24]:
# MPS initialization
def construct_MPS(bond_dim, feature, rank, label_size, label_site=None, dtype=np.float32):
    """Build a random matrix product state (MPS) and connect its bond edges.

    Layout: one node per pixel site plus one label node L inserted at `label_site`::

                                  label_size
                                      |
                 bond_dim             |
        A1 --------- A2 ---------- L --------- A3 --------- An
        |            |                         |            |
      feature     feature                   feature      feature

    Args:
        bond_dim: bond (virtual) dimension between neighbouring nodes.
        feature: physical dimension of every pixel site (size of the local feature map).
        rank: number of pixel sites; the MPS has ``rank + 1`` nodes in total.
        label_size: dimension of the label leg (number of classes).
        label_site: index of the label node; defaults to ``rank // 2`` (the middle).
        dtype: dtype of the random MPS tensors; defaults to ``np.float32``.

    Returns:
        (nodes, connected_bonds): the ``rank + 1`` ``tn.Node`` objects named
        ``block_{i}`` (axis 0 = left bond, axis 1 = pixel/label leg,
        axis 2 = right bond) and the list of bond edges created between them.

    Raises:
        ValueError: if ``label_site`` is outside ``[0, rank]``.
    """
    if label_site is None:
        label_site = rank // 2
    if not 0 <= label_site <= rank:
        raise ValueError(f"label_site must be in [0, {rank}], got {label_site}")

    # Physical dimensions: `feature` for every pixel site, `label_size` at the label site.
    physical_dims = [feature] * label_site + [label_size] + [feature] * (rank - label_site)
    mps = tn.FiniteMPS.random(
        d=physical_dims,
        D=[bond_dim] * rank,  # one bond between each neighbouring pair of the rank + 1 nodes
        dtype=dtype,
    )

    nodes = [tn.Node(tensor, f'block_{i}') for i, tensor in enumerate(mps.tensors)]

    # k = -1 joins the right edge of the last node to the left edge of the first node.
    # These are the open boundary legs of the finite MPS (presumably of dimension 1 —
    # TODO confirm), so closing them leaves only the pixel and label legs dangling.
    connected_bonds = [nodes[k].edges[2] ^ nodes[k + 1].edges[0] for k in range(-1, rank)]

    return nodes, connected_bonds

# MPS hyper-parameters: bond dimension, per-pixel feature dimension,
# number of pixel sites (one per pixel of a 28x28 image) and number of labels.
bond_dim, feature, rank, label_size = 1, 2, 28**2, 10

nodes, bonds = construct_MPS(
    bond_dim=bond_dim,
    feature=feature,
    rank=rank,
    label_size=label_size,
)

# Expect rank pixel nodes plus the single label node.
print(len(nodes))

785


In [25]:
# Smoke test: contract the MPS with a dummy all-ones "image".
# Every pixel site is capped with a constant feature vector, so only the label leg stays
# open and the result is a vector of length `label_size` (one score per class).
# The MPS nodes are replicated first, so this cell neither connects nor contracts the
# original `nodes` from the construction cell and can be re-run safely.

label_site = rank // 2  # must match the label position used when building the MPS
start_time = time.perf_counter()

mps_copy = tn.replicate_nodes(nodes)
# Every node except the label node carries a pixel leg on axis 1.
pixel_sites = mps_copy[:label_site] + mps_copy[label_site + 1:]
pix_nodes = [tn.Node(np.ones(feature), f'pix_{i}') for i in range(rank)]
pix_bonds = [site.edges[1] ^ pix.edges[0] for site, pix in zip(pixel_sites, pix_nodes)]

# NOTE(review): np.ones defaults to float64, which upcasts the float32 MPS here. The
# recorded output (~1e-134) is far below the float32 range, so a pure-float32 contraction
# would likely underflow to zero — worth handling (e.g. normalisation) before training.
rel = tn.contractors.auto(mps_copy + pix_nodes, ignore_edge_order=True)
assert rel.tensor.shape == (label_size,), rel.tensor.shape
print(rel.tensor)
print(f"Runtime so far:         {time.perf_counter() - start_time:.2f} sec\n")

[-3.06950929e-134  1.79473303e-133 -4.93674168e-134  1.26817885e-134
  6.43899148e-134  1.14777735e-133 -1.82832990e-133  7.97770991e-134
  9.06059191e-134 -1.26766471e-133]
Runtime so far:         0 sec



In [None]:
# TODO auto Grad

class mpsTrain(tf.Module):
    """Placeholder for a trainable MPS classifier (training via autograd is TODO).

    The original `def __init__(self, )` had no colon or body, which is a SyntaxError
    that stopped "Run All"; this stub is valid and only forwards the optional name.
    """

    def __init__(self, name=None):
        super().__init__(name=name)

In [None]:
# The cells from here on are scratch space for testing and playing around.

# random_sample takes the output shape as a single tuple; passing (2, 2, 2) as three
# positional arguments raises a TypeError.
random_blocks = [np.random.random_sample((2, 2, 2)) for _ in range(1, rank - 2)]
len(random_blocks)  # show the count rather than dumping every array

In [None]:
# Scratch: contract two 2-vectors along their only edge (an inner product).
# The names are literal strings; f'block_{a}' read `a` before it was assigned.
a = tn.Node(np.ones(2), 'block_a')
b = tn.Node(np.ones(2), 'block_b')
print(a.edges)
bond = a[0] ^ b[0]
c = tn.contract(bond)
print(a.edges)
print(c.tensor)  # ones(2) . ones(2) = 2.0

In [None]:
# Scratch: print the values produced by range(2), one per line.
for i in [*range(2)]:
    print(i)