In [1]:
import xarray as xr 
import numpy as np

from dask import delayed
from tqdm import tqdm
from dask.diagnostics import ProgressBar
import dask
from dask import delayed, compute
import glob
from dask.distributed import Client

import zarr
from numcodecs import Blosc

import gc

import sys, os 
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))


from wofscast.data_generator import (load_wofscast_data, 
                                    wofscast_data_generator, 
                                    wofscast_batch_generator, 
                                    to_static_vars,
                                    add_local_solar_time
                                    
                                    )
from wofscast.wofscast_task_config import WOFS_TASK_CONFIG, train_lead_times, TARGET_VARS
from wofscast import data_utils
import dataclasses

from dask.distributed import performance_report


In [2]:
def read_mfnetcdfs_dask(paths, dim, transform_func=None, load=True, chunks=None):
    """Open many NetCDF files as one xarray dataset concatenated along ``dim``.

    Parameters
    ----------
    paths : list of str
        NetCDF files to open; with ``combine='nested'`` they are concatenated
        in the order given.
    dim : str
        Name of the dimension to concatenate along (e.g. ``'batch'``).
    transform_func : callable, optional
        Passed to ``xr.open_mfdataset`` as ``preprocess``; applied to each
        per-file dataset before concatenation.
    load : bool, default True
        If True, compute the dask-backed dataset into memory and return it.
        If False, return the lazy (dask-backed) dataset without reading data.
    chunks : dict, optional
        Dask chunking passed to ``xr.open_mfdataset``. ``None`` (the default)
        means ``{"time": 4}``, the value this function always used before.

    Returns
    -------
    xarray.Dataset
        In-memory when ``load`` is True, otherwise lazy.
    """
    # Absolutely crucial: the dask cluster must use threads_per_worker=1!
    # https://forum.access-hive.org.au/t/netcdf-not-a-valid-id-errors/389/19
    # netcdf4-python removed its work-around for netcdf-c not being thread safe
    # in 1.6.1; for now the cluster must only use 1 thread per worker.
    if chunks is None:
        # Build a fresh dict per call so no mutable default is shared between calls.
        chunks = {"time": 4}

    dataset = xr.open_mfdataset(paths, concat_dim=dim, combine='nested',
                                parallel=True, preprocess=transform_func,
                                chunks=chunks)
    if not load:
        return dataset

    # NOTE(review): dask.diagnostics.ProgressBar hooks only the local
    # (single-machine) schedulers; with a distributed Client active it likely
    # shows nothing -- consider dask.distributed.progress instead.
    with ProgressBar():
        return dataset.compute()

In [3]:
# Subset configuration: n_files consecutive files starting at position idx.
n_files = 8192
idx = 128
# glob returns paths in arbitrary (filesystem-dependent) order; sort so that
# slicing data_paths[idx:idx + n_files] selects the same files on every run.
data_paths = sorted(glob.glob('/work/mflora/wofs-cast-data/datasets/2019/*.nc'))

In [4]:
# Local dask cluster of 16 worker processes with a single thread each.
# One thread per worker is required: netcdf-c is not thread safe and
# netcdf4-python removed its work-around in 1.6.1 (see read_mfnetcdfs_dask).
client = Client(threads_per_worker=1, n_workers=16)
client  # last expression renders the client / cluster summary

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 16
Total threads: 16,Total memory: 503.42 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:38253,Workers: 16
Dashboard: http://127.0.0.1:8787/status,Total threads: 16
Started: Just now,Total memory: 503.42 GiB

0,1
Comm: tcp://127.0.0.1:33013,Total threads: 1
Dashboard: http://127.0.0.1:36185/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:35547,
Local directory: /tmp/dask-scratch-space/worker-kuz091fp,Local directory: /tmp/dask-scratch-space/worker-kuz091fp

0,1
Comm: tcp://127.0.0.1:41849,Total threads: 1
Dashboard: http://127.0.0.1:38319/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:39913,
Local directory: /tmp/dask-scratch-space/worker-xsk1z6ss,Local directory: /tmp/dask-scratch-space/worker-xsk1z6ss

0,1
Comm: tcp://127.0.0.1:32809,Total threads: 1
Dashboard: http://127.0.0.1:34947/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:44599,
Local directory: /tmp/dask-scratch-space/worker-uu8hcs0x,Local directory: /tmp/dask-scratch-space/worker-uu8hcs0x

0,1
Comm: tcp://127.0.0.1:44733,Total threads: 1
Dashboard: http://127.0.0.1:37235/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:36459,
Local directory: /tmp/dask-scratch-space/worker-ubb8745s,Local directory: /tmp/dask-scratch-space/worker-ubb8745s

0,1
Comm: tcp://127.0.0.1:45287,Total threads: 1
Dashboard: http://127.0.0.1:44881/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:40115,
Local directory: /tmp/dask-scratch-space/worker-4skoqzwi,Local directory: /tmp/dask-scratch-space/worker-4skoqzwi

0,1
Comm: tcp://127.0.0.1:40461,Total threads: 1
Dashboard: http://127.0.0.1:37859/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:32861,
Local directory: /tmp/dask-scratch-space/worker-bvxjv_qy,Local directory: /tmp/dask-scratch-space/worker-bvxjv_qy

0,1
Comm: tcp://127.0.0.1:34269,Total threads: 1
Dashboard: http://127.0.0.1:45153/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:34031,
Local directory: /tmp/dask-scratch-space/worker-aic5i4li,Local directory: /tmp/dask-scratch-space/worker-aic5i4li

0,1
Comm: tcp://127.0.0.1:35539,Total threads: 1
Dashboard: http://127.0.0.1:42089/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:37485,
Local directory: /tmp/dask-scratch-space/worker-jax8bx4s,Local directory: /tmp/dask-scratch-space/worker-jax8bx4s

0,1
Comm: tcp://127.0.0.1:36493,Total threads: 1
Dashboard: http://127.0.0.1:37831/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:39263,
Local directory: /tmp/dask-scratch-space/worker-u6zhqu8s,Local directory: /tmp/dask-scratch-space/worker-u6zhqu8s

0,1
Comm: tcp://127.0.0.1:35047,Total threads: 1
Dashboard: http://127.0.0.1:40997/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:39591,
Local directory: /tmp/dask-scratch-space/worker-r4xzty46,Local directory: /tmp/dask-scratch-space/worker-r4xzty46

0,1
Comm: tcp://127.0.0.1:41553,Total threads: 1
Dashboard: http://127.0.0.1:45871/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:39299,
Local directory: /tmp/dask-scratch-space/worker-723tmqhn,Local directory: /tmp/dask-scratch-space/worker-723tmqhn

0,1
Comm: tcp://127.0.0.1:39823,Total threads: 1
Dashboard: http://127.0.0.1:41347/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:44267,
Local directory: /tmp/dask-scratch-space/worker-ev0qr4ki,Local directory: /tmp/dask-scratch-space/worker-ev0qr4ki

0,1
Comm: tcp://127.0.0.1:45139,Total threads: 1
Dashboard: http://127.0.0.1:35425/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:43209,
Local directory: /tmp/dask-scratch-space/worker-f7a43jdw,Local directory: /tmp/dask-scratch-space/worker-f7a43jdw

0,1
Comm: tcp://127.0.0.1:39501,Total threads: 1
Dashboard: http://127.0.0.1:43243/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:43813,
Local directory: /tmp/dask-scratch-space/worker-vl48z6h_,Local directory: /tmp/dask-scratch-space/worker-vl48z6h_

0,1
Comm: tcp://127.0.0.1:45597,Total threads: 1
Dashboard: http://127.0.0.1:45471/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:36405,
Local directory: /tmp/dask-scratch-space/worker-zntlpxg7,Local directory: /tmp/dask-scratch-space/worker-zntlpxg7

0,1
Comm: tcp://127.0.0.1:40421,Total threads: 1
Dashboard: http://127.0.0.1:42097/status,Memory: 31.46 GiB
Nanny: tcp://127.0.0.1:37005,
Local directory: /tmp/dask-scratch-space/worker-pe3dkile,Local directory: /tmp/dask-scratch-space/worker-pe3dkile




In [5]:
%%time

# 8192 files -> GBs



# 4096 files -> 69GBs
# CPU times: user 3min 21s, sys: 28.1 s, total: 3min 49s
# Wall time: 16min 15s

# 1024 files -> 17GB
# CPU times: user 39.3 s, sys: 5.07 s, total: 44.4 s
# Wall time: 2min 24s 

# Using 12 workers, 
# Using 16 workers, 1min 38s
# Using 32 workers, 1min 23s 
# Using 64 workers, no bueno 

def read_mfnetcdfs_dask(paths, dim, transform_func=None, load=True):
    """Read multiple NetCDF files into memory, using Dask for parallel loading."""
    # Absolutely, crucial to set threads_per_worker=1!!!!
    # https://forum.access-hive.org.au/t/netcdf-not-a-valid-id-errors/389/19
    #To summarise in this thread, it looks like a work-around in netcdf4-python to deal 
    #with netcdf-c not being thread safe was removed in 1.6.1. 
    #The solution (for now) is to make sure your cluster only uses 1 thread per worker.

    dataset = xr.open_mfdataset(paths, concat_dim=dim, combine='nested',
                                parallel=True, preprocess=transform_func,
                                chunks={"time": 4}, )  # Adjust the chunking strategy as needed
    if load:
        with ProgressBar():
            loaded_dataset= dataset.compute()
        return loaded_dataset

    return dataset 

n_files = 8192
idx = 128
data_paths = glob.glob('/work/mflora/wofs-cast-data/datasets/2019/*.nc')

files = data_paths[idx:idx + n_files]
dataset = read_mfnetcdfs_dask(files, dim='batch', transform_func=add_local_solar_time, load=False) 
    
# Initialize Blosc compressor with LZ4 for speed
compressor = Blosc(cname='lz4', clevel=5, shuffle=Blosc.BITSHUFFLE)
    
dataset.to_zarr(f'/work/mflora/wofs-cast-data/wofcast_dataset_n_samples-{n_files}.zarr', 
                    mode='w', 
                   consolidated=True, 
                   encoding={var: {'compressor': compressor} for var in dataset.variables})

This may cause some slowdown.
Consider scattering data ahead of time and using futures.
2024-05-09 10:47:45,836 - distributed.worker_memory - ERROR - 
Traceback (most recent call last):
  File "/work/mflora/miniconda3/envs/wofs-cast/lib/python3.10/site-packages/distributed/compatibility.py", line 236, in asyncio_run
    return loop.run_until_complete(main)
  File "/work/mflora/miniconda3/envs/wofs-cast/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
    self.run_forever()
  File "/work/mflora/miniconda3/envs/wofs-cast/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/work/mflora/miniconda3/envs/wofs-cast/lib/python3.10/asyncio/base_events.py", line 1871, in _run_once
    event_list = self._selector.select(timeout)
  File "/work/mflora/miniconda3/envs/wofs-cast/lib/python3.10/selectors.py", line 469, in select
    fd_event_list = self._selector.poll(timeout, max_ev)
KeyboardInterrupt

During handling of the above excep

2024-05-09 10:47:47,815 - distributed.nanny - ERROR - Worker process died unexpectedly
2024-05-09 10:47:47,815 - distributed.nanny - ERROR - Worker process died unexpectedly
2024-05-09 10:47:47,815 - distributed.nanny - ERROR - Worker process died unexpectedly
2024-05-09 10:47:47,815 - distributed.nanny - ERROR - Worker process died unexpectedly
Process Dask Worker process (from Nanny):
2024-05-09 10:47:47,815 - distributed.nanny - ERROR - Worker process died unexpectedly
2024-05-09 10:47:47,815 - distributed.nanny - ERROR - Worker process died unexpectedly
Process Dask Worker process (from Nanny):
Process Dask Worker process (from Nanny):
2024-05-09 10:47:47,815 - distributed.nanny - ERROR - Worker process died unexpectedly
2024-05-09 10:47:47,815 - distributed.nanny - ERROR - Worker process died unexpectedly
Process Dask Worker process (from Nanny):
Process Dask Worker process (from Nanny):
Process Dask Worker process (from Nanny):
Process Dask Worker process (from Nanny):
Process Da

2024-05-09 10:47:55,451 - distributed.worker - ERROR - Scheduler was unaware of this worker 'tcp://127.0.0.1:35047'. Shutting down.


KeyboardInterrupt: 

In [8]:
# Shut down the dask Client created in the cluster-setup cell. NOTE(review):
# presumably this also stops the LocalCluster that Client started -- confirm
# no worker processes linger afterwards.
client.close()

In [7]:
def wofscast_data_generator(path='/work/mflora/wofs-cast-data/wofcast_dataset.zarr', 
                            train_lead_times=train_lead_times,
                            batch_chunk_size=256, 
                            client=None, 
                            task_config=None):
    """Yield ``(inputs, targets, forcings)`` batches read from a zarr store.

    Note: this redefines (shadows) the ``wofscast_data_generator`` imported
    from ``wofscast.data_generator`` in the imports cell.

    Parameters
    ----------
    path : str
        Consolidated zarr store with a ``batch`` dimension.
    train_lead_times :
        Forwarded to ``data_utils.extract_inputs_targets_forcings`` as
        ``target_lead_times``.
    batch_chunk_size : int, default 256
        Samples per yielded batch; also the dask chunk size along ``batch``
        when the store is opened. Must be positive.
    client : optional
        Unused; kept for signature compatibility. ``dask.compute`` runs on
        whichever dask scheduler is currently active.
    task_config : dataclass instance, optional
        Its fields are passed as keyword arguments to
        ``extract_inputs_targets_forcings``. Defaults to ``WOFS_TASK_CONFIG``.

    Yields
    ------
    tuple
        ``(inputs, targets, forcings)``, each computed into memory. The final
        batch may hold fewer than ``batch_chunk_size`` samples.

    Raises
    ------
    ValueError
        If ``batch_chunk_size`` is not positive.
    """
    if batch_chunk_size <= 0:
        raise ValueError(f'batch_chunk_size must be positive, got {batch_chunk_size}')
    if task_config is None:
        # dataclasses.asdict(None) raises TypeError, so fall back to the default config.
        # NOTE(review): assumes WOFS_TASK_CONFIG is a dataclass instance -- confirm.
        task_config = WOFS_TASK_CONFIG

    with xr.open_dataset(path, 
                         engine='zarr', 
                         consolidated=True, 
                         chunks={'batch' : batch_chunk_size}
                        ) as ds:

        total_samples = len(ds.batch)

        # Stepped range gives ceil(total_samples / batch_chunk_size) batches.
        for start_idx in tqdm(range(0, total_samples, batch_chunk_size), desc='Loading Data..'):
            end_idx = min(start_idx + batch_chunk_size, total_samples)

            # isel on the dask-backed dataset is lazy; data are materialized by
            # dask.compute below (assuming the helpers in between stay lazy).
            this_batch = ds.isel(batch=slice(start_idx, end_idx))

            inputs, targets, forcings = data_utils.extract_inputs_targets_forcings(
                this_batch,
                target_lead_times=train_lead_times,
                **dataclasses.asdict(task_config)
            )

            inputs = to_static_vars(inputs)

            inputs = inputs.transpose('batch', 'time', 'lat', 'lon', 'level')
            targets = targets.transpose('batch', 'time', 'lat', 'lon', 'level')
            forcings = forcings.transpose('batch', 'time', 'lat', 'lon')

            # One compute call for all three so dask can share common tasks (e.g. reads).
            inputs, targets, forcings = dask.compute(inputs, targets, forcings)

            yield inputs, targets, forcings