s5cmd cat attempt
keegansmith21 committed Jul 31, 2023
1 parent 121eee0 commit ebd5e8d
Showing 5 changed files with 270 additions and 43 deletions.
Git LFS file not shown
Git LFS file not shown
77 changes: 46 additions & 31 deletions academic_observatory_workflows/s5cmd.py
@@ -1,11 +1,11 @@
import re
import shlex
import logging
from dataclasses import dataclass
from typing import List, Tuple
from subprocess import Popen, PIPE
import logging
from contextlib import contextmanager
from tempfile import NamedTemporaryFile
import re
import shlex


@dataclass
@@ -37,15 +37,34 @@ def __init__(
self,
access_credentials: Tuple[str, str],
logging_level: str = "debug",
out_stream=PIPE,
cp_config: S5CmdCpConfig = None,
):
if not cp_config:
cp_config = S5CmdCpConfig()
self.cp_config = cp_config

self.access_credentials = access_credentials
self.logging_level = logging_level
self.out_stream = out_stream
self.uri_identifier_regex = r"^[a-zA-Z0-9_]{2}://"

def _initialise_command(self, uri: str):
"""Initializes the command for the given bucket URI.
:param uri: The URI being accessed.
:return: The initialized command."""
cmd = "s5cmd"

# Check that the uri prefixes are supported
bucket_prefix = re.match(self.uri_identifier_regex, uri).group(0)
if bucket_prefix not in ["gs://", "s3://"]:
raise ValueError(f"Only gs:// and s3:// URIs are supported. Found prefix: {bucket_prefix}")

# Amend the URIs with the s3:// prefix and add endpoint URL if required
if bucket_prefix == "gs://":
cmd = " ".join([cmd, "--endpoint-url https://storage.googleapis.com"])

return cmd
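# A usage sketch (not definitive) of the base command _initialise_command builds,
# assuming the S5Cmd constructor above and S5CmdCpConfig's defaults; the bucket
# URIs and HMAC credentials are placeholders.
from academic_observatory_workflows.s5cmd import S5Cmd

s5 = S5Cmd(access_credentials=("HMAC_KEY_ID", "HMAC_SECRET"))
print(s5._initialise_command("s3://example-bucket/file.txt"))
# -> "s5cmd"
print(s5._initialise_command("gs://example-bucket/file.txt"))
# -> "s5cmd --endpoint-url https://storage.googleapis.com"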

@contextmanager
def _bucket_credentials(self):
try:
@@ -58,9 +77,8 @@ def _bucket_credentials(self):
finally:
pass

def download_from_bucket(self, uris: List[str], local_path: str, out_stream=None):
def download_from_bucket(self, uris: List[str], local_path: str):
"""Downloads file(s) from a bucket using s5cmd and a supplied list of URIs."""
main_cmd = "s5cmd"

# Check the integrity of the supplied URIs
uri_prefixes = [re.match(self.uri_identifier_regex, uri) for uri in uris]
@@ -74,15 +92,12 @@ def download_from_bucket(self, uris: List[str], local_path: str, out_stream=None
raise ValueError(f"Only gs:// and s3:// URIs are supported. Found prefix: {uri_prefix}")

# Amend the URIs with the s3:// prefix and add endpoint URL if required
cmd = "s5cmd"
if uri_prefix == "gs://":
main_cmd += " --endpoint-url https://storage.googleapis.com"
cmd += " --endpoint-url https://storage.googleapis.com"
for i, uri in enumerate(uris):
uris[i] = uri.replace("gs://", "s3://")

# Configure the input and output streams
stdout = out_stream if out_stream else PIPE
stderr = out_stream if out_stream else PIPE

# Make the run commands
blob_cmds = []
for uri in uris:
@@ -94,38 +109,38 @@ def download_from_bucket(self, uris: List[str], local_path: str, out_stream=None
blob_stdin = "\n".join(blob_cmds)
logging.info(f"s5cmd blob download command example: {blob_cmds[0]}")

# Initialise credentials and download
# Initialise credentials and execute
with self._bucket_credentials() as credentials:
main_cmd += f" --credentials-file {credentials} run"
logging.info(f"Executing download command: {main_cmd}")
proc = Popen(shlex.split(main_cmd), stdout=stdout, stderr=stderr, stdin=PIPE)
cmd += f" --credentials-file {credentials} run"
logging.info(f"Executing download command: {cmd}")
proc = Popen(shlex.split(cmd), stdout=self.out_stream, stderr=self.out_stream, stdin=PIPE)
proc.communicate(input=blob_stdin.encode())
returncode = proc.wait()
return returncode

def upload_to_bucket(self, files: List[str], bucket_uri: str, out_stream=None):
def upload_to_bucket(self, files: List[str], bucket_uri: str):
"""Downloads file(s) from a bucket using s5cmd and a supplied list of URIs."""
cmd = "s5cmd"

# Check that the uri prefixes are supported
bucket_prefix = re.match(self.uri_identifier_regex, bucket_uri).group(0)
if bucket_prefix not in ["gs://", "s3://"]:
raise ValueError(f"Only gs:// and s3:// URIs are supported. Found prefix: {bucket_prefix}")

# Amend the URIs with the s3:// prefix and add endpoint URL if required
if bucket_prefix == "gs://":
cmd = " ".join([cmd, "--endpoint-url https://storage.googleapis.com"])

# Configure the input and output streams
stdout = out_stream if out_stream else PIPE
stderr = out_stream if out_stream else PIPE
cmd = self._initialise_command(bucket_uri)
blob_stdin = "\n".join(files)

# Initialise credentials and download
# Initialise credentials and execute
with self._bucket_credentials() as credentials:
logging.info(f"Executing download command: {cmd}")
cmd = " ".join([cmd, f" --credentials_file {credentials} cp {self.cp_config_str} {bucket_uri}"])
proc = Popen(shlex.split(cmd), shell=False, stdout=stdout, stderr=stderr)
proc = Popen(shlex.split(cmd), shell=False, stdout=self.out_stream, stderr=self.out_stream, stdin=PIPE)
proc.communicate(input=blob_stdin.encode())
returncode = proc.wait()
return returncode

def cat(self, blob_uri: str) -> Tuple[int, bytes, bytes]:
"""Executes a s5cmd cat operation on a remote file"""
cmd = self._initialise_command(blob_uri)

# Initialise credentials and execute
with self._bucket_credentials() as credentials:
logging.info(f"Executing download command: {cmd}")
cmd = " ".join([cmd, f" --credentials_file {credentials} cat {blob_uri}"])
proc = Popen(shlex.split(cmd), shell=False, stdout=self.out_stream, stderr=self.out_stream)
stdout, stderr = proc.communicate()
returncode = proc.wait()
return returncode, stdout, stderr
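# A usage sketch of the new cat method, assuming the S5Cmd class above; the blob
# URI and HMAC credentials below are placeholders, not values from this commit.
from academic_observatory_workflows.s5cmd import S5Cmd

s5 = S5Cmd(access_credentials=("HMAC_KEY_ID", "HMAC_SECRET"))
returncode, stdout, stderr = s5.cat("gs://example-bucket/001/0000-0001-2345-6789.xml")
if returncode != 0:
    raise RuntimeError(f"s5cmd cat failed: {stderr.decode('utf-8')}")
xml_content = stdout.decode("utf-8")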
105 changes: 95 additions & 10 deletions academic_observatory_workflows/workflows/orcid_telescope.py
@@ -287,7 +287,7 @@ def __init__(
self.add_task(self.create_manifests)

# Download and transform updated files
self.add_task(self.download)
# self.add_task(self.download)
self.add_task(self.transform)

# Load the data to table
@@ -467,14 +467,12 @@ def transform(self, release: OrcidRelease, **kwargs):
start_time = time.time()
for orcid_batch in release.orcid_batches():
batch_start = time.time()
logging.info(f"Transforming ORCID batch {orcid_batch.batch_str}")
logging.info(f"Transforming ORCID batch using s5cmd cat {orcid_batch.batch_str}")
transformed_data = []
with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
futures = []
for record in orcid_batch.existing_records:
future = executor.submit(
transform_orcid_record, os.path.join(orcid_batch.download_batch_dir, record)
)
for record_uri in orcid_batch.blob_uris:
future = executor.submit(transform_orcid_record, record_uri, self.gcs_hmac_key)
futures.append(future)
for future in futures:
transformed_data.append(future.result())
@@ -508,6 +506,54 @@ def transform(self, release: OrcidRelease, **kwargs):
total_time = time.time() - start_time
logging.info(f"Time taken for all batches: {str(datetime.timedelta(seconds=total_time))}")

# def transform(self, release: OrcidRelease, **kwargs):
# """Transforms the downloaded files into serveral bigquery-compatible .jsonl files"""
# total_upsert_records = 0
# total_delete_records = 0
# start_time = time.time()
# for orcid_batch in release.orcid_batches():
# batch_start = time.time()
# logging.info(f"Transforming ORCID batch {orcid_batch.batch_str}")
# transformed_data = []
# with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
# futures = []
# for record in orcid_batch.existing_records:
# future = executor.submit(
# transform_orcid_record, os.path.join(orcid_batch.download_batch_dir, record)
# )
# futures.append(future)
# for future in futures:
# transformed_data.append(future.result())

# # Save records to upsert
# batch_upserts = [record for record in transformed_data if isinstance(record, dict)]
# n_batch_upserts = len(batch_upserts)
# if n_batch_upserts > 0:
# save_jsonl_gz(orcid_batch.transform_upsert_file, batch_upserts)

# # Save records to delete
# batch_deletes = [{"id": record} for record in transformed_data if isinstance(record, str)]
# n_batch_deletes = len(batch_deletes)
# if n_batch_deletes > 0:
# save_jsonl_gz(orcid_batch.transform_delete_file, batch_deletes)

# # Check that the number of records matches the expected number of records
# total_records = n_batch_upserts + n_batch_deletes
# assert total_records == len(
# orcid_batch.expected_records
# ), f"Expected {len(orcid_batch.expected_records)} records but got {total_records} records ({n_batch_upserts} upserts | {n_batch_deletes} deletes)"

# # Record keeping
# total_upsert_records += n_batch_upserts
# total_delete_records += n_batch_deletes
# logging.info(
# f"Transformed {n_batch_upserts} upserts and {n_batch_deletes} deletes for batch {orcid_batch.batch_str}"
# )
# logging.info(f"Time taken for batch: {time.time() - batch_start} seconds")
# logging.info(f"Transformed {total_upsert_records} upserts and {total_delete_records} deletes in total")
# total_time = time.time() - start_time
# logging.info(f"Time taken for all batches: {str(datetime.timedelta(seconds=total_time))}")

def upload_transformed(self, release: OrcidRelease, **kwargs):
"""Uploads the upsert and delete files to the transform bucket."""
success = gcs_upload_files(bucket_name=self.cloud_workspace.transform_bucket, file_paths=release.upsert_files)
@@ -716,7 +762,7 @@ def orcid_batch_names() -> List[str]:
return ["".join(i) for i in combinations]


def transform_orcid_record(record_path: str) -> Union(Dict, str):
def transform_orcid_record(record_uri: str, s5credentials: Tuple[str, str]) -> Union[Dict, str]:
"""Transform a single ORCID file/record.
The XML file is turned into a dictionary; a record should have either a valid 'record' section or an 'error' section.
The keys of the dictionary are slightly changed so they are valid BigQuery fields.
@@ -728,10 +774,13 @@ def transform_orcid_record(record_path: str) -> Union(Dict, str):
:raises AirflowException: If the record is not valid - no 'error' or 'record' section found
"""
# Get the orcid from the record path
expected_orcid = re.search(ORCID_REGEX, record_path).group(0)
expected_orcid = re.search(ORCID_REGEX, record_uri).group(0)
s5cmd = S5Cmd(access_credentials=s5credentials)
returncode, stdout, stderr = s5cmd.cat(record_uri)
orcid_dict = xmltodict.parse(stdout.decode("utf-8"))

with open(record_path, "rb") as f:
orcid_dict = xmltodict.parse(f)
# with open(record_path, "rb") as f:
# orcid_dict = xmltodict.parse(f)

try:
orcid_record = orcid_dict["record:record"]
@@ -752,6 +801,42 @@ def transform_orcid_record(record_path: str) -> Union(Dict, str):
return orcid_record
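# A sketch of how the reworked function is called (mirroring the transform task
# above); the blob URI and HMAC credentials are placeholders.
gcs_hmac_key = ("HMAC_KEY_ID", "HMAC_SECRET")
result = transform_orcid_record("gs://example-bucket/678/0000-0001-2345-6789.xml", gcs_hmac_key)
is_upsert = isinstance(result, dict)  # a dict becomes an upsert row; a str (the ORCID ID) becomes a delete row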


# def transform_orcid_record(record_path: str) -> Union(Dict, str):
# """Transform a single ORCID file/record.
# The xml file is turned into a dictionary, a record should have either a valid 'record' section or an 'error'section.
# The keys of the dictionary are slightly changed so they are valid BigQuery fields.
# If the record is valid, it is returned. If it is an error, the ORCID ID is returned.

# :param download_path: The path to the file with the ORCID record.
# :param transform_folder: The path where transformed files will be saved.
# :return: The transformed ORCID record. If the record is an error, the ORCID ID is returned.
# :raises AirflowException: If the record is not valid - no 'error' or 'record' section found
# """
# # Get the orcid from the record path
# expected_orcid = re.search(ORCID_REGEX, record_path).group(0)

# with open(record_path, "rb") as f:
# orcid_dict = xmltodict.parse(f)

# try:
# orcid_record = orcid_dict["record:record"]
# # Some records do not have a 'record', but only 'error'. We return the path to the file in this case.
# except KeyError:
# orcid_record = orcid_dict["error:error"]
# return expected_orcid

# # Check that the ORCID in the file name matches the ORCID in the record
# if not orcid_record["common:orcid-identifier"]["common:path"] == expected_orcid:
# raise ValueError(
# f"Expected ORCID {expected_orcid} does not match ORCID in record {orcid_record['common:path']}"
# )

# # Transform the keys of the dictionary so they are valid BigQuery fields
# orcid_record = change_keys(orcid_record)

# return orcid_record


def change_keys(obj):
"""Recursively goes through the dictionary obj and replaces keys with the convert function.