In [None]:
import os
import sys
import tensorflow.compat.v1 as tf
import pandas as pd
import hashlib
import PIL
from sklearn.model_selection import train_test_split

In [None]:
def create_example(id, filename, filepath, bbox, label):
    """Build a ``tf.train.Example`` for one image annotated with a single box.

    The file at ``filepath`` is read for its raw bytes (stored under
    ``image/encoded`` and hashed with SHA-256) and opened with Pillow to obtain
    its pixel size. Box coordinates are divided by the image width (x values)
    or height (y values) before being stored.

    Args:
        id: Identifier; its ``str()`` form is stored under ``image/source_id``.
        filename: Text stored under ``image/filename``.
        filepath: Path of the image file on disk.
        bbox: Indexable with at least four items, read as
            ``[xmin, ymin, xmax, ymax]`` in pixels; each item goes through ``int()``.
        label: Value stored in the int64 list ``image/object/class/label``.

    Returns:
        The populated ``tf.train.Example``.
    """
    # ``import PIL`` by itself does not guarantee the ``PIL.Image`` submodule is
    # loaded, so bind it explicitly instead of relying on another library
    # having imported it first.
    from PIL import Image

    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    def _float_feature(value):
        return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

    # Context managers release both file handles even when reading fails.
    with open(filepath, "rb") as image_file:
        img_raw = image_file.read()
    key = hashlib.sha256(img_raw).hexdigest()
    with Image.open(filepath) as image:
        width, height = image.size

    xmin, ymin, xmax, ymax = (int(bbox[k]) for k in range(4))

    return tf.train.Example(features=tf.train.Features(feature={
        "image/height": _int64_feature(height),
        "image/width": _int64_feature(width),
        "image/filename": _bytes_feature(filename.encode("utf-8")),
        "image/source_id": _bytes_feature(str(id).encode("utf-8")),
        "image/key/sha256": _bytes_feature(key.encode("utf-8")),
        "image/encoded": _bytes_feature(img_raw),
        # NOTE(review): the format is always recorded as "jpg", whatever the file is.
        "image/format": _bytes_feature("jpg".encode("utf8")),
        "image/object/bbox/xmin": _float_feature(xmin / width),
        "image/object/bbox/xmax": _float_feature(xmax / width),
        "image/object/bbox/ymin": _float_feature(ymin / height),
        "image/object/bbox/ymax": _float_feature(ymax / height),
        # NOTE(review): class text is fixed to "1d" while the numeric label
        # comes from the caller; confirm the two always agree.
        "image/object/class/text": _bytes_feature("1d".encode("utf-8")),
        "image/object/class/label": _int64_feature(label),
        "image/object/difficult": _int64_feature(0),
        "image/object/truncated": _int64_feature(0),
        "image/object/view": _bytes_feature("Unspecified".encode("utf-8")),
    }))


def create_tfrecord(annotations, dataset_dir, output_path, samples_per_file):
    """Write ``annotations`` into numbered TFRecord partition files.

    Partitions are named ``{output_path}-00``, ``{output_path}-01``, ... and a
    new one is started after every ``samples_per_file`` rows, counted by the
    row's position in the iteration (not by its index label, which is in
    arbitrary order once the frame has been shuffled or filtered).

    Args:
        annotations: Object whose ``iterrows()`` yields ``(index, row)`` pairs.
            Each row is read via the keys ``'file'``, ``'barcode_type'`` and
            ``'bounding_box'``; the bounding box is a comma-separated string,
            optionally wrapped in ``[...]``. ``index + 1`` is passed to
            ``create_example`` as the example id.
        dataset_dir: Directory joined with each row's ``'file'`` value.
        output_path: Path prefix for the partition files.
        samples_per_file: Number of rows written to each partition.
    """
    print(f"creating tf records at {output_path}")

    # write examples into tfrecord
    fid = 0
    tfwriter = tf.io.TFRecordWriter(output_path + f"-{fid:02d}")
    try:
        for position, (i, annotation) in enumerate(annotations.iterrows()):
            # Roll over to the next partition once the current one is full.
            if position > 0 and position % samples_per_file == 0:
                tfwriter.close()
                fid += 1
                tfwriter = tf.io.TFRecordWriter(output_path + f"-{fid:02d}")

            print(f'annotation: {annotation}')
            filename = annotation['file']
            filepath = os.path.join(dataset_dir, filename)
            label = annotation['barcode_type']
            bbox_str = annotation['bounding_box']
            if bbox_str.startswith('[') and bbox_str.endswith(']'):
                bbox_str = bbox_str[1:-1]
            bbox = bbox_str.split(',')

            example = create_example(i + 1, filename, filepath, bbox, label)
            tfwriter.write(example.SerializeToString())
    finally:
        # Always release the open partition, including when an example fails.
        tfwriter.close()


def create_train_val_tfrecords(dataset_dir, annotations_file, output_dir, samples_per_file,
                               test_size=0.2, random_state=42):
    """Split the annotations into train/validation sets and write TFRecords for each.

    Args:
        dataset_dir: Directory containing the image files; forwarded to
            ``create_tfrecord``.
        annotations_file: CSV file loaded with ``pd.read_csv``.
        output_dir: Directory that receives ``train.tfrecord-*`` and
            ``val.tfrecord-*``; created (with any missing parents) if needed.
        samples_per_file: Forwarded to ``create_tfrecord``.
        test_size: Forwarded to ``train_test_split`` as the validation share.
        random_state: Seed forwarded to ``train_test_split``.
    """
    print(f'creating train and validation tf records from {annotations_file} at {output_dir}')

    # Create the output directory itself; the path may be absolute, nested, or
    # lack a trailing separator, so never derive a CWD-relative name from it.
    os.makedirs(output_dir, exist_ok=True)

    # read annotations
    annotations = pd.read_csv(annotations_file)

    # split into train and validation sets
    train, val = train_test_split(
        annotations, test_size=test_size, random_state=random_state, shuffle=True
    )
    print(f'train size: {train.shape[0]} validation size: {val.shape[0]}')

    # create tfrecords for train and validation set
    create_tfrecord(train, dataset_dir, os.path.join(output_dir, "train.tfrecord"), samples_per_file)
    create_tfrecord(val, dataset_dir, os.path.join(output_dir, "val.tfrecord"), samples_per_file)

In [None]:
# Input locations, resolved against the current working directory.
DATASET_DIR = os.path.join(os.getcwd(), 'Muenster_Barcode_Database/N95-2592x1944_scaledTo640x480bilinear')
ANNOTATIONS_PATH = os.path.join(os.getcwd(), 'Muenster_Barcode_Database/annotations.csv')
# Destination directory for the generated TFRecord files.
TFRECORDS_DIR = os.path.join(os.getcwd(), 'tfrecords/')

# Generate the train/validation records, passing 25 as samples_per_file.
create_train_val_tfrecords(
    dataset_dir=DATASET_DIR,
    annotations_file=ANNOTATIONS_PATH,
    output_dir=TFRECORDS_DIR,
    samples_per_file=25,
)