Commit e927830, 1 parent 5019bbc. Showing 2 changed files with 316 additions and 0 deletions.
@@ -0,0 +1,316 @@
# Copyright 2023 Curtin University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Author: Alex Massen-Hane


from __future__ import annotations

import copy
import json
import logging
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from datetime import timedelta
from typing import Dict, List, Set, Optional, Tuple, Union

import pendulum
from airflow.exceptions import AirflowException
from google.cloud import bigquery

from academic_observatory_workflows.config import sql_folder, Tag
from observatory.api.client.model.dataset_release import DatasetRelease
from observatory.platform.api import make_observatory_api
from observatory.platform.bigquery import (
    bq_sharded_table_id,
    bq_create_dataset,
    bq_create_table_from_query,
    bq_create_view,
    bq_select_table_shard_dates,
    bq_run_query,
    bq_table_id,
    bq_select_latest_table,
    bq_load_from_memory,
    bq_update_table_description,
)
from observatory.platform.config import AirflowConns
from observatory.platform.observatory_config import CloudWorkspace
from observatory.platform.utils.dag_run_sensor import DagRunSensor
from observatory.platform.utils.jinja2_utils import (
    make_sql_jinja2_filename,
    render_template,
)
from observatory.platform.workflows.workflow import Workflow, make_snapshot_date, set_task_state, Release


# @dataclass
# class Dataset:
#     def __init__(self, *, table_id: str, source_dataset: str, sharded_table: bool, primary_key: str):
#         """Create a metadata class for each of the tables to be produced.
#
#         There will be one table for each dataset made.
#
#         Each table will hold the QA information from each table, e.g.
#         """
#         self.table_id = table_id
#         self.source_dataset = source_dataset
#         self.sharded_table = sharded_table
#         self.primary_key = primary_key


class QA_CheckRelease(Release):
    # A minimal sketch of this release: assume it records the snapshot date plus the list of tables
    # that were processed and what operations were done for this release.
    def __init__(
        self,
        *,
        dag_id: str,
        run_id: str,
        snapshot_date: pendulum.DateTime,
        table_list: List[Table] = None,
    ):
        super().__init__(dag_id=dag_id, run_id=run_id)
        self.snapshot_date = snapshot_date
        self.table_list = table_list if table_list is not None else []


@dataclass
class Table:
    def __init__(
        self,
        *,
        table_id: str,
        sharded: bool,
        primary_key_id_loc: Union[List[str], str],
        source_dataset: Optional[str] = None,
    ):
        """Create a metadata class for each of the tables to be produced.

        There will be one table for each dataset made.
        Each table will hold the QA information from each table, e.g.

        :param table_id: The table name (not the full table id).
        :param source_dataset: The dataset that the table is from.
        :param sharded: True if the table is sharded, False if not.
        :param primary_key_id_loc: Location of where the primary key is located in the table, e.g. MedlineCitation.PMID.value; could be multiple different identifiers.
        """
        self.table_id = table_id
        self.source_dataset = source_dataset
        self.sharded = sharded
        self.primary_key_id_loc = [primary_key_id_loc] if isinstance(primary_key_id_loc, str) else primary_key_id_loc
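

# An illustrative sketch: one way the Table attributes above could be combined into the fully
# qualified BigQuery table id that a QA check would query. Sharded Academic Observatory tables
# are suffixed with a YYYYMMDD shard date. This helper is hypothetical and not used elsewhere
# in this file.
def _example_full_table_id(project_id: str, table: Table, snapshot_date: pendulum.DateTime) -> str:
    table_name = table.table_id + snapshot_date.strftime("%Y%m%d") if table.sharded else table.table_id
    return f"{project_id}.{table.source_dataset}.{table_name}"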


class QA_Check_Workflow(Workflow):
    SENSOR_DAG_IDS = [
        "crossref_metadata",
        "crossref_fundref",
        "geonames",
        "ror",
        "open_citations",
        "unpaywall",
        "orcid",
        "crossref_events",
        "openalex",
        "pubmed",
        "doi_workflow",
        "grid",
    ]

    def __init__(
        self,
        *,
        dag_id: str,
        cloud_workspace: CloudWorkspace,
        bq_dataset_id: str = "qa_checks",
        observatory_api_conn_id: str = AirflowConns.OBSERVATORY_API,
        start_date: Optional[pendulum.DateTime] = pendulum.datetime(2020, 1, 1),
        schedule_interval: Optional[str] = "@weekly",
        sensor_dag_ids: List[str] = None,
    ):
        """Create the QA_Check_Workflow.

        :param dag_id: the DAG ID.
        :param cloud_workspace: the cloud workspace settings.
        :param bq_dataset_id: the BigQuery dataset ID that the QA check tables are written to.
        :param observatory_api_conn_id: Connection ID for the observatory API.
        :param start_date: the start date.
        :param schedule_interval: the schedule interval.
        :param sensor_dag_ids: IDs of the DAGs to wait for before running the QA checks.
        """

        super().__init__(
            dag_id=dag_id,
            start_date=start_date,
            schedule_interval=schedule_interval,
            catchup=False,
            airflow_conns=[observatory_api_conn_id],
            tags=[Tag.academic_observatory],
        )

        self.bq_dataset_id = bq_dataset_id
        self.cloud_workspace = cloud_workspace
        self.data_location = cloud_workspace.data_location

        self.qa_check_template_path = "academic_observatory_workflows/database/sql/create_qa_table.sql.jinja2"

        self.sensor_dag_ids = sensor_dag_ids
        if sensor_dag_ids is None:
            self.sensor_dag_ids = QA_Check_Workflow.SENSOR_DAG_IDS

        # List of all datasets to go through and produce QA tables of.

        # Main IDs to take note of: doi, PMID.
        self.workflows_to_check = {
            "pubmed": [
                Table(
                    table_id="pubmed",
                    source_dataset="pubmed",
                    sharded=False,
                    primary_key_id_loc=["MedlineCitation.PMID.value", "MedlineCitation.PMID.Version"],
                )
            ],
            "openalex": [
                Table(table_id="authors", source_dataset="openalex", sharded=False, primary_key_id_loc="ids.doi"),
                Table(table_id="concepts", source_dataset="openalex", sharded=False, primary_key_id_loc="ids.doi"),
                Table(table_id="funders", source_dataset="openalex", sharded=False, primary_key_id_loc="ids.doi"),
                Table(
                    table_id="institutions", source_dataset="openalex", sharded=False, primary_key_id_loc="ids.doi"
                ),
                Table(table_id="publishers", source_dataset="openalex", sharded=False, primary_key_id_loc="ids.doi"),
                Table(table_id="sources", source_dataset="openalex", sharded=False, primary_key_id_loc="ids.doi"),
                Table(table_id="works", source_dataset="openalex", sharded=False, primary_key_id_loc="ids.doi"),
            ],
            "doi_workflow": [
                Table(table_id="author", source_dataset="observatory", sharded=True, primary_key_id_loc="id"),
                Table(table_id="book", source_dataset="observatory", sharded=True, primary_key_id_loc="isbn"),
                Table(table_id="country", source_dataset="observatory", sharded=True, primary_key_id_loc="id"),
                Table(table_id="doi", source_dataset="observatory", sharded=True, primary_key_id_loc="id"),
                Table(table_id="funder", source_dataset="observatory", sharded=True, primary_key_id_loc="id"),
                Table(table_id="group", source_dataset="observatory", sharded=True, primary_key_id_loc="id"),
                Table(table_id="institution", source_dataset="observatory", sharded=True, primary_key_id_loc="id"),
                Table(table_id="journal", source_dataset="observatory", sharded=True, primary_key_id_loc="id"),
                Table(table_id="publisher", source_dataset="observatory", sharded=True, primary_key_id_loc="id"),
                Table(table_id="region", source_dataset="observatory", sharded=True, primary_key_id_loc="id"),
                Table(table_id="subregion", source_dataset="observatory", sharded=True, primary_key_id_loc="id"),
            ],
            "ror": [Table(table_id="ror", source_dataset="ror", sharded=True, primary_key_id_loc="id")],
        }
        self.observatory_api_conn_id = observatory_api_conn_id
        self.input_table_task_ids = []

        self.add_task(self.create_qa_check_dataset)
        self.create_tasks()

    def create_tasks(self):
        # Add sensors
        with self.parallel_tasks():
            for ext_dag_id in self.sensor_dag_ids:
                sensor = DagRunSensor(
                    task_id=f"{ext_dag_id}_sensor",
                    external_dag_id=ext_dag_id,
                    mode="reschedule",
                    duration=timedelta(days=7),  # Look back up to 7 days from execution date
                    poke_interval=int(timedelta(hours=1).total_seconds()),  # Check at this interval if dag run is ready
                    timeout=int(timedelta(days=3).total_seconds()),  # Sensor will fail after 3 days of waiting
                )
                self.add_operator(sensor)

        # Setup tasks
        self.add_setup_task(self.check_dependencies)

        # Create the tasks that build the QA metadata tables
        self.input_table_task_ids = []
        with self.parallel_tasks():
            for workflow, table_list in self.workflows_to_check.items():
                task_id = f"qa_check_{workflow}"
                self.add_task(
                    self.qa_check_dataset,
                    op_kwargs={"qa_check_workflow": workflow, "task_id": task_id, "table_list": table_list},
                    task_id=task_id,
                )
                self.input_table_task_ids.append(task_id)
    def make_release(self, **kwargs) -> QA_CheckRelease:
        """Make a release instance. The release is passed as an argument to the function (TelescopeFunction) that is
        called in 'task_callable'.

        :param kwargs: the context passed from the PythonOperator. See
            https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that are passed
            to this argument.
        :return: A release instance or list of release instances.
        """

        snapshot_date = make_snapshot_date(**kwargs)
        return QA_CheckRelease(
            dag_id=self.dag_id,
            run_id=kwargs["run_id"],
            snapshot_date=snapshot_date,
        )

    def create_qa_check_dataset(self, release: QA_CheckRelease, **kwargs):
        """Create the dataset that holds all of the QA check tables."""

        success = bq_create_dataset(
            project_id=self.cloud_workspace.project_id,
            dataset_id=self.bq_dataset_id,
            location=self.data_location,
        )

        set_task_state(success, self.create_qa_check_dataset.__name__, release)

    def qa_check_dataset(self, release: QA_CheckRelease, **kwargs):
        """For each dataset, create a table where the rows hold the QA metadata of each of the dataset's tables.

        A Jinja script creates the metadata for the tables and the result is appended onto the last row that
        exists."""

        # If a table is sharded, and if it has not been done before, use the observatory API.

        table_list: List[Table] = kwargs["table_list"]

        # Get the list of tables in the dataset and make sure that the number of tables matches what is given.
        # Exclude upsert and delete tables, and snapshots.

        for table in table_list:
            # If the table is sharded, loop through each of the shards.

            # Check that the shards have not been done before.

            # If it is a new one, do the QA check.

            # Assert that the table actually exists, else throw an error.
            full_table_id = bq_table_id(self.cloud_workspace.project_id, table.source_dataset, table.table_id)
            assert bq_table_exists(
                full_table_id
            ), f"The table {table.table_id} does not exist in dataset: {table.source_dataset}"

            logging.info(f"QA for table - {table.table_id}")

            # Create the empty QA table for this dataset if it does not exist yet.
            # Needs qa_table_id; needs to append to the last row of info, run date and time.

            # Compile the SQL to make the QA check.
            sql = render_template(
                self.qa_check_template_path,
                last_table_created_date=release.snapshot_date,
            )

            # Dictionary from the SQL run.

            # Query the table - if the last run was the same as the last sharded date.

            # Append onto the table.

            # Assert that the number of rows has increased for the n tables in the source dataset.
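
        # A possible shape for the per-table check sketched above (an illustrative sketch, not the
        # original implementation; the helper calls mirror how they are used elsewhere in this
        # workflow and their exact signatures may differ):
        #
        #   qa_records = bq_run_query(sql)  # e.g. row count, distinct and null primary keys
        #   qa_table_id = bq_table_id(self.cloud_workspace.project_id, self.bq_dataset_id, kwargs["qa_check_workflow"])
        #   success = bq_load_from_memory(qa_table_id, qa_records)  # append this run's QA rows
        #   set_task_state(success, kwargs["task_id"], release)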


def bq_table_exists(table_id: str) -> bool:
    """Check if a BigQuery table exists.

    :param table_id: Fully qualified table id.
    :return exists: True if the table exists."""

    project_id, dataset_id, table_name = table_id.split(".")

    bq_client = bigquery.Client()
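    # A minimal sketch of the existence check, assuming the standard google-cloud-bigquery pattern:
    # Client.get_table() raises NotFound when the table does not exist.
    from google.api_core.exceptions import NotFound

    try:
        bq_client.get_table(table_id)
        return True
    except NotFound:
        return False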