# Training CIFAR-10 with Keras

This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model.


In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

## Step 1: Create your input pipeline

Start by building an efficient input pipeline using advice from:
* The [Performance tips](https://www.tensorflow.org/datasets/performances) guide
* The [Better performance with the `tf.data` API](https://www.tensorflow.org/guide/data_performance#optimize_performance) guide


### Load a dataset

Load the CIFAR-10 dataset with the following arguments:

* `shuffle_files=True`: For datasets stored in multiple files on disk, it's good practice to shuffle the files when training (this has no effect for a split stored in a single file).
* `as_supervised=True`: Returns a tuple `(img, label)` instead of a dictionary `{'image': img, 'label': label}`.

In [None]:
# Fetch the raw CIFAR-10 train/test splits as (image, label) pairs, together
# with the dataset metadata (`ds_info`) used later for split sizes and class
# names.
(ods_train, ods_test), ds_info = tfds.load(
    name='cifar10',
    split=['train', 'test'],
    with_info=True,
    as_supervised=True,
    shuffle_files=True,
)

### Build a training pipeline

Apply the following transformations:

* `tf.data.Dataset.map`: TFDS provides images of type `tf.uint8`, while the model expects `tf.float32`. Therefore, you need to normalize the images.
* `tf.data.Dataset.cache`: As the dataset fits in memory, cache it before shuffling for better performance.<br/>
__Note:__ Random transformations should be applied after caching.
* `tf.data.Dataset.shuffle`: For true randomness, set the shuffle buffer to the full dataset size.<br/>
__Note:__ For large datasets that can't fit in memory, use `buffer_size=1000` if your system allows it.
* `tf.data.Dataset.batch`: Batch elements of the dataset after shuffling to get unique batches at each epoch.
* `tf.data.Dataset.prefetch`: It is good practice to end the pipeline by prefetching [for performance](https://www.tensorflow.org/guide/data_performance#prefetching).

In [None]:
def normalize_img(image, label):
  """Casts the image to `float32` and divides it by 255; the label is passed through.

  For `uint8` input this maps pixel values into the range [0, 1].
  """
  scaled = tf.cast(image, tf.float32) / 255.
  return scaled, label

# Training input pipeline: normalize -> cache -> shuffle -> batch -> prefetch.
# The shuffle buffer spans the whole training split, and batching happens
# after shuffling so that batches differ from epoch to epoch.
ds_train = (
    ods_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(512)
    .prefetch(tf.data.AUTOTUNE)
)

## EDA

### Build an evaluation pipeline

Your testing pipeline is similar to the training pipeline with small differences:

 * You don't need to call `tf.data.Dataset.shuffle`.
 * Caching is done after batching because batches can be the same between epochs.

In [None]:
# Evaluation input pipeline: normalize -> batch -> cache -> prefetch.
# There is no shuffling, and caching comes after batching because the
# batches are identical on every pass.
ds_test = (
    ods_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

## Step 2: Create and train the model

Plug the TFDS input pipeline into a simple Keras model, compile the model, and train it.

In [None]:
# Baseline MLP: flatten each image, apply three identical
# Dense(128, leaky_relu) + Dropout(0.2) stages, then a 10-unit output layer
# that emits raw logits (hence `from_logits=True` in the loss).
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(),
  *(
      layer
      for _ in range(3)
      # The inner tuple is rebuilt on every outer iteration, so each stage
      # gets its own fresh layer instances.
      for layer in (
          tf.keras.layers.Dense(128, activation='leaky_relu'),
          tf.keras.layers.Dropout(0.2),
      )
  ),
  tf.keras.layers.Dense(10),
])
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# `history` keeps the per-epoch metrics; later cells append to it.
history = model.fit(ds_train, validation_data=ds_test, epochs=30)

In [None]:
import matplotlib.pyplot as plt

def show_training_curves(history):
  """Plot training vs. validation accuracy and loss side by side.

  Args:
    history: Object whose `.history` dict holds the per-epoch lists
      'sparse_categorical_accuracy', 'val_sparse_categorical_accuracy',
      'loss' and 'val_loss' (e.g. the value returned by `model.fit`).
  """
  logs = history.history
  # One entry per subplot: (metric name, training curve, validation curve,
  # legend location). All four series are looked up before any figure is
  # created.
  panels = (
      ('Accuracy',
       logs['sparse_categorical_accuracy'],
       logs['val_sparse_categorical_accuracy'],
       'lower right'),
      ('Loss', logs['loss'], logs['val_loss'], 'upper right'),
  )
  # The x axis is the epoch index, sized from the training-accuracy series.
  epochs_range = range(len(panels[0][1]))

  plt.figure(figsize=(12, 4))
  for position, (metric, train_curve, val_curve, legend_loc) in enumerate(
      panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(epochs_range, train_curve, label=f'Training {metric}')
    plt.plot(epochs_range, val_curve, label=f'Validation {metric}')
    plt.legend(loc=legend_loc)
    plt.title(f'Training and Validation {metric}')
    plt.xlabel('Epoch')
    plt.ylabel(metric)
    plt.grid(True)

  plt.tight_layout()
  plt.show()

In [None]:
# Train the same model for 10 more epochs, continuing from its current
# weights.
new_history = model.fit(ds_train, validation_data=ds_test, epochs=10)

# Append the new per-epoch values to the running history so the curves cover
# the whole training run.
for key in new_history.history:
    history.history[key] += new_history.history[key]

In [None]:
# Plot the accuracy/loss curves for every epoch recorded in `history` so far.
show_training_curves(history)

In [None]:
# Another 10 epochs on top of the current weights.
new_history = model.fit(ds_train, validation_data=ds_test, epochs=10)

# Fold this run's metrics into the cumulative history.
for key in new_history.history:
    history.history[key] += new_history.history[key]

In [None]:
# Re-plot the curves, now including the epochs merged in by the previous cell.
show_training_curves(history)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

def show_confusion_matrix(model, dataset):
  """Plot a confusion-matrix heatmap of `model`'s predictions on `dataset`.

  True labels and predictions are gathered in a single pass over the dataset,
  batch by batch. Iterating the dataset twice (once for labels, once for
  predictions) would pair labels with the wrong predictions whenever the
  dataset reshuffles between iterations, as the training pipeline does.

  Args:
    model: Keras model producing one score per class for each example
      (logits or probabilities); the predicted class is the argmax over
      axis 1.
    dataset: Batched `tf.data.Dataset` yielding `(images, labels)`.

  Raises:
    ValueError: If `dataset` yields no batches.

  Class names for the axes are read from the module-level `ds_info`.
  """
  true_batches = []
  predicted_batches = []
  for images, labels in dataset:
    true_batches.append(labels.numpy())
    scores = model.predict_on_batch(images)
    predicted_batches.append(np.argmax(scores, axis=1))

  if not true_batches:
    raise ValueError('dataset yielded no batches; nothing to evaluate')

  true_labels = np.concatenate(true_batches)
  predicted_labels = np.concatenate(predicted_batches)

  class_names = ds_info.features['label'].names

  # Pin the label set so the matrix always has one row/column per class and
  # lines up with the tick labels, even if a class never occurs.
  cm = confusion_matrix(
      true_labels, predicted_labels, labels=np.arange(len(class_names)))

  plt.figure(figsize=(10, 8))
  sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
              xticklabels=class_names, yticklabels=class_names)
  plt.xlabel('Predicted Labels')
  plt.ylabel('True Labels')
  plt.title('Confusion Matrix')
  plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

def plot_value_counts(dataset, title='Counts of Each Output Label in Test Set'):
  """Plot a bar chart of how often each class label occurs in `dataset`.

  Args:
    dataset: `tf.data.Dataset` yielding `(images, labels)`; only the labels
      are read. Batched and unbatched datasets are both accepted.
    title: Figure title. The default keeps the original hard-coded test-set
      wording; pass a different title when plotting another split.

  Class names are read from the module-level `ds_info`.
  """
  # Resolve the class names locally. The name `class_names` was previously
  # looked up as a global, but it only exists as a local variable inside
  # show_confusion_matrix, so the lookup failed with NameError.
  class_names = ds_info.features['label'].names

  # Collect labels one batch at a time rather than one example at a time.
  label_batches = [np.atleast_1d(labels.numpy()) for _, labels in dataset]
  if label_batches:
    true_labels = np.concatenate(label_batches)
  else:
    true_labels = np.array([], dtype=np.int64)

  # Count the occurrences of each label, ordered by label index.
  label_counts = pd.Series(true_labels).value_counts().sort_index()

  plot_df = pd.DataFrame({
      'Label Index': label_counts.index,
      'Count': label_counts.values,
      'Class Name': [class_names[i] for i in label_counts.index]
  })

  plt.figure(figsize=(10, 6))
  sns.barplot(x='Class Name', y='Count', hue='Class Name', data=plot_df, palette='viridis', legend=False)
  plt.xlabel('Output Label')
  plt.ylabel('Count')
  plt.title(title)
  plt.xticks(rotation=45, ha='right')
  plt.grid(axis='y', linestyle='--', alpha=0.7)
  plt.tight_layout()
  plt.show()

In [None]:
# Confusion matrix of the current model on the evaluation pipeline.
show_confusion_matrix(model, ds_test)

In [None]:
# Label distribution of the evaluation pipeline.
plot_value_counts(ds_test)

In [None]:
# Label distribution of the training pipeline.
# NOTE(review): this call does not override the figure title, which reads
# '... in Test Set' even though the training data is being plotted.
plot_value_counts(ds_train)

In [None]:
## Improve

In [None]:
# Wider MLP: BatchNormalization on the flattened input, then
# Dense + Dropout(0.2) stages of 1024, 512 and 256 units, and a 10-unit
# logit output.
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(),
  # No need to call .adapt() for BatchNormalization as it learns stats per batch
  tf.keras.layers.BatchNormalization(),
  *(
      layer
      for units in (1024, 512, 256)
      for layer in (
          tf.keras.layers.Dense(units, activation='leaky_relu'),
          tf.keras.layers.Dropout(0.2),
      )
  ),
  tf.keras.layers.Dense(10),
])

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Start a fresh `history` for this model.
history = model.fit(ds_train, validation_data=ds_test, epochs=30)

In [None]:
# Curves for the model trained in the previous cell.
show_training_curves(history)

In [None]:
# Same architecture as the BatchNormalization model, with L2 weight decay
# (0.001) on the two widest Dense layers and a 10x smaller learning rate.
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.BatchNormalization(),
  *(
      layer
      # (units, L2 strength); None means the layer is left unregularized.
      for units, l2_strength in ((1024, 0.001), (512, 0.001), (256, None))
      for layer in (
          tf.keras.layers.Dense(
              units,
              activation='leaky_relu',
              kernel_regularizer=(
                  tf.keras.regularizers.l2(l2_strength)
                  if l2_strength is not None else None),
          ),
          tf.keras.layers.Dropout(0.2),
      )
  ),
  tf.keras.layers.Dense(10),
])

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Stop once validation accuracy has not improved for 5 consecutive epochs and
# roll the model back to the weights of its best epoch.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_sparse_categorical_accuracy',
    mode='max',
    patience=5,
    restore_best_weights=True,
)

history = model.fit(
    ds_train,
    validation_data=ds_test,
    epochs=30,
    callbacks=[early_stopping_callback],
)

In [None]:
# Curves for the regularized model trained with early stopping.
show_training_curves(history)

In [None]:
import tensorflow as tf

def augment_img(image, label):
  """Apply random training-time augmentations to a single image.

  Intended to be mapped over unbatched, already-normalized examples: the
  padding spec is rank 3, so `image` must be one (height, width, channels)
  tensor rather than a batch.
  NOTE(review): assumes 32x32 RGB input with values in [0, 1] -- confirm if
  this is reused with another dataset.

  Args:
    image: Float image tensor.
    label: Label, returned unchanged.

  Returns:
    Tuple `(augmented_image, label)`; the image keeps its 32x32x3 shape.
  """
  IMG_SIZE = 32
  PAD = 4  # Pixels added on each side before cropping (32 -> 40).

  # Random horizontal flip.
  image = tf.image.random_flip_left_right(image)
  # Random brightness and contrast jitter.
  image = tf.image.random_brightness(image, max_delta=0.2)
  image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
  # The jitter above can push pixels outside [0, 1]; clip so training inputs
  # stay in the same value range the evaluation pipeline produces.
  image = tf.clip_by_value(image, 0.0, 1.0)

  # Random translation: pad, then crop back to the original size.
  # REFLECT padding avoids introducing black borders.
  paddings = tf.constant([[PAD, PAD], [PAD, PAD], [0, 0]])
  image = tf.pad(image, paddings, "REFLECT")
  image = tf.image.random_crop(image, size=[IMG_SIZE, IMG_SIZE, 3])
  return image, label

# Rebuild the training pipeline with augmentation:
# normalize -> cache -> augment -> shuffle -> batch -> prefetch.
# Relies on `ods_train`, `normalize_img` and `ds_info` from earlier cells.
# Augmentation is mapped after `cache()` so only the normalized images are
# cached and the random transforms are re-drawn on every pass.
ds_train = (
    ods_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .map(augment_img, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(512)
    .prefetch(tf.data.AUTOTUNE)
)

print("Training dataset pipeline rebuilt with data augmentation.")
print("The new pipeline includes random horizontal flips, brightness, contrast adjustments, and random cropping.")

In [None]:
# Fresh copy of the regularized MLP, trained on the augmented pipeline:
# BatchNormalization on the flattened input, Dense + Dropout(0.2) stages of
# 1024/512/256 units (L2 weight decay on the two widest), 10-unit logit output.
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.Dense(1024, activation='leaky_relu', kernel_regularizer=tf.keras.regularizers.l2(0.001)),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(512, activation='leaky_relu', kernel_regularizer=tf.keras.regularizers.l2(0.001)),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(256, activation='leaky_relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])


model.compile(
    optimizer=tf.keras.optimizers.Adam(0.0001),  # Reduced learning rate.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Early stopping on validation accuracy.
# The monitored key must match the name Keras logs for the compiled metric:
# `SparseCategoricalAccuracy` is reported as 'sparse_categorical_accuracy',
# so its validation value is 'val_sparse_categorical_accuracy'. A plain
# 'val_accuracy' key is never logged here, which would leave early stopping
# and best-weight restoration inactive.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_sparse_categorical_accuracy',
    mode='max',                # Higher accuracy is better.
    patience=5,                # Epochs without improvement before stopping.
    restore_best_weights=True  # Roll back to the best epoch's weights.
)

history = model.fit(
    ds_train,
    epochs=30,
    validation_data=ds_test,
    callbacks=[early_stopping_callback]
)

In [None]:
# Curves for the model trained on the augmented pipeline.
show_training_curves(history)

In [None]:
# Continue training the current model for up to 30 more epochs, reusing the
# early-stopping callback defined earlier.
more_history = model.fit(
    ds_train,
    validation_data=ds_test,
    epochs=30,
    callbacks=[early_stopping_callback],
)

In [None]:

# Append the extra run's per-epoch metrics to the cumulative history.
for key in more_history.history:
    history.history[key] += more_history.history[key]

In [None]:
# Re-plot the curves, including the epochs merged in from `more_history`.
show_training_curves(history)