In [4]:
from IPython import display

import os
import tensorflow as tf
import glob

In [2]:
# Mount Google Drive so its contents are reachable under /content/drive.
# NOTE(review): presumably this import only resolves inside Google Colab — confirm before running elsewhere.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [47]:
# Base directory of the runtime.
# NOTE(review): ROOT_PATH is not referenced by any cell visible here — confirm it is still needed.
ROOT_PATH = '/content'
# Directory that is searched for the input images.
# NOTE(review): the '...' segments look like placeholders for the real folder names — confirm and fill in before running.
DATASET_PATH = '/content/.../.../custom_Dataset'

In [48]:
def get_img_file_names(dir, format='jpg'):
    """Return the image files directly inside ``dir`` that end in ``.format``.

    Args:
        dir: Directory to search (not recursive). The name is taken
            literally: glob metacharacters such as ``[``, ``*`` or ``?`` in
            it are escaped instead of being interpreted as wildcards.
        format: File extension without the leading dot. Must be one of
            ``jpg``, ``png``, ``jpeg``, ``JPG``, ``JPEG`` or ``PNG``; it is
            matched exactly as given.

    Returns:
        A sorted list of matching paths. The list is empty when ``format``
        is not one of the supported extensions or when nothing matches.
    """
    formats = ('jpg', 'png', 'jpeg', 'JPG', 'JPEG', 'PNG')
    if format not in formats:
        # Unsupported extensions deliberately yield no matches, not an error.
        return []

    # Escape the directory so a folder such as "set[1]" is matched literally.
    rgx_for_img = glob.escape(dir) + "/*." + format

    # glob returns entries in arbitrary order; sort for reproducible runs.
    return sorted(glob.glob(rgx_for_img))

In [49]:
# Collect the files directly inside DATASET_PATH that match '*.JPG';
# the extension is handed to glob exactly as written here.
img_files = get_img_file_names(DATASET_PATH, 'JPG')

In [50]:
# display.display(display.Image(filename=img_files[0]))

In [51]:
import numpy as np

# One zero-valued NumPy integer label per discovered image file.
labels = list(np.zeros(len(img_files), dtype=int))

In [52]:
# Pair every image path with the label at the same position.
images = dict(zip(img_files, labels))

### Write TFRecord File

In [53]:
# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
  """Wrap one string / bytes value in a ``tf.train.Feature`` (bytes_list).

  A value whose type matches that of ``tf.constant(0)`` is converted with
  ``.numpy()`` first, because ``BytesList`` won't unpack a string from an
  EagerTensor.
  """
  tensor_type = type(tf.constant(0))
  payload = value.numpy() if isinstance(value, tensor_type) else value
  bytes_list = tf.train.BytesList(value=[payload])
  return tf.train.Feature(bytes_list=bytes_list)

def _float_feature(value):
  """Wrap one float / double in a ``tf.train.Feature`` (float_list)."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
  """Wrap one bool / enum / int / uint in a ``tf.train.Feature`` (int64_list)."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)

In [72]:
def create_image_tfRecord_example(image_string, label):
    """Build a ``tf.train.Example`` for one encoded image and its label.

    Args:
        image_string: Encoded image bytes, stored unchanged in the example.
        label: Integer label stored alongside the image.

    Returns:
        A ``tf.train.Example`` with int64 features ``height``, ``width``,
        ``depth`` and ``label`` and a bytes feature ``image_raw``.
    """
    # decode_image detects the container format itself, so no decoder has to
    # be swapped by hand; expand_animations=False keeps the result rank-3 so
    # the three shape entries below always exist.
    image_shape = tf.io.decode_image(image_string, expand_animations=False).shape

    feature_map = {
      'height': _int64_feature(image_shape[0]),
      'width': _int64_feature(image_shape[1]),
      'depth': _int64_feature(image_shape[2]),
      'label': _int64_feature(label),
      'image_raw': _bytes_feature(image_string),
    }

    return tf.train.Example(features=tf.train.Features(feature=feature_map))

In [73]:
# Location of the TFRecord file that the notebook writes and then reads back.
TFRECORD_PATH = '/content/images.tfrecords'

In [74]:
def create_image_tfRecord_file(record_file_name, images_dict):
  """Serialize every image in ``images_dict`` into a single TFRecord file.

  Args:
    record_file_name: Path of the TFRecord file to write.
    images_dict: Mapping of image file path -> label.
  """
  with tf.io.TFRecordWriter(record_file_name) as tfrecord_writer_:
    for filename, label in images_dict.items():
        # Read through a context manager so each image file is closed
        # right away instead of leaking a handle per image.
        with open(filename, 'rb') as image_file:
            image_string = image_file.read()
        tf_example = create_image_tfRecord_example(image_string, label)
        tfrecord_writer_.write(tf_example.SerializeToString())

In [75]:
# Write every (image path, label) pair in `images` to the file at TFRECORD_PATH.
create_image_tfRecord_file(record_file_name=TFRECORD_PATH, images_dict=images)

### Read TFRecord File

In [77]:
# Parsing schema for the stored examples: four scalar int64 fields followed
# by the scalar string field that holds the image bytes.
image_feature_description = {
    int_key: tf.io.FixedLenFeature([], tf.int64)
    for int_key in ('height', 'width', 'depth', 'label')
}
image_feature_description['image_raw'] = tf.io.FixedLenFeature([], tf.string)

In [79]:
def read_tfRecord_image(filename, feature_description, up_to_batch=None):
    """Read a TFRecord file and return its parsed examples as a list.

    Args:
        filename: Passed unchanged to ``tf.data.TFRecordDataset``.
        feature_description: Dict of feature name -> parsing spec, passed to
            ``tf.io.parse_single_example`` for every record.
        up_to_batch: Optional cap on how many records are returned. Despite
            the name it counts individual records (no batching is applied
            here); ``None`` returns every record.

    Returns:
        A Python list with one parsed record per entry, in file order.
    """
    raw_dataset = tf.data.TFRecordDataset(filename)

    def _parse_function(example_proto):
      # Parse the input `tf.train.Example` proto using the given description.
      return tf.io.parse_single_example(example_proto, feature_description)

    parsed_dataset = raw_dataset.map(_parse_function)

    if up_to_batch is not None:
      parsed_dataset = parsed_dataset.take(up_to_batch)

    # Materialize the dataset so callers receive a plain list.
    return list(parsed_dataset)

In [80]:
# Parse the TFRecord file; the result is a Python list of parsed records,
# not a tf.data.Dataset.
image_dataset = read_tfRecord_image(TFRECORD_PATH, image_feature_description)

### Visualize Images read from TFRecord File

In [81]:
# image_dataset is a plain list (read_tfRecord_image returns one), which has
# no `.take()` method, so slice it to preview the first two images.
for image_features in image_dataset[:2]:
  image_raw = image_features['image_raw'].numpy()
  display.display(display.Image(data=image_raw))
  print()