
Commit

use concurrent futures which handles and raises exceptions
tasansal committed Apr 28, 2023
1 parent 79e7008 commit 7f9d580
Showing 3 changed files with 22 additions and 24 deletions.
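The motivation behind the change is that concurrent.futures surfaces worker errors in the parent process: iterating the iterator returned by ProcessPoolExecutor.map re-raises any exception thrown inside a worker. A minimal illustrative sketch, not taken from the MDIO codebase (the function name might_fail is made up):

from concurrent.futures import ProcessPoolExecutor


def might_fail(x):
    """Toy worker that raises for one input to show error propagation."""
    if x == 2:
        raise ValueError(f"bad input: {x}")
    return x * x


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2) as executor:
        try:
            # An exception raised in a worker is re-raised here, in the
            # parent process, when the failing result is consumed.
            results = list(executor.map(might_fail, range(4)))
        except ValueError as exc:
            print(f"worker failed: {exc}")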
8 changes: 4 additions & 4 deletions src/mdio/segy/_workers.py
@@ -214,11 +214,11 @@ def trace_worker(
 # However, we need pool.starmap because we have more than one
 # argument to make pool.map work with multiple arguments, we
 # wrap the function and consolidate arguments to one
-def trace_worker_map(args):
-    """Wrapper for trace worker to use with tqdm."""
+def trace_worker_wrapper(args):
+    """Wrapper to make it work with map and multiple arguments."""
     return trace_worker(*args)


-def header_scan_worker_map(args):
-    """Wrapper for header scan worker to use with tqdm."""
+def header_scan_worker_wrapper(args):
+    """Wrapper to make it work with map and multiple arguments."""
     return header_scan_worker(*args)
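For context, these wrappers exist because map-style pool APIs pass exactly one argument per task. A hedged sketch of how such a wrapper is typically driven (worker and worker_wrapper are illustrative stand-ins, not the project's functions):

from concurrent.futures import ProcessPoolExecutor
from itertools import repeat


def worker(chunk_id, path, endian):
    """Stand-in for trace_worker; needs several positional arguments."""
    return f"{chunk_id}:{path}:{endian}"


def worker_wrapper(args):
    """Unpack one tuple so the multi-argument worker fits executor.map."""
    return worker(*args)


if __name__ == "__main__":
    # zip + repeat bundles per-task and shared arguments into tuples.
    parallel_inputs = zip(range(3), repeat("file.segy"), repeat("big"))
    with ProcessPoolExecutor(max_workers=2) as executor:
        print(list(executor.map(worker_wrapper, parallel_inputs)))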
26 changes: 12 additions & 14 deletions src/mdio/segy/blocked_io.py
@@ -4,6 +4,7 @@
 from __future__ import annotations

 import multiprocessing as mp
+from concurrent.futures import ProcessPoolExecutor
 from itertools import repeat

 import numpy as np
@@ -19,7 +20,7 @@

 from mdio.core import Grid
 from mdio.core.indexing import ChunkIterator
-from mdio.segy._workers import trace_worker_map
+from mdio.segy._workers import trace_worker_wrapper
 from mdio.segy.byte_utils import ByteOrder
 from mdio.segy.byte_utils import Dtype
 from mdio.segy.creation import concat_files
@@ -142,25 +143,22 @@ def to_zarr(
         repeat(segy_endian),
     )

-    # This is for Unix async writes to s3fs/fsspec, when using
-    # multiprocessing. By default, Linux uses the 'fork' method.
-    # 'spawn' is a little slower to spool up processes, but 'fork'
-    # doesn't work. If you don't use this, processes get deadlocked
-    # on cloud stores. 'spawn' is default in Windows.
+    # For Unix async writes with s3fs/fsspec & multiprocessing,
+    # use 'spawn' instead of default 'fork' to avoid deadlocks
+    # on cloud stores. Slower but necessary. Default on Windows.
+    num_workers = min(num_chunks, NUM_CORES)
     context = mp.get_context("spawn")
+    executor = ProcessPoolExecutor(max_workers=num_workers, mp_context=context)

-    # This is the chunksize for multiprocessing. Not to be confused
-    # with Zarr chunksize.
-    num_workers = min(num_chunks, NUM_CORES)
+    # Chunksize here is for multiprocessing, not Zarr chunksize.
     pool_chunksize, extra = divmod(num_chunks, num_workers * 4)
     pool_chunksize += 1 if extra else pool_chunksize

     tqdm_kw = dict(unit="block", dynamic_ncols=True)
-    with context.Pool(num_workers) as pool:
-        # pool.imap is lazy
-        lazy_work = pool.imap(
-            func=trace_worker_map,
-            iterable=parallel_inputs,
+    with executor:
+        lazy_work = executor.map(
+            trace_worker_wrapper,  # fn
+            parallel_inputs,  # iterables
             chunksize=pool_chunksize,
         )

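The pattern above, a 'spawn' multiprocessing context handed to ProcessPoolExecutor plus a multiprocessing chunksize, can be exercised in isolation roughly as follows. This is a sketch under assumptions: the worker square is made up, and the chunksize arithmetic is written as plain ceiling division rather than the project's exact formula.

import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor


def square(x):
    """Trivial stand-in for trace_worker_wrapper."""
    return x * x


if __name__ == "__main__":
    num_chunks, num_workers = 100, 4

    # Dispatch roughly a quarter of each worker's share per round, rounded up.
    pool_chunksize, extra = divmod(num_chunks, num_workers * 4)
    pool_chunksize += 1 if extra else 0

    # 'spawn' avoids the fork-related deadlocks described in the comment above.
    context = mp.get_context("spawn")
    with ProcessPoolExecutor(max_workers=num_workers, mp_context=context) as executor:
        results = list(executor.map(square, range(num_chunks), chunksize=pool_chunksize))
        print(len(results))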
12 changes: 6 additions & 6 deletions src/mdio/segy/parsers.py
@@ -3,9 +3,9 @@

 from __future__ import annotations

+from concurrent.futures import ProcessPoolExecutor
 from itertools import repeat
 from math import ceil
-from multiprocessing import Pool
 from typing import Any
 from typing import Sequence

@@ -15,7 +15,7 @@
 from tqdm.auto import tqdm

 from mdio.core import Dimension
-from mdio.segy._workers import header_scan_worker_map
+from mdio.segy._workers import header_scan_worker_wrapper


 NUM_CORES = cpu_count(logical=False)
@@ -117,11 +117,11 @@ def parse_trace_headers(
     num_workers = min(n_blocks, NUM_CORES)

     tqdm_kw = dict(unit="block", dynamic_ncols=True)
-    with Pool(num_workers) as pool:
+    with ProcessPoolExecutor(num_workers) as executor:
         # pool.imap is lazy
-        lazy_work = pool.imap(
-            func=header_scan_worker_map,
-            iterable=parallel_inputs,
+        lazy_work = executor.map(
+            header_scan_worker_wrapper,  # fn
+            parallel_inputs,  # iterables
             chunksize=2,  # Not array chunks. This is for `multiprocessing`
         )

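As in blocked_io.py, executor.map returns a lazy iterator, so progress can be reported as results stream back. A small usage sketch under the same assumptions (scan_block is a made-up stand-in for the header-scan worker):

from concurrent.futures import ProcessPoolExecutor

from tqdm.auto import tqdm


def scan_block(block_id):
    """Stand-in for header_scan_worker_wrapper."""
    return {"block": block_id}


if __name__ == "__main__":
    n_blocks = 50
    with ProcessPoolExecutor(max_workers=4) as executor:
        lazy_work = executor.map(scan_block, range(n_blocks), chunksize=2)
        # tqdm advances as each result arrives from the pool.
        headers = list(tqdm(lazy_work, total=n_blocks, unit="block", dynamic_ncols=True))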
