s5cmd integration

keegansmith21 committed Jul 27, 2023
1 parent 68ef2cd commit 121eee0

Showing 3 changed files with 190 additions and 71 deletions.
131 changes: 131 additions & 0 deletions academic_observatory_workflows/s5cmd.py
@@ -0,0 +1,131 @@
from dataclasses import dataclass
from typing import List, Tuple
from subprocess import Popen, PIPE
import logging
from contextlib import contextmanager
from tempfile import NamedTemporaryFile
import re
import shlex


@dataclass
class S5CmdCpConfig:
flatten_dir: bool = False
no_overwrite: bool = False
overwrite_if_size: bool = False
overwrite_if_newer: bool = False

@property
def cp_config_str(self):
cfg = [
"--flatten " if self.flatten_dir else "",
"--no-clobber " if self.no_overwrite else "",
"--if-size-differ " if self.overwrite_if_size else "",
"--if-source-newer " if self.overwrite_if_newer else "",
]
return "".join(cfg)

def __str__(self):
return self.cp_config_str

def __bool__(self):
return bool(self.cp_config_str)
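
    # Example (illustrative): S5CmdCpConfig(no_overwrite=True, overwrite_if_size=True) renders to
    # "--no-clobber --if-size-differ ", which S5Cmd inserts after "cp" in each generated copy command.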


class S5Cmd:
def __init__(
self,
access_credentials: Tuple[str, str],
logging_level: str = "debug",
cp_config: S5CmdCpConfig = None,
):
        self.cp_config = cp_config if cp_config else S5CmdCpConfig()

self.access_credentials = access_credentials
self.logging_level = logging_level
self.uri_identifier_regex = r"^[a-zA-Z0-9_]{2}://"

    @contextmanager
    def _bucket_credentials(self):
        """Writes the supplied access credentials to a temporary file and yields its path."""
        with NamedTemporaryFile() as tmp:
            with open(tmp.name, "w") as f:
                f.write("[default]\n")
                f.write(f"aws_access_key_id = {self.access_credentials[0]}\n")
                f.write(f"aws_secret_access_key = {self.access_credentials[1]}\n")
            yield tmp.name

def download_from_bucket(self, uris: List[str], local_path: str, out_stream=None):
"""Downloads file(s) from a bucket using s5cmd and a supplied list of URIs."""
main_cmd = "s5cmd"

# Check the integrity of the supplied URIs
uri_prefixes = [re.match(self.uri_identifier_regex, uri) for uri in uris]
if None in uri_prefixes:
raise ValueError("All URIs must begin with a qualified bucket prefix.")
uri_prefixes = [prefix.group() for prefix in uri_prefixes]
        if len(set(uri_prefixes)) != 1:
raise ValueError(f"All URIs must begin with the same prefix. Found prefixes: {set(uri_prefixes)}")
uri_prefix = uri_prefixes[0]
if uri_prefix not in ["gs://", "s3://"]:
raise ValueError(f"Only gs:// and s3:// URIs are supported. Found prefix: {uri_prefix}")

# Amend the URIs with the s3:// prefix and add endpoint URL if required
if uri_prefix == "gs://":
main_cmd += " --endpoint-url https://storage.googleapis.com"
for i, uri in enumerate(uris):
uris[i] = uri.replace("gs://", "s3://")

# Configure the input and output streams
stdout = out_stream if out_stream else PIPE
stderr = out_stream if out_stream else PIPE

# Make the run commands
blob_cmds = []
for uri in uris:
blob_cmd = "cp"
if self.cp_config:
blob_cmd += f" {str(self.cp_config)}"
blob_cmd += f" {uri} {local_path}"
blob_cmds.append(blob_cmd)
blob_stdin = "\n".join(blob_cmds)
logging.info(f"s5cmd blob download command example: {blob_cmds[0]}")

# Initialise credentials and download
with self._bucket_credentials() as credentials:
main_cmd += f" --credentials-file {credentials} run"
logging.info(f"Executing download command: {main_cmd}")
proc = Popen(shlex.split(main_cmd), stdout=stdout, stderr=stderr, stdin=PIPE)
proc.communicate(input=blob_stdin.encode())
returncode = proc.wait()
return returncode

    def upload_to_bucket(self, files: List[str], bucket_uri: str, out_stream=None):
        """Uploads file(s) to a bucket using s5cmd and a supplied list of local file paths."""
        main_cmd = "s5cmd"

        # Check that the URI prefix is supported
        bucket_prefix = re.match(self.uri_identifier_regex, bucket_uri).group(0)
        if bucket_prefix not in ["gs://", "s3://"]:
            raise ValueError(f"Only gs:// and s3:// URIs are supported. Found prefix: {bucket_prefix}")

        # Amend the URI with the s3:// prefix and add endpoint URL if required
        if bucket_prefix == "gs://":
            main_cmd += " --endpoint-url https://storage.googleapis.com"
            bucket_uri = bucket_uri.replace("gs://", "s3://")

        # Configure the output streams
        stdout = out_stream if out_stream else PIPE
        stderr = out_stream if out_stream else PIPE

        # Make the run commands
        blob_cmds = []
        for file in files:
            blob_cmd = "cp"
            if self.cp_config:
                blob_cmd += f" {str(self.cp_config)}"
            blob_cmd += f" {file} {bucket_uri}"
            blob_cmds.append(blob_cmd)
        blob_stdin = "\n".join(blob_cmds)
        logging.info(f"s5cmd blob upload command example: {blob_cmds[0]}")

        # Initialise credentials and upload
        with self._bucket_credentials() as credentials:
            main_cmd += f" --credentials-file {credentials} run"
            logging.info(f"Executing upload command: {main_cmd}")
            proc = Popen(shlex.split(main_cmd), stdout=stdout, stderr=stderr, stdin=PIPE)
            proc.communicate(input=blob_stdin.encode())
            returncode = proc.wait()
        return returncode
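
For reference, a minimal usage sketch of the new S5Cmd wrapper (not part of the commit). The bucket name, blob paths, credentials and local paths below are placeholders, and the no-overwrite flag is just one possible S5CmdCpConfig choice.

from academic_observatory_workflows.s5cmd import S5Cmd, S5CmdCpConfig

# Placeholder GCS HMAC credentials; substitute real values.
access_key_id = "GOOG1EXAMPLEKEY"
secret_access_key = "example-secret"

s5cmd = S5Cmd(
    access_credentials=(access_key_id, secret_access_key),
    cp_config=S5CmdCpConfig(no_overwrite=True),  # skip files that already exist locally
)

# Download two blobs from a GCS bucket into /tmp/orcid via the S3-compatible endpoint.
uris = [
    "gs://example-orcid-bucket/000/0000-0002-1825-0097.xml",
    "gs://example-orcid-bucket/000/0000-0001-5109-3700.xml",
]
with open("/tmp/orcid/download.log", "w") as log_file:
    returncode = s5cmd.download_from_bucket(uris=uris, local_path="/tmp/orcid", out_stream=log_file)
assert returncode == 0, "s5cmd returned a non-zero exit code; see /tmp/orcid/download.log"

Uploads work the same way through upload_to_bucket, passing local file paths and a destination bucket URI.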
123 changes: 52 additions & 71 deletions academic_observatory_workflows/workflows/orcid_telescope.py
@@ -21,7 +21,7 @@
import time
import logging
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from datetime import timedelta
from typing import List, Dict, Tuple, Union
import itertools
@@ -39,6 +39,7 @@


from academic_observatory_workflows.config import schema_folder as default_schema_folder, Tag
from academic_observatory_workflows.s5cmd import S5Cmd
from observatory.api.client.model.dataset_release import DatasetRelease
from observatory.platform.airflow import PreviousDagRunSensor, is_first_dag_run
from observatory.platform.api import get_dataset_releases, get_latest_dataset_release
@@ -59,6 +60,7 @@
gcs_blob_uri,
gcs_blob_name_from_path,
gcs_create_aws_transfer,
gcs_list_blobs,
)
from observatory.platform.observatory_config import CloudWorkspace
from observatory.platform.workflows.workflow import (
@@ -227,6 +229,7 @@ def __init__(
api_dataset_id: str = "orcid",
observatory_api_conn_id: str = AirflowConns.OBSERVATORY_API,
aws_orcid_conn_id: str = "aws_orcid",
gcs_hmac_conn_id: str = "gcs_hmac",
start_date: pendulum.DateTime = pendulum.datetime(2023, 6, 1),
schedule_interval: str = "@weekly",
queue: str = "default", # TODO: remote_queue
@@ -261,6 +264,7 @@ def __init__(
self.max_workers = max_workers
self.observatory_api_conn_id = observatory_api_conn_id
self.aws_orcid_conn_id = aws_orcid_conn_id
self.gcs_hmac_conn_id = gcs_hmac_conn_id

external_task_id = "dag_run_complete"
self.add_operator(
@@ -307,7 +311,14 @@ def aws_orcid_key(self) -> Tuple[str, str]:
connection = BaseHook.get_connection(self.aws_orcid_conn_id)
return connection.login, connection.password

@property
def gcs_hmac_key(self) -> Tuple[str, str]:
"""Return GCS HMAC access key ID and secret"""
connection = BaseHook.get_connection(self.gcs_hmac_conn_id)
return connection.login, connection.password

def make_release(self, **kwargs) -> OrcidRelease:
"""Generates the OrcidRelease object."""
dag_run = kwargs["dag_run"]
is_first_run = is_first_dag_run(dag_run)
releases = get_dataset_releases(dag_id=self.dag_id, dataset_id=self.api_dataset_id)
@@ -420,40 +431,34 @@ def create_manifests(self, release: OrcidRelease, **kwargs):

def download(self, release: OrcidRelease, **kwargs):
"""Reads each batch's manifest and downloads the files from the gcs bucket."""
start_time = time.time()
total_files = 0
start_time = time.time()
for orcid_batch in release.orcid_batches():
if not orcid_batch.missing_records:
logging.info(f"All files present for {orcid_batch.batch_str}. Skipping download.")
continue
all_files_present = False

# Loop - download and assert all files downloaded
batch_start = time.time()
logging.info(f"Downloading files for ORCID directory: {orcid_batch.batch_str}")
for i in range(3): # Try up to 3 times
returncode = gsutil_download(orcid_batch=orcid_batch)
if returncode != 0:
logging.warn(
f"Attempt {i+1} for '{orcid_batch.batch_str}': returned non-zero exit code: {returncode}"
)
continue
if orcid_batch.missing_records:
logging.warn(
f"Attempt {i+1} for '{orcid_batch.batch_str}': {len(orcid_batch.missing_records)} files missing"
)
continue
else:
all_files_present = True
break

if not all_files_present:
raise AirflowException(f"All files were not downloaded for {orcid_batch.batch_str}. Aborting.")
s5cmd = S5Cmd(access_credentials=self.gcs_hmac_key)
            with open(orcid_batch.download_log_file, "w") as f:
                returncode = s5cmd.download_from_bucket(
                    uris=orcid_batch.blob_uris, local_path=orcid_batch.download_batch_dir, out_stream=f
                )
if returncode != 0:
raise RuntimeError(
f"Download attempt '{orcid_batch.batch_str}': returned non-zero exit code: {returncode}. See log file: {orcid_batch.download_log_file}"
)
if orcid_batch.missing_records:
raise FileNotFoundError(f"All files were not downloaded for {orcid_batch.batch_str}. Aborting.")

total_files += len(orcid_batch.expected_records)
logging.info(f"Download for '{orcid_batch.batch_str}' completed successfully.")
logging.info(f"Downloaded {len(orcid_batch.expected_records)} in {time.time() - batch_start} seconds")
logging.info(f"Completed download for {total_files} files in {(time.time() - start_time)/3600} hours")
logging.info(
f"Downloaded {len(orcid_batch.expected_records)} records for batch '{orcid_batch.batch_str}' completed successfully."
)
total_time = time.time() - start_time
logging.info(f"Completed download for {total_files} files in {str(datetime.timedelta(seconds=total_time))}")

def transform(self, release: OrcidRelease, **kwargs):
"""Transforms the downloaded files into serveral bigquery-compatible .jsonl files"""
@@ -464,7 +469,7 @@ def transform(self, release: OrcidRelease, **kwargs):
batch_start = time.time()
logging.info(f"Transforming ORCID batch {orcid_batch.batch_str}")
transformed_data = []
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
futures = []
for record in orcid_batch.existing_records:
future = executor.submit(
@@ -500,7 +505,8 @@ def transform(self, release: OrcidRelease, **kwargs):
)
logging.info(f"Time taken for batch: {time.time() - batch_start} seconds")
logging.info(f"Transformed {total_upsert_records} upserts and {total_delete_records} deletes in total")
logging.info(f"Time taken for all batches: {(time.time() - start_time)/3600} hours")
total_time = time.time() - start_time
logging.info(f"Time taken for all batches: {str(datetime.timedelta(seconds=total_time))}")

def upload_transformed(self, release: OrcidRelease, **kwargs):
"""Uploads the upsert and delete files to the transform bucket."""
@@ -627,6 +633,7 @@ def add_new_dataset_release(self, release: OrcidRelease, **kwargs) -> None:

def cleanup(self, release: OrcidRelease, **kwargs) -> None:
"""Delete all files, folders and XComs associated with this release."""
cleanup(dag_id=self.dag_id, execution_date=kwargs["logical_date"], workflow_folder=release.workflow_folder)


@@ -709,17 +716,7 @@ def orcid_batch_names() -> List[str]:
return ["".join(i) for i in combinations]


def gcs_list_blobs(bucket_name: str, prefix: str = None) -> List[storage.Blob]:
"""List blobs in a bucket using a gcs_uri.
:param bucket_name: The name of the bucket
:param prefix: The prefix to filter by
"""
storage_client = storage.Client()
return list(storage_client.list_blobs(bucket_name, prefix=prefix))


def transform_orcid_record(record_path: str) -> Union(Dict[str, any], str):
def transform_orcid_record(record_path: str) -> Union[Dict, str]:
"""Transform a single ORCID file/record.
    The xml file is turned into a dictionary; a record should have either a valid 'record' section or an 'error' section.
The keys of the dictionary are slightly changed so they are valid BigQuery fields.
@@ -728,62 +725,46 @@ def transform_orcid_record(record_path: str) -> Union(Dict[str, any], str):
    :param record_path: The path to the file with the ORCID record.
:return: The transformed ORCID record. If the record is an error, the ORCID ID is returned.
:raises AirflowException: If the record is not valid - no 'error' or 'record' section found
"""
# Get the orcid from the record path
expected_orcid = re.search(ORCID_REGEX, record_path).group(0)

with open(record_path, "rb") as f:
orcid_dict = xmltodict.parse(f)
orcid_record = orcid_dict.get("record:record")

try:
orcid_record = orcid_dict["record:record"]
# Some records do not have a 'record', but only 'error'. We return the path to the file in this case.
if not orcid_record:
orcid_record = orcid_dict.get("error:error")
if not orcid_record:
raise AirflowException(f"Key error for file: {record_path}")
except KeyError:
orcid_record = orcid_dict["error:error"]
return expected_orcid

# Check that the ORCID in the file name matches the ORCID in the record
assert (
orcid_record["common:orcid-identifier"]["common:path"] == expected_orcid
), f"Expected ORCID {expected_orcid} does not match ORCID in record {orcid_record['common:path']}"
if not orcid_record["common:orcid-identifier"]["common:path"] == expected_orcid:
raise ValueError(
f"Expected ORCID {expected_orcid} does not match ORCID in record {orcid_record['common:path']}"
)

# Transform the keys of the dictionary so they are valid BigQuery fields
orcid_record = change_keys(orcid_record, convert)
orcid_record = change_keys(orcid_record)

return orcid_record


def change_keys(obj, convert):
def change_keys(obj):
"""Recursively goes through the dictionary obj and replaces keys with the convert function.
:param obj: The dictionary value, can be object of any type
:param convert: The convert function.
:param obj: The dictionary value, can be an object of any type
:return: The transformed object.
"""
if isinstance(obj, (str, int, float)):
return obj

convert_key = lambda k: k.split(":")[-1].lstrip("@#").replace("-", "_")
if isinstance(obj, dict):
new = obj.__class__()
for k, v in list(obj.items()):
if k.startswith("@xmlns"):
continue
new[convert(k)] = change_keys(v, convert)
return {convert_key(k): change_keys(v) for k, v in obj.items() if not k.startswith("@xmlns")}
elif isinstance(obj, (list, set, tuple)):
new = obj.__class__(change_keys(v, convert) for v in obj)
return obj.__class__(change_keys(v) for v in obj)
else:
return obj
return new


def convert(k: str) -> str:
"""Convert key of dictionary to valid BQ key.
:param k: Key
:return: The converted key
"""
k = k.split(":")[-1]
if k.startswith("@") or k.startswith("#"):
k = k[1:]
k = k.replace("-", "_")
return k
7 changes: 7 additions & 0 deletions requirements.sh
@@ -26,3 +26,10 @@ apt-get install pigz unzip -y
echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] http://packages.cloud.google.com/apt cloud-sdk main" | \
tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | \
apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - && apt-get update -y && apt-get install google-cloud-sdk -y

# Install s5cmd (used by the ORCID workflow)
curl -LO https://github.com/peak/s5cmd/releases/download/v2.1.0/s5cmd_2.1.0_linux_amd64.deb && \
dpkg -i s5cmd_2.1.0_linux_amd64.deb
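
A quick sanity check (illustrative, not part of the commit) that the workflow environment could run before S5Cmd shells out to the binary installed above:

import shutil
import subprocess

# Fail early if the s5cmd binary installed by requirements.sh is not on PATH.
assert shutil.which("s5cmd") is not None, "s5cmd not found on PATH"
print(subprocess.run(["s5cmd", "version"], capture_output=True, text=True).stdout)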
