# Add JaxMao 'library' to PYTHONPATH

**Will remove this section layer when I know how to add package to actual python's path.**

In [1]:
import sys
sys.path.append("/home/jaxmao/JaxMao")

# Setup
- Import packages
- Import JaxMao
- Set seed and key
- Import and prepare data (MNIST)

##### Import Jax, MNIST dataset, utils functions and set seed

In [2]:
import jax.numpy as jnp
from jax import jit, value_and_grad
from jax import random

from keras import datasets
from keras.utils import to_categorical
from sklearn.utils import shuffle

seed = 42
key = random.PRNGKey(seed)



##### Import JaxMao functions

In [3]:
# Model
from jaxmao.Modules import Module
from jaxmao.Layers import FC, Conv2D, Flatten
from jaxmao.Activations import ReLU, StableSoftmax

# Training
from jaxmao.Optimizers import GradientDescent
from jaxmao.Losses import MeanSquaredError

##### Import and prepare MNIST dataset

In [4]:
(X_train, y_train), (X_test, y_test) = datasets.mnist.load_data()

X_train = jnp.array(X_train/255., jnp.float32).reshape(-1, 1, 28, 28)
X_test = jnp.array(X_test/255., jnp.float32).reshape(-1, 1, 28, 28)

y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

# Build our MNIST Classifier

We will build simple Convolution-FC-FC model.

### Define model

In [42]:
class MNIST_Classifier(Module):
    def __init__(self):
        self.conv1 = Conv2D(1, 32, 3, 2) 
        self.flatten = Flatten()
        self.fc1 = FC(32*14*14, 32)
        self.fc2 = FC(32, 10)
        self.relu = ReLU()
        self.softmax = StableSoftmax()
        
    def __call__(self, x):
        x = self.relu(self.conv1(x))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.softmax(self.fc2(x))
        return x

### Initialize the model and analyze output behavior
- output behavior: shape, value, etc.

In [43]:
clf = MNIST_Classifier()
clf.init_params(key)

**Predict without and training:** Let see what our model's output look like <br>
- The output shape should be (num_inputs, num_classes), where num_classes = 10.
- Since our last layer is softmax, sum of each output should equal to one.

In [44]:
x = X_train[:20]
output = clf(x)
print("output shape: ", output.shape)
print()
print("output sum:\n", output.sum(axis=1))

output shape:  (20, 10)

output sum:
 [1.         1.0000001  1.         0.99999994 1.         0.9999999
 1.         1.         1.0000001  1.         0.99999994 1.0000001
 1.0000001  0.99999994 1.         1.         1.0000001  0.99999994
 1.         1.        ]


### Define loss function and jax.grad(loss)

In [45]:
loss = MeanSquaredError()

def loss_params(params, x, y):
    pred = clf.forward(params, x)
    return loss(pred, y)

grad_loss = jit(value_and_grad(loss_params))

### Define optimizer and training loop

##### Gradient Descent on 50 data points.

In [46]:
optimizer = GradientDescent()

def training_loop(epochs=20, lr=0.01):
    losses, gradients = grad_loss(clf.params, X_train[:50], y_train[:50])
    clf.params = optimizer(clf.params, gradients, lr=lr)
    return clf.params

##### Stochastic Gradient Descent training loop
Usually, taking gradient on all data points is not possible. <br>
We cannot put our entire data into the memory all at once. <br> 
<br>
**Stochastic Gradient Descent:**

In [47]:
def training_loop(x, y, epochs=20, lr=0.01, batch_size=32):
    num_batches = len(x) // batch_size
    
    for epoch in range(epochs):
        x, y = shuffle(x, y)
        for batch_idx in range(num_batches):
            starting_idx = batch_idx * batch_size
            ending_idx = (batch_idx + 1) * batch_size
            batch_x = x[starting_idx:ending_idx]
            batch_y = y[starting_idx:ending_idx]
            
            losses, gradients = grad_loss(clf.params, batch_x, batch_y)
            clf.params = optimizer(clf.params, gradients, lr=lr)
        if (epoch+1) % 5 == 0:
            print("Epoch: {}\tLoss: {}".format(epoch+1, losses))
    
    return clf.params

In [48]:
params = training_loop(
    X_train, y_train, 
    epochs=25, lr=0.05, batch_size=128
    )

params = training_loop(
    X_train, y_train, 
    epochs=25, lr=0.01, batch_size=128
    )

Epoch: 5	Loss: 0.007728424854576588
Epoch: 10	Loss: 0.005018957424908876
Epoch: 15	Loss: 0.010880048386752605
Epoch: 20	Loss: 0.0028994835447520018
Epoch: 25	Loss: 0.0032447841949760914
Epoch: 5	Loss: 0.002887452719733119
Epoch: 10	Loss: 0.0009570828406140208
Epoch: 15	Loss: 0.0031617318745702505
Epoch: 20	Loss: 0.005867414176464081
Epoch: 25	Loss: 0.0021143059711903334


In [67]:
import numpy as np

s = 20
n = np.random.randint(0, len(X_test)-s)
output = clf.forward(clf.params, X_test[n:n+s])

print("\tPredicted: ", output.argmax(axis=1))
print("\tActual   : ", y_test[n:n+s].argmax(axis=1), )
print("\tAccuracy : ", (y_test[n:n+s].argmax(axis=1) == output.argmax(axis=1)).sum() / s)

	Predicted:  [3 0 6 0 2 7 6 4 1 2 8 8 7 7 9 7 7 3 7 9]
	Actual   :  [3 0 6 0 2 7 6 6 1 2 8 8 7 7 4 7 7 3 7 4]
	Accuracy :  0.85


In [68]:
from sklearn.metrics import accuracy_score

s = 4000
n = np.random.randint(0, len(X_test)-s)
print("Accuracy : {}".format(
    accuracy_score(clf(X_test[n:n+s]).argmax(axis=1), y_test[n:n+s].argmax(axis=1))
))

Accuracy : 0.73625
