<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#定义输入函数" data-toc-modified-id="定义输入函数-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>定义输入函数</a></span></li><li><span><a href="#定义模型" data-toc-modified-id="定义模型-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>定义模型</a></span></li><li><span><a href="#训练模型" data-toc-modified-id="训练模型-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>训练模型</a></span></li></ul></div>

In [1]:
import tensorflow as tf

In [2]:
# 重置计算图
tf.reset_default_graph()

# 设置记录消息的阈值
tf.logging.set_verbosity(tf.logging.INFO)

# 定义输入函数

In [3]:
# 定义特征名称
feature_names = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']

# 定义输入函数
def my_input_fn(file_path, perform_shuffle=False, repeat_count=1):
    # 定义解码函数
    def decode_csv(line):
        # 把line(文本中每一行数据)分解成四个float, 一个int类型数据
        parsed_line = tf.decode_csv(line, [[0.], [0.], [0.], [0.], [0]])
        
        # 获取label
        label = parsed_line[-1] # 最后一个元素为label
        del parsed_line[-1] # 删除label
        
        # 获取样本特征
        features = parsed_line # 除了最后一个元素, 其他的为样本的特征数据
        
        # 返回特征数据和label
        d = dict(zip(feature_names, features)), label
        return d
        
    # 读入文件路径并解析数据
    dataset = (tf.data.TextLineDataset(file_path) # 读取text file
               .skip(1) # 跳过文件的第一行
               .map(decode_csv)) # 对文本中每条数据应用函数decode_csv
               
    # 是否打乱数据顺序
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)
    
    # 设置数据重复次数
    dataset = dataset.repeat(repeat_count)
    
    # 设置batch大小
    dataset = dataset.batch(32)
    
    # 生成数据迭代器
    iterator = dataset.make_one_shot_iterator()
    
    # 获取一个batch的数据
    batch_features, batch_labels = iterator.get_next()
    
    # 返回一个batch的数据
    return batch_features, batch_labels

# 创建特征列: 所有的输入都是numeric
features_columns = [tf.feature_column.numeric_column(k) for k in feature_names]

# 定义模型

In [4]:
# 定义checkpoint存储的位置
PATH = "../../../TensorFlow/checkpoint/tf_dataset_and_estimator_api"

classifier = tf.estimator.DNNClassifier(feature_columns=features_columns, 
                                        hidden_units=[10, 10], 
                                        n_classes=3, 
                                        model_dir=PATH)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_keep_checkpoint_every_n_hours': 10000, '_tf_random_seed': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x000001FE7E571748>, '_model_dir': '../../../TensorFlow/checkpoint/tf_dataset_and_estimator_api', '_is_chief': True, '_num_ps_replicas': 0, '_master': '', '_task_id': 0, '_keep_checkpoint_max': 5, '_task_type': 'worker', '_save_checkpoints_secs': 600, '_service': None, '_num_worker_replicas': 1, '_log_step_count_steps': 100, '_save_checkpoints_steps': None, '_session_config': None, '_save_summary_steps': 100}


# 训练模型

In [5]:
FILE_TRAIN = "../../../TensorFlow/datasets/iris_training.csv"
classifier.train(input_fn=lambda: my_input_fn(FILE_TRAIN, True, 8))

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from ../../../TensorFlow/checkpoint/tf_dataset_and_estimator_api\model.ckpt-30


InternalError: Blas GEMM launch failed : a.shape=(32, 4), b.shape=(4, 10), m=32, n=10, k=4
	 [[Node: dnn/hiddenlayer_0/MatMul = MatMul[T=DT_FLOAT, transpose_a=false, transpose_b=false, _device="/job:localhost/replica:0/task:0/device:GPU:0"](dnn/input_from_feature_columns/input_layer/concat, dnn/hiddenlayer_0/kernel/part_0/read)]]
	 [[Node: dnn/head/labels/_117 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_121_dnn/head/labels", tensor_type=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Caused by op 'dnn/hiddenlayer_0/MatMul', defined at:
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\traitlets\config\application.py", line 658, in launch_instance
    app.start()
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelapp.py", line 477, in start
    ioloop.IOLoop.instance().start()
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tornado\ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\zmq\eventloop\zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tornado\stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 235, in dispatch_shell
    handler(stream, idents, msg)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\ipkernel.py", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\ipykernel\zmqshell.py", line 533, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2662, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2785, in _run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2909, in run_ast_nodes
    if self.run_code(code, result):
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\IPython\core\interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-5-a371262099fa>", line 2, in <module>
    classifier.train(input_fn=lambda: my_input_fn(FILE_TRAIN, True, 8))
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 314, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 743, in _train_model
    features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\estimator.py", line 725, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\canned\dnn.py", line 324, in _model_fn
    config=config)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\canned\dnn.py", line 176, in _dnn_model_fn
    logits = logit_fn(features=features, mode=mode)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\estimator\canned\dnn.py", line 100, in dnn_logit_fn
    name=hidden_layer_scope)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\layers\core.py", line 253, in dense
    return layer.apply(inputs)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\layers\base.py", line 762, in apply
    return self.__call__(inputs, *args, **kwargs)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\layers\base.py", line 652, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\layers\core.py", line 162, in call
    outputs = standard_ops.matmul(inputs, self.kernel)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\math_ops.py", line 2022, in matmul
    a, b, transpose_a=transpose_a, transpose_b=transpose_b, name=name)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\gen_math_ops.py", line 2799, in _mat_mul
    name=name)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 3160, in create_op
    op_def=op_def)
  File "D:\Program Files (x86)\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 1625, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InternalError (see above for traceback): Blas GEMM launch failed : a.shape=(32, 4), b.shape=(4, 10), m=32, n=10, k=4
	 [[Node: dnn/hiddenlayer_0/MatMul = MatMul[T=DT_FLOAT, transpose_a=false, transpose_b=false, _device="/job:localhost/replica:0/task:0/device:GPU:0"](dnn/input_from_feature_columns/input_layer/concat, dnn/hiddenlayer_0/kernel/part_0/read)]]
	 [[Node: dnn/head/labels/_117 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_121_dnn/head/labels", tensor_type=DT_INT32, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
