
bitcast errors for some tf.reduce_sum operations when using XLA #911

Closed
roblem opened this issue Mar 26, 2020 · 4 comments
roblem commented Mar 26, 2020

Same system setup and environment as reported in #908. In summary, very basic TensorFlow ops like tf.reduce_sum seem to fail in some instances with "Invalid bitcast" errors when using XLA.

In #908, I reported an issue with bincount (also one concerning bitcast errors) that prevented XLA compilation of an expensive-to-calculate function. To work around it, I have been playing with problems having identical numbers of rows in each bin, so that tf.math.bincount (and tf.math.segment_sum, which hasn't been implemented yet AFAIK) can be avoided. Some code generating toy data:

import numpy as np
import tensorflow as tf

# for each row in the float vector xf, idx indicates which group
# (for summing purposes). Note 5 groups with 3 rows each
idx = tf.constant([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4], dtype=tf.int32)
xf = tf.constant(np.random.randn(idx.shape[0]), dtype=tf.float32)
# parameters for the reshapes below
n_rows = tf.constant(tf.reduce_max(idx) + 1, dtype=tf.int32)
n_cols = tf.constant(3, dtype=tf.int32)
# as a sanity check, we want to replicate this result:
print("tf.math.bincount calculation:\n ", tf.math.bincount(idx, weights=xf))

This gives:

tf.math.bincount calculation:
  tf.Tensor([ 1.775579    1.5032012   0.04776037 -2.9308324  -1.5677404 ], shape=(5,), dtype=float32)

For the purposes of exploring XLA, create two tf.functions, one compiled with XLA and one without:

@tf.function(experimental_compile=True)
def addvec_XLA(x):
    return tf.reduce_sum(tf.reshape(x,(n_rows, n_cols)), axis=-1)

@tf.function
def addvec_noXLA(x):
    return tf.reduce_sum(tf.reshape(x,(n_rows, n_cols)), axis=-1)

Running the non-XLA one gives

print("Reshaped vector reduce_sum operation (no XLA): \n", addvec_noXLA(xf))
Reshaped vector reduce_sum operation (no XLA): 
 tf.Tensor([ 1.775579    1.5032012   0.04776037 -2.9308324  -1.5677404 ], shape=(5,), dtype=float32)

And running the XLA one:

print("Reshaped vector reduce_sum operation (with XLA): \n", addvec_XLA(xf))

fails with this error:

InternalError: RET_CHECK failure (tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:394) !llvm::verifyModule(llvm_module, &err_stream) Invalid LLVM IR before optimizations:
Call parameter type does not match function signature!
  %parameter_buffer = alloca float, addrspace(5)
 float*  call void @Sum_reduction_14(float addrspace(5)* %parameter_buffer, float addrspace(5)* %parameter_buffer1, float addrspace(5)* %return_buffer, i8* null)
Invalid bitcast
  %add.17.typed = bitcast float addrspace(5)* %add.17.raw to float*

This probably indicates a bug in the HLO -> LLVM IR lowering. Rerun with --xla_dump_to to get the IR. 
	This error might be occurring with the use of xla.compile. If it is not necessary that every Op be compiled with XLA, an alternative is to use auto_jit with OptimizerOptions.global_jit_level = ON_2 or the environment variable TF_XLA_FLAGS="tf_xla_auto_jit=2" which will attempt to use xla to compile as much of the graph as the compiler is able to. [Op:__inference_addvec_XLA_591]
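The error text itself suggests a fallback: instead of forcing whole-function compilation with experimental_compile=True, let auto-jit cluster whatever XLA can handle. A minimal sketch of that route, assuming the flag is honored when set before TensorFlow initializes (the canonical spelling carries leading dashes, --tf_xla_auto_jit=2):

import os

# Must be set before TensorFlow is imported/initialized. Auto-jit (ON_2)
# lets XLA compile only the subgraphs it supports instead of failing hard.
os.environ["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=2"

import numpy as np
import tensorflow as tf

xf = tf.constant(np.random.randn(15), dtype=tf.float32)

@tf.function  # no experimental_compile=True; clustering is left to auto-jit
def addvec_autojit(x):
    return tf.reduce_sum(tf.reshape(x, (5, 3)), axis=-1)

print(addvec_autojit(xf))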

roblem commented Mar 26, 2020

I think it may affect any use of tf.reduce_sum:

# but can you reduce_sum if no reshape is involved?
import numpy as np
import tensorflow as tf

xf_r = tf.constant(np.random.randn(5, 3), dtype=tf.float32)

@tf.function(experimental_compile=True)
def reducesum(x):
    return tf.reduce_sum(x, axis=-1)

reducesum(xf_r)

gives the same error:

InternalError: RET_CHECK failure (tensorflow/compiler/xla/service/gpu/gpu_compiler.cc:394) !llvm::verifyModule(llvm_module, &err_stream) Invalid LLVM IR before optimizations:
Call parameter type does not match function signature!
  %4 = alloca float, addrspace(5)
 float*  call void @Sum_reduction_6(float addrspace(5)* %4, float* %8, float addrspace(5)* %4, i8* null)
Invalid bitcast
  %add.9.typed = bitcast float addrspace(5)* %add.9.raw to float*

This probably indicates a bug in the HLO -> LLVM IR lowering. Rerun with --xla_dump_to to get the IR. 
        This error might be occurring with the use of xla.compile. If it is not necessary that every Op be compiled with XLA, an alternative is to use auto_jit with OptimizerOptions.global_jit_level = ON_2 or the environment variable TF_XLA_FLAGS="tf_xla_auto_jit=2" which will attempt to use xla to compile as much of the graph as the compiler is able to. [Op:__inference_reducesum_18]
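As the message suggests, rerunning with an IR dump enabled captures the offending module for a bug report. A sketch, assuming XLA_FLAGS is picked up from the environment; /tmp/xla_dump is an arbitrary writable directory:

import os

# Ask XLA to dump HLO and LLVM IR artifacts for every compilation.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"

import numpy as np
import tensorflow as tf

@tf.function(experimental_compile=True)
def reducesum(x):
    return tf.reduce_sum(x, axis=-1)

try:
    reducesum(tf.constant(np.random.randn(5, 3), dtype=tf.float32))
except tf.errors.InternalError as e:
    # The IR referenced by the RET_CHECK failure now sits under /tmp/xla_dump.
    print(e.message.splitlines()[0])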


roblem commented Mar 27, 2020

I have done some additional digging for a range of ops that I normally use and conducted some simple tests; see the summary table below (and the harness sketch after it). A Pass value of 'Yes' means the XLA-compiled function ran with the correct result (this doesn't test speed or efficiency in any way). Where an op didn't run under XLA, a partial error is reported.

| Command | Pass | Error |
| --- | --- | --- |
| tf.transpose | Yes | N/A |
| tf.math.bincount | No | Invalid bitcast %add.typed = bitcast float addrspace(5)* %add.raw to float* |
| tf.reduce_sum | No | Invalid bitcast %add.9.typed = bitcast float addrspace(5)* %add.9.raw to float* |
| tf.reduce_max | Yes | N/A |
| tf.reduce_mean | Yes | N/A |
| tf.reduce_prod | Yes | N/A |
| tf.reduce_min (and max) | Yes | N/A |
| tf.reduce_logsumexp | Yes | N/A |
| tf.linalg.matvec | No | Invalid bitcast %add.typed = bitcast float addrspace(5)* %add.raw to float* |
| tf.linalg.matmul | Yes | N/A |
| tf.not_equal | Yes | N/A |
| tf.where | No | unsupported op: [Unsupported in all of tf] |
| tf.scatter_nd | No | Invalid bitcast %add.10.typed = bitcast float addrspace(5)* %add.10.raw to float* |
| tf.fill | Yes | N/A |
| tf.multiply | Yes | N/A |
| tf.add | Yes | N/A |
| tf.gather_nd | Yes | N/A |
| tf.linalg.norm | No | Invalid bitcast %add.10.typed = bitcast float addrspace(5)* %add.10.raw to float* |
| tf.squeeze | Yes | N/A |
| tf.math.segment_sum | No | unsupported op |
| tf.greater_equal | Yes | N/A |
| tf.cond | Yes | N/A |
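For reference, a harness along these lines would produce the Pass column (a sketch; the per-op inputs are my own stand-ins, not the original test script):

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.randn(5, 3), dtype=tf.float32)

# Representative subset of the table: wrap each op in an XLA-compiled
# tf.function, run it once, and record the first line of any error.
cases = {
    "tf.transpose": lambda: tf.transpose(x),
    "tf.reduce_sum": lambda: tf.reduce_sum(x, axis=-1),
    "tf.reduce_max": lambda: tf.reduce_max(x, axis=-1),
    "tf.linalg.matmul": lambda: tf.linalg.matmul(x, x, transpose_b=True),
}

for name, fn in cases.items():
    compiled = tf.function(fn, experimental_compile=True)
    try:
        compiled()
        print(name, "| Yes | N/A")
    except Exception as e:
        print(name, "| No |", str(e).splitlines()[0])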


jerryyin commented Apr 3, 2020

@roblem Same as with #908: this is likely the same root cause with a different symptom, and it has already been fixed. Could you try with rocm/tensorflow-autobuilds:rocm3.0-fd0317e?


roblem commented Apr 5, 2020

I can confirm that all tests listed above now pass (except for the unsupported ops, which is not surprising) with rocm/tensorflow-autobuilds:rocm3.0-fd0317e.
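For anyone else verifying: pulling and running the autobuild image looks roughly like this (the device and security flags are the usual ROCm docker invocation, not something specified in this thread):

docker pull rocm/tensorflow-autobuilds:rocm3.0-fd0317e
docker run -it --device=/dev/kfd --device=/dev/dri \
    --group-add video --security-opt seccomp=unconfined \
    rocm/tensorflow-autobuilds:rocm3.0-fd0317e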

roblem closed this as completed Apr 5, 2020