From c86fc68273731136478a59d37c636c47a4bcc08f Mon Sep 17 00:00:00 2001
From: Lance Wang
Date: Mon, 19 Aug 2024 21:24:47 +0000
Subject: [PATCH 1/4] Replace the repeat kv for the dense attention and flash
 attention kernels.

---
 jetstream_pt/environment.py | 1 -
 keys_original               | Bin 0 -> 6074 bytes
 original_scores             | Bin 0 -> 3652 bytes
 3 files changed, 1 deletion(-)
 create mode 100644 keys_original
 create mode 100644 original_scores

diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py
index fad4472..de1bdd7 100644
--- a/jetstream_pt/environment.py
+++ b/jetstream_pt/environment.py
@@ -124,7 +124,6 @@ class JetEngineEnvironmentData:
   # The ratio between query heads and kv heads
   n_reps: int = 0
 
-
 # pylint: disable-next=all
 class JetEngineEnvironment:
   # pylint: disable-next=all
diff --git a/keys_original b/keys_original
new file mode 100644
index 0000000000000000000000000000000000000000..515a15d16b7c408bca3aca1b6fb50216f38a0113
Binary files /dev/null and b/keys_original differ
diff --git a/original_scores b/original_scores
new file mode 100644
Binary files /dev/null and b/original_scores differ

From: Lance Wang
Date: Mon, 19 Aug 2024 21:25:36 +0000
Subject: [PATCH 2/4] Remove temp test files.

---
 keys_original   | Bin 6074 -> 0 bytes
 original_scores | Bin 3652 -> 0 bytes
 2 files changed, 0 insertions(+), 0 deletions(-)
 delete mode 100644 keys_original
 delete mode 100644 original_scores

diff --git a/keys_original b/keys_original
deleted file mode 100644
index 515a15d16b7c408bca3aca1b6fb50216f38a0113..0000000000000000000000000000000000000000
Binary files a/keys_original and /dev/null differ
diff --git a/original_scores b/original_scores
deleted file mode 100644
Binary files a/original_scores and /dev/null differ

From: Lance Wang
Date: Mon, 19 Aug 2024 21:59:42 +0000
Subject: [PATCH 3/4] Fix the performance regression with ragged attention on
 for the llama2 7b model.

---
 jetstream_pt/environment.py | 1 +
 jetstream_pt/layers.py      | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py
index de1bdd7..fad4472 100644
--- a/jetstream_pt/environment.py
+++ b/jetstream_pt/environment.py
@@ -124,6 +124,7 @@ class JetEngineEnvironmentData:
   # The ratio between query heads and kv heads
   n_reps: int = 0
 
+
 # pylint: disable-next=all
 class JetEngineEnvironment:
   # pylint: disable-next=all
diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py
index d66909d..1d82a1c 100644
--- a/jetstream_pt/layers.py
+++ b/jetstream_pt/layers.py
@@ -438,7 +438,7 @@ def attend(xq, keys, values, local_mask=None):
             xq, (0, 0, 0, true_len - seqlen), "constant", 0
         )
 
-      if self.env.ragged_mha and seqlen == 1:
+      if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] > 1:
         local_output, (local_max, local_denom) = torch_xla2.interop.call_jax(
             impl,
             xq,
@@ -589,7 +589,7 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
         )
 
       # We are not using ragged attention for prefill yet.
-      if self.env.ragged_mha and seqlen == 1:
+      if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] > 1:
         local_output, (local_max, local_denom) = torch_xla2.interop.call_jax(
             impl,
             xq,

From 1874ad55807d925350dd5baaaaa5c6ff77550afa Mon Sep 17 00:00:00 2001
From: Lance Wang
Date: Tue, 20 Aug 2024 21:59:21 +0000
Subject: [PATCH 4/4] Replace the global attention output calculation with a
 more numerically stable way.

Fix tests. Replace args with kwargs where possible to avoid potential issues.
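
The new merge in the attn_global scope is algebraically the same as the old
one, but it subtracts a shared running maximum before exponentiating, so
torch.exp never sees large logits. The following is a small standalone sketch
of that identity, not code from this patch; the tensor shapes and random
inputs are made up for illustration, and only the variable names mirror the
diff:

    import torch

    torch.manual_seed(0)
    stat_shape = (1, 4, 1, 1)  # hypothetical (batch, heads, seq, 1) running stats
    existing_output = torch.randn(1, 4, 1, 16)
    new_output = torch.randn(1, 4, 1, 16)
    existing_max = torch.randn(stat_shape)
    new_max = torch.randn(stat_shape)
    existing_denom = torch.rand(stat_shape) + 0.1
    new_denom = torch.rand(stat_shape) + 0.1

    # Old formulation: exponentiates the raw running maxima, which can overflow.
    global_sum = existing_denom * torch.exp(existing_max) + new_denom * torch.exp(new_max)
    old = (
        existing_output * existing_denom * torch.exp(existing_max)
        + new_output * new_denom * torch.exp(new_max)
    ) / global_sum

    # New formulation: shift both maxima by their shared maximum first.
    global_max = torch.max(existing_max, new_max)
    alpha = torch.exp(existing_max - global_max)
    beta = torch.exp(new_max - global_max)
    global_denom = alpha * existing_denom + beta * new_denom
    new = (
        existing_denom * alpha * existing_output
        + beta * new_output * new_denom
    ) / global_denom

    print(torch.allclose(old, new))  # True: same value, but the exponentials stay bounded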

---
 jetstream_pt/attention_kernel.py |  59 ++++++++--------
 jetstream_pt/layers.py           | 116 +++++++++++++++++--------------
 tests/test_model_impl.py         |  39 +++++++++--
 3 files changed, 129 insertions(+), 85 deletions(-)

diff --git a/jetstream_pt/attention_kernel.py b/jetstream_pt/attention_kernel.py
index 38edc89..234e9fe 100644
--- a/jetstream_pt/attention_kernel.py
+++ b/jetstream_pt/attention_kernel.py
@@ -558,33 +558,6 @@ def ragged_mha(
   return out, (m, l)
 
 
-def _dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
-  """The vanilla attention kernel implementation."""
-
-  bsz, _, _, head_dim = xq.shape
-  with jax.named_scope("attn_mat1"):
-    ## Attention start
-    # scores = torch.einsum(jnp.einsum, "ijkl,ikml->ikjm", xq, keys) / math.sqrt(self.head_dim)
-    scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim)
-    if k_scaler is not None:
-      scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2]))
-    if mask is not None:
-      # if mask.shape != (1,1,16,16):
-      #   breakpoint()
-      scores = scores + mask  # (bs, n_local_heads, seqlen, max_seqlen)
-  with jax.named_scope("attn_soft"):
-    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
-    if v_scaler is not None:
-      scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2]))
-
-  with jax.named_scope("attn_mat2"):
-    # output = torch.einsum(
-    #    "ikjm,ikml->ikjl", scores, values
-    # )  # (bs, n_local_heads, seqlen, head_dim)
-    output = torch.einsum("ikjm,ikml->ikjl", scores, values)
-  return output
-
-
 def reshape_heads(xq, keys):
   """Reshapes the query head for GQA"""
   bq, hq, tq, dq = xq.shape
@@ -607,6 +580,29 @@ def reshape_outputs(rep, o, m=None, d=None):
   return o, (m, d)
 
 
+def _dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
+  """The vanilla attention kernel implementation."""
+
+  bsz, _, _, head_dim = xq.shape
+  with jax.named_scope("attn_mat1"):
+    ## Attention start
+    scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim)
+    if k_scaler is not None:
+      scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2]))
+    if mask is not None:
+      scores = scores + mask  # (bs, n_local_heads, seqlen, max_seqlen)
+  with jax.named_scope("attn_soft"):
+    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+    if v_scaler is not None:
+      scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2]))
+
+  with jax.named_scope("attn_mat2"):
+    output = torch.einsum(
+        "ikjm,ikml->ikjl", scores, values
+    )  # (bs, n_local_heads, seqlen, head_dim)
+  return output
+
+
 def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
   """The vanilla attention kernel implementation."""
   xq, rep = reshape_heads(xq, keys)
@@ -680,7 +676,14 @@ def flash_attention(
   """Flash attention kernel."""
   xq, rep = reshape_heads(xq, keys)
   o, (logits_max, denominator) = _flash_attention(
-      xq, keys, values, k_scaler, v_scaler, mask
+      xq=xq,
+      keys=keys,
+      values=values,
+      layer=layer,
+      k_scaler=k_scaler,
+      v_scaler=v_scaler,
+      mask=mask,
+      normalize_var=normalize_var,
   )
   return reshape_outputs(rep, o, logits_max, denominator)
 
diff --git a/jetstream_pt/layers.py b/jetstream_pt/layers.py
index 1d82a1c..1ef2fe5 100644
--- a/jetstream_pt/layers.py
+++ b/jetstream_pt/layers.py
@@ -433,7 +433,6 @@ def attend(xq, keys, values, local_mask=None):
       # When GQA is enabled, it not necessary to expand
       if not (self.env.ragged_mha and n_rep > 1) and seqlen == 1:
         true_len = 2
-        # xq = torch.broadcast_to(xq, (xq.shape[0], xq.shape[1], 2, xq.shape[3]))
         xq = torch.nn.functional.pad(
             xq, (0, 0, 0, true_len - seqlen), "constant", 0
         )
@@ -449,15 +448,28 @@ def attend(xq, keys, values, local_mask=None):
             end,
             ragged_batch_index,
             ragged_block_index,
+            None,  # k_scaler
+            None,  # v_scaler
         )
       elif self.env.flash_attention and seqlen == 1:
         with torch_xla2.default_env():
           local_output, (local_max, local_denom) = self.flash_attention(
-              xq, keys, values, self.layer_id, mask=local_mask
+              xq=xq,
+              keys=keys,
+              values=values,
+              layer=self.layer_id,
+              k_scaler=None,
+              v_scaler=None,
+              mask=local_mask,
           )
       else:
         local_output = self.dense_attention(
-            xq, keys, values, None, None, local_mask
+            xq=xq,
+            keys=keys,
+            values=values,
+            k_scaler=None,
+            v_scaler=None,
+            mask=local_mask,
         )
         local_max = None
         local_denom = None
@@ -474,9 +486,6 @@ def attend(xq, keys, values, local_mask=None):
       if local_denom is not None:
         local_denom = local_denom[:, :, 0:seqlen, :]
 
-      # print(f"attention kernel local_output {local_output.shape} seqlen {seqlen}")
-      # if local_max is not None and local_denom is not None:
-      #   print(f"local_max {local_max.shape} local_denom {local_denom.shape}")
       self.env.apply_sharding(local_output, axis=self.q_shard_axis)
       return local_output, (local_max, local_denom)
 
@@ -486,7 +495,7 @@ def attend(xq, keys, values, local_mask=None):
     # print(f"attention kernel xq {xq.shape} seqlen {seqlen} keys {keys.shape} mask {mask.shape}")
     with jax.named_scope("attn_qkv"):
       existing_output, (existing_max, existing_denom) = attend(
-          xq, orig_keys, orig_values, mask
+          xq=xq, keys=orig_keys, values=orig_values, local_mask=mask
       )
     # Updating cache during each step still has very large impact on latency.
     # For non flash attention or prefill, existing output contains everything
@@ -495,23 +504,20 @@ def attend(xq, keys, values, local_mask=None):
     # For flash attention, existing output contains the existing kv cache generated logits
 
     with jax.named_scope("attn_new_qkv"):
-      new_output, (new_max, new_denom) = attend(xq, xk, xv, None)
+      new_output, (new_max, new_denom) = attend(
+          xq=xq, keys=xk, values=xv, local_mask=None
+      )
 
     with jax.named_scope("attn_global"):
-      # print(f"existing_output {existing_output} existing_max {existing_max} existing_denom {existing_denom}")
-      # print(f"new_output {new_output} new_max {new_max} new_denom {new_denom}")
-
-      global_sum = existing_denom * torch.exp(
-          existing_max
-      ) + new_denom * torch.exp(new_max)
-      existing_output = (
-          existing_output
-          * existing_denom
-          * torch.exp(existing_max)
-          / global_sum
-      )
-      new_output = new_output * new_denom * torch.exp(new_max) / global_sum
-      attn_out = existing_output + new_output
+      global_max = torch.max(existing_max, new_max)
+      alpha = torch.exp(existing_max - global_max)
+      beta = torch.exp(new_max - global_max)
+      global_denom = alpha * existing_denom + beta * new_denom
+      # global_denom = torch.where(global_denom == 0.0, 1.0, global_denom)
+      attn_out = (
+          existing_denom * alpha * existing_output
+          + beta * new_output * new_denom
+      ) / global_denom
     return attn_out
 
 
@@ -588,7 +594,6 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
             xq, (0, 0, 0, true_len - seqlen), "constant", 0
         )
 
-      # We are not using ragged attention for prefill yet.
       if self.env.ragged_mha and seqlen == 1 and keys.shape[-2] > 1:
         local_output, (local_max, local_denom) = torch_xla2.interop.call_jax(
             impl,
@@ -606,17 +611,22 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
       elif self.env.flash_attention and seqlen == 1:
         with torch_xla2.default_env():
           local_output, (local_max, local_denom) = self.flash_attention(
-              xq,
-              keys,
-              values,
-              self.layer_id,
-              k_scaler,
-              v_scaler,
+              xq=xq,
+              keys=keys,
+              values=values,
+              layer=self.layer_id,
+              k_scaler=k_scaler,
+              v_scaler=v_scaler,
               mask=local_mask,
           )
       else:
         local_output = self.dense_attention(
-            xq, keys, values, k_scaler, v_scaler, local_mask
+            xq=xq,
+            keys=keys,
+            values=values,
+            k_scaler=k_scaler,
+            v_scaler=v_scaler,
+            mask=local_mask,
         )
         local_max = None
         local_denom = None
@@ -648,7 +658,12 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
     ) = cache.update(xk, xv, self.layer_id)
     with jax.named_scope("attn_qkv"):
       existing_output, (existing_max, existing_denom) = attend(
-          xq, orig_keys, orig_values, k_scaler, v_scaler, mask
+          xq=xq,
+          keys=orig_keys,
+          values=orig_values,
+          k_scaler=k_scaler,
+          v_scaler=v_scaler,
+          local_mask=mask,
       )
 
     # For non flash attention or prefill, existing output contains everything
@@ -663,18 +678,15 @@ def attend(xq, keys, values, k_scaler, v_scaler, local_mask=None):
       )
 
     with jax.named_scope("attn_global"):
-      global_sum = existing_denom * torch.exp(
-          existing_max
-      ) + new_denom * torch.exp(new_max)
-      existing_output = (
-          existing_output
-          * existing_denom
-          * torch.exp(existing_max)
-          / global_sum
-      )
-      new_output = new_output * new_denom * torch.exp(new_max) / global_sum
-      attn_out = existing_output + new_output
-
+      global_max = torch.max(existing_max, new_max)
+      alpha = torch.exp(existing_max - global_max)
+      beta = torch.exp(new_max - global_max)
+      global_denom = alpha * existing_denom + beta * new_denom
+      # global_denom = torch.where(global_denom == 0.0, 1.0, global_denom)
+      attn_out = (
+          existing_denom * alpha * existing_output
+          + beta * new_output * new_denom
+      ) / global_denom
     return attn_out
 
 
@@ -800,16 +812,16 @@ def forward(
     # if cache is not None and cache.cache_k is not None:
    #   print(f"xq {xq.shape} xk {xk.shape} cache shape {cache.cache_k.shape}")
     output = self.attention_kernel(
-        xq,
-        xk,
-        xv,
-        mask,
+        xq=xq,
+        xk=xk,
+        xv=xv,
+        mask=mask,
         # cache[self.layer_id],
-        cache,
-        start,
-        end,
-        ragged_batch_index,
-        ragged_block_index,
+        cache=cache,
+        start=start,
+        end=end,
+        ragged_batch_index=ragged_batch_index,
+        ragged_block_index=ragged_block_index,
     ).type_as(xq)
     # print(f"output {output.shape}")
     output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1)
diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py
index 0b76c86..ff9ff84 100644
--- a/tests/test_model_impl.py
+++ b/tests/test_model_impl.py
@@ -18,6 +18,8 @@
 import torch
 import torch_xla2
 
+from absl.testing import parameterized
+
 from jetstream_pt.third_party.llama import model_exportable
 from jetstream_pt.third_party.llama import model_original
 from jetstream_pt.third_party.gemma import model_original as gemma_orig
@@ -32,7 +34,7 @@
 from . import helpers
 
 
-class ModelComponentTest(unittest.TestCase):
+class ModelComponentTest(parameterized.TestCase):
   """Test diff between original model and xla model for transformer,
   transformer block, attention and other component in model"""
 
@@ -75,7 +77,7 @@ def _generate_mask(self, cache_length, pos, seqlen, ring_buffer=True):
     if ring_buffer:
       cond = jnp.logical_and(x <= pos, x >= pos - seqlen)
     else:
-      # Left aligned buffer we postpone the cache update
+      # Left aligned buffer: we postpone the cache update, therefore mask out pos
       cond = jnp.logical_and(x < pos, x >= pos - seqlen)
     res = jnp.where(cond, 0, float("-inf"))
     return torchjax.to_torch(res)
@@ -98,10 +100,33 @@ def _make_one_cache_for_generate(self, env, pos):
     )
     return cache_decode
 
+  @parameterized.named_parameters(
+      ("ring_buffer", "ring"),
+      ("non_ring_buffer_flash_attention", "flash"),
+      ("non_ring_buffer_ragged_attention", "ragged"),
+  )
   # pylint: disable-next=all
-  def test_attention(self):
+  def test_attention(self, attn_type):
     torch.manual_seed(0)
     env, model_arg = helpers.make_env_tiny(False)
+    if attn_type == "ring":
+      env.lazy_cache_update = False
+      env.ragged_mha = False
+      env.flash_attention = False
+      env.generate_cache_stacked = False
+      env.ring_buffer = True
+    elif attn_type == "ragged":
+      env.lazy_cache_update = True
+      env.ragged_mha = True
+      env.flash_attention = True
+      env.generate_cache_stacked = True
+      env.ring_buffer = False
+    elif attn_type == "flash":
+      env.lazy_cache_update = True
+      env.ragged_mha = False
+      env.flash_attention = True
+      env.generate_cache_stacked = True
+      env.ring_buffer = False
 
     attention_orig = model_original.Attention(model_arg)
     attention_ours = layers.Attention(
@@ -167,10 +192,14 @@ def test_attention(self):
     )
     expected_out = attention_orig(*inputs_orig2)
     cache_decode.input_pos = [pos]  # next position to update
-    mask = self._generate_mask(env.cache_sequence_length, pos, seqlen)
+    mask = self._generate_mask(
+        env.cache_sequence_length, pos, seqlen, env.ring_buffer
+    )
     mask = mask.reshape(1, 1, 1, -1)  # seq dim is the last one
     freqs_cis = freqs_cis.reshape(batch, 1, -1)
-    input_ours2 = (x2, freqs_cis, mask, cache_decode)
+    start = torch.tensor([0] * batch, dtype=torch.int)
+    end = torch.tensor([pos] * batch, dtype=torch.int)
+    input_ours2 = (x2, freqs_cis, mask, cache_decode, start, end)
     result_torch = helpers.call_xla_model(
         attention_ours, attention_orig.state_dict(), input_ours2
     )