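TensorFlow's high-level machine learning API (tf.contrib.learn) makes it easy to configure, train, and evaluate a wide variety of learning models. In this tutorial we build a neural network classifier, train it on the Iris training data, and use it to predict the species of iris flowers. The steps are:

·Load the training/test data from CSV files into TensorFlow

·Build a neural network classifier

·Train the model

·Evaluate the model's accuracy

·Classify new samples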

### Download and prepare the training data

![flowers](https://www.tensorflow.org/images/iris_three_species.jpg)

Each record contains four features and one label. Sample data:

![test](http://ok33lph8y.bkt.clouddn.com/snipaste20180323_065237.png)

### Import dependencies

```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# tf.contrib.learn exists only in TensorFlow 1.x; tf.contrib was removed in TensorFlow 2.0.
import tensorflow as tf
import numpy as np
```


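Next, load the data into Dataset variables with load_csv_with_header(). Here it is called with three arguments:

·filename: path to the CSV file

·target_dtype: data type of the target values. In this example the targets are class labels encoded as 0 to 2, so they are integers.

·features_dtype: data type of the feature values. In this example the features are the lengths and widths of the sepals and petals, so they are floating-point numbers.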
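To make the file layout concrete, here is a minimal pure-Python/NumPy sketch of what a header-style CSV loader does. It assumes the layout used by the Iris CSV files (first row: number of samples, number of features, then the class names; label in the last column). The function name `load_header_csv` and the sample rows are hypothetical and for illustration only; this is not TensorFlow's actual code.

```python
import csv
import io

import numpy as np


def load_header_csv(f, target_dtype, features_dtype):
    """Hypothetical loader for a CSV whose first row is
    'n_samples,n_features,class names...' and whose last column is the label."""
    reader = csv.reader(f)
    header = next(reader)
    n_samples, n_features = int(header[0]), int(header[1])
    rows = [row for row in reader if row]  # skip blank lines
    if len(rows) != n_samples or any(len(r) != n_features + 1 for r in rows):
        raise ValueError("row count or row width does not match the header")
    # All columns except the last are features; the last column is the label.
    data = np.array([[float(v) for v in row[:-1]] for row in rows], dtype=features_dtype)
    target = np.array([int(row[-1]) for row in rows], dtype=target_dtype)
    return data, target


# Made-up rows in the same layout as iris_training.csv (not real file contents).
sample = io.StringIO(
    "3,4,setosa,versicolor,virginica\n"
    "5.1,3.5,1.4,0.2,0\n"
    "6.4,3.2,4.5,1.5,1\n"
    "5.8,3.1,5.0,1.7,2\n"
)
data, target = load_header_csv(sample, target_dtype=int, features_dtype=np.float32)
print(data.shape, target.tolist())  # (3, 4) [0, 1, 2]
```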

```python
# Data sets (CSV files in the current working directory)
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"

# Load datasets: integer class labels, float32 features.
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TRAINING,
    target_dtype=int,  # np.int was only an alias of int and was removed in NumPy 1.24
    features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TEST,
    target_dtype=int,
    features_dtype=np.float32)
```

Datasets in tf.contrib.learn are named tuples: the feature data and the label data are accessed through the data and target attributes.
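A named tuple can be read either by field name or by position. The sketch below defines a stand-in `Dataset` type with `data` and `target` fields to show both access styles; the class is our own (built with `collections.namedtuple`), not imported from TensorFlow, and the two rows and their labels are hypothetical.

```python
import collections

import numpy as np

# Stand-in for the named tuple returned by the loader (fields: data, target).
Dataset = collections.namedtuple("Dataset", ["data", "target"])

# Hypothetical two-sample set; columns are sepal/petal length and width.
ds = Dataset(
    data=np.array([[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=np.float32),
    target=np.array([1, 2]),
)

print(ds.data.shape)   # (2, 4): attribute access
print(ds[1].tolist())  # [1, 2]: positional access, same object as ds.target
features, labels = ds  # tuple unpacking also works
print(ds._fields)      # ('data', 'target')
```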

### Set up the model
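tf.contrib.learn provides many predefined models, called Estimators, which can be used out of the box to train and evaluate on your data. Here we use DNNClassifier to instantiate a deep neural network classifier.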

```python
# Specify that all features have real-valued data.
# A single column named "" with dimension=4 covers all 4 features of the unnamed data array.
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

# Build a DNN with 3 hidden layers of 10, 20 and 10 units respectively.
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=3,
                                            model_dir="/tmp/iris_model")
```

```
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_save_checkpoints_steps': None, '_session_config': None, '_evaluation_master': '', '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x000001D7F3B6DB38>, '_tf_random_seed': None, '_num_ps_replicas': 0, '_task_id': 0, '_keep_checkpoint_every_n_hours': 10000, '_task_type': None, '_keep_checkpoint_max': 5, '_model_dir': '/tmp/iris_model', '_master': '', '_log_step_count_steps': 100, '_environment': 'local', '_num_worker_replicas': 0, '_tf_config': gpu_options {
  per_process_gpu_memory_fraction: 1
}
, '_save_checkpoints_secs': 600, '_save_summary_steps': 100, '_is_chief': True}
```


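The first line defines the model's feature columns, which specify the data type of the features in the data set. All features are real-valued, and the data set has four of them (sepal length, sepal width, petal length, petal width), so dimension is set to 4.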

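DNNClassifier parameters:

·feature_columns=feature_columns: the feature columns defined above

·hidden_units=[10, 20, 10]: three hidden layers with 10, 20 and 10 neurons respectively

·n_classes=3: three target classes, one for each iris species

·model_dir="/tmp/iris_model": the directory where TensorFlow saves checkpoint data during training. For details, see:
https://www.tensorflow.org/versions/r0.12/tutorials/monitors/index.html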
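To see what hidden_units=[10, 20, 10] and n_classes=3 mean in terms of layer shapes, here is a NumPy-only sketch of a forward pass through a fully connected 4 → 10 → 20 → 10 → 3 network. It assumes ReLU hidden activations (the activation_fn shown in the DNNClassifier params output) and a softmax over the three class scores. The random weights are placeholders, not a trained model, and this is not TensorFlow's implementation.

```python
import numpy as np


def init_layers(sizes, seed=0):
    """Create (weights, bias) pairs for each pair of consecutive layer sizes."""
    rng = np.random.default_rng(seed)
    return [(rng.normal(scale=0.1, size=(n_in, n_out)), np.zeros(n_out))
            for n_in, n_out in zip(sizes[:-1], sizes[1:])]


def forward(x, layers):
    """ReLU on the hidden layers, softmax on the output layer."""
    h = x
    for w, b in layers[:-1]:
        h = np.maximum(h @ w + b, 0.0)
    w, b = layers[-1]
    logits = h @ w + b
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)


sizes = [4, 10, 20, 10, 3]  # 4 input features, hidden_units=[10, 20, 10], n_classes=3
layers = init_layers(sizes)
n_params = sum(w.size + b.size for w, b in layers)
print(n_params)  # 513

x = np.array([[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]])
probs = forward(x, layers)
print(probs.shape)  # (2, 3)
```

The parameter count is (4·10 + 10) + (10·20 + 20) + (20·10 + 10) + (10·3 + 3) = 513, and the predicted class for each sample is the argmax of its row in probs.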

### Train the model
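Now that we have a DNNClassifier model, we can fit it to the training set with the fit() method, passing the training features (training_set.data), the target values (training_set.target), and the number of training steps (2000 here).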
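Each training step is one gradient update on one mini-batch, so how many passes over the data (epochs) a given number of steps amounts to depends on the batch size. The helper below is a back-of-the-envelope sketch; the batch size of 32 and the training-set size of 120 rows are assumptions for illustration (check the header of your iris_training.csv), not values taken from the fit() call.

```python
def steps_to_epochs(steps, batch_size, n_samples):
    """Approximate number of full passes over the data for a given step count."""
    return steps * batch_size / n_samples


# Assumed values: 2000 steps (as in the fit() call), batch size 32, 120 training rows.
print(round(steps_to_epochs(2000, 32, 120), 1))  # 533.3
```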

```python
# Fit model
classifier.fit(x=training_set.data, y=training_set.target, steps=2000)
```

```
Instructions for updating:
Estimator is decoupled from Scikit Learn interface by moving into
separate class SKCompat. Arguments x, y and batch_size are only
available in the SKCompat class, Estimator will only accept input_fn.
Example conversion:
  est = Estimator(...) -> est = SKCompat(Estimator(...))
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Saving checkpoints for 1 into /tmp/iris_model\model.ckpt.
INFO:tensorflow:loss = 2.0995634, step = 1
INFO:tensorflow:global_step/sec: 512.771
INFO:tensorflow:loss = 0.30067253, step = 101 (0.198 sec)
INFO:tensorflow:global_step/sec: 717.511
INFO:tensorflow:loss = 0.12677288, step = 201 (0.139 sec)
INFO:tensorflow:global_step/sec: 

DNNClassifier(params={'activation_fn': <function relu at 0x000001D7EFFC61E0>, 'gradient_clip_norm': None, 'embedding_lr_multipliers': None, 'optimizer': None, 'input_layer_min_slice_size': None, 'hidden_units': [10, 20, 10], 'dropout': None, 'feature_columns': (_RealValuedColumn(column_name='', dimension=4, default_value=None, dtype=tf.float32, normalizer=None),), 'head': <tensorflow.contrib.learn.python.learn.estimators.head._MultiClassHead object at 0x000001D7F3B6DCC0>})
```

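The model's state is kept in classifier (and checkpointed to model_dir), so it can be trained repeatedly. The 2000-step call above is equivalent to the following two 1000-step calls. Note that running this cell after the previous one does not start from scratch: as the log shows, training resumes from the step-2000 checkpoint, so the model ends up trained for 4000 steps in total.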

```python
classifier.fit(x=training_set.data, y=training_set.target, steps=1000)
classifier.fit(x=training_set.data, y=training_set.target, steps=1000)
```

```
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from /tmp/iris_model\model.ckpt-2000
INFO:tensorflow:Saving checkpoints for 2001 into /tmp/iris_model\model.ckpt.
INFO:tensorflow:loss = 0.04439896, step = 2001
INFO:tensorflow:global_step/sec: 824.245
INFO:tensorflow:loss = 0.0440429, step = 2101 (0.122 sec)
INFO:tensorflow:global_step/sec: 882.601
INFO:tensorflow:loss = 0.043414734, step = 2201 (0.114 sec)
INFO:tensorflow:global_step/sec: 852.428
INFO:tensorflow:loss = 0.0435036, step = 2301 (0.117 sec)
INFO:tensorflow:global_step/sec: 886.528
INFO:tensorflow:loss = 0.04209782, step = 2401 (0.113 sec)
INFO:tensorflow:global_step/sec: 834.571
INFO:tensorflow:loss = 0.04159502, step = 2501 (0.123 sec)
INFO:tensorflow:global_step/sec: 838.066
INFO:tensorflow:loss = 0.041020524, step = 2601 (0.117 sec)
INFO:tensorflow:global_step/sec: 776.07
INFO:tensorflow:loss = 0.04287686, step = 2701 (0.128 sec)
INFO:tensorflow:global_step/sec: 727.988
INFO:tensorflow:lo

DNNClassifier(params={'activation_fn': <function relu at 0x000001D7EFFC61E0>, 'gradient_clip_norm': None, 'embedding_lr_multipliers': None, 'optimizer': None, 'input_layer_min_slice_size': None, 'hidden_units': [10, 20, 10], 'dropout': None, 'feature_columns': (_RealValuedColumn(column_name='', dimension=4, default_value=None, dtype=tf.float32, normalizer=None),), 'head': <tensorflow.contrib.learn.python.learn.estimators.head._MultiClassHead object at 0x000001D7F3B6DCC0>})
```
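Two 1000-step calls can match one 2000-step call because the model's parameters and step counter persist between calls. The toy NumPy sketch below shows the idea with full-batch gradient descent on a least-squares problem: a hypothetical `ToyModel` keeps its weights between fit() calls, so 1000 + 1000 steps yields exactly the same weights as 2000 steps. This toy is deterministic by construction; with shuffled mini-batches, as in real training, the two schedules see the data in a different order and need not match bit for bit.

```python
import numpy as np


class ToyModel:
    """Hypothetical model whose state (weights, step counter) persists across fit() calls."""

    def __init__(self, n_features, lr=0.1):
        self.w = np.zeros(n_features)
        self.global_step = 0
        self.lr = lr

    def fit(self, x, y, steps):
        for _ in range(steps):
            grad = 2 * x.T @ (x @ self.w - y) / len(y)  # gradient of the mean squared error
            self.w -= self.lr * grad
            self.global_step += 1
        return self


rng = np.random.default_rng(0)
x = rng.normal(size=(50, 4))
y = x @ np.array([1.0, -2.0, 0.5, 3.0])  # noiseless linear targets

once = ToyModel(4).fit(x, y, steps=2000)
twice = ToyModel(4).fit(x, y, steps=1000).fit(x, y, steps=1000)

print(once.global_step, twice.global_step)  # 2000 2000
print(np.array_equal(once.w, twice.w))      # True
```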

### Evaluate model accuracy
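With the DNNClassifier trained, we can check its accuracy on the test set with the evaluate() method. Like fit(), evaluate() takes the data and target arrays; it returns a dict of metrics, from which we read the "accuracy" entry.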

```python
accuracy_score = classifier.evaluate(x=test_set.data, y=test_set.target)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))
```

```
Instructions for updating:
Estimator is decoupled from Scikit Learn interface by moving into
separate class SKCompat. Arguments x, y and batch_size are only
available in the SKCompat class, Estimator will only accept input_fn.
Example conversion:
  est = Estimator(...) -> est = SKCompat(Estimator(...))
INFO:tensorflow:Starting evaluation at 2018-03-22-23:05:54
INFO:tensorflow:Restoring parameters from /tmp/iris_model\model.ckpt-4000
INFO:tensorflow:Finished evaluation at 2018-03-22-23:05:55
INFO:tensorflow:Saving dict for global step 4000: accuracy = 0.96666664, global_step = 4000, loss = 0.066047914
Accuracy: 0.966667
```
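The "accuracy" metric is the fraction of test samples whose predicted class equals the true label. As a worked check, assuming the test file holds 30 samples (the size of the standard iris_test.csv), 29 correct predictions give 29/30 ≈ 0.966667, which matches the value printed above. The labels in this sketch are hypothetical.

```python
import numpy as np


def accuracy(predicted, actual):
    """Fraction of positions where the predicted class equals the true label."""
    predicted = np.asarray(predicted)
    actual = np.asarray(actual)
    return float(np.mean(predicted == actual))


# Hypothetical labels: 30 samples, exactly one misclassified.
actual = np.array([0, 1, 2] * 10)
predicted = actual.copy()
predicted[0] = 1  # one wrong prediction

print('Accuracy: {0:f}'.format(accuracy(predicted, actual)))  # Accuracy: 0.966667
```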


### Classify new samples

TensorFlow provides classifier.predict() for classifying new samples.

```python
# Classify two new flower samples.
# Each row: sepal length, sepal width, petal length, petal width.
new_samples = np.array([[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = list(classifier.predict(new_samples, as_iterable=True))
print('Predictions: {}'.format(str(y)))
```

```
Instructions for updating:
Please switch to predict_classes, or set `outputs` argument.
Instructions for updating:
Estimator is decoupled from Scikit Learn interface by moving into
separate class SKCompat. Arguments x, y and batch_size are only
available in the SKCompat class, Estimator will only accept input_fn.
Example conversion:
  est = Estimator(...) -> est = SKCompat(Estimator(...))
INFO:tensorflow:Restoring parameters from /tmp/iris_model\model.ckpt-4000
Predictions: [1, 2]
```
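predict() returns integer class ids. To turn them into species names, map each id to the class names in the order they appear in the CSV header. The order setosa, versicolor, virginica used below is an assumption based on the standard Iris CSV files, so check the header of your own data; `CLASS_NAMES` and `to_species` are hypothetical names.

```python
# Assumed class order, matching the header of the standard Iris CSV files.
CLASS_NAMES = ["setosa", "versicolor", "virginica"]


def to_species(class_ids, names=CLASS_NAMES):
    """Map integer class ids (0..2) to species names."""
    return [names[i] for i in class_ids]


predictions = [1, 2]  # the list printed above
print(to_species(predictions))  # ['versicolor', 'virginica']
```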
