In [1]:
import jax.numpy as jnp
import tensorwrap as tw
from tensorflow.keras import datasets
from tensorwrap import nn

In [2]:
# Loading the dataset (MNIST handwritten digits).
# `datasets` (tensorflow.keras) is imported together with the other imports at the top.
(X_train, y_train), (X_test, y_test) = datasets.mnist.load_data()



# Preprocessing the data:

In [3]:
# Converting to float32 tensors and scaling by 1/255
# (presumably 8-bit pixel intensities -> [0, 1]; confirm against the raw dtype).
# NOTE: this rebinds X_train / X_test, so re-running this cell without first
# re-running the loading cell rescales the data a second time.
X_train = tw.convert_to_tensor(X_train, tw.float32) / 255.
X_test = tw.convert_to_tensor(X_test, tw.float32) / 255.

In [4]:
# Splitting data between validation and train_set:
# The first N_TRAIN examples are kept for training; the remainder is held out for validation.
N_TRAIN = 55_000
# This cell rebinds X_train / y_train. Re-running it on already-split data would
# silently produce an empty validation set, so fail loudly instead.
assert len(X_train) > N_TRAIN and len(y_train) > N_TRAIN, (
    "X_train/y_train appear to be split already; re-run the loading and "
    "preprocessing cells before re-running this one."
)
X_valid, X_train = X_train[N_TRAIN:], X_train[:N_TRAIN]
y_valid, y_train = y_train[N_TRAIN:], y_train[:N_TRAIN]

In [5]:
# Creating a class names list:
# The data comes from `datasets.mnist` (handwritten digits), so label k is the digit k.
# Fashion-MNIST names ("T-shirt", "Trouser", ...) would mislabel every class here.
class_names = [str(digit) for digit in range(10)]

In [6]:
# Batch features and labels separately. The custom training loop zips the two
# back together batch-by-batch, so both must use the same batch size (128).
X_train_batched = tw.experimental.data.Dataset(X_train).batch(128)
y_train_batched = tw.experimental.data.Dataset(y_train).batch(128)

# Building a Neural Network:

In [7]:
# Subclassing for all features:
class Classifier(nn.Model):
    """Fully-connected 10-way classifier.

    Architecture: Flatten -> Dense(300) -> ReLU -> Dense(100) -> ReLU -> Dense(10).
    No activation follows the last Dense layer, so `call` returns one raw score
    per class (presumably logits; the loss below is built with from_logits=True).
    """

    def __init__(self, name: str = "Classifier") -> None:
        super().__init__(name)
        self.flatten = nn.layers.Flatten()
        # Dense+ReLU hidden blocks followed by the 10-unit output layer. The list
        # is assembled locally and attached once; layer construction order is
        # Dense(300), ReLU, Dense(100), ReLU, Dense(10).
        stack = []
        for width in (300, 100):
            stack.append(nn.layers.Dense(width))
            stack.append(nn.activations.ReLU())
        stack.append(nn.layers.Dense(10))
        self.hidden_layers = stack

    def call(self, params, inputs):
        """Forward pass: flatten `inputs`, then apply each layer in order with `params`."""
        activations = self.flatten(params, inputs)
        for layer in self.hidden_layers:
            activations = layer(params, activations)
        return activations

In [8]:
# Instantiate the classifier and initialise its parameters from a dummy batch.
# NOTE: despite its name, `input_shape` holds the tensor returned by
# tw.randn((1, 28, 28)) -- a sample input, not a shape tuple.
model = Classifier()
input_shape = tw.randn((1, 28, 28))
params = model.init_params(input_shape)

In [9]:
# Defining the loss function and the optimizer (the metric is created in the next cell):
# The classifier applies no activation after its last layer, hence from_logits=True.
loss_fn = nn.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = nn.optimizers.adam(learning_rate=1e-3)
# Initial optimizer state; the custom training loop below reads and rebinds it as a global.
state = optimizer.init(model.trainable_variables)

In [10]:
# Creating the accuracy metric (note: `Accuracy` is taken from `nn.losses` here):
metrics = nn.losses.Accuracy()

In [11]:
# Bundle model, loss, optimizer and metric into the built-in trainer:
train_state = nn.models.Train(model, loss_fn, optimizer, metrics)

In [17]:
# Built-in training loop: 5 epochs (see "Epoch k/5" output); 128 is presumably the batch size -- TODO confirm.
train_state.train(X_train, y_train, 5, 128, validation_data=(X_valid, y_valid))

Epoch 1/5

Epoch 2/5

Epoch 3/5

Epoch 4/5

Epoch 5/5



In [13]:
# Evaluate the built-in trainer on the held-out test set:
train_state.evaluate(X_test, y_test)



# Creating a training loop:

In [14]:
@tw.value_and_grad
def grad_fn(params, X, y):
    """Loss of `model` with `params` on the batch (X, y).

    The decorator makes the call return ``(loss, grads)`` (see `update`),
    presumably with gradients w.r.t. the first argument as in jax.value_and_grad.
    """
    return loss_fn(y, model(params, X))


@tw.function
def update(params, state, X, y):
    """Run one optimizer step on the batch (X, y).

    Returns ``(new_params, loss, new_state)`` -- note the loss is the middle element.
    """
    loss_value, gradients = grad_fn(params, X, y)
    step, new_state = optimizer.update(gradients, state, params)
    new_params = nn.optimizers.apply_updates(params, step)
    return new_params, loss_value, new_state

@tw.function
def validation(params, X, y):
    """Score `model` (with `params`) on (X, y) using the global `metrics` object."""
    return metrics(y, model(params, X))

def train(epochs, validate_every=1):
    """Custom training loop over the pre-batched training set.

    Rebinds `model.trainable_variables` and the global optimizer `state` after
    every batch and drives `model.loading_animation` for progress output.

    Args:
        epochs: number of passes over `X_train_batched` / `y_train_batched`.
        validate_every: recompute the validation metric on the full
            `X_valid` / `y_valid` every this many batches, and always on the
            first and last batch of an epoch. The default of 1 validates after
            every step. Larger values make epochs much cheaper, because a full
            validation pass can cost far more than a single training step.

    Raises:
        ValueError: if `validate_every` is smaller than 1.
    """
    global state
    if validate_every < 1:
        raise ValueError("validate_every must be a positive integer")
    n_batches = X_train_batched.len()  # loop-invariant: query once, not per batch
    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}/{epochs}")
        for index, (X, y) in enumerate(zip(X_train_batched, y_train_batched)):
            model.trainable_variables, losses, state = update(model.trainable_variables, state, X, y)
            accuracy = metrics(y, model(model.trainable_variables, X))
            # index == 0 always validates, so val_metrics is bound before first use.
            if index % validate_every == 0 or index + 1 == n_batches:
                val_metrics = validation(model.trainable_variables, X_valid, y_valid)
            model.loading_animation(n_batches, index + 1, losses, accuracy, val_metric=val_metrics)
        print("\n")

In [15]:
# Run the custom loop; it starts from whatever weights `model.trainable_variables`
# currently holds (presumably those left by the built-in trainer above -- confirm).
train(epochs=10)

Epoch 1/10

Epoch 2/10

Epoch 3/10

Epoch 4/10

Epoch 5/10

Epoch 6/10

Epoch 7/10

Epoch 8/10

Epoch 9/10

Epoch 10/10



In [16]:
# Evaluate the custom-loop model on the test set with the same loss and metric:
model.evaluate(X_test, y_test, loss_fn=loss_fn, metric_fn=metrics)

