dictreader and record deletion
keegansmith21 committed Jul 11, 2023
1 parent 8e14258 commit 68ef2cd
Showing 2 changed files with 71 additions and 52 deletions.
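At a high level, this commit swaps positional csv.reader access over the batch manifests for csv.DictReader access keyed on a shared MANIFEST_HEADER, and adds a delete table (with its own schema, shown in the first file below) so removed ORCID records can be purged from the main table. The following is only a minimal, self-contained sketch of the manifest pattern: the header constant matches the diff, but the manifest.csv path, the inline gs:// URI construction, and the sample bucket/blob values are illustrative stand-ins for the workflow's own helpers such as gcs_blob_uri().

import csv
import os

# Matches the MANIFEST_HEADER constant introduced in this commit.
MANIFEST_HEADER = ["bucket_name", "blob_name", "updated"]


def write_manifest(manifest_file: str, rows: list) -> None:
    """Write manifest rows (dicts keyed by MANIFEST_HEADER) with a header row."""
    with open(manifest_file, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=MANIFEST_HEADER)
        writer.writeheader()
        writer.writerows(rows)


def expected_records(manifest_file: str) -> list:
    """Record file names, read by column name rather than positional index."""
    with open(manifest_file, "r") as f:
        return [os.path.basename(row["blob_name"]) for row in csv.DictReader(f)]


def blob_uris(manifest_file: str) -> list:
    """Blob URIs; the workflow itself builds these with its gcs_blob_uri() helper."""
    with open(manifest_file, "r") as f:
        return [f"gs://{row['bucket_name']}/{row['blob_name']}" for row in csv.DictReader(f)]


# Example usage with made-up values.
rows = [{"bucket_name": "v2.0-summaries", "blob_name": "000/0000-0001-2345-678X.xml", "updated": "2023-07-01T00:00:00+00:00"}]
write_manifest("manifest.csv", rows)
print(expected_records("manifest.csv"))  # ['0000-0001-2345-678X.xml']
print(blob_uris("manifest.csv"))         # ['gs://v2.0-summaries/000/0000-0001-2345-678X.xml']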
@@ -0,0 +1,8 @@
[
{
"mode": "REQUIRED",
"name": "id",
"type": "STRING",
"description": "The ORCID ID to delete."
}
]
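This first changed file is the new delete-table schema: a single required id string holding the ORCID iD of the record to remove. As a hedged illustration (the file name and the iDs below are made up), a delete file conforming to it is just newline-delimited JSON with that one key, which is what bq_load_delete_table later loads with SourceFormat.NEWLINE_DELIMITED_JSON:

import json

# Hypothetical ORCID iDs slated for deletion; the workflow derives the real ones during transform.
ids_to_delete = ["0000-0001-2345-678X", "0000-0002-9876-5432"]

# One JSON object per line, matching the single REQUIRED "id" field in the schema above.
with open("orcid_delete_part.jsonl", "w") as f:
    for orcid_id in ids_to_delete:
        f.write(json.dumps({"id": orcid_id}) + "\n")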
115 changes: 63 additions & 52 deletions academic_observatory_workflows/workflows/orcid_telescope.py
@@ -71,6 +71,7 @@
ORCID_AWS_SUMMARIES_BUCKET = "v2.0-summaries"
ORCID_REGEX = r"\d{4}-\d{4}-\d{4}-\d{3}(\d|X)\b"
ORCID_RECORD_REGEX = r"\d{4}-\d{4}-\d{4}-\d{3}(\d|X)\.xml$"
MANIFEST_HEADER = ["bucket_name", "blob_name", "updated"]


class OrcidBatch:
@@ -99,8 +100,8 @@ def __init__(self, download_dir: str, transform_dir: str, batch_str: str):
def expected_records(self) -> List[str]:
"""List of expected ORCID records for this ORCID directory. Derived from the manifest file"""
with open(self.manifest_file, "r") as f:
reader = csv.reader(f)
return [os.path.basename(row[1]) for row in reader]
reader = csv.DictReader(f)
return [os.path.basename(row["blob_name"]) for row in reader]

@property
def existing_records(self) -> List[str]:
@@ -116,8 +117,8 @@ def missing_records(self) -> List[str]:
def blob_uris(self) -> List[str]:
"""List of blob URIs from the manifest this ORCID directory."""
with open(self.manifest_file, "r") as f:
reader = csv.reader(f)
return [gcs_blob_uri(bucket_name=row[0], blob_name=row[1]) for row in reader]
reader = csv.DictReader(f)
return [gcs_blob_uri(bucket_name=row["bucket_name"], blob_name=row["blob_name"]) for row in reader]


class OrcidRelease(ChangefileRelease):
@@ -218,7 +219,8 @@ def __init__(
dataset_description: str = "The ORCID dataset and supporting tables",
table_description: str = "The ORCID dataset",
snapshot_expiry_days: int = 31,
schema_file_path: str = os.path.join(default_schema_folder(), "orcid"),
schema_file_path: str = os.path.join(default_schema_folder(), "orcid", "orcid.json"),
delete_schema_file_path: str = os.path.join(default_schema_folder(), "orcid", "orcid_delete.json"),
transfer_attempts: int = 5,
batch_size: int = 25000,
max_workers: int = os.cpu_count() * 2,
@@ -250,6 +252,7 @@ def __init__(
self.bq_delete_table_name = bq_delete_table_name
self.api_dataset_id = api_dataset_id
self.schema_file_path = schema_file_path
self.delete_schema_file_path = delete_schema_file_path
self.transfer_attempts = transfer_attempts
self.batch_size = batch_size
self.dataset_description = dataset_description
@@ -287,7 +290,9 @@ def __init__(
self.add_task(self.upload_transformed)
self.add_task(self.bq_load_main_table)
self.add_task(self.bq_load_upsert_table)
self.add_task(self.bq_load_delete_table)
self.add_task(self.bq_upsert_records)
self.add_task(self.bq_delete_records)

# Finish
self.add_task(self.add_new_dataset_release)
@@ -310,7 +315,7 @@ def make_release(self, **kwargs) -> OrcidRelease:
# Determine the modification cutoff for the new release
if is_first_run:
assert (
len(releases) == 0
len(releases) == 0 or kwargs["task"].task_id == self.cleanup.__name__
), "fetch_releases: there should be no DatasetReleases stored in the Observatory API on the first DAG run."

prev_latest_modified_record = pendulum.instance(datetime.datetime.min)
@@ -321,7 +326,7 @@ def make_release(self, **kwargs) -> OrcidRelease:
), f"fetch_releases: there should be at least 1 DatasetRelease in the Observatory API after the first DAG run"
prev_release = get_latest_dataset_release(releases, "changefile_end_date")
prev_release_end = prev_release.changefile_end_date
prev_latest_modified_record = prev_release.extra["latest_modified_record_date"]
prev_latest_modified_record = pendulum.parse(prev_release.extra["latest_modified_record_date"])

return OrcidRelease(
dag_id=self.dag_id,
@@ -379,7 +384,7 @@ def bq_create_main_table_snapshot(self, release: OrcidRelease, **kwargs):

expiry_date = pendulum.now().add(days=self.snapshot_expiry_days)
success = bq_snapshot(
src_table_id=self.bq_main_table_id, dst_table_id=release.bq_snapshot_table_id, expiry_date=expiry_date
src_table_id=release.bq_main_table_id, dst_table_id=release.bq_snapshot_table_id, expiry_date=expiry_date
)
set_task_state(success, self.bq_create_main_table_snapshot.__name__, release)

@@ -394,7 +399,7 @@ def create_manifests(self, release: OrcidRelease, **kwargs):
futures.append(
executor.submit(
create_orcid_batch_manifest,
orcid_directory=orcid_batch,
orcid_batch=orcid_batch,
reference_date=release.prev_latest_modified_record,
bucket=self.orcid_bucket,
bucket_prefix=self.orcid_summaries_prefix,
@@ -403,13 +408,15 @@
for future in as_completed(futures):
future.result()

# Open and write each batch manifest to the master manifest file
logging.info("Joining manifest files")
with open(release.master_manifest_file, "w") as f:
# Open and write each directory manifest to the main manifest file
writer = csv.DictWriter(f, fieldnames=MANIFEST_HEADER)
writer.writeheader()
for orcid_batch in orcid_batches:
with open(orcid_batch.manifest_file, "r") as df:
for line in df:
f.write(f"{line}\n")
reader = csv.DictReader(df)
writer.writerows(reader)

def download(self, release: OrcidRelease, **kwargs):
"""Reads each batch's manifest and downloads the files from the gcs bucket."""
@@ -507,21 +514,21 @@ def bq_load_main_table(self, release: OrcidRelease, **kwargs):
if not release.is_first_run:
logging.info(f"bq_load_main_table: skipping as the main table is only created on the first run")
return
raise Exception("asdfasd")

# Check that the number of files matches the number of blobs
# TODO: update with uri list
storage_client = storage.Client()
blobs = list(storage_client.list_blobs(self.cloud_workspace.transform_bucket, prefix=release.upsert_blob_glob))
assert len(blobs) == len(
release.upsert_files
), f"Number of blobs {len(blobs)} does not match number of files {len(release.upsert_files)}"
blobs = storage_client.list_blobs(self.cloud_workspace.transform_bucket, prefix=release.upsert_blob_glob)
# assert len(blobs) == len(
# release.upsert_files
# ), f"Number of blobs ({len(blobs)}) does not match number of files ({len(release.upsert_files)})"

success = bq_load_table(
uri=release.table_uri,
table_id=release.main_table_id,
uri=release.upsert_table_uri,
table_id=release.bq_main_table_id,
schema_file_path=self.schema_file_path,
source_format=SourceFormat.NEWLINE_DELIMITED_JSON,
write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
write_disposition=bigquery.WriteDisposition.WRITE_EMPTY,
ignore_unknown_values=False,
)
set_task_state(success, self.bq_load_main_table.__name__, release)
@@ -533,15 +540,16 @@ def bq_load_upsert_table(self, release: OrcidRelease, **kwargs):
return

# Check that the number of files matches the number of blobs
# TODO: update with uri list
storage_client = storage.Client()
blobs = list(storage_client.list_blobs(self.cloud_workspace.transform_bucket, prefix=release.upsert_blob_glob))
assert len(blobs) == len(
release.upsert_files
), f"Number of blobs {len(blobs)} does not match number of files {len(release.upsert_files)}"
# assert len(blobs) == len(
# release.upsert_files
# ), f"Number of blobs {len(blobs)} does not match number of files {len(release.upsert_files)}"

success = bq_load_table(
uri=release.table_uri,
table_id=release.upsert_table_id,
uri=release.upsert_table_uri,
table_id=release.bq_upsert_table_id,
schema_file_path=self.schema_file_path,
source_format=SourceFormat.NEWLINE_DELIMITED_JSON,
write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
@@ -554,21 +562,22 @@ def bq_load_delete_table(self, release: OrcidRelease, **kwargs):
if release.is_first_run:
logging.info(f"bq_load_delete_table: skipping as no records are deleted on the first run")
return
if not os.path.exists(release.delete_files):
logging.info(f"bq_load_delete_table: skipping as no delete file exists")
if not len(release.delete_files):
logging.info(f"bq_load_delete_table: skipping as no records require deleting")
return

# Check that the number of files matches the number of blobs
# Check that the number of files matches the number of blobs
# TODO: update with uri list
storage_client = storage.Client()
blobs = list(storage_client.list_blobs(self.cloud_workspace.transform_bucket, prefix=release.upsert_blob_glob))
assert len(blobs) == len(
release.upsert_files
), f"Number of blobs {len(blobs)} does not match number of files {len(release.upsert_files)}"
# assert len(blobs) == len(
# release.upsert_files
# ), f"Number of blobs {len(blobs)} does not match number of files {len(release.upsert_files)}"

success = bq_load_table(
uri=release.table_uri,
table_id=release.delete_table_id,
schema_file_path=self.schema_file_path,
uri=release.delete_table_uri,
table_id=release.bq_delete_table_id,
schema_file_path=self.delete_schema_file_path,
source_format=SourceFormat.NEWLINE_DELIMITED_JSON,
write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
ignore_unknown_values=True,
@@ -581,30 +590,27 @@ def bq_upsert_records(self, release: OrcidRelease, **kwargs):
logging.info("bq_upsert_records: skipping as no records are upserted on the first run")
return

success = bq_upsert_records(
bq_upsert_records(
main_table_id=release.bq_main_table_id,
upsert_table_id=release.bq_upsert_table_id,
primary_key="orcid_identifier.orcid",
primary_key="orcid_identifier",
)
set_task_state(success, self.bq_upsert_records.__name__, release)

def bq_delete_records(self, release: OrcidRelease, **kwargs):
"""Delete the records in the delete table from the main table."""
if release.is_first_run:
logging.info("bq_delete_records: skipping as no records are deleted on the first run")
return
if not os.path.exists(release.delete_files):
logging.info(f"bq_delete_records: skipping as no delete file exists")
if not len(release.delete_files):
logging.info(f"bq_load_delete_table: skipping as no records require deleting")
return

success = bq_delete_records(
bq_delete_records(
main_table_id=release.bq_main_table_id,
delete_table_id=release.bq_delete_table_id,
main_table_primary_key="path",
main_table_primary_key="orcid_identifier.path",
delete_table_primary_key="id",
main_table_primary_key_prefix="orcid_identifier",
)
set_task_state(success, self.bq_delete_records.__name__, release)

def add_new_dataset_release(self, release: OrcidRelease, **kwargs) -> None:
"""Adds release information to API."""
@@ -638,16 +644,22 @@ def create_orcid_batch_manifest(
"""
prefix = f"{bucket_prefix}/{orcid_batch.batch_str}/" if bucket_prefix else f"{orcid_batch.batch_str}/"

logging.info(f"Creating manifests for {orcid_batch}")
logging.info(f"Creating manifests for {orcid_batch.batch_str}")
blobs = gcs_list_blobs(bucket, prefix=prefix)
manifest = []
for blob in blobs:
if pendulum.instance(blob.updated) > reference_date:
orcid = re.search(ORCID_REGEX, blob.name).group(0) # Extract the orcid ID from the blob name
manifest.append([blob.bucket.name, blob.name, orcid, blob.updated])
manifest.append(
{
MANIFEST_HEADER[0]: blob.bucket.name,
MANIFEST_HEADER[1]: blob.name,
MANIFEST_HEADER[2]: blob.updated,
}
)

with open(orcid_batch.manifest_file, "w", newline="") as csvfile:
writer = csv.writer(csvfile)
writer = csv.DictWriter(csvfile, fieldnames=MANIFEST_HEADER)
writer.writeheader()
writer.writerows(manifest)

logging.info(f"Manifest saved to {orcid_batch.manifest_file}")
@@ -663,9 +675,9 @@ def gsutil_download(orcid_batch: OrcidBatch) -> int:

blob_stdin = "\n".join(orcid_batch.blob_uris)
download_command = download_script.format(
log_file=orcid_batch.log_file, download_folder=orcid_batch.orcid_directory
log_file=orcid_batch.download_log_file, download_folder=orcid_batch.download_batch_dir
)
with open(orcid_batch.error_file, "w") as f:
with open(orcid_batch.download_error_file, "w") as f:
download_process = subprocess.Popen(download_command.split(" "), stdin=subprocess.PIPE, stderr=f, stdout=f)
download_process.communicate(input=blob_stdin.encode())
returncode = download_process.wait()
@@ -678,12 +690,11 @@ def latest_modified_record_date(manifest_file_path) -> pendulum.DateTime:
:param manifest_file_path: the path to the manifest file
:return: the most recent date of modification for the records
"""
if not os.path.exists(manifest_file_path):
return None
assert os.path.exists(manifest_file_path), f"Manifest file does not exist: {manifest_file_path}"

with open(manifest_file_path, "r") as f:
reader = csv.DictReader(f)
modified_dates = [pendulum.parse(row["updated"]) for row in reader].sort()
modified_dates = sorted([pendulum.parse(row["updated"]) for row in reader])
return modified_dates[-1]
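Taken together, the last two BigQuery tasks now key on different identifiers: bq_upsert_records passes primary_key="orcid_identifier", while bq_delete_records pairs the main table's orcid_identifier.path with the delete table's id column. The helper itself lives in the observatory platform, so the snippet below is only a rough sketch of the delete semantics that pairing implies, with placeholder table IDs rather than real project paths.

# Assumption-laden sketch: delete main-table rows whose nested orcid_identifier.path
# appears as an id in the delete table. Table IDs are placeholders.
main_table_id = "my-project.orcid.orcid"
delete_table_id = "my-project.orcid.orcid_delete"

sql = (
    f"DELETE FROM `{main_table_id}` AS main\n"
    f"WHERE main.orcid_identifier.path IN (SELECT d.id FROM `{delete_table_id}` AS d)"
)
print(sql)  # In the workflow this is handled by bq_delete_records(), not hand-written SQL.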


