In [1]:
import numpy as np
import xarray as xr

from scipy.sparse import coo_matrix

import dask
import dask.array as da
from dask.distributed import Client, progress
from dask_kubernetes import KubeCluster

Read the regridding weights generated by [make_large_weights.ipynb](./make_large_weights.ipynb). The weights file is too large to be included in the git repo.

In [2]:
# Report the on-disk size of the weights file in human-readable units
!du -h bilinear_1001x2000_1200x2400.nc

176M	bilinear_1001x2000_1200x2400.nc


In [3]:
# Read the regridding weights from disk and assemble them into a sparse matrix.
# The context manager closes the file once the arrays have been pulled into
# memory (`.values` returns plain NumPy arrays, so nothing below needs the
# open file).
with xr.open_dataset("bilinear_1001x2000_1200x2400.nc") as ds:
    n_s = ds.dims['n_s']
    # shift the stored indices down by one
    # (presumably 1-based in the file, 0-based for scipy -- TODO confirm)
    col = ds['col'].values - 1
    row = ds['row'].values - 1
    S = ds['S'].values

# Matrix shape is (n_out, n_in). The sizes follow the grid dimensions that
# appear in the weights file name (1001x2000 -> 1200x2400).
n_in = 1001 * 2000    # 2002000 columns: length of each input sample
n_out = 1200 * 2400   # 2880000 rows: length of each regridded sample
A = coo_matrix((S, (row, col)), shape=(n_out, n_in))
A.shape

(2880000, 2002000)

In [4]:
def apply_A(data, A):
    """Multiply every row of the 2-D array ``data`` by the matrix ``A``.

    This must stay a pure function: everything it needs arrives through
    its arguments and nothing is captured from the notebook's scope.

    ``data`` is transposed so that its rows become columns, ``A.dot`` is
    applied, and the product is transposed back, so the result has one
    row per row of ``data`` and one column per row of ``A``.
    """
    columns_first = data.T
    product = A.dot(columns_first)
    return product.T

# Single machine

In [5]:
# will read from cloud object storage in real cases
x = np.ones([100, A.shape[1]])
x.nbytes / 1e9 # GB

1.6016

In [6]:
# Time one in-process application of the weights to the in-memory array
%time apply_A(x, A)

CPU times: user 3.08 s, sys: 1.97 s, total: 5.06 s
Wall time: 5.05 s


array([[0., 0., 0., ..., 1., 1., 0.],
       [0., 0., 0., ..., 1., 1., 0.],
       [0., 0., 0., ..., 1., 1., 0.],
       ...,
       [0., 0., 0., ..., 1., 1., 0.],
       [0., 0., 0., ..., 1., 1., 0.],
       [0., 0., 0., ..., 1., 1., 0.]])

# Distributed on Pangeo

In [7]:
# Start a Dask cluster on Kubernetes, asking for 20 workers;
# the bare name on the last line displays the cluster widget.
cluster = KubeCluster(n_workers=20)
cluster

VBox(children=(HTML(value='<h2>KubeCluster</h2>'), HBox(children=(HTML(value='\n<div>\n  <style scoped>\n    .…

In [8]:
# Connect a client to the cluster's scheduler.
# NOTE(review): the captured output reports "Workers: 0" at this point, so the
# workers may still be starting when this cell returns -- confirm they have
# joined before running the cells that depend on them.
client = Client(cluster)
client

0,1
Client  Scheduler: tcp://10.23.154.4:36104  Dashboard: /user/jiaweizhuang/proxy/8787/status,Cluster  Workers: 0  Cores: 0  Memory: 0 B


In [9]:
# Larger all-ones input as a dask array: 1600 rows in blocks of 40 rows,
# with each block spanning every column (-1 = a single chunk along that axis).
x_dask_dist = da.ones((1600, A.shape[1]), chunks=(40, -1))
x_dask_dist.nbytes / 1e9  # size in GB

25.6256

In [10]:
# Initialize this large array in distributed memory.
# Never call compute on it -- that will blow up the master's memory!
x_dask_dist = client.persist(x_dask_dist)
progress(x_dask_dist)

VBox()

In [11]:
# Manually scatter the weights out to the workers before calling the
# regridding function; the returned future stands in for A from here on.
A_future = client.scatter(A, broadcast=True)
progress(A_future)

VBox()

In [12]:
%%time
# very fast
# would have taken several seconds if the function contains data
out_dask_dist = da.map_blocks(apply_A, x_dask_dist, A_future,
                              dtype=np.float64, chunks=[40, A.shape[0]])

CPU times: user 2 ms, sys: 0 ns, total: 2 ms
Wall time: 1.42 ms


In [13]:
# Size of the output array in GB
out_dask_dist.nbytes/1e9

36.864

In [14]:
# In real cases the result will be written back to cloud object storage;
# here the results are just held in memory.
out_dask_dist = client.persist(out_dask_dist)
progress(out_dask_dist)

VBox()

In [15]:
%%time
# Sanity check: the mean of the regridded output should come out at roughly 1
out_dask_dist.mean().compute()

CPU times: user 448 ms, sys: 33 ms, total: 481 ms
Wall time: 3.53 s


0.9983340277777778