[WIP] Feature/patch/softmax cross entropy with logits #21

Open · wants to merge 8 commits into base: master
10 changes: 5 additions & 5 deletions tfdeterminism/enable_determinism.py
@@ -21,7 +21,7 @@

import tensorflow as tf

from .patch import _patch_bias_add
from .patch import _patch_bias_add, _patch_fused_softmax_cross_entropy
from .utils import _Version as Version

def _enable_determinism(seed=None):
@@ -31,7 +31,7 @@ def _enable_determinism(seed=None):
Call this method either before or after explicitly importing TensorFlow,
but always before constructing any graphs.

This function cannot address all possible sources of non-determinism. Please
see further instructions at https://github.com/NVIDIA/tensorflow-determinism
to understand how to use it in a larger deterministic context.

@@ -52,7 +52,7 @@ def _enable_determinism(seed=None):
_patch_bias_add()
if in_ngc_cont and ngc_vers.at_least('19.06') or tf_vers.at_least('2.1'):
os.environ['TF_DETERMINISTIC_OPS'] = '1'
# TODO: Add patch crossentropy here as well? Issue seems to still be present on tf 2.1, 2.2
Collaborator:

No. This is the condition for setting TF_DETERMINISTIC_OPS=1: NGC containers with version >= 19.06 or stock TensorFlow with version >= 2.1.

The condition below will ensure that the fused softmax/cross-entropy patch is applied to NGC containers with version >= 19.06 or stock TensorFlow with version >= 1.14 (which includes versions 2.1 and 2.2).

if in_ngc_cont and ngc_vers.at_least('19.06') or tf_vers.at_least('1.14'):
# Apply the fused softmax/cross-entropy patch here
pass
# TODO: Add other recipe items
_patch_fused_softmax_cross_entropy()
# TODO: Add other recipe items
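
For reference, here is a minimal sketch of how the two version gates described in the collaborator comment above could sit together inside `_enable_determinism`. The helpers `in_ngc_cont`, `ngc_vers`, and `tf_vers` are taken from the diff; everything else is an illustrative assumption, not the merged code.

# Sketch only, not the merged implementation.
# NGC containers >= 19.06 or stock TensorFlow >= 2.1: set the env var.
if in_ngc_cont and ngc_vers.at_least('19.06') or tf_vers.at_least('2.1'):
  os.environ['TF_DETERMINISTIC_OPS'] = '1'
# NGC containers >= 19.06 or stock TensorFlow >= 1.14 (which includes 2.1
# and 2.2): also apply the Python-level fused softmax/cross-entropy patch.
if in_ngc_cont and ngc_vers.at_least('19.06') or tf_vers.at_least('1.14'):
  _patch_fused_softmax_cross_entropy()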
114 changes: 113 additions & 1 deletion tfdeterminism/patch.py
@@ -70,14 +70,20 @@ def _patch():
if re.match("(1\.(14|15)|2\.0)", tf_version):
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
_patch_bias_add()
# Apply the fused softmax/cross-entropy patch here
_patch_fused_softmax_cross_entropy()
Collaborator:

I'll push changes that rough-out enable_determinism and then you can call this from the appropriate part of that.

print("TensorFlow version %s has been patched "
"using tfdeterminism version %s" %
(tf_version, __version__), file=sys.stderr)
elif re.match("2\.1|2\.2", tf_version):
Collaborator:

I want to deprecate patch rather than make it work on newer versions of TensorFlow, and to stop advertising it in the documentation (the docs will be updated before release). If folks want determinism in TensorFlow 2.1 or 2.2, they should use enable_determinism. Not adding this branch doesn't break anything, because no patch was available for TensorFlow 2.1 or 2.2 before.
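
A usage sketch only, assuming the function is eventually exposed from the package root as enable_determinism (the public entry point is still being roughed out in this PR, so the import path below is an assumption):

# Hypothetical usage; the finalized public API may differ.
from tfdeterminism import enable_determinism

enable_determinism(seed=123)  # call before constructing any graphs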

_patch_fused_softmax_cross_entropy()
print("TensorFlow version %s has been patched "
"using tfdeterminism version %s" %
(tf_version, __version__), file=sys.stderr)
else:
raise TypeError("tfdeterminism: No patch available "
"for version %s of TensorFlow" % tf_version)


def _patch_bias_add():
tf.nn.bias_add = _new_bias_add_1_14 # access via public API
nn.bias_add = _new_bias_add_1_14 # called from tf.keras.layers.convolutional.Conv
@@ -136,3 +142,109 @@ def _new_bias_add_1_14(value, bias, data_format=None, name=None):
value, array_ops.reshape(bias, broadcast_shape), name=name)
else: # data_format == 'NHWC' or data_format == None
return math_ops.add(value, bias, name=name)


def _patch_fused_softmax_cross_entropy():
# Non-sparse
tf.nn.softmax_cross_entropy_with_logits = _new_softmax_cross_entropy_with_logits # access via public API
nn.softmax_cross_entropy_with_logits = _new_softmax_cross_entropy_with_logits # access via the nn module alias used internally by TF/Keras
nn_ops.softmax_cross_entropy_with_logits = _new_softmax_cross_entropy_with_logits # called from tests

# Sparse
tf.nn.sparse_softmax_cross_entropy_with_logits = _new_sparse_softmax_cross_entropy_with_logits # access via public API
nn.sparse_softmax_cross_entropy_with_logits = _new_sparse_softmax_cross_entropy_with_logits # access via the nn module alias used internally by TF/Keras
nn_ops.sparse_softmax_cross_entropy_with_logits = _new_sparse_softmax_cross_entropy_with_logits # called from tests

# The original, pre-patched method can be viewed at
# https://github.com/tensorflow/tensorflow/blob/v1.14.0/tensorflow/python/ops/nn_ops.py#L3182
def _new_softmax_cross_entropy_with_logits(labels, logits, axis=-1, name=None):
"""Computes softmax cross entropy between `logits` and `labels`.
Measures the probability error in discrete classification tasks in which the
classes are mutually exclusive (each entry is in exactly one class). For
example, each CIFAR-10 image is labeled with one and only one label: an image
can be a dog or a truck, but not both.
**NOTE:** While the classes are mutually exclusive, their probabilities
need not be. All that is required is that each row of `labels` is
a valid probability distribution. If they are not, the computation of the
gradient will be incorrect.
If using exclusive `labels` (wherein one and only
one class is true at a time), see `sparse_softmax_cross_entropy_with_logits`.
**WARNING:** This op expects unscaled logits, since it performs a `softmax`
on `logits` internally for efficiency. Do not call this op with the
output of `softmax`, as it will produce incorrect results.
A common use case is to have logits and labels of shape
`[batch_size, num_classes]`, but higher dimensions are supported, with
the `dim` argument specifying the class dimension.
Backpropagation will happen only into `logits`. To calculate a cross entropy
loss that allows backpropagation into both `logits` and `labels`, see
`tf.nn.softmax_cross_entropy_with_logits_v2`.
**Note that to avoid confusion, it is required to pass only named arguments to
this function.**
Args:
_sentinel: Used to prevent positional parameters. Internal, do not use.
labels: Each vector along the class dimension should hold a valid
probability distribution e.g. for the case in which labels are of shape
`[batch_size, num_classes]`, each row of `labels[i]` must be a valid
probability distribution.
logits: Per-label activations, typically a linear output. These activation
energies are interpreted as unnormalized log probabilities.
dim: The class dimension. Defaulted to -1 which is the last dimension.
name: A name for the operation (optional).
axis: Alias for dim.
Returns:
A `Tensor` that contains the softmax cross entropy loss. Its type is the
same as `logits` and its shape is the same as `labels` except that it does
not have the last dimension of `labels`.
"""
raise NotImplementedError()
Collaborator:

I like the baby steps. Good job!
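
One plausible way to eventually fill in the body above, sketched under the assumption that the deterministic route composes the loss from non-fused ops (log-softmax followed by a reduction) rather than the fused softmax/cross-entropy kernel. The helper name below is hypothetical and not part of this PR.

# Sketch only: build the loss from non-fused ops so the fused
# softmax/cross-entropy kernel (the op with the known non-deterministic
# GPU backprop) is never launched.
def _softmax_cross_entropy_from_unfused_ops(labels, logits, axis=-1, name=None):
  labels = tf.convert_to_tensor(labels)
  logits = tf.convert_to_tensor(logits)
  log_probs = tf.nn.log_softmax(logits, axis=axis)
  # Cross entropy: -sum(labels * log_softmax(logits)) along the class axis.
  return tf.reduce_sum(-tf.cast(labels, log_probs.dtype) * log_probs,
                       axis=axis, name=name)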



# The original, pre-patched method can be viewed at
# https://github.com/tensorflow/tensorflow/blob/v1.14.0/tensorflow/python/ops/nn_ops.py#L2628
def _new_sparse_softmax_cross_entropy_with_logits(
_sentinel=None, # pylint: disable=invalid-name
labels=None,
logits=None,
name=None):
"""Computes sparse softmax cross entropy between `logits` and `labels`.
Measures the probability error in discrete classification tasks in which the
classes are mutually exclusive (each entry is in exactly one class). For
example, each CIFAR-10 image is labeled with one and only one label: an image
can be a dog or a truck, but not both.
**NOTE:** For this operation, the probability of a given label is considered
exclusive. That is, soft classes are not allowed, and the `labels` vector
must provide a single specific index for the true class for each row of
`logits` (each minibatch entry). For soft softmax classification with
a probability distribution for each entry, see
`softmax_cross_entropy_with_logits_v2`.
**WARNING:** This op expects unscaled logits, since it performs a `softmax`
on `logits` internally for efficiency. Do not call this op with the
output of `softmax`, as it will produce incorrect results.
A common use case is to have logits of shape
`[batch_size, num_classes]` and have labels of shape
`[batch_size]`, but higher dimensions are supported, in which
case the `dim`-th dimension is assumed to be of size `num_classes`.
`logits` must have the dtype of `float16`, `float32`, or `float64`, and
`labels` must have the dtype of `int32` or `int64`.
**Note that to avoid confusion, it is required to pass only named arguments to
this function.**
Args:
_sentinel: Used to prevent positional parameters. Internal, do not use.
labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of
`labels` and result) and dtype `int32` or `int64`. Each entry in `labels`
must be an index in `[0, num_classes)`. Other values will raise an
exception when this op is run on CPU, and return `NaN` for corresponding
loss and gradient rows on GPU.
logits: Per-label activations (typically a linear output) of shape
`[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32`, or
`float64`. These activation energies are interpreted as unnormalized log
probabilities.
name: A name for the operation (optional).
Returns:
A `Tensor` of the same shape as `labels` and of the same type as `logits`
with the softmax cross entropy loss.
Raises:
ValueError: If logits are scalars (need to have rank >= 1) or if the rank
of the labels is not equal to the rank of the logits minus one.
"""
raise NotImplementedError()
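
A similarly hedged sketch for the sparse variant, assuming the approach is to expand the integer labels to one-hot vectors and reuse the non-fused computation above; both helper names are hypothetical and not part of this PR.

# Sketch only: deterministic sparse variant via one-hot expansion.
def _sparse_softmax_cross_entropy_from_unfused_ops(labels, logits, name=None):
  labels = tf.convert_to_tensor(labels)
  logits = tf.convert_to_tensor(logits)
  num_classes = tf.shape(logits)[-1]
  # One-hot expansion costs extra memory but lets the non-sparse,
  # non-fused path above be reused unchanged.
  one_hot_labels = tf.one_hot(labels, depth=num_classes, dtype=logits.dtype)
  return _softmax_cross_entropy_from_unfused_ops(one_hot_labels, logits,
                                                 axis=-1, name=name)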