-
Notifications
You must be signed in to change notification settings - Fork 23
dot/matmul fixes in python/cython #922
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
It is not clear for me why SVD tests with complex fails with |
@@ -97,11 +97,18 @@ def dot(x1, x2, **kwargs): | |||
dim1 = x1_desc.ndim | |||
dim2 = x2_desc.ndim | |||
|
|||
if not (dim1 >= 2 and dim2 == 1) and not (dim1 >= 2 and dim2 >= 2) and (x1_desc.dtype == x2_desc.dtype): | |||
# for now we work only with these cases | |||
if ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could split the condition with code like below:
if (dim1 == 1 and dim2 == 1) or (dim1 == 2 and dim2 == 2):
pass
elif x1_desc.dtype != x2_desc.dtype:
pass
else:
result_obj = dpnp_dot(x1_desc, x2_desc).get_pyobj()
if (dim1 == 2 and dim2 == 2):
return result_obj
else:
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
return result
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could split the condition with code like below:
if (dim1 == 1 and dim2 == 1) or (dim1 == 2 and dim2 == 2): pass elif x1_desc.dtype != x2_desc.dtype: pass else: result_obj = dpnp_dot(x1_desc, x2_desc).get_pyobj() if (dim1 == 2 and dim2 == 2): return result_obj else: result = dpnp.convert_single_elem_array_to_scalar(result_obj) return result
You want to return behavior back and send 2D and 1D cases in fallback :))
if (dim1 == 1 and dim2 == 1) or (dim1 == 2 and dim2 == 2): pass
I think it is a typo, and I understand how to rewrite these new conditions in few lines, but I suggest leave it as is and with existing comments. This code should be changed one more time in nearest future because support of strides in matrix multiplication is requested.
@@ -246,7 +253,7 @@ def matmul(x1, x2, out=None, **kwargs): | |||
x1_desc = dpnp.get_dpnp_descriptor(x1) | |||
x2_desc = dpnp.get_dpnp_descriptor(x2) | |||
if x1_desc and x2_desc and not kwargs: | |||
if x1_desc.size != x2_desc.size: | |||
if x1_desc.ndim != 2 or x2_desc.ndim != 2: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
all 1D or 3D goes to fallback.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In c++ backend implementation GEMM from mkl is used, this function works with 2D matrices. Maybe I am wrong and code on cython layer can extend dimension for 1D, but I don't think that we are ready to work with common 3D case. Please suggest better condition.
In any way, previous condition sends to fallback anything with different sizes, for example matrices with shapes (2, 3) and (3,4) are going to fallback...
No description provided.