In [None]:
#| default_exp airflow.azure_batch_executor

In [None]:
#| export

import os
import tempfile
import shlex
import yaml
from pathlib import Path
from typing import *

from azure.batch.batch_auth import SharedKeyCredentials
from fastcore.script import call_parse, Param

from airt_service.sanitizer import sanitized_print
from airt.executor.subcommand import CLICommandBase, ClassCLICommand
from airt.helpers import slugify
from airt.logger import get_logger
from airt.patching import patch
from airt_service.airflow.base_executor import BaseAirflowExecutor, dag_template
from airt_service.airflow.utils import trigger_dag, wait_for_run_to_complete
from airt_service.azure.batch_utils import (
    get_random_string,
    BatchPool,
    BatchJob,
    AUTO_SCALE_FORMULA,
)
from airt_service.azure.utils import (
    get_azure_batch_environment_component_names,
    get_batch_account_pool_job_names,
)
from airt_service.batch_job import get_environment_vars_for_batch_job
from airt_service.helpers import generate_random_string

22-10-20 06:55:38.602 [INFO] airt.executor.subcommand: Module loaded.


In [None]:
import pytest
from datetime import timedelta
from time import sleep


from airt.executor.subcommand import SimpleCLICommand
from airt.testing import activate_by_import
from airt_service.airflow.utils import list_dag_runs
from airt_service.db.models import create_user_for_testing

[INFO] airt.testing.activate_by_import: Testing environment activated.
[INFO] airt.keras.helpers: Using a single GPU #0 with memory_limit 1024 MB


In [None]:
# Create a test user with subscription_type="small" and show the value returned by the helper
test_username = create_user_for_testing(subscription_type="small")
display(test_username)

'entezlusrf'

In [None]:
#| exporti

# Module-level logger, named after this module
logger = get_logger(__name__)

In [None]:
# Emit one log line to check that the logger above is wired up
logger.info("Module loaded.")

[INFO] __main__: Module loaded.


In [None]:
#| exporti


def setup_test_paths(td: str) -> Tuple[str, str]:
    """Create `data` and `model` folders under `td` and return them as `local:` URLs.

    Args:
        td: Root directory (typically a temporary directory) under which both folders are created.

    Returns:
        A `(data_path_url, model_path_url)` tuple of `local:<path>` strings. The data path is meant
        to be treated as read-only, while the model path may be read and written between calls.
    """
    root = Path(td)
    # NOTE: the name `paths` is printed verbatim by the `=` specifier below
    paths = [root.joinpath(name) for name in ("data", "model")]
    sanitized_print(f"{paths=}")

    for directory in paths:
        directory.mkdir(parents=True, exist_ok=True)

    data_url, model_url = (f"local:{directory}" for directory in paths)
    return data_url, model_url  # type: ignore

In [None]:
with tempfile.TemporaryDirectory() as d:
    data_path_url, model_path_url = setup_test_paths(d)

    # Both folders must exist (while the temporary directory is alive) and be exposed as `local:` URLs
    for url, name in [(data_path_url, "data"), (model_path_url, "model")]:
        assert url == f"local:{Path(d) / name}", url
        assert (Path(d) / name).is_dir(), name

data_path_url, model_path_url

paths=[Path('/tmp/tmpfinf6qnx/data'), Path('/tmp/tmpfinf6qnx/model')]


('local:/tmp/tmpfinf6qnx/data', 'local:/tmp/tmpfinf6qnx/model')

In [None]:
#| export

# Exec environment used for a step when none is given (the whole list is None or the entry is None)
DEFAULT_EXEC_ENVIRONMENT = "preprocessing"

In [None]:
#| export


class AirflowAzureBatchExecutor(BaseAirflowExecutor):
    """Airflow executor whose DAG tasks are generated as `AzureBatchOperator` calls."""

    def __init__(
        self,
        steps: List[CLICommandBase],
        region: str,
        exec_environments: Optional[List[Optional[str]]] = None,
        batch_environment_path: Optional[Union[str, Path]] = None,
    ):
        """Constructs a new AirflowAzureBatchExecutor instance

        Args:
            steps: List of instances of either ClassCLICommand or SimpleCLICommand
            region: Region to execute
            exec_environments: List of execution environments, one per step. A `None` entry, or
                passing `None` for the whole list, selects DEFAULT_EXEC_ENVIRONMENT for that step.
            batch_environment_path: Path for yaml file in which azure batch environment names are stored

        Raises:
            ValueError: if the number of exec environments differs from the number of steps, or if
                an exec environment (including the default) is not configured for `region` in the
                azure batch environment file
        """
        self.region = region
        self.batch_environment_path = batch_environment_path

        if exec_environments is None:
            exec_environments = [DEFAULT_EXEC_ENVIRONMENT] * len(steps)

        if len(exec_environments) != len(steps):
            raise ValueError(
                f"len(exec_environments)={len(exec_environments)} != len(steps)={len(steps)}"
            )

        # Exec environment names configured for this region in the batch environment file
        existing_exec_environments = list(
            get_azure_batch_environment_component_names(
                self.region, self.batch_environment_path
            ).keys()
        )

        resolved_exec_environments: List[str] = []
        for exec_env in exec_environments:
            # Resolve None to the default first, so the default is validated like any explicit value
            resolved = exec_env if exec_env is not None else DEFAULT_EXEC_ENVIRONMENT
            if resolved not in existing_exec_environments:
                raise ValueError(
                    f"Invalid value {resolved} given for exec environment; Allowed values are {existing_exec_environments}"
                )
            resolved_exec_environments.append(resolved)
        self.exec_environments = resolved_exec_environments

        super().__init__(steps)

    def execute(
        self,
        *,
        description: str,
        tags: Union[str, List[str]],
        on_step_start: Optional[CLICommandBase] = None,
        on_step_end: Optional[CLICommandBase] = None,
        **kwargs,
    ) -> Tuple[Path, str]:
        """Create DAG and execute steps in airflow

        Placeholder; the implementation is attached later in this notebook via `@patch`.

        Args:
            description: description of DAG
            tags: tags for DAG
            on_step_start: CLI to call before executing step/task in DAG
            on_step_end: CLI to call after executing step/task in DAG
            kwargs: keyword arguments needed for steps/tasks
        Returns:
            A tuple which contains dag file path and run id
        """
        raise NotImplementedError("Need to implement")

In [None]:
def save_test_azure_batch_environment_names(folder: Path):
    """Write a placeholder azure batch environment yaml file for the tests.

    In the "northeurope" region, each of the four exec environments gets the same placeholder
    value for its batch job, pool and account name.

    Args:
        folder: directory (str or Path) in which `azure_batch_environment.yml` is written

    Returns:
        Path of the written yaml file
    """
    placeholder = "random_azure_batch_env_component_name"
    component_keys = ("batch_job_name", "batch_pool_name", "batch_account_name")
    exec_environment_names = ("csv_processing", "predictions", "preprocessing", "training")

    per_environment = {
        env_name: dict.fromkeys(component_keys, placeholder)
        for env_name in exec_environment_names
    }
    test_batch_environment_names = {"northeurope": per_environment}

    test_batch_environment_path = Path(folder) / "azure_batch_environment.yml"
    with test_batch_environment_path.open("w") as f:
        yaml.dump(test_batch_environment_names, f, default_flow_style=False)

    return test_batch_environment_path

In [None]:
# Two ClassCLICommand steps on MyTestExecutor of "test-executor" (f_name "f" and "g"),
# shared by the test cells that follow
steps = [
    ClassCLICommand(
        executor_name="test-executor", class_name="MyTestExecutor", f_name="f"
    ),
    ClassCLICommand(
        executor_name="test-executor", class_name="MyTestExecutor", f_name="g"
    ),
]

In [None]:
region = "northeurope"
with tempfile.TemporaryDirectory() as d:
    data_path_url, model_path_url = setup_test_paths(d)

    test_batch_environment_path = save_test_azure_batch_environment_names(d)

    # exec_environments=None -> every step uses the default exec environment
    abe = AirflowAzureBatchExecutor(
        steps=steps,
        region=region,
        batch_environment_path=test_batch_environment_path,
    )
    display(abe.exec_environments)
    assert abe.exec_environments == ["preprocessing"] * len(steps)

    # one exec environment for two steps -> length mismatch
    with pytest.raises(
        ValueError, match=r"len\(exec_environments\)=1 != len\(steps\)"
    ) as e:
        abe = AirflowAzureBatchExecutor(
            steps=steps,
            region=region,
            exec_environments=["preprocessing"],
            batch_environment_path=test_batch_environment_path,
        )
    display(e)

    # exec environment not present in the batch environment file
    with pytest.raises(
        ValueError, match="Invalid value gibberish given for exec environment"
    ) as e:
        abe = AirflowAzureBatchExecutor(
            steps=steps,
            region=region,
            exec_environments=["gibberish", "gibberish"],
            batch_environment_path=test_batch_environment_path,
        )
    display(e)

paths=[Path('/tmp/tmpl847y635/data'), Path('/tmp/tmpl847y635/model')]


['preprocessing', 'preprocessing']

<ExceptionInfo ValueError('len(exec_environments)=1 != len(steps)2') tblen=2>

<ExceptionInfo ValueError("Invalid value gibberish given for exec environment; Allowed values are ['csv_processing', 'predictions', 'preprocessing', 'training']") tblen=2>

In [None]:
#| export


@patch
def _create_step_template(
    self: AirflowAzureBatchExecutor,
    step: CLICommandBase,
    exec_environment: str,
    **kwargs,
) -> str:
    """
    Create the `AzureBatchOperator(...)` source snippet for a single step

    Args:
        step: step to create template
        exec_environment: exec environment of the step; selects the batch pool/job names
            (via the batch environment file) and the pool VM size
        kwargs: keyword arguments for step
    Returns:
        Template for step, i.e. Python source of an `AzureBatchOperator(...)` call to embed in the DAG file
    Raises:
        ValueError: if no VM size is configured for `exec_environment`
    """
    # VM size of the batch pool for each supported exec environment
    vm_sizes = {
        "training": "standard_nc6s_v3",
        "predictions": "standard_nc6s_v3",
        "csv_processing": "standard_d2s_v3",
        "preprocessing": "standard_d2s_v3",
    }
    if exec_environment not in vm_sizes:
        # Previously this fell through and failed with UnboundLocalError
        raise ValueError(
            f"No batch pool VM size configured for exec environment {exec_environment!r}; supported values are {list(vm_sizes)}"
        )
    batch_pool_vm_size = vm_sizes[exec_environment]

    cli_command = step.to_cli(**kwargs)
    task_id = slugify(cli_command)

    # Forward the batch-job environment variables to the container as ` --env NAME='value'` options.
    # NOTE(review): values are not escaped; a value containing ' or " would break the generated
    # container options / DAG source.
    azure_batch_environment_vars = "".join(
        f" --env {name}='{value}'"
        for name, value in get_environment_vars_for_batch_job().items()
    )

    _, batch_pool_name, batch_job_name = get_batch_account_pool_job_names(
        task=exec_environment,
        region=self.region,
        batch_environment_path=self.batch_environment_path,
    )

    batch_task_id = f"batch-task-{get_random_string()}"
    azure_batch_conn_id = 'azure_batch_conn_id="custom_azure_batch_default"'
    task_params = f"""task_id="{task_id}", batch_pool_id="{batch_pool_name}", batch_pool_vm_size="{batch_pool_vm_size}", batch_job_id="{batch_job_name}", batch_task_command_line="{cli_command}", batch_task_id="{batch_task_id}", {azure_batch_conn_id}"""

    vm_details = """vm_publisher="microsoft-azure-batch", vm_offer="ubuntu-server-container", vm_sku="20-04-lts", vm_version="latest", vm_node_agent_sku_id="batch.node.ubuntu 20.04" """

    auto_scale_params = (
        f'enable_auto_scale=True, auto_scale_formula="""{AUTO_SCALE_FORMULA}"""'
    )

    # The "latest" image is used only when DOMAIN is api.airt.ai; otherwise (including when DOMAIN
    # is unset, which previously raised KeyError) the "dev" image is used
    tag = "latest" if os.environ.get("DOMAIN") == "api.airt.ai" else "dev"
    batch_task_container_settings = f"""batch_task_container_settings=batchmodels.TaskContainerSettings(image_name="registry.gitlab.com/airt.ai/airt-service:{tag}", container_run_options="{azure_batch_environment_vars}")"""

    return f"""AzureBatchOperator({task_params}, {vm_details}, {auto_scale_params}, {batch_task_container_settings})"""

In [None]:
region = "northeurope"
with tempfile.TemporaryDirectory() as d:
    data_path_url, model_path_url = setup_test_paths(d)

    test_batch_environment_path = save_test_azure_batch_environment_names(d)
    abe = AirflowAzureBatchExecutor(
        steps=steps,
        region=region,
        batch_environment_path=test_batch_environment_path,
    )
    actual = abe._create_step_template(
        steps[0],
        exec_environment="training",
        data_path_url=data_path_url,
        model_path_url=model_path_url,
    )

# The snippet is a single AzureBatchOperator call using the "training" VM size
assert actual.startswith("AzureBatchOperator("), actual[:80]
assert 'batch_pool_vm_size="standard_nc6s_v3"' in actual

paths=[Path('/tmp/tmpvvqf2xsn/data'), Path('/tmp/tmpvvqf2xsn/model')]


In [None]:
#| export


@patch
def _create_dag_template(
    self: AirflowAzureBatchExecutor,
    on_step_start: Optional[CLICommandBase] = None,
    on_step_end: Optional[CLICommandBase] = None,
    **kwargs,
) -> str:
    """
    Render the DAG source: `dag_template` followed by one task per step, plus optional
    before/after hook tasks, all chained sequentially.

    Args:
        on_step_start: CLI to call before executing step/task in DAG; rendered with `step_count`
        on_step_end: CLI to call after executing step/task in DAG; rendered with `step_count`
        kwargs: keyword arguments to pass to steps' CLI
    Returns:
        Generated DAG source whose tasks t1..tN run in order (`t1 >> t2 >> ... >> tN`)
    """
    line_prefix = "\n" + " " * 4
    task_lines: List[str] = []

    for step_index, step in enumerate(self.steps):
        exec_environment = self.exec_environments[step_index]
        step_number = step_index + 1

        # (command, is_hook) in execution order; hook commands also get the 1-based step number
        commands = []
        if on_step_start is not None:
            commands.append((on_step_start, True))
        commands.append((step, False))
        if on_step_end is not None:
            commands.append((on_step_end, True))

        for command, is_hook in commands:
            if is_hook:
                rendered = self._create_step_template(command, exec_environment, step_count=step_number, **kwargs)  # type: ignore
            else:
                rendered = self._create_step_template(command, exec_environment, **kwargs)  # type: ignore
            task_lines.append(f"{line_prefix}t{len(task_lines) + 1} = {rendered}")

    dependency_chain = " >> ".join(f"t{n}" for n in range(1, len(task_lines) + 1))
    return dag_template + "".join(task_lines) + line_prefix + dependency_chain

In [None]:
region = "northeurope"
with tempfile.TemporaryDirectory() as d:
    data_path_url, model_path_url = setup_test_paths(d)
    test_batch_environment_path = save_test_azure_batch_environment_names(d)

    kwargs = {"data_path_url": data_path_url, "model_path_url": model_path_url}

    abe = AirflowAzureBatchExecutor(
        steps=steps,
        region=region,
        batch_environment_path=test_batch_environment_path,
    )

    on_step_start = SimpleCLICommand(command="sleep {step_count}")
    on_step_end = SimpleCLICommand(command="echo step {step_count} completed")
    template = abe._create_dag_template(
        on_step_start=on_step_start, on_step_end=on_step_end, **kwargs
    )
    display(template)

    # 2 steps x (start hook + step + end hook) = 6 tasks, chained in order after the base template
    assert template.startswith(dag_template)
    assert template.endswith("t1 >> t2 >> t3 >> t4 >> t5 >> t6"), template[-80:]

paths=[Path('/tmp/tmplx66w3m7/data'), Path('/tmp/tmplx66w3m7/model')]


'import datetime\nfrom textwrap import dedent\n\n# The DAG object; we\'ll need this to instantiate a DAG\nfrom airflow import DAG\n\n# Operators; we need this to operate!\nfrom airflow.providers.amazon.aws.operators.batch import BatchOperator\nimport azure.batch.models as batchmodels\nfrom airflow.providers.microsoft.azure.operators.batch import AzureBatchOperator\nfrom airflow.operators.bash import BashOperator\nfrom airflow.operators.trigger_dagrun import TriggerDagRunOperator\nwith DAG(\n    \'{dag_name}\',\n    # These args will get passed on to each operator\n    # You can override them on a per-task basis during operator initialization\n    default_args={{\n        \'schedule_interval\': {schedule_interval},\n        \'depends_on_past\': False,\n        \'email\': [\'info@airt.ai\'],\n        \'email_on_failure\': False,\n        \'email_on_retry\': False,\n        \'retries\': 1,\n        \'retry_delay\': datetime.timedelta(minutes=5),\n        # \'queue\': \'queue\',\n        #

In [None]:
# | eval: false
# Test case for AirflowAzureBatchExecutor._create_dag
region = "northeurope"
batch_account_name = "testbatchnortheurope"
with tempfile.TemporaryDirectory() as d:
    data_path_url, model_path_url = setup_test_paths(d)
    steps = [
        SimpleCLICommand(command="env"),
        ClassCLICommand(
            executor_name="test-executor", class_name="MyTestExecutor", f_name="f"
        ),
    ]
    # None -> the step uses DEFAULT_EXEC_ENVIRONMENT
    exec_environments = ["training", None]
    on_step_start = SimpleCLICommand(command="sleep {step_count}")
    on_step_end = SimpleCLICommand(command="echo step {step_count} completed")

    created_azure_env_path = Path(d) / "azure_batch_environment.yml"

    shared_key_credentials = SharedKeyCredentials(
        batch_account_name, os.environ["SHARED_KEY_CREDENTIALS"]
    )
    batch_pool = BatchPool.from_name(
        name="test-cpu-pool",
        batch_account_name=batch_account_name,
        region=region,
        shared_key_credentials=shared_key_credentials,
    )
    batch_job = BatchJob.from_name(name="test-cpu-job", batch_pool=batch_pool)

    display(f"{batch_pool.name=}")
    display(f"{batch_job.name=}")

    # Every exec environment points at the same test pool/job
    test_batch_environment_names = {
        region: {
            task: {
                "batch_job_name": batch_job.name,
                "batch_pool_name": batch_pool.name,
                "batch_account_name": batch_account_name,
            }
            for task in ["csv_processing", "predictions", "preprocessing", "training"]
        }
    }
    display(f"{test_batch_environment_names=}")
    with open(created_azure_env_path, "w") as f:
        yaml.dump(test_batch_environment_names, f, default_flow_style=False)

    abe = AirflowAzureBatchExecutor(
        steps=steps,
        region=region,
        exec_environments=exec_environments,
        batch_environment_path=created_azure_env_path,
    )
    dag_id, dag_file_path = abe._create_dag(
        data_path_url=data_path_url,
        model_path_url=model_path_url,
        schedule_interval=None,
        description="test description",
        tags="test_tag",
        on_step_start=on_step_start,
        on_step_end=on_step_end,
    )

    try:
        display(f"{dag_file_path=}")

        # presumably gives Airflow time to parse the newly written DAG file
        sleep(15)

        dag_runs = list_dag_runs(dag_id=dag_id)
        display(f"{dag_runs=}")

        run_id = trigger_dag(dag_id=dag_id, conf={})
        display(run_id)
        state = wait_for_run_to_complete(dag_id=dag_id, run_id=run_id, timeout=3600)
        display(state)
        assert state == "success", state
    finally:
        # Do not leave the generated DAG in the dags folder if the run fails or times out
        dag_file_path.unlink()

paths=[Path('/tmp/tmp5svusq8h/data'), Path('/tmp/tmp5svusq8h/model')]


"batch_pool.name='test-cpu-pool'"

"batch_job.name='test-cpu-job'"

"test_batch_environment_names={'northeurope': {'csv_processing': {'batch_job_name': 'test-cpu-job', 'batch_pool_name': 'test-cpu-pool', 'batch_account_name': 'testbatchnortheurope'}, 'predictions': {'batch_job_name': 'test-cpu-job', 'batch_pool_name': 'test-cpu-pool', 'batch_account_name': 'testbatchnortheurope'}, 'preprocessing': {'batch_job_name': 'test-cpu-job', 'batch_pool_name': 'test-cpu-pool', 'batch_account_name': 'testbatchnortheurope'}, 'training': {'batch_job_name': 'test-cpu-job', 'batch_pool_name': 'test-cpu-pool', 'batch_account_name': 'testbatchnortheurope'}}}"

"dag_file_path=Path('/root/airflow/dags/env_test-executor-my_test_executor-f-data-path-urllocaltmptmp5svusq8hdata-model-path-urllocaltmptmp5svusq8hmodel.py')"

'dag_runs=[]'

[{'dag_id': 'env_test-executor-my_test_executor-f-data-path-urllocaltmptmp5svusq8hdata-model-path-urllocaltmptmp5svusq8hmodel', 'run_id': 'airt-service__2022-10-20T06:56:04.259388', 'state': 'running', 'execution_date': '2022-10-20T06:56:05+00:00', 'start_date': '2022-10-20T06:56:05.939394+00:00', 'end_date': ''}]


'airt-service__2022-10-20T06:56:04.259388'

'success'

In [None]:
# | eval: false
# Test case for AirflowAzureBatchExecutor.schedule
region = "northeurope"
batch_account_name = "testbatchnortheurope"
with tempfile.TemporaryDirectory() as d:
    data_path_url, model_path_url = setup_test_paths(d)
    steps = [
        SimpleCLICommand(command="env"),
        ClassCLICommand(
            executor_name="test-executor", class_name="MyTestExecutor", f_name="f"
        ),
    ]
    exec_environments = ["csv_processing", "preprocessing"]
    on_step_start = SimpleCLICommand(command="sleep {step_count}")
    on_step_end = SimpleCLICommand(command="echo step {step_count} completed")

    created_azure_env_path = Path(d) / "azure_batch_environment.yml"

    shared_key_credentials = SharedKeyCredentials(
        batch_account_name, os.environ["SHARED_KEY_CREDENTIALS"]
    )
    batch_pool = BatchPool.from_name(
        name="test-cpu-pool",
        batch_account_name=batch_account_name,
        region=region,
        shared_key_credentials=shared_key_credentials,
    )
    batch_job = BatchJob.from_name(name="test-cpu-job", batch_pool=batch_pool)

    display(f"{batch_pool.name=}")
    display(f"{batch_job.name=}")

    # Every exec environment points at the same test pool/job
    test_batch_environment_names = {
        region: {
            task: {
                "batch_job_name": batch_job.name,
                "batch_pool_name": batch_pool.name,
                "batch_account_name": batch_account_name,
            }
            for task in ["csv_processing", "predictions", "preprocessing", "training"]
        }
    }
    display(f"{test_batch_environment_names=}")
    with open(created_azure_env_path, "w") as f:
        yaml.dump(test_batch_environment_names, f, default_flow_style=False)

    abe = AirflowAzureBatchExecutor(
        steps=steps,
        region=region,
        exec_environments=exec_environments,
        batch_environment_path=created_azure_env_path,
    )
    dag_file_path = abe.schedule(
        data_path_url=data_path_url,
        model_path_url=model_path_url,
        schedule_interval=timedelta(days=7),
        description="test description",
        tags="test_tag",
        on_step_start=on_step_start,
        on_step_end=on_step_end,
    )

    try:
        display(f"{dag_file_path=}")
        dag_id = str(dag_file_path).split("/")[-1].split(".py")[0]

        # presumably gives Airflow time to parse the newly written DAG file
        sleep(15)

        dag_runs = list_dag_runs(dag_id=dag_id)
        display(f"{dag_runs=}")

        run_id = trigger_dag(dag_id=dag_id, conf={})
        display(run_id)
        state = wait_for_run_to_complete(dag_id=dag_id, run_id=run_id, timeout=3600)
        display(state)
        assert state == "success", state
    finally:
        # Do not leave the generated DAG in the dags folder if the run fails or times out
        dag_file_path.unlink()

paths=[Path('/tmp/tmpog4g09bp/data'), Path('/tmp/tmpog4g09bp/model')]


"batch_pool.name='test-cpu-pool'"

"batch_job.name='test-cpu-job'"

"test_batch_environment_names={'northeurope': {'csv_processing': {'batch_job_name': 'test-cpu-job', 'batch_pool_name': 'test-cpu-pool', 'batch_account_name': 'testbatchnortheurope'}, 'predictions': {'batch_job_name': 'test-cpu-job', 'batch_pool_name': 'test-cpu-pool', 'batch_account_name': 'testbatchnortheurope'}, 'preprocessing': {'batch_job_name': 'test-cpu-job', 'batch_pool_name': 'test-cpu-pool', 'batch_account_name': 'testbatchnortheurope'}, 'training': {'batch_job_name': 'test-cpu-job', 'batch_pool_name': 'test-cpu-pool', 'batch_account_name': 'testbatchnortheurope'}}}"

"dag_file_path=Path('/root/airflow/dags/env_test-executor-my_test_executor-f-data-path-urllocaltmptmpog4g09bpdata-model-path-urllocaltmptmpog4g09bpmodel.py')"

'dag_runs=[]'

[{'dag_id': 'env_test-executor-my_test_executor-f-data-path-urllocaltmptmpog4g09bpdata-model-path-urllocaltmptmpog4g09bpmodel', 'run_id': 'airt-service__2022-10-20T07:09:23.546392', 'state': 'running', 'execution_date': '2022-10-20T07:09:24+00:00', 'start_date': '2022-10-20T07:09:25.152173+00:00', 'end_date': ''}]


'airt-service__2022-10-20T07:09:23.546392'

'success'

In [None]:
#| export


@patch
def execute(
    self: AirflowAzureBatchExecutor,
    *,
    description: str,
    tags: Union[str, List[str]],
    on_step_start: Optional[CLICommandBase] = None,
    on_step_end: Optional[CLICommandBase] = None,
    **kwargs
) -> Tuple[Path, str]:
    """Create a DAG without a schedule interval for the steps and trigger one run of it

    Args:
        description: description of DAG
        tags: tags for DAG
        on_step_start: CLI to call before executing step/task in DAG
        on_step_end: CLI to call after executing step/task in DAG
        kwargs: keyword arguments needed for steps/tasks
    Returns:
        A `(dag_file_path, run_id)` tuple for the created DAG file and the triggered run
    """
    # schedule_interval=None: the run is started explicitly by trigger_dag below
    dag_id, dag_file_path = self._create_dag(
        schedule_interval=None,
        description=description,
        tags=tags,
        on_step_start=on_step_start,
        on_step_end=on_step_end,
        **kwargs
    )
    return dag_file_path, trigger_dag(dag_id=dag_id, conf={})

In [None]:
# | eval: false
# Test case for AirflowAzureBatchExecutor.execute
region = "northeurope"
batch_account_name = "testbatchnortheurope"
with tempfile.TemporaryDirectory() as d:
    data_path_url, model_path_url = setup_test_paths(d)
    steps = [
        SimpleCLICommand(command="env"),
        ClassCLICommand(
            executor_name="test-executor", class_name="MyTestExecutor", f_name="f"
        ),
    ]
    exec_environments = ["training", "predictions"]
    on_step_start = SimpleCLICommand(command="sleep {step_count}")
    on_step_end = SimpleCLICommand(command="echo step {step_count} completed")

    created_azure_env_path = Path(d) / "azure_batch_environment.yml"

    shared_key_credentials = SharedKeyCredentials(
        batch_account_name, os.environ["SHARED_KEY_CREDENTIALS"]
    )
    batch_pool = BatchPool.from_name(
        name="test-cpu-pool",
        batch_account_name=batch_account_name,
        region=region,
        shared_key_credentials=shared_key_credentials,
    )
    batch_job = BatchJob.from_name(name="test-cpu-job", batch_pool=batch_pool)

    display(f"{batch_pool.name=}")
    display(f"{batch_job.name=}")

    # Every exec environment points at the same test pool/job
    test_batch_environment_names = {
        region: {
            task: {
                "batch_job_name": batch_job.name,
                "batch_pool_name": batch_pool.name,
                "batch_account_name": batch_account_name,
            }
            for task in ["csv_processing", "predictions", "preprocessing", "training"]
        }
    }
    display(f"{test_batch_environment_names=}")
    with open(created_azure_env_path, "w") as f:
        yaml.dump(test_batch_environment_names, f, default_flow_style=False)

    abe = AirflowAzureBatchExecutor(
        steps=steps,
        region=region,
        exec_environments=exec_environments,
        batch_environment_path=created_azure_env_path,
    )
    dag_file_path, run_id = abe.execute(
        data_path_url=data_path_url,
        model_path_url=model_path_url,
        description="test description",
        tags="test_tag",
        on_step_start=on_step_start,
        on_step_end=on_step_end,
    )

    try:
        display(dag_file_path)
        display(run_id)

        dag_id = str(dag_file_path).split("/")[-1].split(".py")[0]
        state = wait_for_run_to_complete(dag_id=dag_id, run_id=run_id, timeout=3600)
        display(state)
        assert state == "success", state
    finally:
        # Do not leave the generated DAG in the dags folder if the run fails or times out
        dag_file_path.unlink()

paths=[Path('/tmp/tmp2bk3wgmy/data'), Path('/tmp/tmp2bk3wgmy/model')]


"batch_pool.name='test-cpu-pool'"

"batch_job.name='test-cpu-job'"

"test_batch_environment_names={'northeurope': {'csv_processing': {'batch_job_name': 'test-cpu-job', 'batch_pool_name': 'test-cpu-pool', 'batch_account_name': 'testbatchnortheurope'}, 'predictions': {'batch_job_name': 'test-cpu-job', 'batch_pool_name': 'test-cpu-pool', 'batch_account_name': 'testbatchnortheurope'}, 'preprocessing': {'batch_job_name': 'test-cpu-job', 'batch_pool_name': 'test-cpu-pool', 'batch_account_name': 'testbatchnortheurope'}, 'training': {'batch_job_name': 'test-cpu-job', 'batch_pool_name': 'test-cpu-pool', 'batch_account_name': 'testbatchnortheurope'}}}"

[{'dag_id': 'env_test-executor-my_test_executor-f-data-path-urllocaltmptmp2bk3wgmydata-model-path-urllocaltmptmp2bk3wgmymodel', 'run_id': 'airt-service__2022-10-20T07:11:48.082338', 'state': 'running', 'execution_date': '2022-10-20T07:11:49+00:00', 'start_date': '2022-10-20T07:11:50.110928+00:00', 'end_date': ''}]


Path('/root/airflow/dags/env_test-executor-my_test_executor-f-data-path-urllocaltmptmp2bk3wgmydata-model-path-urllocaltmptmp2bk3wgmymodel.py')

'airt-service__2022-10-20T07:11:48.082338'

'success'

In [None]:
#| export


def _test_azure_batch_executor(region: str = "northeurope"):  # type: ignore
    """Run a single-step DAG on the shared Azure Batch test pool/job and wait for it to finish.

    The batch account name (`testbatchnortheurope`) and its shared key (env var
    `SHARED_KEY_CREDENTIALS`) are fixed, so only the "northeurope" region is supported.
    Previously any other `region` value was silently replaced by "northeurope".

    Args:
        region: Azure region to run in; must be "northeurope"

    Raises:
        ValueError: if `region` is not "northeurope"
    """
    batch_account_name = "testbatchnortheurope"
    supported_region = "northeurope"
    if region != supported_region:
        raise ValueError(
            f"Only region {supported_region!r} is supported by the test batch account {batch_account_name!r}; got {region!r}"
        )

    with tempfile.TemporaryDirectory() as d:
        data_path_url, model_path_url = setup_test_paths(d)

        steps = [
            ClassCLICommand(
                executor_name="test-executor", class_name="MyTestExecutor", f_name="f"
            )
        ]
        exec_environments = ["training"]

        created_azure_env_path = Path(d) / "azure_batch_environment.yml"

        shared_key_credentials = SharedKeyCredentials(
            batch_account_name, os.environ["SHARED_KEY_CREDENTIALS"]
        )
        batch_pool = BatchPool.from_name(
            name="test-cpu-pool",
            batch_account_name=batch_account_name,
            region=region,
            shared_key_credentials=shared_key_credentials,
        )
        batch_job = BatchJob.from_name(name="test-cpu-job", batch_pool=batch_pool)

        sanitized_print(f"{batch_pool.name=}")
        sanitized_print(f"{batch_job.name=}")

        # Every exec environment points at the same test pool/job
        test_batch_environment_names = {
            region: {
                task: {
                    "batch_job_name": batch_job.name,
                    "batch_pool_name": batch_pool.name,
                    "batch_account_name": batch_account_name,
                }
                for task in [
                    "csv_processing",
                    "predictions",
                    "preprocessing",
                    "training",
                ]
            }
        }
        sanitized_print(f"{test_batch_environment_names=}")
        with open(created_azure_env_path, "w") as f:
            yaml.dump(test_batch_environment_names, f, default_flow_style=False)

        abe = AirflowAzureBatchExecutor(
            steps=steps,
            region=region,
            exec_environments=exec_environments,  # type: ignore
            batch_environment_path=created_azure_env_path,
        )
        dag_file_path, run_id = abe.execute(
            data_path_url=data_path_url,
            model_path_url=model_path_url,
            description="test description",
            tags="test_tag",
        )

        try:
            sanitized_print(f"{dag_file_path=}")
            sanitized_print(f"{run_id=}")

            dag_id = str(dag_file_path).split("/")[-1].split(".py")[0]
            state = wait_for_run_to_complete(
                dag_id=dag_id, run_id=run_id, timeout=3600
            )
            sanitized_print(f"{state=}")
        finally:
            # Remove the generated DAG file even if waiting fails or times out
            dag_file_path.unlink()

In [None]:
#| export


@call_parse
def test_azure_batch_executor(region: Param("region", str) = "northeurope"):  # type: ignore
    """
    Create a throw-away Azure Batch test environment and run the Airflow Azure Batch executor in it
    """
    # Thin CLI wrapper; all the work happens in the helper
    return _test_azure_batch_executor(region=region)

In [None]:
# | eval: false
# End-to-end run of the CLI entry point with its default region; needs the SHARED_KEY_CREDENTIALS
# env var and triggers a real Airflow DAG run on Azure Batch
test_azure_batch_executor()

paths=[Path('/tmp/tmp7xne3phb/data'), Path('/tmp/tmp7xne3phb/model')]
batch_pool.name='test-cpu-pool'
batch_job.name='test-cpu-job'
test_batch_environment_names={'northeurope': {'csv_processing': {'batch_job_name': 'test-cpu-job', 'batch_pool_name': 'test-cpu-pool', 'batch_account_name': 'testbatchnortheurope'}, 'predictions': {'batch_job_name': 'test-cpu-job', 'batch_pool_name': 'test-cpu-pool', 'batch_account_name': 'testbatchnortheurope'}, 'preprocessing': {'batch_job_name': 'test-cpu-job', 'batch_pool_name': 'test-cpu-pool', 'batch_account_name': 'testbatchnortheurope'}, 'training': {'batch_job_name': 'test-cpu-job', 'batch_pool_name': 'test-cpu-pool', 'batch_account_name': 'testbatchnortheurope'}}}
[{'dag_id': 'test-executor-my_test_executor-f-data-path-urllocaltmptmp7xne3phbdata-model-path-urllocaltmptmp7xne3phbmodel', 'run_id': 'airt-service__2022-10-20T07:14:15.998356', 'state': 'running', 'execution_date': '2022-10-20T07:14:17+00:00', 'start_date': '2022-10-20T07:14:17.748007+