Skip to content

Commit

Permalink
Fix rendering parameters in PapermillOperator (#28979)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed Jan 22, 2023
1 parent a1ffb26 commit 736f2e8
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 47 deletions.
70 changes: 43 additions & 27 deletions airflow/providers/papermill/operators/papermill.py
Expand Up @@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, ClassVar, Collection, Optional, Sequence

import attr
import papermill as pm
Expand All @@ -33,8 +33,17 @@
class NoteBook(File):
"""Jupyter notebook"""

type_hint: str | None = "jupyter_notebook"
parameters: dict | None = {}
# For compatibility with Airflow 2.3:
# 1. Use predefined set because `File.template_fields` introduced in Airflow 2.4
# 2. Use old styled annotations because `cattrs` doesn't work well with PEP 604.

template_fields: ClassVar[Collection[str]] = {
"parameters",
*(File.template_fields if hasattr(File, "template_fields") else {"url"}),
}

type_hint: Optional[str] = "jupyter_notebook" # noqa: UP007
parameters: Optional[dict] = {} # noqa: UP007

meta_schema: str = __name__ + ".NoteBook"

Expand All @@ -43,8 +52,8 @@ class PapermillOperator(BaseOperator):
"""
Executes a jupyter notebook through papermill that is annotated with parameters
:param input_nb: input notebook (can also be a NoteBook or a File inlet)
:param output_nb: output notebook (can also be a NoteBook or File outlet)
:param input_nb: input notebook, either path or NoteBook inlet.
:param output_nb: output notebook, either path or NoteBook outlet.
:param parameters: the notebook parameters to set
:param kernel_name: (optional) name of kernel to execute the notebook against
(ignores kernel name in the notebook document metadata)
Expand All @@ -57,36 +66,43 @@ class PapermillOperator(BaseOperator):
def __init__(
self,
*,
input_nb: str | None = None,
output_nb: str | None = None,
input_nb: str | NoteBook | None = None,
output_nb: str | NoteBook | None = None,
parameters: dict | None = None,
kernel_name: str | None = None,
language_name: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)

self.input_nb = input_nb
self.output_nb = output_nb
self.parameters = parameters

if not input_nb:
raise ValueError("Input notebook is not specified")
elif not isinstance(input_nb, NoteBook):
self.input_nb = NoteBook(url=input_nb, parameters=self.parameters)
else:
self.input_nb = input_nb

if not output_nb:
raise ValueError("Output notebook is not specified")
elif not isinstance(output_nb, NoteBook):
self.output_nb = NoteBook(url=output_nb)
else:
self.output_nb = output_nb

self.kernel_name = kernel_name
self.language_name = language_name
if input_nb:
self.inlets.append(NoteBook(url=input_nb, parameters=self.parameters))
if output_nb:
self.outlets.append(NoteBook(url=output_nb))

self.inlets.append(self.input_nb)
self.outlets.append(self.output_nb)

def execute(self, context: Context):
if not self.inlets or not self.outlets:
raise ValueError("Input notebook or output notebook is not specified")

for i, item in enumerate(self.inlets):
pm.execute_notebook(
item.url,
self.outlets[i].url,
parameters=item.parameters,
progress_bar=False,
report_mode=True,
kernel_name=self.kernel_name,
language=self.language_name,
)
pm.execute_notebook(
self.input_nb.url,
self.output_nb.url,
parameters=self.input_nb.parameters,
progress_bar=False,
report_mode=True,
kernel_name=self.kernel_name,
language=self.language_name,
)
96 changes: 76 additions & 20 deletions tests/providers/papermill/operators/test_papermill.py
Expand Up @@ -19,14 +19,55 @@

from unittest.mock import patch

from airflow.models import DAG, DagRun, TaskInstance
from airflow.providers.papermill.operators.papermill import PapermillOperator
import pytest

from airflow.providers.papermill.operators.papermill import NoteBook, PapermillOperator
from airflow.utils import timezone

DEFAULT_DATE = timezone.datetime(2021, 1, 1)
TEST_INPUT_URL = "/foo/bar"
TEST_OUTPUT_URL = "/spam/egg"


class TestNoteBook:
"""Test NoteBook object."""

def test_templated_fields(self):
assert hasattr(NoteBook, "template_fields")
assert "parameters" in NoteBook.template_fields


class TestPapermillOperator:
"""Test PapermillOperator."""

def test_mandatory_attributes(self):
"""Test missing Input or Output notebooks."""
with pytest.raises(ValueError, match="Input notebook is not specified"):
PapermillOperator(task_id="missing_input_nb", output_nb="foo-bar")

with pytest.raises(ValueError, match="Output notebook is not specified"):
PapermillOperator(task_id="missing_input_nb", input_nb="foo-bar")

@pytest.mark.parametrize(
"output_nb",
[
pytest.param(TEST_OUTPUT_URL, id="output-as-string"),
pytest.param(NoteBook(TEST_OUTPUT_URL), id="output-as-notebook-object"),
],
)
@pytest.mark.parametrize(
"input_nb",
[
pytest.param(TEST_INPUT_URL, id="input-as-string"),
pytest.param(NoteBook(TEST_INPUT_URL), id="input-as-notebook-object"),
],
)
def test_notebooks_objects(self, input_nb, output_nb):
"""Test different type of Input/Output notebooks arguments."""
op = PapermillOperator(task_id="test_notebooks_objects", input_nb=input_nb, output_nb=output_nb)
assert op.input_nb.url == TEST_INPUT_URL
assert op.output_nb.url == TEST_OUTPUT_URL

@patch("airflow.providers.papermill.operators.papermill.pm")
def test_execute(self, mock_papermill):
in_nb = "/tmp/does_not_exist"
Expand Down Expand Up @@ -57,26 +98,41 @@ def test_execute(self, mock_papermill):
report_mode=True,
)

def test_render_template(self):
args = {"owner": "airflow", "start_date": DEFAULT_DATE}
dag = DAG("test_render_template", default_args=args)

operator = PapermillOperator(
task_id="render_dag_test",
def test_render_template(self, create_task_instance_of_operator):
"""Test rendering fields."""
ti = create_task_instance_of_operator(
PapermillOperator,
input_nb="/tmp/{{ dag.dag_id }}.ipynb",
output_nb="/tmp/out-{{ dag.dag_id }}.ipynb",
parameters={"msgs": "dag id is {{ dag.dag_id }}!"},
kernel_name="python3",
language_name="python",
dag=dag,
parameters={"msgs": "dag id is {{ dag.dag_id }}!", "test_dt": "{{ ds }}"},
kernel_name="{{ params.kernel_name }}",
language_name="{{ params.language_name }}",
# Additional parameters for render fields
params={
"kernel_name": "python3",
"language_name": "python",
},
# TI Settings
dag_id="test_render_template",
task_id="render_dag_test",
execution_date=DEFAULT_DATE,
)
task = ti.render_templates()

# Test render Input/Output notebook attributes
assert task.input_nb.url == "/tmp/test_render_template.ipynb"
assert task.input_nb.parameters == {
"msgs": "dag id is test_render_template!",
"test_dt": DEFAULT_DATE.date().isoformat(),
}
assert task.output_nb.url == "/tmp/out-test_render_template.ipynb"
assert task.output_nb.parameters == {}

ti = TaskInstance(operator, run_id="papermill_test")
ti.dag_run = DagRun(execution_date=DEFAULT_DATE)
ti.render_templates()
# Test render other templated attributes
assert task.parameters == task.input_nb.parameters
assert "python3" == task.kernel_name
assert "python" == task.language_name

assert "/tmp/test_render_template.ipynb" == operator.input_nb
assert "/tmp/out-test_render_template.ipynb" == operator.output_nb
assert {"msgs": "dag id is test_render_template!"} == operator.parameters
assert "python3" == operator.kernel_name
assert "python" == operator.language_name
# Test render Lineage inlets/outlets
assert task.inlets[0] == task.input_nb
assert task.outlets[0] == task.output_nb

0 comments on commit 736f2e8

Please sign in to comment.