In [None]:
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import layers
from tensorflow import keras
import tensorflow as tf
import datetime
# Load CIFAR-10 through Keras. load_data() returns a (train, test) pair,
# each of which is an (images, labels) pair.
train_split, test_split = cifar10.load_data()
X_train, y_train = train_split
X_test, y_test = test_split

In [None]:
# Hold out the leading 20% of the training arrays as a validation set
# (for CIFAR-10's 50,000 training images: 10,000 validation / 40,000 training).
# NOTE: this cell rebinds X_train / y_train, so re-running it without first
# re-running the loading cell slices and encodes already-processed arrays.
num_validation_samples = int(0.2 * X_train.shape[0])

# Both slices on the right-hand side are evaluated before any name is rebound,
# so the validation and training parts come from the same original array.
X_val, X_train = X_train[:num_validation_samples], X_train[num_validation_samples:]
y_val, y_train = y_train[:num_validation_samples], y_train[num_validation_samples:]

# One-hot encode the integer class labels of every split.
y_train, y_val, y_test = (
    to_categorical(labels) for labels in (y_train, y_val, y_test)
)


In [None]:
# Input-pipeline settings.
AUTOTUNE = tf.data.AUTOTUNE
IMG_SIZE = 180  # not referenced by any of the cells visible here
batch_size = 32

# Training pipeline: shuffle with a 10,000-element buffer, then batch and prefetch.
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_dataset = train_dataset.shuffle(10000).batch(batch_size).prefetch(AUTOTUNE)

# Validation pipeline: `val_dataset` stays the raw per-sample dataset,
# `validation_dataset` is the batched/prefetched version passed to fit().
val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val))
validation_dataset = val_dataset.batch(batch_size).prefetch(AUTOTUNE)

# Test pipeline: batched and prefetched, no shuffling.
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))
test_dataset = test_dataset.batch(batch_size).prefetch(AUTOTUNE)

In [None]:
# Augmentation stack applied to the model's inputs: a random horizontal flip,
# followed by a random rotation and a random zoom, each with factor 0.1.
data_augmentation = keras.Sequential()
data_augmentation.add(layers.RandomFlip("horizontal"))
data_augmentation.add(layers.RandomRotation(0.1))
data_augmentation.add(layers.RandomZoom(0.1))

In [None]:
# Functional-API CNN: 32x32 RGB input -> 10-way class probabilities.
inputs = keras.Input(shape=(32, 32, 3))

# Random augmentation is the first stage of the model itself.
x = data_augmentation(inputs)

# Scale pixel values by 1/255 (0-255 image data becomes [0, 1]).
x = layers.Rescaling(1./255)(x)

# Stage 1: three Conv(64) -> BatchNorm -> MaxPool blocks.
# 'same' padding keeps the spatial size through each convolution, and each
# 2x2 pooling halves it: 32 -> 16 -> 8 -> 4.
x = layers.Conv2D(filters=64, kernel_size=(3,3), activation="relu", padding='same')(x)
x = layers.BatchNormalization()(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu')(x)
x = layers.BatchNormalization()(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu')(x)
x = layers.BatchNormalization()(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Dropout(0.2)(x)

# Stage 2: two Conv(128) -> BatchNorm -> MaxPool blocks (4 -> 2 -> 1).
x = layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same', activation='relu')(x)
x = layers.BatchNormalization()(x)
x = layers.MaxPooling2D(pool_size=(2, 2))(x)
x = layers.Conv2D(filters=128, kernel_size=(3, 3), padding='same', activation='relu')(x)
x = layers.BatchNormalization()(x)
x = layers.MaxPooling2D(pool_size=(2, 2))(x)
x = layers.Dropout(0.3)(x)

# Classifier head.
x = layers.Flatten()(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.3)(x)
# Softmax (not sigmoid): the targets are one-hot, mutually exclusive classes and
# the loss is categorical cross-entropy, which expects the 10 outputs to form a
# single probability distribution rather than 10 independent scores.
outputs = layers.Dense(10, activation="softmax")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

# TensorBoard logs go to a timestamped directory so runs do not overwrite each other.
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
model.summary()

model.compile(loss="categorical_crossentropy",
              optimizer="adamw",
              metrics=["accuracy"])

In [None]:
# Early stopping on validation loss; the best weights seen are restored at the end.
# Author's note: augmentation makes the validation curve more volatile, so a
# generous patience is used.
# NOTE(review): with patience=20 and epochs=30, stopping can only trigger if
# val_loss stops improving (by at least min_delta) within roughly the first
# 10 epochs -- confirm this is intended.
callback = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=0.002,
    patience=20,
    restore_best_weights=True,
)

# Train for up to 30 epochs, validating each epoch and logging to TensorBoard.
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=30,
    callbacks=[callback, tensorboard_callback],
)

In [None]:
import matplotlib.pyplot as plt


def _plot_training_curves(epochs, train_values, val_values, metric_name):
    """Draw one metric's training and validation curves on a new figure.

    Args:
        epochs: x-axis values (1-based epoch numbers).
        train_values: per-epoch metric values on the training data.
        val_values: per-epoch metric values on the validation data.
        metric_name: lower-case metric name used in the title, legend and
            y-axis label (e.g. "accuracy" or "loss").

    Returns:
        The matplotlib Axes the curves were drawn on.
    """
    fig, ax = plt.subplots()
    ax.plot(epochs, train_values, "bo", label=f"Training {metric_name}")
    ax.plot(epochs, val_values, "b", label=f"Validation {metric_name}")
    ax.set_title(f"Training and validation {metric_name}")
    # Axis labels let each figure be read on its own.
    ax.set_xlabel("Epoch")
    ax.set_ylabel(metric_name.capitalize())
    ax.legend()
    return ax


# Per-epoch metrics recorded by model.fit().
accuracy = history.history["accuracy"]
val_accuracy = history.history["val_accuracy"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(accuracy) + 1)

# One figure per metric: dots for training, a line for validation.
_plot_training_curves(epochs, accuracy, val_accuracy, "accuracy")
_plot_training_curves(epochs, loss, val_loss, "loss")
plt.show()

In [None]:
# Evaluate on the batched test set; the model was compiled with a single
# metric, so evaluate() yields the loss followed by the accuracy.
test_metrics = model.evaluate(test_dataset)
test_loss, test_acc = test_metrics
print("Test accuracy: is", test_acc)

In [None]:
# Load an additional, unlabeled image set from disk, resized to the model's
# 32x32 input size.
# NOTE(review): this rebinds the notebook-wide `batch_size` (32 in the dataset
# cell) to 10; re-running earlier cells afterwards would pick up the new value.
batch_size = 10
img_height = 32
img_width = 32

# NOTE(review): `dir` is not assigned in any visible cell. Unless it is defined
# elsewhere, it resolves to Python's builtin `dir` function rather than a
# directory path -- confirm and replace with an explicit path variable.
test2 = tf.keras.utils.image_dataset_from_directory(
    dir,
    label_mode=None,  # no labels are loaded for this set
    batch_size=batch_size,
    image_size=(img_height, img_width),
)