Commit

manifest update - not finished
keegansmith21 committed Jul 4, 2023
1 parent d5b448d commit 35ac8bc
Showing 1 changed file with 82 additions and 49 deletions.
131 changes: 82 additions & 49 deletions academic_observatory_workflows/workflows/orcid_telescope.py
@@ -20,7 +20,7 @@
import datetime
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import timedelta
from typing import List, Dict, Tuple, Union
import itertools
@@ -35,6 +35,7 @@
from airflow.hooks.base import BaseHook
from google.cloud import bigquery, storage
from google.cloud.bigquery import SourceFormat
from tenacity import retry, wait_exponential, stop_after_attempt


from academic_observatory_workflows.config import schema_folder as default_schema_folder, Tag
@@ -85,7 +86,7 @@ def __init__(
bq_delete_table_name: str,
start_date: pendulum.DateTime,
end_date: pendulum.DateTime,
prev_end_date: pendulum.DateTime,
prev_release_end: pendulum.DateTime,
prev_latest_modified_record: pendulum.DateTime,
is_first_run: bool,
):
@@ -115,15 +116,17 @@ def __init__(
self.bq_main_table_name = bq_main_table_name
self.bq_upsert_table_name = bq_upsert_table_name
self.bq_delete_table_name = bq_delete_table_name
self.prev_end_date = prev_end_date
self.prev_release_end = prev_release_end
self.prev_latest_modified_record = prev_latest_modified_record
self.is_first_run = is_first_run

# Files/folders
self.manifest_file_path = os.path.join(self.workflow_folder, "manifest.csv")
self.local_manifest_file_path = os.path.join(self.workflow_folder, "local_manifest.csv")
self.delete_file_path = os.path.join(self.transform_folder, "delete.jsonl.gz")
self.batch_folder = os.path.join(self.workflow_folder, "batch")
self.transfer_manifest_folder = os.path.join(self.workflow_folder, "transfer_manifest")
os.makedirs(self.batch_folder, exist_ok=True)
os.makedirs(self.transfer_manifest_folder, exist_ok=True)

# Table names and URIs
self.table_uri = gcs_blob_uri(
@@ -133,7 +136,7 @@ def __init__(
self.bq_upsert_table_id = bq_table_id(cloud_workspace.project_id, bq_dataset_id, bq_upsert_table_name)
self.bq_delete_table_id = bq_table_id(cloud_workspace.project_id, bq_dataset_id, bq_delete_table_name)
self.bq_snapshot_table_id = bq_sharded_table_id(
cloud_workspace.project_id, bq_dataset_id, bq_main_table_name, prev_end_date
cloud_workspace.project_id, bq_dataset_id, bq_main_table_name, prev_release_end
)

@property
@@ -162,12 +165,14 @@ def __init__(
bq_dataset_id: str = "orcid",
bq_main_table_name: str = "orcid",
bq_upsert_table_name: str = "orcid_upsert",
dataset_description: str = "The ORCID dataset",
bq_delete_table_name: str = "orcid_delete",
dataset_description: str = "The ORCID dataset and supporting tables",
table_description: str = "The ORCID dataset",
snapshot_expiry_days: int = 31,
schema_file_path: str = os.path.join(default_schema_folder(), "orcid"),
transfer_attempts: int = 5,
batch_size: int = 100000,
batch_size: int = 20000,
transfer_size: int = 250000,
max_workers: int = os.cpu_count(),
api_dataset_id: str = "orcid",
observatory_api_conn_id: str = AirflowConns.OBSERVATORY_API,
@@ -194,10 +199,12 @@ def __init__(
self.bq_dataset_id = bq_dataset_id
self.bq_main_table_name = bq_main_table_name
self.bq_upsert_table_name = bq_upsert_table_name
self.bq_delete_table_name = bq_delete_table_name
self.api_dataset_id = api_dataset_id
self.schema_file_path = schema_file_path
self.transfer_attempts = transfer_attempts
self.batch_size = batch_size
self.transfer_size = transfer_size
self.dataset_description = dataset_description
self.table_description = table_description
self.snapshot_expiry_days = snapshot_expiry_days
@@ -223,7 +230,8 @@ def __init__(
self.add_task(self.bq_create_main_table_snapshot)

# Scour the data for updates
self.add_task(self.create_manifest)
self.add_task(self.create_local_manifest)
self.add_task(self.create_gcs_transfer_manifests)
self.add_task(self.create_batches)

# Download and transform updated files
@@ -253,11 +261,6 @@ def aws_orcid_key(self) -> Tuple[str, str]:
connection = BaseHook.get_connection(self.aws_orcid_conn_id)
return connection.login, connection.password

def skipcheck(self, release):
# This should never realistically happen unless ORCID goes down for a week
if len(release.batch_files) == 0:
raise AirflowSkipException("No files found to process. Skipping remaining tasks.")

def make_release(self, **kwargs) -> OrcidRelease:
dag_run = kwargs["dag_run"]
is_first_run = is_first_dag_run(dag_run)
@@ -304,7 +307,7 @@ def create_datasets(self, release: OrcidRelease, **kwargs) -> None:
description=self.dataset_description,
)

def transfer_orcid(self):
def transfer_orcid(self, release: OrcidRelease, **kwargs):
"""Sync files from AWS bucket to Google Cloud bucket."""
for i in range(self.transfer_attempts):
logging.info(f"Beginning AWS to GCP transfer attempt no. {i+1}")
@@ -330,17 +333,16 @@ def bq_create_main_table_snapshot(self, release: OrcidRelease, **kwargs):
if something goes wrong. The snapshot expires after self.snapshot_expiry_days."""

if release.is_first_run:
raise AirflowSkipException(
f"bq_create_main_table_snapshots: skipping as snapshots are not created on the first run"
)
logging.info(f"bq_create_main_table_snapshots: skipping as snapshots are not created on the first run")
return

expiry_date = pendulum.now().add(days=self.snapshot_expiry_days)
success = bq_snapshot(
src_table_id=self.bq_main_table_id, dst_table_id=release.bq_snapshot_table_id, expiry_date=expiry_date
)
set_task_state(success, self.bq_create_main_table_snapshot.__name__, release)

def create_manifest(self, release: OrcidRelease, **kwargs):
def create_local_manifest(self, release: OrcidRelease, **kwargs):
"""Create a manifest of all the modified files in the orcid bucket."""
logging.info("Creating manifest")
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
@@ -359,25 +361,53 @@ def create_manifest(self, release: OrcidRelease, **kwargs):
manifest_files = [future.result() for future in futures]

logging.info("Joining manifest files")
with open(os.path.join("manifest", "manifest.csv"), "w") as f:
with open(release.local_manifest_file_path, "w") as f:
f.write("path,orcid,size,updated")
for batch_file in manifest_files:
with open(batch_file, "r") as bf:
f.write("\n")
for line in bf:
f.write(line)

def create_gcs_transfer_manifests(self, release: OrcidRelease, **kwargs):
"""Creates manifest file(s) for the GCS transfer job(s). One file per 'transfer_size' records."""
# Count the lines in the manifest file
with open(release.local_manifest_file_path, "r") as f:
f.readline() # Skip the header
num_lines = sum(1 for _ in f)
manifests = math.ceil(num_lines / self.transfer_size)
if manifests == 0:
# This should never realistically happen unless ORCID goes down for a week
raise AirflowSkipException("No files found to process. Skipping remaining tasks.")

lines_read = 0
with open(release.local_manifest_file_path, "r") as f:
f.readline() # Skip the header
for man_n in range(manifests):
blobs = []
# Read 'transfer_size' lines
for _ in range(self.transfer_size):
line = f.readline()
if not line: # End of file gives empty string
break
lines_read += 1
blobs.append(line.strip().split(",")[0])
# Write the manifest file
with open(os.path.join(release.transfer_manifest_folder, f"manifest_{man_n}.csv"), "w") as bf:
bf.write(",\n".join(blobs))
logging.info(f"Read {lines_read} of {num_lines} lines in the manifest file")
assert lines_read == num_lines, "Not all lines in the manifest file were read. Aborting."
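(The chunked read above can also be expressed with itertools.islice; itertools is already imported at the top of this module. A sketch under the assumption that the first CSV column is the blob path; the function name and arguments below are hypothetical, mirroring the attributes used above.)

import itertools
import os

def write_transfer_manifests(local_manifest_path: str, out_folder: str, transfer_size: int) -> int:
    """Split the local manifest into chunks of transfer_size blob names, one manifest file per chunk."""
    n_manifests = 0
    with open(local_manifest_path, "r") as f:
        f.readline()  # skip the header: path,orcid,size,updated
        while True:
            chunk = list(itertools.islice(f, transfer_size))  # next transfer_size lines, fewer at end of file
            if not chunk:
                break
            blobs = [line.strip().split(",")[0] for line in chunk]
            with open(os.path.join(out_folder, f"manifest_{n_manifests}.csv"), "w") as out:
                out.write("\n".join(blobs))  # one blob path per line
            n_manifests += 1
    return n_manifests
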

def create_batches(self, release: OrcidRelease, **kwargs):
"""Create batches of files to be processed"""
# Count the lines in the manifest file
with open(release.manifest_file_path, "r") as f:
with open(release.local_manifest_file_path, "r") as f:
f.readline() # Skip the header
num_lines = sum(1 for _ in f)
batches = math.ceil(num_lines / self.batch_size)
if batches == 0:
raise AirflowSkipException("No files found to process. Skipping remaining tasks.")

with open(release.manifest_file_path, "r") as f:
lines_read = 0
with open(release.local_manifest_file_path, "r") as f:
f.readline() # Skip the header
for batch_n in range(batches):
batch_blobs = []
@@ -386,37 +416,43 @@ def create_batches(self, release: OrcidRelease, **kwargs):
line = f.readline()
if not line: # End of file gives empty string
break
lines_read += 1
batch_blobs.append(line.strip().split(",")[0])
# Write the batch file
with open(os.path.join(release.batch_folder, f"batch_{batch_n}.txt"), "w") as bf:
bf.write("\n".join(batch_blobs))
logging.info(f"Read {lines_read} of {num_lines} lines in the manifest file")
assert lines_read == num_lines, "Not all lines in the manifest file were read. Aborting."

def download(self, release: OrcidRelease, **kwargs):
"""Reads the batch files and downloads the files from the gcs bucket."""
logging.info(f"Number of batches: {len(release.batch_files)}")
self.skipcheck(release)

release.make_download_folders()
for batch_file in release.batch_files:
for i, batch_file in enumerate(release.batch_files):
logging.info(f"Downloading batch {i+1} of {len(release.batch_files)}")
with open(batch_file, "r") as f:
orcid_blobs = [line.strip() for line in f.readlines()]

# Download the blobs in parallel
with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
with ThreadPoolExecutor(max_workers=os.cpu_count() * 2) as executor:
futures = []
logging.disable(logging.INFO) # Turn off logging to avoid spamming the logs
for blob in orcid_blobs:
file_path = os.path.join(release.download_folder, blob)
# Get path components of the blob (path / to / summaries / 123 / blob_name.xml)
components = os.path.normpath(blob).split(os.sep)
file_path = os.path.join(release.download_folder, components[-2], components[-1])
future = executor.submit(
gcs_download_blob, bucket_name=self.orcid_bucket, blob_name=blob, file_path=file_path
)
futures.append(future)
for future in futures:
for future in as_completed(futures):
future.result()
logging.disable(logging.NOTSET)
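(For reference, the path handling in the loop above maps a bucket blob name onto a <directory>/<file> layout under the download folder. A small worked example with made-up values.)

import os

blob = "v3.0/summaries/678X/0000-0001-2345-678X.xml"  # hypothetical blob name, not taken from the real bucket
download_folder = "/tmp/orcid/download"  # hypothetical release.download_folder

components = os.path.normpath(blob).split(os.sep)  # ['v3.0', 'summaries', '678X', '0000-0001-2345-678X.xml']
file_path = os.path.join(download_folder, components[-2], components[-1])
# -> /tmp/orcid/download/678X/0000-0001-2345-678X.xml
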

def transform(self, release: OrcidRelease, **kwargs):
"""Transforms the downloaded files into serveral bigquery-compatible .jsonl files"""
logging.info(f"Number of batches: {len(release.batch_files)}")
self.skipcheck(release)
delete_paths = []
for i, batch_file in enumerate(os.listdir(release.batch_folder)):
with open(os.path.join(release.batch_folder, batch_file), "r") as f:
@@ -446,7 +482,6 @@ def transform(self, release: OrcidRelease, **kwargs):

def upload_transformed(self, release: OrcidRelease, **kwargs):
"""Uploads the transformed files to the transform bucket."""
self.skipcheck(release)
success = gcs_upload_files(
bucket_name=self.cloud_workspace.transform_bucket, file_paths=release.transform_files
)
@@ -455,11 +490,10 @@ def bq_load_main_table(self, release: OrcidRelease, **kwargs):
def bq_load_main_table(self, release: OrcidRelease, **kwargs):
"""Load the main table."""
if not release.is_first_run:
raise AirflowSkipException(
f"bq_load_main_table: skipping as the main table is only created on the first run"
)
assert len(release.batch_files), "No batch files found. Batch files must exist before loading the main table."
logging.info(f"bq_load_main_table: skipping as the main table is only created on the first run")
return

assert len(release.batch_files), "No batch files found. Batch files must exist before loading the main table."
success = bq_load_table(
uri=release.table_uri,
table_id=release.main_table_id,
@@ -473,8 +507,8 @@ def bq_load_main_table(self, release: OrcidRelease, **kwargs):
def bq_load_upsert_table(self, release: OrcidRelease, **kwargs):
"""Load the upsert table into bigquery"""
if release.is_first_run:
raise AirflowSkipException(f"bq_load_upsert_table: skipping as no records are upserted on the first run")
self.skipcheck(release)
logging.info(f"bq_load_upsert_table: skipping as no records are upserted on the first run")
return

success = bq_load_table(
uri=release.table_uri,
@@ -489,10 +523,11 @@ def bq_load_upsert_table(self, release: OrcidRelease, **kwargs):
def bq_load_delete_table(self, release: OrcidRelease, **kwargs):
"""Load the delete table into bigquery"""
if release.is_first_run:
raise AirflowSkipException(f"bq_load_delete_table: skipping as no records are deleted on the first run")
self.skipcheck(release)
logging.info(f"bq_load_delete_table: skipping as no records are deleted on the first run")
return
if not os.path.exists(release.delete_file_path):
raise AirflowSkipException(f"bq_load_delete_table: skipping as no delete file exists")
logging.info(f"bq_load_delete_table: skipping as no delete file exists")
return

success = bq_load_table(
uri=release.table_uri,
@@ -507,8 +542,8 @@ def bq_load_delete_table(self, release: OrcidRelease, **kwargs):
def bq_upsert_records(self, release: OrcidRelease, **kwargs):
"""Upsert the records from the upserts table into the main table."""
if release.is_first_run:
raise AirflowSkipException("bq_upsert_records: skipping as no records are upserted on the first run")
self.skipcheck(release)
logging.info("bq_upsert_records: skipping as no records are upserted on the first run")
return

success = bq_upsert_records(
main_table_id=release.bq_main_table_id,
@@ -520,10 +555,11 @@ def bq_upsert_records(self, release: OrcidRelease, **kwargs):
def bq_delete_records(self, release: OrcidRelease, **kwargs):
"""Delete the records in the delete table from the main table."""
if release.is_first_run:
raise AirflowSkipException("bq_delete_records: skipping as no records are deleted on the first run")
self.skipcheck(release)
logging.info("bq_delete_records: skipping as no records are deleted on the first run")
return
if not os.path.exists(release.delete_file_path):
raise AirflowSkipException(f"bq_delete_records: skipping as no delete file exists")
logging.info(f"bq_delete_records: skipping as no delete file exists")
return

success = bq_delete_records(
main_table_id=release.bq_main_table_id,
@@ -536,22 +572,19 @@ def bq_delete_records(self, release: OrcidRelease, **kwargs):

def add_new_dataset_release(self, release: OrcidRelease, **kwargs) -> None:
"""Adds release information to API."""
self.skipcheck(release)

dataset_release = DatasetRelease(
dag_id=self.dag_id,
dataset_id=self.api_dataset_id,
dag_run_id=release.run_id,
changefile_start_date=release.start_date,
changefile_end_date=release.end_date,
extra={"latest_modified_record_date": latest_modified_record_date(release.manifest_file_path)},
extra={"latest_modified_record_date": latest_modified_record_date(release.local_manifest_file_path)},
)
api = make_observatory_api(observatory_api_conn_id=self.observatory_api_conn_id)
api.post_dataset_release(dataset_release)
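(latest_modified_record_date is defined elsewhere in this module and not shown in this diff. A hedged sketch of what such a helper might do, assuming the manifest's updated column holds ISO-8601 timestamps; the name below is marked as a sketch to avoid implying it is the real implementation.)

import csv
import pendulum

def latest_modified_record_date_sketch(manifest_file_path: str) -> pendulum.DateTime:
    """Return the most recent 'updated' timestamp in a path,orcid,size,updated manifest."""
    latest = None
    with open(manifest_file_path, newline="") as f:
        for row in csv.DictReader(f):  # uses the header row for field names
            updated = pendulum.parse(row["updated"])
            if latest is None or updated > latest:
                latest = updated
    return latest
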

def cleanup(self, release: OrcidRelease, **kwargs) -> None:
"""Delete all files, folders and XComs associated with this release."""

cleanup(dag_id=self.dag_id, execution_date=kwargs["logical_date"], workflow_folder=release.workflow_folder)


@@ -602,7 +635,7 @@ def orcid_directories() -> List[str]:
:return: A list of all the possible ORCID directories
"""
n_1_2 = [str(i) for i in range(2)] # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
n_1_2 = [str(i) for i in range(10)] # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
n_3 = n_1_2.copy() + ["X"] # 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, X
combinations = list(itertools.product(n_1_2, n_1_2, n_3)) # Creates the 000 to 99X directory structure
return ["".join(i) for i in combinations]