In [None]:
import tensorflow as tf

# Report the TensorFlow build in use.
print("TensorFlow version:", tf.__version__)

# The device name is kept in a module-level variable because later cells
# hand it to tf.device(); a falsy value means no GPU was reported.
gpu_device_name = tf.test.gpu_device_name()
if not gpu_device_name:
    print("No GPU found")
else:
    print("GPU device name:", gpu_device_name)

TensorFlow version: 2.18.0
GPU device name: /device:GPU:0


In [None]:
# Helper Methods
import pickle
from os.path import isfile, isdir

def load_data(file_list):
    """Yield ``(image, label)`` pairs from a sequence of pickled batch files.

    Each file must contain a single pickled ``(features, labels)`` pair of
    equally indexed sequences. Files are read one at a time, in the order
    given, so only one batch is held in memory at once.

    Args:
        file_list: Iterable of paths to pickle files.

    Yields:
        ``(img, label)`` tuples, pairing ``features[i]`` with ``labels[i]``.
    """
    for path in file_list:
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at files produced by this notebook.
        # The context manager closes the handle right after the batch is
        # read instead of leaving it to the garbage collector.
        with open(path, 'rb') as handle:
            features, labels = pickle.load(handle)
        yield from zip(features, labels)

def preprocess(img, label):
    """Resize an image to 224x224 and return it with its label unchanged.

    Args:
        img: Image tensor accepted by ``tf.image.resize``.
        label: Label associated with ``img``; passed through as-is.

    Returns:
        Tuple ``(resized_img, label)``.
    """
    target_size = (224, 224)
    resized = tf.image.resize(img, target_size)
    return resized, label

from urllib.request import urlretrieve
from os.path import isfile, isdir

from tqdm import tqdm
import tarfile
import pickle
import numpy as np

class DownloadProgress(tqdm):
    """tqdm progress bar whose ``hook`` method is passed to ``urlretrieve``."""

    # Block index reported by the previous hook call; used to turn the
    # cumulative block counter into a per-call increment.
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        """Advance the bar by the data transferred since the previous call.

        Args:
            block_num: Cumulative number of blocks transferred so far.
            block_size: Size of one block, in the bar's unit.
            total_size: Total size of the download; stored as the bar total.
        """
        self.total = total_size
        blocks_since_last_call = block_num - self.last_block
        self.update(blocks_since_last_call * block_size)
        self.last_block = block_num

def download(dataset_folder_path):
    """Download and unpack the CIFAR-10 python archive when it is missing.

    The archive is saved as ``cifar-10-python.tar.gz`` in the current working
    directory and is extracted into the current working directory as well.
    ``dataset_folder_path`` is only used to decide whether extraction is
    still needed; it does not control where files are written.

    Args:
        dataset_folder_path: Directory whose existence means the dataset has
            already been extracted.
    """
    archive_name = 'cifar-10-python.tar.gz'

    if not isfile(archive_name):
        with DownloadProgress(unit='B', unit_scale=True, miniters=1, desc='CIFAR-10 Dataset') as pbar:
            urlretrieve(
                'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz',
                archive_name,
                pbar.hook)
    else:
        print('cifar-10-python.tar.gz already exists')

    if not isdir(dataset_folder_path):
        # The context manager closes the archive; no explicit close() needed.
        with tarfile.open(archive_name) as tar:
            if hasattr(tarfile, 'data_filter'):
                # Refuse absolute paths, '..' components and special files,
                # and avoid the DeprecationWarning Python 3.12/3.13 emit
                # when no filter is given.
                tar.extractall(filter='data')
            else:
                # Older interpreters without extraction filters.
                tar.extractall()
    else:
        print('cifar10 dataset already exists')

def load_cifar10_batch(dataset_folder_path, batch_id):
    """Read one ``data_batch_<batch_id>`` file and return its images and labels.

    Args:
        dataset_folder_path: Directory containing the ``data_batch_*`` files.
        batch_id: Suffix of the batch file to read.

    Returns:
        Tuple ``(features, labels)``. ``features`` is the batch's ``'data'``
        array reshaped to ``(N, 3, 32, 32)`` and transposed to
        ``(N, 32, 32, 3)``; ``labels`` is the batch's ``'labels'`` entry,
        returned as stored.
    """
    batch_path = '{}/data_batch_{}'.format(dataset_folder_path, batch_id)
    with open(batch_path, mode='rb') as handle:
        # the batch files are unpickled with the 'latin1' encoding
        batch = pickle.load(handle, encoding='latin1')

    flat_pixels = batch['data']
    n_images = len(flat_pixels)

    # Split every row into 3 planes of 32x32, then move the plane axis last.
    channels_first = flat_pixels.reshape((n_images, 3, 32, 32))
    features = channels_first.transpose(0, 2, 3, 1)

    return features, batch['labels']


def one_hot_encode(x, num_classes=10):
    """Return a one-hot matrix for a sequence of integer class labels.

    Args:
        x: Sequence (list or 1-D array) of integer class indices.
        num_classes: Width of the encoding. Defaults to 10, the value that
            was previously hard-coded.

    Returns:
        Float array of shape ``(len(x), num_classes)`` with a single 1 per
        row, in the column given by that row's label.

    Raises:
        IndexError: If a label is not a valid index into ``num_classes``
            columns.
    """
    labels = np.asarray(x)
    n_rows = len(labels)
    encoded = np.zeros((n_rows, num_classes))

    # An empty list converts to a float array, which cannot be used as an
    # index; there is nothing to set in that case anyway.
    if n_rows == 0:
        return encoded

    # One vectorised assignment replaces the per-row Python loop.
    encoded[np.arange(n_rows), labels] = 1
    return encoded

def _preprocess_and_save(one_hot_encode, features, labels, filename):
    """One-hot encode ``labels`` and pickle ``(features, labels)`` to ``filename``.

    Args:
        one_hot_encode: Callable applied to ``labels`` before saving.
        features: Feature data, written unchanged.
        labels: Raw labels handed to ``one_hot_encode``.
        filename: Path of the pickle file to create or overwrite.
    """
    labels = one_hot_encode(labels)

    # Close the file deterministically so the data is flushed to disk before
    # another step reads it back.
    with open(filename, 'wb') as handle:
        pickle.dump((features, labels), handle)


def preprocess_and_save_data(dataset_folder_path, n_batches=5, validation_fraction=0.1):
    """Split the CIFAR-10 batches into train/validation/test pickles.

    For each training batch the trailing ``validation_fraction`` of samples
    is held out for validation and the rest is written to
    ``preprocess_batch_<i>.p``. The held-out samples of all batches are
    written together to ``preprocess_validation.p`` and the test batch to
    ``preprocess_testing.p``. Labels are one-hot encoded; pixel values are
    stored exactly as loaded (no normalization is applied here).

    Args:
        dataset_folder_path: Directory holding ``data_batch_*`` and
            ``test_batch``.
        n_batches: Number of training batch files to process, starting at 1.
        validation_fraction: Share of each batch held out for validation.
    """
    valid_features = []
    valid_labels = []

    for batch_i in range(1, n_batches + 1):
        features, labels = load_cifar10_batch(dataset_folder_path, batch_i)

        # Use an explicit split index rather than negative slicing: with a
        # hold-out count of 0, `x[:-0]` is empty and `x[-0:]` is everything,
        # which would swap the training and validation parts.
        n_validation = int(len(features) * validation_fraction)
        split_at = len(features) - n_validation

        # Leading part of the batch: one-hot encode the labels and save it
        # to its own file, one file per batch.
        _preprocess_and_save(one_hot_encode,
                             features[:split_at], labels[:split_at],
                             'preprocess_batch_' + str(batch_i) + '.p')

        # Trailing part of the batch: accumulate across all batches so the
        # validation set is saved once, after the loop.
        valid_features.extend(features[split_at:])
        valid_labels.extend(labels[split_at:])

    # Save the validation samples gathered from every batch.
    _preprocess_and_save(one_hot_encode,
                         np.array(valid_features), np.array(valid_labels),
                         'preprocess_validation.p')

    # Load the test batch; it has the same layout as the training batches.
    with open(dataset_folder_path + '/test_batch', mode='rb') as file:
        batch = pickle.load(file, encoding='latin1')

    # Reshape the flat rows to (N, 3, 32, 32) and move channels last.
    test_features = batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)
    test_labels = batch['labels']

    # Save the whole test set.
    _preprocess_and_save(one_hot_encode,
                         np.array(test_features), np.array(test_labels),
                         'preprocess_testing.p')

In [None]:
from tensorflow.keras import layers, models, optimizers, regularizers

class AlexNet(tf.keras.Model):
    """AlexNet-style image classifier for 224x224x3 inputs.

    The network is a ``Sequential`` stack of five convolutional layers
    (batch normalization after the first two, max pooling after the first,
    second and fifth), two 4096-unit dense layers with dropout 0.5, and a
    softmax output. Every convolutional and dense kernel carries an L2
    regularizer.

    Args:
        num_classes: Number of units in the softmax output layer.
        weight_decay: L2 regularization factor applied to every kernel.
            Defaults to 5e-4, the value that was previously hard-coded.
    """

    def __init__(self, num_classes, weight_decay=5e-4):
        super().__init__()

        def l2_penalty():
            # A fresh regularizer object per layer, as in the inline form.
            return regularizers.l2(weight_decay)

        self.model = models.Sequential([
            # Layer 1
            layers.Input(shape=(224, 224, 3)),
            layers.Conv2D(96, kernel_size=11, strides=4, activation='relu', padding='same',
                          kernel_regularizer=l2_penalty()),
            layers.BatchNormalization(),
            layers.MaxPooling2D(pool_size=3, strides=2),

            # Layer 2
            layers.Conv2D(256, kernel_size=5, padding='same', activation='relu',
                          kernel_regularizer=l2_penalty()),
            layers.BatchNormalization(),
            layers.MaxPooling2D(pool_size=3, strides=2),

            # Layers 3-5
            layers.Conv2D(384, kernel_size=3, padding='same', activation='relu',
                          kernel_regularizer=l2_penalty()),
            layers.Conv2D(384, kernel_size=3, padding='same', activation='relu',
                          kernel_regularizer=l2_penalty()),
            layers.Conv2D(256, kernel_size=3, padding='same', activation='relu',
                          kernel_regularizer=l2_penalty()),
            layers.MaxPooling2D(pool_size=3, strides=2),

            # Flatten
            layers.Flatten(),

            # FC layers
            layers.Dense(4096, activation='relu', kernel_regularizer=l2_penalty()),
            layers.Dropout(0.5),
            layers.Dense(4096, activation='relu', kernel_regularizer=l2_penalty()),
            layers.Dropout(0.5),

            # Output
            layers.Dense(num_classes, activation='softmax', kernel_regularizer=l2_penalty())
        ])

    def call(self, inputs, training=None):
        """Run the wrapped Sequential stack on ``inputs``.

        Args:
            inputs: Batch of images matching the 224x224x3 input layer.
            training: Training-mode flag forwarded to the inner model so the
                Dropout and BatchNormalization layers receive it explicitly
                rather than through implicit propagation. ``None`` leaves
                the decision to Keras.

        Returns:
            The softmax output of the final dense layer.
        """
        return self.model(inputs, training=training)

def train(model, train_dataset, val_dataset, epochs, learning_rate, save_path,
          steps_per_epoch=800):
    """Compile, fit and save ``model``, returning the Keras training history.

    The model is compiled with Adam, categorical cross-entropy and two
    ``TopKCategoricalAccuracy`` metrics (k=1 and k=5). The metrics are left
    unnamed on purpose: ``plot_training_history`` looks them up under the
    names Keras assigns automatically.

    Args:
        model: Keras model to train.
        train_dataset: Training data passed to ``model.fit``.
        val_dataset: Validation data passed to ``model.fit``.
        epochs: Number of epochs to run.
        learning_rate: Learning rate for the Adam optimizer.
        save_path: Destination passed to ``model.save`` after training.
        steps_per_epoch: Batches drawn from ``train_dataset`` per epoch.
            Defaults to 800, the value that was previously hard-coded.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # gpu_device_name is the module-level value set in the first cell.
    with tf.device(gpu_device_name):
        model.compile(optimizer=optimizers.Adam(learning_rate=learning_rate),
                      loss='categorical_crossentropy',
                      metrics=[tf.keras.metrics.TopKCategoricalAccuracy(k=1),
                               tf.keras.metrics.TopKCategoricalAccuracy(k=5)])

        history = model.fit(train_dataset, epochs=epochs, validation_data=val_dataset,
                            steps_per_epoch=steps_per_epoch)
    model.save(save_path)
    return history

In [None]:
import sys
from tensorflow.keras import regularizers

# --- Run configuration ------------------------------------------------------
dataset = 'cifar10'
dataset_path = 'none'
learning_rate = 0.00005
epochs = 20
batch_size = 64
num_classes = 10

cifar10_dataset_folder_path = 'cifar-10-batches-py'
save_path = './image_classification.keras'

# --- Data preparation -------------------------------------------------------
if dataset == 'cifar10' and dataset_path == 'none':
    download(cifar10_dataset_folder_path)

if dataset == 'cifar10':
    print('preprocess_and_save_data...')
    preprocess_and_save_data(cifar10_dataset_folder_path)

    print('load features and labels for valid dataset...')
    # Read through a context manager so the file handle is closed promptly.
    with open('preprocess_validation.p', mode='rb') as validation_file:
        valid_features, valid_labels = pickle.load(validation_file)

    print('converting valid images to fit into imagenet size...')
    # NOTE(review): this resizes the whole validation set eagerly and keeps
    # it in memory; resizing lazily inside the tf.data pipeline would be
    # lighter. Kept as-is so `tmp_valid_features` stays available.
    with tf.device(gpu_device_name):
        tmp_valid_features = tf.image.resize(valid_features, (224, 224))
else:
    sys.exit(0)

# --- Model, input pipelines and training ------------------------------------
with tf.device(gpu_device_name):
    # Use the configured class count instead of repeating the literal 10.
    model = AlexNet(num_classes=num_classes)

    file_list = [f'preprocess_batch_{i}.p' for i in range(1, 6)]

    # NOTE(review): `dataset` is rebound here from the dataset-name string to
    # a tf.data.Dataset; the name is kept for compatibility.
    # NOTE(review): the pickled features/labels do not appear to be stored as
    # float32/int64, so this signature presumably relies on from_generator
    # converting them - confirm.
    dataset = tf.data.Dataset.from_generator(
        lambda: load_data(file_list),
        output_signature=(
            tf.TensorSpec(shape=(32, 32, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(num_classes,), dtype=tf.int64)
        )
    )

    train_ds = (dataset
                .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
                .shuffle(1000)
                .batch(batch_size)
                .repeat()
                .prefetch(tf.data.AUTOTUNE))

    # The validation images are already 224x224 (resized above), so they are
    # not mapped through `preprocess` again. The pipeline is also not
    # shuffled: the metrics are aggregated over the whole set, so shuffling
    # only added a buffer of full batches.
    val_ds = (tf.data.Dataset.from_tensor_slices((tmp_valid_features, valid_labels))
              .batch(batch_size)
              .prefetch(tf.data.AUTOTUNE))

    history = train(model, train_ds, val_ds, epochs, learning_rate, save_path)

cifar-10-python.tar.gz already exists
cifar10 dataset already exists
preprocess_and_save_data...
load features and labels for valid dataset...
converting valid images to fit into imagenet size...
Epoch 1/20
[1m800/800[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m100s[0m 108ms/step - loss: 7.1936 - top_k_categorical_accuracy: 0.3983 - top_k_categorical_accuracy_1: 0.8472 - val_loss: 5.0446 - val_top_k_categorical_accuracy: 0.6228 - val_top_k_categorical_accuracy_1: 0.9598
Epoch 2/20
[1m800/800[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 93ms/step - loss: 4.8471 - top_k_categorical_accuracy: 0.6416 - top_k_categorical_accuracy_1: 0.9661 - val_loss: 4.4257 - val_top_k_categorical_accuracy: 0.6724 - val_top_k_categorical_accuracy_1: 0.9744
Epoch 3/20
[1m800/800[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 94ms/step - loss: 4.1449 - top_k_categorical_accuracy: 0.7295 - top_k_categorical_accuracy_1: 0.9807 - val_loss: 3.9146 - val_top_k_categorical_accuracy: 0.7234

In [None]:
import matplotlib.pyplot as plt

def plot_training_history(history):
  """Plot loss, top-1 accuracy and top-5 accuracy curves side by side.

  Args:
    history: Object whose ``history`` attribute is a dict holding the keys
      'loss', 'val_loss', 'top_k_categorical_accuracy',
      'val_top_k_categorical_accuracy', 'top_k_categorical_accuracy_1' and
      'val_top_k_categorical_accuracy_1', each a per-epoch sequence.
  """
  logs = history.history

  # Look up every series first so a missing key fails before any figure
  # is created.
  loss_values = logs['loss']
  val_loss_values = logs['val_loss']
  top1 = logs['top_k_categorical_accuracy']
  val_top1 = logs['val_top_k_categorical_accuracy']
  top5 = logs['top_k_categorical_accuracy_1']
  val_top5 = logs['val_top_k_categorical_accuracy_1']
  epochs_range = range(1, len(loss_values) + 1)

  # One entry per panel: (title, y-axis label, [(series, legend label), ...])
  panels = [
      ('Loss', 'Loss',
       [(loss_values, 'train loss'), (val_loss_values, 'validation loss')]),
      ('Top-1 Accuracy', 'Accuracy',
       [(top1, 'Train Top-1 Acc'), (val_top1, 'Val Top-1 Acc')]),
      ('Top-5 Accuracy', 'Accuracy',
       [(top5, 'Train Top-5 Acc'), (val_top5, 'Val Top-5 Acc')]),
  ]

  plt.figure(figsize=(16,8))
  for position, (title, y_label, curves) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    for series, legend_label in curves:
      plt.plot(epochs_range, series, label=legend_label)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend()

  plt.tight_layout()
  plt.show()

# Debugging aid: uncomment to list the metric keys recorded during training
# if plot_training_history raises a KeyError.
# print(history.history.keys())

# Draw the curves for the `history` object returned by train() in the
# previous cell.
plot_training_history(history)
