# Genera los TFRecords en base a los XMLs correspondientes a las imágenes

Adaptado de https://github.com/douglasrizzo/detection_util_scripts 

Fuente: https://medium.com/@omcar17/how-to-convert-xml-files-into-tfrecords-in-tensorflow2-0-86120b553f0b


0) Preparar el ambiente:

In [None]:
#@title Actualizar e instalar paquetes necesarios
# Upgrade (-U) to the newest TensorFlow 2.x release, allowing pre-release
# versions (--pre).
!pip install -U --pre tensorflow=="2.*"
# Extra packages installed explicitly; presumably needed by the Object
# Detection API installed in the next cells — TODO confirm.
!pip install tf_slim
!pip install pycocotools

In [None]:
#@title Clonar el repositorio de modelos de TF si no está ya disponible
import os
import pathlib

# If the notebook is currently somewhere inside a "models" directory tree,
# climb back out of it; otherwise clone the TensorFlow models repository
# into the current directory unless a "models" entry already exists there.
if "models" in pathlib.Path.cwd().parts:
  # Go up one level at a time until no path component is named "models".
  while "models" in pathlib.Path.cwd().parts:
    os.chdir('..')
elif not pathlib.Path('models').exists():
  # Shallow clone (latest commit only) to keep the download small.
  !git clone --depth 1 https://github.com/tensorflow/models

In [None]:
#@title Instalar el Object Detection API
# Note: if this fails because of missing requirements, run it again and it will work ;)...
%%bash
cd models/research/
# Generate the Python modules for the Object Detection API protobuf definitions.
protoc object_detection/protos/*.proto --python_out=.
# Install the package using the TF2 variant of its setup.py.
cp object_detection/packages/tf2/setup.py .
python -m pip install .

In [None]:
#@title Re-instalar el Object Detection API
# Note: the installation is run a second time so that it gets installed correctly.
%%bash
# Same steps as the first installation: compile the protos, copy the TF2
# setup.py and install the package with pip.
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
cp object_detection/packages/tf2/setup.py .
python -m pip install .

1) Define las librerías a utilizar:

In [None]:
#@title Cargar Librerías
import os
import glob
import pandas as pd  # builds and reads the intermediate annotation CSVs

import xml.etree.ElementTree as ET  # parses the XML annotation files

import io
import tensorflow as tf

from PIL import Image  # used to read each image's width/height
from tqdm import tqdm  # progress bar while writing the TFRecords
from object_detection.utils import dataset_util  # builds the tf.train.Example features
# NOTE(review): OrderedDict does not appear to be used in the visible cells.
from collections import namedtuple, OrderedDict

print("Librerías cargadas.")

2) Monta el Drive y define archivos a utilizar:

In [None]:
#@title Montar Drive y definir archivos a procesar

# Mount Google Drive again (the connection is lost whenever the environment
# from the previous steps is restarted).
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

# Project folder inside Google Drive and the dataset folder below it.
drive_path = '/content/gdrive/My Drive/GEMIS/objDetectionCursogramas'
data_dir_path = f'{drive_path}/Cursogramas'

# Image and XML annotation folders; the "test" split reads from "validation".
train_image_dir = f'{data_dir_path}/train/images'
train_xml_dir = f'{data_dir_path}/train/annotations'
test_image_dir = f'{data_dir_path}/validation/images'
test_xml_dir = f'{data_dir_path}/validation/annotations'

# Files produced by the following steps of the notebook.
train_csv_fname = f'{data_dir_path}/train_labels.csv'
test_csv_fname = f'{data_dir_path}/test_labels.csv'
label_map_pbtxt_fname = f'{data_dir_path}/label_map.pbtxt'
train_record_fname = f'{data_dir_path}/train.record'
test_record_fname = f'{data_dir_path}/test.record'

# Report the configuration.
print("\n")
print("> Datos disponibles en: ", data_dir_path)
print("\n")
for _label, _path in (
    ("> train_image_dir: ", train_image_dir),
    ("> train_xml_dir: ", train_xml_dir),
    ("> test_image_dir: ", test_image_dir),
    ("> test_xml_dir: ", test_xml_dir),
):
    print(_label, _path)
print("\n")
for _label, _path in (
    ("> train_csv_fname: ", train_csv_fname),
    ("> test_csv_fname: ", test_csv_fname),
    ("> label_map_pbtxt_fname: ", label_map_pbtxt_fname),
    ("> train_record_fname: ", train_record_fname),
    ("> test_record_fname: ", test_record_fname),
):
    print(_label, _path)

3) Genera CSVs auxiliares en base a los XMLs correspondientes a las imágenes: 

In [None]:
#@title Definir funciones auxiliares para CSV
def __list_to_csv(annotations, output_file):
    """Write annotation tuples to ``output_file`` as a CSV without an index column.

    Each tuple must follow the column order
    (filename, width, height, class, xmin, ymin, xmax, ymax).
    """
    column_name = [
        'filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax'
    ]
    xml_df = pd.DataFrame(annotations, columns=column_name)
    xml_df.to_csv(output_file, index=False)


def _xml_int(node, tag):
    """Return the text of the ``tag`` child of ``node`` as an int.

    Accepts both "12" and "12.0": some annotation tools write float pixel
    values, which plain ``int()`` rejects. Values are rounded to the nearest
    integer.
    """
    return round(float(node.find(tag).text))


def xml_to_csv(xml_dir, output_file):
    """Read all XML files, generated by labelImg, from a directory and generate a single CSV file.

    One CSV row is written per ``<object>`` element, with the columns
    filename, width, height, class, xmin, ymin, xmax, ymax.

    Children are looked up by tag name rather than by position. Annotations
    whose ``<object>`` has extra or reordered children (e.g. an
    ``<occluded>`` element before ``<bndbox>``) are therefore still read
    correctly. XML files without any ``<object>`` contribute no rows.
    """
    annotations = []
    # sorted() makes the CSV row order independent of the filesystem order.
    for xml_file in sorted(glob.glob(os.path.join(xml_dir, '*.xml'))):
        root = ET.parse(xml_file).getroot()
        objects = root.findall('object')
        if not objects:
            continue
        filename = root.find('filename').text
        size = root.find('size')
        width = _xml_int(size, 'width')
        height = _xml_int(size, 'height')
        for member in objects:
            bndbox = member.find('bndbox')
            annotations.append((
                filename, width, height, member.find('name').text,
                _xml_int(bndbox, 'xmin'), _xml_int(bndbox, 'ymin'),
                _xml_int(bndbox, 'xmax'), _xml_int(bndbox, 'ymax'),
            ))

    __list_to_csv(annotations, output_file)

print("Funciones auxiliares para CSV definidas")

In [None]:
#@title Generar CSVs

# Just in case, delete CSVs left by a previous run so they are always rebuilt.
for _stale_csv in (train_csv_fname, test_csv_fname):
    if os.path.isfile(_stale_csv):
        os.remove(_stale_csv)
        print(_stale_csv, " eliminado.")

# Build one CSV per split from its XML annotation folder.
print("\n")
print("-- generando train_csv: ", train_csv_fname)
xml_to_csv(xml_dir=train_xml_dir, output_file=train_csv_fname)
print("> train_csv generado.")
print("\n")
print("-- generando test_csv: ", test_csv_fname)
xml_to_csv(xml_dir=test_xml_dir, output_file=test_csv_fname)
print("> test_csv generado.")

4) Genera el archivo de clases label_map.pbtxt:

In [None]:
#@title Definir funciones auxiliares para label_map
def pbtxt_from_csv(csv_path, pbtxt_path):
    """Write a label map to ``pbtxt_path`` with every class found in ``csv_path``.

    Class names come from the CSV's ``class`` column, de-duplicated and
    sorted so that the generated ids are stable between runs.
    """
    class_list = sorted(pd.read_csv(csv_path)['class'].unique())
    pbtxt_from_classlist(class_list, pbtxt_path)


def pbtxt_from_classlist(l, pbtxt_path):
    """Write a label map file with one ``item`` per class in ``l``.

    Ids are assigned from 1 upwards in the order of ``l``. Presumably id 0 is
    avoided because the TF Object Detection API reserves it for the
    background — confirm. Each item is written as::

        item {
            id: <n>
            display_name: "<class>"
        }

    Non-string classes (e.g. numeric labels parsed by pandas) are converted
    with ``str()`` instead of raising ``TypeError``.
    """
    items = [
        f'item {{\n    id: {i}\n    display_name: "{c}"\n}}\n\n'
        for i, c in enumerate(l, start=1)
    ]
    # Explicit UTF-8 so accented class names round-trip with
    # class_dict_from_pbtxt, which reads the file as 'utf-8-sig'.
    with open(pbtxt_path, "w", encoding="utf-8") as pbtxt_file:
        pbtxt_file.write(''.join(items))

print("Funciones auxiliares para label_map definidas")

In [None]:
#@title Generar label_map_pbtx

# Just in case, delete the label map left by a previous run so it is rebuilt.
_label_map_exists = os.path.isfile(label_map_pbtxt_fname)
if _label_map_exists:
    os.remove(label_map_pbtxt_fname)
    print(label_map_pbtxt_fname, " eliminado.")

print("\n")
print("-- generando label_map_pbtxt: ", label_map_pbtxt_fname)
# The class list is taken from the training CSV only.
pbtxt_from_csv(csv_path=train_csv_fname, pbtxt_path=label_map_pbtxt_fname)
print("> label_map_pbtxt generado.")

5) Genera los TFRecords a partir de los CSVs (obtenidos de los XMLs) y de las imágenes:

In [None]:
#@title Definir funciones auxiliares para TFRecords
def generate_TFrecords(pbtxt_input, csv_input, image_dir, output_path):
    """Write a TFRecord file with one ``tf.train.Example`` per annotated image.

    Args:
        pbtxt_input: label map (.pbtxt) used to map class names to ids.
        csv_input: annotations CSV with one row per bounding box.
        image_dir: directory containing the images named in the CSV.
        output_path: path of the TFRecord file to write.
    """
    class_dict = class_dict_from_pbtxt(pbtxt_input)
    # Read and group the annotations before opening the writer, so a bad CSV
    # does not leave an empty output file behind.
    examples = pd.read_csv(csv_input)
    grouped = __split(examples, 'filename')

    # The context manager closes the writer even if building an example
    # raises (e.g. a missing image file).
    with tf.io.TFRecordWriter(output_path) as writer:
        for group in tqdm(grouped, desc='groups'):
            tf_example = create_tf_example(group, image_dir, class_dict)
            writer.write(tf_example.SerializeToString())

def create_tf_example(group, path, class_dict):
    """Build the ``tf.train.Example`` for one image and all of its boxes.

    Args:
        group: ``data(filename, object)`` tuple as returned by ``__split``.
            ``filename`` names the image inside ``path``; ``object`` is the
            DataFrame with one annotation row per bounding box.
        path: directory that contains the image file.
        class_dict: mapping from class name to integer label id.

    Returns:
        A ``tf.train.Example`` with the encoded image bytes, the image size,
        the box coordinates and the class names/ids. Coordinates are taken
        as-is from the ``*_rel`` columns when present; otherwise the pixel
        columns are divided by the image width/height.

    Raises:
        ValueError: if the rows have neither the ``*_rel`` columns nor the
            pixel coordinate columns.
        KeyError: if a row's class is missing from ``class_dict``.
    """
    image_name = str(group.filename)
    with tf.io.gfile.GFile(os.path.join(path, image_name), 'rb') as fid:
        encoded_jpg = fid.read()
    # Only the image size is needed here; the raw bytes are stored unchanged.
    with Image.open(io.BytesIO(encoded_jpg)) as image:
        width, height = image.size

    filename = image_name.encode('utf8')
    # NOTE(review): the format is always b'jpg', even for non-JPEG images.
    image_format = b'jpg'

    boxes = group.object
    # Every row of a group shares the same columns, so decide once which
    # coordinate columns to use instead of re-checking on every row.
    columns = set(boxes.columns)
    if {'xmin_rel', 'xmax_rel', 'ymin_rel', 'ymax_rel'}.issubset(columns):
        relative = True
    elif {'xmin', 'xmax', 'ymin', 'ymax'}.issubset(columns):
        relative = False
    else:
        raise ValueError(
            'Annotations for {} have no bounding-box coordinate columns'
            .format(image_name))

    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []

    for _, row in boxes.iterrows():
        if relative:
            xmin, xmax = row['xmin_rel'], row['xmax_rel']
            ymin, ymax = row['ymin_rel'], row['ymax_rel']
        else:
            xmin, xmax = row['xmin'] / width, row['xmax'] / width
            ymin, ymax = row['ymin'] / height, row['ymax'] / height

        xmins.append(xmin)
        xmaxs.append(xmax)
        ymins.append(ymin)
        ymaxs.append(ymax)
        classes_text.append(row['class'].encode('utf8'))
        classes.append(class_dict[row['class']])

    tf_example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/height':
                dataset_util.int64_feature(height),
                'image/width':
                dataset_util.int64_feature(width),
                'image/filename':
                dataset_util.bytes_feature(filename),
                'image/source_id':
                dataset_util.bytes_feature(filename),
                'image/encoded':
                dataset_util.bytes_feature(encoded_jpg),
                'image/format':
                dataset_util.bytes_feature(image_format),
                'image/object/bbox/xmin':
                dataset_util.float_list_feature(xmins),
                'image/object/bbox/xmax':
                dataset_util.float_list_feature(xmaxs),
                'image/object/bbox/ymin':
                dataset_util.float_list_feature(ymins),
                'image/object/bbox/ymax':
                dataset_util.float_list_feature(ymaxs),
                'image/object/class/text':
                dataset_util.bytes_list_feature(classes_text),
                'image/object/class/label':
                dataset_util.int64_list_feature(classes),
            }))
    return tf_example

def __split(df, group):
    """Split ``df`` by the ``group`` column into ``data(filename, object)`` tuples.

    ``filename`` is each distinct value of the column, and ``object`` is the
    sub-DataFrame with that value's rows. Both come in ``groupby`` key order.
    """
    data = namedtuple('data', ['filename', 'object'])
    grouped = df.groupby(group)
    return [data(key, grouped.get_group(key)) for key in grouped.groups]

def class_dict_from_pbtxt(pbtxt_path):
    """Map each class ``display_name`` in a label map (.pbtxt) to its ``id``.

    The file is parsed line by line. Every stripped line starting with
    ``id`` contributes an id and every line starting with ``display_name``
    contributes a name (quotes removed). The n-th id is paired with the n-th
    name; if a name appears twice, the later id wins.

    Raises:
        ValueError: if the numbers of ``id`` and ``display_name`` lines
            differ, because they cannot be paired reliably.
    """
    # 'utf-8-sig' also accepts files saved with a UTF-8 byte-order mark.
    with open(pbtxt_path, 'r', encoding='utf-8-sig') as pbtxt_file:
        lines = [line.strip() for line in pbtxt_file]

    ids = [int(line.replace('id:', '')) for line in lines if line.startswith('id')]
    names = [
        line.replace('display_name:', '').replace('"', '').strip()
        for line in lines
        if line.startswith('display_name')
    ]
    if len(ids) != len(names):
        raise ValueError(
            f'{pbtxt_path}: found {len(ids)} id lines but {len(names)} '
            'display_name lines'
        )

    # join ids and display_names into a single dictionary
    return dict(zip(names, ids))


print("Funciones auxiliares para TFRecords definidas")

In [None]:
#@title Generar TFRecords

# Just in case, delete TFRecords left by a previous run so they are rebuilt.
for _stale_record in (train_record_fname, test_record_fname):
    if os.path.isfile(_stale_record):
        os.remove(_stale_record)
        print(_stale_record, " eliminado.")

print("\n")
print("-- generando train_TFRecords: ", train_record_fname)
generate_TFrecords(
    pbtxt_input=label_map_pbtxt_fname,
    csv_input=train_csv_fname,
    image_dir=train_image_dir,
    output_path=train_record_fname,
)
print("\n> train_TFRecords generado.")
print("\n")
print("-- generando test_TFRecords: ", test_record_fname)
generate_TFrecords(
    pbtxt_input=label_map_pbtxt_fname,
    csv_input=test_csv_fname,
    image_dir=test_image_dir,
    output_path=test_record_fname,
)
print("\n> test_TFRecords generado.")