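Multitask learning on the CIFAR-10 dataset with TensorFlow: a single shared convolutional feature extractor feeds two output heads, one for 10-class object recognition and one for an auxiliary binary task.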

In [None]:
import tensorflow as tf
from tensorflow.keras.datasets import cifar10

Load the CIFAR-10 dataset

In [None]:

# load_data() returns ((train_images, train_labels), (test_images, test_labels)).
cifar10_splits = cifar10.load_data()
(x_train, y_train), (x_test, y_test) = cifar10_splits
# Drop the tuple so it does not keep extra references to the raw arrays alive.
del cifar10_splits

Normalize the input images by scaling pixel values from [0, 255] to [0, 1]

In [None]:

# Scale pixel intensities from [0, 255] to [0, 1].
# Cast to float32 first: `uint8_array / 255.0` would upcast to float64, which doubles
# the memory (about 1.2 GB for the 50k-image training set). float32 is also Keras'
# default float type, so no precision is lost for training.
# NOTE: run this cell exactly once after loading. Re-running it divides by 255 again.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

Define the shared feature extractor

In [None]:

# Seed Python, NumPy and TensorFlow RNGs before any layer is built, so weight
# initialisation is repeatable under Restart & Run All.
# (Some GPU kernels may still be nondeterministic.)
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Shared trunk: both task heads are attached to `shared_layer_2`.
input_layer = tf.keras.layers.Input(shape=(32, 32, 3))
# The first conv uses 'same' padding and keeps 32x32. The second uses the default
# 'valid' padding (32x32 -> 30x30), and 2x2 max-pooling reduces that to 15x15.
conv_layer_1 = tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu', padding='same')(input_layer)
conv_layer_2 = tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu')(conv_layer_1)
maxpool_layer = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv_layer_2)
flatten_layer = tf.keras.layers.Flatten()(maxpool_layer)
shared_layer_1 = tf.keras.layers.Dense(units=128, activation='relu')(flatten_layer)
shared_layer_2 = tf.keras.layers.Dense(units=64, activation='relu')(shared_layer_1)

Define the first output head for the object-recognition task (10 CIFAR-10 classes)

In [None]:

# Head 1: 10-way softmax over the object classes, on top of the shared trunk.
# The layer name 'object_output' is the key used for this head's loss and metrics.
object_head = tf.keras.layers.Dense(units=10, activation='softmax', name='object_output')
object_output = object_head(shared_layer_2)

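Define the second output head for the auxiliary binary ("color") task. Note: despite the name, its training labels are the parity of the CIFAR-10 class index (`y % 2`), not a color attribute of the image.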

In [None]:

# Head 2: 2-way softmax for the auxiliary binary task.
# The notebook feeds it integer labels 0/1 (`y % 2`). A 2-unit softmax needs
# a sparse categorical loss for such labels; binary_crossentropy expects a
# single probability per sample instead.
color_head = tf.keras.layers.Dense(units=2, activation='softmax', name='color_output')
color_output = color_head(shared_layer_2)

Define the model with two output heads

In [None]:

# Functional model: one input, two heads. `predict` returns the head outputs in
# this list order: [object probabilities, color probabilities].
# (tf.keras.Model is the same class as tf.keras.models.Model.)
model = tf.keras.Model(
    inputs=input_layer,
    outputs=[object_output, color_output],
)



Compile the model with one loss function and one accuracy metric per output head

In [None]:

# One loss per output head, keyed by the output layer names.
# The color head is a 2-unit softmax trained on integer labels 0/1, so it needs
# sparse_categorical_crossentropy, the same as the object head. binary_crossentropy
# would pair (N, 1) labels with (N, 2) predictions. Depending on the Keras version,
# that either raises a shape error or silently broadcasts the label across both units.
# Metrics are given per output as a dict. 'accuracy' then resolves to sparse
# categorical accuracy for each head, and multi-output models accept this form
# in both tf.keras and Keras 3.
model.compile(
    optimizer='adam',
    loss={
        'object_output': 'sparse_categorical_crossentropy',
        'color_output': 'sparse_categorical_crossentropy',
    },
    metrics={'object_output': 'accuracy', 'color_output': 'accuracy'},
)



Train the model on both tasks jointly

In [None]:

# Train both heads jointly. Targets are keyed by output layer name. The auxiliary
# labels are the parity of the class index (0 = even class id, 1 = odd).
# `Model.fit` has no `metrics` argument (passing one raises TypeError);
# per-head metrics are configured in `model.compile`.
history = model.fit(
    x_train,
    {'object_output': y_train, 'color_output': y_train % 2},
    validation_data=(x_test, {'object_output': y_test, 'color_output': y_test % 2}),
    epochs=10,
)


Print the per-task accuracy

In [None]:
# Per-epoch accuracy curves for each head. History keys are
# '<output layer name>_accuracy'; validation keys carry a 'val_' prefix
# (present because `fit` was given validation_data).
object_accuracy = history.history['object_output_accuracy']
color_accuracy = history.history['color_output_accuracy']
val_object_accuracy = history.history['val_object_output_accuracy']
val_color_accuracy = history.history['val_color_output_accuracy']

# Report the final epoch. The original cell only assigned the lists and displayed nothing.
print(f"Object recognition: train accuracy {object_accuracy[-1]:.4f}, "
      f"validation accuracy {val_object_accuracy[-1]:.4f}")
print(f"Color (label parity): train accuracy {color_accuracy[-1]:.4f}, "
      f"validation accuracy {val_color_accuracy[-1]:.4f}")
