In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
keras.mixed_precision.set_global_policy('mixed_float16')

# data loading config
batch_size = 64
img_height = 28
img_width = 28
dataPath = 'fashion_mnist'
labels = 'inferred'  # auto find from folders
label_mode = 'categorical'  # one hot encoding
color_mode = 'grayscale'
shuffle = True
seed = 69
test_split = 0.2  # split into train and test (NOT val)
AUTOTUNE = tf.data.AUTOTUNE

train = keras.preprocessing.image_dataset_from_directory(dataPath, labels=labels, label_mode=label_mode,
color_mode=color_mode, shuffle=shuffle, subset="training", seed=seed, validation_split=test_split,
image_size=(img_height, img_width), batch_size=batch_size)
class_names = train.class_names
train_size = int(70000*0.8)
val = train.take(int(0.2*train_size))
train = train.take(int(0.8*train_size))

test = keras.preprocessing.image_dataset_from_directory(dataPath, labels=labels, label_mode=label_mode,
color_mode=color_mode, shuffle=shuffle, subset="validation", seed=seed, validation_split=test_split,
image_size=(img_height, img_width), batch_size=batch_size)

train = train.cache().prefetch(buffer_size=AUTOTUNE)
val = val.cache().prefetch(buffer_size=AUTOTUNE)
test = test.cache().prefetch(buffer_size=AUTOTUNE)

print(class_names)

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: GeForce RTX 2060 SUPER, compute capability 7.5
Found 70000 files belonging to 10 classes.
Using 56000 files for training.
Found 70000 files belonging to 10 classes.
Using 14000 files for validation.
['Ankle Boot', 'Bag', 'Coat', 'Dress', 'Pullover', 'Sandal', 'Shirt', 'Sneaker', 'T-Shirt', 'Trouser']


In [2]:
# mode hyperparams
lr = 1e-3
opt = keras.optimizers.Adam()
epoch = 100
validation_split = 0.2  # this is within the train dataset, different from previous one where its spilt into train and test
batch_size = 128

# model layers
xInput = layers.Input((28, 28))  # get resolutions, ignore batch size
x = layers.Flatten()(xInput)
x = layers.experimental.preprocessing.Rescaling(1./255)(x)  # fit 0-255 into 0 and 1
x = layers.Dense(128, activation='relu')(x)
x = layers.Dense(256, activation='relu')(x)
xOutput = layers.Dense(len(class_names), activation='sigmoid')(x)  # a probability for each class so need same no as classes

model = keras.Model(xInput, xOutput)
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics='accuracy')
model.summary()

callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', min_delta=0, patience=5, verbose=1,
                                     mode='auto', baseline=None, restore_best_weights=True),
    tf.keras.callbacks.ModelCheckpoint('./best_model',monitor='val_accuracy',save_best_only=True),
    tf.keras.callbacks.TensorBoard(log_dir="./logs")
]

model.fit(train, validation_data=val, batch_size=batch_size, epochs=epoch, callbacks=callbacks, verbose=1)

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28)]          0         
_________________________________________________________________
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
rescaling (Rescaling)        (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 256)               33024     
_________________________________________________________________
dense_2 (Dense)              (None, 10)                2570      
Total params: 136,074
Trainable params: 136,074
Non-trainable params: 0
_______________________________________________________

<tensorflow.python.keras.callbacks.History at 0x1962f1898e0>

In [4]:
model.evaluate(test)
model.save('Fashion MNIST')

INFO:tensorflow:Assets written to: Fashion MNIST\assets
