In [None]:
#| default_exp airflow.utils

In [None]:
#| export

import subprocess  # nosec B404
import shlex
from typing import *
import pandas as pd
import os
import json

from datetime import datetime, timedelta
from pathlib import Path
from contextlib import contextmanager
import tempfile
from time import sleep

from airt_service.sanitizer import sanitized_print

In [None]:
from airt_service.db.models import (
    DataBlob,
    User,
    create_user_for_testing,
    get_session,
    get_session_with_context,
)
from sqlmodel import select
from airt_service.batch_job import get_environment_vars_for_batch_job
from airt_service.data.utils import create_db_uri_for_s3_datablob
from airt_service.helpers import commit_or_rollback

In [None]:
# Test fixture: create a throwaway user and persist an S3 datablob owned by it.
test_username = create_user_for_testing(subscription_type="small")
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    uri = "s3://test-airt-service/account_312571_events"
    datablob = DataBlob(
        type="s3",
        source=uri,
        # The stored uri is built from the plain source uri plus the AWS
        # credentials read from the environment.
        uri=create_db_uri_for_s3_datablob(
            uri=uri,
            access_key=os.environ["AWS_ACCESS_KEY_ID"],
            secret_key=os.environ["AWS_SECRET_ACCESS_KEY"],
        ),
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
    )
    with commit_or_rollback(session):
        session.add(datablob)
    display(datablob)
    # Keep the primary key as a plain value so later cells can use it
    # after this session has been closed.
    datablob_id = datablob.id

DataBlob(id=47, uuid=UUID('e1277f72-133f-4770-a4af-4f657cda7fa0'), type='s3', uri='s3://****************************************@test-airt-service/account_312571_events', source='s3://test-airt-service/account_312571_events', total_steps=1, completed_steps=0, folder_size=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 12, 2, 8, 8, 31), user_id=133, pulled_on=None, tags=[])

In [None]:
# Timestamp-based DAG name; the trailing expression displays it.
test_dag_name = f"test-{datetime.now().isoformat()}"
test_dag_name

'test-2022-12-02T08:08:30.682171'

In [None]:
# Source template of an Airflow DAG file with a single BashOperator task.
# It is rendered with `str.format`, so:
#   * `{dag_name}` is the only placeholder,
#   * doubled braces (`{{` / `}}`) come out as literal `{` / `}`,
#   * quadrupled braces come out as `{{ ... }}` in the rendered file.
bash_dag = """from datetime import datetime, timedelta
from textwrap import dedent

# The DAG object; we'll need this to instantiate a DAG
from airflow import DAG

# Operators; we need this to operate!
from airflow.operators.bash import BashOperator
with DAG(
    '{dag_name}',
    # These args will get passed on to each operator
    # You can override them on a per-task basis during operator initialization
    default_args={{
        'depends_on_past': False,
        'email': ['info@airt.ai'],
        'email_on_failure': False,
        'email_on_retry': False,
        'retries': 1,
        'retry_delay': timedelta(minutes=5),
        # 'queue': 'bash_queue',
        # 'pool': 'backfill',
        # 'priority_weight': 10,
        # 'end_date': datetime(2016, 1, 1),
        # 'wait_for_downstream': False,
        # 'sla': timedelta(hours=2),
        # 'execution_timeout': timedelta(seconds=300),
        # 'on_failure_callback': some_function,
        # 'on_success_callback': some_other_function,
        # 'on_retry_callback': another_function,
        # 'sla_miss_callback': yet_another_function,
        # 'trigger_rule': 'all_success'
    }},
    description='From S3',
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=['s3'],
    is_paused_upon_creation=True,
) as dag:

    # t1, t2 and t3 are examples of tasks created by instantiating operators
    t1 = BashOperator(
        task_id='local_s3_pull',
        depends_on_past=False,
        bash_command='s3_pull {{{{ dag_run.conf["datablob_id"] if dag_run else "" }}}}',
    )
"""

In [None]:
# Render the template with a fixed DAG name to eyeball the generated source.
sanitized_print(bash_dag.format(dag_name="somethinghardcodedstringliterally"))

from datetime import datetime, timedelta
from textwrap import dedent

# The DAG object; we'll need this to instantiate a DAG
from airflow import DAG

# Operators; we need this to operate!
from airflow.operators.bash import BashOperator
with DAG(
    'somethinghardcodedstringliterally',
    # These args will get passed on to each operator
    # You can override them on a per-task basis during operator initialization
    default_args={
        'depends_on_past': False,
        'email': ['info@airt.ai'],
        'email_on_failure': False,
        'email_on_retry': False,
        'retries': 1,
        'retry_delay': timedelta(minutes=5),
        # 'queue': 'bash_queue',
        # 'pool': 'backfill',
        # 'priority_weight': 10,
        # 'end_date': datetime(2016, 1, 1),
        # 'wait_for_downstream': False,
        # 'sla': timedelta(hours=2),
        # 'execution_timeout': timedelta(seconds=300),
        # 'on_failure_callback': some_function,
        # 'on_success_callback': some_

In [None]:
#| export


def list_dags(
    *,
    airflow_command: str = f"{os.environ['HOME']}/airflow_venv/bin/airflow",
):
    """Return the parsed JSON output of ``<airflow_command> dags list -o json``.

    Args:
        airflow_command: Command line prefix used to invoke the Airflow CLI.
            It is tokenised with ``shlex.split`` together with the fixed
            sub-command, so it may carry leading arguments of its own.

    Returns:
        Whatever ``json.loads`` yields for the command's standard output.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status (``check=True``).
        Exception: Any error raised while decoding the output is re-raised
            after the raw output has been printed.
    """
    argv = shlex.split(f"{airflow_command} dags list -o json")
    # nosemgrep: python.lang.security.audit.dangerous-subprocess-use.dangerous-subprocess-use
    p = subprocess.run(  # nosec B603
        argv,
        shell=False,
        capture_output=True,
        text=True,
        check=True,
    )
    try:
        dags = json.loads(p.stdout)
    except Exception as e:
        # Show the raw output so a non-JSON response can be diagnosed.
        sanitized_print(f"{p.stdout=}")
        raise e
    return dags

In [None]:
# Path of the Airflow CLI used by the shell (`!`) cells below; same value as
# the default `airflow_command` of the exported helpers.
airflow_command = f"{os.environ['HOME']}/airflow_venv/bin/airflow"

In [None]:
# Copy Airflow's bundled `tutorial` example DAG into the DAG folder and unpause it.
# NOTE(review): the DAG folder (/root/airflow/dags) and the python3.8
# site-packages path are hard-coded — confirm they match the environment.
# NOTE(review): the `sleep 10` presumably gives Airflow time to pick up the new
# file before it is unpaused — confirm.
!mkdir -p /root/airflow/dags/
!cp {os.environ['HOME']}/airflow_venv/lib/python3.8/site-packages/airflow/example_dags/tutorial.py /root/airflow/dags
!ll /root/airflow/dags
!sleep 10
!{airflow_command} dags unpause tutorial

total 16K
drwxr-xr-x 2 kumaran kumaran 4.0K Dec  2 08:08 [0m[01;34m__pycache__[0m/
-rw-r--r-- 1 kumaran kumaran 1.9K Dec  2 07:59 s3_pull-10.py
-rw-r--r-- 1 kumaran kumaran 1.9K Dec  2 08:01 s3_pull-33.py
-rw-r--r-- 1 kumaran kumaran 3.8K Dec  2 08:08 tutorial.py
  option = self._get_environment_variables(deprecated_key, deprecated_section, key, section)
Dag: tutorial, paused: False


In [None]:
# Show the DAG listing as a table (one row per DAG).
df = pd.DataFrame.from_dict(list_dags())
df

Unnamed: 0,dag_id,filepath,owner,paused
0,s3_pull-10,s3_pull-10.py,airflow,False
1,s3_pull-33,s3_pull-33.py,airflow,False
2,tutorial,tutorial.py,airflow,False


In [None]:
# | export


def list_dag_runs(
    dag_id: str,
    *,
    airflow_command: str = f"{os.environ['HOME']}/airflow_venv/bin/airflow",
):
    """Return the parsed JSON output of ``<airflow_command> dags list-runs -d <dag_id> -o json``.

    Args:
        dag_id: Id of the DAG whose runs are listed. It is handed to the CLI
            as exactly one argument, whatever characters it contains.
        airflow_command: Command line prefix used to invoke the Airflow CLI;
            tokenised with ``shlex.split``.

    Returns:
        Whatever ``json.loads`` yields for the command's standard output.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status (``check=True``).
        Exception: Any error raised while decoding the output is re-raised
            after the raw output has been printed.
    """
    # Build argv directly instead of formatting dag_id into a string that is
    # split afterwards: an id containing whitespace or quotes would otherwise
    # be broken up into several CLI arguments.
    argv = [
        *shlex.split(airflow_command),
        "dags",
        "list-runs",
        "-d",
        dag_id,
        "-o",
        "json",
    ]

    # nosemgrep: python.lang.security.audit.dangerous-subprocess-use.dangerous-subprocess-use
    p = subprocess.run(  # nosec B603
        argv,
        shell=False,
        capture_output=True,
        text=True,
        check=True,
    )

    try:
        return json.loads(p.stdout)
    except Exception as e:
        # Same diagnostics as list_dags: show the raw output on a parse failure.
        sanitized_print(f"{p.stdout=}")
        raise e

In [None]:
# Show the runs of the `tutorial` DAG as a table.
pd.DataFrame.from_dict(list_dag_runs("tutorial"))

total 16K
drwxr-xr-x 2 kumaran kumaran 4.0K Dec  2 08:08 [0m[01;34m__pycache__[0m/
-rw-r--r-- 1 kumaran kumaran 1.9K Dec  2 07:59 s3_pull-10.py
-rw-r--r-- 1 kumaran kumaran 1.9K Dec  2 08:01 s3_pull-33.py
-rw-r--r-- 1 kumaran kumaran 3.8K Dec  2 08:08 tutorial.py
  option = self._get_environment_variables(deprecated_key, deprecated_section, key, section)
Dag: tutorial, paused: False


Unnamed: 0,dag_id,run_id,state,execution_date,start_date,end_date
0,tutorial,airt-service__2022-12-02T08:00:20.973564,success,2022-12-02T08:00:22+00:00,2022-12-02T08:00:22.449038+00:00,2022-12-02T08:00:29.434476+00:00
1,tutorial,scheduled__2022-12-01T07:58:33.458042+00:00,success,2022-12-01T07:58:33.458042+00:00,2022-12-02T07:58:40.464505+00:00,2022-12-02T07:58:47.903252+00:00


In [None]:
#| export


def create_dag(
    dag_id: str,
    dag_definition_template: str,
    *,
    root_path: Path = Path("/root/airflow/dags/"),
    **kwargs,
):
    """Write a DAG definition file and block until ``list_dags`` reports the DAG.

    Args:
        dag_id: Id of the DAG; passed to the template as ``dag_name``.
        dag_definition_template: ``str.format`` template of the DAG source.
        root_path: Directory the DAG file is written to; created if missing.
        **kwargs: Extra values forwarded to ``dag_definition_template.format``.

    Returns:
        Path of the written DAG file.

    Note:
        There is no timeout: the call keeps polling once per second for as
        long as ``list_dags`` does not report ``dag_id``.
    """
    root_path.mkdir(exist_ok=True, parents=True)
    # ":" is replaced in the file name only; the template still receives the
    # unmodified dag_id.
    tmp_file_path = root_path / f'{dag_id.replace(":", "_")}.py'
    with open(tmp_file_path, "w") as temp_file:
        temp_file.write(dag_definition_template.format(dag_name=dag_id, **kwargs))

    # Inspect the listed records directly: an empty listing has no "dag_id"
    # column once converted to a DataFrame, and indexing it raised KeyError.
    while not any(dag.get("dag_id") == dag_id for dag in list_dags()):
        sanitized_print(".", end="")
        sleep(1)
    return tmp_file_path


@contextmanager
def create_testing_dag_ctx(
    dag_definition_template: str,
    *,
    root_path: Path = Path("/root/airflow/dags/"),
    **kwargs,
) -> Iterator[str]:
    """Context manager that creates a temporary, uniquely named test DAG.

    The DAG id is ``test-<current ISO timestamp>`` with ``:`` replaced by
    ``_``. On exit the DAG definition file is deleted.

    Args:
        dag_definition_template: ``str.format`` template of the DAG source.
        root_path: Directory the DAG file is written to.
        **kwargs: Extra values forwarded to ``create_dag``.

    Yields:
        The id of the created DAG.
    """
    dag_id = f"test-{datetime.now().isoformat()}".replace(":", "_")
    # create_dag writes `<dag_id>.py` before it starts polling. Track that path
    # up front so the file is also removed when create_dag raises after the
    # write (previously the path was unknown in that case and the file leaked).
    tmp_file_path = root_path / f"{dag_id}.py"
    try:
        tmp_file_path = create_dag(
            dag_id=dag_id,
            dag_definition_template=dag_definition_template,
            root_path=root_path,
            **kwargs,
        )
        yield dag_id
    finally:
        if tmp_file_path.exists():
            tmp_file_path.unlink()

In [None]:
# Inside the context the temporary DAG must appear in `airflow dags list` ...
with create_testing_dag_ctx(bash_dag) as dag_id:
    s = !{airflow_command} dags list
    display(s)
    display(f"{dag_id=}")
    assert dag_id in "\n".join(s)
# ... and once the context has exited it must no longer be listed.
s = !{airflow_command} dags list
assert dag_id not in "\n".join(s), dag_id

 '  option = self._get_environment_variables(deprecated_key, deprecated_section, key, section)',
 'dag_id                          | filepath                           | owner   | paused',
 's3_pull-10                      | s3_pull-10.py                      | airflow | False ',
 's3_pull-33                      | s3_pull-33.py                      | airflow | False ',
 'test-2022-12-02T08_09_08.275205 | test-2022-12-02T08_09_08.275205.py | airflow | None  ',
 'tutorial                        | tutorial.py                        | airflow | False ',
 '                                                                                       ']

"dag_id='test-2022-12-02T08_09_08.275205'"

In [None]:
#| export


def run_subprocess_with_retry(
    command: str, *, no_retries: int = 12, sleep_for: int = 5
) -> subprocess.CompletedProcess:
    """Run ``command`` until it exits with status 0, retrying on failure.

    Args:
        command: Command line; tokenised with ``shlex.split`` and executed
            without a shell.
        no_retries: Maximum number of attempts; must be at least 1.
        sleep_for: Seconds to wait between two consecutive attempts.

    Returns:
        The ``subprocess.CompletedProcess`` of the first successful attempt.

    Raises:
        ValueError: If ``no_retries`` is smaller than 1.
        TimeoutError: If every attempt failed; the last ``CompletedProcess``
            is passed as the exception argument.
    """
    if no_retries < 1:
        # With zero attempts there is no result to return or to report.
        raise ValueError(f"no_retries must be at least 1, got {no_retries}")

    argv = shlex.split(command)
    for attempt in range(no_retries):
        # Wait between attempts only, never after the final failure.
        if attempt:
            sleep(sleep_for)
        # nosemgrep: python.lang.security.audit.dangerous-subprocess-use.dangerous-subprocess-use
        p = subprocess.run(  # nosec B603
            argv,
            shell=False,
            capture_output=True,
            text=True,
            check=False,
        )
        if p.returncode == 0:
            return p
    raise TimeoutError(p)

In [None]:
#| export


def unpause_dag(
    dag_id: str,
    *,
    airflow_command: str = f"{os.environ['HOME']}/airflow_venv/bin/airflow",
    no_retries: int = 12,
) -> None:
    """Run ``<airflow_command> dags unpause <dag_id>``, retrying on failure.

    Args:
        dag_id: Id of the DAG to unpause.
        airflow_command: Command line prefix used to invoke the Airflow CLI.
        no_retries: Maximum number of attempts, forwarded to
            ``run_subprocess_with_retry``.

    Raises:
        TimeoutError: Propagated from ``run_subprocess_with_retry`` when no
            attempt succeeded.
    """
    # Quote the id so it survives the shlex.split done by the retry helper as
    # a single argument.
    unpause_command = f"{airflow_command} dags unpause {shlex.quote(dag_id)}"
    run_subprocess_with_retry(unpause_command, no_retries=no_retries)

In [None]:
# Unpause a temporary DAG and check that it is still present in the listing.
with create_testing_dag_ctx(bash_dag) as dag_id:
    display(dag_id)
    unpause_dag(dag_id)
    s = !{airflow_command} dags list
    display(s)
    assert dag_id in "\n".join(s), dag_id

'test-2022-12-02T08_09_12.457102'

 '  option = self._get_environment_variables(deprecated_key, deprecated_section, key, section)',
 'dag_id                          | filepath                           | owner   | paused',
 's3_pull-10                      | s3_pull-10.py                      | airflow | False ',
 's3_pull-33                      | s3_pull-33.py                      | airflow | False ',
 'test-2022-12-02T08_09_12.457102 | test-2022-12-02T08_09_12.457102.py | airflow | False ',
 'tutorial                        | tutorial.py                        | airflow | False ',
 '                                                                                       ']

In [None]:
#| export


def trigger_dag(
    dag_id: str,
    conf: Dict[str, Any],
    *,
    airflow_command: str = f"{os.environ['HOME']}/airflow_venv/bin/airflow",
    no_retries: int = 12,
    unpause_if_needed: bool = True,
) -> str:
    """Trigger a run of ``dag_id`` through the Airflow CLI.

    Args:
        dag_id: Id of the DAG to trigger.
        conf: Run configuration; JSON-encoded and passed via ``--conf``.
        airflow_command: Command line prefix used to invoke the Airflow CLI.
        no_retries: Maximum number of attempts for each CLI call made through
            ``run_subprocess_with_retry``.
        unpause_if_needed: When true, ``unpause_dag`` is called first.

    Returns:
        The run id passed to the CLI, ``airt-service__<current ISO timestamp>``.

    Raises:
        TimeoutError: Propagated from ``run_subprocess_with_retry`` when a CLI
            call never succeeded.
    """
    if unpause_if_needed:
        unpause_dag(
            dag_id=dag_id, airflow_command=airflow_command, no_retries=no_retries
        )

    run_id = f"airt-service__{datetime.now().isoformat()}"
    # Every dynamic value is quoted because the command string is split with
    # shlex.split before execution.
    command = (
        f"{airflow_command} dags trigger {shlex.quote(dag_id)}"
        f" --conf {shlex.quote(json.dumps(conf))}"
        f" --run-id {shlex.quote(run_id)}"
    )
    p = run_subprocess_with_retry(command, no_retries=no_retries)
    sanitized_print(p)

    # Use the same CLI binary as the calls above; previously the default
    # binary was used here even when a custom airflow_command was given.
    runs = list_dag_runs(dag_id=dag_id, airflow_command=airflow_command)
    sanitized_print(runs)

    return run_id

In [None]:
# Trigger a run on a temporary DAG and show the generated run id.
with create_testing_dag_ctx(bash_dag) as dag_id:
    display(dag_id)
    # Use the plain id captured while the DB session was open (as the other
    # test cells do) instead of reading it from the ORM object afterwards.
    run_id = trigger_dag(dag_id, conf={"datablob_id": datablob_id})

run_id

'test-2022-12-02T08_09_29.759452'

[{'dag_id': 'test-2022-12-02T08_09_29.759452', 'run_id': 'airt-service__2022-12-02T08:09:45.392344', 'state': 'running', 'execution_date': '2022-12-02T08:09:46+00:00', 'start_date': '2022-12-02T08:09:47.472326+00:00', 'end_date': ''}, {'dag_id': 'test-2022-12-02T08_09_29.759452', 'run_id': 'scheduled__2022-12-01T08:09:41.940713+00:00', 'state': 'running', 'execution_date': '2022-12-01T08:09:41.940713+00:00', 'start_date': '2022-12-02T08:09:45.353470+00:00', 'end_date': ''}]


'airt-service__2022-12-02T08:09:45.392344'

In [None]:
#| export


def wait_for_run_to_complete(
    dag_id: str, run_id: str, timeout: int = 60, poll_interval: int = 5
) -> str:
    """Poll ``list_dag_runs`` until the given run has succeeded or failed.

    Args:
        dag_id: Id of the DAG the run belongs to.
        run_id: Id of the run to wait for.
        timeout: Maximum number of seconds during which a new poll is started.
        poll_interval: Seconds to sleep between two polls.

    Returns:
        The final state of the run, ``"success"`` or ``"failed"``.

    Raises:
        TimeoutError: If the run did not reach a final state in time.
    """
    deadline = datetime.now() + timedelta(seconds=timeout)
    while datetime.now() < deadline:
        # A run that is not listed (yet) yields None and is simply polled
        # again; indexing the first match used to raise in that case.
        state = next(
            (
                run.get("state")
                for run in list_dag_runs(dag_id=dag_id)
                if run.get("run_id") == run_id
            ),
            None,
        )
        if state in ("success", "failed"):
            return state
        sleep(poll_interval)
    raise TimeoutError(
        f"Run {run_id!r} of DAG {dag_id!r} did not complete within {timeout} seconds"
    )

In [None]:
# End-to-end check on a temporary DAG: trigger a run, then wait (up to 600 s)
# for it to reach a final state and display that state.
with create_testing_dag_ctx(bash_dag) as dag_id:
    display(dag_id)
    run_id = trigger_dag(dag_id, conf={"datablob_id": datablob_id})
    display(run_id)
    state = wait_for_run_to_complete(dag_id, run_id, timeout=600)
state

'test-2022-12-02T08_09_48.315850'

[{'dag_id': 'test-2022-12-02T08_09_48.315850', 'run_id': 'airt-service__2022-12-02T08:09:57.542561', 'state': 'running', 'execution_date': '2022-12-02T08:09:58+00:00', 'start_date': '2022-12-02T08:09:59.480074+00:00', 'end_date': ''}, {'dag_id': 'test-2022-12-02T08_09_48.315850', 'run_id': 'scheduled__2022-12-01T08:09:57.226166+00:00', 'state': 'running', 'execution_date': '2022-12-01T08:09:57.226166+00:00', 'start_date': '2022-12-02T08:09:58.403549+00:00', 'end_date': ''}]


'airt-service__2022-12-02T08:09:57.542561'

'success'

In [None]:
# Same trigger-and-wait flow, this time against the `tutorial` DAG by id
# instead of a temporary DAG file.
dag_id = "tutorial"
run_id = trigger_dag(dag_id, conf={"datablob_id": datablob_id})
display(run_id)
state = wait_for_run_to_complete(dag_id, run_id, timeout=600)
state

[{'dag_id': 'tutorial', 'run_id': 'airt-service__2022-12-02T08:10:22.592585', 'state': 'running', 'execution_date': '2022-12-02T08:10:23+00:00', 'start_date': '2022-12-02T08:10:24.452639+00:00', 'end_date': ''}, {'dag_id': 'tutorial', 'run_id': 'airt-service__2022-12-02T08:00:20.973564', 'state': 'success', 'execution_date': '2022-12-02T08:00:22+00:00', 'start_date': '2022-12-02T08:00:22.449038+00:00', 'end_date': '2022-12-02T08:00:29.434476+00:00'}, {'dag_id': 'tutorial', 'run_id': 'scheduled__2022-12-01T07:58:33.458042+00:00', 'state': 'success', 'execution_date': '2022-12-01T07:58:33.458042+00:00', 'start_date': '2022-12-02T07:58:40.464505+00:00', 'end_date': '2022-12-02T07:58:47.903252+00:00'}]


'airt-service__2022-12-02T08:10:22.592585'

'success'

In [None]:
# Environment variables for the batch job; used below both as template input
# (their names) and as the run configuration (their values).
batch_env_vars = get_environment_vars_for_batch_job()

In [None]:
# | eval: false

# Source template of an Airflow DAG file with a single BatchOperator task.
# It is rendered with `str.format`, so:
#   * `{dag_name}` and `{env_str}` are the placeholders,
#   * doubled braces (`{{` / `}}`) come out as literal `{` / `}`,
#   * quadrupled braces come out as `{{ ... }}` in the rendered file.
batch_dag = """from datetime import datetime, timedelta
import json
from textwrap import dedent

# The DAG object; we'll need this to instantiate a DAG
from airflow import DAG

# Operators; we need this to operate!
from airflow.providers.amazon.aws.operators.batch import BatchOperator
with DAG(
    '{dag_name}',
    # These args will get passed on to each operator
    # You can override them on a per-task basis during operator initialization
    default_args={{
        'depends_on_past': False,
        'email': ['info@airt.ai'],
        'email_on_failure': False,
        'email_on_retry': False,
        'retries': 1,
        'retry_delay': timedelta(minutes=5),
        # 'queue': 'bash_queue',
        # 'pool': 'backfill',
        # 'priority_weight': 10,
        # 'end_date': datetime(2016, 1, 1),
        # 'wait_for_downstream': False,
        # 'sla': timedelta(hours=2),
        # 'execution_timeout': timedelta(seconds=300),
        # 'on_failure_callback': some_function,
        # 'on_success_callback': some_other_function,
        # 'on_retry_callback': another_function,
        # 'sla_miss_callback': yet_another_function,
        # 'trigger_rule': 'all_success'
    }},
    description='From S3',
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=['s3'],
    #is_paused_upon_creation=True,
) as dag:

    # t1, t2 and t3 are examples of tasks created by instantiating operators
    env_var_str = '{{{{ dag_run.conf["environment"] }}}}'
    import logging
    
    log: logging.log = logging.getLogger("airflow")
    log.setLevel(logging.INFO)
    log.info("this is me logging some random stuff and see whether it fails or not")
    log.info(env_var_str)
    
    t1 = BatchOperator(
        task_id='batch_s3_pull',
        depends_on_past=False,
        job_definition="staging_csv_processing_job_definition",
        job_queue="staging_csv_processing_job_queue",
        job_name="test_airflow",
        overrides={{
            "command":['s3_pull', '{{{{ dag_run.conf["datablob_id"] if dag_run else "" }}}}'],
            "environment": {env_str}
        }}
    )
"""

In [None]:
# "environment": """+json.dumps([dict(name=name, value=value) for name, value in batch_env_vars.items()]).replace("{", "{{").replace("}", "}}")+"""

In [None]:
# | eval: false

# Build the "environment" override for the batch template: one name/value entry
# per batch env var, where each value is the text `{{ dag_run.conf['<name>'] }}`.
batch_env_var_names = list(batch_env_vars.keys())
batch_env_var_names
env_str = [
    {"name": key, "value": f"{{{{ dag_run.conf['{key}'] }}}}"}
    for key in batch_env_var_names
]

# Render the batch DAG, trigger it with the env var values plus a datablob id
# as run configuration, and wait (up to 600 s) for the run to finish.
with create_testing_dag_ctx(batch_dag, env_str=env_str) as dag_id:
    display(dag_id)
    #     sleep(1)
    conf = batch_env_vars.copy()
    # NOTE(review): hard-coded datablob id, unlike the `datablob_id` used by the
    # other test cells — confirm this is intentional.
    conf["datablob_id"] = 128
    run_id = trigger_dag(
        dag_id,
        conf=conf,
    )
    display(run_id)
    state = wait_for_run_to_complete(dag_id, run_id, timeout=600)