openlineage, bigquery: add openlineage method support for BigQueryExecuteQueryOperator (#31293)

Signed-off-by: Maciej Obuchowski <obuchowski.maciej@gmail.com>
mobuchowski committed Aug 4, 2023
1 parent af08392 commit e10aa6a
Showing 6 changed files with 428 additions and 12 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/bigquery.py
@@ -2245,7 +2245,7 @@ def run_query(
self.running_job_id = job.job_id
return job.job_id

- def generate_job_id(self, job_id, dag_id, task_id, logical_date, configuration, force_rerun=False):
+ def generate_job_id(self, job_id, dag_id, task_id, logical_date, configuration, force_rerun=False) -> str:
if force_rerun:
hash_base = str(uuid.uuid4())
else:
92 changes: 82 additions & 10 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -133,6 +133,68 @@ def get_db_hook(self: BigQueryCheckOperator) -> BigQueryHook:  # type:ignore[misc]
)


class _BigQueryOpenLineageMixin:
def get_openlineage_facets_on_complete(self, task_instance):
"""
Retrieve OpenLineage data for a COMPLETE BigQuery job.

This method retrieves statistics for the specified job_ids using the BigQueryDatasetsProvider.
It calls the BigQuery API, retrieving input and output dataset info from it, as well as run-level
usage statistics.
Run facets should contain:
- ExternalQueryRunFacet
- BigQueryJobRunFacet
Job facets should contain:
- SqlJobFacet if the operator has self.sql
Input datasets should contain facets:
- DataSourceDatasetFacet
- SchemaDatasetFacet
Output datasets should contain facets:
- DataSourceDatasetFacet
- SchemaDatasetFacet
- OutputStatisticsOutputDatasetFacet
"""
from openlineage.client.facet import SqlJobFacet
from openlineage.common.provider.bigquery import BigQueryDatasetsProvider

from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.utils.utils import normalize_sql

if not self.job_id:
return OperatorLineage()

client = self.hook.get_client(project_id=self.hook.project_id)
job_ids = self.job_id
if isinstance(self.job_id, str):
job_ids = [self.job_id]
inputs, outputs, run_facets = {}, {}, {}
for job_id in job_ids:
stats = BigQueryDatasetsProvider(client=client).get_facets(job_id=job_id)
for input in stats.inputs:
input = input.to_openlineage_dataset()
inputs[input.name] = input
if stats.output:
output = stats.output.to_openlineage_dataset()
outputs[output.name] = output
for key, value in stats.run_facets.items():
run_facets[key] = value

job_facets = {}
if hasattr(self, "sql"):
job_facets["sql"] = SqlJobFacet(query=normalize_sql(self.sql))

return OperatorLineage(
inputs=list(inputs.values()),
outputs=list(outputs.values()),
run_facets=run_facets,
job_facets=job_facets,
)
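
For illustration, a minimal sketch of how the mixin is exercised end to end; the operator instance and the listener call below are illustrative assumptions, not part of this diff:

# Sketch: an operator inheriting the mixin exposes lineage once execute()
# has populated self.job_id and self.hook. Task id and query are
# hypothetical example values.
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator

op = BigQueryInsertJobOperator(
    task_id="run_query",
    configuration={"query": {"query": "SELECT 1", "useLegacySql": False}},
)

# After op.execute(context) has run, the OpenLineage listener calls:
#     lineage = op.get_openlineage_facets_on_complete(task_instance)
# where lineage.inputs / lineage.outputs are OpenLineage datasets deduplicated
# by name, lineage.run_facets carry ExternalQueryRunFacet and BigQueryJobRunFacet,
# and lineage.job_facets["sql"] is a SqlJobFacet built from normalize_sql(op.sql).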


class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
"""Performs checks against BigQuery.
@@ -1153,6 +1215,7 @@ def __init__(
self.encryption_configuration = encryption_configuration
self.hook: BigQueryHook | None = None
self.impersonation_chain = impersonation_chain
self.job_id: str | list[str] | None = None

def execute(self, context: Context):
if self.hook is None:
@@ -1164,7 +1227,7 @@ def execute(self, context: Context):
impersonation_chain=self.impersonation_chain,
)
if isinstance(self.sql, str):
- job_id: str | list[str] = self.hook.run_query(
+ self.job_id = self.hook.run_query(
sql=self.sql,
destination_dataset_table=self.destination_dataset_table,
write_disposition=self.write_disposition,
@@ -1184,7 +1247,7 @@ def execute(self, context: Context):
encryption_configuration=self.encryption_configuration,
)
elif isinstance(self.sql, Iterable):
- job_id = [
+ self.job_id = [
self.hook.run_query(
sql=s,
destination_dataset_table=self.destination_dataset_table,
@@ -1210,9 +1273,9 @@ def execute(self, context: Context):
raise AirflowException(f"argument 'sql' of type {type(str)} is neither a string nor an iterable")
project_id = self.hook.project_id
if project_id:
- job_id_path = convert_job_id(job_id=job_id, project_id=project_id, location=self.location)
+ job_id_path = convert_job_id(job_id=self.job_id, project_id=project_id, location=self.location)
context["task_instance"].xcom_push(key="job_id_path", value=job_id_path)
- return job_id
+ return self.job_id

def on_kill(self) -> None:
super().on_kill()
@@ -2562,7 +2625,7 @@ def execute(self, context: Context):
return table


- class BigQueryInsertJobOperator(GoogleCloudBaseOperator):
+ class BigQueryInsertJobOperator(GoogleCloudBaseOperator, _BigQueryOpenLineageMixin):
"""Execute a BigQuery job.
Waits for the job to complete and returns job id.
@@ -2663,6 +2726,13 @@ def __init__(
self.deferrable = deferrable
self.poll_interval = poll_interval

@property
def sql(self) -> str | None:
try:
return self.configuration["query"]["query"]
except KeyError:
return None
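
The property above reaches into the standard BigQuery job configuration; a short sketch of the two shapes involved (field values are illustrative):

# A query job carries the SQL under configuration["query"]["query"],
# so the property returns the statement:
query_config = {"query": {"query": "SELECT * FROM `proj.dataset.table`", "useLegacySql": False}}

# A non-query job (e.g. an extract job) has no "query" key; the KeyError
# is swallowed and the property returns None, so no SqlJobFacet is emitted:
extract_config = {
    "extract": {
        "sourceTable": {"projectId": "proj", "datasetId": "dataset", "tableId": "table"},
        "destinationUris": ["gs://bucket/out.csv"],
    }
}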

def prepare_template(self) -> None:
# If .json is passed then we have to read the file
if isinstance(self.configuration, str) and self.configuration.endswith(".json"):
@@ -2697,7 +2767,7 @@ def execute(self, context: Any):
)
self.hook = hook

- job_id = hook.generate_job_id(
+ self.job_id = hook.generate_job_id(
job_id=self.job_id,
dag_id=self.dag_id,
task_id=self.task_id,
@@ -2708,13 +2778,13 @@ def execute(self, context: Any):

try:
self.log.info("Executing: %s", self.configuration)
- job: BigQueryJob | UnknownJob = self._submit_job(hook, job_id)
+ job: BigQueryJob | UnknownJob = self._submit_job(hook, self.job_id)
except Conflict:
# If the job already exists retrieve it
job = hook.get_job(
project_id=self.project_id,
location=self.location,
- job_id=job_id,
+ job_id=self.job_id,
)
if job.state in self.reattach_states:
# We are reattaching to a job
@@ -2723,7 +2793,7 @@ def execute(self, context: Any):
else:
# Same job configuration so we need force_rerun
raise AirflowException(
f"Job with id: {job_id} already exists and is in {job.state} state. If you "
f"Job with id: {self.job_id} already exists and is in {job.state} state. If you "
f"want to force rerun it consider setting `force_rerun=True`."
f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
)
@@ -2757,7 +2827,9 @@ def execute(self, context: Any):
self.job_id = job.job_id
project_id = self.project_id or self.hook.project_id
if project_id:
- job_id_path = convert_job_id(job_id=job_id, project_id=project_id, location=self.location)
+ job_id_path = convert_job_id(
+     job_id=self.job_id, project_id=project_id, location=self.location  # type: ignore[arg-type]
+ )
context["ti"].xcom_push(key="job_id_path", value=job_id_path)
# Wait for the job to complete
if not self.deferrable:
6 changes: 6 additions & 0 deletions airflow/providers/openlineage/extractors/base.py
@@ -86,6 +86,12 @@ def extract(self) -> OperatorLineage | None:
# OpenLineage methods are optional - if there's no method, return None
try:
return self._get_openlineage_facets(self.operator.get_openlineage_facets_on_start) # type: ignore
except ImportError:
self.log.error(
"OpenLineage provider method failed to import OpenLineage integration. "
"This should not happen. Please report this bug to developers."
)
return None
except AttributeError:
return None
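
For context, the optional method the extractor probes for can be supplied by any operator; a minimal sketch follows (the operator below is a hypothetical example, not from this diff):

from airflow.models.baseoperator import BaseOperator
from airflow.providers.openlineage.extractors import OperatorLineage

class MyOperator(BaseOperator):
    """Hypothetical operator opting in to OpenLineage."""

    def execute(self, context):
        pass

    # The default extractor looks this method up by name; operators without
    # it raise AttributeError, which extract() treats as "no lineage".
    def get_openlineage_facets_on_start(self) -> OperatorLineage:
        return OperatorLineage(inputs=[], outputs=[], run_facets={}, job_facets={})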

9 changes: 8 additions & 1 deletion airflow/providers/openlineage/utils/utils.py
@@ -23,7 +23,7 @@
import os
from contextlib import suppress
from functools import wraps
- from typing import TYPE_CHECKING, Any
+ from typing import TYPE_CHECKING, Any, Iterable
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse

import attrs
@@ -414,3 +414,10 @@ def is_source_enabled() -> bool:
def get_filtered_unknown_operator_keys(operator: BaseOperator) -> dict:
not_required_keys = {"dag", "task_group"}
return {attr: value for attr, value in operator.__dict__.items() if attr not in not_required_keys}


def normalize_sql(sql: str | Iterable[str]):
if isinstance(sql, str):
sql = [stmt for stmt in sql.split(";") if stmt != ""]
sql = [obj for stmt in sql for obj in stmt.split(";") if obj != ""]
return ";\n".join(sql)
