In [10]:
import tensorflow as tf

class UNet(tf.keras.Model):
    def __init__(self, num_classes):
        super(UNet, self).__init__()

        self.down1 = DownsampleBlock(64, 2)
        self.down2 = DownsampleBlock(128, 2)
        self.down3 = DownsampleBlock(256, 2)
        self.down4 = DownsampleBlock(512, 2)

        self.bottom = ConvBlock(1024)

        self.up1 = UpsampleBlock(512, 2)
        self.up2 = UpsampleBlock(256, 2)
        self.up3 = UpsampleBlock(128, 2)
        self.up4 = UpsampleBlock(64, 2)

        self.out_conv = tf.keras.layers.Conv2D(num_classes, kernel_size=1, activation='softmax')

    def call(self, inputs, training=False):
        down1 = self.down1(inputs, training=training)
        down2 = self.down2(down1, training=training)
        down3 = self.down3(down2, training=training)
        down4 = self.down4(down3, training=training)

        bottom = self.bottom(down4, training=training)

        up1 = self.up1(bottom, down4, training=training)
        up2 = self.up2(up1, down3, training=training)
        up3 = self.up3(up2, down2, training=training)
        up4 = self.up4(up3, down1, training=training)

        output = self.out_conv(up4)
        return output

class DownsampleBlock(tf.keras.layers.Layer):
    def __init__(self, filters, num_layers):
        super(DownsampleBlock, self).__init__()
        self.conv_layers = []
        self.maxpool = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))
        self.batchnorm = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.ReLU()
        self.dropout = tf.keras.layers.Dropout(0.2)

        for _ in range(num_layers):
            self.conv_layers.append(ConvBlock(filters))

    def call(self, inputs, training=False):
        x = inputs

        for conv in self.conv_layers:
            x = conv(x, training=training)
            x = self.batchnorm(x, training=training)
            x = self.relu(x)
            x = self.dropout(x, training=training)

        x = self.maxpool(x)
        return x

class UpsampleBlock(tf.keras.layers.Layer):
    def __init__(self, filters, num_layers):
        super(UpsampleBlock, self).__init__()
        self.filters = filters
        self.num_layers = num_layers
        self.conv_layers = []
        self.upsample = tf.keras.layers.Conv2DTranspose(filters, kernel_size=2, strides=2, padding='same')
        self.resize = tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')
        self.batchnorm = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.ReLU()
        self.dropout = tf.keras.layers.Dropout(0.2)

        for _ in range(self.num_layers):
            self.conv_layers.append(ConvBlock(self.filters))

    def call(self, inputs, skip, training=False):
        x = self.upsample(inputs)
        skip = self.resize(skip)
        x = tf.concat([x, skip], axis=-1)

        for conv in self.conv_layers:
            x = conv(x, training=training)
            x = self.batchnorm(x, training=training)
            x = self.relu(x)
            x = self.dropout(x, training=training)

        return x


class ConvBlock(tf.keras.layers.Layer):
    def __init__(self, filters):
        super(ConvBlock, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(filters, kernel_size=3, padding='same')
        self.conv2 = tf.keras.layers.Conv2D(filters, kernel_size=3, padding='same')
        self.batchnorm = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.ReLU()
        self.dropout = tf.keras.layers.Dropout(0.2)

    def call(self, inputs, training=False):
        x = self.conv1(inputs)
        x = self.batchnorm(x, training=training)
        x = self.relu(x)
        x = self.dropout(x, training=training)

        x = self.conv2(x)
        x = self.batchnorm(x, training=training)
        x = self.relu(x)
        x = self.dropout(x, training=training)

        return x


In [11]:
num_classes = 7
model = UNet(num_classes)

# Build the model with an input shape of 512x512x3
input_shape = (None, 512, 512, 3)
model.build(input_shape)

# Print the model summary
model.summary()

# Generate a sample input tensor
sample_input = tf.random.normal((1, 512, 512, 3))

# Get the model's output
output = model(sample_input)

Model: "u_net_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 downsample_block_16 (Downsa  multiple                 113344    
 mpleBlock)                                                      
                                                                 
 downsample_block_17 (Downsa  multiple                 518144    
 mpleBlock)                                                      
                                                                 
 downsample_block_18 (Downsa  multiple                 2068480   
 mpleBlock)                                                      
                                                                 
 downsample_block_19 (Downsa  multiple                 8265728   
 mpleBlock)                                                      
                                                                 
 conv_block_67 (ConvBlock)   multiple                  1416