Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement --export-slice-airflow #320

Merged
merged 16 commits into from
Oct 25, 2021
13 changes: 13 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@
"tests/housing.py",
],
},
{
"name": "lineapy --airflow",
"type": "python",
"request": "launch",
"module": "lineapy.cli.cli",
"args": [
"--slice",
"p value",
"--airflow",
"sliced_housing",
"tests/housing.py",
],
},
{
"name": "Python: Current File",
"type": "python",
Expand Down
21 changes: 20 additions & 1 deletion lineapy/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@
default=None,
help="Requires --slice. Export the sliced code that {slice} depends on to {export_slice}.py",
)
@click.option(
"--export-slice-to-airflow-dag",
"--airflow",
default=None,
help="Requires --slice. Export the sliced code that {slice} depends on to an Airflow DAG {export_slice}.py",
)
@click.option(
"--print-source", help="Whether to print the source code", is_flag=True
)
Expand Down Expand Up @@ -65,6 +71,7 @@ def linea_cli(
mode,
slice,
export_slice,
export_slice_to_airflow_dag,
print_source,
print_graph,
verbose,
Expand All @@ -88,7 +95,8 @@ def linea_cli(

if visualize:
tracer.visualize()
if slice and not export_slice:

if slice and not export_slice and not export_slice_to_airflow_dag:
tree.add(
rich.console.Group(
f"Slice of {repr(slice)}",
Expand All @@ -103,6 +111,17 @@ def linea_cli(
full_code = tracer.sliced_func(slice, export_slice)
pathlib.Path(f"{export_slice}.py").write_text(full_code)

if export_slice_to_airflow_dag:
if not slice:
print(
"Please specify --slice. It is required for --export-slice-to-airflow-dag"
)
exit(1)
full_code = tracer.sliced_aiflow_dag(
slice, export_slice_to_airflow_dag
)
pathlib.Path(f"{export_slice}.py").write_text(full_code)

tracer.db.close()
if print_graph:
graph_code = prettify(
Expand Down
29 changes: 29 additions & 0 deletions lineapy/instrumentation/tracer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import logging
import pathlib
from collections import defaultdict
from dataclasses import InitVar, dataclass, field
from datetime import datetime
from functools import cached_property
from importlib import import_module
from os import getcwd
from typing import Dict, Optional

from airflow import DAG
from airflow.operators.python_operator import PythonOperator
from airflow.utils.dates import days_ago
from black import FileMode, format_str

from lineapy.constants import GET_ITEM, GETATTR
Expand Down Expand Up @@ -156,6 +161,30 @@ def sliced_func(self, slice_name: str, func_name: str) -> str:
full_code = format_str(full_code, mode=black_mode)
return full_code

def sliced_aiflow_dag(self, slice_name: str, func_name: str) -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about moving the Airflow-specific code to a different folder, like linea/airflow?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've created plugins/airflow.py. What do you think?

# Build an Airflow DAG whose single task runs the sliced function.
#
# Generate the source of the sliced function via `self.sliced_func`, write it
# to "<func_name>.py", then import that module by name to obtain the callable.
# NOTE(review): the path is relative, so the file presumably lands in the
# current working directory, and `import_module(func_name)` only succeeds if
# that directory is importable -- confirm against how the CLI is launched.
_sliced_func_code = self.sliced_func(slice_name, func_name)
pathlib.Path(f"{func_name}.py").write_text(_sliced_func_code)
_sliced_func = getattr(import_module(func_name), func_name)
# Arguments handed to the DAG as its `default_args`.
DEFAULT_ARGS = {
"start_date": days_ago(1),
"owner": "airflow",
"retries": 2,
}
# Keyword arguments for the DAG constructor; the schedule is the cron
# expression "0 7 * * *" (minute 0, hour 7, every day).
dagargs = {
"default_args": DEFAULT_ARGS,
"schedule_interval": "0 7 * * *",
"catchup": False,
"max_active_runs": 1,
"concurrency": 5,
}
# The DAG id and the task id are both `func_name`.
dag = DAG(func_name, **dagargs)
with dag as dag:
task_id = func_name
# NOTE(review): `task` is assigned but never read in this method.
task = PythonOperator(
task_id=task_id, python_callable=_sliced_func, dag=dag
)
# NOTE(review): this returns the DAG object, but the method's signature is
# annotated `-> str`, and the CLI caller shown in this change passes the
# result to `pathlib.Path(...).write_text(...)`. Confirm whether the source
# text of a DAG file was intended instead.
return dag

def session_artifacts(self) -> list[ArtifactORM]:
    """Return the artifacts the database reports for this tracer's session.

    The lookup is keyed by ``self.session_context.id`` and delegated to
    ``self.db.get_artifacts_for_session``.
    """
    session_id = self.session_context.id
    return self.db.get_artifacts_for_session(session_id)

Expand Down