In [24]:
import os

# Select JAX as the Keras backend. This lives in its own cell, ahead of the
# `keras_core` import in the next cell, so the setting is in place when the
# library is loaded.
os.environ.update({"KERAS_BACKEND": "jax"})

In [25]:
import numpy as np
import keras_core
from keras_core import layers
from keras_core.utils import to_categorical

In [26]:
# Dataset geometry used by the rest of the notebook: every image is 28x28
# with a single channel, and there are ten target classes.
input_shape = (28, 28, 1)
num_classes = 10

In [27]:
# Load MNIST. `load_data()` yields a (train, test) pair, each of which is an
# (images, labels) tuple; unpack it in two steps for readability.
train_split, test_split = keras_core.datasets.mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [28]:
# Cast the pixel arrays to float32 and rescale them into the [0, 1] range.
# NOTE: this cell rebinds x_train/x_test, so run it exactly once per load.
x_train, x_test = (
    images.astype("float32") / 255 for images in (x_train, x_test)
)

In [29]:
# Give every image an explicit trailing channel axis so the arrays match
# `input_shape` (28, 28, 1). Reshaping to the target shape, rather than
# appending an axis with `np.expand_dims`, keeps this cell idempotent:
# re-running it leaves already-shaped arrays unchanged instead of stacking
# on yet another dimension (which would later break `model.fit`).
x_train = x_train.reshape((-1, *input_shape))
x_test = x_test.reshape((-1, *input_shape))

In [30]:
# Sanity check of the prepared arrays: full training shape, then the number
# of samples in each split (the length of the leading axis).
print("x_train shape:", x_train.shape)
print(len(x_train), "train samples")
print(len(x_test), "test samples")

x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples


In [31]:
# One-hot encode the class labels into `num_classes`-wide binary matrices,
# the format expected by the categorical cross-entropy loss used below.
# NOTE: this cell rebinds y_train/y_test, so run it exactly once per load.
y_train = to_categorical(y_train, num_classes=num_classes)
y_test = to_categorical(y_test, num_classes=num_classes)

In [32]:
# Training hyperparameters: mini-batch size and number of passes over the
# training data.
batch_size, epochs = 128, 3

In [33]:
# Small convolutional classifier: two (3x3 conv + 2x2 max-pool) stages,
# then flatten, dropout, and a softmax layer with one unit per class.
model = keras_core.Sequential(
    [
        layers.Input(shape=input_shape),
        # Stage 1: 32 filters.
        layers.Conv2D(filters=32, kernel_size=3, activation="relu"),
        layers.MaxPooling2D(pool_size=2),
        # Stage 2: 64 filters.
        layers.Conv2D(filters=64, kernel_size=3, activation="relu"),
        layers.MaxPooling2D(pool_size=2),
        # Classification head.
        layers.Flatten(),
        layers.Dropout(rate=0.5),
        layers.Dense(units=num_classes, activation="softmax"),
    ]
)

In [34]:
# Print Keras' summary of the model built in the previous cell.
model.summary()

In [21]:
# Configure training: Adam optimizer, categorical cross-entropy loss (the
# labels were one-hot encoded above), and accuracy as the reported metric.
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

In [22]:
# Train the model, holding out 10% of the training data for validation.
# The call is left as the cell's last expression, so the returned History
# object is displayed as the cell output.
model.fit(
    x=x_train,
    y=y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.1,
)

Epoch 1/3
422/422 ━━━━━━━━━━━━━━━━━━━━ 41s 96ms/step - accuracy: 0.7723 - loss: 0.7488 - val_accuracy: 0.9755 - val_loss: 0.0885
Epoch 2/3
422/422 ━━━━━━━━━━━━━━━━━━━━ 39s 93ms/step - accuracy: 0.9599 - loss: 0.1300 - val_accuracy: 0.9832 - val_loss: 0.0597
Epoch 3/3
422/422 ━━━━━━━━━━━━━━━━━━━━ 40s 94ms/step - accuracy: 0.9711 - loss: 0.0930 - val_accuracy: 0.9875 - val_loss: 0.0470


<keras_core.src.callbacks.history.History at 0x7fbd981fddd0>

In [23]:
# Evaluate on the test split. `score` holds the loss first, followed by the
# metric requested at compile time (accuracy).
score = model.evaluate(x_test, y_test, verbose=0)
test_loss, test_accuracy = score[0], score[1]
print("Test loss:", test_loss)
print("Test accuracy:", test_accuracy)

Test loss: 0.04630282521247864
Test accuracy: 0.9846000075340271
