# Trying a 3D U-Net

In [5]:
import os
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras

# Default image height and width, plus the (height, width) pair built from them.
# Note: the model cell below passes 128x128 explicitly instead of reading these.
IMAGE_HEIGHT, IMAGE_WIDTH = 256, 256
IMG_SIZE = (IMAGE_HEIGHT, IMAGE_WIDTH)

def unet(n_levels, initial_features=32, n_blocks=2, kernel_size=3, pooling_size=2, strides=(1, 1, 2),
         image_height=256, image_width=256, image_depth=None, in_channels=1, out_channels=1):
    """Build a 3D U-Net as a Keras functional model.

    Encoder: at each of ``n_levels`` levels, ``n_blocks`` ReLU ``Conv3D`` layers with
    ``initial_features * 2 ** level`` filters. Every level except the deepest stores
    its output as a skip connection and is followed by ``MaxPool3D(pooling_size)``.
    Decoder: for each shallower level, a ReLU ``Conv3DTranspose`` with stride
    ``pooling_size`` upsamples, the result is concatenated with that level's skip,
    and ``n_blocks`` more ``Conv3D`` layers follow. A final 1x1x1 ``Conv3D`` maps
    the features to ``out_channels``.

    Args:
        n_levels: Number of resolution levels, including the input resolution.
        initial_features: Filter count at level 0; doubled at each deeper level.
        n_blocks: Number of ``Conv3D`` layers per level on each path.
        kernel_size: Kernel size for every ``Conv3D``/``Conv3DTranspose`` except the
            final 1x1x1 output convolution.
        pooling_size: ``MaxPool3D`` pool size and ``Conv3DTranspose`` stride; an int
            or a 3-tuple (one factor per spatial axis).
        strides: Unused. Kept only so existing calls that pass it keep working.
            Upsampling uses ``pooling_size`` so decoder shapes line up with the skips.
        image_height, image_width, image_depth: Spatial input sizes. ``None`` leaves
            an axis dynamic. A dynamic axis must still be divisible by the total
            downsampling factor when real data is fed in.
        in_channels: Number of input channels (last axis of the input).
        out_channels: Number of output channels. 1 uses a sigmoid output; any
            other value uses softmax.

    Returns:
        A ``keras.Model`` named ``UNET-L{n_levels}-F{initial_features}``.

    Raises:
        ValueError: If a fixed spatial size is not divisible by
            ``pooling_size ** (n_levels - 1)`` on that axis. Without this check the
            mismatch only surfaces later as an opaque ``Concatenate`` shape error.
    """
    # Normalise pooling_size to one factor per spatial axis, as MaxPool3D does.
    if isinstance(pooling_size, int):
        pool_factors = (pooling_size,) * 3
    elif isinstance(pooling_size, (tuple, list)):
        pool_factors = tuple(pooling_size)
    else:
        pool_factors = ()  # unrecognised form: leave validation to Keras

    # Only validate well-formed positive integer factors; anything else is left for Keras to reject.
    if (n_levels > 1 and len(pool_factors) == 3
            and all(isinstance(f, int) and f > 0 for f in pool_factors)):
        axes = zip(('image_height', 'image_width', 'image_depth'),
                   (image_height, image_width, image_depth),
                   pool_factors)
        for axis_name, size, factor in axes:
            total = factor ** (n_levels - 1)
            if size is not None and size % total != 0:
                raise ValueError(
                    f'{axis_name}={size} is not divisible by {total} '
                    f'(pooling factor {factor} applied {n_levels - 1} times); '
                    'the decoder would not match the encoder skip connections.')

    inputs = keras.layers.Input(shape=(image_height, image_width, image_depth, in_channels))
    x = inputs

    # Shared settings for every Conv3D / Conv3DTranspose except the output layer.
    convpars = dict(kernel_size=kernel_size, activation='relu', padding='same')

    # Downstream (encoder): the deepest level is neither pooled nor stored as a skip.
    skips = {}
    for level in range(n_levels):
        for _ in range(n_blocks):
            x = keras.layers.Conv3D(initial_features * 2 ** level, **convpars)(x)
        if level < n_levels - 1:
            skips[level] = x
            x = keras.layers.MaxPool3D(pooling_size)(x)

    # Upstream (decoder): upsample by pooling_size (not `strides`) so that, for
    # divisible input sizes, the result matches the stored skip's spatial size.
    for level in reversed(range(n_levels - 1)):
        x = keras.layers.Conv3DTranspose(initial_features * 2 ** level, strides=pooling_size, **convpars)(x)
        x = keras.layers.Concatenate()([x, skips[level]])
        for _ in range(n_blocks):
            x = keras.layers.Conv3D(initial_features * 2 ** level, **convpars)(x)

    # Output: 1x1x1 projection to out_channels.
    activation = 'sigmoid' if out_channels == 1 else 'softmax'
    x = keras.layers.Conv3D(out_channels, kernel_size=1, activation=activation, padding='same')(x)

    return keras.Model(inputs=[inputs], outputs=[x], name=f'UNET-L{n_levels}-F{initial_features}')

In [10]:
# 4-level 3D U-Net on 128x128 inputs with a dynamic depth axis. With 2x pooling
# applied 3 times, any depth fed in must be divisible by 8 so the decoder output
# matches the encoder skip connections.
model_3D = unet(
    n_levels=4,
    initial_features=64,
    n_blocks=2,
    kernel_size=3,
    pooling_size=2,
    image_height=128,
    image_width=128,
    image_depth=None,
    in_channels=1,
    out_channels=1,
)

In [11]:
# Print the layer-by-layer architecture: output shapes, parameter counts and connections.
model_3D.summary()

Model: "UNET-L4-F64"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 128, 128, No 0                                            
__________________________________________________________________________________________________
conv3d_30 (Conv3D)              (None, 128, 128, Non 1792        input_3[0][0]                    
__________________________________________________________________________________________________
conv3d_31 (Conv3D)              (None, 128, 128, Non 110656      conv3d_30[0][0]                  
__________________________________________________________________________________________________
max_pooling3d_6 (MaxPooling3D)  (None, 64, 64, None, 0           conv3d_31[0][0]                  
________________________________________________________________________________________