In [None]:
!pip list | grep jax

In [None]:
!pip list | grep keras

In [None]:
import os

# Select JAX as the keras_core backend. keras_core reads KERAS_BACKEND when it is
# imported, which is why this cell comes before the `import keras_core` cell.
os.environ.update(KERAS_BACKEND="jax")

In [None]:
import keras_core as keras
import numpy as np

In [None]:
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

# Rescale pixel values from [0, 255] to float32 in [0, 1] and append a trailing
# channel axis, giving channels-last batches that match the model's (28, 28, 1) input.
x_train = x_train.astype("float32")[..., np.newaxis] / 255
x_test = x_test.astype("float32")[..., np.newaxis] / 255

In [None]:
# Small CNN classifier for 28x28 single-channel images with 10 output classes.
# Three conv stages (32 -> 64 -> 128 filters, 3x3 kernels), each followed by batch
# normalization; the first two stages also downsample with 2x2 max pooling.
# Global average pooling then feeds a dropout-regularized dense head ending in a
# 10-way softmax.
model = keras.Sequential([
    keras.layers.Input(shape=(28, 28, 1)),
    # stage 1
    keras.layers.Conv2D(32, (3, 3), activation="relu"),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D((2, 2)),
    # stage 2
    keras.layers.Conv2D(64, (3, 3), activation="relu"),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D((2, 2)),
    # stage 3 (no pooling: global average pooling follows)
    keras.layers.Conv2D(128, (3, 3), activation="relu"),
    keras.layers.BatchNormalization(),
    # classifier head
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(rate=0.5),
    keras.layers.Dense(units=256, activation="relu"),
    keras.layers.Dropout(rate=0.5),
    keras.layers.Dense(units=10, activation="softmax"),
])

In [None]:
# Print the layer-by-layer architecture with output shapes and parameter counts.
model.summary()

In [None]:
# Adam (lr = 0.001) with sparse categorical cross-entropy on integer class labels;
# accuracy is reported under the short metric name "acc".
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)

In [None]:
# Train on the NumPy arrays; the test split is evaluated after every epoch as
# validation data (it is only monitored, not used for any model selection here).
batch_size = 128
epochs = 15

model.fit(
    x=x_train,
    y=y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(x_test, y_test),
)


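## Dropping in a torch dataset

The same Keras model is now trained on `torch.Tensor` inputs taken from torchvision's FashionMNIST, passed straight to `model.fit`.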

In [None]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import FashionMNIST

# Load the FashionMNIST dataset through torchvision (downloaded to ./data on first run).
# NOTE: `transform` only takes effect when samples are fetched by indexing the dataset
# (e.g. through a DataLoader). The `.data` / `.targets` attributes used below return the
# stored tensors directly, so scaling and channel layout are done by hand instead.
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)

# Stored 0-255 pixel tensors -> float32 in [0, 1], mirroring the NumPy pipeline above.
x_train = train_dataset.data.float() / 255.0
x_test = test_dataset.data.float() / 255.0

# Append a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1), i.e. channels-last,
# which is the layout the model's Input(shape=(28, 28, 1)) expects.
x_train = x_train.unsqueeze(-1)
x_test = x_test.unsqueeze(-1)

y_train = train_dataset.targets
y_test = test_dataset.targets

# Fail fast rather than feeding mis-shaped or misaligned data to the model.
assert x_train.shape[1:] == (28, 28, 1) and x_test.shape[1:] == (28, 28, 1)
assert len(x_train) == len(y_train) and len(x_test) == len(y_test)

print("x_train shape:", x_train.shape, " Train samples: ", x_train.shape[0])
print("x_test shape:", x_test.shape, " Testing samples: ", x_test.shape[0])

In [None]:
# `model` still holds the weights learned by the NumPy-input training run above, and
# compile() by itself does not reset them. Rebuild the same architecture from its config
# so this torch-input run starts from freshly initialized weights; otherwise its metrics
# would just reflect continued training of an already-trained model.
model = keras.Sequential.from_config(model.get_config())

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)

In [None]:
# Same training schedule as the NumPy run, but x/y are now torch tensors, which
# model.fit accepts directly; the test split is monitored as validation data.
batch_size = 128
epochs = 15

model.fit(
    x=x_train,
    y=y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_data=(x_test, y_test),
)


In [None]:
# Write the most recently trained model to `model.keras` in the working directory
# (the `.keras` extension selects Keras's native saving format).
model.save(filepath="model.keras")