# Overview

This example demonstrates the use of `tf.feature_column.crossed_column` on some simulated Atlanta housing price data. 
This spatial data is used primarily so the results can be easily visualized. 

These functions are designed primarily for categorical data, not to build interpolation tables. 

If you actually want to build smart interpolation tables in TensorFlow you may want to consider [TensorFlow Lattice](https://research.googleblog.com/2017/10/tensorflow-lattice-flexibility.html).

# Imports

In [None]:
# Builtin.
import os
import subprocess
import tempfile

In [None]:
# Third Party
import tensorflow as tf
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

In [None]:
#Local
import synthetic_data

In [None]:
# Require TensorFlow >= 1.4. The major/minor components are compared as
# integers: comparing the raw strings orders '10' before '4', which would
# wrongly reject versions such as 1.10.
assert tuple(int(part) for part in tf.VERSION.split('.')[:2]) >= (1, 4)

In [None]:
# Set TensorFlow's logging verbosity to INFO.
tf.logging.set_verbosity(tf.logging.INFO)

In [None]:
%matplotlib inline
# Default figure size and image colormap for the plots in this notebook.
mpl.rcParams['figure.figsize'] = 12, 6
mpl.rcParams['image.cmap'] = 'viridis'

In [None]:
# Create a fresh temporary directory. Each Estimator in this notebook uses a
# subdirectory of it as its `model_dir`.
logdir = tempfile.mkdtemp()

In [None]:
# Display the path of the temporary log directory.
logdir

# Build Synthetic Data

In [None]:
# Define the grid over the latitude and longitude ranges below.
# NOTE(review): `resolution` presumably sets the number of cells per axis
# (predictions are later reshaped to `(resolution, resolution)`) -- confirm
# against `synthetic_data.Linspace`.
resolution = 100
atlanta = synthetic_data.Grid(
    latitude=synthetic_data.Linspace(33.641336, 33.887157, resolution),
    longitude=synthetic_data.Linspace(-84.558798, -84.287259, resolution),
)

In [None]:
# This blobs function expects inputs in `[0,1]`
# NOTE(review): the argument `20` is presumably the number of blobs --
# confirm against `synthetic_data.Blobs`.
_price_fun = synthetic_data.Blobs(20)

# Wrap the blob function so that it accepts raw `atlanta` grid coordinates.
def atlanta_price_fn(latitude, longitude):
  """Return the simulated price at the given `atlanta` grid coordinates.

  The coordinates are first normalized with `atlanta.normalize` (to `[0,1]`,
  the range the blob function expects). Longitude is then passed to the blob
  function as `x` and latitude as `y`.
  """
  normalized = atlanta.normalize(latitude, longitude)
  y_norm, x_norm = normalized
  return _price_fun(x=x_norm, y=y_norm)

In [None]:
# Evaluate the price at each center-point
# `center_mesh()` supplies the latitude/longitude arrays; the resulting
# `actual_price_grid` is the ground truth used for plotting and evaluation.
latitude_centers, longitude_centers = atlanta.center_mesh()
actual_price_grid = atlanta_price_fn(latitude_centers, longitude_centers)

In [None]:
# Build a plotter for the `atlanta` grid, with `vmin`/`vmax` taken from the
# range of the actual prices. The same plotter is reused for the predictions.
plotter = synthetic_data.GridPlotter(atlanta, vmin=actual_price_grid.min(), vmax=actual_price_grid.max())

In [None]:
# Plot the actual prices over the grid.
plotter(actual_price_grid)

# Build Datasets

In [None]:
def make_dataset(latitude, longitude, labels):
    """Build a `tf.data.Dataset` of `(features, label)` slices.

    All three arrays must have the same shape. Each one is flattened, and the
    features are keyed by `'latitude'` and `'longitude'`.
    """
    assert latitude.shape == longitude.shape == labels.shape

    flat_features = dict(
        latitude=latitude.flatten(),
        longitude=longitude.flatten(),
    )
    flat_labels = labels.flatten()

    return tf.data.Dataset.from_tensor_slices((flat_features, flat_labels))

In [None]:
# For the test data we will use the actual price grid, paired with the
# center-point coordinates it was evaluated at.
test_ds = make_dataset(latitude_centers, longitude_centers, actual_price_grid)
test_ds = test_ds.cache().batch(512).prefetch(1)

# For training data we will use a set of random points: uniform samples in
# `[0, 1)` are mapped onto the grid's coordinates with `atlanta.denormalize`.
train_latitude, train_longitude = atlanta.denormalize(
    np.random.rand(50000), np.random.rand(50000))
train_price = atlanta_price_fn(train_latitude, train_longitude)

# The training dataset repeats indefinitely; the number of training steps is
# set by the `steps` argument passed to `est.train`.
train_ds = make_dataset(train_latitude, train_longitude, train_price)
train_ds = train_ds.cache().repeat().shuffle(100000).batch(512).prefetch(1)

# Adapt a `Dataset` to the zero-argument `input_fn` interface.
def dataset_input_fn(ds):
    """Return a no-argument callable that yields the next element of `ds`.

    Each call to the returned function creates a new one-shot iterator over
    `ds` and returns the result of its `get_next()`.
    """
    def input_fn():
        iterator = ds.make_one_shot_iterator()
        return iterator.get_next()

    return input_fn


# Generate a plot from an Estimator

In [None]:
def plot_est(est, ds = test_ds):
    """Plot the actual price grid next to the predictions of `est`.

    Args:
      est: An Estimator whose `predict` yields dicts with a `'predictions'`
        entry.
      ds: The dataset to predict over. Defaults to `test_ds`. The predictions
        are reshaped to `(resolution, resolution)` for plotting.
    """
    # Two side-by-side axes: actual prices on the first, predictions on the second.
    actual_ax = plt.subplot(1, 2, 1)
    predicted_ax = plt.subplot(1, 2, 2)

    # First axis: the actual price.
    plt.sca(actual_ax)
    plotter(actual_price_grid)

    # Run the estimator over the dataset and gather its predictions into a
    # flat float32 numpy array.
    prediction_iter = est.predict(dataset_input_fn(ds))
    values = np.fromiter(
        (result['predictions'] for result in prediction_iter), np.float32)

    # Second axis: the predictions, laid out on the grid.
    plt.sca(predicted_ax)
    plotter(values.reshape(resolution, resolution))

# Using `numeric_column` with DNNRegressor
In this case the data has spatial relationships that the `DNNRegressor` can exploit to make good predictions. Pure categorical data doesn't have these relationships. Embeddings are a way your model can _learn_ spatial relationships.

In [None]:
# Normalize each coordinate with `normalizer_fn`, so that the model only sees
# values in [0, 1].
fc = [
    tf.feature_column.numeric_column(
        'latitude', normalizer_fn=atlanta.latitude.normalize),
    tf.feature_column.numeric_column(
        'longitude', normalizer_fn=atlanta.longitude.normalize),
]

# Build the Estimator (two hidden layers of 100 units), then train and
# evaluate it.
est = tf.estimator.DNNRegressor(
    feature_columns=fc,
    hidden_units=[100, 100],
    model_dir=os.path.join(logdir, 'DNN'))

est.train(input_fn=dataset_input_fn(train_ds), steps=5000)
est.evaluate(input_fn=dataset_input_fn(test_ds))

In [None]:
# Plot the actual prices next to the DNN's predictions.
plot_est(est)

# Using `bucketized_column`
`bucketized_column` on its own defines a separable function, as the sum of marginals over the two axes.

In [None]:
# Bucketize the latitude and longitude, using the grid `edges` as boundaries.
latitude_bucket_fc = tf.feature_column.bucketized_column(
    source_column=tf.feature_column.numeric_column('latitude'),
    boundaries=list(atlanta.latitude.edges))

longitude_bucket_fc = tf.feature_column.bucketized_column(
    source_column=tf.feature_column.numeric_column('longitude'),
    boundaries=list(atlanta.longitude.edges))

# Only the two per-axis (marginal) columns are used here.
fc = [latitude_bucket_fc, longitude_bucket_fc]

# Build, train and evaluate the Estimator.
est = tf.estimator.LinearRegressor(
    feature_columns=fc,
    model_dir=os.path.join(logdir, 'separable'))
est.train(input_fn=dataset_input_fn(train_ds), steps=5000)
est.evaluate(input_fn=dataset_input_fn(test_ds))

In [None]:
# Plot the actual prices next to the predictions of the bucketized model.
plot_est(est)

# Using `crossed_column` on its own.
Using `crossed_column` defines a joint function over the two axes, with some random weight sharing caused by the fact that `crossed_columns` use hashing, like `categorical_column_with_hash_bucket`.

The single-cell "holes" in the figure are caused by cells which do not contain examples.

In [None]:
# Cross the two bucketized columns into 5000 hash bins (for an average weight
# sharing of 2).
crossed_lat_lon_fc = tf.feature_column.crossed_column(
    keys=[latitude_bucket_fc, longitude_bucket_fc],
    hash_bucket_size=5000)

# The crossed column is the model's only feature here.
fc = [crossed_lat_lon_fc]

# Build, train and evaluate the Estimator.
est = tf.estimator.LinearRegressor(
    feature_columns=fc,
    model_dir=os.path.join(logdir, 'crossed'))

est.train(input_fn=dataset_input_fn(train_ds), steps=5000)
est.evaluate(input_fn=dataset_input_fn(test_ds))

In [None]:
# Plot the actual prices next to the predictions of the crossed-column model.
plot_est(est)

# Using raw categories with `crossed_column` 
The model generalizes better if it also has access to the marginal categories, outside of the `crossed_column`. In this case the marginal columns learn averages, similar to when we only used the marginals (`bucketized_column`), while the `crossed_column` learns how individual cells deviate from those averages. This also mitigates hash collisions, as the model has access to other features it can use to distinguish between examples that collide.

In [None]:
# Combine the two marginal (bucketized) columns with their cross.
fc = [latitude_bucket_fc, longitude_bucket_fc, crossed_lat_lon_fc]

# Build, train and evaluate the Estimator.
est = tf.estimator.LinearRegressor(
    feature_columns=fc,
    model_dir=os.path.join(logdir, 'both'))
est.train(input_fn=dataset_input_fn(train_ds), steps=5000)
est.evaluate(input_fn=dataset_input_fn(test_ds))

In [None]:
# Plot the actual prices next to the predictions of the combined model.
plot_est(est)

# Open TensorBoard

## Start TensorBoard
The following command will kill all running TensorBoard processes, and start a new one monitoring the above logdir. 

In [None]:
# Export the log directory as the LOGDIR environment variable, which the
# shell cell that launches TensorBoard reads.
%env LOGDIR={logdir}

In [None]:
%%bash --bg
# Stop any TensorBoard processes that are already running.
pkill -f tensorboard
# Start a new TensorBoard on the notebook's log directory. The variable is
# quoted so that a path containing spaces is passed as a single argument.
tensorboard --logdir "$LOGDIR"

In [None]:
%%html
<!-- Embed the TensorBoard UI. `localhost` is used as the destination because `0.0.0.0` is a bind address that not every browser/OS accepts as a URL host. -->
<iframe width="900" height="800" src="http://localhost:6006#scalars&_smoothingWeight=0.85" frameborder="0"></iframe>