In [1]:
import sys

# Make the local JaxMao checkout importable from this notebook.
# NOTE(review): this is an absolute, machine-specific path; adjust it to
# wherever the JaxMao repository lives on your machine.
JAXMAO_ROOT = "/home/jaxmao/JaxMao"

# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
if JAXMAO_ROOT not in sys.path:
    sys.path.append(JAXMAO_ROOT)
print(sys.path)

['/home/jaxmao/JaxMao/Example', '/usr/lib/python310.zip', '/usr/lib/python3.10', '/usr/lib/python3.10/lib-dynload', '', '/home/jaxmao/.local/lib/python3.10/site-packages', '/usr/local/lib/python3.10/dist-packages', '/usr/lib/python3/dist-packages', '/home/jaxmao/JaxMao']


In [2]:
from keras import datasets
from keras.utils import to_categorical
from jax import random, grad
import jax.numpy as jnp

from jaxmao.Modules import Module
from jaxmao.Layers import FC, Conv2D, Flatten
from jaxmao.Activations import ReLU, StableSoftmax

# Fixed integer seed so the PRNG key below is identical on every run.
seed = 42
# Root JAX PRNG key derived from the seed; handed to clf.init_params later on.
key = random.PRNGKey(seed)

2023-08-13 19:51:15.994132: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-08-13 19:51:16.039941: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-08-13 19:51:16.041051: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [30]:
# Load MNIST, scale the pixel values by 1/255 and reshape every image to
# (1, 28, 28) so that the arrays have shape (N, 1, 28, 28) in float32.
(X_train, y_train), (X_test, y_test) = datasets.mnist.load_data()
X_train = jnp.reshape(X_train/255., (-1, 1, 28, 28)).astype(jnp.float32)
X_test = jnp.reshape(X_test/255., (-1, 1, 28, 28)).astype(jnp.float32)
# One-hot encode the labels. The width is fixed explicitly to the 10 digit
# classes instead of being inferred from the largest label present, so the
# train and test encodings are guaranteed to have the same number of columns.
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)

In [64]:
from jax.tree_util import tree_flatten, tree_unflatten, tree_structure


class MNIST_Classifier(Module):
    """Small convolutional classifier for MNIST digits.

    Pipeline: Conv2D -> ReLU -> Flatten -> FC -> ReLU -> FC -> softmax.

    The ``params`` sequence given to ``forward`` / ``_forward`` is indexed per
    layer: entry 0 is passed to ``conv1``, entry 2 to ``fc1`` and entry 3 to
    ``fc2``. Entry 1 is never read here.
    NOTE(review): entry 1 presumably belongs to the parameter-free ``flatten``
    layer (layers registered in attribute order) - confirm against ``Module``.
    """

    def __init__(self):
        # 1 input channel -> 32 feature maps with a 3x3 kernel.
        self.conv1 = Conv2D(1, 32, 3)
        self.flatten = Flatten()
        # fc1 expects 32*28*28 inputs, i.e. the flattened conv output is
        # assumed to keep the 28x28 spatial size - TODO confirm Conv2D padding.
        self.fc1 = FC(32*28*28, 32)
        self.fc2 = FC(32, 10)
        self.relu = ReLU()
        self.softmax = StableSoftmax()

    def _forward(self, params, x):
        """Run the network using each layer's ``_forward`` with explicit params.

        Mirrors ``forward`` step for step; the only difference is that the
        parametrised layers are invoked through ``_forward``.
        NOTE(review): how a layer's ``_forward`` differs from its ``forward``
        is defined in jaxmao and not visible in this notebook.

        Args:
            params: per-layer parameter sequence (entries 0, 2 and 3 are used).
            x: batch of input images.

        Returns:
            The softmax output of the last layer.
        """
        # The convolution output goes through ReLU, exactly as in forward().
        # (It used to be fed back into self.conv1 by mistake.)
        x = self.relu(self.conv1._forward(params[0], x))
        x = self.flatten(x)
        x = self.relu(self.fc1._forward(params[2], x))
        x = self.softmax(self.fc2._forward(params[3], x))
        return x

    def forward(self, params, x):
        """Run the network with explicitly supplied parameters.

        Args:
            params: per-layer parameter sequence (entries 0, 2 and 3 are used).
            x: batch of input images.

        Returns:
            The softmax output of the last layer.
        """
        x = self.relu(self.conv1.forward(params[0], x))
        x = self.flatten(x)
        x = self.relu(self.fc1.forward(params[2], x))
        x = self.softmax(self.fc2.forward(params[3], x))
        return x

    def predict(self, x):
        """Run the network on ``x`` using the parameters stored on the model.

        Delegates to ``forward`` with ``self.params`` so that prediction always
        uses the same layer pipeline as training. ``self.params`` must have
        been populated first (e.g. via ``init_params``).
        """
        return self.forward(self.params, x)

In [37]:
# Sanity check: shape of a 20-image slice of the prepared training images.
X_train[:20].shape

(20, 1, 28, 28)

In [65]:
# Build the classifier and initialise its parameters from the PRNG key.
clf = MNIST_Classifier()
clf.init_params(key)

In [66]:
# Shape of the weights in the first params entry (the entry forward() passes to conv1).
clf.params[0]['weights'].shape

(32, 1, 3, 3)

In [67]:
# Forward pass on 20 training images; the last layer is a softmax, so every
# output row is expected to sum to (approximately) 1.
output = clf.forward(clf.params, X_train[:20])
output.shape, output.sum(axis=1)

((20, 10),
 Array([1.        , 1.        , 1.        , 0.9999999 , 1.0000001 ,
        1.        , 0.99999994, 1.        , 1.0000001 , 1.0000001 ,
        1.        , 0.9999999 , 1.        , 1.        , 1.0000001 ,
        1.        , 1.        , 1.        , 0.99999994, 0.99999994],      dtype=float32))

In [68]:
# val, structure = tree_flatten(clf.params)
# tree_unflatten(structure, val)

In [69]:
from jax import value_and_grad, jit, vmap

In [71]:
from jaxmao.Optimizers import GradientDescent
from jaxmao.Losses import MeanSquaredError
from sklearn.utils import shuffle

# Loss and optimizer instances shared by loss_params and training_loop below.
# NOTE(review): mean squared error is applied to softmax outputs here; a
# cross-entropy loss is the more common choice for classification - confirm
# this is intentional.
loss = MeanSquaredError()
optimizer = GradientDescent()

def loss_params(params, x, y):
    """Evaluate the loss of ``clf`` on one batch for the given parameters.

    Args:
        params: model parameters forwarded to ``clf.forward``.
        x: batch of inputs.
        y: targets matching ``x``.

    Returns:
        ``loss(prediction, y)`` where the prediction is ``clf.forward(params, x)``.
    """
    return loss(clf.forward(params, x), y)


# Jit-compiled function returning (loss value, gradients); value_and_grad
# differentiates with respect to the first argument, i.e. ``params``.
grad_loss = jit(value_and_grad(loss_params))

# def training_loop(epochs=20, lr=0.01):
#     losses, gradients = grad_loss(clf.params, X_train[:50], y_train[:50])
#     # return losses, gradients
#     clf.params = optimizer(clf.params, gradients, lr=lr)

def training_loop(x, y, epochs=20, lr=0.01, batch_size=32, random_state=None):
    """Train the module-level ``clf`` with mini-batch gradient descent.

    Every epoch reshuffles ``(x, y)`` and walks over the full batches in order;
    a trailing partial batch (``len(x) % batch_size`` samples) is skipped.
    ``clf.params`` is updated in place after each batch. On every 5th epoch the
    loss of the last processed batch, evaluated with the already updated
    parameters, is printed.

    Args:
        x: training inputs, indexable along the first axis.
        y: training targets aligned with ``x``.
        epochs: number of passes over the data.
        lr: learning rate handed to the optimizer.
        batch_size: number of samples per batch; must be positive.
        random_state: optional integer seed for reproducible shuffling. Epoch
            ``e`` shuffles with ``random_state + e``. ``None`` (default) keeps
            the previous behaviour of using NumPy's global random state.

    Returns:
        ``clf.params`` after training (the same object that is stored on ``clf``).

    Raises:
        ValueError: if ``batch_size`` is not positive, or if ``x`` holds fewer
            than ``batch_size`` samples so that no batch could be formed.
    """
    if batch_size <= 0:
        raise ValueError(
            "batch_size must be a positive integer, got {}".format(batch_size))

    num_batches = len(x) // batch_size
    if num_batches == 0:
        # Without this guard nothing would be trained and the logging line
        # below would fail with a NameError on the undefined batch variables.
        raise ValueError(
            "need at least batch_size={} samples to form a batch, got {}".format(
                batch_size, len(x)))

    for epoch in range(epochs):
        # A different but deterministic seed per epoch when random_state is set.
        epoch_state = None if random_state is None else random_state + epoch
        x, y = shuffle(x, y, random_state=epoch_state)
        for batch_idx in range(num_batches):
            starting_idx = batch_idx * batch_size
            ending_idx = starting_idx + batch_size
            batch_x = x[starting_idx:ending_idx]
            batch_y = y[starting_idx:ending_idx]

            # Only the gradients are needed for the update step.
            _, gradients = grad_loss(clf.params, batch_x, batch_y)
            clf.params = optimizer(clf.params, gradients, lr=lr)
        if (epoch+1) % 5 == 0:
            print("Epoch: {}\tLoss: {}".format(epoch+1, loss_params(clf.params, batch_x, batch_y)))

    return clf.params

In [73]:
# Shape of the forward-pass output for a 20-image batch.
clf.forward(clf.params, X_train[:20]).shape

(20, 10)

In [74]:
# from jax.tree_util import tree_leaves

# s = tree_structure(clf.layers[0].params)
# v = tree_leaves(clf.layers[0].params)
# tree_unflatten(s, v)

In [75]:
# Train on the full training set. training_loop only prints on every 5th
# epoch, so with epochs=21 the final epoch runs but is not reported.
params = training_loop(
    X_train, y_train, 
    epochs=21, lr=0.05, batch_size=128
    )

Epoch: 5	Loss: 0.0037350631318986416
Epoch: 10	Loss: 0.0028351321816444397
Epoch: 15	Loss: 0.0029883235692977905
Epoch: 20	Loss: 0.0038615725934505463


In [77]:
# Compare predicted and true digits on the first 20 training images.
output = clf.forward(clf.params, X_train[:20])
predicted_digits = output.argmax(axis=1)
true_digits = y_train[:20].argmax(axis=1)
is_correct = true_digits == predicted_digits
# (output shape, predictions, ground truth, per-sample hit mask, number of hits)
(output.shape, predicted_digits,
 true_digits,
 is_correct.astype(int),
 is_correct.sum())

((20, 10),
 Array([3, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 1, 3, 6, 1, 9, 2, 8, 6, 7], dtype=int32),
 array([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9]),
 Array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0], dtype=int32),
 Array(16, dtype=int32))

In [84]:
from sklearn.metrics import accuracy_score
# accuracy_score takes (y_true, y_pred): ground-truth labels first, predictions second.
# NOTE(review): this scores a slice of the *training* data; X_test / y_test
# are loaded above but not evaluated in this notebook.
accuracy_score(y_train[4000:8000].argmax(axis=1), clf(X_train[4000:8000]).argmax(axis=1))

0.68325

In [79]:
# training_loop returns clf.params itself, so both names refer to one object.
params is clf.params

True