33
44#include < aie_api/aie.hpp>
55#include < stdint.h>
6+ #include < math.h>
67
78#define SM_VEC_LEN 64 // 32
89#define log2e 1.4453125 // 1.44269504089
@@ -30,7 +31,7 @@ void softmax_simple_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict out
3031 aie::vector<bfloat16, SM_VEC_LEN> in_elems, exp_val, input_bf16, log2e_vec, max_val_vec;
3132 aie::accum<accfloat, SM_VEC_LEN> out_vals, exp_val_accum, scaled_accum, exp_in_accum;
3233
33- float max_val = 0 ;
34+ float max_val = -INFINITY ;
3435 float accum_exp_val = 0 ;
3536 float running_max = 0 ;
3637 bfloat16 col_sum_inv;
@@ -159,6 +160,118 @@ void partial_softmax_alias_bf16(bfloat16 *restrict input_vector,
159160 return ;
160161}
161162
// ---------------------------------------------------------------------------
// Online (partial / tiled) softmax helpers
//
// These three kernels implement a two-pass online softmax that processes a row
// in sub-tile chunks, keeping running max and sum statistics in a small local
// buffer (`stats`). Layout of the stats buffer (bfloat16[16], only [0..1]
// used):
//   stats[0] = running max (scaled by log2e)
//   stats[1] = running sum (of exp2(x*log2e - max))
// ---------------------------------------------------------------------------

174+ void softmax_partial_stats_impl (bfloat16 *restrict input,
175+ bfloat16 *stats,
176+ const int32_t vector_size)
177+ {
178+ event0 ();
179+
180+ const int elem_iters = vector_size / SM_VEC_LEN;
181+
182+ aie::vector<bfloat16, SM_VEC_LEN> input_bf16;
183+ aie::accum<accfloat, SM_VEC_LEN> scaled_accum, exp_in_accum;
184+ aie::accum<accfloat, SM_VEC_LEN> exp_val_accum = aie::zeros<accfloat, SM_VEC_LEN>();
185+
186+ aie::vector<bfloat16, SM_VEC_LEN> log2e_vec =
187+ aie::broadcast<bfloat16, SM_VEC_LEN>((bfloat16)log2e);
188+
189+ // --- Phase 1: find local max (scaled by log2e) -------------------------
190+ float local_max = -INFINITY;
191+ auto it_in1 = aie::cbegin_restrict_vector<SM_VEC_LEN>((bfloat16 *)input);
192+ for (int i = 0 ; i < elem_iters; i++) {
193+ input_bf16 = *it_in1++;
194+ scaled_accum = aie::mul (input_bf16, log2e_vec);
195+ float chunk_max = aie::reduce_max (scaled_accum.to_vector <bfloat16>());
196+ if (chunk_max > local_max) {
197+ local_max = chunk_max;
198+ }
199+ }
200+
201+ // --- Phase 2: update running max, rescale running sum ------------------
202+ float old_max = (float )stats[0 ];
203+ float old_sum = (float )stats[1 ];
204+
205+ if (local_max > old_max) {
206+ // New max is larger — rescale the old sum by exp2(old_max - new_max)
207+ aie::vector<float , SM_VEC_LEN> diff_vec =
208+ aie::broadcast<float , SM_VEC_LEN>(old_max - local_max);
209+ aie::vector<bfloat16, SM_VEC_LEN> corr = aie::exp2<bfloat16>(diff_vec);
210+ old_sum = old_sum * (float )corr[0 ];
211+ old_max = local_max;
212+ }
213+
214+ // --- Phase 3: accumulate exp2(input * log2e - max) for this chunk ------
215+ aie::vector<bfloat16, SM_VEC_LEN> max_val_vec =
216+ aie::broadcast<bfloat16, SM_VEC_LEN>((bfloat16)old_max);
217+
218+ auto it_in2 = aie::cbegin_restrict_vector<SM_VEC_LEN>((bfloat16 *)input);
219+ for (int i = 0 ; i < elem_iters; i++) {
220+ input_bf16 = *it_in2++;
221+ scaled_accum = aie::mul (input_bf16, log2e_vec);
222+ exp_in_accum = aie::sub (scaled_accum, max_val_vec);
223+ aie::vector<bfloat16, SM_VEC_LEN> exp_val =
224+ aie::exp2<bfloat16>(exp_in_accum.to_vector <float >());
225+ exp_val_accum = add (exp_val_accum, exp_val);
226+ }
227+
228+ aie::vector<float , SM_VEC_LEN> reduce = exp_val_accum.to_vector <float >();
229+ float local_sum = aie::reduce_add (reduce);
230+
231+ // --- Phase 4: store updated stats --------------------------------------
232+ stats[0 ] = (bfloat16)old_max;
233+ stats[1 ] = (bfloat16)(old_sum + local_sum);
234+
235+ event1 ();
236+ }
237+
238+ void softmax_partial_norm_impl (bfloat16 *restrict input,
239+ bfloat16 *restrict output,
240+ bfloat16 *stats,
241+ const int32_t vector_size)
242+ {
243+ event0 ();
244+
245+ const int elem_iters = vector_size / SM_VEC_LEN;
246+
247+ float max_val = (float )stats[0 ];
248+ float sum_val = (float )stats[1 ];
249+ bfloat16 inv_sum = (bfloat16)aie::inv (sum_val);
250+
251+ aie::vector<bfloat16, SM_VEC_LEN> log2e_vec =
252+ aie::broadcast<bfloat16, SM_VEC_LEN>((bfloat16)log2e);
253+ aie::vector<bfloat16, SM_VEC_LEN> max_val_vec =
254+ aie::broadcast<bfloat16, SM_VEC_LEN>((bfloat16)max_val);
255+
256+ aie::vector<bfloat16, SM_VEC_LEN> input_bf16;
257+ aie::accum<accfloat, SM_VEC_LEN> scaled_accum, exp_in_accum, out_vals;
258+
259+ auto it_in = aie::cbegin_restrict_vector<SM_VEC_LEN>((bfloat16 *)input);
260+ auto it_out = aie::begin_restrict_vector<SM_VEC_LEN>((bfloat16 *)output);
261+
262+ for (int i = 0 ; i < elem_iters; i++) {
263+ input_bf16 = *it_in++;
264+ scaled_accum = aie::mul (input_bf16, log2e_vec);
265+ exp_in_accum = aie::sub (scaled_accum, max_val_vec);
266+ aie::vector<bfloat16, SM_VEC_LEN> exp_val =
267+ aie::exp2<bfloat16>(exp_in_accum.to_vector <float >());
268+ out_vals = aie::mul (exp_val, inv_sum);
269+ *it_out++ = out_vals.to_vector <bfloat16>();
270+ }
271+
272+ event1 ();
273+ }
274+
162275extern " C" {
163276
164277void softmax_bf16 (bfloat16 *restrict input, bfloat16 *restrict output, const int32_t input_size)
@@ -177,6 +290,27 @@ void partial_softmax_bf16(bfloat16 *restrict input,
177290 partial_softmax_alias_bf16 (input, output, scale_buffer, input_size, row_idx, num_rows, scale);
178291}
179292
293+ void softmax_partial_init_bf16 (bfloat16 *stats)
294+ {
295+ stats[0 ] = (bfloat16)(-INFINITY);
296+ stats[1 ] = (bfloat16)(0 .0f );
297+ }
298+
299+ void softmax_partial_stats_bf16 (bfloat16 *restrict input,
300+ bfloat16 *stats,
301+ const int32_t vector_size)
302+ {
303+ softmax_partial_stats_impl (input, stats, vector_size);
304+ }
305+
306+ void softmax_partial_norm_bf16 (bfloat16 *restrict input,
307+ bfloat16 *restrict output,
308+ bfloat16 *stats,
309+ const int32_t vector_size)
310+ {
311+ softmax_partial_norm_impl (input, output, stats, vector_size);
312+ }
313+
180314void mask_bf16 (bfloat16 *inout, const int32 unmasked_size, const int32 total_size)
181315{
182316 // TODO: Optimize this to use vector code
0 commit comments