In [1]:
from dask.distributed import Client

In [2]:
client = Client() # n_workers=2, threads_per_worker=2, memory_limit='500MB'

2022-10-21 15:34:27,650 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-ihwf7yzu', purging


In [None]:
import dask.array as da
import xarray as xr
import numpy as np

In [None]:
# Mimic all steps in preprocessing with dask arrays

# Generate fake inputs
da.random.RandomState(42)
sp = da.random.random((24, 641, 1440), chunks=(1, -1, -1))
q = da.random.random((24, 22, 641, 1440), chunks=(1, -1, -1, -1))
u = da.random.random((24, 22, 641, 1440), chunks=(1, -1, -1, -1))
v = da.random.random((24, 22, 641, 1440), chunks=(1, -1, -1, -1))

# Get dp
dp = sp[:, None, ...]

# Calculate cwv
cwv = dp * q

# Calculate fluxes
fx = u * cwv
fy = v * cwv

# Aggregate to two layers
s_lower = cwv[:, :10, ...].sum(axis=1)
s_upper = cwv[:, 10:, ...].sum(axis=1)

fx_lower = fx[:, :10, ...].sum(axis=1)
fx_upper = fx[:, 10:, ...].sum(axis=1)

fy_lower = fy[:, :10, ...].sum(axis=1)
fy_upper = fy[:, 10:, ...].sum(axis=1)

In [None]:
%%time
# Final output to netcdf with xarray
dims3 = ['time', 'latitude', 'longitude']
dims4 = ['time', 'level', 'latitude', 'longitude']
xr.Dataset({
    's_upper': xr.DataArray(s_upper, dims=dims3),
    's_lower': xr.DataArray(s_lower, dims=dims3),
    'fx_upper': xr.DataArray(s_upper, dims=dims3),
    'fx_lower': xr.DataArray(s_lower, dims=dims3),
    'fy_upper': xr.DataArray(s_upper, dims=dims3),
    'fy_lower': xr.DataArray(s_lower, dims=dims3),
    
}).to_netcdf('/data/volume_2/test_peter/test3.nc')

CPU times: user 650 ms, sys: 98.4 ms, total: 749 ms
Wall time: 11.1 s


In [None]:
# %%time
# # Now with xarray (current implementation)

# # Generate fake inputs
# xrsp = xr.DataArray(sp, dims=dims3, name='sp')
# xrq = xr.DataArray(q, dims=dims4, name='q')
# xru = xr.DataArray(q, dims=dims4, name='u')
# xrv = xr.DataArray(q, dims=dims4, name='v')

# # Get dp
# xrdp = xrsp.expand_dims(level=range(22), axis=1)

# # Calculate cwv
# xrcwv = xrdp * xrq

# # Calculate fluxes 
# xrfx = xrcwv * xru
# xrfy = xrcwv * xrv

# # Aggregate to two layers
# lower = xrdp.level > 10
# xrs_upper = xrcwv.where(~lower).sum('level')
# xrs_lower = xrcwv.where(lower).sum('level')

# xrfx_upper = xrfx.where(~lower).sum('level')
# xrfx_lower = xrfx.where(lower).sum('level')

# xrfy_upper = xrfy.where(~lower).sum('level')
# xrfy_lower = xrfy.where(lower).sum('level')

# # Realize data to disk
# xr.Dataset({
#     's_upper': xrs_upper,
#     's_lower': xrs_lower,
#     'fx_upper': xrfx_upper,
#     'fx_lower': xrfx_lower,
#     'fy_upper': xrfy_upper,
#     'fy_lower': xrfy_lower,
# }).to_netcdf('/data/volume_2/test_peter/test.nc')

In [None]:
%%time
# Now with xarray (optimized)

# Generate fake inputs
xrsp = xr.DataArray(sp, dims=dims3, name='sp')
xrq = xr.DataArray(q, dims=dims4, name='q')
xru = xr.DataArray(q, dims=dims4, name='u')
xrv = xr.DataArray(q, dims=dims4, name='v')

# Get dp
xrdp = xrsp.expand_dims(level=range(22), axis=1)

# Calculate cwv
xrcwv = xrdp * xrq

# Calculate fluxes 
xrfx = xrcwv * xru
xrfy = xrcwv * xrv

# Aggregate to two layers
idx = xrdp.level.searchsorted(10, side='right')
upper = np.s_[:, :idx, :, :]
lower = np.s_[:, idx:, :, :]

xrs_upper = xrcwv[upper].sum('level')
xrs_lower = xrcwv[lower].sum('level')

xrfx_upper = xrfx[upper].sum('level')
xrfx_lower = xrfx[lower].sum('level')

xrfy_upper = xrfy[upper].sum('level')
xrfy_lower = xrfy[lower].sum('level')

# Realize data to disk
xr.Dataset({
    's_upper': xrs_upper,
    's_lower': xrs_lower,
    'fx_upper': xrfx_upper,
    'fx_lower': xrfx_lower,
    'fy_upper': xrfy_upper,
    'fy_lower': xrfy_lower,
}).to_netcdf('/data/volume_2/test_peter/test2.nc')

CPU times: user 922 ms, sys: 142 ms, total: 1.06 s
Wall time: 19.2 s


In [None]:
before = xr.open_dataset('/data/volume_2/test_peter/test.nc')
after =  xr.open_dataset('/data/volume_2/test_peter/test2.nc')
before == after

In [None]:
np.all(before==after)