diff --git a/src/mdio/converters/segy.py b/src/mdio/converters/segy.py index 425afb16c..3017473ac 100644 --- a/src/mdio/converters/segy.py +++ b/src/mdio/converters/segy.py @@ -482,6 +482,21 @@ def determine_target_size(var_type: str) -> int: ds.variables[index].metadata.chunk_grid = chunk_grid +def _validate_spec_in_template(segy_spec: SegySpec, mdio_template: AbstractDatasetTemplate) -> None: + """Validate that the SegySpec has all required fields in the MDIO template.""" + header_fields = {field.name for field in segy_spec.trace.header.fields} + + required_fields = set(mdio_template._dim_names[:-1]) | set(mdio_template._coord_names) + missing_fields = required_fields - header_fields + + if missing_fields: + err = ( + f"Required fields {sorted(missing_fields)} for template {mdio_template.name} " + f"not found in the provided segy_spec" + ) + raise ValueError(err) + + def segy_to_mdio( # noqa PLR0913 segy_spec: SegySpec, mdio_template: AbstractDatasetTemplate, @@ -507,6 +522,8 @@ def segy_to_mdio( # noqa PLR0913 Raises: FileExistsError: If the output location already exists and overwrite is False. """ + _validate_spec_in_template(segy_spec, mdio_template) + input_path = _normalize_path(input_path) output_path = _normalize_path(output_path) diff --git a/tests/unit/test_segy_spec_validation.py b/tests/unit/test_segy_spec_validation.py new file mode 100644 index 000000000..f8d65fcd9 --- /dev/null +++ b/tests/unit/test_segy_spec_validation.py @@ -0,0 +1,59 @@ +"""Tests for SEG-Y spec validation against MDIO templates.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from segy.schema import HeaderField +from segy.standards import get_segy_standard + +from mdio.converters.segy import _validate_spec_in_template + + +class TestValidateSpecInTemplate: + """Test cases for _validate_spec_in_template function.""" + + def test_validation_passes_with_all_required_fields(self) -> None: + """Test that validation passes when all required fields are present.""" + template = MagicMock() + template._dim_names = ("inline", "crossline", "time") + template._coord_names = ("cdp_x", "cdp_y") + + # SegySpec with all required fields + spec = get_segy_standard(1.0) + header_fields = [ + HeaderField(name="inline", byte=189, format="int32"), + HeaderField(name="crossline", byte=193, format="int32"), + HeaderField(name="cdp_x", byte=181, format="int32"), + HeaderField(name="cdp_y", byte=185, format="int32"), + ] + segy_spec = spec.customize(trace_header_fields=header_fields) + + # Should not raise any exception + _validate_spec_in_template(segy_spec, template) + + def test_validation_fails_with_missing_fields(self) -> None: + """Test that validation fails when required fields are missing.""" + # Template requiring custom fields not in standard spec + template = MagicMock() + template.name = "CustomTemplate" + template._dim_names = ("custom_dim1", "custom_dim2", "time") + template._coord_names = ("custom_coord_x", "custom_coord_y") + + # SegySpec with only one of the required custom fields + spec = get_segy_standard(1.0) + header_fields = [ + HeaderField(name="custom_dim1", byte=189, format="int32"), + ] + segy_spec = spec.customize(trace_header_fields=header_fields) + + # Should raise ValueError listing the missing fields + with pytest.raises(ValueError, match=r"Required fields.*not found in.*segy_spec") as exc_info: + _validate_spec_in_template(segy_spec, template) + + error_message = str(exc_info.value) + assert "custom_dim2" in error_message + assert "custom_coord_x" in error_message + assert "custom_coord_y" in error_message + assert "CustomTemplate" in error_message