<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#定义输入函数" data-toc-modified-id="定义输入函数-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>定义输入函数</a></span></li><li><span><a href="#定义模型" data-toc-modified-id="定义模型-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>定义模型</a></span></li><li><span><a href="#训练模型" data-toc-modified-id="训练模型-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>训练模型</a></span></li></ul></div>

In [2]:
import tensorflow as tf

In [3]:
# Reset the default computation graph
tf.reset_default_graph()

# Set the logging verbosity threshold to INFO
tf.logging.set_verbosity(tf.logging.INFO)

# 定义输入函数

In [4]:
# Feature names; my_input_fn pairs them, in this order, with the leading
# (non-label) fields of each parsed CSV line.
feature_names = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']

# 定义输入函数
def my_input_fn(file_path, perform_shuffle=False, repeat_count=1,
                batch_size=32, shuffle_buffer_size=256):
    """Read a CSV file and return one batch of (features, labels) tensors.

    The first line of the file is skipped. Every remaining line is parsed as
    four float fields followed by one integer field; the last field is used
    as the label and the others are keyed by the module-level
    ``feature_names``.

    Args:
        file_path: Path passed to ``tf.data.TextLineDataset``.
        perform_shuffle: If true, shuffle the parsed records using a buffer
            of ``shuffle_buffer_size`` elements.
        repeat_count: Value passed to ``Dataset.repeat``.
        batch_size: Number of records per batch. Defaults to 32, the value
            that was previously hard-coded.
        shuffle_buffer_size: Size of the shuffle buffer. Defaults to 256,
            the value that was previously hard-coded.

    Returns:
        A ``(batch_features, batch_labels)`` tuple obtained from a one-shot
        iterator; ``batch_features`` is a dict keyed by ``feature_names``.
    """
    def decode_csv(line):
        """Parse one text line into a (feature dict, label) pair."""
        # The record defaults fix the column types: four floats, then one int.
        fields = tf.decode_csv(line, [[0.], [0.], [0.], [0.], [0]])

        # The last field is the label; everything before it is a feature.
        features, label = fields[:-1], fields[-1]
        return dict(zip(feature_names, features)), label

    # Read the text file, skip its first line, and parse every remaining line.
    dataset = (tf.data.TextLineDataset(file_path)
               .skip(1)
               .map(decode_csv))

    # Optionally randomize the record order.
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)

    # Repeat the data, then group it into batches.
    dataset = dataset.repeat(repeat_count)
    dataset = dataset.batch(batch_size)

    # Fetch one batch through a one-shot iterator.
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()

    return batch_features, batch_labels

# Build the feature columns: every input is numeric
features_columns = [
    tf.feature_column.numeric_column(key=name)
    for name in feature_names
]

# 定义模型

In [5]:
# Directory where checkpoints are stored (passed to the estimator as model_dir).
# NOTE(review): the training log recorded in this notebook shows checkpoints
# being saved under "../../../TensorFlow/checkpoint/tf_dataset_and_estimator_api",
# which differs from this value — confirm which directory is intended and re-run.
PATH = "../../datasets/tf_dataset_and_estimator_api"

# DNN classifier over the numeric feature columns: hidden_units=[10, 10], 3 classes.
classifier = tf.estimator.DNNClassifier(feature_columns=features_columns, 
                                        hidden_units=[10, 10], 
                                        n_classes=3, 
                                        model_dir=PATH)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_save_checkpoints_steps': None, '_task_type': 'worker', '_master': '', '_tf_random_seed': None, '_num_ps_replicas': 0, '_is_chief': True, '_task_id': 0, '_keep_checkpoint_max': 5, '_global_id_in_cluster': 0, '_evaluation_master': '', '_keep_checkpoint_every_n_hours': 10000, '_save_checkpoints_secs': 600, '_session_config': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x0000027F6CF40F98>, '_save_summary_steps': 100, '_service': None, '_num_worker_replicas': 1, '_log_step_count_steps': 100, '_train_distribute': None, '_model_dir': '../../datasets/tf_dataset_and_estimator_api'}


# 训练模型

In [5]:
# Path to the training CSV file
FILE_TRAIN = "../../../TensorFlow/datasets/iris_training.csv"

# Train with shuffling enabled and the data repeated 8 times.
classifier.train(
    input_fn=lambda: my_input_fn(FILE_TRAIN, perform_shuffle=True, repeat_count=8)
)

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Saving checkpoints for 1 into ../../../TensorFlow/checkpoint/tf_dataset_and_estimator_api\model.ckpt.
INFO:tensorflow:loss = 58.99401, step = 1
INFO:tensorflow:Saving checkpoints for 30 into ../../../TensorFlow/checkpoint/tf_dataset_and_estimator_api\model.ckpt.
INFO:tensorflow:Loss for final step: 17.460835.


<tensorflow.python.estimator.canned.dnn.DNNClassifier at 0x23152f2f278>