In [None]:
#| default_exp batch_job

In [None]:
from airt.testing import activate_by_import

[INFO] airt.testing.activate_by_import: Testing environment activated.
[INFO] numexpr.utils: Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
[INFO] numexpr.utils: NumExpr defaulting to 8 threads.
[INFO] airt.keras.helpers: Using a single GPU #0 with memory_limit 1024 MB


In [None]:
#| export

import airt_service.sanitizer

from airt_service.batch_job_components.base import BatchJobContext

from airt_service.batch_job_components.aws import AwsBatchJobContext
from airt_service.batch_job_components.azure import AzureBatchJobContext
from airt_service.batch_job_components.fastapi import FastAPIBatchJobContext
from airt_service.batch_job_components.none import NoneBatchJobContext

In [None]:
# The concrete context classes imported above are expected to be registered
# in the base class's factory registry (presumably as a side effect of import).
registered_factories = BatchJobContext._factories
assert registered_factories
registered_factories

{'AwsBatchJobContext': airt_service.batch_job_components.aws.AwsBatchJobContext,
 'AzureBatchJobContext': airt_service.batch_job_components.azure.AzureBatchJobContext,
 'FastAPIBatchJobContext': airt_service.batch_job_components.fastapi.FastAPIBatchJobContext,
 'NoneBatchJobContext': airt_service.batch_job_components.none.NoneBatchJobContext}

In [None]:
#| export

from os import environ

from fastapi import BackgroundTasks

from airt.logger import get_logger

In [None]:
from _pytest.monkeypatch import MonkeyPatch

from airt_service.background_task import execute_cli
from airt_service.helpers import set_env_variable_context

In [None]:
#| exporti

# Module-level logger named after this module (__name__); used by create_batch_job below.
logger = get_logger(__name__)

In [None]:
#| export


def get_environment_vars_for_batch_job() -> dict:
    """Get the necessary environment variables for creating a batch job

    The values are read from the current process environment (``os.environ``)
    and returned in the fixed order listed below.

    Returns:
        The environment variables as a dict, mapping each variable name to its
        value

    Raises:
        KeyError: If one or more of the required variables are not set. The
            message lists every missing name, not only the first one.
    """
    var_names = [
        "AWS_ACCESS_KEY_ID",
        "AWS_SECRET_ACCESS_KEY",
        "AWS_DEFAULT_REGION",
        "AZURE_SUBSCRIPTION_ID",
        "AZURE_TENANT_ID",
        "AZURE_CLIENT_ID",
        "AZURE_CLIENT_SECRET",
        "AZURE_STORAGE_ACCOUNT_PREFIX",
        "AZURE_RESOURCE_GROUP",
        # NOTE(review): AIRT_SERVICE_SUPER_USER_PASSWORD and AIRT_TOKEN_SECRET_KEY
        # are NOT forwarded to the batch job (they were commented out here);
        # presumably deliberate - confirm before adding them back.
        "STORAGE_BUCKET_PREFIX",
        "DB_USERNAME",
        "DB_PASSWORD",
        "DB_HOST",
        "DB_PORT",
        "DB_DATABASE",
        "DB_DATABASE_SERVER",
    ]

    # Report every missing variable at once instead of failing on the first
    # lookup, so a misconfigured deployment can be fixed in a single pass.
    missing = [name for name in var_names if name not in environ]
    if missing:
        raise KeyError(
            f"Required environment variable(s) not set: {', '.join(missing)}"
        )

    return {name: environ[name] for name in var_names}

In [None]:
env_for_job = get_environment_vars_for_batch_job()
# The dict is deliberately not displayed: it carries AWS/Azure/DB credentials.
assert env_for_job.get("STORAGE_BUCKET_PREFIX")

In [None]:
#| export


def create_batch_job(
    command: str,
    task: str,
    cloud_provider: str,
    region: str,
    background_tasks: BackgroundTasks,
) -> None:
    """Create a new batch job

    Opens a `BatchJobContext` for the given task, then submits `command` to it
    together with the environment returned by
    `get_environment_vars_for_batch_job`.

    Args:
        command: The CLI command as a string
        task: Task name as a string
        cloud_provider: Cloud provider in which to execute batch job
        region: Region to execute
        background_tasks: An instance of BackgroundTasks
    """
    logger.info(f"create_batch_job(): {command=}, {task=}")

    context_options = {
        "cloud_provider": cloud_provider,
        "region": region,
        "background_tasks": background_tasks,
    }
    # The variable must stay named `batch_ctx`: the f-string below logs its name.
    with BatchJobContext.create(task, **context_options) as batch_ctx:
        logger.info(f"{batch_ctx=}")
        # Environment is read only after the context has been entered.
        job_env = get_environment_vars_for_batch_job()
        batch_ctx.create_job(command=command, environment_vars=job_env)

In [None]:
fastapi_bg_tasks = BackgroundTasks()
test_command = "s3_pull 1"

# Test using FastAPIBatchJobContext with set_env_variable_context.
# JOB_EXECUTOR=fastapi selects the FastAPI context (see the log output),
# so the cloud_provider/region values here are just passed through.
with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
    create_batch_job(
        command=test_command,
        task="csv_processing",
        cloud_provider="azure",
        region="eu-west-1",
        background_tasks=fastapi_bg_tasks,
    )

bg_task = fastapi_bg_tasks.tasks[-1]
display(f"{bg_task.func=}", f"{bg_task.args=}", f"{bg_task.kwargs=}")
assert bg_task.func == execute_cli
assert bg_task.kwargs["command"] == test_command

[INFO] __main__: create_batch_job(): command='s3_pull 1', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering FastAPIBatchJobContext(task=csv_processing)
[INFO] __main__: batch_ctx=FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.fastapi: FastAPIBatchJobContext.create_job(self=FastAPIBatchJobContext(task=csv_processing), command='s3_pull 1', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************', 'AZURE_TENANT_ID': '************************************', 'AZURE_CLIENT_ID': '************************************', 'AZURE_CLIENT_SECRET': '****************************************', 'AZURE_STORAGE_ACCOUNT_PREFIX': 'kumsairtsdev', 'AZURE_RESOURCE_GROUP': 'kumaran-airt-service-dev', 'STORAGE_BUCKET_PREFIX': 'kumaran-airt-service', 'DB_USERNAME': 'root', 'DB_P

'bg_task.func=<function execute_cli>'

'bg_task.args=()'

"bg_task.kwargs={'command': 's3_pull 1'}"

In [None]:
# Test using AwsBatchJobContext with MonkeyPatch
with MonkeyPatch.context() as monkeypatch:
    job_queue_arn = "aws:job_queue_arn"
    job_definition_arn = "aws:job_definition_arn"
    cloud_provider = "aws"
    region = "eu-west-1"
    monkeypatch.setattr(
        "airt_service.aws.utils.get_queue_definition_arns",
        lambda task, region: (job_queue_arn, job_definition_arn),
    )

    test_command = "db_pull 1"

    # Record each invocation: the asserts below live inside the stub, so without
    # this check the cell would pass vacuously if the stub were never called.
    create_job_calls = []

    def test_patch_create_job(*args, **kwargs):
        create_job_calls.append(kwargs)
        display(f"{kwargs=}")
        assert kwargs["job_queue_arn"] == job_queue_arn
        assert kwargs["job_definition_arn"] == job_definition_arn
        assert kwargs["command"] == test_command
        assert kwargs["region"] == region
        assert "AWS_ACCESS_KEY_ID" in kwargs["environment_vars"]
        assert "AWS_SECRET_ACCESS_KEY" in kwargs["environment_vars"]

    monkeypatch.setattr(
        "airt_service.aws.batch_utils.aws_batch_create_job", test_patch_create_job
    )

    b = BackgroundTasks()
    create_batch_job(
        command=test_command,
        task="csv_processing",
        cloud_provider=cloud_provider,
        region=region,
        background_tasks=b,
    )

    assert (
        len(create_job_calls) == 1
    ), f"aws_batch_create_job called {len(create_job_calls)} times, expected 1"

[INFO] __main__: create_batch_job(): command='db_pull 1', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering AwsBatchJobContext(task=csv_processing)
[INFO] __main__: batch_ctx=AwsBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.aws: AwsBatchJobContext.create_job(self=AwsBatchJobContext(task=csv_processing), command='db_pull 1', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************', 'AZURE_TENANT_ID': '************************************', 'AZURE_CLIENT_ID': '************************************', 'AZURE_CLIENT_SECRET': '****************************************', 'AZURE_STORAGE_ACCOUNT_PREFIX': 'kumsairtsdev', 'AZURE_RESOURCE_GROUP': 'kumaran-airt-service-dev', 'STORAGE_BUCKET_PREFIX': 'kumaran-airt-service', 'DB_USERNAME': 'root', 'DB_PASSWORD': '*********

"kwargs={'job_queue_arn': 'aws:job_queue_arn', 'job_definition_arn': 'aws:job_definition_arn', 'region': 'eu-west-1', 'command': 'db_pull 1', 'environment_vars': {'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************', 'AZURE_TENANT_ID': '************************************', 'AZURE_CLIENT_ID': '************************************', 'AZURE_CLIENT_SECRET': '****************************************', 'AZURE_STORAGE_ACCOUNT_PREFIX': 'kumsairtsdev', 'AZURE_RESOURCE_GROUP': 'kumaran-airt-service-dev', 'STORAGE_BUCKET_PREFIX': 'kumaran-airt-service', 'DB_USERNAME': 'root', 'DB_PASSWORD': '****************************************', 'DB_HOST': 'kumaran-mysql', 'DB_PORT': '3306', 'DB_DATABASE': 'airt_service', 'DB_DATABASE_SERVER': 'mysql'}}"

[INFO] airt_service.batch_job_components.base: Exiting AwsBatchJobContext(task=csv_processing): exc_type=None, exc=None, None


In [None]:
# Test using AzureBatchJobContext with MonkeyPatch
with MonkeyPatch.context() as monkeypatch:
    batch_account_name = "batch_account_name"
    batch_pool_name = "batch_pool_name"
    batch_job_name = "batch_job_name"
    cloud_provider = "azure"
    region = "northeurope"
    monkeypatch.setattr(
        "airt_service.azure.utils.get_batch_account_pool_job_names",
        lambda task, region: (
            batch_account_name,
            batch_pool_name,
            batch_job_name,
        ),
    )

    test_command = "db_pull 1"

    # Record each invocation: the asserts below live inside the stub, so without
    # this check the cell would pass vacuously if the stub were never called.
    create_job_calls = []

    def test_patch_create_job(*args, **kwargs):
        create_job_calls.append(kwargs)
        display(f"{kwargs=}")
        assert kwargs["batch_account_name"] == batch_account_name
        assert kwargs["batch_pool_name"] == batch_pool_name
        assert kwargs["batch_job_name"] == batch_job_name
        assert kwargs["region"] == region
        assert kwargs["command"] == test_command
        assert "AZURE_SUBSCRIPTION_ID" in kwargs["environment_vars"]

    monkeypatch.setattr(
        "airt_service.azure.batch_utils.azure_batch_create_job", test_patch_create_job
    )

    b = BackgroundTasks()
    create_batch_job(
        command=test_command,
        task="csv_processing",
        cloud_provider=cloud_provider,
        region=region,
        background_tasks=b,
    )

    assert (
        len(create_job_calls) == 1
    ), f"azure_batch_create_job called {len(create_job_calls)} times, expected 1"

[INFO] __main__: create_batch_job(): command='db_pull 1', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering AzureBatchJobContext(task=csv_processing)
[INFO] __main__: batch_ctx=AzureBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.azure: AzureBatchJobContext.create_job(self=AzureBatchJobContext(task=csv_processing), command='db_pull 1', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************', 'AZURE_TENANT_ID': '************************************', 'AZURE_CLIENT_ID': '************************************', 'AZURE_CLIENT_SECRET': '****************************************', 'AZURE_STORAGE_ACCOUNT_PREFIX': 'kumsairtsdev', 'AZURE_RESOURCE_GROUP': 'kumaran-airt-service-dev', 'STORAGE_BUCKET_PREFIX': 'kumaran-airt-service', 'DB_USERNAME': 'root', 'DB_PASSWORD': 

"kwargs={'command': 'db_pull 1', 'batch_job_name': 'batch_job_name', 'batch_pool_name': 'batch_pool_name', 'batch_account_name': 'batch_account_name', 'region': 'northeurope', 'container_settings': <azure.batch.models._models_py3.TaskContainerSettings object>, 'environment_vars': {'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************', 'AZURE_TENANT_ID': '************************************', 'AZURE_CLIENT_ID': '************************************', 'AZURE_CLIENT_SECRET': '****************************************', 'AZURE_STORAGE_ACCOUNT_PREFIX': 'kumsairtsdev', 'AZURE_RESOURCE_GROUP': 'kumaran-airt-service-dev', 'STORAGE_BUCKET_PREFIX': 'kumaran-airt-service', 'DB_USERNAME': 'root', 'DB_PASSWORD': '****************************************', 'DB_HOST': 'kumaran-mysql', 'DB_PORT': '3306', 'DB_DATABASE': 'airt_service', 'DB_DATABA

[INFO] airt_service.batch_job_components.base: Exiting AzureBatchJobContext(task=csv_processing): exc_type=None, exc=None, None


In [None]:
#| exporti


def update_all():
    """Replace this module's ``__all__`` with the curated public API.

    The batch-job context classes imported at the top of the module are
    re-exported alongside the two helper functions defined in it.
    """
    global __all__
    context_classes = [
        "BatchJobContext",
        "AwsBatchJobContext",
        "AzureBatchJobContext",
        "FastAPIBatchJobContext",
        "NoneBatchJobContext",
    ]
    helpers = ["get_environment_vars_for_batch_job", "create_batch_job"]
    __all__ = context_classes + helpers


# Executed at module level so that __all__ is replaced whenever this module is imported.
update_all()