# Writing data into TFRecord

In [2]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature``.

    Args:
        value: The integer to store; it becomes the only element of an
            ``Int64List``.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds ``[value]``.
    """
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)

def _bytes_feature(value):
    """Wrap a single byte string in a ``tf.train.Feature``.

    Args:
        value: The byte string to store; it becomes the only element of a
            ``BytesList``.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds ``[value]``.
    """
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)

# Load MNIST with uint8 pixel values and one-hot encoded labels
mnist = input_data.read_data_sets('./mnist', dtype=tf.uint8, one_hot=True)

images = mnist.train.images

labels = mnist.train.labels

# Size of the second dimension of the image array; stored with every record
pixels = images.shape[1]

num_examples = mnist.train.num_examples

# Path of the TFRecord file to write
filename = './mnist/output.tfrecords'

# Use the writer as a context manager so the file is closed even if
# building or writing a record raises part-way through the loop
with tf.python_io.TFRecordWriter(filename) as writer:
    for index in range(num_examples):
        # Serialize the image array to raw bytes
        # (tobytes() replaces the deprecated ndarray.tostring() alias)
        image_raw = images[index].tobytes()

        # The label is stored one-hot, so the class id is the position of the
        # maximum. np.argmax returns a NumPy integer; converting it to a plain
        # int keeps it acceptable to protobuf builds that only take built-in ints.
        label_id = int(np.argmax(labels[index]))

        # Pack all fields of one example into an Example protocol buffer
        example = tf.train.Example(features=tf.train.Features(feature={
            'pixels': _int64_feature(pixels),
            'label': _int64_feature(label_id),
            'image_raw': _bytes_feature(image_raw)
        }))

        # Append the serialized Example to the TFRecord file
        writer.write(example.SerializeToString())

Extracting ./mnist/train-images-idx3-ubyte.gz
Extracting ./mnist/train-labels-idx1-ubyte.gz
Extracting ./mnist/t10k-images-idx3-ubyte.gz
Extracting ./mnist/t10k-labels-idx1-ubyte.gz


# Reading data from TFRecord

In [4]:
import tensorflow as tf

# Reader that yields serialized records from TFRecord files
reader = tf.TFRecordReader()

# Queue that holds the list of input file names
filename_queue = tf.train.string_input_producer(
    ['./mnist/output.tfrecords'])

# Read one record per evaluation; reader.read_up_to can read several at once
_, serialized_example = reader.read(filename_queue)

# Parse a single record; tf.parse_example parses a batch of records
features = tf.parse_single_example(
    serialized_example,
    features={
        # tf.FixedLenFeature yields a Tensor; tf.VarLenFeature yields a SparseTensor
        'image_raw': tf.FixedLenFeature([], tf.string),
        'pixels': tf.FixedLenFeature([], tf.int64),
        'label': tf.FixedLenFeature([], tf.int64)
    })

# tf.decode_raw turns the stored byte string back into a uint8 tensor
images = tf.decode_raw(features['image_raw'], tf.uint8)
labels = tf.cast(features['label'], tf.int32)
pixels = tf.cast(features['pixels'], tf.int32)

sess = tf.Session()
# Start the queue-runner threads that feed the input queue
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    # Each run of the fetch list reads one example from the TFRecord file
    for i in range(10):
        image, label, pixel = sess.run([images, labels, pixels])
finally:
    # Stop and join the queue-runner threads before closing the session, so
    # no background thread is left running against a closed session
    coord.request_stop()
    coord.join(threads)
    sess.close()