In [1]:
#| default_exp airflow.utils

In [2]:
#| export

import subprocess  # nosec B404
import shlex
from typing import *
import pandas as pd
import os
import json

from datetime import datetime, timedelta
from pathlib import Path
from contextlib import contextmanager
import tempfile
from time import sleep

from airt_service.sanitizer import sanitized_print

In [3]:
from airt_service.db.models import (
    DataBlob,
    User,
    create_user_for_testing,
    get_session,
    get_session_with_context,
)
from sqlmodel import select
from airt_service.batch_job import get_environment_vars_for_batch_job
from airt_service.data.utils import create_db_uri_for_s3_datablob
from airt_service.helpers import commit_or_rollback

In [4]:
test_username = create_user_for_testing(subscription_type="small")
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    uri = "s3://test-airt-service/account_312571_events"
    datablob = DataBlob(
        type="s3",
        source=uri,
        uri=create_db_uri_for_s3_datablob(
            uri=uri,
            access_key=os.environ["AWS_ACCESS_KEY_ID"],
            secret_key=os.environ["AWS_SECRET_ACCESS_KEY"],
        ),
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
    )
    with commit_or_rollback(session):
        session.add(datablob)
    display(datablob)
    datablob_id = datablob.id

DataBlob(id=47, uuid=UUID('de400d5c-54b4-46a6-9ea6-3b52200043d9'), type='s3', uri='s3://****************************************@test-airt-service/account_312571_events', source='s3://test-airt-service/account_312571_events', total_steps=1, completed_steps=0, folder_size=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 12, 9, 10, 8, 24), user_id=133, pulled_on=None, tags=[])

In [5]:
test_dag_name = f"test-{datetime.now().isoformat()}"
test_dag_name

'test-2022-12-09T10:08:23.710441'

In [6]:
bash_dag = """from datetime import datetime, timedelta
from textwrap import dedent

# The DAG object; we'll need this to instantiate a DAG
from airflow import DAG

# Operators; we need this to operate!
from airflow.operators.bash import BashOperator
with DAG(
    '{dag_name}',
    # These args will get passed on to each operator
    # You can override them on a per-task basis during operator initialization
    default_args={{
        'depends_on_past': False,
        'email': ['info@airt.ai'],
        'email_on_failure': False,
        'email_on_retry': False,
        'retries': 1,
        'retry_delay': timedelta(minutes=5),
        # 'queue': 'bash_queue',
        # 'pool': 'backfill',
        # 'priority_weight': 10,
        # 'end_date': datetime(2016, 1, 1),
        # 'wait_for_downstream': False,
        # 'sla': timedelta(hours=2),
        # 'execution_timeout': timedelta(seconds=300),
        # 'on_failure_callback': some_function,
        # 'on_success_callback': some_other_function,
        # 'on_retry_callback': another_function,
        # 'sla_miss_callback': yet_another_function,
        # 'trigger_rule': 'all_success'
    }},
    description='From S3',
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=['s3'],
    is_paused_upon_creation=True,
) as dag:

    # t1, t2 and t3 are examples of tasks created by instantiating operators
    t1 = BashOperator(
        task_id='local_s3_pull',
        depends_on_past=False,
        bash_command='s3_pull {{{{ dag_run.conf["datablob_id"] if dag_run else "" }}}}',
    )
"""

In [7]:
sanitized_print(bash_dag.format(dag_name="somethinghardcodedstringliterally"))

from datetime import datetime, timedelta
from textwrap import dedent

# The DAG object; we'll need this to instantiate a DAG
from airflow import DAG

# Operators; we need this to operate!
from airflow.operators.bash import BashOperator
with DAG(
    'somethinghardcodedstringliterally',
    # These args will get passed on to each operator
    # You can override them on a per-task basis during operator initialization
    default_args={
        'depends_on_past': False,
        'email': ['info@airt.ai'],
        'email_on_failure': False,
        'email_on_retry': False,
        'retries': 1,
        'retry_delay': timedelta(minutes=5),
        # 'queue': 'bash_queue',
        # 'pool': 'backfill',
        # 'priority_weight': 10,
        # 'end_date': datetime(2016, 1, 1),
        # 'wait_for_downstream': False,
        # 'sla': timedelta(hours=2),
        # 'execution_timeout': timedelta(seconds=300),
        # 'on_failure_callback': some_function,
        # 'on_success_callback': some_

In [8]:
#| export


def list_dags(
    *,
    airflow_command: str = f"{os.environ['HOME']}/airflow_venv/bin/airflow",
):
    command = f"{airflow_command} dags list -o json"
    # nosemgrep: python.lang.security.audit.dangerous-subprocess-use.dangerous-subprocess-use
    p = subprocess.run(  # nosec B603
        shlex.split(command), shell=False, capture_output=True, text=True, check=True
    )
    try:
        return json.loads(p.stdout)
    except Exception as e:
        sanitized_print(f"{p.stdout=}")
        raise e

In [9]:
airflow_command = f"{os.environ['HOME']}/airflow_venv/bin/airflow"

In [10]:
!mkdir -p {os.environ['HOME']}/airflow/dags/
!cp {os.environ['HOME']}/airflow_venv/lib/python3.10/site-packages/airflow/example_dags/tutorial.py {os.environ['HOME']}/airflow/dags
!ll {os.environ['HOME']}/airflow/dags
!sleep 10
!{airflow_command} dags unpause tutorial

/bin/bash: line 1: ll: command not found
  option = self._get_environment_variables(deprecated_key, deprecated_section, key, section)
DAG: tutorial does not exist in 'dag' table


In [11]:
df = pd.DataFrame.from_dict(list_dags())
df

Unnamed: 0,dag_id,filepath,owner,paused
0,s3_pull-27,s3_pull-27.py,airflow,False
1,s3_pull-8,s3_pull-8.py,airflow,False
2,tutorial,tutorial.py,airflow,True


In [12]:
# | export


def list_dag_runs(
    dag_id: str,
    *,
    airflow_command: str = f"{os.environ['HOME']}/airflow_venv/bin/airflow",
):
    command = f"{airflow_command} dags list-runs -d {dag_id} -o json"

    # nosemgrep: python.lang.security.audit.dangerous-subprocess-use.dangerous-subprocess-use
    p = subprocess.run(  # nosec B603
        shlex.split(command),
        shell=False,
        capture_output=True,
        text=True,
        check=True,
    )

    return json.loads(p.stdout)

In [13]:
pd.DataFrame.from_dict(list_dag_runs("tutorial"))

In [14]:
#| export


def create_dag(
    dag_id: str,
    dag_definition_template: str,
    *,
    root_path: Path = Path(f"{os.environ['HOME']}/airflow/dags/"),
    **kwargs,
):
    root_path.mkdir(exist_ok=True, parents=True)
    tmp_file_path = root_path / f'{dag_id.replace(":", "_")}.py'
    with open(tmp_file_path, "w") as temp_file:
        temp_file.write(dag_definition_template.format(dag_name=dag_id, **kwargs))

    while True:
        df = pd.DataFrame.from_dict(list_dags())
        if (dag_id == df["dag_id"]).sum():
            break
        sanitized_print(".", end="")
        sleep(1)
    return tmp_file_path


@contextmanager
def create_testing_dag_ctx(
    dag_definition_template: str,
    *,
    root_path: Path = Path(f"{os.environ['HOME']}/airflow/dags/"),
    **kwargs,
):
    tmp_file_path = None
    try:
        dag_id = f"test-{datetime.now().isoformat()}".replace(":", "_")

        tmp_file_path = create_dag(
            dag_id=dag_id,
            dag_definition_template=dag_definition_template,
            root_path=root_path,
            **kwargs,
        )
        yield dag_id
    finally:
        if tmp_file_path and tmp_file_path.exists():
            tmp_file_path.unlink()

In [15]:
with create_testing_dag_ctx(bash_dag) as dag_id:
    s = !{airflow_command} dags list
    display(s)
    display(f"{dag_id=}")
    assert dag_id in "\n".join(s)
s = !{airflow_command} dags list
assert dag_id not in "\n".join(s), dag_id

 '  option = self._get_environment_variables(deprecated_key, deprecated_section, key, section)',
 'dag_id                          | filepath                           | owner   | paused',
 's3_pull-27                      | s3_pull-27.py                      | airflow | False ',
 's3_pull-8                       | s3_pull-8.py                       | airflow | False ',
 'test-2022-12-09T10_08_37.558736 | test-2022-12-09T10_08_37.558736.py | airflow | None  ',
 'tutorial                        | tutorial.py                        | airflow | True  ',
 '                                                                                       ']

"dag_id='test-2022-12-09T10_08_37.558736'"

In [16]:
#| export


def run_subprocess_with_retry(
    command: str, *, no_retries: int = 12, sleep_for: int = 5
):
    for i in range(no_retries):
        # nosemgrep: python.lang.security.audit.dangerous-subprocess-use.dangerous-subprocess-use
        p = subprocess.run(  # nosec B603
            shlex.split(command),
            shell=False,
            capture_output=True,
            text=True,
            check=False,
        )
        if p.returncode == 0:
            return p
        sleep(sleep_for)
    raise TimeoutError(p)

In [17]:
#| export


def unpause_dag(
    dag_id: str,
    *,
    airflow_command: str = f"{os.environ['HOME']}/airflow_venv/bin/airflow",
    no_retries: int = 12,
):
    unpause_command = f"{airflow_command} dags unpause {dag_id}"
    p = run_subprocess_with_retry(unpause_command, no_retries=no_retries)

In [18]:
with create_testing_dag_ctx(bash_dag) as dag_id:
    display(dag_id)
    unpause_dag(dag_id)
    s = !{airflow_command} dags list
    display(s)
    assert dag_id in "\n".join(s), dag_id

'test-2022-12-09T10_08_40.893952'

 '  option = self._get_environment_variables(deprecated_key, deprecated_section, key, section)',
 'dag_id                          | filepath                           | owner   | paused',
 's3_pull-27                      | s3_pull-27.py                      | airflow | False ',
 's3_pull-8                       | s3_pull-8.py                       | airflow | False ',
 'test-2022-12-09T10_08_40.893952 | test-2022-12-09T10_08_40.893952.py | airflow | False ',
 'tutorial                        | tutorial.py                        | airflow | True  ',
 '                                                                                       ']

In [19]:
#| export


def trigger_dag(
    dag_id: str,
    conf: Dict[str, Any],
    *,
    airflow_command: str = f"{os.environ['HOME']}/airflow_venv/bin/airflow",
    no_retries: int = 12,
    unpause_if_needed: bool = True,
):
    if unpause_if_needed:
        unpause_dag(
            dag_id=dag_id, airflow_command=airflow_command, no_retries=no_retries
        )

    run_id = f"airt-service__{datetime.now().isoformat()}"
    command = f"{airflow_command} dags trigger {dag_id} --conf {shlex.quote(json.dumps(conf))} --run-id {run_id}"
    p = run_subprocess_with_retry(command, no_retries=no_retries)
    sanitized_print(p)

    runs = list_dag_runs(dag_id=dag_id)
    sanitized_print(runs)

    return run_id

In [20]:
with create_testing_dag_ctx(bash_dag) as dag_id:
    display(dag_id)
    run_id = trigger_dag(dag_id, conf={"datablob_id": datablob.id})

run_id

'test-2022-12-09T10_08_56.163621'

[{'dag_id': 'test-2022-12-09T10_08_56.163621', 'run_id': 'airt-service__2022-12-09T10:09:10.877122', 'state': 'running', 'execution_date': '2022-12-09T10:09:12+00:00', 'start_date': '2022-12-09T10:09:12.286289+00:00', 'end_date': ''}, {'dag_id': 'test-2022-12-09T10_08_56.163621', 'run_id': 'scheduled__2022-12-08T10:09:06.807903+00:00', 'state': 'running', 'execution_date': '2022-12-08T10:09:06.807903+00:00', 'start_date': '2022-12-09T10:09:10.920257+00:00', 'end_date': ''}]


'airt-service__2022-12-09T10:09:10.877122'

In [21]:
#| export


def wait_for_run_to_complete(dag_id: str, run_id: str, timeout: int = 60) -> str:
    t0 = datetime.now()
    while (datetime.now() - t0) < timedelta(seconds=timeout):
        runs = pd.DataFrame(list_dag_runs(dag_id=dag_id))
        state = runs.loc[runs["run_id"] == run_id, "state"].iloc[0]
        if state in ["success", "failed"]:
            return state
        sleep(5)
    raise TimeoutError()

In [22]:
with create_testing_dag_ctx(bash_dag) as dag_id:
    display(dag_id)
    run_id = trigger_dag(dag_id, conf={"datablob_id": datablob_id})
    display(run_id)
    state = wait_for_run_to_complete(dag_id, run_id, timeout=600)
state

'test-2022-12-09T10_09_13.596500'

[{'dag_id': 'test-2022-12-09T10_09_13.596500', 'run_id': 'airt-service__2022-12-09T10:09:27.899344', 'state': 'running', 'execution_date': '2022-12-09T10:09:28+00:00', 'start_date': '2022-12-09T10:09:29.491878+00:00', 'end_date': ''}, {'dag_id': 'test-2022-12-09T10_09_13.596500', 'run_id': 'scheduled__2022-12-08T10:09:22.295839+00:00', 'state': 'running', 'execution_date': '2022-12-08T10:09:22.295839+00:00', 'start_date': '2022-12-09T10:09:28.433180+00:00', 'end_date': ''}]


'airt-service__2022-12-09T10:09:27.899344'

'success'

In [23]:
dag_id = "tutorial"
run_id = trigger_dag(dag_id, conf={"datablob_id": datablob_id})
display(run_id)
state = wait_for_run_to_complete(dag_id, run_id, timeout=600)
state

[{'dag_id': 'tutorial', 'run_id': 'airt-service__2022-12-09T10:09:50.884073', 'state': 'running', 'execution_date': '2022-12-09T10:09:51+00:00', 'start_date': '2022-12-09T10:09:52.589691+00:00', 'end_date': ''}, {'dag_id': 'tutorial', 'run_id': 'scheduled__2022-12-08T10:09:37.362547+00:00', 'state': 'running', 'execution_date': '2022-12-08T10:09:37.362547+00:00', 'start_date': '2022-12-09T10:09:51.694179+00:00', 'end_date': ''}]


'airt-service__2022-12-09T10:09:50.884073'

'success'

In [24]:
batch_env_vars = get_environment_vars_for_batch_job()

In [25]:
# | eval: false

batch_dag = """from datetime import datetime, timedelta
import json
from textwrap import dedent

# The DAG object; we'll need this to instantiate a DAG
from airflow import DAG

# Operators; we need this to operate!
from airflow.providers.amazon.aws.operators.batch import BatchOperator
with DAG(
    '{dag_name}',
    # These args will get passed on to each operator
    # You can override them on a per-task basis during operator initialization
    default_args={{
        'depends_on_past': False,
        'email': ['info@airt.ai'],
        'email_on_failure': False,
        'email_on_retry': False,
        'retries': 1,
        'retry_delay': timedelta(minutes=5),
        # 'queue': 'bash_queue',
        # 'pool': 'backfill',
        # 'priority_weight': 10,
        # 'end_date': datetime(2016, 1, 1),
        # 'wait_for_downstream': False,
        # 'sla': timedelta(hours=2),
        # 'execution_timeout': timedelta(seconds=300),
        # 'on_failure_callback': some_function,
        # 'on_success_callback': some_other_function,
        # 'on_retry_callback': another_function,
        # 'sla_miss_callback': yet_another_function,
        # 'trigger_rule': 'all_success'
    }},
    description='From S3',
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=['s3'],
    #is_paused_upon_creation=True,
) as dag:

    # t1, t2 and t3 are examples of tasks created by instantiating operators
    env_var_str = '{{{{ dag_run.conf["environment"] }}}}'
    import logging
    
    log: logging.log = logging.getLogger("airflow")
    log.setLevel(logging.INFO)
    log.info("this is me logging some random stuff and see whether it fails or not")
    log.info(env_var_str)
    
    t1 = BatchOperator(
        task_id='batch_s3_pull',
        depends_on_past=False,
        job_definition="staging_csv_processing_job_definition",
        job_queue="staging_csv_processing_job_queue",
        job_name="test_airflow",
        overrides={{
            "command":['s3_pull', '{{{{ dag_run.conf["datablob_id"] if dag_run else "" }}}}'],
            "environment": {env_str}
        }}
    )
"""

In [26]:
# "environment": """+json.dumps([dict(name=name, value=value) for name, value in batch_env_vars.items()]).replace("{", "{{").replace("}", "}}")+"""

In [None]:
# | eval: false

batch_env_var_names = list(batch_env_vars.keys())
batch_env_var_names
env_str = [
    {"name": key, "value": f"{{{{ dag_run.conf['{key}'] }}}}"}
    for key in batch_env_var_names
]

with create_testing_dag_ctx(batch_dag, env_str=env_str) as dag_id:
    display(dag_id)
    #     sleep(1)
    conf = batch_env_vars.copy()
    conf["datablob_id"] = 128
    run_id = trigger_dag(
        dag_id,
        conf=conf,
    )
    display(run_id)
    state = wait_for_run_to_complete(dag_id, run_id, timeout=600)