In [1]:
import tensorflow as tf

import matplotlib.pyplot as plt
import numpy as np


def conv_block(inputs, filters, kernel_size=3, num_convs=2):
    """Stack ``num_convs`` rounds of Conv3D -> BatchNormalization -> ReLU.

    Args:
        inputs: Keras tensor to transform.
        filters: Number of output channels of every Conv3D layer.
        kernel_size: Convolution kernel size (int or 3-tuple); default 3.
        num_convs: How many Conv3D/BN/ReLU stages to stack; default 2,
            which matches the original fixed behaviour.

    Returns:
        The output tensor of the last ReLU activation (``inputs`` itself when
        ``num_convs`` is 0).
    """
    x = inputs
    for _ in range(num_convs):
        # "same" padding preserves the spatial dimensions of the volume.
        x = tf.keras.layers.Conv3D(filters, kernel_size, kernel_initializer="he_normal",
                                   padding="same")(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.Activation("relu")(x)
    return x


def encoder_block(inputs, filters=64, pool_size=(2, 2, 2), dropout_rate=0.2):
    """One encoder stage: conv block, then max-pool downsampling and dropout.

    Args:
        inputs: Keras tensor entering this stage.
        filters: Channel count of the conv block (default 64).
        pool_size: Max-pooling window (Keras uses it as the stride too);
            default ``(2, 2, 2)``. If changed, the decoder's transposed-conv
            strides must match, or the skip concatenation will not line up.
        dropout_rate: Fraction of units dropped after pooling (default 0.2).

    Returns:
        ``(f, p)``: ``f`` is the pre-pooling feature map kept for the decoder
        skip connection; ``p`` is the pooled, dropped-out tensor passed on to
        the next stage.
    """
    f = conv_block(inputs, filters=filters)  # output for the skip concat
    p = tf.keras.layers.MaxPooling3D(pool_size)(f)
    p = tf.keras.layers.Dropout(dropout_rate)(p)
    return f, p


def encoder(inputs, filters=64):
    """Contracting path: four encoder stages, doubling the channels each time.

    Args:
        inputs: Model input tensor.
        filters: Channel count of the first stage (default 64). Stages use
            ``filters``, ``2*filters``, ``4*filters`` and ``8*filters``.

    Returns:
        ``(p4, (f1, f2, f3, f4))``: the pooled output of the deepest stage and
        the skip features ordered from shallowest (``f1``) to deepest (``f4``).
    """
    skips = []
    x = inputs
    # Depth is fixed at 4 because ``decoder`` unpacks exactly four skips.
    for level in range(4):
        skip, x = encoder_block(x, filters * 2 ** level)
        skips.append(skip)
    return x, tuple(skips)


def bottleneck(inputs, filters=1024):
    """Deepest U-Net stage: a single conv block between encoder and decoder.

    Args:
        inputs: Output of the encoder's last pooling stage.
        filters: Channel count of the bottleneck convolutions (default 1024).

    Returns:
        The bottleneck feature tensor.
    """
    return conv_block(inputs, filters)


def decoder_block(inputs, conv_out, filters=512, kernel_size=3, strides=(2, 2, 2)):
    """Upsample ``inputs``, merge it with the skip tensor, then refine.

    The stage applies, in order: a transposed 3-D convolution (``strides``
    sets the upsampling factor), concatenation with ``conv_out`` (Keras
    default: last axis), dropout at rate 0.2, and a ``conv_block`` with
    ``filters`` channels and the same ``kernel_size``.

    Returns:
        The refined tensor at the upsampled resolution.
    """
    upsample = tf.keras.layers.Conv3DTranspose(filters, kernel_size, strides, padding="same")
    merged = tf.keras.layers.concatenate([upsample(inputs), conv_out])
    regularised = tf.keras.layers.Dropout(0.2)(merged)
    return conv_block(regularised, filters, kernel_size)


def decoder(inputs, convs, out_channels, filters=512):
    """Expanding path: four decoder stages followed by a 1x1x1 softmax head.

    Args:
        inputs: Bottleneck output tensor.
        convs: Exactly four skip tensors ordered shallowest to deepest; they
            are consumed deepest first.
        out_channels: Channels of the final softmax output.
        filters: Width of the first decoder stage (default 512); each later
            stage halves it with floor division.

    Returns:
        The per-voxel softmax output tensor.
    """
    f1, f2, f3, f4 = convs  # also enforces exactly four skip tensors
    widths = (filters, filters // 2, filters // 4, filters // 8)
    x = inputs
    for skip, width in zip((f4, f3, f2, f1), widths):
        x = decoder_block(x, skip, width)
    head = tf.keras.layers.Conv3D(out_channels, 1, activation="softmax")
    return head(x)


In [2]:
OUTPUT_CHANNELS = 4  # channels of the final softmax output (classes per voxel)


def unet(input_shape=(128, 128, 128, 3), out_channels=OUTPUT_CHANNELS):
    """Build the 3-D U-Net as a Keras functional ``Model``.

    Args:
        input_shape: Shape of one input volume, excluding the batch dimension.
            Default ``(128, 128, 128, 3)``. Spatial sizes should be divisible
            by 16: the encoder halves them four times and the decoder must
            upsample back to the exact skip shapes.
        out_channels: Number of softmax output channels (default
            ``OUTPUT_CHANNELS``).

    Returns:
        An uncompiled ``tf.keras.Model``.

    Note:
        No intensity scaling is applied inside the model. Inputs are presumably
        normalised upstream (TODO: confirm against the data pipeline).
    """
    inputs = tf.keras.layers.Input(shape=input_shape)
    encoder_output, convs = encoder(inputs)
    bottle_neck = bottleneck(encoder_output)
    outputs = decoder(bottle_neck, convs, out_channels)
    return tf.keras.Model(inputs=inputs, outputs=outputs)


model = unet()
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 128, 128, 1  0           []                               
                                28, 3)]                                                           
                                                                                                  
 conv3d (Conv3D)                (None, 128, 128, 12  5248        ['input_1[0][0]']                
                                8, 64)                                                            
                                                                                                  
 batch_normalization (BatchNorm  (None, 128, 128, 12  256        ['conv3d[0][0]']                 
 alization)                     8, 64)                                                        

                                                                                                  
 activation_6 (Activation)      (None, 16, 16, 16,   0           ['batch_normalization_6[0][0]']  
                                512)                                                              
                                                                                                  
 conv3d_7 (Conv3D)              (None, 16, 16, 16,   7078400     ['activation_6[0][0]']           
                                512)                                                              
                                                                                                  
 batch_normalization_7 (BatchNo  (None, 16, 16, 16,   2048       ['conv3d_7[0][0]']               
 rmalization)                   512)                                                              
                                                                                                  
 activatio

                                256)                                                              
                                                                                                  
 batch_normalization_13 (BatchN  (None, 32, 32, 32,   1024       ['conv3d_13[0][0]']              
 ormalization)                  256)                                                              
                                                                                                  
 activation_13 (Activation)     (None, 32, 32, 32,   0           ['batch_normalization_13[0][0]'] 
                                256)                                                              
                                                                                                  
 conv3d_transpose_2 (Conv3DTran  (None, 64, 64, 64,   884864     ['activation_13[0][0]']          
 spose)                         128)                                                              
          