In [3]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50

In [11]:
def _conv_bn_relu(tensor, filters, kernel_size, dilation_rate=1):
    """Apply Conv2D -> BatchNormalization -> ReLU, the basic DeepLab conv unit.

    The convolution has no bias and no activation of its own: the following
    BatchNormalization supplies the learned shift, and ReLU is applied exactly
    once, after normalization.
    """
    y = layers.Conv2D(filters, kernel_size, padding='same',
                      dilation_rate=dilation_rate, use_bias=False)(tensor)
    y = layers.BatchNormalization()(y)
    return layers.Activation('relu')(y)


def _aspp(features):
    """Atrous Spatial Pyramid Pooling over `features`, projected to 256 channels.

    `features` must have a static (known) height and width, because the
    image-pooling branch is upsampled back to exactly that size.
    """
    feat_h, feat_w = features.shape[1], features.shape[2]

    # Parallel branches: one 1x1 conv and three dilated 3x3 convs.
    branches = [_conv_bn_relu(features, 256, 1)]
    branches += [_conv_bn_relu(features, 256, 3, dilation_rate=rate)
                 for rate in (6, 12, 18)]

    # Image-pooling branch: global context squeezed to 1x1, then upsampled to
    # the feature-map size so it can be concatenated with the other branches.
    pooled = layers.GlobalAveragePooling2D()(features)
    pooled = layers.Reshape((1, 1, -1))(pooled)
    pooled = _conv_bn_relu(pooled, 256, 1)
    pooled = layers.UpSampling2D(size=(feat_h, feat_w), interpolation='bilinear')(pooled)
    branches.append(pooled)

    merged = layers.Concatenate()(branches)
    return _conv_bn_relu(merged, 256, 1)


def DeeplabV3Plus(input_shape=(512, 512, 3), num_classes=21, weights='imagenet'):
    """Build a DeepLabV3+ semantic-segmentation model on a ResNet50 backbone.

    Args:
        input_shape: (height, width, channels) of the input images. Height and
            width must be fixed and divisible by 16, so that the stride-16 ASPP
            features, once upsampled 4x, line up with the stride-4 low-level
            features in the decoder.
        num_classes: number of output classes (channels of the softmax map).
        weights: forwarded to ``ResNet50``; ``'imagenet'`` (the default) loads
            pretrained backbone weights, ``None`` gives random initialization.

    Returns:
        A Keras ``Model`` mapping images of ``input_shape`` to per-pixel class
        probabilities of shape (height, width, num_classes).

    Raises:
        ValueError: if the input height or width is unknown or not divisible
            by 16.
    """
    height, width = input_shape[0], input_shape[1]
    if height is None or width is None or height % 16 or width % 16:
        raise ValueError(
            "DeeplabV3Plus needs a fixed input height and width divisible by 16, "
            f"got input_shape={input_shape!r}")

    inputs = layers.Input(shape=input_shape)
    backbone = ResNet50(weights=weights, include_top=False, input_tensor=inputs)

    # High-level features (output stride 16) feed the ASPP head; low-level
    # features (output stride 4) feed the decoder's skip connection.
    high_level = backbone.get_layer("conv4_block6_2_relu").output
    low_level = backbone.get_layer("conv2_block3_2_relu").output

    # Encoder: ASPP, then 4x upsampling from stride 16 to stride 4.
    x = _aspp(high_level)
    x = layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(x)

    # Decoder: reduce the skip connection to 48 channels, fuse, then refine.
    skip = _conv_bn_relu(low_level, 48, 1)
    x = layers.Concatenate()([x, skip])
    x = _conv_bn_relu(x, 256, 3)
    x = _conv_bn_relu(x, 256, 3)

    # Classifier: 4x upsampling from stride 4 to input resolution, then a
    # per-pixel softmax over the classes.
    x = layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(x)
    outputs = layers.Conv2D(num_classes, 1, padding='same', activation='softmax')(x)

    return models.Model(inputs, outputs)


In [13]:
# Build the segmentation network for 512x512x3 inputs with 21 output classes.
model = DeeplabV3Plus(num_classes=21, input_shape=(512, 512, 3))

In [14]:
# Write the model (architecture and current weights) to disk; the `.h5`
# extension selects the HDF5 format.
model.save(filepath='deeplabv3plus.h5')




In [15]:
from tensorflow import keras

# Round-trip check: reload the saved file without compiling (no compile()
# call appears in the cells above) and print the layer summary.
model = keras.models.load_model(filepath='deeplabv3plus.h5', compile=False)
model.summary()
