From 121eee08d3c82d2d9113f2e40f0892e3ca39375f Mon Sep 17 00:00:00 2001
From: Keegan
Date: Thu, 27 Jul 2023 09:07:54 +0000
Subject: [PATCH] s5cmd integration

---
 academic_observatory_workflows/s5cmd.py          | 131 ++++++++++++++++++
 .../workflows/orcid_telescope.py                 | 123 +++++++---------
 requirements.sh                                  |   7 +
 3 files changed, 190 insertions(+), 71 deletions(-)
 create mode 100644 academic_observatory_workflows/s5cmd.py

diff --git a/academic_observatory_workflows/s5cmd.py b/academic_observatory_workflows/s5cmd.py
new file mode 100644
index 000000000..849c59e1f
--- /dev/null
+++ b/academic_observatory_workflows/s5cmd.py
@@ -0,0 +1,131 @@
+from dataclasses import dataclass
+from typing import List, Tuple
+from subprocess import Popen, PIPE
+import logging
+from contextlib import contextmanager
+from tempfile import NamedTemporaryFile
+import re
+import shlex
+
+
+@dataclass
+class S5CmdCpConfig:
+    flatten_dir: bool = False
+    no_overwrite: bool = False
+    overwrite_if_size: bool = False
+    overwrite_if_newer: bool = False
+
+    @property
+    def cp_config_str(self):
+        cfg = [
+            "--flatten " if self.flatten_dir else "",
+            "--no-clobber " if self.no_overwrite else "",
+            "--if-size-differ " if self.overwrite_if_size else "",
+            "--if-source-newer " if self.overwrite_if_newer else "",
+        ]
+        return "".join(cfg)
+
+    def __str__(self):
+        return self.cp_config_str
+
+    def __bool__(self):
+        return bool(self.cp_config_str)
+
+
+class S5Cmd:
+    def __init__(
+        self,
+        access_credentials: Tuple[str, str],
+        logging_level: str = "debug",
+        cp_config: S5CmdCpConfig = None,
+    ):
+        self.cp_config = cp_config if cp_config else S5CmdCpConfig()
+        self.access_credentials = access_credentials
+        self.logging_level = logging_level
+        self.uri_identifier_regex = r"^[a-zA-Z0-9_]{2}://"
+
+    @contextmanager
+    def _bucket_credentials(self):
+        try:
+            with NamedTemporaryFile() as tmp:
+                with open(tmp.name, "w") as f:
+                    f.write("[default]\n")
+                    f.write(f"aws_access_key_id = {self.access_credentials[0]}\n")
+                    f.write(f"aws_secret_access_key = {self.access_credentials[1]}\n")
+                yield tmp.name
+        finally:
+            pass
+
+    def download_from_bucket(self, uris: List[str], local_path: str, out_stream=None):
+        """Downloads file(s) from a bucket using s5cmd and a supplied list of URIs."""
+        main_cmd = "s5cmd"
+
+        # Check the integrity of the supplied URIs
+        uri_prefixes = [re.match(self.uri_identifier_regex, uri) for uri in uris]
+        if None in uri_prefixes:
+            raise ValueError("All URIs must begin with a qualified bucket prefix.")
+        uri_prefixes = [prefix.group() for prefix in uri_prefixes]
+        if not len(set(uri_prefixes)) == 1:
+            raise ValueError(f"All URIs must begin with the same prefix. Found prefixes: {set(uri_prefixes)}")
+        uri_prefix = uri_prefixes[0]
+        if uri_prefix not in ["gs://", "s3://"]:
+            raise ValueError(f"Only gs:// and s3:// URIs are supported. Found prefix: {uri_prefix}")
+
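+        # Note: s5cmd only speaks the S3 API. GCS exposes an S3-compatible XML API at
+        # https://storage.googleapis.com that accepts HMAC credentials, so gs:// URIs are
+        # rewritten to s3:// and pointed at that endpoint below.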
+        # Amend the URIs with the s3:// prefix and add endpoint URL if required
+        if uri_prefix == "gs://":
+            main_cmd += " --endpoint-url https://storage.googleapis.com"
+            for i, uri in enumerate(uris):
+                uris[i] = uri.replace("gs://", "s3://")
+
+        # Configure the input and output streams
+        stdout = out_stream if out_stream else PIPE
+        stderr = out_stream if out_stream else PIPE
+
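+        # The 'run' subcommand reads one command per line from stdin and executes them
+        # concurrently, so a single s5cmd process handles the whole batch of blobs.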
+        # Make the run commands
+        blob_cmds = []
+        for uri in uris:
+            blob_cmd = "cp"
+            if self.cp_config:
+                blob_cmd += f" {str(self.cp_config)}"
+            blob_cmd += f" {uri} {local_path}"
+            blob_cmds.append(blob_cmd)
+        blob_stdin = "\n".join(blob_cmds)
+        logging.info(f"s5cmd blob download command example: {blob_cmds[0]}")
+
+        # Initialise credentials and download
+        with self._bucket_credentials() as credentials:
+            main_cmd += f" --credentials-file {credentials} run"
+            logging.info(f"Executing download command: {main_cmd}")
+            proc = Popen(shlex.split(main_cmd), stdout=stdout, stderr=stderr, stdin=PIPE)
+            proc.communicate(input=blob_stdin.encode())
+            returncode = proc.wait()
+            return returncode
+
+    def upload_to_bucket(self, files: List[str], bucket_uri: str, out_stream=None):
+        """Uploads file(s) to a bucket using s5cmd and a supplied list of local file paths."""
+        main_cmd = "s5cmd"
+
+        # Check that the URI prefix is supported
+        bucket_prefix = re.match(self.uri_identifier_regex, bucket_uri).group(0)
+        if bucket_prefix not in ["gs://", "s3://"]:
+            raise ValueError(f"Only gs:// and s3:// URIs are supported. Found prefix: {bucket_prefix}")
+
+        # Amend the URI with the s3:// prefix and add endpoint URL if required
+        if bucket_prefix == "gs://":
+            main_cmd += " --endpoint-url https://storage.googleapis.com"
+            bucket_uri = bucket_uri.replace("gs://", "s3://")
+
+        # Configure the input and output streams
+        stdout = out_stream if out_stream else PIPE
+        stderr = out_stream if out_stream else PIPE
+
+        # Make the run commands
+        blob_cmds = []
+        for file in files:
+            blob_cmd = "cp"
+            if self.cp_config:
+                blob_cmd += f" {str(self.cp_config)}"
+            blob_cmd += f" {file} {bucket_uri}"
+            blob_cmds.append(blob_cmd)
+        blob_stdin = "\n".join(blob_cmds)
+
+        # Initialise credentials and upload
+        with self._bucket_credentials() as credentials:
+            main_cmd += f" --credentials-file {credentials} run"
+            logging.info(f"Executing upload command: {main_cmd}")
+            proc = Popen(shlex.split(main_cmd), stdout=stdout, stderr=stderr, stdin=PIPE)
+            proc.communicate(input=blob_stdin.encode())
+            returncode = proc.wait()
+            return returncode
diff --git a/academic_observatory_workflows/workflows/orcid_telescope.py b/academic_observatory_workflows/workflows/orcid_telescope.py
index 2a522ade4..0caef2a72 100644
--- a/academic_observatory_workflows/workflows/orcid_telescope.py
+++ b/academic_observatory_workflows/workflows/orcid_telescope.py
@@ -21,7 +21,7 @@
 import time
 import logging
 import os
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
 from datetime import timedelta
 from typing import List, Dict, Tuple, Union
 import itertools
@@ -39,6 +39,7 @@
 
 from academic_observatory_workflows.config import schema_folder as default_schema_folder, Tag
+from academic_observatory_workflows.s5cmd import S5Cmd
 from observatory.api.client.model.dataset_release import DatasetRelease
 from observatory.platform.airflow import PreviousDagRunSensor, is_first_dag_run
 from observatory.platform.api import get_dataset_releases, get_latest_dataset_release
@@ -59,6 +60,7 @@
     gcs_blob_uri,
     gcs_blob_name_from_path,
     gcs_create_aws_transfer,
+    gcs_list_blobs,
 )
 from observatory.platform.observatory_config import CloudWorkspace
 from observatory.platform.workflows.workflow import (
@@ -227,6 +229,7 @@ def __init__(
         api_dataset_id: str = "orcid",
         observatory_api_conn_id: str = AirflowConns.OBSERVATORY_API,
         aws_orcid_conn_id: str = "aws_orcid",
+        gcs_hmac_conn_id: str = "gcs_hmac",
         start_date: pendulum.DateTime = pendulum.datetime(2023, 6, 1),
         schedule_interval: str = "@weekly",
         queue: str = "default",  # TODO: remote_queue
@@ -261,6 +264,7 @@ def __init__(
         self.max_workers = max_workers
         self.observatory_api_conn_id = observatory_api_conn_id
         self.aws_orcid_conn_id = aws_orcid_conn_id
+        self.gcs_hmac_conn_id = gcs_hmac_conn_id
 
         external_task_id = "dag_run_complete"
         self.add_operator(
@@ -307,7 +311,14 @@ def aws_orcid_key(self) -> Tuple[str, str]:
         connection = BaseHook.get_connection(self.aws_orcid_conn_id)
         return connection.login, connection.password
 
+    @property
+    def gcs_hmac_key(self) -> Tuple[str, str]:
+        """Return GCS HMAC access key ID and secret"""
+        connection = BaseHook.get_connection(self.gcs_hmac_conn_id)
+        return connection.login, connection.password
+
     def make_release(self, **kwargs) -> OrcidRelease:
+        """Generates the OrcidRelease object."""
         dag_run = kwargs["dag_run"]
         is_first_run = is_first_dag_run(dag_run)
         releases = get_dataset_releases(dag_id=self.dag_id, dataset_id=self.api_dataset_id)
@@ -420,40 +431,34 @@ def create_manifests(self, release: OrcidRelease, **kwargs):
 
     def download(self, release: OrcidRelease, **kwargs):
         """Reads each batch's manifest and downloads the files from the gcs bucket."""
-        start_time = time.time()
         total_files = 0
+        start_time = time.time()
         for orcid_batch in release.orcid_batches():
             if not orcid_batch.missing_records:
                 logging.info(f"All files present for {orcid_batch.batch_str}. Skipping download.")
                 continue
-            all_files_present = False
-            # Loop - download and assert all files downloaded
-            batch_start = time.time()
             logging.info(f"Downloading files for ORCID directory: {orcid_batch.batch_str}")
-            for i in range(3):  # Try up to 3 times
-                returncode = gsutil_download(orcid_batch=orcid_batch)
-                if returncode != 0:
-                    logging.warn(
-                        f"Attempt {i+1} for '{orcid_batch.batch_str}': returned non-zero exit code: {returncode}"
-                    )
-                    continue
-                if orcid_batch.missing_records:
-                    logging.warn(
-                        f"Attempt {i+1} for '{orcid_batch.batch_str}': {len(orcid_batch.missing_records)} files missing"
-                    )
-                    continue
-                else:
-                    all_files_present = True
-                    break
-
-            if not all_files_present:
-                raise AirflowException(f"All files were not downloaded for {orcid_batch.batch_str}. Aborting.")
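+            # gcs_hmac_key returns the HMAC access key id and secret stored on the 'gcs_hmac'
+            # Airflow connection (login/password); s5cmd uses them to authenticate to the bucket.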
+            s5cmd = S5Cmd(access_credentials=self.gcs_hmac_key)
+            with open(orcid_batch.download_log_file, "w") as f:
+                returncode = s5cmd.download_from_bucket(
+                    uris=orcid_batch.blob_uris, local_path=orcid_batch.download_batch_dir, out_stream=f
+                )
+            if returncode != 0:
+                raise RuntimeError(
+                    f"Download for '{orcid_batch.batch_str}' returned non-zero exit code: {returncode}. "
+                    f"See log file: {orcid_batch.download_log_file}"
+                )
+            if orcid_batch.missing_records:
+                raise FileNotFoundError(f"All files were not downloaded for {orcid_batch.batch_str}. Aborting.")
 
             total_files += len(orcid_batch.expected_records)
-            logging.info(f"Download for '{orcid_batch.batch_str}' completed successfully.")
-            logging.info(f"Downloaded {len(orcid_batch.expected_records)} in {time.time() - batch_start} seconds")
-        logging.info(f"Completed download for {total_files} files in {(time.time() - start_time)/3600} hours")
+            logging.info(f"Downloaded {len(orcid_batch.expected_records)} records for batch '{orcid_batch.batch_str}'.")
+        total_time = time.time() - start_time
+        logging.info(f"Completed download for {total_files} files in {str(timedelta(seconds=total_time))}")
 
     def transform(self, release: OrcidRelease, **kwargs):
         """Transforms the downloaded files into serveral bigquery-compatible .jsonl files"""
@@ -464,7 +469,7 @@ def transform(self, release: OrcidRelease, **kwargs):
             batch_start = time.time()
             logging.info(f"Transforming ORCID batch {orcid_batch.batch_str}")
             transformed_data = []
-            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
                 futures = []
                 for record in orcid_batch.existing_records:
                     future = executor.submit(
@@ -500,7 +505,8 @@ def transform(self, release: OrcidRelease, **kwargs):
             )
             logging.info(f"Time taken for batch: {time.time() - batch_start} seconds")
         logging.info(f"Transformed {total_upsert_records} upserts and {total_delete_records} deletes in total")
-        logging.info(f"Time taken for all batches: {(time.time() - start_time)/3600} hours")
+        total_time = time.time() - start_time
+        logging.info(f"Time taken for all batches: {str(timedelta(seconds=total_time))}")
 
     def upload_transformed(self, release: OrcidRelease, **kwargs):
         """Uploads the upsert and delete files to the transform bucket."""
@@ -709,17 +716,7 @@ def orcid_batch_names() -> List[str]:
     return ["".join(i) for i in combinations]
 
 
-def gcs_list_blobs(bucket_name: str, prefix: str = None) -> List[storage.Blob]:
-    """List blobs in a bucket using a gcs_uri.
-
-    :param bucket_name: The name of the bucket
-    :param prefix: The prefix to filter by
-    """
-    storage_client = storage.Client()
-    return list(storage_client.list_blobs(bucket_name, prefix=prefix))
-
-
-def transform_orcid_record(record_path: str) -> Union(Dict[str, any], str):
+def transform_orcid_record(record_path: str) -> Union[Dict, str]:
     """Transform a single ORCID file/record.
     The xml file is turned into a dictionary, a record should have either a valid 'record' section or an 'error'section.
     The keys of the dictionary are slightly changed so they are valid BigQuery fields.
@@ -728,62 +725,46 @@ def transform_orcid_record(record_path: str) -> Union(Dict[str, any], str):
     :param download_path: The path to the file with the ORCID record.
     :param transform_folder: The path where transformed files will be saved.
     :return: The transformed ORCID record. If the record is an error, the ORCID ID is returned.
+    :raises AirflowException: If the record is not valid - no 'error' or 'record' section found
     """
     # Get the orcid from the record path
     expected_orcid = re.search(ORCID_REGEX, record_path).group(0)
 
     with open(record_path, "rb") as f:
         orcid_dict = xmltodict.parse(f)
 
-    orcid_record = orcid_dict.get("record:record")
+    try:
+        orcid_record = orcid_dict["record:record"]
     # Some records do not have a 'record', but only 'error'. We return the path to the file in this case.
-    if not orcid_record:
-        orcid_record = orcid_dict.get("error:error")
-        if not orcid_record:
-            raise AirflowException(f"Key error for file: {record_path}")
+    except KeyError:
+        if "error:error" not in orcid_dict:
+            raise AirflowException(f"Key error for file: {record_path}")
         return expected_orcid
 
     # Check that the ORCID in the file name matches the ORCID in the record
-    assert (
-        orcid_record["common:orcid-identifier"]["common:path"] == expected_orcid
-    ), f"Expected ORCID {expected_orcid} does not match ORCID in record {orcid_record['common:path']}"
+    if orcid_record["common:orcid-identifier"]["common:path"] != expected_orcid:
+        raise ValueError(
+            f"Expected ORCID {expected_orcid} does not match ORCID in record "
+            f"{orcid_record['common:orcid-identifier']['common:path']}"
+        )
 
     # Transform the keys of the dictionary so they are valid BigQuery fields
-    orcid_record = change_keys(orcid_record, convert)
+    orcid_record = change_keys(orcid_record)
 
     return orcid_record
 
 
-def change_keys(obj, convert):
+def change_keys(obj):
     """Recursively goes through the dictionary obj and replaces keys with the convert function.
 
-    :param obj: The dictionary value, can be object of any type
-    :param convert: The convert function.
+    :param obj: The dictionary value, can be an object of any type
     :return: The transformed object.
     """
     if isinstance(obj, (str, int, float)):
         return obj
+
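+    # Strip the XML namespace prefix and any leading '@'/'#' markers added by xmltodict,
+    # e.g. "common:orcid-identifier" -> "orcid_identifier" and "#text" -> "text".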
+    def convert_key(k: str) -> str:
+        return k.split(":")[-1].lstrip("@#").replace("-", "_")
+
     if isinstance(obj, dict):
-        new = obj.__class__()
-        for k, v in list(obj.items()):
-            if k.startswith("@xmlns"):
-                continue
-            new[convert(k)] = change_keys(v, convert)
+        return {convert_key(k): change_keys(v) for k, v in obj.items() if not k.startswith("@xmlns")}
     elif isinstance(obj, (list, set, tuple)):
-        new = obj.__class__(change_keys(v, convert) for v in obj)
+        return obj.__class__(change_keys(v) for v in obj)
     else:
         return obj
-    return new
-
-
-def convert(k: str) -> str:
-    """Convert key of dictionary to valid BQ key.
-
-    :param k: Key
-    :return: The converted key
-    """
-    k = k.split(":")[-1]
-    if k.startswith("@") or k.startswith("#"):
-        k = k[1:]
-    k = k.replace("-", "_")
-    return k
diff --git a/requirements.sh b/requirements.sh
index b46fdecc2..a0d7d8ac1 100644
--- a/requirements.sh
+++ b/requirements.sh
@@ -26,3 +26,10 @@ apt-get install pigz unzip -y
 echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] http://packages.cloud.google.com/apt cloud-sdk main" | \
 tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | \
 apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - && apt-get update -y && apt-get install google-cloud-sdk -y
+
+# ORCID: install s5cmd
+curl -LO https://github.com/peak/s5cmd/releases/download/v2.1.0/s5cmd_2.1.0_linux_amd64.deb && \
+dpkg -i s5cmd_2.1.0_linux_amd64.deb
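+# s5cmd is used by the ORCID telescope to bulk-download ORCID records from GCS via its
+# S3-compatible endpoint (see academic_observatory_workflows/s5cmd.py).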