# Wide & Deep Tutorial
**tfversion**: v1.3.0-rc2-20-g0787eee
**commit**: 41881b93b1a2b766b69602eb79d3a0514043b7e3 

In [1]:
import argparse
import shutil
import sys
import tempfile

import pandas as pd
from six.moves import urllib
import tensorflow as tf

In [2]:
# Names assigned, positionally, to the columns of the census CSV files.
CSV_COLUMNS = [
    "age",
    "workclass",
    "fnlwgt",
    "education",
    "education_num",
    "marital_status",
    "occupation",
    "relationship",
    "race",
    "gender",
    "capital_gain",
    "capital_loss",
    "hours_per_week",
    "native_country",
    "income_bracket",
]

In [3]:
# Categorical base columns backed by an explicit vocabulary.
# Keep each vocabulary in its current order: the position of a value in the
# list is what the column uses as that value's id.
gender = tf.feature_column.categorical_column_with_vocabulary_list(
    key="gender",
    vocabulary_list=["Female", "Male"])
education = tf.feature_column.categorical_column_with_vocabulary_list(
    key="education",
    vocabulary_list=[
        "Bachelors", "HS-grad", "11th", "Masters", "9th",
        "Some-college", "Assoc-acdm", "Assoc-voc", "7th-8th",
        "Doctorate", "Prof-school", "5th-6th", "10th", "1st-4th",
        "Preschool", "12th",
    ])
marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
    key="marital_status",
    vocabulary_list=[
        "Married-civ-spouse", "Divorced", "Married-spouse-absent",
        "Never-married", "Separated", "Married-AF-spouse", "Widowed",
    ])
relationship = tf.feature_column.categorical_column_with_vocabulary_list(
    key="relationship",
    vocabulary_list=[
        "Husband", "Not-in-family", "Wife", "Own-child", "Unmarried",
        "Other-relative",
    ])
workclass = tf.feature_column.categorical_column_with_vocabulary_list(
    key="workclass",
    vocabulary_list=[
        "Self-emp-not-inc", "Private", "State-gov", "Federal-gov",
        "Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked",
    ])

# Categorical columns without a listed vocabulary: values are hashed into a
# fixed number of buckets instead.
occupation = tf.feature_column.categorical_column_with_hash_bucket(
    key="occupation", hash_bucket_size=1000)
native_country = tf.feature_column.categorical_column_with_hash_bucket(
    key="native_country", hash_bucket_size=1000)

# Continuous (numeric) base columns.
age = tf.feature_column.numeric_column(key="age")
education_num = tf.feature_column.numeric_column(key="education_num")
capital_gain = tf.feature_column.numeric_column(key="capital_gain")
capital_loss = tf.feature_column.numeric_column(key="capital_loss")
hours_per_week = tf.feature_column.numeric_column(key="hours_per_week")

# Derived column: age split into ranges at the given boundaries.
age_buckets = tf.feature_column.bucketized_column(
    source_column=age,
    boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])

In [4]:
# Sparse base columns, used as-is by the linear ("wide") model.
base_columns = [
    gender,
    education,
    marital_status,
    relationship,
    workclass,
    occupation,
    native_country,
    age_buckets,
]

# Feature crosses; each combination is hashed into 1000 buckets.
crossed_columns = [
    tf.feature_column.crossed_column(
        keys=["education", "occupation"],
        hash_bucket_size=1000),
    tf.feature_column.crossed_column(
        keys=[age_buckets, "education", "occupation"],
        hash_bucket_size=1000),
    tf.feature_column.crossed_column(
        keys=["native_country", "occupation"],
        hash_bucket_size=1000),
]

# Dense inputs for the DNN ("deep") model: categorical columns are wrapped as
# indicator or embedding columns, numeric columns are passed through directly.
deep_columns = [
    tf.feature_column.indicator_column(categorical_column=workclass),
    tf.feature_column.indicator_column(categorical_column=education),
    tf.feature_column.indicator_column(categorical_column=gender),
    tf.feature_column.indicator_column(categorical_column=relationship),
    # The hashed columns are embedded rather than one-hot encoded.
    tf.feature_column.embedding_column(
        categorical_column=native_country, dimension=8),
    tf.feature_column.embedding_column(
        categorical_column=occupation, dimension=8),
    age,
    education_num,
    capital_gain,
    capital_loss,
    hours_per_week,
]

In [5]:
def _download_to_temp_file(url):
    """Downloads `url` into a new temporary file and returns that file's path.

    The temporary file is created with `delete=False`, so it persists after
    this function returns; the caller owns it.
    """
    # Only the unique path is needed. Close our own handle before downloading:
    # urlretrieve re-opens the path by name, which fails on Windows while the
    # NamedTemporaryFile handle is still open.
    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
        temp_path = temp_file.name
    urllib.request.urlretrieve(url, temp_path)
    return temp_path


def maybe_download(train_data, test_data):
    """Maybe downloads the data sets and returns train and test file names.

    Args:
        train_data: Path of the training file. If falsy (e.g. "" or None), the
            training set is downloaded to a temporary file instead.
        test_data: Path of the test file. If falsy, the test set is downloaded
            to a temporary file instead.

    Returns:
        A `(train_file_name, test_file_name)` tuple of file paths.
    """
    if train_data:
        train_file_name = train_data
    else:
        train_file_name = _download_to_temp_file(
            "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data")  # pylint: disable=line-too-long
        print("Training data is downloaded to %s" % train_file_name)

    if test_data:
        test_file_name = test_data
    else:
        test_file_name = _download_to_temp_file(
            "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test")  # pylint: disable=line-too-long
        print("Test data is downloaded to %s" % test_file_name)

    return train_file_name, test_file_name

In [6]:
def build_estimator(model_dir, model_type):
    """Creates the estimator selected by `model_type`.

    Args:
        model_dir: Directory passed to the estimator as its `model_dir`.
        model_type: "wide" builds a LinearClassifier, "deep" builds a
            DNNClassifier; any other value builds a
            DNNLinearCombinedClassifier.

    Returns:
        The constructed estimator.
    """
    dnn_layers = [100, 50]

    # Linear-only model over the base and crossed columns.
    if model_type == "wide":
        return tf.estimator.LinearClassifier(
            model_dir=model_dir,
            feature_columns=base_columns + crossed_columns)

    # DNN-only model over the dense columns.
    if model_type == "deep":
        return tf.estimator.DNNClassifier(
            model_dir=model_dir,
            feature_columns=deep_columns,
            hidden_units=dnn_layers)

    # Everything else falls through to the combined wide & deep model.
    return tf.estimator.DNNLinearCombinedClassifier(
        model_dir=model_dir,
        linear_feature_columns=crossed_columns,
        dnn_feature_columns=deep_columns,
        dnn_hidden_units=dnn_layers)

In [7]:
def input_fn(data_file, num_epochs, shuffle):
    """Builds an estimator input function from a census CSV file.

    Args:
        data_file: Path of the CSV file to read (opened through `tf.gfile`).
        num_epochs: Forwarded to `pandas_input_fn`.
        shuffle: Forwarded to `pandas_input_fn`.

    Returns:
        The input function created by `tf.estimator.inputs.pandas_input_fn`,
        yielding batches of 100 rows.
    """
    # Use a context manager so the file handle is closed as soon as pandas has
    # finished reading, instead of being left open.
    with tf.gfile.Open(data_file) as csv_file:
        # skiprows=1 drops the first line of the file.
        # NOTE(review): presumably this is for a non-data header line in the
        # test file; it also drops the first record of a file that has no
        # header -- confirm against the data files.
        df_data = pd.read_csv(
            csv_file,
            names=CSV_COLUMNS,
            skipinitialspace=True,
            engine="python",
            skiprows=1)
    # remove NaN elements
    df_data = df_data.dropna(how="any", axis=0)
    # Label is 1 when the income bracket string contains ">50K", else 0.
    labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
    return tf.estimator.inputs.pandas_input_fn(
        x=df_data,
        y=labels,
        batch_size=100,
        num_epochs=num_epochs,
        shuffle=shuffle,
        num_threads=5)

In [8]:
def train_and_eval(model_dir, model_type, train_steps, train_data, test_data):
    """Trains an estimator, evaluates it, prints the metrics and returns it.

    Args:
        model_dir: Directory for the estimator; a fresh temporary directory is
            created when this is falsy.
        model_type: Forwarded to `build_estimator`.
        train_steps: Number of steps passed to `train`.
        train_data: Training file path, forwarded to `maybe_download`.
        test_data: Test file path, forwarded to `maybe_download`.

    Returns:
        The trained estimator.
    """
    train_file_name, test_file_name = maybe_download(train_data, test_data)

    # Pass an explicit model_dir to find the output easily; otherwise a
    # temporary directory is used.
    if not model_dir:
        model_dir = tempfile.mkdtemp()

    estimator = build_estimator(model_dir, model_type)

    # num_epochs=None gives an infinite stream of data; training stops after
    # `train_steps` steps.
    train_input_fn = input_fn(train_file_name, num_epochs=None, shuffle=True)
    estimator.train(input_fn=train_input_fn, steps=train_steps)

    # steps=None runs evaluation until all data is consumed (a single epoch).
    eval_input_fn = input_fn(test_file_name, num_epochs=1, shuffle=False)
    metrics = estimator.evaluate(input_fn=eval_input_fn, steps=None)

    print("model directory = %s" % model_dir)
    for metric_name in sorted(metrics):
        print("%s: %s" % (metric_name, metrics[metric_name]))
    return estimator

In [9]:
# Empty train_data/test_data make train_and_eval download both data files;
# 'wide_n_deep' selects the combined linear + DNN estimator.
m = train_and_eval(
    model_dir='./tmp',
    model_type='wide_n_deep',
    train_steps=200,
    train_data='',
    test_data='')

Training data is downloaded to /var/folders/qv/glzl2pyj2g15vz3tn5s52c9w0000gn/T/tmpcdmdfssx
Test data is downloaded to /var/folders/qv/glzl2pyj2g15vz3tn5s52c9w0000gn/T/tmpc7mwxlnb
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_keep_checkpoint_max': 5, '_session_config': None, '_save_summary_steps': 100, '_log_step_count_steps': 100, '_tf_random_seed': 1, '_keep_checkpoint_every_n_hours': 10000, '_save_checkpoints_steps': None, '_model_dir': './tmp', '_save_checkpoints_secs': 600}
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt-200
INFO:tensorflow:Saving checkpoints for 201 into ./tmp/model.ckpt.
INFO:tensorflow:loss = 42.5265, step = 201
INFO:tensorflow:global_step/sec: 118.392
INFO:tensorflow:loss = 53.4587, step = 301 (0.846 sec)
INFO:tensorflow:Saving checkpoints for 400 into ./tmp/model.ckpt.
INFO:tensorflow:Loss for final step: 54.8437.
INFO:tensorflow:Starting evaluation at 2017-09-07-01:49:45
INFO:ten

# Export TensorFlow Model using .export_savedmodel()

In [10]:
# Step 1: gather the feature columns to export (crossed + deep).
feature_columns = crossed_columns + deep_columns

# Step 2: derive the example-parsing spec from those feature columns.
feature_spec = tf.feature_column.make_parse_example_spec(
    feature_columns=feature_columns)

# Step 3: build the serving input receiver function from the parsing spec.
export_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
    feature_spec=feature_spec)

# Step 4: export the trained estimator under `servable_model_dir` and show
# the path that was written.
servable_model_dir = "./serving_savemodel"
servable_model_path = m.export_savedmodel(
    export_dir_base=servable_model_dir,
    serving_input_receiver_fn=export_input_fn)
servable_model_path

INFO:tensorflow:Restoring parameters from ./tmp/model.ckpt-400
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: b'./serving_savemodel/1504748998/saved_model.pb'


b'./serving_savemodel/1504748998'