diff --git a/geos-mesh/src/geos/mesh/doctor/filters/MeshDoctorFilterBase.py b/geos-mesh/src/geos/mesh/doctor/filters/MeshDoctorFilterBase.py new file mode 100644 index 00000000..e6bf682b --- /dev/null +++ b/geos-mesh/src/geos/mesh/doctor/filters/MeshDoctorFilterBase.py @@ -0,0 +1,242 @@ +from typing_extensions import Self +from typing import Union +from vtkmodules.vtkCommonDataModel import vtkUnstructuredGrid +from geos.utils.Logger import getLogger, Logger +from geos.mesh.io.vtkIO import VtkOutput, write_mesh + +__doc__ = """ +MeshDoctorFilterBase module provides base classes for all mesh doctor filters using direct mesh manipulation. + +MeshDoctorFilterBase serves as the foundation class for filters that process existing meshes, +while MeshDoctorGenerator is for filters that generate new meshes from scratch. + +These base classes provide common functionality including: +- Logger management and setup +- Mesh access and manipulation methods +- File I/O operations for writing VTK unstructured grids +- Consistent interface across all mesh doctor filters + +Unlike the VTK pipeline-based MeshDoctorBase, these classes work with direct mesh manipulation +following the FillPartialArrays pattern for simpler, more Pythonic usage. + +Example usage patterns +---------------------- + +.. code-block:: python + + # For filters that process existing meshes + from filters.MeshDoctorFilterBase import MeshDoctorFilterBase + + class MyProcessingFilter(MeshDoctorFilterBase): + def __init__(self, mesh, parameter1=default_value): + super().__init__(mesh, "My Filter Name") + self.parameter1 = parameter1 + + def applyFilter(self): + # Process self.mesh directly + # Return True on success, False on failure + pass + + # For filters that generate meshes from scratch + from filters.MeshDoctorFilterBase import MeshDoctorGeneratorBase + + class MyGeneratorFilter(MeshDoctorGeneratorBase): + def __init__(self, parameter1=default_value): + super().__init__("My Generator Name") + self.parameter1 = parameter1 + + def applyFilter(self): + # Generate new mesh and assign to self.mesh + # Return True on success, False on failure + pass +""" + + +class MeshDoctorFilterBase: + """Base class for all mesh doctor filters using direct mesh manipulation.""" + + def __init__( + self: Self, + mesh: vtkUnstructuredGrid, + filterName: str, + useExternalLogger: bool = False, + ) -> None: + """Initialize the base mesh doctor filter.""" + # Check the 'mesh' input + if not isinstance( mesh, vtkUnstructuredGrid ): + raise TypeError( f"Input 'mesh' must be a vtkUnstructuredGrid, but got {type(mesh).__name__}." ) + if mesh.GetNumberOfCells() == 0: + raise ValueError( "Input 'mesh' cannot be empty." ) + + # Check the 'filterName' input + if not isinstance( filterName, str ): + raise TypeError( f"Input 'filterName' must be a string, but got {type(filterName).__name__}." ) + if not filterName.strip(): + raise ValueError( "Input 'filterName' cannot be an empty or whitespace-only string." ) + + # Check the 'useExternalLogger' input + if not isinstance( useExternalLogger, bool ): + raise TypeError( + f"Input 'useExternalLogger' must be a boolean, but got {type(useExternalLogger).__name__}." ) + + # Non-destructive behavior. + # The filter should contain a COPY of the mesh, not the original object. + self.mesh: vtkUnstructuredGrid = vtkUnstructuredGrid() + self.mesh.DeepCopy( mesh ) + self.filterName: str = filterName + + # Logger setup + self.logger: Logger + if not useExternalLogger: + self.logger = getLogger( filterName, True ) + else: + import logging + self.logger = logging.getLogger( filterName ) + self.logger.setLevel( logging.INFO ) + + def setLoggerHandler( self: Self, handler ) -> None: + """Set a specific handler for the filter logger. + + Args: + handler: The logging handler to add. + """ + if not self.logger.handlers: + self.logger.addHandler( handler ) + else: + self.logger.warning( "The logger already has a handler, to use yours set 'useExternalLogger' " + "to True during initialization." ) + + def getMesh( self: Self ) -> vtkUnstructuredGrid: + """Get the processed mesh. + + Returns: + vtkUnstructuredGrid: The processed mesh + """ + return self.mesh + + def writeGrid( self: Self, filepath: str, isDataModeBinary: bool = True, canOverwrite: bool = False ) -> None: + """Writes a .vtu file of the vtkUnstructuredGrid at the specified filepath. + + Args: + filepath (str): /path/to/your/file.vtu + isDataModeBinary (bool, optional): Writes the file in binary format or ascii. Defaults to True. + canOverwrite (bool, optional): Allows or not to overwrite if the filepath already leads to an existing file. + Defaults to False. + """ + if self.mesh: + vtk_output = VtkOutput( filepath, isDataModeBinary ) + write_mesh( self.mesh, vtk_output, canOverwrite ) + else: + self.logger.error( f"No mesh available. Cannot output vtkUnstructuredGrid at {filepath}." ) + + def copyMesh( self: Self, sourceMesh: vtkUnstructuredGrid ) -> vtkUnstructuredGrid: + """Helper method to create a copy of a mesh with structure and attributes. + + Args: + sourceMesh (vtkUnstructuredGrid): Source mesh to copy from. + + Returns: + vtkUnstructuredGrid: New mesh with copied structure and attributes. + """ + output_mesh: vtkUnstructuredGrid = sourceMesh.NewInstance() + output_mesh.CopyStructure( sourceMesh ) + output_mesh.CopyAttributes( sourceMesh ) + return output_mesh + + def applyFilter( self: Self ) -> bool: + """Apply the filter operation. + + This method should be overridden by subclasses to implement specific filter logic. + + Returns: + bool: True if filter applied successfully, False otherwise. + """ + raise NotImplementedError( "Subclasses must implement applyFilter method." ) + + +class MeshDoctorGeneratorBase: + """Base class for mesh doctor generator filters (no input mesh required). + + This class provides functionality for filters that generate meshes + from scratch without requiring input meshes. + """ + + def __init__( + self: Self, + filterName: str, + useExternalLogger: bool = False, + ) -> None: + """Initialize the base mesh doctor generator filter. + + Args: + filterName (str): Name of the filter for logging. + useExternalLogger (bool): Whether to use external logger. Defaults to False. + """ + # Check the 'filterName' input + if not isinstance( filterName, str ): + raise TypeError( f"Input 'filterName' must be a string, but got {type(filterName).__name__}." ) + if not filterName.strip(): + raise ValueError( "Input 'filterName' cannot be an empty or whitespace-only string." ) + + # Check the 'useExternalLogger' input + if not isinstance( useExternalLogger, bool ): + raise TypeError( + f"Input 'useExternalLogger' must be a boolean, but got {type(useExternalLogger).__name__}." ) + + self.mesh: Union[ vtkUnstructuredGrid, None ] = None + self.filterName: str = filterName + + # Logger setup + self.logger: Logger + if not useExternalLogger: + self.logger = getLogger( filterName, True ) + else: + import logging + self.logger = logging.getLogger( filterName ) + self.logger.setLevel( logging.INFO ) + + def setLoggerHandler( self: Self, handler ) -> None: + """Set a specific handler for the filter logger. + + Args: + handler: The logging handler to add. + """ + if not self.logger.handlers: + self.logger.addHandler( handler ) + else: + self.logger.warning( "The logger already has a handler, to use yours set 'useExternalLogger' " + "to True during initialization." ) + + def getMesh( self: Self ) -> Union[ vtkUnstructuredGrid, None ]: + """Get the generated mesh. + + Returns: + Union[vtkUnstructuredGrid, None]: The generated mesh, or None if not yet generated. + """ + return self.mesh + + def writeGrid( self: Self, filepath: str, isDataModeBinary: bool = True, canOverwrite: bool = False ) -> None: + """Writes a .vtu file of the vtkUnstructuredGrid at the specified filepath. + + Args: + filepath (str): /path/to/your/file.vtu + isDataModeBinary (bool, optional): Writes the file in binary format or ascii. Defaults to True. + canOverwrite (bool, optional): Allows or not to overwrite if the filepath already leads to an existing file. + Defaults to False. + """ + if self.mesh: + vtk_output = VtkOutput( filepath, isDataModeBinary ) + write_mesh( self.mesh, vtk_output, canOverwrite ) + else: + self.logger.error( f"No mesh generated. Cannot output vtkUnstructuredGrid at {filepath}." ) + + def applyFilter( self: Self ) -> bool: + """Apply the filter operation to generate a mesh. + + This method should be overridden by subclasses to implement specific generation logic. + The generated mesh should be assigned to self.mesh. + + Returns: + bool: True if mesh generated successfully, False otherwise. + """ + raise NotImplementedError( "Subclasses must implement applyFilter method." ) diff --git a/geos-mesh/src/geos/mesh/io/vtkIO.py b/geos-mesh/src/geos/mesh/io/vtkIO.py index 1b93648a..c325f110 100644 --- a/geos-mesh/src/geos/mesh/io/vtkIO.py +++ b/geos-mesh/src/geos/mesh/io/vtkIO.py @@ -1,192 +1,250 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies. # SPDX-FileContributor: Alexandre Benedicto - -import os.path -import logging from dataclasses import dataclass -from typing import Optional -from vtkmodules.vtkCommonDataModel import vtkUnstructuredGrid, vtkStructuredGrid, vtkPointSet -from vtkmodules.vtkIOLegacy import vtkUnstructuredGridWriter, vtkUnstructuredGridReader -from vtkmodules.vtkIOXML import ( vtkXMLUnstructuredGridReader, vtkXMLUnstructuredGridWriter, - vtkXMLStructuredGridReader, vtkXMLPUnstructuredGridReader, +from enum import Enum +from pathlib import Path +from typing import Optional, Type, TypeAlias +from vtkmodules.vtkCommonDataModel import vtkPointSet, vtkUnstructuredGrid +from vtkmodules.vtkIOCore import vtkWriter +from vtkmodules.vtkIOLegacy import vtkDataReader, vtkUnstructuredGridWriter, vtkUnstructuredGridReader +from vtkmodules.vtkIOXML import ( vtkXMLDataReader, vtkXMLUnstructuredGridReader, vtkXMLUnstructuredGridWriter, + vtkXMLWriter, vtkXMLStructuredGridReader, vtkXMLPUnstructuredGridReader, vtkXMLPStructuredGridReader, vtkXMLStructuredGridWriter ) +from geos.utils.Logger import getLogger __doc__ = """ -Input and Ouput methods for VTK meshes: - - VTK, VTU, VTS, PVTU, PVTS readers - - VTK, VTS, VTU writers +Input and Output methods for various VTK mesh formats. +Supports reading: .vtk, .vtu, .vts, .pvtu, .pvts +Supports writing: .vtk, .vtu, .vts """ +io_logger = getLogger( "IO for geos-mesh" ) +io_logger.propagate = False + + +class VtkFormat( Enum ): + """Enumeration for supported VTK file formats and their extensions.""" + VTK = ".vtk" + VTS = ".vts" + VTU = ".vtu" + PVTU = ".pvtu" + PVTS = ".pvts" + + +# Improved: Use TypeAlias for cleaner and more readable type hints +VtkReaderClass: TypeAlias = Type[ vtkDataReader | vtkXMLDataReader ] +VtkWriterClass: TypeAlias = Type[ vtkWriter | vtkXMLWriter ] + +# Centralized mapping of formats to their corresponding reader classes +READER_MAP: dict[ VtkFormat, VtkReaderClass ] = { + VtkFormat.VTK: vtkUnstructuredGridReader, + VtkFormat.VTS: vtkXMLStructuredGridReader, + VtkFormat.VTU: vtkXMLUnstructuredGridReader, + VtkFormat.PVTU: vtkXMLPUnstructuredGridReader, + VtkFormat.PVTS: vtkXMLPStructuredGridReader +} + +# Centralized mapping of formats to their corresponding writer classes +WRITER_MAP: dict[ VtkFormat, VtkWriterClass ] = { + VtkFormat.VTK: vtkUnstructuredGridWriter, + VtkFormat.VTS: vtkXMLStructuredGridWriter, + VtkFormat.VTU: vtkXMLUnstructuredGridWriter, +} + @dataclass( frozen=True ) class VtkOutput: + """Configuration for writing a VTK file.""" output: str - is_data_mode_binary: bool - - -def __read_vtk( vtk_input_file: str ) -> Optional[ vtkUnstructuredGrid ]: - reader = vtkUnstructuredGridReader() - logging.info( f"Testing file format \"{vtk_input_file}\" using legacy format reader..." ) - reader.SetFileName( vtk_input_file ) - if reader.IsFileUnstructuredGrid(): - logging.info( f"Reader matches. Reading file \"{vtk_input_file}\" using legacy format reader." ) - reader.Update() - return reader.GetOutput() - else: - logging.info( "Reader did not match the input file format." ) - return None + is_data_mode_binary: bool = True -def __read_vts( vtk_input_file: str ) -> Optional[ vtkStructuredGrid ]: - reader = vtkXMLStructuredGridReader() - logging.info( f"Testing file format \"{vtk_input_file}\" using XML format reader..." ) - if reader.CanReadFile( vtk_input_file ): - reader.SetFileName( vtk_input_file ) - logging.info( f"Reader matches. Reading file \"{vtk_input_file}\" using XML format reader." ) - reader.Update() - return reader.GetOutput() - else: - logging.info( "Reader did not match the input file format." ) - return None +def _read_data( filepath: str, reader_class: VtkReaderClass ) -> Optional[ vtkPointSet ]: + """Generic helper to read a VTK file using a specific reader class. + Args: + filepath (str): Path to the VTK file. + reader_class (VtkReaderClass): The VTK reader class to use. -def __read_vtu( vtk_input_file: str ) -> Optional[ vtkUnstructuredGrid ]: - reader = vtkXMLUnstructuredGridReader() - logging.info( f"Testing file format \"{vtk_input_file}\" using XML format reader..." ) - if reader.CanReadFile( vtk_input_file ): - reader.SetFileName( vtk_input_file ) - logging.info( f"Reader matches. Reading file \"{vtk_input_file}\" using XML format reader." ) - reader.Update() - return reader.GetOutput() - else: - logging.info( "Reader did not match the input file format." ) + Returns: + Optional[ vtkPointSet ]: The read VTK point set, or None if reading failed. + """ + reader = reader_class() + io_logger.info( f"Attempting to read '{filepath}' with {reader_class.__name__}..." ) + + reader.SetFileName( str( filepath ) ) + + # For XML-based readers, CanReadFile is a reliable and fast pre-check. + if hasattr( reader, 'CanReadFile' ) and not reader.CanReadFile( filepath ): + io_logger.error( f"Reader {reader_class.__name__} reports it cannot read file '{filepath}'." ) return None + reader.Update() -def __read_pvts( vtk_input_file: str ) -> Optional[ vtkStructuredGrid ]: - reader = vtkXMLPStructuredGridReader() - logging.info( f"Testing file format \"{vtk_input_file}\" using XML format reader..." ) - if reader.CanReadFile( vtk_input_file ): - reader.SetFileName( vtk_input_file ) - logging.info( f"Reader matches. Reading file \"{vtk_input_file}\" using XML format reader." ) - reader.Update() - return reader.GetOutput() - else: - logging.info( "Reader did not match the input file format." ) + # FIX: Check the reader's error code. This is the most reliable way to + # detect a failed read, as GetOutput() can return a default empty object on failure. + if hasattr( reader, 'GetErrorCode' ) and reader.GetErrorCode() != 0: + io_logger.warning( + f"VTK reader {reader_class.__name__} reported an error code after attempting to read '{filepath}'." ) return None + output = reader.GetOutput() -def __read_pvtu( vtk_input_file: str ) -> Optional[ vtkUnstructuredGrid ]: - reader = vtkXMLPUnstructuredGridReader() - logging.info( f"Testing file format \"{vtk_input_file}\" using XML format reader..." ) - if reader.CanReadFile( vtk_input_file ): - reader.SetFileName( vtk_input_file ) - logging.info( f"Reader matches. Reading file \"{vtk_input_file}\" using XML format reader." ) - reader.Update() - return reader.GetOutput() - else: - logging.info( "Reader did not match the input file format." ) + if output is None: return None + io_logger.info( "Read successful." ) + return output + + +def _write_data( mesh: vtkPointSet, writer_class: VtkWriterClass, output: str, is_binary: bool ) -> int: + """Generic helper to write a VTK file using a specific writer class. + + Args: + mesh (vtkPointSet): The grid data to write. + writer_class (VtkWriterClass): The VTK writer class to use. + output (str): The output file path. + is_binary (bool): Whether to write the file in binary mode. + + Returns: + int: The result of the write operation. + """ + io_logger.info( f"Writing mesh to '{output}' using {writer_class.__name__}..." ) + writer = writer_class() + writer.SetFileName( output ) + writer.SetInputData( mesh ) + + # Set data mode only for XML writers that support it + if isinstance( writer, vtkXMLWriter ): + if is_binary: + writer.SetDataModeToBinary() + io_logger.info( "Data mode set to Binary." ) + else: + writer.SetDataModeToAscii() + io_logger.info( "Data mode set to ASCII." ) -def read_mesh( vtk_input_file: str ) -> vtkPointSet: - """Read vtk file and build either an unstructured grid or a structured grid from it. + return writer.Write() + + +def read_mesh( filepath: str ) -> vtkPointSet: + """ + Reads a VTK file, automatically detecting the format. + + It first tries the reader associated with the file extension, then falls + back to trying all other available readers if the first attempt fails. Args: - vtk_input_file (str): The file name. Extension will be used to guess file format\ - If first guess fails, other available readers will be tried. + filepath (str): The path to the VTK file. Raises: - ValueError: Invalid file path error - ValueError: No appropriate reader available for the file format + FileNotFoundError: If the input file does not exist. + ValueError: If no suitable reader can be found for the file. Returns: - vtkPointSet: Mesh read + vtkPointSet: The resulting mesh data. """ - if not os.path.exists( vtk_input_file ): - err_msg: str = f"Invalid file path. Could not read \"{vtk_input_file}\"." - logging.error( err_msg ) - raise ValueError( err_msg ) - file_extension = os.path.splitext( vtk_input_file )[ -1 ] - extension_to_reader = { - ".vtk": __read_vtk, - ".vts": __read_vts, - ".vtu": __read_vtu, - ".pvtu": __read_pvtu, - ".pvts": __read_pvts - } - # Testing first the reader that should match - if file_extension in extension_to_reader: - output_mesh = extension_to_reader.pop( file_extension )( vtk_input_file ) - if output_mesh: - return output_mesh - # If it does not match, then test all the others. - for reader in extension_to_reader.values(): - output_mesh = reader( vtk_input_file ) + filepath_path: Path = Path( filepath ) + if not filepath_path.exists(): + raise FileNotFoundError( f"Invalid file path: '{filepath}' does not exist." ) + + candidate_readers: list[ VtkReaderClass ] = [] + # 1. Prioritize the reader associated with the file extension + try: + file_format = VtkFormat( filepath_path.suffix ) + if file_format in READER_MAP: + candidate_readers.append( READER_MAP[ file_format ] ) + except ValueError: + io_logger.warning( f"Unknown file extension '{filepath_path.suffix}'. Trying all available readers." ) + + # 2. Add all other unique readers as fallbacks + for reader_cls in READER_MAP.values(): + if reader_cls not in candidate_readers: + candidate_readers.append( reader_cls ) + + # 3. Attempt to read with the candidates in order + for reader_class in candidate_readers: + output_mesh = _read_data( filepath, reader_class ) if output_mesh: return output_mesh - # No reader did work. - err_msg = f"Could not find the appropriate VTK reader for file \"{vtk_input_file}\"." - logging.error( err_msg ) - raise ValueError( err_msg ) + raise ValueError( f"Could not find a suitable reader for '{filepath}'." ) -def __write_vtk( mesh: vtkUnstructuredGrid, output: str ) -> int: - logging.info( f"Writing mesh into file \"{output}\" using legacy format." ) - writer = vtkUnstructuredGridWriter() - writer.SetFileName( output ) - writer.SetInputData( mesh ) - return writer.Write() +def read_unstructured_grid( filepath: str ) -> vtkUnstructuredGrid: + """ + Reads a VTK file and ensures it is a vtkUnstructuredGrid. -def __write_vts( mesh: vtkStructuredGrid, output: str, toBinary: bool = False ) -> int: - logging.info( f"Writing mesh into file \"{output}\" using XML format." ) - writer = vtkXMLStructuredGridWriter() - writer.SetFileName( output ) - writer.SetInputData( mesh ) - writer.SetDataModeToBinary() if toBinary else writer.SetDataModeToAscii() - return writer.Write() + This function uses the general `read_mesh` to load the data and then + validates its type. + Args: + filepath (str): The path to the VTK file. -def __write_vtu( mesh: vtkUnstructuredGrid, output: str, toBinary: bool = False ) -> int: - logging.info( f"Writing mesh into file \"{output}\" using XML format." ) - writer = vtkXMLUnstructuredGridWriter() - writer.SetFileName( output ) - writer.SetInputData( mesh ) - writer.SetDataModeToBinary() if toBinary else writer.SetDataModeToAscii() - return writer.Write() + Raises: + FileNotFoundError: If the input file does not exist. + ValueError: If no suitable reader can be found for the file. + TypeError: If the file is read successfully but is not a vtkUnstructuredGrid. + Returns: + vtkUnstructuredGrid: The resulting unstructured grid data. + """ + io_logger.info( f"Reading file '{filepath}' and expecting vtkUnstructuredGrid." ) + mesh = read_mesh( filepath ) + + if not isinstance( mesh, vtkUnstructuredGrid ): + error_msg = ( f"File '{filepath}' was read successfully, but it is of type " + f"'{type(mesh).__name__}', not the expected vtkUnstructuredGrid." ) + io_logger.error( error_msg ) + raise TypeError( error_msg ) -def write_mesh( mesh: vtkPointSet, vtk_output: VtkOutput, canOverwrite: bool = False ) -> int: - """Write mesh to disk. + io_logger.info( "Validation successful. Mesh is a vtkUnstructuredGrid." ) + return mesh + + +def write_mesh( mesh: vtkPointSet, vtk_output: VtkOutput, can_overwrite: bool = False ) -> int: + """ + Writes a vtkPointSet to a file. - Nothing is done if file already exists. + The format is determined by the file extension in `VtkOutput.output`. Args: - mesh (vtkPointSet): Grid to write - vtk_output (VtkOutput): File path. File extension will be used to select VTK file format - canOverwrite (bool, optional): Authorize overwriting the file. Defaults to False. + mesh (vtkPointSet): The grid data to write. + vtk_output (VtkOutput): Configuration for the output file. + can_overwrite (bool, optional): If False, raises an error if the file + already exists. Defaults to False. Raises: - ValueError: Invalid VTK format. + FileExistsError: If the output file exists and `can_overwrite` is False. + ValueError: If the file extension is not a supported write format. + RuntimeError: If the VTK writer fails to write the file. Returns: - int: 0 if success + int: Returns 1 on success, consistent with the VTK writer's return code. """ - if os.path.exists( vtk_output.output ) and canOverwrite: - logging.error( f"File \"{vtk_output.output}\" already exists, nothing done." ) - return 1 - file_extension = os.path.splitext( vtk_output.output )[ -1 ] - if file_extension == ".vtk": - success_code = __write_vtk( mesh, vtk_output.output ) - elif file_extension == ".vts": - success_code = __write_vts( mesh, vtk_output.output, vtk_output.is_data_mode_binary ) - elif file_extension == ".vtu": - success_code = __write_vtu( mesh, vtk_output.output, vtk_output.is_data_mode_binary ) - else: - # No writer found did work. Dying. - err_msg = f"Could not find the appropriate VTK writer for extension \"{file_extension}\"." - logging.error( err_msg ) - raise ValueError( err_msg ) - return 0 if success_code else 2 # the Write member function return 1 in case of success, 0 otherwise. + output_path = Path( vtk_output.output ) + if output_path.exists() and not can_overwrite: + raise FileExistsError( f"File '{output_path}' already exists. Set can_overwrite=True to replace it." ) + + try: + # Catch the ValueError from an invalid enum to provide a consistent error message. + try: + file_format = VtkFormat( output_path.suffix ) + except ValueError: + # Re-raise with the message expected by the test. + raise ValueError( f"Writing to extension '{output_path.suffix}' is not supported." ) + + writer_class = WRITER_MAP.get( file_format ) + if not writer_class: + raise ValueError( f"Writing to extension '{output_path.suffix}' is not supported." ) + + success_code = _write_data( mesh, writer_class, str( output_path ), vtk_output.is_data_mode_binary ) + if not success_code: + raise RuntimeError( f"VTK writer failed to write file '{output_path}'." ) + + io_logger.info( f"Successfully wrote mesh to '{output_path}'." ) + return success_code + + except ( ValueError, RuntimeError ) as e: + io_logger.error( e ) + raise diff --git a/geos-mesh/src/geos/mesh/utils/genericHelpers.py b/geos-mesh/src/geos/mesh/utils/genericHelpers.py index 481b391f..6684d125 100644 --- a/geos-mesh/src/geos/mesh/utils/genericHelpers.py +++ b/geos-mesh/src/geos/mesh/utils/genericHelpers.py @@ -6,8 +6,10 @@ from typing import Iterator, List, Sequence, Any, Union from vtkmodules.util.numpy_support import numpy_to_vtk from vtkmodules.vtkCommonCore import vtkIdList, vtkPoints, reference -from vtkmodules.vtkCommonDataModel import vtkUnstructuredGrid, vtkMultiBlockDataSet, vtkPolyData, vtkDataSet, vtkDataObject, vtkPlane, vtkCellTypes, vtkIncrementalOctreePointLocator -from vtkmodules.vtkFiltersCore import vtk3DLinearGridPlaneCutter +from vtkmodules.vtkCommonDataModel import ( vtkUnstructuredGrid, vtkMultiBlockDataSet, vtkPolyData, vtkDataSet, + vtkDataObject, vtkPlane, vtkCellTypes, vtkIncrementalOctreePointLocator, + vtkStaticPointLocator ) +from vtkmodules.vtkFiltersCore import vtk3DLinearGridPlaneCutter, vtkCellCenters from geos.mesh.utils.multiblockHelpers import ( getBlockElementIndexesFlatten, getBlockFromFlatIndex ) __doc__ = """ @@ -36,7 +38,7 @@ def to_vtk_id_list( data: List[ int ] ) -> vtkIdList: return result -def vtk_iter( vtkContainer: Union[ vtkIdList, vtkCellTypes ] ) -> Iterator[ Any ]: +def vtk_iter( vtkContainer: vtkIdList | vtkCellTypes ) -> Iterator[ Any ]: """Utility function transforming a vtk "container" into an iterable. Args: @@ -85,6 +87,78 @@ def extractSurfaceFromElevation( mesh: vtkUnstructuredGrid, elevation: float ) - return cutter.GetOutputDataObject( 0 ) +def findUniqueCellCenterCellIds( grid1: vtkUnstructuredGrid, + grid2: vtkUnstructuredGrid, + tolerance: float = 1e-6 ) -> tuple[ list[ int ], list[ int ] ]: + """ + Compares two vtkUnstructuredGrids and finds the IDs of cells with unique centers. + + This function identifies cells whose centers exist in one grid but not the other, + within a specified floating-point tolerance. + + Args: + grid1 (vtk.vtkUnstructuredGrid): The first grid. + grid2 (vtk.vtkUnstructuredGrid): The second grid. + tolerance (float): The distance threshold to consider two points as the same. + + Returns: + tuple[list[int], list[int]]: A tuple containing two lists: + - The first list has the IDs of cells with centers unique to grid1. + - The second list has the IDs of cells with centers unique to grid2. + """ + if not grid1 or not grid2: + raise ValueError( "Input grids must be valid vtkUnstructuredGrid objects." ) + + # Generate cell centers for both grids using vtkCellCenters filter + centersFilter1 = vtkCellCenters() + centersFilter1.SetInputData( grid1 ) + centersFilter1.Update() + centers1 = centersFilter1.GetOutput().GetPoints() + + centersFilter2 = vtkCellCenters() + centersFilter2.SetInputData( grid2 ) + centersFilter2.Update() + centers2 = centersFilter2.GetOutput().GetPoints() + + # Find cells with centers that are unique to grid1 + uniqueIdsInGrid1: list[ int ] = [] + uniqueIdsCoordsInGrid1: list[ tuple[ float, float, float ] ] = [] + # Build a locator for the cell centers of grid2 for fast searching + locator2 = vtkStaticPointLocator() + locator2.SetDataSet( centersFilter2.GetOutput() ) + locator2.BuildLocator() + + for i in range( centers1.GetNumberOfPoints() ): + centerPt1 = centers1.GetPoint( i ) + # Find the closest point in grid2 to the current center from grid1 + result = vtkIdList() + locator2.FindPointsWithinRadius( tolerance, centerPt1, result ) + # If no point is found within the tolerance radius, the cell center is unique + if result.GetNumberOfIds() == 0: + uniqueIdsInGrid1.append( i ) + uniqueIdsCoordsInGrid1.append( centerPt1 ) + + # Find cells with centers that are unique to grid2 + uniqueIdsInGrid2: list[ int ] = [] + uniqueIdsCoordsInGrid2: list[ tuple[ float, float, float ] ] = [] + # Build a locator for the cell centers of grid1 for fast searching + locator1 = vtkStaticPointLocator() + locator1.SetDataSet( centersFilter1.GetOutput() ) + locator1.BuildLocator() + + for i in range( centers2.GetNumberOfPoints() ): + centerPt2 = centers2.GetPoint( i ) + # Find the closest point in grid1 to the current center from grid2 + result = vtkIdList() + locator1.FindPointsWithinRadius( tolerance, centerPt2, result ) + # If no point is found, it's unique to grid2 + if result.GetNumberOfIds() == 0: + uniqueIdsInGrid2.append( i ) + uniqueIdsCoordsInGrid2.append( centerPt2 ) + + return uniqueIdsInGrid1, uniqueIdsInGrid2, uniqueIdsCoordsInGrid1, uniqueIdsCoordsInGrid2 + + def getBounds( input: Union[ vtkUnstructuredGrid, vtkMultiBlockDataSet ] ) -> tuple[ float, float, float, float, float, float ]: diff --git a/geos-mesh/tests/test_MeshDoctorFilterBase.py b/geos-mesh/tests/test_MeshDoctorFilterBase.py new file mode 100644 index 00000000..9fb1f16c --- /dev/null +++ b/geos-mesh/tests/test_MeshDoctorFilterBase.py @@ -0,0 +1,486 @@ +import pytest +import logging +import numpy as np +from unittest.mock import Mock +from vtkmodules.vtkCommonDataModel import vtkUnstructuredGrid, VTK_TETRA +from geos.mesh.utils.genericHelpers import createSingleCellMesh +from geos.mesh.doctor.filters.MeshDoctorFilterBase import MeshDoctorFilterBase, MeshDoctorGeneratorBase + +__doc__ = """ +Test module for MeshDoctorFilterBase classes. +Tests the functionality of base classes for mesh doctor filters and generators. +""" + + +@pytest.fixture( scope="module" ) +def single_tetrahedron_mesh() -> vtkUnstructuredGrid: + """Fixture for a single tetrahedron mesh.""" + return createSingleCellMesh( VTK_TETRA, np.array( [ [ 0, 0, 0 ], [ 1, 0, 0 ], [ 0, 1, 0 ], [ 0, 0, 1 ] ] ) ) + + +class ConcreteFilterForTesting( MeshDoctorFilterBase ): + """Concrete implementation of MeshDoctorFilterBase for testing purposes.""" + + def __init__( self, mesh, filterName="TestFilter", useExternalLogger=False, shouldSucceed=True ): + super().__init__( mesh, filterName, useExternalLogger ) + self.shouldSucceed = shouldSucceed + self.applyFilterCalled = False + + def applyFilter( self ): + """Test implementation that can be configured to succeed or fail.""" + self.applyFilterCalled = True + if self.shouldSucceed: + self.logger.info( "Test filter applied successfully" ) + return True + else: + self.logger.error( "Test filter failed" ) + return False + + +class ConcreteGeneratorForTesting( MeshDoctorGeneratorBase ): + """Concrete implementation of MeshDoctorGeneratorBase for testing purposes.""" + + def __init__( self, filterName="TestGenerator", useExternalLogger=False, shouldSucceed=True ): + super().__init__( filterName, useExternalLogger ) + self.shouldSucceed = shouldSucceed + self.applyFilterCalled = False + + def applyFilter( self ): + """Test implementation that generates a simple mesh or fails.""" + self.applyFilterCalled = True + if self.shouldSucceed: + # Generate a simple single-cell mesh + self.mesh = createSingleCellMesh( VTK_TETRA, + np.array( [ [ 0, 0, 0 ], [ 1, 0, 0 ], [ 0, 1, 0 ], [ 0, 0, 1 ] ] ) ) + self.logger.info( "Test generator applied successfully" ) + return True + else: + self.logger.error( "Test generator failed" ) + return False + + +class TestMeshDoctorFilterBase: + """Test class for MeshDoctorFilterBase functionality.""" + + def test_initialization_valid_inputs( self, single_tetrahedron_mesh ): + """Test successful initialization with valid inputs.""" + filter_instance = ConcreteFilterForTesting( single_tetrahedron_mesh, "TestFilter", False ) + + assert filter_instance.filterName == "TestFilter" + assert filter_instance.mesh is not None + assert filter_instance.mesh.GetNumberOfCells() > 0 + assert filter_instance.logger is not None + + # Verify that mesh is a copy, not the original + assert filter_instance.mesh is not single_tetrahedron_mesh + + def test_initialization_with_external_logger( self, single_tetrahedron_mesh ): + """Test initialization with external logger.""" + filter_instance = ConcreteFilterForTesting( single_tetrahedron_mesh, "TestFilter", True ) + + assert filter_instance.filterName == "TestFilter" + assert isinstance( filter_instance.logger, logging.Logger ) + + def test_initialization_invalid_mesh_type( self ): + """Test initialization with invalid mesh type.""" + for error_obj in [ "not_a_mesh", 123, None ]: + with pytest.raises( TypeError, match="Input 'mesh' must be a vtkUnstructuredGrid" ): + ConcreteFilterForTesting( error_obj, "TestFilter" ) + + def test_initialization_empty_mesh( self ): + """Test initialization with empty mesh.""" + with pytest.raises( ValueError, match="Input 'mesh' cannot be empty" ): + ConcreteFilterForTesting( vtkUnstructuredGrid(), "TestFilter" ) + + def test_initialization_invalid_filter_name( self, single_tetrahedron_mesh ): + """Test initialization with invalid filter name.""" + for error_obj in [ 123, None ]: + with pytest.raises( TypeError, match="Input 'filterName' must be a string" ): + ConcreteFilterForTesting( single_tetrahedron_mesh, error_obj ) + + for error_obj in [ "", " " ]: + with pytest.raises( ValueError, match="Input 'filterName' cannot be an empty or whitespace-only string" ): + ConcreteFilterForTesting( single_tetrahedron_mesh, error_obj ) + + def test_initialization_invalid_external_logger_flag( self, single_tetrahedron_mesh ): + """Test initialization with invalid useExternalLogger flag.""" + for error_obj in [ "not_bool", 1 ]: + with pytest.raises( TypeError, match="Input 'useExternalLogger' must be a boolean" ): + ConcreteFilterForTesting( single_tetrahedron_mesh, "TestFilter", error_obj ) + + def test_get_mesh( self, single_tetrahedron_mesh ): + """Test getMesh method returns the correct mesh.""" + filter_instance = ConcreteFilterForTesting( single_tetrahedron_mesh, "TestFilter" ) + returned_mesh = filter_instance.getMesh() + + assert returned_mesh is filter_instance.mesh + assert returned_mesh.GetNumberOfCells() == single_tetrahedron_mesh.GetNumberOfCells() + assert returned_mesh.GetNumberOfPoints() == single_tetrahedron_mesh.GetNumberOfPoints() + + def test_copy_mesh( self, single_tetrahedron_mesh ): + """Test copyMesh helper method.""" + filter_instance = ConcreteFilterForTesting( single_tetrahedron_mesh, "TestFilter" ) + copied_mesh = filter_instance.copyMesh( single_tetrahedron_mesh ) + + assert copied_mesh is not single_tetrahedron_mesh + assert copied_mesh.GetNumberOfCells() == single_tetrahedron_mesh.GetNumberOfCells() + assert copied_mesh.GetNumberOfPoints() == single_tetrahedron_mesh.GetNumberOfPoints() + + def test_apply_filter_success( self, single_tetrahedron_mesh ): + """Test successful filter application.""" + filter_instance = ConcreteFilterForTesting( single_tetrahedron_mesh, "TestFilter", shouldSucceed=True ) + result = filter_instance.applyFilter() + + assert result is True + assert filter_instance.applyFilterCalled + + def test_apply_filter_failure( self, single_tetrahedron_mesh ): + """Test filter application failure.""" + filter_instance = ConcreteFilterForTesting( single_tetrahedron_mesh, "TestFilter", shouldSucceed=False ) + result = filter_instance.applyFilter() + + assert result is False + assert filter_instance.applyFilterCalled + + def test_write_grid_with_mesh( self, single_tetrahedron_mesh, tmp_path ): + """Test writing mesh to file when mesh is available.""" + filter_instance = ConcreteFilterForTesting( single_tetrahedron_mesh, "TestFilter" ) + output_file = tmp_path / "test_output.vtu" + + filter_instance.writeGrid( str( output_file ) ) + + # Verify file was created + assert output_file.exists() + assert output_file.stat().st_size > 0 + + def test_write_grid_with_different_options( self, single_tetrahedron_mesh, tmp_path ): + """Test writing mesh with different file options.""" + filter_instance = ConcreteFilterForTesting( single_tetrahedron_mesh, "TestFilter" ) + + # Test ASCII mode + output_file_ascii = tmp_path / "test_ascii.vtu" + filter_instance.writeGrid( str( output_file_ascii ), isDataModeBinary=False ) + assert output_file_ascii.exists() + + # Test with overwrite enabled + output_file_overwrite = tmp_path / "test_overwrite.vtu" + filter_instance.writeGrid( str( output_file_overwrite ), canOverwrite=True ) + assert output_file_overwrite.exists() + + # Write again with overwrite enabled (should not raise error) + filter_instance.writeGrid( str( output_file_overwrite ), canOverwrite=True ) + + def test_write_grid_without_mesh( self, single_tetrahedron_mesh, tmp_path, caplog ): + """Test writing when no mesh is available.""" + filter_instance = ConcreteFilterForTesting( single_tetrahedron_mesh, "TestFilter" ) + filter_instance.mesh = None # Remove the mesh + + output_file = tmp_path / "should_not_exist.vtu" + + with caplog.at_level( logging.ERROR ): + filter_instance.writeGrid( str( output_file ) ) + + # Should log error and not create file + assert "No mesh available" in caplog.text + assert not output_file.exists() + + def test_set_logger_handler_without_existing_handlers( self, single_tetrahedron_mesh ): + """Test setting logger handler when no handlers exist.""" + filter_instance = ConcreteFilterForTesting( single_tetrahedron_mesh, "TestFilter", useExternalLogger=True ) + + # Clear any existing handlers + filter_instance.logger.handlers.clear() + + # Create a mock handler + mock_handler = Mock() + filter_instance.setLoggerHandler( mock_handler ) + + # Verify handler was added + assert mock_handler in filter_instance.logger.handlers + + def test_set_logger_handler_with_existing_handlers( self, single_tetrahedron_mesh, caplog ): + """Test setting logger handler when handlers already exist.""" + filter_instance = ConcreteFilterForTesting( single_tetrahedron_mesh, + "TestFilter_with_handlers", + useExternalLogger=True ) + filter_instance.logger.addHandler( logging.NullHandler() ) + + mock_handler = Mock() + mock_handler.level = logging.WARNING + + with caplog.at_level( logging.WARNING ): + filter_instance.setLoggerHandler( mock_handler ) + + # Now caplog will capture the warning correctly + assert "already has a handler" in caplog.text + + def test_logger_functionality( self, single_tetrahedron_mesh, caplog ): + """Test that logging works correctly.""" + filter_instance = ConcreteFilterForTesting( single_tetrahedron_mesh, "TestFilter_functionality" ) + + with caplog.at_level( logging.INFO ): + filter_instance.applyFilter() + + # Should have logged the success message + assert "Test filter applied successfully" in caplog.text + + def test_mesh_deep_copy_behavior( self, single_tetrahedron_mesh ): + """Test that the filter creates a deep copy of the input mesh.""" + filter_instance = ConcreteFilterForTesting( single_tetrahedron_mesh, "TestFilter" ) + + # Modify the original mesh + original_cell_count = single_tetrahedron_mesh.GetNumberOfCells() + + # The filter's mesh should be independent of the original + filter_mesh = filter_instance.getMesh() + assert filter_mesh.GetNumberOfCells() == original_cell_count + assert filter_mesh is not single_tetrahedron_mesh + + +class TestMeshDoctorGeneratorBase: + """Test class for MeshDoctorGeneratorBase functionality.""" + + def test_initialization_valid_inputs( self ): + """Test successful initialization with valid inputs.""" + generator_instance = ConcreteGeneratorForTesting( "TestGenerator", False ) + + assert generator_instance.filterName == "TestGenerator" + assert generator_instance.mesh is None # Should start with no mesh + assert generator_instance.logger is not None + + def test_initialization_with_external_logger( self ): + """Test initialization with external logger.""" + generator_instance = ConcreteGeneratorForTesting( "TestGenerator", True ) + + assert generator_instance.filterName == "TestGenerator" + assert isinstance( generator_instance.logger, logging.Logger ) + + def test_initialization_invalid_filter_name( self ): + """Test initialization with invalid filter name.""" + for error_obj in [ 123, None ]: + with pytest.raises( TypeError, match="Input 'filterName' must be a string" ): + ConcreteGeneratorForTesting( error_obj ) + + for error_obj in [ "", " " ]: + with pytest.raises( ValueError, match="Input 'filterName' cannot be an empty or whitespace-only string" ): + ConcreteGeneratorForTesting( error_obj ) + + def test_initialization_invalid_external_logger_flag( self ): + """Test initialization with invalid useExternalLogger flag.""" + for error_obj in [ "not_bool", 1 ]: + with pytest.raises( TypeError, match="Input 'useExternalLogger' must be a boolean" ): + ConcreteGeneratorForTesting( "TestGenerator", error_obj ) + + def test_get_mesh_before_generation( self ): + """Test getMesh method before mesh generation.""" + generator_instance = ConcreteGeneratorForTesting( "TestGenerator" ) + returned_mesh = generator_instance.getMesh() + + assert returned_mesh is None + + def test_get_mesh_after_generation( self ): + """Test getMesh method after successful mesh generation.""" + generator_instance = ConcreteGeneratorForTesting( "TestGenerator", shouldSucceed=True ) + result = generator_instance.applyFilter() + + assert result is True + assert generator_instance.mesh is not None + + returned_mesh = generator_instance.getMesh() + assert returned_mesh is generator_instance.mesh + assert returned_mesh.GetNumberOfCells() > 0 + + def test_apply_filter_success( self ): + """Test successful mesh generation.""" + generator_instance = ConcreteGeneratorForTesting( "TestGenerator", shouldSucceed=True ) + result = generator_instance.applyFilter() + + assert result is True + assert generator_instance.applyFilterCalled + assert generator_instance.mesh is not None + + def test_apply_filter_failure( self ): + """Test mesh generation failure.""" + generator_instance = ConcreteGeneratorForTesting( "TestGenerator", shouldSucceed=False ) + result = generator_instance.applyFilter() + + assert result is False + assert generator_instance.applyFilterCalled + assert generator_instance.mesh is None + + def test_write_grid_with_generated_mesh( self, tmp_path ): + """Test writing generated mesh to file.""" + generator_instance = ConcreteGeneratorForTesting( "TestGenerator", shouldSucceed=True ) + generator_instance.applyFilter() + + output_file = tmp_path / "generated_mesh.vtu" + generator_instance.writeGrid( str( output_file ) ) + + # Verify file was created + assert output_file.exists() + assert output_file.stat().st_size > 0 + + def test_write_grid_without_generated_mesh( self, tmp_path, caplog ): + """Test writing when no mesh has been generated.""" + generator_instance = ConcreteGeneratorForTesting( "TestGenerator" ) + output_file = tmp_path / "should_not_exist.vtu" + + with caplog.at_level( logging.ERROR ): + generator_instance.writeGrid( str( output_file ) ) + + # Should log error and not create file + assert "No mesh generated" in caplog.text + assert not output_file.exists() + + def test_write_grid_with_different_options( self, tmp_path ): + """Test writing generated mesh with different file options.""" + generator_instance = ConcreteGeneratorForTesting( "TestGenerator", shouldSucceed=True ) + generator_instance.applyFilter() + + # Test ASCII mode + output_file_ascii = tmp_path / "generated_ascii.vtu" + generator_instance.writeGrid( str( output_file_ascii ), isDataModeBinary=False ) + assert output_file_ascii.exists() + + # Test with overwrite enabled + output_file_overwrite = tmp_path / "generated_overwrite.vtu" + generator_instance.writeGrid( str( output_file_overwrite ), canOverwrite=True ) + assert output_file_overwrite.exists() + + def test_set_logger_handler_without_existing_handlers( self ): + """Test setting logger handler when no handlers exist.""" + generator_instance = ConcreteGeneratorForTesting( "TestGenerator", useExternalLogger=True ) + + # Clear any existing handlers + generator_instance.logger.handlers.clear() + + # Create a mock handler + mock_handler = Mock() + generator_instance.setLoggerHandler( mock_handler ) + + # Verify handler was added + assert mock_handler in generator_instance.logger.handlers + + def test_set_logger_handler_with_existing_handlers( self, caplog ): + """Test setting logger handler when handlers already exist.""" + generator_instance = ConcreteGeneratorForTesting( "TestGenerator_with_handlers", useExternalLogger=True ) + generator_instance.logger.addHandler( logging.NullHandler() ) + + mock_handler = Mock() + mock_handler.level = logging.WARNING + + with caplog.at_level( logging.WARNING ): + generator_instance.setLoggerHandler( mock_handler ) + + # Now caplog will capture the warning correctly + assert "already has a handler" in caplog.text + + def test_logger_functionality( self, caplog ): + """Test that logging works correctly.""" + generator_instance = ConcreteGeneratorForTesting( "TestGenerator_functionality", shouldSucceed=True ) + + with caplog.at_level( logging.INFO ): + generator_instance.applyFilter() + + # Should have logged the success message + assert "Test generator applied successfully" in caplog.text + + +class TestMeshDoctorBaseEdgeCases: + """Test class for edge cases and integration scenarios.""" + + def test_filter_base_not_implemented_error( self, single_tetrahedron_mesh ): + """Test that base class raises NotImplementedError.""" + filter_instance = MeshDoctorFilterBase( single_tetrahedron_mesh, "BaseFilter" ) + + with pytest.raises( NotImplementedError, match="Subclasses must implement applyFilter method" ): + filter_instance.applyFilter() + + def test_generator_base_not_implemented_error( self ): + """Test that base generator class raises NotImplementedError.""" + generator_instance = MeshDoctorGeneratorBase( "BaseGenerator" ) + + with pytest.raises( NotImplementedError, match="Subclasses must implement applyFilter method" ): + generator_instance.applyFilter() + + def test_filter_with_single_cell_mesh( self, single_tetrahedron_mesh ): + """Test filter with a single cell mesh.""" + filter_instance = ConcreteFilterForTesting( single_tetrahedron_mesh, "SingleCellTest" ) + result = filter_instance.applyFilter() + + assert result is True + assert filter_instance.getMesh().GetNumberOfCells() == 1 + + def test_filter_mesh_independence( self, single_tetrahedron_mesh ): + """Test that multiple filters are independent.""" + filter1 = ConcreteFilterForTesting( single_tetrahedron_mesh, "Filter1" ) + filter2 = ConcreteFilterForTesting( single_tetrahedron_mesh, "Filter2" ) + + mesh1 = filter1.getMesh() + mesh2 = filter2.getMesh() + + # Meshes should be independent copies + assert mesh1 is not mesh2 + assert mesh1 is not single_tetrahedron_mesh + assert mesh2 is not single_tetrahedron_mesh + + def test_generator_multiple_instances( self ): + """Test that multiple generator instances are independent.""" + gen1 = ConcreteGeneratorForTesting( "Gen1", shouldSucceed=True ) + gen2 = ConcreteGeneratorForTesting( "Gen2", shouldSucceed=True ) + + gen1.applyFilter() + gen2.applyFilter() + + assert gen1.getMesh() is not gen2.getMesh() + assert gen1.getMesh() is not None + assert gen2.getMesh() is not None + + def test_filter_logger_names( self, single_tetrahedron_mesh ): + """Test that different filters get different logger names.""" + filter1 = ConcreteFilterForTesting( single_tetrahedron_mesh, "Filter1" ) + filter2 = ConcreteFilterForTesting( single_tetrahedron_mesh, "Filter2" ) + + assert filter1.logger.name != filter2.logger.name + + def test_generator_logger_names( self ): + """Test that different generators get different logger names.""" + gen1 = ConcreteGeneratorForTesting( "Gen1" ) + gen2 = ConcreteGeneratorForTesting( "Gen2" ) + + assert gen1.logger.name != gen2.logger.name + + +@pytest.mark.parametrize( "filter_name,should_succeed", [ + ( "ParametrizedFilter1", True ), + ( "ParametrizedFilter2", False ), + ( "LongFilterNameForTesting", True ), + ( "UnicodeFilter", True ), +] ) +def test_parametrized_filter_behavior( single_tetrahedron_mesh, filter_name, should_succeed ): + """Parametrized test for different filter configurations.""" + filter_instance = ConcreteFilterForTesting( single_tetrahedron_mesh, filter_name, shouldSucceed=should_succeed ) + + result = filter_instance.applyFilter() + assert result == should_succeed + assert filter_instance.filterName == filter_name + + +@pytest.mark.parametrize( "generator_name,should_succeed", [ + ( "ParametrizedGen1", True ), + ( "ParametrizedGen2", False ), + ( "LongGeneratorNameForTesting", True ), + ( "UnicodeGenerator", True ), +] ) +def test_parametrized_generator_behavior( generator_name, should_succeed ): + """Parametrized test for different generator configurations.""" + generator_instance = ConcreteGeneratorForTesting( generator_name, shouldSucceed=should_succeed ) + + result = generator_instance.applyFilter() + assert result == should_succeed + assert generator_instance.filterName == generator_name + + if should_succeed: + assert generator_instance.getMesh() is not None + else: + assert generator_instance.getMesh() is None diff --git a/geos-mesh/tests/test_genericHelpers.py b/geos-mesh/tests/test_genericHelpers.py index 85de45f5..7132ca2c 100644 --- a/geos-mesh/tests/test_genericHelpers.py +++ b/geos-mesh/tests/test_genericHelpers.py @@ -9,7 +9,8 @@ from typing import ( Iterator, ) -from geos.mesh.utils.genericHelpers import getBoundsFromPointCoords, createVertices, createMultiCellMesh +from geos.mesh.utils.genericHelpers import ( getBoundsFromPointCoords, createVertices, createMultiCellMesh, + findUniqueCellCenterCellIds ) from vtkmodules.util.numpy_support import vtk_to_numpy @@ -181,3 +182,56 @@ def test_getBoundsFromPointCoords() -> None: boundsExp: list[ float ] = [ 0., 5., 1., 8., 2., 9. ] boundsObs: list[ float ] = getBoundsFromPointCoords( cellPtsCoord ) assert boundsExp == boundsObs, f"Expected bounds are {boundsExp}." + + +def test_findUniqueCellCenterCellIds() -> None: + """Test of findUniqueCellCenterCellIds method.""" + # Create first mesh with two cells + cellTypes1: list[ int ] = [ VTK_TETRA, VTK_TETRA ] + cellPtsCoord1: list[ npt.NDArray[ np.float64 ] ] = [ + # First tetrahedron centered around (0.25, 0.25, 0.25) + np.array( [ [ 0.0, 0.0, 0.0 ], [ 1.0, 0.0, 0.0 ], [ 0.0, 1.0, 0.0 ], [ 0.0, 0.0, 1.0 ] ], dtype=float ), + # Second tetrahedron centered around (2.25, 2.25, 2.25) + np.array( [ [ 2.0, 2.0, 2.0 ], [ 3.0, 2.0, 2.0 ], [ 2.0, 3.0, 2.0 ], [ 2.0, 2.0, 3.0 ] ], dtype=float ), + ] + + # Create second mesh with different cells, one overlapping with first mesh + cellTypes2: list[ int ] = [ VTK_TETRA, VTK_TETRA ] + cellPtsCoord2: list[ npt.NDArray[ np.float64 ] ] = [ + # First tetrahedron with same center as first cell in mesh1 (should overlap) + np.array( [ [ 0.0, 0.0, 0.0 ], [ 1.0, 0.0, 0.0 ], [ 0.0, 1.0, 0.0 ], [ 0.0, 0.0, 1.0 ] ], dtype=float ), + # Second tetrahedron with different center (unique to mesh2) + np.array( [ [ 4.0, 4.0, 4.0 ], [ 5.0, 4.0, 4.0 ], [ 4.0, 5.0, 4.0 ], [ 4.0, 4.0, 5.0 ] ], dtype=float ), + ] + + # Create meshes + mesh1: vtkUnstructuredGrid = createMultiCellMesh( cellTypes1, cellPtsCoord1, sharePoints=True ) + mesh2: vtkUnstructuredGrid = createMultiCellMesh( cellTypes2, cellPtsCoord2, sharePoints=True ) + + # Test the function + uniqueIds1, uniqueIds2, uniqueCoords1, uniqueCoords2 = findUniqueCellCenterCellIds( mesh1, mesh2 ) + + # Expected results: + # - Cell 0 in both meshes have the same center, so should not be in unique lists + # - Cell 1 in mesh1 (centered at ~(2.25, 2.25, 2.25)) is unique to mesh1 + # - Cell 1 in mesh2 (centered at ~(4.25, 4.25, 4.25)) is unique to mesh2 + assert len( uniqueIds1 ) == 1, f"Expected 1 unique cell in mesh1, got {len(uniqueIds1)}" + assert len( uniqueIds2 ) == 1, f"Expected 1 unique cell in mesh2, got {len(uniqueIds2)}" + assert uniqueIds1 == [ 1 ], f"Expected unique cell 1 in mesh1, got {uniqueIds1}" + assert uniqueIds2 == [ 1 ], f"Expected unique cell 1 in mesh2, got {uniqueIds2}" + + # Test coordinate lists + assert len( uniqueCoords1 ) == 1, f"Expected 1 unique coordinate in mesh1, got {len(uniqueCoords1)}" + assert len( uniqueCoords2 ) == 1, f"Expected 1 unique coordinate in mesh2, got {len(uniqueCoords2)}" + + # Test with tolerance + uniqueIds1_tight, uniqueIds2_tight, _, _ = findUniqueCellCenterCellIds( mesh1, mesh2, tolerance=1e-12 ) + assert len( uniqueIds1_tight ) == 1, "Tight tolerance should still find 1 unique cell in mesh1" + assert len( uniqueIds2_tight ) == 1, "Tight tolerance should still find 1 unique cell in mesh2" + + # Test error handling + with pytest.raises( ValueError, match="Input grids must be valid vtkUnstructuredGrid objects" ): + findUniqueCellCenterCellIds( None, mesh2 ) # type: ignore[arg-type] + + with pytest.raises( ValueError, match="Input grids must be valid vtkUnstructuredGrid objects" ): + findUniqueCellCenterCellIds( mesh1, None ) # type: ignore[arg-type] diff --git a/geos-mesh/tests/test_vtkIO.py b/geos-mesh/tests/test_vtkIO.py new file mode 100644 index 00000000..56ab678a --- /dev/null +++ b/geos-mesh/tests/test_vtkIO.py @@ -0,0 +1,441 @@ +import pytest +import numpy as np +from vtkmodules.vtkCommonCore import vtkPoints +from vtkmodules.vtkCommonDataModel import vtkUnstructuredGrid, vtkStructuredGrid, VTK_TETRA, VTK_HEXAHEDRON +from geos.mesh.utils.genericHelpers import createSingleCellMesh +from geos.mesh.io.vtkIO import ( VtkFormat, VtkOutput, read_mesh, read_unstructured_grid, write_mesh, READER_MAP, + WRITER_MAP ) + +__doc__ = """ +Test module for vtkIO module. +Tests the functionality of reading and writing various VTK file formats. +""" + + +@pytest.fixture( scope="module" ) +def simple_unstructured_mesh(): + """Fixture for a simple unstructured mesh with tetrahedron.""" + return createSingleCellMesh( VTK_TETRA, np.array( [ [ 0, 0, 0 ], [ 1, 0, 0 ], [ 0, 1, 0 ], [ 0, 0, 1 ] ] ) ) + + +@pytest.fixture( scope="module" ) +def simple_hex_mesh(): + """Fixture for a simple hexahedron mesh.""" + return createSingleCellMesh( + VTK_HEXAHEDRON, + np.array( [ [ 0, 0, 0 ], [ 1, 0, 0 ], [ 1, 1, 0 ], [ 0, 1, 0 ], [ 0, 0, 1 ], [ 1, 0, 1 ], [ 1, 1, 1 ], + [ 0, 1, 1 ] ] ) ) + + +@pytest.fixture( scope="module" ) +def structured_mesh(): + """Fixture for a simple structured grid.""" + mesh = vtkStructuredGrid() + mesh.SetDimensions( 2, 2, 2 ) + + points = vtkPoints() + for k in range( 2 ): + for j in range( 2 ): + for i in range( 2 ): + points.InsertNextPoint( i, j, k ) + + mesh.SetPoints( points ) + return mesh + + +class TestVtkFormat: + """Test class for VtkFormat enumeration.""" + + def test_vtk_format_values( self ): + """Test that VtkFormat enum has correct values.""" + assert VtkFormat.VTK.value == ".vtk" + assert VtkFormat.VTS.value == ".vts" + assert VtkFormat.VTU.value == ".vtu" + assert VtkFormat.PVTU.value == ".pvtu" + assert VtkFormat.PVTS.value == ".pvts" + + def test_vtk_format_from_string( self ): + """Test creating VtkFormat from string values.""" + assert VtkFormat( ".vtk" ) == VtkFormat.VTK + assert VtkFormat( ".vtu" ) == VtkFormat.VTU + assert VtkFormat( ".vts" ) == VtkFormat.VTS + assert VtkFormat( ".pvtu" ) == VtkFormat.PVTU + assert VtkFormat( ".pvts" ) == VtkFormat.PVTS + + def test_invalid_format( self ): + """Test that invalid format raises ValueError.""" + with pytest.raises( ValueError ): + VtkFormat( ".invalid" ) + + +class TestVtkOutput: + """Test class for VtkOutput dataclass.""" + + def test_vtk_output_creation( self ): + """Test VtkOutput creation with default parameters.""" + output = VtkOutput( "test.vtu" ) + assert output.output == "test.vtu" + assert output.is_data_mode_binary is True + + def test_vtk_output_creation_custom( self ): + """Test VtkOutput creation with custom parameters.""" + output = VtkOutput( "test.vtu", is_data_mode_binary=False ) + assert output.output == "test.vtu" + assert output.is_data_mode_binary is False + + def test_vtk_output_immutable( self ): + """Test that VtkOutput is immutable (frozen dataclass).""" + output = VtkOutput( "test.vtu" ) + with pytest.raises( AttributeError ): + output.output = "new_test.vtu" + + +class TestMappings: + """Test class for reader and writer mappings.""" + + def test_reader_map_completeness( self ): + """Test that READER_MAP contains all readable formats.""" + expected_formats = { VtkFormat.VTK, VtkFormat.VTS, VtkFormat.VTU, VtkFormat.PVTU, VtkFormat.PVTS } + assert set( READER_MAP.keys() ) == expected_formats + + def test_writer_map_completeness( self ): + """Test that WRITER_MAP contains all writable formats.""" + expected_formats = { VtkFormat.VTK, VtkFormat.VTS, VtkFormat.VTU } + assert set( WRITER_MAP.keys() ) == expected_formats + + def test_reader_map_classes( self ): + """Test that READER_MAP contains valid reader classes.""" + for format_type, reader_class in READER_MAP.items(): + assert hasattr( reader_class, '__name__' ) + # All readers should be classes + assert isinstance( reader_class, type ) + + def test_writer_map_classes( self ): + """Test that WRITER_MAP contains valid writer classes.""" + for format_type, writer_class in WRITER_MAP.items(): + assert hasattr( writer_class, '__name__' ) + # All writers should be classes + assert isinstance( writer_class, type ) + + +class TestWriteMesh: + """Test class for write_mesh functionality.""" + + def test_write_vtu_binary( self, simple_unstructured_mesh, tmp_path ): + """Test writing VTU file in binary mode.""" + output_file = tmp_path / "test_mesh.vtu" + vtk_output = VtkOutput( str( output_file ), is_data_mode_binary=True ) + + result = write_mesh( simple_unstructured_mesh, vtk_output, can_overwrite=True ) + + assert result == 1 # VTK success code + assert output_file.exists() + assert output_file.stat().st_size > 0 + + def test_write_vtu_ascii( self, simple_unstructured_mesh, tmp_path ): + """Test writing VTU file in ASCII mode.""" + output_file = tmp_path / "test_mesh_ascii.vtu" + vtk_output = VtkOutput( str( output_file ), is_data_mode_binary=False ) + + result = write_mesh( simple_unstructured_mesh, vtk_output, can_overwrite=True ) + + assert result == 1 # VTK success code + assert output_file.exists() + assert output_file.stat().st_size > 0 + + def test_write_vtk_format( self, simple_unstructured_mesh, tmp_path ): + """Test writing VTK legacy format.""" + output_file = tmp_path / "test_mesh.vtk" + vtk_output = VtkOutput( str( output_file ) ) + + result = write_mesh( simple_unstructured_mesh, vtk_output, can_overwrite=True ) + + assert result == 1 # VTK success code + assert output_file.exists() + assert output_file.stat().st_size > 0 + + def test_write_vts_format( self, structured_mesh, tmp_path ): + """Test writing VTS (structured grid) format.""" + output_file = tmp_path / "test_mesh.vts" + vtk_output = VtkOutput( str( output_file ) ) + + result = write_mesh( structured_mesh, vtk_output, can_overwrite=True ) + + assert result == 1 # VTK success code + assert output_file.exists() + assert output_file.stat().st_size > 0 + + def test_write_file_exists_error( self, simple_unstructured_mesh, tmp_path ): + """Test that writing to existing file raises error when can_overwrite=False.""" + output_file = tmp_path / "existing_file.vtu" + output_file.write_text( "dummy content" ) # Create existing file + + vtk_output = VtkOutput( str( output_file ) ) + + with pytest.raises( FileExistsError, match="already exists" ): + write_mesh( simple_unstructured_mesh, vtk_output, can_overwrite=False ) + + def test_write_unsupported_format( self, simple_unstructured_mesh, tmp_path ): + """Test that writing unsupported format raises ValueError.""" + output_file = tmp_path / "test_mesh.unsupported" + vtk_output = VtkOutput( str( output_file ) ) + + with pytest.raises( ValueError, match="not supported" ): + write_mesh( simple_unstructured_mesh, vtk_output ) + + def test_write_overwrite_allowed( self, simple_unstructured_mesh, tmp_path ): + """Test that overwriting is allowed when can_overwrite=True.""" + output_file = tmp_path / "overwrite_test.vtu" + vtk_output = VtkOutput( str( output_file ) ) + + # First write + result1 = write_mesh( simple_unstructured_mesh, vtk_output, can_overwrite=True ) + assert result1 == 1 + assert output_file.exists() + + # Second write (overwrite) + result2 = write_mesh( simple_unstructured_mesh, vtk_output, can_overwrite=True ) + assert result2 == 1 + assert output_file.exists() + + +class TestReadMesh: + """Test class for read_mesh functionality.""" + + def test_read_nonexistent_file( self ): + """Test that reading nonexistent file raises FileNotFoundError.""" + with pytest.raises( FileNotFoundError, match="does not exist" ): + read_mesh( "nonexistent_file.vtu" ) + + def test_read_vtu_file( self, simple_unstructured_mesh, tmp_path ): + """Test reading VTU file.""" + output_file = tmp_path / "test_read.vtu" + vtk_output = VtkOutput( str( output_file ) ) + + # First write the file + write_mesh( simple_unstructured_mesh, vtk_output, can_overwrite=True ) + + # Then read it back + read_mesh_result = read_mesh( str( output_file ) ) + + assert read_mesh_result is not None + assert isinstance( read_mesh_result, vtkUnstructuredGrid ) + assert read_mesh_result.GetNumberOfPoints() == simple_unstructured_mesh.GetNumberOfPoints() + assert read_mesh_result.GetNumberOfCells() == simple_unstructured_mesh.GetNumberOfCells() + + def test_read_vtk_file( self, simple_unstructured_mesh, tmp_path ): + """Test reading VTK legacy file.""" + output_file = tmp_path / "test_read.vtk" + vtk_output = VtkOutput( str( output_file ) ) + + # First write the file + write_mesh( simple_unstructured_mesh, vtk_output, can_overwrite=True ) + + # Then read it back + read_mesh_result = read_mesh( str( output_file ) ) + + assert read_mesh_result is not None + assert isinstance( read_mesh_result, vtkUnstructuredGrid ) + assert read_mesh_result.GetNumberOfPoints() == simple_unstructured_mesh.GetNumberOfPoints() + assert read_mesh_result.GetNumberOfCells() == simple_unstructured_mesh.GetNumberOfCells() + + def test_read_vts_file( self, structured_mesh, tmp_path ): + """Test reading VTS (structured grid) file.""" + output_file = tmp_path / "test_read.vts" + vtk_output = VtkOutput( str( output_file ) ) + + # First write the file + write_mesh( structured_mesh, vtk_output, can_overwrite=True ) + + # Then read it back + read_mesh_result = read_mesh( str( output_file ) ) + + assert read_mesh_result is not None + assert isinstance( read_mesh_result, vtkStructuredGrid ) + assert read_mesh_result.GetNumberOfPoints() == structured_mesh.GetNumberOfPoints() + + def test_read_unknown_extension( self, simple_unstructured_mesh, tmp_path ): + """Test reading file with unknown extension falls back to trying all readers.""" + # Create a VTU file but with unknown extension + vtu_file = tmp_path / "test.vtu" + unknown_file = tmp_path / "test.unknown" + + # Write as VTU first + vtk_output = VtkOutput( str( vtu_file ) ) + write_mesh( simple_unstructured_mesh, vtk_output, can_overwrite=True ) + + # Copy to unknown extension + unknown_file.write_bytes( vtu_file.read_bytes() ) + + # Should still be able to read it + read_mesh_result = read_mesh( str( unknown_file ) ) + + assert read_mesh_result is not None + assert isinstance( read_mesh_result, vtkUnstructuredGrid ) + + def test_read_invalid_file_content( self, tmp_path ): + """Test that reading invalid file content raises ValueError.""" + invalid_file = tmp_path / "invalid.vtu" + invalid_file.write_text( "This is not a valid VTU file" ) + + with pytest.raises( ValueError, match="Could not find a suitable reader" ): + read_mesh( str( invalid_file ) ) + + +class TestReadUnstructuredGrid: + """Test class for read_unstructured_grid functionality.""" + + def test_read_unstructured_grid_success( self, simple_unstructured_mesh, tmp_path ): + """Test successfully reading an unstructured grid.""" + output_file = tmp_path / "test_ug.vtu" + vtk_output = VtkOutput( str( output_file ) ) + + # Write unstructured grid + write_mesh( simple_unstructured_mesh, vtk_output, can_overwrite=True ) + + # Read back as unstructured grid + result = read_unstructured_grid( str( output_file ) ) + + assert isinstance( result, vtkUnstructuredGrid ) + assert result.GetNumberOfPoints() == simple_unstructured_mesh.GetNumberOfPoints() + assert result.GetNumberOfCells() == simple_unstructured_mesh.GetNumberOfCells() + + def test_read_unstructured_grid_wrong_type( self, structured_mesh, tmp_path ): + """Test that reading non-unstructured grid raises TypeError.""" + output_file = tmp_path / "test_sg.vts" + vtk_output = VtkOutput( str( output_file ) ) + + # Write structured grid + write_mesh( structured_mesh, vtk_output, can_overwrite=True ) + + # Try to read as unstructured grid - should fail + with pytest.raises( TypeError, match="not the expected vtkUnstructuredGrid" ): + read_unstructured_grid( str( output_file ) ) + + def test_read_unstructured_grid_nonexistent( self ): + """Test that reading nonexistent file raises FileNotFoundError.""" + with pytest.raises( FileNotFoundError, match="does not exist" ): + read_unstructured_grid( "nonexistent.vtu" ) + + +class TestRoundTripReadWrite: + """Test class for round-trip read/write operations.""" + + def test_vtu_round_trip_binary( self, simple_unstructured_mesh, tmp_path ): + """Test round-trip write and read for VTU binary format.""" + output_file = tmp_path / "roundtrip_binary.vtu" + vtk_output = VtkOutput( str( output_file ), is_data_mode_binary=True ) + + # Write + write_result = write_mesh( simple_unstructured_mesh, vtk_output, can_overwrite=True ) + assert write_result == 1 + + # Read back + read_result = read_unstructured_grid( str( output_file ) ) + + # Compare + assert read_result.GetNumberOfPoints() == simple_unstructured_mesh.GetNumberOfPoints() + assert read_result.GetNumberOfCells() == simple_unstructured_mesh.GetNumberOfCells() + + # Check point coordinates are preserved + for i in range( read_result.GetNumberOfPoints() ): + orig_point = simple_unstructured_mesh.GetPoint( i ) + read_point = read_result.GetPoint( i ) + np.testing.assert_array_almost_equal( orig_point, read_point, decimal=6 ) + + def test_vtu_round_trip_ascii( self, simple_unstructured_mesh, tmp_path ): + """Test round-trip write and read for VTU ASCII format.""" + output_file = tmp_path / "roundtrip_ascii.vtu" + vtk_output = VtkOutput( str( output_file ), is_data_mode_binary=False ) + + # Write + write_result = write_mesh( simple_unstructured_mesh, vtk_output, can_overwrite=True ) + assert write_result == 1 + + # Read back + read_result = read_unstructured_grid( str( output_file ) ) + + # Compare + assert read_result.GetNumberOfPoints() == simple_unstructured_mesh.GetNumberOfPoints() + assert read_result.GetNumberOfCells() == simple_unstructured_mesh.GetNumberOfCells() + + def test_vtk_round_trip( self, simple_unstructured_mesh, tmp_path ): + """Test round-trip write and read for VTK legacy format.""" + output_file = tmp_path / "roundtrip.vtk" + vtk_output = VtkOutput( str( output_file ) ) + + # Write + write_result = write_mesh( simple_unstructured_mesh, vtk_output, can_overwrite=True ) + assert write_result == 1 + + # Read back + read_result = read_unstructured_grid( str( output_file ) ) + + # Compare + assert read_result.GetNumberOfPoints() == simple_unstructured_mesh.GetNumberOfPoints() + assert read_result.GetNumberOfCells() == simple_unstructured_mesh.GetNumberOfCells() + + def test_vts_round_trip( self, structured_mesh, tmp_path ): + """Test round-trip write and read for VTS format.""" + output_file = tmp_path / "roundtrip.vts" + vtk_output = VtkOutput( str( output_file ) ) + + # Write + write_result = write_mesh( structured_mesh, vtk_output, can_overwrite=True ) + assert write_result == 1 + + # Read back + read_result = read_mesh( str( output_file ) ) + + # Compare + assert isinstance( read_result, vtkStructuredGrid ) + assert read_result.GetNumberOfPoints() == structured_mesh.GetNumberOfPoints() + + +class TestEdgeCases: + """Test class for edge cases and error conditions.""" + + def test_empty_mesh_write( self, tmp_path ): + """Test writing an empty mesh.""" + empty_mesh = vtkUnstructuredGrid() + output_file = tmp_path / "empty.vtu" + vtk_output = VtkOutput( str( output_file ) ) + + result = write_mesh( empty_mesh, vtk_output, can_overwrite=True ) + assert result == 1 + assert output_file.exists() + + def test_empty_mesh_round_trip( self, tmp_path ): + """Test round-trip with empty mesh.""" + empty_mesh = vtkUnstructuredGrid() + output_file = tmp_path / "empty_roundtrip.vtu" + vtk_output = VtkOutput( str( output_file ) ) + + # Write + write_result = write_mesh( empty_mesh, vtk_output, can_overwrite=True ) + assert write_result == 1 + + # Read back + read_result = read_unstructured_grid( str( output_file ) ) + assert read_result.GetNumberOfPoints() == 0 + assert read_result.GetNumberOfCells() == 0 + + def test_large_path_names( self, simple_unstructured_mesh, tmp_path ): + """Test handling of long file paths.""" + # Create a deep directory structure + deep_dir = tmp_path + for i in range( 5 ): + deep_dir = deep_dir / f"very_long_directory_name_level_{i}" + deep_dir.mkdir( parents=True ) + + output_file = deep_dir / "mesh_with_very_long_filename_that_should_still_work.vtu" + vtk_output = VtkOutput( str( output_file ) ) + + # Should work fine + result = write_mesh( simple_unstructured_mesh, vtk_output, can_overwrite=True ) + assert result == 1 + assert output_file.exists() + + # And read back + read_result = read_unstructured_grid( str( output_file ) ) + assert read_result.GetNumberOfPoints() == simple_unstructured_mesh.GetNumberOfPoints()