In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models

# Define your U-Net model architecture
def _double_conv(x, filters):
    """Apply two consecutive 3x3, same-padded, ReLU-activated convolutions."""
    x = layers.Conv2D(filters, 3, activation='relu', padding='same')(x)
    return layers.Conv2D(filters, 3, activation='relu', padding='same')(x)


def _up_block(x, skip, filters):
    """Upsample `x` by 2, concatenate it with `skip` on the channel axis, then double-conv."""
    upsampled = layers.Conv2DTranspose(filters, 2, strides=(2, 2), padding='same')(x)
    merged = layers.concatenate([upsampled, skip], axis=3)
    return _double_conv(merged, filters)


def unet_model(input_size=(256, 256, 3)):
    """Build a U-Net with a single-channel sigmoid output.

    The encoder has four stages (64, 128, 256, 512 filters), each a pair of
    3x3 convolutions followed by 2x2 max pooling, then a 1024-filter
    bottleneck. The decoder mirrors the encoder: each stage upsamples with a
    stride-2 transposed convolution, concatenates the matching encoder
    feature map, and applies a pair of 3x3 convolutions. A final 1x1
    convolution with a sigmoid activation yields one output channel at the
    input's spatial resolution.

    Args:
        input_size: Shape passed to `tf.keras.Input`, given as
            (height, width, channels) without the batch dimension. Height
            and width may be None for dynamically sized inputs.

    Returns:
        An uncompiled Keras `Model` mapping the input tensor to the
        sigmoid output tensor.

    Raises:
        ValueError: If a statically known height or width is not a multiple
            of 16.
    """
    encoder_filters = (64, 128, 256, 512)
    bottleneck_filters = 1024

    # Each pooling stage floor-halves the spatial size and each decoder stage
    # exactly doubles it, so the skip connections only line up when the size
    # is divisible by 2 ** (number of pooling stages). Check this up front so
    # the failure is an explicit message instead of a shape mismatch raised
    # deep inside the decoder's concatenate.
    required_multiple = 2 ** len(encoder_filters)
    if isinstance(input_size, (tuple, list)):
        for axis_name, dim in zip(("height", "width"), input_size[:2]):
            if dim is not None and dim % required_multiple != 0:
                raise ValueError(
                    f"input_size {axis_name} must be a multiple of "
                    f"{required_multiple} so encoder and decoder feature maps "
                    f"align, got {dim}"
                )

    inputs = tf.keras.Input(input_size)

    # Encoder: remember each stage's pre-pooling output for the skip connections.
    skips = []
    x = inputs
    for filters in encoder_filters:
        x = _double_conv(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck
    x = _double_conv(x, bottleneck_filters)

    # Decoder: walk back up, pairing each stage with the deepest unused skip.
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = _up_block(x, skip, filters)

    # Output layer
    outputs = layers.Conv2D(1, 1, activation='sigmoid')(x)

    return models.Model(inputs=inputs, outputs=outputs)

# Build the U-Net using the default input size
model = unet_model()

# Configure training: Adam optimizer, binary cross-entropy loss
# (matching the sigmoid output), and accuracy as the reported metric
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Print the layer-by-layer summary of the compiled model
model.summary()





Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 256, 256, 3)]        0         []                            
                                                                                                  
 conv2d (Conv2D)             (None, 256, 256, 64)         1792      ['input_1[0][0]']             
                                                                                                  
 conv2d_1 (Conv2D)           (None, 256, 256, 64)         36928     ['conv2d[0][0]']              
                                                                                                  
 max_pooling2d (MaxPooling2  (None, 128, 128, 64)         0         ['conv2d_1[0][0]']            
 D)                                                                                       