In [1]:
pip install -q -U tensorflow_transform

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m447.8/447.8 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.7/14.7 MB[0m [31m60.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.5/22.5 MB[0m [31m41.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.7/89.7 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m138.7/138.7 kB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m152.0/152.0 kB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m51.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━

In [2]:
import pathlib
import pprint
import tempfile

import tensorflow as tf
import tensorflow_transform as tft

import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import schema_utils

In [3]:
# Three in-memory example rows; every row carries the features x, y and s.
raw_data = [
    {'x': 1, 'y': 1, 's': 'hello'},
    {'x': 2, 'y': 2, 's': 'world'},
    {'x': 3, 'y': 3, 's': 'hello'},
]

# Parsing spec for the raw rows: x and y are scalar floats, s a scalar string.
raw_feature_spec = {
    'y': tf.io.FixedLenFeature([], tf.float32),
    'x': tf.io.FixedLenFeature([], tf.float32),
    's': tf.io.FixedLenFeature([], tf.string),
}

# Schema metadata handed to the Beam transform together with raw_data.
raw_data_metadata = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(raw_feature_spec))

In [4]:
def preprocessing_fn(inputs):
  """Build the transformed features from the raw feature dict.

  Args:
    inputs: mapping that provides the raw features 'x', 'y' and 's'.

  Returns:
    Dict with the entries 'x_centered' (x minus tft.mean(x)),
    'y_normalized' (tft.scale_to_0_1 of y), 's_integerized'
    (tft.compute_and_apply_vocabulary of s) and
    'x_centered_times_y_normalized' (product of the first two).
  """
  raw_x, raw_y, raw_s = inputs['x'], inputs['y'], inputs['s']

  centered = raw_x - tft.mean(raw_x)
  normalized = tft.scale_to_0_1(raw_y)
  vocab_ids = tft.compute_and_apply_vocabulary(raw_s)

  outputs = {
      'x_centered': centered,
      'y_normalized': normalized,
      's_integerized': vocab_ids,
  }
  # Derived feature built from two already-transformed features.
  outputs['x_centered_times_y_normalized'] = centered * normalized
  return outputs

In [5]:
def main(output_dir):
  """Analyze and transform `raw_data`, then write the transform function.

  Args:
    output_dir: destination passed to `tft_beam.WriteTransformFn`.

  Returns:
    A `(transformed_data, transformed_metadata)` pair unpacked from the
    transformed dataset.
  """
  # Intermediate analysis artifacts go to a throwaway temporary directory.
  with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
    analyze_and_transform = tft_beam.AnalyzeAndTransformDataset(
        preprocessing_fn)
    raw_dataset = (raw_data, raw_data_metadata)
    transformed_dataset, transform_fn = raw_dataset | analyze_and_transform

  transformed_data, transformed_metadata = transformed_dataset

  # Persist the transform function; the result of the write step is unused.
  write_step = 'WriteTransformFn' >> tft_beam.WriteTransformFn(output_dir)
  _ = transform_fn | write_step

  return transformed_data, transformed_metadata

In [6]:
# Fresh temporary directory that receives the saved transform function.
output_dir = pathlib.Path(tempfile.mkdtemp())

transformed_data, transformed_metadata = main(str(output_dir))

# Show the rows before and after the transform.
raw_text = pprint.pformat(raw_data)
transformed_text = pprint.pformat(transformed_data)
print('\nRaw data:\n{}\n'.format(raw_text))
print('Transformed data:\n{}'.format(transformed_text))






Raw data:
[{'s': 'hello', 'x': 1, 'y': 1},
 {'s': 'world', 'x': 2, 'y': 2},
 {'s': 'hello', 'x': 3, 'y': 3}]

Transformed data:
[{'s_integerized': 0,
  'x_centered': -1.0,
  'x_centered_times_y_normalized': -0.0,
  'y_normalized': 0.0},
 {'s_integerized': 1,
  'x_centered': 0.0,
  'x_centered_times_y_normalized': 0.0,
  'y_normalized': 0.5},
 {'s_integerized': 0,
  'x_centered': 1.0,
  'x_centered_times_y_normalized': 1.0,
  'y_normalized': 1.0}]


In [8]:
ls -l {output_dir}

total 8
drwxr-xr-x 2 root root 4096 Dec 19 14:37 [0m[01;34mtransformed_metadata[0m/
drwxr-xr-x 4 root root 4096 Dec 19 14:37 [01;34mtransform_fn[0m/


In [12]:
# Wrap the directory that `main` wrote the transform function into.
tf_transform_output = tft.TFTransformOutput(output_dir)

# Layer that applies the saved transform to batches of raw features.
tft_layer = tf_transform_output.transform_features_layer()
tft_layer

<tensorflow_transform.output_wrapper.TransformFeaturesLayer at 0x7fb7221a9ea0>

In [14]:
# Column-wise batch built from the row dicts in `raw_data`; x and y are
# explicitly created as float32 tensors.
raw_data_batch = dict(
    s=tf.constant([row['s'] for row in raw_data]),
    x=tf.constant([row['x'] for row in raw_data], dtype=tf.float32),
    y=tf.constant([row['y'] for row in raw_data], dtype=tf.float32),
)

In [15]:
# Apply the saved transform to the raw batch.
transformed_batch = tft_layer(raw_data_batch)

# Display each transformed feature as a NumPy array.
{name: tensor.numpy() for name, tensor in transformed_batch.items()}

{'x_centered': array([-1.,  0.,  1.], dtype=float32),
 'y_normalized': array([0. , 0.5, 1. ], dtype=float32),
 's_integerized': array([0, 1, 0]),
 'x_centered_times_y_normalized': array([-0.,  0.,  1.], dtype=float32)}

In [17]:
class StackDict(tf.keras.layers.Layer):
  """Layer that stacks the values of a feature dict into one float32 tensor."""

  def call(self, inputs):
    """Cast every value to float32 and stack them along axis 1.

    Values are taken in sorted key order, so the column order does not
    depend on the insertion order of `inputs`.
    """
    columns = []
    for name in sorted(inputs):
      columns.append(tf.cast(inputs[name], tf.float32))
    return tf.stack(columns, axis=1)

In [20]:
class TrainedModel(tf.keras.Model):
  """Small dense network that consumes the transformed feature dict.

  The feature dict is stacked into a single float32 tensor by `StackDict`
  and then passed through two 64-unit ReLU layers followed by a 10-unit
  linear layer.
  """

  def __init__(self):
    # Call the base initializer with no extra positional arguments; passing
    # the instance a second time is at best ignored by the base class.
    super().__init__()
    self.concat = StackDict()
    self.body = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])

  def call(self, inputs, training=None):
    """Run the network on a dict of transformed features.

    Args:
      inputs: dict of feature tensors, stacked by `StackDict`.
      training: optional training-mode flag forwarded to the dense stack.

    Returns:
      The output of the final 10-unit dense layer.
    """
    x = self.concat(inputs)
    # Forward `training` by keyword so it cannot bind to a different
    # positional parameter of the wrapped model.
    return self.body(x, training=training)

In [21]:
# Instantiate the model; its weights are created on the first call.
trained_model = TrainedModel()

In [None]:
# Placeholder for the training step: the `...` (Ellipsis) arguments stand in
# for a real loss and training data, which this notebook does not supply.
trained_model.compile(loss=..., optimizer='adam')
trained_model.fit(...)

In [25]:
# Run the model directly on the already-transformed batch.
trained_model_output=trained_model(transformed_batch)
trained_model_output.shape

TensorShape([3, 10])

In [26]:
class ExportModel(tf.Module):
  """Module that chains the input transform and the trained model.

  Calling it on raw features first applies `input_transform` and then
  feeds the result to `trained_model`, so both can be saved as one unit.
  """

  def __init__(self,trained_model,input_transform):
    """Store the two stages.

    Args:
      trained_model: callable applied to the transformed features.
      input_transform: callable applied to the raw input features.
    """
    # Initialise tf.Module's own state (module name / name scope) before
    # attaching the tracked attributes.
    super().__init__()
    self.trained_model=trained_model
    self.input_transform=input_transform

  @tf.function
  def __call__(self,inputs,training=None):
    """Transform `inputs` and return the model output.

    Args:
      inputs: raw feature batch accepted by `input_transform`.
      training: optional training-mode flag forwarded to `trained_model`;
        the default `None` leaves the model's own default in effect.
    """
    x=self.input_transform(inputs)
    # Forward the flag instead of silently dropping it.
    return self.trained_model(x, training=training)

In [27]:
# Combine the transform layer and the model into one exportable module.
export_model = ExportModel(trained_model=trained_model,
                           input_transform=tft_layer)

In [28]:
# Feed raw (untransformed) features; the module applies the transform itself.
export_model_output=export_model(raw_data_batch)
export_model_output.shape

TensorShape([3, 10])

In [29]:
# Largest absolute element-wise difference between the combined module's
# output and the model's output on the pre-transformed batch.
tf.reduce_max(abs(export_model_output - trained_model_output)).numpy()

0.0

In [30]:
# Save the combined module as a SavedModel in a fresh temporary directory.
model_dir=tempfile.mkdtemp(suffix='tft')
tf.saved_model.save(export_model,model_dir)

In [32]:
# Load the SavedModel back from disk.
reloaded = tf.saved_model.load(model_dir)

# Call the reloaded module on the same raw batch as before.
reloaded_model_output = reloaded(raw_data_batch)
reloaded_model_output.shape

TensorShape([3, 10])

In [33]:
# Largest absolute element-wise difference between the in-memory module's
# output and the reloaded module's output.
tf.reduce_max(abs(export_model_output - reloaded_model_output)).numpy()

0.0