diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index fbd6402924..052fe28c60 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -1000,38 +1000,38 @@ sycl::event gemm_impl(sycl::queue &exec_q, dev.get_info(); const size_t reserved_slm_size = 512; - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); - - using OuterInnerIndexerT = dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, - lhs_shape_strides); - OuterInnerIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, - rhs_shape_strides); - OuterInnerIndexerT res_indexer(res_outer_nd, 0, res_shape_strides); - - if (m < 4) { - constexpr size_t m_groups = 1; - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + using OuterInnerIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_shape_strides); + OuterInnerIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_shape_strides); + OuterInnerIndexerT res_indexer(res_outer_nd, 0, res_shape_strides); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + if (m < 4) { + constexpr size_t m_groups = 1; + const size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); - size_t lws = delta_n * delta_k; + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); - auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - auto ndRange = sycl::nd_range<1>(gRange, lRange); + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT = sycl::local_accessor; LocAccT local_B_block(n_wi * delta_k, cgh); @@ -1045,29 +1045,34 @@ sycl::event gemm_impl(sycl::queue &exec_q, lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, lhs_indexer, rhs_indexer, res_indexer)); - } - else if (k > n && k > m) { - constexpr size_t m_groups = 4; - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + }); + return gemm_ev; + } + else if (k > n && k > m) { + constexpr size_t m_groups = 4; + const size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); - size_t lws = delta_n * delta_k; + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); + size_t lws = delta_n * delta_k; - auto ndRange = sycl::nd_range<1>(gRange, lRange); + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT = sycl::local_accessor, 1>; LocAccT local_B_block(n_wi * delta_k, cgh); @@ -1081,34 +1086,39 @@ sycl::event gemm_impl(sycl::queue &exec_q, lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, lhs_indexer, rhs_indexer, res_indexer)); - } - else { - constexpr int wi_delta_n = 2; - constexpr int wi_delta_m = 4; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI - - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); + }); - size_t lws = wg_delta_n * wg_delta_m; + return gemm_ev; + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + size_t lws = wg_delta_n * wg_delta_m; - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + + auto gwsRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT1 = sycl::local_accessor; LocAccT1 local_A_block( @@ -1128,10 +1138,9 @@ sycl::event gemm_impl(sycl::queue &exec_q, lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); - } - }); - - return gemm_ev; + }); + return gemm_ev; + } } typedef sycl::event (*gemm_contig_impl_fn_ptr_t)( @@ -1172,36 +1181,36 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, dev.get_info(); const size_t reserved_slm_size = 512; - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); + using OuterInnerIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerIndexerT lhs_indexer{}; + OuterInnerIndexerT rhs_indexer{}; + OuterInnerIndexerT res_indexer{}; - using OuterInnerIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerIndexerT lhs_indexer{}; - OuterInnerIndexerT rhs_indexer{}; - OuterInnerIndexerT res_indexer{}; + if (m < 4) { + constexpr size_t m_groups = 1; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); - if (m < 4) { - constexpr size_t m_groups = 1; - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t lws = delta_n * delta_k; - size_t lws = delta_n * delta_k; + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); - auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); + auto ndRange = sycl::nd_range<1>(gRange, lRange); - auto ndRange = sycl::nd_range<1>(gRange, lRange); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT = sycl::local_accessor; LocAccT local_B_block(n_wi * delta_k, cgh); @@ -1215,29 +1224,35 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, lhs_indexer, rhs_indexer, res_indexer)); - } - else if (k > n && k > m) { - constexpr size_t m_groups = 4; - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + }); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + return gemm_ev; + } + else if (k > n && k > m) { + constexpr size_t m_groups = 4; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); - size_t lws = delta_n * delta_k; + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); - auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - auto ndRange = sycl::nd_range<1>(gRange, lRange); + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT = sycl::local_accessor, 1>; LocAccT local_B_block(n_wi * delta_k, cgh); @@ -1251,34 +1266,39 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, lhs_indexer, rhs_indexer, res_indexer)); - } - else { - constexpr int wi_delta_n = 2; - constexpr int wi_delta_m = 4; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI - - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); + }); - size_t lws = wg_delta_n * wg_delta_m; + return gemm_ev; + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + size_t lws = wg_delta_n * wg_delta_m; - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + + auto gwsRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT1 = sycl::local_accessor; LocAccT1 local_A_block( @@ -1298,10 +1318,10 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); - } - }); + }); - return gemm_ev; + return gemm_ev; + } } template ( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, + batch_shape_strides); + + if (m < 4) { + constexpr size_t m_groups = 1; + const size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT = sycl::local_accessor; LocAccT local_B_block(n_wi * delta_k, cgh); @@ -3610,30 +3630,36 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); - } - else if (k > n && k > m) { - constexpr size_t m_groups = 4; - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + }); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + return gemm_ev; + } + else if (k > n && k > m) { + constexpr size_t m_groups = 4; + const size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); - size_t lws = delta_n * delta_k; + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); - auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - auto ndRange = sycl::nd_range<1>(gRange, lRange); + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT = sycl::local_accessor, 1>; LocAccT local_B_block(n_wi * delta_k, cgh); @@ -3651,34 +3677,39 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); - } - else { - constexpr int wi_delta_n = 2; - constexpr int wi_delta_m = 4; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI - - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); + }); + return gemm_ev; + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI - size_t lws = wg_delta_n * wg_delta_m; + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t lws = wg_delta_n * wg_delta_m; - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + auto gwsRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT1 = sycl::local_accessor; LocAccT1 local_A_block( @@ -3702,10 +3733,9 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); - } - }); - - return gemm_ev; + }); + return gemm_ev; + } } typedef sycl::event (*gemm_batch_contig_impl_fn_ptr_t)( @@ -3756,49 +3786,50 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, return res_init_ev; } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); - - using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * m)}); - if (m < 4) { - constexpr size_t m_groups = 1; - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - - size_t lws = delta_n * delta_k; + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + if (m < 4) { + constexpr size_t m_groups = 1; + const size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); - auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT = sycl::local_accessor; LocAccT local_B_block(n_wi * delta_k, cgh); @@ -3816,30 +3847,36 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); - } - else if (k > n && k > m) { - constexpr size_t m_groups = 4; - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + }); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + return gemm_ev; + } + else if (k > n && k > m) { + constexpr size_t m_groups = 4; + const size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); - size_t lws = delta_n * delta_k; + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); - auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - auto ndRange = sycl::nd_range<1>(gRange, lRange); + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT = sycl::local_accessor, 1>; LocAccT local_B_block(n_wi * delta_k, cgh); @@ -3857,34 +3894,40 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); - } - else { - constexpr int wi_delta_n = 2; - constexpr int wi_delta_m = 4; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI - - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); + }); - size_t lws = wg_delta_n * wg_delta_m; + return gemm_ev; + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + size_t lws = wg_delta_n * wg_delta_m; - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + + auto gwsRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT1 = sycl::local_accessor; LocAccT1 local_A_block( @@ -3908,10 +3951,10 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); - } - }); + }); - return gemm_ev; + return gemm_ev; + } } template