In [77]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import Input, Model
def lightweight_block(input_tensor, output_name="A"):
    """Stem block: depthwise 3x3 -> pointwise 1x1 -> 3x3 conv -> BatchNorm -> ReLU.

    Every convolution uses stride 1 with ``padding='same'``, so the spatial
    size of ``input_tensor`` is preserved. The result has 8 channels.

    Args:
        input_tensor: Keras tensor to process (assumed 4-D, channels-last;
            the notebook feeds a ``(176, 304, 3)`` input layer).
        output_name: Name given to the final ReLU layer. Defaults to ``"A"``,
            the name this block has always used. Layer names must be unique
            within a Keras model, so pass a different name (or ``None`` to let
            Keras generate one) when calling this block more than once.

    Returns:
        Output tensor of the final ReLU layer.
    """
    # Depthwise 3x3: filters each input channel independently.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(input_tensor)

    # Pointwise 1x1: mixes channels and projects to 8 feature maps.
    x = layers.Conv2D(8, kernel_size=1, padding='same')(x)

    # Standard 3x3 convolution followed by BatchNorm and ReLU.
    x = layers.Conv2D(8, kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU(name=output_name)(x)

    return x

def downsample_block(input_tensor, output_name="B"):
    """Downsample spatially by a stride-2 depthwise conv, keeping the channel count.

    Pipeline: depthwise 3x3 (stride 2) -> BN -> ReLU -> pointwise 1x1 ->
    depthwise 3x3 -> BN -> ReLU. With ``padding='same'`` the stride-2 layer
    yields ``ceil(size / 2)`` along each spatial axis.

    Args:
        input_tensor: Keras tensor to process (assumed 4-D, channels-last).
            Its last dimension is used as the filter count of the 1x1 conv.
        output_name: Name given to the final ReLU layer. Defaults to ``"B"``,
            the name this block has always used. Layer names must be unique
            within a Keras model, so pass a different name (or ``None`` to let
            Keras generate one) when calling this block more than once.

    Returns:
        Output tensor of the final ReLU layer.
    """
    # Stride-2 depthwise convolution performs the downsampling.
    x = layers.DepthwiseConv2D(kernel_size=3, strides=2, padding='same')(input_tensor)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Pointwise 1x1 convolution mixes channels without changing their number.
    x = layers.Conv2D(input_tensor.shape[-1], kernel_size=1, padding='same')(x)

    # Second depthwise convolution at the reduced resolution.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU(name=output_name)(x)

    return x

def double_channels_same_dim_block(input_tensor, output_name="C"):
    """Double the channel count while preserving the spatial size.

    Pipeline: depthwise 3x3 -> BN -> ReLU -> pointwise 1x1 (2x channels) ->
    depthwise 3x3 -> BN -> ReLU. All layers use stride 1 and ``padding='same'``.

    Args:
        input_tensor: Keras tensor to process (assumed 4-D, channels-last).
            Its last dimension must be statically known because it is
            multiplied by 2 to obtain the output filter count.
        output_name: Name given to the final ReLU layer. Defaults to ``"C"``,
            the name this block has always used. Layer names must be unique
            within a Keras model, so pass a different name (or ``None`` to let
            Keras generate one) when calling this block more than once.

    Returns:
        Output tensor of the final ReLU layer.
    """
    # Depthwise convolution on the incoming feature maps.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(input_tensor)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Pointwise 1x1 convolution doubles the number of channels.
    out_channels = input_tensor.shape[-1] * 2
    x = layers.Conv2D(out_channels, kernel_size=1, padding='same')(x)

    # Second depthwise convolution on the widened feature maps.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU(name=output_name)(x)

    return x

def downsample_same_channels_block(input_tensor, output_name="D"):
    """Downsample spatially by a stride-2 depthwise conv, keeping the channel count.

    Pipeline: depthwise 3x3 (stride 2) -> BN -> ReLU -> pointwise 1x1 ->
    depthwise 3x3 -> BN -> ReLU. With ``padding='same'`` the stride-2 layer
    yields ``ceil(size / 2)`` along each spatial axis.

    Args:
        input_tensor: Keras tensor to process (assumed 4-D, channels-last).
            Its last dimension is used as the filter count of the 1x1 conv.
        output_name: Name given to the final ReLU layer. Defaults to ``"D"``,
            the name this block has always used. Layer names must be unique
            within a Keras model, so pass a different name (or ``None`` to let
            Keras generate one) when calling this block more than once.

    Returns:
        Output tensor of the final ReLU layer.
    """
    # Stride-2 depthwise convolution performs the downsampling.
    x = layers.DepthwiseConv2D(kernel_size=3, strides=2, padding='same')(input_tensor)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Pointwise 1x1 convolution mixes channels without changing their number.
    channels = input_tensor.shape[-1]
    x = layers.Conv2D(channels, kernel_size=1, padding='same')(x)

    # Second depthwise convolution at the reduced resolution.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU(name=output_name)(x)

    return x

def double_channels_block(input_tensor, output_name="E"):
    """Double the channel count while preserving the spatial size.

    Pipeline: depthwise 3x3 -> BN -> ReLU -> pointwise 1x1 (2x channels) ->
    depthwise 3x3 -> BN -> ReLU. All layers use stride 1 and ``padding='same'``.

    Args:
        input_tensor: Keras tensor to process (assumed 4-D, channels-last).
            Its last dimension must be statically known because it is
            multiplied by 2 to obtain the output filter count.
        output_name: Name given to the final ReLU layer. Defaults to ``"E"``,
            the name this block has always used. Layer names must be unique
            within a Keras model, so pass a different name (or ``None`` to let
            Keras generate one) when calling this block more than once.

    Returns:
        Output tensor of the final ReLU layer.
    """
    # Depthwise convolution on the incoming feature maps.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(input_tensor)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Pointwise 1x1 convolution doubles the number of channels.
    out_channels = input_tensor.shape[-1] * 2
    x = layers.Conv2D(out_channels, kernel_size=1, padding='same')(x)

    # Second depthwise convolution on the widened feature maps.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU(name=output_name)(x)

    return x

def downsample_block_same_channels(input_tensor, output_name="F"):
    """Downsample spatially by a stride-2 depthwise conv, keeping the channel count.

    Pipeline: depthwise 3x3 (stride 2) -> BN -> ReLU -> pointwise 1x1 ->
    depthwise 3x3 -> BN -> ReLU. With ``padding='same'`` the stride-2 layer
    yields ``ceil(size / 2)`` along each spatial axis.

    Args:
        input_tensor: Keras tensor to process (assumed 4-D, channels-last).
            Its last dimension is used as the filter count of the 1x1 conv.
        output_name: Name given to the final ReLU layer. Defaults to ``"F"``,
            the name this block has always used. Layer names must be unique
            within a Keras model, so pass a different name (or ``None`` to let
            Keras generate one) when calling this block more than once.

    Returns:
        Output tensor of the final ReLU layer.
    """
    # Stride-2 depthwise convolution performs the downsampling.
    x = layers.DepthwiseConv2D(kernel_size=3, strides=2, padding='same')(input_tensor)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Pointwise 1x1 convolution mixes channels without changing their number.
    channels = input_tensor.shape[-1]
    x = layers.Conv2D(channels, kernel_size=1, padding='same')(x)

    # Second depthwise convolution at the reduced resolution.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU(name=output_name)(x)

    return x

def double_channels_block_22(input_tensor, output_name="G"):
    """Double the channel count while preserving the spatial size.

    Pipeline: depthwise 3x3 -> BN -> ReLU -> pointwise 1x1 (2x channels) ->
    depthwise 3x3 -> BN -> ReLU. All layers use stride 1 and ``padding='same'``.

    Args:
        input_tensor: Keras tensor to process (assumed 4-D, channels-last).
            Its last dimension must be statically known because it is
            multiplied by 2 to obtain the output filter count.
        output_name: Name given to the final ReLU layer. Defaults to ``"G"``,
            the name this block has always used. Layer names must be unique
            within a Keras model, so pass a different name (or ``None`` to let
            Keras generate one) when calling this block more than once.

    Returns:
        Output tensor of the final ReLU layer.
    """
    # Depthwise convolution on the incoming feature maps.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(input_tensor)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Pointwise 1x1 convolution doubles the number of channels.
    out_channels = input_tensor.shape[-1] * 2
    x = layers.Conv2D(out_channels, kernel_size=1, padding='same')(x)

    # Second depthwise convolution on the widened feature maps.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU(name=output_name)(x)

    return x

def downsample_block_2(input_tensor, output_name="H"):
    """Downsample spatially by a stride-2 depthwise conv, keeping the channel count.

    Pipeline: depthwise 3x3 (stride 2) -> BN -> ReLU -> pointwise 1x1 ->
    depthwise 3x3 -> BN -> ReLU. With ``padding='same'`` the stride-2 layer
    yields ``ceil(size / 2)`` along each spatial axis.

    Args:
        input_tensor: Keras tensor to process (assumed 4-D, channels-last).
            Its last dimension is used as the filter count of the 1x1 conv.
        output_name: Name given to the final ReLU layer. Defaults to ``"H"``,
            the name this block has always used. Layer names must be unique
            within a Keras model, so pass a different name (or ``None`` to let
            Keras generate one) when calling this block more than once.

    Returns:
        Output tensor of the final ReLU layer.
    """
    # Stride-2 depthwise convolution performs the downsampling.
    x = layers.DepthwiseConv2D(kernel_size=3, strides=2, padding='same')(input_tensor)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Pointwise 1x1 convolution mixes channels without changing their number.
    channels = input_tensor.shape[-1]
    x = layers.Conv2D(channels, kernel_size=1, padding='same')(x)

    # Second depthwise convolution at the reduced resolution.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU(name=output_name)(x)

    return x

def double_channels_block_2(input_tensor, output_name="I"):
    """Double the channel count while preserving the spatial size.

    Pipeline: depthwise 3x3 -> BN -> ReLU -> pointwise 1x1 (2x channels) ->
    depthwise 3x3 -> BN -> ReLU. All layers use stride 1 and ``padding='same'``.

    Args:
        input_tensor: Keras tensor to process (assumed 4-D, channels-last).
            Its last dimension must be statically known because it is
            multiplied by 2 to obtain the output filter count.
        output_name: Name given to the final ReLU layer. Defaults to ``"I"``,
            the name this block has always used. Layer names must be unique
            within a Keras model, so pass a different name (or ``None`` to let
            Keras generate one) when calling this block more than once.

    Returns:
        Output tensor of the final ReLU layer.
    """
    # Depthwise convolution on the incoming feature maps.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(input_tensor)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Pointwise 1x1 convolution doubles the number of channels.
    out_channels = input_tensor.shape[-1] * 2
    x = layers.Conv2D(out_channels, kernel_size=1, padding='same')(x)

    # Second depthwise convolution on the widened feature maps.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU(name=output_name)(x)

    return x


#################################### upsampling and concat blocks ####################################

def fuse_and_double_channels_block(low_res_input, high_res_input, output_name="J"):
    """Upsample a low-resolution map 2x, fuse it with a skip tensor, output 96 channels.

    Pipeline: bilinear 2x upsampling of ``low_res_input`` -> channel-wise
    concatenation with ``high_res_input`` -> 3x3 conv (96 filters) ->
    depthwise 3x3 -> BN -> ReLU -> pointwise 1x1 (96 filters) ->
    depthwise 3x3 -> BN -> ReLU.

    Note that, despite the function name, the 96-filter convolution *reduces*
    the channel count in this notebook (128 + 64 = 192 -> 96).

    Args:
        low_res_input: Keras tensor at half the spatial resolution of
            ``high_res_input`` (in this notebook ``(11, 19, 128)``).
        high_res_input: Skip-connection tensor (in this notebook
            ``(22, 38, 64)``). Its spatial size must be exactly twice that of
            ``low_res_input`` for the concatenation to succeed.
        output_name: Name given to the final ReLU layer. Defaults to ``"J"``,
            the name this block has always used. Layer names must be unique
            within a Keras model, so pass a different name (or ``None`` to let
            Keras generate one) when calling this block more than once.

    Returns:
        Output tensor of the final ReLU layer (96 channels).
    """
    # Bring the low-resolution features up to the skip tensor's spatial size.
    low_res_upsampled = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(low_res_input)

    # Fuse along the channel axis (skip features first, upsampled features second).
    x = layers.Concatenate(axis=-1)([high_res_input, low_res_upsampled])

    # Standard 3x3 convolution projects the fused features to 96 channels.
    x = layers.Conv2D(96, kernel_size=3, padding='same')(x)

    # Depthwise-separable processing at 96 channels.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = layers.Conv2D(96, kernel_size=1, padding='same')(x)

    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU(name=output_name)(x)

    return x

def fuse_without_conv_block(low_res_input, high_res_input, output_name="K"):
    """Upsample a low-resolution map 2x, fuse it with a skip tensor, output 40 channels.

    Pipeline: bilinear 2x upsampling of ``low_res_input`` -> channel-wise
    concatenation with ``high_res_input`` -> depthwise 3x3 -> BN -> ReLU ->
    pointwise 1x1 (40 filters) -> depthwise 3x3 -> BN -> ReLU.

    "Without conv" refers to the absence of a standard 3x3 convolution right
    after the concatenation; the block still contains depthwise and 1x1
    convolutions.

    Args:
        low_res_input: Keras tensor at half the spatial resolution of
            ``high_res_input`` (in this notebook ``(22, 38, 96)``).
        high_res_input: Skip-connection tensor (in this notebook
            ``(44, 76, 32)``). Its spatial size must be exactly twice that of
            ``low_res_input`` for the concatenation to succeed.
        output_name: Name given to the final ReLU layer. Defaults to ``"K"``,
            the name this block has always used. Layer names must be unique
            within a Keras model, so pass a different name (or ``None`` to let
            Keras generate one) when calling this block more than once.

    Returns:
        Output tensor of the final ReLU layer (40 channels).
    """
    # Bring the low-resolution features up to the skip tensor's spatial size.
    low_res_upsampled = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(low_res_input)

    # Fuse along the channel axis (in this notebook 32 + 96 = 128 channels).
    x = layers.Concatenate(axis=-1)([high_res_input, low_res_upsampled])

    # Depthwise conv + BN + ReLU on the fused features.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Pointwise 1x1 convolution mixes channels and reduces them to 40.
    x = layers.Conv2D(40, kernel_size=1, padding='same')(x)

    # Second depthwise conv + BN + ReLU.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU(name=output_name)(x)

    return x

def fuse_and_adjust_channels_block(low_res_input, high_res_input, output_name="L"):
    """Upsample a low-resolution map 2x, fuse it with a skip tensor, output 58 channels.

    Pipeline: bilinear 2x upsampling of ``low_res_input`` -> channel-wise
    concatenation with ``high_res_input`` -> 3x3 conv (58 filters) ->
    depthwise 3x3 -> BN -> ReLU -> pointwise 1x1 (58 filters) ->
    depthwise 3x3 -> BN -> ReLU.

    Args:
        low_res_input: Keras tensor at half the spatial resolution of
            ``high_res_input`` (in this notebook ``(44, 76, 40)``).
        high_res_input: Skip-connection tensor (in this notebook
            ``(88, 152, 16)``). Its spatial size must be exactly twice that of
            ``low_res_input`` for the concatenation to succeed.
        output_name: Name given to the final ReLU layer. Defaults to ``"L"``,
            the name this block has always used. Layer names must be unique
            within a Keras model, so pass a different name (or ``None`` to let
            Keras generate one) when calling this block more than once.

    Returns:
        Output tensor of the final ReLU layer (58 channels).
    """
    # Bring the low-resolution features up to the skip tensor's spatial size.
    low_res_upsampled = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(low_res_input)

    # Fuse along the channel axis (in this notebook 16 + 40 = 56 channels).
    x = layers.Concatenate(axis=-1)([high_res_input, low_res_upsampled])

    # Standard 3x3 convolution projects the fused features to 58 channels.
    x = layers.Conv2D(58, kernel_size=3, padding='same')(x)

    # Depthwise-separable processing at 58 channels.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = layers.Conv2D(58, kernel_size=1, padding='same')(x)

    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU(name=output_name)(x)

    return x

def fuse_without_conv_block_v2(low_res_input, high_res_input, output_name="out"):
    """Upsample a low-resolution map 2x, fuse it with a skip tensor, output 64 channels.

    Pipeline: bilinear 2x upsampling of ``low_res_input`` -> channel-wise
    concatenation with ``high_res_input`` -> depthwise 3x3 -> BN -> ReLU ->
    pointwise 1x1 (64 filters) -> depthwise 3x3 -> BN -> ReLU.

    "Without conv" refers to the absence of a standard 3x3 convolution right
    after the concatenation; the block still contains depthwise and 1x1
    convolutions.

    Args:
        low_res_input: Keras tensor at half the spatial resolution of
            ``high_res_input`` (in this notebook ``(88, 152, 58)``).
        high_res_input: Skip-connection tensor (in this notebook
            ``(176, 304, 8)``). Its spatial size must be exactly twice that of
            ``low_res_input`` for the concatenation to succeed.
        output_name: Name given to the final ReLU layer. Defaults to
            ``"out"``, the name this block has always used. Layer names must
            be unique within a Keras model, so pass a different name (or
            ``None`` to let Keras generate one) when calling this block more
            than once.

    Returns:
        Output tensor of the final ReLU layer (64 channels).
    """
    # Bring the low-resolution features up to the skip tensor's spatial size.
    low_res_upsampled = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(low_res_input)

    # Fuse along the channel axis (in this notebook 8 + 58 = 66 channels).
    x = layers.Concatenate(axis=-1)([high_res_input, low_res_upsampled])

    # Depthwise conv + BN + ReLU on the fused features.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # Pointwise 1x1 convolution mixes channels and maps them to 64.
    x = layers.Conv2D(64, kernel_size=1, padding='same')(x)

    # Second depthwise conv + BN + ReLU.
    x = layers.DepthwiseConv2D(kernel_size=3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU(name=output_name)(x)

    return x

In [85]:
# Assemble the encoder-decoder model. Encoder stage outputs are kept as skip
# connections and fused back in, deepest first, by the decoder.
inputlayer = layers.Input(shape=(176, 304, 3))

# ---- Encoder: each stage after the stem halves the resolution, then doubles the channels.
A = lightweight_block(inputlayer)                                # (176, 304, 8)
C = double_channels_same_dim_block(downsample_block(A))          # (88, 152, 16)
E = double_channels_block(downsample_same_channels_block(C))     # (44, 76, 32)
G = double_channels_block_22(downsample_block_same_channels(E))  # (22, 38, 64)
skip_connections = [A, C, E, G]

# ---- Bottleneck: deepest features, i.e. the output of the ReLU named "I".
outputs = double_channels_block_2(downsample_block_2(G))         # (11, 19, 128)

# ---- Decoder: upsample and concatenate with the matching skip connection.
J = fuse_and_double_channels_block(outputs, G)   # I + G -> (22, 38, 96)
K = fuse_without_conv_block(J, E)                # J + E -> (44, 76, 40)
L = fuse_and_adjust_channels_block(K, C)         # K + C -> (88, 152, 58)
out = fuse_without_conv_block_v2(L, A)           # L + A -> (176, 304, 64)

model = Model(inputs=inputlayer, outputs=out)
model.summary()

In [79]:
import time
# Time 10 eager forward passes of the Keras model on fresh random input.
# The first pass is noticeably slower than the rest in the recorded run, so
# treat it as warm-up when comparing numbers.
for i in range(10):
    input_tensor = tf.random.normal([1, 176, 304, 3])
    # perf_counter() is the clock intended for measuring short intervals;
    # time.time() is wall-clock time and can jump if the system clock changes.
    start = time.perf_counter()
    output_data = model(input_tensor)
    elapsed_ms = (time.perf_counter() - start) * 1000.0
    # ".2f" = two decimals. The previous spec ":2f" meant "minimum width 2"
    # and therefore printed the default six decimals.
    print(f"time for operations is : {elapsed_ms:.2f} ms")


time for operations is : 147.927046 ms
time for operations is : 62.347889 ms
time for operations is : 58.679104 ms
time for operations is : 56.204796 ms
time for operations is : 59.333086 ms
time for operations is : 55.888176 ms
time for operations is : 55.975199 ms
time for operations is : 58.487177 ms
time for operations is : 61.177254 ms
time for operations is : 57.781935 ms


## TFLite model

In [82]:
import numpy as np

# `inputlayer`, `out` and `model` come from the model-building cell above.
# Derive the per-sample input shape from the input layer itself so it cannot
# drift from the model: the previous hard-coded (176, 608, 3) did not match the
# (176, 304, 3) input the model is actually built with.
input_shape = tuple(inputlayer.shape[1:])
# Second Model wrapper over the same input/output tensors as `model`.
# NOTE(review): the conversion call in this cell passes `model`, not this
# object — confirm whether `optimized_model` is still needed.
optimized_model = Model(inputs=inputlayer, outputs=out)


# Step 1: Convert the Keras model to TensorFlow Lite
def convert_to_tflite(keras_model, tflite_model_path='encoder_10.tflite'):
    """Convert a Keras model to a TFLite flatbuffer and write it to disk.

    No converter optimizations are enabled (set ``converter.optimizations``
    before ``convert()`` if they are wanted).

    Args:
        keras_model: Keras model passed to ``TFLiteConverter.from_keras_model``.
        tflite_model_path: Destination path of the serialized ``.tflite`` file.

    Returns:
        None. The destination path is printed once the file has been written.
    """
    flatbuffer = tf.lite.TFLiteConverter.from_keras_model(keras_model).convert()

    # Binary mode: the converted model is written out as raw bytes.
    with open(tflite_model_path, 'wb') as out_file:
        out_file.write(flatbuffer)

    print(f"TFLite model saved to: {tflite_model_path}")

# Export the Keras model; the path is spelled out so it is clear which file
# the TFLite benchmark cell loads afterwards.
convert_to_tflite(model, tflite_model_path='encoder_10.tflite')


INFO:tensorflow:Assets written to: /var/folders/xb/sshzwv1128l88chp71np05t40000gn/T/tmptq44mj8p/assets


INFO:tensorflow:Assets written to: /var/folders/xb/sshzwv1128l88chp71np05t40000gn/T/tmptq44mj8p/assets


Saved artifact at '/var/folders/xb/sshzwv1128l88chp71np05t40000gn/T/tmptq44mj8p'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 176, 304, 3), dtype=tf.float32, name='keras_tensor_1652')
Output Type:
  TensorSpec(shape=(None, 176, 304, 64), dtype=tf.float32, name=None)
Captures:
  12920981648: TensorSpec(shape=(), dtype=tf.resource, name=None)
  12978692816: TensorSpec(shape=(), dtype=tf.resource, name=None)
  12920985488: TensorSpec(shape=(), dtype=tf.resource, name=None)
  6286475088: TensorSpec(shape=(), dtype=tf.resource, name=None)
  6286473936: TensorSpec(shape=(), dtype=tf.resource, name=None)
  6286467984: TensorSpec(shape=(), dtype=tf.resource, name=None)
  6286469328: TensorSpec(shape=(), dtype=tf.resource, name=None)
  6286470864: TensorSpec(shape=(), dtype=tf.resource, name=None)
  6286469136: TensorSpec(shape=(), dtype=tf.resource, name=None)
  6286470096: TensorSpec(shape=(), dtype=tf.resource, name=None)
  6

W0000 00:00:1748701326.215673 1755961 tf_tfl_flatbuffer_helpers.cc:392] Ignored output_format.
W0000 00:00:1748701326.215683 1755961 tf_tfl_flatbuffer_helpers.cc:395] Ignored drop_control_dependency.
2025-05-31 19:52:06.215795: I tensorflow/cc/saved_model/reader.cc:83] Reading SavedModel from: /var/folders/xb/sshzwv1128l88chp71np05t40000gn/T/tmptq44mj8p
2025-05-31 19:52:06.219973: I tensorflow/cc/saved_model/reader.cc:52] Reading meta graph with tags { serve }
2025-05-31 19:52:06.219980: I tensorflow/cc/saved_model/reader.cc:147] Reading SavedModel debug info (if present) from: /var/folders/xb/sshzwv1128l88chp71np05t40000gn/T/tmptq44mj8p
2025-05-31 19:52:06.272595: I tensorflow/cc/saved_model/loader.cc:236] Restoring SavedModel bundle.
2025-05-31 19:52:06.557895: I tensorflow/cc/saved_model/loader.cc:220] Running initialization op on SavedModel bundle at path: /var/folders/xb/sshzwv1128l88chp71np05t40000gn/T/tmptq44mj8p
2025-05-31 19:52:06.644964: I tensorflow/cc/saved_model/loader.cc:

In [84]:
import time


def run_tflite_inference_with_timing(tflite_model_path='encoder_10.tflite', input_shape=(176,304,3), num_iterations=10):
    """Run a TFLite model on random inputs and print the latency of each call.

    Only ``interpreter.invoke()`` is timed; generating the input, copying it
    into the interpreter and reading the output back are excluded.

    Args:
        tflite_model_path: Path of the ``.tflite`` file to load.
        input_shape: Per-sample input shape without the batch dimension. Any
            sequence of ints is accepted; it must match the model's input.
        num_iterations: Number of timed inferences to run.

    Returns:
        None. One line per inference is printed with the output shape and the
        elapsed time in milliseconds.
    """
    interpreter = tf.lite.Interpreter(model_path=tflite_model_path, num_threads=4)
    interpreter.allocate_tensors()

    # Tensor indices are fixed after allocation, so look them up once.
    input_index = interpreter.get_input_details()[0]['index']
    output_index = interpreter.get_output_details()[0]['index']

    # tuple() lets callers pass a list as well; batch size is fixed at 1.
    batched_shape = (1,) + tuple(input_shape)

    for i in range(num_iterations):
        input_data = np.random.random_sample(batched_shape).astype(np.float32)
        interpreter.set_tensor(input_index, input_data)

        # perf_counter() is the clock intended for measuring short intervals;
        # time.time() is wall-clock time and can jump if the system clock changes.
        start_time = time.perf_counter()
        interpreter.invoke()
        end_time = time.perf_counter()

        output_data = interpreter.get_tensor(output_index)

        inference_time = (end_time - start_time) * 1000  # milliseconds
        print(f"Inference {i+1}: output shape = {output_data.shape}, time = {inference_time:.2f} ms")

# Benchmark the exported model: 10 timed inferences on random (176, 304, 3) inputs.
run_tflite_inference_with_timing(
    tflite_model_path='encoder_10.tflite',
    input_shape=(176, 304, 3),
    num_iterations=10,
)

Inference 1: output shape = (1, 176, 304, 64), time = 23.54 ms
Inference 2: output shape = (1, 176, 304, 64), time = 16.70 ms
Inference 3: output shape = (1, 176, 304, 64), time = 10.61 ms
Inference 4: output shape = (1, 176, 304, 64), time = 11.28 ms
Inference 5: output shape = (1, 176, 304, 64), time = 9.89 ms
Inference 6: output shape = (1, 176, 304, 64), time = 11.64 ms
Inference 7: output shape = (1, 176, 304, 64), time = 10.10 ms
Inference 8: output shape = (1, 176, 304, 64), time = 10.41 ms
Inference 9: output shape = (1, 176, 304, 64), time = 9.93 ms
Inference 10: output shape = (1, 176, 304, 64), time = 10.09 ms
