In [1]:
from dask.distributed import Client

# Connect to a Dask scheduler at this TCP address.
# NOTE(review): the address (including the port) is hard-coded; presumably it
# has to be updated whenever the scheduler is restarted — confirm.
client = Client("tcp://127.0.0.1:32983")
# Bare last expression so the notebook renders the client summary as output.
client

0,1
Connection method: Direct,
Dashboard: http://127.0.0.1:8787/status,

0,1
Comm: tcp://127.0.0.1:32983,Workers: 4
Dashboard: http://127.0.0.1:8787/status,Total threads: 8
Started: Just now,Total memory: 12.30 GiB

0,1
Comm: tcp://127.0.0.1:32909,Total threads: 2
Dashboard: http://127.0.0.1:38357/status,Memory: 3.08 GiB
Nanny: tcp://127.0.0.1:34935,
Local directory: /tmp/dask-worker-space/worker-hijsv1ul,Local directory: /tmp/dask-worker-space/worker-hijsv1ul
Tasks executing: 0,Tasks in memory: 0
Tasks ready: 0,Tasks in flight: 0
CPU usage: 2.0%,Last seen: Just now
Memory usage: 99.11 MiB,Spilled bytes: 0 B
Read bytes: 17.61 kiB,Write bytes: 17.61 kiB

0,1
Comm: tcp://127.0.0.1:33073,Total threads: 2
Dashboard: http://127.0.0.1:35611/status,Memory: 3.08 GiB
Nanny: tcp://127.0.0.1:37875,
Local directory: /tmp/dask-worker-space/worker-hq_sevjx,Local directory: /tmp/dask-worker-space/worker-hq_sevjx
Tasks executing: 0,Tasks in memory: 0
Tasks ready: 0,Tasks in flight: 0
CPU usage: 6.0%,Last seen: Just now
Memory usage: 99.16 MiB,Spilled bytes: 0 B
Read bytes: 34.55 kiB,Write bytes: 34.55 kiB

0,1
Comm: tcp://127.0.0.1:36819,Total threads: 2
Dashboard: http://127.0.0.1:40565/status,Memory: 3.08 GiB
Nanny: tcp://127.0.0.1:44783,
Local directory: /tmp/dask-worker-space/worker-5rj82c6s,Local directory: /tmp/dask-worker-space/worker-5rj82c6s
Tasks executing: 0,Tasks in memory: 0
Tasks ready: 0,Tasks in flight: 0
CPU usage: 4.0%,Last seen: Just now
Memory usage: 101.06 MiB,Spilled bytes: 0 B
Read bytes: 34.46 kiB,Write bytes: 34.46 kiB

0,1
Comm: tcp://127.0.0.1:39635,Total threads: 2
Dashboard: http://127.0.0.1:34965/status,Memory: 3.08 GiB
Nanny: tcp://127.0.0.1:36105,
Local directory: /tmp/dask-worker-space/worker-iwsksn6f,Local directory: /tmp/dask-worker-space/worker-iwsksn6f
Tasks executing: 0,Tasks in memory: 0
Tasks ready: 0,Tasks in flight: 0
CPU usage: 2.0%,Last seen: Just now
Memory usage: 99.00 MiB,Spilled bytes: 0 B
Read bytes: 34.57 kiB,Write bytes: 34.57 kiB


In [2]:
import dask.array as da
import xarray as xr
import numpy as np

In [3]:
# Mimic all steps in preprocessing with dask arrays

# Generate fake inputs
# Bind the seeded RandomState and draw from it. A bare
# `da.random.RandomState(42)` expression would build a generator and throw it
# away, leaving the module-level `da.random.random` calls unseeded.
rng = da.random.RandomState(42)
# One chunk per step along the first axis; all other axes stay in a single chunk.
sp = rng.random_sample((24, 641, 1440), chunks=(1, -1, -1))
q = rng.random_sample((24, 22, 641, 1440), chunks=(1, -1, -1, -1))
u = rng.random_sample((24, 22, 641, 1440), chunks=(1, -1, -1, -1))
v = rng.random_sample((24, 22, 641, 1440), chunks=(1, -1, -1, -1))

# Get dp
# Insert a length-1 axis at position 1 so that sp broadcasts against the
# 4-D arrays (q, u, v).
dp = sp[:, None, ...]

# Calculate cwv
cwv = dp * q

# Calculate fluxes
fx = u * cwv
fy = v * cwv

# Aggregate to two layers
# Indices 0-9 along axis 1 go into the "lower" sums, indices 10-21 into the
# "upper" sums; the sum collapses axis 1.
s_lower = cwv[:, :10, ...].sum(axis=1)
s_upper = cwv[:, 10:, ...].sum(axis=1)

fx_lower = fx[:, :10, ...].sum(axis=1)
fx_upper = fx[:, 10:, ...].sum(axis=1)

fy_lower = fy[:, :10, ...].sum(axis=1)
fy_upper = fy[:, 10:, ...].sum(axis=1)

In [4]:
# Bare expression: render the dask array summary (bytes, shape, chunks, dtype)
# as the cell output.
fy_upper

Unnamed: 0,Array,Chunk
Bytes,169.01 MiB,7.04 MiB
Shape,"(24, 641, 1440)","(1, 641, 1440)"
Count,9 Graph Layers,24 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 169.01 MiB 7.04 MiB Shape (24, 641, 1440) (1, 641, 1440) Count 9 Graph Layers 24 Chunks Type float64 numpy.ndarray",1440  641  24,

Unnamed: 0,Array,Chunk
Bytes,169.01 MiB,7.04 MiB
Shape,"(24, 641, 1440)","(1, 641, 1440)"
Count,9 Graph Layers,24 Chunks
Type,float64,numpy.ndarray


In [5]:
# Render the task graph behind fy_upper in the notebook.
display(fy_upper.visualize())

CytoscapeWidget(cytoscape_layout={'name': 'dagre', 'rankDir': 'BT', 'nodeSep': 10, 'edgeSep': 10, 'spacingFact…

In [6]:
%%time
# Final output to netcdf with xarray
dims3 = ['time', 'latitude', 'longitude']
dims4 = ['time', 'level', 'latitude', 'longitude']
xr.Dataset({
    's_upper': xr.DataArray(s_upper, dims=dims3),
    's_lower': xr.DataArray(s_lower, dims=dims3),
    'fx_upper': xr.DataArray(s_upper, dims=dims3),
    'fx_lower': xr.DataArray(s_lower, dims=dims3),
    'fy_upper': xr.DataArray(s_upper, dims=dims3),
    'fy_lower': xr.DataArray(s_lower, dims=dims3),
    
}).to_netcdf('/home/peter/WAM2layers/test3.nc')

CPU times: user 404 ms, sys: 48.8 ms, total: 453 ms
Wall time: 10.9 s


In [7]:
%%time
# Now with xarray (current implementation)

# Generate fake inputs
xrsp = xr.DataArray(sp, dims=dims3, name='sp')
xrq = xr.DataArray(q, dims=dims4, name='q')
xru = xr.DataArray(q, dims=dims4, name='u')
xrv = xr.DataArray(q, dims=dims4, name='v')

# Get dp
xrdp = xrsp.expand_dims(level=range(22), axis=1)

# Calculate cwv
xrcwv = xrdp * xrq

# Calculate fluxes 
xrfx = xrcwv * xru
xrfy = xrcwv * xrv

# Aggregate to two layers
lower = xrcwv.level > 10
xrs_upper = xrcwv.where(~lower).sum('level')
xrs_lower = xrcwv.where(lower).sum('level')

xrfx_upper = xrfx.where(~lower).sum('level')
xrfx_lower = xrfx.where(lower).sum('level')

xrfy_upper = xrfy.where(~lower).sum('level')
xrfy_lower = xrfy.where(lower).sum('level')

# # Realize data to disk
# xr.Dataset({
#     's_upper': xrs_upper,
#     's_lower': xrs_lower,
#     'fx_upper': xrfx_upper,
#     'fx_lower': xrfx_lower,
#     'fy_upper': xrfy_upper,
#     'fy_lower': xrfy_lower,
# }).to_netcdf('/home/peter/WAM2layers/test.nc')

CPU times: user 99.9 ms, sys: 0 ns, total: 99.9 ms
Wall time: 96.4 ms


In [8]:
# Render the task graph of the dask array that backs xrs_upper.
display(xrs_upper.data.visualize())

CytoscapeWidget(cytoscape_layout={'name': 'dagre', 'rankDir': 'BT', 'nodeSep': 10, 'edgeSep': 10, 'spacingFact…

In [9]:
%%time
# Now with xarray (optimized)

# Generate fake inputs
xrsp = xr.DataArray(sp, dims=dims3, name='sp')
xrq = xr.DataArray(q, dims=dims4, name='q')
xru = xr.DataArray(q, dims=dims4, name='u')
xrv = xr.DataArray(q, dims=dims4, name='v')

# Get dp
xrdp = xrsp.expand_dims(level=range(22), axis=1)

# Calculate cwv
xrcwv = xrdp * xrq

# Calculate fluxes 
xrfx = xrcwv * xru
xrfy = xrcwv * xrv

# Aggregate to two layers
idx = xrdp.level.searchsorted(10, side='right')
upper = np.s_[:, :idx, :, :]
lower = np.s_[:, idx:, :, :]

xrs_upper = xrcwv[upper].sum('level')
xrs_lower = xrcwv[lower].sum('level')

xrfx_upper = xrfx[upper].sum('level')
xrfx_lower = xrfx[lower].sum('level')

xrfy_upper = xrfy[upper].sum('level')
xrfy_lower = xrfy[lower].sum('level')

# Realize data to disk
xr.Dataset({
    's_upper': xrs_upper,
    's_lower': xrs_lower,
    'fx_upper': xrfx_upper,
    'fx_lower': xrfx_lower,
    'fy_upper': xrfy_upper,
    'fy_lower': xrfy_lower,
}).to_netcdf('/home/peter/WAM2layers/test2.nc')

CPU times: user 442 ms, sys: 41.2 ms, total: 484 ms
Wall time: 19.6 s
