In [18]:
# Image geometry and convolution settings shared by the model-building cells.
Height, Width, Channels = 256, 256, 3

# Spatial (height, width) pair, without the channel axis.
shape = (Height, Width)

# Window size used by the 3x3 convolutions of the UNet decoder.
kernel_size = (3, 3)

In [19]:
import tensorflow as tf
# Build the UNet architecture.

# Symbolic model input of shape (Height, Width, Channels).
inputLayer = tf.keras.layers.Input(shape=(Height, Width, Channels))

# Encoder path / contraction path: EfficientNetB7 without its classification
# head, initialised with ImageNet weights and wired onto `inputLayer`.
efficientnet = tf.keras.applications.EfficientNetB7(
    include_top=False,
    weights='imagenet',
    input_tensor=inputLayer,
)

# Final feature map of the encoder; the decoder starts from this tensor.
base = efficientnet.output

In [63]:
# Tap the expand-activation output of the first layer in encoder blocks
# 6, 5, 4 and 3; these tensors are the skip connections for the decoder.
# `conv4` is only referenced by the commented-out 512-filter decoder stage.
conv4, conv3, conv2, conv1 = (
    efficientnet.get_layer(f"block{stage}a_expand_activation").output
    for stage in (6, 5, 4, 3)
)

In [64]:
# Decoder Path / Expansion Path

# Disabled 512-filter stage (it would merge `conv4`); kept for reference.
#trans1 = tf.keras.layers.Conv2DTranspose(512, (2,2), strides=2, activation='relu')(base)
#skip4 = tf.keras.layers.Concatenate()([trans1, conv4])
#conv5 = tf.keras.layers.Conv2D(512, kernel_size, padding='same', activation='relu')(skip4)
#conv5 = tf.keras.layers.Conv2D(512, kernel_size, padding='same', activation='relu')(conv5)


def _upsample_and_merge(features, skip, filters):
    """Apply a 2x2, stride-2 transposed convolution and concatenate `skip`.

    Returns the upsampled tensor and the concatenated tensor, in that order.
    """
    upsampled = tf.keras.layers.Conv2DTranspose(filters, (2,2), strides=2, activation='relu')(features)
    merged = tf.keras.layers.Concatenate()([upsampled, skip])
    return upsampled, merged


def _double_conv(features, filters):
    """Apply two same-padded ReLU convolutions with window `kernel_size`."""
    for _ in range(2):
        features = tf.keras.layers.Conv2D(filters, kernel_size, padding='same', activation='relu')(features)
    return features


# NOTE(review): only three stride-2 upsampling stages are applied to `base`;
# confirm that the resulting output resolution matches the training targets.
trans2, skip3 = _upsample_and_merge(base, conv3, 256)
conv6 = _double_conv(skip3, 256)

trans3, skip2 = _upsample_and_merge(conv6, conv2, 128)
conv7 = _double_conv(skip2, 128)

trans4, skip1 = _upsample_and_merge(conv7, conv1, 64)
conv8 = _double_conv(skip1, 64)


# Output layer: 1x1 convolution to 3 channels with a sigmoid activation.
outputLayer = tf.keras.layers.Conv2D(3, (1,1), padding='same', activation='sigmoid')(conv8)


# Assemble the functional model from the input and output tensors.
UNet = tf.keras.models.Model(inputs=inputLayer, outputs=outputLayer, name='UNet')

In [59]:
# Configure the UNet for training: Adam optimizer, binary cross-entropy loss,
# and accuracy reported as a metric.
UNet.compile(
    optimizer="Adam",
    loss="binary_crossentropy",
    metrics=["accuracy"],
)

In [None]:
!pip install keras_flops

In [60]:
# Ask Keras for the parameter count of the UNet model.
num_params = UNet.count_params()

# Print the count. Note that a later cell reuses the name `num_params` for
# the second model, overwriting this value.
print("Number of parameters: ", num_params)

Number of parameters:  71363482


In [25]:
def count_flops(model):
    """Print the FLOPs of ``model`` for a batch of one sample, in units of 1e9.

    The value comes from ``keras_flops.get_flops(model, batch_size=1)`` and is
    printed as ``FLOPS: <value> G`` with three significant digits.

    Args:
        model: Model object forwarded unchanged to ``keras_flops.get_flops``.

    Returns:
        None. The result is only printed.
    """
    import warnings

    # Suppress UserWarning only for the duration of the measurement;
    # catch_warnings restores the previous filters on exit, so the rest of
    # the notebook session keeps its normal warning behaviour.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)

        # Imported lazily (and inside the suppression scope) so the optional
        # dependency is only required when this helper is actually called.
        from keras_flops import get_flops
        flops = get_flops(model, batch_size=1)

    print(f"FLOPS: {flops / 10 ** 9:.03} G")

# Print the FLOPs figure for the UNet model.
count_flops(UNet)

FLOPS: 36.7 G


## EfficientU-Net

In [28]:
import tensorflow as tf

In [29]:
def conv_block(inputs, filters):
    """Apply Conv2D(3x3, same padding) -> BatchNormalization -> ReLU twice.

    Args:
        inputs: Tensor fed to the first convolution.
        filters: Number of filters used by both convolutions.

    Returns:
        The tensor produced by the second ReLU activation.
    """
    features = inputs

    # Two identical rounds; each round creates its own fresh layers.
    for _ in range(2):
        features = tf.keras.layers.Conv2D(filters, (3, 3), padding="same")(features)
        features = tf.keras.layers.BatchNormalization()(features)
        features = tf.keras.layers.Activation('relu')(features)

    return features

In [70]:
def encoder(inputs):
    """Build an EfficientNetB7 backbone on ``inputs`` and collect skip tensors.

    The backbone is created without its top (classification) part and with
    ImageNet weights. The outputs of four named expand-activation layers
    (blocks 2a, 3a, 4a and 5a) are gathered as skip connections.

    Args:
        inputs: Tensor passed to the backbone as ``input_tensor``.

    Returns:
        A ``(output, skip_connections)`` pair: the backbone's output tensor
        and a list of the four skip tensors in the order blocks 2a to 5a.
    """
    backbone = tf.keras.applications.EfficientNetB7(include_top=False, weights='imagenet', input_tensor=inputs)

    skip_layer_names = (
        "block2a_expand_activation",
        "block3a_expand_activation",
        "block4a_expand_activation",
        "block5a_expand_activation",
    )
    skip_connections = [backbone.get_layer(layer_name).output for layer_name in skip_layer_names]

    return backbone.output, skip_connections

In [31]:
def decoder(inputs, skip_connections):
    """Run four upsample -> concatenate -> conv_block stages.

    Each stage doubles the spatial size with bilinear ``UpSampling2D``,
    concatenates one skip tensor, and applies ``conv_block`` with 256, 128,
    64 and 32 filters respectively. Skip tensors are consumed in reverse
    order, i.e. the last element of ``skip_connections`` is used first.

    Args:
        inputs: Tensor the first stage starts from.
        skip_connections: Sequence with at least four skip tensors. It is
            not modified by this function.

    Returns:
        The tensor produced by the last ``conv_block``.

    Raises:
        IndexError: If fewer than four skip tensors are supplied.
    """
    num_filters = [256, 128, 64, 32]

    # Work on a reversed copy: reversing the argument in place would silently
    # reorder the caller's list and break any later use of it.
    reversed_skips = list(reversed(skip_connections))
    x = inputs

    for i, f in enumerate(num_filters):
        x = tf.keras.layers.UpSampling2D((2, 2), interpolation='bilinear')(x)
        x = tf.keras.layers.Concatenate()([x, reversed_skips[i]])
        x = conv_block(x, f)

    return x

In [32]:
def output_block(inputs):
    """Project ``inputs`` to 3 channels with a 1x1 convolution, then sigmoid.

    Args:
        inputs: Tensor fed to the 1x1 convolution.

    Returns:
        The tensor produced by the sigmoid activation layer.
    """
    logits = tf.keras.layers.Conv2D(3, (1, 1), padding="same")(inputs)
    return tf.keras.layers.Activation('sigmoid')(logits)

In [33]:
def Upsample(tensor, size):
    """Bilinear upsampling of ``tensor`` to the spatial ``size``.

    The resize is wrapped in a Keras ``Lambda`` layer so it can be used while
    building a functional model.

    Args:
        tensor: Tensor to resize.
        size: Target size forwarded to ``tf.image.resize``.

    Returns:
        The output of the ``Lambda`` layer applied to ``tensor``.
    """
    def _resize(images):
        return tf.image.resize(images=images, size=size, method="bilinear")

    # No explicit output_shape: `size` holds only the two spatial dimensions,
    # so passing it as the layer's output shape would omit the channel axis.
    # Keras infers the correct shape from the wrapped function instead.
    return tf.keras.layers.Lambda(_resize)(tensor)

In [35]:
def AC_block(x, filter):
    """Combine a pooled branch and four convolution branches of ``x``.

    Branches, in creation order:
      1. Average pooling over the full (x.shape[1], x.shape[2]) window, a 1x1
         convolution (with bias), BatchNormalization, ReLU, then bilinear
         upsampling by the same (x.shape[1], x.shape[2]) factors.
      2. A bias-free 1x1 convolution with BatchNormalization and ReLU.
      3-5. Bias-free 3x3 convolutions with dilation rates 6, 12 and 18, each
         followed by BatchNormalization and ReLU.
    The five branches are concatenated and fused by a final bias-free 1x1
    convolution with BatchNormalization and ReLU.

    Args:
        x: Input tensor; ``x.shape[1]`` and ``x.shape[2]`` are used as the
            pooling window and the upsampling factors.
        filter: Number of filters for every convolution in the block.

    Returns:
        The tensor produced by the final ReLU activation.
    """
    height, width = x.shape[1], x.shape[2]

    def _conv_bn_relu(tensor, kernel, **conv_kwargs):
        # Same-padded convolution followed by batch norm and ReLU.
        tensor = tf.keras.layers.Conv2D(filter, kernel, padding="same", **conv_kwargs)(tensor)
        tensor = tf.keras.layers.BatchNormalization()(tensor)
        return tf.keras.layers.Activation("relu")(tensor)

    # Branch 1: pool to a single spatial position, project, upsample back.
    pooled = tf.keras.layers.AveragePooling2D(pool_size=(height, width))(x)
    pooled = _conv_bn_relu(pooled, 1)
    pooled = tf.keras.layers.UpSampling2D((height, width), interpolation='bilinear')(pooled)

    # Branch 2: pointwise convolution.
    pointwise = _conv_bn_relu(x, 1, dilation_rate=1, use_bias=False)

    # Branches 3-5: dilated 3x3 convolutions.
    dilated = [
        _conv_bn_relu(x, 3, dilation_rate=rate, use_bias=False)
        for rate in (6, 12, 18)
    ]

    merged = tf.keras.layers.Concatenate()([pooled, pointwise, *dilated])

    # Fuse the concatenated branches back down to `filter` channels.
    return _conv_bn_relu(merged, 1, dilation_rate=1, use_bias=False)

In [68]:
def build_model(shape):
    """Assemble the encoder -> AC_block -> decoder -> output_block model.

    Args:
        shape: Input shape handed to ``tf.keras.layers.Input``.

    Returns:
        A ``tf.keras.models.Model`` mapping the input tensor to the tensor
        returned by ``output_block``.
    """
    image = tf.keras.layers.Input(shape=shape)

    # Backbone features plus the skip tensors the decoder will consume.
    bottleneck, skips = encoder(image)

    # 64-filter multi-branch block applied to the encoder output.
    context = AC_block(bottleneck, 64)

    decoded = decoder(context, skips)
    prediction = output_block(decoded)

    return tf.keras.models.Model(inputs=image, outputs=prediction)

In [71]:
# Instantiate the second network for 256x256 inputs with 3 channels.
model = build_model(shape=(256, 256, 3))

In [72]:
# Same training configuration as the UNet: Adam optimizer, binary
# cross-entropy loss, accuracy reported as a metric.
model.compile(
    optimizer="Adam",
    loss="binary_crossentropy",
    metrics=["accuracy"],
)

In [73]:
# Ask Keras for the parameter count of the model returned by build_model().
# This overwrites the `num_params` value computed earlier for `UNet`.
num_params = model.count_params()

# Print the count.
print("Number of parameters: ", num_params)

Number of parameters:  73179898


In [74]:
# Calculate FLOPs for `model` with a batch of one sample.
# UserWarning is silenced first; note that this filter stays active for the
# rest of the kernel session.
import warnings
warnings.simplefilter("ignore", category=UserWarning)

from keras_flops import get_flops
flops = get_flops(model, batch_size=1)

# Report the figure in units of 1e9 with three significant digits.
print("FLOPS: {:.03} G".format(flops / 10 ** 9))

FLOPS: 24.1 G
