[Upstream Backend] [PyTorch UT] error: failed to legalize operation 'tt.mulhiui' that was explicitly marked illegal #548

Closed
jataylo opened this issue Mar 28, 2024 · 7 comments

jataylo commented Mar 28, 2024

Problem Description

Environment:
Docker image: rocm/pytorch-private:rocm_inductor_triton_upstream_migration_v1
Triton branch: https://github.com/jataylo/triton/tree/jack-triton-inductor-migration
Pytorch branch: https://github.com/pytorch/pytorch/tree/rocm-inductor-hip-device

PyTorch UT: inductor/test_torchinductor.py::test_dropout2_cuda

Reproducer:

import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.triton_helpers import libdevice, math as tl_math

@triton.jit
def randint64(seed, offset, low, high):
    r0, r1, r2, r3 = tl.randint4x(seed, offset)
    r0 = r0.to(tl.uint64)
    r1 = r1.to(tl.uint64)
    result = r0 | (r1 << 32)
    size = high - low
    result = result % size.to(tl.uint64)
    result = result.to(tl.int64) + low
    return result

@triton.jit
def triton_fn(in_ptr0, in_ptr1, out_ptr0, ks0, load_seed_offset, load_seed_offset1):
    XBLOCK : tl.constexpr = 16
    RBLOCK : tl.constexpr = 8192
    xnumel : tl.constexpr = 13
    rnumel : tl.constexpr = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp27 = tl.full([XBLOCK, RBLOCK], 0, tl.int32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = r1 + (x0*((12 + ks0) // 13))
        tmp1 = ks0
        tmp2 = tmp0 < tmp1
        tmp3 = tl.load(in_ptr0 + load_seed_offset)
        tmp4 = randint64(tmp3, (tmp0).to(tl.uint32), -2147483648, 2147483648)
        tmp5 = tmp4.to(tl.int32)
        tmp6 = tl_math.abs(tmp5)
        tmp7 = tl.full([1, 1], 2147483647, tl.int32)
        tmp8 = tmp6 % tmp7
        tmp9 = tmp8 + tmp7
        tmp10 = tl.where(((tmp8 != 0) & ((tmp8 < 0) != (tmp7 < 0))), tmp9, tmp8)
        tmp11 = tl.full([1, 1], 1, tl.int32)
        tmp12 = tmp10 + tmp11
        tmp13 = tmp12.to(tl.int64)
        tmp14 = tl.load(in_ptr1 + (r1 + (x0*((12 + ks0) // 13))), rmask & tmp2 & xmask, eviction_policy='evict_first', other=0.0)
        tmp15 = tmp14.to(tl.int64)
        tmp16 = tmp13 * tmp15
        tmp17 = tl.load(in_ptr0 + load_seed_offset1)
        tmp18 = randint64(tmp17, (tmp0).to(tl.uint32), -2147483648, 2147483648)
        tmp19 = tmp18.to(tl.int32)
        tmp20 = tl_math.abs(tmp19)
        tmp21 = tmp20.to(tl.int64)
        tmp22 = tmp16 + tmp21
        tmp23 = tmp22.to(tl.int32)
        tmp24 = tl.full(tmp23.shape, 0, tmp23.dtype)
        tmp25 = tl.where(tmp2, tmp23, tmp24)
        tmp26 = tl.broadcast_to(tmp25, [XBLOCK, RBLOCK])
        tmp28 = _tmp27 ^ tmp26
        _tmp27 = tl.where(rmask & xmask, tmp28, _tmp27)
    tmp27 = tl.xor_sum(_tmp27, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp27, xmask)

src = triton.compiler.ASTSource(fn=triton_fn, signature="*i64, *i32, *i32, i32, i32, i32")
test = triton.compile(src)
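To make the arithmetic in the reproducer's `randint64` helper easier to follow, here is a host-side Python model of it. This is a sketch: `randint64_model` and `to_signed64` are hypothetical names, and `r0`/`r1` stand in for two of the uint32 values returned by `tl.randint4x`. The model packs the two 32-bit draws into a uint64, reduces it modulo `high - low`, and shifts the result by `low`, explicitly wrapping to 64 bits where the kernel presumably relies on `tl.uint64`/`tl.int64` overflow.

```python
# Host-side model of the reproducer's randint64 helper (hypothetical names).
MASK32 = (1 << 32) - 1
MASK64 = (1 << 64) - 1


def to_signed64(x: int) -> int:
    """Reinterpret an unsigned 64-bit value as a two's-complement int64."""
    x &= MASK64
    return x - (1 << 64) if x >= (1 << 63) else x


def randint64_model(r0: int, r1: int, low: int, high: int) -> int:
    """Model of randint64 given two uint32 draws r0 and r1."""
    # r0 | (r1 << 32) on uint64 operands: r0 is the low word, r1 the high word.
    result = (r0 & MASK32) | ((r1 & MASK32) << 32)
    # size.to(tl.uint64), then an unsigned 64-bit modulo.
    size = (high - low) & MASK64
    result %= size
    # result.to(tl.int64) + low
    return to_signed64(result) + low
```

With the arguments used in the reproducer (`low = -2147483648`, `high = 2147483648`), `size` is exactly 2**32, so the modulo discards the high word and only `r0` affects the result.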

Stacktrace:

loc(callsite(callsite(callsite(callsite("/tmp/triton/python/triton/language/random.py":35:28 at "/tmp/triton/python/triton/language/random.py":61:57) at "/tmp/triton/python/triton/language/random.py":94:44) at "repro.py":10:40) at "repro.py":39:66)): error: failed to legalize operation 'tt.mulhiui' that was explicitly marked illegal
Traceback (most recent call last):
  File "repro.py", line 68, in <module>
    test = triton.compile(src)
  File "/tmp/triton/python/triton/compiler/compiler.py", line 268, in compile
    next_module = compile_ir(module, metadata)
  File "/tmp/triton/python/triton/backends/amd/compiler.py", line 223, in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, 90)
  File "/tmp/triton/python/triton/backends/amd/compiler.py", line 163, in make_llir
    pm.run(mod)
RuntimeError: PassManager::run failed
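As far as we can tell, the failing op `tt.mulhiui` is the Triton IR op emitted for `tl.math.umulhi`. Upstream Triton's Philox generator calls this function in every round, and the callsite chain in the error, which passes through `random.py` before it reaches `randint64`, is consistent with that. The sketch below models in plain Python what the op has to compute inside one Philox-4x32 round. The constants and round structure reflect our reading of `python/triton/language/random.py` and should be checked against the Triton revision in use.

```python
# Pure-Python sketch of one Philox-4x32 round as used by tl.randint4x.
# Assumption: constants and round structure follow our reading of upstream
# Triton's python/triton/language/random.py; verify against your revision.

MASK32 = 0xFFFFFFFF

PHILOX_KEY_A = 0x9E3779B9
PHILOX_KEY_B = 0xBB67AE85
PHILOX_ROUND_A = 0xD2511F53
PHILOX_ROUND_B = 0xCD9E8D57


def umulhi(a: int, b: int) -> int:
    """High 32 bits of the full 64-bit product of two uint32 values.

    This is the operation that tt.mulhiui has to implement.
    """
    return ((a & MASK32) * (b & MASK32)) >> 32


def philox_round(c0, c1, c2, c3, k0, k1):
    """Apply one Philox round to the counter (c0..c3) and bump the key (k0, k1)."""
    _c0, _c2 = c0, c2
    c0 = umulhi(PHILOX_ROUND_B, _c2) ^ c1 ^ k0
    c2 = umulhi(PHILOX_ROUND_A, _c0) ^ c3 ^ k1
    # Plain uint32 multiplies keep only the low 32 bits of the product.
    c1 = (PHILOX_ROUND_B * _c2) & MASK32
    c3 = (PHILOX_ROUND_A * _c0) & MASK32
    # Key schedule: add the Weyl constants modulo 2**32.
    k0 = (k0 + PHILOX_KEY_A) & MASK32
    k1 = (k1 + PHILOX_KEY_B) & MASK32
    return c0, c1, c2, c3, k0, k1
```

For example, `umulhi(0xFFFFFFFF, 0xFFFFFFFF)` is `0xFFFFFFFE`. Only the high half of the product carries this value, so a backend cannot emulate the op with a plain 32-bit multiply.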

GPU

AMD Instinct MI250X

ROCm Version

ROCm 6.0.0

jataylo commented Apr 2, 2024

Here is the reproducer with Inductor stripped out entirely (it depends only on Triton):

import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

@triton.jit
def randint64(seed, offset, low, high):
    r0, r1, r2, r3 = tl.randint4x(seed, offset)
    r0 = r0.to(tl.uint64)
    r1 = r1.to(tl.uint64)
    result = r0 | (r1 << 32)
    size = high - low
    result = result % size.to(tl.uint64)
    result = result.to(tl.int64) + low
    return result

@triton.jit
def triton_fn(in_ptr0, in_ptr1, out_ptr0, ks0, load_seed_offset, load_seed_offset1):
    XBLOCK : tl.constexpr = 16
    RBLOCK : tl.constexpr = 8192
    xnumel : tl.constexpr = 13
    rnumel : tl.constexpr = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    _tmp27 = tl.full([XBLOCK, RBLOCK], 0, tl.int32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        tmp0 = r1 + (x0*((12 + ks0) // 13))
        tmp1 = ks0
        tmp2 = tmp0 < tmp1
        tmp3 = tl.load(in_ptr0 + load_seed_offset)
        tmp4 = randint64(tmp3, (tmp0).to(tl.uint32), -2147483648, 2147483648)
        tmp5 = tmp4.to(tl.int32)
        tmp6 = tl.math.abs(tmp5)
        tmp7 = tl.full([1, 1], 2147483647, tl.int32)
        tmp8 = tmp6 % tmp7
        tmp9 = tmp8 + tmp7
        tmp10 = tl.where(((tmp8 != 0) & ((tmp8 < 0) != (tmp7 < 0))), tmp9, tmp8)
        tmp11 = tl.full([1, 1], 1, tl.int32)
        tmp12 = tmp10 + tmp11
        tmp13 = tmp12.to(tl.int64)
        tmp14 = tl.load(in_ptr1 + (r1 + (x0*((12 + ks0) // 13))), rmask & tmp2 & xmask, eviction_policy='evict_first', other=0.0)
        tmp15 = tmp14.to(tl.int64)
        tmp16 = tmp13 * tmp15
        tmp17 = tl.load(in_ptr0 + load_seed_offset1)
        tmp18 = randint64(tmp17, (tmp0).to(tl.uint32), -2147483648, 2147483648)
        tmp19 = tmp18.to(tl.int32)
        tmp20 = tl.math.abs(tmp19)
        tmp21 = tmp20.to(tl.int64)
        tmp22 = tmp16 + tmp21
        tmp23 = tmp22.to(tl.int32)
        tmp24 = tl.full(tmp23.shape, 0, tmp23.dtype)
        tmp25 = tl.where(tmp2, tmp23, tmp24)
        tmp26 = tl.broadcast_to(tmp25, [XBLOCK, RBLOCK])
        tmp28 = _tmp27 ^ tmp26
        _tmp27 = tl.where(rmask & xmask, tmp28, _tmp27)
    tmp27 = tl.xor_sum(_tmp27, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp27, xmask)

src = triton.compiler.ASTSource(fn=triton_fn, signature="*i64, *i32, *i32, i32, i32, i32")
test = triton.compile(src)
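The `tmp8` to `tmp10` sequence in both reproducers turns a truncated remainder into a floor (Python-style) remainder. The Python sketch below models that fix-up, assuming that Triton's signed integer `%` follows C truncation semantics (the sign follows the dividend). It also models the int32 wraparound of `abs` on `INT32_MIN`. All function names are hypothetical.

```python
# Model of the tmp5 -> tmp10 integer arithmetic (hypothetical helper names).


def to_int32(x: int) -> int:
    """Wrap an integer to two's-complement int32, like .to(tl.int32) truncation."""
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x >= (1 << 31) else x


def abs_int32(x: int) -> int:
    """int32 abs with wraparound: abs(-2**31) stays -2**31."""
    return to_int32(abs(x))


def trunc_mod(a: int, b: int) -> int:
    """Remainder with C semantics: the result takes the sign of the dividend."""
    r = abs(a) % abs(b)
    return -r if a < 0 else r


def floor_mod_via_trunc(a: int, b: int) -> int:
    """The tmp8 -> tmp10 pattern: fix up a truncated remainder to floor semantics."""
    r = trunc_mod(a, b)                      # tmp8
    if r != 0 and ((r < 0) != (b < 0)):      # the tl.where condition
        return r + b                         # tmp9
    return r                                 # tmp8 unchanged
```

Python's own `%` is a floor modulo, so `floor_mod_via_trunc(a, b) == a % b` for every nonzero `b`, which is an easy way to sanity-check the pattern. Even when `abs` wraps `-2147483648` back to itself, the fixed-up remainder still lands in `[0, 2147483647)`.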

jataylo commented Apr 2, 2024

Note: this issue is blocking us from creating reproducers for two other failure categories:

  • RuntimeError: Triton Error [HIP]: Code: 1, Messsage: invalid argument
  • error: failed to legalize operation 'triton_gpu.local_load' that was explicitly marked illegal

Once this is unblocked, I will create reproducers for both.

zhanglx13 commented, quoting triton_gpu.local_load from the list above:

This op was recently added upstream and has not yet been pulled into our fork. @jataylo, are you using upstream or the fork?

jataylo commented Apr 2, 2024

This is using the upstream backend. I can raise these issues at openai/triton if we think that is more appropriate.

micmelesse (Collaborator) commented:

This branch fixes the issue: https://github.com/micmelesse/triton/tree/micmelesse/pytorch_2. We will work on upstreaming it.

jataylo commented Apr 3, 2024

@micmelesse Thank you! I can confirm the UT passes with this change. I'll keep the issue open until the upstream PR is scoped.

zhanglx13 commented:

This upstream PR should fix the issue: triton-lang#3563
