diff --git a/src/mdio/builder/template_registry.py b/src/mdio/builder/template_registry.py
index 60b7f062..d90d5cfd 100644
--- a/src/mdio/builder/template_registry.py
+++ b/src/mdio/builder/template_registry.py
@@ -26,6 +26,7 @@
 from mdio.builder.templates.seismic_3d_cdp import Seismic3DCdpGathersTemplate
 from mdio.builder.templates.seismic_3d_coca import Seismic3DCocaGathersTemplate
 from mdio.builder.templates.seismic_3d_poststack import Seismic3DPostStackTemplate
+from mdio.builder.templates.seismic_3d_streamer_field import Seismic3DStreamerFieldRecordsTemplate
 from mdio.builder.templates.seismic_3d_streamer_shot import Seismic3DStreamerShotGathersTemplate
 
 if TYPE_CHECKING:
@@ -135,6 +136,7 @@ def _register_default_templates(self) -> None:
         # Field (shot) data
         self.register(Seismic2DStreamerShotGathersTemplate())
         self.register(Seismic3DStreamerShotGathersTemplate())
+        self.register(Seismic3DStreamerFieldRecordsTemplate())
 
     def get(self, template_name: str) -> AbstractDatasetTemplate:
         """Get an instance of a template from the registry by its name.
diff --git a/src/mdio/builder/templates/base.py b/src/mdio/builder/templates/base.py
index 879f1fd6..50f775dc 100644
--- a/src/mdio/builder/templates/base.py
+++ b/src/mdio/builder/templates/base.py
@@ -39,6 +39,7 @@ def __init__(self, data_domain: SeismicDataDomain) -> None:
             raise ValueError(msg)
 
         self._dim_names: tuple[str, ...] = ()
+        self._calculated_dims: tuple[str, ...] = ()
         self._physical_coord_names: tuple[str, ...] = ()
         self._logical_coord_names: tuple[str, ...] = ()
         self._var_chunk_shape: tuple[int, ...] = ()
@@ -130,6 +131,11 @@ def dimension_names(self) -> tuple[str, ...]:
         """Returns the names of the dimensions."""
         return copy.deepcopy(self._dim_names)
 
+    @property
+    def calculated_dimension_names(self) -> tuple[str, ...]:
+        """Returns the names of the calculated (derived) dimensions."""
+        return copy.deepcopy(self._calculated_dims)
+
     @property
     def physical_coordinate_names(self) -> tuple[str, ...]:
         """Returns the names of the physical (world) coordinates."""
diff --git a/src/mdio/builder/templates/seismic_3d_streamer_field.py b/src/mdio/builder/templates/seismic_3d_streamer_field.py
new file mode 100644
index 00000000..c550ca7d
--- /dev/null
+++ b/src/mdio/builder/templates/seismic_3d_streamer_field.py
@@ -0,0 +1,104 @@
+"""Seismic3DStreamerFieldRecordsTemplate MDIO v1 dataset templates."""
+
+from typing import Any
+
+from mdio.builder.schemas.dtype import ScalarType
+from mdio.builder.schemas.v1.variable import CoordinateMetadata
+from mdio.builder.templates.base import AbstractDatasetTemplate
+from mdio.builder.templates.types import SeismicDataDomain
+
+
+class Seismic3DStreamerFieldRecordsTemplate(AbstractDatasetTemplate):
+    """Seismic 3D streamer shot field records template.
+
+    A generalized template for streamer field records that is optimized for:
+    - Common-shot access
+    - Common-channel access
+
+    It can also store all the sail lines of a survey in one MDIO if needed.
+
+    Args:
+        data_domain: The data domain of the dataset. Defaults to "time".
+    """
+
+    def __init__(self, data_domain: SeismicDataDomain = "time") -> None:
+        super().__init__(data_domain=data_domain)
+
+        self._spatial_dim_names = ("sail_line", "gun", "shot_index", "cable", "channel")
+        self._calculated_dims = ("shot_index",)
+        self._dim_names = (*self._spatial_dim_names, self._data_domain)
+        self._physical_coord_names = ("source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y")
+        self._logical_coord_names = ("shot_point", "orig_field_record_num")  # ffid
+        self._var_chunk_shape = (1, 1, 16, 1, 32, 1024)
+
+    @property
+    def _name(self) -> str:
+        return "StreamerFieldRecords3D"
+
+    def _load_dataset_attributes(self) -> dict[str, Any]:
+        return {"surveyDimensionality": "3D", "ensembleType": "common_source_by_gun"}
+
+    def _add_coordinates(self) -> None:
+        # Add dimension coordinates
+        # EXCLUDE: `shot_index` since it's 0-N
+        self._builder.add_coordinate(
+            "sail_line",
+            dimensions=("sail_line",),
+            data_type=ScalarType.UINT32,
+        )
+        self._builder.add_coordinate(
+            "gun",
+            dimensions=("gun",),
+            data_type=ScalarType.UINT8,
+        )
+        self._builder.add_coordinate(
+            "cable",
+            dimensions=("cable",),
+            data_type=ScalarType.UINT8,
+        )
+        self._builder.add_coordinate(
+            "channel",
+            dimensions=("channel",),
+            data_type=ScalarType.UINT16,
+        )
+        self._builder.add_coordinate(
+            self._data_domain,
+            dimensions=(self._data_domain,),
+            data_type=ScalarType.INT32,
+        )
+
+        # Add non-dimension coordinates
+        self._builder.add_coordinate(
+            "orig_field_record_num",
+            dimensions=("sail_line", "gun", "shot_index"),
+            data_type=ScalarType.UINT32,
+        )
+        self._builder.add_coordinate(
+            "shot_point",
+            dimensions=("sail_line", "gun", "shot_index"),
+            data_type=ScalarType.UINT32,
+        )
+        self._builder.add_coordinate(
+            "source_coord_x",
+            dimensions=("sail_line", "gun", "shot_index"),
+            data_type=ScalarType.FLOAT64,
+            metadata=CoordinateMetadata(units_v1=self.get_unit_by_key("source_coord_x")),
+        )
+        self._builder.add_coordinate(
+            "source_coord_y",
+            dimensions=("sail_line", "gun", "shot_index"),
+            data_type=ScalarType.FLOAT64,
+            metadata=CoordinateMetadata(units_v1=self.get_unit_by_key("source_coord_y")),
+        )
+        self._builder.add_coordinate(
+            "group_coord_x",
+            dimensions=("sail_line", "gun", "shot_index", "cable", "channel"),
+            data_type=ScalarType.FLOAT64,
+            metadata=CoordinateMetadata(units_v1=self.get_unit_by_key("group_coord_x")),
+        )
+        self._builder.add_coordinate(
+            "group_coord_y",
+            dimensions=("sail_line", "gun", "shot_index", "cable", "channel"),
+            data_type=ScalarType.FLOAT64,
+            metadata=CoordinateMetadata(units_v1=self.get_unit_by_key("group_coord_y")),
+        )
diff --git a/src/mdio/converters/segy.py b/src/mdio/converters/segy.py
index 708d4aed..e2fd6b35 100644
--- a/src/mdio/converters/segy.py
+++ b/src/mdio/converters/segy.py
@@ -345,10 +345,10 @@ def _populate_coordinates(
     """
     drop_vars_delayed = []
     # Populate the dimension coordinate variables (1-D arrays)
-    dataset, vars_to_drop_later = populate_dim_coordinates(dataset, grid, drop_vars_delayed=drop_vars_delayed)
+    dataset, drop_vars_delayed = populate_dim_coordinates(dataset, grid, drop_vars_delayed=drop_vars_delayed)
 
     # Populate the non-dimension coordinate variables (N-dim arrays)
-    dataset, vars_to_drop_later = populate_non_dim_coordinates(
+    dataset, drop_vars_delayed = populate_non_dim_coordinates(
         dataset,
         grid,
         coordinates=coords,
@@ -488,6 +488,7 @@ def _validate_spec_in_template(segy_spec: SegySpec, mdio_template: AbstractDatas
     header_fields = {field.name for field in segy_spec.trace.header.fields}
     required_fields = 
set(mdio_template.spatial_dimension_names) | set(mdio_template.coordinate_names)
+    required_fields = required_fields - set(mdio_template.calculated_dimension_names)  # remove computed fields
     required_fields = required_fields | {"coordinate_scalar"}  # ensure coordinate scalar is always present
 
     missing_fields = required_fields - header_fields
@@ -592,6 +593,9 @@ def segy_to_mdio(  # noqa PLR0913
     to_mdio(xr_dataset, output_path=output_path, mode="w", compute=False)
 
     # This will write the non-dimension coordinates and trace mask
+    # We also remove dimensions that don't have associated coordinates
+    unindexed_dims = [d for d in xr_dataset.dims if d not in xr_dataset.coords]
+    drop_vars_delayed = [v for v in drop_vars_delayed if v not in unindexed_dims]
     meta_ds = xr_dataset[drop_vars_delayed + ["trace_mask"]]
     to_mdio(meta_ds, output_path=output_path, mode="r+", compute=True)
diff --git a/src/mdio/segy/geometry.py b/src/mdio/segy/geometry.py
index ed41e42e..bdb0b81b 100644
--- a/src/mdio/segy/geometry.py
+++ b/src/mdio/segy/geometry.py
@@ -149,7 +149,7 @@ def analyze_streamer_headers(
     return unique_cables, cable_chan_min, cable_chan_max, geom_type
 
 
-def analyze_shotlines_for_guns(
+def analyze_saillines_for_guns(
     index_headers: HeaderArray,
 ) -> tuple[NDArray, dict[str, list], ShotGunGeometryType]:
     """Check input headers for SEG-Y input to help determine geometry of shots and guns.
@@ -161,27 +161,27 @@ def analyze_shotlines_for_guns(
         index_headers: numpy array with index headers
 
     Returns:
-        tuple of unique_shot_lines, unique_guns_in_shot_line, geom_type
+        tuple of unique_sail_lines, unique_guns_in_sail_line, geom_type
     """
     # Find unique cable ids
-    unique_shot_lines = np.sort(np.unique(index_headers["shot_line"]))
+    unique_sail_lines = np.sort(np.unique(index_headers["sail_line"]))
     unique_guns = np.sort(np.unique(index_headers["gun"]))
 
-    logger.info("unique_shot_lines: %s", unique_shot_lines)
+    logger.info("unique_sail_lines: %s", unique_sail_lines)
     logger.info("unique_guns: %s", unique_guns)
 
     # Find channel min and max values for each cable
-    unique_guns_in_shot_line = {}
+    unique_guns_in_sail_line = {}
     geom_type = ShotGunGeometryType.B
 
     # Check shot numbers are still unique if div/num_guns
-    for shot_line in unique_shot_lines:
-        shot_line_mask = index_headers["shot_line"] == shot_line
-        shot_current_sl = index_headers["shot_point"][shot_line_mask]
-        gun_current_sl = index_headers["gun"][shot_line_mask]
+    for sail_line in unique_sail_lines:
+        sail_line_mask = index_headers["sail_line"] == sail_line
+        shot_current_sl = index_headers["shot_point"][sail_line_mask]
+        gun_current_sl = index_headers["gun"][sail_line_mask]
         unique_guns_sl = np.sort(np.unique(gun_current_sl))
         num_guns_sl = unique_guns_sl.shape[0]
-        unique_guns_in_shot_line[str(shot_line)] = list(unique_guns_sl)
+        unique_guns_in_sail_line[str(sail_line)] = list(unique_guns_sl)
 
         for gun in unique_guns_sl:
             gun_mask = gun_current_sl == gun
@@ -190,10 +190,10 @@
             mod_shots = np.floor(shots_current_sl_gun / num_guns_sl)
             if len(np.unique(mod_shots)) != num_shots_sl:
                 msg = "Shot line %s has %s when using div by %s %s has %s unique mod shots."
- logger.info(msg, shot_line, num_shots_sl, num_guns_sl, np.unique(mod_shots)) + logger.info(msg, sail_line, num_shots_sl, num_guns_sl, np.unique(mod_shots)) geom_type = ShotGunGeometryType.A - return unique_shot_lines, unique_guns_in_shot_line, geom_type - return unique_shot_lines, unique_guns_in_shot_line, geom_type + return unique_sail_lines, unique_guns_in_sail_line, geom_type + return unique_sail_lines, unique_guns_in_sail_line, geom_type def create_counter( @@ -459,7 +459,7 @@ def transform( class AutoShotWrap(GridOverrideCommand): """Automatically determine ShotGun acquisition type.""" - required_keys = {"shot_line", "gun", "shot_point", "cable", "channel"} + required_keys = {"sail_line", "gun", "shot_point", "cable", "channel"} required_parameters = None def validate(self, index_headers: HeaderArray, grid_overrides: dict[str, bool | int]) -> None: @@ -475,24 +475,28 @@ def transform( """Perform the grid transform.""" self.validate(index_headers, grid_overrides) - result = analyze_shotlines_for_guns(index_headers) - unique_shot_lines, unique_guns_in_shot_line, geom_type = result + result = analyze_saillines_for_guns(index_headers) + unique_sail_lines, unique_guns_in_sail_line, geom_type = result logger.info("Ingesting dataset as shot type: %s", geom_type.name) max_num_guns = 1 - for shot_line in unique_shot_lines: - logger.info("shot_line: %s has guns: %s", shot_line, unique_guns_in_shot_line[str(shot_line)]) - num_guns = len(unique_guns_in_shot_line[str(shot_line)]) + for sail_line in unique_sail_lines: + logger.info("sail_line: %s has guns: %s", sail_line, unique_guns_in_sail_line[str(sail_line)]) + num_guns = len(unique_guns_in_sail_line[str(sail_line)]) max_num_guns = max(num_guns, max_num_guns) # This might be slow and potentially could be improved with a rewrite # to prevent so many lookups if geom_type == ShotGunGeometryType.B: - for shot_line in unique_shot_lines: - shot_line_idxs = np.where(index_headers["shot_line"][:] == shot_line) - index_headers["shot_point"][shot_line_idxs] = np.floor( - index_headers["shot_point"][shot_line_idxs] / max_num_guns + shot_index = np.empty(len(index_headers), dtype="uint32") + index_headers = rfn.append_fields(index_headers.base, "shot_index", shot_index) + for sail_line in unique_sail_lines: + sail_line_idxs = np.where(index_headers["sail_line"][:] == sail_line) + index_headers["shot_index"][sail_line_idxs] = np.floor( + index_headers["shot_point"][sail_line_idxs] / max_num_guns ) + # Make shot index zero-based PER sail line + index_headers["shot_index"][sail_line_idxs] -= index_headers["shot_index"][sail_line_idxs].min() return index_headers diff --git a/src/mdio/segy/utilities.py b/src/mdio/segy/utilities.py index f6d76bc6..195a02c8 100644 --- a/src/mdio/segy/utilities.py +++ b/src/mdio/segy/utilities.py @@ -51,6 +51,9 @@ def get_grid_plan( # noqa: C901, PLR0913 Returns: All index dimensions and chunksize or dimensions and chunksize together with header values. + + Raises: + ValueError: If computed fields are not found after grid overrides. 
""" if grid_overrides is None: grid_overrides = {} @@ -58,6 +61,11 @@ def get_grid_plan( # noqa: C901, PLR0913 # Keep only dimension and non-dimension coordinates excluding the vertical axis horizontal_dimensions = template.spatial_dimension_names horizontal_coordinates = horizontal_dimensions + template.coordinate_names + + # Remove any to be computed fields + computed_fields = set(template.calculated_dimension_names) + horizontal_coordinates = tuple(set(horizontal_coordinates) - computed_fields) + headers_subset = parse_headers( segy_file_kwargs=segy_file_kwargs, num_traces=segy_file_info.num_traces, @@ -73,6 +81,13 @@ def get_grid_plan( # noqa: C901, PLR0913 grid_overrides=grid_overrides, ) + if len(computed_fields) > 0 and not computed_fields.issubset(headers_subset.dtype.names): + err = ( + f"Required computed fields {sorted(computed_fields)} for template {template.name} " + f"not found after grid overrides. Please ensure correct overrides are applied." + ) + raise ValueError(err) + dimensions = [] for dim_name in horizontal_dimensions: dim_unique = np.unique(headers_subset[dim_name]) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index d91a8383..5fe7aed1 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -22,13 +22,13 @@ def get_segy_mock_4d_spec() -> SegySpec: """Create a mock 4D SEG-Y specification.""" trace_header_fields = [ - HeaderField(name="field_rec_no", byte=9, format="int32"), + HeaderField(name="orig_field_record_num", byte=9, format="int32"), HeaderField(name="channel", byte=13, format="int32"), HeaderField(name="shot_point", byte=17, format="int32"), HeaderField(name="offset", byte=37, format="int32"), HeaderField(name="samples_per_trace", byte=115, format="int16"), HeaderField(name="sample_interval", byte=117, format="int16"), - HeaderField(name="shot_line", byte=133, format="int16"), + HeaderField(name="sail_line", byte=133, format="int16"), HeaderField(name="cable", byte=137, format="int16"), HeaderField(name="gun", byte=171, format="int16"), HeaderField(name="coordinate_scalar", byte=71, format="int16"), @@ -111,15 +111,15 @@ def create_segy_mock_4d( # noqa: PLR0913 gun = gun_headers[trc_idx] cable = cable_headers[trc_idx] channel = channel_headers[trc_idx] - shot_line = 1 + sail_line = 1 offset = 0 if index_receivers is False: - channel, gun, shot_line = 0, 0, 0 + channel, gun, sail_line = 0, 0, 0 # Assign dimension coordinate fields with calculated mock data - header_fields = ["field_rec_no", "channel", "shot_point", "offset", "shot_line", "cable", "gun"] - headers[header_fields][trc_idx] = (shot, channel, shot, offset, shot_line, cable, gun) + header_fields = ["orig_field_record_num", "channel", "shot_point", "offset", "sail_line", "cable", "gun"] + headers[header_fields][trc_idx] = (shot, channel, shot, offset, sail_line, cable, gun) # Assign coordinate fields with mock data x = start_x + step_x * trc_shot_idx @@ -144,7 +144,7 @@ def segy_mock_4d_shots(fake_segy_tmp: Path) -> dict[StreamerShotGeometryType, Pa num_samples = 25 shots = [2, 3, 5, 6, 7, 8, 9] guns = [1, 2] - cables = [0, 101, 201, 301] + cables = [0, 3, 5, 7] receivers_per_cable = [1, 5, 7, 5] segy_paths = {} diff --git a/tests/integration/test_import_streamer_grid_overrides.py b/tests/integration/test_import_streamer_grid_overrides.py index 7ff89fb0..c90d8c8c 100644 --- a/tests/integration/test_import_streamer_grid_overrides.py +++ b/tests/integration/test_import_streamer_grid_overrides.py @@ -59,7 +59,7 @@ def test_import_4d_segy( # noqa: 
PLR0913 # Expected values num_samples = 25 shots = [2, 3, 5, 6, 7, 8, 9] - cables = [0, 101, 201, 301] + cables = [0, 3, 5, 7] receivers_per_cable = [1, 5, 7, 5] ds = open_mdio(zarr_tmp) @@ -106,7 +106,7 @@ def test_import_4d_segy( # noqa: PLR0913 # Expected values num_samples = 25 shots = [2, 3, 5, 6, 7, 8, 9] - cables = [0, 101, 201, 301] + cables = [0, 3, 5, 7] receivers_per_cable = [1, 5, 7, 5] ds = open_mdio(zarr_tmp) @@ -156,12 +156,9 @@ def test_import_4d_segy( # noqa: PLR0913 assert "This grid is very sparse and most likely user error with indexing." in str(execinfo.value) -# TODO(Altay): Finish implementing these grid overrides. -# https://github.com/TGSAI/mdio-python/issues/612 -@pytest.mark.skip(reason="AutoShotWrap requires a template that is not implemented yet.") -@pytest.mark.parametrize("grid_override", [{"AutoChannelWrap": True}, {"AutoShotWrap": True}, None]) +@pytest.mark.parametrize("grid_override", [{"AutoChannelWrap": True, "AutoShotWrap": True}, {"AutoShotWrap": True}]) @pytest.mark.parametrize("chan_header_type", [StreamerShotGeometryType.A, StreamerShotGeometryType.B]) -class TestImport6D: # pragma: no cover - tests is skipped +class TestImport6D: """Test for 6D segy import with grid overrides.""" def test_import_6d_segy( # noqa: PLR0913 @@ -177,7 +174,7 @@ def test_import_6d_segy( # noqa: PLR0913 segy_to_mdio( segy_spec=segy_spec, - mdio_template=TemplateRegistry().get("XYZ"), # Placeholder for the template + mdio_template=TemplateRegistry().get("StreamerFieldRecords3D"), input_path=segy_path, output_path=zarr_tmp, overwrite=True, @@ -186,26 +183,33 @@ def test_import_6d_segy( # noqa: PLR0913 # Expected values num_samples = 25 - shots = [2, 3, 5, 6, 7, 8, 9] # original shot list - if grid_override is not None and "AutoShotWrap" in grid_override: - shots_new = [int(shot / 2) for shot in shots] # Updated shot index when ingesting with 2 guns - shots_set = set(shots_new) # remove duplicates - shots = list(shots_set) # Unique shot points for 6D indexed with gun - cables = [0, 101, 201, 301] + shot_points = [2, 3, 5, 6, 7, 8, 9] # original shot list, missing shot ~ 4. 
+
+        shot_index = [int(sp / 2) for sp in shot_points]  # Updated shot index when ingesting with 2 guns
+        shot_index = np.unique(shot_index) - 1  # Expected zero-based shot_index values for 6D indexed with gun
+        cables = [0, 3, 5, 7]
         guns = [1, 2]
         receivers_per_cable = [1, 5, 7, 5]
 
         ds = open_mdio(zarr_tmp)
 
         xrt.assert_duckarray_equal(ds["gun"], guns)
-        xrt.assert_duckarray_equal(ds["shot_point"], shots)
+        xrt.assert_duckarray_equal(ds["shot_index"], shot_index)
         xrt.assert_duckarray_equal(ds["cable"], cables)
 
-        if chan_header_type == StreamerShotGeometryType.B and grid_override is None:
+        if chan_header_type == StreamerShotGeometryType.B and "AutoChannelWrap" not in grid_override:
             expected = list(range(1, np.sum(receivers_per_cable) + 1))
         else:
             expected = list(range(1, np.amax(receivers_per_cable) + 1))
         xrt.assert_duckarray_equal(ds["channel"], expected)
 
+        expected_shot_points = [
+            [
+                [2, 4294967295, 6, 8],  # gun = 1
+                [3, 5, 7, 9],  # gun = 2
+            ],  # sail_line = 1
+        ]
+        xrt.assert_duckarray_equal(ds["shot_point"], expected_shot_points)
+
         times_expected = list(range(0, num_samples, 1))
         xrt.assert_duckarray_equal(ds["time"], times_expected)
diff --git a/tests/unit/v1/templates/test_seismic_3d_streamer_field.py b/tests/unit/v1/templates/test_seismic_3d_streamer_field.py
new file mode 100644
index 00000000..9ea83793
--- /dev/null
+++ b/tests/unit/v1/templates/test_seismic_3d_streamer_field.py
@@ -0,0 +1,164 @@
+"""Unit tests for Seismic3DStreamerFieldRecordsTemplate."""
+
+from tests.unit.v1.helpers import validate_variable
+
+from mdio.builder.schemas.chunk_grid import RegularChunkGrid
+from mdio.builder.schemas.compressors import Blosc
+from mdio.builder.schemas.compressors import BloscCname
+from mdio.builder.schemas.dtype import ScalarType
+from mdio.builder.schemas.dtype import StructuredType
+from mdio.builder.schemas.v1.dataset import Dataset
+from mdio.builder.schemas.v1.units import LengthUnitEnum
+from mdio.builder.schemas.v1.units import LengthUnitModel
+from mdio.builder.schemas.v1.units import TimeUnitEnum
+from mdio.builder.schemas.v1.units import TimeUnitModel
+from mdio.builder.templates.seismic_3d_streamer_field import Seismic3DStreamerFieldRecordsTemplate
+
+UNITS_METER = LengthUnitModel(length=LengthUnitEnum.METER)
+UNITS_SECOND = TimeUnitModel(time=TimeUnitEnum.SECOND)
+
+
+DATASET_SIZE_MAP = {"sail_line": 1, "gun": 2, "shot_index": 128, "cable": 256, "channel": 12, "time": 1024}
+DATASET_DTYPE_MAP = {"sail_line": "uint32", "gun": "uint8", "cable": "uint8", "channel": "uint16", "time": "int32"}
+EXPECTED_COORDINATES = [
+    "shot_point",
+    "orig_field_record_num",
+    "source_coord_x",
+    "source_coord_y",
+    "group_coord_x",
+    "group_coord_y",
+]
+
+
+def _validate_coordinates_headers_trace_mask(dataset: Dataset, headers: StructuredType, domain: str) -> None:
+    """Validate the coordinate, headers, and trace_mask variables in the dataset."""
+    # Verify variables
+    # 5 dim coords + 6 non-dim coords + 1 data + 1 trace mask + 1 headers = 14 variables
+    assert len(dataset.variables) == 14
+
+    # Verify trace headers
+    validate_variable(
+        dataset,
+        name="headers",
+        dims=[(k, v) for k, v in DATASET_SIZE_MAP.items() if k != domain],
+        coords=EXPECTED_COORDINATES,
+        dtype=headers,
+    )
+
+    validate_variable(
+        dataset,
+        name="trace_mask",
+        dims=[(k, v) for k, v in DATASET_SIZE_MAP.items() if k != domain],
+        coords=EXPECTED_COORDINATES,
+        dtype=ScalarType.BOOL,
+    )
+
+    # Verify dimension coordinate variables
+    for dim_name, dim_size in DATASET_SIZE_MAP.items():
+        if dim_name == "shot_index":
+            continue
+
+        
validate_variable( + dataset, + name=dim_name, + dims=[(dim_name, dim_size)], + coords=[dim_name], + dtype=ScalarType(DATASET_DTYPE_MAP[dim_name]), + ) + + # Verify non-dimension coordinate variables + validate_variable( + dataset, + name="orig_field_record_num", + dims=[(k, v) for k, v in DATASET_SIZE_MAP.items() if k in ["sail_line", "gun", "shot_index"]], + coords=["orig_field_record_num"], + dtype=ScalarType.UINT32, + ) + + validate_variable( + dataset, + name="shot_point", + dims=[(k, v) for k, v in DATASET_SIZE_MAP.items() if k in ["sail_line", "gun", "shot_index"]], + coords=["shot_point"], + dtype=ScalarType.UINT32, + ) + + # Verify coordinate variables with units + for coord_name in ["source_coord_x", "source_coord_y"]: + coord = validate_variable( + dataset, + name=coord_name, + dims=[(k, v) for k, v in DATASET_SIZE_MAP.items() if k in ["sail_line", "gun", "shot_index"]], + coords=[coord_name], + dtype=ScalarType.FLOAT64, + ) + assert coord.metadata.units_v1.length == LengthUnitEnum.METER + + for coord_name in ["group_coord_x", "group_coord_y"]: + coord = validate_variable( + dataset, + name=coord_name, + dims=[(k, v) for k, v in DATASET_SIZE_MAP.items() if k != domain], + coords=[coord_name], + dtype=ScalarType.FLOAT64, + ) + assert coord.metadata.units_v1.length == LengthUnitEnum.METER + + +class TestSeismic3DStreamerFieldRecordsTemplate: + """Unit tests for Seismic3DStreamerFieldRecordsTemplate.""" + + def test_configuration(self) -> None: + """Unit tests for Seismic3DStreamerFieldRecordsTemplate.""" + t = Seismic3DStreamerFieldRecordsTemplate(data_domain="time") + + # Template attributes + assert t.name == "StreamerFieldRecords3D" + assert t._dim_names == ("sail_line", "gun", "shot_index", "cable", "channel", "time") + assert t._physical_coord_names == ("source_coord_x", "source_coord_y", "group_coord_x", "group_coord_y") + assert t.full_chunk_shape == (1, 1, 16, 1, 32, 1024) + + # Variables instantiated when build_dataset() is called + assert t._builder is None + assert t._dim_sizes == () + + # Verify dataset attributes + attrs = t._load_dataset_attributes() + assert attrs == {"surveyDimensionality": "3D", "ensembleType": "common_source_by_gun"} + assert t.default_variable_name == "amplitude" + + def test_build_dataset(self, structured_headers: StructuredType) -> None: + """Unit tests for Seismic3DStreamerFieldRecordsTemplate build.""" + t = Seismic3DStreamerFieldRecordsTemplate(data_domain="time") + t.add_units({"source_coord_x": UNITS_METER, "source_coord_y": UNITS_METER}) # spatial domain units + t.add_units({"group_coord_x": UNITS_METER, "group_coord_y": UNITS_METER}) # spatial domain units + t.add_units({"time": UNITS_SECOND}) # data domain units + + dataset = t.build_dataset("Survey3D", sizes=(1, 2, 128, 256, 12, 1024), header_dtype=structured_headers) + + assert dataset.metadata.name == "Survey3D" + assert dataset.metadata.attributes["surveyDimensionality"] == "3D" + assert dataset.metadata.attributes["ensembleType"] == "common_source_by_gun" + + _validate_coordinates_headers_trace_mask(dataset, structured_headers, "time") + + # Verify seismic variable + seismic = validate_variable( + dataset, + name="amplitude", + dims=[("sail_line", 1), ("gun", 2), ("shot_index", 128), ("cable", 256), ("channel", 12), ("time", 1024)], + coords=[ + "shot_point", + "orig_field_record_num", + "source_coord_x", + "source_coord_y", + "group_coord_x", + "group_coord_y", + ], + dtype=ScalarType.FLOAT32, + ) + assert isinstance(seismic.compressor, Blosc) + assert seismic.compressor.cname == 
BloscCname.zstd + assert isinstance(seismic.metadata.chunk_grid, RegularChunkGrid) + assert seismic.metadata.chunk_grid.configuration.chunk_shape == (1, 1, 16, 1, 32, 1024) + assert seismic.metadata.stats_v1 is None diff --git a/tests/unit/v1/templates/test_template_registry.py b/tests/unit/v1/templates/test_template_registry.py index 743c1765..e0b641f0 100644 --- a/tests/unit/v1/templates/test_template_registry.py +++ b/tests/unit/v1/templates/test_template_registry.py @@ -33,6 +33,7 @@ "CocaGathers3DDepth", "StreamerShotGathers2D", "StreamerShotGathers3D", + "StreamerFieldRecords3D", ] @@ -239,7 +240,7 @@ def test_list_all_templates(self) -> None: registry.register(template2) templates = registry.list_all_templates() - assert len(templates) == 16 + 2 # 16 default + 2 custom + assert len(templates) == 17 + 2 # 17 default + 2 custom assert "Template_One" in templates assert "Template_Two" in templates @@ -249,7 +250,7 @@ def test_clear_templates(self) -> None: # Default templates are always installed templates = list_templates() - assert len(templates) == 16 + assert len(templates) == 17 # Add some templates template1 = MockDatasetTemplate("Template1") @@ -258,7 +259,7 @@ def test_clear_templates(self) -> None: registry.register(template1) registry.register(template2) - assert len(registry.list_all_templates()) == 16 + 2 # 16 default + 2 custom + assert len(registry.list_all_templates()) == 17 + 2 # 17 default + 2 custom # Clear all registry.clear() @@ -391,7 +392,7 @@ def test_list_templates_global(self) -> None: register_template(template2) templates = list_templates() - assert len(templates) == 18 # 16 default + 2 custom + assert len(templates) == 19 # 17 default + 2 custom assert "template1" in templates assert "template2" in templates @@ -433,8 +434,8 @@ def register_template_worker(template_id: int) -> None: # All registrations should succeed assert len(errors) == 0 assert len(results) == 10 - # Including 8 default templates - assert len(registry.list_all_templates()) == 26 # 16 default + 10 registered + # Including default templates + assert len(registry.list_all_templates()) == 27 # 17 default + 10 registered # Check all templates are registered for i in range(10):
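
A minimal usage sketch of the new "StreamerFieldRecords3D" template, not part of the patch itself. Assumptions: the `segy_spec` value and the file paths are hypothetical placeholders, and the `grid_overrides` keyword is taken from the grid-override integration tests above rather than shown verbatim in these hunks; the registry name, template properties, and the other converter keywords come from this diff.

from mdio.builder.template_registry import TemplateRegistry
from mdio.converters.segy import segy_to_mdio

template = TemplateRegistry().get("StreamerFieldRecords3D")
print(template.dimension_names)             # ('sail_line', 'gun', 'shot_index', 'cable', 'channel', 'time')
print(template.calculated_dimension_names)  # ('shot_index',) -- computed during ingestion, not read from headers

# Hypothetical: a SegySpec whose trace headers map sail_line, gun, shot_point, cable,
# channel, the source/group coordinate fields, and coordinate_scalar.
segy_spec = ...

segy_to_mdio(
    segy_spec=segy_spec,
    mdio_template=template,
    input_path="field_records.sgy",    # hypothetical path
    output_path="field_records.mdio",  # hypothetical path
    overwrite=True,
    grid_overrides={"AutoChannelWrap": True, "AutoShotWrap": True},
)

The AutoShotWrap override is what derives the zero-based shot_index values per sail line; without it, get_grid_plan raises a ValueError because the computed shot_index field never appears in the parsed headers.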