# MNIST classification with tensor networks (TNs)

- 14x14 images, pixel values scaled to $[0, 1]$, zig-zag pixel ordering
- embedding = trigonometric
- sweeping strategy (claimed: 2–3 sweeps are enough)
- quadratic cost
- bond_dim = 10, 20, 120
- MPS (with output dim = 10)

**Imports**

In [None]:
import numpy as np
import jax.numpy as jnp
import jax
import optax  # used directly below (optax.adam, schedules, gradient transforms)
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from jax.nn.initializers import *

# NOTE: keep the star-import order — later star imports shadow names from earlier ones.
from tn4ml.initializers import *
from tn4ml.models.smpo import *
from tn4ml.models.model import *
from tn4ml.embeddings import *
from tn4ml.loss import *
from tn4ml.strategy import *
from tn4ml.util import *

**Load dataset**

In [None]:
# Load MNIST through Keras. Each of `train` / `test` is an (images, labels) pair:
# index 0 holds the images and index 1 the integer labels (see the `data` dict below).
train, test = mnist.load_data()

In [None]:
# Group images ("X") and labels ("y") by split for uniform access: data[kind][split].
data = {
    "X": {"train": train[0], "test": test[0]},
    "y": {"train": train[1], "test": test[1]},
}

In [None]:
# Reduce the image resolution with max pooling over non-overlapping windows.
# strides = (2, 2) -> 14x14 images; strides = (4, 4) -> 7x7 images.
strides = (2, 2)
# The pooling window equals the stride, so every input pixel lands in exactly one window.
# (A (2, 2) window with (4, 4) strides would silently ignore 3/4 of the pixels.)
pool_size = strides
pool = tf.keras.layers.MaxPooling2D(pool_size=pool_size, strides=strides, padding="same")

In [None]:
def _pool_and_scale(images):
    """Max-pool a stack of grayscale images and scale pixel values to [0, 1].

    Args:
        images: uint8 array indexed as ``images[n, row, col]`` (raw MNIST pixels).

    Returns:
        Array of shape ``(n, out_rows, out_cols)``. The output size follows from
        the `pool` layer's strides, so it is not hard-coded here.
    """
    # Add the trailing channel axis that MaxPooling2D expects, then drop it again.
    # Unlike a fixed reshape(-1, 14, 14), this stays correct for any stride choice.
    pooled = pool(tf.constant(images[..., None])).numpy()
    return pooled[..., 0] / 255.0


X_pooled = _pool_and_scale(data['X']['train'])
X_pooled_test = _pool_and_scale(data['X']['test'])

In [None]:
# Rearrange pixels in zig-zag order (from https://arxiv.org/pdf/1605.05775.pdf)

def zigzag_order(data):
    """Flatten each image into a 1-D sequence of pixels.

    Each image is read row by row, left to right within a row (row-major order).
    This is the pixel ordering this notebook calls "zig-zag".

    Args:
        data: array-like of images, indexed as ``data[n][row][col]``.

    Returns:
        A new ``np.ndarray`` of shape ``(n_images, n_rows * n_cols)``. It keeps the
        input dtype and does not share memory with ``data``.
    """
    # np.array copies, so the result is independent of the input (like the old loop).
    images = np.array(data)
    # An explicit pixel count (instead of -1) keeps an empty batch well-defined.
    n_pixels = int(np.prod(images.shape[1:], dtype=int))
    return images.reshape(images.shape[0], n_pixels)

In [None]:
# Flatten the pooled train and test images into pixel sequences.
train_data, test_data = map(zigzag_order, (X_pooled, X_pooled_test))

In [None]:
n_classes = 10  # MNIST digits 0-9; also the model's output dimension (see phys_dim)

In [None]:
# One-hot encode the integer labels of both splits into n_classes columns.
y_train, y_test = (
    integer_to_one_hot(data['y'][split], n_classes) for split in ('train', 'test')
)

**Take samples for training, validation and testing**

In [None]:
# Sample budget: train_size is split into training + validation in the next cell;
# test_size samples are drawn from the MNIST test set.
train_size, test_size, val_perc = 6000, 10000, 0.2

In [None]:
# Hold out a fraction `val_perc` of the `train_size` samples for validation
# (e.g. 20% of 6000 = 1200); the remainder is used for training.
# round() rather than int(): a product like 0.29 * 100 == 28.999... must not be truncated to 28.
val_size = round(val_perc * train_size)
train_size = train_size - val_size

In [None]:
# Show the resulting split sizes as the cell output.
val_size, train_size

In [None]:
# Randomly pick disjoint training and validation samples from the MNIST training set.
# A fixed seed makes the split reproducible (the model initialization also uses seed 42).
if train_size + val_size > len(train_data):
    raise ValueError(
        f"requested {train_size} train + {val_size} validation samples, "
        f"but only {len(train_data)} are available"
    )
split_rng = np.random.default_rng(42)
indices = split_rng.permutation(len(train_data))

train_indices = indices[:train_size]
val_indices = indices[train_size : train_size + val_size]

# train data and validation inputs
train_inputs = np.take(train_data, train_indices, axis=0)
val_inputs = np.take(train_data, val_indices, axis=0)

# train data and validation labels (same indices, so inputs and labels stay aligned)
train_targets = np.take(y_train, train_indices, axis=0)
val_targets = np.take(y_train, val_indices, axis=0)

In [None]:
# Randomly pick `test_size` samples from the MNIST test set (seeded for reproducibility).
if test_size > len(test_data):
    raise ValueError(
        f"requested {test_size} test samples, but only {len(test_data)} are available"
    )
test_rng = np.random.default_rng(42)
indices = test_rng.permutation(len(test_data))

test_indices = indices[:test_size]

# test inputs
test_inputs = np.take(test_data, test_indices, axis=0)

# test labels (same indices, so inputs and labels stay aligned)
test_targets = np.take(y_test, test_indices, axis=0)

**Training setup**
- direct gradient descent (`strategy = 'global'`)

In [None]:
# model parameters
# Number of sites = pixels per flattened image (14*14 = 196), derived from the data
# so that a different pooling resolution cannot silently mismatch the model.
L = train_data.shape[1]
initializer = jax.nn.initializers.normal(0.5)  # normal initializer, stddev 0.5
key = jax.random.key(42)  # fixed PRNG key for reproducible initialization
shape_method = 'noteven'  # tn4ml bond-shape option — TODO confirm exact meaning
bond_dim = 10
phys_dim = (2, n_classes)  # presumably (embedding dim per pixel, output dim) — TODO confirm
spacing = L  # NOTE(review): spacing == L presumably gives a single output site — confirm
canonical_center = 0

In [None]:
# Build the (S)MPO classifier from the model parameters defined above.
# NOTE(review): cyclic=False presumably means open boundary conditions — confirm in tn4ml.
model = SMPO_initialize(L=L,
                        initializer=initializer,
                        key=key,
                        shape_method=shape_method,
                        spacing=spacing,
                        bond_dim=bond_dim,
                        phys_dim=phys_dim,
                        cyclic=False,
                        canonical_center=canonical_center)

In [None]:
def MSE_loss(*args, **kwargs):
    """Squared-error loss reduced to a single mean value.

    Wraps ``optax.squared_error`` with tn4ml's ``loss_wrapper_optax`` and
    averages the wrapped loss's output with ``.mean()``. All arguments are
    forwarded unchanged to the wrapped loss.

    NOTE(review): not referenced in this notebook — training uses ``loss = MSE``.
    """
    wrapped_squared_error = loss_wrapper_optax(optax.squared_error)
    unreduced = wrapped_squared_error(*args, **kwargs)
    return unreduced.mean()

**Note: when using a cross-entropy loss, the per-sample losses must be averaged (`reduce_mean`).**

In [None]:
# training parameters
optimizer = optax.adam
strategy = 'global'
loss = MSE
train_type = 1
embedding = trigonometric()  # alternative that was tried: original_inverse(p=3)
learning_rate = 5e-5  # fixed learning rate for the first training phase

# Second-phase schedule: starts at 1e-4 and decays exponentially by a factor of
# 0.01 every 1000 steps (see optax.exponential_decay).
scheduler = optax.exponential_decay(init_value=1e-4, transition_steps=1000, decay_rate=0.01)

# Gradient transforms for the second phase, in application order
# (presumably combined with optax.chain by model.configure — confirm).
clip_transform = optax.clip_by_global_norm(1.0)      # clip the gradient by its global norm
adam_transform = optax.scale_by_adam()               # Adam update direction
schedule_transform = optax.scale_by_schedule(scheduler)  # scheduled learning rate
# optax.apply_updates adds the updates, so negate them to descend on the loss.
descend_transform = optax.scale(-1.0)
gradient_transforms = [clip_transform, adam_transform, schedule_transform, descend_transform]

In [None]:
# Phase 1: plain Adam (`optax.adam`) with the fixed `learning_rate`.
# NOTE(review): presumably configure() builds the optimizer as optimizer(learning_rate) — confirm in tn4ml.
model.configure(optimizer=optimizer, strategy=strategy, loss=loss, train_type=train_type, learning_rate=learning_rate)

In [None]:
# Phase-1 training length and mini-batch size.
epochs, batch_size = 100, 256

In [None]:
# Phase 1 training on the sampled train/validation sets, with inputs in float64.
# NOTE(review): `normalize` and `cache` are tn4ml options whose semantics are not visible here — confirm.
history = model.train(train_inputs,
                    targets = train_targets,
                    val_inputs = val_inputs,
                    val_targets = val_targets,
                    epochs = epochs,
                    batch_size = batch_size,
                    embedding = embedding,
                    normalize = True,
                    cache=True,
                    dtype = jnp.float64)

In [None]:
# Phase 2: switch to the clip + Adam + exponential-decay transform list.
# NOTE(review): assumes configure() keeps the previously set strategy/loss/train_type — confirm in tn4ml.
model.configure(gradient_transforms=gradient_transforms)

In [None]:
epochs = 280  # number of epochs for the second (scheduled) training phase

In [None]:
# Phase 2 training: same data and options as phase 1; only the optimizer setup changed.
# NOTE(review): `history` is overwritten here, and the plot below reads `model.history`,
# presumably accumulated across both phases — confirm.
history = model.train(train_inputs,
                    targets = train_targets,
                    val_inputs = val_inputs,
                    val_targets = val_targets,
                    epochs = epochs,
                    batch_size = batch_size,
                    embedding = embedding,
                    normalize = True,
                    cache=True,
                    dtype = jnp.float64)

In [None]:
import matplotlib.pyplot as plt

# Plot the training and validation loss curves recorded in model.history.
for history_key, curve_label in (('loss', 'train'), ('val_loss', 'validation')):
    curve = model.history[history_key]
    plt.plot(range(len(curve)), curve, label=curve_label)
plt.legend()
plt.show()

In [None]:
# save the model
# model.save('model', 'tests/mnist_supervised_model6')

**Evaluate**

In [None]:
from tn4ml.models.model import _batch_iterator

In [None]:
batch_size = 64
# Accumulate raw counts and divide once at the end. This gives the exact accuracy
# even when the last batch is smaller than `batch_size` (10000 is not a multiple of 64).
correct_predictions = 0
n_evaluated = 0

for x, y in _batch_iterator(test_inputs, test_targets, batch_size=batch_size):
    x = jnp.array(x, dtype=jnp.float64)
    y = jnp.array(y, dtype=jnp.float64)

    # Vectorize predict over the batch axis of x; embedding and the flag are shared.
    outputs = jax.vmap(model.predict, in_axes=(0, None, None))(x, embedding, False)
    y_pred = jnp.squeeze(jnp.array(outputs[0]))
    predicted = jnp.argmax(y_pred, axis=-1)
    true = jnp.argmax(y, axis=-1)  # targets are one-hot

    correct_predictions += int(jnp.sum(predicted == true))
    n_evaluated += x.shape[0]

accuracy = correct_predictions / n_evaluated if n_evaluated else float("nan")
print(f"Accuracy: {accuracy}")