In [1]:
import xarray as xr 
import numpy as np

from dask import delayed
from tqdm import tqdm
from dask.diagnostics import ProgressBar
import dask
from dask import delayed, compute
import glob
from dask.distributed import Client

import zarr
from numcodecs import Blosc

import gc

import sys, os 
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))


from wofscast.data_generator import (load_wofscast_data, 
                                    wofscast_data_generator, 
                                    wofscast_batch_generator, 
                                    to_static_vars,
                                    add_local_solar_time
                                    
                                    )
from wofscast.wofscast_task_config import WOFS_TASK_CONFIG, train_lead_times, TARGET_VARS
from wofscast import data_utils
import dataclasses

from dask.distributed import performance_report


In [2]:
def read_mfnetcdfs_dask(paths, dim, transform_func=None, load=True, chunks=None):
    """Open multiple NetCDF files as one dataset, using Dask for parallel opening.

    Parameters
    ----------
    paths : list of str
        NetCDF files to open; handed straight to ``xr.open_mfdataset``.
    dim : str
        Name of the dimension the files are concatenated along
        (``combine='nested'``).
    transform_func : callable, optional
        Forwarded as the ``preprocess`` argument of ``xr.open_mfdataset``.
    load : bool, default True
        If True, compute the dataset (with a progress bar) and return the
        in-memory result. If False, return the lazy, dask-backed dataset.
    chunks : dict, optional
        Dask chunk sizes forwarded to ``xr.open_mfdataset``. Defaults to
        ``{"time": 4}``, the value that used to be hard-coded here.

    Returns
    -------
    The dataset returned by ``xr.open_mfdataset``, computed when ``load`` is True.

    Notes
    -----
    It is crucial to run the Dask cluster with ``threads_per_worker=1``.
    See https://forum.access-hive.org.au/t/netcdf-not-a-valid-id-errors/389/19
    To summarise that thread: a work-around in netcdf4-python for netcdf-c
    not being thread safe was removed in 1.6.1, so (for now) the cluster must
    use only one thread per worker.
    """
    if chunks is None:
        chunks = {"time": 4}

    dataset = xr.open_mfdataset(
        paths,
        concat_dim=dim,
        combine='nested',
        parallel=True,
        preprocess=transform_func,
        chunks=chunks,
    )

    if not load:
        return dataset

    with ProgressBar():
        return dataset.compute()

In [3]:
# Window of the file list to convert: `n_files` files starting at index `idx`.
n_files = 1024
idx = 128
# glob.glob returns matches in arbitrary order; sort so that slicing the list
# selects the same files on every run.
data_paths = sorted(glob.glob('/work/mflora/wofs-cast-data/datasets/2019/*.nc'))

In [4]:
# Start a local Dask cluster with 16 workers, one thread each. The single
# thread per worker is deliberate: see the netCDF thread-safety note in
# read_mfnetcdfs_dask.
client = Client(n_workers=16, threads_per_worker=1) 
# Bare expression so the notebook displays the client/cluster summary.
client     

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 16
Total threads: 16,Total memory: 503.42 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:44767,Workers: 16
Dashboard: http://127.0.0.1:8787/status,Total threads: 16
Started: Just now,Total memory: 503.42 GiB

0,1
Comm: tcp://127.0.0.1:38043,Total threads: 1
Dashboard: http://127.0.0.1:36573/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:45973,
Local directory: /tmp/dask-scratch-space/worker-06s07m0e,Local directory: /tmp/dask-scratch-space/worker-06s07m0e

0,1
Comm: tcp://127.0.0.1:46429,Total threads: 1
Dashboard: http://127.0.0.1:35931/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:41311,
Local directory: /tmp/dask-scratch-space/worker-54sfb3o7,Local directory: /tmp/dask-scratch-space/worker-54sfb3o7

0,1
Comm: tcp://127.0.0.1:36099,Total threads: 1
Dashboard: http://127.0.0.1:33023/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:40883,
Local directory: /tmp/dask-scratch-space/worker-eej4y3x5,Local directory: /tmp/dask-scratch-space/worker-eej4y3x5

0,1
Comm: tcp://127.0.0.1:44393,Total threads: 1
Dashboard: http://127.0.0.1:43559/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:41067,
Local directory: /tmp/dask-scratch-space/worker-uz7p5vfw,Local directory: /tmp/dask-scratch-space/worker-uz7p5vfw

0,1
Comm: tcp://127.0.0.1:46755,Total threads: 1
Dashboard: http://127.0.0.1:34605/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:34731,
Local directory: /tmp/dask-scratch-space/worker-55ym0fwf,Local directory: /tmp/dask-scratch-space/worker-55ym0fwf

0,1
Comm: tcp://127.0.0.1:34415,Total threads: 1
Dashboard: http://127.0.0.1:37517/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:39111,
Local directory: /tmp/dask-scratch-space/worker-jmps06r0,Local directory: /tmp/dask-scratch-space/worker-jmps06r0

0,1
Comm: tcp://127.0.0.1:37707,Total threads: 1
Dashboard: http://127.0.0.1:42373/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:36475,
Local directory: /tmp/dask-scratch-space/worker-sdna5_zt,Local directory: /tmp/dask-scratch-space/worker-sdna5_zt

0,1
Comm: tcp://127.0.0.1:39919,Total threads: 1
Dashboard: http://127.0.0.1:40011/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:46843,
Local directory: /tmp/dask-scratch-space/worker-w8jzhpnx,Local directory: /tmp/dask-scratch-space/worker-w8jzhpnx

0,1
Comm: tcp://127.0.0.1:38443,Total threads: 1
Dashboard: http://127.0.0.1:35281/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:38765,
Local directory: /tmp/dask-scratch-space/worker-c74_b51b,Local directory: /tmp/dask-scratch-space/worker-c74_b51b

0,1
Comm: tcp://127.0.0.1:46745,Total threads: 1
Dashboard: http://127.0.0.1:45175/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:38769,
Local directory: /tmp/dask-scratch-space/worker-8b65doad,Local directory: /tmp/dask-scratch-space/worker-8b65doad

0,1
Comm: tcp://127.0.0.1:43071,Total threads: 1
Dashboard: http://127.0.0.1:37367/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:35355,
Local directory: /tmp/dask-scratch-space/worker-2bdqu3ol,Local directory: /tmp/dask-scratch-space/worker-2bdqu3ol

0,1
Comm: tcp://127.0.0.1:39151,Total threads: 1
Dashboard: http://127.0.0.1:40675/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:38761,
Local directory: /tmp/dask-scratch-space/worker-gy1t8g49,Local directory: /tmp/dask-scratch-space/worker-gy1t8g49

0,1
Comm: tcp://127.0.0.1:37755,Total threads: 1
Dashboard: http://127.0.0.1:34381/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:42065,
Local directory: /tmp/dask-scratch-space/worker-xc4re3ef,Local directory: /tmp/dask-scratch-space/worker-xc4re3ef

0,1
Comm: tcp://127.0.0.1:35335,Total threads: 1
Dashboard: http://127.0.0.1:33527/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:33183,
Local directory: /tmp/dask-scratch-space/worker-qx8c67l8,Local directory: /tmp/dask-scratch-space/worker-qx8c67l8

0,1
Comm: tcp://127.0.0.1:44955,Total threads: 1
Dashboard: http://127.0.0.1:34715/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:40245,
Local directory: /tmp/dask-scratch-space/worker-hgy0mg41,Local directory: /tmp/dask-scratch-space/worker-hgy0mg41

0,1
Comm: tcp://127.0.0.1:39051,Total threads: 1
Dashboard: http://127.0.0.1:42057/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:34201,
Local directory: /tmp/dask-scratch-space/worker-9y_6woty,Local directory: /tmp/dask-scratch-space/worker-9y_6woty


In [5]:
%%time

# 4096 files -> 69GBs
# CPU times: user 3min 21s, sys: 28.1 s, total: 3min 49s
# Wall time: 16min 15s

#1024 files -> 17GB
#CPU times: user 39.3 s, sys: 5.07 s, total: 44.4 s
#Wall time: 2min 24s 

# Using 12 workers, 
# Using 16 workers, 1min 38s
# Using 32 workers, 1min 23s 
# Using 64 workers, no bueno 

# Take the configured window of NetCDF paths and open them lazily (load=False)
# as a single dataset concatenated along 'batch'.
files = data_paths[idx:idx + n_files]
dataset = read_mfnetcdfs_dask(
    files,
    dim='batch',
    transform_func=add_local_solar_time,
    load=False,
)

# Blosc compressor using LZ4 (chosen for speed) with bit-shuffling.
compressor = Blosc(cname='lz4', clevel=5, shuffle=Blosc.BITSHUFFLE)

# Write the store, applying the same compressor to every variable.
dataset.to_zarr(
    '/work/mflora/wofs-cast-data/wofcast_dataset_test1.zarr',
    mode='w',
    consolidated=True,
    encoding={name: {'compressor': compressor} for name in dataset.variables},
)

CPU times: user 55.6 s, sys: 8.6 s, total: 1min 4s
Wall time: 1min 37s


<xarray.backends.zarr.ZarrStore at 0x1453a692a340>

In [6]:
# Close the Dask client now that the Zarr store has been written.
# NOTE(review): the recorded run logged a worker heartbeat error ("Failed to
# communicate with scheduler during heartbeat") right after this call —
# presumably a shutdown race; confirm it is benign.
client.close()

2024-04-17 11:09:43,564 - distributed.worker - ERROR - Failed to communicate with scheduler during heartbeat.
Traceback (most recent call last):
  File "/work/mflora/miniconda3/envs/wofs-cast/lib/python3.10/site-packages/distributed/comm/tcp.py", line 225, in read
    frames_nosplit_nbytes_bin = await stream.read_bytes(fmt_size)
tornado.iostream.StreamClosedError: Stream is closed

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/work/mflora/miniconda3/envs/wofs-cast/lib/python3.10/site-packages/distributed/worker.py", line 1252, in heartbeat
    response = await retry_operation(
  File "/work/mflora/miniconda3/envs/wofs-cast/lib/python3.10/site-packages/distributed/utils_comm.py", line 455, in retry_operation
    return await retry(
  File "/work/mflora/miniconda3/envs/wofs-cast/lib/python3.10/site-packages/distributed/utils_comm.py", line 434, in retry
    return await coro()
  File "/work/mflora/miniconda3/envs/wofs-ca

In [7]:
def wofscast_data_generator(path='/work/mflora/wofs-cast-data/wofcast_dataset.zarr', 
                            train_lead_times=train_lead_times,
                            batch_chunk_size=256, 
                    client=None, 
                    task_config=None):
    """Yield (inputs, targets, forcings) from a Zarr store, one batch chunk at a time.

    NOTE: this definition shadows the ``wofscast_data_generator`` imported from
    ``wofscast.data_generator`` at the top of the notebook.

    Parameters
    ----------
    path : str
        Zarr store to open (consolidated metadata).
    train_lead_times
        Forwarded as ``target_lead_times`` to
        ``data_utils.extract_inputs_targets_forcings``.
    batch_chunk_size : int, default 256
        Number of samples along the ``batch`` dimension per yielded chunk;
        also used as the dask chunk size when opening the store. Must be >= 1.
    client : optional
        Unused; kept so existing callers that pass it keep working.
    task_config : dataclass instance, optional
        Its fields are expanded as keyword arguments to
        ``data_utils.extract_inputs_targets_forcings``. Defaults to
        ``WOFS_TASK_CONFIG`` when None (previously None raised a TypeError).

    Yields
    ------
    tuple
        ``(inputs, targets, forcings)`` after ``dask.compute``. ``inputs`` and
        ``targets`` are transposed to (batch, time, lat, lon, level);
        ``forcings`` to (batch, time, lat, lon).

    Raises
    ------
    ValueError
        If ``batch_chunk_size`` is smaller than 1. Because this is a
        generator, the check runs on the first ``next()``, not at call time.
    """
    if batch_chunk_size < 1:
        raise ValueError(
            f"batch_chunk_size must be a positive integer, got {batch_chunk_size!r}"
        )

    if task_config is None:
        # assumes WOFS_TASK_CONFIG is a dataclass instance — TODO confirm
        task_config = WOFS_TASK_CONFIG

    # Loop-invariant, so convert the config to keyword arguments once.
    task_kwargs = dataclasses.asdict(task_config)

    with xr.open_dataset(path,
                         engine='zarr',
                         consolidated=True,
                         chunks={'batch': batch_chunk_size}
                        ) as ds:

        total_samples = len(ds.batch)

        # One start index per chunk; the final chunk may be shorter.
        batch_starts = range(0, total_samples, batch_chunk_size)

        for start_idx in tqdm(batch_starts, desc='Loading Data..'):
            end_idx = min(start_idx + batch_chunk_size, total_samples)

            # Lazy selection of this chunk; nothing is read until dask.compute below.
            this_batch = ds.isel(batch=slice(start_idx, end_idx))

            inputs, targets, forcings = data_utils.extract_inputs_targets_forcings(
                this_batch,
                target_lead_times=train_lead_times,
                **task_kwargs
            )

            inputs = to_static_vars(inputs)

            # Fix the dimension order of each output.
            inputs = inputs.transpose('batch', 'time', 'lat', 'lon', 'level')
            targets = targets.transpose('batch', 'time', 'lat', 'lon', 'level')
            forcings = forcings.transpose('batch', 'time', 'lat', 'lon')

            # Materialise all three in a single compute call.
            inputs, targets, forcings = dask.compute(inputs, targets, forcings)

            yield inputs, targets, forcings