diff --git a/academic_observatory_workflows/fixtures/orcid/0000-0001-5010-1000.xml b/academic_observatory_workflows/fixtures/orcid/0000-0001-5010-1000.xml
index eabd6bd24..ffd4c5e98 100644
--- a/academic_observatory_workflows/fixtures/orcid/0000-0001-5010-1000.xml
+++ b/academic_observatory_workflows/fixtures/orcid/0000-0001-5010-1000.xml
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f48d49d096a100d209a695346af2987eb316813639251e8266445bc91a0be16f
-size 13480
+oid sha256:865b44007c0171581e4649cd1396d7828191522f6f569db28d8eb43d8ff8244e
+size 1253
diff --git a/academic_observatory_workflows/fixtures/orcid/0000-0001-5011-1000.xml b/academic_observatory_workflows/fixtures/orcid/0000-0001-5011-1000.xml
new file mode 100644
index 000000000..b640efc09
--- /dev/null
+++ b/academic_observatory_workflows/fixtures/orcid/0000-0001-5011-1000.xml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8c16034b0bf8a4495d3ebee5d9536d88d5451818a77bef52b307af222cc7c4af
+size 1479
diff --git a/academic_observatory_workflows/s5cmd.py b/academic_observatory_workflows/s5cmd.py
index 849c59e1f..ee0c0ffdd 100644
--- a/academic_observatory_workflows/s5cmd.py
+++ b/academic_observatory_workflows/s5cmd.py
@@ -1,11 +1,11 @@
+import re
+import shlex
+import logging
 from dataclasses import dataclass
 from typing import List, Tuple
 from subprocess import Popen, PIPE
-import logging
 from contextlib import contextmanager
 from tempfile import NamedTemporaryFile
-import re
-import shlex


 @dataclass
@@ -37,6 +37,7 @@ def __init__(
         self,
         access_credentials: Tuple[str, str],
         logging_level: str = "debug",
+        out_stream=PIPE,
         cp_config: S5CmdCpConfig = None,
     ):
         if not cp_config:
@@ -44,8 +45,26 @@ def __init__(

         self.access_credentials = access_credentials
         self.logging_level = logging_level
+        self.out_stream = out_stream
         self.uri_identifier_regex = r"^[a-zA-Z0-9_]{2}://"

+    def _initialise_command(self, uri: str) -> Tuple[str, str]:
+        """Initialises the s5cmd command for the given bucket URI.
+        :param uri: The URI being accessed.
+        :return: The initialised command and the URI amended to use the s3:// scheme."""
+        cmd = "s5cmd"
+
+        # Check that the URI prefix is supported
+        bucket_prefix = re.match(self.uri_identifier_regex, uri).group(0)
+        if bucket_prefix not in ["gs://", "s3://"]:
+            raise ValueError(f"Only gs:// and s3:// URIs are supported. Found prefix: {bucket_prefix}")
+
+        # Amend the URI with the s3:// prefix and add the endpoint URL if required
+        if bucket_prefix == "gs://":
+            cmd = " ".join([cmd, "--endpoint-url https://storage.googleapis.com"])
+            uri = uri.replace("gs://", "s3://", 1)
+        return cmd, uri
+
     @contextmanager
     def _bucket_credentials(self):
         try:
@@ -58,9 +77,8 @@ def _bucket_credentials(self):
         finally:
             pass

-    def download_from_bucket(self, uris: List[str], local_path: str, out_stream=None):
+    def download_from_bucket(self, uris: List[str], local_path: str):
         """Downloads file(s) from a bucket using s5cmd and a supplied list of URIs."""
-        main_cmd = "s5cmd"

         # Check the integrity of the supplied URIs
         uri_prefixes = [re.match(self.uri_identifier_regex, uri) for uri in uris]
@@ -74,15 +92,12 @@ def download_from_bucket(self, uris: List[str], local_path: str, out_stream=None
             raise ValueError(f"Only gs:// and s3:// URIs are supported. Found prefix: {uri_prefix}")

         # Amend the URIs with the s3:// prefix and add endpoint URL if required
+        cmd = "s5cmd"
         if uri_prefix == "gs://":
-            main_cmd += " --endpoint-url https://storage.googleapis.com"
+            cmd += " --endpoint-url https://storage.googleapis.com"
             for i, uri in enumerate(uris):
                 uris[i] = uri.replace("gs://", "s3://")

-        # Configure the input and output streams
-        stdout = out_stream if out_stream else PIPE
-        stderr = out_stream if out_stream else PIPE
-
         # Make the run commands
         blob_cmds = []
         for uri in uris:
@@ -94,38 +109,38 @@ def download_from_bucket(self, uris: List[str], local_path: str, out_stream=None
         blob_stdin = "\n".join(blob_cmds)
         logging.info(f"s5cmd blob download command example: {blob_cmds[0]}")

-        # Initialise credentials and download
+        # Initialise credentials and execute
         with self._bucket_credentials() as credentials:
-            main_cmd += f" --credentials-file {credentials} run"
-            logging.info(f"Executing download command: {main_cmd}")
-            proc = Popen(shlex.split(main_cmd), stdout=stdout, stderr=stderr, stdin=PIPE)
+            cmd += f" --credentials-file {credentials} run"
+            logging.info(f"Executing download command: {cmd}")
+            proc = Popen(shlex.split(cmd), stdout=self.out_stream, stderr=self.out_stream, stdin=PIPE)
             proc.communicate(input=blob_stdin.encode())
             returncode = proc.wait()
             return returncode

-    def upload_to_bucket(self, files: List[str], bucket_uri: str, out_stream=None):
-        """Downloads file(s) from a bucket using s5cmd and a supplied list of URIs."""
+    def upload_to_bucket(self, files: List[str], bucket_uri: str):
+        """Uploads file(s) to a bucket using s5cmd and a supplied list of local files."""
-        cmd = "s5cmd"
-
-        # Check that the uri prefixes are supported
-        bucket_prefix = re.match(self.uri_identifier_regex, bucket_uri).group(0)
-        if bucket_prefix not in ["gs://", "s3://"]:
-            raise ValueError(f"Only gs:// and s3:// URIs are supported. Found prefix: {bucket_prefix}")
-
-        # Amend the URIs with the s3:// prefix and add endpoint URL if required
-        if bucket_prefix == "gs://":
-            cmd = " ".join([cmd, "--endpoint-url https://storage.googleapis.com"])
-
-        # Configure the input and output streams
-        stdout = out_stream if out_stream else PIPE
-        stderr = out_stream if out_stream else PIPE
+        cmd, bucket_uri = self._initialise_command(bucket_uri)

         blob_stdin = "\n".join(files)

-        # Initialise credentials and download
+        # Initialise credentials and execute
         with self._bucket_credentials() as credentials:
-            logging.info(f"Executing download command: {cmd}")
-            cmd = " ".join([cmd, f" --credentials_file {credentials} cp {self.cp_config_str} {bucket_uri}"])
-            proc = Popen(shlex.split(cmd), shell=False, stdout=stdout, stderr=stderr)
+            cmd = " ".join([cmd, f"--credentials-file {credentials} cp {self.cp_config_str} {bucket_uri}"])
+            logging.info(f"Executing upload command: {cmd}")
+            proc = Popen(shlex.split(cmd), shell=False, stdout=self.out_stream, stderr=self.out_stream, stdin=PIPE)
             proc.communicate(input=blob_stdin.encode())
             returncode = proc.wait()
             return returncode
+
+    def cat(self, blob_uri: str) -> Tuple[int, bytes, bytes]:
+        """Executes an s5cmd cat operation on a remote file."""
+        cmd, blob_uri = self._initialise_command(blob_uri)
+
+        # Initialise credentials and execute
+        with self._bucket_credentials() as credentials:
+            cmd = " ".join([cmd, f"--credentials-file {credentials} cat {blob_uri}"])
+            logging.info(f"Executing cat command: {cmd}")
+            proc = Popen(shlex.split(cmd), shell=False, stdout=self.out_stream, stderr=self.out_stream)
+            stdout, stderr = proc.communicate()
+            returncode = proc.wait()
+            return returncode, stdout, stderr
diff --git a/academic_observatory_workflows/workflows/orcid_telescope.py b/academic_observatory_workflows/workflows/orcid_telescope.py
index 0caef2a72..e009ee59b 100644
--- a/academic_observatory_workflows/workflows/orcid_telescope.py
+++ b/academic_observatory_workflows/workflows/orcid_telescope.py
@@ -287,7 +287,7 @@ def __init__(
         self.add_task(self.create_manifests)

         # Download and transform updated files
-        self.add_task(self.download)
+        # self.add_task(self.download)
         self.add_task(self.transform)

         # Load the data to table
@@ -467,14 +467,12 @@ def transform(self, release: OrcidRelease, **kwargs):
         start_time = time.time()
         for orcid_batch in release.orcid_batches():
             batch_start = time.time()
-            logging.info(f"Transforming ORCID batch {orcid_batch.batch_str}")
+            logging.info(f"Transforming ORCID batch {orcid_batch.batch_str} using s5cmd cat")
             transformed_data = []
             with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
                 futures = []
-                for record in orcid_batch.existing_records:
-                    future = executor.submit(
-                        transform_orcid_record, os.path.join(orcid_batch.download_batch_dir, record)
-                    )
+                for record_uri in orcid_batch.blob_uris:
+                    future = executor.submit(transform_orcid_record, record_uri, self.gcs_hmac_key)
                     futures.append(future)
                 for future in futures:
                     transformed_data.append(future.result())
@@ -508,6 +506,54 @@ def transform(self, release: OrcidRelease, **kwargs):
         total_time = time.time() - start_time
         logging.info(f"Time taken for all batches: {str(datetime.timedelta(seconds=total_time))}")

+    # def transform(self, release: OrcidRelease, **kwargs):
+    #     """Transforms the downloaded files into several BigQuery-compatible .jsonl files"""
+    #     total_upsert_records = 0
+    #     total_delete_records = 0
+    #     start_time = time.time()
+    #     for orcid_batch in release.orcid_batches():
+    #         batch_start = time.time()
+    #         logging.info(f"Transforming ORCID batch {orcid_batch.batch_str}")
+    #         transformed_data = []
+    #         with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
+    #             futures = []
+    #             for record in orcid_batch.existing_records:
+    #                 future = executor.submit(
+    #                     transform_orcid_record, os.path.join(orcid_batch.download_batch_dir, record)
+    #                 )
+    #                 futures.append(future)
+    #             for future in futures:
+    #                 transformed_data.append(future.result())
+
+    #         # Save records to upsert
+    #         batch_upserts = [record for record in transformed_data if isinstance(record, dict)]
+    #         n_batch_upserts = len(batch_upserts)
+    #         if n_batch_upserts > 0:
+    #             save_jsonl_gz(orcid_batch.transform_upsert_file, batch_upserts)
+
+    #         # Save records to delete
+    #         batch_deletes = [{"id": record} for record in transformed_data if isinstance(record, str)]
+    #         n_batch_deletes = len(batch_deletes)
+    #         if n_batch_deletes > 0:
+    #             save_jsonl_gz(orcid_batch.transform_delete_file, batch_deletes)
+
+    #         # Check that the number of records matches the expected number of records
+    #         total_records = n_batch_upserts + n_batch_deletes
+    #         assert total_records == len(
+    #             orcid_batch.expected_records
+    #         ), f"Expected {len(orcid_batch.expected_records)} records but got {total_records} records ({n_batch_upserts} upserts | {n_batch_deletes} deletes)"
+
+    #         # Record keeping
+    #         total_upsert_records += n_batch_upserts
+    #         total_delete_records += n_batch_deletes
+    #         logging.info(
+    #             f"Transformed {n_batch_upserts} upserts and {n_batch_deletes} deletes for batch {orcid_batch.batch_str}"
+    #         )
+    #         logging.info(f"Time taken for batch: {time.time() - batch_start} seconds")
+    #     logging.info(f"Transformed {total_upsert_records} upserts and {total_delete_records} deletes in total")
+    #     total_time = time.time() - start_time
+    #     logging.info(f"Time taken for all batches: {str(datetime.timedelta(seconds=total_time))}")
+
     def upload_transformed(self, release: OrcidRelease, **kwargs):
         """Uploads the upsert and delete files to the transform bucket."""
         success = gcs_upload_files(bucket_name=self.cloud_workspace.transform_bucket, file_paths=release.upsert_files)
@@ -716,7 +762,7 @@ def orcid_batch_names() -> List[str]:
     return ["".join(i) for i in combinations]


-def transform_orcid_record(record_path: str) -> Union(Dict, str):
+def transform_orcid_record(record_uri: str, s5credentials: Tuple[str, str]) -> Union[Dict, str]:
     """Transform a single ORCID file/record.
     The xml file is turned into a dictionary, a record should have either a valid 'record' section or an 'error'section.
     The keys of the dictionary are slightly changed so they are valid BigQuery fields.
@@ -728,10 +774,13 @@ def transform_orcid_record(record_path: str) -> Union(Dict, str):
     :raises AirflowException: If the record is not valid - no 'error' or 'record' section found
     """
     # Get the orcid from the record path
-    expected_orcid = re.search(ORCID_REGEX, record_path).group(0)
+    expected_orcid = re.search(ORCID_REGEX, record_uri).group(0)
+    s5cmd = S5Cmd(access_credentials=s5credentials)
+    returncode, record, stderr = s5cmd.cat(record_uri)
+    orcid_dict = xmltodict.parse(record.decode("utf-8"))

-    with open(record_path, "rb") as f:
-        orcid_dict = xmltodict.parse(f)
+    # with open(record_path, "rb") as f:
+    #     orcid_dict = xmltodict.parse(f)

     try:
         orcid_record = orcid_dict["record:record"]
@@ -752,6 +801,42 @@ def transform_orcid_record(record_path: str) -> Union(Dict, str):
     return orcid_record


+# def transform_orcid_record(record_path: str) -> Union(Dict, str):
+#     """Transform a single ORCID file/record.
+#     The xml file is turned into a dictionary, a record should have either a valid 'record' section or an 'error'section.
+#     The keys of the dictionary are slightly changed so they are valid BigQuery fields.
+#     If the record is valid, it is returned. If it is an error, the ORCID ID is returned.
+
+#     :param download_path: The path to the file with the ORCID record.
+#     :param transform_folder: The path where transformed files will be saved.
+#     :return: The transformed ORCID record. If the record is an error, the ORCID ID is returned.
+#     :raises AirflowException: If the record is not valid - no 'error' or 'record' section found
+#     """
+#     # Get the orcid from the record path
+#     expected_orcid = re.search(ORCID_REGEX, record_path).group(0)
+
+#     with open(record_path, "rb") as f:
+#         orcid_dict = xmltodict.parse(f)
+
+#     try:
+#         orcid_record = orcid_dict["record:record"]
+#     # Some records do not have a 'record', but only 'error'. We return the path to the file in this case.
+#     except KeyError:
+#         orcid_record = orcid_dict["error:error"]
+#         return expected_orcid
+
+#     # Check that the ORCID in the file name matches the ORCID in the record
+#     if not orcid_record["common:orcid-identifier"]["common:path"] == expected_orcid:
+#         raise ValueError(
+#             f"Expected ORCID {expected_orcid} does not match ORCID in record {orcid_record['common:path']}"
+#         )
+
+#     # Transform the keys of the dictionary so they are valid BigQuery fields
+#     orcid_record = change_keys(orcid_record)
+
+#     return orcid_record
+
+
 def change_keys(obj):
     """Recursively goes through the dictionary obj and replaces keys with the convert function.

diff --git a/academic_observatory_workflows/workflows/tests/test_orcid_telescope.py b/academic_observatory_workflows/workflows/tests/test_orcid_telescope.py
index 8df270bf1..e40c79296 100644
--- a/academic_observatory_workflows/workflows/tests/test_orcid_telescope.py
+++ b/academic_observatory_workflows/workflows/tests/test_orcid_telescope.py
@@ -27,7 +27,10 @@
 import shutil
 import tempfile
 from typing import Dict
+import unittest
 from unittest.mock import patch
+from dataclasses import dataclass
+

 import pendulum
 from airflow.models import Connection
@@ -61,6 +64,21 @@
 )


+@dataclass
+class OrcidTestAssets:
+    fixtures_folder = test_fixtures_folder("orcid")
+    valid_orcids = [
+        {"0000-0001-5000-5000": os.path.join(fixtures_folder, "0000-0001-5000-5000.xml")},
+        {"0000-0001-5001-3000": os.path.join(fixtures_folder, "0000-0001-5001-3000.xml")},
+        {"0000-0001-5002-1000": os.path.join(fixtures_folder, "0000-0001-5002-1000.xml")},
+    ]
+    error_orcid = {"0000-0001-5007-2000": os.path.join(fixtures_folder, "0000-0001-5007-2000.xml")}
+    invalid_key_orcid = {"0000-0001-5010-1000": os.path.join(fixtures_folder, "0000-0001-5010-1000.xml")}  # Invalid Key
+    mismatched_orcid = {
+        "0000-0001-5011-1000": os.path.join(fixtures_folder, "0000-0001-5011-1000.xml")
+    }  # ORCID doesn't match path
+
+
 class TestOrcidUtils(ObservatoryTestCase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -190,3 +208,107 @@ def __init__(self, *args, **kwargs):
         self.aws_access_key_id = os.getenv("AWS_ACCESS_KEY_ID")
         self.aws_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
         self.aws_region_name = os.getenv("AWS_DEFAULT_REGION")
+
+
+class TestCreateOrcidBatchManifest(unittest.TestCase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.list_blobs_path = "academic_observatory_workflows.workflows.orcid_telescope.gcs_list_blobs"
+        self.test_batch_str = "12X"
+        self.bucket_name = "test-bucket"
+
+    def test_create_orcid_batch_manifest(self):
+        """Tests that the manifest file is created with the correct header and contains the correct blob names and
+        modification dates"""
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            download_dir = os.path.join(tmp_dir, "download")
+            transform_dir = os.path.join(tmp_dir, "transform")
+            test_batch = OrcidBatch(download_dir, transform_dir, self.test_batch_str)
+
+            # Mock gcs_list_blobs
+            blobs = [
+                storage.Blob(
+                    name=f"{self.test_batch_str}/blob1",
+                    bucket=self.bucket_name,
+                    updated=datetime.datetime(2023, 1, 1),
+                ),
+                storage.Blob(
+                    name=f"{self.test_batch_str}/blob2",
+                    bucket=self.bucket_name,
+                    updated=datetime.datetime(2023, 1, 2),
+                ),
+                storage.Blob(
+                    name=f"{self.test_batch_str}/blob3",
+                    bucket=self.bucket_name,
+                    updated=datetime.datetime(2022, 12, 31),  # Before reference date - should not be included
+                ),
+            ]
+            with patch(self.list_blobs_path, return_value=blobs):
+                create_orcid_batch_manifest(test_batch, pendulum.datetime(2023, 1, 1), self.bucket_name)
+
+            # Assert manifest file is created with correct header and content
+            with open(test_batch.manifest_file, "r") as f:
+                reader = csv.DictReader(f)
+                rows = list(reader)
+
+            self.assertEqual(len(rows), 2)
+            self.assertEqual(rows[0]["bucket_name"], self.bucket_name)
+            self.assertEqual(rows[0]["blob_name"], f"{self.test_batch_str}/blob1")
+            self.assertEqual(rows[0]["updated"], str(datetime.datetime(2023, 1, 1)))
+            self.assertEqual(rows[1]["bucket_name"], self.bucket_name)
+            self.assertEqual(rows[1]["blob_name"], f"{self.test_batch_str}/blob2")
+            self.assertEqual(rows[1]["updated"], str(datetime.datetime(2023, 1, 2)))
+
+    def test_no_results(self):
+        """Tests that an empty manifest file is created if there are no blobs modified after the reference date"""
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            download_dir = os.path.join(tmp_dir, "download")
+            transform_dir = os.path.join(tmp_dir, "transform")
+            test_batch = OrcidBatch(download_dir, transform_dir, self.test_batch_str)
+
+            # Mock gcs_list_blobs
+            blobs = [
+                storage.Blob(
+                    name=f"{self.test_batch_str}/blob1",
+                    bucket=self.bucket_name,
+                    updated=datetime.datetime(2022, 1, 1),  # Before reference date
+                )
+            ]
+            with patch(self.list_blobs_path, return_value=blobs):
+                create_orcid_batch_manifest(test_batch, pendulum.datetime(2023, 1, 1), self.bucket_name)
+
+            # Assert manifest file is created but contains no rows
+            self.assertTrue(os.path.exists(test_batch.manifest_file))
+            with open(test_batch.manifest_file, "r") as f:
+                reader = csv.DictReader(f)
+                rows = list(reader)
+            self.assertEqual(len(rows), 0)
+
+
+class TestTransformOrcidRecord(unittest.TestCase):
+    def test_valid_record(self):
+        """Tests that a valid ORCID record with a 'record' section is transformed correctly"""
+        for asset in OrcidTestAssets.valid_orcids:
+            orcid, path = next(iter(asset.items()))
+            transformed_record = transform_orcid_record(path)
+            self.assertIsInstance(transformed_record, dict)
+            self.assertEqual(transformed_record["orcid_identifier"]["path"], orcid)
+
+    def test_error_record(self):
+        """Tests that an ORCID record with an 'error' section is transformed correctly"""
+        orcid, path = next(iter(OrcidTestAssets.error_orcid.items()))
+        transformed_record = transform_orcid_record(path)
+        self.assertIsInstance(transformed_record, str)
+        self.assertEqual(transformed_record, orcid)
+
+    def test_invalid_key_record(self):
+        """Tests that an ORCID record with no 'error' or 'record' section raises a KeyError"""
+        path = next(iter(OrcidTestAssets.invalid_key_orcid.values()))
+        with self.assertRaises(KeyError):
+            transform_orcid_record(path)
+
+    def test_mismatched_orcid(self):
+        """Tests that a ValueError is raised if the ORCID in the file name does not match the ORCID in the record"""
+        path = next(iter(OrcidTestAssets.mismatched_orcid.values()))
+        with self.assertRaisesRegex(ValueError, "does not match ORCID in record"):
+            transform_orcid_record(path)
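
For context, the following is a minimal sketch of how the new S5Cmd.cat return value, a (returncode, stdout, stderr) tuple, is meant to be consumed when streaming a single ORCID record the way transform_orcid_record now does. It is illustrative only and not part of the diff: the bucket name, blob key and HMAC credentials are placeholders, and it assumes the s5cmd binary is installed and the credentials are valid.

    import xmltodict

    from academic_observatory_workflows.s5cmd import S5Cmd

    # Placeholder HMAC key/secret and blob URI - not real values from this change.
    gcs_hmac_key = ("my-hmac-key-id", "my-hmac-secret")
    record_uri = "gs://example-orcid-sync-bucket/13X/0000-0001-5000-5013.xml"

    s5cmd = S5Cmd(access_credentials=gcs_hmac_key)
    returncode, stdout, stderr = s5cmd.cat(record_uri)
    if returncode != 0:
        raise RuntimeError(f"s5cmd cat failed: {stderr.decode('utf-8')}")

    # stdout holds the raw XML bytes of the record, ready to be parsed.
    orcid_dict = xmltodict.parse(stdout.decode("utf-8"))

Returning the full tuple rather than only stdout keeps the return code and stderr available to the caller for error handling.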
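Similarly, a small sketch of the behaviour assumed of the refactored _initialise_command helper above (the variant that returns both the command prefix and the amended URI). Again illustrative only: the bucket and object names are placeholders, and the private helper is called directly purely to show the expected endpoint handling and gs:// to s3:// rewriting.

    from academic_observatory_workflows.s5cmd import S5Cmd

    s5cmd = S5Cmd(access_credentials=("my-hmac-key-id", "my-hmac-secret"))  # placeholder credentials

    # Google Cloud Storage URIs gain the custom endpoint and are rewritten to the s3:// scheme,
    # since s5cmd expects s3:// URLs.
    cmd, uri = s5cmd._initialise_command("gs://example-bucket/0000-0001-5000-5013.xml")
    # cmd == "s5cmd --endpoint-url https://storage.googleapis.com"
    # uri == "s3://example-bucket/0000-0001-5000-5013.xml"

    # Native S3 URIs pass through unchanged.
    cmd, uri = s5cmd._initialise_command("s3://example-bucket/0000-0001-5000-5013.xml")
    # cmd == "s5cmd"

    # Any other two-character scheme is rejected.
    s5cmd._initialise_command("az://example-bucket/file.xml")  # raises ValueError

Keeping the rewrite inside the helper lets upload_to_bucket and cat share the same scheme check instead of duplicating it.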