In [3]:
## Import ROS1 module
import rospy
from sensor_msgs.msg import Image
import sys
import cv2
import numpy as np
from cv_bridge import CvBridge

In [4]:
## Publisher for the super-resolved output frames.
# queue_size=1: the outgoing queue holds a single message, so stale frames are
# dropped instead of buffered.
# NOTE(review): this is created before rospy.init_node() (called later in
# start_node); confirm rospy registers it with the master once the node initialises.
pub = rospy.Publisher(name="output_image", data_class=Image, queue_size=1)

In [5]:
# TensorFlow stack used for the ESRGAN TFLite conversion and inference below.
# NOTE(review): these imports would ideally sit in the first cell with the ROS ones.
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt  # appears to be used only by the commented-out visualisation cell
# Record the TF version in the cell output; the TFLite conversion depends on it.
print(tf.__version__)

2.4.0


In [6]:
## Convert the TF-Hub ESRGAN model to TFLite once and cache it on disk.
# process_image loads the interpreter from `esrgan_model_path`, so the converted
# model is written there. Later runs reuse the cached file and skip the TF-Hub
# download and the slow conversion. Delete the file to force a re-conversion
# (e.g. after changing ESRGAN_INPUT_SHAPE). Set ESRGAN_TFLITE_PATH to relocate it.
import os

esrgan_model_path = os.environ.get(
    'ESRGAN_TFLITE_PATH',
    '/home/autoware02/Desktop/esrgan_jupyteros/ESRGAN.tflite',
)
# (batch, height, width, channels): TFLite needs a static input shape, and the
# ROS callback resizes every RGB frame to 50x50 to match it.
ESRGAN_INPUT_SHAPE = [1, 50, 50, 3]

if os.path.isfile(esrgan_model_path):
    with open(esrgan_model_path, 'rb') as f:
        tflite_model = f.read()
else:
    model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
    concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    concrete_func.inputs[0].set_shape(ESRGAN_INPUT_SHAPE)
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()

    # Save the TF Lite model where the interpreter will look for it.
    os.makedirs(os.path.dirname(esrgan_model_path) or '.', exist_ok=True)
    with open(esrgan_model_path, 'wb') as f:
        f.write(tflite_model)

Input frames arrive from the ROS camera topic; this replaces the original tutorial's step of downloading a test image (insect head).

## Generate a super-resolution image using TensorFlow Lite

In [7]:
# Cache for objects that are expensive to build and reusable across frames
# (the TFLite interpreter and the CvBridge). Re-running this cell resets it,
# so a newly converted model file is picked up.
_esrgan_runtime = {}


def _get_esrgan_runtime():
    """Return the cached CvBridge / TFLite interpreter, creating them on first use.

    The interpreter is loaded from the global ``esrgan_model_path``, and its
    tensors are allocated once rather than once per camera frame.
    """
    if not _esrgan_runtime:
        interpreter = tf.lite.Interpreter(model_path=esrgan_model_path)
        interpreter.allocate_tensors()
        # Only populate the cache after everything succeeded, so a failed load
        # is retried on the next frame instead of leaving a half-filled cache.
        _esrgan_runtime.update(
            bridge=CvBridge(),
            interpreter=interpreter,
            input_index=interpreter.get_input_details()[0]['index'],
            output_index=interpreter.get_output_details()[0]['index'],
        )
    return _esrgan_runtime


def process_image(msg, *, lr_size=(50, 50), out_size=(356, 200)):
    """ROS subscriber callback: super-resolve one camera frame and publish it on ``pub``.

    Args:
        msg: incoming ``sensor_msgs/Image`` message; converted to an RGB array.
        lr_size: ``(width, height)`` the frame is resized to before inference
            (``cv2.resize`` order). Must match the static input shape the TFLite
            model was converted with (50x50 in this notebook).
        out_size: ``(width, height)`` the super-resolved frame is resized to
            before publishing.

    Any exception is logged and the frame is dropped, so one bad frame does not
    stop the subscriber. The published message reuses the input ``header`` so
    downstream nodes keep the original timestamp and ``frame_id``.
    """
    try:
        runtime = _get_esrgan_runtime()
        bridge = runtime['bridge']
        interpreter = runtime['interpreter']

        # ROS Image -> (1, H, W, 3) float32 batch fed to the model.
        lr_rgb = cv2.resize(bridge.imgmsg_to_cv2(msg, "rgb8"), lr_size)
        lr_batch = np.expand_dims(lr_rgb, axis=0).astype(np.float32)

        interpreter.set_tensor(runtime['input_index'], lr_batch)
        interpreter.invoke()

        # Drop the batch axis and map the float output back to displayable uint8.
        sr = interpreter.get_tensor(runtime['output_index'])[0]
        sr = np.round(np.clip(sr, 0, 255)).astype(np.uint8)

        out_msg = bridge.cv2_to_imgmsg(cv2.resize(sr, out_size), "rgb8")
        out_msg.header = msg.header
        pub.publish(out_msg)
    except Exception as err:
        rospy.logerr("esrgan: dropping frame: %s", err)

## Initialize the ROS node (rospy) and subscribe to the camera

In [8]:
def start_node(input_topic="camera/color/image_raw"):
    """Initialise the ``esrgan`` ROS node, subscribe ``process_image`` and spin.

    Args:
        input_topic: camera image topic to super-resolve.

    ``queue_size=1`` keeps only the newest frame. ESRGAN inference can easily be
    slower than the camera's frame rate, and rospy's default unbounded queue
    would make the published output lag further and further behind.
    ``buff_size`` must exceed the size of one image message for ``queue_size=1``
    to actually discard stale frames; rospy's default buffer is only 64 KiB.
    Blocks in ``rospy.spin()`` until the node shuts down.
    """
    rospy.init_node('esrgan')
    rospy.Subscriber(input_topic, Image, process_image,
                     queue_size=1, buff_size=2 ** 24)
    rospy.spin()

In [None]:
# Start the node; this cell blocks inside rospy.spin() until ROS shuts down.
try:
    start_node()
except rospy.ROSInterruptException as shutdown_exc:
    # rospy signals an interrupted shutdown with this exception; report it
    # instead of letting the cell fail with a traceback.
    print(shutdown_exc)

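## Visualize the result (disabled: the plotting code below is commented out because `lr` and `sr` only exist inside the ROS callback)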

In [None]:
# # lr = tf.cast(tf.squeeze(lr, axis=0), tf.uint8)
# # plt.figure(figsize = (1, 1))
# # plt.title('LR')
# # plt.imshow(lr.numpy());
# # 
# # plt.figure(figsize=(10, 4))
# # plt.subplot(1, 2, 1)        
# # plt.title(f'ESRGAN (x4)')
# plt.imshow(sr.numpy());

# bicubic = tf.image.resize(lr, [200, 200], tf.image.ResizeMethod.BICUBIC)
# bicubic = tf.cast(bicubic, tf.uint8)
# plt.subplot(1, 2, 2)   
# plt.title('Bicubic')
# plt.imshow(bicubic.numpy());