Skip to content

Commit ea275b5

Browse files
authored
Merge 4375996 into 6b7326b
2 parents 6b7326b + 4375996 commit ea275b5

17 files changed

Lines changed: 2496 additions & 108 deletions

File tree

aie_kernels/aie2/softmax.cc

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "lut_based_ops.h"
55

66
#include <aie_api/aie.hpp>
7+
#include <math.h>
78
#include <stdint.h>
89

910
using namespace aie;
@@ -57,13 +58,132 @@ void softmax_simple_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict out
5758
return;
5859
}
5960

61+
// ---------------------------------------------------------------------------
// Online (partial / tiled) softmax helpers
//
// These kernels implement an online softmax that processes a row in sub-tile
// chunks: repeated calls to the stats kernel accumulate running max/sum, and
// a final normalization kernel produces the output. Layout of the stats
// buffer (bfloat16[16], only [0..1] used):
//   stats[0] = running max
//   stats[1] = running sum (of exp(x - max))
// ---------------------------------------------------------------------------

// Accumulate online-softmax statistics for one tile of `vector_size`
// bfloat16 elements. Reads the running max/sum from `stats`, folds this
// tile in, and writes the updated values back. `vector_size` is assumed to
// be a multiple of 16 (the vector lane count used here).
void softmax_partial_stats_impl(bfloat16 *restrict input,
                                bfloat16 *stats,
                                const int32_t vector_size)
{
  event0(); // profiling event marker (start of kernel region)

  // Number of 16-lane vector chunks in this tile.
  const int elem_iters = vector_size / 16;

  // Running statistics carried across tiles via the stats buffer.
  float running_max = (float)stats[0];
  float running_sum = (float)stats[1];

  aie::vector<bfloat16, 16> input_bf16;
  aie::accum<accfloat, 16> exp_val_accum = aie::zeros<accfloat, 16>();

  auto it_in = aie::cbegin_vector<16>((bfloat16 *)input);

  // Single-pass online algorithm: for each vector chunk, check if max
  // needs updating, rescale the running sum if so, then accumulate
  // exp(x - max).
  for (int i = 0; i < elem_iters; i++) {
    input_bf16 = *it_in++;
    float chunk_max = aie::reduce_max(input_bf16);

    if (chunk_max > running_max) {
      // Rescale accumulated exp values by exp(old_max - new_max).
      // getExpBf16 is the LUT-based exp from lut_based_ops.h; we broadcast
      // the (negative) delta and take lane 0 as the scalar scale factor.
      // On the first chunk running_max may be -inf, so the delta is -inf
      // and the scale is 0 — correct, since nothing has been accumulated.
      aie::vector<bfloat16, 16> correction =
          to_v16bfloat16(getExpBf16(
              aie::broadcast<bfloat16, 16>((bfloat16)(running_max - chunk_max))));
      float scale = (float)correction[0];
      // Rescale the partial vector accumulator.
      // NOTE(review): the accumulator is round-tripped through bfloat16
      // here, which loses accumulator precision — presumably acceptable
      // for this kernel; confirm against accuracy requirements.
      aie::vector<bfloat16, 16> scale_vec =
          aie::broadcast<bfloat16, 16>((bfloat16)scale);
      exp_val_accum = aie::mul(exp_val_accum.to_vector<bfloat16>(), scale_vec);
      // Rescale the running scalar sum from previous chunks.
      running_sum *= scale;
      running_max = chunk_max;
    }

    // Accumulate exp(x - running_max) for this chunk.
    aie::vector<bfloat16, 16> shifted = aie::sub(
        input_bf16, aie::broadcast<bfloat16, 16>((bfloat16)running_max));
    aie::vector<bfloat16, 16> exp_val = to_v16bfloat16(getExpBf16(shifted));
    exp_val_accum = add(exp_val_accum, exp_val);
  }

  // Reduce the vector accumulator and add to running sum.
  aie::vector<float, 16> reduce = exp_val_accum.to_vector<float>();
  running_sum += aie::reduce_add(reduce);

  // Persist updated statistics for the next tile / the norm pass.
  stats[0] = (bfloat16)running_max;
  stats[1] = (bfloat16)running_sum;

  event1(); // profiling event marker (end of kernel region)
}
125+
126+
// Final normalization pass of the online softmax: given the finished stats
// (stats[0] = row max, stats[1] = row sum of exp(x - max)), write
// output[i] = exp(input[i] - max) / sum for one tile of `vector_size`
// elements. `vector_size` is assumed to be a multiple of 16.
void softmax_partial_norm_impl(bfloat16 *restrict input,
                               bfloat16 *restrict output,
                               bfloat16 *stats,
                               const int32_t vector_size)
{
  event0(); // profiling event marker (start of kernel region)

  const int elem_iters = vector_size / 16;

  float max_val = (float)stats[0];
  float sum_val = (float)stats[1];
  // Precompute 1/sum once; the loop then multiplies instead of dividing.
  bfloat16 inv_sum = (bfloat16)aie::inv(sum_val);

  aie::vector<bfloat16, 16> max_val_vec =
      aie::broadcast<bfloat16, 16>((bfloat16)max_val);

  aie::vector<bfloat16, 16> input_bf16;
  aie::accum<accfloat, 16> out_vals;

  auto it_in = aie::cbegin_restrict_vector<16>((bfloat16 *)input);
  auto it_out = aie::begin_restrict_vector<16>((bfloat16 *)output);

  for (int i = 0; i < elem_iters; i++) {
    input_bf16 = *it_in++;
    // exp(x - max) via the LUT-based exp, then scale by 1/sum.
    aie::vector<bfloat16, 16> shifted = aie::sub(input_bf16, max_val_vec);
    aie::vector<bfloat16, 16> exp_val = to_v16bfloat16(getExpBf16(shifted));
    out_vals = aie::mul(exp_val, inv_sum);
    *it_out++ = out_vals.to_vector<bfloat16>();
  }

  event1(); // profiling event marker (end of kernel region)
}
158+
60159
extern "C" {
61160

62161
/// C-ABI entry point: dense softmax over a full row of `input_size`
/// bfloat16 elements, delegating to the single-shot vectorized kernel.
void softmax_bf16(bfloat16 *restrict input, bfloat16 *restrict output,
                  const int32_t input_size) {
  softmax_simple_bf16(input, output, input_size);
}
66165

166+
/// Reset the online-softmax stats buffer before the first stats pass:
/// running max starts at -inf (identity for max) and running sum at 0
/// (identity for addition).
void softmax_partial_init_bf16(bfloat16 *stats) {
  stats[0] = (bfloat16)(-INFINITY); // running max
  stats[1] = (bfloat16)(0.0f);      // running sum
}
171+
172+
/// C-ABI wrapper for the online-softmax statistics pass; see
/// softmax_partial_stats_impl for the algorithm and stats layout.
void softmax_partial_stats_bf16(bfloat16 *restrict input, bfloat16 *stats,
                                const int32_t vector_size) {
  softmax_partial_stats_impl(input, stats, vector_size);
}
178+
179+
/// C-ABI wrapper for the online-softmax normalization pass; see
/// softmax_partial_norm_impl for details.
void softmax_partial_norm_bf16(bfloat16 *restrict input,
                               bfloat16 *restrict output, bfloat16 *stats,
                               const int32_t vector_size) {
  softmax_partial_norm_impl(input, output, stats, vector_size);
}
186+
67187
void mask_bf16(bfloat16 *inout, const int32_t unmasked_size, const int32_t total_size)
68188
{
69189
for (int32_t i = unmasked_size; i < total_size; i++) {

aie_kernels/aie2p/softmax.cc

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <aie_api/aie.hpp>
55
#include <stdint.h>
6+
#include <math.h>
67

78
#define SM_VEC_LEN 64 // 32
89
#define log2e 1.4453125 // 1.44269504089
@@ -30,7 +31,7 @@ void softmax_simple_bf16(bfloat16 *restrict input_vector, bfloat16 *restrict out
3031
aie::vector<bfloat16, SM_VEC_LEN> in_elems, exp_val, input_bf16, log2e_vec, max_val_vec;
3132
aie::accum<accfloat, SM_VEC_LEN> out_vals, exp_val_accum, scaled_accum, exp_in_accum;
3233

33-
float max_val = 0;
34+
float max_val = -INFINITY;
3435
float accum_exp_val = 0;
3536
float running_max = 0;
3637
bfloat16 col_sum_inv;
@@ -159,6 +160,118 @@ void partial_softmax_alias_bf16(bfloat16 *restrict input_vector,
159160
return;
160161
}
161162

163+
// ---------------------------------------------------------------------------
// Online (partial / tiled) softmax helpers (aie2p variant)
//
// These kernels implement an online softmax that processes a row in sub-tile
// chunks, keeping running max and sum statistics in a small local buffer
// (`stats`). This variant computes exp via exp2, pre-scaling inputs by
// log2e. Layout of the stats buffer (bfloat16[16], only [0..1] used):
//   stats[0] = running max (scaled by log2e)
//   stats[1] = running sum (of exp2(x*log2e - max))
// ---------------------------------------------------------------------------

// Accumulate online-softmax statistics for one tile of `vector_size`
// bfloat16 elements. `vector_size` is assumed to be a multiple of
// SM_VEC_LEN. Unlike the aie2 variant, this proceeds in distinct phases:
// local max, running-stat merge, exp accumulation, store.
void softmax_partial_stats_impl(bfloat16 *restrict input,
                                bfloat16 *stats,
                                const int32_t vector_size)
{
  event0(); // profiling event marker (start of kernel region)

  const int elem_iters = vector_size / SM_VEC_LEN;

  aie::vector<bfloat16, SM_VEC_LEN> input_bf16;
  aie::accum<accfloat, SM_VEC_LEN> scaled_accum, exp_in_accum;
  aie::accum<accfloat, SM_VEC_LEN> exp_val_accum = aie::zeros<accfloat, SM_VEC_LEN>();

  // log2e is the file-level constant; multiplying by it turns exp(x) into
  // exp2(x * log2e).
  aie::vector<bfloat16, SM_VEC_LEN> log2e_vec =
      aie::broadcast<bfloat16, SM_VEC_LEN>((bfloat16)log2e);

  // --- Phase 1: find local max (scaled by log2e) -------------------------
  float local_max = -INFINITY;
  auto it_in1 = aie::cbegin_restrict_vector<SM_VEC_LEN>((bfloat16 *)input);
  for (int i = 0; i < elem_iters; i++) {
    input_bf16 = *it_in1++;
    scaled_accum = aie::mul(input_bf16, log2e_vec);
    // NOTE(review): the reduce is done on the bfloat16 view of the
    // accumulator, so the max carries bfloat16 precision.
    float chunk_max = aie::reduce_max(scaled_accum.to_vector<bfloat16>());
    if (chunk_max > local_max) {
      local_max = chunk_max;
    }
  }

  // --- Phase 2: update running max, rescale running sum ------------------
  float old_max = (float)stats[0];
  float old_sum = (float)stats[1];

  if (local_max > old_max) {
    // New max is larger — rescale the old sum by exp2(old_max - new_max).
    // The delta is broadcast through a vector exp2 and lane 0 is read back
    // as the scalar correction factor.
    aie::vector<float, SM_VEC_LEN> diff_vec =
        aie::broadcast<float, SM_VEC_LEN>(old_max - local_max);
    aie::vector<bfloat16, SM_VEC_LEN> corr = aie::exp2<bfloat16>(diff_vec);
    old_sum = old_sum * (float)corr[0];
    old_max = local_max;
  }

  // --- Phase 3: accumulate exp2(input * log2e - max) for this chunk ------
  aie::vector<bfloat16, SM_VEC_LEN> max_val_vec =
      aie::broadcast<bfloat16, SM_VEC_LEN>((bfloat16)old_max);

  // Second pass over the same input tile.
  auto it_in2 = aie::cbegin_restrict_vector<SM_VEC_LEN>((bfloat16 *)input);
  for (int i = 0; i < elem_iters; i++) {
    input_bf16 = *it_in2++;
    scaled_accum = aie::mul(input_bf16, log2e_vec);
    exp_in_accum = aie::sub(scaled_accum, max_val_vec);
    aie::vector<bfloat16, SM_VEC_LEN> exp_val =
        aie::exp2<bfloat16>(exp_in_accum.to_vector<float>());
    exp_val_accum = add(exp_val_accum, exp_val);
  }

  aie::vector<float, SM_VEC_LEN> reduce = exp_val_accum.to_vector<float>();
  float local_sum = aie::reduce_add(reduce);

  // --- Phase 4: store updated stats --------------------------------------
  stats[0] = (bfloat16)old_max;
  stats[1] = (bfloat16)(old_sum + local_sum);

  event1(); // profiling event marker (end of kernel region)
}
237+
238+
// Final normalization pass of the online softmax (aie2p variant): given the
// finished stats (stats[0] = log2e-scaled row max, stats[1] = row sum of
// exp2(x*log2e - max)), write output[i] = exp2(input[i]*log2e - max) / sum
// for one tile of `vector_size` elements. `vector_size` is assumed to be a
// multiple of SM_VEC_LEN.
void softmax_partial_norm_impl(bfloat16 *restrict input,
                               bfloat16 *restrict output,
                               bfloat16 *stats,
                               const int32_t vector_size)
{
  event0(); // profiling event marker (start of kernel region)

  const int elem_iters = vector_size / SM_VEC_LEN;

  float max_val = (float)stats[0];
  float sum_val = (float)stats[1];
  // Precompute 1/sum once; the loop then multiplies instead of dividing.
  bfloat16 inv_sum = (bfloat16)aie::inv(sum_val);

  aie::vector<bfloat16, SM_VEC_LEN> log2e_vec =
      aie::broadcast<bfloat16, SM_VEC_LEN>((bfloat16)log2e);
  aie::vector<bfloat16, SM_VEC_LEN> max_val_vec =
      aie::broadcast<bfloat16, SM_VEC_LEN>((bfloat16)max_val);

  aie::vector<bfloat16, SM_VEC_LEN> input_bf16;
  aie::accum<accfloat, SM_VEC_LEN> scaled_accum, exp_in_accum, out_vals;

  auto it_in = aie::cbegin_restrict_vector<SM_VEC_LEN>((bfloat16 *)input);
  auto it_out = aie::begin_restrict_vector<SM_VEC_LEN>((bfloat16 *)output);

  for (int i = 0; i < elem_iters; i++) {
    input_bf16 = *it_in++;
    // exp(x - max) computed as exp2(x*log2e - scaled_max), then scaled
    // by 1/sum.
    scaled_accum = aie::mul(input_bf16, log2e_vec);
    exp_in_accum = aie::sub(scaled_accum, max_val_vec);
    aie::vector<bfloat16, SM_VEC_LEN> exp_val =
        aie::exp2<bfloat16>(exp_in_accum.to_vector<float>());
    out_vals = aie::mul(exp_val, inv_sum);
    *it_out++ = out_vals.to_vector<bfloat16>();
  }

  event1(); // profiling event marker (end of kernel region)
}
274+
162275
extern "C" {
163276

164277
void softmax_bf16(bfloat16 *restrict input, bfloat16 *restrict output, const int32_t input_size)
@@ -177,6 +290,27 @@ void partial_softmax_bf16(bfloat16 *restrict input,
177290
partial_softmax_alias_bf16(input, output, scale_buffer, input_size, row_idx, num_rows, scale);
178291
}
179292

293+
/// Reset the online-softmax stats buffer before the first stats pass:
/// running (log2e-scaled) max starts at -inf (identity for max) and the
/// running sum at 0 (identity for addition).
void softmax_partial_init_bf16(bfloat16 *stats) {
  stats[0] = (bfloat16)(-INFINITY); // running max
  stats[1] = (bfloat16)(0.0f);      // running sum
}
298+
299+
/// C-ABI wrapper for the online-softmax statistics pass; see
/// softmax_partial_stats_impl for the algorithm and stats layout.
void softmax_partial_stats_bf16(bfloat16 *restrict input, bfloat16 *stats,
                                const int32_t vector_size) {
  softmax_partial_stats_impl(input, stats, vector_size);
}
305+
306+
/// C-ABI wrapper for the online-softmax normalization pass; see
/// softmax_partial_norm_impl for details.
void softmax_partial_norm_bf16(bfloat16 *restrict input,
                               bfloat16 *restrict output, bfloat16 *stats,
                               const int32_t vector_size) {
  softmax_partial_norm_impl(input, output, stats, vector_size);
}
313+
180314
void mask_bf16(bfloat16 *inout, const int32 unmasked_size, const int32 total_size)
181315
{
182316
// TODO: Optimize this to use vector code

0 commit comments

Comments
 (0)