## Step 1. Generate a Kerchunk JSON for the NetCDF files

In [1]:
%%time
import xarray as xr
import fsspec
import os
from dask import delayed
import dask

import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))
from wofscast.data_generator import add_local_solar_time
import wofscast.my_graphcast as graphcast
from wofscast import data_utils
import dataclasses


import glob
import ujson
from kerchunk.hdf import SingleHdf5ToZarr


def ensure_json_ext(filename: str) -> str:
    """
    Return ``filename`` with its final extension swapped for '.json'.

    Only the last suffix (as split off by ``os.path.splitext``) is dropped,
    so 'a.tar.gz' becomes 'a.tar.json'. A name with no suffix simply gains
    '.json'.

    Args:
        filename (str): The original filename or path.

    Returns:
        str: The same name ending in '.json'.
    """
    # splitext yields (root, ext); only the root is kept.
    stem = os.path.splitext(filename)[0]
    return stem + ".json"

def gen_json(u, output_dir="/work/mflora/wofs-cast-data/jsons/", 
             original_dir = '/work2/wof/realtime/FCST/'):
    """
    Write a Kerchunk reference JSON for a single NetCDF/HDF5 file.

    The output location mirrors the input path: ``original_dir`` inside ``u``
    is replaced by ``output_dir`` and the extension is switched to '.json'.
    Failures are reported with ``print`` rather than raised, so one bad file
    does not abort a larger run.

    Args:
        u (str): Path of the source file to index.
        output_dir (str): Directory prefix that replaces ``original_dir``.
        original_dir (str): Directory prefix of ``u`` to be replaced.

    Returns:
        str | None: A "Generated JSON for ..." message on success, or None
        if any step failed.
    """
    # File system options forwarded to fsspec.open
    so = dict(
        mode="rb", anon=True, default_fill_cache=False,
        default_cache_type="none"
    )

    try:
        output_path = ensure_json_ext(u.replace(original_dir, output_dir))

        # Build the full payload BEFORE touching the output file, so that a
        # failure while translating does not leave an empty/truncated JSON
        # behind for later globbing to pick up.
        with fsspec.open(u, **so) as inf:
            h5chunks = SingleHdf5ToZarr(inf, u, inline_threshold=300)
            payload = ujson.dumps(h5chunks.translate()).encode()

        # makedirs(exist_ok=True) already tolerates an existing directory;
        # skip it only when the path has no directory component at all.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        with open(output_path, 'wb') as outf:
            outf.write(payload)

    except Exception as e:
        # Deliberate best-effort: report and carry on with the next file.
        print(f"Failed to generate JSON for {u}: {e}")
        return None

    return f"Generated JSON for {output_path}"


def open_mfdataset_batch(paths, batch_chunk_size, concat_dim='batch'):
    """Open groups of Kerchunk JSON references as one batched xarray dataset.

    Using kerchunking, individual zarr or netcdf files are represented by
    individual json files. Each json is opened through fsspec's
    ``reference://`` file system with xarray's zarr engine, so xarray treats
    it like a zarr store. The jsons are opened in parallel via dask.delayed,
    concatenated along 'Time' within each group, and the groups are then
    concatenated along ``concat_dim`` and re-chunked for batch loading.

    Parameters
    ----------
    paths : list of list of str
        One inner list of json paths per sample; each inner list holds the
        paths to a set of forecasts for a given ensemble member.
    batch_chunk_size : int
        Chunk size applied along ``concat_dim`` in the returned dataset.
    concat_dim : str, default 'batch'
        Name of the dimension along which the groups are stacked and
        chunked. (Previously accepted but ignored; 'batch' was hard-coded.)

    Returns
    -------
    xarray.Dataset
        The combined dataset with 'Time' renamed to 'time'.
    """
    @delayed
    def load_dataset_from_json(json_path):
        """Load a dataset from a Kerchunk JSON descriptor."""
        # Using fsspec to create a mapper from the JSON reference
        mapper = fsspec.get_mapper('reference://', fo=json_path, remote_protocol='file')
        # Load the dataset using xarray with the Zarr engine
        ds = xr.open_dataset(mapper, engine='zarr', consolidated=False, chunks={}, decode_times=False)

        return ds

    def load_and_concatenate(json_files, dim):
        """Load multiple datasets from JSON files and concatenate them along ``dim``."""
        # Build one delayed task per json, then run them all in a single compute call.
        delayed_datasets = [load_dataset_from_json(json_file) for json_file in json_files]
        datasets = dask.compute(*delayed_datasets)

        # Concatenate all datasets along the specified dimension
        combined_dataset = xr.concat(datasets, dim=dim)

        # Close the per-file datasets once they have been combined.
        for ds in datasets:
            ds.close()

        return combined_dataset

    # Within each group the files are joined along the 'Time' dimension.
    datasets_per_time = [load_and_concatenate(p, dim='Time') for p in paths]

    # Stack the groups along the caller-selected dimension.
    dataset = xr.concat(datasets_per_time, dim=concat_dim)

    dataset = dataset.rename({'Time' : 'time'})

    dataset = dataset.chunk({concat_dim: batch_chunk_size})

    return dataset


CPU times: user 2.27 s, sys: 12.6 s, total: 14.8 s
Wall time: 1.16 s


In [2]:
# Collect the per-file Kerchunk JSONs, sorted by name, one list per sample.
# NOTE(review): both globs point at ENS_MEM_01, so the two lists are
# identical — confirm whether the second was meant to be another member.
paths1 = sorted(glob.glob('/work/mflora/wofs-cast-data/jsons/2021/20210409/0200/ENS_MEM_01/wrfwof*'))
paths2 = sorted(glob.glob('/work/mflora/wofs-cast-data/jsons/2021/20210409/0200/ENS_MEM_01/wrfwof*'))

paths = [paths1, paths2]


In [3]:
import numpy as np
from tqdm import tqdm

# the number of gridpoints in one direction; square domain.
DOMAIN_SIZE = 300

VARS_3D = ['U', 'V', 'W', 'T', 'GEOPOT', 'QVAPOR']
VARS_2D = ['T2', 'COMPOSITE_REFL_10CM', 'UP_HELI_MAX']
STATIC_VARS = ['XLAND', 'HGT']

INPUT_VARS = VARS_3D + VARS_2D + STATIC_VARS
TARGET_VARS = VARS_3D + VARS_2D

# I compute this myself rather than using the GraphCast code. 
# The trailing comma matters: ('SWDOWN') is just the string 'SWDOWN', and
# iterating it (or calling list() on it) would yield single characters.
FORCING_VARS = (
            'SWDOWN',
        )

# Not pressure levels, but just vertical array indices at the moment. 
# When I created the wrfwof files, I pre-sampled every 3 levels. 
PRESSURE_LEVELS = np.arange(50)

# Loads data from the past 20 minutes (2 steps) and 
# creates a target over the next 10-60 min. 
# NOTE(review): the comment above does not match the values below
# ('10min' of input, '5min' lead time) — confirm which is intended.
INPUT_DURATION = '10min'
train_lead_times = '5min'

# Bundle the variable lists, vertical levels and domain settings into one
# TaskConfig; the data generator later unpacks it with dataclasses.asdict.
task_config = graphcast.TaskConfig(
    input_variables=INPUT_VARS,
    target_variables=TARGET_VARS,
    forcing_variables=FORCING_VARS,
    pressure_levels=PRESSURE_LEVELS,
    input_duration=INPUT_DURATION,
    n_vars_2D=len(VARS_2D),
    domain_size=DOMAIN_SIZE,
    tiling=None,
)

def wofscast_data_generator(paths,
                            train_lead_times,
                            task_config,
                            batch_chunk_size=256,
                            client=None):
    """Yield ``(inputs, targets, forcings)`` for successive slices of the batch dimension.

    The JSON references in ``paths`` are opened with ``open_mfdataset_batch``
    and walked in slices of ``batch_chunk_size`` along 'batch' (the last slice
    may be shorter). Each slice is split by
    ``data_utils.batch_extract_inputs_targets_forcings`` (2 input steps,
    1 target step), transposed to a fixed dimension order and computed with
    dask before being yielded.

    Parameters
    ----------
    paths : list of list of str
        Passed straight to ``open_mfdataset_batch``.
    train_lead_times :
        Forwarded as ``target_lead_times``.
    task_config :
        Dataclass instance; its fields are forwarded as keyword arguments.
    batch_chunk_size : int, default 256
        Number of samples per yielded batch.
    client :
        Not used by this function.

    Yields
    ------
    tuple
        The computed ``(inputs, targets, forcings)`` for one slice.
    """
    with open_mfdataset_batch(paths, batch_chunk_size) as ds:
        total_samples = len(ds.batch)

        # Ceiling division: one extra batch when there is a partial remainder.
        full_batches, leftover = divmod(total_samples, batch_chunk_size)
        total_batches = full_batches + (1 if leftover > 0 else 0)

        for batch_num in tqdm(range(total_batches), desc='Loading Zarr Batch..'):
            start_idx = batch_num * batch_chunk_size
            end_idx = min(start_idx + batch_chunk_size, total_samples)

            # Select this slice of samples along the batch dimension.
            this_batch = ds.isel(batch=slice(start_idx, end_idx))

            inputs, targets, forcings = data_utils.batch_extract_inputs_targets_forcings(
                this_batch,
                n_input_steps=2,
                n_target_steps=1,
                target_lead_times=train_lead_times,
                **dataclasses.asdict(task_config)
            )

            # NOTE(review): to_static_vars is neither defined nor imported in
            # the visible part of this notebook — confirm where it comes from.
            inputs = to_static_vars(inputs)

            # Forcings carry no 'level' dimension, hence the shorter order.
            dim_order = ('batch', 'time', 'lat', 'lon', 'level')
            inputs = inputs.transpose(*dim_order)
            targets = targets.transpose(*dim_order)
            forcings = forcings.transpose(*dim_order[:-1])

            # dask.compute returns the three evaluated objects as a tuple.
            yield dask.compute(inputs, targets, forcings)


In [4]:
# Drain the generator once; the loop body is empty, so afterwards the loop
# variables hold whatever the final iteration produced.
for batch_input, batch_target, batch_forcings in wofscast_data_generator(
    paths,
    '5min',
    task_config,
    batch_chunk_size=8,
    client=None,
):
    pass

Loading Zarr Batch..:   0%|                                                                                            | 0/1 [00:00<?, ?it/s]


TypeError: Concatenation operation is not implemented for NumPy arrays, use np.concatenate() instead. Please do not rely on this error; it may not be given on all Python implementations.

In [None]:
# Display the inputs bound by the last iteration of the batch loop.
batch_input