From 0d5ebc402239232c7edc4615085a7e5d1055c822 Mon Sep 17 00:00:00 2001
From: Huite
Date: Wed, 24 Sep 2025 21:05:00 +0200
Subject: [PATCH 1/7] Exclude auxiliary data correctly from options.

Relates to #1681 and #1682
---
 imod/mf6/boundary_condition.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/imod/mf6/boundary_condition.py b/imod/mf6/boundary_condition.py
index afe63ffa7..0cd5168d4 100644
--- a/imod/mf6/boundary_condition.py
+++ b/imod/mf6/boundary_condition.py
@@ -202,12 +202,17 @@ def _get_unfiltered_pkg_options(
         options = copy(predefined_options)

         if not_options is None:
-            not_options = self._get_period_varnames()
+            not_options = []
+            if hasattr(self, "_period_data"):
+                not_options.extend(self._period_data)
+            if hasattr(self, "_auxiliary_data"):
+                not_options.extend(get_variable_names(self))
+                not_options.extend(self._auxiliary_data.keys())

         for varname in self.dataset.data_vars.keys():  # pylint:disable=no-member
             if varname in not_options:
                 continue
-            v = self.dataset[varname].values[()]
+            v = self.dataset[varname].item()
             options[varname] = v
         return options


From 4a55a214b1e654a89cddb025600c9728ce048099 Mon Sep 17 00:00:00 2001
From: Huite
Date: Wed, 24 Sep 2025 21:14:36 +0200
Subject: [PATCH 2/7] Fix the awful run times of the model splitting

In a nutshell: make sure DIS is in RAM, avoid re-loading data over and
over, and avoid unnecessary work.

Refactor the model splitter: create a class to re-use data where
possible.

* DO NOT PURGE AFTER PARTITION MODEL CREATION. If the data is spatially
  unchunked, purging forces loading the entire dataset into memory. This
  means performance degrades linearly with the number of partitions
  (since each partition requires a separate load), unless all data is
  loaded into memory at the start (not feasible for large models).
* This implementation instead runs a groupby on the unpartitioned data
  (i.e. a single load) and counts the number of active elements per
  partition; see the sketch below. If a package has zero active elements
  in a partition, it is omitted from that partition's model. Point and
  line data get their own dispatch (matching the clipping logic).
* Avoid searching for paired transport models: the mapping is fully
  known a priori, since flow and transport use the same partitioning
  labels. Keep them together in a NamedTuple (with some helpers).
* Added some trailing returns (a matter of taste).
* This seems to reduce the runtime from 2 hours (reported by Joost) to
  1 minute (measured locally).

Unfortunately, performance still degrades with each partition when
writing the model. In principle, chunking fixes this (it requires data
locality for unstructured meshes, which can be done; see the xugrid
issues). Alternatively, we might get dask to optimize the writes. Needs
research.
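Below, a minimal, self-contained sketch of the counting idea (the
variable names are hypothetical; the actual implementation dispatches on
package type and selects a sample variable per package):

    import numpy as np
    import xarray as xr

    # One spatial variable; NaN marks inactive elements.
    conductance = xr.DataArray(
        [[1.0, np.nan, 2.0], [np.nan, np.nan, 3.0]],
        dims=("layer", "cell"),
    )
    # Partition labels, shaped like a single layer.
    labels = xr.DataArray([0, 0, 1], dims=("cell",), name="label")

    # Reduce over every dimension except the labeled one, then count the
    # active elements per partition: one pass over the unpartitioned
    # data instead of a separate load per partition.
    counts = conductance.notnull().sum("layer").groupby(labels).sum()
    print({int(k): int(v) for k, v in zip(counts["label"].values, counts.values)})
    # -> {0: 1, 1: 2}: both partitions have active elements here; a
    # package counting zero in a partition would be omitted there.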
--- imod/mf6/multimodel/exchange_creator.py | 15 +- .../multimodel/exchange_creator_structured.py | 10 +- .../exchange_creator_unstructured.py | 5 +- imod/mf6/multimodel/modelsplitter.py | 315 +++++++++++++++--- imod/mf6/simulation.py | 124 ++++--- 5 files changed, 343 insertions(+), 126 deletions(-) diff --git a/imod/mf6/multimodel/exchange_creator.py b/imod/mf6/multimodel/exchange_creator.py index 31a55e6a8..8a9a16b00 100644 --- a/imod/mf6/multimodel/exchange_creator.py +++ b/imod/mf6/multimodel/exchange_creator.py @@ -1,5 +1,5 @@ import abc -from typing import Dict +from typing import Dict, NamedTuple import numpy as np import pandas as pd @@ -8,10 +8,14 @@ from imod.common.utilities.grid import get_active_domain_slice, to_cell_idx from imod.mf6.gwfgwf import GWFGWF from imod.mf6.gwtgwt import GWTGWT -from imod.mf6.multimodel.modelsplitter import PartitionInfo from imod.typing import GridDataArray +class PartitionInfo(NamedTuple): + active_domain: GridDataArray + partition_id: int + + def _adjust_gridblock_indexing(connected_cells: xr.Dataset) -> xr.Dataset: """ adjusts the gridblock numbering from 0-based to 1-based. @@ -25,8 +29,7 @@ class ExchangeCreator(abc.ABC): """ Creates the GroundWaterFlow to GroundWaterFlow exchange package (gwfgwf) as a function of a submodel label array and a PartitionInfo object. This file contains the cell indices of coupled cells. With coupled cells we mean - cells that are adjacent but that are located in different subdomains. At the moment only structured grids are - supported, for unstructured grids the geometric information is still set to default values. + cells that are adjacent but that are located in different subdomains. The submodel_labels array should have the same topology as the domain being partitioned. The array will be used to determine the connectivity of the submodels after the split operation has been performed. 
@@ -248,7 +251,7 @@ def _create_global_cellidx_to_local_cellid_mapping( mapping = {} for submodel_partition_info in partition_info: - model_id = submodel_partition_info.id + model_id = submodel_partition_info.partition_id mapping[model_id] = pd.merge( global_to_local_idx[model_id], local_cell_idx_to_id[model_id] ) @@ -268,7 +271,7 @@ def _get_local_cell_indices( def _local_cell_idx_to_id(cls, partition_info) -> Dict[int, pd.DataFrame]: local_cell_idx_to_id = {} for submodel_partition_info in partition_info: - model_id = submodel_partition_info.id + model_id = submodel_partition_info.partition_id local_cell_indices = cls._get_local_cell_indices(submodel_partition_info) local_cell_id = list(np.ndindex(local_cell_indices.shape)) diff --git a/imod/mf6/multimodel/exchange_creator_structured.py b/imod/mf6/multimodel/exchange_creator_structured.py index 79b3fee91..7d1572efd 100644 --- a/imod/mf6/multimodel/exchange_creator_structured.py +++ b/imod/mf6/multimodel/exchange_creator_structured.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Dict, List, NamedTuple import numpy as np import pandas as pd @@ -6,12 +6,16 @@ from imod.common.utilities.grid import create_geometric_grid_info from imod.mf6.multimodel.exchange_creator import ExchangeCreator -from imod.mf6.multimodel.modelsplitter import PartitionInfo from imod.typing import GridDataArray NOT_CONNECTED_VALUE = -999 +class PartitionInfo(NamedTuple): + active_domain: GridDataArray + partition_id: int + + class ExchangeCreator_Structured(ExchangeCreator): """ Creates the GroundWaterFlow to GroundWaterFlow exchange package (gwfgwf) as @@ -130,7 +134,7 @@ def _create_global_to_local_idx( compat="override", )["label"] - model_id = submodel_partition_info.id + model_id = submodel_partition_info.partition_id global_to_local_idx[model_id] = pd.DataFrame( { "global_idx": overlap.values.flatten(), diff --git a/imod/mf6/multimodel/exchange_creator_unstructured.py b/imod/mf6/multimodel/exchange_creator_unstructured.py index 2869bc1ea..43d40aaca 100644 --- a/imod/mf6/multimodel/exchange_creator_unstructured.py +++ b/imod/mf6/multimodel/exchange_creator_unstructured.py @@ -4,8 +4,7 @@ import pandas as pd import xarray as xr -from imod.mf6.multimodel.exchange_creator import ExchangeCreator -from imod.mf6.multimodel.modelsplitter import PartitionInfo +from imod.mf6.multimodel.exchange_creator import ExchangeCreator, PartitionInfo from imod.typing import GridDataArray @@ -122,7 +121,7 @@ def _create_global_to_local_idx( compat="override", )["label"] - model_id = submodel_partition_info.id + model_id = submodel_partition_info.partition_id global_to_local_idx[model_id] = pd.DataFrame( { "global_idx": overlap.values.flatten(), diff --git a/imod/mf6/multimodel/modelsplitter.py b/imod/mf6/multimodel/modelsplitter.py index b77a9f725..5210e2652 100644 --- a/imod/mf6/multimodel/modelsplitter.py +++ b/imod/mf6/multimodel/modelsplitter.py @@ -1,8 +1,14 @@ -from typing import List, NamedTuple +import collections +from typing import Any, NamedTuple import numpy as np +from plum import Dispatcher +import imod +from imod.common.interfaces.ilinedatapackage import ILineDataPackage from imod.common.interfaces.imodel import IModel +from imod.common.interfaces.ipackagebase import IPackageBase +from imod.common.interfaces.ipointdatapackage import IPointDataPackage from imod.common.utilities.clip import clip_by_grid from imod.mf6.auxiliary_variables import ( expand_transient_auxiliary_variables, @@ -10,76 +16,283 @@ ) from imod.mf6.boundary_condition import 
BoundaryCondition
 from imod.mf6.hfb import HorizontalFlowBarrierBase
+from imod.mf6.multimodel.exchange_creator import PartitionInfo
+from imod.mf6.multimodel.exchange_creator_structured import ExchangeCreator_Structured
+from imod.mf6.multimodel.exchange_creator_unstructured import (
+    ExchangeCreator_Unstructured,
+)
 from imod.mf6.wel import Well
 from imod.typing import GridDataArray
-from imod.typing.grid import ones_like
+from imod.typing.grid import bounding_polygon, is_unstructured

 HIGH_LEVEL_PKGS = (HorizontalFlowBarrierBase, Well)


-class PartitionInfo(NamedTuple):
-    active_domain: GridDataArray
-    id: int
+dispatch = Dispatcher()
+
+
+@dispatch
+def activity_count(
+    package: object, labels: object, polygons: list[Any], ignore_time_purge_empty: bool
+) -> dict:
+    raise TypeError(
+        f"`labels` should be of type xr.DataArray, xu.Ugrid2d or xu.UgridDataArray, got {type(labels)}"
+    )
+
+
+@dispatch
+def activity_count(  # noqa: F811
+    package: IPackageBase,
+    labels: object,
+    polygons: list[Any],
+    ignore_time_purge_empty: bool,
+) -> dict:
+    label_dims = set(labels.dims)
+    dataset = package.dataset
+
+    # Determine the sample variable: it should be spatial.
+    # Otherwise return a count of 1 for each partition.
+    if not label_dims.intersection(dataset.dims):
+        return dict.fromkeys(range(len(polygons)), 1)
+
+    # Find the variable with the most spatial dimensions.
+    # Accessing variables is cheaper than creating a DataArray.
+    ndim_per_variable = {
+        var_name: len(dims)
+        for var_name in dataset.data_vars
+        if label_dims.intersection(dims := dataset.variables[var_name].dims)
+    }
+    max_variable = max(ndim_per_variable, key=ndim_per_variable.get)
+    # TODO: there might be a more robust way to do this.
+    # Alternatively, we just define a predicate variable (e.g. conductance)
+    # on each package.
+    sample = dataset[max_variable]

+    if "time" in sample.coords:
+        if ignore_time_purge_empty:
+            sample = sample.isel(time=0)
+        else:
+            sample = sample.max("time")

-def create_partition_info(submodel_labels: GridDataArray) -> List[PartitionInfo]:
+    # Reduce over all dimensions except the label dims.
+    dims_to_aggregate = [dim for dim in sample.dims if dim not in label_dims]
+    counts = sample.notnull().sum(dim=dims_to_aggregate).groupby(labels).sum()
+    return {label: int(n) for label, n in enumerate(counts.data)}
+
+
+@dispatch
+def activity_count(  # noqa: F811
+    package: IPointDataPackage,
+    labels: object,
+    polygons: list[Any],
+    ignore_time_purge_empty: bool,
+) -> dict:
+    point_labels = imod.select.points_values(
+        labels, out_of_bounds="ignore", x=package.x, y=package.y
+    )
+    # minlength guarantees an entry for every partition, also those without points.
+    return {label: int(n) for label, n in enumerate(np.bincount(point_labels, minlength=len(polygons)))}
+
+
+@dispatch
+def activity_count(  # noqa: F811
+    package: ILineDataPackage,
+    labels: object,
+    polygons: list[Any],
+    ignore_time_purge_empty: bool,
+) -> dict:
+    counts = {}
+    gdf_linestrings = package.line_data
+    for partition_id, polygon in enumerate(polygons):
+        partition_linestrings = gdf_linestrings.clip(polygon)
+        # Catch an edge case: when a line crosses only a vertex of the
+        # polygon, a point or multipoint is returned. These will be
+        # dropped, and can be identified by zero length.
+        counts[partition_id] = sum(partition_linestrings.length > 0)
+    return counts
+
+
+class PartitionModels(NamedTuple):
     """
-    A PartitionInfo is used to partition a model or package. The partition info's of a domain are created using a
-    submodel_labels array.
The submodel_labels provided as input should have the same shape as a single layer of the - model grid (all layers are split the same way), and contains an integer value in each cell. Each cell in the - model grid will end up in the submodel with the index specified by the corresponding label of that cell. The - labels should be numbers between 0 and the number of partitions. + Mapping of: + flow_model_name (str) => model (object) + partition_id (int) => transport_model_name (str) => model (object) """ - _validate_submodel_label_array(submodel_labels) - unique_labels = np.unique(submodel_labels.values) + flow_models: dict[str, object] + transport_models: dict[int, dict[str, object]] - partition_infos = [] - for label_id in unique_labels: - active_domain = submodel_labels.where(submodel_labels.values == label_id) - active_domain = ones_like(active_domain).where(active_domain.notnull(), 0) - active_domain = active_domain.astype(submodel_labels.dtype) + def paired_keys(self): + for partition_id, key in enumerate(self.flow_models.keys()): + yield key, list(self.transport_models[partition_id].keys()) - submodel_partition_info = PartitionInfo( - id=label_id, active_domain=active_domain - ) - partition_infos.append(submodel_partition_info) + def paired_models(self): + for partition_id, model in enumerate(self.flow_models.values()): + yield model, list(self.transport_models[partition_id].values()) - return partition_infos + def paired_items(self): + for partition_id, (key, model) in enumerate(self.flow_models.items()): + partition_models = self.transport_models[partition_id] + yield ( + (key, model), + (list(partition_models.keys()), list(partition_models.values())), + ) + @property + def flat_transport_models(self): + return { + name: model + for partition_models in self.transport_models.values() + for name, model in partition_models.items() + } -def _validate_submodel_label_array(submodel_labels: GridDataArray) -> None: - unique_labels = np.unique(submodel_labels.values) - if not ( - len(unique_labels) == unique_labels.max() + 1 - and unique_labels.min() == 0 - and np.issubdtype(submodel_labels.dtype, np.integer) +class ModelSplitter: + def __init__( + self, + flow_models: dict[str, object], + transport_models: dict[str, object], + submodel_labels: GridDataArray, + ignore_time_purge_empty: bool = False, ): - raise ValueError( - "The submodel_label array should be integer and contain all the numbers between 0 and the number of " - "partitions minus 1." - ) + self.flow_models = flow_models + self.transport_models = transport_models + self.models = {**flow_models, **transport_models} + self.submodel_labels = submodel_labels + self.unique_labels = self._validate_submodel_label_array(submodel_labels) + self.ignore_time_purge_empty = ignore_time_purge_empty + self._create_partition_info() + self.bounding_polygons = [ + bounding_polygon(partition.active_domain) + for partition in self.partition_info + ] + self.exchange_creator: ExchangeCreator_Unstructured | ExchangeCreator_Structured + if is_unstructured(self.submodel_labels): + self.exchange_creator = ExchangeCreator_Unstructured( + self.submodel_labels, self.partition_info + ) + else: + self.exchange_creator = ExchangeCreator_Structured( + self.submodel_labels, self.partition_info + ) -def slice_model(partition_info: PartitionInfo, model: IModel) -> IModel: - """ - This function slices a Modflow6Model. A sliced model is a model that - consists of packages of the original model that are sliced using the - domain_slice. 
A domain_slice can be created using the - :func:`imod.mf6.modelsplitter.create_domain_slices` function. - """ - modelclass = type(model) - new_model = modelclass(**model.options) + self._count_boundary_activity_per_partition() + + @property + def modelnames(self): + return list(self.models.keys()) + + @staticmethod + def _validate_submodel_label_array(submodel_labels: GridDataArray) -> None: + unique_labels = np.unique(submodel_labels) + + if not ( + len(unique_labels) == unique_labels.max() + 1 + and unique_labels.min() == 0 + and np.issubdtype(submodel_labels.dtype, np.integer) + ): + raise ValueError( + "The submodel_label array should be integer and contain all the numbers between 0 and the number of " + "partitions minus 1." + ) + return unique_labels + + def _create_partition_info(self): + self.partition_info = [] + labels = self.submodel_labels + for label_id in self.unique_labels: + active_domain = (labels == label_id).astype(labels.dtype) + self.partition_info.append( + PartitionInfo( + active_domain=active_domain, + partition_id=int(label_id), + ) + ) + + def _create_partition_polygons(self): + self.partition_polygons = { + info.partition_id: bounding_polygon(info.active_domain) + for info in self.partition_info + } + + def _count_boundary_activity_per_partition(self): + counts = {} + for model_name, model in self.models.items(): + model_counts = {} + for pkg_name, package in model.items(): + # Packages like NPF, DIS are always required. + # We only need to check packages with a MAXBOUND entry. + if not isinstance(package, BoundaryCondition): + continue + model_counts[pkg_name] = activity_count( + package, + self.submodel_labels, + self.bounding_polygons, + self.ignore_time_purge_empty, + ) + counts[model_name] = model_counts + self.boundary_activity_counts = counts + + def slice_model( + self, model: IModel, info: PartitionInfo, boundary_activity_counts: dict + ) -> IModel: + modelclass = type(model) + new_model = modelclass(**model.options) + + for pkg_name, package in model.items(): + if isinstance(package, BoundaryCondition): + # Skip empty boundary conditions + if boundary_activity_counts[pkg_name][info.partition_id] == 0: + continue + else: + remove_expanded_auxiliary_variables_from_dataset(package) + + sliced_package = clip_by_grid(package, info.active_domain) + if sliced_package is not None: + new_model[pkg_name] = sliced_package + + if isinstance(package, BoundaryCondition): + expand_transient_auxiliary_variables(sliced_package) + + return new_model + + def _split(self, models, nest: bool): + partition_models = collections.defaultdict(dict) + for model_name, model in models.items(): + for info in self.partition_info: + new_model = self.slice_model( + model, info, self.boundary_activity_counts[model_name] + ) + new_model_name = f"{model_name}_{info.partition_id}" + if nest: + partition_models[info.partition_id][new_model_name] = new_model + else: + partition_models[new_model_name] = new_model + return partition_models - for pkg_name, package in model.items(): - if isinstance(package, BoundaryCondition): - remove_expanded_auxiliary_variables_from_dataset(package) + def split(self): + # FUTURE: we may currently assume there is a single flow model. See check above. + # And each separate transport model represents a different species. 
+ flow_models = self._split(self.flow_models, nest=False) + transport_models = self._split(self.transport_models, nest=True) + return PartitionModels(flow_models, transport_models) - sliced_package = clip_by_grid(package, partition_info.active_domain) - if sliced_package is not None: - new_model[pkg_name] = sliced_package + def create_gwfgwf_exchanges(self): + exchanges: list[Any] = [] + for model_name, model in self.flow_models.items(): + exchanges += self.exchange_creator.create_gwfgwf_exchanges( + model_name, model.domain.layer + ) + return exchanges - if isinstance(package, BoundaryCondition): - expand_transient_auxiliary_variables(sliced_package) - return new_model + def create_gwtgwt_exchanges(self): + exchanges: list[Any] = [] + # TODO: weird/arbitrary dependence on the single flow model? + flow_model_name = list(self.flow_models.keys())[0] + model = self.flow_models[flow_model_name] + if any(self.transport_models): + for transport_model_name in self.transport_models: + exchanges += self.exchange_creator.create_gwtgwt_exchanges( + transport_model_name, flow_model_name, model.domain.layer + ) + return exchanges diff --git a/imod/mf6/simulation.py b/imod/mf6/simulation.py index 759612969..6fc5e60e3 100644 --- a/imod/mf6/simulation.py +++ b/imod/mf6/simulation.py @@ -39,11 +39,7 @@ from imod.mf6.model import Modflow6Model from imod.mf6.model_gwf import GroundwaterFlowModel from imod.mf6.model_gwt import GroundwaterTransportModel -from imod.mf6.multimodel.exchange_creator_structured import ExchangeCreator_Structured -from imod.mf6.multimodel.exchange_creator_unstructured import ( - ExchangeCreator_Unstructured, -) -from imod.mf6.multimodel.modelsplitter import create_partition_info, slice_model +from imod.mf6.multimodel.modelsplitter import ModelSplitter, PartitionModels from imod.mf6.out import open_cbc, open_conc, open_hds from imod.mf6.package import Package from imod.mf6.ssm import SourceSinkMixing @@ -92,6 +88,12 @@ def get_packages(simulation: Modflow6Simulation) -> dict[str, Package]: } +def force_load_dis(model): + key = model.get_diskey() + model[key].dataset.load() + return + + class Modflow6Simulation(collections.UserDict, ISimulation): """ Modflow6Simulation is a class that represents a Modflow 6 simulation. It @@ -1412,57 +1414,53 @@ def split( f"simulation cannot be split due to presence of package '{error_with_object}' in model '{model_name}'" ) - original_packages = get_packages(self) - - partition_info = create_partition_info(submodel_labels) - - exchange_creator: ExchangeCreator_Unstructured | ExchangeCreator_Structured - if is_unstructured(submodel_labels): - exchange_creator = ExchangeCreator_Unstructured( - submodel_labels, partition_info - ) - else: - exchange_creator = ExchangeCreator_Structured( - submodel_labels, partition_info - ) + # Make sure the DIS package is available in memory and not lazily evaluated, + # since we need its values repeatedly. + for model in original_models.values(): + force_load_dis(model) + original_packages = get_packages(self) new_simulation = imod.mf6.Modflow6Simulation( f"{self.name}_partioned", validation_settings=self._validation_context ) for package_name, package in {**original_packages}.items(): new_simulation[package_name] = deepcopy(package) - for model_name, model in original_models.items(): + model_splitter = ModelSplitter( + flow_models, + transport_models, + submodel_labels, + ignore_time_purge_empty, + ) + + # TODO: isn't it cleaner to just construct a new Solution object instead? 
+ # We know none of the original models will remain? + # And the new models is just the list of newly generated ones. + for model_name in model_splitter.modelnames: solution_name = self.get_solution_name(model_name) solution = cast(Solution, new_simulation[solution_name]) solution._remove_model_from_solution(model_name) - for submodel_partition_info in partition_info: - new_model_name = f"{model_name}_{submodel_partition_info.id}" - new_simulation[new_model_name] = slice_model( - submodel_partition_info, model - ) - new_simulation[new_model_name].purge_empty_packages( - ignore_time=ignore_time_purge_empty - ) - solution._add_model_to_solution(new_model_name) - - exchanges: list[Any] = [] - - for flow_model_name, flow_model in flow_models.items(): - exchanges += exchange_creator.create_gwfgwf_exchanges( - flow_model_name, flow_model.domain.layer - ) - if any(transport_models): - for tpt_model_name in transport_models: - exchanges += exchange_creator.create_gwtgwt_exchanges( - tpt_model_name, flow_model_name, model.domain.layer - ) + partition_models = model_splitter.split() + chained = { + **partition_models.flow_models, + **partition_models.flat_transport_models, + } + for partition_model_name, partition_model in chained.items(): + new_simulation[partition_model_name] = partition_model + # TODO: see to do above. + solution._add_model_to_solution(partition_model_name) + + # Add exchanges + exchanges: list[Any] = ( + model_splitter.create_gwfgwf_exchanges() + + model_splitter.create_gwtgwt_exchanges() + ) new_simulation._add_modelsplit_exchanges(exchanges) - new_simulation._update_buoyancy_packages() + new_simulation._update_buoyancy_packages(partition_models) new_simulation._set_flow_exchange_options() new_simulation._set_transport_exchange_options() - new_simulation._update_ssm_packages() + new_simulation._update_ssm_packages(partition_models) new_simulation._filter_inactive_cells_from_exchanges() return new_simulation @@ -1506,6 +1504,7 @@ def _add_modelsplit_exchanges(self, exchanges_list: list[GWFGWF]) -> None: if not self.is_split(): self["split_exchanges"] = [] self["split_exchanges"].extend(exchanges_list) + return def _set_flow_exchange_options(self) -> None: # collect some options that we will auto-set @@ -1520,6 +1519,7 @@ def _set_flow_exchange_options(self) -> None: xt3d=model_1["npf"].get_xt3d_option(), newton=model_1.is_use_newton(), ) + return def _set_transport_exchange_options(self) -> None: for exchange in self["split_exchanges"]: @@ -1552,6 +1552,7 @@ def _filter_inactive_cells_from_exchanges(self) -> None: # Remove exchange if no cells are left if ex.dataset.sizes["index"] == 0: self["split_exchanges"].remove(ex) + return def _filter_inactive_cells_exchange_domain(self, ex: GWFGWF, i: int) -> None: """Filters inactive cells from one exchange domain inplace""" @@ -1575,6 +1576,7 @@ def _filter_inactive_cells_exchange_domain(self, ex: GWFGWF, i: int) -> None: active_exchange_domain = exchange_domain.where(exchange_domain > 0) active_exchange_domain = active_exchange_domain.dropna("index") ex.dataset = ex.dataset.sel(index=active_exchange_domain["index"]) + return def get_solution_name(self, model_name: str) -> Optional[str]: for k, v in self.items(): @@ -1616,25 +1618,21 @@ def _get_transport_models_per_flow_model(self) -> dict[str, list[str]]: result[flow_model_name].append(tpt_model_name) return result - def _generate_gwfgwt_exchanges(self) -> list[GWFGWT]: + def _generate_gwfgwt_exchanges( + self, partition_models: PartitionModels + ) -> list[GWFGWT]: exchanges = [] - 
flow_transport_mapping = self._get_transport_models_per_flow_model() - for flow_name, tpt_models_of_flow_model in flow_transport_mapping.items(): - if len(tpt_models_of_flow_model) > 0: - for transport_model_name in tpt_models_of_flow_model: - exchanges.append(GWFGWT(flow_name, transport_model_name)) - + for flow_name, transport_model_names in partition_models.paired_keys(): + for transport_name in transport_model_names: + exchanges.append(GWFGWT(flow_name, transport_name)) return exchanges - def _update_ssm_packages(self) -> None: - flow_transport_mapping = self._get_transport_models_per_flow_model() - for flow_name, tpt_models_of_flow_model in flow_transport_mapping.items(): - flow_model = self[flow_name] - for tpt_model_name in tpt_models_of_flow_model: - tpt_model = self[tpt_model_name] - ssm_key = tpt_model._get_pkgkey("ssm") + def _update_ssm_packages(self, partition_models: PartitionModels) -> None: + for flow_model, paired_transport_models in partition_models.paired_models(): + for transport_model in paired_transport_models: + ssm_key = transport_model._get_pkgkey("ssm") if ssm_key is not None: - old_ssm_package = tpt_model.pop(ssm_key) + old_ssm_package = transport_model.pop(ssm_key) state_variable_name = old_ssm_package.dataset[ "auxiliary_variable_name" ].values[0] @@ -1642,13 +1640,13 @@ def _update_ssm_packages(self) -> None: flow_model, state_variable_name, is_split=self.is_split() ) if ssm_package is not None: - tpt_model[ssm_key] = ssm_package + transport_model[ssm_key] = ssm_package + return - def _update_buoyancy_packages(self) -> None: - flow_transport_mapping = self._get_transport_models_per_flow_model() - for flow_name, tpt_models_of_flow_model in flow_transport_mapping.items(): - flow_model = cast(GroundwaterFlowModel, self[flow_name]) - flow_model._update_buoyancy_package(tpt_models_of_flow_model) + def _update_buoyancy_packages(self, partition_models: PartitionModels) -> None: + for (_, flow_model), (names, _) in partition_models.paired_items(): + flow_model._update_buoyancy_package(names) + return def is_split(self) -> bool: """ From 6a204f08917885397f716cd66950baa2c95aa3bd Mon Sep 17 00:00:00 2001 From: Huite Date: Wed, 24 Sep 2025 21:27:46 +0200 Subject: [PATCH 3/7] Restore _generate_gwfgwt_exchanges (does not use partition_models) --- imod/mf6/simulation.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/imod/mf6/simulation.py b/imod/mf6/simulation.py index 6fc5e60e3..68b72d98a 100644 --- a/imod/mf6/simulation.py +++ b/imod/mf6/simulation.py @@ -1618,13 +1618,14 @@ def _get_transport_models_per_flow_model(self) -> dict[str, list[str]]: result[flow_model_name].append(tpt_model_name) return result - def _generate_gwfgwt_exchanges( - self, partition_models: PartitionModels - ) -> list[GWFGWT]: + def _generate_gwfgwt_exchanges(self) -> list[GWFGWT]: exchanges = [] - for flow_name, transport_model_names in partition_models.paired_keys(): - for transport_name in transport_model_names: - exchanges.append(GWFGWT(flow_name, transport_name)) + flow_transport_mapping = self._get_transport_models_per_flow_model() + for flow_name, tpt_models_of_flow_model in flow_transport_mapping.items(): + if len(tpt_models_of_flow_model) > 0: + for transport_model_name in tpt_models_of_flow_model: + exchanges.append(GWFGWT(flow_name, transport_model_name)) + return exchanges def _update_ssm_packages(self, partition_models: PartitionModels) -> None: From c7037e3aafbb2aac032a2db2e13d63766b86c491 Mon Sep 17 00:00:00 2001 From: Huite Date: Fri, 26 Sep 2025 
13:07:10 +0200 Subject: [PATCH 4/7] Simplify (and fix) solution modelnames handling --- imod/mf6/multimodel/modelsplitter.py | 18 +++++++------- imod/mf6/simulation.py | 35 +++++++++++++--------------- 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/imod/mf6/multimodel/modelsplitter.py b/imod/mf6/multimodel/modelsplitter.py index 5210e2652..07e5bde34 100644 --- a/imod/mf6/multimodel/modelsplitter.py +++ b/imod/mf6/multimodel/modelsplitter.py @@ -178,10 +178,6 @@ def __init__( self._count_boundary_activity_per_partition() - @property - def modelnames(self): - return list(self.models.keys()) - @staticmethod def _validate_submodel_label_array(submodel_labels: GridDataArray) -> None: unique_labels = np.unique(submodel_labels) @@ -258,6 +254,7 @@ def slice_model( def _split(self, models, nest: bool): partition_models = collections.defaultdict(dict) + model_names = collections.defaultdict(list) for model_name, model in models.items(): for info in self.partition_info: new_model = self.slice_model( @@ -268,14 +265,19 @@ def _split(self, models, nest: bool): partition_models[info.partition_id][new_model_name] = new_model else: partition_models[new_model_name] = new_model - return partition_models + + model_names[model_name].append(new_model_name) + return partition_models, model_names def split(self): # FUTURE: we may currently assume there is a single flow model. See check above. # And each separate transport model represents a different species. - flow_models = self._split(self.flow_models, nest=False) - transport_models = self._split(self.transport_models, nest=True) - return PartitionModels(flow_models, transport_models) + flow_models, flow_names = self._split(self.flow_models, nest=False) + transport_models, transport_names = self._split( + self.transport_models, nest=True + ) + names = {**flow_names, **transport_names} + return PartitionModels(flow_models, transport_models), names def create_gwfgwf_exchanges(self): exchanges: list[Any] = [] diff --git a/imod/mf6/simulation.py b/imod/mf6/simulation.py index 68b72d98a..bda8259ad 100644 --- a/imod/mf6/simulation.py +++ b/imod/mf6/simulation.py @@ -1419,37 +1419,34 @@ def split( for model in original_models.values(): force_load_dis(model) - original_packages = get_packages(self) - new_simulation = imod.mf6.Modflow6Simulation( - f"{self.name}_partioned", validation_settings=self._validation_context - ) - for package_name, package in {**original_packages}.items(): - new_simulation[package_name] = deepcopy(package) - model_splitter = ModelSplitter( flow_models, transport_models, submodel_labels, ignore_time_purge_empty, ) + partition_models, model_names = model_splitter.split() - # TODO: isn't it cleaner to just construct a new Solution object instead? - # We know none of the original models will remain? - # And the new models is just the list of newly generated ones. - for model_name in model_splitter.modelnames: - solution_name = self.get_solution_name(model_name) - solution = cast(Solution, new_simulation[solution_name]) - solution._remove_model_from_solution(model_name) - - partition_models = model_splitter.split() - chained = { + # Create new simulation object and add the partitioned models. + new_simulation = imod.mf6.Modflow6Simulation( + f"{self.name}_partioned", validation_settings=self._validation_context + ) + chained = { # ChainMap reverses order, annoyingly... 
             **partition_models.flow_models,
             **partition_models.flat_transport_models,
         }
         for partition_model_name, partition_model in chained.items():
             new_simulation[partition_model_name] = partition_model
-            # TODO: see to do above.
-            solution._add_model_to_solution(partition_model_name)
+
+        # Add solution, time_discretization, etc.
+        # Replace the single model name by the partition model names.
+        original_packages = get_packages(self)
+        for package_name, package in original_packages.items():
+            new_package = deepcopy(package)
+            if isinstance(package, Solution):
+                old_name = package.dataset["modelnames"].item()
+                new_package["modelnames"] = xr.DataArray(model_names[old_name])
+            new_simulation[package_name] = new_package

         # Add exchanges
         exchanges: list[Any] = (

From 15ef01f2dd9457763e160e44bd046a7b9f54c7fd Mon Sep 17 00:00:00 2001
From: Huite
Date: Fri, 26 Sep 2025 15:40:05 +0200
Subject: [PATCH 5/7] Try-except scalar .item() in get_unfiltered_pkg_options

---
 imod/mf6/boundary_condition.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/imod/mf6/boundary_condition.py b/imod/mf6/boundary_condition.py
index 0cd5168d4..29dffc42d 100644
--- a/imod/mf6/boundary_condition.py
+++ b/imod/mf6/boundary_condition.py
@@ -212,7 +212,13 @@ def _get_unfiltered_pkg_options(
         for varname in self.dataset.data_vars.keys():  # pylint:disable=no-member
             if varname in not_options:
                 continue
-            v = self.dataset[varname].item()
+            # TODO: can we easily avoid this try-except?
+            # On which keys does it fail?
+            try:
+                v = self.dataset[varname].item()
+            except ValueError:
+                # Apparently not a scalar, therefore not an option entry; skip it.
+                continue
             options[varname] = v
         return options

From b791736aa542dcabb34a4ca98b44a01041467522 Mon Sep 17 00:00:00 2001
From: Huite
Date: Fri, 26 Sep 2025 19:35:30 +0200
Subject: [PATCH 6/7] Add to_zarr on packages, add engine keyword in model and
 simulation dump methods

---
 imod/mf6/hfb.py                   |  6 ++++++
 imod/mf6/model.py                 | 24 +++++++++++++++++----
 imod/mf6/pkgbase.py               | 14 +++++++++++--
 imod/mf6/simulation.py            | 35 ++++++++++++++++++++++++++-----
 imod/mf6/utilities/zarr_helper.py | 24 +++++++++++++++++++++
 5 files changed, 92 insertions(+), 11 deletions(-)
 create mode 100644 imod/mf6/utilities/zarr_helper.py

diff --git a/imod/mf6/hfb.py b/imod/mf6/hfb.py
index 32d594d00..12e8e327e 100644
--- a/imod/mf6/hfb.py
+++ b/imod/mf6/hfb.py
@@ -33,6 +33,7 @@
 from imod.mf6.disv import VerticesDiscretization
 from imod.mf6.mf6_hfb_adapter import Mf6HorizontalFlowBarrier
 from imod.mf6.package import Package
+from imod.mf6.utilities.zarr_helper import to_zarr
 from imod.mf6.validation_settings import ValidationSettings
 from imod.prepare.cleanup import cleanup_hfb
 from imod.schemata import (
@@ -562,6 +563,11 @@ def to_netcdf(
         new.dataset["geometry"] = new.line_data.to_json()
         new.dataset.to_netcdf(*args, **kwargs)

+    def to_zarr(self, path, engine: str, **kwargs):
+        new = deepcopy(self)
+        new.dataset["geometry"] = new.line_data.to_json()
+        to_zarr(new.dataset, path, engine, **kwargs)
+
     def _netcdf_encoding(self):
         return {"geometry": {"dtype": "str"}}

diff --git a/imod/mf6/model.py b/imod/mf6/model.py
index 8f3888abb..2cd0d38e2 100644
--- a/imod/mf6/model.py
+++ b/imod/mf6/model.py
@@ -592,6 +592,7 @@ def dump(
         validate: bool = True,
         mdal_compliant: bool = False,
         crs: Optional[Any] = None,
+        engine="netCDF4",
     ):
         """
         Dump simulation to files.
Writes a model definition as .TOML file, which
@@ -615,6 +616,8 @@
         crs: Any, optional
             Anything accepted by rasterio.crs.CRS.from_user_input
             Requires ``rioxarray`` installed.
+        engine: str, optional
+            "netCDF4" or "zarr" or "zarr.zip". Defaults to "netCDF4".
         """
         modeldirectory = pathlib.Path(directory) / modelname
         modeldirectory.mkdir(exist_ok=True, parents=True)
@@ -624,13 +627,26 @@
         if statusinfo.has_errors():
             raise ValidationError(statusinfo.to_string())

+        match engine:
+            case "netCDF4":
+                ext = "nc"
+            case "zarr":
+                ext = "zarr"
+            case "zarr.zip":
+                ext = "zarr.zip"
+            case _:
+                raise ValueError(f"Unknown engine: {engine}")
+
         toml_content: dict = collections.defaultdict(dict)
         for pkgname, pkg in self.items():
-            pkg_path = f"{pkgname}.nc"
+            pkg_path = f"{pkgname}.{ext}"
             toml_content[type(pkg).__name__][pkgname] = pkg_path
-            pkg.to_netcdf(
-                modeldirectory / pkg_path, crs=crs, mdal_compliant=mdal_compliant
-            )
+            if engine == "netCDF4":
+                pkg.to_netcdf(
+                    modeldirectory / pkg_path, crs=crs, mdal_compliant=mdal_compliant
+                )
+            else:
+                pkg.to_zarr(modeldirectory / pkg_path, engine=engine)

         toml_path = modeldirectory / f"{modelname}.toml"
         with open(toml_path, "wb") as f:
diff --git a/imod/mf6/pkgbase.py b/imod/mf6/pkgbase.py
index d2fbf99bf..8e24232c6 100644
--- a/imod/mf6/pkgbase.py
+++ b/imod/mf6/pkgbase.py
@@ -10,6 +10,7 @@
 import imod
 from imod.common.interfaces.ipackagebase import IPackageBase
+from imod.mf6.utilities.zarr_helper import to_zarr
 from imod.typing.grid import (
     GridDataArray,
     GridDataset,
@@ -108,6 +109,9 @@ def to_netcdf(
             dataset = imod.util.spatial.gdal_compliant_grid(dataset, crs=crs)
         dataset.to_netcdf(*args, **kwargs)

+    def to_zarr(self, path, engine, **kwargs):
+        to_zarr(self.dataset, path, engine, **kwargs)
+
     def _netcdf_encoding(self) -> dict:
         """
@@ -163,10 +167,16 @@ def from_file(cls, path: str | Path, **kwargs) -> Self:
         Refer to the xarray documentation for the possible keyword arguments.
         """
+        path = Path(path)

         if path.suffix in (".zip", ".zarr"):
-            # TODO: seems like a bug? Remove str() call if fixed in xarray/zarr
-            dataset = xr.open_zarr(str(path), **kwargs)
+            import zarr
+
+            if path.suffix == ".zip":
+                with zarr.storage.ZipStore(path, mode="r") as store:
+                    dataset = xr.open_zarr(store, **kwargs).load()  # load before the store closes
+            else:
+                dataset = xr.open_zarr(str(path), **kwargs)
         else:
             dataset = xr.open_dataset(path, **kwargs)
diff --git a/imod/mf6/simulation.py b/imod/mf6/simulation.py
index bda8259ad..5c1d682c6 100644
--- a/imod/mf6/simulation.py
+++ b/imod/mf6/simulation.py
@@ -43,6 +43,7 @@
 from imod.mf6.out import open_cbc, open_conc, open_hds
 from imod.mf6.package import Package
 from imod.mf6.ssm import SourceSinkMixing
+from imod.mf6.utilities.zarr_helper import to_zarr
 from imod.mf6.validation_settings import ValidationSettings
 from imod.mf6.write_context import WriteContext
 from imod.prepare.partition import create_partition_labels
@@ -971,6 +972,7 @@ def dump(
         validate: bool = True,
         mdal_compliant: bool = False,
         crs=None,
+        engine="netCDF4",
     ) -> None:
         """
         Dump simulation to files. Writes a model definition as .TOML file, which
@@ -992,6 +994,8 @@
         crs: Any, optional
             Anything accepted by rasterio.crs.CRS.from_user_input
             Requires ``rioxarray`` installed.
+        engine: str, optional
+            "netCDF4" or "zarr" or "zarr.zip". Defaults to "netCDF4".
Examples -------- @@ -1017,16 +1021,27 @@ def dump( directory = pathlib.Path(directory) directory.mkdir(parents=True, exist_ok=True) + match engine: + case "netCDF4": + ext = "nc" + case "zarr": + ext = "zarr" + case "zarr.zip": + ext = "zarr.zip" + case _: + raise ValueError(f"Unknown engine: {engine}") + toml_content: DefaultDict[str, dict] = collections.defaultdict(dict) # Dump version number version = get_version() toml_content["version"] = {"imod-python": version} + # Dump models and exchanges for key, value in self.items(): cls_name = type(value).__name__ if isinstance(value, Modflow6Model): model_toml_path = value.dump( - directory, key, validate, mdal_compliant, crs + directory, key, validate, mdal_compliant, crs, engine=engine ) toml_content[cls_name][key] = model_toml_path.relative_to( directory @@ -1036,13 +1051,23 @@ def dump( for exchange_package in self[key]: _, filename, _, _ = exchange_package.get_specification() exchange_class_short = type(exchange_package).__name__ - path = f"{filename}.nc" - exchange_package.dataset.to_netcdf(directory / path) + path = f"{filename}.{ext}" + + if engine == "netCDF4": + exchange_package.dataset.to_netcdf(directory / path) + else: + to_zarr( + exchange_package.dataset, directory / path, engine=engine + ) + toml_content[key][exchange_class_short].append(path) else: - path = f"{key}.nc" - value.dataset.to_netcdf(directory / path) + path = f"{key}.{ext}" + if engine == "netCDF4": + value.dataset.to_netcdf(directory / path) + else: + to_zarr(value.dataset, directory / path, engine=engine) toml_content[cls_name][key] = path with open(directory / f"{self.name}.toml", "wb") as f: diff --git a/imod/mf6/utilities/zarr_helper.py b/imod/mf6/utilities/zarr_helper.py new file mode 100644 index 000000000..2dedf92c6 --- /dev/null +++ b/imod/mf6/utilities/zarr_helper.py @@ -0,0 +1,24 @@ +import xugrid as xu + + +def to_zarr(dataset, path, engine, **kwargs): + import zarr + + match engine: + case "zarr": + if isinstance(dataset, xu.UgridDataset): + dataset.ugrid.to_zarr(path, **kwargs) + else: + dataset.to_zarr(path, **kwargs) + case "zarr.zip": + with zarr.storage.ZipStore(path, mode="w") as store: + if isinstance(dataset, xu.UgridDataset): + dataset.ugrid.to_zarr(store, **kwargs) + else: + dataset.to_zarr(store, **kwargs) + case _: + raise ValueError( + f'Expected engine to be "zarr" or "zarr.zip", got: {engine}' + ) + + return From 1a05316c091d76261409ee5d96f4409436dd1c38 Mon Sep 17 00:00:00 2001 From: Huite Date: Sat, 27 Sep 2025 13:12:28 +0200 Subject: [PATCH 7/7] Check if zarr already exists; and delete it if it does --- imod/mf6/utilities/zarr_helper.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/imod/mf6/utilities/zarr_helper.py b/imod/mf6/utilities/zarr_helper.py index 2dedf92c6..3ae21552a 100644 --- a/imod/mf6/utilities/zarr_helper.py +++ b/imod/mf6/utilities/zarr_helper.py @@ -1,9 +1,20 @@ +import shutil +from pathlib import Path + import xugrid as xu -def to_zarr(dataset, path, engine, **kwargs): +def to_zarr(dataset, path: str | Path, engine: str, **kwargs): import zarr + path = Path(path) + if path.exists(): + # Check if directory (ordinary .zarr, directory) or ZipStore (zip file). + if path.is_dir(): + shutil.rmtree(path) + else: + path.unlink() + match engine: case "zarr": if isinstance(dataset, xu.UgridDataset):
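For reference, the intended usage of the new engine keyword, as a sketch
(this assumes an existing Modflow6Simulation instance named `simulation`;
the model and package paths below are hypothetical):

    import imod

    # Dump the entire simulation as zarr ZipStores instead of NetCDF:
    simulation.dump("dumped_simulation", engine="zarr.zip")

    # Packages can be read back with from_file, which recognizes the
    # .zip / .zarr suffixes:
    riv = imod.mf6.River.from_file("dumped_simulation/gwf_1/riv.zarr.zip")

The "zarr.zip" engine keeps each package in a single zip file, which is
easier to move around than a .zarr directory.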