<a href="https://colab.research.google.com/github/Saber-Hosseinzade/TensorFlowExtended_TFX/blob/main/TFX_Tuner_and_Trainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Install packages


In [None]:
# Restart the RunTime after completion of this section
!pip install tfx==1.2

In [None]:
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds

import os
import pprint

from tfx.components import ImportExampleGen
from tfx.components import ExampleValidator
from tfx.components import SchemaGen
from tfx.components import StatisticsGen
from tfx.components import Transform
from tfx.components import Tuner
from tfx.components import Trainer

from tfx.proto import example_gen_pb2
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

#### Fashion MNIST dataset

In [None]:
pipe_dir = './pipeline/'
dataset_dir = './dataset'
temp_dir = './temp'

!rm -r {pipe_dir}         # remove directory if exists from previous run
!rm -r {dataset_dir}      # remove directory if exists from previous run
!rm -r {temp_dir}         # remove directory if exists from previous run

!mkdir {dataset_dir}      # create directory
!mkdir {pipe_dir}         # create directory

data = tfds.load('fashion_mnist', data_dir=temp_dir)

tfds_data_path = './temp/fashion_mnist/3.0.1/fashion_mnist-train.tfrecord-00000-of-00001'
!cp {tfds_data_path} {dataset_dir}

#### Create Interactive Context

In [None]:
# Initialize the InteractiveContext with pipe_dir as its pipeline root.
# Every component below is executed through this context.
context = InteractiveContext(pipeline_root=pipe_dir)

#### Create ExampleGen


In [None]:
# Split the ingested records 8:2 into 'train' and 'eval' using hash buckets
output = example_gen_pb2.Output(
    split_config=example_gen_pb2.SplitConfig(
        splits=[
            example_gen_pb2.SplitConfig.Split(name=split_name, hash_buckets=bucket_count)
            for split_name, bucket_count in (('train', 8), ('eval', 2))
        ]
    )
)

# ImportExampleGen reads the TFRecord file(s) found in dataset_dir
example_gen = ImportExampleGen(output_config=output, input_base=dataset_dir)

# Execute the component inside the interactive context
context.run(example_gen)

In [None]:
# Print the split names and the URI of the first 'examples' artifact produced by ExampleGen
artifact = example_gen.outputs['examples'].get()[0]
print(artifact.split_names, artifact.uri)

#### Create StatisticsGen

In [None]:
# Compute statistics over the examples emitted by ExampleGen
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

context.run(statistics_gen)

#### Create SchemaGen

In [None]:
# Derive a schema from the statistics, with feature-shape inference switched on
schema_gen = SchemaGen(
    infer_feature_shape=True,
    statistics=statistics_gen.outputs['statistics'],
)
context.run(schema_gen)

In [None]:
# Visualize the schema artifact produced by SchemaGen
context.show(schema_gen.outputs['schema'])

#### Create ExampleValidator to detect anomalies

In [None]:
# Validate the dataset statistics against the generated schema
example_validator = ExampleValidator(
    schema=schema_gen.outputs['schema'],
    statistics=statistics_gen.outputs['statistics'],
)
context.run(example_validator)

In [None]:
# Visualize the 'anomalies' artifact produced by ExampleValidator. There should be no anomalies.
context.show(example_validator.outputs['anomalies'])

#### Transform


In [None]:
# Name of the module file that the next cell writes and the Transform component loads
transform_path = 'transform.py'

In [None]:
%%writefile {transform_path}

import tensorflow as tf
import tensorflow_transform as tft

def image_fn(input):
    """Decode one encoded image into a float32 tensor of shape (28, 28, 1).

    Args:
        input: encoded image, passed straight to tf.image.decode_image.

    Returns:
        The decoded single-channel image, reshaped to (28, 28, 1) and cast to float32.
    """
    decoded = tf.image.decode_image(input, channels=1)
    # Pin the static shape explicitly with a reshape
    shaped = tf.reshape(decoded, (28, 28, 1))
    return tf.cast(shaped, tf.float32)

def preprocessing_fn(inputs):
    """Build the transformed features from a batch of raw features.

    Args:
        inputs: dict of raw feature tensors; the 'image' and 'label' keys are read.
            'image' is squeezed on axis 1 before decoding, so it must carry a
            size-1 second dimension.

    Returns:
        dict with two entries:
          'image_xf': images decoded by image_fn (float32), then rescaled with
              tft.scale_to_0_1.
          'label_xf': the labels cast to float32.
    """
    # Decoding handles one encoded image at a time, so it is mapped over the
    # leading (batch) dimension.
    images = tf.map_fn(
        fn=image_fn,
        elems=tf.squeeze(inputs['image'], axis=1),
        fn_output_signature=tf.float32)

    outputs = {
        # scale the pixels from 0 to 1
        'image_xf': tft.scale_to_0_1(images),
        # A cast is already element-wise: apply it to the whole batch at once
        # instead of looping over the elements with tf.map_fn.
        'label_xf': tf.cast(inputs['label'], tf.float32),
    }

    return outputs

In [None]:

# Apply the preprocessing module to the ingested examples, guided by the schema
transform = Transform(
    module_file=os.path.abspath(transform_path),
    schema=schema_gen.outputs['schema'],
    examples=example_gen.outputs['examples'],
)

context.run(transform)

#### Tuner



In [None]:
# Name of the module file that the next cell writes and the Tuner component loads
tuner_path = 'tuner.py'

In [None]:
%%writefile {tuner_path}

# Define imports
from kerastuner.engine import base_tuner
import kerastuner as kt
from tensorflow import keras
from typing import NamedTuple, Dict, Text, Any, List
from tfx.components.trainer.fn_args_utils import FnArgs, DataAccessor
import tensorflow as tf
import tensorflow_transform as tft

class TunerFnResult(NamedTuple):
  """Return type of tuner_fn: the tuner plus the keyword arguments for its search."""
  tuner: base_tuner.BaseTuner
  fit_kwargs: Dict[Text, Any]


# Early-stopping callback handed to the search: stops once val_loss has not
# improved for 5 epochs
stop_early = tf.keras.callbacks.EarlyStopping(patience=5, monitor='val_loss')


def reader_fn(input):
  """Open the given TFRecord file(s) as a GZIP-compressed record dataset.

  Args:
    input: filename(s) forwarded unchanged to tf.data.TFRecordDataset.

  Returns:
    A tf.data.TFRecordDataset reading with compression_type='GZIP'.
  """
  gzip_records = tf.data.TFRecordDataset(input, compression_type='GZIP')
  return gzip_records
  

def _input_fn(file_pattern, tf_transform_output, num_epochs=None, batch_size=32):
  """Create a batched dataset of transformed features with 'label_xf' as the label.

  Args:
    file_pattern: file(s) or pattern(s) of the transformed examples to read.
    tf_transform_output: object whose transformed_feature_spec() supplies the
      parsing spec.
    num_epochs: forwarded to make_batched_features_dataset (None by default).
    batch_size: number of examples per batch.

  Returns:
    The dataset built by tf.data.experimental.make_batched_features_dataset.
  """
  # Work on a copy of the transformed feature spec
  feature_spec = tf_transform_output.transformed_feature_spec().copy()

  return tf.data.experimental.make_batched_features_dataset(
      file_pattern=file_pattern,
      features=feature_spec,
      label_key='label_xf',
      reader=reader_fn,
      batch_size=batch_size,
      num_epochs=num_epochs)


def model_builder(hp):
  """Build and compile the classifier for one hyperparameter trial.

  Args:
    hp: hyperparameter container supplied by the tuner; `Int` and `Choice`
      are called on it to declare the search space.

  Returns:
    A compiled keras.Sequential model with a 10-way softmax output.
  """
  # Width of the hidden layer: 32..512 in steps of 32
  hidden_units = hp.Int('units', min_value=32, max_value=512, step=32)

  model = keras.Sequential([
      keras.layers.Flatten(input_shape=(28, 28, 1)),
      keras.layers.Dense(units=hidden_units, activation='relu'),
      keras.layers.Dropout(0.2),
      keras.layers.Dense(10, activation='softmax'),
  ])

  # Learning rate of the Adam optimizer: one of three candidates
  learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])
  optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

  model.compile(
      optimizer=optimizer,
      loss=keras.losses.SparseCategoricalCrossentropy(),
      metrics=['accuracy'])

  return model

def tuner_fn(fn_args):
  """Assemble the Hyperband tuner and the arguments for its search.

  Args:
    fn_args: argument holder; the attributes read here are working_dir,
      transform_graph_path, train_files, eval_files, train_steps and
      eval_steps.

  Returns:
    A TunerFnResult carrying the tuner and the fit keyword arguments.
  """
  # Hyperband search over model_builder, maximising validation accuracy
  hyperband = kt.Hyperband(
      model_builder,
      objective='val_accuracy',
      max_epochs=10,
      factor=3,
      directory=fn_args.working_dir,
      project_name='kt_hyperband')

  # The transform output provides the feature spec used to parse the examples
  tf_transform_output = tft.TFTransformOutput(fn_args.transform_graph_path)

  # Only the first entry of each file list is read
  fit_kwargs = {
      'x': _input_fn(fn_args.train_files[0], tf_transform_output),
      'validation_data': _input_fn(fn_args.eval_files[0], tf_transform_output),
      'steps_per_epoch': fn_args.train_steps,
      'validation_steps': fn_args.eval_steps,
      'callbacks': [stop_early],
  }

  return TunerFnResult(tuner=hyperband, fit_kwargs=fit_kwargs)

In [None]:
from tfx.proto import trainer_pb2

# Configure the Tuner: 500 training steps and 100 evaluation steps
tuner = Tuner(
    module_file=tuner_path,
    train_args=trainer_pb2.TrainArgs(splits=['train'], num_steps=500),
    eval_args=trainer_pb2.EvalArgs(splits=['eval'], num_steps=100),
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
)

In [None]:
# Run the component. This will take around 10 minutes to run.
# NOTE(review): enable_cache=False presumably forces a real execution instead of
# reusing a cached result — confirm against the TFX InteractiveContext docs.
context.run(tuner, enable_cache=False)

#### Trainer


In [None]:
# Name of the module file that the next cell writes and the Trainer component loads
trainer_path = 'trainer.py'

In [None]:
%%writefile {trainer_path}

from tensorflow import keras
from typing import NamedTuple, Dict, Text, Any, List
from tfx.components.trainer.fn_args_utils import FnArgs, DataAccessor
import tensorflow as tf
import tensorflow_transform as tft

def reader_fn(filenames):
  """Open the given TFRecord file(s) as a GZIP-compressed record dataset.

  Args:
    filenames: filename(s) forwarded unchanged to tf.data.TFRecordDataset.

  Returns:
    A tf.data.TFRecordDataset reading with compression_type='GZIP'.
  """
  gzip_records = tf.data.TFRecordDataset(filenames, compression_type='GZIP')
  return gzip_records
  

def input_fn(file_pattern, tf_transform_output, num_epochs=None, batch_size=32):
  """Create a batched dataset of transformed features with 'label_xf' as the label.

  Args:
    file_pattern: file(s) or pattern(s) of the transformed examples to read.
    tf_transform_output: object whose transformed_feature_spec() supplies the
      parsing spec.
    num_epochs: forwarded to make_batched_features_dataset (None by default).
    batch_size: number of examples per batch.

  Returns:
    The dataset built by tf.data.experimental.make_batched_features_dataset.
  """
  # Work on a copy of the transformed feature spec
  feature_spec = tf_transform_output.transformed_feature_spec().copy()

  return tf.data.experimental.make_batched_features_dataset(
      file_pattern=file_pattern,
      features=feature_spec,
      label_key='label_xf',
      reader=reader_fn,
      batch_size=batch_size,
      num_epochs=num_epochs)


def model_builder(hp):
  """Build and compile the classifier from already-chosen hyperparameters.

  Args:
    hp: object exposing `get`; the 'units' and 'learning_rate' entries are
      read from it.

  Returns:
    A compiled keras.Sequential model with a 10-way softmax output.
  """
  # Hidden-layer width taken from the Tuner results
  hidden_units = hp.get('units')

  model = keras.Sequential([
      keras.layers.Flatten(input_shape=(28, 28, 1)),
      keras.layers.Dense(units=hidden_units, activation='relu'),
      keras.layers.Dropout(0.2),
      keras.layers.Dense(10, activation='softmax'),
  ])

  # Learning rate taken from the Tuner results
  learning_rate = hp.get('learning_rate')
  optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

  model.compile(
      optimizer=optimizer,
      loss=keras.losses.SparseCategoricalCrossentropy(),
      metrics=['accuracy'])

  # Show the layer-by-layer summary
  model.summary()

  return model


def run_fn(fn_args):
  """Train the model with the tuned hyperparameters and export it.

  Args:
    fn_args: argument holder; the attributes read here are model_run_dir,
      transform_graph_path, train_files, eval_files, hyperparameters and
      serving_model_dir.

  Side effects:
    Writes TensorBoard logs to fn_args.model_run_dir and saves the trained
    model to fn_args.serving_model_dir in the 'tf' format.
  """
  # Number of passes over the training split
  num_epochs = 10

  # Callback for TensorBoard
  tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir=fn_args.model_run_dir, update_freq='batch')

  # Load transform output
  tf_transform_output = tft.TFTransformOutput(fn_args.transform_graph_path)

  # Each dataset yields exactly one pass over its split. The epoch count is
  # given to fit() instead of being baked into the data, so validation runs
  # (and TensorBoard logs a point) after every epoch rather than once at the
  # end on a 10x-duplicated eval set.
  train_set = input_fn(fn_args.train_files[0], tf_transform_output, num_epochs=1)
  val_set = input_fn(fn_args.eval_files[0], tf_transform_output, num_epochs=1)

  # Load best hyperparameters
  hp = fn_args.hyperparameters.get('values')

  # Build the model
  model = model_builder(hp)

  # Train the model
  model.fit(
      x=train_set,
      validation_data=val_set,
      epochs=num_epochs,
      callbacks=[tensorboard_callback])

  # Save the model
  model.save(fn_args.serving_model_dir, save_format='tf')

In [None]:
# Configure the Trainer to consume the transformed examples and the best
# hyperparameters found by the Tuner
trainer = Trainer(
    module_file=trainer_path,
    train_args=trainer_pb2.TrainArgs(splits=['train']),
    eval_args=trainer_pb2.EvalArgs(splits=['eval']),
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    hyperparameters=tuner.outputs['best_hyperparameters'],
)

In [None]:
# Run the component.
# NOTE(review): enable_cache=False presumably forces a real execution instead of
# reusing a cached result — confirm against the TFX InteractiveContext docs.
context.run(trainer, enable_cache=False)

#### TensorBoard

In [None]:
# URI of the first 'model_run' artifact produced by the Trainer
model_run_artifact_dir = trainer.outputs['model_run'].get()[0].uri

# Load the TensorBoard notebook extension and point it at that directory
%load_ext tensorboard
%tensorboard --logdir {model_run_artifact_dir}