In [1]:
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

In [2]:
import os
# Google Cloud settings shared by the cells below.
BUCKET = 'going-tfx'
PROJECT = 'going-tfx'
# GCP region names have no dash before the number ('us-east-1' is AWS-style).
REGION = 'us-east1'
BQ_DATASET = 'examples'
# Uncomment to export these as environment variables (e.g. for `!` shell cells).
#os.environ['BUCKET'] = BUCKET
#os.environ['PROJECT'] = PROJECT
#os.environ['REGION'] = REGION

In [3]:
def sample_query(project, dataset):
    """Build a BigQuery standard-SQL query template over ``ATL_JUNE``.

    Each row is given a bucket number
    ``MOD(ABS(FARM_FINGERPRINT(CONCAT(date, carrier, dest))) + CRS_ARR_TIME, 10000)``.
    The returned text keeps two ``{}`` placeholders, ``>= {}`` and ``< {}``
    on that bucket number, for the caller to fill in with ``str.format``.

    Args:
        project: GCP project id that hosts the dataset.
        dataset: BigQuery dataset containing the ``ATL_JUNE`` table.

    Returns:
        The query template string, with exactly two ``{}`` placeholders left.
    """
    selected = ["ORIGIN", "FL_YEAR", "FL_MONTH", "FL_DOW", "UNIQUE_CARRIER",
                "DEST", "CRS_ARR_TIME", "DEP_DELAY", "ARR_DELAY"]
    # The bucket expression appears twice: once per bound.
    bucket = ("MOD(ABS(FARM_FINGERPRINT(\n"
              "        CONCAT(\n"
              "          STRING(TIMESTAMP(FL_DATE)),\n"
              "          UNIQUE_CARRIER,\n"
              "          DEST\n"
              "        )\n"
              "      )) + CRS_ARR_TIME, 10000)")
    table = "`{}.{}.ATL_JUNE`".format(project, dataset)
    pieces = [
        "SELECT\n",
        ",\n".join("      " + name for name in selected),
        "\n    FROM ", table, " \n",
        "    where\n",
        "      ", bucket, " >= {} and \n",
        "      ", bucket, " < {} \n",
        "    ",
    ]
    return "".join(pieces)

In [4]:
def create_queries(training_percentage, eval_percentage, query_template=None):
    """Return train, eval and test queries over disjoint slices of ATL_JUNE.

    ``sample_query`` gives each row a bucket number ``MOD(..., 10000)``.
    Percentages are expressed in percent of those 10000 buckets, so ``p``
    maps to ``100 * p`` buckets (``0.01`` -> 1 bucket, i.e. 1/10000 of the
    rows). The three queries use the half-open bucket ranges ``[0, cut1)``,
    ``[cut1, cut2)`` and ``[cut2, 10000)``. Every bucket value 0..9999
    therefore lands in exactly one split, and the test split takes the
    remainder.

    Args:
        training_percentage: share of the rows for training, in percent.
        eval_percentage: share of the rows for evaluation, in percent.
        query_template: optional query text with two ``{}`` placeholders
            (inclusive lower and exclusive upper bucket). Defaults to
            ``sample_query(PROJECT, BQ_DATASET)``.

    Returns:
        Tuple ``(train_query, eval_query, test_query)``.

    Raises:
        ValueError: if a percentage is negative or the two sum to more
            than 100.
    """
    num_buckets = 10000
    # Round before truncating: int(100 * 0.29) == 28 because of float error.
    cut1 = int(round(100 * training_percentage))
    cut2 = cut1 + int(round(100 * eval_percentage))
    if cut1 < 0 or cut2 < cut1 or cut2 > num_buckets:
        raise ValueError(
            "percentages must be >= 0 and sum to at most 100, got "
            "training={!r}, eval={!r}".format(training_percentage, eval_percentage))
    if query_template is None:
        query_template = sample_query(PROJECT, BQ_DATASET)
    # The lower bound is inclusive (>=) and the upper exclusive (<), so
    # adjacent splits share their boundary and no bucket is skipped.
    return (query_template.format(0, cut1),
            query_template.format(cut1, cut2),
            query_template.format(cut2, num_buckets))

---
The query below retrieves only 1/10,000 of the ~400k rows, i.e. about 40–50 rows.

In [5]:
# 0.01% of the hash buckets for training, 0.02% for eval, the rest for test.
# query_0001 (1 bucket, i.e. 1/10000 of the rows) feeds the BigQuery reads below.
query_0001, e, s = create_queries(training_percentage=0.01, eval_percentage=0.02)
print(query_0001)

SELECT
      ORIGIN,
      FL_YEAR,
      FL_MONTH,
      FL_DOW,
      UNIQUE_CARRIER,
      DEST,
      CRS_ARR_TIME,
      DEP_DELAY,
      ARR_DELAY
    FROM `going-tfx.examples.ATL_JUNE` 
    where
      MOD(ABS(FARM_FINGERPRINT(
        CONCAT(
          STRING(TIMESTAMP(FL_DATE)),
          UNIQUE_CARRIER,
          DEST
        )
      )) + CRS_ARR_TIME, 10000) >= 0 and 
      MOD(ABS(FARM_FINGERPRINT(
        CONCAT(
          STRING(TIMESTAMP(FL_DATE)),
          UNIQUE_CARRIER,
          DEST
        )
      )) + CRS_ARR_TIME, 10000) < 1 
    


In [6]:
# An alternative way of getting some data
# sample = pd.read_csv(os.path.join(DATA_DIR, "atl_june_46.csv"));

In [7]:
import google.datalab.bigquery as bq

# Execute the sample query on BigQuery and pull the rows into pandas.
sample = (bq.Query(query_0001)
          .execute()
          .result()
          .to_dataframe())
sample.head(4)

Unnamed: 0,ORIGIN,FL_YEAR,FL_MONTH,FL_DOW,UNIQUE_CARRIER,DEST,CRS_ARR_TIME,DEP_DELAY,ARR_DELAY
0,ATL,2009,6,7,AA,MIA,1610,0,24
1,ATL,2006,6,1,DL,ABQ,2307,20,8
2,ATL,2007,6,3,DL,SFO,1153,8,-7
3,ATL,2008,6,7,DL,ABQ,1255,3,12


---
Metadata and schema

In [8]:
# One {column: value} dict per sampled row; display the first one.
records = sample.to_dict('records')
records[0:1]

[{u'ARR_DELAY': 24,
  u'CRS_ARR_TIME': 1610,
  u'DEP_DELAY': 0,
  u'DEST': 'MIA',
  u'FL_DOW': 7,
  u'FL_MONTH': 6,
  u'FL_YEAR': 2009,
  u'ORIGIN': 'ATL',
  u'UNIQUE_CARRIER': 'AA'}]

In [9]:
def tft_metadata(columns):
    """Build tf.Transform dataset metadata from a column -> dtype mapping.

    Every column becomes a scalar (shape ``[]``) ``tf.FixedLenFeature`` of
    the given dtype.

    Args:
        columns: dict mapping column name to a TensorFlow dtype
            (e.g. ``tf.string``, ``tf.int64``).

    Returns:
        A ``dataset_metadata.DatasetMetadata`` built from that feature spec.
    """
    import tensorflow as tf
    from tensorflow_transform.tf_metadata import dataset_metadata, dataset_schema

    feature_spec = {}
    for name, dtype in columns.items():
        feature_spec[name] = tf.FixedLenFeature([], dtype)
    schema = dataset_schema.from_feature_spec(feature_spec)
    return dataset_metadata.DatasetMetadata(schema)

In [10]:
import tensorflow as tf

# Raw (pre-transform) schema, one scalar feature per column. The delay
# columns are declared float32 although the sampled values are integers.
raw_data_metadata = tft_metadata(dict(
    ORIGIN=tf.string,
    FL_YEAR=tf.int64,
    FL_MONTH=tf.int64,
    FL_DOW=tf.int64,
    UNIQUE_CARRIER=tf.string,
    DEST=tf.string,
    CRS_ARR_TIME=tf.int64,
    DEP_DELAY=tf.float32,
    ARR_DELAY=tf.float32,
))

---
The preprocessing function scales the arrival delay (ARR_DELAY) to [0, 1] and leaves all other columns unchanged. Pay particular attention to the name of the returned ARR_DELAY tensor.

In [11]:
def preprocessing_fn(inputs):
    """tf.Transform preprocessing callback.

    Scales ``ARR_DELAY`` to [0, 1] with ``tft.scale_to_0_1``. Every other
    column is passed through unchanged via ``tf.identity``.

    Args:
        inputs: dict of column name -> tensor (or value convertible to one).

    Returns:
        dict with the same column names, ``ARR_DELAY`` replaced by its
        scaled version.
    """
    import tensorflow as tf
    import tensorflow_transform as tft

    passthrough = ['ORIGIN', 'FL_YEAR', 'FL_MONTH', 'FL_DOW', 'UNIQUE_CARRIER',
                   'DEST', 'CRS_ARR_TIME', 'DEP_DELAY']
    scaled_delay = tft.scale_to_0_1(inputs['ARR_DELAY'])
    outputs = {'ARR_DELAY': scaled_delay}
    # Identity ops are created in the same column order as before.
    outputs.update((name, tf.identity(inputs[name])) for name in passthrough)
    return outputs

In [12]:
# Trace the function on one raw record outside of Beam. Plain Python values
# become tensors here (int32 rather than the schema's int64, see the output).
preprocessing_fn(inputs=records[0])

{'ARR_DELAY': <tf.Tensor 'scale_by_min_max/add:0' shape=() dtype=float32>,
 'CRS_ARR_TIME': <tf.Tensor 'Identity_6:0' shape=() dtype=int32>,
 'DEP_DELAY': <tf.Tensor 'Identity_7:0' shape=() dtype=int32>,
 'DEST': <tf.Tensor 'Identity_5:0' shape=() dtype=string>,
 'FL_DOW': <tf.Tensor 'Identity_3:0' shape=() dtype=int32>,
 'FL_MONTH': <tf.Tensor 'Identity_2:0' shape=() dtype=int32>,
 'FL_YEAR': <tf.Tensor 'Identity_1:0' shape=() dtype=int32>,
 'ORIGIN': <tf.Tensor 'Identity:0' shape=() dtype=string>,
 'UNIQUE_CARRIER': <tf.Tensor 'Identity_4:0' shape=() dtype=string>}

In [13]:
# Column order for the CSV export in extract_to_tfrecord.
ORDERED_COLS = [
    'ORIGIN',
    'FL_YEAR',
    'FL_MONTH',
    'FL_DOW',
    'UNIQUE_CARRIER',
    'DEST',
    'CRS_ARR_TIME',
    'DEP_DELAY',
    'ARR_DELAY',
]

In [25]:
def extract_to_tfrecord(in_test_mode=True, query=None,
                        base_dir='gs://going-tfx/tutorials/tft/',
                        project='going-tfx'):
    """Run the tf.Transform tutorial pipeline and export its output twice.

    Two Beam pipelines run one after the other:

    1. Read the rows selected by ``query`` from BigQuery and analyze and
       transform them with ``preprocessing_fn`` against ``raw_data_metadata``.
       Write the transformed rows as tf.Example TFRecords under
       ``<base_dir>/tfr/q0001``.
    2. Read those TFRecords back and write them as CSV text (columns in
       ``ORDERED_COLS`` order) under ``<base_dir>/csv/q0001``.

    Args:
        in_test_mode: run locally with the DirectRunner when True, or on
            Cloud Dataflow when False.
        query: BigQuery standard-SQL query; defaults to the global
            ``query_0001``.
        base_dir: GCS prefix for staging, temp, TFRecord and CSV output.
        project: GCP project the jobs run in.
    """
    import datetime
    import os
    import tempfile
    import apache_beam as beam
    import tensorflow_transform as tft
    import tensorflow_transform.beam.impl as tft_beam

    # Previously the flag was overwritten with True here, so Dataflow was
    # never used; it is now honoured.
    if query is None:
        query = query_0001
    runner = 'DirectRunner' if in_test_mode else 'DataflowRunner'
    job_name = 'tft-tutorial-' + datetime.datetime.now().strftime('%y%m%d-%H%M%S')
    filebase_tfr = os.path.join(base_dir, 'tfr', 'q0001')
    filebase_csv = os.path.join(base_dir, 'csv', 'q0001')

    def _options(suffix):
        """Pipeline options for one job; the suffix keeps job names distinct."""
        return beam.pipeline.PipelineOptions(flags=[], **{
            'staging_location': os.path.join(base_dir, 'staging'),
            'temp_location': os.path.join(base_dir, 'tmp'),
            'job_name': job_name + '-' + suffix,
            'project': project,
            'max_num_workers': 24,
            'teardown_policy': 'TEARDOWN_ALWAYS',
            'no_save_main_session': True,
            'requirements_file': 'requirements.txt',
        })

    # tf.Transform's temp dir must be reachable by every worker; a local
    # directory only works with the DirectRunner.
    if in_test_mode:
        tft_temp_dir = tempfile.mkdtemp()
    else:
        tft_temp_dir = os.path.join(base_dir, 'tmp', 'tft', job_name)

    # Pipeline 1: BigQuery -> analyze & transform -> TFRecord.
    with beam.Pipeline(runner, options=_options('tfr')) as p:
        with tft_beam.Context(temp_dir=tft_temp_dir):
            raw_data = p | 'ReadFromBQ' >> beam.io.Read(
                beam.io.BigQuerySource(query=query, use_standard_sql=True))
            (transformed_data, transformed_metadata), _ = (
                (raw_data, raw_data_metadata)
                | 'Transform' >> tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
            encoder = tft.coders.ExampleProtoCoder(transformed_metadata.schema)
            _ = (transformed_data
                 | 'EncodeTFRecord' >> beam.Map(encoder.encode)
                 | 'WriteTFRecord' >> beam.io.WriteToTFRecord(filebase_tfr))

    # Pipeline 2: TFRecord -> CSV. These are the transformed rows written
    # above, so encode them with the transformed schema, not the raw one.
    csv_encode = tft.coders.CsvCoder(ORDERED_COLS, transformed_metadata.schema).encode
    with beam.Pipeline(runner, options=_options('csv')) as p:
        _ = (p
             | 'ReadFromTFRecord' >> beam.io.ReadFromTFRecord(
                 file_pattern=filebase_tfr + '*', coder=encoder)
             | 'EncodeCSV' >> beam.Map(csv_encode)
             | 'WriteCSV' >> beam.io.WriteToText(file_path_prefix=filebase_csv))

                        

In [26]:
# Only show TensorFlow log messages at WARN level or above.
tf.logging.set_verbosity(tf.logging.WARN)
# in_test_mode=False asks extract_to_tfrecord for the DataflowRunner instead of the local DirectRunner.
extract_to_tfrecord(in_test_mode=False)



In [29]:
!gsutil cp gs://going-tfx/tutorials/tft/csv/* /tmp/tfr.csv

Copying gs://going-tfx/tutorials/tft/csv/q0001-00000-of-00001...
/ [1 files][  2.0 KiB/  2.0 KiB]                                                
Operation completed over 1 objects/2.0 KiB.                                      


In [30]:
# Show the exported CSV: columns in ORDERED_COLS order, ARR_DELAY as scaled by preprocessing_fn.
!cat /tmp/tfr.csv

ATL,2009,6,7,AA,MIA,1610,0.0,0.181159421802
ATL,2006,6,1,DL,ABQ,2307,20.0,0.123188406229
ATL,2016,6,1,DL,MSN,1134,-3.0,0.0326086953282
ATL,2008,6,2,DL,LAX,1240,7.0,0.137681156397
ATL,2010,6,2,DL,ONT,2040,52.0,0.242753624916
ATL,2010,6,3,DL,PHL,1206,-1.0,0.105072468519
ATL,2016,6,3,DL,DAL,1744,267.0,1.0
ATL,2017,6,3,DL,ORD,1231,-5.0,0.00724637694657
ATL,2014,6,3,DL,LAX,2219,10.0,0.108695656061
ATL,2007,6,3,DL,SFO,1153,8.0,0.0688405781984
ATL,2016,6,4,DL,ROA,2317,-2.0,0.0326086953282
ATL,2017,6,4,DL,DEN,1059,-5.0,0.0289855077863
ATL,2016,6,5,DL,TPA,1100,-2.0,0.0760869607329
ATL,2008,6,5,DL,DEN,1206,13.0,0.202898561954
ATL,2011,6,6,DL,PWM,2311,-3.0,0.0978260859847
ATL,2014,6,6,DL,JAN,1110,-4.0,0.0724637657404
ATL,2014,6,6,DL,PIT,2337,30.0,0.155797109008
ATL,2013,6,6,DL,BWI,1819,-6.0,0.0108695654199
ATL,2015,6,7,DL,CLE,1541,26.0,0.0942028984427
ATL,2008,6,7,DL,ABQ,1255,3.0,0.137681156397
ATL,2009,6,1,EV,MLU,1213,-4.0,0.0615942031145
ATL,2011,6,1,EV,AEX,2121,43.0,0.278985500336
ATL,2009,6,2