# Estimator

References:

- http://tensorflow.classcat.com/2017/09/18/tensorflow-get-started-estimator/
- https://www.tensorflow.org/get_started/estimator
- https://book.mynavi.jp/manatee/detail/id=79420


In [2]:
import os
import urllib.request

import numpy as np
import tensorflow as tf

In [3]:
# Remote locations of the Iris CSV files.
IRIS_TRAINING_URL = 'http://download.tensorflow.org/data/iris_training.csv'
IRIS_TEST_URL = 'http://download.tensorflow.org/data/iris_test.csv'

# Local file names the CSVs are stored under (relative to the working
# directory); main() downloads to these paths when they do not exist yet.
IRIS_TRAINING = 'iris_training.csv'
IRIS_TEST = 'iris_test.csv'

In [8]:
def _download_if_missing(url, filename):
    """Fetch *url* into *filename* unless that file already exists locally."""
    if not os.path.exists(filename):
        urllib.request.urlretrieve(url, filename)


def _load_iris_csv(filename):
    """Load a headered Iris CSV with float32 features and integer targets."""
    return tf.contrib.learn.datasets.base.load_csv_with_header(
        filename=filename,
        # The builtin ``int`` replaces ``np.int``: that alias was deprecated in
        # NumPy 1.20 and removed in 1.24, and it always referred to ``int``.
        target_dtype=int,
        features_dtype=np.float32)


def main():
    """Train, evaluate and query a DNN classifier on the Iris dataset.

    Steps:
      1. Download the training and test CSVs if they are not present locally.
      2. Build a ``DNNClassifier`` with hidden layers of 10, 20 and 10 units
         and 3 output classes, using ``/tmp/iris_model`` as its model
         directory.
      3. Train for 2000 steps on the shuffled training set.
      4. Evaluate on the test set and print the accuracy.
      5. Predict the class of two hard-coded samples and print the result.
    """
    _download_if_missing(IRIS_TRAINING_URL, IRIS_TRAINING)
    _download_if_missing(IRIS_TEST_URL, IRIS_TEST)

    training_set = _load_iris_csv(IRIS_TRAINING)
    test_set = _load_iris_csv(IRIS_TEST)

    # All four measurements are fed as one 4-wide numeric feature named 'x';
    # the same key must be used by every input_fn below.
    feature_columns = [tf.feature_column.numeric_column('x', shape=[4])]

    classifier = tf.estimator.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=[10, 20, 10],
        n_classes=3,
        model_dir='/tmp/iris_model')

    # num_epochs=None places no epoch limit on the input; the amount of
    # training is bounded by ``steps`` in the train() call instead.
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={'x': np.array(training_set.data)},
        y=np.array(training_set.target),
        num_epochs=None,
        shuffle=True)

    classifier.train(input_fn=train_input_fn, steps=2000)

    # Evaluation makes exactly one unshuffled pass over the test set.
    test_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={'x': np.array(test_set.data)},
        y=np.array(test_set.target),
        num_epochs=1,
        shuffle=False)

    accuracy_score = classifier.evaluate(input_fn=test_input_fn)['accuracy']

    print('\nTest Accuracy: {0:f}\n'.format(accuracy_score))

    # Two unlabeled samples to classify.
    new_samples = np.array(
        [[6.4, 3.2, 4.5, 1.5],
         [5.8, 3.1, 5.0, 1.7]], dtype=np.float32)
    predict_input_fn = tf.estimator.inputs.numpy_input_fn(
        x={'x': new_samples},
        num_epochs=1,
        shuffle=False)
    predictions = list(classifier.predict(input_fn=predict_input_fn))
    predicted_classes = [p['classes'] for p in predictions]

    print(
        'New Sample, Class Predictions: {}\n'.format(predicted_classes))
    
    
# Run the Iris workflow only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()


INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_keep_checkpoint_max': 5, '_model_dir': '/tmp/iris_model', '_log_step_count_steps': 100, '_session_config': None, '_tf_random_seed': 1, '_save_summary_steps': 100, '_keep_checkpoint_every_n_hours': 10000}
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Saving checkpoints for 1 into /tmp/iris_model/model.ckpt.
INFO:tensorflow:step = 1, loss = 292.973
INFO:tensorflow:global_step/sec: 1013.67
INFO:tensorflow:step = 101, loss = 14.446 (0.100 sec)
INFO:tensorflow:global_step/sec: 1021.22
INFO:tensorflow:step = 201, loss = 7.75755 (0.098 sec)
INFO:tensorflow:global_step/sec: 997.039
INFO:tensorflow:step = 301, loss = 12.0455 (0.101 sec)
INFO:tensorflow:global_step/sec: 989.255
INFO:tensorflow:step = 401, loss = 6.32394 (0.101 sec)
INFO:tensorflow:global_step/sec: 945.513
INFO:tensorflow:step = 501, loss = 8.52233 (0.106 sec)
INFO:tensorflow:global_