Creates TFRecord files for the Google (TensorFlow) Object Detection API

In [None]:
import sys
import os

import pandas as pd
import numpy as np
import xml.etree.ElementTree as ET

import tensorflow as tf

from object_detection.utils import dataset_util
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
from time import time

from sklearn.model_selection import train_test_split

In [None]:
# Load the dataset catalogue and order it by label file name.
LABEL_FOLDER = './data/'
labels = pd.read_csv(os.path.join(LABEL_FOLDER, 'data.csv')).sort_values('label')

# Give every dataset a short code: the label prefix (text before the first
# '_') plus a two-digit counter that restarts whenever the prefix changes
# between consecutive rows.
prefixes = labels.label.apply(lambda name: name.split('_')[0])
codes = []
prev_name = ''
prev_value = 1
for prefix in prefixes:
    if prefix != prev_name:
        prev_name, prev_value = prefix, 1
    else:
        prev_value += 1
    codes.append('{}_{}'.format(prev_name, str(prev_value).zfill(2)))
labels['code'] = codes

# The code becomes the (sorted) index; missing cells are replaced by 0.
labels = labels.set_index('code').sort_index()
labels = labels.fillna(0)
labels.head()

In [None]:

# Output path templates; '{}' receives the zero-padded file number.
OUTPUT_TRAIN = './tfrecord/train_{}.record'
OUTPUT_TEST = './tfrecord/test_{}.record'

# Dataset code used by the manual test cells.
TEST_IDX = 'Binya_03'

# For debugging and tests: colour used when drawing each class's boxes.
COLORS = dict(
    vehicle=(255, 0, 0),
    truck=(0, 255, 0),
    motorcycle=(0, 0, 255),
    other=(255, 255, 255),
)

In [None]:
def set_first_frame(df, fps):
    """Renumber frames so the earliest timestamp lands on its video frame.

    When the smallest ``TimeStamp`` is later than zero, the existing
    ``frame`` column is replaced: every distinct timestamp, in time order,
    gets a consecutive frame number starting at ``round(first_time * fps)``.
    Otherwise ``df`` is returned untouched.

    Args:
        df: DataFrame with a ``TimeStamp`` column (timedeltas) and a
            ``frame`` column.
        fps: Frames per second of the matching video.

    Returns:
        The renumbered DataFrame, or ``df`` itself when no change is needed.
    """
    # total_seconds() is independent of the platform's integer width and of
    # the timedelta resolution, unlike casting the column to int and
    # dividing by 1e9.
    first_time = pd.to_timedelta(df.TimeStamp).min().total_seconds()
    if not first_time > 0:
        return df

    first_frame = np.round(first_time * fps)
    # groupby yields the distinct timestamps in ascending order.
    unique_times = df.groupby('TimeStamp').size().sort_index().index
    frames = pd.DataFrame({
        'TimeStamp': unique_times,
        'frame': np.arange(len(unique_times)) + first_frame,
    })
    return df.drop('frame', axis=1).merge(frames, on='TimeStamp')

def change_values(df):
    """
    Put everything that you want to change in data here.

    Currently relabels every 'motorcycle' entry of the ``type`` column as
    'vehicle'. The column is reassigned on ``df`` itself, and the same
    DataFrame is returned.
    """
    # Assign the result back instead of a chained ``inplace=True`` call,
    # which is not guaranteed to reach the parent frame.
    df['type'] = df['type'].replace(to_replace='motorcycle', value='vehicle')
    return df

def read_data(idx):
    """Load the box annotations and open the video for one dataset.

    The spreadsheet named in ``labels.at[idx, 'label']`` is read from
    ``LABEL_FOLDER``; each distinct ``TimeStamp`` is given a frame number,
    adjusted by ``set_first_frame`` and by the dataset's ``shift`` value.
    Rows whose frame number is not positive are dropped.

    Args:
        idx: Index value (dataset code) into the module-level ``labels``.

    Returns:
        Tuple ``(df, cap)``: the annotation DataFrame with integer ``x``,
        ``y``, ``width``, ``height`` and ``frame`` columns, and the
        ``cv2.VideoCapture`` opened on ``labels.loc[idx, 'video']``. The
        capture is returned open; this function never releases it.
    """
    shift = labels.at[idx, 'shift']

    df = pd.read_excel(os.path.join(LABEL_FOLDER, labels.at[idx, 'label']))
    df.TimeStamp = pd.to_timedelta(df.TimeStamp)
    # One frame number per distinct timestamp, in time order.
    tmp = df.groupby('TimeStamp').size().sort_index().reset_index().drop(0, axis=1)
    tmp['frame'] = np.arange(len(tmp)) - shift
    df = df.merge(tmp, on='TimeStamp')

    cap = cv2.VideoCapture(labels.loc[idx, 'video'])
    fps = cap.get(cv2.CAP_PROP_FPS)
    df = set_first_frame(df, fps)
    df.frame += shift
    # Take an explicit copy so the column assignments below act on an
    # independent frame rather than on a filtered slice.
    df = df[df.frame > 0].copy()
    for col in ['x', 'y', 'width', 'height', 'frame']:
        df[col] = df[col].astype(int)
    return df, cap


In [None]:
# Class table: indexed by class name, with the numeric id in 'class_id'.
if False:  # Automatic: number the classes found in `df`, most frequent first
    df_classes = df['type'].value_counts().to_frame()
    df_classes.insert(0, 'class_id', range(1, len(df_classes) + 1))
else:  # Manual: fixed id -> name mapping
    classes_dic = {1: 'vehicle', 2: 'truck', -1: 'delete'}
    df_classes = pd.DataFrame({
        'class_id': list(classes_dic.keys()),
        'class_name': list(classes_dic.values()),
    }).set_index('class_name')


In [None]:
def process_image(img, rows):
    """Paint over every 'delete' box with the mean colour of that region.

    Args:
        img: Image array; it is drawn on directly.
        rows: Annotation rows for this frame (``type``, ``x``, ``y``,
            ``width``, ``height`` columns).

    Returns:
        The same ``img`` object that was passed in.
    """
    to_erase = rows[rows.type == 'delete']
    for _, box in to_erase.iterrows():
        left, right = box.x, box.x + box.width
        top, bottom = box.y, box.y + box.height
        # Per-channel average of the patch that is about to be covered.
        patch = img[top:bottom, left:right]
        color = patch.mean(axis=1).mean(axis=0)
        cv2.rectangle(img, (left, top), (right, bottom), color, -1)
    return img

if False: ### Let's test it
    # Show the first labelled frame of TEST_IDX before and after masking.
    idx = TEST_IDX
    test_df, test_cap = read_data(idx)
    frameno = test_df.frame.min()
    test_cap.set(cv2.CAP_PROP_POS_FRAMES, frameno)
    print ('Frame pointer set to {}'.format(frameno))
    ret, test_img = test_cap.read()
    test_img = cv2.cvtColor(test_img, cv2.COLOR_BGR2RGB)

    # Before masking.
    plt.imshow(test_img)
    plt.show()

    # After masking the 'delete' boxes of that frame.
    frame_rows = test_df[test_df.frame == frameno]
    test_img = process_image(test_img, frame_rows)
    plt.imshow(test_img)
    plt.show()



In [None]:
#test_df.head()

In [None]:
def create_tf_example(filename, img, frameno, df, debug=False):
    """Build the ``tf.train.Example`` for one video frame.

    The frame's 'delete' boxes are painted over by ``process_image``; the
    remaining boxes are normalised by the image size and stored together
    with the JPEG-encoded image.

    Args:
        filename: Name of the source video; the stored file name is
            ``'<filename>_frame_<frameno zero-padded to 5 digits>'``.
        img: Frame image array; it is modified by ``process_image`` (and
            by the debug drawing when ``debug`` is true).
        frameno: Frame number whose rows are selected from ``df``.
        df: Annotations with ``frame``, ``type``, ``x``, ``y``, ``width``
            and ``height`` columns.
        debug: When true, print the feature values and draw the boxes.

    Returns:
        The ``tf.train.Example``; when ``debug`` is true, the tuple
        ``(example, img)`` with the boxes drawn on ``img``.

    Raises:
        ValueError: If the image cannot be JPEG-encoded.
    """
    rows = df[df.frame == frameno]
    img = process_image(img, rows)
    rows = rows[rows.type != 'delete']
    height, width = img.shape[0], img.shape[1]

    # Box corners as fractions of the image size.
    xmins = list(rows['x'] / width)
    xmaxs = list((rows['x'] + rows['width']) / width)
    ymins = list(rows['y'] / height)
    ymaxs = list((rows['y'] + rows['height']) / height)

    classes_text = list(rows['type'])
    # Looks the ids up in the module-level class table; a type missing from
    # the table raises KeyError here.
    classes = list(df_classes.loc[classes_text, 'class_id'])

    is_success, im_buf_arr = cv2.imencode(".jpg", img)
    if not is_success:
        raise ValueError(
            'Could not JPEG-encode frame {} of {}'.format(frameno, filename))
    encoded_image_data = im_buf_arr.tobytes()

    # Change strings into bytes
    image_format = 'jpg'.encode('utf-8')
    filename = '{}_frame_{}'.format(filename, str(frameno).zfill(5)).encode('utf-8')
    classes_text = [x.encode('utf-8') for x in classes_text]

    if debug: #For debugging
        print ('Width: {}, Height: {}'.format(width, height))
        print ('xmins:', xmins)
        print ('xmaxs:', xmaxs)
        print ('ymins:', ymins)
        print ('ymaxs:', ymaxs)
        print ('classes_text:', classes_text)
        print ('classes:', classes)
        print ('Len of encoded data:', len(encoded_image_data))
        print ('Image format:', image_format)

    tf_label_and_data = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_image_data),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))

    if not debug:
        return tf_label_and_data

    # Debug only: draw the boxes (after encoding, so the stored image stays
    # clean) and hand the annotated image back as well.
    for xmin, xmax, ymin, ymax, cl in zip(xmins, xmaxs, ymins, ymaxs, classes_text):
        x1 = int(xmin * width)
        x2 = int(xmax * width)
        y1 = int(ymin * height)
        y2 = int(ymax * height)
        cv2.rectangle(img, (x1, y1), (x2, y2), COLORS[cl.decode('ascii')], 3)
    return tf_label_and_data, img

In [None]:
if False: ### Let's test it
    # Build one example in debug mode for TEST_IDX and show the drawn boxes.
    test_filename = labels.loc[TEST_IDX, 'video']
    test_df, test_cap = read_data(TEST_IDX)
    test_df.head()

    # Jump to the first labelled frame and grab it.
    frameno = test_df.frame.min()
    test_cap.set(cv2.CAP_PROP_POS_FRAMES, frameno)
    ret, test_img = test_cap.read()
    test_img = cv2.cvtColor(test_img, cv2.COLOR_BGR2RGB)

    tld, test_img = create_tf_example(
        test_filename, test_img, frameno, test_df, debug=True)
    plt.imshow(test_img)
    plt.show()

In [None]:
# Datasets listed in test_codes form the test split; all others are training.
test_codes = ['Amir_01', 'Binya_01']
is_test = labels.index.isin(test_codes)
labels_test = labels[is_test].copy()
labels_train = labels[~is_test].copy()

print ('Labels test has {} datasets: {}'.format(len(labels_test), labels_test.index.unique()))
print ('Labels train has {} datasets: {}'.format(len(labels_train), labels_train.index.unique()))
labels_test

In [None]:
def generate_tf_record(labels, idx, suffix, output_path):
    """Write one TFRecord file with every labelled frame of a dataset.

    Frames are read sequentially from the first to the last labelled frame.
    Frames with no annotation rows are skipped with a warning. Reading
    stops early, also with a warning, if the video yields no more frames.

    Args:
        labels: DataFrame indexed by dataset code with a ``video`` column.
            Only used for the video name and the progress total, because
            ``read_data`` consults the module-level ``labels``.
        idx: Dataset code to process.
        suffix: Integer file number; zero-padded to two digits in the path.
        output_path: Path template containing one ``{}`` placeholder.
    """
    t = time()

    df, cap = read_data(idx)
    output_path = output_path.format(str(suffix).zfill(2))
    filename = labels.loc[idx, 'video']

    # The writer does not create missing directories.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    frameno, frameno_max = df.frame.min(), df.frame.max()
    # Both ends are inclusive, so a single-frame dataset counts as 1.
    nr_frames = frameno_max - frameno + 1
    # Built once; membership is then O(1) per frame.
    labelled_frames = set(df.frame.unique())

    i = 0
    written = 0
    writer = tf.compat.v1.python_io.TFRecordWriter(output_path)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frameno)
        while frameno <= frameno_max:
            if i%20 == 0:
                sys.stdout.write('{:.1f}% ({}/{}) images processed in {:.1f} seconds           \r'
                                .format(100 * i / nr_frames, i, nr_frames, time() - t))
            i += 1
            ret, img = cap.read()
            if not ret:
                print ('Warning: could not read frame {}; stopping early'.format(frameno))
                break

            if frameno in labelled_frames:
                tf_example = create_tf_example(filename, img, frameno, df, debug=False)
                writer.write(tf_example.SerializeToString())
                written += 1
            else:
                print ('Warning: No vehicles for frame: {}'.format(frameno))
            frameno += 1
    finally:
        # Flush the file and free the video even if a frame fails.
        writer.close()
        cap.release()

    print('{}/{} ({}) done. {} images has been processed in {:.1f} seconds and written to {}                              '
          .format(
              suffix+1, len(labels), idx,
              written, time() - t, output_path))
    

In [None]:
# One record file per dataset: first the training split, then the test split.
for split_labels, path_template in ((labels_train, OUTPUT_TRAIN),
                                    (labels_test, OUTPUT_TEST)):
    for suffix, idx in enumerate(split_labels.index):
        generate_tf_record(split_labels, idx, suffix, path_template)


## Read image from tfrecord

In [None]:
def create_file_dic(height, width, img_format, xmin, xmax, ymin, ymax, label, text, image = None):
    """Bundle the fields of one decoded record into a dictionary.

    The ``'image'`` key is only present when ``image`` is not None.
    """
    dic = {
        'height': height,
        'width': width,
        'format': img_format,
        'xmin': xmin,
        'xmax': xmax,
        'ymin': ymin,
        'ymax': ymax,
        'label': label,
        'text': text,
    }
    if image is None:
        return dic
    dic['image'] = image
    return dic

def load_file_data(tfrecord_path, load_images = True, max_files = 11):
    """Read records back from a TFRecord file.

    Args:
        tfrecord_path: Path of the TFRecord file to read.
        load_images: When true, also decode each record's image and store
            it under the ``'image'`` key.
        max_files: Maximum number of records to read. Defaults to 11, the
            cut-off this helper previously had hard-coded. Pass None to
            read the whole file.

    Returns:
        Dict mapping each record's file name to the dictionary built by
        ``create_file_dic``. Always a dict, even when the file holds fewer
        than ``max_files`` records.

    Raises:
        ValueError: If ``load_images`` is true and a record's format is not
            'png', 'jpg' or 'jpeg'.
    """
    features = {'image/filename' : tf.io.FixedLenFeature([], tf.string),
                'image/height' : tf.io.FixedLenFeature([], tf.int64),
                'image/width' : tf.io.FixedLenFeature([], tf.int64),
                'image/format' : tf.io.FixedLenFeature([], tf.string),
                'image/encoded' : tf.io.FixedLenFeature([], tf.string),
                'image/object/bbox/xmin': tf.io.VarLenFeature(dtype=tf.float32),
                'image/object/bbox/xmax': tf.io.VarLenFeature(dtype=tf.float32),
                'image/object/bbox/ymin': tf.io.VarLenFeature(dtype=tf.float32),
                'image/object/bbox/ymax': tf.io.VarLenFeature(dtype=tf.float32),
                'image/object/class/label': tf.io.VarLenFeature(dtype=tf.int64),
                'image/object/class/text': tf.io.VarLenFeature(dtype=tf.string),
               }

    # Build the parsing and decoding ops once, in a private graph, and feed
    # each serialized record through a placeholder. Creating the ops inside
    # the loop would add new nodes to the graph for every record.
    graph = tf.Graph()
    with graph.as_default():
        serialized = tf.compat.v1.placeholder(tf.string, shape=[])
        example = tf.io.parse_single_example(serialized, features=features)
        metadata_ops = [
            example['image/filename'],
            example['image/height'],
            example['image/width'],
            example['image/format'],
            example['image/object/bbox/xmin'],
            example['image/object/bbox/xmax'],
            example['image/object/bbox/ymin'],
            example['image/object/bbox/ymax'],
            example['image/object/class/label'],
            example['image/object/class/text'],
        ]
        png_op = tf.image.decode_png(example['image/encoded'])
        jpeg_op = tf.image.decode_jpeg(example['image/encoded'])
        decoders = {'png': png_op, 'jpg': jpeg_op, 'jpeg': jpeg_op}

    loaded_files = {}

    with tf.compat.v1.Session(graph=graph) as sess:
        records = tf.compat.v1.python_io.tf_record_iterator(tfrecord_path)
        for i, s_example in enumerate(records, start=1):
            if max_files is not None and i > max_files:
                break

            feed = {serialized: s_example}
            (filename, height, width, img_format,
             xmin, xmax, ymin, ymax, label, text) = sess.run(metadata_ops, feed_dict=feed)

            filename = filename.decode("utf-8")
            img_format = img_format.decode("utf-8")

            # Variable-length features come back as sparse values; keep
            # only their dense ``values`` arrays.
            text = [x.decode('utf-8') for x in text.values]
            label = label.values
            xmin = xmin.values
            xmax = xmax.values
            ymin = ymin.values
            ymax = ymax.values

            image = None
            if load_images:
                if img_format not in decoders:
                    raise ValueError('Unknown Image Format:' + img_format)
                image = sess.run(decoders[img_format], feed_dict=feed)

            # create_file_dic omits the 'image' key when image is None.
            loaded_files[filename] = create_file_dic(height, width, img_format,
                xmin, xmax, ymin, ymax,
                label, text, image)

            sys.stdout.write('{} files processed.     \r'.format(i))

    return loaded_files

In [None]:
# Read the first test record file back (images included) to check the writer.
fd = load_file_data('./tfrecord/test_00.record')

In [None]:
# The keys are the file names stored in the loaded records.
fd.keys()

In [None]:
# Inspect the record that was loaded last.
last_key = list(fd)[-1]
mypoint = fd[last_key]
print (mypoint.keys())
print (mypoint['xmin'])


In [None]:
import matplotlib.pyplot as plt
# Draw every stored box on the decoded image, scaling the fractional
# coordinates back to pixels, then display the result.
myimg = mypoint['image']
img_w, img_h = mypoint['width'], mypoint['height']
boxes = zip(mypoint['xmin'], mypoint['xmax'],
            mypoint['ymin'], mypoint['ymax'],
            mypoint['text'])
for xmin, xmax, ymin, ymax, cl in boxes:
    top_left = (int(xmin * img_w), int(ymin * img_h))
    bottom_right = (int(xmax * img_w), int(ymax * img_h))
    cv2.rectangle(myimg, top_left, bottom_right, COLORS[cl], 3)
plt.imshow(myimg)
plt.show()