In [None]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.optimizers import Adam


def load_dataset(name):
    """Load one of the supported Keras benchmark datasets by its display name.

    Args:
        name: One of ``"MNIST"``, ``"Fashion-MNIST"`` or ``"CIFAR-10"``.

    Returns:
        The value of the matching ``tf.keras.datasets.<module>.load_data()``
        call, i.e. ``((x_train, y_train), (x_test, y_test))``.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    # Lambdas keep the loaders lazy: only the requested dataset's module is touched.
    known_loaders = (
        ("MNIST", lambda: tf.keras.datasets.mnist.load_data()),
        ("Fashion-MNIST", lambda: tf.keras.datasets.fashion_mnist.load_data()),
        ("CIFAR-10", lambda: tf.keras.datasets.cifar10.load_data()),
    )
    for known_name, loader in known_loaders:
        if name == known_name:
            return loader()
    raise ValueError("Unknown dataset")


def preprocess_data(x, y, num_channels=1):
    """Scale images to [0, 1] and coerce them to a channels-last layout.

    Args:
        x: Image array of shape ``(N, H, W)`` (no channel axis) or
            ``(N, H, W, C)`` with ``C`` in {1, 3}. Values are divided by 255,
            so 8-bit pixel intensities are assumed.
        y: Labels; returned unchanged.
        num_channels: Desired channel count of the output, 1 or 3.

    Returns:
        ``(x, y)`` where ``x`` is a float32 NumPy array of shape
        ``(N, H, W, num_channels)``.

    Raises:
        ValueError: If ``num_channels`` is not 1 or 3, or ``x`` does not have
            one of the supported layouts.
    """
    import numpy as np

    if num_channels not in (1, 3):
        raise ValueError(f"num_channels must be 1 or 3, got {num_channels!r}")

    x = np.asarray(x, dtype="float32") / 255.0
    # Normalize to rank 4 first so the channel logic below only has one layout
    # to deal with (the original expanded dims again on (N, H, W, 1) input).
    if x.ndim == 3:
        x = x[..., None]
    if x.ndim != 4 or x.shape[-1] not in (1, 3):
        raise ValueError(
            f"Expected images of shape (N, H, W) or (N, H, W, 1|3), got {x.shape}"
        )

    if num_channels == 3 and x.shape[-1] == 1:
        # Replicate the single channel (what tf.image.grayscale_to_rgb does),
        # but stay in NumPy so every branch returns the same array type.
        x = np.repeat(x, 3, axis=-1)
    elif num_channels == 1 and x.shape[-1] == 3:
        # Same luma weights as tf.image.rgb_to_grayscale.
        luma = np.array([0.2989, 0.5870, 0.1140], dtype="float32")
        x = (x @ luma)[..., None]
    return x, y


def build_cnn(input_shape, num_classes=10):
    """Build and compile a small two-block CNN classifier.

    Layers: Input -> Conv2D(32, 3x3, ReLU) -> MaxPool(2x2)
    -> Conv2D(64, 3x3, ReLU) -> MaxPool(2x2) -> Flatten
    -> Dense(128, ReLU) -> Dense(num_classes, softmax).

    Args:
        input_shape: Per-sample input shape without the batch dimension,
            e.g. ``x_train.shape[1:]``.
        num_classes: Number of output classes (size of the softmax layer).

    Returns:
        A ``Sequential`` model compiled with a default-configured Adam
        optimizer, sparse categorical cross-entropy loss (integer class-index
        labels) and an accuracy metric.
    """
    model = Sequential([
        # An explicit Input layer replaces the `input_shape=` layer kwarg,
        # which Keras 3 deprecates and warns about from the layer constructor.
        tf.keras.Input(shape=input_shape),
        Conv2D(32, (3, 3), activation='relu'),
        MaxPooling2D(2, 2),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(2, 2),
        Flatten(),
        Dense(128, activation='relu'),
        Dense(num_classes, activation='softmax'),
    ])
    model.compile(optimizer=Adam(),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

# Dataset names understood by load_dataset(), trained in this order.
datasets = ["MNIST", "Fashion-MNIST", "CIFAR-10"]
# Filled by the training loop: dataset name -> final test accuracy.
results = {}

# Weight initialization and batch shuffling are random; seeding Python, NumPy
# and TensorFlow makes re-runs comparable (not bit-exact: oneDNN/threading can
# still introduce small floating-point differences).
SEED = 42
tf.keras.utils.set_random_seed(SEED)


2026-02-23 15:19:09.493622: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-23 15:19:09.495219: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-23 15:19:09.500073: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-23 15:19:09.517496: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1771838349.548439  226726 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1771838349.55

In [None]:
# Training hyperparameters shared by every dataset.
EPOCHS = 15
BATCH_SIZE = 64
# Fraction of the *training* set held out for per-epoch validation, so the
# test set is only touched once, by the final evaluate() call.
VALIDATION_SPLIT = 0.1

for dataset_name in datasets:
    # Drop Keras global state (layer-name counters, graphs) left by the previous model.
    tf.keras.backend.clear_session()

    (x_train, y_train), (x_test, y_test) = load_dataset(dataset_name)
    # Raw arrays with a trailing channel axis (e.g. CIFAR-10) are RGB; the rest are grayscale.
    channels = 3 if x_train.ndim == 4 else 1
    x_train, y_train = preprocess_data(x_train, y_train, num_channels=channels)
    x_test, y_test = preprocess_data(x_test, y_test, num_channels=channels)

    # Assumes labels are contiguous integers 0..K-1 (true for all datasets above).
    num_classes = int(y_train.max()) + 1
    model = build_cnn(x_train.shape[1:], num_classes=num_classes)

    print(f"\nTraining on {dataset_name}...")
    model.fit(
        x_train,
        y_train,
        validation_split=VALIDATION_SPLIT,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        verbose=2,
    )

    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
    print(f"Test Accuracy on {dataset_name}: {test_acc:.4f}")
    results[dataset_name] = test_acc

print("\nFinal Test Accuracies:", results)

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
2026-02-23 15:19:19.996518: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)



Training on MNIST...
Epoch 1/15
938/938 - 20s - 21ms/step - accuracy: 0.9524 - loss: 0.1602 - val_accuracy: 0.9812 - val_loss: 0.0530
Epoch 2/15
938/938 - 17s - 19ms/step - accuracy: 0.9847 - loss: 0.0495 - val_accuracy: 0.9884 - val_loss: 0.0348
Epoch 3/15
938/938 - 19s - 20ms/step - accuracy: 0.9894 - loss: 0.0336 - val_accuracy: 0.9861 - val_loss: 0.0436
Epoch 4/15
938/938 - 18s - 19ms/step - accuracy: 0.9918 - loss: 0.0252 - val_accuracy: 0.9899 - val_loss: 0.0310
Epoch 5/15
938/938 - 19s - 20ms/step - accuracy: 0.9937 - loss: 0.0202 - val_accuracy: 0.9885 - val_loss: 0.0351
Epoch 6/15
938/938 - 17s - 18ms/step - accuracy: 0.9950 - loss: 0.0150 - val_accuracy: 0.9915 - val_loss: 0.0261
Epoch 7/15
938/938 - 21s - 22ms/step - accuracy: 0.9962 - loss: 0.0114 - val_accuracy: 0.9898 - val_loss: 0.0341
Epoch 8/15
938/938 - 12s - 12ms/step - accuracy: 0.9962 - loss: 0.0116 - val_accuracy: 0.9911 - val_loss: 0.0275
Epoch 9/15
938/938 - 12s - 12ms/step - accuracy: 0.9978 - loss: 0.0071 - v