In [2]:
import sys
sys.path.append("/home/jaxmao/JaxMao")
sys.path

['/home/jaxmao/JaxMao/Example',
 '/usr/lib/python310.zip',
 '/usr/lib/python3.10',
 '/usr/lib/python3.10/lib-dynload',
 '',
 '/home/jaxmao/.local/lib/python3.10/site-packages',
 '/usr/local/lib/python3.10/dist-packages',
 '/usr/lib/python3/dist-packages',
 '/home/jaxmao/JaxMao']

# Setup
- Import packages
- Import JaxMao
- Set seed and key
- Import and prepare data (MNIST)

##### Import Jax, MNIST dataset, utils functions and set seed

In [13]:
import jax.numpy as jnp
from jax import jit, value_and_grad
from jax import random

from tensorflow.keras import datasets
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils import shuffle

seed = 42
key = random.PRNGKey(seed)

##### Import JaxMao functions

In [14]:
# Model
from jaxmao.Modules import Module
from jaxmao.Layers import FC, Conv2D, Flatten
from jaxmao.Activations import ReLU, StableSoftmax

# Training
from jaxmao.Optimizers import GradientDescent
from jaxmao.Losses import CategoricalCrossEntropy

In [18]:
y_train

array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

##### Import and prepare MNIST dataset

In [28]:
(X_train, y_train), (X_test, y_test) = datasets.mnist.load_data()

X_train = jnp.array(X_train/255., jnp.float32).reshape(-1, 1, 28, 28)
X_test = jnp.array(X_test/255., jnp.float32).reshape(-1, 1, 28, 28)

onehot_encoder = OneHotEncoder()
y_train = onehot_encoder.fit_transform(y_train.reshape(-1,1))
y_test = onehot_encoder.transform(y_test.reshape(-1,1))

# Build our MNIST Classifier

We will build simple Convolution-FC-FC model.

### Define model

In [29]:
class MNIST_Classifier(Module):
    def __init__(self):
        self.conv1 = Conv2D(1, 32, 3, 2) 
        self.flatten = Flatten()
        self.fc1 = FC(32*14*14, 32)
        self.fc2 = FC(32, 10)
        self.relu = ReLU()
        self.softmax = StableSoftmax()
        
    def __call__(self, x):
        x = self.relu(self.conv1(x))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.softmax(self.fc2(x))
        return x

### Initialize the model and analyze output behavior
- output behavior: shape, value, etc.

In [30]:
clf = MNIST_Classifier()
clf.init_params(key)

**Predict without and training:** Let see what our model's output look like <br>
- The output shape should be (num_inputs, num_classes), where num_classes = 10.
- Since our last layer is softmax, sum of each output should be equaled to one.

In [31]:
x = X_train[:20]
output = clf(x)
print("output shape: ", output.shape)
print()
print("output sum:\n", output.sum(axis=1))

output shape:  (20, 10)

output sum:
 [1.         1.         0.9999999  0.9999999  1.         1.
 0.99999994 1.0000001  1.         1.         0.9999999  1.
 1.         1.0000001  0.99999994 0.99999994 1.         0.9999999
 0.9999999  1.        ]


### Define loss function and jax.grad(loss)

In [32]:
loss = CategoricalCrossEntropy()

def loss_params(params, x, y):
    pred = clf.forward(params, x)
    return loss(pred, y)

grad_loss = jit(value_and_grad(loss_params))

##### Stochastic Gradient Descent training loop
Usually, taking gradient on all data points is not possible. <br>
We cannot put our entire data into the memory all at once. <br> 
<br>
**Stochastic Gradient Descent:**

```Psuedo code
def training_loop(x, y, epochs, learning_rate, batch_size):
    foreach epoch until reach epochs:
        x, y = shuffle(x, y)
        foreach batch until loop through entire dataset (minus the leftover):
            loss, gradient = grad_loss(model.params, batch_x, batch_y)
            model.params = optimizer(model.params, gradients, learning_rate)
    return model.params
```

In [57]:
from jax.experimental import sparse

optimizer = GradientDescent()
def training_loop(x, y, epochs=20, lr=0.01, batch_size=32):
    num_batches = len(x) // batch_size
    
    for epoch in range(epochs):
        x, y = shuffle(x, y)
        for batch_idx in range(num_batches):
            starting_idx = batch_idx * batch_size
            ending_idx = (batch_idx + 1) * batch_size
            batch_x = x[starting_idx:ending_idx]
            batch_y = y[starting_idx:ending_idx].toarray()
            # batch_y = sparse.BCOO.from_scipy_sparse(y[starting_idx:ending_idx])
            
            losses, gradients = grad_loss(clf.params, batch_x, batch_y)
            clf.params = optimizer(clf.params, gradients, lr=lr)
        if (epoch+1) % 5 == 0:
            print("Epoch: {}\tbatch loss: {}".format(epoch+1, losses/batch_size))
    
    return clf.params

In [59]:
params = training_loop(
    X_train, y_train, 
        epochs=15, lr=0.0001, batch_size=128
    )

Epoch: 5	batch loss: 0.22568993270397186
Epoch: 10	batch loss: 0.20340317487716675
Epoch: 15	batch loss: 0.057628169655799866


In [69]:
import numpy as np

s = 20
n = np.random.randint(0, len(X_test)-s)
output = clf.forward(clf.params, X_test[n:n+s])

print("\tPredicted: ", output.argmax(axis=1))
print("\tActual   : ", y_test[n:n+s].argmax(axis=1).ravel(), )
print("\tAccuracy : ", (y_test[n:n+s].toarray().argmax(axis=1) == output.argmax(axis=1)).sum() / s)

	Predicted:  [6 6 4 3 9 8 3 0 1 9 0 5 4 1 9 1 2 7 0 1]
	Actual   :  [[6 6 4 3 8 8 3 0 1 9 0 5 4 1 9 1 2 7 0 1]]
	Accuracy :  0.95


In [72]:
from sklearn.metrics import accuracy_score

s = 4000
n = np.random.randint(0, len(X_test)-s)
print("Accuracy : {}".format(
    accuracy_score(clf(X_test[n:n+s]).argmax(axis=1), y_test[n:n+s].toarray().argmax(axis=1))
))

Accuracy : 0.9635
