Initial commit
alexmassen-hane committed Jul 26, 2023
1 parent 5019bbc commit e927830
Showing 2 changed files with 316 additions and 0 deletions.
Empty file.
316 changes: 316 additions & 0 deletions academic_observatory_workflows/workflows/qa_checks.py
@@ -0,0 +1,316 @@
# Copyright 2023 Curtin University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Author: Alex Massen-Hane

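"""Workflow that produces QA check metadata tables for the datasets created by the
Academic Observatory workflows."""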

from __future__ import annotations

import copy
import json
import logging
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from datetime import timedelta
from typing import Dict, List, Set, Optional, Tuple, Union

import pendulum
from airflow.exceptions import AirflowException
from google.api_core.exceptions import NotFound
from google.cloud import bigquery

from academic_observatory_workflows.config import sql_folder, Tag
from observatory.api.client.model.dataset_release import DatasetRelease
from observatory.platform.api import make_observatory_api
from observatory.platform.bigquery import (
bq_sharded_table_id,
bq_create_dataset,
bq_create_table_from_query,
bq_create_view,
bq_select_table_shard_dates,
bq_run_query,
bq_table_id,
bq_select_latest_table,
bq_load_from_memory,
bq_update_table_description,
)
from observatory.platform.config import AirflowConns
from observatory.platform.observatory_config import CloudWorkspace
from observatory.platform.utils.dag_run_sensor import DagRunSensor
from observatory.platform.utils.jinja2_utils import (
make_sql_jinja2_filename,
render_template,
)
from observatory.platform.workflows.workflow import Workflow, make_snapshot_date, set_task_state, Release


# @dataclass
# class Dataset:
# def __init__(self, *, table_id: str, source_dataset: str, sharded_table: bool, primary_key: str):
# """Create a metadata class for each of the tables to be produced.

# There will be one one table for each dataset made.

# Each table will hold the QA information from each table, e.g.
# """
# self.table_id = table_id
# self.source_dataset = source_dataset
# self.sharded_table = sharded_table
# self.primary_key = primary_key
class QA_CheckRelease(Release):
    def __init__(
        self,
        *,
        dag_id: str,
        run_id: str,
        snapshot_date: pendulum.DateTime,
        table_list: List[Table] = None,
    ):
        """Release object for the QA check workflow.

        :param dag_id: the DAG ID.
        :param run_id: the DAG run ID.
        :param snapshot_date: the snapshot date of this release.
        :param table_list: the list of tables that were processed for this release.
        """

        super().__init__(dag_id=dag_id, run_id=run_id)
        self.snapshot_date = snapshot_date
        self.table_list = table_list if table_list is not None else []
        # TODO: record what operations were done for this release.

class Table:
    def __init__(
        self,
        *,
        table_id: str,
        sharded: bool,
        primary_key_location: Union[List[str], str],
        source_dataset: Optional[str] = None,
    ):
        """Create a metadata class for each of the tables to be checked.

        There will be one QA table produced for each dataset. Each QA table will hold
        the QA information for the tables in that dataset.

        :param table_id: the table name (not the fully qualified table id).
        :param sharded: True if the table is sharded.
        :param primary_key_location: location of the primary key in the table,
            e.g. MedlineCitation.PMID.value; there could be multiple identifiers.
        :param source_dataset: the dataset that the table comes from.
        """
        self.table_id = table_id
        self.sharded = sharded
        self.source_dataset = source_dataset
        self.primary_key_location = (
            [primary_key_location] if isinstance(primary_key_location, str) else primary_key_location
        )


class QA_Check_Workflow(Workflow):
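    """Workflow that creates QA metadata tables for the datasets produced by the Academic Observatory
    workflows. It waits for the upstream DAGs listed in SENSOR_DAG_IDS to finish before running its checks."""

    # Upstream DAGs that are waited on with a DagRunSensor before the QA checks are run.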
SENSOR_DAG_IDS = [
"crossref_metadata",
"crossref_fundref",
"geonames",
"ror",
"open_citations",
"unpaywall",
"orcid",
"crossref_events",
"openalex",
"pubmed",
"doi_workflow",
"grid",
]

def __init__(
self,
*,
dag_id: str,
cloud_workspace: CloudWorkspace,
bq_dataset_id: str = "qa_checks",
observatory_api_conn_id: str = AirflowConns.OBSERVATORY_API,
start_date: Optional[pendulum.DateTime] = pendulum.datetime(2020, 1, 1),
schedule_interval: Optional[str] = "@weekly",
sensor_dag_ids: List[str] = None,
):
"""Create the DoiWorkflow.
:param dag_id: the DAG ID.
:param cloud_workspace: the cloud workspace settings.
:param bq_dataset_id:
:param observatory_api_conn_id: COnnection ID for the observatory API.
:param start_date: the start date.
:param schedule_interval: the schedule interval.
"""

super().__init__(
dag_id=dag_id,
start_date=start_date,
schedule_interval=schedule_interval,
catchup=False,
airflow_conns=[observatory_api_conn_id],
tags=[Tag.academic_observatory],
)

        self.bq_dataset_id = bq_dataset_id
        self.cloud_workspace = cloud_workspace
        self.data_location = cloud_workspace.data_location

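        # Jinja2 SQL template used to render the QA query for each table. The template itself is not
        # part of this commit; it is assumed to compute per-table metadata such as row counts and
        # distinct / null counts of the primary key fields.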
self.qa_check_template_path = "academic_observatory_workflows/database/sql/create_qa_table.sql.jinja2"

self.sensor_dag_ids = sensor_dag_ids
if sensor_dag_ids is None:
self.sensor_dag_ids = QA_Check_Workflow.SENSOR_DAG_IDS

        # The datasets and tables to go through and produce QA tables for.
        # The main identifiers to take note of are DOI and PMID.
self.workflows_to_check = {
"pubmed": [
Table(
table_id="pubmed",
sharded=False,
primary_key_location=["MedlineCitation.PMID.value", "MedlineCitation.PMID.Version"],
)
],
"openalex": [
Table(table_id="authors", source_dataset="openalex", sharded=False, primary_key_location="ids.doi"),
Table(table_id="concepts", source_dataset="openalex", sharded=False, primary_key_location="ids.doi"),
Table(table_id="funders", source_dataset="openalex", sharded=False, primary_key_location="ids.doi"),
Table(
table_id="institutions", source_dataset="openalex", sharded=False, primary_key_location="ids.doi"
),
Table(table_id="publishers", source_dataset="openalex", sharded=False, primary_key_location="ids.doi"),
Table(table_id="sources", source_dataset="openalex", sharded=False, primary_key_location="ids.doi"),
Table(table_id="works", source_dataset="openalex", sharded=False, primary_key_location="ids.doi"),
],
"doi_workflow": [
Table(table_id="author", source_dataset="observatory", sharded=True, primary_key_location="id"),
Table(table_id="book", source_dataset="observatory", sharded=True, primary_key_location="isbn"),
Table(table_id="country", source_dataset="observatory", sharded=True, primary_key_location="id"),
Table(table_id="doi", source_dataset="observatory", sharded=True, primary_key_location="id"),
Table(table_id="funder", source_dataset="observatory", sharded=True, primary_key_location="id"),
Table(table_id="group", source_dataset="observatory", sharded=True, primary_key_location="id"),
Table(table_id="institution", source_dataset="observatory", sharded=True, primary_key_location="id"),
Table(table_id="journal", source_dataset="observatory", sharded=True, primary_key_location="id"),
Table(table_id="publisher", source_dataset="observatory", sharded=True, primary_key_location="id"),
Table(table_id="region", source_dataset="observatory", sharded=True, primary_key_location="id"),
Table(table_id="subregion", source_dataset="observatory", sharded=True, primary_key_location="id"),
],
"ror": [Table(table_id="ror", sharded=True, primary_key_id_loc="id")],
}
self.observatory_api_conn_id = observatory_api_conn_id
        self.input_table_task_ids = []

self.add_task(self.create_qa_check_dataset)
self.create_tasks()

def create_tasks(self):
# Add sensors
with self.parallel_tasks():
for ext_dag_id in self.sensor_dag_ids:
sensor = DagRunSensor(
task_id=f"{ext_dag_id}_sensor",
external_dag_id=ext_dag_id,
mode="reschedule",
duration=timedelta(days=7), # Look back up to 7 days from execution date
poke_interval=int(timedelta(hours=1).total_seconds()), # Check at this interval if dag run is ready
timeout=int(timedelta(days=3).total_seconds()), # Sensor will fail after 3 days of waiting
)
self.add_operator(sensor)

# Setup tasks
self.add_setup_task(self.check_dependencies)

# Create tasks creating the QA Metadata tables
self.input_table_task_ids = []
with self.parallel_tasks():
            for workflow, table_list in self.workflows_to_check.items():
task_id = f"qa_check_{workflow}"
self.add_task(
self.qa_check_dataset,
op_kwargs={"qa_check_workflow": workflow, "task_id": task_id, "table_list": table_list},
task_id=task_id,
)
self.input_table_task_ids.append(task_id)

    def make_release(self, **kwargs) -> QA_CheckRelease:
        """Make a release instance. The release is passed as an argument to the function that is
        called in 'task_callable'.

        :param kwargs: the context passed from the PythonOperator. See
            https://airflow.apache.org/docs/stable/macros-ref.html for a list of the keyword arguments that
            are passed to this argument.
        :return: a release instance.
        """

        snapshot_date = make_snapshot_date(**kwargs)
        return QA_CheckRelease(
            dag_id=self.dag_id,
            run_id=kwargs["run_id"],
            snapshot_date=snapshot_date,
        )

def create_qa_check_dataset(self, release: QA_CheckRelease, **kwargs):
"""Create dataset for all the QA Check tables."""

        success = bq_create_dataset(project_id=self.cloud_workspace.project_id, dataset_id=self.bq_dataset_id)

set_task_state(success, self.create_qa_check_dataset.__name__, release)

    def qa_check_dataset(self, release: QA_CheckRelease, **kwargs):
        """For each dataset, create a table where the rows hold the QA metadata for each of the
        dataset's tables. A Jinja2 SQL template is used to create the metadata for each table,
        which is appended onto the last row of the existing QA table."""

        table_list: List[Table] = kwargs["table_list"]

        # TODO: Get the list of tables in the dataset and make sure that the number of tables matches
        # what is given, excluding upsert, delete and snapshot tables.

        for table in table_list:
            # TODO: If the table is sharded, loop through each of the shards and check (e.g. via the
            # Observatory API) that a shard has not been processed before; only run the QA check for
            # new shards.

            # Assert that the table actually exists, else throw an error.
            full_table_id = bq_table_id(self.cloud_workspace.project_id, table.source_dataset, table.table_id)
            assert bq_table_exists(
                full_table_id
            ), f"The table {table.table_id} does not exist in dataset: {table.source_dataset}"

            logging.info(f"QA check for table - {table.table_id}")

            # TODO: Create the QA table for this dataset if it does not already exist; each run should
            # append a row holding the QA metadata along with the run date and time.

            # Compile the SQL that performs the QA check.
            sql = render_template(
                self.qa_check_template_path,
                last_table_created_date=release.snapshot_date,
            )

            # TODO: Run the rendered sql (e.g. with bq_run_query), skip the table if the last run already
            # covered the latest shard date, and append the resulting QA metadata onto the QA table.

        # TODO: Assert that the number of rows has increased for each of the n tables in the source dataset.


def bq_table_exists(table_id: str) -> bool:
    """Check if a BigQuery table exists.

    :param table_id: fully qualified table id, e.g. project_id.dataset_id.table_id.
    :return: True if the table exists, otherwise False.
    """

    bq_client = bigquery.Client()
    try:
        bq_client.get_table(table_id)
        return True
    except NotFound:
        return False
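
# A minimal usage sketch (not part of this commit) of how this workflow might be turned into an Airflow
# DAG, assuming a CloudWorkspace built from the project's observatory configuration:
#
#   workflow = QA_Check_Workflow(dag_id="qa_checks", cloud_workspace=cloud_workspace)
#   globals()[workflow.dag_id] = workflow.make_dag()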
