Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jdddog committed May 8, 2023
1 parent d160676 commit c05921c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

# Author: Tuan Chien, James Diprose

import datetime
import os
from typing import List
from unittest.mock import patch
Expand Down Expand Up @@ -282,6 +283,7 @@ def test_telescope(self):
changefile_date,
)
],
prev_end_date=pendulum.instance(datetime.datetime.min),
)

# Wait for the previous DAG run to finish
Expand All @@ -308,13 +310,16 @@ def test_telescope(self):
task_ids=task_id,
include_prior_dates=False,
)
actual_snapshot_date, actual_changefiles, actual_is_first_run = parse_release_msg(msg)
actual_snapshot_date, actual_changefiles, actual_is_first_run, actual_prev_end_date = parse_release_msg(
msg
)
self.assertEqual(snapshot_date, actual_snapshot_date)
self.assertListEqual(
release.changefiles,
actual_changefiles,
)
self.assertTrue(actual_is_first_run)
self.assertEqual(pendulum.instance(datetime.datetime.min), actual_prev_end_date)

# Create datasets
ti = env.run_task(workflow.create_datasets.__name__)
Expand Down Expand Up @@ -489,6 +494,7 @@ def test_telescope(self):
self.assertEqual(len(dataset_releases), 1)

# Third run: waiting a couple of days and applying multiple changefiles
prev_end_date = pendulum.datetime(2023, 4, 25, 8, 0, 1)
data_interval_start = pendulum.datetime(2023, 4, 27)
changefile_start_date = pendulum.datetime(2023, 4, 26, 8, 0, 1)
changefile_end_date = pendulum.datetime(2023, 4, 27, 8, 0, 1)
Expand All @@ -511,6 +517,7 @@ def test_telescope(self):
changefile_start_date,
),
],
prev_end_date=prev_end_date,
)

# Fetch releases and check that we have received the expected snapshot date and changefiles
Expand All @@ -535,7 +542,7 @@ def test_telescope(self):
workflow.cloud_workspace.output_project_id,
workflow.bq_dataset_id,
f"{workflow.bq_table_name}_snapshot",
release.changefile_release.changefile_end_date,
prev_end_date,
)
self.assert_table_integrity(dst_table_id, expected_rows=10)

Expand Down
20 changes: 15 additions & 5 deletions academic_observatory_workflows/workflows/unpaywall_telescope.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import datetime
import logging
import os
import re
Expand Down Expand Up @@ -122,6 +123,7 @@ def __init__(
changefile_start_date: pendulum.DateTime,
changefile_end_date: pendulum.DateTime,
changefiles: List[Changefile],
prev_end_date: pendulum.DateTime,
):
"""Construct an UnpaywallRelease instance
Expand All @@ -131,6 +133,8 @@ def __init__(
:param snapshot_date: the date of the Unpaywall snapshot.
:param changefile_start_date: the start date of the Unpaywall changefiles processed in this release.
:param changefile_end_date: the end date of the Unpaywall changefiles processed in this release.
:param changefiles: changefiles.
:param prev_end_date: the previous end date.
"""

super().__init__(
Expand All @@ -148,6 +152,7 @@ def __init__(
self.changefiles = changefiles
for changefile in changefiles:
changefile.changefile_release = self.changefile_release
self.prev_end_date = prev_end_date

# Paths used during processing
self.snapshot_download_file_path = os.path.join(
Expand Down Expand Up @@ -289,6 +294,7 @@ def fetch_releases(self, **kwargs) -> bool:
logging.info(f"fetch_releases: {len(all_changefiles)} JSONL changefiles discovered")
changefiles = []
is_first_run = is_first_dag_run(dag_run)
prev_end_date = pendulum.instance(datetime.datetime.min)

if is_first_run:
assert (
Expand Down Expand Up @@ -319,9 +325,9 @@ def fetch_releases(self, **kwargs) -> bool:
# On subsequent runs, fetch changefiles from after the previous changefile date
prev_release = get_latest_dataset_release(releases, "changefile_end_date")
snapshot_date = prev_release.snapshot_date # so that we can easily see what snapshot is being used
prev_changefile_date = prev_release.changefile_end_date
prev_end_date = prev_release.changefile_end_date
for changefile in all_changefiles:
if prev_changefile_date < changefile.changefile_date:
if prev_end_date < changefile.changefile_date:
changefiles.append(changefile)

# Sort from oldest to newest
Expand All @@ -338,13 +344,15 @@ def fetch_releases(self, **kwargs) -> bool:
logging.info(f"is_first_run: {is_first_run}")
logging.info(f"snapshot_date: {snapshot_date}")
logging.info(f"changefiles: {changefiles}")
logging.info(f"prev_end_date: {prev_end_date}")

# Publish release information
ti: TaskInstance = kwargs["ti"]
msg = dict(
snapshot_date=snapshot_date.isoformat(),
changefiles=changefiles,
is_first_run=is_first_run,
prev_end_date=prev_end_date.isoformat(),
)
ti.xcom_push(UnpaywallTelescope.RELEASE_INFO, msg, kwargs["logical_date"])

Expand All @@ -357,7 +365,7 @@ def make_release(self, **kwargs) -> UnpaywallRelease:
msg = ti.xcom_pull(
key=UnpaywallTelescope.RELEASE_INFO, task_ids=self.fetch_releases.__name__, include_prior_dates=False
)
snapshot_date, changefiles, is_first_run = parse_release_msg(msg)
snapshot_date, changefiles, is_first_run, prev_end_date = parse_release_msg(msg)
run_id = kwargs["run_id"]

# The first changefile is the oldest and the last one is the newest
Expand All @@ -372,6 +380,7 @@ def make_release(self, **kwargs) -> UnpaywallRelease:
changefile_start_date=changefile_start_date,
changefile_end_date=changefile_end_date,
changefiles=changefiles,
prev_end_date=prev_end_date,
)

# Set changefile_release
Expand Down Expand Up @@ -402,7 +411,7 @@ def bq_create_main_table_snapshot(self, release: UnpaywallRelease, **kwargs) ->
self.cloud_workspace.output_project_id,
self.bq_dataset_id,
f"{self.bq_table_name}_snapshot",
release.changefile_release.changefile_end_date,
release.prev_end_date,
)
expiry_date = pendulum.now().add(days=self.snapshot_expiry_days)
success = bq_snapshot(src_table_id=self.bq_main_table_id, dst_table_id=dst_table_id, expiry_date=expiry_date)
def parse_release_msg(msg: Dict):
    """Parse the release-information message that fetch_releases pushed to XCom.

    The message is the dict built in fetch_releases, where both dates were
    serialised with ``.isoformat()`` and each changefile with ``Changefile`` dict
    serialisation, so this function is its inverse.

    :param msg: dict with keys ``snapshot_date`` (ISO 8601 string),
        ``changefiles`` (list of Changefile dicts), ``is_first_run`` (bool)
        and ``prev_end_date`` (ISO 8601 string).
    :return: a tuple of (snapshot_date, changefiles, is_first_run, prev_end_date):
        the dates as pendulum.DateTime instances and changefiles as a list of
        Changefile objects.
    """

    snapshot_date = pendulum.parse(msg["snapshot_date"])
    changefiles = [Changefile.from_dict(changefile) for changefile in msg["changefiles"]]
    is_first_run = msg["is_first_run"]
    prev_end_date = pendulum.parse(msg["prev_end_date"])

    return snapshot_date, changefiles, is_first_run, prev_end_date


def snapshot_url(api_key: str) -> str:
Expand Down

0 comments on commit c05921c

Please sign in to comment.