In [1]:
import tensorflow as tf
import math
import os
import random
import sys

In [2]:
# Number of samples for the validation/test split
# NOTE(review): the code that consumes _NUM_TEST is not visible here — confirm usage.
_NUM_TEST = 500
# Random seed
_RANDOM_SEED = 0
# Number of shards; used in the tfrecord file names and when iterating over shards
_NUM_SHARDS = 5
# Dataset directory
DATASET_DIR = "./images/"
# Labels file name (path)
LABELS_FILENAME = "./images/labels.txt"

# Build the path + name of one tfrecord shard file
def _get_dataset_filename(dataset_dir,split_name,shard_id):
    """Return the path of a single tfrecord shard inside ``dataset_dir``.

    The file name encodes the split name, the zero-padded shard index and the
    zero-padded total shard count taken from the module-level ``_NUM_SHARDS``.
    """
    name_fields = (split_name, shard_id, _NUM_SHARDS)
    return os.path.join(dataset_dir, 'image_%s_%05d-of-%05d.tfrecord' % name_fields)

# Check whether the tfrecord files already exist
def _dataset_exists(dataset_dir):
    """Return True only if every expected tfrecord shard exists in ``dataset_dir``.

    Every shard index in ``range(_NUM_SHARDS)`` is checked for both the
    'train' and the 'test' split; the first missing file makes the function
    return False.
    """
    for split_name in ['train','test']:
        for shard_id in range(_NUM_SHARDS):
            # Path + name of the tfrecord file for this split/shard
            output_filename = _get_dataset_filename(dataset_dir,split_name,shard_id)
            # The existence check must run for each shard, i.e. inside the
            # inner loop; checking after the loop would only test the last shard.
            if not tf.gfile.Exists(output_filename):
                return False
    return True



In [None]:
'''
    1.TensorFlow学习记录-- ７.TensorFlow高效读取数据之tfrecord详细解读:
      https://blog.csdn.net/qq_16949707/article/details/53483493
    2.TensorFlow 学习（二） 制作自己的TFRecord数据集，读取，显示及代码详解:
      https://blog.csdn.net/miaomiaoyuan/article/details/56865361
    3.Tensorflow中使用tfrecord方式读取数据
      https://blog.csdn.net/u010358677/article/details/70544241
'''

# 官网：
# tfrecord文件中的数据是通过tf.train.Example Protocol Buffer的格式存储的，下面是tf.train.Example的定义

message Example {
 Features features = 1;
};

message Features{
 map<string,Feature> feature = 1;
};

message Feature{
    oneof kind{
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};

# 生成Tfrecord文件
# 第一种定义【2】：
# tf.train.Example 协议内存块包含了Features字段，通过feature将图片的二进制数据和label进行统一封装， 然后
# 将example协议内存块转化为字符串， tf.python_io.TFRecordWriter 写入到TFRecords文件中。
# 生成文件dog_train.tfrecords 

import os 
import tensorflow as tf 
from PIL import Image  #注意Image,后面会用到
import matplotlib.pyplot as plt 
import numpy as np

# Root folder containing one sub-folder per class. Backslashes are escaped
# explicitly so the literal contains no invalid escape sequences ("\P", "\d");
# the resulting string is unchanged.
cwd = 'D:\\Python\\data\\dog\\'
# Two hand-picked classes. An ordered list (not a set) is required here:
# the position in this sequence becomes the integer label, and set iteration
# order is not guaranteed to be the same from one run to the next.
classes = ['husky', 'chihuahua']

# The context manager closes the output file even if an image fails to load.
with tf.python_io.TFRecordWriter("dog_train.tfrecords") as writer:  # file to generate
    for index, name in enumerate(classes):
        class_path = cwd + name + '\\'
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name  # path of one image

            # Force 3 channels so the raw buffer always holds 128*128*3 bytes,
            # which is the shape the reader reshapes to.
            img = Image.open(img_path).convert('RGB')
            img = img.resize((128, 128))
            img_raw = img.tobytes()  # raw pixel bytes of the image
            # The Example bundles the label and the image bytes together
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
            }))
            writer.write(example.SerializeToString())  # serialize to a string and append

# 第二种定义【3】
# 部分节选

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from scipy import misc
import scipy.io as sio


def _bytes_feature(value):
    """Wrap ``value`` in a tf.train.Feature holding a one-element BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)

def _int64_feature(value):
    """Wrap ``value`` in a tf.train.Feature holding a one-element Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

# Excerpt: pack the fields of one sample into a tf.train.Example, using
# _int64_feature / _bytes_feature to wrap each value.
# NOTE(review): height, width, item, mask_raw and label are not defined in the
# visible part of this notebook — presumably they come from the omitted part of
# the referenced example; confirm before running this statement.
example = tf.train.Example(features=tf.train.Features(feature={
    'height': _int64_feature(height),
    'width': _int64_feature(width),
    'name': _bytes_feature(item[0]),
    'image_raw': _bytes_feature(img_raw),
    'mask_raw': _bytes_feature(mask_raw),
    'label': _int64_feature(label)}))

# Read a Tfrecord file
# The feature keys "label" and "img_raw" must match the names used when the file was written
def read_and_decode(filename): # reads e.g. dog_train.tfrecords
    """Build the graph ops that read one (image, label) pair from a tfrecord file.

    Args:
        filename: path of the tfrecord file to read.

    Returns:
        A tuple ``(img, label)``: ``img`` is a float32 tensor reshaped to
        [128, 128, 3] and scaled as ``pixel / 255 - 0.5``; ``label`` is the
        stored label cast to int32.
    """
    # Queue that feeds the file name to the reader
    filename_queue = tf.train.string_input_producer([filename])

    record_reader = tf.TFRecordReader()
    # read() returns the record key and the serialized record
    _, serialized_record = record_reader.read(filename_queue)

    # Both entries are scalar, fixed-length features
    feature_spec = {
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw' : tf.FixedLenFeature([], tf.string),
    }
    parsed = tf.parse_single_example(serialized_record, features=feature_spec)

    # Raw bytes -> uint8 vector -> 128x128 image with 3 channels
    pixels = tf.decode_raw(parsed['img_raw'], tf.uint8)
    pixels = tf.reshape(pixels, [128, 128, 3])
    # Convert to float and shift the [0, 1] range to [-0.5, 0.5]
    scaled_img = tf.cast(pixels, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(parsed['label'], tf.int32)
    return scaled_img, label