In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from scipy.sparse import coo_matrix
import dask.array as da

from numba import jit, prange, guvectorize

# Preparation

In [2]:
# sparse matrix, stored as three parallel 1-D arrays (col, row, S)
# The file is only needed while the arrays are read, so the context manager
# closes it as soon as everything is in memory.
with xr.open_dataset("weights.nc") as ds:
    # `sizes` is the mapping form of the dimension lengths; indexing `dims`
    # by name is deprecated in recent xarray releases.
    n_s = ds.sizes['n_s']
    # shift the stored indices down by one so they can be used directly as
    # 0-based NumPy indices (presumably the file stores them 1-based -- confirm)
    col = ds['col'].values - 1
    row = ds['row'].values - 1
    S = ds['S'].values

In [3]:
# input data: random values of shape (500, 240000).
# A fixed seed makes every run of the notebook benchmark and compare the
# same values; a private RandomState leaves NumPy's global RNG untouched.
data = np.random.RandomState(0).rand(500, 240000)

# Reference

In [4]:
@jit(nopython=True, nogil=True)
def sparse_dot(data_out, data, col, row, S):
    """Accumulate the sparse product into ``data_out``, serially.

    For every index ``j`` along the first axis and every stored weight ``k``,
    ``data[j, col[k]] * S[k]`` is added to ``data_out[j, row[k]]``.
    ``data_out`` is updated in place (it is not cleared first) and nothing
    is returned.
    """
    n_weights = S.size
    for j in range(data.shape[0]):
        # 1-D views of the current input / output slice
        src = data[j]
        dst = data_out[j]
        for k in range(n_weights):
            dst[row[k]] += src[col[k]] * S[k]

@jit(nopython=True, nogil=True, parallel=True)
def sparse_dot_pa(data_out, data, col, row, S):
    """Same accumulation as ``sparse_dot``, with the outer loop run via ``prange``.

    Iteration ``j`` only ever writes to ``data_out[j, ...]``, so the parallel
    iterations never touch the same output element.  ``data_out`` is updated
    in place (it is not cleared first) and nothing is returned.
    """
    n_slices = data.shape[0]
    n_weights = S.size
    for j in prange(n_slices):
        for k in range(n_weights):
            src = col[k]
            dst = row[k]
            data_out[j, dst] += data[j, src] * S[k]

# Thin plain-Python wrapper around the compiled kernels.  It is deliberately
# not decorated with @jit: it picks between two compiled functions and reads
# the global weight arrays, which only ever ran through Numba's object-mode
# fallback (newer Numba releases default to nopython mode and refuse to
# compile it).  All of the heavy work happens inside the kernels anyway.
def apply_A(data, parallel=False, n_out=120000):
    """Apply the sparse weights to ``data`` and return a freshly allocated result.

    Parameters
    ----------
    data : 2-D array
        Input values; the last axis is indexed by the global ``col`` array.
    parallel : bool, optional
        If true, use the ``prange`` kernel ``sparse_dot_pa``; otherwise use
        the serial kernel ``sparse_dot``.
    n_out : int, optional
        Length of the last axis of the result.  Defaults to 120000, the value
        that used to be hard-coded here.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(data.shape[0], n_out)`` (NumPy's default float dtype).
    """
    kernel = sparse_dot_pa if parallel else sparse_dot
    # the kernels accumulate with +=, so the output has to start from zeros
    data_out = np.zeros((data.shape[0], n_out))
    kernel(data_out, data, col, row, S)    # use global col, row, S here
    return data_out

out_numba = apply_A(data) # reference result

In [5]:
# reference performance: serial kernel first, then the prange kernel
%timeit apply_A(data)
%timeit apply_A(data, parallel=True) # obvious speed-up with parallelization

697 ms ± 22 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
315 ms ± 3.86 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# guvectorize

In [6]:
# shapes of input arguments that guvectorize would want to know
# (data, col, row, S), followed by the shape of the reference output
data.shape, col.shape, row.shape, S.shape, out_numba.shape 

((500, 240000), (480000,), (480000,), (480000,), (500, 120000))

In [7]:
# The output grid size (120000) cannot be derived from any of the input
# arguments, so a dummy array of that length is created purely to carry
# the shape information.
shape_arr = np.arange(0, 120000)

In [8]:
@guvectorize(["void(float64[:], int32[:], int32[:], float64[:], int64[:], float64[:])"],
             "(n),(k),(k),(k),(m)->(m)",
             target='parallel')
def sparse_dot_guvec(data, col, row, S, shape_arr, data_out):
    """Generalized-ufunc version of the sparse product for one 1-D slice.

    ``data`` is one input slice of length ``n``; ``col``, ``row`` and ``S``
    are the ``k`` stored weights; ``shape_arr`` is never read and exists only
    so that the output length ``m`` appears in the layout string.  The result
    is written into ``data_out`` (length ``m``).
    """
    # The output buffer is allocated by the gufunc machinery and is not
    # guaranteed to be zero-filled, so clear it before accumulating with +=.
    data_out[:] = 0.0
    # only one loop for the grid dimension
    # let numba vectorize over extra dimension
    for i in range(S.size):
        data_out[row[i]] += data[col[i]]*S[i]

out_guvec = sparse_dot_guvec(data, col, row, S, shape_arr)

In [9]:
# result is correct: the gufunc output equals the reference element for element
np.array_equal(out_numba, out_guvec)

True

This doesn't seem to be faster than manually parallelizing with `prange`, so we will keep using the simpler `prange` approach.

In [10]:
# time the gufunc version on the same inputs as the reference
%timeit sparse_dot_guvec(data, col, row, S, shape_arr)

352 ms ± 4.99 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Apply numba function on dask array

In [11]:
# wrap the NumPy array as a dask array, chunked only along the first axis
# (50 rows per chunk, each chunk spanning all 240000 columns)
data_dask = da.from_array(data, chunks=(50, 240000))
data_dask

dask.array<array, shape=(500, 240000), dtype=float64, chunksize=(50, 240000)>

In [12]:
# wrap the dask array in a named DataArray with labelled dimensions
dr_dask = xr.DataArray(
    data_dask,
    name='data',
    dims=['extra_dims', 'grid_dims'],
)
dr_dask

<xarray.DataArray 'data' (extra_dims: 500, grid_dims: 240000)>
dask.array<shape=(500, 240000), dtype=float64, chunksize=(50, 240000)>
Dimensions without coordinates: extra_dims, grid_dims

In [13]:
# Build the lazy result: apply_A is applied to the dask-backed DataArray,
# consuming 'grid_dims' and producing a new 'out_grid' dimension of length
# 120000.  apply_A is called with its default arguments, i.e. parallel=False,
# so the serial Numba kernel is used.
dr_out_pa = xr.apply_ufunc(
    apply_A,
    dr_dask,
    dask='parallelized',
    input_core_dims=[['grid_dims']],
    output_core_dims=[['out_grid']],
    output_sizes={'out_grid': 120000},
    output_dtypes=[float],
)

In [14]:
# result is correct: the dask/xarray result equals the reference element for element
np.array_equal(dr_out_pa, out_numba)

True

With a single worker, the dask version is slower than calling the Numba function directly on the NumPy array (1.12 s vs. 697 ms), and the parallel efficiency is not great either.

In [15]:
# time .compute() with 1, 2 and 4 dask workers
for n in [1, 2, 4]:
    %timeit dr_out_pa.compute(num_workers=n)

1.12 s ± 14.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
809 ms ± 17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
772 ms ± 20.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


`persist` is much faster than `compute` because the result is not concatenated. But it is still slower than calling the Numba function directly.

In [16]:
# time .persist() with 1, 2 and 4 dask workers
for n in [1, 2, 4]:
    %timeit dr_out_pa.persist(num_workers=n)

872 ms ± 18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
550 ms ± 18.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
511 ms ± 14.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
