In [1]:
!pip install -q keras-core

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [6]:
import numpy as np
import os, pprint, collections

os.environ["KERAS_BACKEND"] = "jax"

# Note that keras_core should only be imported after the backend
# has been configured. The backend cannot be changed once the
# package is imported.
import keras_core as keras

In [7]:
pp = pprint.PrettyPrinter()


In [8]:
import jax
import jax.numpy as jnp
import tensorflow as tf  # just for tf.data
import keras_core as keras  # Keras multi-backend

import numpy as np
from tqdm import tqdm

from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P


In [9]:
""" Dataset
Classic MNIST, loaded using tf.data
"""

BATCH_SIZE = 192

(x_train, train_labels), (
    x_eval,
    eval_labels,
) = keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, axis=-1).astype(
    np.float32
)  # from 28x28 to 28x28 x 1 color channel (B&W)
x_eval = np.expand_dims(x_eval, axis=-1).astype(np.float32)

train_data = tf.data.Dataset.from_tensor_slices((x_train, train_labels))
train_data = train_data.shuffle(5000, reshuffle_each_iteration=True)
train_data = train_data.batch(BATCH_SIZE, drop_remainder=True)
train_data = train_data.repeat()

eval_data = tf.data.Dataset.from_tensor_slices((x_eval, eval_labels))
eval_data = eval_data.batch(10000)  # everything as one batch

STEPS_PER_EPOCH = len(train_labels) // BATCH_SIZE

In [11]:
# Keras "sequential" model building style
def make_backbone():
    return keras.Sequential(
        [
            keras.layers.Rescaling(
                1.0 / 255.0
            ),  # input images are in the range [0, 255]
            keras.layers.Conv2D(
                filters=12, kernel_size=3, padding="same", use_bias=False
            ),
            keras.layers.BatchNormalization(scale=False, center=True),
            keras.layers.Activation("relu"),
            keras.layers.Conv2D(
                filters=24,
                kernel_size=6,
                padding="same",
                use_bias=False,
                strides=2,
            ),
            keras.layers.BatchNormalization(scale=False, center=True),
            keras.layers.Activation("relu"),
            keras.layers.Conv2D(
                filters=32,
                kernel_size=6,
                padding="same",
                use_bias=False,
                strides=2,
                name="large_k",
            ),
            keras.layers.BatchNormalization(scale=False, center=True),
            keras.layers.Activation("relu"),
        ],
        name="backbone",
    )


def make_model():
    input = keras.Input(shape=[28, 28, 1])
    y = make_backbone()(input)
    y = keras.layers.Flatten()(y)
    y = keras.layers.Dense(200, activation="relu")(y)
    y = keras.layers.Dropout(0.4)(y)
    y = keras.layers.Dense(10, activation="softmax")(y)
    model = keras.Model(inputs=input, outputs=y)
    return model

In [12]:
""" JAX-native distribution with a Keras model
For now, you have to write a custom training loop for this
Note: The features required by jax.sharding are not supported by the Colab TPU
runtime at this time, but are available on Cloud TPU VMs and Kaggle TPU VMs.
"""

if len(jax.local_devices()) < 8:
    raise Exception("This part requires 8 devices to run")
else:
    print("\nIdentified local devices:")
    pp.pprint(jax.local_devices())

# ----------------- Keras ---------------------

# instantiate the model
model = make_model()

# learning rate
lr = keras.optimizers.schedules.ExponentialDecay(0.01, STEPS_PER_EPOCH, 0.6)

# optimizer
optimizer = keras.optimizers.Adam(lr)

# initialize all state with .build()
(one_batch, one_batch_labels) = next(iter(train_data))
model.build(one_batch)
optimizer.build(model.trainable_variables)


Identified local devices:
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]


In [13]:
""" Distribution settings

* Sharding the data on the batch axis
* Replicating all model variables

Note: this implements standard "data parallel" distributed training

* Just for show, sharding the largest convolutional kernel along the
  "channels" axis 4-ways and replicating 2-ways

Note: this does not reflect a best practice but is intended to show
      that you can split a very large kernel across multiple devices
      if you have to
"""

print(
    "\nMostly data-parallel distribution. "
    "Data is sharded across devices while the model is replicated. "
    "For demo purposes, we split the largest kernel 4-ways "
    "(and replicate 2-ways since we have 8 devices)."
)

# ------------------ Jax ----------------------

devices = mesh_utils.create_device_mesh((8,))

# data will be split along the batch axis
data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
# naming axes of the sharded partition
data_sharding = NamedSharding(
    data_mesh,
    P(
        "batch",
    ),
)
# all variables will be replicated on all devices
var_mesh = Mesh(devices, axis_names=("_"))
# in NamedSharding, axes that are not mentioned are replicated (all axes here)
var_replication = NamedSharding(var_mesh, P())

# for the demo, we will split the largest kernel 4-ways (and replicate 2-ways since we have 8 devices)
large_kernel_mesh = Mesh(
    devices.reshape((-1, 4)), axis_names=(None, "out_chan")
)  # naming axes of the mesh
large_kernel_sharding = NamedSharding(
    large_kernel_mesh, P(None, None, None, "out_chan")
)  # naming axes of the sharded partition



Mostly data-parallel distribution. Data is sharded across devices while the model is replicated. For demo purposes, we split the largest kernel 4-ways (and replicate 2-ways since we have 8 devices).


In [14]:
# ----------------- Keras ---------------------

# Use Keras APIs to find the variable of a specific layer (we will be sharding this one in a special way)
# In a Conv2D or Dense layer, the variables are 'kernel' and 'bias'
special_layer_var = model.get_layer("backbone").get_layer("large_k").kernel

# ------------------ Jax ----------------------
# - accessing variables in Keras lists model.trainable_variables,
# - model.non_trainable_variables and optimizer.variables

# Apply the distribution settings to the model variables
non_trainable_variables = jax.device_put(
    model.non_trainable_variables, var_replication
)
optimizer_variables = jax.device_put(optimizer.variables, var_replication)
# this is what you would do replicate all trainable variables:
# trainable_variables = jax.device_put(model.trainable_variables, var_replication)

# For the demo, we split the largest kernel 4-ways instead of replicating it.
# We still replicate all other trainable variables as in standard "data-parallel"
# distributed training.
print_once = True
trainable_variables = model.trainable_variables
for i, v in enumerate(trainable_variables):
    if v is special_layer_var:
        # Apply distribution settings: sharding
        sharded_v = jax.device_put(v, large_kernel_sharding)
        trainable_variables[i] = sharded_v

        print("Sharding of convolutional", v.name, v.shape)
        jax.debug.visualize_array_sharding(
            jnp.reshape(sharded_v, [-1, v.shape[-1]])
        )
    else:
        # Apply distribution settings: replication
        replicated_v = jax.device_put(v, var_replication)
        trainable_variables[i] = replicated_v

        if print_once:
            print_once = False
            print(
                "\nSharding of all other model variables (they are replicated)"
            )
            jax.debug.visualize_array_sharding(
                jnp.reshape(replicated_v, [-1, v.shape[-1]])
            )

# collect state in a handy named tuple
TrainingState = collections.namedtuple(
    "TrainingState",
    ["trainable_variables", "non_trainable_variables", "optimizer_variables"],
)
device_train_state = TrainingState(
    trainable_variables=trainable_variables,
    non_trainable_variables=non_trainable_variables,
    optimizer_variables=optimizer_variables,
)
# display data sharding
x, y = next(iter(train_data))
sharded_x = jax.device_put(x.numpy(), data_sharding)
print("Data sharding")
jax.debug.visualize_array_sharding(jnp.reshape(sharded_x, [-1, 28 * 28]))



Sharding of all other model variables (they are replicated)


Sharding of convolutional kernel (6, 6, 24, 32)


Data sharding


In [16]:
# ------------------ Jax ----------------------
# - Using Keras-provided stateless APIs
# - model.stateless_call
# - optimizer.stateless_apply
# These functions also work on other backends.

# define loss
loss = keras.losses.SparseCategoricalCrossentropy()


# This is the loss function that will be differentiated.
# Keras provides a pure functional forward pass: model.stateless_call
def compute_loss(trainable_variables, non_trainable_variables, x, y):
    y_pred, updated_non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss_value = loss(y, y_pred)
    return loss_value, updated_non_trainable_variables


# function to compute gradients
compute_gradients = jax.value_and_grad(compute_loss, has_aux=True)


# Trainig step: Keras provides a pure functional optimizer.stateless_apply
@jax.jit
def train_step(train_state, x, y):
    (loss_value, non_trainable_variables), grads = compute_gradients(
        train_state.trainable_variables,
        train_state.non_trainable_variables,
        x,
        y,
    )

    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        train_state.optimizer_variables, grads, train_state.trainable_variables
    )

    return loss_value, TrainingState(
        trainable_variables, non_trainable_variables, optimizer_variables
    )


# training loop
EPOCHS = 200
print("\nTrainig:")
data_iter = iter(train_data)
for epoch in range(EPOCHS):
    for i in tqdm(range(STEPS_PER_EPOCH)):
        x, y = next(data_iter)
        sharded_x = jax.device_put(x.numpy(), data_sharding)
        loss_value, device_train_state = train_step(
            device_train_state, sharded_x, y.numpy()
        )
    print("Epoch", epoch, "loss:", loss_value)

# The output of the model is still sharded. Sharding follows the data.

data, labels = next(iter(eval_data))
sharded_data = jax.device_put(data.numpy(), data_sharding)


@jax.jit
def predict(data):
    predictions, updated_non_trainable_variables = model.stateless_call(
        device_train_state.trainable_variables,
        device_train_state.non_trainable_variables,
        data,
    )
    return predictions


predictions = predict(sharded_data)
print("\nModel output sharding follows data sharding:")
jax.debug.visualize_array_sharding(predictions)

# Post-processing model state update to write them back into the model
update = lambda variable, value: variable.assign(value)

jax.tree_map(
    update, model.trainable_variables, device_train_state.trainable_variables
)
jax.tree_map(
    update,
    model.non_trainable_variables,
    device_train_state.non_trainable_variables,
)
jax.tree_map(
    update, optimizer.variables, device_train_state.optimizer_variables
)

# check that the model has the new state by running an eval
# known issue: the optimizer should not be required here
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
print("\nUpdating model and running an eval:")
loss, accuracy = model.evaluate(eval_data)
print("The model achieved an evaluation accuracy of:", accuracy)


Trainig:


100%|██████████| 312/312 [00:04<00:00, 64.95it/s] 


Epoch 0 loss: 0.0051613655


100%|██████████| 312/312 [00:02<00:00, 154.30it/s]


Epoch 1 loss: 0.023073694


100%|██████████| 312/312 [00:02<00:00, 152.40it/s]


Epoch 2 loss: 0.01335797


100%|██████████| 312/312 [00:02<00:00, 151.26it/s]


Epoch 3 loss: 0.004434864


100%|██████████| 312/312 [00:02<00:00, 151.67it/s]


Epoch 4 loss: 0.002694505


100%|██████████| 312/312 [00:02<00:00, 151.03it/s]


Epoch 5 loss: 0.011267702


100%|██████████| 312/312 [00:02<00:00, 150.24it/s]


Epoch 6 loss: 0.005084862


100%|██████████| 312/312 [00:02<00:00, 152.90it/s]


Epoch 7 loss: 0.0037868065


100%|██████████| 312/312 [00:02<00:00, 151.00it/s]


Epoch 8 loss: 0.004354721


100%|██████████| 312/312 [00:02<00:00, 148.06it/s]


Epoch 9 loss: 0.0033651057


100%|██████████| 312/312 [00:02<00:00, 147.88it/s]


Epoch 10 loss: 0.007970343


100%|██████████| 312/312 [00:02<00:00, 147.56it/s]


Epoch 11 loss: 0.0022554111


100%|██████████| 312/312 [00:02<00:00, 145.90it/s]


Epoch 12 loss: 0.0025315904


100%|██████████| 312/312 [00:02<00:00, 146.35it/s]


Epoch 13 loss: 0.0027730335


100%|██████████| 312/312 [00:02<00:00, 148.88it/s]


Epoch 14 loss: 0.09231086


100%|██████████| 312/312 [00:02<00:00, 147.05it/s]


Epoch 15 loss: 0.00050703855


100%|██████████| 312/312 [00:02<00:00, 147.97it/s]


Epoch 16 loss: 0.005900763


100%|██████████| 312/312 [00:02<00:00, 150.30it/s]


Epoch 17 loss: 0.0047382293


100%|██████████| 312/312 [00:02<00:00, 150.39it/s]


Epoch 18 loss: 0.0057516964


100%|██████████| 312/312 [00:02<00:00, 149.29it/s]


Epoch 19 loss: 0.000799091


100%|██████████| 312/312 [00:02<00:00, 147.69it/s]


Epoch 20 loss: 0.0020407606


100%|██████████| 312/312 [00:02<00:00, 150.54it/s]


Epoch 21 loss: 0.0017153182


100%|██████████| 312/312 [00:02<00:00, 144.87it/s]


Epoch 22 loss: 0.003606855


100%|██████████| 312/312 [00:02<00:00, 146.99it/s]


Epoch 23 loss: 0.005014763


100%|██████████| 312/312 [00:02<00:00, 145.06it/s]


Epoch 24 loss: 0.0015248026


100%|██████████| 312/312 [00:02<00:00, 148.57it/s]


Epoch 25 loss: 0.0032154913


100%|██████████| 312/312 [00:02<00:00, 151.67it/s]


Epoch 26 loss: 0.002289772


100%|██████████| 312/312 [00:02<00:00, 149.41it/s]


Epoch 27 loss: 0.0077875424


100%|██████████| 312/312 [00:02<00:00, 150.29it/s]


Epoch 28 loss: 0.0018546786


100%|██████████| 312/312 [00:02<00:00, 149.81it/s]


Epoch 29 loss: 0.003362091


100%|██████████| 312/312 [00:02<00:00, 150.59it/s]


Epoch 30 loss: 0.0019750698


100%|██████████| 312/312 [00:02<00:00, 149.48it/s]


Epoch 31 loss: 0.0077667153


100%|██████████| 312/312 [00:02<00:00, 149.61it/s]


Epoch 32 loss: 0.00509238


100%|██████████| 312/312 [00:02<00:00, 152.86it/s]


Epoch 33 loss: 0.09413421


100%|██████████| 312/312 [00:02<00:00, 150.38it/s]


Epoch 34 loss: 0.010569198


100%|██████████| 312/312 [00:02<00:00, 150.15it/s]


Epoch 35 loss: 0.004389434


100%|██████████| 312/312 [00:02<00:00, 150.04it/s]


Epoch 36 loss: 0.0046832967


100%|██████████| 312/312 [00:02<00:00, 150.29it/s]


Epoch 37 loss: 0.0033779717


100%|██████████| 312/312 [00:02<00:00, 149.25it/s]


Epoch 38 loss: 0.0012988928


100%|██████████| 312/312 [00:02<00:00, 152.23it/s]


Epoch 39 loss: 0.00148514


100%|██████████| 312/312 [00:02<00:00, 153.47it/s]


Epoch 40 loss: 0.08799639


100%|██████████| 312/312 [00:02<00:00, 150.90it/s]


Epoch 41 loss: 0.102935396


100%|██████████| 312/312 [00:02<00:00, 154.20it/s]


Epoch 42 loss: 0.0064084753


100%|██████████| 312/312 [00:02<00:00, 154.79it/s]


Epoch 43 loss: 0.0059722695


100%|██████████| 312/312 [00:02<00:00, 150.49it/s]


Epoch 44 loss: 0.003496557


100%|██████████| 312/312 [00:02<00:00, 151.02it/s]


Epoch 45 loss: 0.0066671125


100%|██████████| 312/312 [00:02<00:00, 151.05it/s]


Epoch 46 loss: 0.0010258653


100%|██████████| 312/312 [00:02<00:00, 150.70it/s]


Epoch 47 loss: 0.001148876


100%|██████████| 312/312 [00:02<00:00, 151.08it/s]


Epoch 48 loss: 0.0022844695


100%|██████████| 312/312 [00:02<00:00, 153.12it/s]


Epoch 49 loss: 0.0019500902


100%|██████████| 312/312 [00:02<00:00, 151.60it/s]


Epoch 50 loss: 0.0025554872


100%|██████████| 312/312 [00:02<00:00, 148.82it/s]


Epoch 51 loss: 0.006012059


100%|██████████| 312/312 [00:02<00:00, 152.54it/s]


Epoch 52 loss: 0.006816498


100%|██████████| 312/312 [00:02<00:00, 152.91it/s]


Epoch 53 loss: 0.012626072


100%|██████████| 312/312 [00:02<00:00, 151.17it/s]


Epoch 54 loss: 0.0020623181


100%|██████████| 312/312 [00:02<00:00, 154.11it/s]


Epoch 55 loss: 0.00055154663


100%|██████████| 312/312 [00:02<00:00, 151.58it/s]


Epoch 56 loss: 0.0005175198


100%|██████████| 312/312 [00:02<00:00, 149.14it/s]


Epoch 57 loss: 0.08928703


100%|██████████| 312/312 [00:02<00:00, 149.05it/s]


Epoch 58 loss: 0.0010914998


100%|██████████| 312/312 [00:02<00:00, 149.39it/s]


Epoch 59 loss: 0.0023103198


100%|██████████| 312/312 [00:02<00:00, 149.00it/s]


Epoch 60 loss: 0.00221428


100%|██████████| 312/312 [00:02<00:00, 148.17it/s]


Epoch 61 loss: 0.00079013285


100%|██████████| 312/312 [00:02<00:00, 145.70it/s]


Epoch 62 loss: 0.009109486


100%|██████████| 312/312 [00:02<00:00, 147.87it/s]


Epoch 63 loss: 0.006482146


100%|██████████| 312/312 [00:02<00:00, 147.95it/s]


Epoch 64 loss: 0.009222313


100%|██████████| 312/312 [00:02<00:00, 146.95it/s]


Epoch 65 loss: 0.0043873936


100%|██████████| 312/312 [00:02<00:00, 146.83it/s]


Epoch 66 loss: 0.0023372069


100%|██████████| 312/312 [00:02<00:00, 147.29it/s]


Epoch 67 loss: 0.01008275


100%|██████████| 312/312 [00:02<00:00, 146.99it/s]


Epoch 68 loss: 0.00083925517


100%|██████████| 312/312 [00:02<00:00, 151.24it/s]


Epoch 69 loss: 0.011575425


100%|██████████| 312/312 [00:02<00:00, 147.08it/s]


Epoch 70 loss: 0.0026201992


100%|██████████| 312/312 [00:02<00:00, 148.46it/s]


Epoch 71 loss: 0.0012501717


100%|██████████| 312/312 [00:02<00:00, 150.57it/s]


Epoch 72 loss: 0.0050113937


100%|██████████| 312/312 [00:02<00:00, 147.58it/s]


Epoch 73 loss: 0.001466934


100%|██████████| 312/312 [00:02<00:00, 147.11it/s]


Epoch 74 loss: 0.0005016887


100%|██████████| 312/312 [00:02<00:00, 149.29it/s]


Epoch 75 loss: 0.0052876724


100%|██████████| 312/312 [00:02<00:00, 146.68it/s]


Epoch 76 loss: 0.005690371


100%|██████████| 312/312 [00:02<00:00, 147.04it/s]


Epoch 77 loss: 0.003983121


100%|██████████| 312/312 [00:02<00:00, 150.29it/s]


Epoch 78 loss: 0.0013184206


100%|██████████| 312/312 [00:02<00:00, 148.29it/s]


Epoch 79 loss: 0.010736575


100%|██████████| 312/312 [00:02<00:00, 147.26it/s]


Epoch 80 loss: 0.019568603


100%|██████████| 312/312 [00:02<00:00, 149.57it/s]


Epoch 81 loss: 0.004313573


100%|██████████| 312/312 [00:02<00:00, 147.69it/s]


Epoch 82 loss: 0.010991074


100%|██████████| 312/312 [00:02<00:00, 150.66it/s]


Epoch 83 loss: 0.0012171332


100%|██████████| 312/312 [00:02<00:00, 152.07it/s]


Epoch 84 loss: 0.00055077707


100%|██████████| 312/312 [00:02<00:00, 153.65it/s]


Epoch 85 loss: 0.0073316023


100%|██████████| 312/312 [00:02<00:00, 146.77it/s]


Epoch 86 loss: 0.0028434987


100%|██████████| 312/312 [00:02<00:00, 147.29it/s]


Epoch 87 loss: 0.00817379


100%|██████████| 312/312 [00:02<00:00, 143.19it/s]


Epoch 88 loss: 0.004961028


100%|██████████| 312/312 [00:02<00:00, 140.57it/s]


Epoch 89 loss: 0.012947354


100%|██████████| 312/312 [00:02<00:00, 138.62it/s]


Epoch 90 loss: 0.001499664


100%|██████████| 312/312 [00:02<00:00, 140.84it/s]


Epoch 91 loss: 0.0037881478


100%|██████████| 312/312 [00:02<00:00, 146.77it/s]


Epoch 92 loss: 0.009549415


100%|██████████| 312/312 [00:02<00:00, 147.53it/s]


Epoch 93 loss: 0.0078542745


100%|██████████| 312/312 [00:02<00:00, 144.17it/s]


Epoch 94 loss: 0.004703642


100%|██████████| 312/312 [00:02<00:00, 140.28it/s]


Epoch 95 loss: 0.014925811


100%|██████████| 312/312 [00:02<00:00, 139.29it/s]


Epoch 96 loss: 0.008275516


100%|██████████| 312/312 [00:02<00:00, 142.86it/s]


Epoch 97 loss: 0.005108869


100%|██████████| 312/312 [00:02<00:00, 144.96it/s]


Epoch 98 loss: 0.019720722


100%|██████████| 312/312 [00:02<00:00, 148.66it/s]


Epoch 99 loss: 0.001974451


100%|██████████| 312/312 [00:02<00:00, 147.79it/s]


Epoch 100 loss: 0.0014513393


100%|██████████| 312/312 [00:02<00:00, 145.89it/s]


Epoch 101 loss: 0.012814565


100%|██████████| 312/312 [00:02<00:00, 146.26it/s]


Epoch 102 loss: 0.005026892


100%|██████████| 312/312 [00:02<00:00, 148.33it/s]


Epoch 103 loss: 0.001023326


100%|██████████| 312/312 [00:02<00:00, 145.51it/s]


Epoch 104 loss: 0.0073186615


100%|██████████| 312/312 [00:02<00:00, 145.34it/s]


Epoch 105 loss: 0.004433358


100%|██████████| 312/312 [00:02<00:00, 148.76it/s]


Epoch 106 loss: 0.001427429


100%|██████████| 312/312 [00:02<00:00, 149.81it/s]


Epoch 107 loss: 0.0007770895


100%|██████████| 312/312 [00:02<00:00, 147.20it/s]


Epoch 108 loss: 0.0011203168


100%|██████████| 312/312 [00:02<00:00, 149.95it/s]


Epoch 109 loss: 0.005042796


100%|██████████| 312/312 [00:02<00:00, 145.29it/s]


Epoch 110 loss: 0.002338284


100%|██████████| 312/312 [00:02<00:00, 143.27it/s]


Epoch 111 loss: 0.0012297763


100%|██████████| 312/312 [00:02<00:00, 147.27it/s]


Epoch 112 loss: 0.008016388


100%|██████████| 312/312 [00:02<00:00, 147.19it/s]


Epoch 113 loss: 0.0013119545


100%|██████████| 312/312 [00:02<00:00, 146.85it/s]


Epoch 114 loss: 0.0035683354


100%|██████████| 312/312 [00:02<00:00, 144.06it/s]


Epoch 115 loss: 0.0020889153


100%|██████████| 312/312 [00:02<00:00, 147.50it/s]


Epoch 116 loss: 0.0018147466


100%|██████████| 312/312 [00:02<00:00, 149.14it/s]


Epoch 117 loss: 0.003004505


100%|██████████| 312/312 [00:02<00:00, 147.44it/s]


Epoch 118 loss: 0.0057763434


100%|██████████| 312/312 [00:02<00:00, 148.48it/s]


Epoch 119 loss: 0.001639887


100%|██████████| 312/312 [00:02<00:00, 148.29it/s]


Epoch 120 loss: 0.0042094085


100%|██████████| 312/312 [00:02<00:00, 145.79it/s]


Epoch 121 loss: 0.0016605941


100%|██████████| 312/312 [00:02<00:00, 146.75it/s]


Epoch 122 loss: 0.0046089813


100%|██████████| 312/312 [00:02<00:00, 146.33it/s]


Epoch 123 loss: 0.00039930813


100%|██████████| 312/312 [00:02<00:00, 144.72it/s]


Epoch 124 loss: 0.0012195


100%|██████████| 312/312 [00:02<00:00, 148.08it/s]


Epoch 125 loss: 0.004652785


100%|██████████| 312/312 [00:02<00:00, 145.83it/s]


Epoch 126 loss: 0.0021859822


100%|██████████| 312/312 [00:02<00:00, 140.12it/s]


Epoch 127 loss: 0.010130579


100%|██████████| 312/312 [00:02<00:00, 144.52it/s]


Epoch 128 loss: 0.0033728941


100%|██████████| 312/312 [00:02<00:00, 147.09it/s]


Epoch 129 loss: 0.0073223854


100%|██████████| 312/312 [00:02<00:00, 144.13it/s]


Epoch 130 loss: 0.0061236704


100%|██████████| 312/312 [00:02<00:00, 147.07it/s]


Epoch 131 loss: 0.0016905258


100%|██████████| 312/312 [00:02<00:00, 145.58it/s]


Epoch 132 loss: 0.002220683


100%|██████████| 312/312 [00:02<00:00, 145.24it/s]


Epoch 133 loss: 0.0012796351


100%|██████████| 312/312 [00:02<00:00, 146.92it/s]


Epoch 134 loss: 0.010886195


100%|██████████| 312/312 [00:02<00:00, 148.37it/s]


Epoch 135 loss: 0.0010391874


100%|██████████| 312/312 [00:02<00:00, 147.79it/s]


Epoch 136 loss: 0.0029520458


100%|██████████| 312/312 [00:02<00:00, 146.66it/s]


Epoch 137 loss: 0.0020032595


100%|██████████| 312/312 [00:02<00:00, 146.14it/s]


Epoch 138 loss: 0.08622976


100%|██████████| 312/312 [00:02<00:00, 146.25it/s]


Epoch 139 loss: 0.109192364


100%|██████████| 312/312 [00:02<00:00, 146.30it/s]


Epoch 140 loss: 0.0012985271


100%|██████████| 312/312 [00:02<00:00, 148.04it/s]


Epoch 141 loss: 0.0012599411


100%|██████████| 312/312 [00:02<00:00, 150.09it/s]


Epoch 142 loss: 0.002931129


100%|██████████| 312/312 [00:02<00:00, 146.44it/s]


Epoch 143 loss: 0.026181644


100%|██████████| 312/312 [00:02<00:00, 150.53it/s]


Epoch 144 loss: 0.008663185


100%|██████████| 312/312 [00:02<00:00, 148.27it/s]


Epoch 145 loss: 0.0047908197


100%|██████████| 312/312 [00:02<00:00, 146.94it/s]


Epoch 146 loss: 0.0014185531


100%|██████████| 312/312 [00:02<00:00, 148.06it/s]


Epoch 147 loss: 0.0027114234


100%|██████████| 312/312 [00:02<00:00, 147.33it/s]


Epoch 148 loss: 0.002854012


100%|██████████| 312/312 [00:02<00:00, 148.56it/s]


Epoch 149 loss: 0.0036366757


100%|██████████| 312/312 [00:02<00:00, 147.35it/s]


Epoch 150 loss: 0.008331113


100%|██████████| 312/312 [00:02<00:00, 151.92it/s]


Epoch 151 loss: 0.0060296264


100%|██████████| 312/312 [00:02<00:00, 152.78it/s]


Epoch 152 loss: 0.0013188244


100%|██████████| 312/312 [00:02<00:00, 149.75it/s]


Epoch 153 loss: 0.0011017967


100%|██████████| 312/312 [00:02<00:00, 150.80it/s]


Epoch 154 loss: 0.0027540554


100%|██████████| 312/312 [00:02<00:00, 150.64it/s]


Epoch 155 loss: 0.016113192


100%|██████████| 312/312 [00:02<00:00, 144.69it/s]


Epoch 156 loss: 0.0049481853


100%|██████████| 312/312 [00:02<00:00, 149.84it/s]


Epoch 157 loss: 0.0024123907


100%|██████████| 312/312 [00:02<00:00, 151.31it/s]


Epoch 158 loss: 0.011917135


100%|██████████| 312/312 [00:02<00:00, 149.83it/s]


Epoch 159 loss: 0.0059766974


100%|██████████| 312/312 [00:02<00:00, 150.95it/s]


Epoch 160 loss: 0.007746275


100%|██████████| 312/312 [00:02<00:00, 149.18it/s]


Epoch 161 loss: 0.0021171786


100%|██████████| 312/312 [00:02<00:00, 148.42it/s]


Epoch 162 loss: 0.0025407393


100%|██████████| 312/312 [00:02<00:00, 148.47it/s]


Epoch 163 loss: 0.0012321952


100%|██████████| 312/312 [00:02<00:00, 151.32it/s]


Epoch 164 loss: 0.0050710808


100%|██████████| 312/312 [00:02<00:00, 152.75it/s]


Epoch 165 loss: 0.0018525075


100%|██████████| 312/312 [00:02<00:00, 150.31it/s]


Epoch 166 loss: 0.0073213195


100%|██████████| 312/312 [00:02<00:00, 152.10it/s]


Epoch 167 loss: 0.0028305366


100%|██████████| 312/312 [00:02<00:00, 153.94it/s]


Epoch 168 loss: 0.001279901


100%|██████████| 312/312 [00:02<00:00, 151.78it/s]


Epoch 169 loss: 0.0007328101


100%|██████████| 312/312 [00:02<00:00, 152.52it/s]


Epoch 170 loss: 0.005517888


100%|██████████| 312/312 [00:02<00:00, 152.51it/s]


Epoch 171 loss: 0.0039075078


100%|██████████| 312/312 [00:02<00:00, 150.07it/s]


Epoch 172 loss: 0.00078372326


100%|██████████| 312/312 [00:02<00:00, 149.90it/s]


Epoch 173 loss: 0.00035636983


100%|██████████| 312/312 [00:02<00:00, 151.01it/s]


Epoch 174 loss: 0.002435878


100%|██████████| 312/312 [00:02<00:00, 146.56it/s]


Epoch 175 loss: 0.0015519147


100%|██████████| 312/312 [00:02<00:00, 147.11it/s]


Epoch 176 loss: 0.00381902


100%|██████████| 312/312 [00:02<00:00, 150.62it/s]


Epoch 177 loss: 0.000856303


100%|██████████| 312/312 [00:02<00:00, 151.77it/s]


Epoch 178 loss: 0.0051490823


100%|██████████| 312/312 [00:02<00:00, 146.69it/s]


Epoch 179 loss: 0.008321327


100%|██████████| 312/312 [00:02<00:00, 152.64it/s]


Epoch 180 loss: 0.002521487


100%|██████████| 312/312 [00:02<00:00, 151.83it/s]


Epoch 181 loss: 0.0009861202


100%|██████████| 312/312 [00:02<00:00, 148.69it/s]


Epoch 182 loss: 0.010210268


100%|██████████| 312/312 [00:02<00:00, 150.27it/s]


Epoch 183 loss: 0.0035283607


100%|██████████| 312/312 [00:02<00:00, 152.55it/s]


Epoch 184 loss: 0.008553078


100%|██████████| 312/312 [00:02<00:00, 151.20it/s]


Epoch 185 loss: 0.00052486436


100%|██████████| 312/312 [00:02<00:00, 152.22it/s]


Epoch 186 loss: 0.0019686937


100%|██████████| 312/312 [00:02<00:00, 148.86it/s]


Epoch 187 loss: 0.0027586464


100%|██████████| 312/312 [00:02<00:00, 147.57it/s]


Epoch 188 loss: 0.0058374805


100%|██████████| 312/312 [00:02<00:00, 148.67it/s]


Epoch 189 loss: 0.0048602917


100%|██████████| 312/312 [00:02<00:00, 148.52it/s]


Epoch 190 loss: 0.0057701813


100%|██████████| 312/312 [00:02<00:00, 149.10it/s]


Epoch 191 loss: 0.003368033


100%|██████████| 312/312 [00:02<00:00, 151.71it/s]


Epoch 192 loss: 0.003662585


100%|██████████| 312/312 [00:02<00:00, 148.22it/s]


Epoch 193 loss: 0.09084996


100%|██████████| 312/312 [00:02<00:00, 147.28it/s]


Epoch 194 loss: 0.005811976


100%|██████████| 312/312 [00:02<00:00, 142.47it/s]


Epoch 195 loss: 0.004816776


100%|██████████| 312/312 [00:02<00:00, 144.73it/s]


Epoch 196 loss: 0.005332428


100%|██████████| 312/312 [00:02<00:00, 144.56it/s]


Epoch 197 loss: 0.0086028315


100%|██████████| 312/312 [00:02<00:00, 142.43it/s]


Epoch 198 loss: 0.0024974593


100%|██████████| 312/312 [00:02<00:00, 145.95it/s]


Epoch 199 loss: 0.0036550718

Model output sharding follows data sharding:



Updating model and running an eval:
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step - loss: 0.0406 - sparse_categorical_accuracy: 0.9891
The model achieved an evaluation accuracy of: 0.9890999794006348
