Backward pass of broadcasting on GPU is non-deterministic #2652

Closed
ghost opened this issue Jun 4, 2016 · 10 comments
Labels
stat:awaiting response Status - Awaiting response from author


@ghost

ghost commented Jun 4, 2016

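```python
from __future__ import print_function  # lets the same script run on Python 2 and 3

import tensorflow as tf  # TF 0.8 / 1.x graph-mode API


def run(on_gpu):
    tf.reset_default_graph()
    tf.set_random_seed(42)  # fixed graph-level seed, so `a` should be identical on every run
    with tf.device('/gpu:0' if on_gpu else '/cpu:0'):
        a = tf.random_normal([16, 16])
        b = tf.get_variable('b', shape=[], initializer=tf.constant_initializer(value=0.0))
        c = a * b  # scalar b is broadcast against the 16x16 tensor a
        # d(sum(c))/db reduces the upstream gradient over all 256 elements of a.
        grad = tf.gradients(c, [b], gate_gradients=tf.train.Optimizer.GATE_GRAPH)[0]

    with tf.Session() as sess:  # closes the session after each run
        sess.run(tf.initialize_all_variables())
        return sess.run(grad)


for i in range(20):
    print(repr(run(on_gpu=True)), end=' ')
print('')
for i in range(20):
    print(repr(run(on_gpu=False)), end=' ')
```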

Result:

23.066511 23.066511 23.066513 23.066513 23.066511 23.066513 23.066509 23.066513 23.066513 23.066511 23.066513 23.066511 23.066513 23.066513 23.066509 23.066511 23.066513 23.066513 23.066511 23.066511 
23.066509 23.066509 23.066509 23.066509 23.066509 23.066509 23.066509 23.066509 23.066509 23.066509 23.066509 23.066509 23.066509 23.066509 23.066509 23.066509 23.066509 23.066509 23.066509 23.066509

As you can see, the result is consistent across CPU runs but inconsistent across GPU runs.

This is no doubt a CUDA reduction-order issue, but it would be really nice to have deterministic reductions. I am using TF 0.8.0 (self-compiled against cuDNN v5); the cuDNN version is 5005 (not an RC).
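The order dependence can be reproduced without a GPU. The sketch below is plain Python, and `to_f32`, `sum_f32`, and `sums_in_random_orders` are hypothetical helpers, not TensorFlow APIs. It emulates a float32 accumulator by rounding after every addition and sums the same values in different orders, roughly the way atomic adds that land in arbitrary order would.

```python
import random
import struct


def to_f32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack('f', struct.pack('f', x))[0]


def sum_f32(values):
    """Add values left to right, rounding to float32 after every addition."""
    acc = 0.0
    for v in values:
        # float64 carries enough bits that adding two float32 values and
        # rounding once yields the correctly rounded float32 sum.
        acc = to_f32(acc + to_f32(v))
    return acc


def sums_in_random_orders(values, trials, seed=0):
    """Sum the same values in `trials` shuffled orders (a stand-in for atomic
    adds landing in arbitrary order) and return the set of distinct results."""
    rng = random.Random(seed)
    results = set()
    for _ in range(trials):
        order = list(values)
        rng.shuffle(order)
        results.add(sum_f32(order))
    return results


# float32 values near 1e8 are spaced 8.0 apart, so adding 1.0 to 1e8 rounds
# straight back to 1e8, while adding the ones together first does not.
values = [1e8] + [1.0] * 8
print(sum_f32(values))               # 100000000.0
print(sum_f32([1.0] * 8 + [1e8]))    # 100000008.0
print(sorted(sums_in_random_orders(values, trials=50)))
```

The example is deliberately extreme; with the 256 moderate values in the repro, the same effect shows up in the last bit or two, as in the GPU output above.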

@girving
Contributor

girving commented Jun 7, 2016

Unfortunately, the reduction ops on GPU use asynchronous atomic adds, and are therefore fundamentally nondeterministic for floating point. Making them deterministic would require either tree-structured reductions or integer math, both significantly slower.
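Both alternatives can be sketched in plain Python. `tree_sum_f32` and `fixed_point_sum` below are hypothetical illustrations, not TensorFlow kernels: the tree version fixes which partial sums get combined, and the fixed-point version relies on integer addition being associative. Note that the tree sum is reproducible for a given input layout; it is not invariant to permuting the inputs.

```python
import random
import struct


def to_f32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack('f', struct.pack('f', x))[0]


def tree_sum_f32(values, rng=None):
    """Pairwise (tree-structured) float32 sum.

    At each level, slot i receives buf[i] + buf[i + half]. Which partial sums
    get combined depends only on indices, so processing the pairs in a
    shuffled order (standing in for threads finishing in arbitrary order)
    cannot change the result.
    """
    buf = [to_f32(v) for v in values]
    n = len(buf)
    while n > 1:
        half = (n + 1) // 2
        pairs = list(range(n - half))  # slots that receive buf[i + half]
        if rng is not None:
            rng.shuffle(pairs)
        for i in pairs:
            buf[i] = to_f32(buf[i] + buf[i + half])
        n = half
    return buf[0] if buf else 0.0


def fixed_point_sum(values, frac_bits=24):
    """Quantize each value to an integer with `frac_bits` fractional bits and
    add the integers. Integer addition is associative, so every order gives a
    bit-identical total (a real kernel would also need to guard against
    int64 overflow)."""
    scale = 1 << frac_bits
    total = sum(int(round(v * scale)) for v in values)
    return total / scale


rng = random.Random(42)
data = [rng.gauss(0.0, 1.0) for _ in range(256)]  # a 16x16 tensor, flattened

# Tree sum: same data layout, 20 different "thread completion" orders.
tree_results = {tree_sum_f32(data, random.Random(s)) for s in range(20)}

# Fixed-point sum: the inputs themselves are permuted 20 different ways.
fixed_results = set()
for s in range(20):
    shuffled = list(data)
    random.Random(s).shuffle(shuffled)
    fixed_results.add(fixed_point_sum(shuffled))

print(len(tree_results), len(fixed_results))  # 1 1
```

The cost mentioned above shows up here as extra passes over the data (tree) or lost precision and range (fixed point).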

I can leave this open with contributions welcome if you'd like (with an adjusted title), but it'll be a lot of work if someone tries to take it on, and it's unclear how best to make it happen automatically. Even if one added deterministic reductions as an option (either as a separate op or as an attr on the existing ops), we'd need an unpleasant global flag to turn this on when building the backward pass.

@girving girving added the stat:awaiting response Status - Awaiting response from author label Jun 7, 2016
@ghost
Author

ghost commented Jun 8, 2016

I can understand if that's the case. Thanks for the response.

@ghost
Author

ghost commented Jun 12, 2016

By the way, a block reduction based purely on warp shuffles (shfl_down, or shfl_xor for keep_dim) doesn't seem to be much slower than warp shuffle plus atomics.
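To make the shuffle-based reduction concrete, here is a Python simulation of one 32-lane warp. It is a sketch: `shfl_down` stands in for CUDA's `__shfl_down_sync` and is modeled without the participation mask. Each lane adds the value held `offset` lanes above it, with the offset halving every step, so lane 0 holds the full sum after log2(32) = 5 steps, in a fixed order and with no atomics.

```python
WARP_SIZE = 32


def shfl_down(vals, offset):
    """Simulated shfl_down: lane i reads lane i + offset; lanes whose source
    would fall past the end of the warp keep their own value."""
    return [vals[i + offset] if i + offset < len(vals) else vals[i]
            for i in range(len(vals))]


def warp_reduce_down(vals):
    """Tree reduction over one warp using shfl_down. Only lane 0 ends up with
    the complete sum; the offsets are fixed, so the order of additions is
    identical on every run."""
    assert len(vals) == WARP_SIZE
    vals = list(vals)
    offset = WARP_SIZE // 2
    while offset > 0:
        shifted = shfl_down(vals, offset)
        vals = [v + s for v, s in zip(vals, shifted)]
        offset //= 2
    return vals


print(warp_reduce_down(list(range(32)))[0])  # 496
```

The other lanes hold partial sums that are simply ignored.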

@girving
Contributor

girving commented Jun 12, 2016

@MetaP Do you have a link for that? I don't quite follow, especially the bit about keep_dim since that doesn't change the computation structure.

@ghost
Author

ghost commented Jun 13, 2016

Here's the link: https://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/
Using shfl_xor leaves every lane in the warp holding the same copy of the result (this isn't mentioned in the article), which should be better in keep_dim situations. One example is a fast CUDA word2vec implementation that uses warp shuffles to reach 3 million tokens/s on a Titan X with the default CBOW negative-sampling settings (as in Mikolov's original implementation).
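The shfl_xor variant can be simulated the same way (again a hypothetical Python model of CUDA's `__shfl_xor_sync`, ignoring masks). Each lane exchanges values with lane `i ^ mask`, so after five butterfly steps every lane holds the total. Because both partners in each exchange add the same two numbers, the copies are bit-identical even in floating point.

```python
WARP_SIZE = 32


def shfl_xor(vals, lane_mask):
    """Simulated shfl_xor: lane i reads the value held by lane i ^ lane_mask."""
    return [vals[i ^ lane_mask] for i in range(len(vals))]


def warp_allreduce_xor(vals):
    """Butterfly reduction: after log2(WARP_SIZE) steps every lane holds the
    same total, so the result is already broadcast across the warp."""
    assert len(vals) == WARP_SIZE
    vals = list(vals)
    mask = WARP_SIZE // 2
    while mask > 0:
        partner = shfl_xor(vals, mask)
        vals = [v + p for v, p in zip(vals, partner)]
        mask //= 2
    return vals


print(warp_allreduce_xor(list(range(32)))[:4])  # [496, 496, 496, 496]
```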

@girving
Contributor

girving commented Jun 13, 2016

Cc @zheng-xq @benoitsteiner in case more GPU-knowledgeable folks want to take a look. Determinism would certainly be nice to have if we can get it.

@zheng-xq
Contributor

The shfl_down results are only useful within a single warp. The technique needs a second pass to accumulate the per-warp results for each block.
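The second pass can be sketched as a hypothetical Python simulation of one 256-thread block (8 warps); it is not a TensorFlow kernel. Each warp reduces its values to lane 0, those per-warp partials go into a shared-memory array, and one warp then reduces the partials with the same shuffle pattern. This replaces the atomic adds used to combine partial sums in the shuffle-plus-atomic variant. Combining the results of several blocks deterministically would still need another fixed-order pass, for example a second kernel launch.

```python
WARP_SIZE = 32


def warp_reduce(vals):
    """Fixed-order shfl_down-style tree reduction of one warp; returns lane 0."""
    vals = list(vals)
    offset = len(vals) // 2
    while offset > 0:
        vals = [vals[i] + vals[i + offset] if i + offset < len(vals) else vals[i]
                for i in range(len(vals))]
        offset //= 2
    return vals[0]


def block_reduce(block_vals):
    """Two-pass block reduction.

    Pass 1: each warp reduces its 32 values; lane 0 writes the partial sum to
    slot `warp_id` of a (simulated) shared-memory array.
    Pass 2: the first warp reduces those partials, zero-padded to 32 lanes.
    No atomics are involved, so the result depends only on the data layout.
    """
    assert len(block_vals) % WARP_SIZE == 0
    num_warps = len(block_vals) // WARP_SIZE
    assert num_warps <= WARP_SIZE  # at most 1024 threads per block
    shared = [warp_reduce(block_vals[w * WARP_SIZE:(w + 1) * WARP_SIZE])
              for w in range(num_warps)]
    shared += [0] * (WARP_SIZE - num_warps)
    return warp_reduce(shared)


print(block_reduce(list(range(256))))  # 32640
```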

In general, there is no guarantee of determinism on GPU, so we are not sure how much effort we want to spend on it. Even if we fix this particular kernel, other cuDNN kernels are also non-deterministic.

@jkschin

jkschin commented May 12, 2017

@zheng-xq, could you give some examples of other cuDNN kernels that are non-deterministic? I'd like to explore this a little, just for educational purposes, because, as you mentioned, it's probably not worth the effort unless something major happens down the road.

@albertz
Contributor

albertz commented Nov 29, 2018

Which op exactly is non-deterministic here? These are the ops in the graph:

[<tf.Operation 'random_normal/shape' type=Const>,
 <tf.Operation 'random_normal/mean' type=Const>,
 <tf.Operation 'random_normal/stddev' type=Const>,
 <tf.Operation 'random_normal/RandomStandardNormal' type=RandomStandardNormal>,
 <tf.Operation 'random_normal/mul' type=Mul>,
 <tf.Operation 'random_normal' type=Add>,
 <tf.Operation 'b/Initializer/Const' type=Const>,
 <tf.Operation 'b' type=VariableV2>,
 <tf.Operation 'b/Assign' type=Assign>,
 <tf.Operation 'b/read' type=Identity>,
 <tf.Operation 'mul' type=Mul>,
 <tf.Operation 'gradients/Shape' type=Const>,
 <tf.Operation 'gradients/Const' type=Const>,
 <tf.Operation 'gradients/Fill' type=Fill>,
 <tf.Operation 'gradients/mul_grad/Shape' type=Const>,
 <tf.Operation 'gradients/mul_grad/Shape_1' type=Const>,
 <tf.Operation 'gradients/mul_grad/BroadcastGradientArgs' type=BroadcastGradientArgs>,
 <tf.Operation 'gradients/mul_grad/mul' type=Mul>,
 <tf.Operation 'gradients/mul_grad/Sum' type=Sum>,
 <tf.Operation 'gradients/mul_grad/Reshape' type=Reshape>,
 <tf.Operation 'gradients/mul_grad/mul_1' type=Mul>,
 <tf.Operation 'gradients/mul_grad/Sum_1' type=Sum>,
 <tf.Operation 'gradients/mul_grad/Reshape_1' type=Reshape>,
 <tf.Operation 'gradients/mul_grad/tuple/group_deps' type=NoOp>,
 <tf.Operation 'gradients/mul_grad/tuple/control_dependency' type=Identity>,
 <tf.Operation 'gradients/mul_grad/tuple/control_dependency_1' type=Identity>,
 <tf.Operation 'init' type=NoOp>]

Do you expect that BroadcastGradientArgs is non-deterministic?

For reference, I tried running this with both TF 1.4.1 and TF 1.12.0, and it seems deterministic to me (GTX 980, CUDA 9.1).
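The backward pass of `c = a*b` splits into integer shape logic and one float reduction. The sketch below is a simplified Python model with hypothetical helpers, not TensorFlow's implementation: `broadcast_gradient_args` only decides which axes to sum, which involves no floating-point math, while `mul_grad_scalar` sums `upstream * a` over all elements of `a`. If anything in this graph is order-sensitive, it would be that reduction (the `mul_1` → `Sum_1` chain in the listing), not BroadcastGradientArgs.

```python
def broadcast_gradient_args(shape_x, shape_y):
    """Simplified stand-in for BroadcastGradientArgs: right-align the shapes
    and return, for each input, the output axes its gradient must be summed
    over. Pure integer shape logic."""
    rank = max(len(shape_x), len(shape_y))
    sx = [1] * (rank - len(shape_x)) + list(shape_x)
    sy = [1] * (rank - len(shape_y)) + list(shape_y)
    rx = [i for i in range(rank) if sx[i] == 1 and sy[i] != 1]
    ry = [i for i in range(rank) if sy[i] == 1 and sx[i] != 1]
    return rx, ry


def mul_grad_scalar(a, upstream):
    """Gradient of c = a * b w.r.t. a scalar b broadcast over a 2-D `a`:
    sum(upstream * a) over every axis. This loop is the float reduction whose
    summation order a GPU kernel may vary between runs."""
    total = 0.0
    for row_a, row_g in zip(a, upstream):
        for x, g in zip(row_a, row_g):
            total += g * x
    return total


rx, ry = broadcast_gradient_args([16, 16], [])
print(rx, ry)  # [] [0, 1]
print(mul_grad_scalar([[1.0, 2.0], [3.0, 4.0]], [[1.0, 1.0], [1.0, 1.0]]))  # 10.0
```

With an all-ones upstream gradient, the gradient with respect to `b` is simply `sum(a)`.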

@duncanriach
Contributor

The current high-level status is that there are now solutions for TensorFlow GPU determinism related to cuDNN (convolutions and max-pooling) and bias_add. Please see the following repo for up-to-date status: https://github.com/NVIDIA/tensorflow-determinism

This issue was closed.