In [24]:
# Build Data Assimilation Dataset 
from wrfout_file_formatter import FileFormatter

import numpy as np
import itertools
from os.path import join

def get_file_paths(base_path, years, dates, init_times, ens_mem_rng=np.arange(1,19),
                   unpadded_years=('2019',)):
    """Build one ensemble-member directory path per (date, init time, member).

    Parameters
    ----------
    base_path : str
        Root directory; each year is joined beneath it as a sub-directory.
    years : iterable of str
        Four-character years (e.g. ``'2019'``) that the entries of ``dates``
        may start with.
    dates : iterable of str
        Date strings whose first four characters are the year
        (e.g. ``'20190502'``).
    init_times : iterable of str
        Initialization-time directory names (e.g. ``'0200'``).
    ens_mem_rng : iterable of int, optional
        Ensemble member numbers. Defaults to 1 through 18.
    unpadded_years : collection of str, optional
        Years whose member directories are named without zero padding
        (``ENS_MEM_1``). Every other year uses two-digit padding
        (``ENS_MEM_01``). Defaults to ``('2019',)``.

    Returns
    -------
    list of str
        Directory paths, ordered as
        ``itertools.product(dates, init_times, ens_mem_rng)``.

    Raises
    ------
    KeyError
        If the year prefix of a date is not present in ``years``.
    """
    year_dirs = {year: join(base_path, year) for year in years}

    dir_paths = []
    for date, init_time, mem in itertools.product(dates, init_times, ens_mem_rng):
        year = date[:4]
        # The member directory naming convention differs between years.
        if year in unpadded_years:
            member_dir = f'ENS_MEM_{mem}'
        else:
            member_dir = f'ENS_MEM_{mem:02d}'
        dir_paths.append(join(year_dirs[year], date, init_time, member_dir))

    return dir_paths
    
    


class DataAssimDataBuilder(FileFormatter):
    """Format WRF output files into datasets for data assimilation.

    Subclass of ``FileFormatter`` (imported from ``wrfout_file_formatter``).
    Construction, and every ``self.*`` attribute or helper used below that is
    not defined here, comes from that base class.

    NOTE: the methods use ``os``, ``xr``, ``zarr``, ``add_derived_vars`` and
    ``TOARadiationFlux``; none of them is imported in this cell, so they must
    already be available in the notebook namespace.
    """

    def create_filename_from_path(self, path):
        """Return the output file name for the given input path(s).

        Not implemented yet. Raising makes the gap explicit instead of
        returning ``None``, which ``process`` would pass to ``os.path.join``.
        """
        raise NotImplementedError(
            'DataAssimDataBuilder.create_filename_from_path is not implemented yet.'
        )

    def load_dataset(self, path, preprocess_fn, drop_variables):
        """Lazily open one or more files as a single dataset.

        Parameters
        ----------
        path : str, os.PathLike, or sequence of those
            File(s) to open. The engine is chosen from the first entry:
            ``zarr`` (consolidated) if it contains ``'.zarr'``, otherwise
            ``netcdf4``.
        preprocess_fn : callable or None
            Applied to each file's dataset before the files are combined.
        drop_variables : list of str
            Variables to skip when opening.

        Returns
        -------
        xarray.Dataset

        Raises
        ------
        ValueError
            If ``path`` is an empty sequence.
        """
        if isinstance(path, (str, os.PathLike)):
            paths = [path]
        else:
            paths = list(path)

        if not paths:
            raise ValueError('load_dataset requires at least one path.')

        kwargs = {'decode_times' : False,
                  'chunks' : {},
                  'drop_variables' : drop_variables
                 }

        if '.zarr' in str(paths[0]):
            kwargs['engine'] = 'zarr'
            kwargs['consolidated'] = True
        else:
            kwargs['engine'] = 'netcdf4'

        # ``preprocess`` is an ``open_mfdataset`` feature: it runs once per
        # file, which the per-sample preprocessing in ``process`` needs.
        # NOTE(review): ``combine``/``concat_dim`` are left at xarray's
        # defaults because the original code specified none -- confirm they
        # match how FileFormatter concatenates the time steps.
        return xr.open_mfdataset(paths, preprocess=preprocess_fn, **kwargs)

    def process(self, data_paths, drop_vars, lat_1d, lon_1d):
        """Load, transform and prepare one set of files for writing.

        Parameters
        ----------
        data_paths : list of str
            Input files for one case. Must contain exactly
            ``self.n_expected_files - 1`` entries.
        drop_vars : list of str or None
            Variables to drop on load. ``'XLAT'``, ``'XLONG'`` and ``'XTIME'``
            are always dropped as well. The list is not modified.
        lat_1d, lon_1d
            Values assigned as the ``lat`` and ``lon`` coordinates.

        Returns
        -------
        str
            A status message: either the reason processing was skipped or
            ``"Processed result for <out_path>"``.
        """
        # Build a new list; ``drop_vars += [...]`` would extend the caller's
        # list in place and grow it on every call.
        drop_vars = list(drop_vars or []) + ['XLAT', 'XLONG', 'XTIME']

        # Perform initial error checking and abort
        # early if needed.
        if len(data_paths) == 0:
            return "Did not process, no files!"

        if len(data_paths) != self.n_expected_files-1:
            print(data_paths[0], 'Not enough time files, passing...')
            return 'Not enough time files, passing...'

        # The WRF zarr files are already processed, so
        # minimal additional processing is needed.
        is_zarr = '.zarr' in data_paths[0]

        fname = self.create_filename_from_path(data_paths)
        year = self.get_year_from_path(data_paths[0])

        out_path = os.path.join(self.out_path, year, fname)

        if not self.overwrite:
            if os.path.exists(out_path) and not self.debug:
                return "File already processed!"

        # Lazily load the data.
        # Add the forcing variables: time of year, time of day, TOA radiation.
        # Add a batch dimension and a datetime coordinate for the time of day
        # and year. Must be applied to each sample separately.
        def preprocess(ds):
            # Using the GraphCast code, compute the day and year progress
            # and their cos/sin variants. Also, using self-developed code,
            # compute the instantaneous top-of-the-atmosphere radiation flux.
            # Must be applied to a single example, so it is run as the
            # per-file preprocessing step of xr.open_mfdataset.
            # Compute this before altering the latitude and longitude!
            ds = ds.rename({'Time' : 'time'})
            ds = self.add_batch_and_datetime_coords(ds)
            ds = add_derived_vars(ds)
            ds = TOARadiationFlux().add_toa_radiation(ds.isel(batch=0))
            ds = self.add_batch_dim(ds)

            return ds

        if self.legacy:
            preprocess = None

        ds = self.load_dataset(data_paths, preprocess, drop_vars)

        if is_zarr:
            # Add the level coordinate
            level_values = np.arange(ds.sizes['level'])
            ds = ds.assign_coords(level=("level", level_values))
            if self.legacy:
                ds = ds.rename({'Time': 'time'})

        else:
            ds = self.reset_negative_water_vapor(ds)

            # Combine geopotential perturbation + base state
            # The WRF zarr files already have geopotential
            # height computed.
            ds = self.compute_full_geopot(ds)

            # Destagger the wind and geopotential fields
            ds = self.destagger(ds)

            # Add 300. to make it properly Kelvins, so we can convert to deg C/F.
            ds['T']+=300.

            # Renaming coordinate variables to align with the ERA5 naming convention.
            ds = self.rename_coords(ds)

        if 'subset_vertical_levels' in self.processes:
            # Keep every third level by position (not by coordinate value),
            # then reset the coordinate to 0..N-1.
            ds = ds.isel(level=slice(None, None, 3))
            ds.coords['level'] = np.arange(ds.sizes['level'])

        # Assign the supplied latitude and longitude as coordinates.
        # Latitude and longitude are expected to be 1d vectors.
        ds = ds.assign_coords(lat=lat_1d, lon=lon_1d)

        if self.legacy:
            # Shift negative longitude by 180
            ds['lon'] = xr.where(ds['lon'] < 0, ds['lon'] + 180, ds['lon'])
        else:
            ds = self.convert_to_fully_positive_longitude(ds)

        # Deprecated, but keeping for legacy at the moment
        # Add the 'time' coordinate and dimension
        if self.legacy:
            ds = self.add_time_dim(ds, data_paths)

        # Unaccumulate rainfall
        ds = self.unaccum_rainfall(ds)

        if 'resize' in self.processes:
            ds = self.resize(ds)

        #if self.debug:
        #    print(f"Processed result for {out_path}")
        #    return ds

        compressor = zarr.Blosc(cname='zstd', clevel=3, shuffle=zarr.Blosc.SHUFFLE)

        # Set encoding for each variable to use the specified compressor
        encoding = {var: {'compressor': compressor} for var in ds.data_vars}

        # NOTE: the write is currently disabled, so nothing is saved even
        # though the messages below report the output path.
        ###ds.to_zarr(out_path, mode='w', encoding=encoding, consolidated=True)

        if self.verbose > 0:
            print(f'Saving {out_path}')

        return f"Processed result for {out_path}"
            

In [25]:
# Inputs for a small trial run: one date, two init times, two ensemble members.
base_path = '/work2/wofs_zarr/'
years = ['2019']
dates = ['20190502']
init_times = ['0200', '0100']
ens_mem_rng = [1,2] #np.arange(1,19)

# One directory path per (date, init time, member) combination.
paths = get_file_paths(
    base_path=base_path,
    years=years,
    dates=dates,
    init_times=init_times,
    ens_mem_rng=ens_mem_rng,
)
    


In [27]:
from glob import glob
# List the .zarr entries inside one ensemble-member directory.
# glob returns matches in arbitrary order, so the listing is not sorted.
glob('/work2/wofs_zarr/2019/20190502/0200/ENS_MEM_1/*.zarr')

['/work2/wofs_zarr/2019/20190502/0200/ENS_MEM_1/wrfwof_d01_2019-05-03_02:20:00.zarr',
 '/work2/wofs_zarr/2019/20190502/0200/ENS_MEM_1/wrfwof_d01_2019-05-03_07:40:00.zarr',
 '/work2/wofs_zarr/2019/20190502/0200/ENS_MEM_1/wrfwof_d01_2019-05-03_06:15:00.zarr',
 '/work2/wofs_zarr/2019/20190502/0200/ENS_MEM_1/wrfwof_d01_2019-05-03_05:35:00.zarr',
 '/work2/wofs_zarr/2019/20190502/0200/ENS_MEM_1/wrfwof_d01_2019-05-03_03:10:00.zarr',
 '/work2/wofs_zarr/2019/20190502/0200/ENS_MEM_1/wrfwof_d01_2019-05-03_05:50:00.zarr',
 '/work2/wofs_zarr/2019/20190502/0200/ENS_MEM_1/wrfwof_d01_2019-05-03_02:45:00.zarr',
 '/work2/wofs_zarr/2019/20190502/0200/ENS_MEM_1/wrfwof_d01_2019-05-03_07:25:00.zarr',
 '/work2/wofs_zarr/2019/20190502/0200/ENS_MEM_1/wrfwof_d01_2019-05-03_04:05:00.zarr',
 '/work2/wofs_zarr/2019/20190502/0200/ENS_MEM_1/wrfwof_d01_2019-05-03_02:50:00.zarr',
 '/work2/wofs_zarr/2019/20190502/0200/ENS_MEM_1/wrfwof_d01_2019-05-03_07:30:00.zarr',
 '/work2/wofs_zarr/2019/20190502/0200/ENS_MEM_1/wrfwof

In [26]:
# Display the directory paths returned by get_file_paths.
paths

['/work2/wofs_zarr/2019/20190502/0200/ENS_MEM_1',
 '/work2/wofs_zarr/2019/20190502/0200/ENS_MEM_2',
 '/work2/wofs_zarr/2019/20190502/0100/ENS_MEM_1',
 '/work2/wofs_zarr/2019/20190502/0100/ENS_MEM_2']

In [None]:
# Create the dataset builder, passing n_jobs=1 to its constructor.
data_builder = DataAssimDataBuilder(n_jobs=1)