 from pathlib import Path
 from src.block.transformer import TransformerBlock
 from operators.rope.rope_utils import compute_rope_params
-from operators import AIERMSNorm
+from operators import (
+    AIERMSNorm,
+    AIEGEMM,
+)
 from rich.console import Console
 from rich.text import Text
 
@@ -169,13 +172,37 @@ def __init__(
             self.cfg["emb_dim"], eps=1e-5, dtype=self.cfg["dtype"]
         )
 
-        # Depedns on use_aie_final_gemm
-        self.out_head = nn.Linear(
-            self.cfg["emb_dim"],
-            self.cfg["vocab_size"],
-            bias=False,
-            dtype=self.cfg["dtype"],
-        )
+        # Offload final linear layer if enabled
+        if self.cfg.get("use_aie_final_gemm", False):
+            # Since this GEMM has such a large N dimension, partition the N dimension by 4,
+            # and GEMM will execute for a workload of that smaller N dimension across different buffers of B and C
+            aie_config_prefill = {
+                "num_aie_columns": 8,
+                "tile_m": 64,
+                "tile_k": 64,
+                "tile_n": 64,
+                "b_col_maj": True,
+                "use_static_weight": True,
+                "separate_c_tiles": True,
+                "partition_N": 4,
+            }
+            if self.cfg["use_kv_cache"]:
+                M_for_gemm = self.prompt_length
+            else:
+                M_for_gemm = self.prompt_length + self.num_tokens
+            self.out_head = AIEGEMM(
+                M=M_for_gemm,
+                K=self.cfg["emb_dim"],
+                N=self.cfg["vocab_size"],
+                **aie_config_prefill,
+            )
+        else:
+            self.out_head = nn.Linear(
+                self.cfg["emb_dim"],
+                self.cfg["vocab_size"],
+                bias=False,
+                dtype=self.cfg["dtype"],
+            )
 
         # Reusable utilities
         cos, sin = compute_rope_params(
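
As a rough, illustrative sketch (not part of the diff), the shapes this branch sets up work out as follows; emb_dim, vocab_size, prompt_length, and num_tokens below are assumed example values, not taken from this PR:

# Hypothetical shape sketch for the AIEGEMM final projection above; the sizes
# here are assumed example values, not taken from this PR.
emb_dim = 2048                                # K
vocab_size = 128256                           # N (large, hence partition_N)
partition_N = 4
n_per_partition = vocab_size // partition_N   # 32064 columns per B/C buffer

prompt_length = 64
num_tokens = 32
use_kv_cache = True

# Mirrors the M selection in __init__: with a KV cache only prefill goes
# through this GEMM, so M is the prompt length; without a cache the GEMM
# also has to cover the generated tokens.
M_for_gemm = prompt_length if use_kv_cache else prompt_length + num_tokens
print(M_for_gemm, emb_dim, n_per_partition)   # 64 2048 32064
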
@@ -194,6 +221,22 @@ def forward(self, in_idx, input_pos=None, use_kv_cache=False):
         tok_embeds = self.tok_emb(in_idx)
         x = tok_embeds
 
+        # Check if input is a vector (decode phase) or matrix (prefill phase)
+        # Handle 1D: (emb_dim,), 2D: (1, emb_dim), or 3D: (1, 1, emb_dim)
+        is_vector = (
+            len(x.shape) == 1
+            or (len(x.shape) == 2 and x.shape[0] == 1)
+            or (len(x.shape) == 3 and x.shape[0] == 1 and x.shape[1] == 1)
+        )
+
+        # (batch, sequence, embedding) where sequence=1 indicates decode
+        if len(x.shape) == 3:
+            is_decode_with_kv = (x.shape[1] == 1) and self.cfg["use_kv_cache"]
+        elif len(x.shape) == 2:
+            is_decode_with_kv = (x.shape[0] == 1) and self.cfg["use_kv_cache"]
+        else:
+            is_decode_with_kv = False
+
         num_tokens = x.shape[1]
 
         # During generation phase with KV cache, don't create a mask
@@ -219,19 +262,39 @@ def forward(self, in_idx, input_pos=None, use_kv_cache=False):
         else:
             x = self.final_norm(x)
 
-        logits = self.out_head(x.to(self.cfg["dtype"]))
+        if is_decode_with_kv and self.cfg["use_aie_gemv"]:
+            # TODO: Offload to NPU
+            # logits = self.aie_out_head_gemv(x)
+            logits = self.out_head(x)
+        else:
+            logits = self.out_head(x)
 
         return logits
 
-    def assign_weights(self, final_norm):
+    def assign_weights(self, final_norm, out_head, out_head_name):
         if self.cfg.get("use_aie_final_norm", False):
             self.aie_final_norm_prefill.weight = final_norm
             if self.cfg["use_kv_cache"]:
                 self.aie_final_norm_decode.weight = final_norm
-            return
+        else:
+            self.final_norm.weight = assign(
+                self.final_norm.weight,
+                final_norm,
+                f"model.norm.weight",
+            )
 
-        self.final_norm.weight = assign(
-            self.final_norm.weight,
-            final_norm,
-            f"model.norm.weight",
-        )
+        # TODO: Offload GEMV to NPU
+        # if self.cfg["use_kv_cache"] and self.cfg["use_aie_gemv"]:
+        #     self.aie_out_head_gemv.weight = out_head
+        if self.cfg["use_aie_final_gemm"]:
+            # Want column-major for B
+            self.out_head.weight = out_head.T
+            # TODO: Create separate linear layers for prefill and decode (with gemm/gemv)
+            # if self.cfg["use_kv_cache"]:
+            #     self.out_head.weight = out_head.T
+        else:
+            self.out_head.weight = assign(
+                self.out_head.weight,
+                out_head,
+                out_head_name,
+            )
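
For reference, a hypothetical caller-side sketch of the widened assign_weights signature; the params dict and the "lm_head.weight" checkpoint key are assumptions for illustration, not taken from this PR:

# Hypothetical loader-side call; the checkpoint key names are assumptions.
model.assign_weights(
    final_norm=params["model.norm.weight"],
    out_head=params["lm_head.weight"],
    out_head_name="lm_head.weight",
)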