In [1]:
%cd /home/idies/workspace/turbulence-ceph-staging/sciserver-turbulence

/home/idies/workspace/turbulence-ceph-staging/sciserver-turbulence


In [2]:
import os

# root directory of the NCAR JHF input files on the staging volume.
NCAR_JHF_BASE_PATH = '/home/idies/workspace/turbulence-ceph-staging/ncar-jhf'

# sub-directories of the base path holding the two variants of the input data.
# NOTE(review): 'hr' / 'lr' presumably stand for high / low resolution - confirm.
NCAR_JHF_HR_PATH = os.path.join(NCAR_JHF_BASE_PATH, 'hr')
NCAR_JHF_LR_PATH = os.path.join(NCAR_JHF_BASE_PATH, 'lr')

# file extension of the input files.
NCAR_FILES_EXTENSION = '.nc'

In [3]:
"""
import packages
"""
import os
import sys
import zarr
import itertools
import contextlib
import numpy as np
from tqdm.auto import tqdm
from zarr.storage import DirectoryStore
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor

In [4]:
"""
parameters
"""
# name of the dataset; it names the zarr store folder, the zarr store itself and the report files.
dataset_title = 'stsabl2048low'
# full path of the zarr store that is created and filled by this notebook.
store_path = f"/home/idies/workspace/turbulence-ceph-staging/sciserver-turbulence/{dataset_title}/{dataset_title}.zarr"
# directory-backed zarr store; dimension_separator = '/' nests the chunk keys in sub-directories.
store = DirectoryStore(store_path, dimension_separator = '/')
# time offset from 0 when querying the dataset with giverny.
time_offset = 1
# offset from the first time chunk to write data into the zarr store on ceph.
# if this value = 0, that means time 0 is written to the zarr time chunk folder = 0.
# if this value = 1, that means time 0 is written to the zarr time chunk folder = 1. this is needed to keep a placeholder time chunk folder for pchip interpolation
# because the precursor time cannot be read by giverny getCutout and will have to be read and written manually.
time_ceph_offset = 1
# xyz and time dimensions for the full dataset including times for pchip interpolation that cannot be read by giverny.
# this is the shape (in x, y, z, t order) used when the zarr arrays are created.
# NOTE(review): process_cube writes time step t to index t + time_ceph_offset. with time_ceph_offset = 1 and
# xyzt_dims[3] = 20 the last write targets index 20, which does not fit in a time axis of length
# xyzt_dims_full[3] = 20 - confirm the intended full time length.
xyzt_dims_full = [2048, 2048, 2048, 20]
# xyz and time dimensions for the dataset to be read with giverny.
# the processing loop iterates over xyzt_dims[-1] time steps and tiles xyzt_dims[:3] with cubes.
xyzt_dims = [2048, 2048, 2048, 20] 
# xyz and time chunk sizes for the zarr store on ceph.
xyzt_chunk_sizes = [64, 64, 64, 1]
# xyz and time dimensions to query in parallel when reading from the legacy stores.
# each call of process_cube handles one cube of this size.
xyzt_filedb_file_dims = np.array([512, 512, 512, 1])
# use the default stride value of 1 for each axis when using giverny to retrieve a cutout of the data.
strides = [1, 1, 1]
# map the zarr group variables to the number of values stored for each grid point.
zarr_groups = {
    'velocity': 3,
    'pressure': 1,
    'temperature': 1,
    'energy': 1
}
# variable names in the order in which they are read, written and verified.
zarr_variables = list(zarr_groups.keys())
# also used as the thread count when the variables of one cube are written / verified in parallel.
num_variables = len(zarr_variables)
# number of workers to read in parallel.
num_workers = 8
# maximum number of retries in case of an error.
max_retries = 10
# output path for writing the report text file.
output_path = os.path.join('/home/idies/workspace/turbulence-ceph-staging/sciserver-turbulence/reports/', dataset_title)

In [5]:
"""
create the zarr store
"""
def create_zarr_store(store, xyzt_dims_full, xyzt_chunk_sizes, zarr_groups):
    """
    create an empty, zero-filled zarr array for every variable in a freshly created root group.

    store            : zarr store in which the root group is created. existing content is overwritten.
    xyzt_dims_full   : full dataset dimensions, ordered (x, y, z, t).
    xyzt_chunk_sizes : chunk sizes, ordered (x, y, z, t).
    zarr_groups      : mapping of variable name -> number of values stored per grid point.

    every array has shape (t, z, y, x, values per point), uses uncompressed little-endian float32,
    and keeps all values of one grid point in the same chunk.
    """
    # positions of t, z, y, x inside the (x, y, z, t) ordered inputs.
    tzyx_order = (3, 2, 1, 0)
    little_endian_float32 = np.dtype('<f4')

    # (re)create the root group of the store.
    root = zarr.group(store = store, overwrite = True, synchronizer = None)

    # one array per variable, with the per-point values as the trailing axis.
    for variable_name, values_per_point in zarr_groups.items():
        array_shape = tuple(xyzt_dims_full[axis] for axis in tzyx_order) + (values_per_point,)
        array_chunks = tuple(xyzt_chunk_sizes[axis] for axis in tzyx_order) + (values_per_point,)

        root.zeros(variable_name, shape = array_shape, chunks = array_chunks,
                   dtype = little_endian_float32, compressor = None)

    print('zarr store created.')
    print('-')
    sys.stdout.flush()

# WARNING: create_zarr_store opens the root group with overwrite = True, so re-running this cell
# recreates the store from scratch and discards any data that was already written to it.
create_zarr_store(store, xyzt_dims_full, xyzt_chunk_sizes, zarr_groups)

zarr store created.
-


In [12]:
import dask.array as da

def merge_velocities(transposed_ds, chunk_size_base=64):
    """
        Combine the 'u', 'v' and 'w' variables of a dataset into a single
        'velocity' variable whose last axis holds the 3 components, so that
        all components of a grid point live in the same chunk.

        The stacking, squeezing and rechunking are lazy Dask operations.

        transposed_ds   : dataset holding the variables 'u', 'v' and 'w'.
        chunk_size_base : chunk length used for each of the first 3 axes.

        Returns a new dataset without 'u', 'v', 'w' and with 'velocity' added;
        all other variables are carried over unchanged.
    """
    component_names = ['u', 'v', 'w']

    # Put the components next to each other on a new axis 3.
    stacked = da.stack([transposed_ds[name] for name in component_names], axis=3)

    # Drop singleton axes: (2048, 2048, 2048, 3, 1) becomes (2048, 2048, 2048, 3).
    # Note that squeeze() without an axis removes every length-1 axis.
    # Then force chunks that keep the 3 components together (Dask chooses (64,64,64,1) on its own).
    velocity = stacked.squeeze().rechunk((chunk_size_base,) * 3 + (3,))

    # Everything except the individual components is kept as it is.
    merged = transposed_ds.drop_vars(component_names)

    # The component axis needs its own dimension name; it can't share the name used by the scalars.
    merged['velocity'] = xr.DataArray(
        velocity, dims=('nnz', 'nny', 'nnx', 'velocity component (xyz)'))

    return merged

In [None]:
"""
read JHTDB datasets from fileDB using giverny and then write to a zarr store on ceph
"""
import xarray as xr
import traceback
import dask

# open the zarr store using DirectoryStore.
# mode 'a' opens the existing store for reading and writing without deleting its content.
# this module-level handle is the one process_cube writes to and reads back from.
root = zarr.open(store, mode = 'a')

def process_cube(coords):
    """
    read one spatial cube of every variable for one time step, write it to the zarr store and verify the copy.

    coords : (x, y, z, time) - the first three entries are cube indices along each axis (multiplied by
             xyzt_filedb_file_dims to get grid positions), the fourth is the time index.

    returns "successfully processed cube at {coords}" when every variable reads back identical to what was
    written, otherwise "error processing cube at {coords}: verification failed".

    raises Exception (with the original exception type, message and traceback embedded in the message text)
    when anything inside the function fails.

    relies on the module-level names root, zarr_variables, num_variables, xyzt_filedb_file_dims,
    xyzt_chunk_sizes, time_offset, time_ceph_offset and strides.
    """
    try:
        """
        TODO : Ariel change code here for reading in the new NCAR data without giverny.
        """
        
        # Open the input file lazily, with dask chunks of one cube per block.
        # NOTE(review): the file name 'jhf.000.nc' is hard-coded and does not depend on the time index in coords -
        # confirm how the time steps map to input files.
        # NOTE(review): the dataset is opened on every call and never closed explicitly.
        # NOTE(review): 'nnz' is paired with xyzt_filedb_file_dims[0] (the x entry) and 'nnx' with [2] (the z entry);
        # harmless while all three sizes are equal.
        data_xr = xr.open_dataset(os.path.join(NCAR_JHF_LR_PATH, 'jhf.000.nc'),
                                  chunks={'nnz': xyzt_filedb_file_dims[0], 'nny': xyzt_filedb_file_dims[1],
                                          'nnx': xyzt_filedb_file_dims[2]})
    
        # Sanity check that the 'e' variable is backed by a dask array, i.e. was not loaded eagerly.
        assert isinstance(data_xr['e'].data, dask.array.core.Array)

        # Add an extra dimension to the data to match isotropic8192
        # (a length-1 dimension is added and its coordinate variable is dropped again).
        expanded_ds = data_xr.expand_dims({'extra_dim': [1]}).drop_vars('extra_dim')
        # Put the extra dimension in the back
        transposed_ds = expanded_ds.transpose('nnz', 'nny', 'nnx', 'extra_dim')

        # Group 3 velocity components together
        merged_velocity = merge_velocities(transposed_ds, chunk_size_base=xyzt_chunk_sizes[0])

        # Rename variables
        # NOTE(review): merged_velocity is not used anywhere below this line - the data that is written
        # comes from getCutout(dataset, ...) instead.
        merged_velocity = merged_velocity.rename({'e': 'energy', 't': 'temperature', 'p': 'pressure'})

        # NOTE(review): dims is computed but never used in this function.
        dims = [dim for dim in data_xr.dims]
        dims.reverse()  # Use (nnz, nny, nnx) instead of (nnx, nny, nnz)
        
        # The unmodified input dataset is what gets passed to getCutout.
        dataset = data_xr

        # 1-based first grid index of this cube along each axis, and the inclusive [first, last] range per axis.
        x, y, z = [coord * xyzt_filedb_file_dims[index] + 1 for index, coord in enumerate(coords[:3])]
        time = coords[3]
        ranges = [[x, x + xyzt_filedb_file_dims[0] - 1],
                  [y, y + xyzt_filedb_file_dims[1] - 1],
                  [z, z + xyzt_filedb_file_dims[2] - 1]]
            
        # Read the cube of every variable, in the order of zarr_variables.
        # NOTE(review): getCutout is not defined or imported in any cell visible in this notebook, so this line
        # raises NameError unless the name is provided elsewhere - this is the part the TODO above refers to.
        variable_data = []
        for zarr_variable in zarr_variables:
            variable_data.append(getCutout(dataset, zarr_variable, time + time_offset, np.array(ranges), np.array(strides), verbose=False).to_array().to_numpy()[0])
            
        """
        TODO : End of code for reading in data.
        """

        def save_store(giverny_cube, variable_name):
            # Write the cube at time index time + time_ceph_offset; the "- 1" converts the 1-based x, y, z
            # to 0-based slice starts. The zarr arrays are indexed (t, z, y, x, values per point).
            root[variable_name][time + time_ceph_offset,
                                z - 1 : z + xyzt_filedb_file_dims[2] - 1,
                                y - 1 : y + xyzt_filedb_file_dims[1] - 1,
                                x - 1 : x + xyzt_filedb_file_dims[0] - 1, :] = giverny_cube
            
        def verify_copy(giverny_cube, variable_name):
            # Read the same region back from the store and compare it element by element with what was written.
            ceph_cube = root[variable_name][time + time_ceph_offset,
                                            z - 1 : z + xyzt_filedb_file_dims[2] - 1,
                                            y - 1 : y + xyzt_filedb_file_dims[1] - 1,
                                            x - 1 : x + xyzt_filedb_file_dims[0] - 1, :]
            
            # NOTE(review): "==" treats NaN as unequal to NaN, so a cube containing NaN values would be
            # reported as corrupt - confirm the data cannot contain NaN.
            if np.all(giverny_cube == ceph_cube):
                return "valid copy"
            else:
                return "corrupt copy"

        # Save the variables in parallel.
        with ThreadPoolExecutor(num_variables) as p:
            list(p.map(save_store, variable_data, zarr_variables))
            
        # Verify that the copies are not corrupt.
        with ThreadPoolExecutor(num_variables) as p:
            verified = list(p.map(verify_copy, variable_data, zarr_variables))

        # The caller looks for the text "error processing cube" in the returned message.
        if any([message == "corrupt copy" for message in verified]) or len(verified) != num_variables:
            return f"error processing cube at {coords}: verification failed"
        else:
            return f"successfully processed cube at {coords}"
    except Exception as e:
        # Re-raise as a plain Exception whose message carries the cube coordinates and the formatted traceback.
        tb = traceback.format_exc()
        raise Exception(f"error processing cube at {coords}: {type(e).__name__}, {e}\n{tb}")
        
# Commented out the suppress_stderr context manager
# @contextlib.contextmanager
# def suppress_stderr():
#     with open(os.devnull, 'w') as devnull:
#         with contextlib.redirect_stderr(devnull):
#             yield

# Make sure the report directory exists. os.makedirs also creates a missing parent directory
# and does not fail when the directory is already there.
os.makedirs(output_path, exist_ok = True)

# Cube indices (x, y, z) of every cube needed to tile the spatial domain.
cube_coords = list(itertools.product(*[range(dim // chunk) for dim, chunk in zip(xyzt_dims[:3], xyzt_filedb_file_dims[:3])]))

# Progress bar over the cubes of the time step that is currently being processed.
current_time_pbar = tqdm(total=len(cube_coords), desc="chunks completed", leave=False)

# One line per time step is appended to the success or the error report.
# Both files are truncated at the start of every run.
success_report_path = os.path.join(output_path, f"{dataset_title}_report-success.txt")
error_report_path = os.path.join(output_path, f"{dataset_title}_report-error.txt")

try:
    with open(success_report_path, "w") as f_success, open(error_report_path, "w") as f_error:
        for time in tqdm(range(xyzt_dims[-1]), total=xyzt_dims[-1], desc="time"):
            # Attach the time index to every cube; the work list is the same for every attempt.
            cube_coords_time = [(x, y, z, time) for x, y, z in cube_coords]

            correct_flag = False
            retries = 0
            # One initial attempt plus up to max_retries retries for each time step.
            while not correct_flag and retries <= max_retries:
                retries += 1

                results = []
                try:
                    # Process the cubes of this time step in parallel.
                    with ProcessPoolExecutor(num_workers) as executor:
                        for result in executor.map(process_cube, cube_coords_time):
                            results.append(result)
                            current_time_pbar.update(1)

                            # Stop collecting at the first cube that failed verification.
                            if "error processing cube" in result:
                                break

                    # Reset the chunk progress bar.
                    current_time_pbar.reset()

                    # Clear the status line before printing the next status message.
                    print('\r' + ' ' * 100 + '\r', end='', flush=True)

                    # No results at all counts as a failure, as does any cube that failed verification.
                    attempt_failed = not results or any("error processing cube" in message for message in results)

                    if not attempt_failed:
                        output_str = f"successfully consolidated data for time = {time}\n"
                        f_success.write(output_str)
                        f_success.flush()

                        print(f"\rsuccessfully consolidated data for time = {time}", end='')

                        correct_flag = True
                    elif retries > max_retries:
                        # Out of attempts: record the failure; the while condition ends the loop.
                        output_str = f"error consolidating data for time = {time}\n"
                        f_error.write(output_str)
                        f_error.flush()

                        print(f"\rerror consolidating data for time = {time}", end='')
                    else:
                        print(f"\rretrying (n = {retries}) time = {time}", end='')
                except Exception as e:
                    # Raised by a worker (process_cube re-raises with the traceback in the message)
                    # or by the executor itself. Every occurrence is logged to the error report.
                    tb = traceback.format_exc()
                    print('\r' + ' ' * 100 + '\r', end='', flush=True)

                    # Reset the chunk progress bar.
                    current_time_pbar.reset()

                    error_message = f"code exception when consolidating data for time = {time}: {type(e).__name__}, {e}\n{tb}\n"
                    f_error.write(error_message)
                    f_error.flush()

                    print(f"\rcode exception when consolidating data for time = {time}: {e}", end='')
                    if retries > max_retries:
                        break

                    print(f"\rretrying (n = {retries}) time = {time}", end='')
finally:
    # The bar was created with leave=False; closing it removes it from the output even when the run is interrupted.
    current_time_pbar.close()

print('\n-')
print("completed zarr consolidation, check report files")

HBox(children=(HTML(value='chunks completed'), FloatProgress(value=0.0, max=64.0), HTML(value='')))

HBox(children=(HTML(value='time'), FloatProgress(value=0.0, max=20.0), HTML(value='')))