In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, Input
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import MeanIoU



In [None]:
def multitask_resnet18(input_shape, num_classes, weights='imagenet', freeze_backbone=True):
    """Build a two-headed model: image classification + binary segmentation.

    Despite the function name, the backbone is Keras' ResNet50; the name is
    kept so existing callers continue to work.

    Args:
        input_shape: (height, width, channels) of the input images. The
            segmentation decoder upsamples the backbone feature map by
            2**5 = 32 (five stride-2 transposed convolutions). So the predicted
            mask has the same height/width as the input only when those are
            multiples of 32 (e.g. 224 -> 7 -> 224).
        num_classes: Number of classes for the softmax classification head.
        weights: Backbone weights passed to ResNet50 ('imagenet' by default,
            or None for random initialisation).
        freeze_backbone: If True (default), the backbone weights are not
            updated during training.

    Returns:
        A tf.keras Model with outputs [classification, segmentation]:
        - 'classification': (batch, num_classes) softmax probabilities.
        - 'segmentation': (batch, H', W', 1) sigmoid probabilities, where
          H', W' are 32x the backbone feature-map size.
    """
    # Backbone without its ImageNet classifier head.
    base_model = ResNet50(include_top=False, input_shape=input_shape, weights=weights)
    base_model.trainable = not freeze_backbone

    inputs = Input(shape=input_shape)
    # NOTE(review): ImageNet ResNet50 weights expect inputs preprocessed with
    # tf.keras.applications.resnet50.preprocess_input; confirm the data
    # pipeline does this, since the model itself does not.
    # training=False keeps the backbone's BatchNorm layers in inference mode,
    # which is also what you want if the backbone is later unfrozen.
    features = base_model(inputs, training=False)

    # Classification head: pool the spatial feature map, then a softmax classifier.
    pooled = layers.GlobalAveragePooling2D()(features)
    classification_output = layers.Dense(
        num_classes, activation='softmax', name='classification')(pooled)

    # Segmentation head: upsample the backbone feature map (x2 per layer).
    # ReLU between the transposed convolutions is required; without a
    # non-linearity the stack collapses into a single linear transform.
    x_seg = features
    for filters in (512, 256, 128, 64):
        x_seg = layers.Conv2DTranspose(
            filters, (3, 3), strides=(2, 2), padding='same', activation='relu')(x_seg)
    # Final x2 upsample to a single-channel per-pixel probability map.
    segmentation_output = layers.Conv2DTranspose(
        1, (3, 3), strides=(2, 2), activation='sigmoid', padding='same',
        name='segmentation')(x_seg)

    # Output order matters: callers index outputs as [classification, segmentation].
    model = Model(inputs=inputs, outputs=[classification_output, segmentation_output])

    return model

# Example configuration: a 224x224 RGB input (the standard ResNet size)
# and a 10-way classifier.
input_shape = (224, 224, 3)
num_classes = 10

# Build the two-headed (classification + segmentation) network.
model = multitask_resnet18(input_shape=input_shape, num_classes=num_classes)

# Losses and metrics are matched to outputs by the layer names set in the
# builder ('classification' / 'segmentation').
# sparse_categorical_crossentropy expects integer class labels.
# binary_crossentropy expects per-pixel targets in [0, 1] (presumably binary
# masks; confirm against the data pipeline).
model.compile(
    optimizer='adam',
    loss=dict(classification='sparse_categorical_crossentropy',
              segmentation='binary_crossentropy'),
    metrics=dict(classification='accuracy',
                 segmentation='accuracy'),
)

# Layer-by-layer overview of the assembled model.
model.summary()


Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_8 (InputLayer)        [(None, 224, 224, 3)]        0         []                            
                                                                                                  
 resnet50 (Functional)       (None, 7, 7, 2048)           2358771   ['input_8[0][0]']             
                                                          2                                       
                                                                                                  
 conv2d_transpose_5 (Conv2D  (None, 14, 14, 512)          9437696   ['resnet50[0][0]']            
 Transpose)                                                                                       
                                                                                            