In [1]:
import sys
sys.path.append('/home/jaxmao/jaxmao_branches/JaxMao/')
sys.path.append('/home/jaxmao/jaxmao_branches/JaxMao/jaxmao')

import jax
from jaxmao.modules import Module
from jaxmao.layers import Dense
from jaxmao.losses import CategoricalCrossEntropy
from jaxmao.optimizers import GradientDescent
from jaxmao.utils_struct import _check_dict_ids

from sklearn.datasets import load_digits

In [2]:
images, targets = load_digits(return_X_y=True)

images = images / images.max()
targets_enc = jax.nn.one_hot(targets, num_classes=10)

I0000 00:00:1698215200.508901   37071 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


In [3]:
class Classifier(Module):
    def __init__(self):
        super().__init__()
        self.dense1 = Dense(64, 32, activation='relu', batch_norm=True) # Dense apply fc-bn-activ
        self.dense2 = Dense(32, 10, activation='softmax', batch_norm=True)

    def pure_forward(self, params, x, state):
        x, state = self.forward(self.dense1, params, x, state)
        x, state = self.forward(self.dense2, params, x, state)
        return x, state
    
    """ equivalent
    def pure_forward(self, params, x, state):
        x, state['dense1'] = self.dense1.forward(params['dense1'], x, state['dense1'])
        x, state['dense2'] = self.dense2.forward(params['dense2'], x, state['dense2'])
        return x, state
    """

seed = 4
key = jax.random.key(seed)

clf = Classifier()
clf.init_params(key)

training loop

In [4]:
clf.switch_mode('train')

In [5]:
loss_fn = CategoricalCrossEntropy(reduce_fn='mean_over_batch_size')
optimizer = GradientDescent(lr=0.01, params=clf.params)

def loss_fn_wrapped(method, params, x, y, state):
    y_pred, new_state = method(params, x, state)
    loss = loss_fn(y_pred, y)
    return loss, new_state
loss_and_grad = jax.value_and_grad(loss_fn_wrapped, argnums=1, has_aux=True)

EPOCHS = 5
BATCH_SIZE = 128    
NUM_BATCHES = len(images) // BATCH_SIZE
for epoch in range(EPOCHS):
    total_losses = 0.0
    for n in range(BATCH_SIZE): 
        (loss, new_state), gradients = loss_and_grad(clf.pure_forward, clf.params, images, targets_enc, clf.state)
        new_params, optimizer.state = optimizer.step(clf.params, gradients, optimizer.state)            
        clf.update_params(new_params)
        clf.update_state(new_state)
        total_losses += loss
    print('epoch: {} - avg_loss: {} '.format(epoch+1, total_losses/NUM_BATCHES))

epoch: 1 - avg_loss: 15.675487518310547 
epoch: 2 - avg_loss: 9.428189277648926 
epoch: 3 - avg_loss: 7.113495826721191 
epoch: 4 - avg_loss: 5.7458367347717285 
epoch: 5 - avg_loss: 4.820437908172607 


result

In [7]:
clf.switch_mode('inference')

In [8]:
from sklearn.metrics import accuracy_score, precision_score, recall_score

y_pred = clf(images).argmax(axis=1)
accuracy = accuracy_score(targets, y_pred)
precision = precision_score(targets, y_pred, average='macro')
recall = recall_score(targets, y_pred, average='macro')
print('Accuracy : {:<.6f}'.format(accuracy))
print('Precision: {:<.6f}'.format(precision))
print('Recall   : {:<.6f}'.format(recall))

Accuracy : 0.940456
Precision: 0.940922
Recall   : 0.940144


ids are consistent. There are no copies of same items.

In [9]:
print('params dense1: ', _check_dict_ids(clf.params['dense1'], clf.layers['dense1'].params))
print('params dense2: ', _check_dict_ids(clf.params['dense2'], clf.layers['dense2'].params))

print('state dense1: ', _check_dict_ids(clf.state['dense1'], clf.layers['dense1'].state))
print('state dense2: ', _check_dict_ids(clf.state['dense2'], clf.layers['dense2'].state))

params dense1:  True
params dense2:  True
state dense1:  True
state dense2:  True


# compare to Keras

In [10]:
from tensorflow import keras
targets_enc = keras.utils.to_categorical(targets)

2023-10-25 13:26:55.048276: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-25 13:26:55.048336: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-25 13:26:55.048412: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [11]:
model = keras.Sequential([
            keras.layers.Dense(32, use_bias=False),
            keras.layers.BatchNormalization(),
            keras.layers.Activation('relu'),
            keras.layers.Dense(10, use_bias=False),
            keras.layers.BatchNormalization(),
            keras.layers.Activation('softmax'),
        ])

2023-10-25 13:26:57.359137: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [12]:
model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=0.01, momentum=0),
    loss=keras.losses.CategoricalCrossentropy(reduction='sum')
)

model.fit(
        images, targets_enc,
        epochs=EPOCHS, batch_size=BATCH_SIZE, shuffle=False
    )

Epoch 1/5


Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.src.callbacks.History at 0x7fc4571405d0>

In [13]:
y_pred = model.predict(images).argmax(axis=1)
targets = targets
accuracy = accuracy_score(targets, y_pred)
precision = precision_score(targets, y_pred, average='macro')
recall = recall_score(targets, y_pred, average='macro')
print('Accuracy : {:<.6f}'.format(accuracy))
print('Precision: {:<.6f}'.format(precision))
print('Recall   : {:<.6f}'.format(recall))

 1/57 [..............................] - ETA: 5s

Accuracy : 0.937117
Precision: 0.952420
Recall   : 0.936503
