# Tridiagonal matrix solver benchmarks

In [1]:
# %env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env OMP_NUM_THREADS=1
%env CUDA_VISIBLE_DEVICES=0

env: OMP_NUM_THREADS=1
env: CUDA_VISIBLE_DEVICES=0


In [2]:
!nvidia-smi

Sun Sep 13 21:41:30 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.51.06    Driver Version: 450.51.06    CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 207...  Off  | 00000000:09:00.0  On |                  N/A |
|  0%   46C    P5    21W / 215W |    402MiB /  7974MiB |      5%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
import math
import numpy as np
import numba as nb
import numba.cuda
from tke import tke
from scipy.linalg import lapack

Load libllvmlite.so .. os.environ[PATH] is: /home/linuxbrew/.linuxbrew/bin:/home/linuxbrew/.linuxbrew/sbin:/home/till/anaconda3/envs/pyhpc-bench-gpu/bin:/home/till/anaconda3/condabin:/usr/local/cuda-11.0/bin:/home/till/anaconda3/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin
Checking /home/till/anaconda3/envs/pyhpc-bench-gpu/lib/python3.7/site-packages/llvmlite/binding/libllvmlite.so
load_library_permanently(libsvml.so)
load_library_permanently .. so.environ[PATH] is: /home/linuxbrew/.linuxbrew/bin:/home/linuxbrew/.linuxbrew/sbin:/home/till/anaconda3/envs/pyhpc-bench-gpu/bin:/home/till/anaconda3/condabin:/usr/local/cuda-11.0/bin:/home/till/anaconda3/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin
load_library_permanently(libsvml.so)


In [4]:
import jax
jax.config.update('jax_enable_x64', True)

In [5]:
shape = (360, 160, 115)

## Implement TDMA

#### NumPy

In [6]:
def tdma_naive(a, b, c, d):
    """
    Solves many tridiagonal matrix systems with diagonals a, b, c and RHS vectors d.
    """
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape

    n = a.shape[-1]

    for i in range(1, n):
        w = a[..., i] / b[..., i - 1]
        b[..., i] += -w * c[..., i - 1]
        d[..., i] += -w * d[..., i - 1]

    out = np.empty_like(a)
    out[..., -1] = d[..., -1] / b[..., -1]

    for i in range(n - 2, -1, -1):
        out[..., i] = (d[..., i] - c[..., i] * out[..., i + 1]) / b[..., i]

    return out

#### Lapack

In [7]:
def tdma_lapack(a, b, c, d):
    a[..., 0] = c[..., -1] = 0  # remove couplings between slices
    return lapack.dgtsv(a.flatten()[1:], b.flatten(), c.flatten()[:-1], d.flatten())[3].reshape(a.shape)

#### Numba CPU

In [8]:
@nb.guvectorize([(nb.float64[:],) * 5], '(n), (n), (n), (n) -> (n)', nopython=True)
def tdma_numba(a, b, c, d, out):
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape

    n = a.shape[0]

    for i in range(1, n):
        w = a[i] / b[i - 1]
        b[i] += -w * c[i - 1]
        d[i] += -w * d[i - 1]

    out[-1] = d[-1] / b[-1]

    for i in range(n - 2, -1, -1):
        out[i] = (d[i] - c[i] * out[i + 1]) / b[i]

#### Numba CUDA

In [9]:
nconst = shape[-1]


@nb.cuda.jit()
def tdma_numba_cuda_kernel(a, b, c, d, out):
    i, j = nb.cuda.grid(2)
    
    if not(i < a.shape[0] and j < a.shape[1]):
        return

    n = a.shape[2]
    
    cp = nb.cuda.local.array((nconst,), dtype=nb.float64)
    dp = nb.cuda.local.array((nconst,), dtype=nb.float64)
    
    cp[0] = c[i, j, 0] / b[i, j, 0]
    dp[0] = d[i, j, 0] / b[i, j, 0]
    
    for k in range(1, n):
        norm_factor = b[i, j, k] - a[i, j, k] * cp[k-1]
        cp[k] = c[i, j, k] / norm_factor
        dp[k] = (d[i, j, k] - a[i, j, k] * dp[k-1]) / norm_factor

    out[i, j, n-1] = dp[n-1]

    for k in range(n - 2, -1, -1):
        out[i, j, k] = dp[k] - cp[k] * out[i, j, k+1]
        

def tdma_numba_cuda(a, b, c, d):
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape

    threadsperblock = (16, 16)
    blockspergrid_x = math.ceil(a.shape[0] / threadsperblock[0])
    blockspergrid_y = math.ceil(a.shape[1] / threadsperblock[1])
    blockspergrid = (blockspergrid_x, blockspergrid_y)

    out = nb.cuda.device_array(a.shape, dtype=a.dtype)
    tdma_numba_cuda_kernel[blockspergrid, threadsperblock](a, b, c, d, out)
    return out

#### JAX

In [10]:
import jax.numpy as jnp
import jax.lax


def tdma_jax_kernel(a, b, c, d):
    def compute_primes(last_primes, x):
        last_cp, last_dp = last_primes
        a, b, c, d = x

        denom = 1. / (b - a * last_cp)
        cp = c * denom
        dp = (d - a * last_dp) * denom

        new_primes = (cp, dp)
        return new_primes, new_primes

    diags = (a.T, b.T, c.T, d.T)
    init = jnp.zeros((a.shape[1], a.shape[0]))
    _, (cp, dp) = jax.lax.scan(compute_primes, (init, init), diags)

    def backsubstitution(last_x, x):
        cp, dp = x
        new_x = dp - cp * last_x
        return new_x, new_x

    _, sol = jax.lax.scan(backsubstitution, init, (cp[::-1], dp[::-1]))

    return sol[::-1].T


tdma_jax = jax.jit(tdma_jax_kernel, backend='cpu')
tdma_jax_cuda = jax.jit(tdma_jax_kernel, backend='gpu')

#### CuPy

In [11]:
import cupy

In [12]:
from string import Template

kernel_old = Template('''
extern "C" __global__
void execute(
    const ${DTYPE} *a,
    const ${DTYPE} *b,
    const ${DTYPE} *c,
    const ${DTYPE} *d,
    ${DTYPE} *solution
){
    const size_t m = ${SYS_DEPTH};
    const size_t total_size = ${SIZE};
    const size_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * m;

    if (idx >= total_size) {
        return;
    }

    ${DTYPE} cp[${SYS_DEPTH}];
    ${DTYPE} dp[${SYS_DEPTH}];

    cp[0] = c[idx] / b[idx];
    dp[0] = d[idx] / b[idx];

    for (ptrdiff_t j = 1; j < m; ++j) {
        const ${DTYPE} norm_factor = b[idx+j] - a[idx+j] * cp[j-1];
        cp[j] = c[idx+j] / norm_factor;
        dp[j] = (d[idx+j] - a[idx+j] * dp[j-1]) / norm_factor;
    }

    solution[idx + m-1] = dp[m-1];
    for (ptrdiff_t j=m-2; j >= 0; --j) {
        solution[idx + j] = dp[j] - cp[j] * solution[idx + j+1];
    }
}
''').substitute(
    DTYPE='double',
    SYS_DEPTH=shape[-1],
    SIZE=np.product(shape)
)

tdma_cupy_kernel_old = cupy.RawKernel(kernel_old, 'execute')


def tdma_cupy_old(a, b, c, d, out,  blocksize=256):
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape
    # a, b, c, d = (cupy.asarray(k) for k in (a, b, c, d))
    # out = cupy.empty(a.shape, dtype=a.dtype)
    
    tdma_cupy_kernel_old(
        (math.ceil(a.size / a.shape[-1] / blocksize),),
        (blocksize,),
        (a, b, c, d, out)
    )
    
    return out

In [13]:
order_kernel = Template('''
extern "C" __global__
    void order(
        const ${DTYPE} * __restrict__ a,
        const ${DTYPE} * __restrict__ b,
        const ${DTYPE} * __restrict__ c,
        const ${DTYPE} * __restrict__ d,
        ${DTYPE} *a_out,
        ${DTYPE} *b_out,
        ${DTYPE} *c_out,
        ${DTYPE} *d_out
    ){
        
        const unsigned int source_idx = blockIdx.x * blockDim.x + threadIdx.x;

        if (source_idx >= ${SIZE})
            return;

        const unsigned int m = ${SYS_DEPTH};
        const unsigned int stride = ${STRIDE};
        const unsigned int seg_idx = source_idx / m;
        const unsigned int seg_offset = source_idx % m;
        const unsigned int target_idx = seg_offset * stride + seg_idx;

        a_out[target_idx] = a[source_idx]; 
        b_out[target_idx] = b[source_idx];
        c_out[target_idx] = c[source_idx];
        d_out[target_idx] = d[source_idx];
    }

''').substitute(
        DTYPE='double',
        SYS_DEPTH=shape[-1],
        SIZE=np.product(shape),
        STRIDE=shape[0]*shape[1]
)

order_back_kernel = Template('''
extern "C" __global__
    void order_back(
        const ${DTYPE} * __restrict__ out,
        ${DTYPE} *o_out
    ){

        const unsigned int target_idx = blockIdx.x * blockDim.x + threadIdx.x;
        
        if (target_idx >= ${SIZE})
            return;
        const unsigned int m = ${SYS_DEPTH};
        const unsigned int stride = ${STRIDE};
        const unsigned int seg_idx = target_idx / m;
        const unsigned int seg_offset = target_idx % m;
        const unsigned int source_idx = seg_offset * stride + seg_idx;
        o_out[target_idx] = out[source_idx];
    }

''').substitute(
        DTYPE='double',
        SYS_DEPTH=shape[-1],
        SIZE=np.product(shape),
        STRIDE=shape[0]*shape[1]
)

kernel = Template('''
    extern "C" __global__
    void execute(
        const ${DTYPE} * __restrict__ a,
        const ${DTYPE} * __restrict__ b,
        const ${DTYPE} * __restrict__ c,
        const ${DTYPE} * __restrict__ d,
        ${DTYPE} *solution
    ){
        const unsigned int m = ${SYS_DEPTH};
        const unsigned int total_size = ${SIZE};
        const unsigned int stride = total_size / m;
        const unsigned int idx = blockDim.x * blockIdx.x + threadIdx.x;
        if (idx >= stride) {
            return;
        }

        ${DTYPE} cp[${SYS_DEPTH}];
        ${DTYPE} dp[${SYS_DEPTH}];

        cp[0] = c[idx] / b[idx];
        dp[0] = d[idx] / b[idx];
 
        for (int j = 1; j < m; ++j) {
            unsigned int indj = idx+(j*stride);
            const ${DTYPE} norm_factor = (b[indj] - a[indj] * cp[j-1]);
            cp[j] = c[indj] / norm_factor;
            dp[j] = (d[indj] - a[indj] * dp[j-1]) / norm_factor;
        }

        solution[idx + stride*(m-1)] = dp[m-1];

        for (int j=m-2; j >= 0; --j) {
            solution[idx + stride*j] = dp[j] - cp[j] * solution[idx + stride*(j+1)];
        }
    }
    ''').substitute(
        DTYPE='double',
        SYS_DEPTH=shape[-1],
        SIZE=np.product(shape),
        STRIDE=shape[0]*shape[1]
)
tdma_cupy_kernel = cupy.RawKernel(kernel, 'execute')
cupy_order_kernel = cupy.RawKernel(order_kernel, 'order')
cupy_order_back_kernel = cupy.RawKernel(order_back_kernel, 'order_back')


def tdma_cupy(a, b, c, d, o, blocksize=256):
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape

    # a, b, c, d = (cupy.asarray(k) for k in (a, b, c, d))
    
    a_tmp, b_tmp, c_tmp, d_tmp, o_out = (cupy.empty(s.shape, dtype=s.dtype) for s in (a,b,c,d,a))

    cupy_order_kernel(
        (math.ceil(a.size  / blocksize),),
        (blocksize,),
        (a, b, c, d, a_tmp, b_tmp, c_tmp, d_tmp)
    )

    tdma_cupy_kernel(
        (math.ceil(a.size / a.shape[-1] / blocksize),),
        (blocksize,),
        (a_tmp, b_tmp, c_tmp, d_tmp, o_out)
    )
    
    cupy_order_back_kernel(
        (math.ceil(a.size / blocksize),),
        (blocksize,),
        (o_out, o)
    )
    return out


In [14]:
def tdma_futhark(a, b, c, d, futhark_tke):
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape

    # dim1, dim2, seg_size = a.shape
    # # seg_count = dim1*dim2
    # a_s = np.reshape(a, (dim1*dim2, seg_size))
    # b_s = np.reshape(b, (dim1*dim2, seg_size))
    # c_s = np.reshape(c, (dim1*dim2, seg_size))
    # d_s = np.reshape(d, (dim1*dim2, seg_size))
    
    
    
    res = futhark_tke.tridagNested(a_s, b_s, c_s, d_s)
    
    
    return res #np.reshape(res, a.shape)

## Check results

In [15]:
np.random.seed(17)
a, b, c, d = np.random.randn(4, *shape)
res_naive = tdma_naive(a, b, c, d)

for imp in (tdma_cupy, tdma_cupy_old, tdma_lapack, tdma_numba, tdma_numba_cuda, tdma_jax, tdma_jax_cuda):
    np.random.seed(17)
    a, b, c, d = np.random.randn(4, *shape)
    out = imp(a, b, c, d)
    try:
        out = out.get()
    except AttributeError:
        pass

    np.testing.assert_allclose(out, res_naive)
    print('✔️')

✔️
✔️
✔️
✔️
✔️
✔️
✔️


## Benchmark

In [21]:
np.random.seed(17)
a, b, c, d = np.random.randn(4, shape[0], shape[1], shape[2])


### CPU

#### NumPy

In [68]:
%%timeit
tdma_naive(a, b, c, d)

89.1 ms ± 508 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


#### Lapack

In [28]:
%%timeit
tdma_lapack(a, b, c, d)

6.19 µs ± 37.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


#### Numba

In [53]:
%%timeit
tdma_numba(a, b, c, d)

1 s ± 10.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


#### JAX

In [54]:
%%timeit
tdma_jax(a, b, c, d).block_until_ready()

1.15 s ± 13.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### GPU

#### Numba

In [55]:
ac, bc, cc, dc = (nb.cuda.to_device(k) for k in (a, b, c, d))
tdma_numba_cuda(ac, bc, cc, dc);  # trigger compilation

ERROR:numba.cuda.cudadrv.driver:Call to cuMemAlloc results in CUDA_ERROR_OUT_OF_MEMORY
ERROR:numba.cuda.cudadrv.driver:Call to cuMemAlloc results in CUDA_ERROR_OUT_OF_MEMORY


CudaAPIError: [2] Call to cuMemAlloc results in CUDA_ERROR_OUT_OF_MEMORY

In [None]:
%%timeit
tdma_numba_cuda(ac, bc, cc, dc)
numba.cuda.synchronize()

#### JAX

In [None]:
aj, bj, cj, dj = (jnp.array(k).block_until_ready() for k in (a, b, c, d))

In [None]:
%%timeit
tdma_jax_cuda(aj, bj, cj, dj).block_until_ready()

#### CuPy

In [2]:
# stream = cupy.cuda.stream.Stream()

# with stream:
ac, bc, cc, dc = (cupy.asarray(k) for k in (a, b, c, d))
o = cupy.empty(ac.shape, dtype=ac.dtype)
tdma_cupy(ac, bc, cc, dc);  # trigger compilation

# stream.synchronize()

NameError: name 'a' is not defined

In [3]:
import time
runs = 20

total_time = 0

for i in range(runs):
    ac, bc, cc, dc = (cupy.asarray(k) for k in (a, b, c, d))
    o = cupy.empty(ac.shape, dtype=ac.dtype)    

    cupy.cuda.Stream.null.synchronize()
    start_time = time.time()    
    tdma_cupy(ac, bc, cc, dc, o)
    cupy.cuda.Stream.null.synchronize()
    total_time += (time.time() - start_time)
    del ac
    del bc
    del cc
    del dc

print((total_time) * 1000 / float(runs))

    


NameError: name 'a' is not defined

### Futhark

In [71]:
dim1, dim2, seg_size = a.shape
# seg_count = dim1*dim2
a_s = np.reshape(a, (dim1*dim2, seg_size))
b_s = np.reshape(b, (dim1*dim2, seg_size))
c_s = np.reshape(c, (dim1*dim2, seg_size))
d_s = np.reshape(d, (dim1*dim2, seg_size))

futhark_tke = tke()

In [72]:
%%timeit
tdma_futhark(a, b, c, d, futhark_tke)

857 µs ± 3.02 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## Try Jax without transposes

In [25]:
def tdma_jax_kernel_notrans(a, b, c, d):
    def compute_primes(last_primes, x):
        last_cp, last_dp = last_primes
        a, b, c, d = x

        denom = 1. / (b - a * last_cp)
        cp = c * denom
        dp = (d - a * last_dp) * denom

        new_primes = (cp, dp)
        return new_primes, new_primes

    diags = (a, b, c, d)
    init = jnp.zeros((a.shape[1], a.shape[2]))
    _, (cp, dp) = jax.lax.scan(compute_primes, (init, init), diags)

    def backsubstitution(last_x, x):
        cp, dp = x
        new_x = dp - cp * last_x
        return new_x, new_x

    _, sol = jax.lax.scan(backsubstitution, init, (cp[::-1], dp[::-1]))

    return sol[::-1]


tdma_jax_notrans = jax.jit(tdma_jax_kernel_notrans, backend='cpu')
tdma_jax_cuda_notrans = jax.jit(tdma_jax_kernel_notrans, backend='gpu')

In [26]:
at, bt, ct, dt = (k.T for k in (a, b, c, d))

In [27]:
%%timeit
tdma_jax_notrans(at, bt, ct, dt).block_until_ready()

537 ms ± 3.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [28]:
aj, bj, cj, dj = (jnp.array(k.T).block_until_ready() for k in (a, b, c, d))

In [29]:
%%timeit
tdma_jax_cuda_notrans(aj, bj, cj, dj).block_until_ready()

10.1 ms ± 942 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
