diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index cbef6235c098..da6a4ef83452 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -463,6 +463,8 @@ def apply(
                     TS, TR = 4, 64
                 else:
                     TS, TR = 16, 32
+            else:
+                TS, TR = 1, 64
         elif target.kind.name == "metal":
             # Note that the following tile size is tuned on M2 Ultra for 7B
             TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
@@ -476,6 +478,8 @@ def apply(
                     TS, TR = 4, 16
                 else:
                     TS, TR = 2, 64
+            else:
+                TS, TR = 1, 64
         elif target.kind.name == "rocm":
             VEC_C = 4
             # TODO: set LOAD_V_SHARED = False for now
@@ -489,13 +493,15 @@ def apply(
                     TS, TR = 1, 128
                 else:
                     TS, TR = 8, 64
+            else:
+                TS, TR = 1, 64
         elif target.kind.name == "opencl" and "android" in str(target.host):
             TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
             VEC_C = 8
             LOAD_V_SHARED = False
             LOAD_V_VEC = -1
             UNROLL = 8
-            TS, TR = 2, 64
+            TS, TR = 2, 32
         elif target.kind.name == "vulkan":
             VEC_C = 4
             LOAD_V_SHARED = True
@@ -506,6 +512,8 @@ def apply(
                     TS, TR = 4, 32
                 else:
                     TS, TR = 16, 32
+            else:
+                TS, TR = 1, 64
         elif target.kind.name == "opencl" and "mali" in str(target.attrs):
             VEC_C = 8
             LOAD_V_SHARED = False
@@ -519,9 +527,6 @@ def apply(
             UNROLL = 64
             TS, TR = 1, 64
 
-        if not isinstance(len_S, int):
-            TS, TR = 1, 64
-
         while TS * TR > target.max_num_threads:
             if TS > 1:
                 TS //= 2
@@ -709,7 +714,11 @@ def apply(
         if not isinstance(len_r, int):
             return None
 
-        if isinstance(len_s, int) and len_s > 32000:
+        if not isinstance(len_s, int):
+            TS, TR = 256, 1
+            LOAD_V_SHARED = True
+
+        if isinstance(len_s, int) and len_s > 96000:
             return None
 
         _, TILE_R = (
@@ -754,7 +763,8 @@ def sch_outer_reduction_fallback(  # pylint: disable=too-many-arguments, invalid
         len_s = get_extent(sch, s)
 
         # The config is designed for Adreno
-        tx_len = 64
+        LOAD_V_SHARED = 1
+        tx_len = 128
         vec_len = (4 if len_s > 4096 else 2) if isinstance(len_s, int) else 1
         inner_r = 4
 
@@ -768,16 +778,23 @@ def sch_outer_reduction_fallback(  # pylint: disable=too-many-arguments, invalid
         sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=8)
         sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1)
 
-        cache_v = sch.cache_read(block, vector_input_buffers[0], "local")
-        sch.compute_at(cache_v, r1, preserve_unit_loops=True)
-        sch.vectorize(sch.get_loops(cache_v)[-1])
+        if LOAD_V_SHARED:
+            V_shared = sch.cache_read(block, vector_input_buffers[0], storage_scope="shared")
+            sch.compute_at(V_shared, bx, preserve_unit_loops=True)
+            l = sch.get_loops(block=V_shared)[-1]
+            _, tx, vec_r = sch.split(l, factors=[None, tx_len, 8], preserve_unit_iters=True)
+            sch.bind(tx, "threadIdx.x")
+            sch.vectorize(vec_r)
         sch.vectorize(vec)
 
         # Schedule epilogue
         if epilogue_info is not None:
-            sch.reverse_compute_at(epilogue_info.block_rv, tx)
-
+            sch.reverse_compute_at(epilogue_info.block_rv, bx, preserve_unit_loops=True)
+            ts_tile_s = sch.get_loops(epilogue_info.block_rv)[-1]
+            ts, vec = sch.split(ts_tile_s, factors=[tx_len, vec_len], preserve_unit_iters=True)
+            sch.bind(ts, "threadIdx.x")
+            sch.vectorize(vec)
         sch.set_scope(block, 0, "local")
 
         sch.decompose_reduction(block, r0)
 
diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py
index 4aae617654d2..0f7b6f45ae3f 100644
--- a/tests/python/dlight/test_gpu_gemv.py
+++ b/tests/python/dlight/test_gpu_gemv.py
@@ -1106,82 +1106,95 @@ def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1),
         p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), v))
         # with T.block("root"):
         var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), v), "float16", scope="local")
-        var_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), v), "float16", scope="local")
-        var_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(4), T.int64(1), T.int64(1), v), "float16", scope="local")
+        var_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(8), T.int64(1), T.int64(1), v), "float16", scope="local")
+        var_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1), v), "float16", scope="local")
         lv613_local = T.alloc_buffer((T.int64(128), v), "float16", scope="local")
         lv612_local = T.alloc_buffer((T.int64(512), v), "uint32", scope="local")
-        for u_fused_ax0_fused_fused_0 in T.thread_binding((v + T.int64(63)) // T.int64(64), thread="blockIdx.x"):
-            for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
-                for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init in T.thread_binding(T.int64(4), thread="threadIdx.y"):
+        lv1607_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="shared")
+        for u_fused_ax0_fused_fused_0 in T.thread_binding((v + T.int64(255)) // T.int64(256), thread="blockIdx.x"):
+            for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init in T.thread_binding(T.int64(1), thread="threadIdx.y"):
                     for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init in T.vectorized(T.int64(8)):
                         with T.block("matmul_rf_init"):
-                            vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(32), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init)
-                            v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1)
-                            T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v)
+                            vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(8), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init)
+                            v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1)
+                            T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 < v)
                             T.reads()
                             T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0])
                             var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] = T.float16(0)
-                for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
-                    for ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1 in T.grid(T.int64(32), T.int64(1)):
-                        for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
-                            with T.block("lv613_local"):
-                                v0 = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 + ax0)
-                                v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 + ax1)
-                                T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v)
-                                T.reads(lv613[v0, v1])
-                                T.writes(lv613_local[v0, v1])
-                                lv613_local[v0, v1] = lv613[v0, v1]
-                        for ax1_0_fused_ax1_1_fused_3 in range(T.int64(4)):
+                for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 in T.thread_binding(T.int64(1), thread="threadIdx.y"):
+                    for ax1_0_fused_ax1_1_fused_0 in range(T.int64(128)):
+                        for ax0, ax1, ax2_0, ax2_1 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)):
+                            for ax2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                                for ax2_3 in T.thread_binding(T.int64(1), thread="threadIdx.y"):
+                                    for ax2_4 in T.vectorized(T.int64(4)):
+                                        with T.block("lv1607_shared"):
+                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                                            v2 = T.axis.spatial(T.int64(4096), ax1_0_fused_ax1_1_fused_0 * T.int64(32) + (ax2_0 * T.int64(1024) + ax2_1 * T.int64(1024) + ax2_2 * T.int64(4) + ax2_3 * T.int64(4) + ax2_4))
+                                            T.where(((ax2_0 + ax2_1) * T.int64(256) + ax2_2 + ax2_3) * T.int64(4) + ax2_4 < T.int64(32))
+                                            T.reads(lv1607[v0, v1, v2])
+                                            T.writes(lv1607_shared[v0, v1, v2])
+                                            lv1607_shared[v0, v1, v2] = lv1607[v0, v1, v2]
+                        for ax1_0_fused_ax1_1_fused_1 in range(T.int64(1)):
                             for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
-                                with T.block("lv612_local"):
-                                    v0 = T.axis.spatial(T.int64(512), ax1_0_fused_ax1_1_fused_0 * T.int64(16) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_3 + ax0)
-                                    v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 + ax1)
-                                    T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v)
-                                    T.reads(lv612[v0, v1])
-                                    T.writes(lv612_local[v0, v1])
-                                    lv612_local[v0, v1] = lv612[v0, v1]
-                            for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 in T.vectorized(T.int64(8)):
-                                with T.block("matmul_rf_update"):
-                                    vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(32), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1)
-                                    v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1)
-                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_1, vax1_0_fused_ax1_1_fused_3 = T.axis.remap("RRR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1, ax1_0_fused_ax1_1_fused_3])
-                                    T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v)
-                                    T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0], lv1607[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(128) + vax1_0_fused_ax1_1_fused_1 * T.int64(128) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % T.int64(8)], lv612_local[vax1_0_fused_ax1_1_fused_0 * T.int64(16) + vax1_0_fused_ax1_1_fused_1 * T.int64(16) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(4) + vax1_0_fused_ax1_1_fused_3, v0], lv613_local[vax1_0_fused_ax1_1_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1 * T.int64(4) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) + vax1_0_fused_ax1_1_fused_3 // T.int64(4), v0])
-                                    T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0])
-                                    var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] = var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] + lv1607[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(128) + vax1_0_fused_ax1_1_fused_1 * T.int64(128) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % T.int64(8)] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv612_local[vax1_0_fused_ax1_1_fused_0 * T.int64(16) + vax1_0_fused_ax1_1_fused_1 * T.int64(16) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(4) + vax1_0_fused_ax1_1_fused_3, v0], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * T.int64(128) + vax1_0_fused_ax1_1_fused_1 * T.int64(128) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % T.int64(8)) % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv613_local[vax1_0_fused_ax1_1_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1 * T.int64(4) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) + vax1_0_fused_ax1_1_fused_3 // T.int64(4), v0])
-            for ax2 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
-                for ax0 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
+                                with T.block("lv613_local"):
+                                    v0 = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_0 + ax0)
+                                    v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 + ax1)
+                                    T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 < v)
+                                    T.reads(lv613[v0, v1])
+                                    T.writes(lv613_local[v0, v1])
+                                    lv613_local[v0, v1] = lv613[v0, v1]
+                            for ax1_0_fused_ax1_1_fused_3 in range(T.int64(4)):
+                                for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
+                                    with T.block("lv612_local"):
+                                        v0 = T.axis.spatial(T.int64(512), ax1_0_fused_ax1_1_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_3 + ax0)
+                                        v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 + ax1)
+                                        T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 < v)
+                                        T.reads(lv612[v0, v1])
+                                        T.writes(lv612_local[v0, v1])
+                                        lv612_local[v0, v1] = lv612[v0, v1]
+                                for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 in T.vectorized(T.int64(8)):
+                                    with T.block("matmul_rf_update"):
+                                        vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(8), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1)
+                                        v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1)
+                                        vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_1, vax1_0_fused_ax1_1_fused_3 = T.axis.remap("RRR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1, ax1_0_fused_ax1_1_fused_3])
+                                        T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 < v)
+                                        T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0], lv1607_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1 * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused], lv612_local[vax1_0_fused_ax1_1_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1 * T.int64(4) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) + vax1_0_fused_ax1_1_fused_3, v0], lv613_local[(vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused) // T.int64(32) + vax1_0_fused_ax1_1_fused_0 + vax1_0_fused_ax1_1_fused_1, v0])
+                                        T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0])
+                                        var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] = var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] + lv1607_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1 * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv612_local[vax1_0_fused_ax1_1_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1 * T.int64(4) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) + vax1_0_fused_ax1_1_fused_3, v0], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1 * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused) % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv613_local[(vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused) // T.int64(32) + vax1_0_fused_ax1_1_fused_0 + vax1_0_fused_ax1_1_fused_1, v0])
+            for ax2 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                for ax0 in T.thread_binding(T.int64(1), thread="threadIdx.y"):
                     with T.block("matmul_rf_init"):
-                        vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.spatial(T.int64(4), ax0)
-                        v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax2)
-                        T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + ax2 < v)
+                        vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.spatial(T.int64(1), ax0)
+                        v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + ax2)
+                        T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + ax2 < v)
                         T.reads()
                         T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0])
                         var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0)
                     for ax1 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}):
                         with T.block("matmul_rf_update"):
                             vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 = T.axis.remap("SR", [ax0, ax1])
-                            v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax2)
-                            T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + ax2 < v)
+                            v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + ax2)
+                            T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + ax2 < v)
                             T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0], var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, T.int64(0), T.int64(0), v0])
                             T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0])
                             var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] = var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, T.int64(0), T.int64(0), v0]
-            for ax1 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
-                for ax0 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
+            for ax1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+                for ax0 in T.thread_binding(T.int64(1), thread="threadIdx.y"):
                     with T.block("matmul"):
-                        vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.reduce(T.int64(4), ax0)
-                        v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax1)
-                        T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + ax1 < v)
+                        vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.reduce(T.int64(1), ax0)
+                        v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + ax1)
+                        T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + ax1 < v)
                         T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0])
                         T.writes(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0])
                         with T.init():
                             var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = T.float16(0)
                         var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0]
-            for ax0_fused_0 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
+            for ax0_fused_0 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                 for ax0_fused_1 in range(T.int64(1)):
                     with T.block("compute"):
-                        v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax0_fused_0 + ax0_fused_1)
-                        T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + (ax0_fused_0 + ax0_fused_1) < v)
+                        v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + ax0_fused_0 + ax0_fused_1)
+                        T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + (ax0_fused_0 + ax0_fused_1) < v)
                         T.reads(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0])
                         T.writes(p_output0_intermediate[T.int64(0), T.int64(0), v0])
                         p_output0_intermediate[T.int64(0), T.int64(0), v0] = T.Cast("float32", var_matmul_intermediate_local[T.int64(0), T.int64(0), v0])