In [5]:
import os
import h5py
import numpy as np
from PIL import Image
import tensorflow as tf

In [16]:
def get_files(file_dir, type = 'train'):
    """
    Load one split of the image dataset from its HDF5 file.

    Args:
        file_dir: directory that contains 'train_happy.h5' and 'test_happy.h5'.
        type: which split to load, 'train' or 'validation'. The name shadows
            the builtin but is kept so existing keyword callers still work.

    Returns:
        A (data, label) tuple of numpy arrays. For 'train' the labels are
        reshaped into a column of shape (-1, 1); for 'validation' they keep
        the shape stored in the file.

    Raises:
        ValueError: if type is neither 'train' nor 'validation'.
    """
    # split name -> (file name, image dataset key, label dataset key)
    splits = {
        'train': ('train_happy.h5', 'train_set_x', 'train_set_y'),
        'validation': ('test_happy.h5', 'test_set_x', 'test_set_y'),
    }
    # Validate before touching the disk: an unknown split used to fall through
    # and return None, which only failed later when the caller unpacked it.
    if type not in splits:
        raise ValueError(
            "type must be 'train' or 'validation', got {!r}".format(type))
    file_name, data_key, label_key = splits[type]

    # Explicit read-only mode: never create or modify the dataset file,
    # regardless of the h5py version's default mode.
    with h5py.File(os.path.join(file_dir, file_name), 'r') as f:
        print(list(f.keys()))
        # Slicing reads each dataset in a single call instead of iterating
        # over it row by row in Python.
        data = f[data_key][:]
        label = f[label_key][:]

    # Only the training labels are returned as a column vector (kept as-is
    # for backward compatibility with existing callers).
    if type == 'train':
        label = label.reshape([-1, 1])
    return data, label
    
    
def image2tfrecord(img, label, str_name):
    """
    Write images and their labels to a TFRecord file.

    TFRecord is a binary file format recommended by Google; in principle it
    can store any kind of information. Each image is resized to 224x224,
    converted to bytes and stored together with its integer label as one
    serialized tf.train.Example.

    Args:
        img: array of images, a 4-D array (iterated image by image).
        label: labels indexed by image position; each entry must hold exactly
            one integer-convertible value (a scalar or a 1-element array).
        str_name: file name of the TFRecord file to write.
    """
    # The context manager closes the writer even when an image fails part-way
    # through, so the output file handle is never leaked.
    with tf.python_io.TFRecordWriter(str_name) as writer:
        for index, image in enumerate(img):
            resized = Image.fromarray(image).resize((224, 224))
            image_bytes = resized.tobytes()        # image converted to bytes

            # .item() pulls the single value out of a scalar or 1-element
            # array; calling int() directly on a 1-element array is
            # deprecated in recent NumPy releases.
            label_value = int(np.asarray(label[index]).item())

            features = {
                # The image bytes.
                'image_raw': tf.train.Feature(
                    bytes_list = tf.train.BytesList(value = [image_bytes])),
                # The image label (this slot could also carry size info).
                'label': tf.train.Feature(
                    int64_list = tf.train.Int64List(value = [label_value])),
            }

            # Bundle the features into an Example and serialize it.
            tf_example = tf.train.Example(
                features = tf.train.Features(feature = features))
            writer.write(tf_example.SerializeToString())
        

# Save the training images and the validation images in TFRecord format.
train_imgs, train_labels = get_files('datasets', type = 'train')
val_imgs, val_labels = get_files('datasets', type = 'validation')
image2tfrecord(img = train_imgs, label = train_labels, str_name = 'train.tfrecords')
image2tfrecord(img = val_imgs, label = val_labels, str_name = 'val.tfrecords')

['list_classes', 'train_set_x', 'train_set_y']
['list_classes', 'test_set_x', 'test_set_y']


In [17]:
from tensorflow.python.framework import graph_util

def read_and_decode_tfrecord_files(filename, batch_size):
    """
    从 tfrecord 格式的文件中读取出数据，并将其转换成正常的图片和标签
    
    Args:
        filename: tfrecords 文件的文件名（完整的路径）
        batch_size: 批大小
    """
    
    filename_queue = tf.train.string_input_producer([filename])   # 创建一个文件名队列
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    img_features = tf.parse_single_example(serialized_example, 
                        features = {'label': tf.FixedLenFeature([], tf.int64),
                                    'image_raw': tf.FixedLenFeature([], tf.string)})
    image = tf.decode_raw(img_features['image_raw', tf.uint8])
    image = tf.reshape(image, [224, 224, 3])
    image = tf.cast(image, tf.float32) / 255.0
    label = tf.cast(img_features)
    