# Create TFRecords files starting from csv data and images

Note: this code can create TFRecord files for both domain-adversarial and domain-agnostic training without modification. If the csv file also contains lightbox and sunlamp entries, each image is read from the folder that matches the class of its csv entry (synthetic = 1; lightbox = 2; sunlamp = 3).

In [1]:
#Root folder of the dataset; images are read from <images_dir>/<subfolder>/images
images_dir='speedplus\\speedplus'
#csv file with one row per image (image filename followed by its labels)
csv_dir='train.csv'
#Output file pattern; '{}' is replaced by the index of each TFRecord file
output_path='TFRecords\\train{}.record' #Create this folder before running the script

In [2]:
#Import all required packages
import os
import pandas as pd
import io
import tensorflow.compat.v1 as tf
from PIL import Image
import numpy as np

In [3]:
#Define some utilities

def int64_feature(value):
    """Wrap a single integer in a `tf.train.Feature` backed by an `Int64List`."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)


def int64_list_feature(value):
    """Wrap an iterable of integers in a `tf.train.Feature` backed by an `Int64List`."""
    int64_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int64_list)


def bytes_feature(value):
    """Wrap a single bytes object in a `tf.train.Feature` backed by a `BytesList`."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


def bytes_list_feature(value):
    """Wrap an iterable of bytes objects in a `tf.train.Feature` backed by a `BytesList`."""
    bytes_list = tf.train.BytesList(value=value)
    return tf.train.Feature(bytes_list=bytes_list)


def float_feature(value):
    """Wrap a single float in a `tf.train.Feature` backed by a `FloatList`."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)


def float_list_feature(value):
    """Wrap an iterable of floats in a `tf.train.Feature` backed by a `FloatList`."""
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)

In [4]:
# Part of the code is adapted from https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/training.html#create-tensorflow-records

def create_tf_example(row, path):
    """Build the `tf.train.Example` for one csv row and the image it refers to.

    Args:
        row: Indexable csv record laid out as
            [0]       image filename (str),
            [1]-[4]   quaternion components q_1..q_4,
            [5]-[7]   position Xc, Yc, Zc,
            [8]-[11]  bounding box xmin, ymin, xmax, ymax,
            [12]-[33] keypoints as (X, Y) pairs for A, B, C, D, E, F, G, H, I, L, M,
            [34]      dataset class.
        path: Folder that contains the image file.

    Returns:
        A `tf.train.Example` holding the encoded image bytes, the image
        metadata and every label of the row.

    Raises:
        ValueError: If the dataset class cannot be converted to an integer
            (for instance when the csv cell is empty and was parsed as NaN).
    """
    image_filename = row[0]
    with tf.gfile.GFile(os.path.join(path, image_filename), 'rb') as fid:
        encoded_jpg = fid.read()

    # Read size and bands from the file itself instead of assuming 1920x1200
    # and "L"/"RGB": the record then describes the stored image correctly for
    # any resolution or mode. Pillow only parses the header for these fields.
    with Image.open(io.BytesIO(encoded_jpg)) as image:
        width, height = image.size
        channels = len(image.getbands())

    # pandas yields a float column when the class column has missing values;
    # Int64List rejects floats, so convert explicitly.
    dataset_class = int(row[34])

    feature = {
        'image/actual_channels': int64_feature(channels),
        # Always 3, independent of the actual number of bands
        # NOTE(review): presumably the training pipeline decodes to 3 channels - confirm
        'image/channels': int64_feature(3),
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
        'image/filename': bytes_feature(image_filename.encode('utf8')),
        'image/dataset_class': int64_feature(dataset_class),
        'image/encoded': bytes_feature(encoded_jpg),
        'image/format': bytes_feature(('jpg').encode('utf8')),

        'image/object/position/Xc': float_feature(row[5]),
        'image/object/position/Yc': float_feature(row[6]),
        'image/object/position/Zc': float_feature(row[7]),

        'image/object/bbox/xmin': float_feature(row[8]),
        'image/object/bbox/xmax': float_feature(row[10]),
        'image/object/bbox/ymin': float_feature(row[9]),
        'image/object/bbox/ymax': float_feature(row[11]),
        # Some fields are kept for compatibility with TF Object detection API
        'image/object/class/text': bytes_feature(('Tango').encode('utf8')),
        'image/object/class/label': int64_feature(1),
    }

    # Quaternion components q_1..q_4 are stored in columns 1..4
    for component in range(1, 5):
        key = 'image/object/quaternions/q_{}'.format(component)
        feature[key] = float_feature(row[component])

    # Keypoints are stored as consecutive (X, Y) pairs starting at column 12
    for index, letter in enumerate('ABCDEFGHILM'):
        x_column = 12 + 2 * index
        feature['image/object/kpts/X_{}'.format(letter)] = float_feature(row[x_column])
        feature['image/object/kpts/Y_{}'.format(letter)] = float_feature(row[x_column + 1])

    return tf.train.Example(features=tf.train.Features(feature=feature))

Run the following cell to create the TFRecord files

In [5]:
#Define the desired number of TFRecords files
number_of_tfrecords_files=100

#Read csv file content
csv = pd.read_csv(csv_dir).values

#Image subfolder associated to each dataset class (column 34 of the csv)
subfolder_by_class={1:'synthetic', 2:'lightbox', 3:'sunlamp'}

#Every file receives the same number of rows; the last one also takes the remainder
number_of_images_per_file=len(csv)//number_of_tfrecords_files
images_processed=0

#Create the output folder if needed: the writer cannot create missing folders
output_folder=os.path.dirname(output_path)
if output_folder:
    os.makedirs(output_folder, exist_ok=True)

for i in range(number_of_tfrecords_files):
    images_index_start=i*number_of_images_per_file
    if i==number_of_tfrecords_files-1:
        images_index_end=len(csv)
    else:
        images_index_end=(i+1)*number_of_images_per_file

    #The context manager closes the file even if the loop fails or is interrupted
    with tf.python_io.TFRecordWriter(output_path.format(i)) as writer:
        for row in csv[images_index_start:images_index_end]:
            images_processed+=1

            #Fail loudly on an unknown class instead of hitting a NameError or
            #silently reusing the folder selected for the previous row
            dataset_class=row[34]
            if dataset_class not in subfolder_by_class:
                raise ValueError('Unknown dataset class {!r} for image {!r}'.format(dataset_class, row[0]))
            images_dir_full=os.path.join(images_dir, subfolder_by_class[dataset_class], 'images')

            tf_example = create_tf_example(row, images_dir_full)
            writer.write(tf_example.SerializeToString())

    print('Successfully created the TFRecord file: {}'.format(output_path.format(i)))

Successfully created the TFRecord file: C:\Users\Alessandro Lotti\Documents\MATLAB\Matlab_codes\SPEC_2022_tools_v0\submission1\ai22\TFRecords\train0.record


KeyboardInterrupt: ignored