In [2]:
import os

import tensorflow_model_analysis as tfma
import ml_metadata as mlmd
from tfx import v1 as tfx
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
import pipeline_settings as settings
import pandas as pd

# Report the library versions this notebook was run against.
version_reports = (
    ('TFX version: {}', tfx),
    ('MLMD version: {}', mlmd),
)
for version_template, versioned_module in version_reports:
  print(version_template.format(versioned_module.__version__))


TFX version: 1.12.0
MLMD version: 1.12.0


In [3]:
# Interactive context configured with the pipeline name/root from the settings
# module and a SQLite metadata connection pointing at settings.METADATA_PATH.
interactive_context = InteractiveContext(
    pipeline_name=settings.PIPELINE_NAME,
    pipeline_root=settings.PIPELINE_ROOT,
    metadata_connection_config=tfx.orchestration.metadata.sqlite_metadata_connection_config(
        settings.METADATA_PATH)
)

# Open an MLMD store on the same connection config the context uses.
connection_config = interactive_context.metadata_connection_config
store = mlmd.MetadataStore(connection_config)

# Directory containing the metadata database; display_artifacts() shortens
# artifact URIs relative to it (assumes artifacts live under this directory).
# It is derived from the path itself instead of splitting on a hard-coded
# 'metadata.sqlite' file name, so it still works when settings.METADATA_PATH
# uses a different file name. Joining with '' keeps the trailing separator
# the old split-based value had (and leaves an empty dirname empty).
base_dir = os.path.join(
    os.path.dirname(connection_config.sqlite.filename_uri), '')

In [4]:
def display_types(types):
  """Tabulate artifact or execution types.

  Args:
    types: iterable of objects exposing `id` and `name` attributes.

  Returns:
    A pandas DataFrame with one row per type and the columns 'id' and 'name'.
  """
  ids, names = [], []
  for entry in types:
    ids.append(entry.id)
    names.append(entry.name)
  return pd.DataFrame(data={'id': ids, 'name': names})


def display_artifacts(store, artifacts):
  """Render artifacts as a DataFrame of id, type name and shortened URI.

  Reads the notebook-level `base_dir`: a URI that starts with it is shown as
  './<remainder>'; any other URI is shown unchanged.

  Args:
    store: metadata store; only `get_artifact_types_by_id` is called on it.
    artifacts: iterable of artifacts exposing `id`, `type_id` and `uri`.

  Returns:
    A pandas DataFrame with the columns 'artifact id', 'type' and 'uri'.
  """
  table = {'artifact id': [], 'type': [], 'uri': []}
  # Several artifacts usually share a type, so look each type_id up only once.
  type_names = {}
  for a in artifacts:
    if a.type_id not in type_names:
      artifact_type = store.get_artifact_types_by_id([a.type_id])[0]
      type_names[a.type_id] = artifact_type.name
    uri = a.uri
    # Abbreviate only a leading base_dir. str.replace() would also rewrite
    # matches in the middle of the URI, and with an empty base_dir it would
    # insert './' between every character.
    if base_dir and uri.startswith(base_dir):
      uri = './' + uri[len(base_dir):]
    table['artifact id'].append(a.id)
    table['type'].append(type_names[a.type_id])
    table['uri'].append(uri)
  return pd.DataFrame(data=table)


def _property_value(value):
  """Return the payload of a metadata `Value`, whichever oneof member is set."""
  field = value.WhichOneof('value')
  # With nothing set, keep the earlier behaviour of reporting int_value.
  return getattr(value, field) if field else value.int_value


def display_properties(store, node):
  """Render the properties of an artifact or execution as a DataFrame.

  Standard properties are listed first, followed by custom properties. The
  value shown is the member of the `Value` oneof that is actually set, so
  e.g. double-valued properties are no longer reported as a zero int.

  Args:
    store: unused; kept so the helper has the same call shape as its siblings.
    node: object exposing `properties` and `custom_properties` mappings from
      property name to a metadata `Value` message.

  Returns:
    A pandas DataFrame with the columns 'property' and 'value'.
  """
  table = {'property': [], 'value': []}
  for props in (node.properties, node.custom_properties):
    for k, v in props.items():
      table['property'].append(k)
      table['value'].append(_property_value(v))
  return pd.DataFrame(data=table)


def get_one_hop_parent_artifacts(store, artifacts):
  """Return the artifacts one hop upstream of `artifacts`.

  First collects the executions that have an OUTPUT event for any of the given
  artifacts, then collects every artifact those executions have an INPUT event
  for.

  Args:
    store: metadata store used for the event and artifact lookups.
    artifacts: iterable of artifacts exposing an `id` attribute.

  Returns:
    A list with whatever `store.get_artifacts_by_id` yields for the collected
    input artifact ids.
  """
  producer_ids = set()
  for event in store.get_events_by_artifact_ids([a.id for a in artifacts]):
    if event.type == mlmd.proto.Event.OUTPUT:
      producer_ids.add(event.execution_id)

  parent_ids = set()
  for event in store.get_events_by_execution_ids(producer_ids):
    if event.type == mlmd.proto.Event.INPUT:
      parent_ids.add(event.artifact_id)

  return list(store.get_artifacts_by_id(parent_ids))


def find_producer_execution(store, artifact):
  """Return an execution that recorded `artifact` as one of its outputs.

  Args:
    store: metadata store used for the event and execution lookups.
    artifact: artifact exposing an `id` attribute.

  Returns:
    The first execution returned by `store.get_executions_by_id` for the
    executions that have an OUTPUT event for the artifact.

  Raises:
    IndexError: if no producing execution is recorded for the artifact. The
      exception type matches what the bare `[0]` lookup raised before, but now
      carries a message naming the artifact.
  """
  producer_ids = {
      event.execution_id
      for event in store.get_events_by_artifact_ids([artifact.id])
      if event.type == mlmd.proto.Event.OUTPUT
  }
  # Skip the store round-trip when there is nothing to look up.
  executions = store.get_executions_by_id(producer_ids) if producer_ids else []
  if not executions:
    raise IndexError(
        f'No producer execution recorded for artifact id {artifact.id}')
  return executions[0]


In [5]:
# List every artifact type registered in the metadata store.
display_types(store.get_artifact_types())

Unnamed: 0,id,name
0,15,Examples
1,17,ExampleStatistics
2,19,Schema


In [6]:
# Fetch all artifacts of type "ExampleStatistics" and tabulate them.
example_statistics_set = store.get_artifacts_by_type("ExampleStatistics")
display_artifacts(store, example_statistics_set)


Unnamed: 0,artifact id,type,uri
0,3,ExampleStatistics,/root/tfx_data/minst_pipeline_root/StatisticsG...
1,6,ExampleStatistics,/root/tfx_data/minst_pipeline_root/StatisticsG...
2,9,ExampleStatistics,/root/tfx_data/minst_pipeline_root/StatisticsG...
3,12,ExampleStatistics,/root/tfx_data/minst_pipeline_root/StatisticsG...


In [7]:
# Pick the last statistics artifact returned by the store (presumably the most
# recent one — TODO confirm the store's ordering) and show its properties.
example_statistics = example_statistics_set[-1]
display_properties(store, example_statistics)


Unnamed: 0,property,value
0,split_names,"[""train"", ""eval""]"
1,tfx_version,1.12.0
2,state,published
3,is_external,0


In [8]:
# Walk one hop upstream: the input artifacts of the execution(s) that produced
# the selected statistics artifact.
parent_artifacts = get_one_hop_parent_artifacts(store, [example_statistics])
display_artifacts(store, parent_artifacts)


Unnamed: 0,artifact id,type,uri
0,11,Examples,/root/tfx_data/minst_pipeline_root/ImportExamp...


In [9]:
# Take the first parent artifact (the Examples artifact in the table above)
# and show its properties.
exported_dataset = parent_artifacts[0]
display_properties(store, exported_dataset)


Unnamed: 0,property,value
0,split_names,"[""train"", ""eval""]"
1,state,published
2,payload_format,FORMAT_TF_EXAMPLE
3,input_fingerprint,"split:single_split,num_files:1,total_bytes:597..."
4,is_external,0
5,tfx_version,1.12.0
6,span,0
7,file_format,tfrecords_gzip


In [10]:
# List every execution type registered in the metadata store.
display_types(store.get_execution_types())


Unnamed: 0,id,name
0,13,tfx.components.example_gen.import_example_gen....
1,16,tfx.components.statistics_gen.component.Statis...
2,18,tfx.components.schema_gen.component.SchemaGen


In [11]:
# Look up the execution that produced the Examples artifact and show its
# properties.
# NOTE(review): the name `trainer` is misleading — this is the producer of the
# dataset artifact (its properties look like an example-gen run), not a
# Trainer execution. Left unchanged because it is a notebook-level name.
trainer = find_producer_execution(store, exported_dataset)
display_properties(store, trainer)


Unnamed: 0,property,value
0,span,0
1,output_config,"{\n ""split_config"": {\n ""splits"": [\n ..."
2,output_file_format,5
3,output_data_format,6
4,input_fingerprint,"split:single_split,num_files:1,total_bytes:597..."
5,input_config,"{\n ""splits"": [\n {\n ""name"": ""single..."
6,input_base,/tmp/tfx-datazpm6kxqa/v1.0


In [12]:
from tfx.orchestration.experimental.interactive import standard_visualizations
from tfx.orchestration.experimental.interactive import visualizations
from ml_metadata.proto import metadata_store_pb2
# Non-public APIs, just for showcase.
from tfx.orchestration.portable.mlmd import execution_lib

# TODO(b/171447278): Move these functions into the TFX library.


def get_latest_artifacts(metadata, pipeline_name, component_id):
  """Fetch the output artifacts of a component's most recently updated run.

  Args:
    metadata: metadata handle exposing a `store` attribute; also passed on to
      `execution_lib.get_output_artifacts`.
    pipeline_name: pipeline name, first half of the 'node' context name.
    component_id: component id, second half of the 'node' context name.

  Returns:
    Whatever `execution_lib.get_output_artifacts` returns for the execution
    with the greatest `last_update_time_since_epoch` in that context.
  """
  def _updated_at(execution):
    return execution.last_update_time_since_epoch

  node_context = metadata.store.get_context_by_type_and_name(
      'node', f'{pipeline_name}.{component_id}')
  candidates = metadata.store.get_executions_by_context(node_context.id)
  newest = max(candidates, key=_updated_at)
  return execution_lib.get_output_artifacts(metadata, newest.id)


# Non-public APIs, just for showcase.


def visualize_artifacts(artifacts):
  """Display each artifact that has a registered visualization.

  Args:
    artifacts: iterable of artifacts exposing a `type_name` attribute.
      Artifacts whose type has no visualization in the registry are skipped.
  """
  for item in artifacts:
    renderer = visualizations.get_registry().get_visualization(item.type_name)
    if not renderer:
      continue
    renderer.display(item)


# Register TFX's standard visualizations (presumably into the registry that
# visualize_artifacts() consults — TODO confirm).
standard_visualizations.register_standard_visualizations()


In [23]:
# Non-public APIs, just for showcase.
from tfx.orchestration.metadata import Metadata
from tfx.types import standard_component_specs
import tensorflow_data_validation as tfdv

# Last ExampleStatistics artifact returned by the store (presumably the most
# recent one — TODO confirm the store's ordering).
latest_statistics = store.get_artifacts_by_type("ExampleStatistics")[-1]

# Load the binary statistics file of the 'train' split from that artifact.
stats_uri = latest_statistics.uri + '/Split-train/FeatureStats.pb'
stats = tfdv.load_stats_binary(stats_uri)

# NOTE(review): `lateast_schema` is a typo for `latest_schema`; it is left
# unchanged here because it is a notebook-level name.
lateast_schema = store.get_artifacts_by_type("Schema")[-1]
schema_uri = lateast_schema.uri + '/schema.pbtxt'
schema = tfdv.load_schema_text(schema_uri)

# Validate the statistics against the schema; the returned result is shown as
# the cell output.
tfdv.validate_statistics(stats, schema)




baseline {
  feature {
    name: "image"
    type: BYTES
    presence {
      min_fraction: 1.0
      min_count: 1
    }
    shape {
      dim {
        size: 1
      }
    }
  }
  feature {
    name: "label"
    type: INT
    presence {
      min_fraction: 1.0
      min_count: 1
    }
    shape {
      dim {
        size: 1
      }
    }
  }
}
anomaly_name_format: SERIALIZED_PATH

In [24]:
# Render the 'train' split statistics loaded in the previous cell.
tfdv.visualize_statistics(stats)
