simplify mp code and remove wrappers
tasansal committed Apr 28, 2023
1 parent 7f9d580 commit 75bfa50
Showing 3 changed files with 15 additions and 40 deletions.
src/mdio/segy/_workers.py: 14 changes (0 additions, 14 deletions)
@@ -208,17 +208,3 @@ def trace_worker(
     max_val = tmp_data.max()

     return count, chunk_sum, chunk_sum_squares, min_val, max_val
-
-
-# tqdm only works properly with pool.map
-# However, we need pool.starmap because we have more than one
-# argument to make pool.map work with multiple arguments, we
-# wrap the function and consolidate arguments to one
-def trace_worker_wrapper(args):
-    """Wrapper to make it work with map and multiple arguments."""
-    return trace_worker(*args)
-
-
-def header_scan_worker_wrapper(args):
-    """Wrapper to make it work with map and multiple arguments."""
-    return header_scan_worker(*args)
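The deleted wrappers existed only because multiprocessing.Pool.map forwards a single argument to the worker, so multi-argument workers had to receive a zipped tuple and unpack it. concurrent.futures.ProcessPoolExecutor.map accepts one iterable per parameter, which is what makes the wrappers unnecessary. A minimal, self-contained sketch of the old and new call patterns; the worker and its arguments below are hypothetical stand-ins, not MDIO code:

from concurrent.futures import ProcessPoolExecutor
from itertools import repeat
from multiprocessing import Pool


def worker(path, block, endian):
    """Hypothetical multi-argument worker standing in for trace_worker."""
    return f"{path}:{block}:{endian}"


def worker_wrapper(args):
    """Old-style shim: unpack the single tuple that Pool.map forwards."""
    return worker(*args)


if __name__ == "__main__":
    blocks = range(4)

    # Old pattern: consolidate arguments with zip and unpack inside the wrapper,
    # because Pool.map only passes one positional argument per work item.
    with Pool(2) as pool:
        old = pool.map(worker_wrapper, zip(repeat("file.segy"), blocks, repeat("big")))

    # New pattern: ProcessPoolExecutor.map takes one iterable per parameter,
    # so the wrapper (and the zip) can be dropped.
    with ProcessPoolExecutor(2) as executor:
        new = list(executor.map(worker, repeat("file.segy"), blocks, repeat("big")))

    assert old == new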
src/mdio/segy/blocked_io.py: 21 changes (8 additions, 13 deletions)
@@ -20,7 +20,7 @@

 from mdio.core import Grid
 from mdio.core.indexing import ChunkIterator
-from mdio.segy._workers import trace_worker_wrapper
+from mdio.segy._workers import trace_worker
 from mdio.segy.byte_utils import ByteOrder
 from mdio.segy.byte_utils import Dtype
 from mdio.segy.creation import concat_files
@@ -133,16 +133,6 @@ def to_zarr(
     chunker = ChunkIterator(trace_array, chunk_samples=False)
     num_chunks = len(chunker)

-    # Setting all multiprocessing parameters.
-    parallel_inputs = zip(  # noqa: B905
-        repeat(segy_path),
-        repeat(trace_array),
-        repeat(header_array),
-        repeat(grid),
-        chunker,
-        repeat(segy_endian),
-    )
-
     # For Unix async writes with s3fs/fsspec & multiprocessing,
     # use 'spawn' instead of default 'fork' to avoid deadlocks
     # on cloud stores. Slower but necessary. Default on Windows.
@@ -157,8 +147,13 @@
     tqdm_kw = dict(unit="block", dynamic_ncols=True)
     with executor:
         lazy_work = executor.map(
-            trace_worker_wrapper,  # fn
-            parallel_inputs,  # iterables
+            trace_worker,  # fn
+            repeat(segy_path),
+            repeat(trace_array),
+            repeat(header_array),
+            repeat(grid),
+            chunker,
+            repeat(segy_endian),
             chunksize=pool_chunksize,
         )

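The new call keeps `chunker` as the only per-task iterable and broadcasts everything else with `repeat`. The executor construction itself is elided from this diff; the sketch below assumes a 'spawn'-context ProcessPoolExecutor as the comment above describes, with placeholder arguments (path, arrays, grid, chunk sizes) that are not MDIO objects:

import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat


def trace_worker_stub(segy_path, trace_array, header_array, grid, chunk, endian):
    """Hypothetical stand-in for trace_worker; just echoes a few inputs."""
    return segy_path, chunk, endian


if __name__ == "__main__":
    chunker = [(0, 128), (128, 256), (256, 384)]  # stand-in for ChunkIterator
    # 'spawn' avoids the fork-related deadlocks that can occur with s3fs/fsspec.
    context = mp.get_context("spawn")
    with ProcessPoolExecutor(max_workers=2, mp_context=context) as executor:
        lazy_work = executor.map(
            trace_worker_stub,    # fn
            repeat("file.segy"),  # broadcast to every chunk
            repeat(None),         # trace_array placeholder
            repeat(None),         # header_array placeholder
            repeat(None),         # grid placeholder
            chunker,              # the only iterable that varies per task
            repeat("big"),        # segy_endian
            chunksize=2,          # multiprocessing batch size, not array chunks
        )
        print(list(lazy_work))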
src/mdio/segy/parsers.py: 20 changes (7 additions, 13 deletions)
@@ -15,7 +15,7 @@
 from tqdm.auto import tqdm

 from mdio.core import Dimension
-from mdio.segy._workers import header_scan_worker_wrapper
+from mdio.segy._workers import header_scan_worker


 NUM_CORES = cpu_count(logical=False)
@@ -104,24 +104,18 @@ def parse_trace_headers(

         trace_ranges.append((start, stop))

-    # Note: Make sure the order of this is exactly
-    # the same as the function call.
-    parallel_inputs = zip(  # noqa: B905 or strict=False >= py3.10
-        repeat(segy_path),
-        trace_ranges,
-        repeat(byte_locs),
-        repeat(byte_lengths),
-        repeat(segy_endian),
-    )
-
     num_workers = min(n_blocks, NUM_CORES)

     tqdm_kw = dict(unit="block", dynamic_ncols=True)
     with ProcessPoolExecutor(num_workers) as executor:
         # pool.imap is lazy
         lazy_work = executor.map(
-            header_scan_worker_wrapper,  # fn
-            parallel_inputs,  # iterables
+            header_scan_worker,  # fn
+            repeat(segy_path),
+            trace_ranges,
+            repeat(byte_locs),
+            repeat(byte_lengths),
+            repeat(segy_endian),
             chunksize=2,  # Not array chunks. This is for `multiprocessing`
         )

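For context on the `chunksize` comment above: ProcessPoolExecutor.map sends that many work items per batch to each worker process; it is unrelated to the array chunking used elsewhere in MDIO. A tiny illustrative sketch of how the (start, stop) trace ranges driving this map are built; the trace count and block size below are made up, not values from this diff:

# Illustrative only: split a trace count into per-process (start, stop) blocks,
# mirroring how trace_ranges feeds executor.map above.
num_traces, block_size = 1_000, 300

trace_ranges = []
for start in range(0, num_traces, block_size):
    stop = min(start + block_size, num_traces)
    trace_ranges.append((start, stop))

print(trace_ranges)  # [(0, 300), (300, 600), (600, 900), (900, 1000)]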
