In [1]:
# !pip install -U tfx
# !pip install apache-beam==2.39.0

## Imports and Paths

We start by initializing the paths for our data (training, testing, and serving) and for the outputs of our pipeline and schema.

In [2]:
import tensorflow as tf
from tfx import v1 as tfx

# TFX libraries
import tensorflow_data_validation as tfdv

# For performing feature selection and preprocessing and modeling
from sklearn.model_selection import train_test_split

# For feature visualization
import matplotlib.pyplot as plt 
import seaborn as sns

# Utilities
from tensorflow_metadata.proto.v0 import schema_pb2
from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration.portable.mlmd import execution_lib
from tfx.orchestration.experimental.interactive import visualizations, standard_visualizations
from google.protobuf.json_format import MessageToDict
from  tfx.proto import example_gen_pb2
from tfx.types import standard_artifacts, standard_component_specs
from tfx.orchestration.metadata import Metadata
from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils
import os
import pprint
import tempfile
import pandas as pd
import numpy as np
from io import StringIO
from absl import logging
import shutil

# Only let TensorFlow's own logger emit messages at ERROR level or above
tf.get_logger().setLevel('ERROR')
# absl logging (used for the INFO lines printed by the pipeline run) stays at INFO
logging.set_verbosity(logging.INFO) 

# For formatting print statements
pp = pprint.PrettyPrinter()

# Register the standard TFX artifact visualizations; `visualize_artifacts` below
# looks them up in the visualization registry by artifact type name
standard_visualizations.register_standard_visualizations()

# Display versions of TF and TFX related packages
print('TensorFlow version: {}'.format(tf.__version__))
print('TFX version: {}'.format(tfx.__version__))
print('TensorFlow Data Validation version: {}'.format(tfdv.__version__))

TensorFlow version: 2.9.1
TFX version: 1.9.0
TensorFlow Data Validation version: 1.9.0


In [3]:
# Colab-only step: mount Google Drive so the project paths under
# /content/drive (see DATA_DIR below) are available
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# Project root on the mounted Google Drive; every other path is derived from it
DATA_DIR = '/content/drive/My Drive/Stroke Prediction ML System'
# Folder holding the raw CSV and the `select` subfolder
DATA_ROOT = f'{DATA_DIR}/data'

# `select` receives the cleaned dataset and is the folder the schema pipeline ingests.
# The per-split folders sit directly under DATA_DIR (not under DATA_ROOT).
DATA_DIR_SELECT = f'{DATA_ROOT}/select'
TRAINING_ROOT = f'{DATA_DIR}/training'
TESTING_ROOT = f'{DATA_DIR}/testing'
SERVING_ROOT = f'{DATA_DIR}/serving'

# CSV files written by the train/test/serving split in the Data Cleaning section.
# NOTE(review): the file names spell "stoke" (not "stroke"); left unchanged because
# other notebooks may read these exact paths - confirm before renaming.
TRAINING_DATA = f'{TRAINING_ROOT}/stoke_prediction_training_dataset.csv'
TESTING_DATA = f'{TESTING_ROOT}/stoke_prediction_testing_dataset.csv'
SERVING_DATA = f'{SERVING_ROOT}/stoke_prediction_serving_dataset.csv'

# We will create two pipelines. One for schema generation and one for training.
# Only the schema pipeline is run in this notebook.
SCHEMA_PIPELINE_NAME = 'stroke-mlops-schema'
PIPELINE_NAME = 'stroke-mlops'

# Output directory to store artifacts generated from the pipeline.
SCHEMA_PIPELINE_ROOT = os.path.join(DATA_DIR, 'pipeline', SCHEMA_PIPELINE_NAME)
PIPELINE_ROOT = os.path.join(DATA_DIR, 'pipeline', PIPELINE_NAME)

# Path to a SQLite DB file to use as an MLMD storage.
SCHEMA_METADATA_PATH = os.path.join(DATA_DIR, 'metadata', SCHEMA_PIPELINE_NAME, 'metadata.db')
METADATA_PATH = os.path.join(DATA_DIR, 'metadata', PIPELINE_NAME, 'metadata.db')

# Seed passed to both train_test_split calls so the splits are reproducible
RANDOM_SEED = 0

# Data Cleaning

This section covers the initial, messy data cleaning needed to get to a first curated schema. It can involve handling nulls or other relatively basic tasks. Nulls may also be manageable during schema curation, but we haven't found a good example of that yet.

Then we drop unneeded columns, save the datasets out to their respective folders, and use them to create the first schema.

In [5]:
# Load the raw stroke dataset from DATA_ROOT and preview the first rows
stroke_dataset = pd.read_csv(f'{DATA_ROOT}/healthcare-dataset-stroke-data.csv')

stroke_dataset.head()

Unnamed: 0,id,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke
0,9046,Male,67.0,0,1,Yes,Private,Urban,228.69,36.6,formerly smoked,1
1,51676,Female,61.0,0,0,Yes,Self-employed,Rural,202.21,,never smoked,1
2,31112,Male,80.0,0,1,Yes,Private,Rural,105.92,32.5,never smoked,1
3,60182,Female,49.0,0,0,Yes,Private,Urban,171.23,34.4,smokes,1
4,1665,Female,79.0,1,0,Yes,Self-employed,Rural,174.12,24.0,never smoked,1


In [6]:
stroke_dataset.isna().sum()

id                     0
gender                 0
age                    0
hypertension           0
heart_disease          0
ever_married           0
work_type              0
Residence_type         0
avg_glucose_level      0
bmi                  201
smoking_status         0
stroke                 0
dtype: int64

In [7]:
# Drop every row that has a missing value (per the null counts above, only `bmi`
# has nulls). This rebinds `stroke_dataset`, so the raw frame is no longer
# available after this cell.
stroke_dataset = stroke_dataset.dropna()

stroke_dataset.shape

(4909, 12)

In [8]:
# Hold out 10% of the rows to act as serving data, then split the remainder
# 80/20 into train and test. Both splits use RANDOM_SEED so they are reproducible.
train_test, holdout = train_test_split(stroke_dataset, test_size=.1, random_state=RANDOM_SEED)
train, test = train_test_split(train_test, test_size=.2, random_state=RANDOM_SEED)

# Sanity check: every row must land in exactly one of the three splits
assert len(train) + len(test) + len(holdout) == len(stroke_dataset)

# Drop the `id` column from everything that gets written out. The serving set
# additionally drops the `stroke` label.
stroke_dataset = stroke_dataset.drop(columns=['id'])
train = train.drop(columns=['id'])
test = test.drop(columns=['id'])
serving = holdout.drop(columns=['stroke', 'id'])

# Make sure the output folders exist before writing: `to_csv` fails when the
# parent directory is missing (e.g. first run against a fresh Drive folder).
for output_dir in (DATA_DIR_SELECT, TRAINING_ROOT, TESTING_ROOT, SERVING_ROOT):
  os.makedirs(output_dir, exist_ok=True)

# The cleaned full dataset goes to the folder the schema pipeline ingests;
# each split goes to its own folder.
stroke_dataset.to_csv(f'{DATA_DIR_SELECT}/healthcare-dataset-stroke-data-cleaned.csv', index=False)
train.to_csv(TRAINING_DATA, index=False)
test.to_csv(TESTING_DATA, index=False)
serving.to_csv(SERVING_DATA, index=False)

In [9]:
train_test.head()

Unnamed: 0,id,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke
3688,38575,Male,58.0,1,0,Yes,Self-employed,Rural,209.15,52.9,formerly smoked,0
1605,39286,Female,35.0,0,0,Yes,Self-employed,Rural,151.25,28.4,Unknown,0
4514,46670,Female,75.0,1,0,Yes,Self-employed,Rural,197.06,26.1,never smoked,0
2857,4309,Female,23.0,0,0,Yes,Private,Rural,102.88,38.9,Unknown,0
3725,24202,Male,63.0,0,0,Yes,Private,Rural,78.23,34.8,never smoked,0


In [10]:
train.shape, test.shape, serving.shape

((3534, 11), (884, 11), (491, 10))

In [11]:
train_test.isna().sum()

id                   0
gender               0
age                  0
hypertension         0
heart_disease        0
ever_married         0
work_type            0
Residence_type       0
avg_glucose_level    0
bmi                  0
smoking_status       0
stroke               0
dtype: int64

In [12]:
serving.isna().sum()

gender               0
age                  0
hypertension         0
heart_disease        0
ever_married         0
work_type            0
Residence_type       0
avg_glucose_level    0
bmi                  0
smoking_status       0
dtype: int64

# Curate Schema

The pipeline uses CsvExampleGen to ingest data into TFX pipelines. ExampleGen components can take in data from external sources such as CSV, TFRecord, Avro, Parquet and BigQuery; CsvExampleGen handles the CSV case.

StatisticsGen generates feature statistics over both training and serving data, which can be used by other pipeline components. It takes in the examples read in by the ExampleGen component.

SchemaGen will automatically generate a schema by inferring types, categories, and ranges from the training data. It can specify data types for feature values, whether a feature has to be present in all examples, allowed value ranges, and other properties. It takes in the statistics generated on the examples ingested.

In [13]:
def create_schema_pipeline(pipeline_name: str, pipeline_root: str, data_root: str, metadata_path: str) -> tfx.dsl.Pipeline:
  """
  Build the three-component pipeline used for the initial data validation.

  Args:
    pipeline_name: name given to the schema pipeline
    pipeline_root: directory under which the pipeline writes its artifacts
    data_root: directory containing the CSV data to ingest
    metadata_path: path of the SQLite file used as the MLMD store

  Returns:
    A tfx.dsl.Pipeline chaining CsvExampleGen -> StatisticsGen -> SchemaGen
  """

  # Ingestion: read the CSV files found under `data_root` into the pipeline
  ingest = tfx.components.CsvExampleGen(input_base=data_root)

  # Statistics: computed over the examples emitted by the ingestion step
  stats = tfx.components.StatisticsGen(examples=ingest.outputs['examples'])

  # Schema: inferred from the statistics, with feature shapes inferred as well
  schema = tfx.components.SchemaGen(
      statistics=stats.outputs['statistics'],
      infer_feature_shape=True)

  # MLMD is backed by a local SQLite database at `metadata_path`
  mlmd_config = tfx.orchestration.metadata.sqlite_metadata_connection_config(metadata_path)

  return tfx.dsl.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=mlmd_config,
      components=[ingest, stats, schema])

We define a curate_schema function that reads in the initial schema artifact and sets appropriate data types, ranges, environment configurations, and more. Schema curation can also include drift and skew thresholds for certain features which can be added here or individually after.

In [14]:
def curate_schema(schema_artifacts):
  """
  Load the schema produced by SchemaGen and tighten it to our specifications.

  Args:
    schema_artifacts: sequence of schema artifacts; only the first one is used,
      and its `uri` must point at a folder containing `schema.pbtxt`

  Returns:
    The loaded schema with value ranges, categorical 0/1 domains, and the
    TRAINING / SERVING environments applied
  """

  # Load schema.pbtxt from the first artifact's output folder
  inferred_path = os.path.join(schema_artifacts[0].uri, 'schema.pbtxt')
  schema = tfdv.load_schema_text(inferred_path)

  # Allowed value ranges for the continuous features: (feature, min, max)
  float_ranges = (
      ('bmi', 0.0, 200.0),
      ('age', 0, 100),
      ('avg_glucose_level', 25.0, 300.0),
  )
  for feature, low, high in float_ranges:
    tfdv.set_domain(schema, feature, schema_pb2.FloatDomain(name=feature, min=low, max=high))

  # Binary indicator features: restricted to 0/1 and flagged as categorical
  for feature in ('heart_disease', 'hypertension', 'stroke'):
    tfdv.set_domain(schema, feature, schema_pb2.IntDomain(name=feature, min=0, max=1, is_categorical=True))

  # Declare training and serving environments so their differences can be expressed
  schema.default_environment.extend(['TRAINING', 'SERVING'])

  # The `stroke` label is not expected in the SERVING environment
  tfdv.get_feature(schema, 'stroke').not_in_environment.append('SERVING')

  return schema

def save_schema(schema, folder, subfolder):
  """
  Save schema to `<DATA_DIR>/<folder>/<subfolder>/schema.pbtxt`.

  Args:
    schema: schema object to write with tfdv.write_schema_text
    folder: first path component under the module-level DATA_DIR
    subfolder: second path component under DATA_DIR

  Returns:
    None
  """

  schema_dir = f'{DATA_DIR}/{folder}/{subfolder}'

  # The target folder may not exist yet (e.g. on a first run); create it so
  # the write below cannot fail on a missing directory.
  os.makedirs(schema_dir, exist_ok=True)

  tfdv.write_schema_text(schema, os.path.join(schema_dir, 'schema.pbtxt'))

Running the pipeline just involves passing in the needed paths; we can add the schema curation to the pipeline itself, or we can run it afterwards to keep the steps separate and make debugging easier.

In [15]:
# Build the schema pipeline over the `select` data folder and execute it locally.
# Artifacts are written under SCHEMA_PIPELINE_ROOT and tracked in the SQLite
# MLMD store at SCHEMA_METADATA_PATH.
tfx.orchestration.LocalDagRunner().run(create_schema_pipeline(pipeline_name=SCHEMA_PIPELINE_NAME,
                                                              pipeline_root=SCHEMA_PIPELINE_ROOT,
                                                              data_root=DATA_DIR_SELECT,
                                                              metadata_path=SCHEMA_METADATA_PATH))

INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Using deployment config:
 executor_specs {
  key: "CsvExampleGen"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "tfx.components.example_gen.csv_example_gen.executor.Executor"
      }
    }
  }
}
executor_specs {
  key: "SchemaGen"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.schema_gen.executor.Executor"
    }
  }
}
executor_specs {
  key: "StatisticsGen"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "tfx.components.statistics_gen.executor.Executor"
      }
    }
  }
}
custom_driver_specs {
  key: "CsvExampleGen"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.example_gen.driver.FileBasedDriver"
    }
  }
}
metadata_connection_config {
  database_connection_config {
    sqlite {
      filename_uri: "/conte

INFO:absl:Processing input csv data /content/drive/My Drive/Stroke Prediction ML System/data/select/* to TFExample.
INFO:absl:Examples generated.
INFO:absl:Value type <class 'NoneType'> of key version in exec_properties is not supported, going to drop it
INFO:absl:Value type <class 'list'> of key _beam_pipeline_args in exec_properties is not supported, going to drop it
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 40 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'examples': [Artifact(artifact: uri: "/content/drive/My Drive/Stroke Prediction ML System/pipeline/stroke-mlops-schema/CsvExampleGen/examples/40"
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:single_split,num_files:1,total_bytes:286686,xor_checksum:1659020926,sum_checksum:1659020926"
  }
}
custom_properties {
  key: "name"
  value {
    string_value: "stroke-mlops-schema:2022-07-28T15:08:53.53

In [16]:
def get_latest_artifacts(metadata, pipeline_name, component_id):
  """
  Output artifacts of the latest run of the component.

  Args:
    metadata: open metadata handler exposing a `store` attribute
    pipeline_name: name of the pipeline the component belongs to
    component_id: id of the component (e.g. 'SchemaGen')

  Returns:
    The result of execution_lib.get_artifacts_dict for the OUTPUT events of
    the most recently updated execution of the component.

  Raises:
    ValueError: if no 'node' context named `<pipeline_name>.<component_id>`
      exists in the store, or if that context has no executions.
  """

  # Node contexts are looked up by the name `<pipeline_name>.<component_id>`
  node_name = f'{pipeline_name}.{component_id}'
  context = metadata.store.get_context_by_type_and_name('node', node_name)
  if context is None:
    # Without this guard a typo in the pipeline/component name surfaces as an
    # opaque AttributeError on `context.id`.
    raise ValueError(f"No 'node' context named '{node_name}' found in the metadata store.")

  executions = metadata.store.get_executions_by_context(context.id)
  if not executions:
    raise ValueError(f"Node '{node_name}' has no executions in the metadata store.")

  # "Latest" means the execution with the most recent update timestamp
  latest_execution = max(executions, key=lambda e: e.last_update_time_since_epoch)

  return execution_lib.get_artifacts_dict(metadata, latest_execution.id, [metadata_store_pb2.Event.OUTPUT])

def visualize_artifacts(artifacts):
  """
  Render each artifact with the visualization registered for its type.

  Artifacts whose type has no registered visualization are skipped.
  """

  for item in artifacts:
    # Look up the renderer registered for this artifact's type name
    renderer = visualizations.get_registry().get_visualization(item.type_name)
    if not renderer:
      continue
    renderer.display(item)

In [17]:
# Connect to the SQLite MLMD store written by the schema pipeline run
metadata_connection_config = tfx.orchestration.metadata.sqlite_metadata_connection_config(SCHEMA_METADATA_PATH)

with Metadata(metadata_connection_config) as metadata_handler:
  # Find output artifacts from MLMD: the statistics from the latest
  # StatisticsGen execution...
  stat_gen_output = get_latest_artifacts(metadata_handler, SCHEMA_PIPELINE_NAME, 'StatisticsGen')
  stats_artifacts = stat_gen_output[standard_component_specs.STATISTICS_KEY]

  # ...and the schema from the latest SchemaGen execution
  schema_gen_output = get_latest_artifacts(metadata_handler, SCHEMA_PIPELINE_NAME, 'SchemaGen')
  schema_artifacts = schema_gen_output[standard_component_specs.SCHEMA_KEY]

INFO:absl:MetadataStore with DB connection initialized


In [18]:
visualize_artifacts(stats_artifacts)

In [19]:
visualize_artifacts(schema_artifacts)

Unnamed: 0_level_0,Type,Presence,Valency,Domain
Feature name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
'Residence_type',STRING,required,,'Residence_type'
'age',FLOAT,required,,-
'avg_glucose_level',FLOAT,required,,-
'bmi',FLOAT,required,,-
'ever_married',STRING,required,,'ever_married'
'gender',STRING,required,,'gender'
'heart_disease',INT,required,,-
'hypertension',INT,required,,-
'smoking_status',STRING,required,,'smoking_status'
'stroke',INT,required,,-


Unnamed: 0_level_0,Values
Domain,Unnamed: 1_level_1
'Residence_type',"'Rural', 'Urban'"
'ever_married',"'No', 'Yes'"
'gender',"'Female', 'Male', 'Other'"
'smoking_status',"'Unknown', 'formerly smoked', 'never smoked', 'smokes'"
'work_type',"'Govt_job', 'Never_worked', 'Private', 'Self-employed', 'children'"


In [20]:
# Curate the inferred schema with `curate_schema` and write the result to
# <DATA_DIR>/pipeline/updated_schema/schema.pbtxt
curated_schema = curate_schema(schema_artifacts)
save_schema(curated_schema, 'pipeline','updated_schema')



In [21]:
# Reload the curated schema from the file written in the previous cell, so the
# displayed schema reflects what was actually persisted
UPDATED_SCHEMA_FILE = f'{DATA_DIR}/pipeline/updated_schema/schema.pbtxt'
new_schema = tfdv.load_schema_text(UPDATED_SCHEMA_FILE)

# Display the schema. Check that the Domain column still contains the ranges
tfdv.display_schema(schema=new_schema)

Unnamed: 0_level_0,Type,Presence,Valency,Domain
Feature name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
'Residence_type',STRING,required,,'Residence_type'
'age',FLOAT,required,,min: 0.000000; max: 100.000000
'avg_glucose_level',FLOAT,required,,min: 25.000000; max: 300.000000
'bmi',FLOAT,required,,min: 0.000000; max: 200.000000
'ever_married',STRING,required,,'ever_married'
'gender',STRING,required,,'gender'
'heart_disease',INT,required,,min: 0; max: 1
'hypertension',INT,required,,min: 0; max: 1
'smoking_status',STRING,required,,'smoking_status'
'stroke',INT,required,,min: 0; max: 1


Unnamed: 0_level_0,Values
Domain,Unnamed: 1_level_1
'Residence_type',"'Rural', 'Urban'"
'ever_married',"'No', 'Yes'"
'gender',"'Female', 'Male', 'Other'"
'smoking_status',"'Unknown', 'formerly smoked', 'never smoked', 'smokes'"
'work_type',"'Govt_job', 'Never_worked', 'Private', 'Self-employed', 'children'"


After curating the schema to have the proper configurations, we can evaluate the training and serving environments to make sure no anomalies are detected before saving the final version to file.

In [22]:
# Compute statistics for the serving CSV, using the curated schema to decide
# feature types, then validate them in the SERVING environment
stats_options = tfdv.StatsOptions(schema=new_schema, infer_type_from_schema=True)
serving_stats = tfdv.generate_statistics_from_csv(SERVING_DATA, stats_options=stats_options)
serving_anomalies = tfdv.validate_statistics(serving_stats, new_schema, environment='SERVING')

tfdv.display_anomalies(serving_anomalies)



In [23]:
# Compute statistics for the training CSV and validate them in the TRAINING
# environment. Reuses `stats_options` defined in the serving-validation cell,
# so that cell must have been run first.
training_stats = tfdv.generate_statistics_from_csv(TRAINING_DATA, stats_options=stats_options)
training_anomalies = tfdv.validate_statistics(training_stats, new_schema, environment='TRAINING')

tfdv.display_anomalies(training_anomalies)



In [24]:
# Add skew comparator for 'bmi' feature (L-infinity norm threshold between
# training and serving distributions)
tfdv.get_feature(new_schema, 'bmi').skew_comparator.infinity_norm.threshold = 0.01

# Add drift comparator for 'heart_disease' feature
# NOTE(review): TFDV appears to evaluate drift comparators against
# `previous_statistics`, which is not passed to validate_statistics below, so
# this check likely only exercises the `bmi` skew threshold; the drift threshold
# is just stored in the schema for later use - confirm against the TFDV docs.
tfdv.get_feature(new_schema, 'heart_disease').drift_comparator.infinity_norm.threshold = 0.001

# Check for anomalies after adding skew and drift thresholds
# (training statistics compared against the serving statistics)
skew_anomalies = tfdv.validate_statistics(training_stats, new_schema, serving_statistics=serving_stats)

tfdv.display_anomalies(skew_anomalies)

In [25]:
# Freeze final schema (including the skew/drift thresholds set above) to
# <DATA_DIR>/schema/schema_output/schema.pbtxt
save_schema(new_schema, 'schema', 'schema_output')

This file is meant to be paired with the modeling POC work; think of this as the data POC work. We want to get a sense of what kind of data is going to be coming in when we deploy our best model to production so we can save out an "ideal" schema that represents our data model(s).

Then, as we operate our ML pipeline in production, we can load our curated schema, validate new data against it, and assess whether we need to retrain the model because of changes to the data or unforeseen issues that call for new data prep methods.