In [None]:
import os

import matplotlib.pyplot as plt
import tensorflow as tf

In [None]:
# Load the Fashion-MNIST train/test splits bundled with Keras.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

# Normalize the images to a range of 0 to 1.
# Cast to float32 first: dividing the raw uint8 arrays by a Python float would
# promote them to float64, doubling memory for data Keras computes on as
# float32 by default.
train_images = train_images.astype("float32") / 255.0
test_images = test_images.astype("float32") / 255.0

# Reshape the images to include the channel dimension: (N, 28, 28) -> (N, 28, 28, 1).
# The explicit 28x28 doubles as a sanity check on the image size.
train_images = train_images.reshape((-1, 28, 28, 1))
test_images = test_images.reshape((-1, 28, 28, 1))


In [None]:
def baseline_model(input_shape=(28, 28, 1), num_classes=10):
    """Build the (uncompiled) baseline CNN classifier.

    The network stacks three unpadded 3x3 convolutions (32, 64 and 64
    filters) with max pooling and dropout between them, then flattens into a
    128-unit dense layer and a softmax output layer.

    Args:
        input_shape: ``(height, width, channels)`` of a single input image.
            Defaults to the 28x28 single-channel images used in this
            notebook. Because the convolutions are unpadded, height and
            width must be at least 18 for the last convolution to have a
            valid input.
        num_classes: Number of output classes. Defaults to 10.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model named ``"cnn_baseline"``.
    """
    layers = tf.keras.layers
    return tf.keras.models.Sequential([
        # Declare the input explicitly instead of passing ``input_shape`` to
        # the first layer, which newer Keras versions warn about.
        tf.keras.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.25),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax'),
    ], name="cnn_baseline")


def global_avg_model(input_shape=(28, 28, 1), num_classes=10):
    """Build the (uncompiled) fully-convolutional classifier.

    Four blocks of ``Conv2D(same padding) -> SpatialDropout2D ->
    BatchNormalization -> MaxPooling2D`` with 32, 64, 128 and 256 filters,
    followed by a 1x1 convolution with one softmax-activated channel per
    class and a global average pooling layer. There are no dense layers.

    Args:
        input_shape: ``(height, width, channels)`` of a single input image.
            Defaults to the 28x28 single-channel images used in this
            notebook. Height and width must be at least 16 so that the four
            2x2 poolings leave a non-empty feature map.
        num_classes: Number of output classes. Defaults to 10.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model named ``"global_avg"``.
    """
    layers = tf.keras.layers
    return tf.keras.models.Sequential([
        # Declare the input explicitly instead of passing ``input_shape`` to
        # the first layer, which newer Keras versions warn about.
        tf.keras.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), padding="same", activation='relu'),
        layers.SpatialDropout2D(0.25),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), padding="same", activation='relu'),
        layers.SpatialDropout2D(0.25),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), padding="same", activation='relu'),
        layers.SpatialDropout2D(0.25),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(256, (3, 3), padding="same", activation='relu'),
        layers.SpatialDropout2D(0.5),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        # Classification head: the 1x1 convolution produces a softmax over the
        # class channels at every spatial position, and the pooling averages
        # those distributions. With the default 28x28 input the feature map is
        # already 1x1 here (28 -> 14 -> 7 -> 3 -> 1), so the pooling only
        # drops the spatial axes.
        layers.Conv2D(num_classes, (1, 1), padding="same", activation='softmax'),
        layers.GlobalAvgPool2D(),
    ], name="global_avg")


# Instantiate the network that the following cells compile, train, evaluate
# and save. Call baseline_model() here instead to train the baseline CNN.
model = global_avg_model()

In [None]:
# Print the layer-by-layer architecture of the selected model.
model.summary()

In [None]:
# Active optimizer: AdamW with its default hyperparameters.
optimizer = tf.keras.optimizers.AdamW()
# Disabled alternative: Nadam with explicit hyperparameters
# (presumably the result of a hyperparameter search -- TODO confirm).
# optimizer = tf.keras.optimizers.Nadam(**{
#     'learning_rate': 0.0007289374526908369,
#     'beta_1': 0.7576824018427424,
#     'beta_2': 0.955563109739185,
#     'ema_momentum': 0.9664519778143205}
# )
# The sparse variant of categorical cross-entropy takes integer class labels
# rather than one-hot vectors.
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

In [None]:
# Random augmentation pipeline: small rotations, zooms and shifts plus
# horizontal flips.
# NOTE: this generator is only consumed by the commented-out
# `model.fit(datagen.flow(...))` call; the active training call feeds the raw
# arrays and therefore trains without augmentation.
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=10,
    zoom_range=0.1,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True
)

# `datagen.fit(train_images)` is intentionally not called: fitting only
# computes dataset statistics for featurewise_center,
# featurewise_std_normalization or zca_whitening, none of which is enabled
# above, so it would merely copy the whole training set. Add the call back if
# one of those options is switched on.

In [None]:
# history = model.fit(datagen.flow(train_images, train_labels, batch_size=64),
#                     epochs=30,
#                     validation_data=(test_images, test_labels),
# )
def schedule(epoch, lr):
    """Learning-rate schedule: one step decay at epoch index 30.

    Args:
        epoch: Index of the epoch about to start.
        lr: Learning rate currently in effect.

    Returns:
        ``lr / 10`` when ``epoch`` is exactly 30, otherwise ``lr`` unchanged.

    An earlier variant that also halved the rate at epoch 25 is disabled.
    """
    return lr / 10 if epoch == 30 else lr

# Wrap `schedule` in a Keras callback so it can adjust the learning rate
# during training; verbose=0 keeps the callback silent.
reduce_lr = tf.keras.callbacks.LearningRateScheduler(
    schedule, verbose=0
)
# Train on the in-memory arrays (the `datagen` augmentation is not used here),
# holding out 25% of the training data for validation. `history` is plotted
# in a later cell.
history = model.fit(train_images, train_labels, epochs=40, 
                    validation_split=0.25, callbacks=[reduce_lr])

In [None]:
# Measure loss and accuracy on the test split, which is not passed to
# model.fit above.
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print('\nTest accuracy:', test_acc)

In [None]:
# Persist the trained weights under ../models/<model name>/ so each
# architecture gets its own directory. (`os` is imported in the notebook's
# top import cell.)
model_path = os.path.join("..", "models", model.name)
os.makedirs(model_path, exist_ok=True)
# overwrite=False: never silently replace weights saved by an earlier run.
# The file name keeps the `.weights.h5` suffix that Keras 3 requires for
# weight files.
model.save_weights(os.path.join(model_path, ".weights.h5"), overwrite=False)

In [None]:
# Plot training and validation accuracy per epoch on a single axes.
fig, ax = plt.subplots()
for metric in ('accuracy', 'val_accuracy'):
    ax.plot(history.history[metric], label=metric)
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_ylim([0, 1])  # fixed full accuracy range
ax.legend(loc='lower right')
plt.show()