
[Bug] SplitHostDevice generates free var when var only exists in T.thread_binding of device function #16237

Closed
jinhongyii opened this issue Dec 14, 2023 · 8 comments
Labels: needs-triage, type: bug

Comments

jinhongyii (Contributor) commented Dec 14, 2023

Expected behavior

SplitHostDevice generates no free variables.

Actual behavior

SplitHostDevice generates the free variable cse_var_3 when running the script below (under "Steps to reproduce"). The offending device function:


    @T.prim_func(private=True)
    def default_function_kernel_2(output_buf: T.handle("int32", "global"), cse_var_4: T.int64, i: T.int32, seq_len: T.int32):
        T.func_attr({"target": T.target({"keys": ["metal", "gpu"], "kind": "metal", "max_function_args": 31, "max_num_threads": 256, "max_shared_memory_per_block": 32768, "max_threads_per_block": 256, "tag": "", "thread_warp_size": 16}), "tir.is_global_func": T.bool(True), "tir.noalias": T.bool(True)})
        output_buf_1 = T.decl_buffer((T.Cast("int64", seq_len) * T.int64(8),), "int32", data=output_buf, align=8)
        end_1 = T.handle("int64", "local")
        end_1_1 = T.decl_buffer((1,), "int64", data=end_1, scope="local")
        middle_1 = T.handle("int64", "local")
        middle_1_1 = T.decl_buffer((1,), "int64", data=middle_1, scope="local")
        start_1 = T.handle("int64", "local")
        start_1_1 = T.decl_buffer((1,), "int64", data=start_1, scope="local")
        threadIdx_x = T.launch_thread("threadIdx.x", 256)
        start_1 = T.allocate([1], "int64", "local")
        middle_1 = T.allocate([1], "int64", "local")
        end_1 = T.allocate([1], "int64", "local")
        cse_var_3 = T.int32()
        blockIdx_x = T.launch_thread("blockIdx.x", (cse_var_3 - 1) // (T.shift_left(2, i) * 256) + 1)
        blockIdx_y = T.launch_thread("blockIdx.y", 1)
        start_1_1[0] = T.Cast("int64", T.shift_left(2, i)) * (T.Cast("int64", blockIdx_x) * T.int64(256) + T.Cast("int64", threadIdx_x))
        if start_1_1[0] < cse_var_4:
            middle_1_1[0] = T.Cast("int64", T.shift_left(2, i)) // T.int64(2) + start_1_1[0]
            end_1_1[0] = T.min(start_1_1[0] + T.Cast("int64", T.shift_left(2, i)), cse_var_4)
            if middle_1_1[0] < cse_var_4:
                output_buf_1[end_1_1[0] - T.int64(1)] = output_buf_1[end_1_1[0] - T.int64(1)] + output_buf_1[middle_1_1[0] - T.int64(1)]

Environment

TVM Unity branch

Steps to reproduce

from tvm.script import tir as T
import tvm

@T.prim_func(private=True)
def cumsum(var_A: T.handle, var_T_add: T.handle, seq_len: T.int32):
    T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
    A = T.match_buffer(var_A, (seq_len * 8,), "int32")
    T_add = T.match_buffer(var_T_add, (seq_len * 8,), "int32")
    # with T.block("root"):
    T_expand_dims = T.alloc_buffer((1, seq_len * 8), "int32")
    output_buf = T.alloc_buffer((1, seq_len * 8), "int32", align=8)
    T_squeeze = T.alloc_buffer((seq_len * 8,), "int32")
    for ax0_fused_0 in T.thread_binding((seq_len * 8 + 1023) // 1024, thread="blockIdx.x"):
        for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
            with T.block("T_expand_dims"):
                v0 = T.axis.spatial(seq_len * 8, ax0_fused_0 * 1024 + ax0_fused_1)
                T.where(ax0_fused_0 * 1024 + ax0_fused_1 < seq_len * 8)
                T.reads(A[v0])
                T.writes(T_expand_dims[0, v0])
                T_expand_dims[0, v0] = A[v0]
    with T.block("exclusive_scan"):
        T.reads(T_expand_dims[0, 0:seq_len * 8])
        T.writes(output_buf[0, 0:seq_len * 8])
        if seq_len * 8 == 0:
            blockIdx_x = T.launch_thread("blockIdx.x", 1)
            if blockIdx_x < 1:
                T.evaluate(0)
        else:
            with T.launch_thread("threadIdx.x", 256) as threadIdx_x:
                blockIdx_x = T.launch_thread("blockIdx.x", T.max(1, (seq_len * 8 + 255) // 256))
                blockIdx_y = T.launch_thread("blockIdx.y", 1)
                if blockIdx_x * 256 + threadIdx_x < seq_len * 8:
                    output_buf[(blockIdx_y * (seq_len * 8) + (blockIdx_x * 256 + threadIdx_x)) // (seq_len * 8), (blockIdx_y * (seq_len * 8) + (blockIdx_x * 256 + threadIdx_x)) % (seq_len * 8)] = T_expand_dims[(blockIdx_y * (seq_len * 8) + (blockIdx_x * 256 + threadIdx_x)) // (seq_len * 8), (blockIdx_y * (seq_len * 8) + (blockIdx_x * 256 + threadIdx_x)) % (seq_len * 8)]
            for i in range(T.Cast("int32", T.ceil(T.log2(T.Cast("float32", seq_len * 8))))):
                threadIdx_x = T.launch_thread("threadIdx.x", 256)
                blockIdx_x = T.launch_thread("blockIdx.x", T.max(1, T.Cast("int32", (seq_len * 8 + (256 * T.shift_left(2, i) - 1)) // (256 * T.shift_left(2, i)))))
                blockIdx_y = T.launch_thread("blockIdx.y", 1)
                start = T.allocate([1], "int64", "local")
                middle = T.allocate([1], "int64", "local")
                end = T.allocate([1], "int64", "local")
                start_1 = T.Buffer((1,), "int64", data=start, scope="local")
                start_1[0] = T.Cast("int64", T.shift_left(2, i)) * T.Cast("int64", blockIdx_x * 256 + threadIdx_x)
                if start_1[0] < T.Cast("int64", seq_len * 8):
                    middle_1 = T.Buffer((1,), "int64", data=middle, scope="local")
                    middle_1[0] = start_1[0] + T.Cast("int64", T.shift_left(2, i) // 2)
                    end_1 = T.Buffer((1,), "int64", data=end, scope="local")
                    end_1[0] = T.min(start_1[0] + T.Cast("int64", T.shift_left(2, i)), T.Cast("int64", seq_len * 8))
                    if middle_1[0] < T.Cast("int64", seq_len * 8):
                        output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_1[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_1[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)] = output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_1[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_1[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)] + output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)]
            with T.launch_thread("blockIdx.x", 1) as blockIdx_x:
                if blockIdx_x < 1:
                    output_buf[((blockIdx_x + 1) * (seq_len * 8) - 1) // (seq_len * 8), ((blockIdx_x + 1) * (seq_len * 8) - 1) % (seq_len * 8)] = 0
            for j in range(T.Cast("int32", T.ceil(T.log2(T.Cast("float32", seq_len * 8))))):
                threadIdx_x = T.launch_thread("threadIdx.x", 256)
                blockIdx_x = T.launch_thread("blockIdx.x", T.max(1, T.Cast("int32", (T.Cast("int64", seq_len * 8) + (T.int64(256) * T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float32", seq_len * 8)))) - T.Cast("int64", j) - T.int64(1)) - T.int64(1))) // (T.int64(256) * T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float32", seq_len * 8)))) - T.Cast("int64", j) - T.int64(1))))))
                blockIdx_y = T.launch_thread("blockIdx.y", 1)
                start = T.allocate([1], "int64", "local")
                middle = T.allocate([1], "int64", "local")
                end = T.allocate([1], "int64", "local")
                end_1 = T.allocate([1], "int32", "local")
                start_1 = T.Buffer((1,), "int64", data=start, scope="local")
                start_1[0] = T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float32", seq_len * 8)))) - T.Cast("int64", j) - T.int64(1)) * T.Cast("int64", blockIdx_x * 256 + threadIdx_x)
                if start_1[0] < T.Cast("int64", seq_len * 8):
                    middle_1 = T.Buffer((1,), "int64", data=middle, scope="local")
                    middle_1[0] = start_1[0] + T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float32", seq_len * 8)))) - T.Cast("int64", j) - T.int64(1)) // T.int64(2)
                    end_2 = T.Buffer((1,), "int64", data=end, scope="local")
                    end_2[0] = T.min(start_1[0] + T.shift_left(T.int64(2), T.Cast("int64", T.ceil(T.log2(T.Cast("float32", seq_len * 8)))) - T.Cast("int64", j) - T.int64(1)), T.Cast("int64", seq_len * 8))
                    if middle_1[0] < T.Cast("int64", seq_len * 8):
                        end_3 = T.Buffer((1,), "int32", data=end_1, scope="local")
                        end_3[0] = output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)]
                        output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + middle_1[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)] = output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)]
                        output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)] = output_buf[(T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) // T.Cast("int64", seq_len * 8), (T.Cast("int64", blockIdx_y * (seq_len * 8)) + end_2[0] - T.int64(1)) % T.Cast("int64", seq_len * 8)] + end_3[0]
    for ax0_fused_0 in T.thread_binding((seq_len * 8 + 1023) // 1024, thread="blockIdx.x"):
        for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
            with T.block("T_squeeze"):
                v0 = T.axis.spatial(seq_len * 8, ax0_fused_0 * 1024 + ax0_fused_1)
                T.where(ax0_fused_0 * 1024 + ax0_fused_1 < seq_len * 8)
                T.reads(output_buf[0, v0])
                T.writes(T_squeeze[v0])
                T_squeeze[v0] = output_buf[0, v0]
    for ax0_fused_0 in T.thread_binding((seq_len * 8 + 1023) // 1024, thread="blockIdx.x"):
        for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"):
            with T.block("T_add"):
                v0 = T.axis.spatial(seq_len * 8, ax0_fused_0 * 1024 + ax0_fused_1)
                T.where(ax0_fused_0 * 1024 + ax0_fused_1 < seq_len * 8)
                T.reads(A[v0], T_squeeze[v0])
                T.writes(T_add[v0])
                T_add[v0] = A[v0] + T_squeeze[v0]
                
tvm.build(cumsum, target="metal")

cc: @Lunderberg

jinhongyii added the type: bug and needs-triage labels on Dec 14, 2023
jinhongyii (Contributor, Author) commented:

#16236 is a temporary fix that forces the free var in the device function to share the same pointer as the one in the host function, but a proper fix requires the pass to stop producing ill-formed output that contains free vars.

tqchen (Member) commented Dec 14, 2023

cc @Lunderberg. I think we need to revisit #14918.

Specifically, we should follow the principle that every stage contains functions that are self-contained. One approach could be for SplitHostDevice to generate the function call with that extra launch information, and then eliminate it when the callee is detected to be a device-only function.

Right now, having a function depend on symbols defined in another function's body can be quite fragile.

Lunderberg (Contributor) commented Dec 14, 2023

SplitHostDevice generates the free variable cse_var_3 when running the script below

@jinhongyii The test case does not reproduce the described error on the apache:main branch. Either before or after #16236, it instead raises an error in the tir.transform.LowerAutoCopy pass (memhammer_lower_auto_copy.cc). It does sporadically reproduce the error on apache:unity, but inconsistently (60 failures out of 100 iterations). This makes sense if the failure mode is order-dependent.

Also, the failure mode does not require target="metal" to reproduce. This is good for debugging, as Metal is macOS-specific.
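
Since Metal is not required, the reproduction script below can be pointed at another GPU target. For example (an untested variation, assuming a TVM build with CUDA codegen enabled):

    tvm.build(cumsum, target="cuda")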

Specifically, we should follow the principle that every stage contains functions that are self-contained. One approach could be for SplitHostDevice to generate the function call with that extra launch information, and then eliminate it when the callee is detected to be a device-only function.

This is both the intention and current design of SplitHostDevice. Any variables not defined in a function should be provided from the calling scope. The device function contains sufficient information to know how it should be launched, using the known thread extents. This is later used to determine the launch parameters when called from a host function.
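
To illustrate that contract, here is a hand-written sketch of what a device kernel can look like after SplitHostDevice (illustrative only, not actual pass output; the buffer A, the size n, and the cuda target are invented for the example):

    import tvm
    from tvm.script import ir as I, tir as T

    @I.ir_module
    class AfterSplit:
        @T.prim_func(private=True)
        def main_kernel(A: T.handle("float32", "global"), n: T.int32):
            # `n` is defined by the host function but used both in the kernel
            # body and in the blockIdx.x launch extent, so it must be passed
            # in as a parameter rather than captured as a free variable.
            T.func_attr({"target": T.target("cuda"), "tir.is_global_func": T.bool(True), "tir.noalias": T.bool(True)})
            A_1 = T.decl_buffer((n,), "float32", data=A)
            blockIdx_x = T.launch_thread("blockIdx.x", (n + 255) // 256)
            threadIdx_x = T.launch_thread("threadIdx.x", 256)
            if blockIdx_x * 256 + threadIdx_x < n:
                A_1[blockIdx_x * 256 + threadIdx_x] = T.float32(1)

LowerDeviceKernelLaunch can then substitute the call-site arguments into the blockIdx.x extent to compute the launch grid, which only works when every variable in the extent is a kernel parameter.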

Right now, having a function depend on symbols defined in another function's body can be quite fragile.

Agreed. Depending on symbols defined in a different function body sounds incredibly fragile, and is certainly not the intended behavior.

Lunderberg (Contributor) commented:

I've tracked down the root cause: the primary cause is the use of visit_thread_extent_ = false in VarUseDefAnalysis, with a secondary cause in EliminateCommonSubexpr's choice of LetStmt location.

  1. Multiple expressions use the same seq_len * 8. This is hoisted by EliminateCommonSubexpr and placed outside the "target" attribute.
  2. SplitHostDevice collects the variables that are used by the body of a function used within a "target" attribute, but defined outside. The cse_var_3 only occurs within a thread extent, so it is not included.
  3. Each variable collected in the previous step is used as a parameter to the device-side kernel. This causes a duplicate definition of these variables, between the LetStmt in the host-side function and the parameter in the device-side function.
  4. To restore SSA, the module is post-processed with ConvertSSA.
  5. ConvertSSA sees that cse_var_3 is defined in the host-side function but used in the device-side function; the usage in the device-side function has no definition.
  6. When the thread extents are collected in LowerDeviceKernelLaunch, the undefined variable cse_var_3 is found in the AttrStmt of the device-side function. All device-side function parameters are substituted to determine the thread extent in terms of the host-side variables, but there is no definition of cse_var_3 on the device side, so it remains undefined.

The solution is to remove the /* visit_thread_extent = */ false argument passed to VarUseDefAnalysis. This produces the correct parameters in the device-side function, which are then replaced correctly by ConvertSSA.
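
For intuition, here is a rough Python analogue of the corrected collection (a sketch under stated assumptions: the real fix lives in the C++ VarUseDefAnalysis, this version ignores statement ordering and scoping, and tvm.tir.stmt_functor.post_order_visit is used as a stand-in because it walks every node, including AttrStmt values such as thread extents):

    import tvm
    from tvm import tir

    def free_vars(stmt, params):
        """Vars used somewhere in `stmt` but never defined in it."""
        defined = set(params)
        used = set()

        def visit(node):
            if isinstance(node, tir.Var):
                used.add(node)
            elif isinstance(node, tir.LetStmt):
                defined.add(node.var)
            elif isinstance(node, tir.Allocate):
                defined.add(node.buffer_var)
            elif isinstance(node, tir.For):
                defined.add(node.loop_var)
            elif isinstance(node, tir.AttrStmt) and node.attr_key == "thread_extent":
                defined.add(node.node.var)  # the thread IterVar is a definition

        # post_order_visit traverses expressions too, including the `value`
        # of AttrStmt nodes, so a var that appears only inside a
        # "thread_extent" extent (like cse_var_3) is still counted as used --
        # the case that `visit_thread_extent_ = false` skipped.
        tvm.tir.stmt_functor.post_order_visit(stmt, visit)
        return used - defined

Every variable such a collection returns becomes a parameter of the device kernel, so ConvertSSA's de-duplication no longer strands a use without a definition.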

Lunderberg added a commit to Lunderberg/tvm that referenced this issue Dec 15, 2023
The bug reported in apache#16237 can be resolved by tracking variable usage in a thread extent.
Lunderberg (Contributor) commented Dec 15, 2023

Confirmed: when the thread_extent is visited, I can make it through LLVM codegen without issue.

In the process of investigating, I also ran into a few additional limitations:

  1. IRConvertSSA checks whether the AttrStmtNode::node is a tir::Var, but doesn't check for a tir::IterVar. For a thread_extent attribute, this should be treated as the point of definition for the variable, to be de-duplicated.

  2. The IRConvertSSA doesn't update the min/extent of the tir::IterVar. If the min/extent includes a variable being replaced, those usages will instead be undefined after ConvertSSA.

  3. The TVMScript printer doesn't display the min/extent of the tir::IterVar used in thread_extent. Instead, only the variable name and the attribute's value appear in the generated T.launch_thread, even when syntax sugar is disabled, making it quite difficult to debug.

  4. The test case from this thread does not have DeclBuffer nodes. As a result, the buffers are undefined when used, and assumed to need propagation from the host to the device. SplitHostDevice inserts a DeclBufferNode in the device kernel, but this occurs prior to the AllocateNode.

  5. The well-formed check for TIR (tvm.tir.analysis.verify_well_formed) does not validate SSA, nor does it validate whether variables/buffers are defined at their point of use (see the probe sketched after this list).
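
As a probe of limitation 5 (a hypothetical snippet; uses_free_var and its body are invented, modeled on the cse_var_3 = T.int32() pattern in the kernel dump above):

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def uses_free_var(A: T.handle("int32", "global")):
        n = T.int32()  # a var with no definition, like cse_var_3 above
        A_1 = T.decl_buffer((n,), "int32", data=A)
        A_1[0] = 0

    # Per limitation 5, this passed at the time of the report despite the
    # free `n`; PR #16250 later tightened the checker for the environment-
    # thread SSA requirements described in the merged commit below.
    tvm.tir.analysis.verify_well_formed(uses_free_var)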

I've put together draft PR #16250 with a fix. Can you test it and verify that it resolves the error observed in LLVM codegen?

jinhongyii (Contributor, Author) commented:

The fix works for me. Thanks @Lunderberg

Lunderberg added a commit to Lunderberg/tvm that referenced this issue Dec 18, 2023
The bug reported in apache#16237 can be resolved by tracking variable usage in a thread extent.
Lunderberg (Contributor) commented:

@jinhongyii Thank you for testing, and I've cleaned up the PR and marked it ready for review.

Lunderberg added a commit that referenced this issue Jan 3, 2024
* [TIR] In SplitHostDevice, check for variables in thread extents

Otherwise, they would be undefined after being de-duplicated by
`ConvertSSA`.

* Revert #16236

The bug reported in #16237 can be resolved by tracking variable usage in a thread extent.

* lint fixes

* Update TIR well-formed checker for env thread SSA requirements

Environment threads must reuse the same `tir::Var` across all
`AttrStmt` instances in a `PrimFunc`, but must not reuse across
separate `PrimFunc`s in an `IRModule`.

* Update ConvertSSA to handle environment threads' SSA requirements

* lint fix

* Updated docstrings for VerifyWellFormed

* Rely on script.Complete for read/writes

Avoids issue in cortexm unit tests resulting from read/write
annotations being present in the root block, followed by application
of BindParams.

* Typo fix

* Added structural equal comparison in unit test
Lunderberg (Contributor) commented:

The long-term fix #16250 has landed, and so I believe this issue can be closed.

junrushao pushed a commit to junrushao/tvm that referenced this issue Jan 7, 2024
[TIR] In SplitHostDevice, check for variables in thread extents (apache#16250)