From f7188ceb1a0d517ac46a9a6d4908bc1ad21d9981 Mon Sep 17 00:00:00 2001 From: Oleksandr Pavlyk Date: Mon, 5 Feb 2024 16:10:31 -0600 Subject: [PATCH] Moved logic of submitting different kernels in gemm outside of handler function This should improve debugging experience, where one can see in the debugger which branch the execution took. No other logic has been changed. --- .../include/kernels/linalg_functions/gemm.hpp | 681 ++++++++++-------- 1 file changed, 362 insertions(+), 319 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp index fbd6402924..052fe28c60 100644 --- a/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp +++ b/dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -1000,38 +1000,38 @@ sycl::event gemm_impl(sycl::queue &exec_q, dev.get_info(); const size_t reserved_slm_size = 512; - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); - - using OuterInnerIndexerT = dpctl::tensor::offset_utils::StridedIndexer; - OuterInnerIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, - lhs_shape_strides); - OuterInnerIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, - rhs_shape_strides); - OuterInnerIndexerT res_indexer(res_outer_nd, 0, res_shape_strides); - - if (m < 4) { - constexpr size_t m_groups = 1; - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + using OuterInnerIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_shape_strides); + OuterInnerIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_shape_strides); + OuterInnerIndexerT res_indexer(res_outer_nd, 0, res_shape_strides); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + if (m < 4) { + constexpr size_t m_groups = 1; + const size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); - size_t lws = delta_n * delta_k; + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); - auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - auto ndRange = sycl::nd_range<1>(gRange, lRange); + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT = sycl::local_accessor; LocAccT local_B_block(n_wi * delta_k, cgh); @@ -1045,29 +1045,34 @@ sycl::event gemm_impl(sycl::queue &exec_q, lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, lhs_indexer, rhs_indexer, res_indexer)); - } - else if (k > n && k > m) { - constexpr size_t m_groups = 4; - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + }); + return gemm_ev; + } + else if (k > n && k > m) { + constexpr size_t m_groups = 4; + const size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); - size_t lws = delta_n * delta_k; + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); + size_t lws = delta_n * delta_k; - auto ndRange = sycl::nd_range<1>(gRange, lRange); + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT = sycl::local_accessor, 1>; LocAccT local_B_block(n_wi * delta_k, cgh); @@ -1081,34 +1086,39 @@ sycl::event gemm_impl(sycl::queue &exec_q, lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, lhs_indexer, rhs_indexer, res_indexer)); - } - else { - constexpr int wi_delta_n = 2; - constexpr int wi_delta_m = 4; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI - - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); + }); - size_t lws = wg_delta_n * wg_delta_m; + return gemm_ev; + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + size_t lws = wg_delta_n * wg_delta_m; - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + + auto gwsRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT1 = sycl::local_accessor; LocAccT1 local_A_block( @@ -1128,10 +1138,9 @@ sycl::event gemm_impl(sycl::queue &exec_q, lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); - } - }); - - return gemm_ev; + }); + return gemm_ev; + } } typedef sycl::event (*gemm_contig_impl_fn_ptr_t)( @@ -1172,36 +1181,36 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, dev.get_info(); const size_t reserved_slm_size = 512; - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); + using OuterInnerIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerIndexerT lhs_indexer{}; + OuterInnerIndexerT rhs_indexer{}; + OuterInnerIndexerT res_indexer{}; - using OuterInnerIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerIndexerT lhs_indexer{}; - OuterInnerIndexerT rhs_indexer{}; - OuterInnerIndexerT res_indexer{}; + if (m < 4) { + constexpr size_t m_groups = 1; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); - if (m < 4) { - constexpr size_t m_groups = 1; - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + size_t lws = delta_n * delta_k; - size_t lws = delta_n * delta_k; + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); - auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); + auto ndRange = sycl::nd_range<1>(gRange, lRange); - auto ndRange = sycl::nd_range<1>(gRange, lRange); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT = sycl::local_accessor; LocAccT local_B_block(n_wi * delta_k, cgh); @@ -1215,29 +1224,35 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, lhs_indexer, rhs_indexer, res_indexer)); - } - else if (k > n && k > m) { - constexpr size_t m_groups = 4; - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + }); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + return gemm_ev; + } + else if (k > n && k > m) { + constexpr size_t m_groups = 4; + size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); - size_t lws = delta_n * delta_k; + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); - auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lRange = sycl::range<1>(lws); + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - auto ndRange = sycl::nd_range<1>(gRange, lRange); + size_t lws = delta_n * delta_k; + + auto gRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT = sycl::local_accessor, 1>; LocAccT local_B_block(n_wi * delta_k, cgh); @@ -1251,34 +1266,39 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, lhs_tp, rhs_tp, res_tp, workspace, local_B_block, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, lhs_indexer, rhs_indexer, res_indexer)); - } - else { - constexpr int wi_delta_n = 2; - constexpr int wi_delta_m = 4; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI - - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); + }); - size_t lws = wg_delta_n * wg_delta_m; + return gemm_ev; + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); - auto gwsRange = - sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + size_t lws = wg_delta_n * wg_delta_m; - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + + auto gwsRange = sycl::range<1>(n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT1 = sycl::local_accessor; LocAccT1 local_A_block( @@ -1298,10 +1318,10 @@ sycl::event gemm_contig_impl(sycl::queue &exec_q, lhs_tp, rhs_tp, res_tp, local_A_block, local_B_block, n, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, lhs_indexer, rhs_indexer, res_indexer)); - } - }); + }); - return gemm_ev; + return gemm_ev; + } } template ( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, + batch_shape_strides); + + if (m < 4) { + constexpr size_t m_groups = 1; + const size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); - size_t lws = delta_n * delta_k; - - auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT = sycl::local_accessor; LocAccT local_B_block(n_wi * delta_k, cgh); @@ -3610,30 +3630,36 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); - } - else if (k > n && k > m) { - constexpr size_t m_groups = 4; - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + }); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + return gemm_ev; + } + else if (k > n && k > m) { + constexpr size_t m_groups = 4; + const size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); - size_t lws = delta_n * delta_k; + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); - auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - auto ndRange = sycl::nd_range<1>(gRange, lRange); + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT = sycl::local_accessor, 1>; LocAccT local_B_block(n_wi * delta_k, cgh); @@ -3651,34 +3677,39 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); - } - else { - constexpr int wi_delta_n = 2; - constexpr int wi_delta_m = 4; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI - - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); + }); + return gemm_ev; + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI - size_t lws = wg_delta_n * wg_delta_m; + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + size_t lws = wg_delta_n * wg_delta_m; - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + auto gwsRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT1 = sycl::local_accessor; LocAccT1 local_A_block( @@ -3702,10 +3733,9 @@ sycl::event gemm_batch_impl(sycl::queue &exec_q, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); - } - }); - - return gemm_ev; + }); + return gemm_ev; + } } typedef sycl::event (*gemm_batch_contig_impl_fn_ptr_t)( @@ -3756,49 +3786,50 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, return res_init_ev; } - sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { - cgh.depends_on(res_init_ev); - - using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; - OuterInnerDimsIndexerT lhs_indexer{}; - OuterInnerDimsIndexerT rhs_indexer{}; - OuterInnerDimsIndexerT res_indexer{}; - using dpctl::tensor::offset_utils::Strided1DIndexer; - using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; - using BatchDimsIndexerT = - ThreeOffsets_CombinedIndexer; - - BatchDimsIndexerT batch_indexer( - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * k)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(k * m)}, - Strided1DIndexer{0, static_cast(batch_nelems), - static_cast(n * m)}); - if (m < 4) { - constexpr size_t m_groups = 1; - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); - - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - - size_t lws = delta_n * delta_k; + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + OuterInnerDimsIndexerT lhs_indexer{}; + OuterInnerDimsIndexerT rhs_indexer{}; + OuterInnerDimsIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + BatchDimsIndexerT batch_indexer( + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * k)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(k * m)}, + Strided1DIndexer{0, static_cast(batch_nelems), + static_cast(n * m)}); + + if (m < 4) { + constexpr size_t m_groups = 1; + const size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); - auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); - - auto ndRange = sycl::nd_range<1>(gRange, lRange); + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT = sycl::local_accessor; LocAccT local_B_block(n_wi * delta_k, cgh); @@ -3816,30 +3847,36 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); - } - else if (k > n && k > m) { - constexpr size_t m_groups = 4; - size_t delta_k(4); - size_t n_wi(64); - size_t delta_n(32); - - gemm_detail::scale_gemm_k_parameters( - local_mem_size, reserved_slm_size, delta_k, - n_wi, // modified by reference - delta_n // modified by reference - ); + }); - size_t n_blocks = (n + delta_n - 1) / delta_n; - size_t m_blocks = (m + m_groups - 1) / m_groups; - size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + return gemm_ev; + } + else if (k > n && k > m) { + constexpr size_t m_groups = 4; + const size_t delta_k(4); + size_t n_wi(64); + size_t delta_n(32); - size_t lws = delta_n * delta_k; + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); - auto gRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lRange = sycl::range<1>(lws); + size_t n_blocks = (n + delta_n - 1) / delta_n; + size_t m_blocks = (m + m_groups - 1) / m_groups; + size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); - auto ndRange = sycl::nd_range<1>(gRange, lRange); + size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT = sycl::local_accessor, 1>; LocAccT local_B_block(n_wi * delta_k, cgh); @@ -3857,34 +3894,40 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, n, n_blocks, delta_n, k, k_blocks, delta_k, n_wi, m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); - } - else { - constexpr int wi_delta_n = 2; - constexpr int wi_delta_m = 4; - size_t wg_delta_n(16); // rows of A processed in WG - size_t wg_delta_m(16); // rows of B processed in WG - size_t wi_delta_k(64); // Elements in K dimension processed by WI - - gemm_detail::scale_gemm_nm_parameters( - local_mem_size, reserved_slm_size, wi_delta_n, - wi_delta_k, // modified by reference - wg_delta_n, // modified by reference - wg_delta_m // modified by reference - ); + }); - size_t lws = wg_delta_n * wg_delta_m; + return gemm_ev; + } + else { + constexpr int wi_delta_n = 2; + constexpr int wi_delta_m = 4; + size_t wg_delta_n(16); // rows of A processed in WG + size_t wg_delta_m(16); // rows of B processed in WG + size_t wi_delta_k(64); // Elements in K dimension processed by WI - size_t n_blocks = - ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); - size_t m_blocks = - ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); - size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); - auto gwsRange = sycl::range<1>(batch_nelems * n_blocks * m_blocks * - k_blocks * lws); - auto lwsRange = sycl::range<1>(lws); + size_t lws = wg_delta_n * wg_delta_m; - auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + + auto gwsRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); + auto lwsRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(res_init_ev); using LocAccT1 = sycl::local_accessor; LocAccT1 local_A_block( @@ -3908,10 +3951,10 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, wg_delta_n, k, k_blocks, wi_delta_k, m, m_blocks, wg_delta_m, batch_nelems, batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); - } - }); + }); - return gemm_ev; + return gemm_ev; + } } template