In [1]:
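# Uncomment to install the dependencies on a fresh runtime
# (the outputs in this notebook were produced with TFX 1.9.0).
# !pip install -U tfx
# !pip install apache-beam==2.39.0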

In [2]:
from absl import logging
from google.protobuf.json_format import MessageToDict
from io import StringIO
from ml_metadata.proto import metadata_store_pb2
import numpy as np
import os
import pandas as pd
import pprint
import shutil
import tempfile
import tensorflow as tf
import tensorflow_data_validation as tfdv
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils
from tfx import v1 as tfx
from tfx.orchestration.portable.mlmd import execution_lib
from tfx.orchestration.experimental.interactive import visualizations, standard_visualizations
from tfx.orchestration.metadata import Metadata
from  tfx.proto import example_gen_pb2
from tfx.types import standard_artifacts, standard_component_specs

# Only let TensorFlow's own logger report errors
tf.get_logger().setLevel('ERROR')
# Keep absl logging at INFO so the pipeline's progress messages are shown
logging.set_verbosity(logging.INFO)

# Shared pretty-printer for nested structures
pp = pprint.PrettyPrinter()

# Make the standard artifact visualizations available to the visualization registry
standard_visualizations.register_standard_visualizations()

# Report the versions of the TF / TFX packages in use
print(f'TensorFlow version: {tf.__version__}')
print(f'TFX version: {tfx.__version__}')
print(f'TensorFlow Data Validation version: {tfdv.__version__}')

TensorFlow version: 2.9.1
TFX version: 1.9.0
TensorFlow Data Validation version: 1.9.0


In [3]:
# Mount Google Drive at /content/drive; every project path configured below
# lives under this mount point (Colab-only import).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# Project root on the mounted Drive, and the data folder inside it
DATA_DIR = '/content/drive/My Drive/Stroke Prediction ML System'
DATA_ROOT = f'{DATA_DIR}/data'

# Folder of the reduced ("select") dataset, followed by the training / testing / serving folders
# NOTE(review): the three *_ROOT folders hang off DATA_DIR, not DATA_ROOT like
# DATA_DIR_SELECT does — confirm this matches the actual Drive layout.
DATA_DIR_SELECT = f'{DATA_ROOT}/select'
TRAINING_ROOT = f'{DATA_DIR}/training'
TESTING_ROOT = f'{DATA_DIR}/testing'
SERVING_ROOT = f'{DATA_DIR}/serving'

# CSV files for each split
# NOTE(review): the file names are spelled "stoke_..." (not "stroke_...");
# presumably this matches the files on disk — verify before renaming anything.
TRAINING_DATA = f'{TRAINING_ROOT}/stoke_prediction_training_dataset.csv'
TESTING_DATA = f'{TESTING_ROOT}/stoke_prediction_testing_dataset.csv'
SERVING_DATA = f'{SERVING_ROOT}/stoke_prediction_serving_dataset.csv'

# Name of the pipeline used for schema-based data validation
SCHEMA_PIPELINE_NAME = 'stroke-mlops-schema'

# Output directory to store artifacts generated from the pipeline
SCHEMA_PIPELINE_ROOT = os.path.join(DATA_DIR, 'pipeline', SCHEMA_PIPELINE_NAME)

# Folder containing the curated schema file (schema.pbtxt is read from here)
SCHEMA_FOLDER = os.path.join(DATA_DIR, 'schema/schema_output')

# Path to a SQLite DB file to use as an MLMD storage
SCHEMA_METADATA_PATH = os.path.join(DATA_DIR, 'metadata', SCHEMA_PIPELINE_NAME, 'metadata.db')

# Set random seed
RANDOM_SEED = 0

In [5]:
# Load the cleaned stroke dataset from the "select" folder and preview the first rows.
# NOTE(review): the preview shows an 'Unnamed: 0' column, presumably an index that was
# written to the CSV — consider index_col=0 if nothing downstream relies on it.
cleaned_dataset = pd.read_csv(f'{DATA_DIR_SELECT}/healthcare-dataset-stroke-data-cleaned.csv')
cleaned_dataset.head()

Unnamed: 0,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke
0,Male,67.0,0,1,Yes,Private,Urban,228.69,36.6,formerly smoked,1
1,Male,80.0,0,1,Yes,Private,Rural,105.92,32.5,never smoked,1
2,Female,49.0,0,0,Yes,Private,Urban,171.23,34.4,smokes,1
3,Female,79.0,1,0,Yes,Self-employed,Rural,174.12,24.0,never smoked,1
4,Male,81.0,0,0,Yes,Private,Urban,186.21,29.0,formerly smoked,1


In [6]:
# Load the serving dataset and preview the first rows
# (the preview shows the same feature columns but no `stroke` column).
serving_dataset = pd.read_csv(SERVING_DATA)
serving_dataset.head()

Unnamed: 0,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status
0,Male,49.0,0,0,Yes,Self-employed,Urban,70.73,27.3,formerly smoked
1,Male,40.0,0,0,Yes,Private,Rural,144.48,29.8,smokes
2,Female,74.0,0,0,Yes,Govt_job,Urban,251.99,25.5,never smoked
3,Male,52.0,1,0,Yes,Govt_job,Urban,235.06,39.9,formerly smoked
4,Female,51.0,1,0,Yes,Govt_job,Urban,69.94,33.3,smokes


In [7]:
def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,
                     schema_path: str, metadata_path: str) -> tfx.dsl.Pipeline:
  """
  Builds the data-validation pipeline that checks CSV examples against a saved schema.

  Args:
    pipeline_name: name given to the pipeline
    pipeline_root: directory passed to the pipeline as its root
    data_root: directory handed to CsvExampleGen as its input base
    schema_path: directory that contains the curated `schema.pbtxt`
    metadata_path: path used for the SQLite metadata connection config

  Returns:
    A tfx.dsl.Pipeline made of CsvExampleGen, StatisticsGen, ImportSchemaGen
    and ExampleValidator, in that order.
  """

  # Ingest the CSV data under `data_root` as examples
  csv_ingest = tfx.components.CsvExampleGen(input_base=data_root)

  # Compute statistics over the ingested examples
  stats = tfx.components.StatisticsGen(examples=csv_ingest.outputs['examples'])

  # Bring the already-curated schema file into the pipeline
  curated_schema = tfx.components.ImportSchemaGen(
      schema_file=f'{schema_path}/schema.pbtxt')

  # Check the statistics against the imported schema to detect anomalies
  validator = tfx.components.ExampleValidator(
      statistics=stats.outputs['statistics'],
      schema=curated_schema.outputs['schema'])

  # SQLite-backed metadata store for this pipeline
  mlmd_config = tfx.orchestration.metadata.sqlite_metadata_connection_config(
      metadata_path)

  return tfx.dsl.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=mlmd_config,
      components=[csv_ingest, stats, curated_schema, validator])

In [8]:
# Assemble the validation pipeline from the notebook configuration ...
schema_pipeline = _create_pipeline(
    pipeline_name=SCHEMA_PIPELINE_NAME,
    pipeline_root=SCHEMA_PIPELINE_ROOT,
    data_root=DATA_DIR_SELECT,
    schema_path=SCHEMA_FOLDER,
    metadata_path=SCHEMA_METADATA_PATH)

# ... and execute it with the local DAG runner
tfx.orchestration.LocalDagRunner().run(schema_pipeline)

INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Using deployment config:
 executor_specs {
  key: "CsvExampleGen"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "tfx.components.example_gen.csv_example_gen.executor.Executor"
      }
    }
  }
}
executor_specs {
  key: "ExampleValidator"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.example_validator.executor.Executor"
    }
  }
}
executor_specs {
  key: "ImportSchemaGen"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.schema_gen.import_schema_gen.executor.Executor"
    }
  }
}
executor_specs {
  key: "StatisticsGen"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "tfx.components.statistics_gen.executor.Executor"
      }
    }
  }
}
custom_driver_specs {
  key: "CsvExampleGen"
  value {
    python_class_execu

INFO:absl:Processing input csv data /content/drive/My Drive/Stroke Prediction ML System/data/select/* to TFExample.
INFO:absl:Examples generated.
INFO:absl:Value type <class 'NoneType'> of key version in exec_properties is not supported, going to drop it
INFO:absl:Value type <class 'list'> of key _beam_pipeline_args in exec_properties is not supported, going to drop it
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 43 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'examples': [Artifact(artifact: uri: "/content/drive/My Drive/Stroke Prediction ML System/pipeline/stroke-mlops-schema/CsvExampleGen/examples/43"
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:single_split,num_files:1,total_bytes:286686,xor_checksum:1659020936,sum_checksum:1659020936"
  }
}
custom_properties {
  key: "name"
  value {
    string_value: "stroke-mlops-schema:2022-07-28T15:19:36.68

In [9]:
def get_latest_artifacts(metadata, pipeline_name, component_id):
  """
  Output artifacts of the latest run of the component

  Args:
    metadata: open metadata handle exposing the MLMD store as `.store`
    pipeline_name: name of the pipeline the component belongs to
    component_id: id of the component (node) inside the pipeline

  Returns:
    The result of `execution_lib.get_artifacts_dict` for the OUTPUT events of
    the most recently updated execution of the component.

  Raises:
    ValueError: if no 'node' context exists for the component (e.g. the
      pipeline has never been run against this metadata store, or a name is
      misspelled), or if that context has no executions.
  """

  # Node contexts are looked up by the name '<pipeline_name>.<component_id>'
  node_context_name = f'{pipeline_name}.{component_id}'
  context = metadata.store.get_context_by_type_and_name('node', node_context_name)

  # The store signals "not found" with None; fail with a readable message
  # instead of an AttributeError on `context.id`.
  if context is None:
    raise ValueError(
        f"No 'node' context named '{node_context_name}' found in the metadata store")

  executions = metadata.store.get_executions_by_context(context.id)

  # max() on an empty sequence raises an unhelpful ValueError; say what is missing.
  if not executions:
    raise ValueError(
        f"Node context '{node_context_name}' has no executions in the metadata store")

  # "Latest" means the execution with the most recent update timestamp
  latest_execution = max(executions, key=lambda e: e.last_update_time_since_epoch)

  return execution_lib.get_artifacts_dict(metadata, latest_execution.id, [metadata_store_pb2.Event.OUTPUT])

def visualize_artifacts(artifacts):
  """
  Displays every artifact that has a visualization registered for its type.

  Artifacts whose type name has no registered visualization are skipped.
  """

  for item in artifacts:
    renderer = visualizations.get_registry().get_visualization(item.type_name)

    # Nothing registered for this artifact type: move on to the next one
    if not renderer:
      continue

    renderer.display(item)

In [10]:
# Connect to the same SQLite metadata store the pipeline wrote to
metadata_connection_config = tfx.orchestration.metadata.sqlite_metadata_connection_config(SCHEMA_METADATA_PATH)

with Metadata(metadata_connection_config) as metadata_handler:

  # Latest outputs of each component, fetched in pipeline order
  stat_gen_output, schema_gen_output, ev_output = (
      get_latest_artifacts(metadata_handler, SCHEMA_PIPELINE_NAME, component_id)
      for component_id in ('StatisticsGen', 'ImportSchemaGen', 'ExampleValidator'))

  # Pick the artifact list of interest out of each component's outputs
  stats_artifacts = stat_gen_output[standard_component_specs.STATISTICS_KEY]
  schema_artifacts = schema_gen_output[standard_component_specs.SCHEMA_KEY]
  anomalies_artifacts = ev_output[standard_component_specs.ANOMALIES_KEY]

INFO:absl:MetadataStore with DB connection initialized


In [11]:
# Render the statistics artifacts from the latest StatisticsGen execution
visualize_artifacts(stats_artifacts)

In [12]:
# Render the schema artifacts from the latest ImportSchemaGen execution
visualize_artifacts(schema_artifacts)

Unnamed: 0_level_0,Type,Presence,Valency,Domain
Feature name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
'Residence_type',STRING,required,,'Residence_type'
'age',FLOAT,required,,min: 0.000000; max: 100.000000
'avg_glucose_level',FLOAT,required,,min: 25.000000; max: 300.000000
'bmi',FLOAT,required,,min: 0.000000; max: 200.000000
'ever_married',STRING,required,,'ever_married'
'gender',STRING,required,,'gender'
'heart_disease',INT,required,,min: 0; max: 1
'hypertension',INT,required,,min: 0; max: 1
'smoking_status',STRING,required,,'smoking_status'
'stroke',INT,required,,min: 0; max: 1


Unnamed: 0_level_0,Values
Domain,Unnamed: 1_level_1
'Residence_type',"'Rural', 'Urban'"
'ever_married',"'No', 'Yes'"
'gender',"'Female', 'Male', 'Other'"
'smoking_status',"'Unknown', 'formerly smoked', 'never smoked', 'smokes'"
'work_type',"'Govt_job', 'Never_worked', 'Private', 'Self-employed', 'children'"


In [13]:
# Render the anomalies artifacts from the latest ExampleValidator execution
visualize_artifacts(anomalies_artifacts)