In [None]:
import numpy as np

In [None]:
%load_ext cython

In [None]:
# Naive Cython implementation

In [None]:
%%cython -a
cimport numpy as np
import numpy as np
import cython

@cython.boundscheck(False)
@cython.wraparound(False)
def matmul1(double[:,:] a, double[:,:] b):
    """Naive triple-loop matrix product ``a @ b`` for 2-D double arrays.

    Parameters
    ----------
    a : double[:, :]
        Left operand with shape ``(N, M)``.
    b : double[:, :]
        Right operand with shape ``(M, P)``.

    Returns
    -------
    numpy.ndarray
        Newly allocated ``(N, P)`` array of doubles holding the product.

    Raises
    ------
    ValueError
        If ``a.shape[1]`` differs from ``b.shape[0]``.
    """
    cdef Py_ssize_t N = a.shape[0]
    cdef Py_ssize_t M = a.shape[1]
    cdef Py_ssize_t P = b.shape[1]
    cdef Py_ssize_t i, j, k
    cdef double a_ij
    cdef double[:, ::1] c

    # Bounds checking is switched off for this function, so a shape mismatch
    # would silently read past the end of ``b``; reject it up front instead.
    if b.shape[0] != M:
        raise ValueError(
            "shapes (%d, %d) and (%d, %d) not aligned" % (N, M, b.shape[0], P)
        )

    # Keep a handle on the ndarray itself so it can be returned without a copy;
    # ``c`` is only a typed view used for fast element access.
    result = np.zeros((N, P), dtype=np.double)
    c = result

    for i in range(N):
        for j in range(M):
            # a[i, j] does not depend on k, so load it once per inner loop.
            a_ij = a[i, j]
            for k in range(P):
                # Tuple indexing (x[i, j]) is direct element access; chained
                # x[i][j] would build an intermediate slice on every access.
                c[i, k] += a_ij * b[j, k]
    return result

In [None]:
%%cython -a
cimport numpy as np
import numpy as np
import cython

@cython.boundscheck(False)
@cython.wraparound(False)
def matmul2(double[:,:] a, double[:,:] b):
    """Matrix product ``a @ b`` that fills four output columns per pass.

    ``b`` is transposed into a C-contiguous buffer so that the innermost loop
    walks both operands along their last axis, and every pass over a row of
    ``a`` accumulates four entries of the corresponding output row at once.
    Columns left over when ``P`` is not a multiple of four are handled by a
    scalar tail loop.

    Parameters
    ----------
    a : double[:, :]
        Left operand with shape ``(N, M)``.
    b : double[:, :]
        Right operand with shape ``(M, P)``.

    Returns
    -------
    numpy.ndarray
        Newly allocated ``(N, P)`` array of doubles holding the product.

    Raises
    ------
    ValueError
        If ``a.shape[1]`` differs from ``b.shape[0]``.
    """
    cdef Py_ssize_t N = a.shape[0]
    cdef Py_ssize_t M = a.shape[1]
    cdef Py_ssize_t P = b.shape[1]
    # Largest multiple of 4 that does not exceed P.
    cdef Py_ssize_t P4 = P - (P % 4)
    cdef Py_ssize_t i, j, k
    cdef double a_ij, val0, val1, val2, val3
    cdef double[:, ::1] bt
    cdef double[:, ::1] c

    # Bounds checking is switched off for this function, so a shape mismatch
    # would silently read past the end of ``b``; reject it up front instead.
    if b.shape[0] != M:
        raise ValueError(
            "shapes (%d, %d) and (%d, %d) not aligned" % (N, M, b.shape[0], P)
        )

    # bt[k, j] == b[j, k]; laid out so that j is the contiguous axis.
    bt = np.ascontiguousarray(np.asarray(b).T)

    # Keep a handle on the ndarray itself so it can be returned directly;
    # ``c`` is only a typed view used for fast element access.
    result = np.zeros((N, P), dtype=np.double)
    c = result

    for i in range(N):
        # Blocks of four output columns: each a[i, j] is loaded once and
        # reused for four accumulators.
        for k in range(0, P4, 4):
            val0 = 0.0
            val1 = 0.0
            val2 = 0.0
            val3 = 0.0
            # Products are added in increasing j, one at a time, which is the
            # same summation order as a hand-unrolled j loop but is valid for
            # any M.
            for j in range(M):
                a_ij = a[i, j]
                val0 += a_ij * bt[k, j]
                val1 += a_ij * bt[k + 1, j]
                val2 += a_ij * bt[k + 2, j]
                val3 += a_ij * bt[k + 3, j]
            c[i, k] = val0
            c[i, k + 1] = val1
            c[i, k + 2] = val2
            c[i, k + 3] = val3

        # Remaining 0-3 columns that do not fill a complete block.
        for k in range(P4, P):
            val0 = 0.0
            for j in range(M):
                val0 += a[i, j] * bt[k, j]
            c[i, k] = val0

    return result

In [None]:
# Sanity check: every implementation must agree with `np.dot`

In [None]:
# Fixed seed so the check uses the same inputs on every Restart & Run All.
np.random.seed(0)
N, M, P = 1024, 1024, 1024
a = np.random.randn(N, M)
b = np.random.randn(M, P)
# Kept as strings: the timing cell builds its timeit statements from these
# names and the plot uses them as legend labels.
funcs = ['np.dot',"matmul1", "matmul2"]
# c[0] is the reference product computed by NumPy.
c = [np.dot(a, b)]

# eval() only ever sees the hard-coded names listed in `funcs` above.
matches = [np.allclose(c[0], eval(func)(a, b)) for func in funcs]
# Fail loudly instead of relying on someone reading the displayed booleans.
assert all(matches), dict(zip(funcs, matches))
matches

In [None]:
import timeit
import matplotlib.pyplot as plt

In [None]:
# Time each implementation on square matrices of increasing size.
# `ts` ends up with one row per size and one column per entry in `funcs`,
# each value being the mean wall time of a single call.
size_rng = range(8, 128, 4)
ts = []
for i in size_rng:
    N = M = P = i
    a = np.random.randn(N, M)
    b = np.random.randn(M, P)
    # timeit resolves `a`, `b` and the function names through globals().
    row = []
    for func in funcs:
        total = timeit.timeit(stmt=func + '(a,b)', globals=globals(), number=100)
        row.append(total / 100)
    ts.append(row)

In [None]:
# One curve per implementation: column n of the timing table belongs to
# funcs[n], so walk the transposed table alongside the names.
timings = np.array(ts)
for column, label in zip(timings.T, funcs):
    plt.plot(size_rng, column, label=label)

plt.xlabel("Mat Size [N]")
plt.ylabel("Time (averaged)")
plt.legend()