In [3]:
import xarray as xr
import fsspec
import dask 

import numpy as np
import dataclasses

import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))
from wofscast.data_generator import add_local_solar_time, to_static_vars 
from wofscast import data_utils


class WoFSCastDataGenerator:
    """
    Generates batches from a WoFS dataset for training/testing machine learning models.

    Samples are pulled into CPU memory ``cpu_batch_size`` at a time (via ``.compute()``)
    and then yielded in smaller ``gpu_batch_size`` slices.

    Attributes:
        task_config: TaskConfig object, including variables, pressure levels, etc.
                     Defined in graphcast_lam.py
        cpu_batch_size : int : Number of samples to preload into CPU memory.
        gpu_batch_size : int : Number of samples sent to GPU at one time.
        seed : int : Random seed for shuffling the dataset in ``open_mfdataset``.
    """

    def __init__(self, task_config, cpu_batch_size=512, gpu_batch_size=32, seed=123):
        # A non-positive batch size would make the slicing loops in __call__ never advance.
        if cpu_batch_size < 1 or gpu_batch_size < 1:
            raise ValueError(
                'cpu_batch_size and gpu_batch_size must be positive integers, '
                f'got {cpu_batch_size} and {gpu_batch_size}.'
            )

        self.task_config = task_config
        self.cpu_batch_size = cpu_batch_size
        self.gpu_batch_size = gpu_batch_size
        self.seed = seed

        # Private RNG for reproducible shuffling without reseeding NumPy's global state.
        # RandomState(seed).permutation(n) produces the same order as
        # np.random.seed(seed); np.random.permutation(n).
        self._rng = np.random.RandomState(seed)

    def __call__(self, dataset):
        """
        Iterate over ``dataset`` in GPU-sized batches.

        Args:
            dataset : xarray.Dataset : lazily loaded dataset (e.g. from ``open_mfdataset``)
                                       with a ``'batch'`` dimension.

        Yields:
            inputs, targets, forcings : xarray.Datasets, transposed to
            ('batch', 'time', 'lat', 'lon', 'level') order where those dims exist.

        Each block of ``cpu_batch_size`` samples is computed into CPU RAM once, then
        sliced into ``gpu_batch_size`` pieces so only small subsets go to the GPU.
        """
        dims = ('batch', 'time', 'lat', 'lon', 'level')
        total_samples = dataset.sizes['batch']

        for outer_start in range(0, total_samples, self.cpu_batch_size):
            outer_end = min(outer_start + self.cpu_batch_size, total_samples)

            # Preload the chunk of batches (a contiguous slice; the dataset is
            # already shuffled by open_mfdataset).
            chunk = dataset.isel(batch=slice(outer_start, outer_end)).compute()
            chunk_size = outer_end - outer_start

            for inner_start in range(0, chunk_size, self.gpu_batch_size):
                inner_end = min(inner_start + self.gpu_batch_size, chunk_size)
                batch = chunk.isel(batch=slice(inner_start, inner_end))

                inputs, targets, forcings = data_utils.extract_inputs_targets_forcings(
                    batch,
                    target_lead_times=self.task_config.train_lead_times,
                    **dataclasses.asdict(self.task_config)
                )

                inputs = to_static_vars(inputs)

                inputs = inputs.transpose(*dims, missing_dims='ignore')
                targets = targets.transpose(*dims, missing_dims='ignore')
                forcings = forcings.transpose(*dims, missing_dims='ignore')

                yield inputs, targets, forcings

    def is_nested_list(self, lst):
        """
        Check if a list is a nested list (i.e., contains other lists as elements).

        Args:
            lst (list): The list to check.

        Returns:
            bool: True if any element of the list is itself a list, False otherwise.
        """
        return any(isinstance(i, list) for i in lst)

    def open_mfdataset(self, paths, concat_time=False):
        """
        Open multiple files as a single xarray dataset using Kerchunk JSON descriptors.

        Args:
            paths (list of str, or list of list of str): Paths to the JSON files.
            concat_time : bool : Whether to concatenate along a time dimension
                                 before concatenating along a batch dimension.
                                 Every element of ``paths`` must then be a list of
                                 paths for one sample.

        Returns:
            xarray.Dataset: The combined dataset, shuffled along 'batch' and chunked
            with ``gpu_batch_size`` samples per chunk.

        Raises:
            ValueError: If ``paths`` is empty, or ``concat_time`` is True and
                        ``paths`` is not a list of lists.
        """
        # Validate before any I/O so bad input fails fast with a clear message.
        if not paths:
            raise ValueError('paths must contain at least one JSON file.')
        if concat_time and not all(isinstance(p, list) for p in paths):
            raise ValueError('paths must be a nested list if concatenating along a time dimension.')

        @dask.delayed
        def load_dataset_from_json(json_path):
            """Load a dataset from a Kerchunk JSON descriptor."""
            mapper = fsspec.get_mapper('reference://', fo=json_path, remote_protocol='file')
            ds = xr.open_dataset(mapper, engine='zarr', consolidated=False, chunks={}, decode_times=False)
            ds = add_local_solar_time(ds)
            return ds

        def load_and_concatenate(json_files, concat_dim):
            """Load multiple datasets from JSON files and concatenate them along a specified dimension."""
            datasets = [load_dataset_from_json(json_file) for json_file in json_files]
            datasets = dask.compute(*datasets)
            combined_dataset = xr.concat(datasets, dim=concat_dim)
            for ds in datasets:
                ds.close()
            return combined_dataset

        if concat_time:
            datasets_per_time = [load_and_concatenate(p, concat_dim='Time') for p in paths]
            dataset = xr.concat(datasets_per_time, dim='batch')
            dataset = dataset.rename({'Time': 'time'})
        else:
            dataset = load_and_concatenate(paths, concat_dim='batch')

        # Shuffle samples using the instance RNG so results are reproducible per seed.
        total_samples = dataset.sizes['batch']
        indices = self._rng.permutation(total_samples)
        dataset = dataset.isel(batch=indices)

        dataset = dataset.chunk({'batch': self.gpu_batch_size})

        return dataset

In [4]:
%%time 

from wofscast.wofscast_task_config import DBZ_TASK_CONFIG

n_epoches = 1  # NOTE(review): not used by this cell

directory = '/work/mflora/wofs-cast-data/datasets_jsons/2019'
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so the
# same 128 files are selected on every run.
files = sorted(os.listdir(directory))[:128]
paths = [os.path.join(directory, file) for file in files]

generator = WoFSCastDataGenerator(DBZ_TASK_CONFIG, cpu_batch_size=128, gpu_batch_size=32)
dataset = generator.open_mfdataset(paths)

# Iterate once over the whole dataset, reporting each GPU-sized batch.
for batch_idx, (inputs, targets, forcings) in enumerate(generator(dataset)):
    print(f'Batch : {batch_idx}')
        

# Dataset with 128 samples, prefetch_steps=4, batch_size=8
#CPU times: user 20.9 s, sys: 8.99 s, total: 29.9 s
#Wall time: 12.4 s

Batch : 0
Batch : 1
Batch : 2
Batch : 3
CPU times: user 19.8 s, sys: 6.4 s, total: 26.2 s
Wall time: 9.7 s
