In [8]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, MaxPooling2D, Concatenate, BatchNormalization, Activation, Dense

class Basic_UNet(tf.keras.Model):
    """Four-level U-Net that maps an image pair to a 2-channel per-pixel output.

    The model takes ``(moving_image, fixed_image)``, stacks the two tensors
    along the channel axis and runs them through a symmetric
    encoder / bridge / decoder with skip connections:

    * encoder: four stages of two 3x3 ReLU convolutions (64, 128, 256, 512
      filters), each followed by 2x2 max pooling;
    * bridge: two 3x3 ReLU convolutions with 1024 filters;
    * decoder: four stages of a stride-2 transposed convolution, channel-wise
      concatenation with the matching encoder output, and two 3x3 ReLU
      convolutions (512, 256, 128, 64 filters);
    * head: two ``Dense(128)`` layers followed by a linear 1x1 convolution
      with 2 output channels.

    Because the input is halved four times and then doubled four times, the
    spatial dimensions should be divisible by 16 so that the upsampled feature
    maps line up with the encoder outputs they are concatenated with.

    NOTE(review): the 2 output channels are presumably a 2-D displacement /
    deformation field (the result is named ``output_def``) -- confirm with the
    training code.
    """

    def __init__(self):
        super(Basic_UNet, self).__init__()

        # Encoder
        self.conv1 = Conv2D(64, 3, padding='same', activation='relu')
        self.conv2 = Conv2D(64, 3, padding='same', activation='relu')
        self.pool1 = MaxPooling2D(pool_size=(2, 2))

        self.conv3 = Conv2D(128, 3, padding='same', activation='relu')
        self.conv4 = Conv2D(128, 3, padding='same', activation='relu')
        self.pool2 = MaxPooling2D(pool_size=(2, 2))

        self.conv5 = Conv2D(256, 3, padding='same', activation='relu')
        self.conv6 = Conv2D(256, 3, padding='same', activation='relu')
        self.pool3 = MaxPooling2D(pool_size=(2, 2))

        self.conv7 = Conv2D(512, 3, padding='same', activation='relu')
        self.conv8 = Conv2D(512, 3, padding='same', activation='relu')
        self.pool4 = MaxPooling2D(pool_size=(2, 2))

        # Bridge
        self.conv9 = Conv2D(1024, 3, padding='same', activation='relu')
        self.conv10 = Conv2D(1024, 3, padding='same', activation='relu')

        # Decoder
        self.upconv1 = Conv2DTranspose(512, 3, strides=2, padding='same', activation='relu')
        self.conv11 = Conv2D(512, 3, padding='same', activation='relu')
        self.conv12 = Conv2D(512, 3, padding='same', activation='relu')

        self.upconv2 = Conv2DTranspose(256, 3, strides=2, padding='same', activation='relu')
        self.conv13 = Conv2D(256, 3, padding='same', activation='relu')
        self.conv14 = Conv2D(256, 3, padding='same', activation='relu')

        self.upconv3 = Conv2DTranspose(128, 3, strides=2, padding='same', activation='relu')
        self.conv15 = Conv2D(128, 3, padding='same', activation='relu')
        self.conv16 = Conv2D(128, 3, padding='same', activation='relu')

        self.upconv4 = Conv2DTranspose(64, 3, strides=2, padding='same', activation='relu')
        self.conv17 = Conv2D(64, 3, padding='same', activation='relu')
        self.conv18 = Conv2D(64, 3, padding='same', activation='relu')

        # Dense operates on the last axis only, so on the 4-D feature map these
        # two layers are applied independently at every pixel.
        self.dense1 = Dense(128, activation='relu')
        self.dense2 = Dense(128, activation='relu')

        # Linear 1x1 convolution producing the 2 output channels.
        self.output_conv = Conv2D(2, 1, padding='same', activation='linear')

    def call(self, inputs):
        """Run the forward pass.

        Args:
            inputs: a two-element sequence ``(moving_image, fixed_image)``.
                Both tensors are concatenated on the last axis, so they must
                agree on every other dimension.

        Returns:
            The output of the final 1x1 convolution: a tensor with the same
            batch and spatial dimensions as the inputs and 2 channels.
        """
        moving_image, fixed_image = inputs
        x = tf.concat([moving_image, fixed_image], axis=-1)

        # Encoder
        e1 = self.conv1(x)
        e1 = self.conv2(e1)
        p1 = self.pool1(e1)

        e2 = self.conv3(p1)
        e2 = self.conv4(e2)
        p2 = self.pool2(e2)

        e3 = self.conv5(p2)
        e3 = self.conv6(e3)
        p3 = self.pool3(e3)

        e4 = self.conv7(p3)
        e4 = self.conv8(e4)
        p4 = self.pool4(e4)

        # Bridge
        b = self.conv9(p4)
        b = self.conv10(b)

        # Decoder. Skip connections use tf.concat (as for the input pair above)
        # instead of building a fresh Concatenate layer on every forward pass.
        d1 = self.upconv1(b)
        d1 = tf.concat([d1, e4], axis=-1)
        d1 = self.conv11(d1)
        d1 = self.conv12(d1)

        d2 = self.upconv2(d1)
        d2 = tf.concat([d2, e3], axis=-1)
        d2 = self.conv13(d2)
        d2 = self.conv14(d2)

        d3 = self.upconv3(d2)
        d3 = tf.concat([d3, e2], axis=-1)
        d3 = self.conv15(d3)
        d3 = self.conv16(d3)

        d4 = self.upconv4(d3)
        d4 = tf.concat([d4, e1], axis=-1)
        d4 = self.conv17(d4)
        d4 = self.conv18(d4)

        # Per-pixel head.
        dense1 = self.dense1(d4)
        dense2 = self.dense2(dense1)

        output_def = self.output_conv(dense2)

        return output_def

# Symbolic inputs for the two single-channel 128x128 images; created in the
# same order as before (fixed first, then moving).
fixed_input, moving_input = (
    Input(shape=(128, 128, 1), name=layer_name)
    for layer_name in ("fixed_image", "moving_image")
)

# The network expects the pair ordered as (moving, fixed).
unet = Basic_UNet()
out_def = unet([moving_input, fixed_input])

# Wrap the subclassed network in a functional model so it has named inputs.
model = tf.keras.Model(
    inputs=[moving_input, fixed_input],
    outputs=out_def,
)

# Show the layer table and parameter counts.
model.summary()

Model: "model_4"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 moving_image (InputLayer)   [(None, 128, 128, 1)]        0         []                            
                                                                                                  
 fixed_image (InputLayer)    [(None, 128, 128, 1)]        0         []                            
                                                                                                  
 basic_u_net_3 (Basic_UNet)  (None, 128, 128, 2)          3453779   ['moving_image[0][0]',        
                                                          4          'fixed_image[0][0]']         
                                                                                                  
Total params: 34537794 (131.75 MB)
Trainable params: 34537794 (131.75 MB)
Non-trainable para

In [9]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Conv2D, Conv2DTranspose, MaxPooling2D, Concatenate, 
    BatchNormalization, Activation, Dense
)

class Basic_UNet2(tf.keras.Model):
    """Four-level U-Net with batch normalisation after every (transposed) convolution.

    The model takes ``(moving_image, fixed_image)``, stacks the two tensors
    along the channel axis and runs them through a symmetric
    encoder / bridge / decoder with skip connections. Every convolution and
    transposed convolution is followed by ``BatchNormalization`` and a ReLU
    ``Activation``:

    * encoder: four stages of two 3x3 conv blocks (64, 128, 256, 512 filters),
      each followed by 2x2 max pooling;
    * bridge: two 3x3 conv blocks with 1024 filters;
    * decoder: four stages of a stride-2 transposed-conv block, channel-wise
      concatenation with the matching encoder output, and two 3x3 conv blocks
      (512, 256, 128, 64 filters);
    * head: two ``Dense(128)`` layers followed by a linear 1x1 convolution
      with 2 output channels (no batch normalisation in the head).

    Because the input is halved four times and then doubled four times, the
    spatial dimensions should be divisible by 16 so that the upsampled feature
    maps line up with the encoder outputs they are concatenated with.

    Naming note: in the decoder the ``bnN`` / ``actN`` indices are offset from
    the ``convN`` indices because the transposed convolutions have their own
    normalisation layers (e.g. ``conv11`` pairs with ``bn12`` / ``act12``).

    NOTE(review): the 2 output channels are presumably a 2-D displacement /
    deformation field (the result is named ``output_def``) -- confirm with the
    training code.
    """

    def __init__(self):
        super(Basic_UNet2, self).__init__()

        # Encoder
        self.conv1 = Conv2D(64, 3, padding='same')
        self.bn1 = BatchNormalization()
        self.act1 = Activation('relu')

        self.conv2 = Conv2D(64, 3, padding='same')
        self.bn2 = BatchNormalization()
        self.act2 = Activation('relu')
        self.pool1 = MaxPooling2D(pool_size=(2, 2))

        self.conv3 = Conv2D(128, 3, padding='same')
        self.bn3 = BatchNormalization()
        self.act3 = Activation('relu')

        self.conv4 = Conv2D(128, 3, padding='same')
        self.bn4 = BatchNormalization()
        self.act4 = Activation('relu')
        self.pool2 = MaxPooling2D(pool_size=(2, 2))

        self.conv5 = Conv2D(256, 3, padding='same')
        self.bn5 = BatchNormalization()
        self.act5 = Activation('relu')

        self.conv6 = Conv2D(256, 3, padding='same')
        self.bn6 = BatchNormalization()
        self.act6 = Activation('relu')
        self.pool3 = MaxPooling2D(pool_size=(2, 2))

        self.conv7 = Conv2D(512, 3, padding='same')
        self.bn7 = BatchNormalization()
        self.act7 = Activation('relu')

        self.conv8 = Conv2D(512, 3, padding='same')
        self.bn8 = BatchNormalization()
        self.act8 = Activation('relu')
        self.pool4 = MaxPooling2D(pool_size=(2, 2))

        # Bridge
        self.conv9 = Conv2D(1024, 3, padding='same')
        self.bn9 = BatchNormalization()
        self.act9 = Activation('relu')

        self.conv10 = Conv2D(1024, 3, padding='same')
        self.bn10 = BatchNormalization()
        self.act10 = Activation('relu')

        # Decoder
        self.upconv1 = Conv2DTranspose(512, 3, strides=2, padding='same')
        self.bn11 = BatchNormalization()
        self.act11 = Activation('relu')

        self.conv11 = Conv2D(512, 3, padding='same')
        self.bn12 = BatchNormalization()
        self.act12 = Activation('relu')

        self.conv12 = Conv2D(512, 3, padding='same')
        self.bn13 = BatchNormalization()
        self.act13 = Activation('relu')

        self.upconv2 = Conv2DTranspose(256, 3, strides=2, padding='same')
        self.bn14 = BatchNormalization()
        self.act14 = Activation('relu')

        self.conv13 = Conv2D(256, 3, padding='same')
        self.bn15 = BatchNormalization()
        self.act15 = Activation('relu')

        self.conv14 = Conv2D(256, 3, padding='same')
        self.bn16 = BatchNormalization()
        self.act16 = Activation('relu')

        self.upconv3 = Conv2DTranspose(128, 3, strides=2, padding='same')
        self.bn17 = BatchNormalization()
        self.act17 = Activation('relu')

        self.conv15 = Conv2D(128, 3, padding='same')
        self.bn18 = BatchNormalization()
        self.act18 = Activation('relu')

        self.conv16 = Conv2D(128, 3, padding='same')
        self.bn19 = BatchNormalization()
        self.act19 = Activation('relu')

        self.upconv4 = Conv2DTranspose(64, 3, strides=2, padding='same')
        self.bn20 = BatchNormalization()
        self.act20 = Activation('relu')

        self.conv17 = Conv2D(64, 3, padding='same')
        self.bn21 = BatchNormalization()
        self.act21 = Activation('relu')

        self.conv18 = Conv2D(64, 3, padding='same')
        self.bn22 = BatchNormalization()
        self.act22 = Activation('relu')

        # Dense operates on the last axis only, so on the 4-D feature map these
        # two layers are applied independently at every pixel.
        self.dense1 = Dense(128, activation='relu')
        self.dense2 = Dense(128, activation='relu')

        # Linear 1x1 convolution producing the 2 output channels.
        self.output_conv = Conv2D(2, 1, padding='same', activation='linear')

    @staticmethod
    def _conv_bn_act(x, conv, bn, act, training):
        """Apply ``conv`` -> ``bn`` -> ``act`` to ``x`` and return the result."""
        x = conv(x)
        x = bn(x, training=training)
        return act(x)

    def call(self, inputs, training=None):
        """Run the forward pass.

        Args:
            inputs: a two-element sequence ``(moving_image, fixed_image)``.
                Both tensors are concatenated on the last axis, so they must
                agree on every other dimension.
            training: forwarded to every ``BatchNormalization`` layer. When
                left as ``None`` Keras resolves the mode itself (from the
                enclosing call / ``fit`` / ``predict``).

        Returns:
            The output of the final 1x1 convolution: a tensor with the same
            batch and spatial dimensions as the inputs and 2 channels.
        """
        moving_image, fixed_image = inputs
        x = tf.concat([moving_image, fixed_image], axis=-1)
        block = self._conv_bn_act

        # Encoder
        e1 = block(x, self.conv1, self.bn1, self.act1, training)
        e1 = block(e1, self.conv2, self.bn2, self.act2, training)
        p1 = self.pool1(e1)

        e2 = block(p1, self.conv3, self.bn3, self.act3, training)
        e2 = block(e2, self.conv4, self.bn4, self.act4, training)
        p2 = self.pool2(e2)

        e3 = block(p2, self.conv5, self.bn5, self.act5, training)
        e3 = block(e3, self.conv6, self.bn6, self.act6, training)
        p3 = self.pool3(e3)

        e4 = block(p3, self.conv7, self.bn7, self.act7, training)
        e4 = block(e4, self.conv8, self.bn8, self.act8, training)
        p4 = self.pool4(e4)

        # Bridge
        b = block(p4, self.conv9, self.bn9, self.act9, training)
        b = block(b, self.conv10, self.bn10, self.act10, training)

        # Decoder. Skip connections use tf.concat (as for the input pair above)
        # instead of building a fresh Concatenate layer on every forward pass.
        d1 = block(b, self.upconv1, self.bn11, self.act11, training)
        d1 = tf.concat([d1, e4], axis=-1)
        d1 = block(d1, self.conv11, self.bn12, self.act12, training)
        d1 = block(d1, self.conv12, self.bn13, self.act13, training)

        d2 = block(d1, self.upconv2, self.bn14, self.act14, training)
        d2 = tf.concat([d2, e3], axis=-1)
        d2 = block(d2, self.conv13, self.bn15, self.act15, training)
        d2 = block(d2, self.conv14, self.bn16, self.act16, training)

        d3 = block(d2, self.upconv3, self.bn17, self.act17, training)
        d3 = tf.concat([d3, e2], axis=-1)
        d3 = block(d3, self.conv15, self.bn18, self.act18, training)
        d3 = block(d3, self.conv16, self.bn19, self.act19, training)

        d4 = block(d3, self.upconv4, self.bn20, self.act20, training)
        d4 = tf.concat([d4, e1], axis=-1)
        d4 = block(d4, self.conv17, self.bn21, self.act21, training)
        d4 = block(d4, self.conv18, self.bn22, self.act22, training)

        # Per-pixel head (no batch normalisation here).
        dense1 = self.dense1(d4)
        dense2 = self.dense2(dense1)

        output_def = self.output_conv(dense2)

        return output_def
    

# Two single-channel 128x128 symbolic inputs (fixed is created first, as before).
fixed_input = tf.keras.Input(name="fixed_image", shape=(128, 128, 1))
moving_input = tf.keras.Input(name="moving_image", shape=(128, 128, 1))

# Batch-normalised U-Net; it is called with the pair ordered (moving, fixed).
unet = Basic_UNet2()
out_def = unet([moving_input, fixed_input])

# Functional wrapper exposing the two named inputs and the network output.
model = tf.keras.Model(
    inputs=[moving_input, fixed_input],
    outputs=out_def,
)

# Show the layer table and parameter counts.
model.summary()

Model: "model_5"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 moving_image (InputLayer)   [(None, 128, 128, 1)]        0         []                            
                                                                                                  
 fixed_image (InputLayer)    [(None, 128, 128, 1)]        0         []                            
                                                                                                  
 basic_u_net2_1 (Basic_UNet  (None, 128, 128, 2)          3456518   ['moving_image[0][0]',        
 2)                                                       6          'fixed_image[0][0]']         
                                                                                                  
Total params: 34565186 (131.86 MB)
Trainable params: 34551490 (131.80 MB)
Non-trainable para