Skip to content

Commit

Permalink
Apply review comments. Add get_models_of_type to ISimulation interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Manangka committed May 14, 2024
1 parent cb3cd20 commit f021fec
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 27 deletions.
2 changes: 1 addition & 1 deletion imod/mf6/interfaces/imaskingsettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class IMaskingSettings(IPackage, abc.ABC):
"""
Interface for packages that support regridding
Interface for packages that support masking
"""

@property
Expand Down
16 changes: 16 additions & 0 deletions imod/mf6/interfaces/imodel.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import abstractmethod
from typing import Optional, Tuple

from imod.mf6.interfaces.idict import IDict
Expand All @@ -10,22 +11,37 @@ class IModel(IDict):
Interface for imod.mf6.model.Modflow6Model
"""

@abstractmethod
def mask_all_packages(self, mask: GridDataArray):
raise NotImplementedError

@abstractmethod
def purge_empty_packages(self, model_name: Optional[str] = "") -> None:
raise NotImplementedError

@abstractmethod
def validate(self, model_name: str = "") -> StatusInfoBase:
raise NotImplementedError

@property
@abstractmethod
def domain(self):
raise NotImplementedError

@property
@abstractmethod
def options(self) -> dict:
raise NotImplementedError

@property
@abstractmethod
def model_id(self) -> str:
raise NotImplementedError

@abstractmethod
def is_regridding_supported(self) -> Tuple[bool, str]:
raise NotImplementedError

@abstractmethod
def is_splitting_supported(self) -> Tuple[bool, str]:
raise NotImplementedError
9 changes: 9 additions & 0 deletions imod/mf6/interfaces/isimulation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from abc import abstractmethod

from imod.mf6.interfaces.idict import IDict
from imod.mf6.interfaces.imodel import IModel

Expand All @@ -7,11 +9,18 @@ class ISimulation(IDict):
Interface for imod.mf6.simulation.Modflow6Simulation
"""

@abstractmethod
def is_split(self) -> bool:
raise NotImplementedError

@abstractmethod
def has_one_flow_model(self) -> bool:
raise NotImplementedError

@abstractmethod
def get_models(self) -> dict[str, IModel]:
raise NotImplementedError

@abstractmethod
def get_models_of_type(self, model_id: str) -> dict[str, IModel]:
raise NotImplementedError
6 changes: 6 additions & 0 deletions imod/mf6/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,12 @@ def from_file(cls, toml_path):

return instance

@property
def options(self) -> dict:
if self._options is None:
raise ValueError("Model id has not been set")
return self._options

@property
def model_id(self) -> str:
if self._model_id is None:
Expand Down
6 changes: 3 additions & 3 deletions imod/mf6/multimodel/modelsplitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
)
from imod.mf6.boundary_condition import BoundaryCondition
from imod.mf6.hfb import HorizontalFlowBarrierBase
from imod.mf6.model import Modflow6Model
from imod.mf6.interfaces.imodel import IModel
from imod.mf6.utilities.clip import clip_by_grid
from imod.mf6.utilities.grid import get_active_domain_slice
from imod.mf6.wel import Well
Expand Down Expand Up @@ -64,14 +64,14 @@ def _validate_submodel_label_array(submodel_labels: GridDataArray) -> None:
)


def slice_model(partition_info: PartitionInfo, model: Modflow6Model) -> Modflow6Model:
def slice_model(partition_info: PartitionInfo, model: IModel) -> IModel:
"""
This function slices a Modflow6Model. A sliced model is a model that consists of packages of the original model
that are sliced using the domain_slice. A domain_slice can be created using the
:func:`imod.mf6.modelsplitter.create_domain_slices` function.
"""
modelclass = type(model)
new_model = modelclass(**model._options)
new_model = modelclass(**model.options)
domain_slice2d = get_active_domain_slice(partition_info.active_domain)
if is_unstructured(model.domain):
new_idomain = model.domain.sel(domain_slice2d)
Expand Down
6 changes: 2 additions & 4 deletions imod/mf6/multimodel/partition_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,15 @@ def get_label_array(simulation: Modflow6Simulation, npartitions: int) -> GridDat
idomain. Every array element is the partition number to which the column of
gridblocks of idomain at that location belong.
"""
gwf_models = [
model for model in simulation.get_models().values() if model.model_id == "gwf6"
]
gwf_models = simulation.get_models_of_type("gwf6")
if len(gwf_models) != 1:
raise ValueError(
"for partitioning a simulation to work, it must have exactly 1 flow model"
)
if npartitions <= 0:
raise ValueError("You should create at least 1 partition")

flowmodel = gwf_models[0]
flowmodel = list(gwf_models.values())[0]
idomain = flowmodel.domain
idomain_top = copy.deepcopy(idomain.isel(layer=0))

Expand Down
19 changes: 10 additions & 9 deletions imod/mf6/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from imod.mf6.gwfgwt import GWFGWT
from imod.mf6.gwtgwt import GWTGWT
from imod.mf6.ims import Solution
from imod.mf6.interfaces.imodel import IModel
from imod.mf6.interfaces.isimulation import ISimulation
from imod.mf6.model import Modflow6Model
from imod.mf6.model_gwf import GroundwaterFlowModel
Expand Down Expand Up @@ -566,7 +567,7 @@ def _open_output(self, output: str, **settings) -> GridDataArray | GridDataset:
output function.
"""
modeltype = OUTPUT_MODEL_MAPPING[output]
modelnames = self._get_models_of_type(modeltype._model_id).keys()
modelnames = self.get_models_of_type(modeltype._model_id).keys()
if len(modelnames) == 0:
modeltype = OUTPUT_MODEL_MAPPING[output]
raise ValueError(
Expand All @@ -575,7 +576,7 @@ def _open_output(self, output: str, **settings) -> GridDataArray | GridDataset:
)

if output in ["head", "budget-flow"]:
return self._open_single_output(modelnames, output, **settings)
return self._open_single_output(list(modelnames), output, **settings)
elif output in ["concentration", "budget-transport"]:
return self._concat_species(output, **settings)
else:
Expand Down Expand Up @@ -903,11 +904,11 @@ def get_exchange_relationships(self):
result.append(exchange.get_specification())
return result

def _get_models_of_type(self, modeltype):
def get_models_of_type(self, model_id) -> dict[str, IModel]:
return {
k: v
for k, v in self.items()
if isinstance(v, Modflow6Model) and (v.model_id == modeltype)
if isinstance(v, Modflow6Model) and (v.model_id == model_id)
}

def get_models(self):
Expand Down Expand Up @@ -1021,8 +1022,8 @@ def split(self, submodel_labels: GridDataArray) -> Modflow6Simulation:
raise ValueError(
"splitting of simulations with more (or less) than 1 flow model currently not supported."
)
transport_models = self._get_models_of_type("gwt6")
flow_models = self._get_models_of_type("gwf6")
transport_models = self.get_models_of_type("gwt6")
flow_models = self.get_models_of_type("gwf6")
if not any(flow_models) and not any(transport_models):
raise ValueError("a simulation without any models cannot be split.")

Expand Down Expand Up @@ -1207,8 +1208,8 @@ def __repr__(self) -> str:
return "\n".join(content)

def _get_transport_models_per_flow_model(self) -> dict[str, list[str]]:
flow_models = self._get_models_of_type("gwf6")
transport_models = self._get_models_of_type("gwt6")
flow_models = self.get_models_of_type("gwf6")
transport_models = self.get_models_of_type("gwt6")
# exchange for flow and transport
result = collections.defaultdict(list)

Expand Down Expand Up @@ -1258,7 +1259,7 @@ def is_split(self) -> bool:
return "split_exchanges" in self.keys()

def has_one_flow_model(self) -> bool:
flow_models = self._get_models_of_type("gwf6")
flow_models = self.get_models_of_type("gwf6")
return len(flow_models) == 1

def mask_all_models(
Expand Down
13 changes: 3 additions & 10 deletions imod/mf6/utilities/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,9 @@ def _mask_all_models(
"masking can only be applied to simulations that have not been split. Apply masking before splitting."
)

flowmodels = [
name
for name, model in simulation.get_models().items()
if model.model_id == "gwf6"
]
transportmodels = [
name
for name, model in simulation.get_models().items()
if model.model_id == "gwt6"
]
flowmodels = list(simulation.get_models_of_type("gwf6").keys())
transportmodels = list(simulation.get_models_of_type("gwt6").keys())

modelnames = flowmodels + transportmodels

for name in modelnames:
Expand Down

0 comments on commit f021fec

Please sign in to comment.