From 9ce141c349f51dc2927c2976492ca57794a749c2 Mon Sep 17 00:00:00 2001 From: Keegan Date: Tue, 1 Aug 2023 03:34:09 +0000 Subject: [PATCH] finalised workflow structure --- academic_observatory_workflows/s5cmd.py | 139 +++++++----- .../workflows/orcid_telescope.py | 213 +++++------------- .../workflows/tests/test_orcid_telescope.py | 107 +++++---- 3 files changed, 203 insertions(+), 256 deletions(-) diff --git a/academic_observatory_workflows/s5cmd.py b/academic_observatory_workflows/s5cmd.py index ee0c0ffdd..ec851a1f0 100644 --- a/academic_observatory_workflows/s5cmd.py +++ b/academic_observatory_workflows/s5cmd.py @@ -1,53 +1,59 @@ import re import shlex import logging -from dataclasses import dataclass -from typing import List, Tuple +from typing import List, Tuple, Union from subprocess import Popen, PIPE from contextlib import contextmanager from tempfile import NamedTemporaryFile -@dataclass class S5CmdCpConfig: - flatten_dir: bool = False - no_overwrite: bool = False - overwrite_if_size: bool = False - overwrite_if_newer: bool = False + """Configuration for S5Cmd cp command - @property - def cp_config_str(self): - cfg = [ - "--flatten " if self.flatten_dir else "", - "--no-clobber " if self.no_overwrite else "", - "--if-size-differ " if self.overwrite_if_size else "", - "--if-source-newer " if self.overwrite_if_newer else "", - ] - return "".join(cfg) + :param flatten_dir: Whether to flatten the directory structure + :param no_overwrite: Whether to not overwrite files if they already exist + :param overwrite_if_size: Whether to overwrite files only if source size differs + :param overwrite_if_newer: Whether to overwrite files only if source is newer + """ - def __str__(self): - return self.cp_config_str + def __init__( + self, + flatten_dir: bool = False, + no_overwrite: bool = False, + overwrite_if_size: bool = False, + overwrite_if_newer: bool = False, + ): + self.flatten_dir = flatten_dir + self.no_overwrite = no_overwrite + self.overwrite_if_size = overwrite_if_size + self.overwrite_if_newer = overwrite_if_newer - def __bool__(self): - return bool(self.cp_config_str) + def __str__(self): + cfg = [ + "--flatten" if self.flatten_dir else "", + "--no-clobber" if self.no_overwrite else "", + "--if-size-differ" if self.overwrite_if_size else "", + "--if-source-newer" if self.overwrite_if_newer else "", + ] + cfg = [i for i in cfg if i] # Remove empty strings + return " ".join(cfg) class S5Cmd: def __init__( self, access_credentials: Tuple[str, str], - logging_level: str = "debug", + logging_level: str = "info", out_stream: str = PIPE, - cp_config: S5CmdCpConfig = None, ): - if not cp_config: - self.cp_config = S5CmdCpConfig() - self.access_credentials = access_credentials self.logging_level = logging_level - self.output_stream = out_stream + self.out_stream = out_stream self.uri_identifier_regex = r"^[a-zA-Z0-9_]{2}://" + def _uri(self, uri: str): + return uri.replace("gs://", "s3://") + def _initialise_command(self, uri: str): """Initializes the command for the given bucket URI. :param uri: The URI being accessed. @@ -55,12 +61,12 @@ def _initialise_command(self, uri: str): cmd = "s5cmd" # Check that the uri prefixes are supported - bucket_prefix = re.match(self.uri_identifier_regex, uri).group(0) - if bucket_prefix not in ["gs://", "s3://"]: - raise ValueError(f"Only gs:// and s3:// URIs are supported. 
Found prefix: {bucket_prefix}")
+        uri_prefix = re.match(self.uri_identifier_regex, uri).group(0)
+        if uri_prefix not in ["gs://", "s3://"]:
+            raise ValueError(f"Only gs:// and s3:// URIs are supported. Found prefix: {uri_prefix}")
 
         # Amend the URIs with the s3:// prefix and add endpoint URL if required
-        if bucket_prefix == "gs://":
+        if uri_prefix == "gs://":
             cmd = " ".join([cmd, "--endpoint-url https://storage.googleapis.com"])
         return cmd
 
@@ -77,8 +83,15 @@ def _bucket_credentials(self):
         finally:
             pass
 
-    def download_from_bucket(self, uris: List[str], local_path: str):
-        """Downloads file(s) from a bucket using s5cmd and a supplied list of URIs."""
+    def download_from_bucket(
+        self, uris: Union[List[str], str], local_path: str, cp_config: S5CmdCpConfig = None
+    ) -> Tuple[bytes, bytes, int]:
+        """Downloads file(s) from a bucket using s5cmd and a supplied list of URIs.
+
+        :param uris: The URI or list of URIs to download.
+        :param local_path: The local path to download to.
+        :param cp_config: Configuration for the cp command. Defaults to S5CmdCpConfig().
+        :return: A tuple of (stdout, stderr, s5cmd exit code).
+        """
+        if not isinstance(uris, list):
+            uris = [uris]
+        if not cp_config:
+            cp_config = S5CmdCpConfig()
 
         # Check the integrity of the supplied URIs
         uri_prefixes = [re.match(self.uri_identifier_regex, uri) for uri in uris]
@@ -92,18 +105,13 @@ def download_from_bucket(self, uris: List[str], local_path: str):
             raise ValueError(f"Only gs:// and s3:// URIs are supported. Found prefix: {uri_prefix}")
 
         # Amend the URIs with the s3:// prefix and add endpoint URL if required
-        cmd = "s5cmd"
-        if uri_prefix == "gs://":
-            cmd += " --endpoint-url https://storage.googleapis.com"
-        for i, uri in enumerate(uris):
-            uris[i] = uri.replace("gs://", "s3://")
+        cmd = self._initialise_command(uris[0])
 
         # Make the run commands
         blob_cmds = []
-        for uri in uris:
+        for uri in map(self._uri, uris):
             blob_cmd = "cp"
-            if self.cp_config:
-                blob_cmd += f" {str(self.cp_config)}"
+            blob_cmd += f" {cp_config}"
             blob_cmd += f" {uri} {local_path}"
             blob_cmds.append(blob_cmd)
         blob_stdin = "\n".join(blob_cmds)
@@ -112,35 +120,50 @@ def download_from_bucket(self, uris: List[str], local_path: str):
         # Initialise credentials and execute
         with self._bucket_credentials() as credentials:
             cmd += f" --credentials-file {credentials} run"
-            logging.info(f"Executing download command: {cmd}")
             proc = Popen(shlex.split(cmd), stdout=self.out_stream, stderr=self.out_stream, stdin=PIPE)
-            proc.communicate(input=blob_stdin.encode())
-            returncode = proc.wait()
-        return returncode
-
-    def upload_to_bucket(self, files: List[str], bucket_uri: str):
-        """Downloads file(s) from a bucket using s5cmd and a supplied list of URIs."""
+            stdout, stderr = proc.communicate(input=blob_stdin.encode())
+            returncode = proc.wait()
+            if returncode > 0:
+                logging.warning(f"s5cmd cp failed with return code {returncode}: {stderr}")
+        return stdout, stderr, returncode
+
+    def upload_to_bucket(self, files: Union[List[str], str], bucket_uri: str, cp_config: S5CmdCpConfig = None):
+        """Uploads file(s) to a bucket using s5cmd and a supplied list of local file paths.
+
+        :param files: The file(s) to upload.
+        :param bucket_uri: The URI to upload to.
+        :param cp_config: Configuration for the cp command. Defaults to S5CmdCpConfig().
+        :return: A tuple of (stdout, stderr, s5cmd exit code).
+        """
+        if not isinstance(files, list):
+            files = [files]
+        if not cp_config:
+            cp_config = S5CmdCpConfig()
 
         cmd = self._initialise_command(bucket_uri)
         blob_stdin = "\n".join(files)
 
         # Initialise credentials and execute
         with self._bucket_credentials() as credentials:
-            logging.info(f"Executing download command: {cmd}")
-            cmd = " ".join([cmd, f" --credentials_file {credentials} cp {self.cp_config_str} {bucket_uri}"])
-            proc = Popen(shlex.split(cmd), shell=False, stdout=self.out_stream, stderr=self.out_stream)
-            proc.communicate(input=blob_stdin.encode())
+            cmd = " ".join([cmd, f"--credentials-file {credentials} cp {cp_config} {self._uri(bucket_uri)}"])
+            proc = Popen(shlex.split(cmd), shell=False, stdout=self.out_stream, stderr=self.out_stream, stdin=PIPE)
+            stdout, stderr = proc.communicate(input=blob_stdin.encode())
             returncode = proc.wait()
-        return returncode
-
-    def cat(self, blob_uri: str) -> Tuple[int, bytes, bytes]:
-        """Executes a s5cmd cat operation on a remote file"""
+            if returncode > 0:
+                logging.warning(f"s5cmd cp failed with return code {returncode}: {stderr}")
+        return stdout, stderr, returncode
+
+    def cat(self, blob_uri: str) -> Tuple[bytes, bytes, int]:
+        """Executes a s5cmd cat operation on a remote file.
+
+        :param blob_uri: The URI to execute the cat on.
+        :return: A tuple of (stdout, stderr, s5cmd exit code).
+        """
         cmd = self._initialise_command(blob_uri)
 
         # Initialise credentials and execute
        with self._bucket_credentials() as credentials:
-            logging.info(f"Executing download command: {cmd}")
-            cmd = " ".join([cmd, f" --credentials_file {credentials} cat {blob_uri}"])
-            proc = Popen(shlex.split(cmd), shell=False, stdout=self.out_stream, stderr=self.out_stream)
+            cmd = " ".join([cmd, f"--credentials-file {credentials} cat {self._uri(blob_uri)}"])
+            proc = Popen(shlex.split(cmd), shell=False, stdout=PIPE, stderr=PIPE)
             stdout, stderr = proc.communicate()
             returncode = proc.wait()
-        return returncode, stdout, stderr
+            if returncode > 0:
+                logging.warning(f"s5cmd cat failed with return code {returncode}: {stderr}")
+        return stdout, stderr, returncode
diff --git a/academic_observatory_workflows/workflows/orcid_telescope.py b/academic_observatory_workflows/workflows/orcid_telescope.py
index e009ee59b..a15bea167 100644
--- a/academic_observatory_workflows/workflows/orcid_telescope.py
+++ b/academic_observatory_workflows/workflows/orcid_telescope.py
@@ -74,13 +74,12 @@
 ORCID_REGEX = r"\d{4}-\d{4}-\d{4}-\d{3}(\d|X)\b"
 ORCID_RECORD_REGEX = r"\d{4}-\d{4}-\d{4}-\d{3}(\d|X)\.xml$"
 MANIFEST_HEADER = ["bucket_name", "blob_name", "updated"]
+BATCH_REGEX = r"^\d{2}(\d|X)$"
 
 
 class OrcidBatch:
     """Describes a single ORCID batch and its related files/folders"""
 
-    BATCH_REGEX = r"^\d{2}(\d|X)$"
-
     def __init__(self, download_dir: str, transform_dir: str, batch_str: str):
         self.download_dir = download_dir
         self.transform_dir = transform_dir
@@ -92,9 +91,12 @@ def __init__(self, download_dir: str, transform_dir: str, batch_str: str):
         self.transform_upsert_file = os.path.join(self.transform_dir, f"{self.batch_str}_upsert.jsonl.gz")
         self.transform_delete_file = os.path.join(self.transform_dir, f"{self.batch_str}_delete.jsonl.gz")
 
-        assert os.path.exists(self.download_dir), f"Directory {self.download_dir} does not exist."
-        assert os.path.exists(self.transform_dir), f"Directory {self.transform_dir} does not exist."
-        assert re.match(self.BATCH_REGEX, self.batch_str), f"Batch string {self.batch_str} is not valid."
+        if not os.path.exists(self.download_dir):
+            raise NotADirectoryError(f"Directory {self.download_dir} does not exist.")
+        if not os.path.exists(self.transform_dir):
+            raise NotADirectoryError(f"Directory {self.transform_dir} does not exist.")
+        if not re.match(BATCH_REGEX, self.batch_str):
+            raise ValueError(f"Batch string {self.batch_str} is not valid.")
 
         os.makedirs(self.download_batch_dir, exist_ok=True)
 
@@ -287,7 +289,7 @@ def __init__(
         self.add_task(self.create_manifests)
 
         # Download and transform updated files
-        # self.add_task(self.download)
+        self.add_task(self.download)
         self.add_task(self.transform)
 
         # Load the data to table
@@ -325,9 +327,10 @@ def make_release(self, **kwargs) -> OrcidRelease:
 
         # Determine the modication cutoff for the new release
         if is_first_run:
-            assert (
-                len(releases) == 0 or kwargs["task"].task_id == self.cleanup.__name__
-            ), "fetch_releases: there should be no DatasetReleases stored in the Observatory API on the first DAG run."
+            # assert (
+            #     len(releases) == 0 or kwargs["task"].task_id == self.cleanup.__name__
+            # ), "fetch_releases: there should be no DatasetReleases stored in the Observatory API on the first DAG run."
+            # TODO: uncomment the above ^
 
             prev_latest_modified_record = pendulum.instance(datetime.datetime.min)
             prev_release_end = pendulum.instance(datetime.datetime.min)
@@ -399,6 +402,37 @@ def bq_create_main_table_snapshot(self, release: OrcidRelease, **kwargs):
         )
         set_task_state(success, self.bq_create_main_table_snapshot.__name__, release)
 
+    def download(self, release: OrcidRelease, **kwargs):
+        """Reads each batch's manifest and downloads the files from the GCS bucket."""
+        total_files = 0
+        start_time = time.time()
+        for orcid_batch in release.orcid_batches():
+            if not orcid_batch.missing_records:
+                logging.info(f"All files present for {orcid_batch.batch_str}. Skipping download.")
+                continue
+
+            logging.info(f"Downloading files for ORCID directory: {orcid_batch.batch_str}")
+            with open(orcid_batch.download_log_file, "w") as f:
+                s5cmd = S5Cmd(access_credentials=self.gcs_hmac_key, out_stream=f)
+                returncode = s5cmd.download_from_bucket(
+                    uris=orcid_batch.blob_uris, local_path=orcid_batch.download_batch_dir
+                )[-1]
+            if returncode != 0:
+                raise RuntimeError(
+                    f"Download attempt for batch '{orcid_batch.batch_str}' returned non-zero exit code: {returncode}. See log file: {orcid_batch.download_log_file}"
+                )
+            if orcid_batch.missing_records:
+                raise FileNotFoundError(f"Not all files were downloaded for {orcid_batch.batch_str}. Aborting.")
+
+            total_files += len(orcid_batch.expected_records)
+            logging.info(
+                f"Downloaded {len(orcid_batch.expected_records)} records for batch '{orcid_batch.batch_str}'."
+ ) + total_time = time.time() - start_time + logging.info(f"Completed download for {total_files} files in {str(datetime.timedelta(seconds=total_time))}") + def create_manifests(self, release: OrcidRelease, **kwargs): """Create a manifest of all the modified files in the orcid bucket.""" logging.info("Creating manifest") @@ -429,37 +463,6 @@ def create_manifests(self, release: OrcidRelease, **kwargs): reader = csv.DictReader(df) writer.writerows(reader) - def download(self, release: OrcidRelease, **kwargs): - """Reads each batch's manifest and downloads the files from the gcs bucket.""" - total_files = 0 - start_time = time.time() - for orcid_batch in release.orcid_batches(): - if not orcid_batch.missing_records: - logging.info(f"All files present for {orcid_batch.batch_str}. Skipping download.") - continue - - logging.info(f"Downloading files for ORCID directory: {orcid_batch.batch_str}") - s5cmd = S5Cmd(access_credentials=self.gcs_hmac_key) - print(orcid_batch.blob_uris[0]) - with open(orcid_batch.download_log_file, "w") as f: - returncode = s5cmd.download_from_bucket( - uris=orcid_batch.blob_uris, local_path=orcid_batch.download_batch_dir, out_stream=f - ) - # returncode = gsutil_download(orcid_batch=orcid_batch) - if returncode != 0: - raise RuntimeError( - f"Download attempt '{orcid_batch.batch_str}': returned non-zero exit code: {returncode}. See log file: {orcid_batch.download_log_file}" - ) - if orcid_batch.missing_records: - raise FileNotFoundError(f"All files were not downloaded for {orcid_batch.batch_str}. Aborting.") - - total_files += len(orcid_batch.expected_records) - logging.info( - f"Downloaded {len(orcid_batch.expected_records)} records for batch '{orcid_batch.batch_str}' completed successfully." - ) - total_time = time.time() - start_time - logging.info(f"Completed download for {total_files} files in {str(datetime.timedelta(seconds=total_time))}") - def transform(self, release: OrcidRelease, **kwargs): """Transforms the downloaded files into serveral bigquery-compatible .jsonl files""" total_upsert_records = 0 @@ -467,12 +470,14 @@ def transform(self, release: OrcidRelease, **kwargs): start_time = time.time() for orcid_batch in release.orcid_batches(): batch_start = time.time() - logging.info(f"Transforming ORCID batch using s5cmd cat {orcid_batch.batch_str}") + logging.info(f"Transforming ORCID batch {orcid_batch.batch_str}") transformed_data = [] with ProcessPoolExecutor(max_workers=self.max_workers) as executor: futures = [] - for record_uri in orcid_batch.blob_uris: - future = executor.submit(transform_orcid_record, record_uri, self.gcs_hmac_key) + for record in orcid_batch.existing_records: + future = executor.submit( + transform_orcid_record, os.path.join(orcid_batch.download_batch_dir, record) + ) futures.append(future) for future in futures: transformed_data.append(future.result()) @@ -506,54 +511,6 @@ def transform(self, release: OrcidRelease, **kwargs): total_time = time.time() - start_time logging.info(f"Time taken for all batches: {str(datetime.timedelta(seconds=total_time))}") - # def transform(self, release: OrcidRelease, **kwargs): - # """Transforms the downloaded files into serveral bigquery-compatible .jsonl files""" - # total_upsert_records = 0 - # total_delete_records = 0 - # start_time = time.time() - # for orcid_batch in release.orcid_batches(): - # batch_start = time.time() - # logging.info(f"Transforming ORCID batch {orcid_batch.batch_str}") - # transformed_data = [] - # with ProcessPoolExecutor(max_workers=self.max_workers) as executor: - # 
futures = [] - # for record in orcid_batch.existing_records: - # future = executor.submit( - # transform_orcid_record, os.path.join(orcid_batch.download_batch_dir, record) - # ) - # futures.append(future) - # for future in futures: - # transformed_data.append(future.result()) - - # # Save records to upsert - # batch_upserts = [record for record in transformed_data if isinstance(record, dict)] - # n_batch_upserts = len(batch_upserts) - # if n_batch_upserts > 0: - # save_jsonl_gz(orcid_batch.transform_upsert_file, batch_upserts) - - # # Save records to delete - # batch_deletes = [{"id": record} for record in transformed_data if isinstance(record, str)] - # n_batch_deletes = len(batch_deletes) - # if n_batch_deletes > 0: - # save_jsonl_gz(orcid_batch.transform_delete_file, batch_deletes) - - # # Check that the number of records matches the expected number of records - # total_records = n_batch_upserts + n_batch_deletes - # assert total_records == len( - # orcid_batch.expected_records - # ), f"Expected {len(orcid_batch.expected_records)} records but got {total_records} records ({n_batch_upserts} upserts | {n_batch_deletes} deletes)" - - # # Record keeping - # total_upsert_records += n_batch_upserts - # total_delete_records += n_batch_deletes - # logging.info( - # f"Transformed {n_batch_upserts} upserts and {n_batch_deletes} deletes for batch {orcid_batch.batch_str}" - # ) - # logging.info(f"Time taken for batch: {time.time() - batch_start} seconds") - # logging.info(f"Transformed {total_upsert_records} upserts and {total_delete_records} deletes in total") - # total_time = time.time() - start_time - # logging.info(f"Time taken for all batches: {str(datetime.timedelta(seconds=total_time))}") - def upload_transformed(self, release: OrcidRelease, **kwargs): """Uploads the upsert and delete files to the transform bucket.""" success = gcs_upload_files(bucket_name=self.cloud_workspace.transform_bucket, file_paths=release.upsert_files) @@ -701,7 +658,7 @@ def create_orcid_batch_manifest( blobs = gcs_list_blobs(bucket, prefix=prefix) manifest = [] for blob in blobs: - if pendulum.instance(blob.updated) > reference_date: + if pendulum.instance(blob.updated) >= reference_date: manifest.append( { MANIFEST_HEADER[0]: blob.bucket.name, @@ -718,33 +675,12 @@ def create_orcid_batch_manifest( logging.info(f"Manifest saved to {orcid_batch.manifest_file}") -def gsutil_download(orcid_batch: OrcidBatch) -> int: - """Download the ORCID files from GCS to the local machine - - :param orcid_batch: The OrcidBatch instance for this orcid directory - :return: The return code of the gsutil command's subprocess - """ - download_script = "gsutil -m -q cp -I -L {log_file} {download_folder}" - - blob_stdin = "\n".join(orcid_batch.blob_uris) - download_command = download_script.format( - log_file=orcid_batch.download_log_file, download_folder=orcid_batch.download_batch_dir - ) - with open(orcid_batch.download_error_file, "w") as f: - download_process = subprocess.Popen(download_command.split(" "), stdin=subprocess.PIPE, stderr=f, stdout=f) - download_process.communicate(input=blob_stdin.encode()) - returncode = download_process.wait() - return returncode - - def latest_modified_record_date(manifest_file_path) -> pendulum.DateTime: """Reads the manifest file and finds the most recent date of modification for the records :param manifest_file_path: the path to the manifest file :return: the most recent date of modification for the records """ - assert os.path.exists(manifest_file_path), f"Manifest file does not exist: 
{manifest_file_path}"
-
     with open(manifest_file_path, "r") as f:
         reader = csv.DictReader(f)
         modified_dates = sorted([pendulum.parse(row["updated"]) for row in reader])
@@ -762,8 +698,9 @@ def orcid_batch_names() -> List[str]:
     return ["".join(i) for i in combinations]
 
 
-def transform_orcid_record(record_uri: str, s5credentials: Tuple[str, str]) -> Union(Dict, str):
+def transform_orcid_record(record_path: str) -> Union[Dict, str]:
     """Transform a single ORCID file/record.
+    The record is read from a local XML file that was downloaded from the ORCID bucket in the download task.
     The xml file is turned into a dictionary, a record should have either a valid 'record' section or an 'error'section.
     The keys of the dictionary are slightly changed so they are valid BigQuery fields.
     If the record is valid, it is returned. If it is an error, the ORCID ID is returned.
@@ -771,16 +708,14 @@ def transform_orcid_record(record_uri: str, s5credentials: Tuple[str, str]) -> U
     :param download_path: The path to the file with the ORCID record.
     :param transform_folder: The path where transformed files will be saved.
     :return: The transformed ORCID record. If the record is an error, the ORCID ID is returned.
-    :raises AirflowException: If the record is not valid - no 'error' or 'record' section found
+    :raises KeyError: If the record is not valid - no 'error' or 'record' section found
+    :raises ValueError: If the ORCID ID does not match the file name's ID
     """
     # Get the orcid from the record path
-    expected_orcid = re.search(ORCID_REGEX, record_uri).group(0)
-    s5cmd = S5Cmd(access_credentials=s5credentials)
-    record = s5cmd.cat(record_uri)
-    orcid_dict = xmltodict.parse(record.decode("utf-8"))
+    expected_orcid = re.search(ORCID_REGEX, record_path).group(0)
 
-    # with open(record_path, "rb") as f:
-    #     orcid_dict = xmltodict.parse(f)
+    with open(record_path, "rb") as f:
+        orcid_dict = xmltodict.parse(f)
 
     try:
         orcid_record = orcid_dict["record:record"]
@@ -801,42 +736,6 @@ def transform_orcid_record(record_uri: str, s5credentials: Tuple[str, str]) -> U
 
     return orcid_record
 
 
-# def transform_orcid_record(record_path: str) -> Union(Dict, str):
-#     """Transform a single ORCID file/record.
-#     The xml file is turned into a dictionary, a record should have either a valid 'record' section or an 'error'section.
-#     The keys of the dictionary are slightly changed so they are valid BigQuery fields.
-#     If the record is valid, it is returned. If it is an error, the ORCID ID is returned.
-
-#     :param download_path: The path to the file with the ORCID record.
-#     :param transform_folder: The path where transformed files will be saved.
-#     :return: The transformed ORCID record. If the record is an error, the ORCID ID is returned.
-#     :raises AirflowException: If the record is not valid - no 'error' or 'record' section found
-#     """
-#     # Get the orcid from the record path
-#     expected_orcid = re.search(ORCID_REGEX, record_path).group(0)
-
-#     with open(record_path, "rb") as f:
-#         orcid_dict = xmltodict.parse(f)
-
-#     try:
-#         orcid_record = orcid_dict["record:record"]
-#     # Some records do not have a 'record', but only 'error'. We return the path to the file in this case.
-# except KeyError: -# orcid_record = orcid_dict["error:error"] -# return expected_orcid - -# # Check that the ORCID in the file name matches the ORCID in the record -# if not orcid_record["common:orcid-identifier"]["common:path"] == expected_orcid: -# raise ValueError( -# f"Expected ORCID {expected_orcid} does not match ORCID in record {orcid_record['common:path']}" -# ) - -# # Transform the keys of the dictionary so they are valid BigQuery fields -# orcid_record = change_keys(orcid_record) - -# return orcid_record - - def change_keys(obj): """Recursively goes through the dictionary obj and replaces keys with the convert function. diff --git a/academic_observatory_workflows/workflows/tests/test_orcid_telescope.py b/academic_observatory_workflows/workflows/tests/test_orcid_telescope.py index e40c79296..40fe677a2 100644 --- a/academic_observatory_workflows/workflows/tests/test_orcid_telescope.py +++ b/academic_observatory_workflows/workflows/tests/test_orcid_telescope.py @@ -22,13 +22,14 @@ import gzip import json import os +import re import pathlib import csv import shutil import tempfile from typing import Dict import unittest -from unittest.mock import patch +from unittest.mock import MagicMock, patch from dataclasses import dataclass @@ -42,8 +43,9 @@ OrcidBatch, OrcidRelease, OrcidTelescope, + MANIFEST_HEADER, + BATCH_REGEX, create_orcid_batch_manifest, - gsutil_download, latest_modified_record_date, orcid_batch_names, gcs_list_blobs, @@ -213,37 +215,31 @@ def __init__(self, *args, **kwargs): class TestCreateOrcidBatchManifest(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.list_blobs_path = "oaebu_workflows.workflows.orcid_telescope.create_orcid_batch_manifest.gcs_list_blobs" + self.list_blobs_path = "academic_observatory_workflows.workflows.orcid_telescope.gcs_list_blobs" self.test_batch_str = "12X" self.bucket_name = "test-bucket" def test_create_orcid_batch_manifest(self): """Tests that the manifest file is created with the correct header and contains the correct blob names and modification dates""" + blob1 = MagicMock() + blob1.name = f"{self.test_batch_str}/blob1" + blob1.bucket.name = self.bucket_name + blob1.updated = datetime.datetime(2023, 1, 1) + blob2 = MagicMock() + blob2.name = f"{self.test_batch_str}/blob2" + blob2.bucket.name = self.bucket_name + blob2.updated = datetime.datetime(2023, 1, 2) + blob3 = MagicMock() + blob3.name = f"{self.test_batch_str}/blob3" + blob3.bucket.name = self.bucket_name + blob3.updated = datetime.datetime(2022, 12, 31) + with tempfile.TemporaryDirectory() as tmp_dir: - download_dir = os.path.join(tmp_dir, "download") transform_dir = os.path.join(tmp_dir, "transform") - test_batch = OrcidBatch(download_dir, transform_dir, self.test_batch_str) - - # Mock gcs_list_blobs - blobs = [ - storage.Blob( - name=f"{self.test_batch_str}/blob1", - bucket=self.bucket_name, - updated=datetime.datetime(2023, 1, 1), - ), - storage.Blob( - name=f"{self.test_batch_str}/blob2", - bucket=self.bucket_name, - updated=datetime.datetime(2023, 1, 2), - ), - storage.Blob( - name=f"{self.test_batch_str}/blob3", - bucket=self.bucket_name, - updated=datetime.datetime(2022, 12, 31), # Before reference date - should not be included - ), - ] - with patch(self.list_blobs_path, return_value=blobs): + os.mkdir(transform_dir) + test_batch = OrcidBatch(tmp_dir, transform_dir, self.test_batch_str) + with patch(self.list_blobs_path, return_value=[blob1, blob2, blob3]): create_orcid_batch_manifest(test_batch, pendulum.datetime(2023, 
1, 1), self.bucket_name) # Assert manifest file is created with correct header and content @@ -252,31 +248,28 @@ def test_create_orcid_batch_manifest(self): rows = list(reader) self.assertEqual(len(rows), 2) - self.assertEqual(rows[0]["bucket_name"], self.bucket_name) - self.assertEqual(rows[0]["blob_name"], f"{self.test_batch_str}/blob1") - self.assertEqual(rows[0]["updated"], str(datetime.datetime(2023, 1, 1))) - self.assertEqual(rows[1]["bucket_name"], self.bucket_name) - self.assertEqual(rows[1]["blob_name"], f"{self.test_batch_str}/blob2") - self.assertEqual(rows[1]["updated"], str(datetime.datetime(2023, 1, 2))) + self.assertEqual(rows[0]["bucket_name"], blob1.bucket.name) + self.assertEqual(rows[0]["blob_name"], blob1.name) + self.assertEqual(rows[0]["updated"], str(blob1.updated)) + self.assertEqual(rows[1]["bucket_name"], blob2.bucket.name) + self.assertEqual(rows[1]["blob_name"], blob2.name) + self.assertEqual(rows[1]["updated"], str(blob2.updated)) # Tests that the manifest file is not created if there are no blobs modified after the reference date def test_no_results(self): """Tests that the manifest file is not created if there are no blobs modified after the reference date""" with tempfile.TemporaryDirectory() as tmp_dir: - download_dir = os.path.join(tmp_dir, "download") transform_dir = os.path.join(tmp_dir, "transform") - test_batch = OrcidBatch(download_dir, transform_dir, self.test_batch_str) + os.mkdir(transform_dir) + test_batch = OrcidBatch(tmp_dir, transform_dir, self.test_batch_str) # Mock gcs_list_blobs - blobs = [ - storage.Blob( - name=f"{self.test_batch_str}/blob1", - bucket=self.bucket_name, - updated=datetime.datetime(2022, 1, 1), # Before reference date - ) - ] - with patch(self.list_blobs_path, return_value=blobs): + blob = MagicMock() + blob.name = f"{self.test_batch_str}/blob1" + blob.bucket.name = self.bucket_name + blob.updated = datetime.datetime(2022, 6, 1) + with patch(self.list_blobs_path, return_value=[blob]): create_orcid_batch_manifest(test_batch, pendulum.datetime(2023, 1, 1), self.bucket_name) # Assert manifest file is created @@ -314,3 +307,35 @@ def test_mismatched_orcid(self): path = OrcidTestAssets.mismatched_orcid.values()[0] with self.assertRaisesRegex(ValueError, "does not match ORCID in record"): transform_orcid_record(path) + + +class TestExtras(unittest.TestCase): + def test_latest_modified_record_date(self): + """Tests that the latest_modified_record_date function returns the correct date""" + # Create a temporary manifest file for the test + with tempfile.NamedTemporaryFile() as temp_file: + with open(temp_file.name, "w") as f: + f.write(",".join(MANIFEST_HEADER)) + f.write("\n") + f.write("gs://test-bucket,folder/0000-0000-0000-0001.xml,2023-06-03T00:00:00Z\n") + f.write("gs://test-bucket,folder/0000-0000-0000-0002.xml,2023-06-03T00:00:00Z\n") + f.write("gs://test-bucket,folder/0000-0000-0000-0003.xml,2023-06-02T00:00:00Z\n") + f.write("gs://test-bucket,folder/0000-0000-0000-0004.xml,2023-06-01T00:00:00Z\n") + + # Call the function and assert the result + expected_date = pendulum.parse("2023-06-03T00:00:00Z") + actual_date = latest_modified_record_date(temp_file.name) + self.assertEqual(actual_date, expected_date) + + def test_orcid_batch_names(self): + """Tests that the orcid_batch_names function returns the expected results""" + batch_names = orcid_batch_names() + + # Test that the function returns a list + self.assertIsInstance(batch_names, list) + self.assertEqual(len(batch_names), 1100) + self.assertTrue(all(isinstance(element, 
str) for element in batch_names)) + self.assertEqual(len(set(batch_names)), len(batch_names)) + # Test that the batch names match the OrcidBatch regex + for batch_name in batch_names: + self.assertTrue(re.match(BATCH_REGEX, batch_name))
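
Below is a minimal usage sketch of the S5Cmd surface reworked in this patch (S5CmdCpConfig, download_from_bucket, upload_to_bucket and cat). It is illustrative only and not part of the diff: the HMAC key pair, bucket names, blob URIs and local paths are placeholders, and the gs:// URIs are rewritten to s3:// internally by the class as shown above.

import os

from academic_observatory_workflows.s5cmd import S5Cmd, S5CmdCpConfig

# Placeholder HMAC key/secret for the GCS bucket - supply real values via configuration.
ACCESS_CREDENTIALS = ("GOOG1EXAMPLEACCESSKEY", "example-secret")

# Placeholder blob URIs and local working directory.
URIS = [
    "gs://example-orcid-bucket/123/0000-0000-0000-0123.xml",
    "gs://example-orcid-bucket/123/0000-0000-0000-012X.xml",
]
DOWNLOAD_DIR = "/tmp/orcid_download"


def main():
    os.makedirs(DOWNLOAD_DIR, exist_ok=True)
    s5cmd = S5Cmd(access_credentials=ACCESS_CREDENTIALS)

    # Download the records; the method returns (stdout, stderr, returncode).
    stdout, stderr, returncode = s5cmd.download_from_bucket(uris=URIS, local_path=DOWNLOAD_DIR)
    if returncode != 0:
        raise RuntimeError(f"s5cmd download failed: {stderr}")

    # Upload a transformed file, skipping objects that already exist in the bucket.
    cp_config = S5CmdCpConfig(no_overwrite=True)
    stdout, stderr, returncode = s5cmd.upload_to_bucket(
        files="/tmp/orcid_transform/123_upsert.jsonl.gz",
        bucket_uri="gs://example-transform-bucket/orcid/",
        cp_config=cp_config,
    )
    if returncode != 0:
        raise RuntimeError(f"s5cmd upload failed: {stderr}")

    # Stream a single remote record; stdout holds the raw XML bytes.
    xml_bytes, _, _ = s5cmd.cat(URIS[0])
    print(xml_bytes.decode("utf-8"))


if __name__ == "__main__":
    main()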