In [None]:
import os
import logging
from enum import Enum
from typing import List, Optional, Text, Union, Dict, Iterable,Mapping

import apache_beam as beam

import mlflow

from apache_beam.options.pipeline_options import PipelineOptions
import datetime
from datetime import timedelta
from google.cloud import bigquery
from jinja2 import Template
import json
import numpy as np

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import statistics_pb2
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_metadata.proto.v0 import anomalies_pb2
import tensorflow_data_validation as tfdv
from tensorflow_data_validation.utils import io_util
from tensorflow_data_validation import constants
from tensorflow_data_validation import types

In [None]:
# Activate the MLflow experiment used by this notebook and record the
# tracking URI the client resolved.
experiment_name = "chicago-taxi"
mlflow.set_experiment(experiment_name)
mlflow_tracking_uri = mlflow.get_tracking_uri()

# Deployment settings. os.environ[...] raises KeyError when a variable is
# unset, so the cell fails fast on a misconfigured kernel.
MLFLOW_TRACKING_EXTERNAL_URI = os.environ["MLFLOW_TRACKING_EXTERNAL_URI"]
REGION = os.environ["MLOPS_REGION"]
ML_IMAGE_URI = os.environ["ML_IMAGE_URI"]
COMPOSER_NAME = os.environ["MLOPS_COMPOSER_NAME"]
MLFLOW_GCS_ROOT_URI = os.environ["MLFLOW_GCS_ROOT_URI"]

# Folder holding the CSV sample, the reference schema and the statistics
# outputs used by the cells below.
DATASET_GCS_FOLDER = f"{MLFLOW_GCS_ROOT_URI}/data"

# GCP project for BigQuery reads; falls back to "edgeml-demo" if unset.
PROJECT_ID = os.getenv("GCP_PROJECT", "edgeml-demo")

In [None]:
# Baseline statistics: computed from a '|'-delimited CSV sample of the
# BigQuery table, then used to infer the reference schema.
# TODO: Create BQ -> CSV sample file

statopt = tfdv.StatsOptions(num_histogram_buckets=5)
baseline_stats = tfdv.generate_statistics_from_csv(
    data_location=f'{DATASET_GCS_FOLDER}/stats_source/ds_bq_data_statistics_*.csv',
    stats_options=statopt,
    delimiter='|',
)

# Persist the inferred schema as a text-format protobuf; the last cell
# reloads it from the same path with load_schema_text.
reference_schema = tfdv.infer_schema(baseline_stats)
schema_text = text_format.MessageToString(reference_schema)
io_util.write_string_to_file(f'{DATASET_GCS_FOLDER}/taxi_schema.pbtxt', schema_text)

In [None]:
import json
import apache_beam as beam
import numpy as np

from typing import List, Optional, Text, Union, Dict, Iterable, Mapping
from tensorflow_data_validation import types
from tensorflow_data_validation import constants
from tensorflow_metadata.proto.v0 import schema_pb2

# NOTE(review): _RAW_DATA_COLUMN and _INSTANCES_KEY are not referenced in the
# visible cells; confirm they are still needed.
_RAW_DATA_COLUMN = 'raw_data'
_INSTANCES_KEY = 'instances'
# Row column that InstanceCoder maps to a time slice when slicing is enabled.
_TIMESTAMP_KEY = 'trip_start_timestamp'

# Schema feature type -> NumPy dtype used for always-present features.
# The builtin types are used directly: `np.str` / `np.float` were only
# deprecated aliases of `str` / `float` and were removed in NumPy 1.24,
# which made this cell raise AttributeError on current NumPy.
_SCHEMA_TO_NUMPY = {
    schema_pb2.FeatureType.BYTES: str,
    schema_pb2.FeatureType.INT: np.int64,
    schema_pb2.FeatureType.FLOAT: float,
}


@beam.typehints.with_input_types(Dict)
@beam.typehints.with_output_types(types.BeamExample)
class InstanceCoder(beam.DoFn):
    """A DoFn which converts a taxi row to types.BeamExample elements.

    Each input row (a dict keyed by column name, or a list ordered like the
    schema features) becomes a dict mapping feature name to a 1-D NumPy array.
    When time slicing is configured, the example additionally gets a
    `slicing_column` feature holding the label of the time window that the
    row's `trip_start_timestamp` value falls into.
    """

    def __init__(self,
        schema: schema_pb2.Schema,
        end_time: Optional[datetime.datetime] = None,
        time_window: Optional[timedelta] = None,
        slicing_column: Optional[str] = None):
        """
        Args:
            schema: Schema whose features define the expected columns and the
                NumPy dtype each column is converted to.
            end_time: End of the analysed range; slices are aligned to it.
            time_window: Width of each time slice.
            slicing_column: Name of the feature that receives the slice label.
                Slicing is enabled only when end_time, time_window and
                slicing_column are all truthy (a zero timedelta disables it).

        Raises:
            ValueError: if the schema has a feature type that is not in
                _SCHEMA_TO_NUMPY (BYTES, INT, FLOAT).
        """
        # NOTE(review): this counter is declared but never incremented here.
        self._example_size = beam.metrics.Metrics.counter(
            constants.METRICS_NAMESPACE, "example_size")

        self._features = {}
        for feature in schema.feature:
            if feature.type not in _SCHEMA_TO_NUMPY:
                raise ValueError(
                    "Unsupported feature type: {}".format(feature.type))
            if feature.HasField('presence') and feature.presence.min_fraction == 1.0:
                # Always-present feature: use the concrete dtype.
                self._features[feature.name] = _SCHEMA_TO_NUMPY[feature.type]
            else:
                # Possibly-missing feature: presumably values can be None,
                # which fixed-width dtypes such as int64 cannot hold, so fall
                # back to `object`. (`np.object` was a deprecated alias of the
                # builtin and was removed in NumPy 1.24.)
                self._features[feature.name] = object

        if end_time and time_window and slicing_column:
            self._end_time = end_time
            self._time_window = time_window
            self._slicing_column = slicing_column
        else:
            # Always define the attributes so the instance has a uniform shape.
            self._end_time = None
            self._time_window = None
            self._slicing_column = None

    def _get_time_slice(self, time_stamp: str) -> str:
        """
        Assigns a time stamp to a time slice.

        Slices are consecutive windows of width `time_window` ending at
        `end_time`. With q = (end_time - time_stamp) // time_window, the slice
        is the half-open interval
        (end_time - (q + 1) * time_window, end_time - q * time_window].

        Args:
            time_stamp: A date_time string in the ISO YYYY-MM-DDTHH:MM:SS format
        Returns:
            A time slice as a string in the following format:
            YYYY-MM-DDTHH:MM_YYYY-MM-DDTHH:MM
        """
        # `datetime` is the module in this notebook (`import datetime`), so the
        # class must be named explicitly; `datetime.strptime` raises
        # AttributeError.
        # NOTE(review): confirm the BigQuery source yields timestamps in exactly
        # this format (no space separator or timezone suffix).
        parsed = datetime.datetime.strptime(time_stamp, '%Y-%m-%dT%H:%M:%S')

        q = (self._end_time - parsed) // self._time_window
        slice_end = self._end_time - q * self._time_window
        slice_beginning = slice_end - self._time_window

        return (slice_beginning.strftime('%Y-%m-%dT%H:%M') + '_' +
                slice_end.strftime('%Y-%m-%dT%H:%M'))

    def _parse_raw_instance(self, raw_instance: Union[list, dict]) -> dict:
        """Converts one raw row into {feature name: 1-D NumPy array}.

        A dict row is converted key by key: list values become arrays as-is and
        scalars are wrapped in a one-element array. A list row is matched
        positionally against the schema feature order.

        Raises:
            TypeError: if the row is neither a list nor a dict.
            KeyError: if a dict row has a column that is not in the schema.
        """
        if isinstance(raw_instance, dict):
            return {name: np.array(value if isinstance(value, list) else [value],
                                   dtype=self._features[name])
                    for name, value in raw_instance.items()}
        if isinstance(raw_instance, list):
            return {name: np.array([value], dtype=self._features[name])
                    for name, value in zip(self._features, raw_instance)}
        raise TypeError(
            "Unsupported input instance format. Only JSON list or JSON object instances are supported")

    def process(self, raw_instance: Dict) -> Iterable:
        """Yields the parsed example, tagged with its time slice if enabled."""
        instance = self._parse_raw_instance(raw_instance)
        if self._slicing_column:
            instance[self._slicing_column] = np.array(
                [self._get_time_slice(raw_instance[_TIMESTAMP_KEY])])
        yield instance


In [None]:
# Artifact file names, joined onto `output_path` by generate_statistics.
_STATS_FILENAME = 'stats.pb'             # TFRecord of DatasetFeatureStatisticsList
_ANOMALIES_FILENAME = 'anomalies.pbtxt'  # anomaly report written as text

# Synthetic feature added to the schema/examples when time slicing is enabled.
_SLICING_COLUMN_NAME = 'time_slice'
_SLICING_COLUMN_TYPE = schema_pb2.FeatureType.BYTES

def _alert_if_anomalies(anomalies: anomalies_pb2.Anomalies, output_path: str):
    """
    Logs whether validation found anomalies and passes the result through.

    Currently, the function just writes to a default Python logger.
    A more comprehensive alerting strategy will be considered in the future.

    Args:
        anomalies: Result of `tfdv.validate_statistics`.
        output_path: Where the pipeline writes the anomaly report; only used
            in the log message.

    Returns:
        `anomalies`, unchanged, so the function can sit inside `beam.Map`
        ahead of the step that writes the report.
    """
    if len(anomalies.anomaly_info) > 0:
        # logging.warn is a deprecated alias of logging.warning. The report is
        # written by a downstream pipeline step, so it does not exist yet here.
        logging.warning(
            "Anomalies detected. The anomaly report will be written to: {}".format(output_path))
    else:
        logging.info("No anomalies detected.")

    return anomalies

def _generate_query(sampling_query_template: str, start_time: str, end_time: str) -> str:
    """
    Renders a Jinja2 SQL template for a time range.

    Args:
        sampling_query_template: Jinja2 template text; it may reference the
            variables `start_time` and `end_time`.
        start_time: Value rendered for `start_time`.
        end_time: Value rendered for `end_time`.

    Returns:
        The rendered query string.
    """
    template = Template(sampling_query_template)
    return template.render(end_time=end_time, start_time=start_time)


In [None]:
def generate_statistics(
    query: str,
    output_path: str,
    start_time: datetime.datetime,
    end_time: datetime.datetime,
    schema: schema_pb2.Schema,
    baseline_stats: Optional[statistics_pb2.DatasetFeatureStatisticsList]=None,
    time_window: Optional[timedelta]=None,
    pipeline_options: Optional[PipelineOptions] = None,
):
    """Computes statistics for a BigQuery time range and validates them.

    Runs a Beam pipeline that:
      1. renders `query` for [start_time, end_time] and reads the rows from
         BigQuery,
      2. converts each row into an example (optionally tagged with a time
         slice),
      3. generates statistics and writes them to `<output_path>/stats.pb`,
      4. validates the statistics against `schema` (and `baseline_stats` as
         `previous_statistics`), logs whether anomalies were found and writes
         the anomaly report to `<output_path>/anomalies.pbtxt`.

    Args:
        query: Jinja2 SQL template referencing `start_time` and `end_time`.
        output_path: Output directory; also passed to BigQuery as the export
            staging location (`gcs_location`).
        start_time: Start of the range; seconds and microseconds are dropped.
        end_time: End of the range; seconds and microseconds are dropped.
        schema: Reference schema. It is not modified; when slicing is enabled
            a copy is extended with the slicing feature.
        baseline_stats: Optional statistics passed to validation as
            `previous_statistics`.
        time_window: Optional slice width, truncated to whole minutes.
            Slicing is enabled only when the truncated window is non-zero and
            shorter than the time range.
        pipeline_options: Optional Beam pipeline options.
    """
    end_time = end_time.replace(second=0, microsecond=0)
    start_time = start_time.replace(second=0, microsecond=0)
    query = _generate_query(
        sampling_query_template=query,
        start_time=start_time.strftime('%Y-%m-%dT%H:%M:%S'),
        end_time=end_time.strftime('%Y-%m-%dT%H:%M:%S'))

    # Work on a copy: adding the slicing feature to the caller's schema would
    # mutate it and accumulate duplicate features across calls.
    pipeline_schema = schema_pb2.Schema()
    pipeline_schema.CopyFrom(schema)

    # Configure slicing for statistics calculations
    slicing_column = None
    slice_functions = None
    if time_window:
        # Truncate to whole minutes to match the minute-resolution slice labels.
        time_window = timedelta(
            days=time_window.days,
            seconds=(time_window.seconds // 60) * 60)

        # A sub-minute window truncates to zero. InstanceCoder treats a falsy
        # window as "no slicing", so disable slicing consistently here too
        # instead of declaring a slicing feature that is never populated.
        if time_window and end_time - start_time > time_window:
            slice_functions = [
                tfdv.get_feature_value_slicer(features={_SLICING_COLUMN_NAME: None})]
            slicing_column = _SLICING_COLUMN_NAME
            slicing_feature = pipeline_schema.feature.add()
            slicing_feature.name = _SLICING_COLUMN_NAME
            slicing_feature.type = _SLICING_COLUMN_TYPE

    # Build the options after the schema is final.
    stats_options = tfdv.StatsOptions(schema=pipeline_schema)
    if slice_functions:
        stats_options.slice_functions = slice_functions

    # Configure output paths
    stats_output_path = os.path.join(output_path, _STATS_FILENAME)
    anomalies_output_path = os.path.join(output_path, _ANOMALIES_FILENAME)

    # Define and run the pipeline; it executes when the `with` block exits.
    with beam.Pipeline(options=pipeline_options) as p:
        # ReadFromBigQuery is already a PTransform, so it is applied directly
        # rather than wrapped in beam.io.Read.
        raw_examples = (p
           | 'GetData' >> beam.io.ReadFromBigQuery(
               query=query,
               gcs_location=output_path,
               project=PROJECT_ID,
               use_standard_sql=True))

        examples = (raw_examples
           | 'InstancesToBeamExamples' >> beam.ParDo(
               InstanceCoder(pipeline_schema, end_time, time_window, slicing_column)))

        stats = (examples
           | 'BeamExamplesToArrow' >> tfdv.utils.batch_util.BatchExamplesToArrowRecordBatches()
           | 'GenerateStatistics' >> tfdv.GenerateStatistics(options=stats_options))

        _ = (stats
            | 'WriteStatsOutput' >> beam.io.WriteToTFRecord(
                file_path_prefix=stats_output_path,
                shard_name_template='',
                coder=beam.coders.ProtoCoder(
                    statistics_pb2.DatasetFeatureStatisticsList)))

        anomalies = (stats
            | 'ValidateStatistics' >> beam.Map(
                tfdv.validate_statistics,
                schema=pipeline_schema,
                previous_statistics=baseline_stats))

        _ = (anomalies
            | 'AlertIfAnomalies' >> beam.Map(_alert_if_anomalies, anomalies_output_path)
            | 'WriteAnomaliesOutput' >> beam.io.textio.WriteToText(
                file_path_prefix=anomalies_output_path,
                shard_name_template='',
                append_trailing_newlines=False))


In [None]:
from tensorflow_data_validation import StatsOptions
from tensorflow_data_validation import load_statistics
from tensorflow_data_validation import load_schema_text

# NOTE(review): not used in this notebook; presumably meant for a Dataflow
# `setup_file` pipeline option.
_SETUP_FILE = './setup.py'

sampling_query_template = """
    SELECT trip_start_timestamp, 
        unique_key, taxi_id, trip_end_timestamp, trip_seconds, trip_miles, pickup_census_tract, 
        dropoff_census_tract, pickup_community_area, dropoff_community_area, fare, tips, tolls, extras, trip_total, 
        payment_type, company, pickup_latitude, pickup_longitude, pickup_location, dropoff_latitude, dropoff_longitude, dropoff_location
    FROM `bigquery-public-data.chicago_taxi_trips.taxi_trips`
        WHERE trip_start_timestamp BETWEEN '{{ start_time }}' AND '{{ end_time }}'
    LIMIT 1000
"""

# Statistics of a previous run, passed to validation as previous_statistics.
# This uses its own name so it does not overwrite `baseline_stats` computed
# from the CSV sample earlier in the notebook.
# TODO: load a previous run's statistics file with load_statistics(<path>).
previous_stats = None

# Load back the saved reference schema
loaded_schema = load_schema_text(f'{DATASET_GCS_FOLDER}/taxi_schema.pbtxt')
tfdv.display_schema(schema=loaded_schema)

# NOTE(review): datetime.datetime.today() is the kernel's local time, while the
# query compares against BigQuery timestamps; confirm the kernel runs in UTC.
generate_statistics(
    query=sampling_query_template,
    output_path=f'{DATASET_GCS_FOLDER}/stats',
    start_time=datetime.datetime(1990, 1, 1),
    end_time=datetime.datetime.today(),
    schema=loaded_schema,
    baseline_stats=previous_stats,
    time_window=None)