# This is a playground notebook

In order to evaluate int4 compression, we spent some time trying to get a decently fast matrix multiply operation implemented using `numpy` and `numba`. Unfortunately it was quite hard to compete with the performance of built-in `matmul` in `numpy` due to that function relying on very well optimized BLAS routines.

In [1]:
import numba
import numpy as np
from numpy.typing import NDArray

# Shared decorator for the kernels below: nopython JIT with fast-math enabled and
# the NumPy error model. `parallel=True` is not requested here.
jit = numba.njit(error_model="numpy", fastmath=True)
ks = 16  # Kernel size: edge length of the square tiles used by the tiled matmul.

# Test data: `a` has shape (n, d); `b` is drawn as (m, d) and then transposed, so
# it is a (d, m) view whose columns `b[:, j]` are contiguous in memory.
# NOTE(review): `rng.choice(15, ...)` samples from 0..14, so the largest 4-bit
# value (15) never appears in the benchmark inputs -- confirm this is intended.
n, m, d = 256, 4 * 4096, 256
rng = np.random.default_rng(0)  # Fixed seed keeps the inputs reproducible.
a = rng.choice(15, (n, d)).astype(np.uint8)
b = rng.choice(15, (m, d)).astype(np.uint8).T

@jit
def uint8_matmul(a: NDArray[np.uint8], b: NDArray[np.uint8]) -> NDArray[np.uint32]:
    """Naive triple-loop matmul for uint4 values stored in uint8 arrays.

    This is the straightforward baseline kernel. It is single-threaded (the
    shared ``jit`` decorator does not enable ``parallel=True`` and no ``prange``
    is used) and performs no tiling.

    Args:
        a: Left operand of shape ``(n, d)``; every entry must be below 16.
        b: Right operand of shape ``(d, m)``; every entry must be below 16.

    Returns:
        A freshly allocated ``(n, m)`` array of dtype ``uint32`` with
        ``out[i, j] = sum_k a[i, k] * b[k, j]``.

    Raises:
        AssertionError: If the inner dimensions differ, or if either operand
            contains a value of 16 or more.
    """
    n, d = a.shape
    d2, m = b.shape
    assert d2 == d
    # Two values below 16 multiply to at most 225, so every element product
    # still fits in 8 bits.
    assert (a < 16).all(), "Large value will trigger multiplication overflow"
    assert (b < 16).all(), "Large value will trigger multiplication overflow"
    # np.empty is sufficient: each out[i, j] is written exactly once below.
    out = np.empty((n, m), dtype=np.uint32)
    for i in range(n):
        row = a[i, :]
        for j in range(m):
            col = b[:, j]
            tmp = np.uint32(0)
            for k in range(d):
                tmp += row[k] * col[k]
            out[i, j] = tmp
    return out

@jit
def _mm_uint8_kernel(out: NDArray[np.uint32], a: NDArray[np.uint8], b: NDArray[np.uint8]) -> NDArray[np.uint32]:
    """Accumulate one ``ks`` x ``ks`` tile product into ``out`` in place.

    Performs ``out[i, j] += sum_k a[i, k] * b[k, j]`` for ``i``, ``j`` and ``k``
    in ``range(ks)``, where ``ks`` is the module-level kernel size. Only the
    leading ``ks`` rows/columns of each argument are accessed.

    Args:
        out: Output tile that is updated in place (added to, not overwritten).
        a: Left-hand tile, indexed as ``a[i, k]``.
        b: Right-hand tile, indexed as ``b[k, j]``.

    Returns:
        The same ``out`` array that was passed in.
    """
    for i in range(ks):
        row = a[i, :]
        for j in range(ks):
            col = b[:, j]
            tmp = np.uint32(0)
            for k in range(ks):
                tmp += row[k] * col[k]
            # Add rather than assign so callers can sum contributions from
            # several chunks of the inner dimension into the same tile.
            out[i, j] += tmp
    return out

@jit
def mm_uint8_tiled(a: NDArray[np.uint8], b: NDArray[np.uint8]) -> NDArray[np.uint32]:
    """Tiled matmul for uint4 values stored in uint8 arrays.

    The output is split into ``ks`` x ``ks`` tiles (``ks`` is the module-level
    kernel size). For each output tile, the matching tiles of ``a`` and ``b``
    along the inner dimension are multiplied and accumulated by
    ``_mm_uint8_kernel``.

    Args:
        a: Left operand of shape ``(n, d)``; every entry must be below 16.
        b: Right operand of shape ``(d, m)``; every entry must be below 16.

    Returns:
        A freshly allocated ``(n, m)`` array of dtype ``uint32`` holding the
        matrix product.

    Raises:
        AssertionError: If an operand contains a value of 16 or more, if the
            inner dimensions differ, or if ``n``, ``m`` or ``d`` is not a
            multiple of ``ks``.
    """
    n, d = a.shape
    d2, m = b.shape
    assert (a < 16).all(), "Large value will trigger multiplication overflow"
    assert (b < 16).all(), "Large value will trigger multiplication overflow"
    assert d == d2
    # The kernel always processes full ks x ks tiles, so partial tiles are
    # rejected up front.
    assert n % ks == 0
    assert m % ks == 0
    assert d % ks == 0
    # Must start at zero because the kernel accumulates with `+=`.
    out = np.zeros((n, m), dtype=np.uint32)
    n_chunks = n // ks
    m_chunks = m // ks
    d_chunks = d // ks
    # A single flat loop over all output tiles; (i, j) is recovered by divmod.
    # NOTE(review): presumably flattened so the loop could later be swapped
    # for a parallel range -- confirm.
    for ijp in range(n_chunks * m_chunks):
        i, j = divmod(ijp, m_chunks)
        i_start = i * ks
        i_end = i_start + ks
        j_start = j * ks
        j_end = j_start + ks
        # Slicing yields a view, so the kernel's in-place updates land in `out`.
        out_chunk = out[i_start:i_end, j_start:j_end]
        for k in range(d_chunks):
            k_start = k * ks
            k_end = k_start + ks
            a_chunk = a[i_start:i_end, k_start:k_end]
            b_chunk = b[k_start:k_end, j_start:j_end]
            _mm_uint8_kernel(out_chunk, a_chunk, b_chunk)
    return out


def mm_einsum(a: NDArray[np.uint8], b: NDArray[np.uint8]) -> NDArray[np.uint32]:
    """Matrix product of ``a`` and ``b`` via ``np.einsum``, computed in uint32.

    Args:
        a: Left operand, indexed ``(i, j)``.
        b: Right operand, indexed ``(j, k)``.

    Returns:
        The ``(i, k)`` product as a ``uint32`` array.
    """
    # Contract the shared index j; dtype forces the computation into uint32.
    subscripts = "ij, jk -> ik"
    product = np.einsum(subscripts, a, b, dtype=np.uint32)
    return product

## Trying to beat BLAS-based fp32 matmul

In [2]:
# Reference result from the BLAS-backed float32 matmul, cast back to uint32.
# With entries below 16 and d = 256, every partial sum is an integer far below
# 2**24, so float32 represents the result exactly.
ground_truth = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.uint32)

In [3]:
%%timeit
# Baseline timing: includes both uint8 -> float32 casts and the cast back to uint32.
_ = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.uint32)

15.5 ms ± 3.15 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


## My attempts
(note: single-threaded, so I'd be happy even being 1/10 as fast)

In [4]:
# First call: also triggers JIT compilation of uint8_matmul for these argument types.
o1 = uint8_matmul(a, b)
# Exact equality is required. Unlike `assert np.all(...)`, this rejects a shape
# mismatch and reports the mismatching elements on failure.
np.testing.assert_array_equal(o1, ground_truth)

In [5]:
%%timeit
# When cells run in order, compilation already happened in the correctness check,
# so this times the compiled function only.
_ = uint8_matmul(a, b)

366 ms ± 16.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
# First call: also triggers JIT compilation of mm_uint8_tiled and its tile kernel.
o2 = mm_uint8_tiled(a, b)
# Exact equality is required. Unlike `assert np.all(...)`, this rejects a shape
# mismatch and reports the mismatching elements on failure.
np.testing.assert_array_equal(o2, ground_truth)

In [7]:
%%timeit
# When cells run in order, compilation already happened in the correctness check,
# so this times the compiled function only.
_ = mm_uint8_tiled(a, b)

383 ms ± 3.66 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
o3 = mm_einsum(a, b)
# Exact equality is required. Unlike `assert np.all(...)`, this rejects a shape
# mismatch and reports the mismatching elements on failure.
np.testing.assert_array_equal(o3, ground_truth)

In [9]:
%%timeit
# Pure NumPy path (np.einsum with dtype=uint32); mm_einsum is not JIT-compiled.
_ = mm_einsum(a, b)

199 ms ± 2.33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
