Initial implementation of gpu_embedding. #2114

Open. Wants to merge 2 commits into base: develop-upstream.
15 changes: 15 additions & 0 deletions tensorflow/python/distribute/mirrored_strategy.py
@@ -496,20 +496,32 @@ def _get_variable_creator_initial_value(self,
                                          primary_var,
                                          **kwargs):
    """Return the initial value for variables on a replica."""
    print('MS _get_variable_creator_initial_value')
    if replica_id == 0:
      print('replica_id == 0')
      return kwargs["initial_value"]
    else:
      print('replica_id != 0')
      assert primary_var is not None
      assert device is not None
      assert kwargs is not None

      print('----------- 1')
      print(primary_var)
      print('----------- 2')
      print(primary_var.value())
      print('----------- 3')
      def initial_value_fn():
        if context.executing_eagerly() or ops.inside_function():
          init_value = primary_var.value()
          print('MS initial_value_fn 1')
          print(init_value)
          return array_ops.identity(init_value)
        else:
          with ops.device(device):
            init_value = primary_var.initial_value
            print('MS initial_value_fn 2')
            print(init_value)
            return array_ops.identity(init_value)

      return initial_value_fn
@@ -526,6 +538,7 @@ def _create_variable(self, next_creator, **kwargs):
      devices = colocate_with._devices  # pylint: disable=protected-access

    def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
      print('_real_mirrored_creator')
      value_list = []
      for i, d in enumerate(devices):
        with ops.device(d):
@@ -547,6 +560,8 @@ def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
            with record.stop_recording():
              v = next_creator(**kwargs)
          assert not isinstance(v, values.DistributedVariable)
          print('_real_mirrored_creator 7')
          print(v)
          value_list.append(v)
      return value_list

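For readers following the instrumented path above: replica 0 keeps its own initializer, while every other replica is seeded with a copy of the primary variable's current value placed on its own device. A minimal standalone sketch of that pattern, outside the strategy machinery (names here are illustrative only, not part of this change):

```python
import tensorflow as tf

def make_replica_initial_value_fn(primary_var, device):
  """Returns a closure that copies the primary variable's value onto `device`."""
  def initial_value_fn():
    with tf.device(device):
      # tf.identity produces a fresh tensor placed on the target device.
      return tf.identity(primary_var.value())
  return initial_value_fn

primary = tf.Variable([1.0, 2.0, 3.0], name="primary")
replica = tf.Variable(
    make_replica_initial_value_fn(primary, "/CPU:0")(), name="replica_1")
assert (replica.numpy() == primary.numpy()).all()
```

In MirroredStrategy, `_get_variable_creator_initial_value` returns such a closure so that each replica variable starts from the same value as the primary.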
7 changes: 5 additions & 2 deletions tensorflow/python/ops/variable_v1.py
@@ -258,7 +258,7 @@ def __init__(
        shape and `validate_shape` is `True`.
      RuntimeError: If eager execution is enabled.
    """

    print("Variable VariableV1")
  SaveSliceInfo = variables.Variable.SaveSliceInfo

  def initialized_value(self):
@@ -294,12 +294,15 @@ def _variable_call(
    **kwargs,
):
  """VariableV1 class getter. Useful to force the signature."""
  print("Var init v1")
  if cls is not VariableV1:
    return None
  previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
  for _, getter in ops.get_default_graph()._variable_creator_stack:  # pylint: disable=protected-access
    previous_getter = variables._make_getter(getter, previous_getter)  # pylint: disable=protected-access


  print("Var init v1 2")
  # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
  if aggregation is None:
    aggregation = variables.VariableAggregation.NONE
@@ -321,5 +324,5 @@ def _variable_call(
      collections=collections,
      use_resource=use_resource,
  )

print("Var init v1 3")
variable_scope.set_variable_v1(VariableV1)
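The `_variable_call` path traced above chains the getters registered on the graph's variable-creator stack; the same mechanism is exposed publicly as `tf.variable_creator_scope`. A small illustrative sketch of that chaining (not code from this PR):

```python
import tensorflow as tf

def logging_creator(next_creator, **kwargs):
  # Invoked for every variable constructed inside the scope below; it can
  # inspect or modify kwargs before delegating to the next creator in the chain.
  print("creating variable:", kwargs.get("name"))
  return next_creator(**kwargs)

with tf.variable_creator_scope(logging_creator):
  v = tf.Variable([1.0, 2.0], name="demo")  # triggers logging_creator
```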
1 change: 1 addition & 0 deletions tensorflow/python/ops/variables.py
@@ -1502,6 +1502,7 @@ def __init__(self, name, shape, dtype, variable_list, partitions):
      ValueError: If `variable_list` is empty, or the `Variable` shape
        information does not match `shape`, or `partitions` has invalid values.
    """
    print('PartitionedVariable')
    if not isinstance(variable_list, (list, tuple)):
      raise TypeError("variable_list is not a list or tuple: %s" %
                      variable_list)
27 changes: 27 additions & 0 deletions tensorflow/python/tpu/BUILD
@@ -232,6 +232,7 @@ py_strict_library(
":tpu_embedding",
":tpu_embedding_for_serving",
":tpu_embedding_v1",
":gpu_embedding",
":tpu_embedding_v2",
":tpu_embedding_v2_utils",
":tpu_hardware_feature",
@@ -842,6 +843,7 @@ tf_py_strict_test(
srcs_version = "PY3",
deps = [
":tpu_embedding_for_serving",
":gpu_embedding",
":tpu_embedding_v2_utils",
"//tensorflow/python:init_ops_v2",
"//tensorflow/python/compat:v2_compat",
@@ -879,6 +881,31 @@ pytype_strict_library(
    ],
)

pytype_strict_library(
    name = "gpu_embedding",
    srcs = ["gpu_embedding.py"],
    srcs_version = "PY3",
    deps = [
        ":tpu_embedding_base",
        ":tpu_embedding_v2_utils",
        ":tpu_py",
        ":tpu_replication",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:embedding_ops",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:sparse_ops",
        "//tensorflow/python:variables",
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/distribute:tpu_strategy",
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/framework:ops",
        "//tensorflow/python/framework:sparse_tensor",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//tensorflow/python/util:nest",
        "//tensorflow/python/util:tf_export",
    ],
)

tf_py_strict_test(
name = "tpu_embedding_v2_utils_test",
srcs = [
1 change: 1 addition & 0 deletions tensorflow/python/tpu/api.py
@@ -25,6 +25,7 @@
from tensorflow.python.tpu import tpu_embedding
from tensorflow.python.tpu import tpu_embedding_for_serving
from tensorflow.python.tpu import tpu_embedding_v1
from tensorflow.python.tpu import gpu_embedding
from tensorflow.python.tpu import tpu_embedding_v2
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu import tpu_hardware_feature
19 changes: 0 additions & 19 deletions tensorflow/python/tpu/experimental/__init__.py
@@ -1,19 +0,0 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental TPU library."""

# pylint: disable=unused-import
from tensorflow.python.tpu import tpu_strategy_util
# pylint: enable=unused-import
120 changes: 120 additions & 0 deletions tensorflow/python/tpu/gpu_embedding.py
@@ -0,0 +1,120 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""GPU embeddings API."""

from typing import Any, Dict, Iterable, Optional, Text, Union

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.tpu import tpu_embedding_base
from tensorflow.python.tpu import tpu_embedding_v1
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu import tpu_replication
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@tf_export("tpu.experimental.embedding.GPUEmbeddingV0")
class GPUEmbeddingV0(tpu_embedding_v1.TPUEmbeddingV0):
"""GPU embeddings API.

NOTE: This API is not intended for large embedding table lookup.
Embedding tables will be replicated across devices rather than sharding
across them. To do large embedding table lookup, please use the
`tpu.experimental.embedding.GPUEmbedding` class. This class is an alternative
way to do embedding lookups on GPUs. See
`tpu.experimental.tpu_hardware_feature.embedding_feature` for a detailed
explanation.

This class has to be created under the `MirroredStrategy`, Otherwise a RuntimeError
will be raised.
```python
strategy = tf.distribute.MirroredStrategy(...)
with strategy.scope():
embedding = tf.tpu.experimental.embedding.GPUEmbeddingV0(
feature_config=feature_config,
optimizer=tf.tpu.experimental.embedding.SGD(0.1))
```
When creating a distributed dataset that is to be passed to the lookup
operation a special input option must be specified:

```python
distributed_dataset = (
strategy.distribute_datasets_from_function(
dataset_fn=...,
options=tf.distribute.InputOptions(
experimental_fetch_to_device=False))
dataset_iterator = iter(distributed_dataset)
```

Below is an example of a training and evaluation step:

```python
optimizer = tf.keras.optimizers.SGD(0.1)

@tf.function
def training_step(dataset_iterator, num_steps):
def tpu_step(embedding_features):
with tf.GradientTape() as tape:
tape.watch(embedding.embedding_table.values())
activation = embedding(embedding_features)
model_output = model(activations)
loss = ... # some function of labels and model_output

embedding_gradients = tape.gradient(loss,
embedding.embedding_table.values())
optimizer.apply_gradients(list(zip(gradients,
mid_level_api.embedding_tables.values())))
# Insert your model gradient and optimizer application here

for _ in tf.range(num_steps):
strategy.run(tpu_step, args=(next(dataset_iterator), ))

@tf.function
def evalution_step(dataset_iterator, num_steps):
def tpu_step(embedding_features):
activations = embedding(embedding_features)
model_output = model(activations)
# Insert your evaluation code here.

for _ in tf.range(num_steps):
strategy.run(tpu_step, args=(next(dataset_iterator), ))
```
"""

  def __init__(
      self,
      feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
      optimizer: Optional[tpu_embedding_v2_utils._Optimizer]):  # pylint:disable=protected-access
    print('__init__ GPUEmbedding')
    super(tpu_embedding_v1.TPUEmbeddingV0, self).__init__(feature_config, optimizer)
    self._strategy = distribute_lib.get_strategy()
    self._built = False
    if not isinstance(self._strategy,
                      (mirrored_strategy.MirroredStrategy)):
      raise RuntimeError(
          "GPUEmbeddingV0 should be created under MirroredStrategy but found {}."
          .format(self._strategy))
    self._built = False
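As the class docstring notes, GPUEmbeddingV0 keeps one full copy of each embedding table per device. For intuition, here is a minimal sketch of what a replicated, per-replica lookup looks like with stock ops under `MirroredStrategy`; it illustrates the idea only and is not the API added by this PR:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # one replica per visible GPU
with strategy.scope():
  # Created inside the scope, so this becomes a MirroredVariable:
  # every replica holds a full copy of the table.
  table = tf.Variable(tf.random.uniform([100, 16]), name="embedding_table")

@tf.function
def lookup(ids):
  def replica_fn(replica_ids):
    # Each replica gathers rows from its local copy of the table.
    return tf.nn.embedding_lookup(table, replica_ids)
  return strategy.run(replica_fn, args=(ids,))

activations = lookup(tf.constant([[1, 5, 7], [2, 3, 9]]))
```

GPUEmbeddingV0 wraps the same replicated-lookup idea behind the feature_config/optimizer interface it inherits from `TPUEmbeddingV0`.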
60 changes: 60 additions & 0 deletions tensorflow/python/tpu/tests/BUILD
@@ -504,8 +504,11 @@ tpu_py_strict_test(
"//tensorflow/python/eager:def_function",
"//tensorflow/python/keras/optimizer_v2",
"//tensorflow/python/platform:client_testlib",
"//tensorflow/python/tpu:tpu_embedding_for_serving",
"//tensorflow/python/tpu:tpu_embedding_v1",
"//tensorflow/python/tpu:tpu_embedding_v2_utils",
"//tensorflow/python/saved_model:load",
"//tensorflow/python/saved_model:save",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
@@ -530,3 +533,60 @@ tpu_py_strict_test(
"@absl_py//absl/testing:parameterized",
],
)


pytype_strict_library(
    name = "gpu_embedding_base_test",
    srcs = ["gpu_embedding_base_test.py"],
    srcs_version = "PY3",
    deps = [
        ":tpu_embedding_base_test",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:init_ops_v2",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:math_ops_gen",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/distribute:mirrored_strategy",
        "//tensorflow/python/eager:remote",
        "//tensorflow/python/framework:constant_op",
        "//tensorflow/python/framework:dtypes",
        "//tensorflow/python/framework:sparse_tensor",
        "//tensorflow/python/ops/ragged:ragged_tensor",
        "//tensorflow/python/platform:client_testlib",
        "//tensorflow/python/tpu:gpu_embedding",
        "//tensorflow/python/tpu:tpu_embedding_v2_utils",
        "//tensorflow/python/tpu:tpu_strategy_util",
        "//tensorflow/python/util:nest",
        "//third_party/py/numpy",
        "@absl_py//absl/flags",
        "@absl_py//absl/testing:parameterized",
    ],
)


tpu_py_strict_test(
    name = "gpu_embedding_v1_correctness_test",
    srcs = [
        "gpu_embedding_v1_correctness_test.py",
    ],
    disable_mlir_bridge = False,
    python_version = "PY3",
    srcs_version = "PY3",
    tags = ["no_oss"],
    deps = [
        ":gpu_embedding_base_test",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/distribute:distribute_lib",
        "//tensorflow/python/eager:backprop",
        "//tensorflow/python/eager:def_function",
        "//tensorflow/python/keras/optimizer_v2",
        "//tensorflow/python/platform:client_testlib",
        "//tensorflow/python/tpu:gpu_embedding",
        "//tensorflow/python/tpu:tpu_embedding_v1",
        "//tensorflow/python/tpu:tpu_embedding_v2_utils",
        "//tensorflow/python/saved_model:load",
        "//tensorflow/python/saved_model:save",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)