# TFX - Interactive Pipeline

This multi-part tutorial shows how to use the matrix factorization algorithm in BigQuery ML to generate embeddings for items based on their co-occurrence statistics. The generated item embeddings can then be used to find similar items.

This notebook covers creating and running a TFX pipeline that performs the following steps:
1. Compute pointwise mutual information (PMI) for co-occurring items using a [Custom Python function](https://www.tensorflow.org/tfx/guide/custom_function_component) component.
2. Train a BigQuery ML matrix factorization model using a [Custom Python function](https://www.tensorflow.org/tfx/guide/custom_function_component) component.
3. Extract the embeddings from the BigQuery ML model to a BigQuery table using a [Custom Python function](https://www.tensorflow.org/tfx/guide/custom_function_component) component.
4. Export the embeddings as TFRecords using the standard [BigQueryExampleGen](https://www.tensorflow.org/tfx/api_docs/python/tfx/extensions/google_cloud_big_query/example_gen/component/BigQueryExampleGen) component.
5. Import the schema for the embeddings using the standard [ImporterNode](https://www.tensorflow.org/tfx/api_docs/python/tfx/components/ImporterNode) component.
6. Validate the embeddings against the imported schema using the standard [StatisticsGen](https://www.tensorflow.org/tfx/guide/statsgen) and [ExampleValidator](https://www.tensorflow.org/tfx/guide/exampleval) components.
7. Create an embedding lookup SavedModel using the standard [Trainer](https://www.tensorflow.org/tfx/api_docs/python/tfx/components/Trainer) component, and infra-validate it using the standard [InfraValidator](https://www.tensorflow.org/tfx/guide/infra_validator) component.
8. Push the embedding lookup model to a model registry directory using the standard [Pusher](https://www.tensorflow.org/tfx/guide/pusher) component.
9. Build the ScaNN index using the standard [Trainer](https://www.tensorflow.org/tfx/api_docs/python/tfx/components/Trainer) component.
10. Evaluate and validate the ScaNN index latency and recall by implementing a [TFX Custom Component](https://www.tensorflow.org/tfx/guide/custom_component).
11. Push the ScaNN index to a model registry directory using the standard [Pusher](https://www.tensorflow.org/tfx/guide/pusher) component.


After running the pipeline steps, we inspect the metadata stored in the local ML Metadata (MLMD) store.

## Setup

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
!pip install -U -q tfx

### Import libraries

In [None]:
import os
import numpy as np
import tfx
import tensorflow as tf
import tensorflow_data_validation as tfdv
from tensorflow_transform.tf_metadata import schema_utils
import logging

logging.getLogger().setLevel(logging.INFO)

print("Tensorflow Version:", tf.__version__)
print("TFX Version:", tfx.__version__)

### Configure GCP environment settings

In [None]:
PROJECT_ID = 'ksalama-cloudml' # Change to your project.
BUCKET = 'ksalama-cloudml' # Change to your bucket.
BQ_DATASET_NAME = 'recommendations'
ARTIFACT_STORE = f'gs://{BUCKET}/tfx_artifact_store'
LOCAL_MLMD_SQLLITE = 'mlmd/mlmd.sqllite'
PIPELINE_NAME = 'tfx_bqml_scann'
EMBEDDING_LOOKUP_MODEL_NAME = 'embeddings_lookup'
SCANN_INDEX_MODEL_NAME = 'embeddings_scann'

PIPELINE_ROOT = os.path.join(ARTIFACT_STORE, f'{PIPELINE_NAME}_interactive')
MODEL_REGISTRY_DIR = os.path.join(ARTIFACT_STORE, 'model_registry_interactive')

!gcloud config set project $PROJECT_ID

### Authenticate your GCP account
This is required if you run the notebook in Colab.

In [None]:
try:
  from google.colab import auth
  auth.authenticate_user()
  print("Colab user is authenticated.")
except ImportError:
  pass  # Not running in Colab; rely on the existing gcloud credentials.

## Create Interactive Context

In [None]:
CLEAN_ARTIFACTS = True
if CLEAN_ARTIFACTS:
  if tf.io.gfile.exists(PIPELINE_ROOT):
    print("Removing previous artifacts...")
    tf.io.gfile.rmtree(PIPELINE_ROOT)
  if tf.io.gfile.exists('mlmd'):
    print("Removing local mlmd SQLite...")
    tf.io.gfile.rmtree('mlmd')

if not tf.io.gfile.exists('mlmd'):
  print("Creating mlmd directory...")
  tf.io.gfile.mkdir('mlmd')

print(f'Pipeline artifacts directory: {PIPELINE_ROOT}')
print(f'Model registry directory: {MODEL_REGISTRY_DIR}')
print(f'Local metadata SQLite path: {LOCAL_MLMD_SQLLITE}')

In [None]:
import ml_metadata as mlmd
from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.filename_uri = LOCAL_MLMD_SQLLITE
connection_config.sqlite.connection_mode = (
  metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE_OPENCREATE)
mlmd_store = mlmd.metadata_store.MetadataStore(connection_config)

context = InteractiveContext(
  pipeline_name=PIPELINE_NAME,
  pipeline_root=PIPELINE_ROOT,
  metadata_connection_config=connection_config
)

## Executing the Pipeline Steps
The BigQuery steps of the pipeline are implemented as custom Python function components in the [tfx_pipeline/bq_components.py](tfx_pipeline/bq_components.py) module.

In [None]:
from tfx_pipeline import bq_components

### 1. Compute PMI step

In [None]:
pmi_computer = bq_components.compute_pmi(
  project_id=PROJECT_ID,
  bq_dataset=BQ_DATASET_NAME,
  min_item_frequency=15,
  max_group_size=100,
)

In [None]:
context.run(pmi_computer)

In [None]:
pmi_computer.outputs.item_cooc.get()[0].get_string_custom_property('bq_result_table')
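
Step 1 turns item co-occurrence counts into pointwise mutual information (PMI) scores. The actual computation lives in `bq_components.compute_pmi` and runs in BigQuery. The sketch below only illustrates the standard definition, PMI(a, b) = log(P(a, b) / (P(a) · P(b))), applied to an in-memory dictionary of pair counts. The helper name `pmi_from_pair_counts` and its input layout are hypothetical. The sketch also ignores the `min_item_frequency` and `max_group_size` filters that the real component accepts.

```python
import math
from collections import Counter


def pmi_from_pair_counts(pair_counts):
    """Compute PMI for each co-occurring item pair.

    Hypothetical helper: `pair_counts` maps (item_a, item_b) -> count and is
    assumed to be symmetric, i.e. (a, b) and (b, a) are both present.
    """
    total = sum(pair_counts.values())

    # Marginal count of each item: the sum of its row in the co-occurrence matrix.
    item_counts = Counter()
    for (item_a, _), count in pair_counts.items():
        item_counts[item_a] += count

    pmi = {}
    for (item_a, item_b), count in pair_counts.items():
        p_ab = count / total
        p_a = item_counts[item_a] / total
        p_b = item_counts[item_b] / total
        pmi[(item_a, item_b)] = math.log(p_ab / (p_a * p_b))
    return pmi


counts = {
    ('x', 'y'): 3, ('y', 'x'): 3,
    ('x', 'z'): 1, ('z', 'x'): 1,
    ('y', 'z'): 1, ('z', 'y'): 1,
}
scores = pmi_from_pair_counts(counts)
print(scores[('x', 'y')], scores[('x', 'z')])
```

Items that co-occur more often than their individual frequencies would predict get a higher score. Here `x` and `y` score higher than `x` and `z`.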

### 2. Train the BigQuery matrix factorization model step

In [None]:
bqml_trainer = bq_components.train_item_matching_model(
  project_id=PROJECT_ID,
  bq_dataset=BQ_DATASET_NAME,
  item_cooc=pmi_computer.outputs.item_cooc,
  dimensions=50,
)

In [None]:
context.run(bqml_trainer)

In [None]:
bqml_trainer.outputs.bq_model.get()[0].get_string_custom_property('bq_model_name')
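
Step 2 trains a BigQuery ML matrix factorization model on the PMI scores. The model learns a `dimensions`-sized vector per item and a bias, which is the `bias` column exported in step 4. The sketch below assumes a standard biased factorization; check the model definition in `bq_components.train_item_matching_model` to confirm this. Under that assumption, the model estimates a pair's PMI as the dot product of the two item vectors plus both biases, and training minimizes the squared error of that estimate. For simplicity, the sketch keeps a single vector per item. The helper names are hypothetical.

```python
def predicted_pmi(vec_a, bias_a, vec_b, bias_b):
    """Biased factorization estimate: dot(vec_a, vec_b) + bias_a + bias_b."""
    if len(vec_a) != len(vec_b):
        raise ValueError("embedding dimensions must match")
    return sum(x * y for x, y in zip(vec_a, vec_b)) + bias_a + bias_b


def squared_error_loss(observed, embeddings, biases):
    """Sum of squared errors over observed {(item_a, item_b): pmi} scores."""
    loss = 0.0
    for (item_a, item_b), pmi in observed.items():
        estimate = predicted_pmi(embeddings[item_a], biases[item_a],
                                 embeddings[item_b], biases[item_b])
        loss += (pmi - estimate) ** 2
    return loss


embeddings = {'x': [1.0, 0.0], 'y': [0.5, 0.5]}
biases = {'x': 0.1, 'y': -0.1}
observed = {('x', 'y'): 0.6}
print(squared_error_loss(observed, embeddings, biases))
```

A real trainer would also add regularization and update the vectors iteratively. This sketch only shows what quantity is being fitted.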

### 3. Extract trained embeddings step

In [None]:
embeddings_extractor = bq_components.extract_embeddings(
  project_id=PROJECT_ID,
  bq_dataset=BQ_DATASET_NAME,
  bq_model=bqml_trainer.outputs.bq_model,
)

In [None]:
context.run(embeddings_extractor)

In [None]:
embeddings_extractor.outputs.item_embeddings.get()[0].get_string_custom_property('bq_result_table')
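
The extracted table holds one `embedding` vector and one `bias` per `item_Id`. The simplest exact way to find similar items with these embeddings is a brute-force scan that scores every item against a query vector. The sketch below does this with cosine similarity over a hypothetical in-memory dictionary. Cosine similarity is an assumption here; the pipeline's ScaNN index may be configured with a different distance measure. An exact scan like this is also the kind of baseline that an approximate index such as ScaNN is measured against.

```python
import heapq
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def exact_top_k(query, item_embeddings, k):
    """Return the k item IDs most similar to `query`, best first."""
    scored = ((cosine(query, vec), item_id) for item_id, vec in item_embeddings.items())
    return [item_id for _, item_id in heapq.nlargest(k, scored)]


items = {'a': [1.0, 0.0], 'b': [0.9, 0.1], 'c': [0.0, 1.0]}
print(exact_top_k([1.0, 0.0], items, 2))  # ['a', 'b']
```

The scan costs O(N · d) per query, which is why the pipeline builds a ScaNN index for larger catalogs.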

### 4. Export embeddings as TFRecords step

In [None]:
from tfx.proto import example_gen_pb2
from tfx.extensions.google_cloud_big_query.example_gen.component import BigQueryExampleGen

query = f'''
  SELECT item_Id, embedding, bias,
  FROM {BQ_DATASET_NAME}.item_embeddings
  LIMIT 1000
'''

output_config = example_gen_pb2.Output(
  split_config=example_gen_pb2.SplitConfig(splits=[
    example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=1)])
)

embeddings_exporter = BigQueryExampleGen(
  query=query,
  output_config=output_config
)

In [None]:
beam_pipeline_args = [
  '--runner=DirectRunner',
  f'--project={PROJECT_ID}',
  f'--temp_location=gs://{BUCKET}/bqml_scann/beam/temp',
]

context.run(embeddings_exporter, beam_pipeline_args=beam_pipeline_args)
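
The `output_config` above declares a single `train` split with `hash_buckets=1`, so every exported row lands in `train`. When several splits are declared, ExampleGen hashes each record into `sum(hash_buckets)` buckets and maps consecutive bucket ranges to the splits in the order they are listed. The sketch below mimics that idea. It uses `hashlib.md5` on a string key purely as a stand-in for TFX's own fingerprinting, so its actual assignments will differ from ExampleGen's. The function name is hypothetical.

```python
import bisect
import hashlib
from itertools import accumulate


def assign_split(record_key, splits):
    """Map a record to a split name.

    `splits` is a list of (name, hash_buckets) pairs, e.g. [('train', 1)].
    md5 is a stand-in hash; TFX uses its own fingerprint function.
    """
    names = [name for name, _ in splits]
    # Cumulative bucket boundaries, e.g. [('train', 8), ('eval', 2)] -> [8, 10].
    boundaries = list(accumulate(buckets for _, buckets in splits))
    bucket = int(hashlib.md5(record_key.encode('utf-8')).hexdigest(), 16) % boundaries[-1]
    # Buckets 0..7 -> 'train', 8..9 -> 'eval' in the example above.
    return names[bisect.bisect(boundaries, bucket)]


print(assign_split('item_42', [('train', 1)]))  # always 'train' with a single split
```

With `[('train', 8), ('eval', 2)]`, roughly 80% of records would go to `train` and 20% to `eval`.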

### 5. Import the Schema for the embeddings step

In [None]:
schema_importer = tfx.components.ImporterNode(
  source_uri='tfx_pipeline/schema',
  artifact_type=tfx.types.standard_artifacts.Schema,
  instance_name='SchemaImporter'
)

In [None]:
context.run(schema_importer)

In [None]:
context.show(schema_importer.outputs.result)

#### Read a sample embedding from the exported TFRecord files using the Schema:

In [None]:
schema_file = schema_importer.outputs.result.get()[0].uri + "/schema.pbtxt"
schema = tfdv.load_schema_text(schema_file)
feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec

In [None]:
data_uri = embeddings_exporter.outputs.examples.get()[0].uri + "/train/*"

def _gzip_reader_fn(filenames):
  # ExampleGen writes GZIP-compressed TFRecord files.
  return tf.data.TFRecordDataset(filenames, compression_type='GZIP')

dataset = tf.data.experimental.make_batched_features_dataset(
  data_uri,
  batch_size=1,
  num_epochs=1,
  features=feature_spec,
  reader=_gzip_reader_fn,
  shuffle=True
)

# Count the records with one pass over the single-epoch dataset.
num_records = sum(1 for _ in dataset)
print(f'Number of records: {num_records}')
print('')

for batch in dataset.take(1):
  print(f'item: {batch["item_Id"].numpy()[0][0].decode()}')
  print(f'embedding vector: {batch["embedding"].numpy()[0]}')

### 6. Validate the embeddings against the imported Schema step

In [None]:
stats_generator = tfx.components.StatisticsGen(
  examples=embeddings_exporter.outputs.examples,
)

context.run(stats_generator)

In [None]:
stats_validator = tfx.components.ExampleValidator(
  statistics=stats_generator.outputs.statistics,
  schema=schema_importer.outputs.result,
)

context.run(stats_validator)

In [None]:
context.show(stats_validator.outputs.anomalies)
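
ExampleValidator compares the statistics computed by StatisticsGen against the imported schema and reports anomalies. The sketch below is not TFDV. It only illustrates the kind of per-record checks a schema for these embeddings might encode: a non-empty `item_Id` and a fixed-length `embedding`. The expected length of 50 mirrors `dimensions=50` from step 2. The record layout and the function name are hypothetical.

```python
def find_anomalies(records, embedding_dim=50):
    """Return human-readable anomaly descriptions for a list of dict records."""
    anomalies = []
    for index, record in enumerate(records):
        if not record.get('item_Id'):
            anomalies.append(f'record {index}: missing item_Id')
        embedding = record.get('embedding')
        if embedding is None:
            anomalies.append(f'record {index}: missing embedding')
        elif len(embedding) != embedding_dim:
            anomalies.append(
                f'record {index}: embedding has {len(embedding)} values, '
                f'expected {embedding_dim}')
    return anomalies


sample = [
    {'item_Id': 'a', 'embedding': [0.0] * 50, 'bias': 0.1},
    {'item_Id': '', 'embedding': [0.0] * 3, 'bias': 0.2},
]
for line in find_anomalies(sample):
    print(line)
```

TFDV performs comparable checks on aggregate statistics rather than record by record, which is why StatisticsGen runs first.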

### 7. Create an embedding lookup SavedModel step

In [None]:
from tfx.components.base import executor_spec
from tfx.components.trainer import executor as trainer_executor

_module_file = 'tfx_pipeline/lookup_creator.py'

embedding_lookup_creator = tfx.components.Trainer(
  custom_executor_spec=executor_spec.ExecutorClassSpec(trainer_executor.GenericExecutor),
  module_file=_module_file,
  train_args={'splits': ['train'], 'num_steps': 0},
  eval_args={'splits': ['train'], 'num_steps': 0},
  schema=schema_importer.outputs.result,
  examples=embeddings_exporter.outputs.examples,
)

In [None]:
context.run(embedding_lookup_creator)
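
The lookup SavedModel wraps the exported embeddings behind a serving signature that maps item IDs to vectors. How `tfx_pipeline/lookup_creator.py` handles unknown IDs and space-separated inputs is defined in that module. The pure-Python sketch below assumes one plausible behavior to show what such a lookup does:

- Average the vectors of the known items in a space-separated input.
- Return a zero vector when no item is known.

The class name `EmbeddingLookup` is hypothetical.

```python
class EmbeddingLookup:
    """Hypothetical in-memory stand-in for the embedding lookup model."""

    def __init__(self, vocabulary, vectors):
        self._index = {item: i for i, item in enumerate(vocabulary)}
        self._vectors = vectors
        self._dim = len(vectors[0])

    def __call__(self, inputs):
        outputs = []
        for text in inputs:
            known = [self._vectors[self._index[token]]
                     for token in text.split() if token in self._index]
            if not known:
                # Assumed behavior: inputs with no known items map to a zero vector.
                outputs.append([0.0] * self._dim)
            else:
                # Assumed behavior: average the vectors of the known items.
                outputs.append([sum(values) / len(known) for values in zip(*known)])
        return outputs


lookup = EmbeddingLookup(['i1', 'i2', 'i3'], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(lookup(['i1', 'i2 i3', 'abc123']))  # [[1.0, 0.0], [0.5, 1.0], [0.0, 0.0]]
```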

### 7b. Infra-validate the lookup model step

In [None]:
from tfx.proto import infra_validator_pb2

serving_config = infra_validator_pb2.ServingSpec(
  tensorflow_serving=infra_validator_pb2.TensorFlowServing(
      tags=['latest']),
  local_docker=infra_validator_pb2.LocalDockerConfig(),
)
  
validation_config = infra_validator_pb2.ValidationSpec(
  max_loading_time_seconds=60,
  num_tries=3,
)

infra_validator = tfx.components.InfraValidator(
  model=embedding_lookup_creator.outputs.model,
  serving_spec=serving_config,
  validation_spec=validation_config,
)

In [None]:
context.run(infra_validator)

In [None]:
tf.io.gfile.listdir(infra_validator.outputs.blessing.get()[0].uri)
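
The InfraValidator writes a marker file into its `blessing` artifact directory, which the `listdir` call above shows. To the best of our knowledge, the file is named `INFRA_BLESSED` on success and `INFRA_NOT_BLESSED` otherwise. Treat those names as an assumption and confirm them against the listing. The hypothetical helper below turns the listing into a boolean, for example to stop the notebook before pushing. For a GCS path, pass `tf.io.gfile.listdir` as `list_dir`.

```python
import os
import tempfile


def is_infra_blessed(blessing_dir, list_dir=os.listdir):
    """Return True if the (assumed) INFRA_BLESSED marker file is present."""
    return 'INFRA_BLESSED' in list_dir(blessing_dir)


# Local demonstration with a temporary directory standing in for the artifact URI.
with tempfile.TemporaryDirectory() as demo_dir:
    open(os.path.join(demo_dir, 'INFRA_BLESSED'), 'w').close()
    print(is_infra_blessed(demo_dir))  # True
```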

### 8. Push the embedding lookup model to the model registry step

In [None]:
embedding_lookup_pusher = tfx.components.Pusher(
  model=embedding_lookup_creator.outputs.model,
  infra_blessing=infra_validator.outputs.blessing,
  push_destination=tfx.proto.pusher_pb2.PushDestination(
    filesystem=tfx.proto.pusher_pb2.PushDestination.Filesystem(
      base_directory=os.path.join(MODEL_REGISTRY_DIR, EMBEDDING_LOOKUP_MODEL_NAME))
  )
)

In [None]:
context.run(embedding_lookup_pusher)

In [None]:
lookup_savedmodel_dir = embedding_lookup_pusher.outputs.pushed_model.get()[0].get_string_custom_property('pushed_destination')
!saved_model_cli show --dir {lookup_savedmodel_dir} --tag_set serve --signature_def serving_default

In [None]:
loaded_model = tf.saved_model.load(lookup_savedmodel_dir)
vocab_path = loaded_model.vocabulary_file.asset_path.numpy().decode()
with tf.io.gfile.GFile(vocab_path, 'r') as vocab_file:
  vocab = [token.strip() for token in vocab_file.readlines()]

In [None]:
input_items = [vocab[0], ' '.join([vocab[1], vocab[2]]), 'abc123']
print(input_items)
output = loaded_model(input_items)
print(f'Embeddings retrieved: {len(output)}')
for idx, embedding in enumerate(output):
  print(f'{input_items[idx]}: {embedding[:5]}')

### 9. Build the ScaNN index step

In [None]:
from tfx.components.base import executor_spec
from tfx.components.trainer import executor as trainer_executor

_module_file = 'tfx_pipeline/scann_indexer.py'

scann_indexer = tfx.components.Trainer(
  custom_executor_spec=executor_spec.ExecutorClassSpec(trainer_executor.GenericExecutor),
  module_file=_module_file,
  train_args={'splits': ['train'], 'num_steps': 0},
  eval_args={'splits': ['train'], 'num_steps': 0},
  schema=schema_importer.outputs.result,
  examples=embeddings_exporter.outputs.examples
)

In [None]:
context.run(scann_indexer)

### 10. Evaluate and validate the ScaNN Index step

The IndexEvaluator custom component is implemented in the [tfx_pipeline/scann_evaluator.py](tfx_pipeline/scann_evaluator.py) module.

In [None]:
from tfx_pipeline import scann_evaluator

index_evaluator = scann_evaluator.IndexEvaluator(
  examples=embeddings_exporter.outputs.examples,
  model=scann_indexer.outputs.model,
  schema=schema_importer.outputs.result,
  min_recall=0.8,
  max_latency=0.01,
)

In [None]:
context.run(index_evaluator)
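
The IndexEvaluator blesses the index only if it meets `min_recall=0.8` and `max_latency=0.01`. The exact metric definitions live in `tfx_pipeline/scann_evaluator.py`. The sketch below shows one common formulation:

- Recall@k is the fraction of the exact top-k neighbors that the approximate search also returns, averaged over queries.
- Latency is the mean wall-clock time of the approximate search per query.

Treating `max_latency` as seconds per query is an assumption. The `evaluate_index` function and its arguments are hypothetical.

```python
import time


def recall_at_k(approx_ids, exact_ids):
    """Fraction of the exact neighbors that the approximate search also found."""
    if not exact_ids:
        return 1.0
    return len(set(approx_ids) & set(exact_ids)) / len(exact_ids)


def evaluate_index(queries, approx_search, exact_search, k, min_recall, max_latency):
    """Return (mean_recall, mean_latency_seconds, blessed) over the queries."""
    recalls, latencies = [], []
    for query in queries:
        start = time.perf_counter()
        approx_ids = approx_search(query, k)
        latencies.append(time.perf_counter() - start)
        recalls.append(recall_at_k(approx_ids, exact_search(query, k)))
    mean_recall = sum(recalls) / len(recalls)
    mean_latency = sum(latencies) / len(latencies)
    blessed = mean_recall >= min_recall and mean_latency <= max_latency
    return mean_recall, mean_latency, blessed


# Toy searches over fixed ID lists: the "approximate" one misses one neighbor.
def exact_search(query, k):
    return ['a', 'b', 'c', 'd', 'e'][:k]


def approx_search(query, k):
    return ['a', 'b', 'c', 'd', 'x'][:k]


print(evaluate_index(['q1', 'q2'], approx_search, exact_search,
                     k=5, min_recall=0.8, max_latency=0.01))
```

In the toy run, recall@5 is 4/5 = 0.8, which just meets the `min_recall` threshold.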

### 11. Push the ScaNN index to the model registry step

In [None]:
embedding_scann_pusher = tfx.components.Pusher(
  model=scann_indexer.outputs.model,
  model_blessing=index_evaluator.outputs.blessing,
  push_destination=tfx.proto.pusher_pb2.PushDestination(
    filesystem=tfx.proto.pusher_pb2.PushDestination.Filesystem(
      base_directory=os.path.join(MODEL_REGISTRY_DIR, SCANN_INDEX_MODEL_NAME))
  )
)

In [None]:
context.run(embedding_scann_pusher)

In [None]:
from index_server.matching import ScaNNMatcher
scann_index_dir = embedding_scann_pusher.outputs.pushed_model.get()[0].get_string_custom_property('pushed_destination')
scann_matcher = ScaNNMatcher(scann_index_dir)

In [None]:
vector = np.random.rand(50)
scann_matcher.match(vector, 5)

## Check Local MLMD Store 

In [None]:
mlmd_store.get_artifacts()

## View the Model Registry Directory

In [None]:
!gsutil ls {MODEL_REGISTRY_DIR}
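
Each Pusher run writes the model into a new version subdirectory under its `base_directory`, for example `{MODEL_REGISTRY_DIR}/embeddings_lookup/<version>`. The Pusher records the full path in its `pushed_destination` property. A consumer without access to the pipeline metadata can instead pick the newest version from the directory listing, as sketched below. The sketch assumes the versions are integer directory names, as TensorFlow Serving expects. The helper and the example bucket path are hypothetical. On GCS, pass `tf.io.gfile.listdir` as `list_dir`. That call may return directory names with a trailing `/`, which the sketch strips.

```python
import os


def latest_version_dir(base_dir, list_dir=os.listdir):
    """Return the path of the highest integer-named version under base_dir, or None."""
    versions = []
    for name in list_dir(base_dir):
        name = name.rstrip('/')
        if name.isdigit():
            versions.append(int(name))
    if not versions:
        return None
    return os.path.join(base_dir, str(max(versions)))


def fake_listing(_):
    # Stand-in for a registry listing; 'tmp/' is not a version and is skipped.
    return ['1600000000/', '1600000500/', 'tmp/']


print(latest_version_dir('gs://bucket/model_registry/embeddings_lookup', fake_listing))
# gs://bucket/model_registry/embeddings_lookup/1600000500
```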

## License

Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

See the License for the specific language governing permissions and limitations under the License.

**This is not an official Google product but sample code provided for educational purposes**