From a82fe803c48cbce5d4827a2087221d02d9f83dd5 Mon Sep 17 00:00:00 2001 From: Altay Sansal Date: Tue, 21 May 2024 11:12:21 -0500 Subject: [PATCH] Update CPU count configuration for parallel operations Changes have been made to improve the control of parallel computations in the application. The number of CPUs used by the mdio_to_segy and to_zarr functions can now be controlled by altering environment variables MDIO__EXPORT__CPU_COUNT and MDIO__IMPORT__CPU_COUNT respectively. This allows users to optimize the program's performance based on their specific hardware setup. --- src/mdio/converters/mdio.py | 13 ++++++++++++- src/mdio/segy/blocked_io.py | 7 ++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/mdio/converters/mdio.py b/src/mdio/converters/mdio.py index e50f75a0..c5a394c6 100644 --- a/src/mdio/converters/mdio.py +++ b/src/mdio/converters/mdio.py @@ -3,10 +3,12 @@ from __future__ import annotations +import os from os import path from tempfile import TemporaryDirectory import numpy as np +from psutil import cpu_count from tqdm.dask import TqdmCallback from mdio import MDIOReader @@ -24,6 +26,10 @@ distributed = None +default_cpus = cpu_count(logical=True) +NUM_CPUS = int(os.getenv("MDIO__EXPORT__CPU_COUNT", default_cpus)) + + def mdio_to_segy( # noqa: C901 mdio_path_or_buffer: str, output_segy_path: str, @@ -176,7 +182,12 @@ def mdio_to_segy( # noqa: C901 out_byteorder=out_byteorder, file_root=tmp_dir.name, axis=tuple(range(1, samples.ndim)), - ).compute() + ) + + if client is not None: + flat_files = flat_files.compute() + else: + flat_files = flat_files.compute(num_workers=NUM_CPUS) # If whole blocks are missing, remove them from the list. 
missing_mask = flat_files == "missing" diff --git a/src/mdio/segy/blocked_io.py b/src/mdio/segy/blocked_io.py index 97c8aacb..75554e02 100644 --- a/src/mdio/segy/blocked_io.py +++ b/src/mdio/segy/blocked_io.py @@ -4,6 +4,7 @@ from __future__ import annotations import multiprocessing as mp +import os from concurrent.futures import ProcessPoolExecutor from itertools import repeat @@ -35,8 +36,8 @@ ZFPY = None zfpy = None -# Globals -NUM_CORES = cpu_count(logical=False) +default_cpus = cpu_count(logical=True) +NUM_CPUS = int(os.getenv("MDIO__IMPORT__CPU_COUNT", default_cpus)) def to_zarr( @@ -136,7 +137,7 @@ def to_zarr( # For Unix async writes with s3fs/fsspec & multiprocessing, # use 'spawn' instead of default 'fork' to avoid deadlocks # on cloud stores. Slower but necessary. Default on Windows. - num_workers = min(num_chunks, NUM_CORES) + num_workers = min(num_chunks, NUM_CPUS) context = mp.get_context("spawn") executor = ProcessPoolExecutor(max_workers=num_workers, mp_context=context)