Non unit-stride ndarrays #16

Open
mratsim opened this issue Sep 15, 2019 · 1 comment
mratsim commented Sep 15, 2019

Looking at the code, I'm fairly sure it is buggy for non-unit-stride ndarrays, such as those produced by slicing, reverse slicing, or broadcasting:

cython-blis/blis/py.pyx

Lines 64 to 102 in c5df079

def gemm(const_reals2d_ft A, const_reals2d_ft B,
         np.ndarray out=None, bint trans1=False, bint trans2=False,
         double alpha=1., double beta=1.):
    cdef cy.dim_t nM = A.shape[0] if not trans1 else A.shape[1]
    cdef cy.dim_t nK = A.shape[1] if not trans1 else A.shape[0]
    cdef cy.dim_t nN = B.shape[1] if not trans2 else B.shape[0]
    if const_reals2d_ft is const_float2d_t:
        if out is None:
            out = numpy.zeros((nM, nN), dtype='f')
        C = <float*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                nM, nN, nK,
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0,0], B.shape[1], 1,
                beta,
                C, out.shape[1], 1)
        return out
    elif const_reals2d_ft is const_double2d_t:
        if out is None:
            out = numpy.zeros((A.shape[0], B.shape[1]), dtype='d')
        C = <double*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                A.shape[0], B.shape[1], A.shape[1],
                alpha,
                &A[0,0], A.shape[1], 1,
                &B[0,0], B.shape[1], 1,
                beta,
                C, out.shape[1], 1)
        return out
    else:
        C = NULL
        raise TypeError("Unhandled fused type")

There is no check that the inputs are row-major (C-contiguous), yet passing &A[0,0], A.shape[1], 1 as the pointer, row stride, and column stride assumes exactly that layout.
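To illustrate why the hard-coded strides matter, here is a minimal pure-Python sketch of (pointer, row stride, column stride) addressing over a flat buffer. The helper `read_strided` is hypothetical and only models how a strided routine would locate elements; it is not part of cython-blis or BLIS.

```python
def read_strided(buf, offset, shape, rs, cs):
    """Read a 2D view from a flat buffer using strides counted in elements."""
    rows, cols = shape
    return [[buf[offset + i * rs + j * cs] for j in range(cols)]
            for i in range(rows)]

# A 3x4 row-major matrix stored flat: element (i, j) lives at index i*4 + j.
base = list(range(12))

# The view base[:, ::2] has shape (3, 2), row stride 4 and column stride 2.
view_shape = (3, 2)
correct = read_strided(base, 0, view_shape, rs=4, cs=2)

# Assuming a contiguous row-major layout means rs = shape[1] = 2 and cs = 1,
# which reads the wrong elements for this view.
assumed = read_strided(base, 0, view_shape, rs=view_shape[1], cs=1)

print(correct)
print(assumed)
```

With the real strides the view picks every other column; with the assumed strides it silently reads the first six elements of the buffer instead.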

Instead the code should probably be:

def gemm(const_reals2d_ft A, const_reals2d_ft B,
         np.ndarray out=None, bint trans1=False, bint trans2=False,
         double alpha=1., double beta=1.):
    cdef cy.dim_t nM = A.shape[0] if not trans1 else A.shape[1]
    cdef cy.dim_t nK = A.shape[1] if not trans1 else A.shape[0]
    cdef cy.dim_t nN = B.shape[1] if not trans2 else B.shape[0]
    if const_reals2d_ft is const_float2d_t:
        if out is None:
            out = numpy.zeros((nM, nN), dtype='f')
        C = <float*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                nM, nN, nK,
                alpha,
                &A[0,0], A.strides[0], A.strides[1],
                &B[0,0], B.strides[0], B.strides[1],
                beta,
                C, out.strides[0], out.strides[1])
        return out
    elif const_reals2d_ft is const_double2d_t:
        if out is None:
            out = numpy.zeros((A.shape[0], B.shape[1]), dtype='d')
        C = <double*>out.data
        with nogil:
            cy.gemm(
                cy.TRANSPOSE if trans1 else cy.NO_TRANSPOSE,
                cy.TRANSPOSE if trans2 else cy.NO_TRANSPOSE,
                A.shape[0], B.shape[1], A.shape[1],
                alpha,
                &A[0,0], A.strides[0], A.strides[1],
                &B[0,0], B.strides[0], B.strides[1],
                beta,
                C, out.strides[0], out.strides[1])
        return out
    else:
        C = NULL
        raise TypeError("Unhandled fused type")

The same applies to gemv.
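One caveat worth checking before adopting the proposal as written: NumPy reports `strides` in bytes, whereas BLIS-style row and column strides are, as far as I know, counted in elements. If that holds for `cy.gemm`, the values would need dividing by the item size first. A small sketch of that conversion in plain NumPy, where `element_strides` is a hypothetical helper and not an existing cython-blis function:

```python
import numpy as np

def element_strides(arr):
    """Return the strides of a 2D array in elements rather than bytes."""
    rs, cs = arr.strides
    return rs // arr.itemsize, cs // arr.itemsize

a = np.arange(12, dtype=np.float32).reshape(3, 4)

print(a.strides)                     # byte strides of the contiguous array
print(element_strides(a))            # contiguous row-major
print(element_strides(a[:, ::2]))    # sliced: every other column
print(element_strides(a[::-1]))      # reverse-sliced rows: negative row stride
print(element_strides(a.T))          # transposed view: column-major strides
print(element_strides(np.broadcast_to(a[0], (3, 4))))  # broadcast: zero row stride
```

Whether the underlying kernel accepts negative or zero strides is a separate question that this sketch does not answer.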

This has two advantages:

  • it works for any strides
  • it is faster than the default OpenBLAS/MKL path, since no conversion to a contiguous array is needed.

The main draw of the BLIS API is that it supports strided arrays without giving up performance, so this is the perfect use case.
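For context on the conversion cost mentioned above, a short NumPy sketch showing that making a strided view contiguous allocates a fresh copy, while the view itself shares memory with the original array. The array names are illustrative only.

```python
import numpy as np

a = np.arange(12, dtype=np.float64).reshape(3, 4)

# A strided view: no data is copied, but the layout is not C-contiguous.
view = a[:, ::2]
view_is_contiguous = bool(view.flags['C_CONTIGUOUS'])
view_shares = np.shares_memory(a, view)

# Converting to a contiguous array copies the data into a new buffer.
packed = np.ascontiguousarray(view)
packed_is_contiguous = bool(packed.flags['C_CONTIGUOUS'])
packed_shares = np.shares_memory(a, packed)

print(view_is_contiguous, view_shares)
print(packed_is_contiguous, packed_shares)
```

A routine that takes explicit strides can work on `view` directly and skip the copy into `packed`.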

@honnibal
Member

Thanks! I think you're right, will fix.
