In [None]:
# !pip install -U tfx
# !pip install apache-beam==2.39.0

In [None]:
from absl import logging
from google.protobuf.json_format import MessageToDict
from io import StringIO
from ml_metadata.proto import metadata_store_pb2
import numpy as np
import os
import pandas as pd
import pprint
import shutil
import tempfile
import tensorflow as tf
import tensorflow_data_validation as tfdv
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils
from tfx import v1 as tfx
from tfx.orchestration.portable.mlmd import execution_lib
from tfx.orchestration.experimental.interactive import visualizations, standard_visualizations
from tfx.orchestration.metadata import Metadata
from  tfx.proto import example_gen_pb2
from tfx.types import standard_artifacts, standard_component_specs

# Only let TensorFlow's own logger emit errors
tf.get_logger().setLevel('ERROR')
# Default verbosity for absl logging
logging.set_verbosity(logging.INFO)

# Shared pretty-printer for formatting print statements
pp = pprint.PrettyPrinter()

# Register the standard TFX visualizations so they can be looked up later
standard_visualizations.register_standard_visualizations()

# Report the versions of TF and the TFX-related packages in use
print(f'TensorFlow version: {tf.__version__}')
print(f'TFX version: {tfx.__version__}')
print(f'TensorFlow Data Validation version: {tfdv.__version__}')

TensorFlow version: 2.8.2
TFX version: 1.8.0
TensorFlow Data Validation version: 1.8.0


In [None]:
# Colab-only: mount Google Drive at /content/drive so the project paths
# configured below (which all live under that mount point) are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Project folder on the mounted Drive; every other path is derived from it
DATA_DIR = '/content/drive/My Drive/Stroke Prediction ML System'
DATA_ROOT = f'{DATA_DIR}/data'

# Set the paths to the reduced dataset
DATA_DIR_SELECT = f'{DATA_ROOT}/select'
TRAINING_ROOT = f'{DATA_DIR}/training'
TESTING_ROOT = f'{DATA_DIR}/testing'
SERVING_ROOT = f'{DATA_DIR}/serving'

# Each dataset file is resolved under its own split directory. The testing
# file previously pointed into the training directory, which did not match
# TESTING_ROOT.
# NOTE(review): the file names are spelled "stoke_..."; left unchanged because
# they must match the files on disk -- confirm before renaming.
TRAINING_DATA = f'{TRAINING_ROOT}/stoke_prediction_training_dataset.csv'
TESTING_DATA = f'{TESTING_ROOT}/stoke_prediction_testing_dataset.csv'
SERVING_DATA = f'{SERVING_ROOT}/stoke_prediction_serving_dataset.csv'

# We will create a pipeline for schema generation
SCHEMA_PIPELINE_NAME = 'stroke-mlops-schema'

# Output directory to store artifacts generated from the pipeline
SCHEMA_PIPELINE_ROOT = os.path.join(DATA_DIR, 'pipeline', SCHEMA_PIPELINE_NAME)

# Path to curated schema file
SCHEMA_FOLDER = os.path.join(DATA_DIR, 'schema/schema_output')

# Path to a SQLite DB file to use as an MLMD storage
SCHEMA_METADATA_PATH = os.path.join(DATA_DIR, 'metadata', SCHEMA_PIPELINE_NAME, 'metadata.db')

# Set random seed
RANDOM_SEED = 0

In [None]:
def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,
                     schema_path: str, metadata_path: str) -> tfx.dsl.Pipeline:
  """
  Builds a TFX pipeline that validates data against an imported schema

  Args:
    pipeline_name: name given to the pipeline
    pipeline_root: directory passed to the pipeline as its root
    data_root: directory handed to CsvExampleGen as its input base
    schema_path: location of the curated schema to import
    metadata_path: path used for the SQLite metadata connection config

  Returns:
    Pipeline made of example_gen, statistics_gen, schema_importer, and
    example_validator, in that order
  """

  # Ingest the CSV data found under data_root
  csv_examples = tfx.components.CsvExampleGen(input_base=data_root)

  # Statistics over the ingested examples, used for visualization and validation
  stats = tfx.components.StatisticsGen(examples=csv_examples.outputs['examples'])

  # Bring the existing schema in as a Schema artifact under a fixed node id
  schema_node = tfx.dsl.Importer(
      source_uri=schema_path,
      artifact_type=tfx.types.standard_artifacts.Schema)
  schema_node = schema_node.with_id('schema_importer')

  # Compare the statistics against the imported schema to detect anomalies
  validator = tfx.components.ExampleValidator(
      statistics=stats.outputs['statistics'],
      schema=schema_node.outputs['result'])

  # Pipeline metadata is stored in a SQLite database at metadata_path
  mlmd_config = tfx.orchestration.metadata.sqlite_metadata_connection_config(
      metadata_path)

  return tfx.dsl.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=mlmd_config,
      components=[csv_examples, stats, schema_node, validator])

In [None]:
# Run the validation pipeline locally, wired to the paths configured above
local_runner = tfx.orchestration.LocalDagRunner()
local_runner.run(
    _create_pipeline(
        pipeline_name=SCHEMA_PIPELINE_NAME,
        pipeline_root=SCHEMA_PIPELINE_ROOT,
        data_root=DATA_DIR_SELECT,
        schema_path=SCHEMA_FOLDER,
        metadata_path=SCHEMA_METADATA_PATH))

INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Using deployment config:
 executor_specs {
  key: "CsvExampleGen"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "tfx.components.example_gen.csv_example_gen.executor.Executor"
      }
    }
  }
}
executor_specs {
  key: "ExampleValidator"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.example_validator.executor.Executor"
    }
  }
}
executor_specs {
  key: "StatisticsGen"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "tfx.components.statistics_gen.executor.Executor"
      }
    }
  }
}
custom_driver_specs {
  key: "CsvExampleGen"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.example_gen.driver.FileBasedDriver"
    }
  }
}
metadata_connection_config {
  database_connection_config {
    sqlite {
      filenam

In [None]:
def get_latest_artifacts(metadata, pipeline_name, component_id):
  """
  Output artifacts of the latest run of the component

  Args:
    metadata: metadata handler exposing a `store` attribute
    pipeline_name: name of the pipeline the component belongs to
    component_id: id of the component (node) within that pipeline

  Returns:
    The value returned by execution_lib.get_artifacts_dict for the most
    recently updated execution of the node, restricted to OUTPUT events

  Raises:
    ValueError: if the node has no context in the metadata store, or the
      context has no executions
  """

  # Node contexts are looked up by the name '<pipeline_name>.<component_id>'
  context_name = f'{pipeline_name}.{component_id}'
  context = metadata.store.get_context_by_type_and_name('node', context_name)

  # The lookup yields None for an unknown node; fail with a clear message
  # instead of an AttributeError on `context.id`
  if context is None:
    raise ValueError(
        f"No 'node' context named '{context_name}' found in the metadata store; "
        'check the pipeline name and component id, and that the pipeline has run.')

  executions = metadata.store.get_executions_by_context(context.id)
  if not executions:
    raise ValueError(f"Node '{context_name}' has no recorded executions.")

  # The latest execution is the one with the greatest last-update timestamp
  latest_execution = max(executions, key=lambda e: e.last_update_time_since_epoch)

  return execution_lib.get_artifacts_dict(metadata, latest_execution.id, [metadata_store_pb2.Event.OUTPUT])

def visualize_artifacts(artifacts):
  """
  Displays each artifact with the visualization registered for its type

  Artifacts whose type has no registered visualization are skipped.
  """

  for item in artifacts:
    renderer = visualizations.get_registry().get_visualization(item.type_name)

    # Nothing registered for this artifact type: move on to the next one
    if not renderer:
      continue

    renderer.display(item)

In [None]:
# Connect to the same SQLite metadata store the pipeline wrote to
metadata_connection_config = tfx.orchestration.metadata.sqlite_metadata_connection_config(SCHEMA_METADATA_PATH)

with Metadata(metadata_connection_config) as metadata_handler:

  # Statistics produced by the StatisticsGen node
  stat_gen_output = get_latest_artifacts(metadata_handler, SCHEMA_PIPELINE_NAME, 'StatisticsGen')
  stats_artifacts = stat_gen_output[standard_component_specs.STATISTICS_KEY]

  # This pipeline has no SchemaGen component: the schema comes in through the
  # Importer node registered with id 'schema_importer', whose output is
  # exposed under the 'result' key (the same key the pipeline wires into
  # ExampleValidator).
  schema_gen_output = get_latest_artifacts(metadata_handler, SCHEMA_PIPELINE_NAME, 'schema_importer')
  schema_artifacts = schema_gen_output['result']

  # Anomalies reported by the ExampleValidator node
  ev_output = get_latest_artifacts(metadata_handler, SCHEMA_PIPELINE_NAME, 'ExampleValidator')
  anomalies_artifacts = ev_output[standard_component_specs.ANOMALIES_KEY]

INFO:absl:MetadataStore with DB connection initialized


In [None]:
# Render the statistics artifacts with their registered visualization
visualize_artifacts(stats_artifacts)

In [None]:
# Render the schema artifacts with their registered visualization
visualize_artifacts(schema_artifacts)

In [None]:
# Render the anomalies artifacts with their registered visualization
visualize_artifacts(anomalies_artifacts)