forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 93
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial implementation of gpu_embedding.
- Loading branch information
1 parent
5e47f2b
commit 9ae9192
Showing
4 changed files
with
169 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Copyright 2022 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Mid level API for TPU Embeddings without Embedding Accelerator.""" | ||
|
||
from typing import Any, Dict, Iterable, Optional, Text, Union | ||
|
||
from tensorflow.python.distribute import distribute_lib | ||
from tensorflow.python.distribute import tpu_strategy | ||
from tensorflow.python.framework import dtypes | ||
from tensorflow.python.framework import ops | ||
from tensorflow.python.framework import sparse_tensor | ||
from tensorflow.python.ops import array_ops | ||
from tensorflow.python.ops import embedding_ops | ||
from tensorflow.python.ops import math_ops | ||
from tensorflow.python.ops import sparse_ops | ||
from tensorflow.python.ops import variables as tf_variables | ||
from tensorflow.python.ops.ragged import ragged_tensor | ||
from tensorflow.python.tpu import tpu_embedding_base | ||
from tensorflow.python.tpu import tpu_embedding_v1 | ||
from tensorflow.python.tpu import tpu_embedding_v2_utils | ||
from tensorflow.python.tpu import tpu_replication | ||
from tensorflow.python.util import nest | ||
from tensorflow.python.util.tf_export import tf_export | ||
|
||
|
||
@tf_export("tpu.experimental.embedding.GPUEmbedding") | ||
class GPUEmbedding(tpu_embedding_v1.TPUEmbeddingV0): | ||
"""The GPUEmbedding mid level API running on TPU without Embedding accelerator. | ||
NOTE: This mid level API is not intended for large embedding table lookup. | ||
Embedding tables will be replicated across devices rather than sharding | ||
across them. To do large embedding table lookup, please use the | ||
`tpu.experimental.embedding.TPUEmbedding` class. This class is an alternative | ||
way to do embedding lookups when the TPU doesn't support any version of | ||
embedding feature. See | ||
`tpu.experimental.tpu_hardware_feature.embedding_feature` for a detailed | ||
explanation. | ||
This class has to be created under the `TPUStrategy`; otherwise a RuntimeError | ||
will be raised. | ||
```python | ||
strategy = tf.distribute.TPUStrategy(...) | ||
with strategy.scope(): | ||
embedding = tf.tpu.experimental.embedding.GPUEmbedding( | ||
feature_config=feature_config, | ||
optimizer=tf.tpu.experimental.embedding.SGD(0.1)) | ||
``` | ||
When creating a distributed dataset that is to be passed to the lookup | ||
operation, a special input option must be specified: | ||
```python | ||
distributed_dataset = ( | ||
strategy.distribute_datasets_from_function( | ||
dataset_fn=..., | ||
options=tf.distribute.InputOptions( | ||
experimental_fetch_to_device=False))) | ||
dataset_iterator = iter(distributed_dataset) | ||
``` | ||
Below is an example of a training and evaluation step: | ||
```python | ||
optimizer = tf.keras.optimizers.SGD(0.1) | ||
@tf.function | ||
def training_step(dataset_iterator, num_steps): | ||
def tpu_step(embedding_features): | ||
with tf.GradientTape() as tape: | ||
tape.watch(embedding.embedding_tables.values()) | ||
activations = embedding(embedding_features) | ||
model_output = model(activations) | ||
loss = ... # some function of labels and model_output | ||
embedding_gradients = tape.gradient(loss, | ||
embedding.embedding_tables.values()) | ||
optimizer.apply_gradients(list(zip(embedding_gradients, | ||
embedding.embedding_tables.values()))) | ||
# Insert your model gradient and optimizer application here | ||
for _ in tf.range(num_steps): | ||
strategy.run(tpu_step, args=(next(dataset_iterator), )) | ||
@tf.function | ||
def evaluation_step(dataset_iterator, num_steps): | ||
def tpu_step(embedding_features): | ||
activations = embedding(embedding_features) | ||
model_output = model(activations) | ||
# Insert your evaluation code here. | ||
for _ in tf.range(num_steps): | ||
strategy.run(tpu_step, args=(next(dataset_iterator), )) | ||
``` | ||
NOTE: The optimizer used here is a Keras optimizer. In order to make the slot | ||
variable creation stay consistent between Keras optimizers and | ||
embedding optimizers, the `slot_variable_creation_fn` argument of the | ||
embedding optimizers has to be passed with the Keras `add_slot` function. Also | ||
note that the slot names might be slightly different between them. | ||
```python | ||
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.1) | ||
def slot_variable_creation_fn(table, slot_names, slot_initializers): | ||
slots = {} | ||
for slot, initializer in zip(slot_names, slot_initializers): | ||
slots[slot] = optimizer.add_slot(table, slot, initializer) | ||
return slots | ||
embedding_optimizer = tf.tpu.experimental.embedding.Adagrad( | ||
learning_rate=0.1, | ||
slot_variable_creation_fn=slot_variable_creation_fn) | ||
# Use the embedding optimizer to create mid level api and keras optimizer to | ||
# apply gradients. | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable], # pylint:disable=g-bare-generic | ||
optimizer: Optional[tpu_embedding_v2_utils._Optimizer]): # pylint:disable=protected-access | ||
super(tpu_embedding_v1.TPUEmbeddingV0, self).__init__(feature_config, optimizer) | ||
self._strategy = distribute_lib.get_strategy() | ||
self._built = False | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters