In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import plot_model

def conv_block(x, filters):
    """Apply a depthwise-separable 3x3 convolution followed by a ReLU.

    Args:
        x: Input Keras tensor.
        filters: Number of output channels.

    Returns:
        Keras tensor with ``filters`` channels; ``padding='same'`` with the
        default stride keeps the spatial size of ``x``.
    """
    separable = keras.layers.SeparableConv2D(filters, (3, 3), padding='same')
    activation = keras.layers.ReLU()
    return activation(separable(x))

# Attention Block
def attention_block(skip_connection, gating_signal, filters):
    """Gate a skip connection with an additive attention map.

    The skip connection is projected by a stride-2 1x1 convolution and the
    gating signal by a stride-1 1x1 convolution (both to ``filters`` channels).
    The two are summed, so the gating signal must have half the spatial size
    of the skip connection. A sigmoid 1x1 convolution turns the sum into a
    single-channel map, which is bilinearly upsampled by 2 and multiplied into
    the skip connection.

    Args:
        skip_connection: Encoder feature map to be gated.
        gating_signal: Coarser feature map that drives the gate.
        filters: Channel width of the intermediate projections.

    Returns:
        The skip connection scaled per spatial location by coefficients in [0, 1].
    """
    layers = keras.layers

    skip_projection = layers.Conv2D(filters, (1, 1), strides=(2, 2), padding='same')(skip_connection)
    gate_projection = layers.Conv2D(filters, (1, 1), padding='same')(gating_signal)

    summed = layers.Add()([skip_projection, gate_projection])
    activated = layers.ReLU()(summed)

    # One attention coefficient per spatial location, squashed into [0, 1].
    coefficients = layers.Conv2D(1, (1, 1), padding='same', activation='sigmoid')(activated)
    coefficients_full_res = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(coefficients)

    # The single coefficient channel broadcasts over every channel of the skip tensor.
    return layers.Multiply()([skip_connection, coefficients_full_res])

# U-Net model with attention blocks


def _residual_encoder_block(x, filters):
    """One encoder stage with a projected residual shortcut.

    1x1 conv (``filters``) -> ReLU -> separable 3x3 (``filters``) -> 1x1 conv
    (``2 * filters``), plus a 1x1 projection shortcut from ``x`` to
    ``2 * filters`` channels, summed, ReLU, then 2x2 max-pooling.

    Returns:
        ``(features, pooled)``: the pre-pooling features (used as the skip
        connection, same spatial size as ``x``) and the pooled features
        (half the spatial size) fed to the next stage.
    """
    h = keras.layers.Conv2D(filters, (1, 1), padding='same')(x)
    h = keras.layers.ReLU()(h)
    h = conv_block(h, filters)
    h = keras.layers.Conv2D(2 * filters, (1, 1), padding='same')(h)

    # Project the stage input so it can be added to the widened branch.
    shortcut = keras.layers.Conv2D(2 * filters, (1, 1), padding='same')(x)
    h = keras.layers.Add()([h, shortcut])

    features = keras.layers.ReLU()(h)
    pooled = keras.layers.MaxPooling2D(pool_size=(2, 2))(features)
    return features, pooled


def _attention_decoder_block(gating, skip, filters):
    """One decoder stage: upsample, merge with an attention-gated skip, refine.

    ``gating`` is upsampled by 2 and concatenated (channels axis) with
    ``attention_block(skip, gating, filters)``; the result goes through
    1x1 conv (``filters``) -> ReLU -> separable 3x3 (``filters // 2``)
    -> 1x1 conv (``filters // 2``) -> ReLU.

    Returns:
        Tensor at ``skip``'s spatial size with ``filters // 2`` channels.
    """
    upsampled = keras.layers.UpSampling2D(size=(2, 2))(gating)
    gated_skip = attention_block(skip, gating, filters)
    merged = keras.layers.Concatenate(axis=3)([gated_skip, upsampled])

    h = keras.layers.Conv2D(filters, (1, 1), padding='same')(merged)
    h = keras.layers.ReLU()(h)
    h = conv_block(h, filters // 2)
    h = keras.layers.Conv2D(filters // 2, (1, 1), padding='same')(h)
    return keras.layers.ReLU()(h)


def build_model___(input_shape, output_activation='sigmoid'):
    """Build a residual Attention U-Net with three down/up-sampling levels.

    Encoder stages widen channels 3->64->128->256 while halving the spatial
    size; a bottleneck runs at 1/8 resolution; each decoder stage upsamples by
    2 and concatenates with an attention-gated encoder skip connection.

    Args:
        input_shape: ``(height, width, channels)``, channels-last (the decoder
            concatenates on axis 3). Known height/width must be divisible by 8
            because of the three 2x2 poolings; ``None`` entries are not checked.
        output_activation: Activation of the final 1-channel 1x1 convolution.

    Returns:
        A ``keras.Model`` producing a single-channel map at the input resolution.

    Raises:
        ValueError: If a known spatial dimension is not divisible by 8 (the
            skip/upsampled shapes would otherwise mismatch deep in the graph).
    """
    for size in tuple(input_shape)[:2]:
        if size is not None and size % 8 != 0:
            raise ValueError(
                "build_model___: spatial dimensions of input_shape must be "
                f"divisible by 8 (three 2x2 poolings), got {tuple(input_shape)}"
            )

    inputs = keras.layers.Input(input_shape)  # H x W x C

    # Encoder: each stage returns (skip features, pooled features).
    skip1, pool1 = _residual_encoder_block(inputs, 32)  # H x W x 64      -> H/2
    skip2, pool2 = _residual_encoder_block(pool1, 64)   # H/2 x W/2 x 128 -> H/4
    skip3, pool3 = _residual_encoder_block(pool2, 128)  # H/4 x W/4 x 256 -> H/8

    # Bottleneck at H/8 x W/8: 256 -> 512 -> 256 channels.
    center = keras.layers.Conv2D(256, (1, 1), padding='same')(pool3)
    center = keras.layers.ReLU()(center)
    center = conv_block(center, 256)
    center = keras.layers.Conv2D(512, (1, 1), padding='same')(center)
    center = keras.layers.ReLU()(center)
    center = conv_block(center, 256)
    center = keras.layers.Conv2D(256, (1, 1), padding='same')(center)
    center = keras.layers.ReLU()(center)  # H/8 x W/8 x 256

    # Decoder with attention-gated skips.
    dec3 = _attention_decoder_block(center, skip3, 256)  # H/4 x W/4 x 128
    dec2 = _attention_decoder_block(dec3, skip2, 128)    # H/2 x W/2 x 64
    dec1 = _attention_decoder_block(dec2, skip1, 64)     # H x W x 32

    # Output layer
    outputs = keras.layers.Conv2D(1, (1, 1), activation=output_activation)(dec1)  # H x W x 1

    return keras.Model(inputs=inputs, outputs=outputs)

def visualize_model(model, to_file='min2_split_aug.png'):
    """Render the architecture diagram of ``model`` to an image file.

    Uses the ``plot_model`` imported at the top of the notebook. Per the Keras
    docs, that function requires pydot and graphviz.

    Args:
        model: Keras model to draw.
        to_file: Output image path; defaults to the previously hard-coded name.

    Returns:
        Whatever ``plot_model`` returns. The Keras docs describe this as a
        notebook Image when IPython is available, so the diagram displays
        inline if this call is the last expression in a cell.
    """
    return plot_model(model, to_file=to_file, show_shapes=True, show_layer_names=True)

# Example of building the model
model = build_model___((256, 256, 3))
model.summary()
# To save an architecture diagram (needs pydot + graphviz), call visualize_model(model).

In [None]:
import tensorflow as tf

def get_flops(model, input_shape=(1, 256, 256, 3)):
    """Estimate the floating-point operation count of one forward pass of ``model``.

    The model is traced once with ``tf.function`` for a float32 input of
    ``input_shape`` (batch dimension included). The TF1 profiler's
    ``float_operation`` view is then run over the traced graph. The graph is
    not frozen: variables are not converted to constants first.

    Args:
        model: Keras model to profile.
        input_shape: Full input shape, including the batch size.

    Returns:
        ``(flops_count, gflops)``: the profiler's ``total_float_ops`` and that
        value divided by 1e9. If profiling raises, the error is printed and
        ``(0, 0.0)`` is returned so the notebook keeps running.
    """
    # Tracing only needs the input signature, so no sample tensor is created.
    # Inference mode is made explicit so any training-only behaviour is excluded.
    forward = tf.function(lambda x: model(x, training=False))
    concrete_func = forward.get_concrete_function(tf.TensorSpec(input_shape, tf.float32))

    try:
        run_meta = tf.compat.v1.RunMetadata()
        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        profile = tf.compat.v1.profiler.profile(graph=concrete_func.graph, run_meta=run_meta, options=opts)
        flops_count = profile.total_float_ops
    except Exception as e:
        # Deliberately best-effort: report the problem instead of aborting the run.
        print("Error calculating FLOPS:", e)
        flops_count = 0

    gflops = flops_count / 1e9
    return flops_count, gflops

# Example usage: FLOPs depend only on the architecture, so reuse the model
# built in the previous cell instead of constructing a second copy.
flops, gflops = get_flops(model)
print(f"FLOPS: {flops}")
print(f"GFLOPS: {gflops:.2f}")


Instructions for updating:
This API was designed for TensorFlow v1. See https://www.tensorflow.org/guide/migrate for instructions on how to migrate your code to TensorFlow v2.


FLOPS: 8511116294
GFLOPS: 8.51 GFLOPS
