# TFX - Create BigQuery Stored Procedures

This multi-part tutorial shows how to use the Matrix Factorization algorithm in BigQuery ML to generate embeddings for items based on their co-occurrence statistics. The generated item embeddings can then be used to find similar items.

This notebook covers creating and running a TFX pipeline that performs the following steps:
1. Compute pointwise mutual information (PMI) between items using the [Custom Python function](https://www.tensorflow.org/tfx/guide/custom_function_component) component.
2. Train BigQuery Matrix Factorization Model using the [Custom Python function](https://www.tensorflow.org/tfx/guide/custom_function_component) component.
3. Extract the Embeddings from the Model to a Table using the [Custom Python function](https://www.tensorflow.org/tfx/guide/custom_function_component) component.
4. Export the embeddings as TFRecords using the [BigQueryExampleGen](https://www.tensorflow.org/tfx/api_docs/python/tfx/extensions/google_cloud_big_query/example_gen/component/BigQueryExampleGen) component.
5. Create an embedding lookup SavedModel using the [Trainer](https://www.tensorflow.org/tfx/api_docs/python/tfx/components/Trainer) component.
6. Push the embedding lookup model to a model registry directory using the [Pusher](https://www.tensorflow.org/tfx/guide/pusher) component.
7. Build the ScaNN index using the [Trainer](https://www.tensorflow.org/tfx/api_docs/python/tfx/components/Trainer) component.
8. Push the ScaNN index to a model registry directory using the [Pusher](https://www.tensorflow.org/tfx/guide/pusher) component.


After running the pipeline steps, we check the metadata stored in the local ML Metadata (MLMD) store.

## Setup

In [None]:
%load_ext autoreload
%autoreload 2

### Import libraries

In [None]:
import os
import numpy as np
import tfx
import tensorflow as tf
import logging

logging.getLogger().setLevel(logging.INFO)

print("Tensorflow Version:", tf.__version__)
print("TFX Version:", tfx.__version__)

### Configure GCP environment settings

In [None]:
PROJECT_ID = 'ksalama-cloudml'
BUCKET = 'ksalama-cloudml'
REGION = 'us-central1'
BQ_DATASET_NAME = 'recommendations'
WORKSPACE = f'gs://{BUCKET}/tfx_artifact_store/tfx_bqml_scann_interactive'
LOCAL_MLMD_SQLLITE = 'mlmd/mlmd.sqllite'
PIPELINE_NAME = 'bqml-scann'
EMBEDDING_LOOKUP_MODEL_NAME = 'embeddings_lookup'
SCANN_INDEX_MODEL_NAME = 'embeddings_scann'
MODEL_REGISTRY_DIR = os.path.join(WORKSPACE, 'model_registry')

!gcloud config set project $PROJECT_ID

### Authenticate your GCP account
This is required only if you run the notebook in Colab.

In [None]:
try:
  from google.colab import auth
  auth.authenticate_user()
  print("Colab user is authenticated.")
except ImportError:
  # Not running in Colab; assume credentials are already configured.
  pass

## Create Interactive Context

In [None]:
CLEAN_WORKSPACE = True
if CLEAN_WORKSPACE:
  if tf.io.gfile.exists(WORKSPACE):
    print("Removing previous artifacts...")
    tf.io.gfile.rmtree(WORKSPACE)
  if tf.io.gfile.exists('mlmd'):
    print("Removing local mlmd SQLite...")
    tf.io.gfile.rmtree('mlmd')

if not tf.io.gfile.exists('mlmd'):
  tf.io.gfile.mkdir('mlmd')

In [None]:
import ml_metadata as mlmd
from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.filename_uri = LOCAL_MLMD_SQLLITE
connection_config.sqlite.connection_mode = 3 # READWRITE_OPENCREATE
mlmd_store = mlmd.metadata_store.MetadataStore(connection_config)

context = InteractiveContext(
  pipeline_name=PIPELINE_NAME,
  pipeline_root=WORKSPACE,
  metadata_connection_config=connection_config
)

## Executing the Pipeline Steps
The BigQuery components of the pipeline are implemented in the [tfx_pipeline/bq_components.py](tfx_pipeline/bq_components.py) module.

In [None]:
from tfx_pipeline import bq_components

### 1. Compute PMI step

In [None]:
pmi_computer = bq_components.compute_pmi(
  project_id=PROJECT_ID,
  dataset=BQ_DATASET_NAME,
  min_item_frequency=15,
  max_group_size=100
)

In [None]:
context.run(pmi_computer)

In [None]:
pmi_computer.outputs.item_cooc.get()[0].get_string_custom_property('bq_result_table')
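
To make the PMI step more concrete, here is a minimal sketch of the textbook definition, PMI(a, b) = log2(P(a, b) / (P(a) P(b))), computed over groups of items that occur together. It is an illustration only: the function name `pmi_from_groups` and the toy data are hypothetical, and the sketch does not reproduce the BigQuery logic behind `bq_components.compute_pmi` or the filtering controlled by `min_item_frequency` and `max_group_size`.

```python
import math
from collections import Counter
from itertools import combinations


def pmi_from_groups(groups):
    """Return {(item_a, item_b): PMI} for every pair that co-occurs in a group."""
    item_counts = Counter()
    pair_counts = Counter()
    for group in groups:
        unique_items = sorted(set(group))  # count each item once per group
        item_counts.update(unique_items)
        pair_counts.update(combinations(unique_items, 2))

    total_groups = len(groups)
    pmi = {}
    for (a, b), pair_count in pair_counts.items():
        p_ab = pair_count / total_groups
        p_a = item_counts[a] / total_groups
        p_b = item_counts[b] / total_groups
        pmi[(a, b)] = math.log2(p_ab / (p_a * p_b))
    return pmi


# Toy groups (hypothetical data, for example playlists of songs).
toy_groups = [['a', 'b'], ['a', 'b'], ['a', 'c'], ['c', 'd']]
toy_pmi = pmi_from_groups(toy_groups)
for pair, value in sorted(toy_pmi.items()):
    print(pair, round(value, 3))
```

A positive score means the two items appear together more often than their individual frequencies would suggest; a negative score means less often.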

### 2. Train the BigQuery matrix factorization model step

In [None]:
bqml_trainer = bq_components.train_item_matching_model(
  project_id=PROJECT_ID,
  dataset=BQ_DATASET_NAME,
  item_cooc=pmi_computer.outputs.item_cooc,
  dimensions=50,
)

In [None]:
context.run(bqml_trainer)

In [None]:
bqml_trainer.outputs.model.get()[0].get_string_custom_property('bq_model_name')
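
As a rough intuition for what a matrix factorization model learns, the sketch below fits one small vector per item so that the dot product of two item vectors approximates their pairwise score. This is a toy stochastic gradient descent loop under our own assumptions (the names `factorize` and `squared_error`, the hyperparameters, and the toy scores are all hypothetical). BigQuery ML uses its own training procedure, and the notebook trains 50 dimensions rather than 2.

```python
import random


def dot(u, v):
    return sum(x * y for x, y in zip(u, v))


def factorize(scores, dimensions=2, epochs=500, learning_rate=0.05, l2=0.01, seed=0):
    """Learn vectors so that dot(row[a], col[b]) approximates scores[(a, b)]."""
    rng = random.Random(seed)
    items = sorted({item for pair in scores for item in pair})
    row = {i: [rng.uniform(-0.1, 0.1) for _ in range(dimensions)] for i in items}
    col = {i: [rng.uniform(-0.1, 0.1) for _ in range(dimensions)] for i in items}
    for _ in range(epochs):
        for (a, b), target in scores.items():
            error = dot(row[a], col[b]) - target
            for d in range(dimensions):
                r, c = row[a][d], col[b][d]
                # Gradient step on the squared error plus L2 regularization.
                row[a][d] -= learning_rate * (error * c + l2 * r)
                col[b][d] -= learning_rate * (error * r + l2 * c)
    return row, col


def squared_error(scores, row, col):
    return sum((dot(row[a], col[b]) - t) ** 2 for (a, b), t in scores.items())


# Toy pairwise scores (hypothetical; think of them as PMI values).
toy_scores = {('a', 'b'): 1.0, ('b', 'a'): 1.0, ('a', 'c'): -1.0, ('c', 'a'): -1.0}

initial_row, initial_col = factorize(toy_scores, epochs=0)  # random start only
trained_row, trained_col = factorize(toy_scores)

print('error before training:', squared_error(toy_scores, initial_row, initial_col))
print('error after training: ', squared_error(toy_scores, trained_row, trained_col))
```

The learned vectors play the role of the item embeddings that the next step extracts from the trained model.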

### 3. Extract trained embeddings step

In [None]:
embeddings_extractor = bq_components.extract_embeddings(
  project_id=PROJECT_ID,
  dataset=BQ_DATASET_NAME,
  model=bqml_trainer.outputs.model
)

In [None]:
context.run(embeddings_extractor)

In [None]:
embeddings_extractor.outputs.item_embeddings.get()[0].get_string_custom_property('bq_result_table')

### 4. Export embeddings as TFRecords step

In [None]:
from tfx.proto import example_gen_pb2
from tfx.extensions.google_cloud_big_query.example_gen.component import BigQueryExampleGen

fetch_embeddings_query = f'''
  SELECT item_Id, embedding
  FROM {BQ_DATASET_NAME}.item_embeddings
  LIMIT 1000
'''

output_config = example_gen_pb2.Output(
  split_config=example_gen_pb2.SplitConfig(splits=[
    example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=1)])
)

embeddings_exporter = BigQueryExampleGen(
  query=fetch_embeddings_query,
  output_config=output_config,
  instance_name='BQExportEmbeddings'
)

In [None]:
beam_pipeline_args = [
  '--runner=DirectRunner',
  f'--project={PROJECT_ID}',
  f'--temp_location=gs://{BUCKET}/bqml_scann/beam/temp',
]

context.run(embeddings_exporter, beam_pipeline_args=beam_pipeline_args)

In [None]:
data_uri = embeddings_exporter.outputs.examples.get()[0].uri + "/train/*"
tfrecord_filenames = tf.data.Dataset.list_files(data_uri)

# Create a `TFRecordDataset` to read these files
dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")

counter = 0
for _ in dataset:
  counter += 1
print(f'Number of records: {counter}')
print('')

# Display one randomly sampled record
for tfrecord in dataset.shuffle(100).take(1):
  serialized_example = tfrecord.numpy()
  example = tf.train.Example.FromString(serialized_example)
  item_Id = example.features.feature['item_Id'].bytes_list.value[0].decode()
  embedding = np.array(example.features.feature['embedding'].float_list.value)
  print(f'item: {item_Id}')
  print(f'embedding vector: {embedding}')

### 5. Create an embedding lookup SavedModel step

In [None]:
schema_importer = tfx.components.ImporterNode(
  instance_name='RawSchemaImporter',
  source_uri='tfx_pipeline/schema',
  artifact_type=tfx.types.standard_artifacts.Schema
)

context.run(schema_importer)

In [None]:
context.show(schema_importer.outputs.result)

In [None]:
from tfx.components.base import executor_spec
from tfx.components.trainer import executor as trainer_executor

_module_file = 'tfx_pipeline/lookup_exporter.py'

lookup_savedmodel_exporter = tfx.components.Trainer(
  custom_executor_spec=executor_spec.ExecutorClassSpec(trainer_executor.GenericExecutor),
  module_file=_module_file,
  train_args={'num_steps': 0},
  eval_args={'num_steps': 0},
  schema=schema_importer.outputs.result,
  examples=embeddings_exporter.outputs.examples,
  instance_name='ExportEmbeddingLookupSavedModel'
)

In [None]:
context.run(lookup_savedmodel_exporter)

### 6. Push the embedding lookup model to the model registry step

In [None]:
embedding_lookup_pusher = tfx.components.Pusher(
  model=lookup_savedmodel_exporter.outputs.model,
  push_destination=tfx.proto.pusher_pb2.PushDestination(
    filesystem=tfx.proto.pusher_pb2.PushDestination.Filesystem(
      base_directory=os.path.join(MODEL_REGISTRY_DIR, EMBEDDING_LOOKUP_MODEL_NAME))
  )
)

In [None]:
context.run(embedding_lookup_pusher)

In [None]:
lookup_savedmodel_dir = embedding_lookup_pusher.outputs.pushed_model.get()[0].get_string_custom_property('pushed_destination')
!saved_model_cli show --dir {lookup_savedmodel_dir} --tag_set serve --signature_def serving_default

In [None]:
loaded_model = tf.saved_model.load(lookup_savedmodel_dir)
vocab = [token.strip() for token in tf.io.gfile.GFile(
 loaded_model.vocabulary_file.asset_path.numpy().decode(), 'r').readlines()]

In [None]:
input_items = [vocab[0], ' '.join([vocab[1], vocab[2]]), 'abc123']
print(input_items)
output = loaded_model(input_items)
print(f'Embeddings retrieved: {len(output)}')
for idx, embedding in enumerate(output):
  print(f'{input_items[idx]}: {embedding[:5]}')
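
The test inputs above cover three cases: a single known item, two known items separated by a space, and an item that is not in the vocabulary. The following pure-Python sketch shows one plausible way a lookup could treat them. Note the assumptions: averaging the vectors of multiple tokens and returning a zero vector for unknown items are our guesses for illustration, and the actual behavior is defined in `tfx_pipeline/lookup_exporter.py`. The function name and toy data are hypothetical.

```python
def lookup_embeddings(inputs, vocabulary, embeddings):
    """Return one vector per input string.

    Assumed behavior (not taken from the exported SavedModel):
    - each input is split on whitespace into item IDs,
    - known items are averaged element-wise,
    - inputs with no known item map to a zero vector.
    """
    token_to_index = {token: idx for idx, token in enumerate(vocabulary)}
    dimensions = len(embeddings[0])
    results = []
    for text in inputs:
        vectors = [
            embeddings[token_to_index[token]]
            for token in text.split()
            if token in token_to_index
        ]
        if not vectors:
            results.append([0.0] * dimensions)
        else:
            results.append([sum(values) / len(vectors) for values in zip(*vectors)])
    return results


# Toy vocabulary and embeddings (hypothetical data).
toy_vocab = ['x', 'y', 'z']
toy_embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

toy_output = lookup_embeddings(['x', 'y z', 'abc123'], toy_vocab, toy_embeddings)
print(toy_output)
```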

### 7. Build the ScaNN index step

In [None]:
from tfx.components.base import executor_spec
from tfx.components.trainer import executor as trainer_executor

_module_file = 'tfx_pipeline/scann_indexer.py'

scann_indexer = tfx.components.Trainer(
  custom_executor_spec=executor_spec.ExecutorClassSpec(trainer_executor.GenericExecutor),
  module_file=_module_file,
  train_args={'num_steps': 0},
  eval_args={'num_steps': 0},
  schema=schema_importer.outputs.result,
  examples=embeddings_exporter.outputs.examples,
  instance_name='BuildScaNNIndex'
)

In [None]:
context.run(scann_indexer)

### 8. Push the ScaNN index to the model registry step

In [None]:
embedding_scann_pusher = tfx.components.Pusher(
  model=scann_indexer.outputs.model,
  push_destination=tfx.proto.pusher_pb2.PushDestination(
    filesystem=tfx.proto.pusher_pb2.PushDestination.Filesystem(
      base_directory=os.path.join(MODEL_REGISTRY_DIR, SCANN_INDEX_MODEL_NAME))
  )
)

In [None]:
context.run(embedding_scann_pusher)

In [None]:
from index_server.matching import ScaNNMatcher
scann_index_dir = embedding_scann_pusher.outputs.pushed_model.get()[0].get_string_custom_property('pushed_destination')
scann_matcher = ScaNNMatcher(scann_index_dir)

In [None]:
vector = np.random.rand(50)
scann_matcher.match(vector, 5)
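
ScaNN performs approximate nearest-neighbor search. For intuition, and as a possible sanity check on small data, the sketch below does the exact, brute-force equivalent: score every item against the query and keep the top matches. It assumes dot-product similarity and returns item IDs only; the name `exact_match` and the toy data are hypothetical, and the return format of `ScaNNMatcher.match` may differ.

```python
import heapq


def exact_match(query, item_ids, embeddings, num_matches):
    """Return the IDs of the num_matches items with the highest dot product."""
    scored = [
        (sum(q * e for q, e in zip(query, embedding)), item_id)
        for item_id, embedding in zip(item_ids, embeddings)
    ]
    top = heapq.nlargest(num_matches, scored, key=lambda pair: pair[0])
    return [item_id for _, item_id in top]


# Toy index (hypothetical data).
toy_ids = ['a', 'b', 'c']
toy_vectors = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]]

toy_matches = exact_match([1.0, 0.0], toy_ids, toy_vectors, num_matches=2)
print(toy_matches)
```

Brute-force search scans every item for each query, which is the cost an approximate index such as ScaNN is designed to avoid on large item catalogs.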

## Check Local MLMD Store 

In [None]:
mlmd_store.get_artifacts()

## View the Model Registry Directory

In [None]:
!gsutil ls {MODEL_REGISTRY_DIR}

## License

Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. You may obtain a copy of the License at: http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

See the License for the specific language governing permissions and limitations under the License.

**This is not an official Google product but sample code provided for an educational purpose**