In [None]:
#Note: Code doesn't work as expected
import tensorflow as tf
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import layers, models
from tensornetwork.matrixproductstates.finite_mps import FiniteMPS
from tensorflow.keras.preprocessing.image import smart_resize

# Load and preprocess data
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add a channel dimension to the grayscale images
x_train = np.expand_dims(x_train, axis=-1)  # Shape becomes (60000, 28, 28, 1)
x_test = np.expand_dims(x_test, axis=-1)    # Shape becomes (10000, 28, 28, 1)

# Resize to 28x28 instead of 28x28 for all samples
x_train = np.array([smart_resize(img, (28, 28)) for img in x_train])
x_test = np.array([smart_resize(img, (28, 28)) for img in x_test])

# Normalize the pixel values to the range [0, 1]
x_train = x_train / 255.0
x_test = x_test / 255.0

# Reshape to (batch_size, 28, 28)
x_train = x_train.reshape(-1, 28, 28)
x_test = x_test.reshape(-1, 28, 28)

# One-hot encode the labels
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# Define the TensorNetwork model
def create_tn_model(input_shape, bond_dim, output_dim):
    # Initialize a random FiniteMPS
    mps = FiniteMPS.random(d=[input_shape[1]] * input_shape[0], 
                           D=[bond_dim] * (input_shape[0] - 1), 
                           dtype=np.float32, 
                           canonicalize=True,
                           backend='numpy')
    
    # Define a custom layer to use the MPS
    class MPSLayer(tf.keras.layers.Layer):
        def __init__(self, mps, output_dim):
            super(MPSLayer, self).__init__()
            self.mps = mps
            self.output_dim = output_dim
            
        def call(self, inputs):
            # Directly operate on the TensorFlow tensors
            batch_size = tf.shape(inputs)[0]
    
            results = tf.reduce_sum(inputs, axis=[1, 2])
            results = tf.reshape(results, [batch_size, 1])
            return results

        def compute_output_shape(self, input_shape):
            return (input_shape[0], 1)

    # Build the model
    inputs = tf.keras.Input(shape=input_shape)
    x = MPSLayer(mps, output_dim)(inputs)
    
    # Reshape to match the expected input shape of the dense layer
    x = tf.keras.layers.Flatten()(x)  # Flatten to (batch_size, 1)
    
    # Add a Dense layer with the correct input size
    outputs = layers.Dense(output_dim, activation='softmax')(x)
    
    model = models.Model(inputs=inputs, outputs=outputs)
    return model

# Parameters
input_shape = (28, 28)  # 28 segments of size 28
bond_dim = 10
output_dim = 10

# Create, compile, and train the model
model = create_tn_model(input_shape, bond_dim, output_dim)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

# Train the model
model.fit(x_train, y_train, epochs=50, batch_size=100, validation_data=(x_test, y_test))

# Evaluate the model
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'Test Accuracy: {test_acc * 100:.2f}%')


Epoch 1/50
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step - accuracy: 0.0892 - loss: 33.2178 - val_accuracy: 0.0998 - val_loss: 2.4771
Epoch 2/50
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 836us/step - accuracy: 0.1091 - loss: 2.3051 - val_accuracy: 0.1137 - val_loss: 2.2701
Epoch 3/50
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step - accuracy: 0.1244 - loss: 2.2608 - val_accuracy: 0.1129 - val_loss: 2.2449
Epoch 4/50
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 910us/step - accuracy: 0.1606 - loss: 2.2315 - val_accuracy: 0.2007 - val_loss: 2.2165
Epoch 5/50
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 875us/step - accuracy: 0.2080 - loss: 2.2059 - val_accuracy: 0.2183 - val_loss: 2.1915
Epoch 6/50
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step - accuracy: 0.2147 - loss: 2.1864 - val_accuracy: 0.2152 - val_loss: 2.1774
Epoch 7/50
[1m600/60