From 2eadfe1089b158488f7b808dd8143c2041dd67c9 Mon Sep 17 00:00:00 2001 From: keegansmith21 Date: Thu, 29 Jun 2023 14:46:10 +0800 Subject: [PATCH] Update --- .../{orcid_2020-01-01.json => orcid.json} | 0 .../database/schema/orcid/orcid_lambda.json | 26 - .../workflows/orcid_telescope.py | 617 ++++++++++++++---- 3 files changed, 499 insertions(+), 144 deletions(-) rename academic_observatory_workflows/database/schema/orcid/{orcid_2020-01-01.json => orcid.json} (100%) delete mode 100644 academic_observatory_workflows/database/schema/orcid/orcid_lambda.json diff --git a/academic_observatory_workflows/database/schema/orcid/orcid_2020-01-01.json b/academic_observatory_workflows/database/schema/orcid/orcid.json similarity index 100% rename from academic_observatory_workflows/database/schema/orcid/orcid_2020-01-01.json rename to academic_observatory_workflows/database/schema/orcid/orcid.json diff --git a/academic_observatory_workflows/database/schema/orcid/orcid_lambda.json b/academic_observatory_workflows/database/schema/orcid/orcid_lambda.json deleted file mode 100644 index d1a2aae2b..000000000 --- a/academic_observatory_workflows/database/schema/orcid/orcid_lambda.json +++ /dev/null @@ -1,26 +0,0 @@ -[ - { - "mode": "NULLABLE", - "name": "orcid", - "type": "STRING", - "description": "ORCID iD." - }, - { - "mode": "NULLABLE", - "name": "path", - "type": "STRING", - "description": "ORCID iD in URI form." - }, - { - "mode": "NULLABLE", - "name": "date_created", - "type": "DATE", - "description": "The date this ID was created" - }, - { - "mode": "NULLABLE", - "name": "last_modified", - "type": "DATE", - "description": "The most recent modification instance for this ID" - } -] \ No newline at end of file diff --git a/academic_observatory_workflows/workflows/orcid_telescope.py b/academic_observatory_workflows/workflows/orcid_telescope.py index 1187760ed..d66239a77 100644 --- a/academic_observatory_workflows/workflows/orcid_telescope.py +++ b/academic_observatory_workflows/workflows/orcid_telescope.py @@ -20,49 +20,56 @@ import datetime import logging import os -import pathlib -import time -from concurrent.futures import ProcessPoolExecutor, as_completed, ThreadPoolExecutor +from concurrent.futures import ThreadPoolExecutor from datetime import timedelta -from typing import List, Dict, Tuple -import tarfile +from typing import List, Dict, Tuple, Union +import itertools +import math +import csv +import re -import jsonlines import pendulum -import requests -from airflow.exceptions import AirflowSkipException -from airflow.models.taskinstance import TaskInstance +import xmltodict +from airflow.exceptions import AirflowException, AirflowSkipException from airflow.operators.dummy import DummyOperator from airflow.hooks.base import BaseHook -import boto3 -from google.cloud import bigquery +from google.cloud import bigquery, storage from google.cloud.bigquery import SourceFormat -from academic_observatory_workflows.config import schema_folder as default_schema_folder, Tag + +from academic_observatory_workflows.config import schema_folder as default_schema_folder from observatory.api.client.model.dataset_release import DatasetRelease from observatory.platform.airflow import PreviousDagRunSensor, is_first_dag_run from observatory.platform.api import get_dataset_releases, get_latest_dataset_release from observatory.platform.api import make_observatory_api from observatory.platform.bigquery import ( bq_table_id, - bq_find_schema, bq_load_table, bq_upsert_records, - bq_snapshot, + bq_delete_records, 
bq_sharded_table_id, bq_create_dataset, - bq_delete_records, + bq_snapshot, ) from observatory.platform.config import AirflowConns -from observatory.platform.files import list_files, yield_jsonl, merge_update_files, save_jsonl_gz, load_csv -from observatory.platform.gcs import gcs_upload_files, gcs_blob_uri, gcs_blob_name_from_path +from observatory.platform.files import list_files, save_jsonl_gz +from observatory.platform.gcs import ( + gcs_upload_files, + gcs_download_blob, + gcs_blob_uri, + gcs_blob_name_from_path, + gcs_create_aws_transfer, +) from observatory.platform.observatory_config import CloudWorkspace -from observatory.platform.utils.url_utils import get_user_agent, retry_get_url -from observatory.platform.workflows.workflow import Workflow, ChangefileRelease, cleanup, set_task_state +from observatory.platform.workflows.workflow import ( + Workflow, + ChangefileRelease, + cleanup, + set_task_state, +) ORCID_AWS_SUMMARIES_BUCKET = "v2.0-summaries" -ORCID_AWS_LAMBDA_BUCKET = "orcid-lambda-file" -ORCID_LAMBDA_OBJECT = "last_modified.csv.tar" +ORCID_REGEX = r"\d{4}-\d{4}-\d{4}-\d{3}(\d|X)\b" class OrcidRelease(ChangefileRelease): @@ -71,47 +78,85 @@ def __init__( *, dag_id: str, run_id: str, + cloud_workspace: CloudWorkspace, + bq_dataset_id: str, + bq_table_name: str, start_date: pendulum.DateTime, end_date: pendulum.DateTime, + prev_end_date: pendulum.DateTime, + modification_cutoff: pendulum.DateTime, + is_first_run: bool, ): """Construct a CrossrefEventsRelease instance :param dag_id: the id of the DAG. :param start_date: the start_date of the release. Inclusive. :param end_date: the end_date of the release. Exclusive. + :param modification_cutoff: the record modification cutoff date of the release. Files modified after this date will be included. 
""" - super().__init__( dag_id=dag_id, run_id=run_id, start_date=start_date, end_date=end_date, ) - self.upsert_table_file_path = os.path.join(self.transform_folder, "upsert_table.jsonl") - self.delete_table_file_path = os.path.join(self.transform_folder, "delete_table.jsonl") + self.cloud_workspace = cloud_workspace + self.bq_dataset_id = bq_dataset_id + self.bq_table_name = bq_table_name + self.prev_end_date = prev_end_date + self.modification_cutoff = modification_cutoff + self.is_first_run = is_first_run + + # Files/folders + self.manifest_file_path = os.path.join(self.workflow_folder, "manifest.csv") + self.delete_file_path = os.path.join(self.transform_folder, "delete.jsonl.gz") + self.batch_folder = os.path.join(self.workflow_folder, "batch") + os.makedirs(self.batch_folder, exist_ok=True) + + # Table names and URIs + self.table_uri = gcs_blob_uri( + self.cloud_workspace.transform_bucket, f"{gcs_blob_name_from_path(self.transform_folder)}/*.jsonl.gz" + ) + self.bq_main_table_id = bq_table_id(cloud_workspace.project_id, bq_dataset_id, bq_table_name) + self.bq_upsert_table_id = bq_table_id(cloud_workspace.project_id, bq_dataset_id, f"{bq_table_name}_upsert") + self.bq_delete_table_id = bq_table_id(cloud_workspace.project_id, bq_dataset_id, f"{bq_table_name}_delete") + self.bq_snapshot_table_id = bq_sharded_table_id( + cloud_workspace.project_id, bq_dataset_id, bq_table_name, prev_end_date + ) - self.lambda_download_file_path = os.path.join(self.download_folder, "lambda", "orcid_lambda.csv.tar") - self.lambda_transform_file_path = os.path.join(self.transform_folder, "lambda", "orcid_lambda.csv") + @property + def batch_files(self): # Batch files in the batch_folder + return list_files(self.batch_folder, r"^batch_\d+.txt") @property - def lambda_extract_files(self): - return list_files(self.extract_folder, r"^.*\.csv$") # Match files ending with .csv + def transform_files(self): # Transform files in the transform folder + return list_files(self.transform_folder, r"^transformed_batch_\d+.jsonl.gz$") + + def make_download_folders(self) -> None: + """Creates the orcid directories in the download folder if they don't exist""" + for folder in orcid_directories(): + os.makedirs(os.path.join(self.download_folder, folder), exist_ok=True) -class CrossrefEventsTelescope(Workflow): - """Crossref Events telescope""" +class OrcidTelescope(Workflow): + """ORCID telescope""" def __init__( self, dag_id: str, cloud_workspace: CloudWorkspace, + orcid_bucket: str = "orcid-testing", # TODO: change before publishing + orcid_summaries_prefix: str = "orcid_summaries", bq_dataset_id: str = "orcid", - bq_summary_table_name: str = "orcid", - bq_lambda_table_name: str = "orcid_lambda", - schema_folder: str = os.path.join(default_schema_folder(), "orcid"), + bq_main_table_name: str = "orcid", + bq_upsert_table_name: str = "orcid_upsert", dataset_description: str = "The ORCID dataset", table_description: str = "The ORCID dataset", - max_processes: int = os.cpu_count(), + snapshot_expiry_days: int = 31, + schema_file_path: str = os.path.join(default_schema_folder(), "orcid"), + transfer_attempts: int = 5, + batch_size: int = 100000, + max_workers: int = os.cpu_count(), api_dataset_id: str = "orcid", observatory_api_conn_id: str = AirflowConns.OBSERVATORY_API, aws_orcid_conn_id: str = "aws_orcid", @@ -124,14 +169,19 @@ def __init__( self.dag_id = dag_id self.cloud_workspace = cloud_workspace + self.orcid_bucket = orcid_bucket + self.orcid_summaries_prefix = orcid_summaries_prefix self.bq_dataset_id = bq_dataset_id - 
self.bq_summary_table_name = bq_summary_table_name - self.bq_lambda_table_name = bq_lambda_table_name + self.bq_main_table_name = bq_main_table_name + self.bq_upsert_table_name = bq_upsert_table_name self.api_dataset_id = api_dataset_id - self.schema_folder = schema_folder + self.schema_file_path = schema_file_path + self.transfer_attempts = transfer_attempts + self.batch_size = batch_size self.dataset_description = dataset_description self.table_description = table_description - self.max_processes = max_processes + self.snapshot_expiry_days = snapshot_expiry_days + self.max_workers = max_workers self.observatory_api_conn_id = observatory_api_conn_id self.aws_orcid_conn_id = aws_orcid_conn_id self.start_date = start_date @@ -147,14 +197,32 @@ def __init__( execution_delta=timedelta(days=7), # To match the @weekly schedule_interval ) ) - self.add_task(self.check_dependencies) + self.add_setup_task(self.check_dependencies) + self.add_task(self.create_datasets) + + # Download the data self.add_task(self.transfer_orcid) - self.add_task(self.process_lambda) - self.add_task(self.bq_load_lambda) - # self.add_task(self.upsert_files) - # self.add_task(self.add_new_dataset_release) - # self.add_task(self.cleanup) + # Create snapshots of main table in case we mess up + self.add_task(self.bq_create_main_table_snapshot) + + # Scour the data for updates + self.add_task(self.create_manifest) + self.add_task(self.create_batches) + + # Download and transform updated files + self.add_task(self.download) + self.add_task(self.transform) + + # Load the data to table + self.add_task(self.upload_transformed) + self.add_task(self.bq_load_main_table) + self.add_task(self.bq_load_upsert_table) + self.add_task(self.bq_upsert_records) + + # Finish + self.add_task(self.add_new_dataset_release) + self.add_task(self.cleanup) # The last task that the next DAG run's ExternalTaskSensor waits for. self.add_operator( @@ -169,116 +237,429 @@ def aws_orcid_key(self) -> Tuple[str, str]: connection = BaseHook.get_connection(self.aws_orcid_conn_id) return connection.login, connection.password + def skipcheck(self, release): + # This should never realistically happen unless ORCID goes down for a week + if len(release.batch_files) == 0: + raise AirflowSkipException("No files found to process. Skipping remaining tasks.") + def make_release(self, **kwargs) -> OrcidRelease: - is_first_run = is_first_dag_run(kwargs["dag_run"]) + dag_run = kwargs["dag_run"] + is_first_run = is_first_dag_run(dag_run) releases = get_dataset_releases(dag_id=self.dag_id, dataset_id=self.api_dataset_id) - # Get start date + # Determine the modication cutoff for the new release if is_first_run: assert ( len(releases) == 0 ), "fetch_releases: there should be no DatasetReleases stored in the Observatory API on the first DAG run." 
-            start_date = pendulum.instance(datetime.datetime.min)
+
+            modification_cutoff = pendulum.instance(datetime.datetime.min)
+            prev_end_date = pendulum.instance(datetime.datetime.min)
         else:
             assert (
                 len(releases) >= 1
             ), f"fetch_releases: there should be at least 1 DatasetRelease in the Observatory API after the first DAG run"
-            start_date = kwargs["data_interval_start"]
+            prev_release = get_latest_dataset_release(releases, "changefile_end_date")
+            # The 'extra' field is serialised by the API, so parse it back into a DateTime
+            modification_cutoff = pendulum.parse(prev_release.extra["latest_modified_record_date"])
+            prev_end_date = pendulum.instance(prev_release.changefile_end_date)
 
         return OrcidRelease(
             dag_id=self.dag_id,
             run_id=kwargs["run_id"],
+            cloud_workspace=self.cloud_workspace,
+            bq_dataset_id=self.bq_dataset_id,
+            bq_table_name=self.bq_main_table_name,
-            start_date=start_date,
+            start_date=kwargs["data_interval_start"],
             end_date=kwargs["data_interval_end"],
+            prev_end_date=prev_end_date,
+            modification_cutoff=modification_cutoff,
+            is_first_run=is_first_run,
+        )
+
+    def create_datasets(self, release: OrcidRelease, **kwargs) -> None:
+        """Create datasets"""
+
+        bq_create_dataset(
+            project_id=self.cloud_workspace.project_id,
+            dataset_id=self.bq_dataset_id,
+            location=self.cloud_workspace.data_location,
+            description=self.dataset_description,
         )
 
     def transfer_orcid(self):
-        pass
+        """Sync files from the AWS bucket to the Google Cloud bucket."""
+        success = False
+        for i in range(self.transfer_attempts):
+            logging.info(f"Beginning AWS to GCP transfer attempt no. {i+1}")
+            success, objects_count = gcs_create_aws_transfer(
+                aws_key=self.aws_orcid_key,
+                aws_bucket=ORCID_AWS_SUMMARIES_BUCKET,
+                include_prefixes=[],
+                gc_project_id=self.cloud_workspace.project_id,
+                gc_bucket_dst_uri=gcs_blob_uri(self.orcid_bucket, f"{self.orcid_summaries_prefix}/"),
+                description="Transfer ORCID data from AWS to GCP",
+            )
+            logging.info(f"Attempt {i+1}: Total number of objects transferred: {objects_count}")
+            logging.info(f"Attempt {i+1}: Success? {success}")
+            if success:
+                break
+        if not success:
+            raise AirflowException(
+                f"Failed to transfer ORCID data from AWS to GCP after {self.transfer_attempts} attempts"
+            )
 
-    def process_lambda(self, release: OrcidRelease, **kwargs):
-        """Downloads, extracts, transforms and uploads the ORCID Lambda manifest file"""
-        aws_key_id, aws_key = self.aws_orcid_key
-        s3client = boto3.client("s3", aws_access_key_id=aws_key_id, aws_secret_access_key=aws_key)
+    def bq_create_main_table_snapshot(self, release: OrcidRelease, **kwargs):
+        """Create a snapshot of the main table. The purpose of this table is to be able to roll back the main table
+        if something goes wrong. The snapshot expires after self.snapshot_expiry_days."""
 
-        # Download from S3 bucket
-        s3client.download_file(ORCID_AWS_LAMBDA_BUCKET, ORCID_LAMBDA_OBJECT, release.lambda_download_file_path)
+        if release.is_first_run:
+            raise AirflowSkipException(
+                f"bq_create_main_table_snapshot: skipping as snapshots are not created on the first run"
+            )
 
-        # Extract
-        with tarfile.open(release.lambda_download_file_path) as lambda_tar:
-            lambda_tar.extractall(release.extract_folder)
-        assert len(release.lambda_extract_files) == 1, "Unexpected number of files in extract folder"
+        expiry_date = pendulum.now().add(days=self.snapshot_expiry_days)
+        success = bq_snapshot(
+            src_table_id=release.bq_main_table_id, dst_table_id=release.bq_snapshot_table_id, expiry_date=expiry_date
+        )
+        set_task_state(success, self.bq_create_main_table_snapshot.__name__, release)
 
-        # Transform
-        orcid_lambda = load_csv(release.lambda_extract_files[0])
-        with ProcessPoolExecutor(max_workers=self.max_processes) as executor:
+    def create_manifest(self, release: OrcidRelease, **kwargs):
+        """Create a manifest of all the modified files in the orcid bucket."""
+        logging.info("Creating manifest")
+        with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
             futures = []
-            for row in orcid_lambda:
-                future = executor.submit(transform_item, row)
-                futures.append(future)
-
-            transformed = []
-            for future in as_completed(futures):
-                transformed.append(future.result())
-                finished += 1
-                if finished % 100000 == 0:
-                    logging.info(f"Transformed {finished}/{len(orcid_lambda)} rows")
-        save_jsonl_gz(release.lambda_transform_file_path, transformed)
-        gcs_upload_files(
-            bucket_name=self.cloud_workspace.transform_bucket, file_paths=[release.lambda_transform_file_path]
+            for orcid_dir in orcid_directories():
+                dir_name = f"{self.orcid_summaries_prefix}/{orcid_dir}/"
+                futures.append(
+                    executor.submit(
+                        create_manifest_batch,
+                        bucket=self.orcid_bucket,
+                        dir_name=dir_name,
+                        reference_date=release.modification_cutoff,
+                        save_path=os.path.join(release.download_folder, f"{orcid_dir}_manifest.csv"),
+                    )
+                )
+            manifest_files = [future.result() for future in futures]
+
+        logging.info("Joining manifest files")
+        with open(release.manifest_file_path, "w") as f:
+            f.write("path,orcid,size,updated")
+            for batch_file in manifest_files:
+                if not os.path.exists(batch_file):  # Directories with no modified records have no batch manifest
+                    continue
+                with open(batch_file, "r") as bf:
+                    f.write("\n")
+                    for line in bf:
+                        f.write(line)
+
+    def create_batches(self, release: OrcidRelease, **kwargs):
+        """Create batches of files to be processed"""
+        # Count the lines in the manifest file
+        with open(release.manifest_file_path, "r") as f:
+            f.readline()  # Skip the header
+            num_lines = sum(1 for _ in f)
+        batches = math.ceil(num_lines / self.batch_size)
+        if batches == 0:
+            raise AirflowSkipException("No files found to process. Skipping remaining tasks.")
+
+        with open(release.manifest_file_path, "r") as f:
+            f.readline()  # Skip the header
+            for batch_n in range(batches):
+                batch_blobs = []
+                # Read 'batch_size' lines
+                for _ in range(self.batch_size):
+                    line = f.readline()
+                    if not line:  # End of file gives empty string
+                        break
+                    batch_blobs.append(line.strip().split(",")[0])
+                # Write the batch file
+                with open(os.path.join(release.batch_folder, f"batch_{batch_n}.txt"), "w") as bf:
+                    bf.write("\n".join(batch_blobs))
+
+    def download(self, release: OrcidRelease, **kwargs):
+        """Reads the batch files and downloads the files from the GCS bucket."""
+        logging.info(f"Number of batches: {len(release.batch_files)}")
+        self.skipcheck(release)
+
+        release.make_download_folders()
+        for batch_file in release.batch_files:
+            with open(batch_file, "r") as f:
+                orcid_blobs = [line.strip() for line in f.readlines()]
+
+            # Download the blobs in parallel
+            with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
+                futures = []
+                for blob in orcid_blobs:
+                    file_path = os.path.join(release.download_folder, blob)
+                    os.makedirs(os.path.dirname(file_path), exist_ok=True)  # Blob names include the summaries prefix
+                    future = executor.submit(
+                        gcs_download_blob, bucket_name=self.orcid_bucket, blob_name=blob, file_path=file_path
+                    )
+                    futures.append(future)
+                for future in futures:
+                    future.result()
+
+    def transform(self, release: OrcidRelease, **kwargs):
+        """Transforms the downloaded files into several BigQuery-compatible .jsonl.gz files"""
+        logging.info(f"Number of batches: {len(release.batch_files)}")
+        self.skipcheck(release)
+        delete_paths = []
+        for i, batch_file in enumerate(os.listdir(release.batch_folder)):
+            with open(os.path.join(release.batch_folder, batch_file), "r") as f:
+                orcid_records = [line.strip() for line in f.readlines()]
+
+            # Transform the files in parallel
+            transformed_data = []
+            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+                futures = []
+                for record in orcid_records:
+                    record_path = os.path.join(release.download_folder, record)  # Mirrors the download task's layout
+                    future = executor.submit(transform_orcid_record, record_path)
+                    futures.append(future)
+                for future in futures:
+                    transformed_data.append(future.result())
+
+            # Keep track of records to delete - transform_orcid_record returns the ORCID iD for error records
+            delete_paths.extend([{"id": record} for record in transformed_data if isinstance(record, str)])
+
+            # Save transformed records
+            transformed_data = [record for record in transformed_data if isinstance(record, dict)]
+            transform_file = os.path.join(release.transform_folder, f"transformed_batch_{i}.jsonl.gz")
+            save_jsonl_gz(transform_file, transformed_data)
+
+        # Save the delete paths if there are any
+        if delete_paths:
+            save_jsonl_gz(release.delete_file_path, delete_paths)
+
+    def upload_transformed(self, release: OrcidRelease, **kwargs):
+        """Uploads the transformed files to the transform bucket."""
+        self.skipcheck(release)
+        success = gcs_upload_files(
+            bucket_name=self.cloud_workspace.transform_bucket, file_paths=release.transform_files
         )
+        set_task_state(success, self.upload_transformed.__name__, release)
 
-    def bq_load_lambda(self, release: OrcidRelease, **kwargs):
-        bq_create_dataset(
-            project_id=self.cloud_workspace.project_id,
-            dataset_id=self.bq_dataset_id,
-            location=self.cloud_workspace.data_location,
-            description=self.dataset_description,
-        )
+    def bq_load_main_table(self, release: OrcidRelease, **kwargs):
+        """Load the main table."""
+        if not release.is_first_run:
+            raise AirflowSkipException(
+                f"bq_load_main_table: skipping as the main table is only created on the first run"
+            )
+        assert len(release.batch_files), "No batch files found. Batch files must exist before loading the main table."
-        # Selects all jsonl.gz files in the releases transform folder on the Google Cloud Storage bucket and all of its
-        # subfolders: https://cloud.google.com/bigquery/docs/batch-loading-data#load-wildcards
-        uri = gcs_blob_uri(
-            self.cloud_workspace.transform_bucket,
-            gcs_blob_name_from_path(release.lambda_transform_file_path),
+        success = bq_load_table(
+            uri=release.table_uri,
+            table_id=release.bq_main_table_id,
+            schema_file_path=self.schema_file_path,
+            source_format=SourceFormat.NEWLINE_DELIMITED_JSON,
+            write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
+            ignore_unknown_values=False,
         )
-        table_id = bq_table_id(
-            self.cloud_workspace.project_id, self.bq_dataset_id, self.bq_lambda_table_name, release.end_date
+        set_task_state(success, self.bq_load_main_table.__name__, release)
+
+    def bq_load_upsert_table(self, release: OrcidRelease, **kwargs):
+        """Load the upsert table into BigQuery"""
+        if release.is_first_run:
+            raise AirflowSkipException(f"bq_load_upsert_table: skipping as no records are upserted on the first run")
+        self.skipcheck(release)
+
+        success = bq_load_table(
+            uri=release.table_uri,
+            table_id=release.bq_upsert_table_id,
+            schema_file_path=self.schema_file_path,
+            source_format=SourceFormat.NEWLINE_DELIMITED_JSON,
+            write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
+            ignore_unknown_values=True,
         )
-        schema_file_path = bq_find_schema(path=self.schema_folder, table_name=self.bq_table_name)
+        set_task_state(success, self.bq_load_upsert_table.__name__, release)
+
+    def bq_load_delete_table(self, release: OrcidRelease, **kwargs):
+        """Load the delete table into BigQuery"""
+        if release.is_first_run:
+            raise AirflowSkipException(f"bq_load_delete_table: skipping as no records are deleted on the first run")
+        self.skipcheck(release)
+        if not os.path.exists(release.delete_file_path):
+            raise AirflowSkipException(f"bq_load_delete_table: skipping as no delete file exists")
+
         success = bq_load_table(
-            uri=uri,
-            table_id=table_id,
-            schema_file_path=schema_file_path,
+            uri=release.table_uri,
+            table_id=release.bq_delete_table_id,
+            schema_file_path=self.schema_file_path,
             source_format=SourceFormat.NEWLINE_DELIMITED_JSON,
-            table_description="The ORCID Lambda Manifest",
+            write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
             ignore_unknown_values=True,
         )
-        set_task_state(success, self.bq_load.__name__, release)
+        set_task_state(success, self.bq_load_delete_table.__name__, release)
+
+    def bq_upsert_records(self, release: OrcidRelease, **kwargs):
+        """Upsert the records from the upsert table into the main table."""
+        if release.is_first_run:
+            raise AirflowSkipException("bq_upsert_records: skipping as no records are upserted on the first run")
+        self.skipcheck(release)
+
+        success = bq_upsert_records(
+            main_table_id=release.bq_main_table_id,
+            upsert_table_id=release.bq_upsert_table_id,
+            primary_key="orcid_identifier.orcid",
+        )
+        set_task_state(success, self.bq_upsert_records.__name__, release)
+
+    def bq_delete_records(self, release: OrcidRelease, **kwargs):
+        """Delete the records in the delete table from the main table."""
+        if release.is_first_run:
+            raise AirflowSkipException("bq_delete_records: skipping as no records are deleted on the first run")
+        self.skipcheck(release)
+        if not os.path.exists(release.delete_file_path):
+            raise AirflowSkipException(f"bq_delete_records: skipping as no delete file exists")
+
+        success = bq_delete_records(
+            main_table_id=release.bq_main_table_id,
+            delete_table_id=release.bq_delete_table_id,
+            main_table_primary_key="path",
+            delete_table_primary_key="id",
+            main_table_primary_key_prefix="orcid_identifier",
+        )
+        set_task_state(success, self.bq_delete_records.__name__, release)
+
+    def add_new_dataset_release(self, release: OrcidRelease, **kwargs) -> None:
+        """Adds release information to API."""
+        self.skipcheck(release)
+
+        dataset_release = DatasetRelease(
+            dag_id=self.dag_id,
+            dataset_id=self.api_dataset_id,
+            dag_run_id=release.run_id,
+            changefile_start_date=release.start_date,
+            changefile_end_date=release.end_date,
+            extra={"latest_modified_record_date": latest_modified_record_date(release.manifest_file_path)},
+        )
+        api = make_observatory_api(observatory_api_conn_id=self.observatory_api_conn_id)
+        api.post_dataset_release(dataset_release)
 
+    def cleanup(self, release: OrcidRelease, **kwargs) -> None:
+        """Delete all files, folders and XComs associated with this release."""
 
-def transform_item(item):
-    """Transform a single Crossref Metadata JSON value.
+        cleanup(dag_id=self.dag_id, execution_date=kwargs["logical_date"], workflow_folder=release.workflow_folder)
 
-    :param item: a JSON value.
-    :return: the transformed item.
+
+def create_manifest_batch(bucket: str, dir_name: str, save_path: str, reference_date: pendulum.DateTime) -> str:
+    """Create a manifest of all the modified files in the orcid bucket for a subfolder
+
+    :param bucket: The name of the bucket
+    :param dir_name: The name of the subfolder. e.g. 025 or 94X - see orcid_directories()
+    :param save_path: The path to save the manifest to
+    :param reference_date: The date to use as a reference for the manifest
+    :return: The path to the manifest
     """
+    logging.info(f"Creating manifest for {dir_name}")
+    if not save_path:
+        save_path = f"{dir_name}_manifest.csv"
+    blobs = gcs_list_blobs(bucket, prefix=dir_name)
+    manifest = []
+    for blob in blobs:
+        if pendulum.instance(blob.updated) > reference_date:
+            # Extract the orcid ID from the blob name
+            orcid = re.search(ORCID_REGEX, blob.name).group(0)
+            manifest.append((blob.name, orcid, blob.size, blob.updated))
+
+    if manifest:
+        with open(save_path, "w") as f:
+            f.write("\n".join([",".join([str(i) for i in row]) for row in manifest]))
+
+    return save_path
+
+
+def latest_modified_record_date(manifest_file_path) -> pendulum.DateTime:
+    """Reads the manifest file and finds the most recent date of modification for the records
+
+    :param manifest_file_path: the path to the manifest file
+    :return: the most recent date of modification for the records
+    """
+    if not os.path.exists(manifest_file_path):
+        return None
+
+    with open(manifest_file_path, "r") as f:
+        reader = csv.DictReader(f)
+        modified_dates = sorted(pendulum.parse(row["updated"]) for row in reader)
+    return modified_dates[-1]
+
+
+def orcid_directories() -> List[str]:
+    """Create a list of all the possible ORCID directories
 
-    if isinstance(item, dict):
-        new = {}
-        for k, v in item.items():
-            # Replace hyphens with underscores for BigQuery compatibility
-            k = k.replace("-", "_")
-
-            # Get inner array for date parts
-            if k == "date_created" or k == "last_modified":
-                try:
-                    datetime.strptime(v, "%Y-%m-%dT%H:%M:%SZ")
-                except ValueError:
-                    v = ""
-
-            new[k] = transform_item(v)
-        return new
-    elif isinstance(item, list):
-        return [transform_item(i) for i in item]
+    :return: A list of all the possible ORCID directories
+    """
+    n_1_2 = [str(i) for i in range(10)]  # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
+    n_3 = n_1_2.copy() + ["X"]  # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, X
+    combinations = list(itertools.product(n_1_2, n_1_2, n_3))  # Creates the 000 to 99X directory structure
+    return ["".join(i) for i in combinations]
+
+
+def gcs_list_blobs(bucket_name: str, prefix: str = None) -> List[storage.Blob]:
+    """List blobs in a bucket using a gcs_uri.
+
+    :param bucket_name: The name of the bucket
+    :param prefix: The prefix to filter by
+    :return: The list of blobs in the bucket matching the prefix
+    """
+    storage_client = storage.Client()
+    return list(storage_client.list_blobs(bucket_name, prefix=prefix))
+
+
+def transform_orcid_record(record_path: str) -> Union[Dict, str]:
+    """Transform a single ORCID file/record.
+    The xml file is turned into a dictionary, a record should have either a valid 'record' section or an 'error' section.
+    The keys of the dictionary are slightly changed so they are valid BigQuery fields.
+    If the record is valid, it is returned. If it is an error, the ORCID ID is returned.
+
+    :param record_path: The path to the file with the ORCID record.
+    :return: The transformed ORCID record. If the record is an error, the ORCID ID is returned.
+    """
+    # Get the orcid from the record path
+    expected_orcid = re.search(ORCID_REGEX, record_path).group(0)
+
+    with open(record_path, "r") as f:
+        orcid_dict = xmltodict.parse(f.read())
+    orcid_record = orcid_dict.get("record:record")
+
+    # Some records do not have a 'record', but only 'error'. We return the ORCID ID in this case.
+    if not orcid_record:
+        orcid_record = orcid_dict.get("error:error")
+        if not orcid_record:
+            raise AirflowException(f"Key error for file: {record_path}")
+        return expected_orcid
+
+    # Check that the ORCID in the file name matches the ORCID in the record
+    assert (
+        orcid_record["common:orcid-identifier"]["common:path"] == expected_orcid
+    ), f"Expected ORCID {expected_orcid} does not match ORCID in record {orcid_record['common:orcid-identifier']['common:path']}"
+
+    # Transform the keys of the dictionary so they are valid BigQuery fields
+    orcid_record = change_keys(orcid_record, convert)
+
+    return orcid_record
+
+
+def change_keys(obj, convert):
+    """Recursively goes through the dictionary obj and replaces keys with the convert function.
+
+    :param obj: The dictionary value, can be object of any type
+    :param convert: The convert function.
+    :return: The transformed object.
+    """
+    if isinstance(obj, (str, int, float)):
+        return obj
+    if isinstance(obj, dict):
+        new = obj.__class__()
+        for k, v in list(obj.items()):
+            if k.startswith("@xmlns"):
+                continue
+            new[convert(k)] = change_keys(v, convert)
+    elif isinstance(obj, (list, set, tuple)):
+        new = obj.__class__(change_keys(v, convert) for v in obj)
     else:
-        return item
+        return obj
+    return new
+
+
+def convert(k: str) -> str:
+    """Convert key of dictionary to valid BQ key.
+
+    :param k: Key
+    :return: The converted key
+    """
+    if len(k.split(":")) > 1:
+        k = k.split(":")[1]
+    if k.startswith("@") or k.startswith("#"):
+        k = k[1:]
+    k = k.replace("-", "_")
+    return k
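
Illustrative note (not part of the patch): a minimal sketch of what the change_keys/convert helpers above do to an xmltodict-parsed ORCID record. The sample record is hypothetical; the printed result assumes the functions as defined in this file.

# Namespace prefixes are stripped, @xmlns attributes are dropped, and keys become BigQuery-safe.
record = {
    "common:orcid-identifier": {
        "common:uri": "https://orcid.org/0000-0002-1825-0097",
        "common:path": "0000-0002-1825-0097",
        "@xmlns:common": "http://www.orcid.org/ns/common",
    },
    "history:creation-method": {"#text": "Direct"},
}
print(change_keys(record, convert))
# {'orcid_identifier': {'uri': 'https://orcid.org/0000-0002-1825-0097', 'path': '0000-0002-1825-0097'},
#  'creation_method': {'text': 'Direct'}}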