In [1]:
import os

import ml_metadata as mlmd
import pandas as pd
import tensorflow as tf
from tfx import v1 as tfx
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

import pipeline_settings as settings

# Report the library versions (and visible GPUs) this notebook runs against.
version_report = (
    ('TF version: {}', tf.__version__),
    ('TFX version: {}', tfx.__version__),
    ('MLMD version: {}', mlmd.__version__),
)
for template, version in version_report:
  print(template.format(version))

gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))


TF version: 2.11.0
TFX version: 1.12.0
MLMD version: 1.12.0
Num GPUs Available:  1


In [None]:
# Attach an interactive context to the pipeline's SQLite metadata database so
# the cells below can query what earlier runs recorded.
interactive_context = InteractiveContext(
    pipeline_name=settings.PIPELINE_NAME,
    pipeline_root=settings.PIPELINE_ROOT,
    metadata_connection_config=tfx.orchestration.metadata.sqlite_metadata_connection_config(
        settings.METADATA_PATH)
)

connection_config = interactive_context.metadata_connection_config
store = mlmd.MetadataStore(connection_config)

# Directory part of the metadata database path, including its trailing
# separator ('' when the path has no directory component). The display helpers
# abbreviate this prefix to './'. Stripping the basename works regardless of
# how the database file is named.
# NOTE(review): assumes the artifacts live under the same directory as the
# metadata database -- confirm against settings.PIPELINE_ROOT.
metadata_db_path = connection_config.sqlite.filename_uri
base_dir = metadata_db_path[
    :len(metadata_db_path) - len(os.path.basename(metadata_db_path))]

In [None]:
def display_types(types):
  """Renders artifact or execution types as a two-column dataframe.

  Args:
    types: Iterable of objects exposing `id` and `name` attributes.

  Returns:
    A `pd.DataFrame` with one row per type and the columns `id` and `name`.
  """
  # Collect the (id, name) pairs in a single pass over `types`.
  pairs = [(entry.id, entry.name) for entry in types]
  return pd.DataFrame(data={
      'id': [type_id for type_id, _ in pairs],
      'name': [type_name for _, type_name in pairs],
  })


def display_artifacts(store, artifacts, root_dir=None):
  """Renders artifacts as a dataframe of id, type name and shortened URI.

  Args:
    store: Metadata store used to resolve type ids; it must provide
      `get_artifact_types_by_id`.
    artifacts: Iterable of artifacts exposing `id`, `type_id` and `uri`.
    root_dir: Optional path prefix that is abbreviated to './' in the `uri`
      column. When omitted, the notebook-level `base_dir` is used.

  Returns:
    A `pd.DataFrame` with the columns `artifact id`, `type` and `uri`, one row
    per artifact in input order.
  """
  artifacts = list(artifacts)
  prefix = base_dir if root_dir is None else root_dir

  # Resolve all distinct type ids with a single store query instead of issuing
  # one query per artifact.
  type_ids = list(dict.fromkeys(a.type_id for a in artifacts))
  type_names = {}
  if type_ids:
    type_names = {
        t.id: t.name for t in store.get_artifact_types_by_id(type_ids)}

  table = {'artifact id': [], 'type': [], 'uri': []}
  for a in artifacts:
    table['artifact id'].append(a.id)
    table['type'].append(type_names[a.type_id])
    # `str.replace` with an empty pattern inserts the replacement between
    # every character, so an empty prefix must leave the URI untouched.
    table['uri'].append(a.uri.replace(prefix, './') if prefix else a.uri)
  return pd.DataFrame(data=table)


def display_properties(store, node):
  """Renders the properties and custom properties of an MLMD node.

  Args:
    store: Unused; kept so the helper has the same call shape as the other
      display helpers.
    node: Artifact or execution exposing `properties` and `custom_properties`
      mappings from name to an MLMD `Value` message.

  Returns:
    A `pd.DataFrame` with the columns `property` and `value`; regular
    properties come first, followed by custom properties.
  """

  def _unwrap(value):
    # An MLMD `Value` holds exactly one member of its `value` oneof. Read
    # whichever member is populated rather than assuming "string or int",
    # which would show every double/bool/struct value as 0.
    populated = value.WhichOneof('value')
    return getattr(value, populated) if populated else None

  table = {'property': [], 'value': []}
  for props in (node.properties, node.custom_properties):
    for key, value in props.items():
      table['property'].append(key)
      table['value'].append(_unwrap(value))
  return pd.DataFrame(data=table)


def get_one_hop_parent_artifacts(store, artifacts):
  """Returns the inputs of the executions that produced `artifacts`.

  Args:
    store: Metadata store providing `get_events_by_artifact_ids`,
      `get_events_by_execution_ids` and `get_artifacts_by_id`.
    artifacts: Iterable of artifacts (each exposing `id`) to start from.

  Returns:
    A list with the artifacts the store returns for the collected input ids.
  """
  ids_of_interest = [item.id for item in artifacts]

  # Hop 1: executions that recorded any starting artifact as an OUTPUT.
  producer_ids = {
      ev.execution_id
      for ev in store.get_events_by_artifact_ids(ids_of_interest)
      if ev.type == mlmd.proto.Event.OUTPUT
  }

  # Hop 2: artifacts those executions recorded as an INPUT.
  input_ids = {
      ev.artifact_id
      for ev in store.get_events_by_execution_ids(producer_ids)
      if ev.type == mlmd.proto.Event.INPUT
  }
  return list(store.get_artifacts_by_id(input_ids))


def find_producer_execution(store, artifact):
  """Returns an execution that recorded `artifact` as one of its outputs.

  Args:
    store: Metadata store providing `get_events_by_artifact_ids` and
      `get_executions_by_id`.
    artifact: Artifact (exposing `id`) whose producer is wanted.

  Returns:
    The first execution the store returns for the producing execution ids.
    Indexing fails if the store returns no execution.
  """
  output_events = (
      ev for ev in store.get_events_by_artifact_ids([artifact.id])
      if ev.type == mlmd.proto.Event.OUTPUT)
  producer_ids = {ev.execution_id for ev in output_events}
  producers = store.get_executions_by_id(producer_ids)
  return producers[0]


In [None]:
# Overview of every artifact type registered in the metadata store.
artifact_types = store.get_artifact_types()
display_types(artifact_types)

In [None]:
# Tabulate every artifact of type "ExampleStatistics" known to the store.
example_statistics_set = store.get_artifacts_by_type("ExampleStatistics")
display_artifacts(store=store, artifacts=example_statistics_set)


In [None]:
# Inspect the last ExampleStatistics artifact in the list.
# NOTE(review): presumably the most recent one -- confirm the store's ordering.
example_statistics = example_statistics_set[-1]
display_properties(store=store, node=example_statistics)


In [None]:
# Walk one hop upstream: the inputs of the execution(s) that produced the
# selected statistics artifact.
parent_artifacts = get_one_hop_parent_artifacts(
    store=store, artifacts=[example_statistics])
display_artifacts(store=store, artifacts=parent_artifacts)


In [None]:
# Inspect the first upstream artifact of the statistics.
# NOTE(review): presumably an Examples artifact -- confirm against the table
# rendered by the previous cell.
exported_dataset = parent_artifacts[0]
display_properties(store=store, node=exported_dataset)


In [None]:
# Overview of every execution type registered in the metadata store.
execution_types = store.get_execution_types()
display_types(execution_types)


In [None]:
# Look up the execution that produced `exported_dataset` and show its
# properties.
# NOTE(review): despite the name, `trainer` holds whichever execution emitted
# that artifact, which is presumably not a Trainer run -- consider renaming.
trainer = find_producer_execution(store=store, artifact=exported_dataset)
display_properties(store=store, node=trainer)


In [None]:
from tfx.orchestration.experimental.interactive import standard_visualizations
from tfx.orchestration.experimental.interactive import visualizations
from ml_metadata.proto import metadata_store_pb2
# Non-public APIs, just for showcase.
from tfx.orchestration.portable.mlmd import execution_lib

# TODO(b/171447278): Move these functions into the TFX library.


def get_latest_artifacts(metadata, pipeline_name, component_id):
  """Returns the output artifacts of a component's latest execution.

  "Latest" is the execution with the greatest
  `last_update_time_since_epoch` in the component's node context.

  Args:
    metadata: Metadata handle exposing a `store` attribute; it is also passed
      through to `execution_lib.get_output_artifacts`.
    pipeline_name: Name of the pipeline the component belongs to.
    component_id: Id of the component within that pipeline.

  Returns:
    Whatever `execution_lib.get_output_artifacts` returns for that execution.
  """
  # Node contexts are looked up by the name "<pipeline_name>.<component_id>".
  node_context_name = f'{pipeline_name}.{component_id}'
  node_context = metadata.store.get_context_by_type_and_name(
      'node', node_context_name)
  node_executions = metadata.store.get_executions_by_context(node_context.id)

  def _updated_at(execution):
    return execution.last_update_time_since_epoch

  newest = max(node_executions, key=_updated_at)
  return execution_lib.get_output_artifacts(metadata, newest.id)


# Non-public APIs, just for showcase.


def visualize_artifacts(artifacts):
  """Displays each artifact with the visualization registered for its type.

  Args:
    artifacts: Iterable of artifacts exposing `type_name`. Artifacts whose
      type has no registered visualization are skipped.
  """
  for item in artifacts:
    renderer = visualizations.get_registry().get_visualization(item.type_name)
    if not renderer:
      continue
    renderer.display(item)


# Register TFX's standard visualizations once at definition time.
# NOTE(review): presumably this populates the registry that
# `visualize_artifacts` queries via `visualizations.get_registry()` -- confirm.
standard_visualizations.register_standard_visualizations()


In [None]:
import tensorflow_data_validation as tfdv

# Load the statistics file of the last ExampleStatistics artifact.
# NOTE(review): this reads FeatureStats.pb from the artifact root; confirm the
# layout, since statistics may be stored per split (e.g. Split-train/).
latest_statistics = store.get_artifacts_by_type("ExampleStatistics")[-1]
stats_uri = f'{latest_statistics.uri}/FeatureStats.pb'
stats = tfdv.load_stats_binary(stats_uri)

# Load the text-format schema of the last Schema artifact.
latest_schema = store.get_artifacts_by_type("Schema")[-1]
schema_uri = f'{latest_schema.uri}/schema.pbtxt'
schema = tfdv.load_schema_text(schema_uri)

# Validate the statistics against the schema; as the cell's last expression,
# the result is what the cell displays.
tfdv.validate_statistics(stats, schema)




In [None]:
# Render the statistics loaded in the validation cell above.
tfdv.visualize_statistics(stats)


In [None]:
# Side-by-side view of two statistics sets.
# NOTE(review): `stats_eval_uri` is built from the same artifact and file name
# as `stats_uri`, so both sides of this comparison load the same file. If the
# intent is train vs. eval, point this at the eval split's statistics file --
# confirm the artifact's directory layout first.
stats_eval_uri = f'{latest_statistics.uri}/FeatureStats.pb'
stats_eval = tfdv.load_stats_binary(stats_eval_uri)
tfdv.visualize_statistics(lhs_statistics=stats, rhs_statistics=stats_eval)


In [None]:
# Tabulate every artifact of type "Examples" known to the store.
examples = store.get_artifacts_by_type("Examples")
display_artifacts(store=store, artifacts=examples)


In [None]:
def _count_records(dataset):
  """Counts the records of a dataset by folding over it once."""
  return dataset.reduce(0, lambda total, _: total + 1).numpy()


# Both splits are read from the last Examples artifact.
# NOTE(review): the hard-coded shard names assume a single-shard, transformed
# output -- confirm for this pipeline.
latest_examples_uri = examples[-1].uri

# Train split.
dataset_uri = (
    latest_examples_uri + '/Split-train/transformed_examples-00000-of-00001.gz')
tf_dataset = tf.data.TFRecordDataset(dataset_uri, compression_type='GZIP')
lengt_dataset = _count_records(tf_dataset)
print(f'Number of train examples: {lengt_dataset}')

# Eval split.
dataset_eval_uri = (
    latest_examples_uri + '/Split-eval/transformed_examples-00000-of-00001.gz')
tf_dataset_eval = tf.data.TFRecordDataset(
    dataset_eval_uri, compression_type='GZIP')
lengt_dataset_eval = _count_records(tf_dataset_eval)
print(f'Number of validation examples: {lengt_dataset_eval}')


In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
from tensorflow_transform.tf_metadata import schema_utils

# Build the parsing spec from the schema loaded in the validation cell.
feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec
print(feature_spec)


# Parse a single-record batch from the training dataset and show it.
for records in tf_dataset.batch(1).take(1):

  # Dictionary of feature name -> parsed tensor for the batch.
  # NOTE(review): the records come from `transformed_examples` files, so these
  # are presumably transformed rather than raw features -- confirm that the
  # schema used for `feature_spec` matches them.
  parsed_examples = tf.io.parse_example(records, feature_spec)

  # The batch holds one record, so index 0 is that record's image.
  image = parsed_examples["image"][0]
  plt.imshow(image, interpolation='nearest', cmap='gray')
  plt.show()

  # Print every parsed feature of the batch.
  print("\nRAW FEATURES:")
  for key, value in parsed_examples.items():
    print(f'{key}: {value.numpy()}')
        


In [None]:
# # Define a function to parse the `tf.train.Example` protocol buffer
# def parse_fn(example):
#     features = {
#         'image': tf.io.FixedLenFeature([], tf.string),
#         'label': tf.io.FixedLenFeature([], tf.int64)
#     }
#     parsed_example = tf.io.parse_single_example(example, features)
#     print(example)
#     print(parsed_example)
#     return parsed_example


# # Apply the parse function to the dataset
# dataset_inspect = tf_dataset.map(parse_fn)

# # one_example = next(iter(tf_dataset.take(1)))
# # parsed_example = parse_fn(one_example)
# # numpy_image = tf.io.parse_tensor(parsed_example["image"], out_type=tf.uint8)

# # Iterate over the dataset and print the features of each example
# for example in dataset_inspect.take(1):
#     numpy_image = tf.io.parse_tensor(example["image"], out_type=tf.uint8)
#     print(f'Image shape: {numpy_image.shape}')
#     plt.imshow(numpy_image, interpolation='nearest', cmap='gray')
#     plt.show()


In [None]:
# Build an ImportExampleGen over the last Examples artifact. The URI is taken
# from the list fetched in this cell, so the cell does not depend on the
# `examples` variable of an earlier cell.
examples_debug = store.get_artifacts_by_type("Examples")
dataset_debug_uri = examples_debug[-1].uri
import_examples = tfx.components.ImportExampleGen(input_base=dataset_debug_uri)
# TODO: run the component (e.g. via `interactive_context.run`).
#run the component


In [None]:
# Debug cell: every URI used here is derived from store queries made in this
# cell, so it does not rely on `examples`, `dataset_uri` or `schema_uri` left
# behind by earlier cells.
examples_debug = store.get_artifacts_by_type("Examples")
dataset_debug_uri = examples_debug[-1].uri + \
    '/Split-train/transformed_examples-00000-of-00001.gz'

tf_dataset = tf.data.TFRecordDataset(
    dataset_debug_uri, compression_type='GZIP')
lengt_dataset = tf_dataset.reduce(0, lambda x, _: x+1).numpy()

print(f'Number of train examples: {lengt_dataset}')
dataset_eval_uri = examples_debug[-1].uri + \
    '/Split-eval/transformed_examples-00000-of-00001.gz'
tf_dataset_eval = tf.data.TFRecordDataset(
    dataset_eval_uri, compression_type='GZIP')
lengt_dataset_eval = tf_dataset_eval.reduce(0, lambda x, _: x+1).numpy()
print(f'Number of validation examples: {lengt_dataset_eval}')

# NOTE(review): `source_uri` points at a single shard file, not at the
# artifact directory -- confirm this is what downstream consumers expect.
dataset_importer = tfx.dsl.Importer(
    source_uri=dataset_debug_uri,
    artifact_type=tfx.types.standard_artifacts.Examples).with_id(
        'dataset_importer')

latest_schema = store.get_artifacts_by_type("Schema")[-1]
schema_debug_uri = latest_schema.uri + '/schema.pbtxt'

# NOTE(review): as above, `source_uri` is the schema file itself rather than
# the Schema artifact directory -- confirm.
schema_importer = tfx.dsl.Importer(
    source_uri=schema_debug_uri,
    artifact_type=tfx.types.standard_artifacts.Schema).with_id(
    'schema_importer')


In [None]:
# Tabulate every artifact of type "ModelEvaluation" known to the store.
model_evaluations = store.get_artifacts_by_type("ModelEvaluation")
display_artifacts(store=store, artifacts=model_evaluations)

In [None]:
import tensorflow_model_analysis as tfma
# Load the TFMA result stored at the last ModelEvaluation artifact's URI.
model_evaluation = model_evaluations[-1]
evaluation_uri = model_evaluation.uri
print(evaluation_uri)
eval_result = tfma.load_eval_result(evaluation_uri)
# tfma.view.render_slicing_metrics(eval_result, slicing_column='label')
# Last expression of the cell: shows the loaded result's slicing metrics.
eval_result.slicing_metrics
