diff --git a/src/lightning/data/streaming/data_processor.py b/src/lightning/data/streaming/data_processor.py
index ad9ec4793bc47..f2f79f73463ed 100644
--- a/src/lightning/data/streaming/data_processor.py
+++ b/src/lightning/data/streaming/data_processor.py
@@ -7,6 +7,7 @@
 import types
 from abc import abstractmethod
 from dataclasses import dataclass
+from datetime import datetime
 from multiprocessing import Process, Queue
 from queue import Empty
 from shutil import copyfile, rmtree
@@ -15,7 +16,7 @@
 from urllib import parse
 
 import torch
-from tqdm.auto import tqdm
+from tqdm.auto import tqdm as _tqdm
 
 from lightning import seed_everything
 from lightning.data.streaming import Cache
@@ -278,6 +279,7 @@ def __init__(
         error_queue: Queue,
         stop_queue: Queue,
         num_downloaders: int,
+        num_uploaders: int,
         remove: bool,
     ) -> None:
         """The BaseWorker is responsible to process the user data."""
@@ -290,18 +292,19 @@ def __init__(
         self.items = items
         self.num_items = len(self.items)
         self.num_downloaders = num_downloaders
+        self.num_uploaders = num_uploaders
         self.remove = remove
         self.paths: List[List[str]] = []
         self.remover: Optional[Process] = None
         self.downloaders: List[Process] = []
+        self.uploaders: List[Process] = []
         self.to_download_queues: List[Queue] = []
+        self.to_upload_queues: List[Queue] = []
         self.stop_queue = stop_queue
         self.ready_to_process_queue: Queue = Queue()
         self.remove_queue: Queue = Queue()
-        self.upload_queue: Queue = Queue()
         self.progress_queue: Queue = progress_queue
         self.error_queue: Queue = error_queue
-        self.uploader: Optional[Process] = None
         self._collected_items = 0
         self._counter = 0
         self._last_time = time()
@@ -316,14 +319,14 @@ def run(self) -> None:
             traceback_format = traceback.format_exc()
             print(traceback_format)
             self.error_queue.put(traceback_format)
-        print(f"Worker {self.worker_index} is done.")
+        print(f"Worker {str(_get_node_rank() * self.num_workers + self.worker_index)} is done.")
 
     def _setup(self) -> None:
         self._set_environ_variables()
         self._create_cache()
         self._collect_paths()
         self._start_downloaders()
-        self._start_uploader()
+        self._start_uploaders()
         self._start_remover()
 
     def _loop(self) -> None:
@@ -335,13 +338,19 @@ def _loop(self) -> None:
             if index is None:
                 num_downloader_finished += 1
                 if num_downloader_finished == self.num_downloaders:
+                    print(f"Worker {str(_get_node_rank() * self.num_workers + self.worker_index)} is terminating.")
+
                     if isinstance(self.data_recipe, DataChunkRecipe):
                         self._handle_data_chunk_recipe_end()
 
                     if self.output_dir.url if self.output_dir.url else self.output_dir.path:
-                        assert self.uploader
-                        self.upload_queue.put(None)
-                        self.uploader.join()
+                        # Inform the uploaders they are done working
+                        for i in range(self.num_uploaders):
+                            self.to_upload_queues[i].put(None)
+
+                        # Wait for them all to be finished
+                        for uploader in self.uploaders:
+                            uploader.join()
 
                     if self.remove:
                         assert self.remover
@@ -402,7 +411,7 @@ def _try_upload(self, filepath: Optional[str]) -> None:
             return
 
         assert os.path.exists(filepath), filepath
-        self.upload_queue.put(filepath)
+        self.to_upload_queues[self._counter % self.num_uploaders].put(filepath)
 
     def _collect_paths(self) -> None:
         items = []
@@ -475,19 +484,24 @@ def _start_remover(self) -> None:
             )
         self.remover.start()
 
-    def _start_uploader(self) -> None:
+    def _start_uploaders(self) -> None:
         if self.output_dir.path is None and self.output_dir.url is None:
             return
-        self.uploader = Process(
-            target=_upload_fn,
-            args=(
-                self.upload_queue,
-                self.remove_queue,
-                self.cache_chunks_dir,
-                self.output_dir,
-            ),
-        )
-        self.uploader.start()
+
+        for _ in range(self.num_uploaders):
+            to_upload_queue: Queue = Queue()
+            p = Process(
+                target=_upload_fn,
+                args=(
+                    to_upload_queue,
+                    self.remove_queue,
+                    self.cache_chunks_dir,
+                    self.output_dir,
+                ),
+            )
+            p.start()
+            self.uploaders.append(p)
+            self.to_upload_queues.append(to_upload_queue)
 
     def _handle_data_chunk_recipe(self, index: int) -> None:
         try:
@@ -509,10 +523,10 @@ def _handle_data_chunk_recipe(self, index: int) -> None:
 
     def _handle_data_chunk_recipe_end(self) -> None:
         chunks_filepaths = self.cache.done()
-        if chunks_filepaths:
-            for chunk_filepath in chunks_filepaths:
+        if chunks_filepaths and len(self.to_upload_queues):
+            for i, chunk_filepath in enumerate(chunks_filepaths):
                 if isinstance(chunk_filepath, str) and os.path.exists(chunk_filepath):
-                    self.upload_queue.put(chunk_filepath)
+                    self.to_upload_queues[i % self.num_uploaders].put(chunk_filepath)
 
     def _handle_data_transform_recipe(self, index: int) -> None:
         # Don't use a context manager to avoid deleting files that are being uploaded.
@@ -721,6 +735,7 @@ def __init__(
         output_dir: Optional[Union[str, Dir]] = None,
         num_workers: Optional[int] = None,
        num_downloaders: Optional[int] = None,
+        num_uploaders: Optional[int] = None,
         delete_cached_files: bool = True,
         fast_dev_run: Optional[Union[bool, int]] = None,
         random_seed: Optional[int] = 42,
@@ -734,6 +749,7 @@ def __init__(
             output_dir: The path to where the output data are stored.
             num_workers: The number of worker threads to use.
             num_downloaders: The number of file downloaders to use.
+            num_uploaders: The number of file uploaders to use.
             delete_cached_files: Whether to delete the cached files.
             fast_dev_run: Whether to run a quick dev run.
             random_seed: The random seed to be set before shuffling the data.
@@ -744,7 +760,8 @@ def __init__(
         self.input_dir = _resolve_dir(input_dir)
         self.output_dir = _resolve_dir(output_dir)
         self.num_workers = num_workers or (1 if fast_dev_run else (os.cpu_count() or 1) * 4)
-        self.num_downloaders = num_downloaders or 1
+        self.num_downloaders = num_downloaders or 2
+        self.num_uploaders = num_uploaders or 5
         self.delete_cached_files = delete_cached_files
         self.fast_dev_run = _get_fast_dev_run() if fast_dev_run is None else fast_dev_run
         self.workers: Any = []
@@ -816,30 +833,43 @@ def run(self, data_recipe: DataRecipe) -> None:
 
         current_total = 0
         has_failed = False
-        with tqdm(total=num_items, smoothing=0, position=-1, mininterval=1) as pbar:
-            while True:
+        pbar = _tqdm(
+            desc="Progress",
+            total=num_items,
+            smoothing=0,
+            position=-1,
+            mininterval=1,
+            leave=True,
+            dynamic_ncols=True,
+        )
+
+        while True:
+            try:
+                error = self.error_queue.get(timeout=0.001)
+                self._exit_on_error(error)
+            except Empty:
+                assert self.progress_queue
                 try:
-                    error = self.error_queue.get(timeout=0.001)
-                    self._exit_on_error(error)
+                    index, counter = self.progress_queue.get(timeout=0.001)
                 except Empty:
-                    assert self.progress_queue
-                    try:
-                        index, counter = self.progress_queue.get(timeout=0.001)
-                    except Empty:
-                        continue
-                self.workers_tracker[index] = counter
-                new_total = sum(self.workers_tracker.values())
-
-                pbar.update(new_total - current_total)
-                current_total = new_total
-                if current_total == num_items:
-                    break
-
-                # Exit early if all the workers are done.
-                # This means there were some kinda of errors.
-                if all(not w.is_alive() for w in self.workers):
-                    has_failed = True
-                    break
+                    continue
+                self.workers_tracker[index] = counter
+                new_total = sum(self.workers_tracker.values())
+
+                pbar.set_postfix({"time": datetime.now().strftime("%H:%M:%S.%f")})
+                pbar.update(new_total - current_total)
+
+                current_total = new_total
+                if current_total == num_items:
+                    break
+
+                # Exit early if all the workers are done.
+                # This means there was some kind of error.
+                if all(not w.is_alive() for w in self.workers):
+                    has_failed = True
+                    break
+
+        pbar.close()
 
         num_nodes = _get_num_nodes()
         node_rank = _get_node_rank()
@@ -896,6 +926,7 @@ def _create_process_workers(self, data_recipe: DataRecipe, workers_user_items: L
                 self.error_queue,
                 stop_queues[-1],
                 self.num_downloaders,
+                self.num_uploaders,
                 self.delete_cached_files,
             )
             worker.start()
diff --git a/tests/tests_data/streaming/test_data_processor.py b/tests/tests_data/streaming/test_data_processor.py
index df0bfe0116789..45316e2dbb46f 100644
--- a/tests/tests_data/streaming/test_data_processor.py
+++ b/tests/tests_data/streaming/test_data_processor.py
@@ -484,6 +484,8 @@ def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir,
         delete_cached_files=delete_cached_files,
         fast_dev_run=fast_dev_run,
         output_dir=remote_output_dir,
+        num_uploaders=1,
+        num_downloaders=1,
     )
     data_processor.run(CustomDataChunkRecipe(chunk_size=2))
 
@@ -508,6 +510,7 @@ def test_data_processsor_distributed(fast_dev_run, delete_cached_files, tmpdir,
     data_processor = TestDataProcessor(
         input_dir=input_dir,
         num_workers=2,
+        num_uploaders=1,
         num_downloaders=1,
         delete_cached_files=delete_cached_files,
         fast_dev_run=fast_dev_run,
@@ -668,7 +671,6 @@ def test_data_processing_map(monkeypatch, tmpdir):
 
     def optimize_fn(filepath):
-        print(filepath)
         from PIL import Image
 
         return [Image.open(filepath), os.path.basename(filepath)]
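For reference, a minimal standalone sketch of the fan-out pattern this patch introduces in _start_uploaders and _try_upload: one queue per uploader process, files assigned round-robin, and None used as the stop sentinel before joining. The upload worker and the sample file names below are illustrative stand-ins, not the library's _upload_fn.

# Illustrative sketch only; not part of the patch.
from multiprocessing import Process, Queue


def upload(queue: Queue) -> None:
    # Drain the queue until the None sentinel arrives.
    while True:
        filepath = queue.get()
        if filepath is None:
            break
        print(f"uploading {filepath}")


if __name__ == "__main__":
    num_uploaders = 3
    to_upload_queues = [Queue() for _ in range(num_uploaders)]
    uploaders = [Process(target=upload, args=(q,)) for q in to_upload_queues]
    for p in uploaders:
        p.start()

    # Round-robin assignment, mirroring `self._counter % self.num_uploaders` above.
    for i, filepath in enumerate(f"chunk-{n}.bin" for n in range(10)):
        to_upload_queues[i % num_uploaders].put(filepath)

    # Tell every uploader it is done working, then wait for all of them.
    for q in to_upload_queues:
        q.put(None)
    for p in uploaders:
        p.join()

Dedicated per-uploader queues let each process receive its own sentinel and be joined independently, which is what the termination path in _loop relies on.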