In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import math
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model

# Hyper-parameters shared by the cells below.
input_shape = [224, 224, 3]  # handed to the ResNet-50 constructor and to the input pipelines
batch_size, num_epochs = 64, 5
random_seed = 42

# ANSI escape sequences used to decorate the progress messages printed while training.
log_begin_red = '\033[91m'
log_begin_green = '\033[92m'
log_begin_bold = '\033[1m'
log_begin_underline = '\033[4m'
log_end_format = '\033[0m'  # closes any of the sequences above

# **Preparing the Data**

In [2]:
import cifar_utils

# Fetch the dataset description and show it.
cifar_info = cifar_utils.get_info()
print(cifar_info)

# Number of target categories, read from the 'label' feature.
num_classes = cifar_info.features['label'].num_classes

# Example counts per split; the 'test' split is the one used for validation.
num_train_imgs, num_val_imgs = (
    cifar_info.splits[split_name].num_examples for split_name in ('train', 'test')
)

# Steps per epoch for each split, rounded up so a trailing partial batch
# still counts as one step.
train_steps_per_epoch, val_steps_per_epoch = (
    math.ceil(num_imgs / batch_size) for num_imgs in (num_train_imgs, num_val_imgs)
)

[1mDownloading and preparing dataset cifar100/3.0.2 (download: 160.71 MiB, generated: 132.03 MiB, total: 292.74 MiB) to /root/tensorflow_datasets/cifar100/3.0.2...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]






0 examples [00:00, ? examples/s]

Shuffling and writing examples to /root/tensorflow_datasets/cifar100/3.0.2.incomplete2HYG9X/cifar100-train.tfrecord


  0%|          | 0/50000 [00:00<?, ? examples/s]

0 examples [00:00, ? examples/s]

Shuffling and writing examples to /root/tensorflow_datasets/cifar100/3.0.2.incomplete2HYG9X/cifar100-test.tfrecord


  0%|          | 0/10000 [00:00<?, ? examples/s]

[1mDataset cifar100 downloaded and prepared to /root/tensorflow_datasets/cifar100/3.0.2. Subsequent calls will reuse this data.[0m
tfds.core.DatasetInfo(
    name='cifar100',
    version=3.0.2,
    description='This dataset is just like the CIFAR-10, except it has 100 classes containing 600 images each. There are 500 training images and 100 testing images per class. The 100 classes in the CIFAR-100 are grouped into 20 superclasses. Each image comes with a "fine" label (the class to which it belongs) and a "coarse" label (the superclass to which it belongs).',
    homepage='https://www.cs.toronto.edu/~kriz/cifar.html',
    features=FeaturesDict({
        'coarse_label': ClassLabel(shape=(), dtype=tf.int64, num_classes=20),
        'id': Text(shape=(), dtype=tf.string),
        'image': Image(shape=(32, 32, 3), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=100),
    }),
    total_num_examples=60000,
    splits={
        'test': 10000,
        'train

# **Fine-tuning the ResNet feature extractor**

## **Meta-Iterating to find the best transfer learning solution**

In [None]:
# Compare several fine-tuning configurations of an ImageNet-pre-trained
# ResNet-50. For each `freeze_num` > 0, every Conv2D layer located before the
# first layer whose name contains 'conv<freeze_num + 2>' is frozen; with
# `freeze_num` == 0 nothing is frozen. The remaining weights are trained and the
# resulting Keras History is stored in `histories`, keyed by configuration name.
# We only freeze up to the first three macro-blocks; freezing every layer of the
# pre-trained ResNet-50 is beyond the scope of this notebook.
num_macroblocks_to_freeze = [0, 1, 2, 3]
histories = dict()

for freeze_num in num_macroblocks_to_freeze:

  print('{1}{2}>> {3}ResNet-50 with {0} macro-block(s) frozen:{4} '.format(
      freeze_num, log_begin_green, log_begin_bold, log_begin_underline, log_end_format
  ))

  # ----------------------------------------
  # 0. Start every configuration from a clean, identically-seeded state
  # A new model is built on each iteration; drop the global Keras state left by
  # the previous one so it does not keep accumulating across the loop.
  tf.keras.backend.clear_session()
  # Seed TensorFlow so that the freshly created Dense head starts from the same
  # random initialization in every configuration (fair comparison, repeatable run).
  tf.random.set_seed(random_seed)

  # ----------------------------------------
  # 1. Instantiate a new classifier each time
  resnet50_feature_extractor = tf.keras.applications.resnet50.ResNet50(
      include_top=False, weights='imagenet', input_shape=input_shape
  )

  # New classification head: global average pooling, then a softmax layer with
  # one unit per class.
  avg_pool = GlobalAveragePooling2D()(resnet50_feature_extractor.output)
  prediction = Dense(num_classes, activation='softmax')(avg_pool)
  resnet50_finetune = Model(resnet50_feature_extractor.input, prediction)

  # -----------------------------------------
  # 2. Freeze the desired layers
  # Layers are visited in model order; the walk stops at the first layer whose
  # name contains `break_layer_name`.
  break_layer_name = 'conv{}'.format(freeze_num + 2)
  frozen_layers = []
  if freeze_num > 0:  # with 0 macro-blocks to freeze, every layer stays trainable
    for layer in resnet50_finetune.layers:
      if break_layer_name in layer.name:
        break
      if isinstance(layer, tf.keras.layers.Conv2D):
        # Only Conv2D layers are frozen; other layer types are left untouched.
        layer.trainable = False
        frozen_layers.append(layer.name)

  print('\t> {2}Layers we froze:{4} {0} ({3}total = {1}{4})'.format(
      frozen_layers, len(frozen_layers), log_begin_red, log_begin_bold, log_end_format
  ))

  # -----------------------------------------
  # 3. Re-instantiate the input pipelines (same parameters)
  train_cifar_dataset = cifar_utils.get_dataset('train', batch_size=batch_size,
                                                num_epochs=num_epochs, shuffle=True,
                                                input_shape=input_shape,
                                                return_batch_as_tuple=True,
                                                seed=random_seed)

  val_cifar_dataset = cifar_utils.get_dataset('val', batch_size=batch_size,
                                              num_epochs=num_epochs, shuffle=False,
                                              input_shape=input_shape,
                                              return_batch_as_tuple=True,
                                              seed=random_seed)

  # -----------------------------------------
  # 4. Set up the training operations, and start the process
  optimizer = tf.keras.optimizers.SGD(learning_rate=1e-4, momentum=0.9, nesterov=True)

  # The metric names below become the keys read back from `history.history`.
  metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name='acc'),
             tf.keras.metrics.SparseTopKCategoricalAccuracy(name='top3_acc', k=3)]

  # Compile after changing `trainable`, so the frozen layers are taken into account.
  resnet50_finetune.compile(loss='sparse_categorical_crossentropy',
                            optimizer=optimizer, metrics=metrics)

  print('\t> Training - {}start{} (logs = off)'.format(log_begin_red, log_end_format))

  history = resnet50_finetune.fit(train_cifar_dataset, epochs=num_epochs, verbose=0,
                                  validation_data=val_cifar_dataset,
                                  steps_per_epoch=train_steps_per_epoch,
                                  validation_steps=val_steps_per_epoch)

  print('\t> Training - {}over{}'.format(log_begin_green, log_end_format))

  # Report the metrics of the last epoch, as percentages.
  acc = history.history['acc'][-1] * 100
  top3_acc = history.history['top3_acc'][-1] * 100
  val_acc = history.history['val_acc'][-1] * 100
  val_top3_acc = history.history['val_top3_acc'][-1] * 100

  print('\t> Results after {5}{0}{6} epochs: \t{5}acc = {1:.2f}%; top3 = {2:.2f}%, val_acc '\
        '= {3:.2f}%; val_top3 = {4:.2f}%{6}'.format(num_epochs, acc, top3_acc, val_acc,
                                            val_top3_acc, log_begin_bold,
                                            log_end_format))

  # Keep the whole History object so the per-epoch curves can be plotted later.
  histories['freeze_{}_macro_block(s)'.format(freeze_num)] = history

[92m[1m>> [4mResNet-50 with 0 macro-block(s) frozen:[0m 
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
	> [91mLayers we froze:[0m [] ([1mtotal = 0[0m)
	> Training - [91mstart[0m (logs = off)




	> Training - [92mover[0m
	> Results after [1m5[0m epochs: 	[1macc = 54.38%; top3 = 75.57%, val_acc = 57.60%; val_top3 = 79.37%[0m
[92m[1m>> [4mResNet-50 with 1 macro-block(s) frozen:[0m 
	> [91mLayers we froze:[0m ['conv1_conv', 'conv2_block1_1_conv', 'conv2_block1_2_conv', 'conv2_block1_0_conv', 'conv2_block1_3_conv', 'conv2_block2_1_conv', 'conv2_block2_2_conv', 'conv2_block2_3_conv', 'conv2_block3_1_conv', 'conv2_block3_2_conv', 'conv2_block3_3_conv'] ([1mtotal = 11[0m)
	> Training - [91mstart[0m (logs = off)




	> Training - [92mover[0m
	> Results after [1m5[0m epochs: 	[1macc = 52.85%; top3 = 74.81%, val_acc = 57.12%; val_top3 = 78.74%[0m
[92m[1m>> [4mResNet-50 with 2 macro-block(s) frozen:[0m 
	> [91mLayers we froze:[0m ['conv1_conv', 'conv2_block1_1_conv', 'conv2_block1_2_conv', 'conv2_block1_0_conv', 'conv2_block1_3_conv', 'conv2_block2_1_conv', 'conv2_block2_2_conv', 'conv2_block2_3_conv', 'conv2_block3_1_conv', 'conv2_block3_2_conv', 'conv2_block3_3_conv', 'conv3_block1_1_conv', 'conv3_block1_2_conv', 'conv3_block1_0_conv', 'conv3_block1_3_conv', 'conv3_block2_1_conv', 'conv3_block2_2_conv', 'conv3_block2_3_conv', 'conv3_block3_1_conv', 'conv3_block3_2_conv', 'conv3_block3_3_conv', 'conv3_block4_1_conv', 'conv3_block4_2_conv', 'conv3_block4_3_conv'] ([1mtotal = 24[0m)
	> Training - [91mstart[0m (logs = off)


In [None]:
# 3x2 grid of curves: one row per metric, training values in the left column
# and validation values in the right column.
fig, axs = plt.subplots(3, 2, figsize=(15, 10), sharex='col')

# (row, column, panel title, key in `History.history`), listed in drawing order.
panels = [
    (0, 0, 'loss', 'loss'),
    (0, 1, 'val-loss', 'val_loss'),
    (1, 0, 'acc', 'acc'),
    (1, 1, 'val-acc', 'val_acc'),
    (2, 0, 'top3-acc', 'top3_acc'),
    (2, 1, 'val-top3-acc', 'val_top3_acc'),
]
for row, col, panel_title, metric_key in panels:
  axs[row, col].set_title(panel_title)

lines, labels = [], []
for config_name, history in histories.items():
  for row, col, panel_title, metric_key in panels:
    drawn = axs[row, col].plot(history.history[metric_key])
  # The handle drawn last (bottom-right panel) stands for this configuration
  # in the figure-level legend.
  line, = drawn
  labels.append(config_name)
  lines.append(line)

fig.legend(lines, labels, loc='center right')