In [42]:
import os
# Pin both frameworks to the CPU. These variables are deliberately set before
# jax / tensorflow are imported below; the device listings printed by this
# cell show a single CPU device for each framework.
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import jax
import jax.numpy as jnp
import numpy as np

import jaxmao
from jaxmao.layers import Conv2D, SimpleDense, Dense, BatchNorm, ReLU, Flatten, StableSoftmax, BatchNorm2D, DepthwiseConv2D, Activation
from jaxmao.modules import Module
from jaxmao.optimizers import GradientDescent
from jaxmao.losses import CategoricalCrossEntropy
from jaxmao.metrics import Accuracy, Precision, Recall

# Sanity check: which devices JAX ended up with.
print('jax.devices() :', jax.devices())

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential
# Sanity check: which devices TensorFlow ended up with.
print('tf.config.list_physical_devices(): ', tf.config.list_physical_devices())

# Base seed and the JAX PRNG key derived from it; `key` is what the JaxMao
# model is initialised with further down in the notebook.
seed = 42
key = jax.random.PRNGKey(seed)

# Keras default float type; the training data in this notebook is cast to float32 as well.
tf.keras.backend.set_floatx('float32')

jax.devices() : [CpuDevice(id=0)]
tf.config.list_physical_devices():  [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]


In [43]:
def close_enough(A, B, eps=1e-5):
    """Element-wise check that ``A`` and ``B`` differ by at most ``eps``.

    Args:
        A: First array-like operand.
        B: Second array-like operand (must broadcast against ``A``).
        eps: Absolute tolerance; the bound is inclusive.

    Returns:
        Boolean array (or scalar) that is True where ``|A - B| <= eps``.
    """
    gap = np.abs(A - B)
    return gap <= eps

# jax.grad vs tf.GradientTape

In [44]:
# Define the function f(x) = x^2
def jax_function(x):
    """Return ``x`` squared, computed with ``jax.lax.pow``."""
    exponent = 2
    return jax.lax.pow(x, exponent)

def tf_function(x):
    """Return ``x`` squared, computed with ``tf.pow`` (TensorFlow twin of ``jax_function``)."""
    exponent = 2
    return tf.pow(x, exponent)

# Evaluation points shared by both frameworks.
# Deliberately NOT named `input`: a notebook global with that name would
# shadow the builtin `input()` for every later cell in the kernel session.
grad_inputs = [2.0, 5.4]

# Compute the gradient using JAX
# `jax.grad` is applied to the function and `jax.vmap` maps the resulting
# gradient function over the elements of the input vector.
jax_grad = jax.grad(jax_function)
x_jax = jnp.array(grad_inputs)
grad_jax = jax.vmap(jax_grad)(x_jax)

# Compute the gradient using TensorFlow
# `x_tf` is a (trainable) tf.Variable, which GradientTape watches
# automatically, so no explicit `tape.watch` call is needed.
x_tf = tf.Variable(grad_inputs, dtype=tf.float32)
with tf.GradientTape() as tape:
    y_tf = tf_function(x_tf)
grad_tf = tape.gradient(y_tf, x_tf).numpy()

# Compare
print("JAX Gradient:", grad_jax)
print("TensorFlow Gradient:", grad_tf)
print("Are they close enough?", np.isclose(grad_jax, grad_tf, atol=1e-6))

JAX Gradient: [ 4.  10.8]
TensorFlow Gradient: [ 4.  10.8]
Are they close enough? [ True  True]


# simple module

In [45]:
# Synthetic training set: 200 random 8x8 single-channel "images" and 200
# random integer labels in [0, 10), one-hot encoded into 10 classes.
#
# The data is drawn from a generator seeded with the notebook-level `seed`
# so that Restart Kernel -> Run All reproduces the same arrays. Outputs
# recorded while the data was still unseeded will not match after a re-run.
data_rng = np.random.default_rng(seed)
X_train = data_rng.normal(2, 4, (200, 8, 8, 1)).astype('float32')
# Labels are kept as float32, exactly as before; `high` is exclusive.
y_train = data_rng.integers(0, 10, (200,)).astype('float32')
y_train_enc = np.array(
    jax.nn.one_hot(y_train, num_classes=10)
)

In [46]:
class DenseMNISTClasifier(Module):
    """BatchNorm -> 3x3 Conv -> Flatten -> three Dense layers ending in a 10-way softmax.

    Layers are registered by name in ``__init__`` and applied in the same
    order by ``forward``. The first dense layer takes ``8*8*4`` inputs, which
    matches the 8x8 single-channel inputs used in this notebook.
    """

    def __init__(self):
        super().__init__()
        layers = jaxmao.layers
        # Called once per layer so each layer gets its own initializer instance.
        glorot = jaxmao.initializers.GlorotNormal

        self.add('bn1', layers.BatchNorm2D(1, momentum=0.99, eps=1e-5))
        self.add('conv1', layers.Conv2D(1, 4, (3, 3), (1,1), 'relu', weights_initializer=glorot()))
        self.add('flatten', layers.Flatten())
        self.add('dense1', layers.Dense(8*8*4, 128, 'relu', weights_initializer=glorot()))
        self.add('dense2', layers.Dense(128, 32, 'relu', weights_initializer=glorot()))
        self.add('dense3', layers.Dense(32, 10, 'softmax', weights_initializer=glorot()))

    def forward(self, params, x, state):
        """Apply every registered layer in order.

        Each ``self.apply`` call receives the output and state produced by
        the previous one.

        Returns:
            ``(x, state)`` as returned by the last ``apply`` call.
        """
        for layer_name in ('bn1', 'conv1', 'flatten', 'dense1', 'dense2', 'dense3'):
            x, state = self.apply(params, x, layer_name, state)
        return x, state
    
# JaxMao side: instantiate the model, initialise its parameters from the
# notebook's PRNG key and print a per-layer summary for a batch of 4 inputs.
jaxmao_model = DenseMNISTClasifier()
jaxmao_model.init_params(key)
summary = jaxmao_model.summarize(input_shape=(4, 8, 8, 1))

print('\n\n')
# Keras side: the same layer stack, declared up front and handed to
# Sequential in one go.
keras_layers = [
    keras.layers.BatchNormalization(momentum=0.99, input_shape=(8, 8, 1), epsilon=1e-5),
    keras.layers.Conv2D(filters=4, kernel_size=(3, 3), strides=(1, 1), activation='relu', padding='same'),
    keras.layers.Flatten(),
    keras.layers.Dense(units=128, activation='relu'),
    keras.layers.Dense(units=32, activation='relu'),
    keras.layers.Dense(units=10, activation='softmax'),
]
keras_model = keras.Sequential(keras_layers, name='keras_denseMNIST')

# Print the Keras architecture for side-by-side comparison with the JaxMao summary.
keras_model.summary()

layer                output shape         #'s params           #'s states          
bn1                  (4, 8, 8, 1)         2                    0                   
conv1                (4, 8, 8, 4)         40                   0                   
flatten              (4, 256)             0                    0                   
dense1               (4, 128)             32896                0                   
dense2               (4, 32)              4128                 0                   
dense3               (4, 10)              330                  0                   

total parameters: 37396



Model: "keras_denseMNIST"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 batch_normalization_3 (Bat  (None, 8, 8, 1)           4         
 chNormalization)                                                
                                                                 
 conv2d_3 (Conv2D)          

In [47]:
# Copy the JaxMao parameters into the Keras model, layer by layer, so both
# models start from the same weights.
jaxmao_layers = list(jaxmao_model.layers.values())
# zip() would silently stop at the shorter sequence and leave the remaining
# Keras layers with their own random weights, so fail loudly on a mismatch.
if len(jaxmao_layers) != len(keras_model.layers):
    raise ValueError(
        'layer count mismatch: jaxmao has {}, keras has {}'.format(
            len(jaxmao_layers), len(keras_model.layers)
        )
    )

for jaxmao_layer, keras_layer in zip(jaxmao_layers, keras_model.layers):
    if isinstance(jaxmao_layer, jaxmao.layers.BatchNorm):
        # Passed in the order gamma, beta, running mean, running variance.
        keras_layer.set_weights([
            np.array(jaxmao_layer.params['gamma']),
            np.array(jaxmao_layer.params['beta']),
            np.array(jaxmao_layer.state['running_mean']),
            np.array(jaxmao_layer.state['running_var']),
        ])
    elif isinstance(jaxmao_layer, jaxmao.layers.Conv2D):
        # NOTE(review): relies on the params dict yielding its values in the
        # order Keras expects (presumably kernel, then bias) - confirm.
        keras_layer.set_weights([np.array(value) for value in jaxmao_layer.params['conv2d/simple_conv2d'].values()])
        # TODO if the Conv2D has batchnorm.
    elif isinstance(jaxmao_layer, jaxmao.layers.Flatten):
        # Nothing to copy for Flatten.
        pass
    elif isinstance(jaxmao_layer, jaxmao.layers.Dense):
        # NOTE(review): same ordering assumption as for Conv2D above.
        keras_layer.set_weights([np.array(value) for value in jaxmao_layer.params['dense/simple_dense'].values()])
        # TODO if the Dense has batchnorm.

In [48]:
# Forward the full training set through both models and count how many of
# the output entries agree to within 1e-7.
keras_probs = keras_model(X_train).numpy()
jaxmao_probs = jaxmao_model(jaxmao_model.params, X_train)
close_prediction1 = close_enough(keras_probs, jaxmao_probs, 1e-7)
# (total number of entries, number of entries within tolerance)
np.prod(close_prediction1.shape), close_prediction1.sum()

(2000, 1973)

### loss grad

In [49]:
# JaxMao loss (reduction mode 'mean_over_batch_size') and gradient-descent
# optimizer, referenced through the names imported at the top of the notebook.
loss_fn = CategoricalCrossEntropy('mean_over_batch_size')
optimizer = GradientDescent(lr=0.01, params=jaxmao_model.params)

def _loss_fnx(pure_forward, params, x, y, state):
    """Run a forward pass and evaluate the notebook-level ``loss_fn`` on it.

    Args:
        pure_forward: Callable invoked as ``pure_forward(params, x, state)``
            that returns ``(predictions, new_state)``.
        params: Model parameters passed through to ``pure_forward``.
        x: Input batch.
        y: Targets handed to ``loss_fn`` together with the predictions.
        state: Model state passed through to ``pure_forward``.

    Returns:
        ``(loss, new_state)`` - the loss value plus the state returned by
        the forward pass.
    """
    predictions, updated_state = pure_forward(params, x, state)
    return loss_fn(predictions, y), updated_state

# Differentiate `_loss_fnx` with respect to positional argument 1 (`params`);
# has_aux=True because it returns `(loss, new_state)` rather than a bare loss.
loss_and_grad = jax.value_and_grad(_loss_fnx, has_aux=True, argnums=1)
# Example full-batch call, kept for reference (not executed):
# (loss, new_state), gradients = jax.block_until_ready(
#     loss_and_grad(jaxmao_model.pure_forward, jaxmao_model.params, X_train, y_train_enc, jaxmao_model.state)
# )
# jaxmao model, jaxmao loss_fn
jaxmao_probs = jaxmao_model(jaxmao_model.params, X_train)
loss_fn(jaxmao_probs, y_train_enc)

Array(5.1540365, dtype=float32)

In [50]:
# keras model, jaxmao loss_fn
keras_probs = np.array(keras_model(X_train))
loss_fn(keras_probs, y_train_enc)

Array(5.1540365, dtype=float32)

In [51]:
# keras model, keras loss
keras_cce = keras.losses.CategoricalCrossentropy(reduction='sum_over_batch_size')
keras_cce(y_train_enc, keras_model(X_train)).numpy()

5.154037

In [52]:
# jaxmao model, keras loss
# NOTE(review): the value recorded for this cell differs slightly from the
# other three model/loss combinations - worth investigating.
keras_cce = keras.losses.CategoricalCrossentropy(reduction='sum_over_batch_size')
keras_cce(y_train_enc, jaxmao_model(jaxmao_model.params, X_train)).numpy()

5.1526775

#### JaxMao SGD fit

In [53]:
from sklearn.metrics import accuracy_score

# Training hyper-parameters for the hand-written JaxMao loop.
EPOCHS = 20
BATCH_SIZE = 10
# Whole batches only: a trailing partial batch (if any) is never visited.
NUM_BATCHES = len(X_train) // BATCH_SIZE

for epoch in range(EPOCHS):
    losses = 0.0
    # Batches are taken in order, without shuffling
    # (this is where the training set could be shuffled).
    for num_batch in range(NUM_BATCHES):
        batch_rows = slice(num_batch * BATCH_SIZE, (num_batch + 1) * BATCH_SIZE)
        batch_x, batch_y = X_train[batch_rows], y_train_enc[batch_rows]

        # Loss, new model state and parameter gradients for this batch.
        (loss, new_state), gradients = loss_and_grad(
            jaxmao_model.pure_forward,
            jaxmao_model.params,
            batch_x,
            batch_y,
            jaxmao_model.state,
        )
        # One optimizer step, then store the state returned by the forward pass.
        jaxmao_model.params, optimizer.state = optimizer(jaxmao_model.params, gradients, optimizer.state)
        jaxmao_model.update_state(new_state)

        losses += loss

    # End-of-epoch report: mean batch loss and accuracy over the full training set.
    pred = jaxmao_model(jaxmao_model.params, X_train)
    mean_loss = losses / NUM_BATCHES
    train_accuracy = accuracy_score(y_train, pred.argmax(axis=1))
    print('epoch {}: loss: {}, accuracy: {}'.format(epoch+1, mean_loss, train_accuracy))

epoch 1: loss: 3.0711724758148193, accuracy: 0.315
epoch 2: loss: 2.032979726791382, accuracy: 0.435
epoch 3: loss: 1.7698472738265991, accuracy: 0.565
epoch 4: loss: 1.5275793075561523, accuracy: 0.635
epoch 5: loss: 1.259604573249817, accuracy: 0.74
epoch 6: loss: 1.0200330018997192, accuracy: 0.79
epoch 7: loss: 0.8040340542793274, accuracy: 0.82
epoch 8: loss: 0.6357768774032593, accuracy: 0.85
epoch 9: loss: 0.4631962776184082, accuracy: 0.91
epoch 10: loss: 0.3440026342868805, accuracy: 0.955
epoch 11: loss: 0.2704353332519531, accuracy: 0.965
epoch 12: loss: 0.1963098645210266, accuracy: 0.98
epoch 13: loss: 0.14424698054790497, accuracy: 0.98
epoch 14: loss: 0.10772375762462616, accuracy: 0.99
epoch 15: loss: 0.08619878441095352, accuracy: 0.995
epoch 16: loss: 0.06618763506412506, accuracy: 1.0
epoch 17: loss: 0.05326313525438309, accuracy: 1.0
epoch 18: loss: 0.04275178909301758, accuracy: 1.0
epoch 19: loss: 0.03586540371179581, accuracy: 1.0
epoch 20: loss: 0.02980282902717

#### Keras SGD fit

In [54]:
# Train the Keras twin with the same setup as the JaxMao loop: plain SGD at
# lr=0.01, categorical cross-entropy, no shuffling.
keras_model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=0.01),
    loss=keras.losses.CategoricalCrossentropy(reduction='sum_over_batch_size'),
    metrics=['accuracy'],
)
# Reuse EPOCHS / BATCH_SIZE from the JaxMao training cell instead of repeating
# the literals, so the two runs cannot drift apart.
history = keras_model.fit(X_train, y_train_enc, epochs=EPOCHS, batch_size=BATCH_SIZE, shuffle=False)

Epoch 1/20


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
