In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras
from pprint import pprint

# Report the interpreter version, then the name and version of each key library.
print(sys.version_info)
for lib in (mpl, np, pd, sklearn, tf, keras):
    print(lib.__name__, lib.__version__)

sys.version_info(major=3, minor=7, micro=5, releaselevel='final', serial=0)
matplotlib 3.1.2
numpy 1.17.4
pandas 0.25.3
sklearn 0.22
tensorflow 2.0.0
tensorflow_core.keras 2.2.4-tf


## 1. tfrecord基本结构和使用

### 1.1 生成tf.train.Example对象 

In [2]:
# tfrecord 是一种文件格式，其结构如下：
# -> tf.train.Example
#       -> tf.train.Features -> {"key": tf.train.Feature}
#            -> tf.train.Feature
#                 -> tf.train.BytesList
#                 -> tf.train.FloatList
#                 -> tf.train.Int64List

In [3]:
# tf.train.BytesList: the titles are UTF-8 encoded to bytes before being wrapped.
book_titles = ["machine learning", "ccl150"]
favorite_books = [title.encode("utf-8") for title in book_titles]
favorite_books_bytelist = tf.train.BytesList(value=favorite_books)
print(favorite_books)
print(favorite_books_bytelist)

# tf.train.FloatList: the recorded output echoes the values with rounding
# (e.g. 0.9 is shown as 0.8999999761581421).
hour_values = [15.5, 0.9, 10.2, 5.1]
hours_floatlist = tf.train.FloatList(value=hour_values)
print(hours_floatlist)

# tf.train.Int64List: a single integer still has to be passed as a list.
age_values = [42]
age_int64list = tf.train.Int64List(value=age_values)
print(age_int64list)

[b'machine learning', b'ccl150']
value: "machine learning"
value: "ccl150"

value: 15.5
value: 0.8999999761581421
value: 10.199999809265137
value: 5.099999904632568

value: 42



In [4]:
# Map each feature name to a tf.train.Feature wrapping the matching typed list,
# then bundle the mapping into a tf.train.Features message.
feature_map = {
    "favorite_books": tf.train.Feature(bytes_list=favorite_books_bytelist),
    "hours": tf.train.Feature(float_list=hours_floatlist),
    "age": tf.train.Feature(int64_list=age_int64list),
}
features = tf.train.Features(feature=feature_map)

print(features)

feature {
  key: "age"
  value {
    int64_list {
      value: 42
    }
  }
}
feature {
  key: "favorite_books"
  value {
    bytes_list {
      value: "machine learning"
      value: "ccl150"
    }
  }
}
feature {
  key: "hours"
  value {
    float_list {
      value: 15.5
      value: 0.8999999761581421
      value: 10.199999809265137
      value: 5.099999904632568
    }
  }
}



In [5]:
# Wrap the features in a tf.train.Example and serialize it to a bytes string;
# the serialized form is what gets written to the tfrecord file.
example = tf.train.Example(features=features)
serialized_example = example.SerializeToString()

# Show the readable message first, then its serialized bytes.
print(example)
print(serialized_example)

features {
  feature {
    key: "age"
    value {
      int64_list {
        value: 42
      }
    }
  }
  feature {
    key: "favorite_books"
    value {
      bytes_list {
        value: "machine learning"
        value: "ccl150"
      }
    }
  }
  feature {
    key: "hours"
    value {
      float_list {
        value: 15.5
        value: 0.8999999761581421
        value: 10.199999809265137
        value: 5.099999904632568
      }
    }
  }
}

b'\n]\n.\n\x0efavorite_books\x12\x1c\n\x1a\n\x10machine learning\n\x06ccl150\n\x0c\n\x03age\x12\x05\x1a\x03\n\x01*\n\x1d\n\x05hours\x12\x14\x12\x12\n\x10\x00\x00xAfff?33#A33\xa3@'


## 2. 生成tfrecord文件

In [6]:
# Write the same serialized Example three times into a plain tfrecord file.
output_dir = "tfrecord_basic"
# exist_ok=True makes this a no-op when the directory already exists, so the
# cell can be re-run without a separate existence check.
os.makedirs(output_dir, exist_ok=True)
filename = "test.tfrecords"
filename_fullpath = os.path.join(output_dir, filename)
with tf.io.TFRecordWriter(filename_fullpath) as writer:
    for _ in range(3):
        writer.write(serialized_example)

## 3. 读取并解析tfrecord文件

### 3.1 读取文件 

In [7]:
# Each element of a TFRecordDataset is a scalar string tensor holding one
# still-serialized record (see the shape=(), dtype=string output).
tfrecord_paths = [filename_fullpath]
dataset = tf.data.TFRecordDataset(tfrecord_paths)
for raw_record in dataset:
    print(raw_record)

tf.Tensor(b'\n]\n.\n\x0efavorite_books\x12\x1c\n\x1a\n\x10machine learning\n\x06ccl150\n\x0c\n\x03age\x12\x05\x1a\x03\n\x01*\n\x1d\n\x05hours\x12\x14\x12\x12\n\x10\x00\x00xAfff?33#A33\xa3@', shape=(), dtype=string)
tf.Tensor(b'\n]\n.\n\x0efavorite_books\x12\x1c\n\x1a\n\x10machine learning\n\x06ccl150\n\x0c\n\x03age\x12\x05\x1a\x03\n\x01*\n\x1d\n\x05hours\x12\x14\x12\x12\n\x10\x00\x00xAfff?33#A33\xa3@', shape=(), dtype=string)
tf.Tensor(b'\n]\n.\n\x0efavorite_books\x12\x1c\n\x1a\n\x10machine learning\n\x06ccl150\n\x0c\n\x03age\x12\x05\x1a\x03\n\x01*\n\x1d\n\x05hours\x12\x14\x12\x12\n\x10\x00\x00xAfff?33#A33\xa3@', shape=(), dtype=string)


### 3.2 解析文件

In [8]:
# Schema describing how each serialized Example should be decoded.
expected_features = {
    "favorite_books": tf.io.VarLenFeature(dtype=tf.string),
    "hours": tf.io.VarLenFeature(dtype=tf.float32),
    # The empty shape [] matters: "age" is parsed as a scalar tensor (shape=()).
    "age": tf.io.FixedLenFeature([], dtype=tf.int64),
}
dataset = tf.data.TFRecordDataset([filename_fullpath])


def _parse_example(serialized):
    """Decode one serialized Example into a dict of tensors using expected_features."""
    return tf.io.parse_single_example(serialized, expected_features)


# Print the full parsed dict (favorite_books, hours, age) for every record.
for serialized_example_tensor in dataset:
    example = _parse_example(serialized_example_tensor)
    print(example)

print("---------------------")

# Look at the first record only and unpack its variable-length features.
for serialized_example_tensor in dataset.take(1):
    example = _parse_example(serialized_example_tensor)

    # VarLenFeature yields a SparseTensor; densify it before iterating.
    books = tf.sparse.to_dense(example["favorite_books"], default_value=b"")
    for title in books:
        print(title.numpy().decode("UTF-8"))

    print("---------------------")

    hours = tf.sparse.to_dense(example["hours"])
    for duration in hours:
        print(duration.numpy())
        


{'favorite_books': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000017844D6B288>, 'hours': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000017844D6BEC8>, 'age': <tf.Tensor: id=46, shape=(), dtype=int64, numpy=42>}
{'favorite_books': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000017844E06088>, 'hours': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000017844E06EC8>, 'age': <tf.Tensor: id=55, shape=(), dtype=int64, numpy=42>}
{'favorite_books': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000017844D6B488>, 'hours': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x0000017844D6C208>, 'age': <tf.Tensor: id=64, shape=(), dtype=int64, numpy=42>}
---------------------
machine learning
ccl150
---------------------
15.5
0.9
10.2
5.1


## 4. 生成GZIP压缩的tfrecord文件

In [9]:
# Write the same serialized Example three times into a GZIP-compressed tfrecord file.
output_dir = "tfrecord_basic"
# exist_ok=True makes this a no-op when the directory already exists, so the
# cell can be re-run without a separate existence check.
os.makedirs(output_dir, exist_ok=True)
filename = "test.tfrecords"
# NOTE(review): the content is GZIP-compressed (see compression_type below), not
# a ZIP archive; the ".zip" suffix is kept only so the on-disk path is unchanged.
filename_fullpath_zip = os.path.join(output_dir, filename) + ".zip"
options = tf.io.TFRecordOptions(compression_type="GZIP")
with tf.io.TFRecordWriter(filename_fullpath_zip, options) as writer:
    for _ in range(3):
        writer.write(serialized_example)

## 5. 读取压缩文件

In [10]:
# Read the compressed file back, using the same GZIP compression type it was
# written with, and print every book title of every record.
dataset_zip = tf.data.TFRecordDataset([filename_fullpath_zip], compression_type="GZIP")
for serialized_example_tensor in dataset_zip:
    example = tf.io.parse_single_example(serialized_example_tensor, expected_features)
    # Pass an explicit empty-bytes filler for the string-typed sparse tensor
    # instead of relying on the implicit default.
    books = tf.sparse.to_dense(example["favorite_books"], default_value=b"")
    for book in books:
        print(book.numpy().decode("UTF-8"))

machine learning
ccl150
machine learning
ccl150
machine learning
ccl150
