Skip to content

Commit

Permalink
Add support to specify kernel name in PapermillOperator (#20035)
Browse files (browse the repository at this point in the history)
  • Loading branch information
minhthong582000 committed Dec 4, 2021
1 parent 7a85224 commit d3f4456
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
8 changes: 7 additions & 1 deletion airflow/providers/papermill/operators/papermill.py
Expand Up @@ -44,25 +44,30 @@ class PapermillOperator(BaseOperator):
:type output_nb: str
:param parameters: the notebook parameters to set
:type parameters: dict
:param kernel_name: (optional) name of the kernel to execute the notebook against
(ignores the kernel name in the notebook document metadata)
:type kernel_name: str
"""

supports_lineage = True

template_fields = ('input_nb', 'output_nb', 'parameters')
template_fields = ('input_nb', 'output_nb', 'parameters', 'kernel_name')

def __init__(
self,
*,
input_nb: Optional[str] = None,
output_nb: Optional[str] = None,
parameters: Optional[Dict] = None,
kernel_name: Optional[str] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)

self.input_nb = input_nb
self.output_nb = output_nb
self.parameters = parameters
self.kernel_name = kernel_name
if input_nb:
self.inlets.append(NoteBook(url=input_nb, parameters=self.parameters))
if output_nb:
Expand All @@ -79,4 +84,5 @@ def execute(self, context):
parameters=item.parameters,
progress_bar=False,
report_mode=True,
kernel_name=self.kernel_name,
)
13 changes: 11 additions & 2 deletions tests/providers/papermill/operators/test_papermill.py
Expand Up @@ -30,21 +30,28 @@ class TestPapermillOperator(unittest.TestCase):
def test_execute(self, mock_papermill):
in_nb = "/tmp/does_not_exist"
out_nb = "/tmp/will_not_exist"
kernel_name = "python3"
parameters = {"msg": "hello_world", "train": 1}

op = PapermillOperator(
input_nb=in_nb,
output_nb=out_nb,
parameters=parameters,
task_id="papermill_operator_test",
kernel_name=kernel_name,
dag=None,
)

op.pre_execute(context={}) # make sure to have the inlets
op.pre_execute(context={}) # Make sure to have the inlets
op.execute(context={})

mock_papermill.execute_notebook.assert_called_once_with(
in_nb, out_nb, parameters=parameters, progress_bar=False, report_mode=True
in_nb,
out_nb,
parameters=parameters,
kernel_name=kernel_name,
progress_bar=False,
report_mode=True,
)

def test_render_template(self):
Expand All @@ -56,6 +63,7 @@ def test_render_template(self):
input_nb="/tmp/{{ dag.dag_id }}.ipynb",
output_nb="/tmp/out-{{ dag.dag_id }}.ipynb",
parameters={"msgs": "dag id is {{ dag.dag_id }}!"},
kernel_name="python3",
dag=dag,
)

Expand All @@ -66,3 +74,4 @@ def test_render_template(self):
assert "/tmp/test_render_template.ipynb" == getattr(operator, 'input_nb')
assert '/tmp/out-test_render_template.ipynb' == getattr(operator, 'output_nb')
assert {"msgs": "dag id is test_render_template!"} == getattr(operator, 'parameters')
assert "python3" == getattr(operator, 'kernel_name')

0 comments on commit d3f4456

Please sign in to comment.