Code used to sanity-check the converted JHF data: plotting the raw (NetCDF) and written (Zarr) files and calculating various statistics.

In [1]:
# Common prefix shared by the stored Zarr output and the raw NetCDF inputs.
STAGING_ROOT = "/home/idies/workspace/turbulence-ceph-staging"

stored_jhf_hr_path = f"{STAGING_ROOT}/sciserver-turbulence/stsabl2048high/stsabl2048high.zarr"
raw_jhf_hr_0_path = f"{STAGING_ROOT}/ncar-jhf/hr/jhf.000.nc"
raw_jhf_hr_1_path = f"{STAGING_ROOT}/ncar-jhf/hr/jhf.001.nc"
raw_jhf_hr_104_path = f"{STAGING_ROOT}/ncar-jhf/hr/jhf.104.nc"

In [2]:
import zarr
import xarray as xr
import numpy as np

# Open the stored store read-only: zarr.open's default mode='a' would
# silently create a new, empty group if the path were mistyped, and a
# sanity-check notebook should never be able to modify the stored data.
jhf_hr_zarr = zarr.open(stored_jhf_hr_path, mode='r')
jhf_hr_netcdf_t0 = xr.open_dataset(raw_jhf_hr_0_path)
jhf_hr_netcdf_t1 = xr.open_dataset(raw_jhf_hr_1_path)
jhf_hr_netcdf_t104 = xr.open_dataset(raw_jhf_hr_104_path)

In [3]:
stored_jhf_lr_path = "/home/idies/workspace/turbulence-ceph-staging/sciserver-turbulence/stsabl2048low/stsabl2048low.zarr"
raw_jhf_lr_0_path = "/home/idies/workspace/turbulence-ceph-staging/ncar-jhf/lr/jhf.000.nc"
raw_jhf_lr_1_path = "/home/idies/workspace/turbulence-ceph-staging/ncar-jhf/lr/jhf.001.nc"
raw_jhf_lr_19_path = "/home/idies/workspace/turbulence-ceph-staging/ncar-jhf/lr/jhf.019.nc"

# Read-only: zarr.open's default mode='a' would silently create an empty
# group on a mistyped path instead of raising, and could modify stored data.
jhf_lr_zarr = zarr.open(stored_jhf_lr_path, mode='r')
jhf_lr_netcdf_t0 = xr.open_dataset(raw_jhf_lr_0_path)
jhf_lr_netcdf_t1 = xr.open_dataset(raw_jhf_lr_1_path)
jhf_lr_netcdf_t19 = xr.open_dataset(raw_jhf_lr_19_path)

<font color='cyan'>

## Remember: the index order is nnz-nny-nnx, i.e. the first spatial dimension is Z.
    
</font>

<font color="cyan">

# Remember, data is saved in `nnz-nny-nnx`
    
    
</font>

# Quick Verification of Data Correctness

1. Check whether the data is all zeros (`data == 0` everywhere)

2. Pick one $64^3$ chunk and compare it to the raw NetCDF file

### High-Rate

#### `zarr.info`

In [None]:
# Display the high-rate Zarr group opened above.
jhf_hr_zarr

In [None]:
# zarr's `.info` metadata report for the stored 'energy' array.
jhf_hr_zarr['energy'].info

In [None]:
# zarr's `.info` metadata report for the stored 'velocity' array.
jhf_hr_zarr['velocity'].info

In [None]:
# zarr's `.info` report for the whole high-rate group.
jhf_hr_zarr.info

#### Indexing, Compare to 0

In [None]:
print("Checking whether field all zeros - True is bad!")

# Cheap smoke test: only the first 64^3 block ([:64, :64, :64]) of each
# timestep is checked. The timestep count comes from the array itself
# rather than a hard-coded 105, so the loop always covers the whole time axis.
hr_temperature = jhf_hr_zarr['temperature']
for t in range(hr_temperature.shape[0]):
    print("t=", t, " - ", np.all(hr_temperature[t,:64,:64,:64,0] == 0))

#### Comparing Real Values

In [None]:
# High-rate temperature, t=0: first 10 values along Z at y=0, x=0 (component 0).
jhf_hr_zarr['temperature'][0,:10,0,0,0]

In [None]:
# Low-rate counterpart of the cell above, for a side-by-side look at t=0.
jhf_lr_zarr['temperature'][0,:10,0,0,0]

In [None]:
# Same high-rate temperature column at t=1, to check that timesteps differ.
jhf_hr_zarr['temperature'][1,:10,0,0,0]

In [10]:
# Compare the first 64^3 block of each late timestep against the raw NetCDF.
for t in range(93, 105):
    zarr_comparison_data = jhf_hr_zarr['temperature'][t,:64,:64,:64,0]

    raw_t_path = f"/home/idies/workspace/turbulence-ceph-staging/ncar-jhf/hr/jhf.{t:03d}.nc"
    # `with` closes each file; previously one handle leaked per iteration.
    with xr.open_dataset(raw_t_path) as raw_xr:
        raw_t = (
            raw_xr['t']
            .isel(nnx=slice(0, 64), nny=slice(0, 64), nnz=slice(0, 64))
            # Pin the zarr layout (nnz, nny, nnx) explicitly instead of
            # relying on the dimension order stored in the .nc file.
            .transpose('nnz', 'nny', 'nnx', ...)
            .values
        )

    # array_equal reports a shape mismatch as False instead of broadcasting.
    print("t: ", t, np.array_equal(zarr_comparison_data, raw_t))

t:  97 True
t:  98 True
t:  99 True
t:  100 True
t:  101 True
t:  102 True
t:  103 True
t:  104 True


In [None]:
# High-rate temperature, t=0: first 10 values along X at z=0, y=0 (component 0).
jhf_hr_zarr['temperature'][0,0,0,:10,0]

In [None]:
# Stored energy, t=0: first 10 values along Z at y=0, x=0; compare with the raw 'e' below.
jhf_hr_zarr['energy'][0,:10,0,0,0]

In [None]:
# Raw NetCDF energy 'e' at t=0: first 10 values of the first dimension
# (presumably nnz, matching the zarr layout — confirm against the file's dims).
np.array(jhf_hr_netcdf_t0['e'][:10,0,0])

Opening the written Zarr store with xarray fails with a missing-metadata error (still unresolved; GPT o1's suggestion did not fix it):

`written_ds_xr = xr.open_dataset(stored_jhf_path, engine='zarr', consolidated=False)`

In [None]:
import numpy as np

# Uses the high-rate group opened at the top of the notebook; `ds` is only
# bound much further down, so referencing it here fails on a fresh kernel.
# True => the first 10 Z values of velocity component 0 (t=0, y=0, x=0) are all zero (bad).
np.all(jhf_hr_zarr['velocity'][0, :10,0,0, 0] == 0)

In [None]:
import numpy as np

# `jhf_hr_zarr` instead of `ds`, which is not yet defined at this point on a fresh kernel.
# True => velocity component 0 at t=100 (first 10 Z values, y=0, x=0) is all zero (bad).
np.all(jhf_hr_zarr['velocity'][100, :10,0,0, 0] == 0)

In [None]:
import numpy as np

# `jhf_hr_zarr` instead of `ds`, which is not yet defined at this point on a fresh kernel.
# True => energy at t=100 (first 10 Z values, y=0, x=0) is all zero (bad).
np.all(jhf_hr_zarr['energy'][100, :10,0,0, 0] == 0)

### Low-Rate

#### Comparing Real Values

In [None]:
# Stored low-rate temperature, t=0: first 10 values along Z at y=0, x=0.
jhf_lr_zarr['temperature'][0,:10,0,0,0]

In [None]:
# Raw low-rate NetCDF temperature 't' at t=0, to compare with the cell above
# (presumably the first dimension is nnz — confirm against the file's dims).
np.array(jhf_lr_netcdf_t0['t'][:10,0,0])

# Efficient Zarr full data Slices

This can take a few minutes.

SciServer doesn't allow localhost connections, so the Dask cluster console can't be used.

<font color="green">Lazy loading speeds up reading by 10x or more.</font>

In [11]:
import os
import zarr
import dask.array as da
import matplotlib.pyplot as plt
from dask import compute
from dask.diagnostics import ProgressBar

plt.rcParams['image.cmap'] = 'inferno'

def process_zarr(stored_jhf_path, dataset, variables=None, timestep_stride=5):
    """
    Save X=0, Y=0 and Z=0 slice images of a stored JHF Zarr group as PNGs.

    Every stored array is wrapped lazily with Dask. All requested slices are
    gathered into one graph and materialised with a single ``compute`` call,
    then each slice is plotted and written to
    ``zarr_slices/<dataset>/<variable>/Timestep_<t>_<title>.png``.

    Parameters
    ----------
    stored_jhf_path : str
        Path to the Zarr group; it is opened read-only.
    dataset : str
        Either ``'hr'`` or ``'lr'``; selects the output sub-folder.
    variables : iterable of str, optional
        Variables to plot, in order. Defaults to
        ``['energy', 'temperature', 'pressure', 'velocity']``.
        ``'velocity'`` is treated as a 3-component vector (one image per
        component and plane); everything else is treated as a scalar.
    timestep_stride : int, optional
        Plot timesteps ``0, stride, 2*stride, ...`` (default 5). The number of
        timesteps is read from the first plotted variable ('energy' by default).

    Raises
    ------
    ValueError
        If ``dataset`` is not ``'hr'`` or ``'lr'``.
    KeyError
        If a requested variable is not stored in the group.
    """
    if dataset not in ['hr', 'lr']:
        raise ValueError("Dataset must be 'hr' or 'lr'.")

    if variables is None:
        variables = ['energy', 'temperature', 'pressure', 'velocity']
    plot_variables = list(variables)

    store = zarr.open_group(stored_jhf_path, mode='r')

    # Lazy Dask views of every stored array. A trailing length-1 axis (how
    # scalar fields are stored) is squeezed so every slice below is 2-D.
    data_arrays = {}
    for var_name in store.array_keys():
        dask_array = da.from_zarr(store[var_name])
        if dask_array.shape[-1] == 1:
            dask_array = dask_array.squeeze(axis=-1)
        data_arrays[var_name] = dask_array

    missing = [name for name in plot_variables if name not in data_arrays]
    if missing:
        raise KeyError(f"Variables not found in {stored_jhf_path}: {missing}")

    n_timesteps = data_arrays[plot_variables[0]].shape[0] if plot_variables else 0
    timesteps = range(0, n_timesteps, timestep_stride)

    master_folder = os.path.join("zarr_slices", dataset)

    # One (lazy slice, title, output file) entry per image, so results never
    # need to be matched back to titles by index arithmetic after compute.
    slice_jobs = []
    for variable in plot_variables:
        var_folder = os.path.join(master_folder, variable)
        if variable == 'velocity':
            labels = [(comp, f"{variable} (component {comp})") for comp in range(3)]
        else:
            labels = [(None, variable)]

        for timestep in timesteps:
            # Scalars: [Z, Y, X]; velocity: [Z, Y, X, 3] (component axis last).
            frame = data_arrays[variable][timestep]
            planes = [
                ("nnx", frame[:, :, 0]),
                ("nny", frame[:, 0, :]),
                ("nnz", frame[0, :, :]),
            ]
            for component, label in labels:
                for axis_name, plane in planes:
                    image = plane if component is None else plane[..., component]
                    title = f"{label} {axis_name}=0 Timestep={timestep}"
                    safe_title = title.replace(" ", "_").replace("=", "_")
                    filename = os.path.join(var_folder, f"Timestep_{timestep}_{safe_title}.png")
                    slice_jobs.append((image, title, filename))

    # Single compute: Dask reads each needed chunk once for all slices.
    with ProgressBar():
        computed_slices = compute(*(job[0] for job in slice_jobs))

    os.makedirs(master_folder, exist_ok=True)
    for variable in plot_variables:
        os.makedirs(os.path.join(master_folder, variable), exist_ok=True)

    for (_, title, filename), image in zip(slice_jobs, computed_slices):
        fig, ax = plt.subplots()
        mappable = ax.imshow(image)
        ax.set_title(title)
        ax.invert_yaxis()  # index 0 at the bottom
        fig.colorbar(mappable, ax=ax)
        fig.savefig(filename, dpi=150, bbox_inches='tight')
        plt.close(fig)  # many figures are produced; free each one

In [12]:
# Low-rate store: slice images go under zarr_slices/lr/<variable>/.
process_zarr(stored_jhf_path=stored_jhf_lr_path, dataset="lr")

[########################################] | 100% Completed | 52.64 s


# Data Statistics: Mean Temperature Along an Axis

### Function Definitions

In [12]:
import os
import matplotlib.pyplot as plt
import dask.array as da
from dask import compute
from dask.diagnostics import ProgressBar
import numpy as np
from typing import Literal

def _as_lazy(array):
    """Return `array` as a Dask array without reading any data from it."""
    if isinstance(array, da.Array):
        return array
    # Reuse the source's own chunking (e.g. a zarr.Array's chunk shape) when it
    # exposes one; plain NumPy arrays have no `.chunks` and fall back to "auto".
    chunks = getattr(array, "chunks", None) or "auto"
    return da.from_array(array, chunks=chunks)


def _squeeze_scalar(array, var):
    """Drop the trailing length-1 axis that stored scalar fields carry.

    Zarr scalars are indexed [time, z, y, x, 0], so a 5-D array with a final
    axis of length 1 becomes [time, z, y, x]. Velocity keeps its component axis.
    """
    if var != "velocity" and array.ndim == 5 and array.shape[-1] == 1:
        return array[..., 0]
    return array


def plot_slices_and_mean(
    ds, 
    variables, 
    timesteps, 
    data_type: Literal["zarr", "original"] = "zarr",
    spacing: int = 128,
    master_slice_folder: str = "zarr_slices",
    mean_output_folder: str = "mean_temperature_plots"
):
    """
    Save X=0 / Y=0 / Z=0 slice plots and mean-temperature-across-Z plots.

    Every input is wrapped as a lazy Dask array, all slices and means go into
    one graph, and a single ``compute`` materialises them.

    Parameters
    ----------
    ds : mapping
        ``ds[var]`` may be a Dask array, a zarr.Array (e.g. a zarr group) or a
        NumPy array, shaped:
        - scalars (e.g. "energy"): [time, z, y, x] or [time, z, y, x, 1]
        - "velocity": [time, z, y, x, 3]
    variables : list of str
        Variables to slice, e.g. ["energy", "temperature", "pressure", "velocity"].
    timesteps : iterable of int
        Timesteps to process, e.g. ``range(0, ds["energy"].shape[0], 5)``.
        May be a one-shot iterator; it is materialised once.
    data_type : {"zarr", "original"}
        Tag used in the mean-plot titles and filenames. Both kinds of input
        take the same lazy code path.
    spacing : int
        Sample every ``spacing``-th Z plane for the mean temperature.
    master_slice_folder : str
        Slice plots go to ``<master_slice_folder>/<variable>/<title>.png``.
    mean_output_folder : str
        Mean-temperature plots go here (only if ``"temperature" in ds``).
    """
    timesteps = list(timesteps)  # iterated several times below

    os.makedirs(master_slice_folder, exist_ok=True)
    os.makedirs(mean_output_folder, exist_ok=True)

    # --- 1) Lazy slices: one (array, title, output path) entry per image ---
    slice_jobs = []
    for var in variables:
        arr = _squeeze_scalar(_as_lazy(ds[var]), var)
        var_folder = os.path.join(master_slice_folder, var)
        if var == "velocity":
            labels = [(comp, f"{var} (component {comp})") for comp in range(3)]
        else:
            labels = [(None, var)]

        for t in timesteps:
            frame = arr[t]  # [z, y, x] or [z, y, x, 3]
            planes = [
                ("nnx", frame[:, :, 0]),
                ("nny", frame[:, 0, :]),
                ("nnz", frame[0, :, :]),
            ]
            for comp, label in labels:
                for axis_name, plane in planes:
                    image = plane if comp is None else plane[..., comp]
                    title_str = f"{label} {axis_name}=0 Timestep={t}"
                    safe_title = title_str.replace(" ", "_").replace("=", "_")
                    slice_jobs.append(
                        (image, title_str, os.path.join(var_folder, f"{safe_title}.png"))
                    )

    # --- 2) Lazy mean temperature over (y, x) at every `spacing`-th Z plane ---
    mean_temp_stack = None
    z_indices_list = []
    if "temperature" in ds and timesteps:
        temperature = _squeeze_scalar(_as_lazy(ds["temperature"]), "temperature")
        z_indices_list = list(range(0, temperature.shape[1], spacing))
        # `::spacing` picks exactly the planes in z_indices_list; only the
        # chunks containing those planes are read.
        mean_temp_stack = da.stack(
            [temperature[t, ::spacing].mean(axis=(1, 2)) for t in timesteps],
            axis=0,
        )  # [n_timesteps, len(z_indices_list)]

    if not slice_jobs and mean_temp_stack is None:
        print("No slices or mean arrays found to compute.")
        return

    # --- 3) Single compute for everything ---
    print("Building Dask graph for slices + mean temperature. Now computing...")
    with ProgressBar():
        slice_images, mean_temp_results = compute(
            [job[0] for job in slice_jobs], mean_temp_stack
        )

    # --- 4) Slice plots ---
    for var in variables:
        os.makedirs(os.path.join(master_slice_folder, var), exist_ok=True)

    for (_, title_str, fname), image in zip(slice_jobs, slice_images):
        fig, ax = plt.subplots()
        mappable = ax.imshow(image)
        ax.set_title(title_str)
        ax.invert_yaxis()  # index 0 at the bottom
        fig.colorbar(mappable, ax=ax)
        fig.savefig(fname, dpi=150, bbox_inches='tight')
        plt.close(fig)

    # --- 5) Mean-temperature plots ---
    if mean_temp_results is not None:
        for t, xy_means in zip(timesteps, mean_temp_results):
            fig, ax = plt.subplots()
            ax.plot(z_indices_list, xy_means, marker='o')
            ax.set_title(f"Mean Temperature across Z (t={t}, data_type={data_type})")
            ax.set_xlabel("Slice (Z)")
            ax.set_ylabel("Average Temperature")

            file_name = os.path.join(mean_output_folder, f"mean_z_temperature_{data_type}_t_{t}.png")
            fig.savefig(file_name, dpi=150, bbox_inches='tight')
            plt.close(fig)
            print(f"Saved plot as {file_name}")

    print("All slice plots and mean-temperature plots saved.")

### Zarr - High Rate

In [11]:
process_dataset = "high_rate"

# Read-only: zarr.open defaults to mode='a', which would silently create an
# empty group on a mistyped path and allows writes to the stored data.
ds = jhf_hr_zarr = zarr.open(stored_jhf_hr_path, mode='r')

## Across Z

In [None]:
# Suppose you've already got a dictionary-like `ds` with:
# ds["energy"], ds["temperature"], ds["pressure"], ds["velocity"] 
# as Dask arrays (from .from_zarr), each shape ~ [time, z, y, x, (maybe 3 for velocity)]

variables = ["energy", "temperature", "pressure", "velocity"]
timesteps = range(0, ds["energy"].shape[0], 5)  # or [0, 5, 10, ..., 105]

plot_slices_and_mean(
    ds,
    variables,
    timesteps,
    data_type="zarr",    # so we do the lazy means
    spacing=128,         # how often to sample Z dimension for the mean
    master_slice_folder="zarr_slices",
    mean_output_folder="mean_temperature_plots"
)

### Original NetCDF

In [None]:
# Mean temperature across Z computed directly from the raw high-rate NetCDF
# files, sampled like the zarr plots above (every 5th timestep, every 128th
# Z plane). Raw files are named jhf.<ttt>.nc, one per timestep, so the
# timestep count is taken from the stored zarr time axis.
raw_hr_dir = os.path.dirname(raw_jhf_hr_0_path)
raw_spacing = 128
raw_mean_folder = "mean_temperature_plots"
os.makedirs(raw_mean_folder, exist_ok=True)

for t in range(0, jhf_hr_zarr["temperature"].shape[0], 5):
    raw_path = os.path.join(raw_hr_dir, f"jhf.{t:03d}.nc")
    with xr.open_dataset(raw_path) as raw_xr:
        raw_temp = raw_xr["t"]
        z_positions = list(range(0, raw_temp.sizes["nnz"], raw_spacing))
        # Select the sampled Z planes before averaging so only they are loaded.
        xy_means = (
            raw_temp.isel(nnz=slice(0, None, raw_spacing))
            .mean(dim=("nny", "nnx"))
            .values
        )

    fig, ax = plt.subplots()
    ax.plot(z_positions, xy_means, marker="o")
    ax.set_title(f"Mean Temperature across Z (t={t}, data_type=original)")
    ax.set_xlabel("Slice (Z)")
    ax.set_ylabel("Average Temperature")
    file_name = os.path.join(raw_mean_folder, f"mean_z_temperature_original_t_{t}.png")
    fig.savefig(file_name, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved plot as {file_name}")

## Across X

- [ ] TODO if necessary