1515
1616#include < aie_api/aie.hpp>
1717
18+ #ifndef VEC_SIZE
19+ #define VEC_SIZE 64
20+ #endif
21+
1822void matvec_scalar (uint32_t m,
1923 uint32_t k,
2024 const bfloat16 *__restrict a,
@@ -40,22 +44,17 @@ Matrix-vector multiplication kernel
4044 - c: Pointer to the output vector
4145 - r: Vector size; data from the matrix and vector will be loaded in and processed in chunks of this size
4246*/
43- template <uint32_t r>
44- void matvec_vectorized (uint32_t m,
45- uint32_t k,
46- const bfloat16 *__restrict a,
47- const bfloat16 *__restrict b,
48- bfloat16 *__restrict c)
47+ template <uint32_t r, uint32_t k>
48+ void matvec_vectorized (uint32_t m, const bfloat16 *__restrict a, const bfloat16 *__restrict b, bfloat16 *__restrict c)
4949{
5050 ::aie::set_rounding (aie::rounding_mode::conv_even);
5151 bfloat16 *c_end = c + m;
5252 const bfloat16 *b_end = b + k;
5353 for (; c < c_end; c++) {
5454 aie::accum acc = aie::zeros<accfloat, r>();
55- // The following two pragmas enable pipelining the zero-overhead loop, but they do assume that k is at least
56- // two. This assumption should hold for any useful use of this function; if k were one, this would be a simple
57- // scalar multiplication of a vector.
58- AIE_LOOP_MIN_ITERATION_COUNT (2 )
55+ // The following two pragmas enable pipelining the zero-overhead loop, but they do assume that there are at
56+ // least two iterations of the loop, i.e. k >= 2*r. This pragma will break the code if that is not the case!
57+ AIE_LOOP_MIN_ITERATION_COUNT (k / VEC_SIZE)
5958 for (const bfloat16 *__restrict b_cur = b; b_cur < b_end; b_cur += r, a += r) {
6059 aie::vector<bfloat16, r> a_vec = aie::load_v<r>(a);
6160 aie::vector<bfloat16, r> b_vec = aie::load_v<r>(b_cur);
@@ -72,25 +71,23 @@ extern "C" {
7271 * `c`. */
7372
7473void matvec_scalar_bf16_bf16 (uint32_t m,
75- uint32_t k,
7674 uint32_t row_offset,
7775 const bfloat16 *__restrict a_in,
7876 const bfloat16 *__restrict b_in,
7977 bfloat16 *__restrict c_out)
8078{
8179 c_out += row_offset;
82- matvec_scalar (m, k , a_in, b_in, c_out);
80+ matvec_scalar (m, DIM_K , a_in, b_in, c_out);
8381}
8482
8583void matvec_vectorized_bf16_bf16 (uint32_t m,
86- uint32_t k,
8784 uint32_t row_offset,
8885 const bfloat16 *__restrict a_in,
8986 const bfloat16 *__restrict b_in,
9087 bfloat16 *__restrict c_out)
9188{
9289 c_out += row_offset;
93- matvec_vectorized<64 >(m, k , a_in, b_in, c_out);
90+ matvec_vectorized<VEC_SIZE, DIM_K >(m, a_in, b_in, c_out);
9491}
9592
9693} // extern "C"
0 commit comments