Initial implementation of gpu_embedding.
zoranjovanovic-ns committed May 25, 2023
1 parent 5e47f2b commit 9ae9192
Showing 4 changed files with 169 additions and 1 deletion.
27 changes: 27 additions & 0 deletions tensorflow/python/tpu/BUILD
@@ -232,6 +232,7 @@ py_strict_library(
":tpu_embedding",
":tpu_embedding_for_serving",
":tpu_embedding_v1",
":gpu_embedding",
":tpu_embedding_v2",
":tpu_embedding_v2_utils",
":tpu_hardware_feature",
@@ -842,6 +843,7 @@ tf_py_strict_test(
srcs_version = "PY3",
deps = [
":tpu_embedding_for_serving",
":gpu_embedding",
":tpu_embedding_v2_utils",
"//tensorflow/python:init_ops_v2",
"//tensorflow/python/compat:v2_compat",
@@ -879,6 +881,31 @@ pytype_strict_library(
    ],
)

pytype_strict_library(
    name = "gpu_embedding",
    srcs = ["gpu_embedding.py"],
    srcs_version = "PY3",
    deps = [
        ":tpu_embedding_base",
        ":tpu_embedding_v2_utils",
        ":tpu_py",
        ":tpu_replication",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:variables",
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/distribute:tpu_strategy",
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/framework:ops",
        "//tensorflow/python/framework:sparse_tensor",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//tensorflow/python/util:nest",
        "//tensorflow/python/util:tf_export",
    ],
)

tf_py_strict_test(
    name = "tpu_embedding_v2_utils_test",
    srcs = [
1 change: 1 addition & 0 deletions tensorflow/python/tpu/api.py
@@ -25,6 +25,7 @@
from tensorflow.python.tpu import tpu_embedding
from tensorflow.python.tpu import tpu_embedding_for_serving
from tensorflow.python.tpu import tpu_embedding_v1
from tensorflow.python.tpu import gpu_embedding
from tensorflow.python.tpu import tpu_embedding_v2
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu import tpu_hardware_feature
137 changes: 137 additions & 0 deletions tensorflow/python/tpu/gpu_embedding.py
@@ -0,0 +1,137 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mid level API for TPU Embeddings without Embedding Accelerator."""

from typing import Any, Dict, Iterable, Optional, Text, Union

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.tpu import tpu_embedding_base
from tensorflow.python.tpu import tpu_embedding_v1
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu import tpu_replication
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@tf_export("tpu.experimental.embedding.GPUEmbedding")
class GPUEmbedding(tpu_embedding_v1.TPUEmbeddingV0):
"""The GPUEmbedding mid level API running on TPU without Embedding accelerator.
NOTE: This mid level API is not intended for large embedding table lookup.
Embedding tables will be replicated across devices rather than sharding
across them. To do large embedding table lookup, please use the
`tpu.experimental.embedding.GPUEmbedding` class. This class is an alternative
way to do embedding lookups when the TPU doesn't support any version of
embedding feature. See
`tpu.experimental.tpu_hardware_feature.embedding_feature` for a detailed
explanation.
This class has to be created under the `TPUStrategy`, Otherwise a RuntimeError
will be raised.
```python
strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
embedding = tf.tpu.experimental.embedding.GPUEmbedding(
feature_config=feature_config,
optimizer=tf.tpu.experimental.embedding.SGD(0.1))
```

  When creating a distributed dataset that is to be passed to the lookup
  operation a special input option must be specified:

  ```python
  distributed_dataset = (
      strategy.distribute_datasets_from_function(
          dataset_fn=...,
          options=tf.distribute.InputOptions(
              experimental_fetch_to_device=False)))
  dataset_iterator = iter(distributed_dataset)
  ```

  Below is an example of a training and evaluation step:

  ```python
  optimizer = tf.keras.optimizers.SGD(0.1)

  @tf.function
  def training_step(dataset_iterator, num_steps):
    def tpu_step(embedding_features):
      with tf.GradientTape() as tape:
        tape.watch(embedding.embedding_tables.values())
        activations = embedding(embedding_features)
        model_output = model(activations)
        loss = ...  # some function of labels and model_output

      embedding_gradients = tape.gradient(
          loss, embedding.embedding_tables.values())
      optimizer.apply_gradients(list(zip(
          embedding_gradients, embedding.embedding_tables.values())))
      # Insert your model gradient and optimizer application here

    for _ in tf.range(num_steps):
      strategy.run(tpu_step, args=(next(dataset_iterator), ))

  @tf.function
  def evaluation_step(dataset_iterator, num_steps):
    def tpu_step(embedding_features):
      activations = embedding(embedding_features)
      model_output = model(activations)
      # Insert your evaluation code here.

    for _ in tf.range(num_steps):
      strategy.run(tpu_step, args=(next(dataset_iterator), ))
  ```

  NOTE: The optimizer used here is a Keras optimizer. In order to make the slot
  variable creation stay consistent between Keras optimizers and embedding
  optimizers, the `slot_variable_creation_fn` argument of the embedding
  optimizers has to be passed with the Keras `add_slot` function. Also note
  that the slot names might be slightly different between them.

  ```python
  optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.1)

  def slot_variable_creation_fn(table, slot_names, slot_initializers):
    slots = {}
    for slot, initializer in zip(slot_names, slot_initializers):
      slots[slot] = optimizer.add_slot(table, slot, initializer)
    return slots

  embedding_optimizer = tf.tpu.experimental.embedding.Adagrad(
      learning_rate=0.1,
      slot_variable_creation_fn=slot_variable_creation_fn)

  # Use the embedding optimizer to create the mid level API and the Keras
  # optimizer to apply gradients.
  ```
  """

  def __init__(
      self,
      feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
      optimizer: Optional[tpu_embedding_v2_utils._Optimizer]):  # pylint:disable=protected-access
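    # super(tpu_embedding_v1.TPUEmbeddingV0, self) resolves to the next class
    # in the MRO (tpu_embedding_base.TPUEmbeddingBase), so this skips
    # TPUEmbeddingV0.__init__ and its TPUStrategy check.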
    super(tpu_embedding_v1.TPUEmbeddingV0, self).__init__(feature_config,
                                                          optimizer)
    self._strategy = distribute_lib.get_strategy()
    self._built = False

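For context: the new module builds on `TPUEmbeddingV0`, whose lookup path is composed of plain TensorFlow ops (hence the `embedding_ops`, `sparse_ops`, and `math_ops` imports above) against replicated tables rather than an embedding accelerator. Below is a minimal sketch of that idea, with illustrative table shapes and ids; it is not code from this commit:

```python
import tensorflow as tf

# Illustrative only: without an embedding accelerator, a table is an ordinary
# replicated variable and a lookup is a plain gather/segment op.
table = tf.Variable(tf.random.uniform([1000, 16]))  # [vocab_size, embedding_dim]

# Dense ids: a straight gather.
ids = tf.constant([[3, 7], [42, 9]])              # [batch, seq]
activations = tf.nn.embedding_lookup(table, ids)  # [batch, seq, 16]

# Sparse ids: pooled with a combiner, as a mean-combined feature would be.
sparse_ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([3, 7, 42], dtype=tf.int64),
    dense_shape=[2, 2])
pooled = tf.nn.safe_embedding_lookup_sparse(table, sparse_ids, combiner="mean")
```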
5 changes: 4 additions & 1 deletion tensorflow/python/tpu/tpu_hardware_feature.py
@@ -46,6 +46,7 @@ class EmbeddingFeature(enum.Enum):
UNSUPPORTED = "UNSUPPORTED"
V1 = "V1"
V2 = "V2"
GPU = "GPU"

@classmethod
def _embedding_feature_proto_to_string(cls, embedding_feature_proto):
@@ -56,7 +57,9 @@ def _embedding_feature_proto_to_string(cls, embedding_feature_proto):
        topology_pb2.TPUHardwareFeature.EmbeddingFeature.V1:
            HardwareFeature.EmbeddingFeature.V1,
        topology_pb2.TPUHardwareFeature.EmbeddingFeature.V2:
-           HardwareFeature.EmbeddingFeature.V2
+           HardwareFeature.EmbeddingFeature.V2,
+       topology_pb2.TPUHardwareFeature.EmbeddingFeature.GPU:
+           HardwareFeature.EmbeddingFeature.GPU
    }
    return embedding_feature_proto_to_string_map.get(
        embedding_feature_proto, HardwareFeature.EmbeddingFeature.UNSUPPORTED)
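With the `GPU` value wired into the proto map, client code can branch on the reported embedding feature. Below is a hypothetical dispatch sketch, assuming the existing `strategy.extended.tpu_hardware_feature` accessor and that `feature_config` and `optimizer` are defined elsewhere:

```python
import tensorflow as tf

hw = strategy.extended.tpu_hardware_feature  # a HardwareFeature instance
EmbeddingFeature = tf.tpu.experimental.HardwareFeature.EmbeddingFeature

if hw.embedding_feature == EmbeddingFeature.GPU:
  # No embedding accelerator available: use the replicated-table path.
  mid_level = tf.tpu.experimental.embedding.GPUEmbedding(
      feature_config=feature_config, optimizer=optimizer)
```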