Notebook to process a TFRecord of predictions and create a CSV with the columns: filename, labels (class), xmin, ymin, xmax, ymax, score (prediction score)

In [1]:
# import modules
import csv, os, sys, io
import tensorflow as tf
import pandas as pd
import numpy as np
sys.path.append('/home/ubuntu/data/tensorflow/my_workspace/camera-trap-detection/data/')
from utils import dataset_util
from PIL import ImageFile
from PIL import Image
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [2]:
# Input parameters
# TFRecord file(s) holding the predictions to export.
filename_list = [
    '/home/ubuntu/data/tensorflow/my_workspace/training_demo/Predictions/snapshot_serengeti_s01_s06-0-10000.record',
]
# Number of decoded records grouped into one batch.
batch_size = 10000
# A detection is kept only when its score is strictly greater than this value.
score_threshold = 0.5

In [3]:
"""This section contains the helper functions needed for 
decoding TFRecord...
1. decode_record: This function decodes a TFexample with the following features
               'image/filename' - a fixed length feature with the file name
               'image/encoded' - a fixed length feature with image encodings
               'image/format' - a fixed length feature with image format
               'image/detection/bbox/xmin' - a variable length feature with normalized xmin values
               'image/detection/bbox/xmax' - normalized xmax values
               'image/detection/bbox/ymin' - normalized ymin values
               'image/detection/bbox/ymax' - normalized ymax values
               'image/detection/label' - bounding box labels
               'image/detection/score' - prediction score

"""

def decode_record(serialized_example):
    """Parse one serialized record into a ``(context, sequence)`` pair of dicts.

    Args:
        serialized_example: a single serialized record, as yielded by the
            TFRecord dataset.

    Returns:
        A 2-tuple of plain dicts. The first maps each requested feature key to
        its parsed value (fixed-length string features for the file name, the
        encoded image and the image format; variable-length features for the
        detection boxes, labels and scores). The second holds the sequence
        features returned by the parser; no sequence features are requested
        here, so it is expected to be empty.
    """
    fixed_string_keys = (
        'image/filename',
        'image/encoded',
        'image/format',
    )
    var_float_keys = (
        "image/detection/bbox/xmin",
        "image/detection/bbox/xmax",
        "image/detection/bbox/ymin",
        "image/detection/bbox/ymax",
    )

    # Assemble the parsing spec in the same key order as it is documented above.
    feature_spec = {}
    for key in fixed_string_keys:
        feature_spec[key] = tf.FixedLenFeature([], tf.string)
    for key in var_float_keys:
        feature_spec[key] = tf.VarLenFeature(tf.float32)
    feature_spec["image/detection/label"] = tf.VarLenFeature(tf.int64)
    feature_spec["image/detection/score"] = tf.VarLenFeature(tf.float32)

    # Only context features are requested; every field of the record is read
    # from the context part.
    parsed_context, parsed_sequence = tf.parse_single_sequence_example(
        serialized=serialized_example,
        context_features=feature_spec,
        example_name=None,
        name=None)

    return (dict(parsed_context), dict(parsed_sequence))

# Build the input pipeline: file names -> serialized records -> decoded batches
filename_dataset = tf.data.Dataset.from_tensor_slices(filename_list)
record_dataset = tf.data.TFRecordDataset(filename_dataset)
# decode_record takes the serialized record as its only argument, so it can be
# handed to map() directly.
dataset = record_dataset.map(decode_record).batch(batch_size)

# One-shot iterator over the batched dataset; evaluating batch_data in a
# session yields the next decoded batch.
iterator = dataset.make_one_shot_iterator()
batch_data = iterator.get_next()

# Run the session and extract the feature values
def _select_detections(context, threshold):
    """Pick the detections of one decoded batch whose score exceeds ``threshold``.

    Args:
        context: dict obtained by evaluating ``batch_data``. 'image/filename'
            is indexed by batch row; every 'image/detection/*' entry is read
            through its ``.indices`` / ``.values`` attributes, with column 0 of
            ``indices`` taken as the batch row the detection belongs to.
        threshold: a detection is kept only when its score is strictly greater.

    Returns:
        Tuple ``(names, columns, unmatched)`` where ``names`` is the decoded
        file name of every kept detection, ``columns`` maps
        'xmin'/'ymin'/'xmax'/'ymax'/'label'/'score' to the kept values (all
        parallel to ``names``, grouped by batch row in ascending order), and
        ``unmatched`` lists the file names of rows with no kept detection.

    Raises:
        ValueError: if a detection feature does not hold the same number of
            values as the score feature, since the score mask is applied to
            all of them by position.
    """
    filenames = context['image/filename']
    score = context['image/detection/score']
    features = {
        'xmin': context['image/detection/bbox/xmin'],
        'ymin': context['image/detection/bbox/ymin'],
        'xmax': context['image/detection/bbox/xmax'],
        'ymax': context['image/detection/bbox/ymax'],
        'label': context['image/detection/label'],
        'score': score,
    }
    for key, feature in features.items():
        if len(feature.values) != len(score.values):
            raise ValueError(
                'detection feature %r has %d values but there are %d scores'
                % (key, len(feature.values), len(score.values)))

    # Positions of the detections above the threshold, and the row of each.
    kept = np.flatnonzero(score.values > threshold)
    rows = score.indices[kept, 0]
    # Stable sort by row so the output is grouped per image in row order,
    # keeping the original relative order of the boxes inside each image.
    order = np.argsort(rows, kind='mergesort')
    kept, rows = kept[order], rows[order]

    # utf-8 decodes every name that ascii does, without failing on other bytes.
    names = [filenames[row].decode('utf-8') for row in rows]
    columns = {key: feature.values[kept] for key, feature in features.items()}

    has_detection = np.zeros(len(filenames), dtype=bool)
    has_detection[rows] = True
    unmatched = [filenames[row].decode('utf-8')
                 for row in np.flatnonzero(~has_detection)]
    return names, columns, unmatched


file_name = []  # file name of each kept detection (parallel to the lists below)
no_prediction = []  # images with no detection above score_threshold
incorrect_prediction = []  # never populated in this cell; kept so the name stays defined
xmins_d, ymins_d, xmaxs_d, ymaxs_d, labels_d, scores = [], [], [], [], [], []

with tf.Session() as sess:
    # Drain the iterator: a single run only returns the first batch, so any
    # record beyond batch_size would otherwise be dropped silently.
    while True:
        try:
            (context, sequence) = sess.run(batch_data)
        except tf.errors.OutOfRangeError:
            break

        # The encoded image is not needed: box coordinates stay normalized,
        # so the image size is never used.
        batch_names, batch_columns, batch_unmatched = _select_detections(
            context, score_threshold)

        file_name.extend(batch_names)
        xmins_d.extend(batch_columns['xmin'])
        ymins_d.extend(batch_columns['ymin'])
        xmaxs_d.extend(batch_columns['xmax'])
        ymaxs_d.extend(batch_columns['ymax'])
        labels_d.extend(batch_columns['label'])
        scores.extend(batch_columns['score'])
        no_prediction.extend(batch_unmatched)

In [4]:
# Create pandas dataframe
# One row per kept detection. The column order is pinned explicitly (to the
# order shown in the recorded output of this notebook) because the order
# derived from a dict depends on the pandas / Python version in use.
df_predictions = pd.DataFrame(
    {'labels': labels_d,
     'filename': file_name,
     'score': scores,
     'xmin': xmins_d,
     'ymin': ymins_d,
     'xmax': xmaxs_d,
     'ymax': ymaxs_d},
    columns=['filename', 'labels', 'score', 'xmax', 'xmin', 'ymax', 'ymin'])
# write predictions to csv (the row index is written as the first, unnamed column)
df_predictions.to_csv('/home/ubuntu/data/tensorflow/my_workspace/training_demo/Predictions/snapshot_serengeti_s01_s06-0-10000.csv')

In [5]:
# Preview the first rows of the predictions table.
df_predictions.head()

Unnamed: 0,filename,labels,score,xmax,xmin,ymax,ymin
0,S1/B04/B04_R1/S1_B04_R1_PICT0009,5,0.990571,0.790646,0.354476,0.946177,0.552227
1,S1/B04/B04_R1/S1_B04_R1_PICT0012,5,0.992132,0.213541,0.073717,0.806135,0.643844
2,S1/B04/B04_R1/S1_B04_R1_PICT0017,5,0.995016,0.303868,0.150443,0.842562,0.610818
3,S1/B04/B04_R1/S1_B04_R1_PICT0018,5,0.996753,0.760716,0.496047,0.937487,0.676392
4,S1/B05/B05_R1/S1_B05_R1_PICT0002,7,0.996732,0.825283,0.532042,0.930573,0.52743
