From de117e37a1378e5cd4743337e8d4dae672a9efe0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 24 Oct 2023 17:15:08 +0000 Subject: [PATCH] [Transform][Redo] Apply split_rotary optimization on prefill Prior to this commit, the `transform.fuse_split_rotary_embedding` function was only applicable to the `decode` function of a Llama-type model. This was due to the sequence length being restricted to one, both in the pattern-match rule and in the `split_rotary` function, and the function being restricted to operate only on the `decode` function. This commit updates the `transform.fuse_split_rotary_embedding` pass to be a `tvm.ir.transform.Pass`, operating on all applicable functions matched in the `IRModule`. The `split_rotary` function is now produced as a fully-generic function, with static parameters substituted in afterwards. At this stage, the sequence length is retained as a dynamic parameter, such that it can be used by the `prefill` function. This commit reapplies the reverted commit https://github.com/mlc-ai/mlc-llm/pull/1033. The error in the previous implementation was in the definition of `rotary_embedding_offset`, which provided the `query_sequence_length` instead of `kv_sequence_length`. This was able to pass the validity tests described [here](https://github.com/mlc-ai/mlc-llm/pull/1058#issuecomment-1761622534), as these two sequence lengths are identical for the first call. 
--- mlc_llm/core.py | 11 +- .../transform/fuse_split_rotary_embedding.py | 446 ++++++++++-------- 2 files changed, 247 insertions(+), 210 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 8c3a75c374..5550a87fcd 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -440,12 +440,11 @@ def mod_transform_before_build( if max_seq_len: num_key_value_heads = config.get_num_key_value_heads() mod = fuse_split_rotary_embedding( - mod, - config.num_attention_heads // args.num_shards, - num_key_value_heads // args.num_shards, - config.hidden_size // args.num_shards, - config.position_embedding_base, - ) + config.num_attention_heads // args.num_shards, + num_key_value_heads // args.num_shards, + config.hidden_size // args.num_shards, + config.position_embedding_base, + )(mod) if args.target_kind == "cuda": patterns = [] diff --git a/mlc_llm/transform/fuse_split_rotary_embedding.py b/mlc_llm/transform/fuse_split_rotary_embedding.py index a7dbdf6c31..ed19a7095c 100644 --- a/mlc_llm/transform/fuse_split_rotary_embedding.py +++ b/mlc_llm/transform/fuse_split_rotary_embedding.py @@ -1,5 +1,5 @@ +import tvm from tvm import relax -from tvm.script import tir as T from tvm.relax.dpl import ( PatternContext, is_op, @@ -10,237 +10,275 @@ TuplePattern, is_shape, ) -from tvm.script import relax as R +from tvm.script import relax as R, tir as T -def get_split_rotary(num_attention_heads, head_dim, position_embedding_base): - hidden_size = num_attention_heads * head_dim +def get_dynamic_split_rotary(): + """Implementation of R.split(rotary_embedding(fused_qkv)) - @T.prim_func + Implementation is generic over the number of query heads, + key/value heads, sequence length, head dimension, and position + embedding base. These parameters can be replaced with static + values using `PrimFunc.specialize`. 
+ """ + + @T.prim_func(private=True) def split_rotary( - qkv: T.handle, - split_0: T.handle, - split_1: T.handle, - split_2: T.handle, - n: T.int64, + fused_qkv_handle: T.handle, + embedded_query_handle: T.handle, + embedded_key_handle: T.handle, + value_handle: T.handle, + rotary_offset: T.int64, + batch_size: T.int64, + seq_len: T.int64, + num_query_heads: T.int64, + num_kv_heads: T.int64, + head_dim: T.int64, + position_embedding_base: T.float32, ): - A = T.match_buffer(qkv, [1, 1, hidden_size * 3], dtype="float16") - T_split = T.match_buffer(split_0, [1, 1, hidden_size], dtype="float16") - T_split_1 = T.match_buffer(split_1, [1, 1, hidden_size], dtype="float16") - T_split_2 = T.match_buffer(split_2, [1, 1, hidden_size], dtype="float16") + Fused_QKV = T.match_buffer( + fused_qkv_handle, + [batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim], + dtype="float16", + ) + EmbeddedQuery = T.match_buffer( + embedded_query_handle, + [batch_size, seq_len, num_query_heads, head_dim], + dtype="float16", + ) + EmbeddedKey = T.match_buffer( + embedded_key_handle, + [batch_size, seq_len, num_kv_heads, head_dim], + dtype="float16", + ) + Value = T.match_buffer( + value_handle, + [batch_size, seq_len, num_kv_heads, head_dim], + dtype="float16", + ) T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(hidden_size)): - with T.block("T_split"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - A[v_ax0, v_ax1, v_ax2], - A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size)], - A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size * 2)], - ) - T.writes( - T_split[v_ax0, v_ax1, v_ax2], - T_split_1[v_ax0, v_ax1, v_ax2], - T_split_2[v_ax0, v_ax1, v_ax2], - ) - pos: T.float32 = T.Cast("float32", n - T.int64(1)) + + for iters in T.grid(batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim): + with T.block("FusedRotaryEmbeddingAndSplitQKV"): + batch_i, seq_i, head_num, head_i = 
T.axis.remap("SSSS", iters) + pos: T.float32 = T.Cast("float32", rotary_offset + seq_i - seq_len) + inv_freq: T.float32 = T.float32(1) / T.pow( - T.float32(position_embedding_base), - T.Cast("float32", (v_ax2 * 2) % head_dim) / T.float32(head_dim), + position_embedding_base, + T.Cast("float32", (head_i * 2) % head_dim) / T.float32(head_dim), ) freq: T.float32 = pos * inv_freq cos_value: T.float16 = T.Cast("float16", T.cos(freq)) sin_value: T.float16 = T.Cast("float16", T.sin(freq)) - T_split[v_ax0, v_ax1, v_ax2] = cos_value * A[ - v_ax0, v_ax1, v_ax2 - ] + sin_value * T.Select( - T.int64(head_dim // 2) <= v_ax2 % head_dim, - A[v_ax0, v_ax1, v_ax2 - T.int64(head_dim // 2)], - A[v_ax0, v_ax1, v_ax2 + T.int64(head_dim // 2)] * T.float16(-1), - ) - T_split_1[v_ax0, v_ax1, v_ax2] = cos_value * A[ - v_ax0, v_ax1, v_ax2 + T.int64(hidden_size) - ] + sin_value * T.Select( - T.int64(head_dim // 2) <= v_ax2 % head_dim, - A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size) - T.int64(head_dim // 2)], - A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size) + T.int64(head_dim // 2)] + + input_value = Fused_QKV[batch_i, seq_i, head_num, head_i] + embedded_value = cos_value * input_value + sin_value * T.Select( + head_i < T.int64(head_dim // 2), + Fused_QKV[batch_i, seq_i, head_num, head_i + T.int64(head_dim // 2)] * T.float16(-1), + Fused_QKV[batch_i, seq_i, head_num, head_i - T.int64(head_dim // 2)], ) - T_split_2[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size * 2)] + if head_num < num_query_heads: + EmbeddedQuery[batch_i, seq_i, head_num, head_i] = embedded_value + elif head_num < num_query_heads + num_kv_heads: + EmbeddedKey[batch_i, seq_i, head_num - num_query_heads, head_i] = embedded_value + else: + Value[ + batch_i, seq_i, head_num - num_query_heads - num_kv_heads, head_i + ] = input_value + + param_sinfo = [] + for param in split_rotary.params: + if param in split_rotary.buffer_map: + buf = split_rotary.buffer_map[param] + sinfo = 
relax.TensorStructInfo(shape=buf.shape, dtype=buf.dtype) + else: + sinfo = relax.PrimStructInfo(param.dtype) + param_sinfo.append(sinfo) + + relax.expr._update_struct_info( + split_rotary, + tvm.relax.FuncStructInfo( + params=param_sinfo, + ret=relax.TupleStructInfo([]), + purity=False, + ), + ) return split_rotary -def get_split_rotary_group_query_attention( - num_query_heads, num_kv_heads, head_dim, position_embedding_base +def fuse_split_rotary_embedding( + num_query_heads, num_kv_heads, hidden_size, position_embedding_base ): - query_hidden_size = num_query_heads * head_dim - kv_hidden_size = num_kv_heads * head_dim - total_size = query_hidden_size + kv_hidden_size * 2 + @tvm.ir.transform.module_pass(opt_level=0, name="fuse_split_rotary_embedding") + def ir_module_pass(mod: tvm.IRModule, _pass_context) -> tvm.IRModule: + head_dim = hidden_size // num_query_heads + split_rotary = get_dynamic_split_rotary() - @T.prim_func - def split_rotary( - qkv: T.handle, - split_0: T.handle, - split_1: T.handle, - split_2: T.handle, - n: T.int64, - ): - A = T.match_buffer(qkv, [1, 1, total_size], dtype="float16") - T_split = T.match_buffer(split_0, [1, 1, query_hidden_size], dtype="float16") - T_split_1 = T.match_buffer(split_1, [1, 1, kv_hidden_size], dtype="float16") - T_split_2 = T.match_buffer(split_2, [1, 1, kv_hidden_size], dtype="float16") + ( + dyn_batch_size, + dyn_seq_len, + dyn_num_query_heads, + dyn_num_kv_heads, + dyn_head_dim, + dyn_position_embedding_base, + ) = split_rotary.params[-6:] - T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(query_hidden_size)): - with T.block("T_split"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - A[v_ax0, v_ax1, v_ax2], - ) - T.writes(T_split[v_ax0, v_ax1, v_ax2]) - pos: T.float32 = T.Cast("float32", n - T.int64(1)) - inv_freq: T.float32 = T.float32(1) / T.pow( - T.float32(position_embedding_base), - T.Cast("float32", (v_ax2 * 2) 
% head_dim) / T.float32(head_dim), - ) - freq: T.float32 = pos * inv_freq - cos_value: T.float16 = T.Cast("float16", T.cos(freq)) - sin_value: T.float16 = T.Cast("float16", T.sin(freq)) - T_split[v_ax0, v_ax1, v_ax2] = cos_value * A[ - v_ax0, v_ax1, v_ax2 - ] + sin_value * T.Select( - T.int64(head_dim // 2) <= v_ax2 % head_dim, - A[v_ax0, v_ax1, v_ax2 - T.int64(head_dim // 2)], - A[v_ax0, v_ax1, v_ax2 + T.int64(head_dim // 2)] * T.float16(-1), - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(kv_hidden_size)): - with T.block("T_split"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size)], - A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size + kv_hidden_size)], - ) - T.writes( - T_split_1[v_ax0, v_ax1, v_ax2], - T_split_2[v_ax0, v_ax1, v_ax2], - ) - pos: T.float32 = T.Cast("float32", n - T.int64(1)) - inv_freq: T.float32 = T.float32(1) / T.pow( - T.float32(position_embedding_base), - T.Cast("float32", (v_ax2 * 2) % head_dim) / T.float32(head_dim), - ) - freq: T.float32 = pos * inv_freq - cos_value: T.float16 = T.Cast("float16", T.cos(freq)) - sin_value: T.float16 = T.Cast("float16", T.sin(freq)) - T_split_1[v_ax0, v_ax1, v_ax2] = cos_value * A[ - v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size) - ] + sin_value * T.Select( - T.int64(head_dim // 2) <= v_ax2 % head_dim, - A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size) - T.int64(head_dim // 2)], - A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size) + T.int64(head_dim // 2)] - * T.float16(-1), - ) - T_split_2[v_ax0, v_ax1, v_ax2] = A[ - v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size + kv_hidden_size) - ] + split_rotary = split_rotary.specialize( + { + # Static model parameters + dyn_batch_size: T.int64(1), + dyn_num_query_heads: T.int64(num_query_heads), + dyn_num_kv_heads: T.int64(num_kv_heads), + dyn_head_dim: T.int64(head_dim), + dyn_position_embedding_base: T.float32(position_embedding_base), + # Dynamic parameters, to be 
inferred from TIR Buffer shapes + dyn_seq_len: tvm.tir.Var("query_sequence_length", "int64"), + } + ) - return split_rotary + mod["split_rotary"] = split_rotary + split_rotary_gvar = mod.get_global_var("split_rotary") + relax.expr._update_struct_info(split_rotary_gvar, mod["split_rotary"].struct_info) -def fuse_split_rotary_embedding( - mod, num_query_heads, num_kv_heads, hidden_size, position_embedding_base -): - if "rotary_embedding1" not in [gv.name_hint for gv in mod.functions]: - return mod - - head_dim = hidden_size // num_query_heads - mod["split_rotary"] = ( - get_split_rotary(num_query_heads, head_dim, position_embedding_base) - if num_query_heads == num_kv_heads - else get_split_rotary_group_query_attention( - num_query_heads, num_kv_heads, head_dim, position_embedding_base - ) - ) + with PatternContext() as ctx: + # flat_qkv_tuple: R.Tuple( + # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), + # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), + # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), + # ) = R.split(flat_fused_qkv, indices_or_sections=[4096, 8192], axis=2) + # + # flat_query: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[0] + # query: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( + # flat_query, R.shape([batch_size, seq_len, 32, 128]) + # ) + # flat_key: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[1] + # key: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( + # flat_key, R.shape([batch_size, seq_len, 32, 128]) + # ) + # flat_value: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[2] + # value: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( + # flat_value, R.shape([batch_size, seq_len, 32, 128]) + # ) + # embedded_query = R.call_tir( + # cls.rotary_embedding1, + # [query], + # out_sinfo=R.Tensor((batch_size, seq_len, 32, 128), dtype="float16"), + # tir_vars=R.shape([n]), + # ) 
+ # embedded_key = R.call_tir( + # cls.rotary_embedding1, + # [key], + # out_sinfo=R.Tensor((batch_size, seq_len, 32, 128), dtype="float16"), + # tir_vars=R.shape([n]), + # ) - gvar = mod.get_global_var("split_rotary") - relax.expr._update_struct_info(gvar, mod.get_global_var("rotary_embedding1").struct_info) + pat_rotary_embedding_gvar = GlobalVarPattern() - with PatternContext() as ctx: - # lv3: R.Tuple(R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")) = R.split(lv2, indices_or_sections=[4096, 8192], axis=2) + pat_flat_fused_qkv = wildcard() + pat_offset = wildcard() - # lv1521: R.Tensor((1, 1, 4096), dtype="float16") = lv3[0] - # lv1522: R.Tensor((1, 1, 32, 128), dtype="float16") = R.reshape(lv1521, R.shape([1, 1, 32, 128])) - # lv1524: R.Tensor((1, 1, 4096), dtype="float16") = lv3[1] - # lv1525: R.Tensor((1, 1, 32, 128), dtype="float16") = R.reshape(lv1524, R.shape([1, 1, 32, 128])) - # lv1527: R.Tensor((1, 1, 4096), dtype="float16") = lv3[2] - # lv1528: R.Tensor((1, 1, 32, 128), dtype="float16") = R.reshape(lv1527, R.shape([1, 1, 32, 128])) - # lv1530 = R.call_tir(cls.rotary_embedding1, (lv1525, cos_cached1, sin_cached1), out_sinfo=R.Tensor((1, 1, 32, 128), dtype="float16"), tir_vars=R.shape([n])) - # lv_1 = R.call_tir(cls.rotary_embedding1, (lv1522, cos_cached1, sin_cached1), out_sinfo=R.Tensor((1, 1, 32, 128), dtype="float16"), tir_vars=R.shape( + # query_shape = is_shape([1, seq_len, num_query_heads, head_dim]) + pat_query_shape = wildcard() + # value_shape = is_shape([1, seq_len, num_kv_heads, head_dim]) + pat_key_shape = wildcard() + # value_shape = is_shape([1, seq_len, num_kv_heads, head_dim]) + pat_value_shape = wildcard() - inp_pat = wildcard() - offset = wildcard() + pat_flat_qkv_tuple = is_op("relax.split")(pat_flat_fused_qkv) + pat_flat_query = is_tuple_get_item(pat_flat_qkv_tuple, 0) + pat_query = is_op("relax.reshape")( + pat_flat_query, pat_query_shape, add_constraint=False 
+ ) + pat_flat_query.used_by(pat_query) + pat_flat_key = is_tuple_get_item(pat_flat_qkv_tuple, 1) + pat_key = is_op("relax.reshape")(pat_flat_key, pat_key_shape, add_constraint=False) + pat_flat_key.used_by(pat_key) + pat_flat_value = is_tuple_get_item(pat_flat_qkv_tuple, 2) + pat_value = is_op("relax.reshape")( + pat_flat_value, pat_value_shape, add_constraint=False + ) + pat_flat_value.used_by(pat_value) - lv3 = is_op("relax.split")(inp_pat) - lv1521 = is_tuple_get_item(lv3, 0) - lv1522 = is_op("relax.reshape")( - lv1521, is_shape([1, 1, num_query_heads, head_dim]), add_constraint=False - ) - lv1521.used_by(lv1522) - lv1524 = is_tuple_get_item(lv3, 1) - lv1525 = is_op("relax.reshape")( - lv1524, is_shape([1, 1, num_kv_heads, head_dim]), add_constraint=False - ) - lv1524.used_by(lv1525) - lv1527 = is_tuple_get_item(lv3, 2) - V = is_op("relax.reshape")( - lv1527, is_shape([1, 1, num_kv_heads, head_dim]), add_constraint=False - ) - lv1527.used_by(V) + pat_embedded_query = is_op("relax.call_tir")( + pat_rotary_embedding_gvar, + TuplePattern([pat_query]), + pat_offset, + add_constraint=False, + ) + pat_embedded_key = is_op("relax.call_tir")( + pat_rotary_embedding_gvar, + TuplePattern([pat_key]), + pat_offset, + add_constraint=False, + ) - Q = is_op("relax.call_tir")( - GlobalVarPattern(), TuplePattern([lv1522]), offset, add_constraint=False - ) - K = is_op("relax.call_tir")( - GlobalVarPattern(), TuplePattern([lv1525]), offset, add_constraint=False - ) + pat_flat_qkv_tuple.used_by(pat_flat_query) + pat_flat_qkv_tuple.used_by(pat_flat_key) + pat_flat_qkv_tuple.used_by(pat_flat_value) + pat_query.used_by(pat_embedded_query) + pat_key.used_by(pat_embedded_key) - lv3.used_by(lv1521) - lv3.used_by(lv1524) - lv3.used_by(lv1527) - lv1522.used_by(Q) - lv1525.used_by(K) - - def rewriter(matchings, bindings): - inp = matchings[inp_pat] - call_tir = matchings[Q] - n = bindings[call_tir].args[-1] - out_sinfo = [ - R.Tensor((1, 1, num_query_heads * head_dim), dtype="float16"), - 
R.Tensor((1, 1, num_kv_heads * head_dim), dtype="float16"), - R.Tensor((1, 1, num_kv_heads * head_dim), dtype="float16"), - ] - lv3_new = R.call_tir( - mod.get_global_var("split_rotary"), (inp,), out_sinfo=out_sinfo, tir_vars=n - ) - lv1521_new = lv3_new[0] - lv1522_new = R.reshape(lv1521_new, R.shape([1, 1, num_query_heads, head_dim])) - lv1524_new = lv3_new[1] - lv1525_new = R.reshape(lv1524_new, R.shape([1, 1, num_kv_heads, head_dim])) - lv1527_new = lv3_new[2] - lv1528_new = R.reshape(lv1527_new, R.shape([1, 1, num_kv_heads, head_dim])) - - return { - matchings[lv3]: lv3_new, - matchings[lv1521]: lv1521_new, - matchings[lv1522]: lv1522_new, - matchings[lv1524]: lv1524_new, - matchings[lv1525]: lv1525_new, - matchings[lv1527]: lv1527_new, - matchings[V]: lv1528_new, - matchings[Q]: lv1522_new, - matchings[K]: lv1525_new, - } - - mod["decode"] = rewrite_bindings(ctx, rewriter, mod["decode"]) - return mod + def rewriter(matchings, bindings): + # Extracting all the relax and TIR variables that we'll need + flat_fused_qkv = matchings[pat_flat_fused_qkv] + flat_qkv_tuple = matchings[pat_flat_qkv_tuple] + + flat_query = matchings[pat_flat_query] + flat_key = matchings[pat_flat_key] + flat_value = matchings[pat_flat_value] + + query = matchings[pat_query] + key = matchings[pat_key] + value = matchings[pat_value] + + embedded_query = matchings[pat_embedded_query] + embedded_key = matchings[pat_embedded_key] + + # rotary_embedding_offset = bindings[query].args[-1][1] + rotary_embedding_offset = bindings[embedded_query].args[-1][0] + + batch_size, seq_len, num_query_heads, head_dim = query.struct_info.shape + _batch_size, _seq_len, num_kv_heads, _head_dim = key.struct_info.shape + + # Rewriting along the new path + + fused_qkv = relax.op.reshape( + flat_fused_qkv, [batch_size, seq_len, num_query_heads + 2 * num_kv_heads, head_dim] + ) + + split_rotary_sinfo = [ + R.Tensor((batch_size, seq_len, num_query_heads, head_dim), dtype="float16"), + R.Tensor((batch_size, seq_len, 
num_kv_heads, head_dim), dtype="float16"), + R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), + ] + qkv_tuple_new = R.call_tir( + split_rotary_gvar, + (fused_qkv,), + out_sinfo=split_rotary_sinfo, + tir_vars=[rotary_embedding_offset], + ) + + embedded_query_new = qkv_tuple_new[0] + embedded_key_new = qkv_tuple_new[1] + value_new = qkv_tuple_new[2] + + return { + value: value_new, + embedded_query: embedded_query_new, + embedded_key: embedded_key_new, + } + + new_mod = {} + for gvar, func in mod.functions.items(): + if isinstance(func, relax.Function): + func = rewrite_bindings(ctx, rewriter, func) + new_mod[gvar] = func + + new_mod = tvm.IRModule(new_mod, mod.type_definitions, mod.attrs, mod.global_infos) + return new_mod + + return ir_module_pass