In [6]:
'''
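Sample implementation of multi-task learning (MTL) in TensorFlow/Keras: a shared
trunk of Dense layers feeding a regression head and a classification head.

Reference: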
https://www.geeksforgeeks.org/deep-learning/multi-task-learning-scenario-in-tensorflow/
'''

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Concatenate
from tensorflow.keras.models import Model

def build_multi_task_model(input_shape, num_classes, *, hidden_units=(128, 64)):
    """Build an uncompiled two-headed multi-task Keras model.

    A shared stack of ReLU ``Dense`` layers (the trunk) feeds two task heads:

    * ``regression_output``: one linear unit for a single continuous target.
    * ``classification_output``: ``num_classes`` softmax probabilities.

    Args:
        input_shape: Shape of one sample without the batch axis,
            e.g. ``(10,)`` for 10 features.
        num_classes: Number of classes predicted by the classification head.
        hidden_units: Widths of the shared hidden ``Dense`` layers, applied in
            order. Keyword-only; the default ``(128, 64)`` reproduces the
            original architecture.

    Returns:
        A ``Model`` with outputs ``[regression_output, classification_output]``.
        The output layer names are the keys to use for per-task ``loss`` /
        ``metrics`` in ``compile`` and for per-task targets in ``fit``.
    """
    inputs = Input(shape=input_shape)

    # Shared trunk: both heads are built on top of it, so (once compiled with a
    # loss per head) its weights are trained by both tasks at the same time.
    x = inputs
    for units in hidden_units:
        x = Dense(units, activation='relu')(x)

    # Task 1: regression head (no activation -> unbounded linear output).
    reg_output = Dense(1, name='regression_output')(x)

    # Task 2: classification head producing one probability per class.
    class_output = Dense(num_classes, activation='softmax', name='classification_output')(x)

    return Model(inputs=inputs, outputs=[reg_output, class_output])

# Model configuration
input_shape = (10,)  # Example input size (e.g., 10 features)
num_classes = 3      # Example number of classes for classification
num_samples = 1000   # Number of synthetic training examples
seed = 42            # Fixes weight initialisation and synthetic data

# Seed the Python, NumPy and TensorFlow RNGs *before* any layer is created so
# the initial weights (and hence the training curve) are reproducible.
# NOTE: some GPU kernels can still introduce run-to-run nondeterminism.
tf.keras.utils.set_random_seed(seed)

# Build the model
model = build_multi_task_model(input_shape, num_classes)

# Compile with a separate loss and metric per task. The dict keys must match
# the output layer names set in build_multi_task_model; Keras sums the
# per-task losses into the total "loss" it optimises.
model.compile(
    optimizer='adam',
    loss={'regression_output': 'mse', 'classification_output': 'sparse_categorical_crossentropy'},
    metrics={'regression_output': ['mae'], 'classification_output': ['accuracy']},
)

# Summary of the model
model.summary()

# Hypothetical (synthetic) dataset. Shapes are derived from the configuration
# above, so changing input_shape or num_samples cannot desync data and model.
rng = np.random.default_rng(seed)
train_data = rng.random((num_samples, *input_shape))
train_labels_regression = rng.random((num_samples, 1))  # Regression targets in [0, 1)
# Integer class ids in [0, num_classes): the label format that
# sparse_categorical_crossentropy expects (not one-hot).
train_labels_classification = rng.integers(0, num_classes, size=num_samples)

# Train the model; keep the History so per-task loss/metric curves can be
# inspected after training (history.history[...]).
history = model.fit(
    train_data,
    {'regression_output': train_labels_regression, 'classification_output': train_labels_classification},
    epochs=10,
)


Epoch 1/10
32/32 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - classification_output_accuracy: 0.3349 - classification_output_loss: 1.1021 - loss: 1.2613 - regression_output_loss: 0.1592 - regression_output_mae: 0.3332
Epoch 2/10
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - classification_output_accuracy: 0.3464 - classification_output_loss: 1.0998 - loss: 1.1902 - regression_output_loss: 0.0905 - regression_output_mae: 0.2555
Epoch 3/10
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - classification_output_accuracy: 0.3771 - classification_output_loss: 1.0954 - loss: 1.1812 - regression_output_loss: 0.0858 - regression_output_mae: 0.2507
Epoch 4/10
32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - classification_output_accuracy: 0.3661 - classification_output_loss: 1.0921 - loss: 1.1798 - regression_output_loss: 0.0876 - regression_output_mae: 0.2508
Epoch 5/10
32/32 ━━ [output truncated]

<keras.src.callbacks.history.History at 0x787e3d8dab90>