# Tridiagonal matrix solver benchmarks

In [10]:
# %env XLA_PYTHON_CLIENT_PREALLOCATE=false
%env OMP_NUM_THREADS=1
%env CUDA_VISIBLE_DEVICES=0

env: OMP_NUM_THREADS=1
env: CUDA_VISIBLE_DEVICES=0


In [11]:
!nvidia-smi

Fri Oct  2 15:31:32 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.51.06    Driver Version: 450.51.06    CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 207...  Off  | 00000000:09:00.0  On |                  N/A |
|  0%   38C    P5    14W / 215W |    607MiB /  7974MiB |      7%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [12]:
import math
import numpy as np
import numba as nb
import numba.cuda
from tke import tke
from scipy.linalg import lapack
from jax_xla import tridiag 

In [13]:
import jax
jax.config.update('jax_enable_x64', True)

In [14]:
shape = (360, 160, 115)

## Implement TDMA

#### NumPy

In [15]:
def tdma_naive(a, b, c, d):
    """
    Solves many tridiagonal matrix systems with diagonals a, b, c and RHS vectors d.
    """
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape

    n = a.shape[-1]

    for i in range(1, n):
        w = a[..., i] / b[..., i - 1]
        b[..., i] += -w * c[..., i - 1]
        d[..., i] += -w * d[..., i - 1]

    out = np.empty_like(a)
    out[..., -1] = d[..., -1] / b[..., -1]

    for i in range(n - 2, -1, -1):
        out[..., i] = (d[..., i] - c[..., i] * out[..., i + 1]) / b[..., i]

    return out

#### Lapack

In [16]:
def tdma_lapack(a, b, c, d):
    a[..., 0] = c[..., -1] = 0  # remove couplings between slices
    return lapack.dgtsv(a.flatten()[1:], b.flatten(), c.flatten()[:-1], d.flatten())[3].reshape(a.shape)

#### Numba CPU

In [17]:
@nb.guvectorize([(nb.float64[:],) * 5], '(n), (n), (n), (n) -> (n)', nopython=True)
def tdma_numba(a, b, c, d, out):
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape

    n = a.shape[0]

    for i in range(1, n):
        w = a[i] / b[i - 1]
        b[i] += -w * c[i - 1]
        d[i] += -w * d[i - 1]

    out[-1] = d[-1] / b[-1]

    for i in range(n - 2, -1, -1):
        out[i] = (d[i] - c[i] * out[i + 1]) / b[i]

#### Numba CUDA

In [18]:
nconst = shape[-1]


@nb.cuda.jit()
def tdma_numba_cuda_kernel(a, b, c, d, out):
    i, j = nb.cuda.grid(2)
    
    if not(i < a.shape[0] and j < a.shape[1]):
        return

    n = a.shape[2]
    
    cp = nb.cuda.local.array((nconst,), dtype=nb.float64)
    dp = nb.cuda.local.array((nconst,), dtype=nb.float64)
    
    cp[0] = c[i, j, 0] / b[i, j, 0]
    dp[0] = d[i, j, 0] / b[i, j, 0]
    
    for k in range(1, n):
        norm_factor = b[i, j, k] - a[i, j, k] * cp[k-1]
        cp[k] = c[i, j, k] / norm_factor
        dp[k] = (d[i, j, k] - a[i, j, k] * dp[k-1]) / norm_factor

    out[i, j, n-1] = dp[n-1]

    for k in range(n - 2, -1, -1):
        out[i, j, k] = dp[k] - cp[k] * out[i, j, k+1]
        

def tdma_numba_cuda(a, b, c, d):
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape

    threadsperblock = (16, 16)
    blockspergrid_x = math.ceil(a.shape[0] / threadsperblock[0])
    blockspergrid_y = math.ceil(a.shape[1] / threadsperblock[1])
    blockspergrid = (blockspergrid_x, blockspergrid_y)

    out = nb.cuda.device_array(a.shape, dtype=a.dtype)
    tdma_numba_cuda_kernel[blockspergrid, threadsperblock](a, b, c, d, out)
    return out

#### JAX

In [19]:
import jax.numpy as jnp
import jax.lax


def tdma_jax_kernel(a, b, c, d):
    def compute_primes(last_primes, x):
        last_cp, last_dp = last_primes
        a, b, c, d = x

        denom = 1. / (b - a * last_cp)
        cp = c * denom
        dp = (d - a * last_dp) * denom

        new_primes = (cp, dp)
        return new_primes, new_primes

    diags = (a.T, b.T, c.T, d.T)
    init = jnp.zeros((a.shape[1], a.shape[0]))
    _, (cp, dp) = jax.lax.scan(compute_primes, (init, init), diags)

    def backsubstitution(last_x, x):
        cp, dp = x
        new_x = dp - cp * last_x
        return new_x, new_x

    _, sol = jax.lax.scan(backsubstitution, init, (cp[::-1], dp[::-1]))

    return sol[::-1].T


tdma_jax = jax.jit(tdma_jax_kernel, backend='cpu')
tdma_jax_cuda = jax.jit(tridiag.tridiag, backend='gpu') # jax.jit(tdma_jax_kernel, backend='gpu')

#### CuPy

In [20]:
import cupy

In [21]:
from string import Template

kernel_old = Template('''
extern "C" __global__
void execute(
    const ${DTYPE} *a,
    const ${DTYPE} *b,
    const ${DTYPE} *c,
    const ${DTYPE} *d,
    ${DTYPE} *solution
){
    const size_t m = ${SYS_DEPTH};
    const size_t total_size = ${SIZE};
    const size_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * m;

    if (idx >= total_size) {
        return;
    }

    ${DTYPE} cp[${SYS_DEPTH}];
    ${DTYPE} dp[${SYS_DEPTH}];

    cp[0] = c[idx] / b[idx];
    dp[0] = d[idx] / b[idx];

    for (ptrdiff_t j = 1; j < m; ++j) {
        const ${DTYPE} norm_factor = b[idx+j] - a[idx+j] * cp[j-1];
        cp[j] = c[idx+j] / norm_factor;
        dp[j] = (d[idx+j] - a[idx+j] * dp[j-1]) / norm_factor;
    }

    solution[idx + m-1] = dp[m-1];
    for (ptrdiff_t j=m-2; j >= 0; --j) {
        solution[idx + j] = dp[j] - cp[j] * solution[idx + j+1];
    }
}
''').substitute(
    DTYPE='double',
    SYS_DEPTH=shape[-1],
    SIZE=np.product(shape)
)

tdma_cupy_kernel_old = cupy.RawKernel(kernel_old, 'execute')


def tdma_cupy_old(a, b, c, d,  blocksize=256):
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape
    a, b, c, d = (cupy.asarray(k) for k in (a, b, c, d))
    out = cupy.empty(a.shape, dtype=a.dtype)
    
    tdma_cupy_kernel_old(
        (math.ceil(a.size / a.shape[-1] / blocksize),),
        (blocksize,),
        (a, b, c, d, out)
    )
    
    return out

In [22]:
tile_dim = 32
block_rows = 8 
transpose4_kernel = Template('''
extern "C" __global__
    void transpose4(
  const ${DTYPE}* a,
  const ${DTYPE}* b,
  const ${DTYPE}* c,
  const ${DTYPE}* d,
  ${DTYPE}* a_t,
  ${DTYPE}* b_t,
  ${DTYPE}* c_t,
  ${DTYPE}* d_t,
  int xdim,
  int ydim,
  int total_size
)
{
  __shared__ ${DTYPE} tile[4*${TILE_DIM}][${TILE_DIM}+1];

  int x = blockIdx.x * ${TILE_DIM} + threadIdx.x;
  int y = blockIdx.y * ${TILE_DIM} + threadIdx.y;
  
  if (x < xdim)
  {
    for (int j = 0; j < ${TILE_DIM}; j += ${BLOCK_ROWS})
    {
        int index = (y+j)*xdim + x;
        if (index < total_size)
        {  
          tile[threadIdx.y+j][threadIdx.x] = a[index];
          tile[${TILE_DIM} + threadIdx.y+j][threadIdx.x] = b[index];
          tile[2 * ${TILE_DIM} + threadIdx.y+j][threadIdx.x] = c[index];
          tile[3 * ${TILE_DIM} + threadIdx.y+j][threadIdx.x] = d[index];
        }
    }
  }
  __syncthreads();

  x = blockIdx.y * ${TILE_DIM} + threadIdx.x;  // transpose block offset
  y = blockIdx.x * ${TILE_DIM} + threadIdx.y;
  if (x < ydim)
  {
    for (int j = 0; j < ${TILE_DIM}; j += ${BLOCK_ROWS})
    {
      int index = (y+j)*ydim + x;
      if (index < total_size)
      {
        
        a_t[index] = tile[threadIdx.x][threadIdx.y + j];
        b_t[index] = tile[${TILE_DIM} + threadIdx.x][threadIdx.y + j];
        c_t[index] = tile[2 * ${TILE_DIM} + threadIdx.x][threadIdx.y + j];
        d_t[index] = tile[3 * ${TILE_DIM} + threadIdx.x][threadIdx.y + j];
      }
    }
  }
}
''').substitute(
        DTYPE='double',
        TILE_DIM=tile_dim,
        BLOCK_ROWS=block_rows
)


transpose_kernel = Template('''
extern "C" __global__
void transpose(
  const ${DTYPE}* m,
  ${DTYPE}* m_t,
  int xdim,
  int ydim,
  int total_size
)
{
  __shared__ ${DTYPE} tile[${TILE_DIM}][${TILE_DIM}+1];

  int x = blockIdx.x * ${TILE_DIM} + threadIdx.x;
  int y = blockIdx.y * ${TILE_DIM} + threadIdx.y;
  
  if (x < xdim)
  {
    for (int j = 0; j < ${TILE_DIM}; j += ${BLOCK_ROWS})
    {
        int index = (y+j)*xdim + x;
        if (index < total_size)
          tile[threadIdx.y+j][threadIdx.x] = m[index];
    }
  }
  __syncthreads();

  x = blockIdx.y * ${TILE_DIM} + threadIdx.x;  // transpose block offset
  y = blockIdx.x * ${TILE_DIM} + threadIdx.y;
  if (x < ydim)
  {
    for (int j = 0; j < ${TILE_DIM}; j += ${BLOCK_ROWS})
    {
      int index = (y+j)*ydim + x;
      if (index < total_size)
        m_t[index] = tile[threadIdx.x][threadIdx.y + j];
    }
  }
}
''').substitute(
        DTYPE='double',
        TILE_DIM=tile_dim,
        BLOCK_ROWS=block_rows
)


kernel = Template('''
    extern "C" __global__
    void execute(
        const ${DTYPE} * __restrict__ a,
        const ${DTYPE} * __restrict__ b,
        const ${DTYPE} * __restrict__ c,
        const ${DTYPE} * __restrict__ d,
        ${DTYPE} *solution
    ){
        const unsigned int n = ${SYS_DEPTH};
        const unsigned int num_chunks = ${STRIDE};
        const unsigned int idx = blockDim.x * blockIdx.x + threadIdx.x;
        if (idx >= num_chunks) {
            return;
        }

        ${DTYPE} cp[n];
        ${DTYPE} dp[n];

        cp[0] = c[idx] / b[idx];
        dp[0] = d[idx] / b[idx];

        #pragma unroll
        for (int j = 1; j < n; ++j) {
            unsigned int indj = idx+(j*num_chunks);
            const ${DTYPE} norm_factor = 1.0 /(b[indj] - a[indj] * cp[j-1]);
            cp[j] = c[indj] * norm_factor;
            dp[j] = (d[indj] - a[indj] * dp[j-1]) * norm_factor;
        }

        int stridedIndex = num_chunks*(n-1);
        solution[idx + stridedIndex] = dp[n-1];
        
        #pragma unroll
        for (int j=n-2; j >= 0; --j)
        {
            ${DTYPE} s = dp[j] - cp[j] * solution[idx + stridedIndex];
            stridedIndex -= num_chunks;
            solution[idx + stridedIndex] = s;
        }
    }
    ''').substitute(
        DTYPE='double',
        SYS_DEPTH=shape[-1],
        SIZE=np.product(shape),
        STRIDE=shape[0]*shape[1]
)
transpose4 = cupy.RawKernel(transpose4_kernel, 'transpose4')
transpose = cupy.RawKernel(transpose_kernel, 'transpose')
tdma_cupy_kernel = cupy.RawKernel(kernel, 'execute')


def tdma_cupy(a, b, c, d, o=None, blocksize=256):
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape
    
    if o is None:
      a, b, c, d = (cupy.asarray(k) for k in (a, b, c, d))
      o = cupy.empty(a.shape, dtype=a.dtype)

    a_tmp, b_tmp, c_tmp, d_tmp, o_out = (cupy.empty(s.shape, dtype=s.dtype) for s in (a,b,c,d,a))
    
    xdim = int((shape[2] + tile_dim -1) / tile_dim)
    ydim = int((shape[1]*shape[0] + tile_dim - 1) / tile_dim)
    transpose4(
        (xdim, ydim, 1),
        (tile_dim, block_rows, 1),
        (a, b, c, d, a_tmp, b_tmp, c_tmp, d_tmp, int(shape[2]), int(shape[1]*shape[0]), int(np.product(shape)))
    )
    
    tdma_cupy_kernel(
        (math.ceil(a.size / a.shape[-1] / blocksize),),
        (blocksize,),
        (a_tmp, b_tmp, c_tmp, d_tmp, o_out)
    )
    
    transpose(
        (ydim, xdim, 1),
        (tile_dim, block_rows, 1),
        (o_out, o, int(shape[1]*shape[0]), int(shape[2]), int(np.product(shape)))
    )
    return o


In [23]:
def tdma_futhark(a, b, c, d, futhark_tke):
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape

    # dim1, dim2, seg_size = a.shape
    # # seg_count = dim1*dim2
    # a_s = np.reshape(a, (dim1*dim2, seg_size))
    # b_s = np.reshape(b, (dim1*dim2, seg_size))
    # c_s = np.reshape(c, (dim1*dim2, seg_size))
    # d_s = np.reshape(d, (dim1*dim2, seg_size))
    
    
    
    res = futhark_tke.tridagNested(a_s, b_s, c_s, d_s)
    
    
    return res #np.reshape(res, a.shape)

## Check results

In [25]:
np.random.seed(17)
a, b, c, d = np.random.randn(4, *shape)
res_naive = tdma_naive(a, b, c, d)

for imp in (tdma_jax_cuda, tdma_cupy, tdma_cupy_old, tdma_lapack, tdma_numba, tdma_numba_cuda, tdma_jax):
    np.random.seed(17)
    a, b, c, d = np.random.randn(4, *shape)
    
    out = imp(a, b, c, d)
    try:
        out = out.get()
    except AttributeError:
        pass
    np.testing.assert_allclose(out, res_naive)
    print('✔️')

<class 'jaxlib.xla_extension.XlaOp'>


TypeError: Argument '<jaxlib.xla_extension.XlaOp object at 0x7f67c4784430>' of type '<class 'jaxlib.xla_extension.XlaOp'>' is not a valid JAX type

## Benchmark

In [67]:
np.random.seed(17)
a, b, c, d = np.random.randn(4, shape[0], shape[1], shape[2])


### CPU

#### NumPy

In [31]:
%%timeit
tdma_naive(a, b, c, d)

180 ms ± 7.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


#### Lapack

In [32]:
%%timeit
tdma_lapack(a, b, c, d)

153 ms ± 1.32 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


#### Numba

In [33]:
%%timeit
tdma_numba(a, b, c, d)

99.9 ms ± 77.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


#### JAX

In [34]:
%%timeit
tdma_jax(a, b, c, d).block_until_ready()

119 ms ± 3.32 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


### GPU

#### Numba

In [39]:
ac, bc, cc, dc = (nb.cuda.to_device(k) for k in (a, b, c, d))
tdma_numba_cuda(ac, bc, cc, dc);  # trigger compilation

In [40]:
%%timeit
tdma_numba_cuda(ac, bc, cc, dc)
numba.cuda.synchronize()

17 ms ± 519 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


#### JAX

In [68]:
aj, bj, cj, dj = (jnp.array(k).block_until_ready() for k in (a, b, c, d))

In [69]:
%%timeit
tdma_jax_cuda(aj, bj, cj, dj).block_until_ready()

17.3 ms ± 94.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


#### CuPy

In [101]:
# stream = cupy.cuda.stream.Stream()

# with stream:
ac, bc, cc, dc = (cupy.asarray(k) for k in (a, b, c, d))
o = cupy.empty(ac.shape, dtype=ac.dtype)
tdma_cupy(ac, bc, cc, dc,o);  # trigger compilation

# stream.synchronize()

In [102]:
import time
runs = 20

total_time = 0

for i in range(runs):
    ac, bc, cc, dc = (cupy.asarray(k) for k in (a, b, c, d))
    o = cupy.empty(ac.shape, dtype=ac.dtype)
    cupy.cuda.Stream.null.synchronize()
    start_time = time.time()    
    tdma_cupy(ac, bc, cc, dc, o)
    cupy.cuda.Stream.null.synchronize()
    total_time += (time.time() - start_time)
    del ac
    del bc
    del cc
    del dc

print((total_time) * 1000 / float(runs))

    


2.8993964195251465


### Futhark

In [49]:
dim1, dim2, seg_size = a.shape
# seg_count = dim1*dim2
a_s = np.reshape(a, (dim1*dim2, seg_size))
b_s = np.reshape(b, (dim1*dim2, seg_size))
c_s = np.reshape(c, (dim1*dim2, seg_size))
d_s = np.reshape(d, (dim1*dim2, seg_size))

futhark_tke = tke()

In [50]:
%%timeit
tdma_futhark(a, b, c, d, futhark_tke)

57.4 ms ± 743 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Try Jax without transposes

In [47]:
def tdma_jax_kernel_notrans(a, b, c, d):
    def compute_primes(last_primes, x):
        last_cp, last_dp = last_primes
        a, b, c, d = x

        denom = 1. / (b - a * last_cp)
        cp = c * denom
        dp = (d - a * last_dp) * denom

        new_primes = (cp, dp)
        return new_primes, new_primes

    diags = (a, b, c, d)
    init = jnp.zeros((a.shape[1], a.shape[2]))
    _, (cp, dp) = jax.lax.scan(compute_primes, (init, init), diags)

    def backsubstitution(last_x, x):
        cp, dp = x
        new_x = dp - cp * last_x
        return new_x, new_x

    _, sol = jax.lax.scan(backsubstitution, init, (cp[::-1], dp[::-1]))

    return sol[::-1]


tdma_jax_notrans = jax.jit(tdma_jax_kernel_notrans, backend='cpu')
tdma_jax_cuda_notrans = jax.jit(tdma_jax_kernel_notrans, backend='gpu')

In [48]:
at, bt, ct, dt = (k.T for k in (a, b, c, d))

In [49]:
%%timeit
tdma_jax_notrans(at, bt, ct, dt).block_until_ready()

175 ms ± 3.51 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [50]:
aj, bj, cj, dj = (jnp.array(k.T).block_until_ready() for k in (a, b, c, d))

In [51]:
%%timeit
tdma_jax_cuda_notrans(aj, bj, cj, dj).block_until_ready()

6.2 ms ± 453 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
