Multitask learning on the CIFAR-100 dataset using TensorFlow for image classification.

In [None]:
import tensorflow as tf
from tensorflow.keras.datasets import cifar100

Load the CIFAR-100 dataset

In [None]:

# label_mode='fine' requests the 100-class label ids (one integer per image).
(x_train, y_train), (x_test, y_test) = cifar100.load_data(
    label_mode='fine',
)

Normalize the input images

In [None]:

# Scale pixel values from [0, 255] to [0, 1].
# - Cast to float32 first: dividing the integer images by a Python float would
#   promote them to float64 and double the memory footprint for no benefit.
# - Only rescale while the arrays still have an integer dtype, so re-running
#   this cell does not divide the already-normalized data by 255 a second time.
if x_train.dtype.kind in 'iu':
    x_train = x_train.astype('float32') / 255.0
if x_test.dtype.kind in 'iu':
    x_test = x_test.astype('float32') / 255.0

Define the shared feature extractor

In [None]:

# Shared trunk: two 3x3 convolutions -> 2x2 max-pool -> flatten -> two dense layers.
# `shared_layer_2` is the representation that the task-specific heads build on.
input_layer = tf.keras.layers.Input(shape=(32, 32, 3))

conv_layer_1 = tf.keras.layers.Conv2D(
    filters=32,
    kernel_size=(3, 3),
    padding='same',
    activation='relu',
)(input_layer)
conv_layer_2 = tf.keras.layers.Conv2D(
    filters=32,
    kernel_size=(3, 3),
    activation='relu',  # default 'valid' padding: no border padding on this layer
)(conv_layer_1)
maxpool_layer = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv_layer_2)

flatten_layer = tf.keras.layers.Flatten()(maxpool_layer)
shared_layer_1 = tf.keras.layers.Dense(units=128, activation='relu')(flatten_layer)
shared_layer_2 = tf.keras.layers.Dense(units=64, activation='relu')(shared_layer_1)



Define the first output head for fine-grained classification task

In [None]:

# Softmax head over the 100 fine-grained classes; the layer name is the key
# used to address this output in loss / label dictionaries.
fine_output = tf.keras.layers.Dense(
    units=100,
    activation='softmax',
    name='fine_output',
)(shared_layer_2)

Define the second output head for coarse-grained classification task

In [None]:

# Softmax head over the 20 coarse classes; the layer name is the key
# used to address this output in loss / label dictionaries.
coarse_output = tf.keras.layers.Dense(
    units=20,
    activation='softmax',
    name='coarse_output',
)(shared_layer_2)

Define the model with two output heads

In [None]:

# One input, two outputs: [fine head, coarse head], both fed by the shared trunk.
model = tf.keras.models.Model(
    inputs=input_layer,
    outputs=[fine_output, coarse_output],
)

Compile the model with two loss functions

In [None]:

# Each head gets its own sparse categorical cross-entropy loss (labels are
# integer class ids, not one-hot vectors); Keras sums the per-head losses.
#
# Metrics are given as a dict keyed by output-layer name rather than a single
# shared list: for a multi-output model, newer Keras versions reject a metrics
# list whose length differs from the number of outputs. The dict form yields
# the 'fine_output_accuracy' / 'coarse_output_accuracy' history entries.
model.compile(
    optimizer='adam',
    loss={
        'fine_output': 'sparse_categorical_crossentropy',
        'coarse_output': 'sparse_categorical_crossentropy',
    },
    metrics={
        'fine_output': 'accuracy',
        'coarse_output': 'accuracy',
    },
)

Train the model and record each head's per-epoch training accuracy

In [None]:
# CIFAR-100 fine label ids are not grouped by superclass, so `fine_id // 5`
# does NOT give the superclass id. Load the dataset's own 20-class coarse
# labels instead (the images from this second load are discarded).
y_train_coarse, y_test_coarse = (
    split[1] for split in cifar100.load_data(label_mode='coarse')
)
# Sanity check: coarse labels must line up one-to-one with the fine labels.
assert len(y_train_coarse) == len(y_train)
assert len(y_test_coarse) == len(y_test)

# Train both heads jointly. Targets are passed as dicts keyed by output-layer
# name. `fit` takes no `metrics` argument -- metrics are configured in
# `compile` -- so none is passed here.
history = model.fit(
    x_train,
    {'fine_output': y_train, 'coarse_output': y_train_coarse},
    validation_data=(
        x_test,
        {'fine_output': y_test, 'coarse_output': y_test_coarse},
    ),
    epochs=10,
)

# Per-epoch training accuracy of each head.
fine_accuracy = history.history['fine_output_accuracy']
coarse_accuracy = history.history['coarse_output_accuracy']
