-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
35ac8bc
commit 8e14258
Showing
11 changed files
with
442 additions
and
185 deletions.
There are no files selected for viewing
3 changes: 3 additions & 0 deletions
3
academic_observatory_workflows/fixtures/orcid/0000-0001-5000-5000.xml
Git LFS file not shown
3 changes: 3 additions & 0 deletions
3
academic_observatory_workflows/fixtures/orcid/0000-0001-5001-3000.xml
Git LFS file not shown
3 changes: 3 additions & 0 deletions
3
academic_observatory_workflows/fixtures/orcid/0000-0001-5002-1000.xml
Git LFS file not shown
3 changes: 3 additions & 0 deletions
3
academic_observatory_workflows/fixtures/orcid/0000-0001-5007-2000.xml
Git LFS file not shown
3 changes: 3 additions & 0 deletions
3
academic_observatory_workflows/fixtures/orcid/0000-0001-5010-1000.xml
Git LFS file not shown
3 changes: 0 additions & 3 deletions
3
academic_observatory_workflows/fixtures/orcid/0000-0002-9227-8610.xml
This file was deleted.
Oops, something went wrong.
3 changes: 0 additions & 3 deletions
3
academic_observatory_workflows/fixtures/orcid/0000-0002-9228-8514.xml
This file was deleted.
Oops, something went wrong.
3 changes: 0 additions & 3 deletions
3
academic_observatory_workflows/fixtures/orcid/0000-0002-9229-8514.xml
This file was deleted.
Oops, something went wrong.
3 changes: 3 additions & 0 deletions
3
academic_observatory_workflows/fixtures/orcid/test_manifest.csv
Git LFS file not shown
408 changes: 232 additions & 176 deletions
408
academic_observatory_workflows/workflows/orcid_telescope.py
Large diffs are not rendered by default.
Oops, something went wrong.
192 changes: 192 additions & 0 deletions
192
academic_observatory_workflows/workflows/tests/test_orcid_telescope.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
# Copyright 2023 Curtin University | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
# Author: Keegan Smith | ||
|
||
from __future__ import annotations | ||
|
||
import copy | ||
import datetime | ||
import gzip | ||
import json | ||
import os | ||
import pathlib | ||
import csv | ||
import shutil | ||
import tempfile | ||
from typing import Dict | ||
from unittest.mock import patch | ||
|
||
import pendulum | ||
from airflow.models import Connection | ||
from airflow.utils.state import State | ||
from google.cloud import storage | ||
|
||
from academic_observatory_workflows.config import test_fixtures_folder | ||
from academic_observatory_workflows.workflows.orcid_telescope import ( | ||
OrcidBatch, | ||
OrcidRelease, | ||
OrcidTelescope, | ||
create_orcid_batch_manifest, | ||
gsutil_download, | ||
latest_modified_record_date, | ||
orcid_batch_names, | ||
gcs_list_blobs, | ||
transform_orcid_record, | ||
) | ||
from observatory.platform.api import get_dataset_releases | ||
from observatory.platform.bigquery import bq_table_id, bq_sharded_table_id | ||
from observatory.platform.files import save_jsonl_gz, load_file | ||
from observatory.platform.gcs import gcs_blob_name_from_path | ||
from observatory.platform.observatory_config import Workflow, CloudWorkspace | ||
from observatory.platform.observatory_environment import ( | ||
ObservatoryEnvironment, | ||
ObservatoryTestCase, | ||
aws_bucket_test_env, | ||
find_free_port, | ||
load_and_parse_json, | ||
random_id, | ||
) | ||
|
||
|
||
class TestOrcidUtils(ObservatoryTestCase): | ||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.dag_id = "orcid" | ||
self.aws_key = (os.getenv("AWS_ACCESS_KEY_ID"), os.getenv("AWS_SECRET_ACCESS_KEY")) | ||
self.aws_region_name = os.getenv("AWS_DEFAULT_REGION") | ||
self.fixtures_folder = test_fixtures_folder("orcid") | ||
|
||
def test_orcid_batch(self): | ||
"""Test that the orcid batches are correctly constructed""" | ||
with tempfile.TemporaryDirectory() as tmp_dir: | ||
download_dir = os.path.join(tmp_dir, "download") | ||
transform_dir = os.path.join(tmp_dir, "transform") | ||
test_batch_str = "12X" | ||
|
||
# Download/transform dirs don't exist | ||
with self.assertRaises(AssertionError): | ||
OrcidBatch(download_dir, transform_dir, test_batch_str) | ||
shutil.makedirs(download_dir) | ||
with self.assertRaises(AssertionError): | ||
OrcidBatch(download_dir, transform_dir, test_batch_str) | ||
shutil.makedirs(transform_dir) | ||
|
||
# Invalid batch string | ||
for batch_str in ["0000", "12C", "99", "XXX"]: | ||
with self.assertRaises(AssertionError): | ||
OrcidBatch(download_dir, transform_dir, batch_str) | ||
|
||
# Create a batch for testing | ||
test_batch = OrcidBatch(download_dir, transform_dir, test_batch_str) | ||
|
||
# Check that expected folders exist | ||
self.assertTrue(os.path.isdir(test_batch.download_dir)) | ||
|
||
# Check that file names are as expected | ||
self.assertEqual(test_batch.download_batch_dir, os.path.join(download_dir, test_batch_str)) | ||
self.assertEqual(test_batch.download_log_file, os.path.join(download_dir, f"{test_batch_str}_log.txt")) | ||
self.assertEqual(test_batch.download_error_file, os.path.join(download_dir, f"{test_batch_str}_error.txt")) | ||
self.assertEqual(test_batch.manifest_file, os.path.join(download_dir, f"{test_batch_str}_manifest.csv")) | ||
self.assertEqual( | ||
test_batch.transform_upsert_file, os.path.join(download_dir, f"{test_batch_str}_upsert.jsonl.gz") | ||
) | ||
self.assertEqual( | ||
test_batch.transform_delete_file, os.path.join(download_dir, f"{test_batch_str}_delete.jsonl.gz") | ||
) | ||
|
||
# Make the manifest file | ||
shutil.copy(os.path.join(self.fixtures_folder, "manifest.csv"), test_batch.manifest_file) | ||
|
||
# Check that missing, expected and existing records are correctly identified | ||
records = [ | ||
"0000-0001-5000-5000.xml", | ||
"0000-0001-5001-3000.xml", | ||
"0000-0001-5002-1000.xml", | ||
"0000-0001-5007-2000.xml", | ||
"0000-0001-5010-1000.xml", | ||
] | ||
self.assertEqual(set(test_batch.expected_records), set(records)) | ||
self.assertEqual(test_batch.existing_records, []) | ||
self.assertEqual(set(test_batch.missing_records), set(records)) | ||
for record in records: | ||
shutil.copy( | ||
os.path.join(self.fixtures_folder, record), os.path.join(test_batch.download_batch_dir, record) | ||
) | ||
self.assertEqual(set(test_batch.expected_records), set(records)) | ||
self.assertEqual(test_batch.existing_records, set(records)) | ||
self.assertEqual(test_batch.missing_records, []) | ||
|
||
# Check that the blob uris are correctly generated | ||
expected_blob_uris = [ | ||
"gs://orcid-testing/orcid_summaries/000/0000-0001-5000-5000.xml", | ||
"gs://orcid-testing/orcid_summaries/000/0000-0001-5001-3000.xml", | ||
"gs://orcid-testing/orcid_summaries/000/0000-0001-5002-1000.xml", | ||
"gs://orcid-testing/orcid_summaries/000/0000-0001-5007-2000.xml", | ||
"gs://orcid-testing/orcid_summaries/000/0000-0001-5010-1000.xml", | ||
] | ||
self.assertEqual(set(test_batch.blob_uris), set(expected_blob_uris)) | ||
|
||
def test_create_orcid_batch_manifest(self): | ||
"""Tests the create_orcid_batch_manifest function""" | ||
with tempfile.TemporaryDirectory() as tmp_dir: | ||
download_dir = os.path.join(tmp_dir, "download") | ||
transform_dir = os.path.join(tmp_dir, "transform") | ||
test_batch_str = "12X" | ||
# Create a batch for testing | ||
test_batch = OrcidBatch(download_dir, transform_dir, test_batch_str) | ||
|
||
# Upload the .xml files to the test bucket | ||
client = storage.Client() | ||
bucket_id = f"orcid_test_{random_id()}" | ||
bucket = client.create_bucket(bucket_id) | ||
|
||
blob1 = storage.Blob(f"{test_batch_str}/0000-0001-5000-1000.xml", bucket) | ||
blob1.upload_from_string("Test data 1") | ||
# Make now the reference time - blob1 should be ignored | ||
reference_time = pendulum.now() | ||
blob2 = storage.Blob(f"{test_batch_str}/0000-0001-5000-2000.xml", bucket) | ||
blob2.upload_from_string("Test data 2") | ||
blob3 = storage.Blob(f"{test_batch_str}/0000-0001-5000-3000.xml", bucket) | ||
blob3.upload_from_string("Test data 3") | ||
# Put a blob in a different folder - should be ignored | ||
blob4 = storage.Blob(f"somewhere_else/{test_batch_str}/0000-0001-5000-4000.xml", bucket) | ||
blob4.upload_from_string("Test data 4") | ||
|
||
create_orcid_batch_manifest(orcid_batch=test_batch, reference_time=reference_time, bucket=bucket_id) | ||
with open(test_batch.manifest_file, "w", newline="") as csvfile: | ||
reader = csv.reader(csvfile) | ||
manifest_rows = [row for row in reader] | ||
bucket = [row[0] for row in manifest_rows] | ||
blobs = [row[1] for row in manifest_rows] | ||
orcid = [row[2] for row in manifest_rows] | ||
modification_times = [row[3] for row in manifest_rows] | ||
self.assertEqual(len(manifest_rows), 2) | ||
self.assertEqual(set(blobs), set([blob2.name, blob3.name])) | ||
self.assertEqual(set(orcid), set(["0000-0001-5000-2000", "0000-0001-5000-3000"])) | ||
self.assertEqual(set(modification_times), set([blob2.updated.isoformat(), blob3.updated.isoformat()])) | ||
|
||
|
||
class TestOrcidTelescope(ObservatoryTestCase): | ||
"""Tests for the OpenAlex telescope""" | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
self.dag_id = "orcid" | ||
self.project_id = os.getenv("TEST_GCP_PROJECT_ID") | ||
self.data_location = os.getenv("TEST_GCP_DATA_LOCATION") | ||
self.aws_access_key_id = os.getenv("AWS_ACCESS_KEY_ID") | ||
self.aws_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY") | ||
self.aws_region_name = os.getenv("AWS_DEFAULT_REGION") |