In [1]:
pip install tensornetwork

Collecting tensornetwork
[?25l  Downloading https://files.pythonhosted.org/packages/b5/e9/8575b2a9bdf634258b64e581014117f1c8386fd55c2bc84e6624ba747916/tensornetwork-0.4.1-py3-none-any.whl (295kB)
[K     |█                               | 10kB 19.5MB/s eta 0:00:01[K     |██▏                             | 20kB 1.8MB/s eta 0:00:01[K     |███▎                            | 30kB 2.3MB/s eta 0:00:01[K     |████▍                           | 40kB 2.6MB/s eta 0:00:01[K     |█████▌                          | 51kB 2.0MB/s eta 0:00:01[K     |██████▋                         | 61kB 2.3MB/s eta 0:00:01[K     |███████▊                        | 71kB 2.6MB/s eta 0:00:01[K     |████████▉                       | 81kB 2.8MB/s eta 0:00:01[K     |██████████                      | 92kB 3.0MB/s eta 0:00:01[K     |███████████                     | 102kB 2.8MB/s eta 0:00:01[K     |████████████▏                   | 112kB 2.8MB/s eta 0:00:01[K     |█████████████▎                  | 122kB 2.

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
# Import tensornetwork
import tensornetwork as tn

# Seed both RNGs used in this notebook: NumPy draws the initial MPS weight
# noise (MPSLayer._initializer) and TensorFlow drives Keras-side randomness.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Set the TensorNetwork backend to TensorFlow (the default is NumPy).
tn.set_default_backend("tensorflow")

In [19]:

class MPSLayer(tf.keras.layers.Layer):
  """Matrix-product-state (MPS) classifier layer.

  The ``rank`` input sites are split into a left half (sites
  ``[0, rank // 2)``) and a right half (the remaining sites). At every site
  the local feature vector is contracted with a trainable stack of
  ``(bond, bond)`` matrices. The per-site matrices of each half are then
  multiplied in order. The two half-products are closed into a trace
  through one ``(bond, bond)`` matrix per class:

      output[b, o] = Tr(L_b @ W_o @ R_b)

  Tensor-network picture (the two open ends are joined by the trace):

               o
               |
      >-- L -- W -- R --<
          |         |
         x_l       x_r

  Args:
    rank: number of input sites.
    bond: bond dimension of the MPS matrices.
    feature: size of the local feature vector at each site.
    label_size: number of output scores (classes).
    dtype: dtype of the trainable variables and of the layer.
    **kwargs: forwarded to ``tf.keras.layers.Layer``.
  """

  def __init__(self, rank, bond, feature, label_size, dtype=tf.float32,
               **kwargs):
    # Forward dtype so Keras autocasts inputs to the same dtype as the
    # variables; otherwise a non-float32 dtype breaks the einsums.
    super(MPSLayer, self).__init__(dtype=dtype, **kwargs)
    self.label_site = rank // 2
    # Shape (feature, label_site, bond, bond).
    self.l_region = tf.Variable(
        self._initializer(self.label_site, feature, bond),
        dtype=dtype, trainable=True, name="left")
    # The right half holds rank - label_site sites (one more than the left
    # half when rank is odd). Shape (feature, rank - label_site, bond, bond).
    self.r_region = tf.Variable(
        self._initializer(rank - self.label_site, feature, bond),
        dtype=dtype, trainable=True, name="right")
    # One matrix per class: shape (label_size, bond, bond).
    self.output_site = tf.Variable(
        self._initializer(label_size, 1, bond)[0],
        dtype=dtype, trainable=True, name="output")

  @staticmethod
  def _initializer(n_sites, d_phys, d_bond):
    """Return near-identity weights of shape (d_phys, n_sites, d_bond, d_bond).

    Every (d_bond, d_bond) slice is the identity plus N(0, 1e-2) noise.
    Presumably this keeps the long ordered products in ``call`` well scaled
    at the start of training. Works for n_sites == 0.
    """
    w = np.broadcast_to(np.eye(d_bond), (d_phys, n_sites, d_bond, d_bond))
    return w + np.random.normal(0, 1e-2, size=w.shape)

  def call(self, input_x):
    """Contract a batch of inputs with the MPS.

    Args:
      input_x: tensor of shape (batch, rank, feature).

    Returns:
      Tensor of shape (batch, label_size) holding Tr(L @ W_o @ R) per class.
    """
    # Select each site's matrix by its feature vector:
    # (feature, sites, i, j) x (batch, sites, feature) -> (sites, batch, i, j).
    left = tf.einsum("fsij,bsf->sbij", self.l_region,
                     input_x[:, :self.label_site])
    right = tf.einsum("fsij,bsf->sbij", self.r_region,
                      input_x[:, self.label_site:])
    # Ordered product over sites -> (batch, bond, bond).
    left = self.reduction(left)
    right = self.reduction(right)
    # output_site is laid out (label, bond, bond), so the label axis is first.
    return tf.einsum("bij,ojk,bki->bo", left, self.output_site, right)

  @staticmethod
  def reduction(tensor):
    """Return the ordered matrix product over the leading axis.

    Args:
      tensor: shape (n, ..., i, j).

    Returns:
      tensor[0] @ tensor[1] @ ... @ tensor[n-1], shape (..., i, j). For
      n == 0 this is the empty product, i.e. a batch of identities.
    """
    size = int(tensor.shape[0])
    if size == 0:
      return tf.eye(int(tensor.shape[-2]), num_columns=int(tensor.shape[-1]),
                    batch_shape=tf.shape(tensor)[1:-2], dtype=tensor.dtype)
    # Pairwise tree reduction: ~log2(n) rounds of batched matmuls.
    while size > 1:
      half_size = size // 2
      nice_size = 2 * half_size
      # An odd trailing matrix is carried unchanged into the next round.
      leftover = tensor[nice_size:]
      # Neighbours are multiplied as (even @ odd), which preserves the order.
      tensor = tf.matmul(tensor[0:nice_size:2], tensor[1:nice_size:2])
      tensor = tf.concat([tensor, leftover], axis=0)
      size = half_size + size % 2
    return tensor[0]


In [20]:
# Build the model: MPS contraction followed by a softmax over class scores.
bond_dim = 10
feature = 2         # each binarized pixel is one-hot encoded into 2 features
rank = 28 * 28      # one site per pixel of a 28x28 image
label_size = 10     # one score per digit class

mps_model = tf.keras.Sequential([
    MPSLayer(rank=rank, bond=bond_dim, feature=feature,
             label_size=label_size),
    tf.keras.layers.Softmax(),
])

In [21]:
# Data

num_classes = 10


def preprocess_images(x, threshold=0.5):
  """Binarize images and one-hot encode every pixel.

  Args:
    x: array of shape (n_data, ...) with pixel values in [0, 255]; all
      non-batch axes are flattened into a single site axis.
    threshold: a pixel becomes 1 when value / 255 is strictly greater than
      this, otherwise 0. Defaults to 0.5.

  Returns:
    float32 numpy array of shape (n_data, n_sites, 2). [1, 0] encodes a 0
    pixel and [0, 1] encodes a 1 pixel.
  """
  x = np.asarray(x)
  n_data = x.shape[0]
  # Computed explicitly (not -1) so that an empty batch still reshapes.
  n_sites = int(np.prod(x.shape[1:]))
  bits = (x.reshape((n_data, n_sites)) / 255 > threshold).astype(np.intp)
  # Row k of the 2x2 identity is the one-hot code of bit k. This is pure
  # NumPy, so there is no round trip through TensorFlow tensors.
  return np.eye(2, dtype=np.float32)[bits]


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# One-hot encode the integer class labels of both splits.
y_train, y_test = (tf.keras.utils.to_categorical(labels, num_classes)
                   for labels in (y_train, y_test))
# Binarize and one-hot encode the pixels of both splits.
x_train, x_test = map(preprocess_images, (x_train, x_test))

# (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# def preprosessing(X):
#     X = X.astype(np.float32).reshape(-1, 28**2) / 255.0
#     return np.stack([X, 1-X], axis=2)

# x_train, x_test = preprosessing(x_train), preprosessing(x_test)
# y_train = tf.keras.utils.to_categorical(y_train, 10)
# y_test = tf.keras.utils.to_categorical(y_test, 10)

# # train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# # train_ds = train_ds.shuffle(buffer_size=2048).batch(batch_size)
# # test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))

In [22]:
# training
%%time
batch_size = 128
epochs = 30
learning_rate = 1e-4

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

mps_model.compile(
    loss = loss_fn,
    optimizer = optimizer,
    metrics=['accuracy']
)

mps_model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))

score = mps_model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30

KeyboardInterrupt: ignored