In [None]:
import wandb
import time
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from wandb.integration.keras import WandbMetricsLogger  # ✅ Import the new Wandb callback

# Each execution logs to its own timestamped W&B project (the same string is also
# used as the run group).
# NOTE(review): a stable project plus a timestamped run `name=` is usually
# preferable — runs split across separate projects can't be compared side by side.
project_name = f"mnist-execution-{int(time.time())}"

NUM_CLASSES = 10  # MNIST digits 0-9

# Single source of truth for hyperparameters. The same dict is logged to W&B and
# passed to model.fit below, so the recorded config always matches what ran.
config = {"epochs": 10, "batch_size": 128, "validation_split": 0.2}

# Initialize Weights & Biases.
# NOTE: with no API key configured, wandb.init asks interactively how to log in.
# In a non-interactive kernel, set WANDB_API_KEY (or WANDB_MODE=offline /
# disabled) before running this cell.
run = wandb.init(project=project_name, group=project_name, config=config)

# Load MNIST: 28x28 grayscale images with pixel values in [0, 255], and integer labels.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add a trailing channel axis for Conv2D and scale pixels to [0, 1].
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255

# One-hot encode labels to match the categorical_crossentropy loss used below.
y_train = to_categorical(y_train, NUM_CLASSES)
y_test = to_categorical(y_test, NUM_CLASSES)

# Build the CNN. An explicit Input layer declares the input shape instead of the
# `input_shape=` kwarg on the first Conv2D, which Keras 3 deprecates.
model = models.Sequential([
    layers.Input(shape=(28, 28, 1)),
    layers.Conv2D(32, (3, 3), activation='relu'),           # Conv layer 1
    layers.MaxPooling2D((2, 2)),                            # Pooling 1
    layers.Conv2D(64, (3, 3), activation='relu'),           # Conv layer 2
    layers.MaxPooling2D((2, 2)),                            # Pooling 2
    layers.Flatten(),                                       # Flatten to 1D
    layers.Dense(128, activation='relu'),                   # Dense layer
    layers.Dropout(0.5),                                    # Dropout for regularization
    layers.Dense(NUM_CLASSES, activation='softmax'),        # Class probabilities
])

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Print model architecture.
model.summary()

# Train. WandbMetricsLogger forwards the Keras training/validation metrics to the
# active W&B run.
history = model.fit(
    x_train, y_train,
    epochs=config["epochs"],
    batch_size=config["batch_size"],
    validation_split=config["validation_split"],
    callbacks=[WandbMetricsLogger()],
)

# Evaluate on the held-out test set and record the final scores.
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"Test Accuracy: {test_acc:.4f}")
run.log({"test_accuracy": test_acc, "test_loss": test_loss})

# Finish the run explicitly so its data is flushed and the run is marked complete.
# In a notebook the kernel keeps running, so nothing else closes it at the end of
# this cell.
run.finish()

[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:[34m[1mwandb[0m: Enter your choice:[34m[1mwandb[0m: Enter your choice:[34m[1mwandb[0m: Enter your choice:[34m[1mwandb[0m: Enter your choice:[34m[1mwandb[0m: Enter your choice:[34m[1mwandb[0m: Enter your choice:[34m[1mwandb[0m: Enter your choice:[34m[1mwandb[0m: Enter your choice:[34m[1mwandb[0m: Enter your choice:[34m[1mwandb[0m: Enter your choice:[34m[1mwandb[0m: Enter your choice:[34m[1mwandb[0m: Enter your choice: