Code used to sanity-check the data: plotting raw and written files, and calculating various statistics.

In [1]:
import os
import zarr
import dask.array as da
import matplotlib.pyplot as plt
from dask import compute
from dask.diagnostics import ProgressBar
import numpy as np
import xarray as xr

In [2]:
# High-rate (hr) inputs: the written Zarr store, plus three of the raw NetCDF
# files (jhf.000, jhf.001 and jhf.104) that are opened for spot checks below.
stored_jhf_hr_path = "/home/idies/workspace/turbulence-ceph-staging/sciserver-turbulence/stsabl2048high/stsabl2048high.zarr"
raw_jhf_hr_0_path = "/home/idies/workspace/turbulence-ceph-staging/ncar-jhf/hr/jhf.000.nc"
raw_jhf_hr_1_path = "/home/idies/workspace/turbulence-ceph-staging/ncar-jhf/hr/jhf.001.nc"
raw_jhf_hr_104_path = "/home/idies/workspace/turbulence-ceph-staging/ncar-jhf/hr/jhf.104.nc"

In [3]:
# Open the written high-rate store read-only: zarr.open defaults to mode 'a',
# which would create an empty store on a mistyped path and allows writes.
# This notebook only inspects the data, so fail loudly instead.
jhf_hr_zarr = zarr.open(stored_jhf_hr_path, mode='r')

# Raw NetCDF files used as the reference for the spot checks below.
jhf_hr_netcdf_t0 = xr.open_dataset(raw_jhf_hr_0_path)
jhf_hr_netcdf_t1 = xr.open_dataset(raw_jhf_hr_1_path)
jhf_hr_netcdf_t104 = xr.open_dataset(raw_jhf_hr_104_path)

In [4]:
# Low-rate (lr) inputs: the written Zarr store, plus three of the raw NetCDF
# files (jhf.000, jhf.001 and jhf.019).
stored_jhf_lr_path = "/home/idies/workspace/turbulence-ceph-staging/sciserver-turbulence/stsabl2048low/stsabl2048low.zarr"
raw_jhf_lr_0_path = "/home/idies/workspace/turbulence-ceph-staging/ncar-jhf/lr/jhf.000.nc"
raw_jhf_lr_1_path = "/home/idies/workspace/turbulence-ceph-staging/ncar-jhf/lr/jhf.001.nc"
raw_jhf_lr_19_path = "/home/idies/workspace/turbulence-ceph-staging/ncar-jhf/lr/jhf.019.nc"


# Read-only on purpose: zarr.open defaults to mode 'a', which would create an
# empty store on a mistyped path and allows accidental writes.
jhf_lr_zarr = zarr.open(stored_jhf_lr_path, mode='r')
jhf_lr_netcdf_t0 = xr.open_dataset(raw_jhf_lr_0_path)
jhf_lr_netcdf_t1 = xr.open_dataset(raw_jhf_lr_1_path)
jhf_lr_netcdf_t19 = xr.open_dataset(raw_jhf_lr_19_path)

<font color="cyan">

# Remember: data is saved in `nnz-nny-nnx` axis order
    
    
</font>

# Quick-Verify Correctness of data

1. Check if `data == 0`

2. Pick one $64^3$ chunk and compare it to raw NetCDF

### High-Rate

#### `zarr.info`

In [None]:
# Display the repr of the opened high-rate Zarr object.
jhf_hr_zarr

In [None]:
# Display the `.info` summary of the high-rate 'energy' array.
jhf_hr_zarr['energy'].info

In [None]:
# Display the `.info` summary of the high-rate 'velocity' array.
jhf_hr_zarr['velocity'].info

In [None]:
# Display the `.info` summary of the high-rate Zarr object itself.
jhf_hr_zarr.info

#### Indexing, Compare to 0

In [None]:
print("Checking whether field all zeros - True is bad!")

# Probe the first 64x64x64 corner (component 0) of the high-rate temperature
# field at timesteps 0..104; an all-zero corner suggests data was not written.
for t in range(105):
    print(f"t= {t}  -  {np.all(jhf_hr_zarr['temperature'][t,:64,:64,:64,0] == 0)}")

#### Comparing Real Values

In [None]:
# First 10 values along axis 1 of the high-rate temperature array at timestep 0
# (all other indices fixed at 0).
jhf_hr_zarr['temperature'][0,:10,0,0,0]

In [None]:
# Same probe on the low-rate store, for a side-by-side look at the values.
jhf_lr_zarr['temperature'][0,:10,0,0,0]

In [None]:
# Same probe on the high-rate store at timestep 1.
jhf_hr_zarr['temperature'][1,:10,0,0,0]

In [10]:
# Spot-check timesteps 93..104: the first 64x64x64 corner of the Zarr
# temperature field must equal the same corner of the raw NetCDF 't' variable.
for t in range(93, 105):
    zarr_comparison_data = jhf_hr_zarr['temperature'][t,:64,:64,:64,0]

    raw_t_path = f"/home/idies/workspace/turbulence-ceph-staging/ncar-jhf/hr/jhf.{t:03d}.nc"
    # The context manager closes each NetCDF file once its corner has been
    # read, so the loop does not accumulate open file handles.
    with xr.open_dataset(raw_t_path) as raw_xr:
        raw_t = raw_xr['t'].isel(nnx=slice(0, 64), nny=slice(0, 64), nnz=slice(0, 64)).values

    print("t: ", t, np.all(zarr_comparison_data == raw_t))

t:  97 True
t:  98 True
t:  99 True
t:  100 True
t:  101 True
t:  102 True
t:  103 True
t:  104 True


In [None]:
# First 10 values along axis 3 of the high-rate temperature array at timestep 0.
jhf_hr_zarr['temperature'][0,0,0,:10,0]

In [None]:
# First 10 values along axis 1 of the high-rate energy array at timestep 0.
jhf_hr_zarr['energy'][0,:10,0,0,0]

In [None]:
# First 10 values along the leading axis of variable 'e' in raw file jhf.000.nc.
# NOTE(review): presumably 'e' is the raw counterpart of the Zarr 'energy' array - confirm.
np.array(jhf_hr_netcdf_t0['e'][:10,0,0])

Xarray complains about missing metadata when opening the written Zarr store. This was NOT fixed by the GPT o1 suggestion below:

`written_ds_xr = xr.open_dataset(stored_jhf_path, engine='zarr', consolidated=False)`

In [None]:
import numpy as np

# `ds` was never defined in this notebook; this High-Rate check reads from the
# opened high-rate Zarr store instead. True means the probed values are all zero.
np.all(jhf_hr_zarr['velocity'][0, :10,0,0, 0] == 0)

In [None]:
import numpy as np

# `ds` was never defined in this notebook; this High-Rate check reads from the
# opened high-rate Zarr store instead. True means the probed values are all zero.
np.all(jhf_hr_zarr['velocity'][100, :10,0,0, 0] == 0)

In [None]:
import numpy as np

# `ds` was never defined in this notebook; this High-Rate check reads from the
# opened high-rate Zarr store instead. True means the probed values are all zero.
np.all(jhf_hr_zarr['energy'][100, :10,0,0, 0] == 0)

### Low-Rate

#### Comparing Real Values

In [None]:
# First 10 values along axis 1 of the low-rate temperature array at timestep 0.
jhf_lr_zarr['temperature'][0,:10,0,0,0]

In [None]:
# First 10 values along the leading axis of variable 't' in raw low-rate file
# jhf.000.nc, to compare by eye with the Zarr temperature values.
np.array(jhf_lr_netcdf_t0['t'][:10,0,0])

# Efficient Zarr full data Slices

This can take a few minutes

 Sciserver doesn't allow localhost connections, so the Dask Cluster console can't be used.
 
<font color="green"> Lazy loading speeds up reading times 10x+</font>

In [11]:
import os
import zarr
import dask.array as da
import matplotlib.pyplot as plt
from dask import compute
from dask.diagnostics import ProgressBar

# Default colormap for every image produced in this notebook session.
plt.rcParams['image.cmap'] = 'inferno'

def process_zarr(stored_jhf_path, dataset):
    """Save index-0 boundary-plane images for every 5th timestep of a Zarr store.

    For 'energy', 'temperature', 'pressure' and 'velocity', the planes at index 0
    of array axes 3, 2 and 1 (labelled nnx, nny and nnz in the titles) are built
    lazily with Dask, evaluated in one ``compute`` call, and written as PNG files
    to ``zarr_slices/<dataset>/<variable>/``. 'velocity' yields one image per
    component (0, 1, 2) and plane.

    Parameters
    ----------
    stored_jhf_path : str
        Location of the Zarr group; it is opened read-only.
    dataset : str
        'hr' or 'lr'. Only used to name the output folder.

    Raises
    ------
    ValueError
        If ``dataset`` is neither 'hr' nor 'lr'.
    """
    if dataset not in ('hr', 'lr'):
        raise ValueError("Dataset must be 'hr' or 'lr'.")

    group = zarr.open_group(stored_jhf_path, mode='r')

    # Wrap every array of the group lazily; a trailing length-1 axis is dropped
    # so scalar fields become 4-D.
    lazy_fields = {}
    for name in list(group.array_keys()):
        lazy = da.from_zarr(group[name])
        lazy_fields[name] = lazy.squeeze(axis=-1) if lazy.shape[-1] == 1 else lazy

    plotted_variables = ['energy', 'temperature', 'pressure', 'velocity']
    plane_labels = ('nnx', 'nny', 'nnz')

    # Every 5th timestep; the count is taken from the 'energy' array.
    timesteps = range(0, lazy_fields['energy'].shape[0], 5)

    # Queue the lazy planes and their captions in plotting order:
    # variable -> timestep -> (component ->) plane.
    pending = []
    captions = []
    for name in plotted_variables:
        for step in timesteps:
            field = lazy_fields[name]
            planes = (
                field[step, :, :, 0],
                field[step, :, 0, :],
                field[step, 0, :, :],
            )
            if name == 'velocity':
                # The remaining last axis holds the vector components.
                for comp in range(3):
                    for label, plane in zip(plane_labels, planes):
                        pending.append(plane[..., comp])
                        captions.append(f"{name} (component {comp}) {label}=0 Timestep={step}")
            else:
                for label, plane in zip(plane_labels, planes):
                    pending.append(plane)
                    captions.append(f"{name} {label}=0 Timestep={step}")

    # One compute call evaluates every queued plane.
    with ProgressBar():
        rendered = compute(*pending)

    master_folder = os.path.join("zarr_slices", dataset)
    os.makedirs(master_folder, exist_ok=True)

    # Walk the flat results in the same order they were queued.
    cursor = 0
    for name in plotted_variables:
        var_folder = os.path.join(master_folder, name)
        os.makedirs(var_folder, exist_ok=True)

        # 3 planes per timestep, times 3 components for velocity.
        per_step = 9 if name == 'velocity' else 3

        for step in timesteps:
            stop = cursor + per_step
            for image, caption in zip(rendered[cursor:stop], captions[cursor:stop]):
                plt.figure()
                plt.imshow(image)
                plt.title(caption)
                plt.gca().invert_yaxis()
                plt.colorbar()

                safe_caption = caption.replace(" ", "_").replace("=", "_")
                target = os.path.join(var_folder, f"Timestep_{step}_{safe_caption}.png")
                plt.savefig(target, dpi=150, bbox_inches='tight')
                plt.close()
            cursor = stop

In [12]:
# Generate and save the slice images for the low-rate store.
process_zarr(stored_jhf_lr_path, "lr")

[########################################] | 100% Completed | 52.64 s


# Data Statistics - Mean Temp. across Axis

## Zarr

In [5]:
import os
import dask.array as da
from dask import compute
from dask.diagnostics import ProgressBar
import matplotlib.pyplot as plt
import numpy as np

def plot_slices_and_mean_z(
    ds_zarr_group, 
    variables, 
    time_start=0, 
    time_stop=105, 
    time_step=5, 
    z_step=128,
    master_slice_folder="zarr_slices",
    mean_output_folder="mean_temperature_plots"
):
    """
    Plot index-0 boundary slices of each variable and the mean temperature per Z level.

    1) Wraps each requested Zarr array as a lazy Dask array.
    2) Gathers the slices at index 0 of the x, y and z axes for each variable
       over the selected timesteps.
    3) If "temperature" is requested, gathers its mean over (y, x) at every
       z_step-th Z index.
    4) Performs exactly one compute() to load everything.
    5) Saves one image per slice and timestep, and one line plot per timestep.

    Parameters
    ----------
    ds_zarr_group : zarr group or mapping of name -> Zarr array
        Must contain every name in `variables`. Arrays are indexed as
        [time, z, y, x, components]; a trailing axis of length 1 is squeezed.
    variables : list of str
        E.g. ["energy", "temperature", "pressure", "velocity"]. "velocity" is
        split into components 0..2 when it is 5-D after squeezing.
    time_start, time_stop, time_step : int
        Timesteps are selected with slice(time_start, time_stop, time_step).
        A time_stop beyond the number of stored timesteps is clamped, exactly
        as array slicing does, so the default works for shorter datasets too.
    z_step : int
        Step along Z used for the mean-temperature samples.
    master_slice_folder : str
        Folder receiving the slice images (created if missing).
    mean_output_folder : str
        Folder receiving the mean-temperature plots (created if missing).

    Returns
    -------
    None
    """
    # -------------------------------------------------------------------------
    # 1) Convert each Zarr array to a lazy Dask array
    # -------------------------------------------------------------------------
    data_arrays = {}
    for var in variables:
        dask_array = da.from_zarr(ds_zarr_group[var])
        # Scalar fields carry a trailing axis of length 1; drop it
        if dask_array.shape[-1] == 1:
            dask_array = dask_array.squeeze(axis=-1)
        data_arrays[var] = dask_array

    # -------------------------------------------------------------------------
    # 2) Time and Z selections as plain step slices
    # -------------------------------------------------------------------------
    time_slice = slice(time_start, time_stop, time_step)
    z_slice = slice(None, None, z_step)

    def selected_times(arr):
        # range(n)[slice] clamps exactly like slicing an axis of length n, so
        # the labels always line up with the rows of arr[time_slice, ...].
        return list(range(arr.shape[0])[time_slice])

    # -------------------------------------------------------------------------
    # 3a) Lazy slices at x=0, y=0, z=0 with their titles and time labels
    # -------------------------------------------------------------------------
    slice_arrays = []
    slice_titles = []
    slice_times = []

    for var in variables:
        arr = data_arrays[var]
        var_times = selected_times(arr)

        if var == "velocity" and arr.ndim == 5:
            # One (x=0, y=0, z=0) triple per vector component
            planes = []
            for comp in range(3):
                planes.extend([
                    (f"{var} (component {comp}) x=0", arr[time_slice, :, :, 0, comp]),
                    (f"{var} (component {comp}) y=0", arr[time_slice, :, 0, :, comp]),
                    (f"{var} (component {comp}) z=0", arr[time_slice, 0, :, :, comp]),
                ])
        else:
            planes = [
                (f"{var} x=0", arr[time_slice, :, :, 0]),
                (f"{var} y=0", arr[time_slice, :, 0, :]),
                (f"{var} z=0", arr[time_slice, 0, :, :]),
            ]

        for title, lazy_plane in planes:
            slice_arrays.append(lazy_plane)
            slice_titles.append(title)
            slice_times.append(var_times)

    # -------------------------------------------------------------------------
    # 3b) Lazy mean temperature over (y, x) at every z_step-th Z index
    # -------------------------------------------------------------------------
    mean_temp_dask = None
    mean_times = []
    z_indices_list = []
    if "temperature" in data_arrays:
        temp_arr = data_arrays["temperature"]
        # Result shape: [num_times, number_of_z_samples]
        mean_temp_dask = temp_arr[time_slice, z_slice, :, :].mean(axis=(2, 3))
        mean_times = selected_times(temp_arr)
        z_indices_list = list(range(temp_arr.shape[1])[z_slice])

    # -------------------------------------------------------------------------
    # 4) Single compute() for everything
    # -------------------------------------------------------------------------
    to_compute = list(slice_arrays)
    if mean_temp_dask is not None:
        to_compute.append(mean_temp_dask)

    if not to_compute:
        print("No data to compute. Exiting.")
        return

    print("Building Dask graph. No data is read yet. Now calling .compute() with a ProgressBar...")
    os.makedirs(master_slice_folder, exist_ok=True)
    os.makedirs(mean_output_folder, exist_ok=True)

    with ProgressBar():
        results = compute(*to_compute)

    # The slice results come first; the mean (if requested) is the last item
    slice_results = results[:len(slice_arrays)]
    mean_results = results[-1] if mean_temp_dask is not None else None

    # -------------------------------------------------------------------------
    # 5) One image per slice and timestep
    # -------------------------------------------------------------------------
    for slice_result, title, times in zip(slice_results, slice_titles, slice_times):
        safe_title = title.replace(" ", "_").replace("=", "_")
        for i, t in enumerate(times):
            plt.figure()
            plt.imshow(slice_result[i])
            plt.gca().invert_yaxis()
            plt.colorbar()
            plt.title(f"{title}, t={t}")

            fname = os.path.join(master_slice_folder, f"{safe_title}_t_{t}.png")
            plt.savefig(fname, dpi=150, bbox_inches='tight')
            plt.close()

    # -------------------------------------------------------------------------
    # 6) One mean-temperature line plot per timestep
    # -------------------------------------------------------------------------
    if mean_results is not None:
        for i, t in enumerate(mean_times):
            row = mean_results[i, :]  # one value per sampled Z index

            plt.figure()
            plt.plot(z_indices_list, row, marker='o')
            plt.title(f"Mean Temperature across Z (t={t})")
            plt.xlabel("Z index")
            plt.ylabel("Avg Temp")

            fname = os.path.join(mean_output_folder, f"mean_z_temperature_t_{t}.png")
            plt.savefig(fname, dpi=150, bbox_inches='tight')
            plt.close()
            print(f"Saved plot as {fname}")

    print("All done.")

## Across Z

### High Rate

In [None]:
# Suppose 'jhf_hr_zarr' is a zarr Group with
#  jhf_hr_zarr['energy'] shape [105, 2048, 2048, 2048, 1]
#  jhf_hr_zarr['temperature'] ...
#  jhf_hr_zarr['pressure'] ...
#  jhf_hr_zarr['velocity'] ...

# Variables to plot; each name must be an array in jhf_hr_zarr.
variables = ["energy", "temperature", "pressure", "velocity"]

# Run the slice and mean-temperature plots on the high-rate store.
plot_slices_and_mean_z(
    ds_zarr_group=jhf_hr_zarr,
    variables=variables,
    time_start=0,
    time_stop=105,  # up to but not including 105 => indices 0..104
    time_step=5,
    z_step=128,
    master_slice_folder="zarr_slices",
    mean_output_folder="mean_temperature_plots"
)

Building Dask graph. No data is read yet. Now calling .compute() with a ProgressBar...


### Original NetCDF

In [None]:
# NOTE(review): neither `plot_mean_temp_across_z` nor a top-level `data_arrays`
# is defined in the visible part of this notebook - this cell depends on
# leftover kernel state and will fail on a fresh kernel.
# The range stops at 105 (not 106): valid timestep indices are 0..104, so 105
# itself would be out of range.
plot_mean_temp_across_z(data_arrays, timesteps=range(0, 105, 5), data_type="original", spacing=128)

## Across X

- [ ] TODO if necessary

# Compare Zarr to NetCDF correctness

In [None]:
import os
import xarray as xr
import dask.array as da
import numpy as np

def compare_zarr_and_netcdf(
    zarr_ds, 
    netcdf_path_pattern, 
    times=range(93, 105),
    z_slice=64,  # how much of Z dimension to compare
    y_slice=64,
    x_slice=64
):
    """
    Compare the temperature corner block of a Zarr dataset against NetCDF files.

    For every requested timestep, the block [0:z_slice, 0:y_slice, 0:x_slice]
    of zarr_ds['temperature'] is compared element-wise with the same block of
    the NetCDF variable 't', and one "t: <t>, match: <bool>" line is printed.
    The comparison is positional (axis by axis), not by dimension name, and
    uses `==`, so a NaN on either side counts as a mismatch.

    Parameters
    ----------
    zarr_ds : zarr group, xarray.Dataset, or dict-like
        zarr_ds['temperature'] must be a Zarr array, Dask array, NumPy array or
        xarray.DataArray indexed as [time, z, y, x] with an optional trailing
        component axis (component 0 is used).
    netcdf_path_pattern : str
        Glob pattern for the NetCDF files, e.g. "path/to/jhf.*.nc", opened with
        xarray.open_mfdataset and stacked along a new "time" dimension.
        NOTE(review): assumes the matched files stack in timestep order, so that
        position i along "time" is timestep i - confirm for your file naming.
    times : iterable of int
        Timesteps to compare. E.g. range(93, 105).
    z_slice, y_slice, x_slice : int
        Number of grid cells per dimension to compare, counted from index 0.

    Raises
    ------
    ValueError
        If the Zarr and NetCDF blocks end up with different shapes.
    """
    time_indices = list(times)
    if not time_indices:
        # Nothing requested: avoid opening any file.
        return

    #---------------------
    # 1) ZARR SIDE (positional, lazy where possible)
    #---------------------
    temperature = zarr_ds['temperature']
    if isinstance(temperature, xr.DataArray):
        # Use the underlying array so the comparison below cannot broadcast
        # by dimension name.
        temperature = temperature.data
    if not isinstance(temperature, (da.Array, np.ndarray)):
        # e.g. a zarr Array taken from a zarr group
        temperature = da.from_zarr(temperature)

    zarr_temp = temperature[time_indices, :z_slice, :y_slice, :x_slice]
    if zarr_temp.ndim == 5:
        zarr_temp = zarr_temp[..., 0]  # drop the component axis

    #---------------------
    # 2) NETCDF SIDE
    #---------------------
    # The context manager closes the files even if the comparison raises.
    with xr.open_mfdataset(
        netcdf_path_pattern,
        concat_dim="time",
        combine="nested",
        parallel=True,
        # Chunk so that only the compared corner needs to be read
        chunks={
            "nnz": 64,
            "nny": 64,
            "nnx": 64
        }
    ) as ds_nc:
        nc_temp = (
            ds_nc['t']
            .isel(
                time=time_indices,
                nnz=slice(0, z_slice),
                nny=slice(0, y_slice),
                nnx=slice(0, x_slice)
            )
            # Fix the axis order explicitly before dropping the labels
            .transpose("time", "nnz", "nny", "nnx")
            .data
        )

        if zarr_temp.shape != nc_temp.shape:
            raise ValueError(
                f"Shape mismatch: zarr block {zarr_temp.shape} vs NetCDF block {nc_temp.shape}"
            )

        #---------------------
        # 3) LAZY COMPARISON + SINGLE COMPUTE
        #---------------------
        # Reduce over the spatial axes only => one boolean per timestep.
        # The compute must run while the NetCDF files are still open.
        eq_per_timestep = (zarr_temp == nc_temp).all(axis=(1, 2, 3))
        result = eq_per_timestep.compute()

    #---------------------
    # 4) REPORT RESULTS
    #---------------------
    for idx, t in enumerate(time_indices):
        print(f"t: {t}, match: {result[idx]}")

#-------------------------------------------------------------
# USAGE EXAMPLE
#-------------------------------------------------------------
if __name__ == "__main__":
    import zarr
    import xarray as xr
    
    # `jhf_hr_zarr` is not created in this cell; it must already be open in
    # the kernel before this runs.
    #
    # The NetCDF pattern below globs the raw high-rate files
    # (jhf.000.nc, jhf.001.nc, ...), which are named by timestep.

    times_to_compare = range(93, 105)  # 12 timesteps
    compare_zarr_and_netcdf(
        zarr_ds=jhf_hr_zarr, 
        netcdf_path_pattern="/home/idies/workspace/turbulence-ceph-staging/ncar-jhf/hr/jhf.*.nc",
        times=times_to_compare,
        z_slice=64,
        y_slice=64,
        x_slice=64
    )