In [2]:
# 加载库文件
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

#### 1. 将MNIST数据转化为TFRecord格式并保存

In [5]:
def _int64_feature(value):
    """Build a ``tf.train.Feature`` that stores one integer.

    Args:
        value: a single integer scalar; it becomes the sole entry of the
            feature's ``Int64List``.

    Returns:
        ``tf.train.Feature`` whose ``int64_list`` field is ``[value]``.
    """
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    """Build a ``tf.train.Feature`` that stores one byte string.

    The value is wrapped as-is (no conversion is performed); callers are
    expected to pass an already-serialized ``bytes`` object.

    Args:
        value: a single bytes object; it becomes the sole entry of the
            feature's ``BytesList``.

    Returns:
        ``tf.train.Feature`` whose ``bytes_list`` field is ``[value]``.
    """
    payload = tf.train.BytesList(value=[value])
    feature = tf.train.Feature(bytes_list=payload)
    return feature

def _make_example(pixels, label, image):
    """Pack one image sample into a ``tf.train.Example`` protocol buffer.

    Args:
        pixels: number of pixels per image; stored as the ``'pixels'`` feature.
        label: either a one-hot label vector (reduced to its class index with
            ``np.argmax``) or an already-integer class label (stored as-is).
        image: numpy array of pixel values. Its raw bytes are stored as the
            ``'image_raw'`` feature, so a reader must decode them with the
            same dtype the array had when written.

    Returns:
        A ``tf.train.Example`` holding the ``'pixels'``, ``'label'`` and
        ``'image_raw'`` features.
    """
    # ``tobytes`` returns the same buffer as the deprecated ``tostring``.
    image_raw = image.tobytes()

    # np.argmax on a scalar always yields 0, so only reduce labels that are
    # vectors (one-hot); scalar labels are already class indices. Cast to a
    # plain Python int so the protobuf field receives a native integer.
    if np.ndim(label) == 0:
        label_index = int(label)
    else:
        label_index = int(np.argmax(label))

    example = tf.train.Example(features=tf.train.Features(feature={
        'pixels': _int64_feature(pixels),
        'label': _int64_feature(label_index),
        'image_raw': _bytes_feature(image_raw),
    }))
    return example

def _write_tfrecord(path, images, labels, num):
    """Serialize the first ``num`` (image, label) pairs into a TFRecord file.

    Every record is built with ``_make_example`` and stores
    ``images.shape[1]`` as its ``'pixels'`` feature.

    Args:
        path: destination TFRecord file; its parent directory is created
            first if it does not exist yet.
        images: array of images with at least two dimensions, indexed as
            ``images[index]``.
        labels: labels aligned with ``images`` (one-hot rows in this notebook).
        num: number of leading examples to write.
    """
    import os

    out_dir = os.path.dirname(path)
    if out_dir:
        # Make sure the output directory exists so a fresh checkout can
        # re-run this cell without manual setup.
        tf.gfile.MakeDirs(out_dir)

    pixels = images.shape[1]
    with tf.python_io.TFRecordWriter(path) as writer:
        for index in range(num):
            example = _make_example(pixels, labels[index], images[index])
            writer.write(example.SerializeToString())


# Directory holding the MNIST archives read by input_data.
dataset_path = "../../../../TensorFlow/datasets/MNIST_data/"

# Load MNIST with uint8 pixels (must match the tf.uint8 used by
# tf.decode_raw when the records are read back) and one-hot labels.
mnist = input_data.read_data_sets(dataset_path, dtype=tf.uint8, one_hot=True)

# --- Training split -> TFRecord --------------------------------------------
train_images = mnist.train.images
train_labels = mnist.train.labels
train_pixels = train_images.shape[1]
train_num = mnist.train.num_examples

filename = "../../../../practise/output.tfrecords"
_write_tfrecord(filename, train_images, train_labels, train_num)
print("TFRecord训练文件已经保存")

# --- Test split -> TFRecord ------------------------------------------------
test_images = mnist.test.images
test_labels = mnist.test.labels
test_pixels = test_images.shape[1]
test_num = mnist.test.num_examples

# NOTE(review): the test file uses the ".tfrecord" extension while the
# training file uses ".tfrecords"; the path is kept unchanged in case other
# cells read this exact file.
filename = "../../../../practise/test_output.tfrecord"
_write_tfrecord(filename, test_images, test_labels, test_num)
print("TFRecord测试数据已经保存")
    

Extracting ../../../../TensorFlow/datasets/MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../../../TensorFlow/datasets/MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../../../TensorFlow/datasets/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../../../TensorFlow/datasets/MNIST_data/t10k-labels-idx1-ubyte.gz
TFRecord训练文件已经保存
TFRecord测试数据已经保存


#### 2.读取TFRecord文件

In [8]:
# Path of the training TFRecord file written in section 1.
train_path = "../../../../practise/output.tfrecords"

# Reader that pulls serialized records out of TFRecord files.
tfrecord_reader = tf.TFRecordReader()

# Queue that holds the list of input files for the reader.
file_queue = tf.train.string_input_producer([train_path])

# Read one serialized example from the file queue (read_up_to can be used
# instead to read several examples in one call).
_, serialized_example = tfrecord_reader.read(file_queue)

# Parse the serialized example. The keys and types mirror the features
# written by _make_example: 'image_raw' (bytes), 'pixels' and 'label' (int64).
features = tf.parse_single_example(
    serialized_example,
    features={
        'image_raw':tf.FixedLenFeature([], tf.string),
        'pixels':tf.FixedLenFeature([], tf.int64),
        'label':tf.FixedLenFeature([], tf.int64)
    })

# tf.decode_raw turns the raw byte string back into pixel values. The dtype
# must match the one the images were written with (uint8 in section 1); the
# result is not reshaped here.
images = tf.decode_raw(features['image_raw'], tf.uint8)
labels = tf.cast(features['label'], tf.int32)
pixels = tf.cast(features['pixels'], tf.int32)

# NOTE(review): this session is not closed within this cell view — confirm
# it is closed (or use a `with tf.Session()` block) further down.
sess = tf.Session()

# Start the queue-runner threads that feed file_queue; they must be started
# before sess.run pulls from the reader.
# NOTE(review): coord.request_stop()/coord.join(threads) are not visible
# here — confirm they are called once reading is finished.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# Fetch decoded (image, label, pixels) values for 10 examples, one per step.
for i in range(10):
    image, label, pixel = sess.run([images, labels, pixels])