1313from src .operator .aie_gemm import AIEGEMM
1414from src .operator .aie_gemv import AIEGEMV
1515from src .operator .aie_silu import AIESiLU
16+ from src .operator .aie_swiglu_prefill import AIESwiGLUPrefill
17+ from src .operator .aie_swiglu_decode import AIESwiGLUDecode
18+ from ml_dtypes import bfloat16
1619
1720
1821class FeedForward (nn .Module ):
@@ -25,6 +28,16 @@ def __init__(
2528 super ().__init__ ()
2629 self .cfg = cfg .copy ()
2730
31+ assert (
32+ cfg ["use_aie_ffn_swiglu" ]
33+ and not (
34+ cfg ["use_aie_ffn_silu" ]
35+ or cfg ["use_aie_ffn_gemm" ]
36+ or cfg ["use_aie_ffn_mul" ]
37+ )
38+ or not cfg ["use_aie_ffn_swiglu" ]
39+ ), "Cannot mix fused SwiGLU with individual AIE operators."
40+
2841 self .emb_dim = cfg ["emb_dim" ]
2942 self .hidden_dim = cfg ["hidden_dim" ]
3043
@@ -36,10 +49,17 @@ def __init__(
3649 else :
3750 self .silu = nn .SiLU ()
3851
39- self .emb_dim = cfg ["emb_dim" ]
40- self .hidden_dim = cfg ["hidden_dim" ]
52+ if self .cfg ["use_aie_ffn_swiglu" ]:
53+ self .aie_swiglu_prefill = AIESwiGLUPrefill (
54+ seq_len = prompt_length ,
55+ embedding_dim = self .emb_dim ,
56+ hidden_dim = self .hidden_dim ,
57+ )
58+ if self .cfg ["use_kv_cache" ]:
59+ self .aie_swiglu_decode = AIESwiGLUDecode (
60+ embedding_dim = self .emb_dim , hidden_dim = self .hidden_dim
61+ )
4162
42- # Initialize FFN up and down projections
4363 if self .cfg ["use_aie_ffn_gemm" ]:
4464 if self .cfg ["use_kv_cache" ]:
4565 M_prefill = prompt_length
@@ -108,8 +128,15 @@ def forward(self, x):
108128 or (len (x .shape ) == 3 and x .shape [0 ] == 1 and x .shape [1 ] == 1 )
109129 )
110130
131+ is_prefill = not is_vector or not self .cfg ["use_kv_cache" ]
111132 is_decode_with_kv = is_vector and self .cfg ["use_kv_cache" ]
112133
134+ if self .cfg ["use_aie_ffn_swiglu" ]:
135+ if is_prefill :
136+ return self .aie_swiglu_prefill (x )
137+ else :
138+ return self .aie_swiglu_decode (x )
139+
113140 if is_decode_with_kv and self .cfg ["use_aie_gemv" ]:
114141 x_fc1 = self .aie_fc1_gemv (x )
115142 x_fc2 = self .aie_fc2_gemv (x )
@@ -131,6 +158,21 @@ def forward(self, x):
131158 return self .fc3 (x ).view (original_shape )
132159
133160 def assign_weights (self , l , fc1 , fc2 , fc3 ):
161+ if self .cfg ["use_kv_cache" ] and self .cfg ["use_aie_gemv" ]:
162+ self .aie_fc1_gemv .weight = fc1
163+ self .aie_fc2_gemv .weight = fc2
164+ self .aie_fc3_gemv .weight = fc3
165+
166+ if self .cfg ["use_aie_ffn_swiglu" ]:
167+ self .aie_swiglu_prefill .weights_1 = fc1
168+ self .aie_swiglu_prefill .weights_2 = fc2
169+ self .aie_swiglu_prefill .weights_3 = fc3
170+ if self .cfg ["use_kv_cache" ]:
171+ self .aie_swiglu_decode .weights_1 = fc1
172+ self .aie_swiglu_decode .weights_2 = fc2
173+ self .aie_swiglu_decode .weights_3 = fc3
174+ return
175+
134176 self .fc1 .weight = assign (
135177 self .fc1 .weight ,
136178 fc1 ,
@@ -146,8 +188,3 @@ def assign_weights(self, l, fc1, fc2, fc3):
146188 fc3 ,
147189 f"model.layers.{ l } .mlp.down_proj.weight" ,
148190 )
149-
150- if self .cfg ["use_kv_cache" ] and self .cfg ["use_aie_gemv" ]:
151- self .aie_fc1_gemv .weight = fc1
152- self .aie_fc2_gemv .weight = fc2
153- self .aie_fc3_gemv .weight = fc3
0 commit comments