71 changes: 71 additions & 0 deletions src/mdio/segy/_disaster_recovery_wrapper.py
@@ -0,0 +1,71 @@
"""Consumer-side utility to get both raw and transformed header data with single filesystem read."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
import numpy as np
from segy.file import SegyFile
from segy.transforms import Transform, ByteSwapTransform, IbmFloatTransform
from numpy.typing import NDArray

def _reverse_single_transform(data: NDArray, transform: Transform, endianness: Endianness) -> NDArray:
"""Reverse a single transform operation."""
from segy.schema import Endianness
from segy.transforms import ByteSwapTransform
from segy.transforms import IbmFloatTransform

if isinstance(transform, ByteSwapTransform):
# Reverse the endianness conversion
if endianness == Endianness.LITTLE:
return data

reverse_transform = ByteSwapTransform(Endianness.BIG)
return reverse_transform.apply(data)

elif isinstance(transform, IbmFloatTransform): # TODO: This seems incorrect...
# Reverse IBM float conversion
reverse_direction = "to_ibm" if transform.direction == "to_ieee" else "to_ieee"
reverse_transform = IbmFloatTransform(reverse_direction, transform.keys)
return reverse_transform.apply(data)

else:
# For unknown transforms, return data unchanged
return data

def get_header_raw_and_transformed(
segy_file: SegyFile,
indices: int | list[int] | NDArray | slice,
do_reverse_transforms: bool = True
) -> tuple[NDArray | None, NDArray, NDArray]:
"""Get both raw and transformed header data.

Args:
segy_file: The SegyFile instance
indices: Which headers to retrieve
do_reverse_transforms: Whether to apply the reverse transform to get raw data

Returns:
Tuple of (raw_headers, transformed_headers, traces)
"""
traces = segy_file.trace[indices]
transformed_headers = traces.header

# Reverse transforms to get raw data
if do_reverse_transforms:
raw_headers = _reverse_transforms(transformed_headers, segy_file.header.transform_pipeline, segy_file.spec.endianness)
else:
raw_headers = None

return raw_headers, transformed_headers, traces

def _reverse_transforms(transformed_data: NDArray, transform_pipeline, endianness: Endianness) -> NDArray:
"""Reverse the transform pipeline to get raw data."""
raw_data = transformed_data.copy() if hasattr(transformed_data, 'copy') else transformed_data

# Apply transforms in reverse order
for transform in reversed(transform_pipeline.transforms):
raw_data = _reverse_single_transform(raw_data, transform, endianness)

return raw_data
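
A minimal usage sketch may help reviewers; this is illustrative only (the file path and index range are hypothetical, and it assumes the segy library's public SegyFile entry point):

# Hypothetical usage sketch -- not part of this PR's changes.
from segy import SegyFile

from mdio.segy._disaster_recovery_wrapper import get_header_raw_and_transformed

segy_file = SegyFile("filt_mig.sgy")  # illustrative file path

# Fetch headers for traces 0-9: `transformed` is what the segy library normally
# exposes; `raw` has the transform pipeline reversed to match the on-disk bytes.
raw, transformed, traces = get_header_raw_and_transformed(segy_file, slice(0, 10))

if raw is not None:  # None when do_reverse_transforms=False
    on_disk_records = raw.view("|V240")  # assumes standard 240-byte trace headers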
44 changes: 13 additions & 31 deletions src/mdio/segy/_workers.py
@@ -13,6 +13,7 @@

from mdio.api.io import to_mdio
from mdio.builder.schemas.dtype import ScalarType
from mdio.segy._disaster_recovery_wrapper import get_header_raw_and_transformed

if TYPE_CHECKING:
    from segy.arrays import HeaderArray
@@ -81,7 +82,6 @@ def header_scan_worker(

    return cast("HeaderArray", trace_header)


def trace_worker(  # noqa: PLR0913
    segy_kw: SegyFileArguments,
    output_path: UPath,
@@ -120,26 +120,30 @@
    zarr_config.set({"threading.max_workers": 1})

    live_trace_indexes = local_grid_map[not_null].tolist()
    traces = segy_file.trace[live_trace_indexes]

    header_key = "headers"
    raw_header_key = "raw_headers"

    # Used to disable the reverse transforms if we aren't going to write the raw headers
    do_reverse_transforms = False

    # Get the subset of the dataset that has not yet been saved
    # The headers might not be present in the dataset
    worker_variables = [data_variable_name]
    if header_key in dataset.data_vars:  # Keeping the `if` here to allow for more worker configurations
        worker_variables.append(header_key)
    if raw_header_key in dataset.data_vars:
        do_reverse_transforms = True
        worker_variables.append(raw_header_key)

    raw_headers, transformed_headers, traces = get_header_raw_and_transformed(
        segy_file, live_trace_indexes, do_reverse_transforms=do_reverse_transforms
    )
    ds_to_write = dataset[worker_variables]

    if header_key in worker_variables:
        # Create temporary array for headers with the correct shape
        # TODO(BrianMichell): Implement this better so that we can enable fill values without changing the code. #noqa: TD003
        tmp_headers = np.zeros_like(dataset[header_key])
        tmp_headers[not_null] = traces.header
        tmp_headers[not_null] = transformed_headers
        # Create a new Variable object to avoid copying the temporary array
        # The ideal solution is to use `ds_to_write[header_key][:] = tmp_headers`
        # but Xarray appears to be copying memory instead of doing direct assignment.
@@ -150,41 +154,19 @@
            attrs=ds_to_write[header_key].attrs,
            encoding=ds_to_write[header_key].encoding,  # Not strictly necessary, but safer than not doing it.
        )
        del transformed_headers  # Manage memory
    if raw_header_key in worker_variables:
        tmp_raw_headers = np.zeros_like(dataset[raw_header_key])

        # Get the indices where we need to place results
        live_mask = not_null
        live_positions = np.where(live_mask.ravel())[0]

        if len(live_positions) > 0:
            # Calculate byte ranges for headers
            header_size = 240
            trace_offset = segy_file.spec.trace.offset
            trace_itemsize = segy_file.spec.trace.itemsize

            starts = []
            ends = []
            for global_trace_idx in live_trace_indexes:
                header_start = trace_offset + global_trace_idx * trace_itemsize
                header_end = header_start + header_size
                starts.append(header_start)
                ends.append(header_end)

            # Capture raw bytes
            raw_header_bytes = merge_cat_file(segy_file.fs, segy_file.url, starts, ends)

            # Convert and place results
            raw_headers_array = np.frombuffer(bytes(raw_header_bytes), dtype="|V240")
            tmp_raw_headers.ravel()[live_positions] = raw_headers_array
        tmp_raw_headers[not_null] = raw_headers.view("|V240")

        ds_to_write[raw_header_key] = Variable(
            ds_to_write[raw_header_key].dims,
            tmp_raw_headers,
            attrs=ds_to_write[raw_header_key].attrs,
            encoding=ds_to_write[raw_header_key].encoding,  # Not strictly necessary, but safer than not doing it.
        )

        del raw_headers  # Manage memory
    data_variable = ds_to_write[data_variable_name]
    fill_value = _get_fill_value(ScalarType(data_variable.dtype.name))
    tmp_samples = np.full_like(data_variable, fill_value=fill_value)
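
For reviewers, a small self-contained sketch of the mask-based scatter the new raw-header path performs (the grid shape and header dtype below are toy values for illustration; the real layout comes from the SEG-Y spec):

# Hypothetical illustration of the raw-header placement above, not part of this PR.
import numpy as np

# Toy live-trace mask on a 2x3 grid (three live traces).
not_null = np.array([[True, False, True], [False, True, False]])

# Structured "header" records totaling 240 bytes, mimicking SEG-Y trace headers.
header_dtype = np.dtype([("inline", ">i4"), ("crossline", ">i4"), ("pad", "V232")])
raw_headers = np.zeros(3, dtype=header_dtype)  # one record per live trace

# Scatter the per-trace records into the full grid, exactly like
# `tmp_raw_headers[not_null] = raw_headers.view("|V240")` in the diff.
tmp_raw_headers = np.zeros(not_null.shape, dtype="|V240")
tmp_raw_headers[not_null] = raw_headers.view("|V240")
print(tmp_raw_headers.shape)  # (2, 3)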
2 changes: 1 addition & 1 deletion src/mdio/segy/blocked_io.py
@@ -280,4 +280,4 @@ def to_segy(

    non_consecutive_axes -= 1

    return block_io_records
    return block_io_records