From 35ac8bc64e2aaa7995cb7e33b224bc0601445425 Mon Sep 17 00:00:00 2001 From: Keegan Date: Tue, 4 Jul 2023 02:29:09 +0000 Subject: [PATCH] manifest update - not finished --- .../workflows/orcid_telescope.py | 131 +++++++++++------- 1 file changed, 82 insertions(+), 49 deletions(-) diff --git a/academic_observatory_workflows/workflows/orcid_telescope.py b/academic_observatory_workflows/workflows/orcid_telescope.py index 24448e78..86f22ce0 100644 --- a/academic_observatory_workflows/workflows/orcid_telescope.py +++ b/academic_observatory_workflows/workflows/orcid_telescope.py @@ -20,7 +20,7 @@ import datetime import logging import os -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import timedelta from typing import List, Dict, Tuple, Union import itertools @@ -35,6 +35,7 @@ from airflow.hooks.base import BaseHook from google.cloud import bigquery, storage from google.cloud.bigquery import SourceFormat +from tenacity import retry, wait_exponential, stop_after_attempt from academic_observatory_workflows.config import schema_folder as default_schema_folder, Tag @@ -85,7 +86,7 @@ def __init__( bq_delete_table_name: str, start_date: pendulum.DateTime, end_date: pendulum.DateTime, - prev_end_date: pendulum.DateTime, + prev_release_end: pendulum.DateTime, prev_latest_modified_record: pendulum.DateTime, is_first_run: bool, ): @@ -115,15 +116,17 @@ def __init__( self.bq_main_table_name = bq_main_table_name self.bq_upsert_table_name = bq_upsert_table_name self.bq_delete_table_name = bq_delete_table_name - self.prev_end_date = prev_end_date + self.prev_release_end = prev_release_end self.prev_latest_modified_record = prev_latest_modified_record self.is_first_run = is_first_run # Files/folders - self.manifest_file_path = os.path.join(self.workflow_folder, "manifest.csv") + self.local_manifest_file_path = os.path.join(self.workflow_folder, "local_manifest.csv") self.delete_file_path = os.path.join(self.transform_folder, "delete.jsonl.gz") self.batch_folder = os.path.join(self.workflow_folder, "batch") + self.transfer_manifest_folder = os.path.join(self.workflow_folder, "transfer_manifest") os.makedirs(self.batch_folder, exist_ok=True) + os.makedirs(self.transfer_manifest_folder, exist_ok=True) # Table names and URIs self.table_uri = gcs_blob_uri( @@ -133,7 +136,7 @@ def __init__( self.bq_upsert_table_id = bq_table_id(cloud_workspace.project_id, bq_dataset_id, bq_upsert_table_name) self.bq_delete_table_id = bq_table_id(cloud_workspace.project_id, bq_dataset_id, bq_delete_table_name) self.bq_snapshot_table_id = bq_sharded_table_id( - cloud_workspace.project_id, bq_dataset_id, bq_main_table_name, prev_end_date + cloud_workspace.project_id, bq_dataset_id, bq_main_table_name, prev_release_end ) @property @@ -162,12 +165,14 @@ def __init__( bq_dataset_id: str = "orcid", bq_main_table_name: str = "orcid", bq_upsert_table_name: str = "orcid_upsert", - dataset_description: str = "The ORCID dataset", + bq_delete_table_name: str = "orcid_delete", + dataset_description: str = "The ORCID dataset and supporting tables", table_description: str = "The ORCID dataset", snapshot_expiry_days: int = 31, schema_file_path: str = os.path.join(default_schema_folder(), "orcid"), transfer_attempts: int = 5, - batch_size: int = 100000, + batch_size: int = 20000, + transfer_size: int = 250000, max_workers: int = os.cpu_count(), api_dataset_id: str = "orcid", observatory_api_conn_id: str = AirflowConns.OBSERVATORY_API, @@ -194,10 +199,12 @@ def 
__init__( self.bq_dataset_id = bq_dataset_id self.bq_main_table_name = bq_main_table_name self.bq_upsert_table_name = bq_upsert_table_name + self.bq_delete_table_name = bq_delete_table_name self.api_dataset_id = api_dataset_id self.schema_file_path = schema_file_path self.transfer_attempts = transfer_attempts self.batch_size = batch_size + self.transfer_size = transfer_size self.dataset_description = dataset_description self.table_description = table_description self.snapshot_expiry_days = snapshot_expiry_days @@ -223,7 +230,8 @@ def __init__( self.add_task(self.bq_create_main_table_snapshot) # Scour the data for updates - self.add_task(self.create_manifest) + self.add_task(self.create_local_manifest) + self.add_task(self.create_gcs_transfer_manifests) self.add_task(self.create_batches) # Download and transform updated files @@ -253,11 +261,6 @@ def aws_orcid_key(self) -> Tuple[str, str]: connection = BaseHook.get_connection(self.aws_orcid_conn_id) return connection.login, connection.password - def skipcheck(self, release): - # This should never realistically happen unless ORCID goes down for a week - if len(release.batch_files) == 0: - raise AirflowSkipException("No files found to process. Skipping remaining tasks.") - def make_release(self, **kwargs) -> OrcidRelease: dag_run = kwargs["dag_run"] is_first_run = is_first_dag_run(dag_run) @@ -304,7 +307,7 @@ def create_datasets(self, release: OrcidRelease, **kwargs) -> None: description=self.dataset_description, ) - def transfer_orcid(self): + def transfer_orcid(self, release: OrcidRelease, **kwargs): """Sync files from AWS bucket to Google Cloud bucket.""" for i in range(self.transfer_attempts): logging.info(f"Beginning AWS to GCP transfer attempt no. {i+1}") @@ -330,9 +333,8 @@ def bq_create_main_table_snapshot(self, release: OrcidRelease, **kwargs): if something goes wrong. The snapshot expires after self.snapshot_expiry_days.""" if release.is_first_run: - raise AirflowSkipException( - f"bq_create_main_table_snapshots: skipping as snapshots are not created on the first run" - ) + logging.info(f"bq_create_main_table_snapshots: skipping as snapshots are not created on the first run") + return expiry_date = pendulum.now().add(days=self.snapshot_expiry_days) success = bq_snapshot( @@ -340,7 +342,7 @@ def bq_create_main_table_snapshot(self, release: OrcidRelease, **kwargs): ) set_task_state(success, self.bq_create_main_table_snapshot.__name__, release) - def create_manifest(self, release: OrcidRelease, **kwargs): + def create_local_manifest(self, release: OrcidRelease, **kwargs): """Create a manifest of all the modified files in the orcid bucket.""" logging.info("Creating manifest") with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: @@ -359,7 +361,7 @@ def create_manifest(self, release: OrcidRelease, **kwargs): manifest_files = [future.result() for future in futures] logging.info("Joining manifest files") - with open(os.path.join("manifest", "manifest.csv"), "w") as f: + with open(release.local_manifest_file_path, "w") as f: f.write("path,orcid,size,updated") for batch_file in manifest_files: with open(batch_file, "r") as bf: @@ -367,17 +369,45 @@ def create_manifest(self, release: OrcidRelease, **kwargs): for line in bf: f.write(line) + def create_gcs_transfer_manifests(self, release: OrcidRelease, **kwargs): + """Creates manifest file(s) for the GCS transfer job(s). 
One file per 'transfer_size' records."""
+        # Count the lines in the manifest file
+        with open(release.local_manifest_file_path, "r") as f:
+            f.readline()  # Skip the header
+            num_lines = sum(1 for _ in f)
+        manifests = math.ceil(num_lines / self.transfer_size)
+        if manifests == 0:
+            # This should never realistically happen unless ORCID goes down for a week
+            raise AirflowSkipException("No files found to process. Skipping remaining tasks.")
+
+        lines_read = 0
+        with open(release.local_manifest_file_path, "r") as f:
+            f.readline()  # Skip the header
+            for man_n in range(manifests):
+                blobs = []
+                # Read 'transfer_size' lines
+                for _ in range(self.transfer_size):
+                    line = f.readline()
+                    if not line:  # End of file gives empty string
+                        break
+                    lines_read += 1
+                    blobs.append(line.strip().split(",")[0])
+                # Write the manifest file
+                with open(os.path.join(release.transfer_manifest_folder, f"manifest_{man_n}.csv"), "w") as bf:
+                    bf.write("\n".join(blobs))
+        logging.info(f"Read {lines_read} of {num_lines} lines in the manifest file")
+        assert lines_read == num_lines, "Not all lines in the manifest file were read. Aborting."
+
     def create_batches(self, release: OrcidRelease, **kwargs):
         """Create batches of files to be processed"""
         # Count the lines in the manifest file
-        with open(release.manifest_file_path, "r") as f:
+        with open(release.local_manifest_file_path, "r") as f:
             f.readline()  # Skip the header
             num_lines = sum(1 for _ in f)
         batches = math.ceil(num_lines / self.batch_size)
-        if batches == 0:
-            raise AirflowSkipException("No files found to process. Skipping remaining tasks.")
-        with open(release.manifest_file_path, "r") as f:
+        lines_read = 0
+        with open(release.local_manifest_file_path, "r") as f:
             f.readline()  # Skip the header
             for batch_n in range(batches):
                 batch_blobs = []
@@ -386,37 +416,43 @@ def create_batches(self, release: OrcidRelease, **kwargs):
                     line = f.readline()
                     if not line:  # End of file gives empty string
                         break
+                    lines_read += 1
-                    batch_blobs.append(f.readline().strip().split(",")[0])
+                    batch_blobs.append(line.strip().split(",")[0])
                 # Write the batch file
                 with open(os.path.join(release.batch_folder, f"batch_{batch_n}.txt"), "w") as bf:
                     bf.write("\n".join(batch_blobs))
+        logging.info(f"Read {lines_read} of {num_lines} lines in the manifest file")
+        assert lines_read == num_lines, "Not all lines in the manifest file were read. Aborting."
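Both create_gcs_transfer_manifests and create_batches above use the same pattern: count the data rows once, then re-read the manifest and write one output file per chunk of at most N rows, keeping only the blob-path column. The sketch below is a minimal standalone version of that pattern for reference only; the split_manifest name, its arguments and the chunk_*.csv naming are illustrative and not part of the workflow. It assumes a manifest with a single header row (path,orcid,size,updated).

import math
import os
from typing import List


def split_manifest(manifest_path: str, out_folder: str, chunk_size: int) -> List[str]:
    """Split a manifest CSV into files of at most chunk_size blob paths each."""
    os.makedirs(out_folder, exist_ok=True)

    # First pass: count the data rows (everything after the header)
    with open(manifest_path, "r") as f:
        f.readline()  # Skip the header
        num_lines = sum(1 for _ in f)
    n_chunks = math.ceil(num_lines / chunk_size)

    # Second pass: emit one file per chunk, keeping only the first (path) column
    out_paths = []
    with open(manifest_path, "r") as f:
        f.readline()  # Skip the header
        for chunk_n in range(n_chunks):
            blobs = []
            for _ in range(chunk_size):
                line = f.readline()
                if not line:  # End of file gives an empty string
                    break
                blobs.append(line.strip().split(",")[0])
            out_path = os.path.join(out_folder, f"chunk_{chunk_n}.csv")
            with open(out_path, "w") as out:
                out.write("\n".join(blobs))
            out_paths.append(out_path)
    return out_paths

Appending the same line that was just read, rather than calling readline() a second time, is what keeps every row accounted for, so a final lines_read == num_lines style check holds.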
     def download(self, release: OrcidRelease, **kwargs):
         """Reads the batch files and downloads the files from the gcs bucket."""
         logging.info(f"Number of batches: {len(release.batch_files)}")
-        self.skipcheck(release)
         release.make_download_folders()
-        for batch_file in release.batch_files:
+        for i, batch_file in enumerate(release.batch_files):
+            logging.info(f"Downloading batch {i+1} of {len(release.batch_files)}")
             with open(batch_file, "r") as f:
                 orcid_blobs = [line.strip() for line in f.readlines()]
             # Download the blobs in parallel
-            with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
+            with ThreadPoolExecutor(max_workers=os.cpu_count() * 2) as executor:
                 futures = []
+                logging.disable(logging.INFO)  # Turn off logging to avoid spamming the logs
                 for blob in orcid_blobs:
-                    file_path = os.path.join(release.download_folder, blob)
+                    # Get path components of the blob (path / to / summaries / 123 / blob_name.xml)
+                    components = os.path.normpath(blob).split(os.sep)
+                    file_path = os.path.join(release.download_folder, components[-2], components[-1])
                     future = executor.submit(
                         gcs_download_blob, bucket_name=self.orcid_bucket, blob_name=blob, file_path=file_path
                     )
                     futures.append(future)
-                for future in futures:
+                for future in as_completed(futures):
                     future.result()
+                logging.disable(logging.NOTSET)
 
     def transform(self, release: OrcidRelease, **kwargs):
         """Transforms the downloaded files into several bigquery-compatible .jsonl files"""
         logging.info(f"Number of batches: {len(release.batch_files)}")
-        self.skipcheck(release)
         delete_paths = []
         for i, batch_file in enumerate(os.listdir(release.batch_folder)):
             with open(os.path.join(release.batch_folder, batch_file), "r") as f:
@@ -446,7 +482,6 @@ def transform(self, release: OrcidRelease, **kwargs):
 
     def upload_transformed(self, release: OrcidRelease, **kwargs):
         """Uploads the transformed files to the transform bucket."""
-        self.skipcheck(release)
         success = gcs_upload_files(
             bucket_name=self.cloud_workspace.transform_bucket, file_paths=release.transform_files
         )
@@ -455,11 +490,10 @@ def bq_load_main_table(self, release: OrcidRelease, **kwargs):
         """Load the main table."""
         if not release.is_first_run:
-            raise AirflowSkipException(
-                f"bq_load_main_table: skipping as the main table is only created on the first run"
-            )
-        assert len(release.batch_files), "No batch files found. Batch files must exist before loading the main table."
+            logging.info(f"bq_load_main_table: skipping as the main table is only created on the first run")
+            return
+        assert len(release.batch_files), "No batch files found. Batch files must exist before loading the main table."
success = bq_load_table( uri=release.table_uri, table_id=release.main_table_id, @@ -473,8 +507,8 @@ def bq_load_main_table(self, release: OrcidRelease, **kwargs): def bq_load_upsert_table(self, release: OrcidRelease, **kwargs): """Load the upsert table into bigquery""" if release.is_first_run: - raise AirflowSkipException(f"bq_load_upsert_table: skipping as no records are upserted on the first run") - self.skipcheck(release) + logging.info(f"bq_load_upsert_table: skipping as no records are upserted on the first run") + return success = bq_load_table( uri=release.table_uri, @@ -489,10 +523,11 @@ def bq_load_upsert_table(self, release: OrcidRelease, **kwargs): def bq_load_delete_table(self, release: OrcidRelease, **kwargs): """Load the delete table into bigquery""" if release.is_first_run: - raise AirflowSkipException(f"bq_load_delete_table: skipping as no records are deleted on the first run") - self.skipcheck(release) + logging.info(f"bq_load_delete_table: skipping as no records are deleted on the first run") + return if not os.path.exists(release.delete_file_path): - raise AirflowSkipException(f"bq_load_delete_table: skipping as no delete file exists") + logging.info(f"bq_load_delete_table: skipping as no delete file exists") + return success = bq_load_table( uri=release.table_uri, @@ -507,8 +542,8 @@ def bq_load_delete_table(self, release: OrcidRelease, **kwargs): def bq_upsert_records(self, release: OrcidRelease, **kwargs): """Upsert the records from the upserts table into the main table.""" if release.is_first_run: - raise AirflowSkipException("bq_upsert_records: skipping as no records are upserted on the first run") - self.skipcheck(release) + logging.info("bq_upsert_records: skipping as no records are upserted on the first run") + return success = bq_upsert_records( main_table_id=release.bq_main_table_id, @@ -520,10 +555,11 @@ def bq_upsert_records(self, release: OrcidRelease, **kwargs): def bq_delete_records(self, release: OrcidRelease, **kwargs): """Delete the records in the delete table from the main table.""" if release.is_first_run: - raise AirflowSkipException("bq_delete_records: skipping as no records are deleted on the first run") - self.skipcheck(release) + logging.info("bq_delete_records: skipping as no records are deleted on the first run") + return if not os.path.exists(release.delete_file_path): - raise AirflowSkipException(f"bq_delete_records: skipping as no delete file exists") + logging.info(f"bq_delete_records: skipping as no delete file exists") + return success = bq_delete_records( main_table_id=release.bq_main_table_id, @@ -536,22 +572,19 @@ def bq_delete_records(self, release: OrcidRelease, **kwargs): def add_new_dataset_release(self, release: OrcidRelease, **kwargs) -> None: """Adds release information to API.""" - self.skipcheck(release) - dataset_release = DatasetRelease( dag_id=self.dag_id, dataset_id=self.api_dataset_id, dag_run_id=release.run_id, changefile_start_date=release.start_date, changefile_end_date=release.end_date, - extra={"latest_modified_record_date": latest_modified_record_date(release.manifest_file_path)}, + extra={"latest_modified_record_date": latest_modified_record_date(release.local_manifest_file_path)}, ) api = make_observatory_api(observatory_api_conn_id=self.observatory_api_conn_id) api.post_dataset_release(dataset_release) def cleanup(self, release: OrcidRelease, **kwargs) -> None: """Delete all files, folders and XComs associated with this release.""" - cleanup(dag_id=self.dag_id, execution_date=kwargs["logical_date"], 
workflow_folder=release.workflow_folder)
@@ -602,7 +635,7 @@ def orcid_directories() -> List[str]:
 
     :return: A list of all the possible ORCID directories
     """
-    n_1_2 = [str(i) for i in range(2)]  # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
+    n_1_2 = [str(i) for i in range(10)]  # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
     n_3 = n_1_2.copy() + ["X"]  # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, X
     combinations = list(itertools.product(n_1_2, n_1_2, n_3))  # Creates the 000 to 99X directory structure
     return ["".join(i) for i in combinations]
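For reference, once n_1_2 covers the full 0-9 range, orcid_directories() yields the 000 to 99X directory prefixes that the download code expects (the three-character directory under summaries/): 10 * 10 * 11 = 1,100 prefixes, where only the final character can be "X" because an ORCID iD's checksum character may be X. The snippet below is an illustrative sanity check only, not part of the workflow or its tests.

import itertools

n_1_2 = [str(i) for i in range(10)]  # 0-9
n_3 = n_1_2 + ["X"]  # 0-9 plus X, the possible ORCID checksum character
dirs = ["".join(c) for c in itertools.product(n_1_2, n_1_2, n_3)]

assert len(dirs) == 1100  # 10 * 10 * 11
assert dirs[0] == "000" and dirs[-1] == "99X"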