diff --git a/README.md b/README.md index b1de6d8..e2653e6 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ in [issue #1](https://github.com/Imageomics/distributed-downloader/issues/1)). official websites: - [OpenMPI](https://docs.open-mpi.org/en/v5.0.x/installing-open-mpi/quickstart.html) - [IntelMPI](https://www.intel.com/content/www/us/en/docs/mpi-library/developer-guide-linux/2021-6/installation.html) -3. Install required package: +3. Install the required package: - For general use: ```commandline pip install git+https://github.com/Imageomics/distributed-downloader @@ -44,33 +44,86 @@ in [issue #1](https://github.com/Imageomics/distributed-downloader/issues/1)). `distributed-downloader` utilizes multiple nodes on a High Performance Computing (HPC) system (specifically, an HPC with `slurm` workload manager) to download a collection of images specified in a given tab-delimited text file. +### Configuration + +The downloader is configured using a YAML configuration file. Here's an example configuration: + +```yaml +# Example configuration file +path_to_input: "/path/to/input/urls.csv" +path_to_output: "/path/to/output" + +output_structure: + urls_folder: "urls" + logs_folder: "logs" + images_folder: "images" + schedules_folder: "schedules" + profiles_table: "profiles.csv" + ignored_table: "ignored.csv" + inner_checkpoint_file: "checkpoint.json" + tools_folder: "tools" + +downloader_parameters: + num_downloads: 1 + max_nodes: 20 + workers_per_node: 20 + cpu_per_worker: 1 + header: true + image_size: 224 + logger_level: "INFO" + batch_size: 10000 + rate_multiplier: 0.5 + default_rate_limit: 3 + +tools_parameters: + num_workers: 1 + max_nodes: 10 + workers_per_node: 20 + cpu_per_worker: 1 + threshold_size: 10000 + new_resize_size: 224 +``` + ### Main script -There are one manual step to get the downloader running as designed: +There is one manual step to get the downloader running as designed: You need to call function `download_images` from package `distributed_downloader` with the `config_path` as an argument. This will initialize filestructure in the output folder, partition the input file, profile the servers for their possible download speed, and start downloading images. If downloading didn't finish, you can call the same function with the same `config_path` argument to continue downloading. +```python +from distributed_downloader import download_images + +# Start or continue downloading +download_images("/path/to/config.yaml") +``` + Downloader has two logging profiles: -- `INFO` - logs only the most important information, for example when a batch is started and finished. It also logs out +- `INFO` - logs only the most important information, for example, when a batch is started and finished. It also logs out any error that occurred during download, image decoding, or writing batch to the filesystem -- `DEBUG` - logs all information, for example logging start and finish of each downloaded image. +- `DEBUG` - logs all information, for example, logging start and finish of each downloaded image. ### Tools script -After downloading is finished, you can use the `tools` package perform various operations on them. +After downloading is finished, you can use the `tools` package to perform various operations on the downloaded images. To do this, you need to call the function `apply_tools` from package `distributed_downloader` with the `config_path` and `tool_name` as an argument. 
-Following tools are available: -- `resize` - resizes images to a new size -- `image_verification` - verifies images by checking if they are corrupted -- `duplication_based` - removes duplicate images -- `size_based` - removes images that are too small +```python +from distributed_downloader import apply_tools -You can also add your own tool, the instructions are in the section below. +# Apply a specific tool +apply_tools("/path/to/config.yaml", "resize") +``` + +The following tools are available: + +- `resize` - resizes images to a new size (specified in config) +- `image_verification` - verifies images by checking if they are corrupted +- `duplication_based` - removes duplicate images using MD5 hashing +- `size_based` - removes images that are too small (threshold specified in config) ### Creating a new tool @@ -87,7 +140,34 @@ You can also add your own tool by creating 3 classes and registering them with r - Each tool should have a `run` method that will be called by the main script. - Each tool should be registered with a decorator from a respective package (`FilterRegister` from `filters` etc.) -## Rules for scripts: +Example of creating a custom tool: + +```python +from distributed_downloader.tools import FilterRegister, SchedulerRegister, RunnerRegister, ToolsBase + + +@FilterRegister("my_custom_tool") +class MyCustomToolFilter(ToolsBase): + def run(self): + # Implementation of filter step + pass + + +@SchedulerRegister("my_custom_tool") +class MyCustomToolScheduler(ToolsBase): + def run(self): + # Implementation of scheduler step + pass + + +@RunnerRegister("my_custom_tool") +class MyCustomToolRunner(ToolsBase): + def run(self): + # Implementation of runner step + pass +``` + +## Environment Variables All scripts can expect to have the following custom environment variables, specific variables are only initialized when respective tool is called: @@ -123,3 +203,63 @@ when respective tool is called: - `TOOLS_CPU_PER_WORKER` - `TOOLS_THRESHOLD_SIZE` - `TOOLS_NEW_RESIZE_SIZE` + +## Working with downloaded data + +Downloaded data is stored in `images_folder` (configured in config file), +partitioned by `server_name` and `partition_id`, in two parquet files with following schemes: + +- `successes.parquet`: + - uuid: string - downloaded dataset internal unique identifier (created to distinguish between all component datasets downloaded with this package) + - source_id: string - id of the entry provided by its source (e.g., `gbifID`) + - identifier: string - source URL of the image + - is_license_full: boolean - True indicates that `license`, `source`, and `title` all have non-null values for that + particular entry. + - license: string + - source: string + - title: string + - hashsum_original: string - MD5 hash of the original image data + - hashsum_resized: string - MD5 hash of the resized image data + - original_size: [height, width] - dimensions of original image + - resized_size: [height, width] - dimensions after resizing + - image: bytes - binary image data + +- `errors.parquet`: + - uuid: string - downloaded dataset internal unique identifier (created to distinguish between all component datasets downloaded with this package) + - identifier: string - URL of the image + - retry_count: integer - number of download attempts + - error_code: integer - HTTP or other error code + - error_msg: string - detailed error message + +For general operations (that do not involve access to `image` column, e.g. count the total number of entries, create +size distribution etc.) 
it is recommended to use Spark or similar applications. For any operation that does involve +`image` column, it is recommended to use Pandas or similar library to access each parquet file separately. + +## Supported Image Formats + +The downloader supports most common image formats, including: + +- JPEG/JPG +- PNG +- GIF (first frame only) +- BMP +- TIFF + +## Error Handling and Troubleshooting + +Common issues and solutions: + +1. **Rate limiting errors**: If you see many errors with code 429, adjust the `default_rate_limit` in your config to a + lower value. + +2. **Memory issues**: If the process is killed due to memory constraints, try reducing `batch_size` or + `workers_per_node` in your config. + +3. **Corrupt images**: Images that cannot be decoded are logged in the errors parquet file with appropriate error codes. + +4. **Resuming failed downloads**: The downloader creates checkpoints automatically. Simply run the same command again to + resume from the last checkpoint. + +## License + +This project is licensed under the MIT License - see the LICENSE file for details. diff --git a/environment.yaml b/environment.yaml index e9c9913..3fd9fd7 100644 --- a/environment.yaml +++ b/environment.yaml @@ -3,51 +3,35 @@ channels: - conda-forge - defaults dependencies: - - openmpi - - python - - uv - - opencv - - pyspark + - python>=3.10 <=3.12 - attrs - brotli - - certifi - - charset-normalizer - cramjam - cython - - exceptiongroup - fsspec - - hatchling - - idna - inflate64 - - iniconfig + - openmpi - mpi4py - multivolumefile - - numpy - - packaging + - opencv - pandas - pathspec - pillow - - pip - - pluggy - psutil - - py4j - pyarrow - pybcj - pycryptodomex - pyppmd - - pytest - - python-dateutil + - pyspark>=3.4.0 - python-dotenv - - pytz - pyyaml - pyzstd - requests - setuptools - - six - texttable - - tomli - trove-classifiers - typing-extensions - - tzdata - - urllib3 - wheel + # Development dependencies + - pytest + - ruff \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index b56cad2..4b750fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["hatchling", "hatch-requirements-txt"] +requires = ["hatchling"] build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] @@ -7,7 +7,7 @@ packages = ["src/distributed_downloader"] [project] name = "distributed_downloader" -dynamic = ["dependencies", "version"] +dynamic = ["version"] authors = [ { name = "Andrey Kopanev", email = "kopanev.1@osu.edu" }, { name = "Elizabeth G. Campolongo", email = "e.campolongo479@gmail.com" }, @@ -15,18 +15,47 @@ authors = [ ] description = "A tool for downloading files from a list of URLs in parallel." 
readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.10, <=3.12" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ] - -[tool.hatch.metadata.hooks.requirements_txt] -files = ["requirements.txt"] +dependencies = [ + "attrs", + "brotli", + "cramjam", + "cython", + "fsspec", + "inflate64", + "mpi4py", + "multivolumefile", + "opencv-python", + "pandas", + "pathspec", + "pillow", + "psutil", + "pyarrow", + "pybcj", + "pycryptodomex", + "pyppmd", + "pyspark", + "python-dotenv", + "pyyaml", + "pyzstd", + "requests", + "setuptools", + "texttable", + "trove-classifiers", + "typing-extensions", + "wheel" +] [project.optional-dependencies] -dev = ["pytest"] +dev = [ + "pytest", + "ruff" +] keywords = [ "parallel", diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 5c32703..0000000 --- a/requirements.txt +++ /dev/null @@ -1,85 +0,0 @@ -# This file was autogenerated by uv via the following command: -# uv pip compile - -o requirements.txt -attrs==24.1.0 -brotli==1.1.0 -certifi==2024.7.4 - # via - # requests -charset-normalizer==3.3.2 - # via - # requests -cramjam==2.8.3 -cython==3.0.11 -exceptiongroup==1.2.2 - # via pytest -fsspec==2024.6.1 -hatchling==1.25.0 -idna==3.7 - # via - # requests -inflate64==1.0.0 -iniconfig==2.0.0 - # via pytest -mpi4py==4.0.0 -multivolumefile==0.2.3 -numpy==2.0.1 - # via - # opencv-python - # pandas - # pyarrow -opencv-python==4.10.0.84 -packaging==24.1 - # via - # hatchling - # pytest -pandas==2.2.2 -pathspec==0.12.1 - # via - # hatchling -pillow==10.4.0 -pip==24.2 -pluggy==1.5.0 - # via - # hatchling - # pytest -psutil==6.0.0 -py4j==0.10.9.7 - # via - # pyspark -pyarrow==17.0.0 -pybcj==1.0.2 -pycryptodomex==3.20.0 -pyppmd==1.1.0 -pyspark==3.5.1 -pytest==8.3.2 -python-dateutil==2.9.0.post0 - # via - # pandas -python-dotenv==1.0.1 -pytz==2024.1 - # via - # pandas -pyyaml==6.0.1 -pyzstd==0.16.1 -requests==2.32.3 -setuptools==72.1.0 -six==1.16.0 - # via - # python-dateutil -texttable==1.7.0 -tomli==2.0.1 - # via - # hatchling - # pytest -trove-classifiers==2024.7.2 - # via - # hatchling -typing-extensions==4.12.2 -tzdata==2024.1 - # via - # pandas -urllib3==2.2.2 - # via - # requests -wheel==0.44.0 -hatch-requirements-txt==0.4.1 diff --git a/scripts/tools_filter.slurm b/scripts/tools_filter.slurm index a3be2c1..4642e34 100644 --- a/scripts/tools_filter.slurm +++ b/scripts/tools_filter.slurm @@ -1,6 +1,7 @@ #!/bin/bash #SBATCH --job-name tool_filter #SBATCH --mem=0 +#SBATCH --time=01:00:00 if [ "$#" -eq 0 ]; then echo "Usage: $0 tool_name" diff --git a/src/distributed_downloader/core/MPI_download_prep.py b/src/distributed_downloader/core/MPI_download_prep.py index a05cc56..0c8c200 100644 --- a/src/distributed_downloader/core/MPI_download_prep.py +++ b/src/distributed_downloader/core/MPI_download_prep.py @@ -1,3 +1,15 @@ +""" +Preparation module for MPI-based distributed image downloads. + +This module handles the creation of download schedules and submission of download jobs +based on configuration parameters. It coordinates the distribution of download tasks +across multiple nodes and handles dependency chains between jobs. + +The main workflow: +1. Create schedules based on server profiles and available resources +2. Submit downloaders with appropriate dependencies +3. 
Submit verifiers to check the completion of download jobs +""" import os from logging import Logger from typing import Dict, List, Tuple @@ -5,13 +17,30 @@ import pandas as pd from pandas._libs.missing import NAType +from distributed_downloader.core.utils import ( + create_schedule_configs, + verify_batches_for_prep, +) from distributed_downloader.tools.checkpoint import Checkpoint from distributed_downloader.tools.config import Config -from distributed_downloader.tools.utils import submit_job, init_logger, preprocess_dep_ids -from distributed_downloader.core.utils import create_schedule_configs, verify_batches_for_prep +from distributed_downloader.tools.utils import ( + init_logger, + preprocess_dep_ids, + submit_job, +) def schedule_rule(total_batches: int, rule: List[Tuple[int, int]]) -> int | NAType: + """ + Determine the number of nodes to allocate based on batch count and rules. + + Args: + total_batches: The total number of batches to process + rule: List of tuples (min_batches, nodes) for resource allocation rules + + Returns: + int: Number of nodes to allocate or pd.NA if no rule matches + """ for min_batches, nodes in rule: if total_batches >= min_batches: return nodes @@ -19,6 +48,14 @@ def schedule_rule(total_batches: int, rule: List[Tuple[int, int]]) -> int | NATy def init_new_current_folder(old_folder: str) -> None: + """ + Initialize a new 'current' folder for schedules or logs. + + If a 'current' folder exists, rename it with a sequential number and create a new empty one. + + Args: + old_folder: Path to the parent folder containing the 'current' directory + """ if os.path.exists(f"{old_folder}/current"): number_of_folders = len( [folder for folder in os.listdir(old_folder) if os.path.isdir(f"{old_folder}/{folder}")]) @@ -28,6 +65,15 @@ def init_new_current_folder(old_folder: str) -> None: def fix_rule(rule: Dict[str, int]) -> List[Tuple[int, int]]: + """ + Convert scheduling rules from config dict format to sorted list of tuples. + + Args: + rule: Dict mapping minimum batch count (as string) to number of nodes + + Returns: + List[Tuple[int, int]]: List of (min_batches, nodes) tuples, sorted by min_batches in descending order + """ fixed_rule = [] for key, value in rule.items(): fixed_rule.append((int(key), value)) @@ -40,6 +86,19 @@ def submit_downloader(_schedule: str, dep_id: int, mpi_submitter_script: str, downloading_script: str) -> int: + """ + Submit a download job through the MPI submitter script. + + Args: + _schedule: Path to the schedule directory + iteration_id: Iteration identifier for the job + dep_id: ID of the job this submission depends on, or None + mpi_submitter_script: Path to the MPI job submission script + downloading_script: Path to the downloading script to be executed + + Returns: + int: Job ID of the submitted job + """ iteration = str(iteration_id).zfill(4) idx = submit_job(mpi_submitter_script, @@ -56,6 +115,19 @@ def submit_verifier(_schedule: str, mpi_submitter_script: str, verifying_script: str, dep_id: int = None) -> int: + """ + Submit a verification job through the MPI submitter script. 
+ + Args: + _schedule: Path to the schedule directory + iteration_id: Iteration identifier for the job + mpi_submitter_script: Path to the MPI job submission script + verifying_script: Path to the verifying script to be executed + dep_id: ID of the job this submission depends on, or None + + Returns: + int: Job ID of the submitted job + """ iteration = str(iteration_id).zfill(4) idx = submit_job(mpi_submitter_script, @@ -68,6 +140,17 @@ def submit_verifier(_schedule: str, def create_schedules(config: Config, logger: Logger) -> None: + """ + Create download schedules based on server profiles and available resources. + + This function analyzes server profiles, determines how many resources to allocate + to each server based on scheduling rules, and creates schedule configurations + for the downloader jobs. + + Args: + config: Configuration object with download parameters + logger: Logger instance for output + """ logger.info("Creating schedules") # Get parameters from config server_ignored_csv: str = config.get_folder("ignored_table") @@ -116,6 +199,18 @@ def create_schedules(config: Config, logger: Logger) -> None: def submit_downloaders(config: Config, logger: Logger) -> None: + """ + Submit download and verification jobs for all schedules. + + For each schedule: + 1. Submit multiple download jobs with appropriate dependencies + 2. Submit a verification job dependent on the last download job + 3. Record job IDs for future reference + + Args: + config: Configuration object with job submission parameters + logger: Logger instance for output + """ logger.info("Submitting downloaders") # Get parameters from config schedules_path: str = os.path.join(config.get_folder("schedules_folder"), @@ -164,6 +259,14 @@ def submit_downloaders(config: Config, logger: Logger) -> None: def main(): + """ + Main entry point that coordinates the schedule creation and job submission process. + + 1. Loads configuration from environment variables + 2. Creates downloading schedules based on server profiles + 3. Submits downloading and verification jobs + 4. Updates the checkpoint to indicate completion + """ config_path = os.environ.get("CONFIG_PATH") if config_path is None: raise ValueError("CONFIG_PATH not set") diff --git a/src/distributed_downloader/core/MPI_downloader_verifier.py b/src/distributed_downloader/core/MPI_downloader_verifier.py index ff5f98a..d905b2a 100644 --- a/src/distributed_downloader/core/MPI_downloader_verifier.py +++ b/src/distributed_downloader/core/MPI_downloader_verifier.py @@ -13,6 +13,24 @@ def verify_batches(config: Config, server_schedule: str, logger: Logger) -> None: + """ + Verifies download completion status for batches in a schedule. + + This function: + 1. Loads the schedule configuration and existing verification data + 2. Checks each batch's status on disk (completed/failed) + 3. Updates the verification file with current status + 4. Creates _UNCHANGED flag if verification is stable + 5. Creates _DONE flag when all batches are completed + + Args: + config: Configuration object with paths to relevant folders + server_schedule: Path to the schedule directory to verify + logger: Logger instance for output messages + + Raises: + ValueError: If the schedule config file is not found + """ logger.info(f"Verifying batches for {server_schedule}") server_urls_downloaded = config.get_folder("images_folder") @@ -88,6 +106,18 @@ def verify_batches(config: Config, def main(): + """ + Main entry point that loads configuration and triggers batch verification. + + This function: + 1. 
Reads the configuration from the environment + 2. Parses command line arguments to get the schedule path + 3. Initializes a logger + 4. Calls verify_batches() to check download completion status + + Raises: + ValueError: If the CONFIG_PATH environment variable is not set + """ config_path = os.environ.get("CONFIG_PATH") if config_path is None: raise ValueError("CONFIG_PATH not set") diff --git a/src/distributed_downloader/core/MPI_multimedia_downloader.py b/src/distributed_downloader/core/MPI_multimedia_downloader.py index 5c11c13..39739e9 100644 --- a/src/distributed_downloader/core/MPI_multimedia_downloader.py +++ b/src/distributed_downloader/core/MPI_multimedia_downloader.py @@ -10,8 +10,11 @@ from distributed_downloader.core.mpi_downloader.dataclasses import CompletedBatch from distributed_downloader.core.mpi_downloader.Downloader import Downloader from distributed_downloader.core.mpi_downloader.PreLoader import load_one_batch -from distributed_downloader.core.mpi_downloader.utils import get_latest_schedule, \ - get_or_init_downloader, is_enough_time +from distributed_downloader.core.mpi_downloader.utils import ( + get_latest_schedule, + get_or_init_downloader, + is_enough_time, +) from distributed_downloader.tools.config import Config from distributed_downloader.tools.utils import init_logger @@ -21,6 +24,22 @@ def download_batch( _input_path: str, _batch_id: int, ) -> Tuple[CompletedBatch, float]: + """ + Downloads a single batch of images using the provided downloader. + + This function loads URL data from a parquet file at the input path, + passes it to the downloader, and returns both the completed batch + and the final download rate achieved. + + Args: + _downloader: Instance of Downloader class to use for downloading + _input_path: Path to the directory containing batch data + _batch_id: Identifier for the batch being downloaded + + Returns: + Tuple[CompletedBatch, float]: Completed batch object containing downloaded items and + the final download rate achieved + """ batch = load_one_batch(_input_path) _completed_batch, _finish_rate = _downloader.get_images(batch) @@ -33,6 +52,24 @@ def download_schedule( server_schedule: str, logger: logging.Logger, ): + """ + Main download function that processes a schedule using MPI parallelism. + + This function coordinates multiple MPI processes to download images in parallel: + 1. Each MPI rank loads its assigned part of the schedule + 2. For each server assigned, it initializes or reuses a downloader with appropriate rate limits + 3. Downloads assigned batches using exclusive MPI locks to prevent server overloading + 4. Writes downloaded batch data to storage + 5. Tracks performance metrics (download time, write time) + + The function handles synchronization between MPI ranks using window locking + to ensure server rate limits are respected globally. 
+ + Args: + config: Configuration object with download parameters + server_schedule: Path to the directory containing schedule files + logger: Logger instance for output + """ header_str = config["downloader_parameters"]["header"] header = {header_str.split(": ")[0]: header_str.split(": ")[1]} img_size = config["downloader_parameters"]["image_size"] @@ -44,6 +81,7 @@ def download_schedule( logger.info(f"Schedule {server_schedule} already done") return + # Initialize MPI communication comm = MPI.COMM_WORLD rank = comm.rank mem = MPI.Alloc_mem(1) @@ -60,6 +98,7 @@ def download_schedule( latest_schedule = latest_schedule.to_dict("records") job_end_time: int = int(os.getenv("SLURM_JOB_END_TIME", 0)) + # Dictionary for reusing downloaders across batches for the same server downloader_schedule: Dict[str, Tuple] = {} downloading_time = 0 @@ -68,6 +107,7 @@ def download_schedule( logger.info(f"Rank {rank} started downloading") for schedule_dict in latest_schedule: + # Get or initialize a downloader for this server downloader, _, rate_limit = get_or_init_downloader(header, img_size, schedule_dict, @@ -77,6 +117,7 @@ def download_schedule( logger) for batch_id in range(schedule_dict["partition_id_from"], schedule_dict["partition_id_to"]): + # Lock to ensure exclusive access when downloading from a server window.Lock(schedule_dict["main_rank"], MPI.LOCK_EXCLUSIVE) try: if not is_enough_time(rate_limit, job_end_time=job_end_time): @@ -110,12 +151,24 @@ def download_schedule( except Exception as e: logger.error(f"Rank {rank} failed with error: {e}") finally: - # comm.Barrier() + # Clean up MPI resources window.Free() mem.release() def main(): + """ + Main entry point that parses arguments and initiates the download process. + + This function: + 1. Loads configuration from the environment + 2. Initializes logging + 3. Parses the schedule path from command line arguments + 4. Calls the download_schedule function to begin downloading + + Raises: + ValueError: If the CONFIG_PATH environment variable is not set + """ config_path = os.environ.get("CONFIG_PATH") if config_path is None: raise ValueError("CONFIG_PATH not set") diff --git a/src/distributed_downloader/core/MPI_multimedia_downloader_controller.py b/src/distributed_downloader/core/MPI_multimedia_downloader_controller.py index 3979586..f3a842c 100644 --- a/src/distributed_downloader/core/MPI_multimedia_downloader_controller.py +++ b/src/distributed_downloader/core/MPI_multimedia_downloader_controller.py @@ -2,13 +2,17 @@ import os from collections import deque from logging import Logger -from typing import Any, Dict, List, Deque +from typing import Any, Deque, Dict, List import pandas as pd -from distributed_downloader.core.mpi_downloader.utils import get_latest_schedule, generate_ids_to_download, \ - separate_to_blocks, \ - get_largest_nonempty_bucket, get_schedule_count +from distributed_downloader.core.mpi_downloader.utils import ( + generate_ids_to_download, + get_largest_nonempty_bucket, + get_latest_schedule, + get_schedule_count, + separate_to_blocks, +) from distributed_downloader.tools.config import Config from distributed_downloader.tools.utils import init_logger @@ -16,6 +20,25 @@ def create_new_schedule(config: Config, server_schedule: str, logger: Logger) -> None: + """ + Creates a new download schedule based on verification results. + + This function orchestrates the download scheduling process by: + 1. Loading server configuration and previous verification results + 2. Identifying which batches still need to be downloaded + 3. 
Organizing batches into optimal server-specific processing blocks + 4. Distributing the workload across available workers + 5. Creating a new schedule file for the downloader to use + + The scheduling algorithm prioritizes servers with higher process_per_node + values to maximize throughput and uses a bucket-based allocation to efficiently + distribute work across available worker processes. + + Args: + config: Configuration object with download parameters + server_schedule: Path to the schedule directory + logger: Logger instance for output messages + """ logger.info(f"Creating new schedule for {server_schedule}") number_of_workers: int = (config["downloader_parameters"]["max_nodes"] @@ -30,6 +53,7 @@ def create_new_schedule(config: Config, logger.info(f"Schedule {server_schedule} already done") return + # Set up batch ranges for each server server_config_df["start_index"] = 0 server_config_df["end_index"] = 0 server_config_columns = server_config_df.columns.to_list() @@ -41,6 +65,7 @@ def create_new_schedule(config: Config, server_config_df["end_index"] = server_config_df["total_batches"] - 1 server_config_df = server_config_df[server_config_columns] + # Incorporate data from latest schedule if it exists latest_schedule = get_latest_schedule(server_schedule) if latest_schedule is not None and len(latest_schedule) > 0: latest_schedule_aggr = latest_schedule.groupby("server_name").agg( @@ -50,6 +75,7 @@ def create_new_schedule(config: Config, server_config_df["start_index"] = server_config_df["partition_id_from"].astype(int) server_config_df = server_config_df[server_config_columns] + # Find which batches still need downloading by comparing with verification results batches_to_download: pd.DataFrame = server_config_df.apply(generate_ids_to_download, axis=1, args=(server_verifier_df,)) batches_to_download = batches_to_download.merge(server_config_df, on="server_name", how="left").drop( @@ -58,6 +84,7 @@ def create_new_schedule(config: Config, batches_to_download.sort_values(by=["process_per_node", "nodes"], inplace=True, ascending=False) + # Create buckets based on process_per_node for efficient worker allocation ids_to_schedule_in_buckets: Dict[int, Deque[Dict[str, Any]]] = {} process_per_nodes = batches_to_download["process_per_node"].unique() for process_per_node in process_per_nodes: @@ -67,6 +94,7 @@ def create_new_schedule(config: Config, logger.info("Filtered out already downloaded batches, creating schedule...") logger.debug(ids_to_schedule_in_buckets) + # Generate schedule by assigning batches to workers schedule_list: List[Dict[str, Any]] = [] worker_id = 0 @@ -81,12 +109,14 @@ def create_new_schedule(config: Config, worker_id = 0 continue + # Pop a server from the highest priority bucket current_server = ids_to_schedule_in_buckets[largest_key].popleft() current_server["nodes"] -= 1 server_rate_limit = server_profiler_df[server_profiler_df["server_name"] == current_server["server_name"]][ "rate_limit"].array[0] if len(current_server["batches"]) > 0: + # Schedule batches for this server across multiple workers batches_to_schedule = [current_server["batches"].pop(0) for _ in range(current_server["process_per_node"])] main_worker_id = worker_id for batches in batches_to_schedule: @@ -101,12 +131,15 @@ def create_new_schedule(config: Config, }) worker_id += 1 + # Return server to bucket if it still has nodes to allocate if current_server["nodes"] > 0: ids_to_schedule_in_buckets[largest_key].append(current_server) + # Remove empty buckets if len(ids_to_schedule_in_buckets[largest_key]) 
== 0: del ids_to_schedule_in_buckets[largest_key] + # Write the new schedule to disk schedule_number = get_schedule_count(server_schedule) pd.DataFrame(schedule_list).to_csv(f"{server_schedule}/{schedule_number:0=4}.csv", index=False, header=True) @@ -114,6 +147,18 @@ def create_new_schedule(config: Config, def main(): + """ + Main entry point that loads configuration and triggers schedule creation. + + This function: + 1. Reads the configuration from the environment + 2. Parses command line arguments to get the schedule path + 3. Initializes a logger + 4. Calls create_new_schedule() to generate the next download schedule + + Raises: + ValueError: If the CONFIG_PATH environment variable is not set + """ config_path = os.environ.get("CONFIG_PATH") if config_path is None: raise ValueError("CONFIG_PATH not set") diff --git a/src/distributed_downloader/core/fake_profiler.py b/src/distributed_downloader/core/fake_profiler.py index e299ded..3c3bd4a 100644 --- a/src/distributed_downloader/core/fake_profiler.py +++ b/src/distributed_downloader/core/fake_profiler.py @@ -7,6 +7,37 @@ def main(): + """ + Generates a profiling table for servers with default rate limits. + + Background: + ----------- + The original concept was to dynamically profile server performance by: + - Running test downloads to measure actual download speeds + - Identifying response times and download capacity before throttling occurs + - Filtering out non-responsive or problematic servers + + Current Implementation: + ---------------------- + This simpler approach was adopted because comprehensive profiling: + - Is time-consuming to execute + - Isn't compatible with the current downloader architecture (lacks dynamic + allocation of new downloaders when bandwidth permits) + + Current behavior: + 1. Counts available download partitions for each server + 2. Assigns a default rate limit (from config) to all servers + 3. Creates a profiles table CSV with this information + + Future Direction: + ---------------- + 1. Move partition counting to initialization phase + 2. Implement controller-worker downloader architecture to: + - Dynamically adjust speeds based on real-time performance + - Vary the number of concurrent downloaders per server + - Self-regulate without needing a separate profiling step + 3. 
Allow initial speed/concurrency suggestions to optimize early performance + """ config_path = os.environ.get("CONFIG_PATH") if config_path is None: raise ValueError("CONFIG_PATH not set") diff --git a/src/distributed_downloader/core/initialization.py b/src/distributed_downloader/core/initialization.py index 8e7de9b..8b12e7b 100644 --- a/src/distributed_downloader/core/initialization.py +++ b/src/distributed_downloader/core/initialization.py @@ -1,30 +1,32 @@ import os.path -import uuid -from typing import Dict -from urllib.parse import urlparse - -import pyspark.sql.functions as func -from pyspark.sql import SparkSession, Window -from pyspark.sql.functions import udf -from pyspark.sql.types import StringType - -from distributed_downloader.core.schemes import multimedia_scheme -from distributed_downloader.tools.config import Config -from distributed_downloader.tools.utils import load_dataframe, truncate_paths, init_logger - - -@udf(returnType=StringType()) -def get_server_name(url: str): - return urlparse(url).netloc - - -@udf(returnType=StringType()) -def get_uuid(): - return str(uuid.uuid4()) +from typing import Dict, Type + +from distributed_downloader.core.initializers.base_initializer import BaseInitializer +from distributed_downloader.core.initializers.eol_initializer import EoLInitializer +from distributed_downloader.core.initializers.fathom_net_initializer import ( + FathomNetInitializer, +) +from distributed_downloader.core.initializers.gbif_initializer import GBIFInitializer +from distributed_downloader.core.initializers.lila_initializer import LilaInitializer +from distributed_downloader.tools import Config +from distributed_downloader.tools.utils import ( + truncate_paths, +) + +__initializers: Dict[str, Type[BaseInitializer]] = { + "gbif": GBIFInitializer, + "fathom_net": FathomNetInitializer, + "lila": LilaInitializer, + "eol": EoLInitializer, +} def init_filestructure(file_structure: Dict[str, str]) -> None: - filtered_fs = [value for key, value in file_structure.items() if key not in ["inner_checkpoint_file", "ignored_table"]] + filtered_fs = [ + value + for key, value in file_structure.items() + if key not in ["inner_checkpoint_file", "ignored_table"] + ] truncate_paths(filtered_fs) @@ -32,70 +34,10 @@ def init_filestructure(file_structure: Dict[str, str]) -> None: config_path = os.environ.get("CONFIG_PATH") if config_path is None: raise ValueError("CONFIG_PATH not set") - config = Config.from_path(config_path, "downloader") + assert ( + config["initializer_type"] in __initializers.keys() + ), "Unknown initialization type, aborting" - # Initialize parameters - input_path = config["path_to_input"] - # init_filestructure(config) - output_path = config.get_folder("urls_folder") - logger = init_logger(__name__) - - # Initialize SparkSession - spark = SparkSession.builder.appName("Multimedia prep").getOrCreate() - spark.conf.set("spark.sql.parquet.datetimeRebaseModeInWrite", "CORRECTED") - spark.conf.set("spark.sql.parquet.int96RebaseModeInWrite", "CORRECTED") - - multimedia_df = load_dataframe(spark, input_path, multimedia_scheme.schema) - - multimedia_df_prep = (multimedia_df - .filter((multimedia_df["gbifID"].isNotNull()) - & (multimedia_df["identifier"].isNotNull()) - & ( - (multimedia_df["type"] == "StillImage") - | ( - (multimedia_df["type"].isNull()) - & (multimedia_df["format"].contains("image")) - ) - )) - .repartition(20)) - - multimedia_df_prep = multimedia_df_prep.withColumn("server_name", - get_server_name(multimedia_df_prep.identifier)) - multimedia_df_prep = 
multimedia_df_prep.withColumn("UUID", get_uuid()) - - columns = multimedia_df_prep.columns - - logger.info("Starting batching") - - servers_grouped = (multimedia_df_prep - .select("server_name") - .groupBy("server_name") - .count() - .withColumn("batch_count", - func.floor(func.col("count") / config["downloader_parameters"]["batch_size"]))) - - window_part = Window.partitionBy("server_name").orderBy("server_name") - master_df_filtered = (multimedia_df_prep - .withColumn("row_number", func.row_number().over(window_part)) - .join(servers_grouped, ["server_name"]) - .withColumn("partition_id", func.col("row_number") % func.col("batch_count")) - .withColumn("partition_id", - (func - .when(func.col("partition_id").isNull(), 0) - .otherwise(func.col("partition_id")))) - .select(*columns, "partition_id")) - - logger.info("Writing to parquet") - - (master_df_filtered - .repartition("server_name", "partition_id") - .write - .partitionBy("server_name", "partition_id") - .mode("overwrite") - .format("parquet") - .save(output_path)) - - logger.info("Finished batching") - - spark.stop() + initializer = __initializers[config["initializer_type"]](config) + initializer.run() diff --git a/src/distributed_downloader/core/initializers/README.md b/src/distributed_downloader/core/initializers/README.md new file mode 100644 index 0000000..41f7553 --- /dev/null +++ b/src/distributed_downloader/core/initializers/README.md @@ -0,0 +1,80 @@ +# Initializers + +Initializers are used to process the target dataset into a format that can be used by the downloader. +Spark is used to process the data in parallel. +The following steps are performed by the initializers: + +1. Read the dataset from the source. +2. Filter the dataset to remove unwanted data. +3. Rename the columns to match the downloader's schema. Specifically: + - `identifier` is the url to the image. + - `source_id` some identification of the image, that can later be used to access relevant information from the + source. + - `license` (optional) the license of the image (needs to be an url to the license). + - `source` (optional) source of the image, for licensing purposes. + - `title` (optional) title of the image, for licensing purposes. +4. Extract server name from the `identifier` and generate `uuid` (unique internal identifier) for each image. +5. Partition the dataset first by the server name and then into smaller partitions with a fixed size (can be configured, + default is 10,000) +6. Save the partitioned dataset to the target location in `parquet` format. + +The initializers are run only once for each dataset, and the resulting partitioned dataset is used by the downloader to +download the images. + +## Structure + +### Base Initializer + +The `BaseInitializer` class is the base class for all initializers. It contains all the common functionality used by all +initializers. +It has the following methods: + +- `load_raw_df`: Reads the dataset from the source. +- `extract_server_name`: Extracts the server name from the `identifier`. +- `generate_uuid`: Generates a unique identifier for each image. +- `partition_dataframe`: Partitions the dataset first by the server name and then into smaller partitions with a fixed + size. +- `save_results`: Saves the partitioned dataset to the target location in `parquet` format. + +### Initializers + +The initializers are classes that inherit from the `BaseInitializer` class. They implement the specific logic for each +dataset. 
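Which initializer runs is selected by the `initializer_type` key in the configuration file (for example, `initializer_type: "gbif"`); `initialization.py` looks this value up in its `__initializers` mapping and aborts if the name is not registered.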
+The following initializers are available: + +- `GBIFInitializer`: Initializer for the GBIF dataset. It filters out any entries without an `gbifID` or `identifier` + value. + Additionally, removes any entries that are not `StillImage` by type or `image` by format. + And lastly, it removes any entries that have `MATERIAL_CITATION` in `basisOfRecord`. This is because these are known to be images of text documents. +- `FathomNetInitializer`: Initializer for the FathomNet dataset. It filters out any entries without an `uuid` or `url` + value. + Additionally, removes any entries that are "not valid" by the `valid` column. +- `EoLInitializer`: Initializer for the EOL dataset. It filters out any entries without an `EOL content ID` or + `EOL Full-Size Copy URL` value. + - The `EOL content ID` is set as the `source_id`, which is used to map to the original metadata file to get the `EOL page ID` to match to the `taxon.tab` for taxa information. `EOL content ID` is not a persistent identifier at EOL, so it is important to maintain the original metadata file. +- `LilaInitializer`: Initializer for the LILA dataset. It filters out any entries that do not have a `url` value (by + `url_gcp`, `url_aws` or `url_azure`). + Additionally, removes any entries that are `empty` by the `original_label` column. + +## Creating a new Initializer + +To create a new initializer, you need to create a new class that inherits from the `BaseInitializer` class. +You will need to implement only `run` method, which ties together all the steps described above. +In the most cases you will need to implement only custom logic for filtering and renaming columns, everything else is +already implemented in the `BaseInitializer` class, and should be called in the order, described above. + +### Important considerations before creating a new Initializer + +- You need to be familiar with the dataset you are working with, and understand the structure of the data before + starting to create a new initializer. +- `source_id` is highly recommended to be unique for each image, and be persistent across iterations of the dataset. + Otherwise, it may be challenging to map additional information to the images. Or it may be challenging to update the + dataset in the future. 
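### Example of a custom Initializer

A rough sketch of what a new initializer can look like (the `MyDatasetInitializer` class and its `photo_id`/`photo_url` source columns are hypothetical, not part of the package). A new initializer also has to be added to the `__initializers` mapping in `initialization.py` so that it can be selected via `initializer_type`:

```python
from distributed_downloader.core.initializers.base_initializer import BaseInitializer


class MyDatasetInitializer(BaseInitializer):
    """Hypothetical initializer for a dataset whose source columns are `photo_id` and `photo_url`."""

    def run(self):
        multimedia_df = self.load_raw_df()

        # Dataset-specific part: drop unusable rows and rename the source columns
        # (assumed names) to the downloader's schema (`source_id`, `identifier`).
        multimedia_df_prep = (
            multimedia_df.filter(
                (multimedia_df["photo_id"].isNotNull())
                & (multimedia_df["photo_url"].isNotNull())
            )
            .repartition(20)
            .withColumnsRenamed({"photo_id": "source_id", "photo_url": "identifier"})
        )

        # Common steps inherited from BaseInitializer, in the order described above.
        multimedia_df_prep = self.extract_server_name(multimedia_df_prep)
        multimedia_df_prep = self.generate_uuid(multimedia_df_prep)
        master_df_filtered = self.partition_dataframe(multimedia_df_prep)

        self.logger.info("Writing to parquet")
        self.save_results(master_df_filtered)
        self.logger.info("Finished batching")
```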
+ +### Required columns for proper `distributed-downloader` work + +Initializer has to create a dataset with the following columns: +- `uuid` - unique internal identifier that is highly recommended to be generated on this step for later consistency +- `identifier` - url to the image +- `source_id` - some identification of the image that can later be used to access relative information from the + source diff --git a/src/distributed_downloader/core/schemes/__init__.py b/src/distributed_downloader/core/initializers/__init__.py similarity index 100% rename from src/distributed_downloader/core/schemes/__init__.py rename to src/distributed_downloader/core/initializers/__init__.py diff --git a/src/distributed_downloader/core/initializers/base_initializer.py b/src/distributed_downloader/core/initializers/base_initializer.py new file mode 100644 index 0000000..dc59585 --- /dev/null +++ b/src/distributed_downloader/core/initializers/base_initializer.py @@ -0,0 +1,173 @@ +import uuid +from abc import ABC, abstractmethod +from urllib.parse import urlparse + +import pyspark.sql.functions as func +from pyspark.sql import DataFrame, SparkSession, Window +from pyspark.sql.functions import udf +from pyspark.sql.types import StringType + +from distributed_downloader.tools import Config, init_logger, load_dataframe + + +class BaseInitializer(ABC): + """ + Base class for all initializers. + This class is responsible for initializing the Spark session and loading the raw dataframe. + It also provides methods for saving the results and extracting the server name from the URL. + + It is an abstract class, therefore method run() must be implemented in the child class. + """ + def __init__(self, config: Config): + """ + Initializes the Spark session and loads the raw dataframe. + Args: + config: Config object containing the configuration parameters. + """ + self.config = config + + self.input_path = self.config["path_to_input"] + self.output_path = self.config.get_folder("urls_folder") + + self.logger = init_logger(__name__) + + self.spark = SparkSession.builder.appName("Multimedia prep").getOrCreate() + self.spark.conf.set("spark.sql.parquet.datetimeRebaseModeInWrite", "CORRECTED") + self.spark.conf.set("spark.sql.parquet.int96RebaseModeInWrite", "CORRECTED") + + def load_raw_df(self) -> DataFrame: + """ + Loads the raw dataframe from the input path (taken from the config file). + Returns: + DataFrame: Raw dataframe. + """ + return load_dataframe(self.spark, self.input_path) + + def save_results(self, resul_df: DataFrame) -> None: + """ + Saves the results to the output path (taken from the config file). + Args: + resul_df: DataFrame to be saved. + + Returns: + None + """ + ( + resul_df.repartition("server_name", "partition_id") + .write.partitionBy("server_name", "partition_id") + .mode("overwrite") + .format("parquet") + .save(self.output_path) + ) + + def extract_server_name(self, data_frame: DataFrame) -> DataFrame: + """ + Extracts the server name from the URL (`identifier` column) and adds it as a new column - + `server_name` to the dataframe. + Args: + data_frame: DataFrame to be processed. + + Returns: + DataFrame: DataFrame with the new column. + """ + return data_frame.withColumn( + "server_name", self.get_server_name(data_frame.identifier) + ) + + def generate_uuid(self, data_frame: DataFrame) -> DataFrame: + """ + Generates a UUID for each row in the dataframe and adds it as a new column - `uuid`. + Args: + data_frame: DataFrame to be processed. + + Returns: + DataFrame: DataFrame with the new column. 
+ """ + return data_frame.withColumn("uuid", self.get_uuid()) + + def partition_dataframe(self, data_frame: DataFrame) -> DataFrame: + """ + Partitions the dataframe into batches based on the `server_name` and batch size. + Args: + data_frame: DataFrame to be processed. + + Returns: + DataFrame: DataFrame with the new column. + """ + columns = data_frame.columns + + self.logger.info("Starting batching") + + servers_grouped = ( + data_frame.select("server_name") + .groupBy("server_name") + .count() + .withColumn( + "batch_count", + func.floor( + func.col("count") + / self.config["downloader_parameters"]["batch_size"] + ), + ) + ) + + window_part = Window.partitionBy("server_name").orderBy("server_name") + master_df_filtered = ( + data_frame.withColumn("row_number", func.row_number().over(window_part)) + .join(servers_grouped, ["server_name"]) + .withColumn( + "partition_id", func.col("row_number") % func.col("batch_count") + ) + .withColumn( + "partition_id", + ( + func.when(func.col("partition_id").isNull(), 0).otherwise( + func.col("partition_id") + ) + ), + ) + .select(*columns, "partition_id") + ) + + self.logger.info("Finished batching") + + return master_df_filtered + + @staticmethod + @udf(returnType=StringType()) + def get_server_name(url: str): + """ + PySpark UDF that extracts the server name from the URL. + Args: + url: URL to be processed. + + Returns: + str: Server name. + """ + return urlparse(url).netloc + + @staticmethod + @udf(returnType=StringType()) + def get_uuid(): + """ + PySpark UDF that generates a UUID. + Returns: + str: UUID. + """ + return str(uuid.uuid4()) + + def __del__(self): + """ + Destructor method that stops the Spark session when the object is deleted. + """ + self.spark.stop() + + @abstractmethod + def run(self): + """ + Abstract method that must be implemented in the child class. + Intended to be an entry point for the class. + Returns: + None + """ + pass diff --git a/src/distributed_downloader/core/initializers/eol_initializer.py b/src/distributed_downloader/core/initializers/eol_initializer.py new file mode 100644 index 0000000..82d04d9 --- /dev/null +++ b/src/distributed_downloader/core/initializers/eol_initializer.py @@ -0,0 +1,50 @@ +from distributed_downloader.core.initializers.base_initializer import BaseInitializer + + +class EoLInitializer(BaseInitializer): + """ + Initializer for the Encyclopedia of Life (EoL) dataset. + + This initializer processes the EoL dataset with the following steps: + 1. Loads the raw dataframe from the specified input path. + 2. Filters out entries that don't have an 'EOL content ID' or 'EOL Full-Size Copy URL'. + 3. Renames columns to match the downloader schema: + - 'EOL content ID' -> 'source_id' + - 'EOL Full-Size Copy URL' -> 'identifier' + - 'License Name' -> 'license' + - 'Copyright Owner' -> 'owner' + 4. Extracts server names from the identifiers + 5. Generates UUIDs for each entry + 6. Partitions the dataframe based on server names and batch size + 7. Saves the processed dataset to the specified output location + """ + + def run(self): + """ + Executes the initialization process for the Encyclopedia of Life dataset. + + This method performs the complete pipeline of loading, filtering, + processing, and saving the EoL data. 
+ """ + multimedia_df = self.load_raw_df() + + multimedia_df_prep = ( + multimedia_df + .filter((multimedia_df["EOL content ID"].isNotNull()) + & (multimedia_df["EOL Full-Size Copy URL"].isNotNull())) + .repartition(20) + .withColumnsRenamed( + { + "EOL content ID": "source_id", + "EOL Full-Size Copy URL": "identifier", + "License Name": "license", + "Copyright Owner": "owner" + }) + ) + + multimedia_df_prep = self.extract_server_name(multimedia_df_prep) + multimedia_df_prep = self.generate_uuid(multimedia_df_prep) + master_df_filtered = self.partition_dataframe(multimedia_df_prep) + self.logger.info("Writing to parquet") + self.save_results(master_df_filtered) + self.logger.info("Finished batching") diff --git a/src/distributed_downloader/core/initializers/fathom_net_initializer.py b/src/distributed_downloader/core/initializers/fathom_net_initializer.py new file mode 100644 index 0000000..b945a40 --- /dev/null +++ b/src/distributed_downloader/core/initializers/fathom_net_initializer.py @@ -0,0 +1,55 @@ +import pyspark.sql.functions as func + +from distributed_downloader.core.initializers.base_initializer import BaseInitializer + + +class FathomNetInitializer(BaseInitializer): + """ + Initializer for the FathomNet dataset. + + This initializer processes the FathomNet dataset with the following steps: + 1. Loads the raw dataframe from the specified input path. + 2. Filters out entries that: + - Don't have a uuid or url + - Are not valid (based on the 'valid' column) + 3. Renames columns to match the downloader schema: + - 'uuid' -> 'source_id' + - 'url' -> 'identifier' + 4. Adds license information to each entry + 5. Extracts server names from the identifiers + 6. Generates UUIDs for each entry + 7. Partitions the dataframe based on server names and batch size + 8. Saves the processed dataset to the specified output location + """ + + def run(self): + """ + Executes the initialization process for the FathomNet dataset. + + This method performs the complete pipeline of loading, filtering, + processing, and saving the FathomNet data. + """ + multimedia_df = self.load_raw_df() + + multimedia_df_prep = ( + multimedia_df.filter( + (multimedia_df["uuid"].isNotNull()) + & (multimedia_df["url"].isNotNull()) + & (multimedia_df["valid"].cast("boolean")) + ) + .repartition(20) + .withColumnsRenamed({"uuid": "source_id", "url": "identifier"}) + .withColumn( + "license", + func.lit( + "https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode.en" + ), + ) + ) + + multimedia_df_prep = self.extract_server_name(multimedia_df_prep) + multimedia_df_prep = self.generate_uuid(multimedia_df_prep) + master_df_filtered = self.partition_dataframe(multimedia_df_prep) + self.logger.info(f"Writing to parquet: {self.output_path}") + self.save_results(master_df_filtered) + self.logger.info("Finished batching") diff --git a/src/distributed_downloader/core/initializers/gbif_initializer.py b/src/distributed_downloader/core/initializers/gbif_initializer.py new file mode 100644 index 0000000..d9d7c34 --- /dev/null +++ b/src/distributed_downloader/core/initializers/gbif_initializer.py @@ -0,0 +1,48 @@ +from distributed_downloader.core.initializers.base_initializer import BaseInitializer + + +class GBIFInitializer(BaseInitializer): + """ + Initializer for the Global Biodiversity Information Facility (GBIF) dataset. + + This initializer processes the GBIF dataset with the following steps: + 1. Loads the raw dataframe from the specified input path. + 2. 
Filters out entries that: + - Don't have a gbifID or identifier + - Are not of type StillImage or don't contain 'image' in the format field + - Have 'MATERIAL_CITATION' in the basisOfRecord field + 3. Extracts server names from the identifiers + 4. Generates UUIDs for each entry + 5. Partitions the dataframe based on server names and batch size + 6. Saves the processed dataset to the specified output location + """ + + def run(self): + """ + Executes the initialization process for the GBIF dataset. + + This method performs the complete pipeline of loading, filtering, + processing, and saving the GBIF data. + """ + multimedia_df = self.load_raw_df() + + multimedia_df_prep = multimedia_df.filter( + (multimedia_df["gbifID"].isNotNull()) + & (multimedia_df["identifier"].isNotNull()) + & ( + (multimedia_df["type"] == "StillImage") + | ( + (multimedia_df["type"].isNull()) + & (multimedia_df["format"].contains("image")) + ) + ) + & ~(multimedia_df["basisOfRecord"].contains("MATERIAL_CITATION")) + ).repartition(20) + + multimedia_df_prep = self.extract_server_name(multimedia_df_prep) + multimedia_df_prep = self.generate_uuid(multimedia_df_prep) + master_df_filtered = self.partition_dataframe(multimedia_df_prep) + + self.logger.info("Writing to parquet") + self.save_results(master_df_filtered) + self.logger.info("Finished batching") diff --git a/src/distributed_downloader/core/initializers/lila_initializer.py b/src/distributed_downloader/core/initializers/lila_initializer.py new file mode 100644 index 0000000..ea3d2e2 --- /dev/null +++ b/src/distributed_downloader/core/initializers/lila_initializer.py @@ -0,0 +1,61 @@ +import pyspark.sql.functions as func + +from distributed_downloader.core.initializers.base_initializer import BaseInitializer + + +class LilaInitializer(BaseInitializer): + """ + Initializer for the Labeled Information Library of Alexandria (Lila) dataset. + + This initializer processes the Lila dataset with the following steps: + 1. Loads the raw dataframe from the specified input path. + 2. Filters out entries that: + - Don't have a URL value (checking url_gcp, url_aws, and url_azure columns) + - Have "empty" as the original_label + 3. Creates an 'identifier' column from available URL columns with the following priority: + url_gcp -> url_aws -> url_azure + 4. Renames 'image_id' to 'source_id' + 5. Extracts server names from the identifiers + 6. Generates UUIDs for each entry + 7. Partitions the dataframe based on server names and batch size + 8. Saves the processed dataset to the specified output location + """ + + def run(self): + """ + Executes the initialization process for the Lila dataset. + + This method performs the complete pipeline of loading, filtering, + processing, and saving the Lila data. 
+ """ + multimedia_df = self.load_raw_df() + + multimedia_df_prep = ( + multimedia_df.filter( + ( + multimedia_df["url_gcp"].isNotNull() + | multimedia_df["url_aws"].isNotNull() + | multimedia_df["url_azure"].isNotNull() + ) + & (multimedia_df["original_label"] != "empty") + ) + .repartition(20) + .withColumn( + "identifier", + func.when( + multimedia_df["url_gcp"].isNotNull(), multimedia_df["url_gcp"] + ).otherwise( + func.when( + multimedia_df["url_aws"].isNotNull(), multimedia_df["url_aws"] + ).otherwise(multimedia_df["url_azure"]) + ), + ) + .withColumnsRenamed({"image_id": "source_id"}) + ) + + multimedia_df_prep = self.extract_server_name(multimedia_df_prep) + multimedia_df_prep = self.generate_uuid(multimedia_df_prep) + master_df_filtered = self.partition_dataframe(multimedia_df_prep) + self.logger.info("Writing to parquet") + self.save_results(master_df_filtered) + self.logger.info("Finished batching") diff --git a/src/distributed_downloader/core/main.py b/src/distributed_downloader/core/main.py index 4f48be5..3822969 100644 --- a/src/distributed_downloader/core/main.py +++ b/src/distributed_downloader/core/main.py @@ -168,14 +168,16 @@ def main() -> None: config_path = _args.config_path state_override = None - if _args.reset_filtering: + if _args.reset_batched: state_override = { "batched": False, - "profiled": False + "profiled": False, + "schedule_creation_scheduled": False } - elif _args.reset_scheduling: + elif _args.reset_profiled: state_override = { - "profiled": False + "profiled": False, + "schedule_creation_scheduled": False } dd = DistributedDownloader.from_path(config_path, state_override) diff --git a/src/distributed_downloader/core/mpi_downloader/DirectWriter.py b/src/distributed_downloader/core/mpi_downloader/DirectWriter.py index b03beb6..84d87a6 100644 --- a/src/distributed_downloader/core/mpi_downloader/DirectWriter.py +++ b/src/distributed_downloader/core/mpi_downloader/DirectWriter.py @@ -1,11 +1,19 @@ +""" +Storage module for the distributed downloader. + +This module handles writing downloaded image data to persistent storage. +It processes both successful and failed downloads, storing them in +separate parquet files with appropriate metadata. +""" + import logging import os import time -from typing import List, Any +from typing import Any, List import pandas as pd -from .dataclasses import ErrorEntry, SuccessEntry, CompletedBatch +from .dataclasses import CompletedBatch, ErrorEntry, SuccessEntry def write_batch( @@ -14,6 +22,24 @@ def write_batch( job_end_time: int, logger: logging.Logger = logging.getLogger() ) -> None: + """ + Write a completed batch of downloads to storage. + + This function processes successful and failed downloads from a batch, + writing them to separate parquet files with appropriate metadata. + It also creates marker files indicating completion status. 
+ + Args: + completed_batch: CompletedBatch object containing successful and failed downloads + output_path: Directory path to write the batch data + job_end_time: UNIX timestamp when the job should end + logger: Logger instance for output messages + + Raises: + TimeoutError: If there is not enough time left to complete writing + ValueError: If the batch is empty (no successes or errors) + Exception: For other errors during the write process + """ logger.debug(f"Writing batch to {output_path}") os.makedirs(output_path, exist_ok=True) @@ -46,9 +72,11 @@ def write_batch( logger.info(f"Completed collecting entries for {output_path}") pd.DataFrame(successes_list, columns=SuccessEntry.get_names()).to_parquet(f"{output_path}/successes.parquet", - index=False) + index=False, compression="zstd", + compression_level=3) pd.DataFrame(errors_list, columns=ErrorEntry.get_names()).to_parquet(f"{output_path}/errors.parquet", - index=False) + index=False, compression="zstd", + compression_level=3) logger.info(f"Completed writing to {output_path}") diff --git a/src/distributed_downloader/core/mpi_downloader/Downloader.py b/src/distributed_downloader/core/mpi_downloader/Downloader.py index 6bc06ad..3833319 100644 --- a/src/distributed_downloader/core/mpi_downloader/Downloader.py +++ b/src/distributed_downloader/core/mpi_downloader/Downloader.py @@ -1,16 +1,29 @@ +""" +Core downloader module for retrieving and processing images. + +This module implements the Downloader class which is responsible for: +- Downloading images from URLs with rate limiting and error handling +- Processing and resizing downloaded images +- Managing parallel downloads with thread pooling +- Adapting download rates based on server responses + +The downloader uses concurrent threads with semaphores to control the +rate of requests to each server. +""" + import concurrent.futures import hashlib import logging import queue import threading import time -from typing import List, Dict, Any, Tuple +from typing import Any, Dict, List, Tuple import cv2 import numpy as np import requests -from .dataclasses import DownloadedImage, CompletedBatch, RateLimit +from .dataclasses import CompletedBatch, DownloadedImage, RateLimit _MAX_RETRIES = 5 _TIMEOUT = 5 @@ -19,6 +32,26 @@ class Downloader: + """ + Image downloader with concurrent processing, rate limiting, and error handling. + + This class manages the downloading and processing of images from URLs, controlling + request rates to avoid server throttling while maximizing throughput. It handles + retries for transient failures and processes images for storage. + + Attributes: + header: HTTP headers to use for requests + job_end_time: UNIX timestamp when the job should end + logger: Logger instance for output messages + session: HTTP session for making requests + rate_limit: Current download rate limit (requests per second) + img_size: Maximum size for image resize + upper_limit: Maximum allowed download rate + bottom_limit: Minimum allowed download rate + semaphore: Controls concurrent access to resources + convert_image: Whether to process and resize images after download + """ + def __init__(self, header: dict, session: requests.Session, @@ -27,6 +60,18 @@ def __init__(self, job_end_time: int = 0, convert_image: bool = True, logger=logging.getLogger()): + """ + Initialize a new Downloader instance. 
+ + Args: + header: HTTP headers to use for requests + session: Prepared requests.Session for HTTP connections + rate_limit: Rate limit object with initial, min, and max rates + img_size: Maximum size for image resize + job_end_time: UNIX timestamp when the job should end + convert_image: Whether to process and resize images + logger: Logger instance for output messages + """ self.header = header self.job_end_time = job_end_time self.logger = logger @@ -50,6 +95,20 @@ def get_images(self, images_requested: List[Dict[str, Any]], new_rate_limit: RateLimit = None) \ -> Tuple[CompletedBatch, float]: + """ + Download and process a batch of images. + + This method handles the concurrent downloading of multiple images, + managing rate limits and collecting results. + + Args: + images_requested: List of dictionaries containing image metadata and URLs + new_rate_limit: Optional new rate limit to apply + + Returns: + Tuple[CompletedBatch, float]: A completed batch containing successful and + failed downloads, and the final download rate + """ if new_rate_limit is not None: self.rate_limit = new_rate_limit.initial_rate self.upper_limit = new_rate_limit.upper_bound @@ -87,6 +146,24 @@ def get_images(self, return CompletedBatch(self.success_queue, self.error_queue), self.rate_limit def load_url(self, url: DownloadedImage, timeout: int = 2, was_delayed=True) -> bytes: + """ + Download an image from a URL with rate limiting. + + This method handles the actual HTTP request to download an image, + respecting rate limits and checking for job timeouts. + + Args: + url: DownloadedImage object containing the URL and metadata + timeout: HTTP request timeout in seconds + was_delayed: Whether a delay was already applied before calling + + Returns: + bytes: Raw image content + + Raises: + TimeoutError: If the job end time is approaching + requests.HTTPError: For HTTP-related errors + """ with self.semaphore: if self.job_end_time - time.time() < 0: raise TimeoutError("Not enough time") @@ -109,6 +186,19 @@ def load_url(self, url: DownloadedImage, timeout: int = 2, was_delayed=True) -> return response.content def _callback_builder(self, executor: concurrent.futures.ThreadPoolExecutor, prep_img: DownloadedImage): + """ + Build a callback function for handling download completion or failure. + + This method creates a callback function that will be called when a download + completes or fails, handling retry logic and rate limit adjustments. + + Args: + executor: Thread executor for submitting retry tasks + prep_img: DownloadedImage object being processed + + Returns: + function: Callback function to handle download results + """ def _done_callback(future: concurrent.futures.Future): with self._condition: try: @@ -144,6 +234,22 @@ def _done_callback(future: concurrent.futures.Future): return _done_callback def process_image(self, return_entry: DownloadedImage, raw_image_bytes: bytes) -> None: + """ + Process a downloaded image by resizing and computing checksums. + + This method processes raw image data by: + 1. Converting to a NumPy array + 2. Computing checksums for integrity verification + 3. Resizing if larger than the target dimensions + 4. 
Storing the processed image and metadata + + Args: + return_entry: DownloadedImage object to update with processed data + raw_image_bytes: Raw image bytes from the download + + Raises: + ValueError: If the image is corrupted or invalid + """ return_entry.end_time = time.perf_counter() self.logger.debug(f"Processing {return_entry.identifier}") @@ -180,6 +286,20 @@ def process_image(self, return_entry: DownloadedImage, raw_image_bytes: bytes) - self.logger.debug(f"Processed {return_entry.identifier}") def process_error(self, return_entry: DownloadedImage, error: Exception) -> bool: + """ + Process download errors and determine if retry is appropriate. + + This method handles various error types that can occur during download + and decides if a retry should be attempted based on the error type and + retry count. + + Args: + return_entry: DownloadedImage object that encountered an error + error: Exception that occurred during download or processing + + Returns: + bool: True if a retry should be attempted, False otherwise + """ if isinstance(error, TimeoutError): self.logger.info("Timout, trying to exit") @@ -216,6 +336,16 @@ def process_error(self, return_entry: DownloadedImage, error: Exception) -> bool @staticmethod def image_resize(image: np.ndarray, max_size=1024) -> tuple[np.ndarray[int, np.dtype[np.uint8]], np.ndarray[int, np.dtype[np.uint32]]]: + """ + Resize an image while preserving aspect ratio. + + Args: + image: NumPy array containing the image data + max_size: Maximum dimension for the resized image + + Returns: + tuple: (resized_image, new_dimensions) where dimensions are [height, width] + """ h, w = image.shape[:2] if h > w: new_h = max_size diff --git a/src/distributed_downloader/core/mpi_downloader/PreLoader.py b/src/distributed_downloader/core/mpi_downloader/PreLoader.py index 7c635c4..e4b1f2f 100644 --- a/src/distributed_downloader/core/mpi_downloader/PreLoader.py +++ b/src/distributed_downloader/core/mpi_downloader/PreLoader.py @@ -1,7 +1,14 @@ -from typing import List, Dict, Any, Iterator +""" +Batch loading utilities for the distributed downloader. + +This module provides functions to load batches of download tasks from +parquet files. It supports both loading multiple batches and single batches +for processing by the downloader workers. +""" + +from typing import Any, Dict, Iterator, List import pandas as pd -import re def load_batch( @@ -9,6 +16,20 @@ def load_batch( server_name: str, batches_to_download: List[int], ) -> Iterator[List[Dict[str, Any]]]: + """ + Load multiple batches for a specific server. + + This generator function loads batches of URLs from parquet files + for a given server, one batch at a time. + + Args: + path_to_parquet: Base directory for input parquet files + server_name: Name of the server to load batches for + batches_to_download: List of batch IDs to download + + Yields: + List[Dict[str, Any]]: List of dictionaries containing URL and metadata for each batch + """ for batch_id in batches_to_download: server_df = pd.read_parquet( f"{path_to_parquet}/ServerName={server_name.replace(':', '%3A')}/partition_id={batch_id}") @@ -16,4 +37,13 @@ def load_batch( def load_one_batch(input_path: str) -> List[Dict[str, Any]]: + """ + Load a single batch from a parquet file. 
+ + Args: + input_path: Path to the parquet file to load + + Returns: + List[Dict[str, Any]]: List of dictionaries containing URL and metadata + """ return pd.read_parquet(input_path).to_dict("records") diff --git a/src/distributed_downloader/core/mpi_downloader/dataclasses.py b/src/distributed_downloader/core/mpi_downloader/dataclasses.py index ef0c159..0523e09 100644 --- a/src/distributed_downloader/core/mpi_downloader/dataclasses.py +++ b/src/distributed_downloader/core/mpi_downloader/dataclasses.py @@ -1,11 +1,22 @@ +""" +Data structures for the distributed downloader system. + +This module defines the core data structures used throughout the downloader: +- DownloadedImage: Represents an image being downloaded with metadata +- SuccessEntry/ErrorEntry: Models for successful and failed downloads +- CompletedBatch: Collection of download results +- RateLimit: Dynamic rate limiting configuration +- Other supporting classes for batch processing and scheduling +""" + from __future__ import annotations +import math import multiprocessing import queue import threading import uuid -import math -from typing import List, Dict, Any +from typing import Any, Dict, List import numpy as np from attr import define, field @@ -16,6 +27,32 @@ @define class DownloadedImage: + """ + Represents an image being downloaded with its metadata and processing state. + + This class tracks the state of an image throughout the download process, + including retry attempts, error conditions, and image processing results. + + Attributes: + retry_count: Number of retry attempts made + error_code: Error code if download failed (0 for success) + error_msg: Error message if download failed + unique_name: Unique identifier for the image + source_id: Source identifier from the original dataset + identifier: URL or other identifier for the image + is_license_full: Whether complete license information is available (license, source, and title) + license: License information + source: Source information + title: Title or caption information + hashsum_original: MD5 checksum of the original image + hashsum_resized: MD5 checksum of the resized image + image: Actual image data as bytes + original_size: Original dimensions [height, width] + resized_size: Resized dimensions [height, width] + start_time: Time when download started + end_time: Time when download completed + """ + retry_count: int error_code: int error_msg: str @@ -40,12 +77,26 @@ class DownloadedImage: @classmethod def from_row(cls, row: Dict[str, Any]) -> DownloadedImage: + """ + Create a DownloadedImage instance from a dictionary row. + + Args: + row: Dictionary containing image metadata + + Returns: + DownloadedImage: A new initialized instance + """ + if "EOL content ID" in row.keys() and 'EOL page ID' in row.keys(): + source_id = row["EOL content ID"] + "_" + row['EOL page ID'] + else: + source_id = "None" + return cls( retry_count=0, error_code=0, error_msg="", unique_name=row.get("uuid", uuid.uuid4().hex), - source_id=row.get("source_id", 0), + source_id=row.get("source_id", source_id), identifier=row.get("identifier", ""), is_license_full=all([row.get("license", None), row.get("source", None), row.get("title", None)]), license=row.get("license", _NOT_PROVIDED) or _NOT_PROVIDED, @@ -55,6 +106,16 @@ def from_row(cls, row: Dict[str, Any]) -> DownloadedImage: def init_downloaded_image_entry(image_entry: np.ndarray, row: Dict[str, Any]) -> np.ndarray: + """ + Initialize a numpy array entry with image metadata. 
+ + Args: + image_entry: numpy array to be initialized + row: Dictionary containing image metadata + + Returns: + np.ndarray: Initialized array entry + """ image_entry["is_downloaded"] = False image_entry["retry_count"] = 0 image_entry["error_code"] = 0 @@ -72,6 +133,27 @@ def init_downloaded_image_entry(image_entry: np.ndarray, row: Dict[str, Any]) -> @define class SuccessEntry: + """ + Represents a successfully downloaded image with all required metadata. + + This class stores all information about a successfully downloaded and processed image, + including image data, checksums, and metadata from the original source. + + Attributes: + uuid: Unique identifier for the image + source_id: Source identifier from the original dataset + identifier: URL or other identifier for the image + is_license_full: Whether complete license information is available (license, source, and title) + license: License information + source: Source information + title: Title or caption information + hashsum_original: MD5 checksum of the original image + hashsum_resized: MD5 checksum of the resized image + original_size: Original dimensions [height, width] + resized_size: Resized dimensions [height, width] + image: Actual image data as bytes + """ + uuid: str source_id: int identifier: str @@ -86,9 +168,18 @@ class SuccessEntry: image: bytes def __success_dtype(self, img_size: int): + """ + Define the NumPy dtype for storing success entries. + + Args: + img_size: Size of the image dimension + + Returns: + np.dtype: NumPy data type definition + """ return np.dtype([ ("uuid", "S32"), - ("source_id", "i4"), + ("source_id", "S32"), ("identifier", "S256"), ("is_license_full", "bool"), ("license", "S256"), @@ -103,17 +194,25 @@ def __success_dtype(self, img_size: int): @staticmethod def get_success_spark_scheme(): - from pyspark.sql.types import StructType - from pyspark.sql.types import StringType - from pyspark.sql.types import LongType - from pyspark.sql.types import StructField - from pyspark.sql.types import BooleanType - from pyspark.sql.types import ArrayType - from pyspark.sql.types import BinaryType + """ + Define the PySpark schema for success entries. + + Returns: + StructType: PySpark schema definition + """ + from pyspark.sql.types import ( + ArrayType, + BinaryType, + BooleanType, + LongType, + StringType, + StructField, + StructType, + ) return StructType([ StructField("uuid", StringType(), False), - StructField("source_id", LongType(), False), + StructField("source_id", StringType(), False), StructField("identifier", StringType(), False), StructField("is_license_full", BooleanType(), False), StructField("license", StringType(), True), @@ -128,6 +227,15 @@ def get_success_spark_scheme(): @classmethod def from_downloaded(cls, downloaded: DownloadedImage) -> SuccessEntry: + """ + Create a SuccessEntry from a DownloadedImage. + + Args: + downloaded: The DownloadedImage to convert + + Returns: + SuccessEntry: A new success entry instance + """ return cls( uuid=downloaded.unique_name, source_id=downloaded.source_id, @@ -145,6 +253,15 @@ def from_downloaded(cls, downloaded: DownloadedImage) -> SuccessEntry: @staticmethod def to_list_download(downloaded: DownloadedImage) -> List: + """ + Convert a DownloadedImage to a list format for storage. 
+ + Args: + downloaded: DownloadedImage to convert + + Returns: + List: List of values in the correct order for storage + """ return [ downloaded.unique_name, downloaded.source_id, @@ -162,6 +279,12 @@ def to_list_download(downloaded: DownloadedImage) -> List: @staticmethod def get_names() -> List[str]: + """ + Get the column names for success entries. + + Returns: + List[str]: List of column names + """ return [ "uuid", "source_id", @@ -178,6 +301,12 @@ def get_names() -> List[str]: ] def to_list(self) -> List: + """ + Convert this SuccessEntry to a list format. + + Returns: + List: List representation of this entry + """ return [ self.uuid, self.source_id, @@ -194,6 +323,12 @@ def to_list(self) -> List: ] def to_np(self) -> np.ndarray: + """ + Convert this SuccessEntry to a NumPy array. + + Returns: + np.ndarray: NumPy array representation of this entry + """ np_structure = np.array( [ (self.uuid, @@ -216,6 +351,20 @@ def to_np(self) -> np.ndarray: @define class ErrorEntry: + """ + Represents a failed download attempt with error information. + + This class stores information about failed downloads, including + the error code, error message, and retry count. + + Attributes: + uuid: Unique identifier for the download attempt + identifier: URL or other identifier for the image + retry_count: Number of retry attempts made + error_code: Error code from the download attempt + error_msg: Error message describing the failure + """ + uuid: str identifier: str retry_count: int @@ -232,6 +381,15 @@ class ErrorEntry: @classmethod def from_downloaded(cls, downloaded: DownloadedImage) -> ErrorEntry: + """ + Create an ErrorEntry from a DownloadedImage. + + Args: + downloaded: The DownloadedImage that failed + + Returns: + ErrorEntry: A new error entry instance + """ return cls( uuid=downloaded.unique_name, identifier=downloaded.identifier, @@ -242,6 +400,15 @@ def from_downloaded(cls, downloaded: DownloadedImage) -> ErrorEntry: @staticmethod def to_list_download(downloaded: DownloadedImage) -> List: + """ + Convert a DownloadedImage to a list format for error storage. + + Args: + downloaded: DownloadedImage to convert + + Returns: + List: List of error values in the correct order for storage + """ return [ downloaded.unique_name, downloaded.identifier, @@ -251,6 +418,12 @@ def to_list_download(downloaded: DownloadedImage) -> List: ] def to_list(self) -> List: + """ + Convert this ErrorEntry to a list format. + + Returns: + List: List representation of this entry + """ return [ self.uuid, self.identifier, @@ -260,6 +433,12 @@ def to_list(self) -> List: ] def to_np(self) -> np.ndarray: + """ + Convert this ErrorEntry to a NumPy array. + + Returns: + np.ndarray: NumPy array representation of this entry + """ np_structure = np.array( [ (self.uuid, @@ -274,6 +453,12 @@ def to_np(self) -> np.ndarray: @staticmethod def get_names() -> List[str]: + """ + Get the column names for error entries. + + Returns: + List[str]: List of column names + """ return [ "uuid", "identifier", @@ -285,6 +470,21 @@ def get_names() -> List[str]: @define class ImageBatchesByServerToRequest: + """ + Container for batches of images to be requested from a specific server. + + This class manages queue of URL batches for a single server, along with + synchronization primitives for coordinating access. 
+ + Attributes: + server_name: Name of the server + lock: Lock for synchronized access to this container + writer_notification: Event to notify when writing is completed + urls: Queue of URL batches to process + max_rate: Maximum request rate for this server + total_batches: Total number of batches to process + """ + server_name: str lock: threading.Lock writer_notification: threading.Event @@ -298,6 +498,18 @@ def from_pandas(cls, manager: multiprocessing.Manager, urls: List[DataFrame], max_rate: int = 50) -> ImageBatchesByServerToRequest: + """ + Create an instance from pandas DataFrames. + + Args: + server_name: Name of the server + manager: Multiprocessing manager for creating synchronized objects + urls: List of DataFrame batches containing URLs + max_rate: Maximum request rate for this server + + Returns: + ImageBatchesByServerToRequest: A new instance + """ urls_queue: queue.Queue[List[Dict[str, Any]]] = queue.Queue() for url_batch in urls: urls_queue.put(url_batch.to_dict("records")) @@ -314,6 +526,19 @@ def from_pandas(cls, @define class CompletedBatch: + """ + Container for completed download results. + + This class holds queues of successful and failed downloads + from a batch processing operation. + + Attributes: + success_queue: Queue of successfully downloaded images + error_queue: Queue of failed download attempts + batch_id: Identifier for this batch + offset: Offset within the processing sequence + """ + success_queue: queue.Queue[DownloadedImage] error_queue: queue.Queue[DownloadedImage] batch_id: int = -1 @@ -322,30 +547,67 @@ class CompletedBatch: @define class WriterServer: + """ + Coordinates writing results for a specific server. + + This class tracks the completion status of downloads for a server + and manages the queue of completed batches. + + Attributes: + server_name: Name of the server + download_complete: Event signaling when downloads are complete + completed_queue: Queue of completed batches + total_batches: Total number of batches to process + done_batches: Number of batches that have been processed + """ + server_name: str download_complete: threading.Event - competed_queue: queue.Queue[CompletedBatch] + completed_queue: queue.Queue[CompletedBatch] total_batches: int done_batches: int = 0 @define class RateLimit: + """ + Dynamic rate limiting configuration. + + This class manages the rate limits for download requests, + providing upper and lower bounds for adaptive rate control. + + Attributes: + initial_rate: Starting rate limit + _multiplier: Multiplier for determining bounds + lower_bound: Minimum allowed rate + upper_bound: Maximum allowed rate + """ + initial_rate = field(init=True, type=float) _multiplier = field(init=True, type=float, default=0.5, validator=lambda _, __, value: 0 < value) lower_bound = field(init=False, type=int, converter=math.floor, default=0) upper_bound = field(init=False, type=int, converter=math.floor, default=0) def __attrs_post_init__(self): + """ + Post-initialization setup to calculate rate limits. + """ self.lower_bound = max(self.initial_rate * (1 - self._multiplier), 1) self.upper_bound = self.initial_rate * (1 + self._multiplier) def change_rate(self, new_rate: float): + """ + Update rate limits based on a new rate. 
+ + Args: + new_rate: New rate to base limits on + """ self.initial_rate = new_rate self.lower_bound = max(self.initial_rate * (1 - self._multiplier), 1) self.upper_bound = self.initial_rate * (1 + self._multiplier) +# NumPy dtype for server profile data profile_dtype = np.dtype([ ("server_name", "S256"), ("total_batches", "i4"), diff --git a/src/distributed_downloader/core/mpi_downloader/utils.py b/src/distributed_downloader/core/mpi_downloader/utils.py index b7cc3a8..1743fe2 100644 --- a/src/distributed_downloader/core/mpi_downloader/utils.py +++ b/src/distributed_downloader/core/mpi_downloader/utils.py @@ -1,21 +1,41 @@ from __future__ import annotations +""" +Utility functions for the MPI-based distributed downloader. + +This module provides helper functions for the downloader system, including: +- Session management for HTTP requests +- Schedule handling and parsing +- Downloader initialization and reuse +- Batch processing utilities +- Time management for job execution +""" + import logging import os -import shutil import time -from typing import Dict, Tuple, Union, List, Any, Deque, Set +from typing import Any, Deque, Dict, List, Set, Tuple, Union import pandas as pd import requests from requests.adapters import HTTPAdapter from urllib3 import Retry -from .Downloader import Downloader from .dataclasses import RateLimit +from .Downloader import Downloader def create_new_session(url: str, max_rate: int) -> requests.Session: + """ + Create a new HTTP session with retry logic and connection pooling. + + Args: + url: Base URL for the server + max_rate: Maximum number of concurrent connections in the pool + + Returns: + requests.Session: Configured session with retry and pooling settings + """ session = requests.Session() retry = Retry(total=5, backoff_factor=0.1, status_forcelist=[500, 502, 503, 504]) adapter = HTTPAdapter(max_retries=retry, pool_maxsize=max_rate, pool_connections=max_rate) @@ -26,6 +46,16 @@ def create_new_session(url: str, max_rate: int) -> requests.Session: def get_latest_schedule(path_to_dir: str, rank: int = None) -> Union[pd.DataFrame, None]: + """ + Get the most recent schedule file from a directory. + + Args: + path_to_dir: Directory containing schedule files + rank: Optional MPI rank to filter the schedule for a specific worker + + Returns: + pd.DataFrame or None: DataFrame containing schedule information, or None if no schedules found + """ if not os.path.exists(path_to_dir) or not os.path.isdir(path_to_dir): return None @@ -50,6 +80,25 @@ def get_or_init_downloader(header: dict, rate_multiplier: float, job_end_time: int, logger: logging.Logger) -> Tuple[Downloader, requests.Session, RateLimit]: + """ + Get an existing downloader for a server or initialize a new one. + + This function maintains a cache of downloaders per server to avoid + creating multiple connections to the same server. 
+ + Args: + header: HTTP headers to use for requests + img_size: Target size for image resizing + schedule_dict: Dictionary containing server information and rate limits + downloader_schedule: Cache of existing downloaders + rate_multiplier: Rate limit adjustment multiplier + job_end_time: Unix timestamp when the job should end + logger: Logger instance for output + + Returns: + Tuple[Downloader, requests.Session, RateLimit]: + The downloader, its session, and the rate limiter + """ if schedule_dict["server_name"] not in downloader_schedule.keys(): server_name = schedule_dict["server_name"].replace("%3A", ":") rate_limit = RateLimit(schedule_dict["rate_limit"], rate_multiplier) @@ -62,6 +111,20 @@ def get_or_init_downloader(header: dict, def generate_ids_to_download(schedule_row: pd.Series, verifier_df: pd.DataFrame) -> pd.Series: + """ + Determine which batch IDs need to be downloaded for a server. + + This function compares the total range of batches for a server with + those that have already been verified as downloaded, and returns + the difference (batches that still need to be downloaded). + + Args: + schedule_row: Row from schedule DataFrame with server information + verifier_df: DataFrame containing verification status of downloaded batches + + Returns: + pd.Series: Series with server_name and list of batch IDs to download + """ server_name = schedule_row["server_name"] server_start_idx = schedule_row["start_index"] server_end_idx = schedule_row["end_index"] @@ -80,6 +143,18 @@ def generate_ids_to_download(schedule_row: pd.Series, verifier_df: pd.DataFrame) def separate_to_blocks(data_row: pd.Series) -> List[List[Tuple[int, int]]]: + """ + Organize batch IDs into processing blocks for efficient distribution. + + This function takes a list of batch IDs and organizes them into blocks + based on the available processing capacity (nodes × processes per node). + + Args: + data_row: Row containing server information and batch IDs to process + + Returns: + List[List[Tuple[int, int]]]: Nested lists of batch ranges + """ batches: List[int] = data_row["batches"] num_of_blocks: int = data_row["process_per_node"] * data_row["nodes"] @@ -101,6 +176,15 @@ def separate_to_blocks(data_row: pd.Series) -> List[List[Tuple[int, int]]]: def compress_ids(ids: List[int]) -> List[Tuple[int, int]]: + """ + Compress consecutive batch IDs into ranges for efficient storage. + + Args: + ids: List of batch IDs (integers) + + Returns: + List[Tuple[int, int]]: List of (start_id, end_id+1) tuples representing ranges + """ if len(ids) < 1: return [] compressed_ids = [] @@ -118,6 +202,19 @@ def compress_ids(ids: List[int]) -> List[Tuple[int, int]]: def get_largest_nonempty_bucket(buckets: Dict[int, Deque[Dict[str, Any]]], avail_space: int) -> int: + """ + Find the largest bucket that fits within available space. + + This function is used in scheduling to find the largest process_per_node + value that can be accommodated with the remaining worker slots. 
+ + Args: + buckets: Dictionary mapping bucket sizes to queues of servers + avail_space: Available space (number of worker slots) + + Returns: + int: Size of the largest bucket that fits, or 0 if none fit + """ largest_bucket = 0 for key, bucket in buckets.items(): @@ -130,11 +227,35 @@ def get_largest_nonempty_bucket(buckets: Dict[int, Deque[Dict[str, Any]]], avail def is_enough_time(rate_limit: RateLimit, batch_size: int = 10000, avg_write_time: int = 600, job_end_time: int = int(os.getenv("SLURM_JOB_END_TIME", 0))) -> bool: + """ + Check if there is enough time left in the job to process a batch. + + This function estimates if there's enough time to download and process + a batch based on the current rate limit and average write time. + + Args: + rate_limit: Rate limit object with current download rate + batch_size: Size of the batch to be processed + avg_write_time: Average time in seconds needed to write results + job_end_time: Unix timestamp when the job is scheduled to end + + Returns: + bool: True if there is enough time, False otherwise + """ current_time = time.time() time_left = job_end_time - current_time - avg_write_time return rate_limit.initial_rate * time_left >= batch_size def get_schedule_count(path_to_dir: str) -> int: + """ + Count the number of schedule files in a directory. + + Args: + path_to_dir: Path to the directory containing schedule files + + Returns: + int: Count of schedule files (excluding files starting with '_') + """ schedule_files = [file for file in os.listdir(path_to_dir) if not file.startswith("_")] return len(schedule_files) diff --git a/src/distributed_downloader/core/schemes/multimedia_scheme.py b/src/distributed_downloader/core/schemes/multimedia_scheme.py deleted file mode 100644 index 60efe37..0000000 --- a/src/distributed_downloader/core/schemes/multimedia_scheme.py +++ /dev/null @@ -1,19 +0,0 @@ -from pyspark.sql.types import StructField, StructType, StringType, LongType, TimestampType - -schema = StructType([ - StructField("gbifID", LongType(), True), - StructField("type", StringType(), True), - StructField("format", StringType(), True), - StructField("identifier", StringType(), True), - StructField("references", StringType(), True), - StructField("title", StringType(), True), - StructField("description", StringType(), True), - StructField("source", StringType(), True), - StructField("audience", StringType(), True), - StructField("created", TimestampType(), True), - StructField("creator", StringType(), True), - StructField("contributor", StringType(), True), - StructField("publisher", StringType(), True), - StructField("license", StringType(), True), - StructField("rightsHolder", StringType(), True) -]) diff --git a/src/distributed_downloader/tools/checkpoint.py b/src/distributed_downloader/tools/checkpoint.py index 799d241..820683c 100644 --- a/src/distributed_downloader/tools/checkpoint.py +++ b/src/distributed_downloader/tools/checkpoint.py @@ -18,6 +18,8 @@ def __load_checkpoint(path: str, default_structure: Dict[str, bool]) -> Dict[str try: with open(path, "r") as file: checkpoint = yaml.full_load(file) + if checkpoint is None: + checkpoint = {} for key, value in default_structure.items(): if key not in checkpoint: checkpoint[key] = value diff --git a/src/distributed_downloader/tools/config_templates/downloader.yaml b/src/distributed_downloader/tools/config_templates/downloader.yaml index 2473371..78c4ae3 100644 --- a/src/distributed_downloader/tools/config_templates/downloader.yaml +++ 
b/src/distributed_downloader/tools/config_templates/downloader.yaml @@ -1,6 +1,7 @@ account: "" path_to_input: "" path_to_output_folder: "" +initializer_type: "" scripts: # Wrapper scripts to submit jobs to the cluster diff --git a/src/distributed_downloader/tools/filter.py b/src/distributed_downloader/tools/filter.py index 2e6fc2c..c3d9685 100644 --- a/src/distributed_downloader/tools/filter.py +++ b/src/distributed_downloader/tools/filter.py @@ -1,6 +1,7 @@ import argparse import os +from distributed_downloader.tools import Checkpoint from distributed_downloader.tools.utils import init_logger from distributed_downloader.tools.config import Config from distributed_downloader.tools.registry import ToolsRegistryBase @@ -10,22 +11,38 @@ if config_path is None: raise ValueError("CONFIG_PATH not set") - config = Config.from_path(config_path, "tools") - logger = init_logger(__name__) - - parser = argparse.ArgumentParser(description='Filtering step of the Tool') - parser.add_argument("filter_name", metavar="filter_name", type=str, - help="the name of the tool that is intended to be used") + parser = argparse.ArgumentParser(description="Filtering step of the Tool") + parser.add_argument( + "filter_name", + metavar="filter_name", + type=str, + help="the name of the tool that is intended to be used", + ) _args = parser.parse_args() tool_name = _args.filter_name - assert tool_name in ToolsRegistryBase.TOOLS_REGISTRY.keys(), ValueError("unknown filter") + assert tool_name in ToolsRegistryBase.TOOLS_REGISTRY.keys(), ValueError( + "unknown filter" + ) + + config = Config.from_path(config_path, "tools") + checkpoint = Checkpoint.from_path( + os.path.join( + config.get_folder("tools_folder"), tool_name, "tool_checkpoint.yaml" + ), + {"filtering_scheduled": True, "filtering_completed": False}, + ) + logger = init_logger(__name__) + checkpoint["filtering_scheduled"] = False tool_filter = ToolsRegistryBase.TOOLS_REGISTRY[tool_name]["filter"](config) - logger.info("Starting filter") - tool_filter.run() + if not checkpoint.get("filtering_completed", False): + logger.info("Starting filter") + tool_filter.run() - logger.info("completed filtering") + logger.info("completed filtering") - tool_filter = None + checkpoint["filtering_completed"] = True + else: + logger.info("Filtering was already completed") diff --git a/src/distributed_downloader/tools/filters.py b/src/distributed_downloader/tools/filters.py index 7fd86b5..e2295f8 100644 --- a/src/distributed_downloader/tools/filters.py +++ b/src/distributed_downloader/tools/filters.py @@ -1,10 +1,12 @@ import os.path from functools import partial +from typing import Optional import pandas as pd import pyspark.sql as ps import pyspark.sql.functions as func from pyspark.sql import SparkSession +from pyspark.sql.types import StructType from distributed_downloader.core.mpi_downloader.dataclasses import SuccessEntry from distributed_downloader.tools.config import Config @@ -12,11 +14,13 @@ from distributed_downloader.tools.registry import ToolsRegistryBase FilterRegister = partial(ToolsRegistryBase.register, "filter") -__all__ = ["FilterRegister", - "SizeBasedFiltering", - "DuplicatesBasedFiltering", - "ResizeToolFilter", - "ImageVerificationToolFilter"] +__all__ = [ + "FilterRegister", + "SizeBasedFiltering", + "DuplicatesBasedFiltering", + "ResizeToolFilter", + "ImageVerificationToolFilter", +] class FilterToolBase(ToolsBase): @@ -31,30 +35,39 @@ class SparkFilterToolBase(FilterToolBase): def __init__(self, cfg: Config, spark: SparkSession = None): 
super().__init__(cfg) - self.spark: SparkSession = spark if spark is not None else SparkSession.builder.appName( - "Filtering").getOrCreate() + self.spark: SparkSession = ( + spark + if spark is not None + else SparkSession.builder.appName("Filtering").getOrCreate() + ) self.spark.conf.set("spark.sql.parquet.datetimeRebaseModeInWrite", "CORRECTED") self.spark.conf.set("spark.sql.parquet.int96RebaseModeInWrite", "CORRECTED") def run(self): raise NotImplementedError() - def load_data_parquet(self): - return (self.spark - .read - .schema(self.success_scheme) - .option("basePath", self.downloaded_images_path) - .parquet(self.downloaded_images_path + "/server_name=*/partition_id=*/successes.parquet")) + def load_data_parquet(self, scheme: Optional[StructType] = None): + if scheme is None: + scheme = self.success_scheme + return ( + self.spark.read.schema(scheme) + .option("basePath", self.downloaded_images_path) + .parquet( + self.downloaded_images_path + + "/server_name=*/partition_id=*/successes.parquet" + ) + ) def save_filter(self, df: ps.DataFrame): if self.filter_name is None: raise ValueError("filter name was not defined") - (df - .repartition(10) - .write - .csv(os.path.join(self.tools_path, self.filter_name, "filter_table"), - header=True, - mode="overwrite")) + ( + df.repartition(10).write.csv( + os.path.join(self.tools_path, self.filter_name, "filter_table"), + header=True, + mode="overwrite", + ) + ) def __del__(self): if self.spark is not None: @@ -63,30 +76,29 @@ def __del__(self): @FilterRegister("size_based") class SizeBasedFiltering(SparkFilterToolBase): - def __init__(self, cfg: Config, spark: SparkSession = None): super().__init__(cfg, spark) self.filter_name: str = "size_based" - assert "threshold_size" in self.config["tools_parameters"], ( - ValueError("threshold_size have to be defined")) - assert isinstance(self.config["tools_parameters"]["threshold_size"], int), ( - ValueError("threshold_size have to be Integer")) + assert "threshold_size" in self.config["tools_parameters"], ValueError( + "threshold_size have to be defined" + ) + assert isinstance( + self.config["tools_parameters"]["threshold_size"], int + ), ValueError("threshold_size have to be Integer") self.threshold_size = self.config["tools_parameters"]["threshold_size"] def run(self): successes_df: ps.DataFrame = self.load_data_parquet() - successes_df = (successes_df - .withColumn("is_big", - func.array_min(func.col("original_size")) >= - self.threshold_size)) + successes_df = successes_df.withColumn( + "is_big", func.array_min(func.col("original_size")) >= self.threshold_size + ) - too_small_images = successes_df.filter(~successes_df["is_big"]).select("uuid", - "gbif_id", - "server_name", - "partition_id") + too_small_images = successes_df.filter(~successes_df["is_big"]).select( + "uuid", "source_id", "server_name", "partition_id" + ) self.save_filter(too_small_images) @@ -95,7 +107,6 @@ def run(self): @FilterRegister("duplication_based") class DuplicatesBasedFiltering(SparkFilterToolBase): - def __init__(self, cfg: Config, spark: SparkSession = None): super().__init__(cfg, spark) self.filter_name: str = "duplication_based" @@ -103,34 +114,43 @@ def __init__(self, cfg: Config, spark: SparkSession = None): def run(self): successes_df: ps.DataFrame = self.load_data_parquet() - not_duplicate_records = (successes_df - .groupBy("hashsum_original") - .count() - .where('count = 1') - .drop('count')) + not_duplicate_records = ( + successes_df.groupBy("hashsum_original") + .count() + .where("count = 1") + .drop("count") 
+ ) - duplicate_records = (successes_df - .join(not_duplicate_records, on="hashsum_original", how='left_anti') - .select("uuid", "gbif_id", "server_name", "partition_id", "hashsum_original")) + duplicate_records = successes_df.join( + not_duplicate_records, on="hashsum_original", how="left_anti" + ).select("uuid", "source_id", "server_name", "partition_id", "hashsum_original") - window = ps.Window.partitionBy("hashsum_original").orderBy("partition_id", "server_name") + window = ps.Window.partitionBy("hashsum_original").orderBy( + "partition_id", "server_name" + ) - duplicate_records_top = (duplicate_records - .withColumn("rn", func.row_number().over(window)) - .where("rn == 1") - .drop("rn")) + duplicate_records_top = ( + duplicate_records.withColumn("rn", func.row_number().over(window)) + .where("rn == 1") + .drop("rn") + ) duplicate_records_top = duplicate_records_top.withColumnsRenamed( - {"uuid": "uuid_main", - "gbif_id": "gbif_id_main", - "server_name": "server_name_main", - "partition_id": "partition_id_main"}) - - duplicate_records = (duplicate_records - .join(duplicate_records_top, on="hashsum_original", how="left") - .where("uuid != uuid_main") - .drop("hashsum_original") - ) + { + "uuid": "uuid_main", + "source_id": "source_id_main", + "server_name": "server_name_main", + "partition_id": "partition_id_main", + } + ) + + duplicate_records = ( + duplicate_records.join( + duplicate_records_top, on="hashsum_original", how="left" + ) + .where("uuid != uuid_main") + .drop("hashsum_original") + ) self.save_filter(duplicate_records) @@ -138,7 +158,6 @@ def run(self): class PythonFilterToolBase(FilterToolBase): - def __init__(self, cfg: Config): super().__init__(cfg) @@ -149,8 +168,9 @@ def get_all_paths_to_merge(self) -> pd.DataFrame: server_name = folder.split("=")[1] for partition in os.listdir(f"{path}/{folder}"): partition_path = f"{path}/{folder}/{partition}" - if (not os.path.exists(f"{partition_path}/successes.parquet") or - not os.path.exists(f"{partition_path}/completed")): + if not os.path.exists( + f"{partition_path}/successes.parquet" + ) or not os.path.exists(f"{partition_path}/completed"): continue all_schedules.append([server_name, partition.split("=")[1]]) return pd.DataFrame(all_schedules, columns=["server_name", "partition_id"]) @@ -158,15 +178,18 @@ def get_all_paths_to_merge(self) -> pd.DataFrame: def run(self): filter_table = self.get_all_paths_to_merge() - filter_table_folder = os.path.join(self.tools_path, self.filter_name, "filter_table") + filter_table_folder = os.path.join( + self.tools_path, self.filter_name, "filter_table" + ) os.makedirs(filter_table_folder, exist_ok=True) - filter_table.to_csv(filter_table_folder + "/table.csv", header=True, index=False) + filter_table.to_csv( + filter_table_folder + "/table.csv", header=True, index=False + ) @FilterRegister("resize") class ResizeToolFilter(PythonFilterToolBase): - def __init__(self, cfg: Config): super().__init__(cfg) self.filter_name = "resize" @@ -174,7 +197,6 @@ def __init__(self, cfg: Config): @FilterRegister("image_verification") class ImageVerificationToolFilter(PythonFilterToolBase): - def __init__(self, cfg: Config): super().__init__(cfg) self.filter_name = "image_verification" diff --git a/src/distributed_downloader/tools/main.py b/src/distributed_downloader/tools/main.py index e23ed8b..10be816 100644 --- a/src/distributed_downloader/tools/main.py +++ b/src/distributed_downloader/tools/main.py @@ -25,40 +25,56 @@ class Tools: logger: Logger = field(default=Factory(lambda: init_logger(__name__))) - 
tool_folder: str = None - tool_job_history_path: str = None - tool_checkpoint_path: str = None + tool_folder: Optional[str] = None + tool_job_history_path: Optional[str] = None + tool_checkpoint_path: Optional[str] = None checkpoint_scheme = { - "filtered": False, - "schedule_created": False, - "completed": False + "filtering_scheduled": False, + "filtering_completed": False, + "scheduling_scheduled": False, + "scheduling_completed": False, + "completed": False, } - tool_checkpoint: Checkpoint = None + tool_checkpoint: Optional[Checkpoint] = None _checkpoint_override: Optional[Dict[str, bool]] = None - tool_job_history: List[int] = None - tool_job_history_io: TextIO = None + tool_job_history: Optional[List[int]] = None + tool_job_history_io: Optional[TextIO] = None @classmethod - def from_path(cls, path: str, - tool_name: str, - checkpoint_override: Optional[Dict[str, bool]] = None) -> "Tools": - if tool_name not in ToolsRegistryBase.TOOLS_REGISTRY.keys(): + def from_path( + cls, + path: str, + tool_name: str, + checkpoint_override: Optional[Dict[str, bool]] = None, + tool_name_override: Optional[bool] = False, + ) -> "Tools": + if ( + not tool_name_override + and tool_name not in ToolsRegistryBase.TOOLS_REGISTRY.keys() + ): raise ValueError("unknown tool name") - return cls(config=Config.from_path(path, "tools"), - tool_name=tool_name, - checkpoint_override=checkpoint_override) + return cls( + config=Config.from_path(path, "tools"), + tool_name=tool_name, + checkpoint_override=checkpoint_override, + ) def __attrs_post_init__(self): # noinspection PyTypeChecker - self.tool_folder: str = os.path.join(self.config.get_folder("tools_folder"), - self.tool_name) - self.tool_job_history_path: str = os.path.join(self.tool_folder, "job_history.csv") - self.tool_checkpoint_path: str = os.path.join(self.tool_folder, "tool_checkpoint.yaml") + self.tool_folder: str = os.path.join( + self.config.get_folder("tools_folder"), self.tool_name + ) + self.tool_job_history_path: str = os.path.join( + self.tool_folder, "job_history.csv" + ) + self.tool_checkpoint_path: str = os.path.join( + self.tool_folder, "tool_checkpoint.yaml" + ) self.__init_environment() - self.__init_filestructure() + self.__init_file_structure() def __init_environment(self) -> None: os.environ["CONFIG_PATH"] = self.config.config_path @@ -69,27 +85,31 @@ def __init_environment(self) -> None: os.environ["PATH_TO_OUTPUT"] = self.config["path_to_output_folder"] for output_folder, output_path in self.config.folder_structure.items(): os.environ["OUTPUT_" + output_folder.upper()] = output_path - os.environ["OUTPUT_TOOLS_LOGS_FOLDER"] = os.path.join(self.tool_folder, - "logs") + os.environ["OUTPUT_TOOLS_LOGS_FOLDER"] = os.path.join(self.tool_folder, "logs") for downloader_var, downloader_value in self.config["tools_parameters"].items(): os.environ["TOOLS_" + downloader_var.upper()] = str(downloader_value) self.logger.info("Environment initialized") - def __init_filestructure(self): - ensure_created([ - self.tool_folder, - os.path.join(self.tool_folder, "filter_table"), - os.path.join(self.tool_folder, "verification"), - os.path.join(self.tool_folder, "logs") - ]) - - self.tool_checkpoint = Checkpoint.from_path(self.tool_checkpoint_path, self.checkpoint_scheme) + def __init_file_structure(self): + ensure_created( + [ + self.tool_folder, + os.path.join(self.tool_folder, "filter_table"), + os.path.join(self.tool_folder, "verification"), + os.path.join(self.tool_folder, "logs"), + ] + ) + + self.tool_checkpoint = Checkpoint.from_path( + 
self.tool_checkpoint_path, self.checkpoint_scheme + ) if self._checkpoint_override is not None: for key, value in self._checkpoint_override.items(): if key == "verification": truncate_paths([os.path.join(self.tool_folder, "verification")]) + self.tool_checkpoint["completed"] = False continue if key not in self.checkpoint_scheme.keys(): raise KeyError("Unknown key for override in checkpoint") @@ -118,22 +138,27 @@ def __update_job_history(self, new_id: int) -> None: def __schedule_filtering(self) -> None: self.logger.info("Scheduling filtering script") - job_id = submit_job(self.config.get_script("tools_submitter"), - self.config.get_script("tools_filter_script"), - self.tool_name, - *preprocess_dep_ids( - [self.tool_job_history[-1] if len(self.tool_job_history) != 0 else None]), - "--spark") + job_id = submit_job( + self.config.get_script("tools_submitter"), + self.config.get_script("tools_filter_script"), + self.tool_name, + *preprocess_dep_ids( + [self.tool_job_history[-1] if len(self.tool_job_history) != 0 else None] + ), + "--spark", + ) self.__update_job_history(job_id) self.tool_checkpoint["filtered"] = True self.logger.info("Scheduled filtering script") def __schedule_schedule_creation(self) -> None: self.logger.info("Scheduling schedule creation script") - job_id = submit_job(self.config.get_script("tools_submitter"), - self.config.get_script("tools_scheduling_script"), - self.tool_name, - *preprocess_dep_ids([self.tool_job_history[-1]])) + job_id = submit_job( + self.config.get_script("tools_submitter"), + self.config.get_script("tools_scheduling_script"), + self.tool_name, + *preprocess_dep_ids([self.tool_job_history[-1]]), + ) self.__update_job_history(job_id) self.tool_checkpoint["schedule_created"] = True self.logger.info("Scheduled schedule creation script") @@ -142,30 +167,44 @@ def __schedule_workers(self) -> None: self.logger.info("Scheduling workers script") for _ in range(self.config["tools_parameters"]["num_workers"]): - job_id = submit_job(self.config.get_script("tools_submitter"), - self.config.get_script("tools_worker_script"), - self.tool_name, - *preprocess_dep_ids([self.tool_job_history[-1]])) + job_id = submit_job( + self.config.get_script("tools_submitter"), + self.config.get_script("tools_worker_script"), + self.tool_name, + *preprocess_dep_ids([self.tool_job_history[-1]]), + ) self.__update_job_history(job_id) - job_id = submit_job(self.config.get_script("tools_submitter"), - self.config.get_script("tools_verification_script"), - self.tool_name, - *preprocess_dep_ids([self.tool_job_history[-1]])) + job_id = submit_job( + self.config.get_script("tools_submitter"), + self.config.get_script("tools_verification_script"), + self.tool_name, + *preprocess_dep_ids([self.tool_job_history[-1]]), + ) self.__update_job_history(job_id) self.logger.info("Scheduled workers script") def apply_tool(self): - if not self.tool_checkpoint.get("filtered", False): + if not ( + self.tool_checkpoint.get("filtering_scheduled", False) + or self.tool_checkpoint.get("filtering_completed", False) + ): self.__schedule_filtering() else: - self.logger.info("Skipping filtering script: table already created") - - if not self.tool_checkpoint.get("schedule_created", False): + self.logger.info( + "Skipping filtering script: job is already scheduled or table has been already created" + ) + + if not ( + self.tool_checkpoint.get("schedule_scheduled", False) + or self.tool_checkpoint.get("schedule_completed", False) + ): self.__schedule_schedule_creation() else: - self.logger.info("Skipping schedule 
creation script: schedule already created") + self.logger.info( + "Skipping schedule creation script: job is already scheduled or schedule has been already created" + ) if not self.tool_checkpoint.get("completed", False): self.__schedule_workers() @@ -178,14 +217,37 @@ def __del__(self): def main(): - parser = argparse.ArgumentParser(description='Tools') - parser.add_argument("config_path", metavar="config_path", type=str, - help="the name of the tool that is intended to be used") - parser.add_argument("tool_name", metavar="tool_name", type=str, - help="the name of the tool that is intended to be used") - parser.add_argument("--reset_filtering", action="store_true", help="Will reset filtering and scheduling steps") - parser.add_argument("--reset_scheduling", action="store_true", help="Will reset scheduling step") - parser.add_argument("--reset_runners", action="store_true", help="Will reset runners, making them to start over") + parser = argparse.ArgumentParser(description="Tools") + parser.add_argument( + "config_path", + metavar="config_path", + type=str, + help="the name of the tool that is intended to be used", + ) + parser.add_argument( + "tool_name", + metavar="tool_name", + type=str, + help="the name of the tool that is intended to be used", + ) + parser.add_argument( + "--reset_filtering", + action="store_true", + help="Will reset filtering and scheduling steps", + ) + parser.add_argument( + "--reset_scheduling", action="store_true", help="Will reset scheduling step" + ) + parser.add_argument( + "--reset_runners", + action="store_true", + help="Will reset runners, making them to start over", + ) + parser.add_argument( + "--tool_name_override", + action="store_true", + help="Will override tool name check (allows for custom tool run)", + ) _args = parser.parse_args() config_path = _args.config_path @@ -193,22 +255,25 @@ def main(): state_override = None if _args.reset_filtering: state_override = { - "filtered": False, - "schedule_created": False, - "verification": False + "filtering_scheduled": False, + "filtering_completed": False, + "scheduling_scheduled": False, + "scheduling_completed": False, + "verification": False, + "completed": False, } elif _args.reset_scheduling: state_override = { - "schedule_created": False - } - elif _args.reset_runners: - state_override = { - "verification": False + "scheduling_scheduled": False, + "scheduling_completed": False, + "completed": False, } + if _args.reset_runners: + state_override = {"verification": False, "completed": False} - dd = Tools.from_path(config_path, - tool_name, - state_override) + dd = Tools.from_path( + config_path, tool_name, state_override, _args.tool_name_override + ) dd.apply_tool() diff --git a/src/distributed_downloader/tools/runner.py b/src/distributed_downloader/tools/runner.py index fb69c20..f173e2d 100644 --- a/src/distributed_downloader/tools/runner.py +++ b/src/distributed_downloader/tools/runner.py @@ -1,6 +1,7 @@ import argparse import os +from distributed_downloader.tools import Checkpoint from distributed_downloader.tools.utils import init_logger from distributed_downloader.tools.config import Config from distributed_downloader.tools.registry import ToolsRegistryBase @@ -10,20 +11,36 @@ if config_path is None: raise ValueError("CONFIG_PATH not set") - config = Config.from_path(config_path, "tools") - logger = init_logger(__name__) - - parser = argparse.ArgumentParser(description='Running step of the Tool') - parser.add_argument("runner_name", metavar="runner_name", type=str, - help="the name of the tool that 
is intended to be used") + parser = argparse.ArgumentParser(description="Running step of the Tool") + parser.add_argument( + "runner_name", + metavar="runner_name", + type=str, + help="the name of the tool that is intended to be used", + ) _args = parser.parse_args() tool_name = _args.runner_name - assert tool_name in ToolsRegistryBase.TOOLS_REGISTRY.keys(), ValueError("unknown runner") + assert tool_name in ToolsRegistryBase.TOOLS_REGISTRY.keys(), ValueError( + "unknown runner" + ) + + config = Config.from_path(config_path, "tools") + checkpoint = Checkpoint.from_path( + os.path.join( + config.get_folder("tools_folder"), tool_name, "tool_checkpoint.yaml" + ), + {"scheduling_completed": False}, + ) + logger = init_logger(__name__) + + if not checkpoint.get("scheduling_completed", False): + logger.error("Scheduling wasn't complete, can't perform work") + exit(1) - tool_filter = ToolsRegistryBase.TOOLS_REGISTRY[tool_name]["runner"](config) + tool_runner = ToolsRegistryBase.TOOLS_REGISTRY[tool_name]["runner"](config) logger.info("Starting runner") - tool_filter.run() + tool_runner.run() logger.info("completed runner") diff --git a/src/distributed_downloader/tools/runners.py b/src/distributed_downloader/tools/runners.py index baca458..fb3cf88 100644 --- a/src/distributed_downloader/tools/runners.py +++ b/src/distributed_downloader/tools/runners.py @@ -3,13 +3,12 @@ import os import time from functools import partial -from typing import List, TextIO, Tuple +from typing import List, TextIO, Tuple, Optional import cv2 import numpy as np import pandas as pd from PIL import UnidentifiedImageError, Image -import mpi4py.MPI as MPI from distributed_downloader.tools.config import Config from distributed_downloader.tools.registry import ToolsBase, ToolsRegistryBase @@ -33,19 +32,21 @@ def __init__(self, cfg: Config): class MPIRunnerTool(RunnerToolBase): def __init__(self, cfg: Config): + import mpi4py.MPI as MPI + super().__init__(cfg) - self.filter_folder: str = None - self.filter_table_folder: str = None - self.verification_folder: str = None - self.verification_IO: TextIO = None + self.filter_folder: Optional[str] = None + self.filter_table_folder: Optional[str] = None + self.verification_folder: Optional[str] = None + self.verification_IO: Optional[TextIO] = None - self.data_scheme: List[str] = None - self.verification_scheme: List[str] = None + self.data_scheme: Optional[List[str]] = None + self.verification_scheme: Optional[List[str]] = None self.mpi_comm: MPI.Intracomm = MPI.COMM_WORLD self.mpi_rank: int = self.mpi_comm.rank - self.total_time: int = None + self.total_time: Optional[int] = None def is_enough_time(self): assert self.total_time is not None, ValueError("total_time is not set") @@ -83,8 +84,8 @@ def ensure_folders_created(self): def get_schedule(self): schedule_df = pd.read_csv(os.path.join(self.filter_folder, "schedule.csv")) schedule_df = schedule_df.query(f"rank == {self.mpi_rank}") - verification_df = self.load_table(self.verification_folder, ["server_name", "partition_id"]) - outer_join = schedule_df.merge(verification_df, how='outer', indicator=True, on=["server_name", "partition_id"]) + verification_df = self.load_table(self.verification_folder, self.verification_scheme) + outer_join = schedule_df.merge(verification_df, how='outer', indicator=True, on=self.verification_scheme) return outer_join[(outer_join["_merge"] == 'left_only')].drop('_merge', axis=1) def get_remaining_table(self, schedule: pd.DataFrame) -> pd.api.typing.DataFrameGroupBy: @@ -93,10 +94,10 @@ def 
get_remaining_table(self, schedule: pd.DataFrame) -> pd.api.typing.DataFrame df = self.load_table(self.filter_table_folder) df = df.merge(schedule, how="right", - on=["server_name", "partition_id"]) + on=self.verification_scheme) df = df[self.data_scheme] - return df.groupby(["server_name", "partition_id"], group_keys=True) + return df.groupby(self.verification_scheme, group_keys=True) def apply_filter(self, filtering_df: pd.DataFrame, server_name: str, partition_id: str) -> int: raise NotImplementedError() @@ -143,7 +144,7 @@ class FilterRunnerTool(MPIRunnerTool): def __init__(self, cfg: Config): super().__init__(cfg) - self.data_scheme: List[str] = ["uuid", "gbif_id", "server_name", "partition_id"] + self.data_scheme: List[str] = ["uuid", "source_id", "server_name", "partition_id"] self.verification_scheme: List[str] = ["server_name", "partition_id"] self.total_time = 150 @@ -203,9 +204,9 @@ def __init__(self, cfg: Config): self.data_scheme: List[str] = ["server_name", "partition_id"] self.verification_scheme: List[str] = ["server_name", "partition_id"] - self.corrupted_folder: str = None - self.corrupted_scheme: List[str] = ["uuid", "gbif_id", "server_name", "partition_id"] - self.corrupted_IO: TextIO = None + self.corrupted_folder: Optional[str] = None + self.corrupted_scheme: List[str] = ["uuid", "source_id", "server_name", "partition_id"] + self.corrupted_IO: Optional[TextIO] = None self.total_time = 150 def ensure_folders_created(self): diff --git a/src/distributed_downloader/tools/scheduler.py b/src/distributed_downloader/tools/scheduler.py index 285f070..7937910 100644 --- a/src/distributed_downloader/tools/scheduler.py +++ b/src/distributed_downloader/tools/scheduler.py @@ -1,6 +1,7 @@ import argparse import os +from distributed_downloader.tools import Checkpoint from distributed_downloader.tools.utils import init_logger from distributed_downloader.tools.config import Config from distributed_downloader.tools.registry import ToolsRegistryBase @@ -10,20 +11,46 @@ if config_path is None: raise ValueError("CONFIG_PATH not set") + parser = argparse.ArgumentParser(description="Running step of the Tool") + parser.add_argument( + "scheduler_name", + metavar="scheduler_name", + type=str, + help="the name of the tool that is intended to be used", + ) + _args = parser.parse_args() + tool_name = _args.scheduler_name + + assert tool_name in ToolsRegistryBase.TOOLS_REGISTRY.keys(), ValueError( + "unknown scheduler" + ) + config = Config.from_path(config_path, "tools") + checkpoint = Checkpoint.from_path( + os.path.join( + config.get_folder("tools_folder"), tool_name, "tool_checkpoint.yaml" + ), + { + "filtering_completed": False, + "scheduling_schedule": True, + "scheduling_completed": False, + }, + ) logger = init_logger(__name__) + checkpoint["scheduling_schedule"] = False - parser = argparse.ArgumentParser(description='Running step of the Tool') - parser.add_argument("scheduler_name", metavar="scheduler_name", type=str, - help="the name of the tool that is intended to be used") - _args = parser.parse_args() - tool_name = _args.scheduler_name + if not checkpoint.get("filtering_completed", False): + logger.error("Filtering wasn't complete, can't create schedule") + exit(1) - assert tool_name in ToolsRegistryBase.TOOLS_REGISTRY.keys(), ValueError("unknown scheduler") + tool_scheduler = ToolsRegistryBase.TOOLS_REGISTRY[tool_name]["scheduler"](config) - tool_filter = ToolsRegistryBase.TOOLS_REGISTRY[tool_name]["scheduler"](config) + if not checkpoint.get("scheduling_completed", False): + 
logger.info("Starting scheduler") + tool_scheduler.run() - logger.info("Starting scheduler") - tool_filter.run() + logger.info("completed scheduler") - logger.info("completed scheduler") + checkpoint["scheduling_completed"] = True + else: + logger.info("Scheduling was already completed") diff --git a/src/distributed_downloader/tools/schedulers.py b/src/distributed_downloader/tools/schedulers.py index efe8736..4b897be 100644 --- a/src/distributed_downloader/tools/schedulers.py +++ b/src/distributed_downloader/tools/schedulers.py @@ -1,6 +1,7 @@ import glob import os from functools import partial +from typing import List import pandas as pd @@ -28,16 +29,19 @@ class DefaultScheduler(SchedulerToolBase): def __init__(self, cfg: Config): super().__init__(cfg) + self.scheme: List[str] = ["server_name", "partition_id"] + def run(self): assert self.filter_name is not None, ValueError("filter name is not set") + assert self.scheme is not None, ValueError("Scheme was not set") filter_folder = os.path.join(self.tools_path, self.filter_name) filter_table_folder = os.path.join(filter_folder, "filter_table") all_files = glob.glob(os.path.join(filter_table_folder, "*.csv")) df: pd.DataFrame = pd.concat((pd.read_csv(f) for f in all_files), ignore_index=True) - df = df[["server_name", "partition_id"]] - df = df.drop_duplicates(subset=["server_name", "partition_id"]).reset_index(drop=True) + df = df[self.scheme] + df = df.drop_duplicates(subset=self.scheme).reset_index(drop=True) df["rank"] = df.index % self.total_workers df.to_csv(os.path.join(filter_folder, "schedule.csv"), header=True, index=False) diff --git a/src/distributed_downloader/tools/utils.py b/src/distributed_downloader/tools/utils.py index 4598328..e37c106 100644 --- a/src/distributed_downloader/tools/utils.py +++ b/src/distributed_downloader/tools/utils.py @@ -46,7 +46,7 @@ def infer_delimiter(_first_line): return df -def ensure_created(list_of_path: List[str]) -> None: +def ensure_created(list_of_path: Sequence[str]) -> None: for path in list_of_path: os.makedirs(path, exist_ok=True)