In [1]:
import numpy as np
import tensorflow as tf
import os
import pprint
from tensorflow import keras

In [2]:
source_dir = "./generate_csv/"

# 通过判断开头去添加文件
def get_filename_by_prefix(source_dir, prefix_name):
    all_files = os.listdir(source_dir)
    results = []
    for filename in all_files:
        if filename.startswith(prefix_name):
            results.append(os.path.join(source_dir, filename))
    return results


train_filenames = get_filename_by_prefix(source_dir, 'train')
valid_filenames = get_filename_by_prefix(source_dir, 'valid')
test_filenames = get_filename_by_prefix(source_dir, 'test')
pprint.pprint(train_filenames)
pprint.pprint(valid_filenames)
pprint.pprint(test_filenames)

['./generate_csv/train_00.csv',
 './generate_csv/train_01.csv',
 './generate_csv/train_02.csv',
 './generate_csv/train_03.csv',
 './generate_csv/train_04.csv',
 './generate_csv/train_05.csv',
 './generate_csv/train_06.csv',
 './generate_csv/train_07.csv',
 './generate_csv/train_08.csv',
 './generate_csv/train_09.csv',
 './generate_csv/train_10.csv',
 './generate_csv/train_11.csv',
 './generate_csv/train_12.csv',
 './generate_csv/train_13.csv',
 './generate_csv/train_14.csv',
 './generate_csv/train_15.csv',
 './generate_csv/train_16.csv',
 './generate_csv/train_17.csv',
 './generate_csv/train_18.csv',
 './generate_csv/train_19.csv']
['./generate_csv/valid_00.csv',
 './generate_csv/valid_01.csv',
 './generate_csv/valid_02.csv',
 './generate_csv/valid_03.csv',
 './generate_csv/valid_04.csv',
 './generate_csv/valid_05.csv',
 './generate_csv/valid_06.csv',
 './generate_csv/valid_07.csv',
 './generate_csv/valid_08.csv',
 './generate_csv/valid_09.csv']
['./generate_csv/test_00.csv',
 './gener

In [3]:
os.listdir(source_dir)

['test_00.csv',
 'test_01.csv',
 'test_02.csv',
 'test_03.csv',
 'test_04.csv',
 'test_05.csv',
 'test_06.csv',
 'test_07.csv',
 'test_08.csv',
 'test_09.csv',
 'train_00.csv',
 'train_01.csv',
 'train_02.csv',
 'train_03.csv',
 'train_04.csv',
 'train_05.csv',
 'train_06.csv',
 'train_07.csv',
 'train_08.csv',
 'train_09.csv',
 'train_10.csv',
 'train_11.csv',
 'train_12.csv',
 'train_13.csv',
 'train_14.csv',
 'train_15.csv',
 'train_16.csv',
 'train_17.csv',
 'train_18.csv',
 'train_19.csv',
 'valid_00.csv',
 'valid_01.csv',
 'valid_02.csv',
 'valid_03.csv',
 'valid_04.csv',
 'valid_05.csv',
 'valid_06.csv',
 'valid_07.csv',
 'valid_08.csv',
 'valid_09.csv']

In [4]:
def parse_csv_line(line, n_fields=9):
    defs = [tf.constant(np.nan)] * n_fields
    parsed_fields = tf.io.decode_csv(line, record_defaults=defs)
    x = tf.stack(parsed_fields[0:-1])
    y = tf.stack(parsed_fields[-1:])
    return x, y


def csv_reader_dataset(filenames, n_readers=5,
                       batch_size=32, n_parse_threads=5,
                       shuffile_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()  # 无限次重复服务于epoch
    dataset = dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length=n_readers)
    dataset.shuffle(shuffile_buffer_size),
    dataset = dataset.map(parse_csv_line, num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset


batch_size = 32
train_set = csv_reader_dataset(train_filenames, batch_size=batch_size)
valid_set = csv_reader_dataset(valid_filenames, batch_size=batch_size)
test_set = csv_reader_dataset(test_filenames, batch_size=batch_size)

In [5]:
# 把train_set,valid_set,test_set存储到tfrecord类型的文件中
# 封装
def serialize_example(x, y):
    """Converts x, y to tf.train.Example and serialize"""
    input_features = tf.train.FloatList(value=x)  #特征
    label = tf.train.FloatList(value=y)
    features = tf.train.Features(
        feature={
            'input_features': tf.train.Feature(float_list=input_features),
            'label': tf.train.Feature(float_list=label)})
    example = tf.train.Example(features=features)
    return example.SerializeToString()


# n_shards是存为多少个文件，steps_per_shard和 steps_per_epoch类似
def csv_dataset_to_tfrecords(base_filename, dataset,
                             n_shards, steps_per_shard,
                             compression_type=None):
    # 压缩文件类型
    options = tf.io.TFRecordOptions(compression_type=compression_type)
    all_filenames = []
    for shard_id in range(n_shards):
        filename_fullpath = '{}_{:05d}-of-{:05d}'.format(
            base_filename, shard_id, n_shards)
        with tf.io.TFRecordWriter(filename_fullpath, options) as writer:
            for x_batch, y_batch in dataset.skip(shard_id * steps_per_shard).take(steps_per_shard):
                for x_example, y_example in zip(x_batch, y_batch):
                    writer.write(serialize_example(x_example, y_example))
        all_filenames.append(filename_fullpath)
    return all_filenames

In [6]:
%%time
n_shards = 20
train_steps_per_shard = 11610 // batch_size // n_shards
valid_steps_per_shard = 3880 // batch_size // 10
test_steps_per_shard = 5170 // batch_size // 10
output_dir = "generate_tfrecords"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

train_basename = os.path.join(output_dir, "train")
valid_basename = os.path.join(output_dir, "valid")
test_basename = os.path.join(output_dir, "test")

train_tfrecord_filenames = csv_dataset_to_tfrecords(
    train_basename, train_set, n_shards, train_steps_per_shard, None)
valid_tfrecord_filenames = csv_dataset_to_tfrecords(
    valid_basename, valid_set, 10, valid_steps_per_shard, None)
test_tfrecord_fielnames = csv_dataset_to_tfrecords(
    test_basename, test_set, 10, test_steps_per_shard, None)

CPU times: total: 49.5 s
Wall time: 38.7 s


In [7]:
# 生成压缩文件
# n_shards = 20
# train_steps_per_shard = 11610 // batch_size // n_shards
# valid_steps_per_shard = 3880 // batch_size // n_shards
# test_steps_per_shard = 5170 // batch_size // n_shards

# output_dir = "generate_tfrecords_zip"
# if not os.path.exists(output_dir):
#     os.mkdir(output_dir)

# train_basename = os.path.join(output_dir, "train")
# valid_basename = os.path.join(output_dir, "valid")
# test_basename = os.path.join(output_dir, "test")
# #只需修改参数的类型即可
# train_tfrecord_filenames = csv_dataset_to_tfrecords(
#     train_basename, train_set, n_shards, train_steps_per_shard,
#     compression_type = "GZIP")
# valid_tfrecord_filenames = csv_dataset_to_tfrecords(
#     valid_basename, valid_set, n_shards, valid_steps_per_shard,
#     compression_type = "GZIP")
# test_tfrecord_fielnames = csv_dataset_to_tfrecords(
#     test_basename, test_set, n_shards, test_steps_per_shard,
#     compression_type = "GZIP")

In [8]:
pprint.pprint(train_tfrecord_filenames)
pprint.pprint(valid_tfrecord_filenames)
pprint.pprint(test_tfrecord_fielnames)

['generate_tfrecords\\train_00000-of-00020',
 'generate_tfrecords\\train_00001-of-00020',
 'generate_tfrecords\\train_00002-of-00020',
 'generate_tfrecords\\train_00003-of-00020',
 'generate_tfrecords\\train_00004-of-00020',
 'generate_tfrecords\\train_00005-of-00020',
 'generate_tfrecords\\train_00006-of-00020',
 'generate_tfrecords\\train_00007-of-00020',
 'generate_tfrecords\\train_00008-of-00020',
 'generate_tfrecords\\train_00009-of-00020',
 'generate_tfrecords\\train_00010-of-00020',
 'generate_tfrecords\\train_00011-of-00020',
 'generate_tfrecords\\train_00012-of-00020',
 'generate_tfrecords\\train_00013-of-00020',
 'generate_tfrecords\\train_00014-of-00020',
 'generate_tfrecords\\train_00015-of-00020',
 'generate_tfrecords\\train_00016-of-00020',
 'generate_tfrecords\\train_00017-of-00020',
 'generate_tfrecords\\train_00018-of-00020',
 'generate_tfrecords\\train_00019-of-00020']
['generate_tfrecords\\valid_00000-of-00010',
 'generate_tfrecords\\valid_00001-of-00010',
 'generate

In [9]:
%%time
expected_features = {
    "input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
    "label": tf.io.FixedLenFeature([1], dtype=tf.float32)
}


def parse_example(serialized_example):
    example = tf.io.parse_single_example(serialized_example,
                                         expected_features)
    return example["input_features"], example["label"]


def tfrecords_reader_dataset(filenames, n_readers=5,
                             batch_size=32, n_parse_threads=5,
                             shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()  #为了能够无限次epoch
    dataset = dataset.interleave(
        # lambda filename: tf.data.TFRecordDataset(filename, compression_type = "GZIP"),
        lambda filename: tf.data.TFRecordDataset(filename),
        cycle_length=n_readers)
    # 洗牌打乱样本数据顺序
    dataset.shuffle(shuffle_buffer_size)
    # 字节流样本转变为浮点类型
    dataset = dataset.map(parse_example,
                          num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

CPU times: total: 0 ns
Wall time: 0 ns


In [10]:
# 测试tfrecords_reader_dataset
tfrecords_train = tfrecords_reader_dataset(train_tfrecord_filenames,
                                           batch_size=3)
for x_batch, y_batch in tfrecords_train.take(10):
    print(x_batch)
    print(y_batch)

Cause: could not parse the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x0000023C575F53A0>: no matching AST found among candidates:

Cause: could not parse the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x0000023C575F53A0>: no matching AST found among candidates:

Cause: Unable to locate the source code of <function parse_example at 0x0000023C7E1B4310>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.experimental.do_not_convert. Original error: could not get source code
Cause: Unable to locate the source code of <function parse_example at 0x0000023C7E1B4310>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should defi

In [11]:
%%time
#dataset的tensor，可以直接拿tensor训练
batch_size = 32
tfrecords_train_set = tfrecords_reader_dataset(
    train_tfrecord_filenames, batch_size=batch_size)
tfrecords_valid_set = tfrecords_reader_dataset(
    valid_tfrecord_filenames, batch_size=batch_size)
tfrecords_test_set = tfrecords_reader_dataset(
    test_tfrecord_fielnames, batch_size=batch_size)
print(type(tfrecords_train_set))
for i in tfrecords_train_set.take(1):
    print(i)

Cause: could not parse the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x0000023C7DFC60D0>: no matching AST found among candidates:

Cause: could not parse the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x0000023C7DFC60D0>: no matching AST found among candidates:

Cause: could not parse the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x0000023C575F5430>: no matching AST found among candidates:

Cause: could not parse the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x0000023C575F5430>: no matching AST found among candidates:

Cause: could not parse the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x0000023C7E11C670>: no matching AST found among candidates:

Cause: could not parse the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x0000023C7E11C670>: no matching AST found among candidates:

<class 'tensorflow.python.data.ops.datas

In [12]:
model = keras.models.Sequential([
    keras.layers.Dense(30, activation='relu',
                       input_shape=[8]),
    keras.layers.Dense(1),
])
model.compile(loss="mean_squared_error", optimizer="sgd")
callbacks = [keras.callbacks.EarlyStopping(
    patience=5, min_delta=1e-2)]
history = model.fit(tfrecords_train_set,
                    validation_data=tfrecords_valid_set,
                    steps_per_epoch=11160 // batch_size,
                    validation_steps=3870 // batch_size,
                    epochs=100,
                    callbacks=callbacks)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100


In [13]:
print(model.evaluate(tfrecords_test_set, steps=5160 // batch_size))

0.3699362576007843
