Merge pull request #202 from TGSAI/enh/ingestion_exception_handling
Ingestion: Replace `multiprocessing.Pool` with `concurrent.futures.ProcessPoolExecutor`
tasansal committed Apr 28, 2023
2 parents 79e7008 + 75bfa50 commit 793e968
Showing 3 changed files with 27 additions and 54 deletions.
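
The whole change boils down to one pattern swap: `Pool.imap` takes a single iterable, so worker arguments had to be zipped together and unpacked by a `*_map` wrapper, whereas `ProcessPoolExecutor.map` accepts one iterable per positional argument plus a `chunksize`. A minimal, self-contained sketch of the before/after pattern; the worker and its arguments are placeholders, not MDIO's `trace_worker`:

```python
# Illustrative only: a toy worker stands in for MDIO's trace_worker.
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat
from multiprocessing import Pool


def toy_worker(path, scale, block):
    """Stand-in for a real worker that takes several arguments."""
    return sum(block) * scale


def toy_worker_map(args):
    """Old-style shim: Pool.imap can only pass a single argument."""
    return toy_worker(*args)


if __name__ == "__main__":
    blocks = [range(10), range(20), range(30)]

    # Before: single-iterable imap, so arguments are zipped and unpacked by a wrapper.
    with Pool(2) as pool:
        old = list(pool.imap(toy_worker_map, zip(repeat("dummy.segy"), repeat(2), blocks)))

    # After: executor.map takes one iterable per argument, so no wrapper is needed.
    with ProcessPoolExecutor(max_workers=2) as executor:
        new = list(executor.map(toy_worker, repeat("dummy.segy"), repeat(2), blocks, chunksize=2))

    assert old == new
```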
14 changes: 0 additions & 14 deletions src/mdio/segy/_workers.py
@@ -208,17 +208,3 @@ def trace_worker(
max_val = tmp_data.max()

return count, chunk_sum, chunk_sum_squares, min_val, max_val


# tqdm only works properly with pool.map
# However, we need pool.starmap because we have more than one argument.
# To make pool.map work with multiple arguments, we
# wrap the function and consolidate arguments to one.
def trace_worker_map(args):
"""Wrapper for trace worker to use with tqdm."""
return trace_worker(*args)


def header_scan_worker_map(args):
"""Wrapper for header scan worker to use with tqdm."""
return header_scan_worker(*args)
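
With `executor.map`, tqdm can wrap the returned lazy result iterator directly, exactly as it wrapped `pool.imap`, so the `*_map` shims deleted above have nothing left to do. A short sketch of that usage with a toy worker (not MDIO's `header_scan_worker`):

```python
# Sketch: progress bar over ProcessPoolExecutor.map results (toy worker).
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat

from tqdm.auto import tqdm


def scan_block(start_stop, endian):
    """Stand-in for a header-scan worker taking multiple arguments."""
    start, stop = start_stop  # 'endian' is unused in this toy version
    return stop - start


if __name__ == "__main__":
    ranges = [(0, 1000), (1000, 2000), (2000, 2500)]
    with ProcessPoolExecutor(max_workers=2) as executor:
        lazy_work = executor.map(scan_block, ranges, repeat("big"), chunksize=2)
        # Results arrive in submission order; tqdm simply iterates them.
        sizes = list(tqdm(lazy_work, total=len(ranges), unit="block"))
    print(sizes)
```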
41 changes: 17 additions & 24 deletions src/mdio/segy/blocked_io.py
@@ -4,6 +4,7 @@
from __future__ import annotations

import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat

import numpy as np
@@ -19,7 +20,7 @@

from mdio.core import Grid
from mdio.core.indexing import ChunkIterator
from mdio.segy._workers import trace_worker_map
from mdio.segy._workers import trace_worker
from mdio.segy.byte_utils import ByteOrder
from mdio.segy.byte_utils import Dtype
from mdio.segy.creation import concat_files
@@ -132,35 +133,27 @@ def to_zarr(
chunker = ChunkIterator(trace_array, chunk_samples=False)
num_chunks = len(chunker)

# Setting all multiprocessing parameters.
parallel_inputs = zip( # noqa: B905
repeat(segy_path),
repeat(trace_array),
repeat(header_array),
repeat(grid),
chunker,
repeat(segy_endian),
)

# This is for Unix async writes to s3fs/fsspec, when using
# multiprocessing. By default, Linux uses the 'fork' method.
# 'spawn' is a little slower to spool up processes, but 'fork'
# doesn't work. If you don't use this, processes get deadlocked
# on cloud stores. 'spawn' is default in Windows.
# For Unix async writes with s3fs/fsspec & multiprocessing,
# use 'spawn' instead of default 'fork' to avoid deadlocks
# on cloud stores. Slower but necessary. Default on Windows.
num_workers = min(num_chunks, NUM_CORES)
context = mp.get_context("spawn")
executor = ProcessPoolExecutor(max_workers=num_workers, mp_context=context)

# This is the chunksize for multiprocessing. Not to be confused
# with Zarr chunksize.
num_workers = min(num_chunks, NUM_CORES)
# Chunksize here is for multiprocessing, not Zarr chunksize.
pool_chunksize, extra = divmod(num_chunks, num_workers * 4)
pool_chunksize += 1 if extra else pool_chunksize

tqdm_kw = dict(unit="block", dynamic_ncols=True)
with context.Pool(num_workers) as pool:
# pool.imap is lazy
lazy_work = pool.imap(
func=trace_worker_map,
iterable=parallel_inputs,
with executor:
lazy_work = executor.map(
trace_worker, # fn
repeat(segy_path),
repeat(trace_array),
repeat(header_array),
repeat(grid),
chunker,
repeat(segy_endian),
chunksize=pool_chunksize,
)

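
The new setup pins the start method to `spawn` via `mp.get_context` and hands that context to `ProcessPoolExecutor` through `mp_context`, then derives a map `chunksize` from the number of work blocks. A minimal sketch of the same wiring follows; the worker, paths, and counts are placeholders, and the chunksize line uses plain ceil division rather than the exact `divmod` expression above (for 100 blocks and 8 workers both land on a chunksize of 4):

```python
# Sketch: 'spawn'-context ProcessPoolExecutor with a derived map chunksize.
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
from math import ceil


def write_block(block_id, target):
    """Stand-in worker; a real one would write one Zarr chunk to 'target'."""
    return block_id


if __name__ == "__main__":
    num_chunks = 100           # e.g. len(chunker)
    num_cores = 8              # e.g. NUM_CORES
    num_workers = min(num_chunks, num_cores)

    # Aim for roughly 4 map batches per worker: ceil(100 / 32) == 4.
    pool_chunksize = ceil(num_chunks / (num_workers * 4))

    # 'spawn' avoids fork-related deadlocks with s3fs/fsspec event loops.
    context = mp.get_context("spawn")
    executor = ProcessPoolExecutor(max_workers=num_workers, mp_context=context)

    with executor:
        results = executor.map(
            write_block,
            range(num_chunks),
            ["s3://bucket/out.mdio"] * num_chunks,
            chunksize=pool_chunksize,
        )
        print(sum(results))
```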
26 changes: 10 additions & 16 deletions src/mdio/segy/parsers.py
@@ -3,9 +3,9 @@

from __future__ import annotations

from concurrent.futures import ProcessPoolExecutor
from itertools import repeat
from math import ceil
from multiprocessing import Pool
from typing import Any
from typing import Sequence

@@ -15,7 +15,7 @@
from tqdm.auto import tqdm

from mdio.core import Dimension
from mdio.segy._workers import header_scan_worker_map
from mdio.segy._workers import header_scan_worker


NUM_CORES = cpu_count(logical=False)
@@ -104,24 +104,18 @@ def parse_trace_headers(

trace_ranges.append((start, stop))

# Note: Make sure the order of this is exactly
# the same as the function call.
parallel_inputs = zip( # noqa: B905 or strict=False >= py3.10
repeat(segy_path),
trace_ranges,
repeat(byte_locs),
repeat(byte_lengths),
repeat(segy_endian),
)

num_workers = min(n_blocks, NUM_CORES)

tqdm_kw = dict(unit="block", dynamic_ncols=True)
with Pool(num_workers) as pool:
with ProcessPoolExecutor(num_workers) as executor:
# pool.imap is lazy
lazy_work = pool.imap(
func=header_scan_worker_map,
iterable=parallel_inputs,
lazy_work = executor.map(
header_scan_worker, # fn
repeat(segy_path),
trace_ranges,
repeat(byte_locs),
repeat(byte_lengths),
repeat(segy_endian),
chunksize=2, # Not array chunks. This is for `multiprocessing`
)

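
The branch name (`enh/ingestion_exception_handling`) points at the other practical difference: with `ProcessPoolExecutor`, an exception raised in a worker is re-raised in the parent when that result is pulled from the `executor.map` iterator, and an abruptly killed worker surfaces as `BrokenProcessPool` rather than a hang. A minimal sketch of that behaviour with a toy worker, not MDIO code:

```python
# Sketch: how worker exceptions surface through ProcessPoolExecutor.map.
from concurrent.futures import ProcessPoolExecutor


def parse_block(block_id):
    """Toy worker that fails on one block."""
    if block_id == 2:
        raise ValueError(f"bad headers in block {block_id}")
    return block_id


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2) as executor:
        results = executor.map(parse_block, range(4))
        try:
            print(list(results))  # raises when the failing block's result is reached
        except ValueError as exc:
            print(f"worker failed: {exc}")
```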
