In [None]:
import functools
import os
import pathlib

import cv2
import numpy as np
import tensorflow as tf
# from object_detection.utils import visualization_utils as vis_util
from object_detection.utils import label_map_util
from object_detection.utils import dataset_util

In [None]:
def load_model(model_dir):
    """Load a TensorFlow SavedModel from ``model_dir``.

    Args:
        model_dir: Path of the SavedModel directory to load.

    Returns:
        Whatever ``tf.saved_model.load`` returns for that directory; callers
        use its ``signatures['serving_default']`` entry for inference.
    """
    return tf.saved_model.load(model_dir)

def preprocess_image(image_path):
    """Read an image from disk and convert it to a batched RGB array.

    OpenCV decodes images in BGR channel order; the array is converted to
    RGB (the order the TF Object Detection API models expect) and a leading
    batch axis is added for the batched variant.

    Args:
        image_path: Path (str or os.PathLike) of the image file.

    Returns:
        Tuple ``(image_rgb, image_expanded)``: ``image_rgb`` is the RGB image
        array and ``image_expanded`` is the same array with a leading batch
        axis of size 1.

    Raises:
        OSError: If OpenCV cannot read or decode the file.
    """
    image = cv2.imread(str(image_path))
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise OSError(f"cv2.imread could not read or decode image: {image_path}")
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_expanded = np.expand_dims(image_rgb, axis=0)
    return image_rgb, image_expanded

def run_inference(model, image):
    """Run one image through a SavedModel's ``serving_default`` signature.

    Args:
        model: Object returned by ``load_model``; its
            ``signatures['serving_default']`` function is called.
        image: Image array of shape (H, W, C), or an already-batched array of
            shape (1, H, W, C). Presumably uint8 RGB, as exported TF Object
            Detection API signatures expect -- TODO confirm for this model.

    Returns:
        dict mapping every signature output key to a numpy array holding the
        first ``num_detections`` entries of batch element 0, plus
        ``'num_detections'`` as a Python int. ``'detection_classes'`` is cast
        to int64.
    """
    input_tensor = tf.convert_to_tensor(image)
    # Add the batch axis only when the caller has not already done so
    # (e.g. the second value returned by preprocess_image is pre-batched).
    if input_tensor.shape.rank == 3:
        input_tensor = input_tensor[tf.newaxis, ...]
    model_fn = model.signatures['serving_default']
    output_dict = model_fn(input_tensor)
    # num_detections is a per-batch tensor; take element 0 explicitly rather
    # than calling int() on a 1-element array (deprecated since NumPy 1.25).
    num_detections = int(output_dict.pop('num_detections')[0])
    output_dict = {key: value[0, :num_detections].numpy()
                   for key, value in output_dict.items()}
    output_dict['num_detections'] = num_detections
    output_dict['detection_classes'] = output_dict['detection_classes'].astype(np.int64)
    return output_dict

In [None]:
@functools.lru_cache(maxsize=None)
def _load_category_index(label_map_path):
    """Parse a label-map .pbtxt into a {class_id: category} dict, cached per path.

    The label map is the same for every image, so it is parsed once per path
    instead of once per call. The returned dict is shared between callers
    and must not be mutated.
    """
    label_map = label_map_util.load_labelmap(label_map_path)
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes=90, use_display_name=True)
    return label_map_util.create_category_index(categories)


def create_tf_example(image_path, output_dict, label_map_path, confidence_threshold=0.5):
    """Build a ``tf.train.Example`` holding one image and its pseudo-labels.

    Detections whose score is strictly greater than ``confidence_threshold``
    become the record's boxes and class labels; the image is re-encoded as
    JPEG and embedded in the record.

    Args:
        image_path: Path (str or os.PathLike) of the image that was run
            through the detector. Also stored as filename and source_id.
        output_dict: Result of ``run_inference`` for that image; the keys
            ``num_detections``, ``detection_scores``, ``detection_classes``
            and ``detection_boxes`` are read.
        label_map_path: Path of the label-map .pbtxt that maps class ids to
            names.
        confidence_threshold: Minimum (exclusive) score for a detection to
            be kept. Defaults to 0.5.

    Returns:
        tf.train.Example with ``image/*`` and ``image/object/*`` features.

    Raises:
        OSError: If the image cannot be read or JPEG-encoded.
        KeyError: If a kept detection's class id is missing from the label map.
    """
    image_path_str = str(image_path)
    image = cv2.imread(image_path_str)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise OSError(f"cv2.imread could not read or decode image: {image_path_str}")
    height, width = image.shape[:2]
    category_index = _load_category_index(label_map_path)

    xmins, xmaxs, ymins, ymaxs = [], [], [], []
    classes, classes_text = [], []

    for i in range(output_dict['num_detections']):
        score = output_dict['detection_scores'][i]
        # Written as `not >` (rather than `<=`) so NaN scores are still dropped.
        if not score > confidence_threshold:
            continue
        # Resolve the class only for kept detections, so low-confidence
        # predictions of ids absent from the label map cannot raise.
        class_id = int(output_dict['detection_classes'][i])
        if class_id not in category_index:
            raise KeyError(f"class id {class_id} is not in label map {label_map_path}")
        # Box order is [ymin, xmin, ymax, xmax]; TF OD API boxes are
        # normalized to [0, 1], which is what the bbox features expect.
        ymin, xmin, ymax, xmax = (float(v) for v in output_dict['detection_boxes'][i])
        xmins.append(xmin)
        xmaxs.append(xmax)
        ymins.append(ymin)
        ymaxs.append(ymax)
        classes.append(class_id)
        classes_text.append(category_index[class_id]['name'].encode('utf8'))

    # Encode the array exactly as imread returned it: cv2.imencode expects BGR,
    # so encoding an RGB-converted copy would swap red and blue in the JPEG.
    ok, encoded_image = cv2.imencode('.jpg', image)
    if not ok:
        raise OSError(f"cv2.imencode failed for image: {image_path_str}")
    encoded_image_data = encoded_image.tobytes()

    path_bytes = image_path_str.encode('utf8')
    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(path_bytes),
        'image/source_id': dataset_util.bytes_feature(path_bytes),
        'image/encoded': dataset_util.bytes_feature(encoded_image_data),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example

In [None]:
# Paths and parameters (relative to the notebook's working directory).
model_name = 'faster_rcnn_resnet50_v1_640x640_coco17_tpu-8'

# SavedModel directory of the pretrained detector:
#   <object_detection dir>/<model_name>/saved_model
saved_model_path = pathlib.Path('../models/research/object_detection/') / model_name / "saved_model"
model_dir = str(saved_model_path)

image_dir = './data/resized_640x640/'  # Root folder whose sub-folders hold the unlabeled images
output_dir = './annotations/'  # TFRecords are written here, one sub-folder per image sub-folder
label_map_path = '../models/research/object_detection/data/mscoco_label_map.pbtxt'  # Label map (.pbtxt) mapping class ids to names

# Load the pretrained detector once; the annotation loop reuses it for every image.
model = load_model(model_dir)

In [None]:
# Pseudo-label every image under image_dir/<folder>/ and write one
# single-example TFRecord per image to output_dir/<folder>/<stem>.tfrecord.
# Images whose TFRecord already exists are skipped, so re-running the cell
# resumes an interrupted job.
excluded_folders = {
    'bluefox_2016-10-04-14-22-41_bag',
    'bluefox_2016-09-30-15-19-35_bag',
}
img_count = 0  # images accounted for: newly written + already present
for folder in sorted(os.listdir(image_dir)):
    folder_path = os.path.join(image_dir, folder)
    # Skip excluded folders and stray top-level files (os.listdir on a file raises).
    if folder in excluded_folders or not os.path.isdir(folder_path):
        continue
    folder_output_dir = os.path.join(output_dir, folder)
    for image_file in sorted(os.listdir(folder_path)):
        image_path = os.path.join(folder_path, image_file)
        tfrecord_filename = os.path.splitext(image_file)[0] + '.tfrecord'
        tfrecord_path = os.path.join(folder_output_dir, tfrecord_filename)
        if os.path.exists(tfrecord_path):
            # Already annotated by an earlier run; avoid re-reading the image.
            img_count += 1
            continue

        image_bgr = cv2.imread(image_path)
        if image_bgr is None:
            # Not an image OpenCV can decode (or a sub-directory); skip it.
            print(f"Skipping unreadable file: {image_path}")
            continue
        # cv2.imread yields BGR; the COCO-trained detector expects RGB input.
        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        output_dict = run_inference(model, image=image_rgb)
        tf_example = create_tf_example(image_path, output_dict, label_map_path)

        os.makedirs(folder_output_dir, exist_ok=True)
        # Write to a temp file then rename, so an interrupted write never
        # leaves a truncated .tfrecord that later runs would silently skip.
        tmp_path = tfrecord_path + '.tmp'
        with tf.io.TFRecordWriter(tmp_path) as writer:
            writer.write(tf_example.SerializeToString())
        os.replace(tmp_path, tfrecord_path)
        img_count += 1
        print(tfrecord_path, img_count)
print("Annotation generation and TFRecord conversion complete.")