Skip to content

Commit

Permalink
Add a function to check the inputs of a Workflow object
Browse files Browse the repository at this point in the history
  • Loading branch information
keegansmith21 committed Apr 27, 2023
1 parent a0e6f5e commit b464ae2
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 17 deletions.
31 changes: 31 additions & 0 deletions observatory-platform/observatory/platform/workflows/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
delete_old_xcoms,
)
from observatory.platform.airflow import get_data_path
from observatory.platform.observatory_config import CloudWorkspace


DATE_TIME_FORMAT = "YYYY-MM-DD_HH:mm:ss"
Expand All @@ -68,6 +69,36 @@ def make_workflow_folder(dag_id: str, run_id: str, *subdirs: str) -> str:
return path


def check_workflow_inputs(workflow: Workflow, check_cloud_workspace=True) -> None:
    """Validate the fields of a Workflow object.

    :param workflow: The Workflow object to validate.
    :param check_cloud_workspace: Whether to also validate the workflow's CloudWorkspace, defaults to True.
    :raises AirflowException: Raised if any field is invalid; the message lists every invalid field.
    """
    # Accumulate the names of all bad fields so a single error reports them together.
    problems = []

    dag_id = workflow.dag_id
    if not isinstance(dag_id, str) or not dag_id:
        problems.append("dag_id")

    if check_cloud_workspace:
        workspace = workflow.cloud_workspace
        if isinstance(workspace, CloudWorkspace):
            # These fields must all be present and be non-empty strings.
            for name in ("project_id", "data_location", "download_bucket", "transform_bucket"):
                value = getattr(workspace, name, None)
                if not isinstance(value, str) or not value:
                    problems.append(f"cloud_workspace.{name}")

            # output_project_id is optional (None is allowed), but when set it must be a non-empty string.
            output_project_id = workspace.output_project_id
            if output_project_id is not None and (not isinstance(output_project_id, str) or not output_project_id):
                problems.append("cloud_workspace.output_project_id")
        else:
            problems.append("cloud_workspace")

    if problems:
        raise AirflowException(f"Workflow input fields invalid: {problems}")


def cleanup(dag_id: str, execution_date: str, workflow_folder: str = None, retention_days=31) -> None:
"""Delete all files, folders and XComs associated from a release.
Expand Down
65 changes: 48 additions & 17 deletions tests/observatory/platform/workflows/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from datetime import datetime, timezone
from functools import partial
from tempfile import TemporaryDirectory
from unittest.mock import patch
from unittest.mock import patch, MagicMock
from copy import deepcopy

import pendulum
from airflow import DAG
Expand All @@ -29,9 +30,7 @@
from airflow.operators.python import PythonOperator
from airflow.sensors.external_task import ExternalTaskSensor

from observatory.api.client import ApiClient, Configuration
from observatory.api.client.api.observatory_api import ObservatoryApi # noqa: E501
from observatory.api.testing import ObservatoryApiEnvironment
from observatory.platform.observatory_config import CloudWorkspace
from observatory.platform.observatory_environment import (
ObservatoryEnvironment,
ObservatoryTestCase,
Expand All @@ -45,6 +44,7 @@
make_snapshot_date,
cleanup,
set_task_state,
check_workflow_inputs,
)


Expand Down Expand Up @@ -104,6 +104,48 @@ def test_make_workflow_folder(self, mock_get_variable):
os.path.join(tempdir, f"test_dag/scheduled__2023-03-26T00:00:00+00:00/sub_folder/subsub_folder"),
)

def test_check_workflow_inputs(self):
"""Test check_workflow_inputs"""
# Test Dag ID validity
wf = MagicMock(dag_id="valid")
check_workflow_inputs(wf, check_cloud_workspace=False) # Should pass
for dag_id in ["", None, 42]: # Should all fail
wf.dag_id = dag_id
with self.assertRaises(AirflowException) as cm:
check_workflow_inputs(wf, check_cloud_workspace=False)
msg = cm.exception.args[0]
self.assertIn("dag_id", msg)

# Test when cloud workspace is of wrong type
wf = MagicMock(dag_id="valid", cloud_workspace="invalid")
with self.assertRaisesRegex(AirflowException, "cloud_workspace"):
check_workflow_inputs(wf)

# Test validity of each part of the cloud workspace
valid_cloud_workspace = CloudWorkspace(
project_id="project_id",
download_bucket="download_bucket",
transform_bucket="transform_bucket",
data_location="data_location",
output_project_id="output_project_id",
)
wf = MagicMock(dag_id="valid", cloud_workspace=deepcopy(valid_cloud_workspace))
check_workflow_inputs(wf) # Should pass
for attr, invalid_val in [
("project_id", ""),
("download_bucket", None),
("transform_bucket", 42),
("data_location", MagicMock()),
("output_project_id", ""),
]:
wf = MagicMock(dag_id="valid", cloud_workspace=deepcopy(valid_cloud_workspace))
setattr(wf.cloud_workspace, attr, invalid_val)
with self.assertRaisesRegex(AirflowException, f"cloud_workspace.{attr}"):
check_workflow_inputs(wf)
wf = MagicMock(dag_id="valid", cloud_workspace=deepcopy(valid_cloud_workspace))
wf.cloud_workspace.output_project_id = None
check_workflow_inputs(wf) # This one should pass

def test_make_snapshot_date(self):
"""Test make_table_name"""

Expand Down Expand Up @@ -173,11 +215,6 @@ def __init__(self, *args, **kwargs):

self.host = "localhost"
self.port = find_free_port()
configuration = Configuration(host=f"http://{self.host}:{self.port}")
api_client = ApiClient(configuration)
self.api = ObservatoryApi(api_client=api_client) # noqa: E501
self.env = ObservatoryApiEnvironment(host=self.host, port=self.port)
self.org_name = "Curtin University"

def test_make_task_id(self):
"""Test make_task_id"""
Expand Down Expand Up @@ -317,14 +354,8 @@ def test_make_dag(self):
self.assertTrue(telescope._parallel_tasks)
self.assertFalse(telescope._parallel_tasks)

@patch("observatory.platform.api.make_observatory_api")
def test_telescope(self, m_makeapi):
"""Basic test to make sure that the Workflow class can execute in an Airflow environment.
:return: None.
"""

m_makeapi.return_value = self.api

def test_telescope(self):
"""Basic test to make sure that the Workflow class can execute in an Airflow environment."""
# Setup Observatory environment
env = ObservatoryEnvironment(self.project_id, self.data_location, api_host=self.host, api_port=self.port)

Expand Down

0 comments on commit b464ae2

Please sign in to comment.