16 changes: 16 additions & 0 deletions src/mdio/transpose_writers/__init__.py
@@ -0,0 +1,16 @@
"""Views and transformations for MDIO datasets.

This module provides convenience APIs for creating different views and
transformations of MDIO Variables, including rechunking (repartitioning),
level-of-detail (LoD) downsampling, and sharding.
"""

from mdio.transpose_writers.chunking import from_variable as chunk_variable
from mdio.transpose_writers.lod import from_variable as lod_variable
from mdio.transpose_writers.shard import from_variable as shard_variable

__all__ = [
    "chunk_variable",
    "lod_variable",
    "shard_variable",
]
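
A minimal usage sketch of the re-exported API, based only on the names and signatures visible in this diff. The dataset path, variable names, grid objects, and helper functions below are hypothetical:

    from mdio.transpose_writers import chunk_variable

    # Two pre-built chunk grids; their construction is not shown in this PR.
    il_grid, xl_grid = make_inline_grid(), make_crossline_grid()  # hypothetical helpers

    chunk_variable(
        dataset_path="s3://bucket/survey.mdio",  # hypothetical path
        source_variable="amplitude",             # hypothetical source Variable
        new_variable=["amplitude_il", "amplitude_xl"],
        chunk_grid=[il_grid, xl_grid],           # one grid per new variable
    )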
246 changes: 246 additions & 0 deletions src/mdio/transpose_writers/chunking.py
@@ -0,0 +1,246 @@
"""Repartitioning operations for MDIO Variables."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING
from typing import Any

import dask
from tqdm.auto import tqdm
from tqdm.dask import TqdmCallback
from xarray import DataArray

from mdio.api.io import _normalize_path
from mdio.api.io import open_mdio
from mdio.api.io import to_mdio
from mdio.builder.xarray_builder import _compressor_to_encoding
from mdio.core.config import MDIOSettings

logger = logging.getLogger(__name__)


if TYPE_CHECKING:
    from pathlib import Path

    from upath import UPath
    from xarray import Dataset

    from mdio.builder.schemas.chunk_grid import RectilinearChunkGrid
    from mdio.builder.schemas.chunk_grid import RegularChunkGrid
    from mdio.builder.schemas.compressors import ZFP
    from mdio.builder.schemas.compressors import Blosc


def _remove_fillvalue_attrs(dataset: Dataset) -> None:
    """Remove _FillValue from all variable attrs to avoid conflicts with consolidated metadata.

    This is only relevant for Zarr v2 format.
    """
    for var_name in list(dataset.data_vars) + list(dataset.coords):
        if "_FillValue" in dataset[var_name].attrs:
            del dataset[var_name].attrs["_FillValue"]


def _validate_inputs(
    new_variable: str | list[str],
    chunk_grid: RegularChunkGrid | RectilinearChunkGrid | list[RegularChunkGrid | RectilinearChunkGrid],
    compressor: ZFP | Blosc | list[ZFP | Blosc] | None,
) -> None:
    """Validate basic shapes/types (no broadcasting here)."""
    # new_variable must be a string or a non-empty list of strings
    if isinstance(new_variable, list):
        if not new_variable:
            msg = "new_variable list must not be empty"
            raise ValueError(msg)
        if not all(isinstance(v, str) for v in new_variable):
            msg = "All entries in new_variable must be strings"
            raise TypeError(msg)
    elif not isinstance(new_variable, str):
        msg = "new_variable must be a string or a list of strings"
        raise TypeError(msg)

    # chunk_grid can be a single grid or a non-empty list of grids
    if isinstance(chunk_grid, list) and not chunk_grid:
        msg = "chunk_grid list must not be empty"
        raise ValueError(msg)

    # compressor can be None, a single compressor, or a non-empty list
    if isinstance(compressor, list) and not compressor:
        msg = "compressor list must not be empty"
        raise ValueError(msg)
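
# Illustrating the validation rules above ("grid" and "blosc" stand for
# hypothetical chunk-grid and compressor objects):
#   _validate_inputs("v1", grid, None)               # ok: single name, single grid
#   _validate_inputs(["v1", "v2"], [grid], [blosc])  # ok: non-empty lists
#   _validate_inputs([], grid, None)                 # ValueError: empty name list
#   _validate_inputs(["v1", 2], grid, None)          # TypeError: non-string entry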


def _normalize_new_variable(
    new_variable: str | list[str],
) -> list[str]:
    """Normalize new_variable to a list of names."""
    if isinstance(new_variable, str):
        return [new_variable]
    # At this point _validate_inputs has already ensured a non-empty list[str].
    return list(new_variable)


def _normalize_chunk_grid(
    chunk_grid: RegularChunkGrid | RectilinearChunkGrid | list[RegularChunkGrid | RectilinearChunkGrid],
    num_variables: int,
) -> list[RegularChunkGrid | RectilinearChunkGrid]:
    """Broadcast chunk_grid to match num_variables."""
    if isinstance(chunk_grid, list):
        if len(chunk_grid) == 1 and num_variables > 1:
            return chunk_grid * num_variables
        if len(chunk_grid) == num_variables:
            return list(chunk_grid)
        msg = "chunk_grid list length must be 1 or equal to the number of new variables"
        raise ValueError(msg)
    # A single grid is reused for all variables.
    return [chunk_grid] * num_variables
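
# The broadcast rules, illustrated with hypothetical grid objects a, b, c
# (_normalize_compressor below applies the same length-1-or-exact-match rule):
#   _normalize_chunk_grid(a, 3)          # -> [a, a, a]
#   _normalize_chunk_grid([a], 3)        # -> [a, a, a]
#   _normalize_chunk_grid([a, b, c], 3)  # -> [a, b, c]
#   _normalize_chunk_grid([a, b], 3)     # -> ValueError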


def _normalize_compressor(
    compressor: ZFP | Blosc | list[ZFP | Blosc] | None,
    num_variables: int,
) -> list[ZFP | Blosc | None]:
    """Broadcast compressor to match num_variables."""
    if compressor is None:
        return [None] * num_variables

    if isinstance(compressor, list):
        if len(compressor) == 1 and num_variables > 1:
            return compressor * num_variables
        if len(compressor) == num_variables:
            return list(compressor)
        msg = "compressor list length must be 1 or equal to the number of new variables"
        raise ValueError(msg)

    # A single compressor is reused for all variables.
    return [compressor] * num_variables


def from_variable(  # noqa: PLR0913
    dataset_path: UPath | Path | str,
    source_variable: str,
    new_variable: str | list[str],
    chunk_grid: RegularChunkGrid | RectilinearChunkGrid | list[RegularChunkGrid | RectilinearChunkGrid],
    compressor: ZFP | Blosc | list[ZFP | Blosc] | None = None,
    copy_metadata: bool = True,
) -> None:
    """Add new Variable(s) to the Dataset with different chunking and compression.

    Copies data from the source Variable to the new Variable(s) to create
    different access patterns.

    Args:
        dataset_path: The path to a pre-existing MDIO Dataset.
        source_variable: The name of the existing Variable to copy data from.
        new_variable: The name(s) of the new Variable(s) to create.
        chunk_grid: Chunk grid(s) to use for the new Variable(s).
            - Single grid: applied to all new variables.
            - List of grids: length must be 1 (broadcast) or match len(new_variable).
        compressor: Compressor(s) for the new Variable(s).
            - None: use the source encoding's compressor if present.
            - Single compressor: applied to all new variables.
            - List of compressors: length must be 1 (broadcast) or match len(new_variable).
        copy_metadata: Whether to copy attrs/encoding from the source Variable.
    """
    # 1) Basic validation (types, emptiness)
    _validate_inputs(new_variable, chunk_grid, compressor)

    # 2) Normalize/broadcast each argument
    new_variables = _normalize_new_variable(new_variable)
    num_vars = len(new_variables)
    chunk_grids = _normalize_chunk_grid(chunk_grid, num_vars)
    compressors = _normalize_compressor(compressor, num_vars)

    normed_path = _normalize_path(dataset_path)
    ds = open_mdio(normed_path)

    source_var = ds[source_variable]
    dims = source_var.dims
    shape = source_var.shape
    store_chunks = source_var.encoding.get("chunks", None)

    logger.debug(
        "Source variable %r: dims=%r, shape=%r, store_chunks=%r",
        source_variable,
        dims,
        shape,
        store_chunks,
    )

    settings = MDIOSettings()
    num_workers = settings.export_cpus

    dask_config: dict[str, Any] = {"scheduler": "processes", "num_workers": num_workers}

    # 3) One Dask config context; write each new variable sequentially
    with dask.config.set(**dask_config):
        for name, grid, comp in tqdm(
            zip(new_variables, chunk_grids, compressors, strict=True),
            total=len(new_variables),
            desc="Generating newly chunked Variables",
            unit="variable",
        ):
            new_chunks = tuple(grid.configuration.chunk_shape)

            if len(dims) != len(new_chunks):
                # Note: the strict zip below will raise on this mismatch.
                logger.warning(
                    "Original variable %r has dimensions %r, but new chunk shape %r "
                    "was provided for new variable %r. Behavior is currently undefined.",
                    source_variable,
                    dims,
                    new_chunks,
                    name,
                )

            # Build the Dask chunk mapping for the target chunks
            dest_mapping = dict(zip(dims, new_chunks, strict=True))

            # Rechunk directly to the target chunks; skipping intermediate
            # work chunks avoids a task-graph explosion.
            if store_chunks is not None and tuple(store_chunks) == new_chunks:
                rechunked = source_var
            else:
                rechunked = source_var.chunk(dest_mapping)

            logger.debug(
                "Variable %r: nominal_chunks=%r, task graph has %d tasks",
                name,
                tuple(dim_chunks[0] for dim_chunks in rechunked.chunks) if rechunked.chunks is not None else new_chunks,
                len(rechunked.__dask_graph__()) if rechunked.__dask_graph__() is not None else 0,
            )

            # Build a DataArray for the new variable
            attrs = source_var.attrs.copy() if copy_metadata else {}
            new_da = DataArray(
                data=rechunked.data,
                dims=dims,
                coords=source_var.coords,
                attrs=attrs,
                name=name,
            )
            new_ds = new_da.to_dataset(name=name)

            # Per-variable encoding
            encoding: dict[str, Any] = source_var.encoding.copy() if copy_metadata else {}
            encoding["chunks"] = new_chunks
            if comp is not None:
                compressor_encoding = _compressor_to_encoding(comp)
                encoding.update(compressor_encoding)
            new_ds[name].encoding = encoding

            # Clean up attrs that can conflict with consolidated metadata
            _remove_fillvalue_attrs(new_ds)

            # Drop non-dimensional coordinates to avoid chunk conflicts
            coords_to_drop = [coord for coord in new_ds.coords if coord not in new_ds.dims]
            if coords_to_drop:
                new_ds = new_ds.drop_vars(coords_to_drop)

            logger.debug("Starting write operation for variable %r with compute=True", name)
            with TqdmCallback(desc=f"Writing variable '{name}'", unit="chunk"):
                to_mdio(new_ds, normed_path, mode="a", compute=True)
            logger.debug("Completed write operation for variable %r", name)

    logger.info(
        "Variable copy complete: %r -> %s",
        source_variable,
        ", ".join(new_variables),
    )
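
A worked trace of one iteration of the write loop above, with hypothetical dimension names and chunk sizes:

    # dims         = ("inline", "crossline", "sample")
    # new_chunks   = (128, 128, 1024)   # tuple(grid.configuration.chunk_shape)
    # dest_mapping = {"inline": 128, "crossline": 128, "sample": 1024}
    #
    # If the source store chunks already equal new_chunks, the rechunk step is
    # skipped; otherwise source_var.chunk(dest_mapping) builds the Dask graph.
    # Either way the copy is written with encoding["chunks"] = (128, 128, 1024).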
34 changes: 34 additions & 0 deletions src/mdio/transpose_writers/lod.py
@@ -0,0 +1,34 @@
"""Level of Detail (LoD) views for MDIO datasets."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from xarray import DataArray
    from xarray import Dataset
    from xarray import Variable


def from_variable(
    data: DataArray | Dataset | Variable,
    reduction_factor: int | dict[str, int],
    method: str = "mean",
) -> DataArray | Dataset | Variable:
    """Create a Level of Detail view by downsampling the data.

    Args:
        data: The input data to downsample.
        reduction_factor: Reduction factor for each dimension. Can be a single
            integer (applied to all spatial dimensions) or a dict mapping
            dimension names to reduction factors.
        method: Downsampling method. Options: 'mean', 'max', 'min', 'median'.

    Returns:
        A downsampled copy of the input data.

    Raises:
        NotImplementedError: Always; Level of Detail operations are not yet
            implemented.
    """
    msg = "Level of Detail operations are not yet implemented"
    raise NotImplementedError(msg)
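
The stub above always raises, so any implementation shown here is speculative. A minimal sketch of one plausible approach for DataArray/Dataset inputs, using xarray's coarsen (the helper name and the "trim" boundary choice are assumptions, not part of this PR):

    import xarray as xr

    def _lod_sketch(
        data: xr.DataArray | xr.Dataset,
        reduction_factor: int | dict[str, int],
        method: str = "mean",
    ) -> xr.DataArray | xr.Dataset:
        # Broadcast a scalar factor to every dimension.
        if isinstance(reduction_factor, int):
            reduction_factor = {dim: reduction_factor for dim in data.dims}
        if method not in ("mean", "max", "min", "median"):
            raise NotImplementedError(f"Unsupported method: {method}")
        # boundary="trim" drops trailing samples that do not fill a window.
        coarsened = data.coarsen(reduction_factor, boundary="trim")
        return getattr(coarsened, method)()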
32 changes: 32 additions & 0 deletions src/mdio/transpose_writers/shard.py
@@ -0,0 +1,32 @@
"""Sharding operations for MDIO datasets."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from xarray import Variable


def from_variable(
    variable: Variable,
    num_shards: int,
    *,
    shard_dimension: str | None = None,
) -> list[Variable]:
    """Shard a Variable across multiple pieces for distributed processing.

    Args:
        variable: The input Variable to shard.
        num_shards: Number of shards to create.
        shard_dimension: Dimension along which to shard. If None,
            automatically selects the largest dimension.

    Returns:
        List of Variable shards.

    Raises:
        NotImplementedError: Always; sharding operations are not yet
            implemented.
    """
    msg = "Sharding operations are not yet implemented"
    raise NotImplementedError(msg)
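
Again speculative, since this stub also always raises: a minimal sketch that slices contiguous, near-equal pieces along the chosen (or largest) dimension. The helper name is hypothetical; only Variable.sizes, Variable.isel, and numpy are assumed:

    import numpy as np
    from xarray import Variable

    def _shard_sketch(
        variable: Variable,
        num_shards: int,
        *,
        shard_dimension: str | None = None,
    ) -> list[Variable]:
        # Default to the largest dimension, as the docstring describes.
        if shard_dimension is None:
            shard_dimension = max(variable.sizes, key=variable.sizes.get)
        size = variable.sizes[shard_dimension]
        # num_shards + 1 evenly spaced cut points covering [0, size].
        bounds = np.linspace(0, size, num_shards + 1, dtype=int)
        return [
            variable.isel({shard_dimension: slice(lo, hi)})
            for lo, hi in zip(bounds[:-1], bounds[1:], strict=True)
        ]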