
tf.math.bincount can only sum integer weights when using XLA #908

Closed

roblem opened this issue Mar 24, 2020 · 9 comments

Labels: question (Further information is requested), triage

roblem commented Mar 24, 2020

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes

  • OS Platform and Distribution: Linux Ubuntu 18.04, using the upstream radeon kernel driver and launching ROCm scripts in the latest docker container with TF_ROCM_FUSION_ENABLE=1

  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if applicable: N/A

  • TensorFlow installed from (source or binary): installed from docker as rocm/tensorflow:latest

  • TensorFlow version (use the command shown after this list): v2.1.0-15-g5466af3 2.1.0

  • Python version: 3.5.2

  • Bazel version (if compiling from source): Build label: 0.29.1

  • GCC/Compiler version (if compiling from source): gcc (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609

  • CUDA/cuDNN version:

  • GPU model and memory: Radeon VII gfx906
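
The version string above can be reproduced with the standard snippet from the TensorFlow issue template (a minimal sketch, not part of the original report):

import tensorflow as tf
# Prints the git build tag and release version, e.g. "v2.1.0-15-g5466af3 2.1.0"
print(tf.version.GIT_VERSION, tf.version.VERSION)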

Describe the current behavior

According to the documentation, tf.math.bincount accepts weights as an argument; these weights can be integers or floats. If the optional dtype argument isn't given, the output is assigned a dtype equal to the dtype of weights.

This isn't currently happening with bincount under XLA compilation. bincount only successfully applies weights when they are integers: it neither uses the dtype of the weights variable nor honors dtype when it is supplied as an optional argument.

Standalone code to reproduce the issue
Here is some example code with output:

import tensorflow as tf
import numpy as np

# Define the index used for summing over elements in weights (the sum groups on index value).
# I believe this index must be of type int32.
idx = tf.constant([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 4], dtype=tf.int32)
# Here are the optional weights we'll be using: one float and one integer.
x_float = tf.constant(np.random.randn(idx.shape[0]), dtype=tf.float32)
x_int = tf.constant([i + 1 for i in range(idx.shape[0])], dtype=tf.int32)

# A very simple XLA-compiled function that calls bincount
@tf.function(experimental_compile=True)
def bincount_XLA(idx_, x_):
    return tf.math.bincount(idx_, weights=x_)

# And another that doesn't use XLA
def bincount(idx_, x_):
    return tf.math.bincount(idx_, weights=x_)

# Applying the non-XLA function to integer weights:
bincount(idx, x_int)

yields

<tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 6, 15, 24, 21, 54], dtype=int32)>

Running with integer weights under XLA yields the same results:

In [11]: bincount_XLA(idx, x_int)  
2020-03-24 10:24:24.767456: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x5587bb802580 initialized for platform ROCM (this does not guarantee that XLA will be used). Devices:
2020-03-24 10:24:24.767486: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Vega 20, AMDGPU ISA version: gfx906
2020-03-24 10:24:24.986954: I tensorflow/compiler/jit/xla_compilation_cache.cc:242] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
Out[11]: <tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 6, 15, 24, 21, 54], dtype=int32)>

Now try a weights vector of floats without XLA:

In [12]: bincount(idx, x_float)                                                                                                                               
Out[12]: 
<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([-1.9718069,  2.658638 , -0.6710254,  1.0607541,  1.5008008],
      dtype=float32)>
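
(As a cross-check, not part of the original report: NumPy's bincount also accepts float weights and computes the same grouped sums, so the non-XLA result above is the documented behavior.)

import numpy as np
# Same grouped sums as tf.math.bincount with float weights; NumPy returns float64.
expected = np.bincount(idx.numpy(), weights=x_float.numpy())
print(expected)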

And with XLA, the command fails with this at the bottom of the traceback:

In [13]: bincount_XLA(idx, x_float)
Out[13]:
<snip>
InternalError: RET_CHECK failure (tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:394) !llvm::verifyModule(llvm_module, &err_stream) Invalid LLVM IR before optimizations:
Invalid bitcast
  %19 = bitcast i32 addrspace(5)* %cas_new_output_address to float*
Call parameter type does not match function signature!
  %14 = alloca float, addrspace(5)
 float*  call void @scatter_combiner_21(float* %19, float addrspace(5)* %14, float* %19, i8* null)
Invalid bitcast
  %add.24.typed = bitcast float addrspace(5)* %add.24.raw to float*

This probably indicates a bug in the HLO -> LLVM IR lowering. Rerun with --xla_dump_to to get the IR. 
	This error might be occurring with the use of xla.compile. If it is not necessary that every Op be compiled with XLA, an alternative is to use auto_jit with OptimizerOptions.global_jit_level = ON_2 or the environment variable TF_XLA_FLAGS="tf_xla_auto_jit=2" which will attempt to use xla to compile as much of the graph as the compiler is able to. [Op:__inference_bincount_XLA_63]
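
(The message above points at XLA auto-clustering as an alternative to compiling every op; a hedged sketch of that alternative follows. Whether auto-clustering avoids this particular lowering failure on the affected build is not verified here.)

import os
# Must be set before TensorFlow is imported; lets XLA pick which clusters to compile.
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=2"

import tensorflow as tf

@tf.function  # no experimental_compile=True, so compilation is left to auto-clustering
def bincount_auto_jit(idx_, x_):
    return tf.math.bincount(idx_, weights=x_)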

It seems that the code generated by XLA is trying to force the results of bincount into an integer dtype. I should also mention that adding the optional dtype argument to the bincount functions results in the same errors (so it is ignored during code generation).
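
(A minimal sketch of the dtype variant mentioned above, reusing idx and x_float from the reproduction code; the function name here is illustrative only. On the affected build this fails with the same InternalError under XLA.)

@tf.function(experimental_compile=True)
def bincount_XLA_float(idx_, x_):
    # Explicitly requesting a float output dtype; ignored by the XLA lowering on the affected build.
    return tf.math.bincount(idx_, weights=x_, dtype=tf.float32)

# bincount_XLA_float(idx, x_float)  # raises the same RET_CHECK failure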

@sunway513

@jerryyin can you help review this issue for our current XLA implementation?

@jerryyin
Member

jerryyin commented Apr 3, 2020

@roblem Thanks for reporting the problem. I dug through the history a bit and think this might be a fairly recent regression from the IR-emitter changes. It has already been fixed in tensorflow#36187. Could you try rocm/tensorflow-autobuilds:rocm3.0-fd0317e (the first image that pulled in the fix)? I have tested it and can confirm there is no regression in that image.

@roblem
Author

roblem commented Apr 5, 2020

@jerryyin thanks. I can confirm that all my tests pass (including those in #911) with rocm/tensorflow-autobuilds:rocm3.0-fd0317e. When might we expect this to make it into the "latest" docker image or onto PyPI?

@jerryyin
Member

jerryyin commented Apr 6, 2020

@roblem This is a question for @sunway513. In order for him to determine whether the latest docker image is up to date, we will need your image hash (did you pull it recently?). The latest tag for rocm/tensorflow:latest is always moving; depending on when you pull it, it may point to a different image.

jerryyin added the question (Further information is requested) and triage labels on Apr 6, 2020
@sunway513

Hi @roblem, the rocm/tensorflow:latest tag points to the last TF release build with the latest ROCm release packages. We won't be able to point the latest tag at builds from the develop-upstream branch.
For this issue, I don't see the PR tensorflow#36187 in the upstream TF r2.2 release branch, so the fix should be propagated to the next release after TF 2.2.
In the meantime, you can continue to use the docker tag referred to by @jerryyin.

@sunway513

Wait, upstream TF actually applied the patch internally, and the fix is part of the TF r2.2 release branch.
Can you try the following docker container for TF2.2 RC2?
rocm/tensorflow:rocm3.3-tf2.2-rc2-dev

@roblem
Author

roblem commented Apr 12, 2020

Apologies for the delay. My workstation is running a memory-constrained TensorFlow job right now that I project will take approximately two weeks to complete. Then I can try this.

@jerryyin
Member

Closing the issue now, as a viable solution has already been given. Please feel free to comment or re-open if the image turns out not to work.

@roblem
Author

roblem commented Apr 17, 2020

I can confirm that all my tests pass with rocm/tensorflow:rocm3.3-tf2.2-rc2-dev, so the fix is included in 2.2.
