# CIFAR-10: Computer Vision Transfer Learning

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
import tensorflow.keras.ops as ops
import pandas as pd

import random

import pickle
from sklearn.metrics import confusion_matrix
import itertools

In [None]:
# CIFAR-10 ships 50k train + 10k test images; re-split them into 48k / 6k / 6k (80/10/10 %).
(data_train, data_test, data_val), infos = tfds.load(
    'cifar10',
    split=['train[:48000]', 'train[48000:]+test[:4000]', 'test[4000:]'],
    shuffle_files=True,
    as_supervised=True,  # yield (image, label) tuples
    with_info=True,
)

In [None]:
# Preprocessing: make every image 32x32x3 and rescale pixel values to [0, 1].
resize_rescale = tf.keras.Sequential([
    tf.keras.layers.Resizing(32, 32),
    tf.keras.layers.Rescaling(1./255)
])


def preprocess(data, augment=False, *, batch_size=64, shuffle=True, shuffle_buffer=1000):
  """Turn an `(image, label)` tf.data.Dataset into a batched, prefetched input pipeline.

  Args:
    data: dataset of `(image, label)` pairs (as produced by `tfds.load(..., as_supervised=True)`).
    augment: if True, randomly flip each image left/right after resizing/rescaling.
      Meant for the training split only.
    batch_size: number of examples per batch.
    shuffle: whether to shuffle examples before batching.
    shuffle_buffer: size of the shuffle buffer used when `shuffle` is True.

  Returns:
    The batched dataset, prefetched with `tf.data.AUTOTUNE`.
  """
  if shuffle:
    data = data.shuffle(buffer_size=shuffle_buffer)

  data = data.map(lambda x, y: (resize_rescale(x), y),
                  num_parallel_calls=tf.data.AUTOTUNE)

  if augment:
    # Mirroring does not change the class of any CIFAR-10 object, so it is a safe augmentation.
    data = data.map(lambda x, y: (tf.image.random_flip_left_right(x), y),
                    num_parallel_calls=tf.data.AUTOTUNE)

  data = data.batch(batch_size)

  return data.prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
# Build the three input pipelines with identical preprocessing.
batch_train, batch_test, batch_val = map(preprocess, (data_train, data_test, data_val))
batch_train, batch_test, batch_val

In [None]:
# Mixed precision: compute in float16 where possible while keeping variables in float32,
# which reduces GPU memory use and speeds up training on supporting GPUs.
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy('mixed_float16')

In [None]:
mixed_precision.global_policy()  # show the active global dtype policy

In [None]:
# Callbacks used by the training runs below.
checkpoint_path = "best_model_tl_cifar10.weights.h5"

# Stop training once val_loss has not improved for 5 consecutive epochs.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=5, verbose=1)

# Multiply the learning rate by 0.2 whenever val_loss stalls for 2 epochs (floor: 1e-7).
LrReducer = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.2, patience=2, verbose=1, min_lr=1e-7)

# Save only the weights with the highest val_accuracy seen so far.
# NOTE(review): this instance is reused for several models; ModelCheckpoint does not
# appear to reset its best score between fit() calls — confirm before relying on the file.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='val_accuracy',
    verbose=0,
    save_best_only=True,
    save_weights_only=True,
    mode='max',
)

In [None]:
# Silence TensorFlow's Python logger below ERROR level (warnings, info).
tf_logger = tf.get_logger()
tf_logger.setLevel('ERROR')

## Feature extraction (frozen ResNet152 backbone)

In [None]:
from tensorflow.keras import layers

input_shape = (32, 32, 3)
base_model = tf.keras.applications.ResNet152(include_top=False)
base_model.trainable = False  # freeze every backbone layer: pure feature extraction

inputs = layers.Input(shape=input_shape, name="input_layer")
# The pipeline yields RGB pixels in [0, 1], but the ImageNet ResNet weights expect
# `resnet.preprocess_input` on 0-255 pixels (RGB->BGR plus per-channel mean subtraction).
x = tf.keras.applications.resnet.preprocess_input(inputs * 255.0)
# training=False keeps the backbone's BatchNorm layers in inference mode.
x = base_model(x, training=False)
x = layers.GlobalAveragePooling2D(name="pooling_layer")(x)
x = layers.Dense(10)(x)
# Keep the final softmax in float32 for numerical stability under the mixed_float16 policy.
outputs = layers.Activation("softmax", dtype=tf.float32, name="softmax_float32")(x)
model_0 = tf.keras.Model(inputs, outputs)

# Integer class labels + softmax probabilities -> sparse categorical cross-entropy.
model_0.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=["accuracy"])

In [None]:
model_0.summary()  # layer-by-layer architecture and trainable / non-trainable parameter counts

In [None]:
# First run: train the new classification head on top of the frozen backbone for 5 epochs.
history_feature_extraction = model_0.fit(
    batch_train,
    epochs=5,
    validation_data=batch_val,
    callbacks=[early_stop, LrReducer, checkpoint],
)

In [None]:
res_0 = model_0.evaluate(batch_test)  # [test loss, test accuracy]

In [None]:
# History -> DataFrame. Columns are suffixed by name rather than relabelled by position,
# so the labels stay correct whatever order Keras stores the metrics in.
# The learning-rate column (added by ReduceLROnPlateau; its name varies across Keras versions) is dropped.
df_fe = (pd.DataFrame(history_feature_extraction.history)
         .drop(columns=['learning_rate', 'lr'], errors='ignore')
         .add_suffix('_fe'))
df_fe_acc = df_fe.loc[:, ['accuracy_fe', 'val_accuracy_fe']]
df_fe_loss = df_fe.loc[:, ['loss_fe', 'val_loss_fe']]

In [None]:
# Training vs. validation curves of the first feature-extraction run.
df_fe_loss.plot()  # loss

df_fe_acc.plot()  # accuracy

In [None]:
# Second run: resume right after the last completed epoch and continue up to epoch 10
# (i.e. 5 more epochs when the first run completed all of its 5).
resume_epoch = history_feature_extraction.epoch[-1] + 1
history_fe_2 = model_0.fit(
    batch_train,
    epochs=10,
    validation_data=batch_val,
    callbacks=[checkpoint, LrReducer, early_stop],
    initial_epoch=resume_epoch,
)

In [None]:
res_0 = model_0.evaluate(batch_test)  # re-evaluate after the extra epochs: [test loss, test accuracy]

In [None]:
# Stitch both feature-extraction runs into one frame indexed by position (0..n-1).
# Columns are renamed by metric name, not by position, so the labels never get shuffled.
df_fe = (pd.concat([pd.DataFrame(h.history) for h in (history_feature_extraction, history_fe_2)],
                   ignore_index=True)
         .drop(columns=['learning_rate', 'lr'], errors='ignore')
         .rename(columns={'accuracy': 'acc_fe', 'loss': 'loss_fe',
                          'val_accuracy': 'val_acc_fe', 'val_loss': 'val_loss_fe'}))
df_fe_acc = df_fe.loc[:, ['acc_fe', 'val_acc_fe']]
df_fe_loss = df_fe.loc[:, ['loss_fe', 'val_loss_fe']]

In [None]:
# Position of the last epoch of the first run; the dashed line marks where training resumed.
# Derived from the history instead of hard-coding 4, so it stays right if that run stopped early.
split_epoch = len(history_feature_extraction.epoch) - 1

ax = df_fe_loss.plot()  # Plot loss
# Span the dashed line over BOTH curves (train and validation).
ax.vlines(split_epoch, df_fe_loss.min().min() - 0.001, df_fe_loss.max().max() + 0.001,
          linestyles='dashed')

ax = df_fe_acc.plot()  # Plot acc
ax.vlines(split_epoch, df_fe_acc.min().min() - 0.001, df_fe_acc.max().max() + 0.001,
          linestyles='dashed')

## Fine-tuning

### Unfreezing the last 10 layers

In [None]:
input_shape = (32, 32, 3)
base_model = tf.keras.applications.ResNet152(include_top=False)

# Unfreeze only the last 10 backbone layers. Make the backbone trainable first and then
# re-freeze the rest: if the parent itself stays `trainable = False`, some Keras versions
# hide all of its weights from the optimizer, even for children re-enabled afterwards.
base_model.trainable = True
for layer in base_model.layers[:-10]:
  layer.trainable = False

inputs = layers.Input(shape=input_shape, name="input_layer")
# The pipeline yields RGB pixels in [0, 1]; the ImageNet ResNet weights expect
# `resnet.preprocess_input` on 0-255 pixels (RGB->BGR plus per-channel mean subtraction).
x = tf.keras.applications.resnet.preprocess_input(inputs * 255.0)
# training=False keeps BatchNorm layers in inference mode, including the unfrozen ones.
x = base_model(x, training=False)
x = layers.GlobalAveragePooling2D(name="pooling_layer")(x)
x = layers.Dense(10)(x)
# float32 softmax for numerical stability under the mixed_float16 policy.
outputs = layers.Activation("softmax", dtype=tf.float32, name="softmax_float32")(x)
model_1 = tf.keras.Model(inputs, outputs)

# Compile the model.
# NOTE(review): fine-tuning pretrained layers usually uses a smaller learning rate than Adam's default.
model_1.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=["accuracy"])

# Give this model its own checkpoint. ModelCheckpoint remembers its best score across
# fit() calls, so with the shared instance model_1 would only be saved if it beat an
# earlier model, and every model would overwrite the same file.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "best_model_ft10_cifar10.weights.h5", monitor='val_accuracy', verbose=0,
    save_best_only=True, save_weights_only=True, mode='max')

In [None]:
# Train the head plus the last 10 backbone layers for up to 10 epochs.
history_fine_tune10 = model_1.fit(
    batch_train,
    epochs=10,
    validation_data=batch_val,
    callbacks=[checkpoint, LrReducer, early_stop],
)

In [None]:
res_1 = model_1.evaluate(batch_test)  # [test loss, test accuracy]

In [None]:
# History -> DataFrame with name-based suffixing (robust to the metric order in History).
df_ft10 = (pd.DataFrame(history_fine_tune10.history)
           .drop(columns=['learning_rate', 'lr'], errors='ignore')
           .add_suffix('_ft10'))
df_ft10_acc = df_ft10.loc[:, ['accuracy_ft10', 'val_accuracy_ft10']]
df_ft10_loss = df_ft10.loc[:, ['loss_ft10', 'val_loss_ft10']]

In [None]:
# Training vs. validation curves for the 10-layer fine-tune.
df_ft10_loss.plot()  # loss
df_ft10_acc.plot()  # accuracy

### Unfreezing the last 25 layers

In [None]:
input_shape = (32, 32, 3)
base_model = tf.keras.applications.ResNet152(include_top=False)

# Unfreeze only the last 25 backbone layers. Make the backbone trainable first and then
# re-freeze the rest: if the parent itself stays `trainable = False`, some Keras versions
# hide all of its weights from the optimizer, even for children re-enabled afterwards.
base_model.trainable = True
for layer in base_model.layers[:-25]:
  layer.trainable = False

inputs = layers.Input(shape=input_shape, name="input_layer")
# The pipeline yields RGB pixels in [0, 1]; the ImageNet ResNet weights expect
# `resnet.preprocess_input` on 0-255 pixels (RGB->BGR plus per-channel mean subtraction).
x = tf.keras.applications.resnet.preprocess_input(inputs * 255.0)
# training=False keeps BatchNorm layers in inference mode, including the unfrozen ones.
x = base_model(x, training=False)
x = layers.GlobalAveragePooling2D(name="pooling_layer")(x)
x = layers.Dense(10)(x)
# float32 softmax for numerical stability under the mixed_float16 policy.
outputs = layers.Activation("softmax", dtype=tf.float32, name="softmax_float32")(x)
model_2 = tf.keras.Model(inputs, outputs)

# Compile the model.
# NOTE(review): fine-tuning pretrained layers usually uses a smaller learning rate than Adam's default.
model_2.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=["accuracy"])

# Give this model its own checkpoint. ModelCheckpoint remembers its best score across
# fit() calls, so with the shared instance model_2 would only be saved if it beat an
# earlier model, and every model would overwrite the same file.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "best_model_ft25_cifar10.weights.h5", monitor='val_accuracy', verbose=0,
    save_best_only=True, save_weights_only=True, mode='max')

In [None]:
# Train the head plus the last 25 backbone layers for up to 15 epochs.
history_fine_tune25 = model_2.fit(
    batch_train,
    epochs=15,
    validation_data=batch_val,
    callbacks=[checkpoint, LrReducer, early_stop],
)

In [None]:
res_2 = model_2.evaluate(batch_test)  # [test loss, test accuracy]

In [None]:
# History -> DataFrame with name-based suffixing (robust to the metric order in History).
df_ft25 = (pd.DataFrame(history_fine_tune25.history)
           .drop(columns=['learning_rate', 'lr'], errors='ignore')
           .add_suffix('_ft25'))
df_ft25_acc = df_ft25.loc[:, ['accuracy_ft25', 'val_accuracy_ft25']]
df_ft25_loss = df_ft25.loc[:, ['loss_ft25', 'val_loss_ft25']]

In [None]:
# Put the runs side by side, aligned on the epoch index (axis=1). A shorter run is NaN-padded
# instead of being stacked underneath with duplicated index values.
pd.concat([df_ft10_loss, df_ft25_loss], axis=1).plot()  # Plot loss
pd.concat([df_ft10_acc, df_ft25_acc], axis=1).plot()  # Plot acc

### Unfreezing the last 50 layers

In [None]:
input_shape = (32, 32, 3)
base_model = tf.keras.applications.ResNet152(include_top=False)

# Unfreeze only the last 50 backbone layers. Make the backbone trainable first and then
# re-freeze the rest: if the parent itself stays `trainable = False`, some Keras versions
# hide all of its weights from the optimizer, even for children re-enabled afterwards.
base_model.trainable = True
for layer in base_model.layers[:-50]:
  layer.trainable = False

inputs = layers.Input(shape=input_shape, name="input_layer")
# The pipeline yields RGB pixels in [0, 1]; the ImageNet ResNet weights expect
# `resnet.preprocess_input` on 0-255 pixels (RGB->BGR plus per-channel mean subtraction).
x = tf.keras.applications.resnet.preprocess_input(inputs * 255.0)
# training=False keeps BatchNorm layers in inference mode, including the unfrozen ones.
x = base_model(x, training=False)
x = layers.GlobalAveragePooling2D(name="pooling_layer")(x)
x = layers.Dense(10)(x)
# float32 softmax for numerical stability under the mixed_float16 policy.
outputs = layers.Activation("softmax", dtype=tf.float32, name="softmax_float32")(x)
model_3 = tf.keras.Model(inputs, outputs)

# Compile the model.
# NOTE(review): fine-tuning pretrained layers usually uses a smaller learning rate than Adam's default.
model_3.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                optimizer=tf.keras.optimizers.Adam(),
                metrics=["accuracy"])

# Give this model its own checkpoint. ModelCheckpoint remembers its best score across
# fit() calls, so with the shared instance model_3 would only be saved if it beat an
# earlier model, and every model would overwrite the same file.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "best_model_ft50_cifar10.weights.h5", monitor='val_accuracy', verbose=0,
    save_best_only=True, save_weights_only=True, mode='max')

In [None]:
# Train the head plus the last 50 backbone layers for up to 15 epochs.
history_fine_tune50 = model_3.fit(
    batch_train,
    epochs=15,
    validation_data=batch_val,
    callbacks=[checkpoint, LrReducer, early_stop],
)

In [None]:
res_3 = model_3.evaluate(batch_test)  # [test loss, test accuracy]

In [None]:
# History -> DataFrame with name-based suffixing (robust to the metric order in History).
df_ft50 = (pd.DataFrame(history_fine_tune50.history)
           .drop(columns=['learning_rate', 'lr'], errors='ignore')
           .add_suffix('_ft50'))
df_ft50_acc = df_ft50.loc[:, ['accuracy_ft50', 'val_accuracy_ft50']]
df_ft50_loss = df_ft50.loc[:, ['loss_ft50', 'val_loss_ft50']]

In [None]:
# All fine-tuning runs side by side, aligned on the epoch index (axis=1); shorter runs are
# NaN-padded instead of being stacked underneath with duplicated index values.
pd.concat([df_ft10_loss, df_ft25_loss, df_ft50_loss], axis=1).plot(ylim=(0.01, 5))  # Plot loss
pd.concat([df_ft10_acc, df_ft25_acc, df_ft50_acc], axis=1).plot()  # Plot acc

## Conclusion: test-set comparison

In [None]:
# One row per model, holding the [loss, accuracy] pair returned by evaluate() on the test split.
res = pd.DataFrame(
    [res_0, res_1, res_2, res_3],
    index=['model_fe', 'model_ft10', 'model_ft25', 'model_ft50'],
    columns=['loss', 'accuracy'],
)

In [None]:
# Bar chart of test loss and accuracy, highest-accuracy model first.
ranked = res.sort_values(by='accuracy', ascending=False)
ranked.plot(kind='bar')

### Confusion matrix

In [None]:
# https://stackoverflow.com/questions/64622210/how-to-extract-classes-from-prefetched-dataset-in-tensorflow-for-confusion-matri

# Collect true and predicted labels batch by batch. Both come from the same pass over
# batch_test, so they stay aligned even though the pipeline reshuffles on every iteration.
y_true = []  # true labels, one tensor per batch
y_pred = []  # predicted class ids, one tensor per batch

for image_batch, label_batch in batch_test:
  y_true.append(label_batch)
  # Calling the model directly avoids the per-call overhead of model.predict() on a single batch.
  probs = model_2(image_batch, training=False)
  y_pred.append(tf.argmax(probs, axis=1))

# Flatten the per-batch tensors into one label vector each.
correct_labels = tf.concat(y_true, axis=0)
predicted_labels = tf.concat(y_pred, axis=0)

In [None]:
# https://github.com/mrdbourke/tensorflow-deep-learning/blob/main/06_transfer_learning_in_tensorflow_part_3_scaling_up.ipynb

class_names = infos.features['label'].names
n_classes = len(class_names)

# Rows = true class, columns = predicted class.
cm = confusion_matrix(correct_labels, predicted_labels)

# Darker cells hold more examples: a dark diagonal is good, while dark off-diagonal
# cells are frequent confusions.
fig, ax = plt.subplots(figsize=(8, 8))
cax = ax.matshow(cm, cmap=plt.cm.Blues)
fig.colorbar(cax)

ax.set(title="Confusion Matrix",
       xlabel="Predicted label",
       ylabel="True label",
       xticks=np.arange(n_classes),  # one tick per class
       yticks=np.arange(n_classes),
       xticklabels=class_names,
       yticklabels=class_names)

# matshow draws the x ticks on top; move the ticks and the axis label to the bottom.
ax.xaxis.set_label_position("bottom")
ax.xaxis.tick_bottom()

# Rotated, larger tick labels keep the class names readable.
plt.xticks(rotation=70, fontsize=12)
plt.yticks(fontsize=12)

# Cells above the midpoint of the count range get white text so it stays legible.
threshold = (cm.max() + cm.min()) / 2.

# Write the count in every cell.
for row, col in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
  count = cm[row, col]
  plt.text(col, row, f"{count}",
           ha="center",
           va='center',
           color="white" if count > threshold else "black",
           size=12)