Skip to content

Commit

Permalink
Initial implementation of gpu_embedding.
Browse files Browse the repository at this point in the history
  • Loading branch information
zoranjovanovic-ns committed May 30, 2023
1 parent c2bd866 commit 7f0c638
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 1 deletion.
27 changes: 27 additions & 0 deletions tensorflow/python/tpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ py_strict_library(
":tpu_embedding",
":tpu_embedding_for_serving",
":tpu_embedding_v1",
":gpu_embedding",
":tpu_embedding_v2",
":tpu_embedding_v2_utils",
":tpu_hardware_feature",
Expand Down Expand Up @@ -842,6 +843,7 @@ tf_py_strict_test(
srcs_version = "PY3",
deps = [
":tpu_embedding_for_serving",
":gpu_embedding",
":tpu_embedding_v2_utils",
"//tensorflow/python:init_ops_v2",
"//tensorflow/python/compat:v2_compat",
Expand Down Expand Up @@ -879,6 +881,31 @@ pytype_strict_library(
],
)

pytype_strict_library(
name = "gpu_embedding",
srcs = ["gpu_embedding.py"],
srcs_version = "PY3",
deps = [
":tpu_embedding_base",
":tpu_embedding_v2_utils",
":tpu_py",
":tpu_replication",
"//tensorflow/python:array_ops",
"//tensorflow/python:embedding_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:variables",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:tpu_strategy",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/framework:ops",
"//tensorflow/python/framework:sparse_tensor",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//tensorflow/python/util:nest",
"//tensorflow/python/util:tf_export",
],
)

tf_py_strict_test(
name = "tpu_embedding_v2_utils_test",
srcs = [
Expand Down
1 change: 1 addition & 0 deletions tensorflow/python/tpu/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tensorflow.python.tpu import tpu_embedding
from tensorflow.python.tpu import tpu_embedding_for_serving
from tensorflow.python.tpu import tpu_embedding_v1
from tensorflow.python.tpu import gpu_embedding
from tensorflow.python.tpu import tpu_embedding_v2
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu import tpu_hardware_feature
Expand Down
114 changes: 114 additions & 0 deletions tensorflow/python/tpu/gpu_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mid level API for TPU Embeddings without Embedding Accelerator."""

from typing import Any, Dict, Iterable, Optional, Text, Union

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.tpu import tpu_embedding_base
from tensorflow.python.tpu import tpu_embedding_v1
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu import tpu_replication
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@tf_export("tpu.experimental.embedding.GPUEmbedding")
class GPUEmbeddingV0(tpu_embedding_v1.TPUEmbeddingV0):
"""The GPUEmbedding mid level API running on TPU without Embedding accelerator.
NOTE: This mid level API is not intended for large embedding table lookup.
Embedding tables will be replicated across devices rather than sharding
across them. To do large embedding table lookup, please use the
`tpu.experimental.embedding.GPUEmbedding` class. This class is an alternative
way to do embedding lookups when the TPU doesn't support any version of
embedding feature. See
`tpu.experimental.tpu_hardware_feature.embedding_feature` for a detailed
explanation.
This class has to be created under the `MirroredStrategy`, Otherwise a RuntimeError
will be raised.
```python
strategy = tf.distribute.MirroredStrategy(...)
with strategy.scope():
embedding = tf.tpu.experimental.embedding.GPUEmbeddingV0(
feature_config=feature_config,
optimizer=tf.tpu.experimental.embedding.SGD(0.1))
```
When creating a distributed dataset that is to be passed to the lookup
operation a special input option must be specified:
```python
distributed_dataset = (
strategy.distribute_datasets_from_function(
dataset_fn=...,
options=tf.distribute.InputOptions(
experimental_fetch_to_device=False))
dataset_iterator = iter(distributed_dataset)
```
Below is an example of a training and evaluation step:
```python
optimizer = tf.keras.optimizers.SGD(0.1)
@tf.function
def training_step(dataset_iterator, num_steps):
def tpu_step(embedding_features):
with tf.GradientTape() as tape:
tape.watch(embedding.embedding_table.values())
activation = embedding(embedding_features)
model_output = model(activations)
loss = ... # some function of labels and model_output
embedding_gradients = tape.gradient(loss,
embedding.embedding_table.values())
optimizer.apply_gradients(list(zip(gradients,
mid_level_api.embedding_tables.values())))
# Insert your model gradient and optimizer application here
for _ in tf.range(num_steps):
strategy.run(tpu_step, args=(next(dataset_iterator), ))
@tf.function
def evalution_step(dataset_iterator, num_steps):
def tpu_step(embedding_features):
activations = embedding(embedding_features)
model_output = model(activations)
# Insert your evaluation code here.
for _ in tf.range(num_steps):
strategy.run(tpu_step, args=(next(dataset_iterator), ))
```
"""

def __init__(
self,
feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable], # pylint:disable=g-bare-generic
optimizer: Optional[tpu_embedding_v2_utils._Optimizer]): # pylint:disable=protected-access
super(tpu_embedding_v1.TPUEmbeddingV0, self).__init__(feature_config, optimizer)
self._strategy = distribute_lib.get_strategy()
self._built = False

5 changes: 4 additions & 1 deletion tensorflow/python/tpu/tpu_hardware_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class EmbeddingFeature(enum.Enum):
UNSUPPORTED = "UNSUPPORTED"
V1 = "V1"
V2 = "V2"
GPU = "GPU"

@classmethod
def _embedding_feature_proto_to_string(cls, embedding_feature_proto):
Expand All @@ -56,7 +57,9 @@ def _embedding_feature_proto_to_string(cls, embedding_feature_proto):
topology_pb2.TPUHardwareFeature.EmbeddingFeature.V1:
HardwareFeature.EmbeddingFeature.V1,
topology_pb2.TPUHardwareFeature.EmbeddingFeature.V2:
HardwareFeature.EmbeddingFeature.V2
HardwareFeature.EmbeddingFeature.V2,
topology_pb2.TPUHardwareFeature.EmbeddingFeature.GPU:
HardwareFeature.EmbeddingFeature.GPU
}
return embedding_feature_proto_to_string_map.get(
embedding_feature_proto, HardwareFeature.EmbeddingFeature.UNSUPPORTED)
Expand Down

0 comments on commit 7f0c638

Please sign in to comment.