From ee7c6d2173a27240499422fdade2e2126cbfec0f Mon Sep 17 00:00:00 2001 From: Bohan Hou Date: Thu, 26 Oct 2023 00:10:36 -0400 Subject: [PATCH 1/2] metal perf --- python/tvm/dlight/gpu/gemv.py | 19 +++++++++++++++---- python/tvm/dlight/gpu/utils.py | 2 ++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index 3544719af054..9c7bab6e2f4d 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -418,11 +418,22 @@ def apply( else: TS, TR = 16, 32 elif target.kind.name == "metal": - VEC_C = 2 - LOAD_V_SHARED = True - LOAD_V_VEC = 4 + # VEC_C = 2 + # LOAD_V_SHARED = True + # LOAD_V_VEC = 4 + # UNROLL = 256 + # TS, TR = 64, 8 + # Note that the following tile size is tuned on M2 Ultra for 7B + TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" + VEC_C = 4 + LOAD_V_SHARED = False + LOAD_V_VEC = -1 UNROLL = 256 - TS, TR = 64, 8 + if isinstance(len_S, int): + if len_S > len_R: + TS, TR = 1, 64 + else: + TS, TR = 1, 256 elif target.kind.name == "rocm": VEC_C = 4 LOAD_V_SHARED = True diff --git a/python/tvm/dlight/gpu/utils.py b/python/tvm/dlight/gpu/utils.py index 9f9a9c5ae48d..00d97ab7f12a 100644 --- a/python/tvm/dlight/gpu/utils.py +++ b/python/tvm/dlight/gpu/utils.py @@ -53,6 +53,8 @@ def suggest_threads_per_block( threads = 256 elif target.kind.name == "rocm": threads = 256 + elif target.kind.name == "metal": + threads = 256 else: threads = 64 results: List[Optional[int]] = [] From 7105e653a0384c0c6cee8884862d86e1a2245653 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 1 Nov 2023 13:20:14 -0700 Subject: [PATCH 2/2] Fix unittest --- python/tvm/dlight/gpu/gemv.py | 5 -- tests/python/dlight/test_gpu_gemv.py | 70 +++++++++++----------------- 2 files changed, 28 insertions(+), 47 deletions(-) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index 9c7bab6e2f4d..76839d41662d 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -418,11 +418,6 @@ def apply( else: TS, TR = 16, 32 elif target.kind.name == "metal": - # VEC_C = 2 - # LOAD_V_SHARED = True - # LOAD_V_VEC = 4 - # UNROLL = 256 - # TS, TR = 64, 8 # Note that the following tile size is tuned on M2 Ultra for 7B TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" VEC_C = 4 diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py index 7f60d5db329d..83d2c3c06cb1 100644 --- a/tests/python/dlight/test_gpu_gemv.py +++ b/tests/python/dlight/test_gpu_gemv.py @@ -209,78 +209,64 @@ def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128) def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): - var_NT_matmul_intermediate_rf_local = T.alloc_buffer((16, 1, 1, 22016), "float16", scope="local") - var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((8, 1, 1, 22016), "float16", scope="local") + var_NT_matmul_intermediate_rf_local = T.alloc_buffer((256, 1, 1, 22016), "float16", scope="local") + var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((64, 1, 1, 22016), "float16", scope="local") lv571_local = T.alloc_buffer((22016, 512), "uint32", scope="local") - lv1654_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared") - for u_fused_ax0_fused_fused_0 in T.thread_binding(688, thread="blockIdx.x"): - for u_fused_ax0_fused_fused_1 in T.thread_binding(32, thread="threadIdx.y"): - for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(8, thread="threadIdx.x"): - for ax0, ax1 in T.grid(1, 1): - for ax2_0 in T.serial(4, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): - for ax2_1 in T.thread_binding(32, thread="threadIdx.y"): - for ax2_2 in T.thread_binding(8, thread="threadIdx.x"): - for ax2_3 in T.vectorized(4): - with T.block("lv1654_shared"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - v2 = T.axis.spatial(4096, ax2_0 * 1024 + ax2_1 * 32 + ax2_2 * 4 + ax2_3) - T.reads(lv1654[v0, v1, v2]) - T.writes(lv1654_shared[v0, v1, v2]) - lv1654_shared[v0, v1, v2] = lv1654[v0, v1, v2] + for u_fused_ax0_fused_fused_0 in T.thread_binding(22016, thread="blockIdx.x"): + for u_fused_ax0_fused_fused_1 in T.thread_binding(1, thread="threadIdx.x"): + for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(64, thread="threadIdx.y"): for u_fused_ax0_fused_fused_2_init in range(1): - for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(2): + for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(4): with T.block("NT_matmul_rf_init"): - vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init) - v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(256, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init) + v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) T.reads() T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0]) var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0) - for ax1_0_fused_ax1_1_fused_0 in T.serial(64, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax1_0_fused_ax1_1_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax0_0, ax1 in T.grid(1, 1): for ax0_1 in T.vectorized(1): with T.block("lv571_local"): - v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1) - v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 8 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1) + v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 + ax0_0 + ax0_1) + v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 64 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1) T.reads(lv571[v0, v1]) T.writes(lv571_local[v0, v1]) lv571_local[v0, v1] = lv571[v0, v1] - for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 4): - for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(2): + for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 2): + for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(4): with T.block("NT_matmul_rf_update"): - vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(16, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1) - v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 32 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(256, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1) + v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2]) - T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv1654_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2], lv571_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], lv572[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32]) + T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv1654[0, 0, vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4], lv571_local[v0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], lv572[v0, (vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32]) T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0]) - var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv1654_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv571_local[v0, vax1_0_fused_ax1_1_fused_0 * 8 + vax1_0_fused_ax1_1_fused_2 // 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v0, (vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 2 * 8 + vax1_0_fused_ax1_1_fused_2 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 2) // 32]) - for ax2_fused_0 in T.thread_binding(32, thread="threadIdx.y"): - for ax0 in T.thread_binding(8, thread="threadIdx.x"): + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv1654[0, 0, vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv571_local[v0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v0, (vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32]) + for ax2_fused_0 in T.thread_binding(1, thread="threadIdx.x"): + for ax0 in T.thread_binding(64, thread="threadIdx.y"): for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax2_fused_1_1 in T.vectorized(1): with T.block("NT_matmul_rf_init"): - vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(8, ax0) - v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 32 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, v0 = T.axis.remap("SS", [ax0, u_fused_ax0_fused_fused_0]) T.reads() T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0) - for ax1 in range(2): + for ax1 in range(4): with T.block("NT_matmul_rf_update"): - vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) - v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 32 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) - T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, v0 = T.axis.remap("SRS", [ax0, ax1, u_fused_ax0_fused_fused_0]) + T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]) T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) - var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 2 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0] for ax1_fused_1 in range(1): - for ax1_fused_0 in T.thread_binding(32, thread="threadIdx.y"): - for ax0 in T.thread_binding(8, thread="threadIdx.x"): + for ax1_fused_0 in T.thread_binding(1, thread="threadIdx.x"): + for ax0 in T.thread_binding(64, thread="threadIdx.y"): with T.block("NT_matmul"): - vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(8, ax0) - v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 32 + ax1_fused_0 + ax1_fused_1) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, v0 = T.axis.remap("RS", [ax0, u_fused_ax0_fused_fused_0]) T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) T.writes(var_NT_matmul_intermediate[0, 0, v0]) with T.init(): var_NT_matmul_intermediate[0, 0, v0] = T.float16(0) var_NT_matmul_intermediate[0, 0, v0] = var_NT_matmul_intermediate[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + # fmt: on mod = tvm.IRModule({"main": before})