In [3]:
# This script needs these libraries to be installed:
#   tensorflow, numpy, wandb
#
# Trains a small dense classifier on a subsample of MNIST, logging metrics and
# model checkpoints to Weights & Biases.

import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbModelCheckpoint

import random
import numpy as np
import tensorflow as tf


# Start a run, tracking hyperparameters.
# Using the run as a context manager guarantees it is finished even if any
# step below raises (the run is then closed with a non-zero exit code). This
# matters in notebooks, where an unfinished run otherwise stays attached to
# the kernel.
with wandb.init(
    # set the wandb project where this run will be logged
    project="my-awesome-project",

    # track hyperparameters and run metadata with wandb.config
    config={
        "layer_1": 512,
        "activation_1": "relu",
        # sampled uniformly from [0.01, 0.80] on every run
        "dropout": random.uniform(0.01, 0.80),
        "layer_2": 10,
        "activation_2": "softmax",
        "optimizer": "sgd",
        "loss": "sparse_categorical_crossentropy",
        "metric": "accuracy",
        "epoch": 8,
        "batch_size": 256,
    },
) as run:
    # [optional] use the run's config as your config
    config = run.config

    # get the data
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # scale uint8 pixel values into [0, 1]
    x_train, x_test = x_train / 255.0, x_test / 255.0
    # keep every 5th training sample and every 20th test sample so the demo
    # runs quickly
    x_train, y_train = x_train[::5], y_train[::5]
    x_test, y_test = x_test[::20], y_test[::20]
    # class names "0".."max label"; not used in this cell, kept for later cells
    labels = [str(digit) for digit in range(np.max(y_train) + 1)]

    # build a model
    model = tf.keras.models.Sequential([
        # An explicit Input layer replaces `Flatten(input_shape=...)`, which
        # Keras warns about when the shape is passed to a regular layer.
        tf.keras.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(config.layer_1, activation=config.activation_1),
        tf.keras.layers.Dropout(config.dropout),
        tf.keras.layers.Dense(config.layer_2, activation=config.activation_2),
    ])

    # compile the model
    model.compile(
        optimizer=config.optimizer,
        loss=config.loss,
        metrics=[config.metric],
    )

    # WandbMetricsLogger will log train and validation metrics to wandb
    # (every 5 batches, plus per epoch).
    # WandbModelCheckpoint will upload model checkpoints to wandb.
    history = model.fit(
        x=x_train,
        y=y_train,
        epochs=config.epoch,
        batch_size=config.batch_size,
        validation_data=(x_test, y_test),
        callbacks=[
            WandbMetricsLogger(log_freq=5),
            WandbModelCheckpoint("models.keras"),
        ],
    )

# Leaving the `with` block finishes the wandb run, so no explicit
# wandb.finish() call is needed, even in notebooks.

Epoch 1/8


  super().__init__(**kwargs)
I0000 00:00:1737629942.635320    8076 service.cc:148] XLA service 0x7f867c0060f0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1737629942.635773    8076 service.cc:156]   StreamExecutor device (0): NVIDIA GeForce RTX 3080, Compute Capability 8.6
2025-01-23 19:59:02.653719: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1737629942.697961    8076 cuda_dnn.cc:529] Loaded cuDNN version 90600





[1m41/47[0m [32m━━━━━━━━━━━━━━━━━[0m[37m━━━[0m [1m0s[0m 3ms/step - accuracy: 0.1284 - loss: 2.3860

I0000 00:00:1737629944.996011    8076 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.





[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 88ms/step - accuracy: 0.1356 - loss: 2.3664 - val_accuracy: 0.5740 - val_loss: 1.8697
Epoch 2/8
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - accuracy: 0.3591 - loss: 1.9035 - val_accuracy: 0.7200 - val_loss: 1.5478
Epoch 3/8
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 11ms/step - accuracy: 0.4985 - loss: 1.6239 - val_accuracy: 0.7760 - val_loss: 1.3073
Epoch 4/8
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.5948 - loss: 1.4106 - val_accuracy: 0.7940 - val_loss: 1.1290
Epoch 5/8
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6449 - loss: 1.2563 - val_accuracy: 0.8020 - val_loss: 0.9974
Epoch 6/8
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.6711 - loss: 1.1357 - val_accuracy: 0.8100 - val_loss: 0.8960
Epoch 7/8
[1m47/47[0m [32m━━━━━━━━━━━━━━━━━━━━[

0,1
batch/accuracy,▁▁▂▂▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇████████
batch/batch_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
batch/learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
batch/loss,██▇▇▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
epoch/accuracy,▁▄▅▆▇▇██
epoch/epoch,▁▂▃▄▅▆▇█
epoch/learning_rate,▁▁▁▁▁▁▁▁
epoch/loss,█▆▄▃▂▂▁▁
epoch/val_accuracy,▁▅▇▇▇▇██
epoch/val_loss,█▆▄▃▂▂▁▁

0,1
batch/accuracy,0.7269
batch/batch_step,395.0
batch/learning_rate,0.01
batch/loss,0.95456
epoch/accuracy,0.72592
epoch/epoch,7.0
epoch/learning_rate,0.01
epoch/loss,0.95542
epoch/val_accuracy,0.83
epoch/val_loss,0.76051
