In [None]:
import os

PROJECT_ID = 'qwiklabs-gcp-da02053fb2a13c97' # CHANGE THIS
BUCKET = 'qwiklabs-gcp-da02053fb2a13c97' # CHANGE THIS
MODEL_BASE = 'taxi_trained/export/exporter'
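# os.listdir() returns entries in arbitrary order, so sort before taking the last export directory.
MODEL_PATH = os.path.join(MODEL_BASE, sorted(os.listdir(MODEL_BASE))[-1])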
MODEL_NAME = 'taxifare'
VERSION_NAME = 'v1'

# Deploy for Online Prediction

To get our predictions, in addition to the features provided by the client, we also need to fetch the latest traffic information from BigQuery. We then combine these and invoke our TensorFlow model. This is visualized by the 'on-demand' portion (red arrows) in the diagram below:

<img src="assets/architecture.png" >

To do this, we'll take advantage of [AI Platform's Custom Prediction Routines](https://cloud.google.com/ml-engine/docs/tensorflow/custom-prediction-routines), which allow us to execute custom Python code in response to every online prediction request. There are 5 steps to creating a custom prediction routine:

1. Upload Model Artifacts to GCS
2. Implement Predictor interface 
3. Package the prediction code and dependencies
4. Deploy
5. Invoke API

## 1. Upload Model Artifacts to GCS

Here we upload our model weights so that AI Platform can access them.

In [None]:
!gcloud storage cp --recursive $MODEL_PATH/* gs://$BUCKET/taxifare/model/

## 2. Implement Predictor Interface

Interface Spec: https://cloud.google.com/ml-engine/docs/tensorflow/custom-prediction-routines#predictor-class

This tells AI Platform how to load the model artifacts, and is where we specify our custom prediction code.
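
As a minimal sketch of the shape this interface takes, the toy class below implements the same two methods (`predict` and `from_path`) without TensorFlow or BigQuery. The name `EchoPredictor` and its doubling "model" are hypothetical stand-ins used only for illustration; they are not part of this lab.

```python
class EchoPredictor(object):
    """Hypothetical predictor showing the two methods the interface expects."""

    def __init__(self, model):
        self._model = model

    def predict(self, instances, **kwargs):
        # The return value must be JSON serializable, so return a plain list.
        return [self._model(instance) for instance in instances]

    @classmethod
    def from_path(cls, model_dir):
        # A real implementation would load artifacts from model_dir.
        # This stand-in "model" ignores the path and doubles its input.
        return cls(lambda x: x * 2)


sketch = EchoPredictor.from_path('unused/model/dir')
print(sketch.predict([1, 2, 3]))
```

The real predictor in this notebook follows the same outline: `from_path` builds the object from the model directory, and `predict` handles each request.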

Note: the correct PROJECT_ID will automatically be inserted using the bash `sed` command in the subsequent cell.

In [None]:
%%writefile predictor.py
import tensorflow as tf
from google.cloud import bigquery

PROJECT_ID = 'will_be_replaced'

class TaxifarePredictor(object):
    def __init__(self, predict_fn):
        self.predict_fn = predict_fn

    def predict(self, instances, **kwargs):
        # Fetch the most recent traffic reading from BigQuery.
        bq = bigquery.Client(PROJECT_ID)
        query_string = """
        SELECT
          *
        FROM
          `taxifare.traffic_realtime`
        ORDER BY
          time DESC
        LIMIT 1
        """
        trips = bq.query(query_string).to_dataframe()['trips_last_5min'][0]
        # Repeat the traffic value once per instance (instances is a dict of equal-length lists).
        num_instances = len(next(iter(instances.values())))
        instances['trips_last_5min'] = [trips] * num_instances
        predictions = self.predict_fn(instances)
        return predictions['predictions'].tolist()  # convert to list so it is JSON serializable (required)


    @classmethod
    def from_path(cls, model_dir):
        predict_fn = tf.contrib.predictor.from_saved_model(model_dir,'predict')
        return cls(predict_fn)
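
The step in `predict` that attaches the traffic value to every instance can be tried in isolation, without BigQuery or a saved model. The sketch below assumes the columnar request format used in this notebook (a dict mapping each feature name to a list of values); the helper name `add_traffic_feature` is hypothetical.

```python
def add_traffic_feature(instances, trips_last_5min):
    """Return a copy of columnar `instances` with the traffic value repeated per row."""
    # Every feature list is assumed to have the same length, so any one gives the row count.
    num_rows = len(next(iter(instances.values())))
    enriched = dict(instances)  # shallow copy so the caller's dict is not mutated
    enriched['trips_last_5min'] = [trips_last_5min] * num_rows
    return enriched


example = {'dayofweek': [6, 5], 'hourofday': [12, 11]}
print(add_traffic_feature(example, 1200))
```

Unlike the predictor above, this sketch returns a new dict rather than modifying the request in place, which makes it easier to test.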

In [None]:
!sed -i -e 's/will_be_replaced/{PROJECT_ID}/g' predictor.py

### Test Predictor Class Works Locally

In [None]:
import predictor

instances = {'dayofweek' : [6,5],
             'hourofday' : [12,11],
             'pickuplon' : [-73.99,-73.99], 
             'pickuplat' : [40.758,40.758],
             'dropofflat' : [40.742,40.758],
             'dropofflon' : [-73.97,-73.97]}

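# Use a distinct name so the `predictor` module is not shadowed (keeps this cell re-runnable).
taxifare_predictor = predictor.TaxifarePredictor.from_path(MODEL_PATH)
taxifare_predictor.predict(instances)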
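
Because the predictor takes the number of instances from the first feature list, a request whose lists differ in length would be handled incorrectly. A small check like the following, run before calling `predict`, can catch that early. This is only a sketch; `check_columnar_instances` is a hypothetical helper, not part of the lab code.

```python
def check_columnar_instances(instances):
    """Return the number of rows, raising ValueError if feature lists differ in length."""
    lengths = {name: len(values) for name, values in instances.items()}
    # Exactly one distinct length means every feature has the same number of rows.
    if len(set(lengths.values())) != 1:
        raise ValueError('Feature lists differ in length: {}'.format(lengths))
    return next(iter(lengths.values()))


print(check_columnar_instances({'dayofweek': [6, 5], 'hourofday': [12, 11]}))
```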

## 3. Package Predictor Class and Dependencies

We must package the predictor as a `tar.gz` source distribution. Instructions for this are specified [here](http://cloud.google.com/ml-engine/docs/custom-prediction-routines#predictor-tarball). The AI Platform runtime comes preinstalled with several packages [listed here](https://cloud.google.com/ml-engine/docs/tensorflow/runtime-version-list). However, it does not come with `google-cloud-bigquery`, so we list that as a dependency below.

In [None]:
%%writefile setup.py
from setuptools import setup

setup(
    name='taxifare_custom_predict_code',
    version='0.1',
    scripts=['predictor.py'],
    install_requires=[
        'google-cloud-bigquery==1.16.0',
    ])

In [None]:
!python setup.py sdist --formats=gztar
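
Before uploading, it can be worth confirming that the tarball actually contains `predictor.py`. The sketch below uses only the standard library `tarfile` module; the helper name `sdist_contains` is hypothetical, and the path assumes the default `dist/` output location of the `sdist` command above.

```python
import os
import tarfile


def sdist_contains(tarball_path, filename):
    """Return True if any member of the .tar.gz archive has the given base name."""
    with tarfile.open(tarball_path, 'r:gz') as tar:
        return any(name.split('/')[-1] == filename for name in tar.getnames())


# Assumed output path of the sdist command; skipped if the package has not been built yet.
tarball = 'dist/taxifare_custom_predict_code-0.1.tar.gz'
if os.path.exists(tarball):
    print(sdist_contains(tarball, 'predictor.py'))
```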

In [None]:
!gcloud storage cp dist/taxifare_custom_predict_code-0.1.tar.gz gs://$BUCKET/taxifare/predict_code/

## 4. Deploy

This is similar to how we deploy standard models to AI Platform, with a few extra command line arguments.

Note the use of the `--service-account` parameter below.

The default service account does not have permissions to read from BigQuery, so we [specify a different service account](https://cloud.google.com/ml-engine/docs/tensorflow/deploying-models#service-account) that does have permission.

Specifically, we use the [Compute Engine default service account](https://cloud.google.com/compute/docs/access/service-accounts#compute_engine_default_service_account), which has the IAM project editor role.
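
The deploy command below builds this account's email address inline from the project number. The same string construction, shown as a small Python sketch (the helper name and the example project number are hypothetical):

```python
def compute_default_service_account(project_number):
    """Build the Compute Engine default service account email for a project number."""
    return '{}-compute@developer.gserviceaccount.com'.format(project_number)


# 123456789012 is a made-up project number used only for illustration.
print(compute_default_service_account(123456789012))
```

Note that the address uses the numeric project number, not the project ID, which is why the command looks it up with `gcloud projects list`.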

In [None]:
!gcloud beta ai-platform models create $MODEL_NAME --regions us-central1 --enable-logging --enable-console-logging
#!gcloud ai-platform versions delete $VERSION_NAME --model taxifare --quiet
!gcloud beta ai-platform versions create $VERSION_NAME \
  --model $MODEL_NAME \
  --origin gs://$BUCKET/taxifare/model \
  --service-account $(gcloud projects list --filter="$PROJECT_ID" --format="value(PROJECT_NUMBER)")-compute@developer.gserviceaccount.com \
  --runtime-version 1.14 \
  --python-version 3.5 \
  --package-uris gs://$BUCKET/taxifare/predict_code/taxifare_custom_predict_code-0.1.tar.gz \
  --prediction-class predictor.TaxifarePredictor

## 5. Invoke API

**Warning:** You will see `ImportError: file_cache is unavailable when using oauth2client >= 4.0.0 or google-auth` when you run this. Although it looks like an error, it is just a warning and is safe to ignore; the subsequent cell will still work.

In [None]:
import googleapiclient.discovery

instances = {'dayofweek' : [6], 
             'hourofday' : [12],
             'pickuplon' : [-73.99], 
             'pickuplat' : [40.758],
             'dropofflat' : [40.742],
             'dropofflon' : [-73.97]}

service = googleapiclient.discovery.build('ml', 'v1')
name = 'projects/{}/models/{}/versions/{}'.format(PROJECT_ID, MODEL_NAME, VERSION_NAME)

In [None]:
response = service.projects().predict(
    name=name,
    body={'instances': instances}
).execute()

if 'error' in response:
    raise RuntimeError(response['error'])
else:
    print(response['predictions'])

Try re-running the prediction request after 15 seconds (the windowing period for Dataflow) and note how the prediction changes in response to the new traffic data!
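
Instead of re-running the cell by hand, the repeated request could be scripted. The sketch below is one possible approach: `poll_predictions` is a hypothetical helper that calls any zero-argument function several times with a pause in between. The demo uses a fake prediction function and a no-op sleep so that it runs instantly and without credentials.

```python
import time


def poll_predictions(predict_once, attempts=3, delay_seconds=15, sleep=time.sleep):
    """Call predict_once() `attempts` times, pausing between calls, and collect the results."""
    results = []
    for attempt in range(attempts):
        if attempt > 0:
            sleep(delay_seconds)  # wait for the next traffic window before asking again
        results.append(predict_once())
    return results


# Demo with a stand-in prediction function (made-up values, not real model output).
fake_fares = iter([[10.0], [10.5], [11.0]])
print(poll_predictions(lambda: next(fake_fares), sleep=lambda seconds: None))
```

To use it against the deployed model, `predict_once` could wrap the request from the cell above, for example `lambda: service.projects().predict(name=name, body={'instances': instances}).execute()['predictions']`.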