
row_bcast operand from src/kernels/reduction_functions.h leads to an error on gfx10 #1528

Closed
kvirikroma opened this issue Apr 26, 2022 · 8 comments


@kvirikroma

kvirikroma commented Apr 26, 2022

Note: I'm new to ROCm and this error may be my own fault, but I would still be grateful for any help.
I've compiled MIOpen from source and tried to use it with TensorFlow. It works fine with a simple dense network, but when I tried a convolutional one, I got an error during the training stage:

MIOpen(HIP): Info [FindConvFwdAlgorithm] FW Chosen Algorithm: ConvBinWinogradRxSf2x3g1 , 0, 11.0525
MIOpen(HIP): Info [get_device_name] Raw device name: gfx1031
MIOpen(HIP): Info [SetStream] stream: 0x5561e7c1c8c0, device_id: 0
MIOpen(HIP): Info [ConvolutionForward] algo = 3, workspace = 0
MIOpen(HIP): Info [get_device_name] Raw device name: gfx1031
MIOpen(HIP): Info [SetStream] stream: 0x5561e7c1c8c0, device_id: 0
MIOpen(HIP): Info [PrintVersionImpl] COMgr v.2.4.0, USE_HIP_PCH: 1
<inline asm>:14:20: error: not a valid operand.
v_add_f32 v4 v4 v4 row_bcast:15 row_mask:0xa
                   ^
<inline asm>:15:20: error: not a valid operand.
v_add_f32 v3 v3 v3 row_bcast:15 row_mask:0xa
                   ^
<inline asm>:17:20: error: not a valid operand.
v_add_f32 v4 v4 v4 row_bcast:31 row_mask:0xc
                   ^
<inline asm>:18:20: error: not a valid operand.
v_add_f32 v3 v3 v3 row_bcast:31 row_mask:0xc
                   ^
MIOpen(HIP): Error [Do] 'amd_comgr_do_action(kind, handle, in.GetHandle(), out.GetHandle())' AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE: ERROR (1)
MIOpen(HIP): Error [BuildOcl] comgr status = ERROR (1)
MIOpen(HIP): Warning [BuildOcl] error: cannot compile inline asm
MIOpen Error: /home/roman/ROCm/MIOpen/src/hipoc/hipoc_program.cpp:300: Code object build failed. Source: MIOpenBatchNormFwdTrainSpatial.cl
2022-04-26 20:42:17.080292: E tensorflow/stream_executor/rocm/rocm_dnn.cc:3781] failed to enqueue forward batch normalization on stream: miopenStatusUnknownError

I've built MIOpen with the following parameters:

AMDGPU_TARGETS=gfx1031
BUILD_WITH_TENSILE_HOST=false
CMAKE_BUILD_TYPE=Release

Environment:
GPU: Radeon RX 6700 XT (Navi 22), gfx1031 architecture
OS: Arch Linux
ROCm version: 5.1.1
MIOpen version: 2.16.0
TensorFlow version: tf-nightly-rocm 2.10.0 (custom build for gfx1031)

Yes, I know that gfx1031 is not officially supported, but AFAICT the error happens while compiling the GPU assembly (not at runtime), so I don't think it's related.
I've also tried changing those instructions to v_add_f32_dpp, but that did not fix the problem (they are on lines 95-99 of src/kernels/reduction_functions.h).
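For readers unfamiliar with DPP: the rejected lines look like the final two steps of a wavefront-wide DPP reduction. The Python sketch below simulates only those two steps on one 64-lane wavefront. It assumes the GFX8/GFX9 semantics of these controls (row_bcast:15 feeds lane 15 of each 16-lane row into the next row, row_bcast:31 feeds lane 31 into rows 2 and 3, and row_mask selects which rows are written) and assumes each row already holds its row-local inclusive prefix sums. The earlier row_shr steps of MIOpen's kernel are not reproduced, integers stand in for floats, and all function names are hypothetical.

```python
from collections.abc import Callable

WAVE = 64  # lanes per wavefront (wave64)
ROW = 16   # DPP row_* controls operate on rows of 16 lanes


def dpp_add_row_bcast(v: list[int], src_lane: Callable[[int], int], row_mask: int) -> list[int]:
    """Simulate `v_add_f32 vN, vN, vN row_bcast:... row_mask:...` (hypothetical helper).

    src0 is read through the DPP broadcast, src1 is the lane's own value.
    Every lane reads before any lane writes; rows whose bit is clear in
    row_mask keep their old value.
    """
    out = list(v)
    for lane in range(WAVE):
        row = lane // ROW
        if row_mask & (1 << row):
            out[lane] = v[src_lane(row)] + v[lane]
    return out


def row_prefix_sums(x: list[int]) -> list[int]:
    """Assumed starting state: inclusive prefix sums within each 16-lane row."""
    out = []
    for r in range(WAVE // ROW):
        acc = 0
        for value in x[r * ROW:(r + 1) * ROW]:
            acc += value
            out.append(acc)
    return out


def finish_reduction(v: list[int]) -> list[int]:
    # row_bcast:15 row_mask:0xa -> rows 1 and 3 add lane 15 of the previous row
    v = dpp_add_row_bcast(v, lambda row: row * ROW - 1, 0xA)
    # row_bcast:31 row_mask:0xc -> rows 2 and 3 add lane 31 (rows 0+1 total)
    v = dpp_add_row_bcast(v, lambda row: 31, 0xC)
    return v


if __name__ == "__main__":
    values = list(range(1, WAVE + 1))
    lanes = finish_reduction(row_prefix_sums(values))
    print(lanes[WAVE - 1])  # 2080, the sum of 1..64
```

Under these assumptions the last lane ends up holding the total of all 64 lanes (in fact every lane holds a full inclusive prefix sum). LLVM's AMDGPU assembler accepts row_bcast:15 and row_bcast:31 only for GFX8 and GFX9, which is consistent with the "not a valid operand" errors on gfx1031: a gfx10 build has to move the row totals across rows some other way.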

@kvirikroma changed the title from "row_bcast operand from src/kernels/reduction_functions.h leads to error on gfx10" to "row_bcast operand from src/kernels/reduction_functions.h leads to an error on gfx10" on Apr 26, 2022
@xuhuisheng
Contributor

xuhuisheng commented Apr 27, 2022

This may be the reason: ROCm only handles gfx1030 and gfx1011 here:
https://github.com/ROCmSoftwarePlatform/rocSPARSE/blob/develop/library/src/include/common.h#L218

You can use my patch as a reference to try it on gfx1031:
https://github.com/xuhuisheng/rocm-build/blob/master/patch/25.rocsparse-gfx10-1.patch
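A minimal sketch of the failure mode described above, assuming the code picks its gfx10 path by matching an explicit list of target names. Every name here is hypothetical and is not a MIOpen or rocSPARSE API; the sketch only contrasts an exact-match allow-list with a family-level prefix check.

```python
# Hypothetical illustration: none of these names are MIOpen or rocSPARSE APIs.
EXPLICIT_GFX10_TARGETS = {"gfx1011", "gfx1030"}


def base_arch(device_name: str) -> str:
    """Drop target-feature suffixes such as ':xnack-' from a device name."""
    return device_name.split(":", 1)[0]


def uses_gfx10_path_allowlist(device_name: str) -> bool:
    """Exact match: any gfx10 target missing from the list falls through."""
    return base_arch(device_name) in EXPLICIT_GFX10_TARGETS


def uses_gfx10_path_family(device_name: str) -> bool:
    """Family match: every gfx10xx target is treated the same way."""
    return base_arch(device_name).startswith("gfx10")


if __name__ == "__main__":
    for name in ("gfx1030", "gfx1031", "gfx906"):
        print(name, uses_gfx10_path_allowlist(name), uses_gfx10_path_family(name))
```

With an allow-list, gfx1031 is not recognized and would take whatever non-gfx10 path the code falls back to; the prefix check covers it along with gfx1030 while still excluding GCN targets such as gfx906.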

@kvirikroma
Author

@xuhuisheng Thank you for the answer.
I've already done that (using your rocm-build repo, which is really helpful btw), but unfortunately it did not help. I don't think it's related to rocSPARSE, because the code that causes the error is in MIOpen.

@atamazov
Contributor

@kvirikroma This is an issue with the inline assembly code used in the BN OpenCL kernels. Note that gfx1031 is not officially supported. We can update MIOpen to resolve the issue. To help us prioritize this work, please indicate how important this issue is for you.

/cc @junliume

@kvirikroma
Author

kvirikroma commented Apr 27, 2022

@atamazov It's not urgent for me: I'm doing all this just for fun/learning/experience and won't lose any money because of the issue (and even if I did, it would be my own fault tbh, considering gfx1031 is not supported). But it would be great to get it working soon :)

@atamazov
Contributor

@kvirikroma Okay. Luckily, the fix seems easy to implement. Here is PR #1531; please review it first, then give it a try and let us know whether it works for you.

@kvirikroma
Author

The issue was solved by @atamazov in https://github.com/atamazov/MIOpen/tree/generalize-gfx10, so I'm closing it.

@atamazov
Contributor

@kvirikroma Technically, the issue is not closed until the PR is merged in. I recommend re-opening.

@kvirikroma reopened this on Apr 30, 2022
@junliume
Collaborator

junliume commented May 5, 2022

> @kvirikroma Technically, the issue is not closed until the PR is merged in. I recommend re-opening.

The PR is merged :)


4 participants