
Commit

Add test and cleanup
alexmassen-hane committed Jun 29, 2023
1 parent 90b57b9 commit 54b5db9
Showing 6 changed files with 161 additions and 64 deletions.
4 Git LFS files not shown
66 changes: 41 additions & 25 deletions academic_observatory_workflows/workflows/pubmed_telescope.py
@@ -16,9 +16,9 @@

# Common
import os
import re
import gzip
import json
import math
import logging
import pendulum
from datetime import timedelta
@@ -333,9 +333,8 @@ def __init__(
max_download_retry: int = 5,
observatory_api_conn_id: str = AirflowConns.OBSERVATORY_API,
queue: str = "remote_queue",
snapshot_expiry_days: int = 7,
snapshot_expiry_days: int = 31,
max_processes: int = 4, # Limited to 4 due to RAM usage when transforming files
batch_size: int = 4,
):
"""Construct a PubMed Telescope instance.
@@ -352,7 +351,6 @@ def __init__(
:param queue: the queue that the tasks should run on.
:param snapshot_expiry_days: How long until the snapshot expires.
:param max_processes: Max number of parallel processors.
:param batch_size: Number of changefiles per batch.
:param bq_table_id: Name of Pubmed table.
:param table_description: Description of the main table.
"""
@@ -416,7 +414,6 @@ def __init__(
self.check_md5_hash = check_md5_hash
self.max_download_retry = max_download_retry
self.max_processes = max_processes
self.batch_size = batch_size

# Required file size of the update files.
self.merged_file_size = 3.8 # Gb
@@ -483,7 +480,7 @@ def list_changefiles_for_release(self, **kwargs) -> bool:
for file in baseline_list_ftp:
if file.endswith(".xml.gz"): # Find all the xml.gz files available from the server.
filename = file
file_index = int(file[9:13])
file_index = int(re.findall(r"\d{4}", file)[0])
path_on_ftp = self.baseline_path + file
changefile = Changefile(
filename=filename,
@@ -493,6 +490,10 @@ def list_changefiles_for_release(self, **kwargs) -> bool:
changefile_date=self.start_date,
)
files_to_download.append(changefile)

logging.info(f"List of files to download from the PubMed FTP server for 'baseline':")
for changefile in files_to_download:
logging.info(f"{changefile.filename}")
else:
logging.info(
f"Grabbing list of 'updatefiles' for this release: {data_interval_start} to {data_interval_end}"
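This hunk (and the updatefiles branch below) now pulls the 4-digit sequence number out of each filename with a regular expression rather than the fixed slice `file[9:13]`. A minimal sketch of that extraction; the filename and the `extract_file_index` helper are illustrative, not part of the commit:

```python
import re


def extract_file_index(filename: str) -> int:
    """Return the first four consecutive digits found in a PubMed filename as an int."""
    matches = re.findall(r"\d{4}", filename)
    if not matches:
        raise ValueError(f"No 4-digit file index found in {filename!r}")
    return int(matches[0])


# e.g. 'pubmed23n0004.xml.gz' -> 4 ('23' is too short to match \d{4})
print(extract_file_index("pubmed23n0004.xml.gz"))
```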
@@ -511,9 +512,9 @@ def list_changefiles_for_release(self, **kwargs) -> bool:
file_upload_time = ftp_conn.sendcmd("MDTM {}".format(file))[4:]
file_upload_date = pendulum.from_format(file_upload_time, "YYYYMMDDHHmmss")
if file_upload_date in pendulum.period(data_interval_start, data_interval_end):
# Grab name and metadata for this release file.
# Grab metadata and path of the file.
filename = file
file_index = int(file[9:13])
file_index = int(re.findall(r"\d{4}", file)[0])
path_on_ftp = self.updatefiles_path + file
changefile_date = file_upload_date

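The updatefile filter above keys off the FTP server's MDTM reply, which returns the upload time as `213 YYYYMMDDHHMMSS`. A small sketch of that parsing and window check, assuming pendulum 2.x (the version that exposes `pendulum.period`, as used in the diff); the timestamps and the `parse_mdtm_reply` helper are illustrative:

```python
import pendulum


def parse_mdtm_reply(reply: str) -> pendulum.DateTime:
    """Convert an FTP MDTM reply such as '213 20230601123045' into a datetime."""
    return pendulum.from_format(reply[4:], "YYYYMMDDHHmmss")  # drop the '213 ' status code


upload_date = parse_mdtm_reply("213 20230601123045")
data_interval_start = pendulum.datetime(2023, 5, 28)
data_interval_end = pendulum.datetime(2023, 6, 4)

# Only files uploaded inside the release window are added to the download list.
print(upload_date in pendulum.period(data_interval_start, data_interval_end))  # True
```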
@@ -542,7 +543,7 @@ def list_changefiles_for_release(self, **kwargs) -> bool:
file_index_last = changefile.file_index
else:
raise AirflowException(

f"The update files are not going to be sequential. Please investigate download {changefile.file_index} and {file_index_last+1}"
f"The updatefiles are not going to be sequential. Please investigate download {changefile.file_index} and {file_index_last+1}"
)

# Make sure the first changefile file index for this release is n + 1 ahead of the last release.
@@ -589,9 +590,6 @@ def make_release(self, **kwargs) -> PubMedRelease:
# Sort the incoming list.
changefile_list.sort(key=lambda c: c.file_index, reverse=False)

# limit to the first 20 files for testing
changefile_list = changefile_list[:100]

run_id = kwargs["run_id"]
dag_run = kwargs["dag_run"]
is_first_run = is_first_dag_run(dag_run)
@@ -775,20 +773,43 @@ def merge_updatefiles(self, release: PubMedRelease, **kwargs):
for entity in self.entity_list:
merged_updatefiles[entity.name] = []

# Get the size of all the updatefiles
# Determine which files are merged together by summing file sizes and
# starting a new chunk when the running total reaches the target merged file size.
file_size_sum = 0
temp_chunk = []
chunks = []
part_counter = 1
for changefile in files_to_merge:
transform_file = changefile.transform_file_path(entity.type)
transform_file_stats = os.stat(transform_file)
transform_file_size = transform_file_stats.st_size / 1024.0**3
file_size_sum += transform_file_size

logging.info(f"Total size of updatefiles for {entity.type} for this release: {file_size_sum} ")
if (
changefile.file_index == release.changefile_list[-1].file_index
or file_size_sum + transform_file_size > self.merged_file_size
):
# If last in the list, still needs to be added to the chunks to be merged.
if changefile.file_index == release.changefile_list[-1].file_index:
temp_chunk.append(changefile)
file_size_sum += transform_file_size

# Start a new chunk as this one fits the size requirement.
chunks.append(temp_chunk)
logging.info(
f"Rough file size of part {part_counter} = {format(file_size_sum, '.2f')} Gb for changefiles {temp_chunk[0].file_index} to {temp_chunk[-1].file_index}"
)

num_chunks = math.ceil(file_size_sum / self.merged_file_size)
# Reset variables
file_size_sum = 0
temp_chunk = [changefile]
file_size_sum = transform_file_size
part_counter += 1

logging.info(f"Aproximate size of each merged: {file_size_sum/num_chunks} Gb")
else:
temp_chunk.append(changefile)
file_size_sum += transform_file_size

num_chunks = len(chunks)
if num_chunks == 1:
logging.info(f"There will be 1 part for the merged updatefiles.")

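The reworked merge step above groups the transformed changefiles greedily by cumulative size, starting a new part whenever adding the next file would push the running total past `self.merged_file_size` (3.8 GB here), rather than splitting the list into equal-count chunks as before. A simplified sketch of that grouping, using plain `(name, size_gb)` tuples and a hypothetical `chunk_by_size` helper instead of `Changefile` objects:

```python
from typing import List, Tuple


def chunk_by_size(files: List[Tuple[str, float]], max_size_gb: float) -> List[List[str]]:
    """Greedily group files so each group's summed size stays under max_size_gb."""
    chunks: List[List[str]] = []
    current: List[str] = []
    current_size = 0.0
    for name, size in files:
        # Close the current chunk if this file would push it over the target size.
        if current and current_size + size > max_size_gb:
            chunks.append(current)
            current, current_size = [], 0.0
        current.append(name)
        current_size += size
    if current:
        chunks.append(current)  # the last, possibly undersized, part
    return chunks


# Illustrative sizes: three 1.5 GB files with a 3.8 GB target fall into two parts.
print(chunk_by_size([("f0001", 1.5), ("f0002", 1.5), ("f0003", 1.5)], max_size_gb=3.8))
# [['f0001', 'f0002'], ['f0003']]
```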
@@ -798,11 +819,6 @@ def merge_updatefiles(self, release: PubMedRelease, **kwargs):
logging.info(f"Successfully merged updatefiles to - {merged_updatefile_path}")

else:
chunk_size = math.floor(len(files_to_merge) / num_chunks)
chunks = [
chunk for i, chunk in enumerate(get_chunks(input_list=files_to_merge, chunk_size=chunk_size))
]

if num_chunks > self.max_processes:
processes_to_use = self.max_processes


@@ -811,10 +827,10 @@ def merge_updatefiles(self, release: PubMedRelease, **kwargs):

logging.info(f"There will be {len(chunks)} parts for the merged updatefiles.")


# Multiple output for merged files, do in parallel.
# Multiple outputs for merged files, do in parallel.
for j, sub_chunks in enumerate(get_chunks(input_list=chunks, chunk_size=processes_to_use)):
# Pass off each chunk to a process for them to merge files in parallel.
with ProcessPoolExecutor(max_workers=processes_to_use) as executor:
with ProcessPoolExecutor(max_workers=len(sub_chunks)) as executor:
futures = []

for i, chunk in enumerate(sub_chunks):
futures.append(

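The merged parts are then written in parallel, with the pool sized by the smaller of `max_processes` and the number of chunks in the current batch. A stripped-down sketch of that dispatch pattern, where `merge_one_chunk` and `merge_in_parallel` are stand-ins for the real merge functions; the committed code additionally batches the chunks with `get_chunks` and opens a pool per batch, which this sketch collapses into a single pool for brevity:

```python
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import List


def merge_one_chunk(chunk: List[str], part_num: int) -> int:
    # Placeholder for the real merge work; just reports which part it handled.
    return part_num


def merge_in_parallel(chunks: List[List[str]], max_processes: int = 4) -> List[int]:
    # Never spawn more workers than there are chunks to merge.
    workers = min(max_processes, len(chunks))
    results = []
    with ProcessPoolExecutor(max_workers=workers) as executor:
        futures = [
            executor.submit(merge_one_chunk, chunk, part_num=i + 1) for i, chunk in enumerate(chunks)
        ]
        for future in as_completed(futures):
            results.append(future.result())
    return results


if __name__ == "__main__":
    print(sorted(merge_in_parallel([["f0001", "f0002"], ["f0003"]])))  # [1, 2]
```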
@@ -1053,7 +1069,7 @@ def bq_add_updates_to_table(self, release: PubMedRelease, **kwargs):
backup_table_id = bq_sharded_table_id(
self.cloud_workspace.project_id, self.bq_dataset_id, f"{self.bq_table_id}_backup", prev_end_date
)
expiry_date = pendulum.now().add(days=31)
expiry_date = pendulum.now().add(days=self.snapshot_expiry_days)
success = bq_snapshot(src_table_id=full_table_id, dst_table_id=backup_table_id, expiry_date=expiry_date)

set_task_state(success, kwargs["ti"].task_id, release=release)
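The backup snapshot's expiry above is now derived from the configurable `snapshot_expiry_days` (whose default moves from 7 to 31 in this commit) instead of a hardcoded 31 days. A tiny sketch of the expiry calculation, using the new default purely for illustration:

```python
import pendulum

snapshot_expiry_days = 31  # telescope default after this commit; configurable per deployment
expiry_date = pendulum.now().add(days=snapshot_expiry_days)
print(expiry_date.to_iso8601_string())  # the backup snapshot expires at this timestamp
```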
147 changes: 108 additions & 39 deletions academic_observatory_workflows/workflows/tests/test_pubmed_telescope.py
@@ -23,6 +23,7 @@
import pendulum
import datetime
from ftplib import FTP
from typing import List, Dict
from click.testing import CliRunner
from airflow.utils.state import State

@@ -32,7 +33,7 @@
from Bio.Entrez.Parser import StringElement, ListElement, DictionaryElement
from observatory.platform.gcs import gcs_blob_name_from_path, gcs_download_blob
from observatory.platform.observatory_environment import ObservatoryEnvironment, ObservatoryTestCase
from observatory.platform.bigquery import bq_sharded_table_id, bq_create_table_from_query, bq_export_table
from observatory.platform.bigquery import bq_run_query, bq_sharded_table_id, bq_create_table_from_query, bq_export_table
from observatory.platform.observatory_environment import (
ObservatoryEnvironment,
ObservatoryTestCase,
@@ -48,10 +49,25 @@
PubMedTelescope,
PubmedEntity,
add_attributes_to_data_from_biopython_classes,
merge_changefiles_together,
transform_pubmed_xml_file_to_jsonl,
)


def query_table(table_id: str, select_columns: str, order_by_field: str) -> List[Dict]:
"""Query a BigQuery table, sorting the results and returning results as a list of dicts.
:param table_id: the table id.
:param select_columns: Columns to pull from the table.
:param order_by_field: what field or fields to order by.
:return: the table rows.
"""

return [
dict(row) for row in bq_run_query(f"SELECT {select_columns} FROM {table_id} ORDER BY {order_by_field} ASC;")
]

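The new `query_table` helper above replaces the earlier export-to-GCS-and-download round trip in the tests with a direct query. A quick usage sketch; the project, dataset, and table identifiers are made up for illustration, and running it requires BigQuery access:

```python
# Hypothetical call against a sharded Pubmed table; identifiers are illustrative.
rows = query_table(
    table_id="my-project.pubmed.pubmed20230702",
    select_columns="(MedlineCitation.PMID.Version, MedlineCitation.PMID.value)",
    order_by_field="MedlineCitation.PMID.value",
)
print(len(rows), rows[:2])
```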

class TestPubMedTelescope(ObservatoryTestCase):
"""Tests for the Pubmed telescope"""

@@ -375,26 +391,12 @@ def test_telescope(self):
self.assert_table_integrity(main_table_id, 4)

# Run query to get list of PMIDs that are present in the table and compare against what it should be.
PMID_list = f"{env.project_id}.{workflow.bq_dataset_id}.{workflow.bq_table_id}_PMID_list_first_run"
bq_query_list_PMIDs = f"""
SELECT (MedlineCitation.PMID.Version, MedlineCitation.PMID.value)
FROM `{main_table_id}`
ORDER BY MedlineCitation.PMID.value
"""
destination_uri = f"gs://{env.transform_bucket}/PMID_list_first_run.jsonl"
PMID_list_path = os.path.join(release.transform_folder, "PMID_list_first_run.jsonl")
bq_create_table_from_query(sql=bq_query_list_PMIDs, table_id=PMID_list)
bq_export_table(table_id=PMID_list, file_type="jsonl", destination_uri=destination_uri)
gcs_download_blob(
bucket_name=env.transform_bucket,
blob_name="PMID_list_first_run.jsonl",
file_path=PMID_list_path,
actual_output = query_table(
main_table_id,
"(MedlineCitation.PMID.Version, MedlineCitation.PMID.value)",
"MedlineCitation.PMID.value",
)
logging.info(f"Downloaded table to: {PMID_list_path}")
with open(PMID_list_path, "rb") as f_in:
PMID_list = [json.loads(line) for line in f_in]

self.assertEqual(PMID_list, run["PMID_list"])
self.assertEqual(actual_output, run["PMID_list"])

### add_new_dataset_release ###
task_id = workflow.add_new_dataset_release.__name__
@@ -566,26 +568,12 @@ def test_telescope(self):
self.assert_table_integrity(main_table_id, 5)

# Run query to get list of PMIDs that are present in the table and compare against what it should be.
PMID_list = f"{env.project_id}.{workflow.bq_dataset_id}.{workflow.bq_table_id}_PMID_list_second_run"
bq_query_list_PMIDs = f"""
SELECT (MedlineCitation.PMID.Version, MedlineCitation.PMID.value)
FROM `{main_table_id}`
ORDER BY MedlineCitation.PMID.value
"""
destination_uri = f"gs://{env.transform_bucket}/PMID_list_second_run.jsonl"
PMID_list_path = os.path.join(release.transform_folder, "PMID_list_second_run.jsonl")
bq_create_table_from_query(sql=bq_query_list_PMIDs, table_id=PMID_list)
bq_export_table(table_id=PMID_list, file_type="jsonl", destination_uri=destination_uri)
gcs_download_blob(
bucket_name=env.transform_bucket,
blob_name="PMID_list_second_run.jsonl",
file_path=PMID_list_path,
actual_output = query_table(
main_table_id,
"(MedlineCitation.PMID.Version, MedlineCitation.PMID.value)",
"MedlineCitation.PMID.value",
)
logging.info(f"Downloaded table to: {PMID_list_path}")
with open(PMID_list_path, "rb") as f_in:
PMID_list = [json.loads(line) for line in f_in]

self.assertEqual(PMID_list, run["PMID_list"])
self.assertEqual(actual_output, run["PMID_list"])

### add_new_dataset_release ###
task_id = workflow.add_new_dataset_release.__name__
@@ -694,6 +682,87 @@ def test_transform_pubmed_xml_file_to_jsonl(self):
for entity in entity_list:
self.assertTrue(os.path.exists(changefile_returned.transform_file_path(entity.type)))

def test_merge_changefiles_together(
self,
):
"""Test that *.jsonl.gz files can be reliably merged into one or more files."""

expected_hash = {
"article_additions": "8823ea43ca4619175d21dad430a03826",
"article_deletions": "c8b6f684ad613e1e8be46022afc83916",
}

# Setup environment
env = ObservatoryEnvironment(self.project_id, self.data_location, api_port=find_free_port())

with env.create(task_logging=True):
changefile_release = ChangefileRelease(
dag_id="pubmed_telescope",
run_id="something",
start_date=pendulum.now(),
end_date=pendulum.now(),
sequence_start=1,
sequence_end=1,
)

entity_list = [
PubmedEntity(
name="article_additions",
type="additions",
sub_key="PubmedArticle",
set_key="PubmedArticleSet",
pmid_location="MedlineCitation",
table_description="""PubmedArticle""",
),
PubmedEntity(
name="article_deletions",
type="deletions",
sub_key="PMID",
set_key="DeleteCitation",
pmid_location=None,
table_description="""DeleteCitation""",
),
]

# Changefiles to merge
changefile_list = [
Changefile(
filename="pubmed23n0003.xml.gz",
file_index=3,
path_on_ftp="dummy_string",
is_first_run=False,
changefile_date=pendulum.now(),
changefile_release=changefile_release,
),
Changefile(
filename="pubmed23n0004.xml.gz",
file_index=4,
path_on_ftp="dummy_string",
is_first_run=False,
changefile_date=pendulum.now(),
changefile_release=changefile_release,
),
]

# Perform merge step on the transformed files.
for entity in entity_list:
# Copy test files into temp test directory
for changefile in changefile_list:
copy_path = os.path.join(
test_fixtures_folder(),
"pubmed",
"other",
f"test_{entity.type}_{changefile.filename.split('.')[0]}.jsonl.gz",
)
shutil.copy2(copy_path, changefile.transform_file_path(entity.type))

output_file = merge_changefiles_together(changefile_list, part_num=1, entity_type=entity.type)

# Check against expected hash for the files.
self.assertEqual(
hashlib.md5(gzip.open(output_file, "rb").read()).hexdigest(), expected_hash[entity.name]
)

def test_add_attributes_to_data_from_biopython(self):
"""
Test that attributes from the Biopython data classes can be reliably pulled out and added to the dictionary.
