Below is a brief overview of the ResUNet architecture:

1. **Encoder (Residual Blocks):**
   - The encoder is a stack of residual blocks, each followed by a 2×2 max-pooling layer that halves the spatial resolution.
   - Each residual block contains two 3×3 convolutional layers with batch normalization and ReLU activation.
   - A residual (shortcut) connection adds the block's input to the output of the second convolutional layer. When the channel counts differ, the input is first projected with a 1×1 convolution. The shortcut gives gradients a direct path through the block, which makes deeper networks easier to train. (Preserving spatial detail is the job of the separate encoder-to-decoder skip connections described below.)

2. **Bridge:**
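   - After the last downsampling step (four poolings in this implementation), a bridge connects the encoder to the decoder.
   - It is another residual block with no pooling, so it operates at the lowest spatial resolution (1/16 of the input height and width here) with the largest number of filters.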

3. **Decoder (Upsampling Blocks):**
   - The decoder is responsible for upsampling the feature maps back to the original input resolution.
   - Each upsampling block consists of a transposed convolutional layer (or upsampling followed by convolution) that doubles the spatial resolution, a concatenation with the corresponding encoder feature map via a skip connection, and a residual block.
   - This process allows the decoder to recover spatial details lost during downsampling.

4. **Output Layer:**
   - The final layer is a 1×1 convolution with sigmoid activation, which suits binary segmentation tasks such as the two-class vegetation vs. non-vegetation task here: a single output channel gives the per-pixel probability of the positive class.
   - For multi-class segmentation, the output layer would instead have one channel per class and softmax activation.
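
The stage-by-stage shapes described above can be checked with plain arithmetic before building the model. The sketch below uses a hypothetical helper, `trace_resunet_shapes` (not part of the notebook), and assumes the layout used in the code cells: four 2×2 poolings, filters doubling from 64, and transposed convolutions that exactly double H and W. Stage names are illustrative and do not match the notebook's `conv1` to `conv9`. One consequence it makes visible: the input height and width must be divisible by 16, otherwise the upsampled maps no longer line up with their skip connections.

```python
def trace_resunet_shapes(height=256, width=256, base_filters=64, depth=4, num_classes=1):
    """Return (stage, (H, W, C)) tuples for the ResUNet layout sketched above.

    Assumption: each encoder stage doubles the filters and is followed by a
    2x2 max-pool; each decoder stage undoes one pooling step.
    """
    factor = 2 ** depth
    if height % factor or width % factor:
        raise ValueError(f"input height and width must be divisible by {factor}")

    shapes, skips = [], []
    h, w = height, width
    # Encoder: residual block at the current resolution, then 2x2 pooling.
    for level in range(depth):
        filters = base_filters * 2 ** level
        shapes.append((f"enc{level + 1}", (h, w, filters)))
        skips.append((h, w, filters))
        h, w = h // 2, w // 2
    # Bridge: no pooling, twice the deepest encoder filter count.
    shapes.append(("bridge", (h, w, base_filters * 2 ** depth)))
    # Decoder: upsample x2, concatenate with the matching skip, then a residual block.
    for level in reversed(range(depth)):
        filters = base_filters * 2 ** level
        h, w = h * 2, w * 2
        skip_h, skip_w, skip_c = skips[level]
        assert (h, w) == (skip_h, skip_w), "upsampled map must match its skip"
        shapes.append((f"concat{level + 1}", (h, w, filters + skip_c)))
        shapes.append((f"dec{level + 1}", (h, w, filters)))
    # Output: 1x1 convolution down to one channel per class.
    shapes.append(("output", (h, w, num_classes)))
    return shapes


for name, shape in trace_resunet_shapes():
    print(name, shape)
```

For the default 256×256×3 input, the bridge comes out at 16×16×1024, and each decoder concatenation doubles the channel count before the residual block reduces it again (for example, 512 + 512 = 1024 channels going into `conv6`).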

In [21]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, concatenate, Conv2DTranspose

In [22]:

def resunet(input_shape=(256, 256, 3), num_classes=1):
    inputs = Input(input_shape)
    
    # Encoder
    conv1 = resblock(inputs, 64)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    
    conv2 = resblock(pool1, 128)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    
    conv3 = resblock(pool2, 256)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    
    conv4 = resblock(pool3, 512)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
    
    # Bridge
    conv5 = resblock(pool4, 1024)
    
    # Decoder
    upconv6 = Conv2DTranspose(512, (2, 2), strides=(2, 2), padding='same')(conv5)
    skip_connection4 = conv4
    upconv6 = concatenate([upconv6, skip_connection4], axis=-1)
    conv6 = resblock(upconv6, 512)
    
    upconv7 = Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv6)
    skip_connection3 = conv3
    upconv7 = concatenate([upconv7, skip_connection3], axis=-1)
    conv7 = resblock(upconv7, 256)
    
    upconv8 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv7)
    skip_connection2 = conv2
    upconv8 = concatenate([upconv8, skip_connection2], axis=-1)
    conv8 = resblock(upconv8, 128)
    
    upconv9 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv8)
    skip_connection1 = conv1
    upconv9 = concatenate([upconv9, skip_connection1], axis=-1)
    conv9 = resblock(upconv9, 64)
    
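    # Output layer: sigmoid for a single binary mask, softmax for multi-class
    output_activation = 'sigmoid' if num_classes == 1 else 'softmax'
    outputs = Conv2D(num_classes, (1, 1), activation=output_activation)(conv9)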
    
    # Create model
    model = Model(inputs=inputs, outputs=outputs)
    
    return model

In [23]:
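def resblock(x, filters, kernel_size=(3, 3), activation='relu', padding='same'):
    # Residual block: conv -> BN -> activation -> conv -> BN, then add the shortcut
    conv1 = Conv2D(filters, kernel_size, padding=padding)(x)
    conv1 = tf.keras.layers.BatchNormalization()(conv1)
    conv1 = tf.keras.layers.Activation(activation)(conv1)
    conv2 = Conv2D(filters, kernel_size, padding=padding)(conv1)
    conv2 = tf.keras.layers.BatchNormalization()(conv2)
    # The element-wise add needs matching channel counts, so project the
    # shortcut with a 1x1 convolution when they differ
    shortcut = x
    if x.shape[-1] != filters:
        shortcut = Conv2D(filters, (1, 1), padding='same')(x)
    add = tf.keras.layers.add([shortcut, conv2])
    out = tf.keras.layers.Activation(activation)(add)
    return out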
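
Why `resblock` needs the 1×1 projection: `add` requires both tensors to have the same shape, and a 1×1 convolution is just a per-pixel linear map over channels, so it can change the channel count without touching H or W. This happens in the first block (3 input channels vs. 64 filters) and in every decoder block (concatenated channels vs. the block's filters). The NumPy sketch below illustrates this with a hypothetical `conv1x1` helper and random weights standing in for learned ones; it is not how Keras computes the layer internally.

```python
import numpy as np


def conv1x1(x, kernel, bias):
    """1x1 convolution on an (H, W, C_in) array: the same C_in -> C_out
    linear map applied independently at every pixel."""
    return np.einsum("hwc,cd->hwd", x, kernel) + bias


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 4, 3))         # block input: 3 channels
residual = rng.standard_normal((4, 4, 8))  # stand-in for conv2's output: 8 channels

# x + residual would fail to broadcast (3 vs 8 channels), so project x first.
kernel = rng.standard_normal((3, 8))       # random stand-in for learned weights
bias = np.zeros(8)
shortcut = conv1x1(x, kernel, bias)

out = np.maximum(shortcut + residual, 0.0)  # add, then ReLU, as in resblock
print(out.shape)
```

Because the projection only mixes channels at each pixel, the spatial layout of the input is unchanged, which is exactly what the element-wise add requires.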

In [24]:
if __name__ == "__main__":
    model = resunet()
    model.save('resunet.keras')
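
Once the model is trained, its per-pixel probabilities still have to be turned into a label mask. A minimal sketch, assuming `model.predict` returns an array of shape (batch, H, W, num_classes), a 0.5 threshold for the single-channel sigmoid case, and that label 1 means vegetation; `probabilities_to_mask` is a hypothetical helper, not part of the notebook:

```python
import numpy as np


def probabilities_to_mask(probs, threshold=0.5):
    """Convert a (batch, H, W, num_classes) prediction to an integer label mask.

    Assumptions: one channel means sigmoid output, and a pixel is labeled 1
    (e.g. vegetation) where its probability is >= threshold; several channels
    mean softmax output, and each pixel takes its most probable class.
    """
    probs = np.asarray(probs)
    if probs.shape[-1] == 1:
        return (probs[..., 0] >= threshold).astype(np.uint8)
    return np.argmax(probs, axis=-1).astype(np.uint8)


# Toy batch: one 2x2 image with a single sigmoid channel.
toy = np.array([[[[0.2], [0.7]], [[0.5], [0.9]]]])
print(probabilities_to_mask(toy))
```

With the default model, `probabilities_to_mask(model.predict(images))` would return a (batch, 256, 256) array of 0/1 labels, where `images` is a hypothetical batch of shape (batch, 256, 256, 3). The 0.5 threshold is only a starting point and can be tuned on validation data.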
