In [None]:
import cv2
import numpy as np
import os
# import tensorflow as tf

from PIL import Image
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDict
import xml.etree.ElementTree as ET
import tensorflow.compat.v1 as tf
import argparse
import random
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
def check_dir(dirpath):
    """Ensure the directory ``dirpath`` exists, creating missing parents.

    Creation is best-effort: if the directory cannot be created (permission
    denied, a regular file already occupies the path, ...) a message is
    printed instead of raising, so callers are never interrupted.

    Args:
        dirpath: Path of the directory to create.
    """
    try:
        # exist_ok=True replaces a separate os.path.exists() test, which
        # raced with any other process creating the same directory.
        os.makedirs(dirpath, exist_ok=True)
    except OSError as err:
        # Keep the original non-raising contract, but make failures visible
        # instead of swallowing them silently.
        print(f"check_dir: could not create {dirpath!r}: {err}")

def class_text_to_int(row_label, classes='food-recipe'):
    """Map a class name to its integer label id.

    ``'food'`` maps to 1 and ``'recipe'`` to 2, but only when the name is also
    ``in`` ``classes``; with the default string argument that is a substring
    test against ``'food-recipe'``.

    Raises:
        Exception: ``('Label not found', row_label)`` for any other name, or
            when the name is not contained in ``classes``.
    """
    for known_name, label_id in (("food", 1), ("recipe", 2)):
        if row_label == known_name and row_label in classes:
            return label_id
    raise Exception('Label not found', row_label)
        
def get_class_name(boxes, with_recipe_bb):
    """Return the class name to use for a box.

    Both arguments are currently ignored; every box is named ``'recipe'``.
    """
    class_name = 'recipe'
    return class_name

def get_box_coords(box, height, width):
    """Read a bounding-box XML element and clamp its corners to the image.

    ``box`` must have ``xmin``, ``ymin``, ``xmax`` and ``ymax`` children whose
    text parses as a float. Only one side of each coordinate is clamped: the
    minimum corner is raised to at least 0, and the maximum corner is lowered
    to at most ``width - 1`` / ``height - 1``. The result can therefore still be
    empty or inverted (e.g. ``xmax <= xmin``).

    Returns:
        ``((xmin, ymin), (xmax, ymax))``.
    """
    def coord(tag):
        return float(box.find(tag).text)

    top = coord("ymin")
    left = coord("xmin")
    bottom = coord("ymax")
    right = coord("xmax")

    left, top = max(left, 0), max(top, 0)
    right, bottom = min(right, width - 1), min(bottom, height - 1)
    return (left, top), (right, bottom)

def draw_bbs(img, xmins, ymins, xmaxs, ymaxs):
    """Debug helper: draw boxes on ``img``, show it, and wait for Enter.

    Box coordinates are multiplied by the image width/height read from
    ``img.shape`` and truncated to ints, so they are presumably normalized
    to [0, 1]. ``cv2.rectangle`` draws into ``img`` in place.

    NOTE(review): ``[0, 0, 255]`` is red in OpenCV's BGR convention, but
    ``plt.imshow`` interprets channels as RGB — confirm the intended colour.
    """
    img_h, img_w, _ = img.shape
    for x0, y0, x1, y1 in zip(xmins, ymins, xmaxs, ymaxs):
        print(img_w, img_h)
        top_left = (int(x0 * img_w), int(y0 * img_h))
        bottom_right = (int(x1 * img_w), int(y1 * img_h))
        print(top_left, bottom_right)
        cv2.rectangle(img, top_left, bottom_right, [0, 0, 255], 10)
    plt.imshow(img)
    plt.show()
    # Block so each visualisation can be inspected before moving on.
    input('Press enter to continue')
    
    
def create_tf_example(example, path, with_recipe_bb, gt_classes, multi_piatto_area):
        """Build tf.train.Example record(s) from one XML annotation and write them.

        NOTE(review): this function looks unfinished. In the visible body it uses
        ``new_filename`` and ``encoded_jpg``, which are never assigned here, and
        ``ANNOTATIONS_PATH``, ``IMAGES_PATH`` and ``writer``, which are presumably
        module-level globals defined elsewhere -- verify before running.

        Args:
            example: Annotation file name, joined onto ``ANNOTATIONS_PATH``. The
                image name is derived from it by replacing ``.xml`` with ``.jpg``.
            path: Not used in the visible body.
            with_recipe_bb: Not used in the visible body.
            gt_classes: Passed to ``class_text_to_int`` as its ``classes`` argument.
            multi_piatto_area: Area threshold, in the same units as the XML box
                coordinates (presumably pixels). Boxes with a larger area are
                labelled ``'recipe'``; smaller ones are candidates for ``'food'``.
        """

        root = ET.parse(os.path.join(ANNOTATIONS_PATH, example))
        
        filename = example.replace(".xml", ".jpg")
        # NOTE(review): after the replace above this is only true if `example`
        # itself ended in ".jpeg"; the intended check is unclear -- verify.
        if filename.endswith(".jpeg"):
            filename = filename.replace(".jpeg", ".jpg")
            
        file_path = os.path.join(IMAGES_PATH, filename)
        # Fall back to a PNG with the same stem when no JPG exists.
        if not os.path.isfile(file_path):
            file_path = file_path.replace(".jpg", ".png")

        # NOTE(review): the PIL file handle is not explicitly closed.
        image = np.asarray(Image.open(file_path))
        # NOTE(review): a numpy image array is laid out (rows, cols, channels),
        # i.e. (height, width, channels), so this unpacking appears to swap width
        # and height. It also fails for 2-D (grayscale) images -- verify.
        width, height, _ = image.shape

        # NOTE(review): stays b'jpg' even when the PNG fallback above was taken.
        image_format = b'jpg'

        # Per-box outputs: coordinates are divided by width/height below, and the
        # class text/ids go into the record's class features.
        xmins, xmaxs = [], []
        ymins, ymaxs = [], []
        classes_text = []
        classes = []
        for boxes in root.iter('object'):
            # Only the first <bndbox> of each <object> is used.
            box = boxes.findall("bndbox")[0]
            
            (xmin, ymin), (xmax, ymax) = get_box_coords(box, height, width)  
            # Skip boxes that are empty or inverted after clamping.
            if (xmax <= xmin) or (ymax <= ymin): continue

            crop_width, crop_height = xmax - xmin, ymax - ymin
            # Large boxes are labelled 'recipe'.
            if crop_width*crop_height > multi_piatto_area: 
                xmins.append(xmin/width) ; xmaxs.append(xmax/width)
                ymins.append(ymin/height); ymaxs.append(ymax/height)
                classes_text.append('recipe'.encode('utf8'))
                classes.append(class_text_to_int('recipe', gt_classes))
            else:                
                # Small boxes are kept as 'food' only if their centre is not inside
                # another box.
                # NOTE(review): this inner loop rebinds `boxes` and also visits the
                # current box itself. It clamps with crop_height/crop_width instead
                # of the image height/width -- verify all three are intended.
                inside_another_box = False
                for boxes in root.iter('object'):
                    (xbb1, ybb1), (xbb2, ybb2) = get_box_coords(boxes.findall("bndbox")[0], crop_height, crop_width)
                    if (xbb2 <= xbb1) or (ybb2 <= ybb1): continue

                    xc, yc = (xmin + xmax)/2, (ymin + ymax)/2
                    if xbb1 < xc < xbb2 and ybb1 < yc < ybb2: 
                        inside_another_box = True
                        break
                
                if not inside_another_box:
                    xmins.append(xmin/width) ; xmaxs.append(xmax/width)
                    ymins.append(ymin/height); ymaxs.append(ymax/height)                        
                    classes_text.append('food'.encode('utf8'))
                    classes.append(class_text_to_int('food', gt_classes))

                # NOTE(review): this record is built inside the `else` branch, once
                # per small box, with the accumulated lists so far. It uses the box
                # size (crop_height/crop_width) as the image size, and depends on
                # `new_filename` / `encoded_jpg`, which are not defined in the
                # visible code.
                tf_example = tf.train.Example(features=tf.train.Features(feature={
                    'image/height'            : dataset_util.int64_feature(crop_height),
                    'image/width'             : dataset_util.int64_feature(crop_width),
                    'image/filename'          : dataset_util.bytes_feature(new_filename),
                    'image/source_id'         : dataset_util.bytes_feature(new_filename),
                    'image/encoded'           : dataset_util.bytes_feature(encoded_jpg),
                    'image/format'            : dataset_util.bytes_feature(image_format),
                    'image/object/bbox/xmin'  : dataset_util.float_list_feature(xmins),
                    'image/object/bbox/xmax'  : dataset_util.float_list_feature(xmaxs),
                    'image/object/bbox/ymin'  : dataset_util.float_list_feature(ymins),
                    'image/object/bbox/ymax'  : dataset_util.float_list_feature(ymaxs),
                    'image/object/class/text' : dataset_util.bytes_list_feature(classes_text),
                    'image/object/class/label': dataset_util.int64_list_feature(classes),
                }))

                # `writer` is presumably a module-level TFRecord writer -- verify.
                writer.write(tf_example.SerializeToString())