# Playground notebook: fast matrix multiplication on uint4 values stored as uint8

In order to evaluate int4 compression, we spent some time trying to implement a reasonably fast integer matrix multiply using `numpy`, `numba`, and Cython. Unfortunately, it was quite hard to compete with NumPy's built-in floating-point `matmul`, which relies on very well optimized BLAS routines. (NumPy only dispatches `matmul` to BLAS for floating-point dtypes, which is why the baseline below converts the uint8 inputs to float32 first.)

In [1]:
%load_ext cython

In [2]:
%%cython --compile-args=-fopenmp --link-args=-fopenmp
cimport cython
from cython.parallel import prange

@cython.boundscheck(False)
@cython.wraparound(False)
def cython_matmul(cython.uint[:, ::1] out, cython.uchar[:, ::1] a, cython.uchar[::1, :] b, cython.long n, cython.long m, cython.long d):
    """Accumulate ``a @ b`` into ``out`` in place: ``out[i, j] += sum_k a[i, k] * b[k, j]``.

    ``a`` must be C-contiguous with shape (n, d) and ``b`` Fortran-contiguous with shape
    (d, m), so both the row ``a[i, :]`` and the column ``b[:, j]`` walked by the inner loop
    are contiguous. Output rows are distributed over OpenMP threads via ``prange``.

    The operands are typed ``unsigned char`` so they match NumPy ``uint8`` buffers exactly.
    Plain ``char`` is signed on common platforms (e.g. x86), which would turn stored values
    >= 128 into negative numbers and corrupt the products. Each product is at most
    255 * 255, so the uint32 accumulator stays exact while d * 255**2 < 2**32.
    """
    cdef cython.uint tmp
    cdef cython.int i, j, k
    for i in prange(n, nogil=True):
        for j in range(m):
            tmp = 0
            for k in range(d):
                # Deliberately `tmp = tmp + ...` rather than `tmp += ...`: inside a prange body
                # an in-place operator makes Cython treat `tmp` as a reduction variable, whereas
                # here it is a per-thread accumulator that is reset for every (i, j).
                tmp = tmp + a[i, k] * b[k, j]
            out[i, j] = out[i, j] + tmp

In [3]:
from multiprocessing import cpu_count
from multiprocessing.pool import ThreadPool
from collections import deque

import numba
import numpy as np
from numpy.typing import NDArray

# Shared Numba decorator (nopython mode) applied to the compiled kernels below.
jit = numba.njit(fastmath=True, error_model="numpy")
# Tile edge length for the tiled kernels; Numba reads this global as a compile-time constant.
ks = 16

# Test data: uint4 values stored in uint8, matching the `< 16` contract the kernels assert.
n, m, d = 256, 4 * 4096, 256
rng = np.random.default_rng(0)
# `integers(0, 16)` covers the full uint4 range 0..15 (`choice(15)` only draws 0..14 and never
# exercises the largest value); sampling directly as uint8 also avoids an int64 temporary.
a = rng.integers(0, 16, size=(n, d), dtype=np.uint8)
# Transposing a C-contiguous (m, d) array yields a Fortran-contiguous (d, m) `b`, the layout
# the Cython kernel expects for its right operand.
b = rng.integers(0, 16, size=(m, d), dtype=np.uint8).T

def uint8_matmul_cython(a: NDArray[np.uint8], b: NDArray[np.uint8]) -> NDArray[np.uint32]:
    """Compute ``a @ b`` as uint32 using the OpenMP-parallel Cython kernel ``cython_matmul``.

    The kernel's memoryviews require ``a`` to be C-contiguous and ``b`` to be
    Fortran-contiguous (so each column of ``b`` is contiguous in memory). Inputs in another
    layout are copied into the required one. Inputs that already have it, such as ``x.T``
    of a C-contiguous ``x`` passed as ``b``, are used as-is without copying.

    Args:
        a: (n, d) uint8 left operand.
        b: (d, m) uint8 right operand.

    Returns:
        (n, m) uint32 array equal to ``a @ b``.
    """
    n, d = a.shape
    d2, m = b.shape
    assert d == d2
    # Both calls return their argument unchanged when the layout already matches, so the
    # benchmarked path stays copy-free.
    a = np.ascontiguousarray(a)
    b = np.asfortranarray(b)
    # Zero-initialised because the kernel accumulates into `out` (out[i, j] += ...).
    out = np.zeros((n, m), dtype=np.uint32)
    cython_matmul(out, a, b, n, m, d)
    return out

@jit
def uint8_matmul_numba(a: NDArray[np.uint8], b: NDArray[np.uint8]) -> NDArray[np.uint32]:
    """Naive triple-loop matmul of uint4-stored-in-uint8 values, compiled with Numba.

    Runs on a single thread: the shared ``jit`` decorator does not set ``parallel=True``
    and the loops use plain ``range``.

    Args:
        a: (n, d) left operand; every value must be < 16.
        b: (d, m) right operand; every value must be < 16.

    Returns:
        (n, m) uint32 array with ``out[i, j] = sum_k a[i, k] * b[k, j]``.
    """
    n, d = a.shape
    d2, m = b.shape
    assert d2 == d
    # Enforce the uint4 value contract (values in [0, 16)) before multiplying.
    assert (a < 16).all(), "Large value will trigger multiplication overlfow"
    assert (b < 16).all(), "Large value will trigger multiplication overlfow"
    # np.empty is safe here: the loops below write every element of `out`.
    out = np.empty((n, m), dtype=np.uint32)
    for i in range(n):
        row = a[i, :]
        for j in range(m):
            # Column j of b; contiguous in memory only when b is Fortran-ordered (as the test data is).
            col = b[:, j]
            tmp = np.uint32(0)
            for k in range(d):
                tmp += row[k] * col[k]
            out[i, j] = tmp
    return out

def uint8_matmul_einsum(a: NDArray[np.uint8], b: NDArray[np.uint8]) -> NDArray[np.uint32]:
    """Matrix product ``a @ b`` via ``np.einsum``, with the computation carried out in uint32."""
    accumulator_dtype = np.uint32
    product = np.einsum("ij, jk -> ik", a, b, dtype=accumulator_dtype)
    return product

@jit
def _mm_uint8_kernel(out: NDArray[np.uint32], a: NDArray[np.uint8], b: NDArray[np.uint8]) -> NDArray[np.uint32]:
    """Accumulate the product of one ks x ks tile pair into ``out`` in place.

    Computes ``out[i, j] += sum_k a[i, k] * b[k, j]`` for ``i, j, k`` in ``range(ks)``.
    Only the leading ks x ks region of each argument is touched. ``uint8_matmul_tiled``
    passes ks x ks slices (views) of the full matrices, so the writes land in its output.

    Returns:
        ``out`` (the same array object that was passed in).
    """
    # `ks` is the module-level tile size; Numba freezes global values at compile time.
    for i in range(ks):
        row = a[i, :]
        for j in range(ks):
            col = b[:, j]
            tmp = np.uint32(0)
            for k in range(ks):
                tmp += row[k] * col[k]
            out[i, j] += tmp
    return out

@jit
def uint8_matmul_tiled(a: NDArray[np.uint8], b: NDArray[np.uint8]) -> NDArray[np.uint32]:
    """Tiled (blocked) matmul of uint4-stored-in-uint8 values built on ``_mm_uint8_kernel``.

    The output is produced one ks x ks tile at a time. For each output tile, the matching
    row-tiles of ``a`` and column-tiles of ``b`` are multiplied chunk by chunk along the
    shared dimension and accumulated into that tile. Runs on a single thread.

    Args:
        a: (n, d) left operand; values must be < 16, and n and d multiples of ``ks``.
        b: (d, m) right operand; values must be < 16, and m a multiple of ``ks``.

    Returns:
        (n, m) uint32 array equal to ``a @ b``.
    """
    n, d = a.shape
    d2, m = b.shape
    # Enforce the uint4 value contract (values in [0, 16)).
    assert (a < 16).all(), "Large value will trigger multiplication overlfow"
    assert (b < 16).all(), "Large value will trigger multiplication overlfow"
    assert d == d2
    # No remainder handling: every dimension must split evenly into ks-sized tiles.
    assert n % ks == 0
    assert m % ks == 0
    assert d % ks == 0
    # Zero-initialised because the kernel accumulates (out[i, j] += ...).
    out = np.zeros((n, m), dtype=np.uint32)
    n_chunks = n // ks
    m_chunks = m // ks
    d_chunks = d // ks
    # Single flat loop over all output tiles; (i, j) is the tile's row/column index.
    for ijp in range(n_chunks * m_chunks):
        i, j = divmod(ijp, m_chunks)
        i_start = i * ks
        i_end = i_start + ks
        j_start = j * ks
        j_end = j_start + ks
        # Basic slicing gives a view, so kernel writes to `out_chunk` update `out`.
        out_chunk = out[i_start:i_end, j_start:j_end]
        for k in range(d_chunks):
            k_start = k * ks
            k_end = k_start + ks
            a_chunk = a[i_start:i_end, k_start:k_end]
            b_chunk = b[k_start:k_end, j_start:j_end]
            _mm_uint8_kernel(out_chunk, a_chunk, b_chunk)
    return out

def uint8_matmul_einsum_threaded(a: NDArray[np.uint8], b: NDArray[np.uint8], num_thread: int = 4) -> NDArray[np.uint32]:
    """Compute ``a @ b`` as uint32 by running ``uint8_matmul_einsum`` on row blocks in a thread pool.

    The larger of the two output dimensions is split into at most ``num_thread`` contiguous
    blocks. Each worker writes its block directly into a shared, preallocated output array.

    Args:
        a: (n, d) left operand.
        b: (d, m) right operand.
        num_thread: Maximum number of worker threads. Clamped to [1, number of rows split].

    Returns:
        (n, m) uint32 array equal to ``a @ b``.
    """
    n, _ = a.shape
    _, m = b.shape

    result = np.empty((n, m), dtype=np.uint32)
    out = result

    # Work is split along the rows of `a`. If `b` has more columns than `a` has rows, compute
    # out.T = b.T @ a.T instead, so the larger dimension is the one that gets split.
    # `result.T` is a view, so writes through it land in `result`.
    if m > n:
        a, b, out = b.T, a.T, result.T
        n, m = m, n

    # Ceil division yields at most `num_thread` blocks. The max(1, ...) guards keep the range()
    # step non-zero: floor division gave a step of 0 whenever n < num_thread, making range() raise.
    num_thread = max(1, min(num_thread, n))
    slice_size = max(1, -(-n // num_thread))
    row_slices = [slice(start, start + slice_size) for start in range(0, n, slice_size)]

    def _target(s: slice) -> None:
        out[s, :] = uint8_matmul_einsum(a[s, :], b)

    with ThreadPool(num_thread) as pool:
        # map() blocks until every block is done and re-raises any worker exception.
        pool.map(_target, row_slices)

    return result

## Trying to beat BLAS-based fp32 matmul

In [4]:
# BLAS-backed float32 reference. It is exact here: inputs are < 16 and d = 256, so every dot
# product (and every partial sum) is an integer <= 256 * 15 * 15 = 57,600 < 2**24, which
# float32 represents exactly.
ground_truth = np.matmul(a.astype(np.float32), b.astype(np.float32)).astype(np.uint32)

In [5]:
%%timeit
_ = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.uint32)

13.8 ms ± 2.65 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Attempts

In [6]:
matmul_impls = {
    "uint8_matmul_cython": uint8_matmul_cython,
    "uint8_matmul_numba": uint8_matmul_numba,
    "uint8_matmul_einsum": uint8_matmul_einsum,
    "uint8_matmul_tiled": uint8_matmul_tiled,
    "uint8_matmul_einsum_threaded": uint8_matmul_einsum_threaded,
}
# assert_array_equal checks shapes as well as values (np.all(x == y) would silently broadcast
# mismatched shapes) and reports the differing entries on failure.
for name, impl in matmul_impls.items():
    np.testing.assert_array_equal(impl(a, b), ground_truth, err_msg=name)

In [7]:
%%timeit
_ = uint8_matmul_einsum(a, b)

202 ms ± 989 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
%%timeit
_ = uint8_matmul_numba(a, b)

358 ms ± 2.51 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
%%timeit
_ = uint8_matmul_tiled(a, b)

381 ms ± 2.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
%%timeit
_ = uint8_matmul_einsum_threaded(a, b, num_thread=8)

33.9 ms ± 1.72 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [11]:
%%timeit
_ = uint8_matmul_cython(a, b)

12.5 ms ± 149 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## A possibly easier challenge: vector-matrix multiplication

In practice, search is often performed one query at a time, implying a (query embedding) vector vs. (document embedding) matrix multiplication. Let's see if we can accelerate this operation on uint8 datatypes to be competitive with the BLAS baseline.

In [12]:
@jit
def uint8_vector_matrix_multiplication(mat: NDArray[np.uint8], vec: NDArray[np.uint8]) -> NDArray[np.uint32]:
    """Matrix-vector product of uint4-stored-in-uint8 values (naive Numba loop).

    Despite the name, the result is ``out[i] = sum_j mat[i, j] * vec[j]``, i.e. ``mat @ vec``.
    Runs on a single thread: the shared ``jit`` decorator does not set ``parallel=True``
    and the loops use plain ``range``.

    Args:
        mat: 2-D (n, d) array; every value must be < 16.
        vec: 1-D (d,) array; every value must be < 16.

    Returns:
        1-D (n,) uint32 array.
    """
    assert mat.ndim == 2
    assert vec.ndim == 1
    n, d = mat.shape
    (d2,) = vec.shape
    assert d2 == d
    # Enforce the uint4 value contract (values in [0, 16)).
    assert (mat < 16).all(), "Large value will trigger multiplication overlfow"
    assert (vec < 16).all(), "Large value will trigger multiplication overlfow"
    # np.empty is safe here: every out[i] is written in the loop below.
    out = np.empty(n, dtype=np.uint32)
    for i in range(n):
        row = mat[i, :]
        tmp = np.uint32(0)
        for j in range(d):
            tmp += row[j] * vec[j]
        out[i] = tmp
    return out

@jit
def _mv_uint8_kernel(out: NDArray[np.uint32], mat: NDArray[np.uint8], vec: NDArray[np.uint8]) -> NDArray[np.uint32]:
    """Accumulate one ks x ks tile of a matrix-vector product into ``out`` in place.

    Computes ``out[i] += sum_k mat[i, k] * vec[k]`` for ``i, k`` in ``range(ks)``.
    ``mv_uint8_tiled`` passes a ks x ks slice of the matrix, the matching length-ks slice
    of the vector, and the matching length-ks slice (view) of its output.

    Returns:
        ``out`` (the same array object that was passed in).
    """
    # `ks` is the module-level tile size; Numba freezes global values at compile time.
    for i in range(ks):
        row = mat[i, :]
        for k in range(ks):
            # Accumulates straight into out[i]; there is no local accumulator here.
            out[i] += row[k] * vec[k]
    return out

@jit
def mv_uint8_tiled(mat: NDArray[np.uint8], vec: NDArray[np.uint8]) -> NDArray[np.uint32]:
    """Tiled matrix-vector product ``mat @ vec`` of uint4-stored-in-uint8 values.

    The outer loop walks ks-wide column chunks, so each ``vec`` chunk is reused for every
    row tile before moving on. The per-tile work is done by ``_mv_uint8_kernel``.
    Runs on a single thread.

    Args:
        mat: 2-D (n, d) array; values must be < 16, and n and d multiples of ``ks``.
        vec: 1-D (d,) array; values must be < 16.

    Returns:
        1-D (n,) uint32 array.
    """
    assert mat.ndim == 2
    assert vec.ndim == 1
    n, d = mat.shape
    (d2,) = vec.shape
    assert d2 == d
    # Enforce the uint4 value contract (values in [0, 16)).
    assert (mat < 16).all(), "Large value will trigger multiplication overlfow"
    assert (vec < 16).all(), "Large value will trigger multiplication overlfow"
    # No remainder handling: both dimensions must split evenly into ks-sized tiles.
    assert n % ks == 0
    assert d % ks == 0
    # Zero-initialised because the kernel accumulates (out[i] += ...).
    out = np.zeros(n, dtype=np.uint32)
    n_chunks = n // ks
    d_chunks = d // ks
    for j in range(d_chunks):
        j_start = j * ks
        j_end = j_start + ks
        vec_chunk = vec[j_start:j_end]
        for i in range(n_chunks):
            i_start = i * ks
            i_end = i_start + ks
            mat_chunk = mat[i_start:i_end, j_start:j_end]
            # Basic slicing gives a view, so kernel writes to `out_chunk` update `out`.
            out_chunk = out[i_start:i_end]
            _mv_uint8_kernel(out_chunk, mat_chunk, vec_chunk)
    return out


def mat_vec_einsum(mat: NDArray[np.uint8], vec: NDArray[np.uint8]) -> NDArray[np.uint32]:
    """Matrix-vector product ``mat @ vec`` via ``np.einsum``, computed in uint32."""
    accumulator_dtype = np.uint32
    product = np.einsum("ij, j -> i", mat, vec, dtype=accumulator_dtype)
    return product



def mat_vec_einsum_multithread(mat: NDArray[np.uint8], vec: NDArray[np.uint8]) -> NDArray[np.uint32]:
    """Compute ``mat @ vec`` as uint32 by running ``mat_vec_einsum`` on row blocks in a thread pool.

    The rows of ``mat`` are split into at most ``cpu_count()`` contiguous blocks (never more
    blocks than rows). Each worker writes its slice of a shared, preallocated output array.

    Args:
        mat: (n, d) matrix.
        vec: (d,) vector.

    Returns:
        (n,) uint32 array equal to ``mat @ vec``.
    """
    n, _ = mat.shape
    # Never start more threads than there are rows, and always at least one.
    num_thread = max(1, min(cpu_count(), n))

    # Split the rows of the matrix across worker threads to dispatch to multiple CPU cores.
    # Ceil division yields at most `num_thread` blocks. max(1, ...) keeps the range() step
    # non-zero: floor division gave a step of 0 whenever n < cpu_count(), making range() raise.
    slice_size = max(1, -(-n // num_thread))
    mat_row_slices = [slice(start, start + slice_size) for start in range(0, n, slice_size)]
    out = np.empty(n, dtype=np.uint32)

    def _target(s: slice) -> None:
        out[s] = mat_vec_einsum(mat[s, :], vec)

    with ThreadPool(num_thread) as pool:
        # map() blocks until every block is done and re-raises any worker exception.
        pool.map(_target, mat_row_slices)

    return out

In [13]:
m_matvec = 1024 * 1024
d_matvec = 256
# `integers(0, 16)` covers the full uint4 range 0..15 (`choice(15)` only draws 0..14). Sampling
# directly as uint8 also avoids a ~2 GB int64 temporary for the 1M x 256 matrix.
mat = rng.integers(0, 16, size=(m_matvec, d_matvec), dtype=np.uint8)
vec = rng.integers(0, 16, size=(d_matvec,), dtype=np.uint8)

In [14]:
# BLAS-backed float32 reference for the matrix-vector product. It is exact: with inputs < 16
# and d = 256, every dot product is an integer <= 256 * 15 * 15 = 57,600 < 2**24.
mv_gt = np.matmul(mat.astype(np.float32), vec.astype(np.float32)).astype(np.uint32)

In [15]:
%%timeit
_ = (mat.astype(np.float32) @ vec.astype(np.float32)).astype(np.uint32)

67.4 ms ± 1.92 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [16]:
matvec_impls = {
    "uint8_vector_matrix_multiplication": uint8_vector_matrix_multiplication,
    "mv_uint8_tiled": mv_uint8_tiled,
    "mat_vec_einsum": mat_vec_einsum,
    "mat_vec_einsum_multithread": mat_vec_einsum_multithread,
}
# assert_array_equal checks shapes as well as values; np.all(x == y) silently broadcasts.
for name, impl in matvec_impls.items():
    np.testing.assert_array_equal(impl(mat, vec), mv_gt, err_msg=name)
# A (1, d) @ (d, N) product via the Cython matmul has shape (1, N); take row 0 so it is compared
# against the (N,) reference with matching shapes rather than by broadcasting.
np.testing.assert_array_equal(uint8_matmul_cython(vec[None, :], mat.T)[0], mv_gt, err_msg="uint8_matmul_cython")

In [17]:
%%timeit
uint8_vector_matrix_multiplication(mat, vec)

424 ms ± 1.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
%%timeit
mv_uint8_tiled(mat, vec)

629 ms ± 4.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [19]:
%%timeit
mat_vec_einsum(mat, vec)

50.5 ms ± 315 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [20]:
%%timeit
mat_vec_einsum_multithread(mat, vec)

12.9 ms ± 284 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [21]:
%%timeit
uint8_matmul_cython(vec[None, :], mat.T).T

16.8 ms ± 135 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
