# Classification of a Linearly Separable Dataset

Data: the `linearly-separable` dataset from PennyLane Datasets (https://pennylane.ai/datasets/).

In [None]:
import pennylane as qml
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax.nn.initializers import *
import optax
from tn4ml.embeddings import *
from tn4ml.util import *
from tn4ml.models.model import *
from tn4ml.models.smpo import *
from tn4ml.initializers import *
from tn4ml.loss import *

**Load Dataset**

In [None]:
# Load the PennyLane "linearly-separable" benchmark dataset.
[ds] = qml.data.load("other", name="linearly-separable")

# NOTE: the '2' key presumably selects the input dimension; the training and
# evaluation cells reshape every sample to 2 features, consistent with that.
inputs = np.array(ds.train['2']['inputs'])  # points in 2-dimensional space
labels = np.array(ds.train['2']['labels']).astype(int)  # class label of each point above

test_inputs = np.array(ds.test['2']['inputs'])  # points in 2-dimensional space
test_labels = np.array(ds.test['2']['labels']).astype(int)  # class label of each point above

In [None]:
# Inspect the shape of the test label array (displayed by the notebook).
np.shape(test_labels)

In [None]:
# Min-max scale every feature to the range [0, 1].
# The statistics come from the training inputs only and are reused for the test
# inputs, so the model is evaluated on data with the same scale it was trained on.
# (Test values may therefore fall slightly outside [0, 1].)
feature_min = np.min(inputs, axis=0)
feature_range = np.max(inputs, axis=0) - feature_min
# Guard constant features against division by zero.
feature_range = np.where(feature_range == 0, 1, feature_range)

inputs = (inputs - feature_min) / feature_range
test_inputs = (test_inputs - feature_min) / feature_range

In [None]:
# Scatter plot of the two features, coloured by class.
# Labels are still the raw +1/-1 integers at this point.
points_pos = inputs[labels == 1]
points_neg = inputs[labels == -1]

plt.figure(figsize=(8, 6))
# [:, k] selects feature column k of every point; [k] would select the k-th point.
plt.scatter(points_pos[:, 0], points_pos[:, 1], color='blue', label='Label 1')
plt.scatter(points_neg[:, 0], points_neg[:, 1], color='red', label='Label -1')

# Adding labels and title
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.title('Visualization of Linearly Separable Data')
plt.legend()
plt.show()

In [None]:
# The raw labels are +1/-1. Map them to class indices {0, 1} (-1 -> 0, +1 -> 1)
# before one-hot encoding. A negative value is not a valid class index; e.g. if
# integer_to_one_hot indexes an identity matrix, -1 would alias class 1 and both
# classes would get the same target. (Labels already in {0, 1} are unchanged.)
labels = integer_to_one_hot((labels > 0).astype(int), 2)
test_labels = integer_to_one_hot((test_labels > 0).astype(int), 2)

In [None]:
# Fraction of the labelled training samples held out for validation,
# and the total number of labelled training samples.
val_perc = 0.2
train_size = len(inputs)

In [None]:
# Hold out val_perc of the training samples (from both classes) for validation;
# the remainder is used for training.
val_size = int(val_perc * train_size)
train_size = train_size - val_size

In [None]:
# Shuffle the sample indices with a fixed seed, so that the train/validation split
# (and hence the reported metrics) is reproducible across runs. Then slice off
# two disjoint subsets.
rng = np.random.default_rng(42)
indices = rng.permutation(len(inputs))

train_indices = indices[:train_size]
val_indices = indices[train_size:train_size + val_size]

# train and validation inputs
train_inputs = np.take(inputs, train_indices, axis=0)
val_inputs = np.take(inputs, val_indices, axis=0)

# train and validation targets (one-hot labels)
train_targets = np.take(labels, train_indices, axis=0)
val_targets = np.take(labels, val_indices, axis=0)

**Define the tensor network (TN) model**

In [None]:
# Hyper-parameters passed to SMPO_initialize.
n_classes = 2
L = 2  # presumably the number of tensor sites - TODO confirm against tn4ml docs
spacing = L
bond_dim = 4
phys_dim = (2, n_classes)
shape_method = 'noteven'

# Seeded random initialisation of the tensor entries (normal, stddev 0.5).
initializer = jax.nn.initializers.normal(0.5)
key = jax.random.key(42)

In [None]:
# Build a non-cyclic (cyclic=False) SMPO from the hyper-parameters defined earlier.
model = SMPO_initialize(
    L=L,
    bond_dim=bond_dim,
    phys_dim=phys_dim,
    spacing=spacing,
    shape_method=shape_method,
    cyclic=False,
    initializer=initializer,
    key=key,
)

In [None]:
def cross_entropy(*args, **kwargs):
    """Softmax cross-entropy loss adapted to the tn4ml training API.

    Wraps ``optax.softmax_cross_entropy`` with ``loss_wrapper_optax`` and forwards
    every positional and keyword argument unchanged.

    Returns:
        The first element of the wrapped loss function's result.
    """
    wrapped_loss = loss_wrapper_optax(optax.softmax_cross_entropy)
    result = wrapped_loss(*args, **kwargs)
    return result[0]

In [None]:
# training parameters
optimizer = optax.adam
strategy = 'global'
loss = cross_entropy
train_type = 1
# Alternative embedding that was tried:
#   basis_quantum_encoding(basis={0: np.array([1, 0]), 1: np.array([0, 1])})
embedding = trigonometric()
learning_rate = 1e-3

# NOTE(review): `scheduler` and `gradient_transforms` below are NOT passed to
# model.configure (which only receives `optimizer` and `learning_rate`), so they
# have no effect on training as written. Their initial rate (1e-4) also differs
# from `learning_rate` (1e-3). To use them, configure would need to accept a
# pre-built transformation such as optax.chain(*gradient_transforms) - TODO
# confirm whether tn4ml supports that.

# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=0.0001,
    transition_steps=1000,
    decay_rate=0.99)

# Combining gradient transforms using `optax.chain`.
gradient_transforms = [
    optax.clip_by_global_norm(1.0),  # Clip the gradient by its global norm.
    optax.scale_by_adam(),  # Use the updates from adam.
    optax.scale_by_schedule(scheduler),  # Use the learning rate from the scheduler.
    # Scale updates by -1 since optax.apply_updates is additive and we want to descend on the loss.
    optax.scale(-1.0)
]

In [None]:
# Attach optimizer, loss and training strategy to the model.
model.configure(
    optimizer=optimizer,
    learning_rate=learning_rate,
    loss=loss,
    strategy=strategy,
    train_type=train_type,
)

In [None]:
# Maximum number of epochs (early stopping may end training sooner) and mini-batch size.
epochs, batch_size = 300, 32

In [None]:
from flax.training.early_stopping import EarlyStopping

# Flax early stopping: halt after `patience` consecutive updates without improvement
# of the tracked metric (presumably validation loss, as fed by model.train).
# With min_delta=0, any decrease counts as an improvement.
earlystop = EarlyStopping(patience=5, min_delta=0)

In [None]:
# Inspect the shape of the training input array (displayed by the notebook).
np.shape(train_inputs)

In [None]:
# Train the model; every sample is passed as a vector of 2 features.
train_x = train_inputs.reshape(len(train_inputs), 2)
val_x = val_inputs.reshape(len(val_inputs), 2)

history = model.train(
    train_x,
    targets=train_targets,
    val_inputs=val_x,
    val_targets=val_targets,
    epochs=epochs,
    batch_size=batch_size,
    embedding=embedding,
    earlystop=earlystop,
    normalize=True,
    cache=True,
    # NOTE(review): JAX silently uses float32 for float64 unless jax_enable_x64 is
    # enabled somewhere - confirm it is.
    dtype=jnp.float64,
)

In [None]:
# Plot the training and validation loss recorded in model.history, one point per entry.
plt.figure()
for history_key, curve_label in (('loss', 'train'), ('val_loss', 'val')):
    curve = model.history[history_key]
    plt.plot(range(len(curve)), curve, label=curve_label)
plt.legend()
plt.show()

**Evaluate on the test set**

In [None]:
from tn4ml.models.model import _batch_iterator

In [None]:
# Test-set accuracy: the fraction of test samples whose predicted class
# (argmax of the model output) matches the true class (argmax of the one-hot label).
# NOTE: test_inputs are expected to be on the same scale as the training inputs.
batch_size = 10
correct_predictions = 0
total_samples = 0

for x, y in _batch_iterator(test_inputs.reshape(test_inputs.shape[0], 2), test_labels, batch_size=batch_size):
    x = jnp.array(x, dtype=jnp.float64)
    y = jnp.array(y)

    # vmap over the samples of the batch; embedding and the False flag are shared.
    y_pred = jnp.squeeze(jnp.array(jax.vmap(model.predict, in_axes=(0, None, None))(x, embedding, False)[0]))
    predicted = jnp.argmax(y_pred, axis=-1)
    true = jnp.argmax(y, axis=-1)

    # Count samples directly, so a final partial batch is weighted correctly
    # (instead of averaging per-batch fractions over the number of full batches).
    correct_predictions += int(jnp.sum(predicted == true))
    total_samples += int(y.shape[0])

accuracy = correct_predictions / total_samples if total_samples else float('nan')
print(f"Accuracy: {accuracy}")