
Commit

Add test and cleanup
alexmassen-hane committed Jun 28, 2023
1 parent 90b57b9 commit 34c2913
Showing 7 changed files with 181 additions and 68 deletions.
Git LFS files not shown (4)
66 changes: 41 additions & 25 deletions academic_observatory_workflows/workflows/pubmed_telescope.py
@@ -16,9 +16,9 @@

# Common
import os
import re
import gzip
import json
import math
import logging
import pendulum
from datetime import timedelta
@@ -333,9 +333,8 @@ def __init__(
max_download_retry: int = 5,
observatory_api_conn_id: str = AirflowConns.OBSERVATORY_API,
queue: str = "remote_queue",
snapshot_expiry_days: int = 7,
snapshot_expiry_days: int = 31,
max_processes: int = 4, # Limited to 4 due to RAM usage when transforming files
batch_size: int = 4,
):
"""Construct an PubMed Telescope instance.
@@ -352,7 +351,6 @@ def __init__(
:param queue: the queue that the tasks should run on.
:param snapshot_expiry_days: How long until the snapshot expires.
:param max_processes: Max number of parallel processors.
:param batch_size: Number of changefiles per batch.
:param bq_table_id: Name of Pubmed table.
:param table_description: Description of the main table.
"""
@@ -416,7 +414,6 @@ def __init__(
self.check_md5_hash = check_md5_hash
self.max_download_retry = max_download_retry
self.max_processes = max_processes
self.batch_size = batch_size

# Required file size of the update files.
self.merged_file_size = 3.8 # Gb
@@ -483,7 +480,7 @@ def list_changefiles_for_release(self, **kwargs) -> bool:
for file in baseline_list_ftp:
if file.endswith(".xml.gz"): # Find all the xml.gz files available from the server.
filename = file
file_index = int(file[9:13])
file_index = int(re.findall("\d{4}", file)[0])
path_on_ftp = self.baseline_path + file
changefile = Changefile(
filename=filename,
@@ -493,6 +490,10 @@ def list_changefiles_for_release(self, **kwargs) -> bool:
changefile_date=self.start_date,
)
files_to_download.append(changefile)

logging.info(f"List of files to download from the PubMed FTP server for 'baseline':")
for changefile in files_to_download:
logging.info(f"{changefile.filename}")
else:
logging.info(
f"Grabbing list of 'updatefiles' for this release: {data_interval_start} to {data_interval_end}"
@@ -511,9 +512,9 @@ def list_changefiles_for_release(self, **kwargs) -> bool:
file_upload_time = ftp_conn.sendcmd("MDTM {}".format(file))[4:]
file_upload_date = pendulum.from_format(file_upload_time, "YYYYMMDDHHmmss")
if file_upload_date in pendulum.period(data_interval_start, data_interval_end):
# Grab name and metadata for this release file.
# Grab metadata and path of the file.
filename = file
file_index = int(file[9:13])
file_index = int(re.findall("\d{4}", file)[0])
path_on_ftp = self.updatefiles_path + file
changefile_date = file_upload_date
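
The file index is now parsed with a regular expression instead of a fixed string slice, and the upload date comes from the FTP MDTM reply. A minimal standalone sketch of both steps, using a made-up filename and MDTM response purely for illustration:

```python
import re
import pendulum

# Hypothetical inputs for illustration only; real values come from the PubMed FTP server.
filename = "pubmed23n0123.xml.gz"
mdtm_reply = "213 20230610093015"  # FTP MDTM replies look like "213 YYYYMMDDHHMMSS"

# The first run of four consecutive digits is the file index
# (the old slice file[9:13] assumed a fixed prefix length).
file_index = int(re.findall(r"\d{4}", filename)[0])

# Drop the "213 " status prefix, then parse the timestamp, as the workflow does.
file_upload_date = pendulum.from_format(mdtm_reply[4:], "YYYYMMDDHHmmss")

print(file_index)        # 123
print(file_upload_date)  # 2023-06-10 09:30:15+00:00
```
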

@@ -542,7 +543,7 @@ def list_changefiles_for_release(self, **kwargs) -> bool:
file_index_last = changefile.file_index
else:
raise AirflowException(
f"The update files are not going to be sequential. Please investigate download {changefile.file_index} and {file_index_last+1}"
f"The updatefiles are not going to be sequential. Please investigate download {changefile.file_index} and {file_index_last+1}"
)

# Make sure the first changefile file index for this release is n + 1 ahead of the last release.
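
The comment above refers to the sequential-index guard; a simplified sketch of that check, operating on plain integer indexes rather than Changefile objects (the helper name, list input, and ValueError are illustrative assumptions, the real code raises AirflowException):

```python
# Simplified sketch: verify that update-file indexes are strictly sequential.
def assert_sequential(file_indexes: list[int]) -> None:
    for previous, current in zip(file_indexes, file_indexes[1:]):
        if current != previous + 1:
            raise ValueError(
                f"The updatefiles are not sequential. Please investigate download {current} and {previous + 1}"
            )


assert_sequential([1163, 1164, 1165])  # passes silently
# assert_sequential([1163, 1165])      # would raise ValueError
```
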
@@ -589,9 +590,6 @@ def make_release(self, **kwargs) -> PubMedRelease:
# Sort the incoming list.
changefile_list.sort(key=lambda c: c.file_index, reverse=False)

# limit to the first 20 files for testing
changefile_list = changefile_list[:100]

run_id = kwargs["run_id"]
dag_run = kwargs["dag_run"]
is_first_run = is_first_dag_run(dag_run)
@@ -775,20 +773,43 @@ def merge_updatefiles(self, release: PubMedRelease, **kwargs):
for entity in self.entity_list:
merged_updatefiles[entity.name] = []

# Get the size of all the updatefiles
# Determine which files are merged together by summing up each file's size
# and creating chunks based on the total size.
file_size_sum = 0
temp_chunk = []
chunks = []
part_counter = 1
for changefile in files_to_merge:
transform_file = changefile.transform_file_path(entity.type)
transform_file_stats = os.stat(transform_file)
transform_file_size = transform_file_stats.st_size / 1024.0**3
file_size_sum += transform_file_size

logging.info(f"Total size of updatefiles for {entity.type} for this release: {file_size_sum} ")
if (
changefile.file_index == release.changefile_list[-1].file_index
or file_size_sum + transform_file_size > self.merged_file_size
):
# If last in the list, still needs to be added to the chunks to be merged.
if changefile.file_index == release.changefile_list[-1].file_index:
temp_chunk.append(changefile)
file_size_sum += transform_file_size

# Start a new chunk as this one fits the size requirement.
chunks.append(temp_chunk)
logging.info(
f"Rough file size of part {part_counter} = {format(file_size_sum, '.2f')} Gb for changefiles {temp_chunk[0].file_index} to {temp_chunk[-1].file_index}"
)

num_chunks = math.ceil(file_size_sum / self.merged_file_size)
# Reset variables
file_size_sum = 0
temp_chunk = [changefile]
file_size_sum = transform_file_size
part_counter += 1

logging.info(f"Aproximate size of each merged: {file_size_sum/num_chunks} Gb")
else:
temp_chunk.append(changefile)
file_size_sum += transform_file_size

num_chunks = len(chunks)
if num_chunks == 1:
logging.info(f"There were will be 1 part for the merged updatefiles.")

@@ -798,11 +819,6 @@ def merge_updatefiles(self, release: PubMedRelease, **kwargs):
logging.info(f"Successfully merged updatefiles to - {merged_updatefile_path}")

else:
chunk_size = math.floor(len(files_to_merge) / num_chunks)
chunks = [
chunk for i, chunk in enumerate(get_chunks(input_list=files_to_merge, chunk_size=chunk_size))
]

if num_chunks > self.max_processes:
processes_to_use = self.max_processes
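
The capped worker count feeds the process-pool fan-out shown in the next hunk: chunks are handed off in batches so that at most max_processes merges run at once. A rough sketch of that pattern (merge_chunk and merge_in_batches are stand-ins invented for this example, not part of the workflow code):

```python
from concurrent.futures import ProcessPoolExecutor, as_completed


def merge_chunk(part: int, chunk: list[str]) -> str:
    # Stand-in for the real merge routine; just reports which part it would produce.
    return f"part_{part} <- {len(chunk)} files"


def merge_in_batches(chunks: list[list[str]], max_processes: int = 4) -> list[str]:
    results = []
    for start in range(0, len(chunks), max_processes):
        batch = chunks[start : start + max_processes]
        # One worker per chunk in this batch, never more than max_processes workers.
        with ProcessPoolExecutor(max_workers=len(batch)) as executor:
            futures = [
                executor.submit(merge_chunk, start + i + 1, chunk)
                for i, chunk in enumerate(batch)
            ]
            results.extend(f.result() for f in as_completed(futures))
    return results
```
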

@@ -811,10 +827,10 @@ def merge_updatefiles(self, release: PubMedRelease, **kwargs):

logging.info(f"There were will be {len(chunks)} parts for the merged updatefiles.")

# Multiple output for merged files, do in parallel.
# Multiple outputs for merged files, do in parallel.
for j, sub_chunks in enumerate(get_chunks(input_list=chunks, chunk_size=processes_to_use)):
# Pass off each chunk to a process for them to merge files in parallel.
with ProcessPoolExecutor(max_workers=processes_to_use) as executor:
with ProcessPoolExecutor(max_workers=len(sub_chunks)) as executor:
futures = []
for i, chunk in enumerate(sub_chunks):
futures.append(
@@ -1053,7 +1069,7 @@ def bq_add_updates_to_table(self, release: PubMedRelease, **kwargs):
backup_table_id = bq_sharded_table_id(
self.cloud_workspace.project_id, self.bq_dataset_id, f"{self.bq_table_id}_backup", prev_end_date
)
expiry_date = pendulum.now().add(days=31)
expiry_date = pendulum.now().add(days=self.snapshot_expiry_days)
success = bq_snapshot(src_table_id=full_table_id, dst_table_id=backup_table_id, expiry_date=expiry_date)

set_task_state(success, kwargs["ti"].task_id, release=release)
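
The backup snapshot's expiry now follows the telescope's snapshot_expiry_days setting instead of a hard-coded 31 days; the expiry date itself is just a pendulum offset, e.g.:

```python
import pendulum

snapshot_expiry_days = 31  # default from the constructor change above
expiry_date = pendulum.now().add(days=snapshot_expiry_days)
print(expiry_date)  # current time plus 31 days
```
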
Second changed file (filename not shown):
@@ -16,12 +16,13 @@

from __future__ import annotations

import os
from datetime import timedelta
from typing import Dict, List

import pendulum
import os
import vcr
import json
import pendulum
from typing import Dict, List
from datetime import timedelta
from airflow.utils.state import State

from academic_observatory_workflows.config import test_fixtures_folder
@@ -205,6 +206,7 @@ def test_dag_structure(self):
"orcid_sensor": ["check_dependencies"],
"crossref_events_sensor": ["check_dependencies"],
"openalex_sensor": ["check_dependencies"],
"pubmed_sensor": ["check_dependencies"],
"check_dependencies": ["create_datasets"],
"create_datasets": ["create_repo_institution_to_ror_table"],
"create_repo_institution_to_ror_table": ["create_ror_hierarchy_table"],
@@ -217,6 +219,7 @@ def test_dag_structure(self):
"create_open_citations",
"create_unpaywall",
"create_openalex",
"create_pubmed",
],
"create_crossref_events": ["create_doi"],
"create_crossref_fundref": ["create_doi"],
@@ -226,6 +229,7 @@ def test_dag_structure(self):
"create_open_citations": ["create_doi"],
"create_unpaywall": ["create_doi"],
"create_openalex": ["create_doi"],
"create_pubmed": ["create_doi"],
"create_doi": ["create_book"],
"create_book": [
"create_country",
@@ -335,6 +339,7 @@ def test_telescope(self):
dataset_id_observatory=bq_observatory_dataset_id,
dataset_id_observatory_intermediate=bq_intermediate_dataset_id,
dataset_id_openalex=fake_dataset_id,
dataset_id_pubmed=fake_dataset_id,
)
transforms, transform_doi, transform_book = dataset_transforms

@@ -494,6 +499,14 @@ def test_telescope(self):
table_id = bq_sharded_table_id(self.project_id, bq_observatory_dataset_id, "doi", snapshot_date)
actual_output = query_table(table_id, "doi")

with open("/home/alexmassen-hane/doi_table_expected.jsonl", "w") as f_out:
for line in expected_output:
f_out.write(json.dumps(line) + "\n")

with open("/home/alexmassen-hane/doi_table_actual.jsonl", "w") as f_out:
for line in actual_output:
f_out.write(json.dumps(line) + "\n")

self.assert_doi(expected_output, actual_output)

# Test create book
@@ -675,6 +688,9 @@ def assert_doi(self, expected: List[Dict], actual: List[Dict]):
# Check affiliations
self.assert_doi_affiliations(expected_record["affiliations"], actual_record["affiliations"])

# Check that Pubmed matches
self.assertEqual(expected_record["pubmed"], actual_record["pubmed"])

def assert_doi_events(self, expected: Dict, actual: Dict):
"""Assert the DOI table events field.
(remaining changed files not loaded)
