MNIST Keras example: [Simple MNIST convnet](https://keras.io/examples/vision/mnist_convnet/)

In [1]:
import os
# Pick the Keras backend before `import keras` runs in the next cell.
# "jax" is used here; swap the value for "torch" to compare backends.
os.environ.update({"KERAS_BACKEND": "jax"})


In [2]:
import numpy as np
import keras
from keras import layers


In [3]:
# Problem constants: number of target classes and the per-image input shape
num_classes = 10
input_shape = (28, 28, 1)

# Fetch MNIST, already partitioned into training and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Convert pixels to float32 scaled into [0, 1], then append a trailing
# channel axis so every image has shape (28, 28, 1)
x_train = (x_train.astype("float32") / 255)[..., np.newaxis]
x_test = (x_test.astype("float32") / 255)[..., np.newaxis]

# Report the resulting array shape and the size of each split
print("x_train shape:", x_train.shape)
print(len(x_train), "train samples")
print(len(x_test), "test samples")

# One-hot encode the class-index label vectors (one column per class)
y_train, y_test = (
    keras.utils.to_categorical(labels, num_classes) for labels in (y_train, y_test)
)


x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples


In [4]:
# Seed Python's `random`, NumPy and the backend RNG before any weights are created.
# Without this, weight initialisation (and the randomness used later during training)
# differs on every run, so the reported metrics cannot be reproduced with
# Restart Kernel -> Run All.
keras.utils.set_random_seed(42)

# Small convnet classifier: two Conv2D + MaxPooling2D stages, then a
# dropout-regularised softmax layer with one output per class.
model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        # Convolutional stage 1: 32 filters of size 3x3, then 2x2 max-pooling
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        # Convolutional stage 2: 64 filters of size 3x3, then 2x2 max-pooling
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        # Classifier head: flatten the feature maps, drop half the activations
        # during training, and emit class probabilities
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)

# Print the layer-by-layer output shapes and parameter counts
model.summary()


In [5]:
# Training hyper-parameters
batch_size = 128
epochs = 10

# Categorical cross-entropy matches the one-hot encoded labels produced by
# keras.utils.to_categorical earlier in the notebook.
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

# Keep the returned History object instead of discarding it: the per-epoch
# loss/accuracy values stay available in `history.history`, and the cell no
# longer ends with a bare `<History at 0x...>` repr.
# validation_split=0.1 reserves 10% of the training arrays for validation.
history = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.1,
)


Epoch 1/10
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 13ms/step - accuracy: 0.8886 - loss: 0.3627 - val_accuracy: 0.9777 - val_loss: 0.0809
Epoch 2/10
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 12ms/step - accuracy: 0.9661 - loss: 0.1086 - val_accuracy: 0.9843 - val_loss: 0.0561
Epoch 3/10
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 12ms/step - accuracy: 0.9740 - loss: 0.0834 - val_accuracy: 0.9863 - val_loss: 0.0451
Epoch 4/10
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 12ms/step - accuracy: 0.9787 - loss: 0.0674 - val_accuracy: 0.9877 - val_loss: 0.0417
Epoch 5/10
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 12ms/step - accuracy: 0.9812 - loss: 0.0595 - val_accuracy: 0.9887 - val_loss: 0.0413
Epoch 6/10
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 12ms/step - accuracy: 0.9829 - loss: 0.0557 - val_accuracy: 0.9882 - val_loss: 0.0408
Epoch 7/10
[1m422/422

<keras.src.callbacks.history.History at 0x11efa6de0>

In [6]:
# Score the trained model on the test split; with the compiled metrics the
# result holds the loss first and the accuracy second.
score = model.evaluate(x_test, y_test, verbose=0)
test_loss, test_accuracy = score[0], score[1]

print("Test loss:", test_loss)
print("Test accuracy:", test_accuracy)


Test loss: 0.027485515922307968
Test accuracy: 0.9909999966621399


Measured training time per step:

- JAX: 12 ms/step
- Torch: 20 ms/step (??)
- Torch on Colab T4: 10 ms/step
- JAX on Colab T4: 2 ms/step

In [7]:
# Pull the first variable out of the model's first layer and show its raw
# backend value (presumably the convolution kernel of the first Conv2D — verify).
first_layer = model.layers[0]
first_variable = first_layer.variables[0]
w = first_variable.value
w

Array([[[[ 0.13831805, -0.08997356, -0.41121262,  0.02935524,
          -0.27025893,  0.11473007,  0.07417739, -0.06986379,
          -0.03312434, -0.58960044,  0.20869508,  0.16039714,
          -0.15204097,  0.34677064,  0.13520537,  0.01782749,
           0.05888194,  0.11134432,  0.07271089, -0.27778688,
           0.28234974, -0.08048576, -0.03002048, -0.3107495 ,
           0.02368851, -0.31540358,  0.21255097,  0.14913292,
           0.12414618,  0.0819942 , -0.40066648,  0.10615517]],

        [[ 0.01240036, -0.01661164, -0.2378345 , -0.38389987,
           0.14609525,  0.3245014 , -0.30482882,  0.1586864 ,
          -0.10974918, -0.25236815,  0.12905763, -0.0764987 ,
          -0.5518677 ,  0.21232215,  0.13719097,  0.01094102,
          -0.03674712,  0.28724822,  0.02440302, -0.5832697 ,
          -0.20290904, -0.02966259, -0.3770637 , -0.2856497 ,
           0.17963928, -0.13059008,  0.1870545 ,  0.16800481,
          -0.07926358, -0.06054571,  0.07833687,  0.1517742 ]],

  

In [8]:
# Show which concrete array class backs the variable's value; with the "jax"
# backend configured at the top of the notebook this reports a jaxlib array type.
type(w)

jaxlib._jax.ArrayImpl