26 changes: 24 additions & 2 deletions dpnp/backend/extensions/blas/gemm_batch.cpp
@@ -271,16 +271,38 @@ std::tuple<sycl::event, sycl::event, bool>

standardize_strides_to_nonzero(a_stride, a_shape);
standardize_strides_to_nonzero(b_stride, b_shape);
standardize_strides_to_nonzero(c_stride, c_shape);
const bool A_base_is_f_contig =
a_stride[1] == 1 && a_stride[2] == a_shape[1];
const bool A_base_is_c_contig =
a_stride[1] == a_shape[2] && a_stride[2] == 1;
const bool B_base_is_f_contig =
b_stride[1] == 1 && b_stride[2] == b_shape[1];
const bool B_base_is_c_contig =
b_stride[1] == b_shape[2] && b_stride[2] == 1;
const bool C_base_is_f_contig =
c_stride[1] == 1 && c_stride[2] == c_shape[1];
const bool C_base_is_c_contig =
c_stride[1] == c_shape[2] && c_stride[2] == 1;

bool is_row_major = true;
if (A_base_is_f_contig && B_base_is_f_contig) {
is_row_major = false;
}

if (!A_base_is_f_contig and !A_base_is_c_contig) {
throw py::value_error("The 2D base of the first input array is "
"neither c-contiguous nor f-contiguous.");
}
if (!B_base_is_f_contig and !B_base_is_c_contig) {
throw py::value_error("The 2D base of the second input array is "
"neither c-contiguous nor f-contiguous.");
}
if (!C_base_is_f_contig and !C_base_is_c_contig) {
throw py::value_error("The 2D base of the result array is neither "
"c-contiguous nor f-contiguous.");
}

oneapi::mkl::transpose transA;
oneapi::mkl::transpose transB;
if (is_row_major) {
@@ -346,10 +368,10 @@ std::tuple<sycl::event, sycl::event, bool>
strideb, stridec, transA, transB, a_typeless_ptr,
b_typeless_ptr, r_typeless_ptr, is_row_major, depends);

sycl::event args_batch_ev = dpctl::utils::keep_args_alive(
sycl::event args_ev = dpctl::utils::keep_args_alive(
exec_q, {matrixA, matrixB, resultC}, {gemm_batch_ev});

return std::make_tuple(args_batch_ev, gemm_batch_ev, is_row_major);
return std::make_tuple(args_ev, gemm_batch_ev, is_row_major);
}

template <typename fnT, typename Tab, typename Tc>
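The stride checks above classify each operand's 2D base as c- or f-contiguous, and `is_row_major` is cleared only when both input bases are f-contiguous. A minimal NumPy sketch of the same classification (the helpers `classify_2d_base` and `pick_layout` are illustrative, not part of the PR; the C++ code works with element strides, so byte strides are divided by the item size):

```python
import numpy as np

def classify_2d_base(a):
    # Mirror the stride tests in gemm_batch.cpp: strides are in
    # elements (NumPy reports bytes, so divide by itemsize).
    s = [st // a.itemsize for st in a.strides]
    _, rows, cols = a.shape
    f_contig = s[1] == 1 and s[2] == rows   # column-major 2D base
    c_contig = s[1] == cols and s[2] == 1   # row-major 2D base
    return f_contig, c_contig

def pick_layout(a, b):
    # gemm runs row-major unless both input bases are f-contiguous
    a_f, _ = classify_2d_base(a)
    b_f, _ = classify_2d_base(b)
    return "column_major" if a_f and b_f else "row_major"

batch = np.zeros((4, 2, 3))                        # each 2x3 base is c-contiguous
print(classify_2d_base(batch))                     # (False, True)
print(classify_2d_base(batch.transpose(0, 2, 1)))  # (True, False)
```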
8 changes: 0 additions & 8 deletions dpnp/backend/extensions/blas/gemv.hpp
@@ -40,14 +40,6 @@ extern std::pair<sycl::event, sycl::event>
const bool transpose,
const std::vector<sycl::event> &depends);

extern std::pair<sycl::event, sycl::event>
gemv_batch(sycl::queue &exec_q,
const dpctl::tensor::usm_ndarray &matrixA,
const dpctl::tensor::usm_ndarray &vectorX,
const dpctl::tensor::usm_ndarray &vectorY,
const bool transpose,
const std::vector<sycl::event> &depends);

extern void init_gemv_dispatch_vector(void);
extern void init_gemv_batch_dispatch_vector(void);
} // namespace dpnp::extensions::blas
81 changes: 47 additions & 34 deletions dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
@@ -134,11 +134,12 @@ def _create_result_array(
"""
Create the result array.

If `out` is not ``None`` and its features match the specified `shape`, `dtype`,
`usm_type`, and `sycl_queue` and it is C-contiguous or F-contiguous and
does not have any memory overlap with `x1` and `x2`, `out` itself is returned.
If `out` is not ``None`` and its shape and dtype match the desired `shape`
and `dtype`, and its 2-D base is contiguous and it does not have any memory
overlap with `x1` and `x2`, `out` itself is returned.
If these conditions are not satisfied, an empty array is returned with the
specified `shape`, `dtype`, `usm_type`, and `sycl_queue`.

"""

if out is not None:
@@ -150,7 +151,6 @@
if (
out.dtype == dtype
and out.shape == shape
and out.usm_type == usm_type
and contig_flag
and not ti._array_overlap(x1_usm, out_usm)
and not ti._array_overlap(x2_usm, out_usm)
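The reuse test this docstring and condition describe can be sketched as follows; `np.shares_memory` stands in for `ti._array_overlap`, and whole-array contiguity flags stand in for the 2D-base check done by `_define_contig_flag` (a simplification, not the PR's exact logic — note the PR also drops the old `usm_type` condition):

```python
import numpy as np

def can_reuse_out(out, shape, dtype, x1, x2):
    # out must match the requested shape/dtype, be contiguous, and
    # not alias either input; otherwise a fresh array is allocated
    if out is None:
        return False
    contig = out.flags.c_contiguous or out.flags.f_contiguous
    overlap = np.shares_memory(out, x1) or np.shares_memory(out, x2)
    return out.dtype == dtype and out.shape == shape and contig and not overlap

x1, x2 = np.ones((2, 3)), np.ones((3, 4))
print(can_reuse_out(np.empty((2, 4)), (2, 4), x1.dtype, x1, x2))  # True
print(can_reuse_out(x1.T, (3, 2), x1.dtype, x1, x2))              # False: aliases x1
```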
@@ -325,10 +325,13 @@ def _get_result_shape(x1, x2, out, np_flag):

def _gemm_batch_matmul(exec_q, x1, x2, res):
# arrays here are already at least 3D; reshape them to exactly 3D
x1 = dpnp.reshape(x1, (-1, x1.shape[-2], x1.shape[-1]))
x2 = dpnp.reshape(x2, (-1, x2.shape[-2], x2.shape[-1]))
x1_shape = x1.shape
x2_shape = x2.shape
x1 = dpnp.reshape(x1, (-1, x1_shape[-2], x1_shape[-1]))
x2 = dpnp.reshape(x2, (-1, x2_shape[-2], x2_shape[-1]))
orig_shape = res.shape
res = dpnp.reshape(res, (-1, res.shape[-2], res.shape[-1]))
res = dpnp.reshape(res, (-1, orig_shape[-2], orig_shape[-1]))
res_shape = res.shape

# gemm_batch does not handle negative strides, make a copy if needed
x1 = _copy_array(x1, copy_flag=x1.strides[0] < 0)
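In plain NumPy, the two preparation steps above — collapsing all leading batch dimensions into one, and copying an operand whose batch stride is negative — look like this (illustrative sketch only):

```python
import numpy as np

x = np.ones((2, 5, 3, 4))                 # two leading batch dims
x3d = x.reshape(-1, x.shape[-2], x.shape[-1])
print(x3d.shape)                          # (10, 3, 4)

rev = x3d[::-1]                           # reversed batch axis
print(rev.strides[0] < 0)                 # True: gemm_batch can't consume this
if rev.strides[0] < 0:
    rev = rev.copy()                      # materialize a positively-strided copy
```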
@@ -338,16 +341,16 @@ def _gemm_batch_matmul(exec_q, x1, x2, res):
_manager = dpu.SequentialOrderManager[exec_q]

# onemkl::blas::gemm_batch throws an exception (Provided range is out
# of integer limits) if the batch_size is too large (>=4096*4096), so
# we need to split the batch into smaller chunks
chunk = 2048 * 2048
batch_size = res.shape[0]
# of integer limits) if the batch_size is too large, so we need to
# split the batch into smaller chunks; the chunk size depends on the device
chunk = 4096 * 4096 - 2
batch_size = res_shape[0]
for i in range(0, batch_size, chunk):
if x1.shape[0] == 1:
if x1_shape[0] == 1:
# x1 is repeatedly multiplied with each matrix in x2
x1_usm = dpnp.get_usm_ndarray(x1)
x2_usm = dpnp.get_usm_ndarray(x2[i : i + chunk, ...])
elif x2.shape[0] == 1:
elif x2_shape[0] == 1:
x1_usm = dpnp.get_usm_ndarray(x1[i : i + chunk, ...])
x2_usm = dpnp.get_usm_ndarray(x2)
else:
@@ -364,25 +367,36 @@ def _gemm_batch_matmul(exec_q, x1, x2, res):
)
_manager.add_event_pair(ht_ev, blas_ev)

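The loop above reduces to the following pattern: slice the batch into backend-sized chunks and let a length-1 operand broadcast against every chunk. This is a minimal sketch with a toy `gemm_batch` stand-in for the oneMKL call; the real chunk limit `4096 * 4096 - 2` comes from oneMKL's integer-range restriction:

```python
import numpy as np

def gemm_batch(a, b, out):                 # toy stand-in for the oneMKL call
    np.matmul(a, b, out=out)

def batched_matmul(x1, x2, res, chunk=4096 * 4096 - 2):
    batch_size = res.shape[0]
    for i in range(0, batch_size, chunk):
        sl = slice(i, i + chunk)
        # a single matrix (batch dim 1) multiplies every matrix in the chunk
        a = x1 if x1.shape[0] == 1 else x1[sl]
        b = x2 if x2.shape[0] == 1 else x2[sl]
        gemm_batch(a, b, res[sl])
    return res

x1 = np.random.rand(1, 2, 3)               # broadcast against 10 matrices
x2 = np.random.rand(10, 3, 4)
res = batched_matmul(x1, x2, np.empty((10, 2, 4)), chunk=4)
assert np.allclose(res, x1 @ x2)
```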
res_shape = res.shape
_, res_is_c_contig, res_is_f_contig = _define_contig_flag(res)
if row_major:
if res_is_f_contig:
res = dpnp.reshape(
dpnp.ravel(res, order="F"),
(res_shape[1], res_shape[2], batch_size),
).transpose(2, 0, 1)
# Considering the multiplication of one pair of matrices in
# the batch, we have result[0, 1] = a[0, :] @ b[:, 1]. In
# row_major mode, the result array is assumed to be
# c-contiguous, i.e. result[0, 1] occupies the second
# position in memory. However, the result array is a batch
# of 2D f-contiguous arrays, i.e. the second position in
# memory holds res[1, 0]. So, we need to read the data of
# each 2D array in the batch in "F" order and write it in
# "C" order
res = (
res.ravel(order="F")
.reshape(res_shape[1], res_shape[2], batch_size)
.transpose(2, 0, 1)
)
else:
if res_is_c_contig:
res = dpnp.reshape(
dpnp.ravel(res, order="C"),
(batch_size, res_shape[2], res_shape[1]),
).transpose(0, 2, 1)
# read the data of each 2D array in the batch in "C" order
# and write it in "F" order
res = (
res.ravel(order="C")
.reshape(batch_size, res_shape[2], res_shape[1])
.transpose(0, 2, 1)
)

if res_shape != orig_shape:
res = res.reshape(orig_shape)

return dpnp.ascontiguousarray(res)
return res

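The reordering in the row-major branch above can be verified in isolation: write each 2D result into memory in "C" order, view that memory as f-contiguous bases (what Python sees), then apply the ravel/reshape/transpose repair and compare with the true product. A self-contained check (NumPy used for illustration):

```python
import numpy as np

batch, n, m = 3, 2, 4
rng = np.random.default_rng(0)
T = rng.random((batch, n, 5)) @ rng.random((batch, 5, m))  # ground truth

# Memory as a row-major kernel leaves it: each 2D result written in
# "C" order, bases stored back to back.
memory = T.ravel(order="C")

# What Python sees: the same memory viewed as f-contiguous 2D bases.
V = memory.reshape(batch, m, n).transpose(0, 2, 1)

# The repair from _gemm_batch_matmul: read each base in "F" order,
# regroup, and move the batch axis back to the front.
fixed = V.ravel(order="F").reshape(n, m, batch).transpose(2, 0, 1)
assert np.array_equal(fixed, T)
```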

def _gemm_matmul(exec_q, x1, x2, res):
@@ -400,13 +414,13 @@
if row_major:
if res.flags.f_contiguous is True:
# read data in "F" order and write it in "C" order
res = dpnp.reshape(dpnp.ravel(res, order="F"), res.shape, order="C")
res = dpnp.ravel(res, order="F").reshape(res.shape, order="C")
else:
if res.flags.c_contiguous is True:
# read data in "C" order and write it in "F" order
res = dpnp.reshape(dpnp.ravel(res, order="C"), res.shape, order="F")
res = dpnp.ravel(res, order="C").reshape(res.shape, order="F")

return dpnp.ascontiguousarray(res)
return res


def _shape_error(a, b, core_dim, err_msg):
@@ -767,9 +781,9 @@ def dpnp_matmul(
call_flag = "multiply"
elif x1_is_1D and x2_is_1D:
call_flag = "dot"
x1 = dpnp.reshape(x1, x1_shape[-1])
if x2_ndim != 1:
x2 = dpnp.reshape(x2, x2_shape[-2])
# arrays are inherently 1D, flatten them to 1D
x1 = dpnp.ravel(x1)
x2 = dpnp.ravel(x2)
elif x1_base_is_1D and x2_base_is_1D:
# TODO: implement a batch version of dot to use it here
call_flag = "gemm_batch"
@@ -912,12 +926,11 @@ def dpnp_matmul(
# we need to update it to match the passed `order`.
if order not in ["k", "K"]:
return dpnp.array(result, copy=False, order=order)
return result
# dpnp.ascontiguousarray changes 0-D array to 1-D array
if result.ndim == 0:
return result
return dpnp.ascontiguousarray(result)

# TODO: There is an opportunity to improve performance when the out
# keyword is present. In some cases, out is NOT result but they share
# the same base (they are views of the same data). In such cases, we
# can avoid copying result to out.
result = dpnp.get_result_array(result, out, casting=casting)
if axes is not None and out is result:
# out and out_orig contain the same data but they have different shape
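The 0-D special case near the end of `dpnp_matmul` guards against the quirk the comment mentions: `ascontiguousarray` promotes a 0-D array to 1-D. NumPy shows the same behavior, and per the comment above dpnp mirrors it:

```python
import numpy as np

scalar = np.array(3.0)                    # 0-D, e.g. a 1-D @ 1-D product
print(scalar.ndim, scalar.shape)          # 0 ()
promoted = np.ascontiguousarray(scalar)
print(promoted.ndim, promoted.shape)      # 1 (1,)
```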
8 changes: 5 additions & 3 deletions tests/test_mathematical.py
@@ -3037,6 +3037,7 @@ def test_matmul_strided3(self, stride, transpose):
@pytest.mark.parametrize("incy", [-2, 2], ids=["-2", "2"])
@pytest.mark.parametrize("transpose", [False, True], ids=["False", "True"])
def test_matmul_strided_mat_vec(self, shape, incx, incy, transpose):
# vector is strided
if transpose:
s1 = shape[-2]
s2 = shape[-1]
@@ -3069,6 +3070,7 @@ def test_matmul_strided_mat_vec(self, shape, incx, incy, transpose):
@pytest.mark.parametrize("incy", [-2, 2], ids=["-2", "2"])
@pytest.mark.parametrize("transpose", [False, True], ids=["False", "True"])
def test_matmul_strided_vec_mat(self, shape, incx, incy, transpose):
# vector is strided
if transpose:
s1 = shape[-2]
s2 = shape[-1]
@@ -3217,9 +3219,9 @@ def test_matmul_out_0D(self, out_shape):
@pytest.mark.parametrize(
"shape_pair",
[
((4096, 4096, 2, 2), (4096, 4096, 2, 2)),
((2, 2), (4096, 4096, 2, 2)),
((4096, 4096, 2, 2), (2, 2)),
((5000, 5000, 2, 2), (5000, 5000, 2, 2)),
((2, 2), (5000, 5000, 2, 2)),
((5000, 5000, 2, 2), (2, 2)),
],
)
def test_matmul_large(self, shape_pair):