Skip to content

Commit

Permalink
[DLIGHT][GPU] Improved gemv outer fallback schedule (#16973)
Browse files Browse the repository at this point in the history
* [DLIGHT][GPU] Improved gemv outer fallback schedule

Improved the gemv outer fallback schedules. This improves the
performance of a few gemv kernels by 20%.

* Fix lint error

* Fix the gemv schedule params for dynamic vocab_size kernel
  • Loading branch information
krishnaraj36 committed May 21, 2024
1 parent 18a2a25 commit 209971a
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 61 deletions.
39 changes: 28 additions & 11 deletions python/tvm/dlight/gpu/gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,8 @@ def apply(
TS, TR = 4, 64
else:
TS, TR = 16, 32
else:
TS, TR = 1, 64
elif target.kind.name == "metal":
# Note that the following tile size is tuned on M2 Ultra for 7B
TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
Expand All @@ -476,6 +478,8 @@ def apply(
TS, TR = 4, 16
else:
TS, TR = 2, 64
else:
TS, TR = 1, 64
elif target.kind.name == "rocm":
VEC_C = 4
# TODO: set LOAD_V_SHARED = False for now
Expand All @@ -489,13 +493,15 @@ def apply(
TS, TR = 1, 128
else:
TS, TR = 8, 64
else:
TS, TR = 1, 64
elif target.kind.name == "opencl" and "android" in str(target.host):
TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
VEC_C = 8
LOAD_V_SHARED = False
LOAD_V_VEC = -1
UNROLL = 8
TS, TR = 2, 64
TS, TR = 2, 32
elif target.kind.name == "vulkan":
VEC_C = 4
LOAD_V_SHARED = True
Expand All @@ -506,6 +512,8 @@ def apply(
TS, TR = 4, 32
else:
TS, TR = 16, 32
else:
TS, TR = 1, 64
elif target.kind.name == "opencl" and "mali" in str(target.attrs):
VEC_C = 8
LOAD_V_SHARED = False
Expand All @@ -519,9 +527,6 @@ def apply(
UNROLL = 64
TS, TR = 1, 64

if not isinstance(len_S, int):
TS, TR = 1, 64

while TS * TR > target.max_num_threads:
if TS > 1:
TS //= 2
Expand Down Expand Up @@ -709,7 +714,11 @@ def apply(
if not isinstance(len_r, int):
return None

if isinstance(len_s, int) and len_s > 32000:
if not isinstance(len_s, int):
TS, TR = 256, 1
LOAD_V_SHARED = True

if isinstance(len_s, int) and len_s > 96000:
return None

_, TILE_R = (
Expand Down Expand Up @@ -754,7 +763,8 @@ def sch_outer_reduction_fallback( # pylint: disable=too-many-arguments, invalid
len_s = get_extent(sch, s)

# The config is designed for Adreno
tx_len = 64
LOAD_V_SHARED = 1
tx_len = 128
vec_len = (4 if len_s > 4096 else 2) if isinstance(len_s, int) else 1
inner_r = 4

Expand All @@ -768,16 +778,23 @@ def sch_outer_reduction_fallback( # pylint: disable=too-many-arguments, invalid
sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=8)
sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1)

cache_v = sch.cache_read(block, vector_input_buffers[0], "local")
sch.compute_at(cache_v, r1, preserve_unit_loops=True)
sch.vectorize(sch.get_loops(cache_v)[-1])
if LOAD_V_SHARED:
V_shared = sch.cache_read(block, vector_input_buffers[0], storage_scope="shared")
sch.compute_at(V_shared, bx, preserve_unit_loops=True)
l = sch.get_loops(block=V_shared)[-1]
_, tx, vec_r = sch.split(l, factors=[None, tx_len, 8], preserve_unit_iters=True)
sch.bind(tx, "threadIdx.x")
sch.vectorize(vec_r)

sch.vectorize(vec)

# Schedule epilogue
if epilogue_info is not None:
sch.reverse_compute_at(epilogue_info.block_rv, tx)

sch.reverse_compute_at(epilogue_info.block_rv, bx, preserve_unit_loops=True)
ts_tile_s = sch.get_loops(epilogue_info.block_rv)[-1]
ts, vec = sch.split(ts_tile_s, factors=[tx_len, vec_len], preserve_unit_iters=True)
sch.bind(ts, "threadIdx.x")
sch.vectorize(vec)
sch.set_scope(block, 0, "local")

sch.decompose_reduction(block, r0)
Expand Down

0 comments on commit 209971a

Please sign in to comment.