In [None]:
# | default_exp batch_job_components.azure

In [None]:
from airt.testing import activate_by_import

[INFO] airt.testing.activate_by_import: Testing environment activated.


2022-12-19 07:54:31.649615: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


[INFO] numexpr.utils: Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
[INFO] numexpr.utils: NumExpr defaulting to 8 threads.


In [None]:
# | export

from os import environ
from typing import *

import azure.batch.models as batchmodels
from airt.logger import get_logger

import airt_service
import airt_service.sanitizer
from airt_service.azure.batch_utils import azure_batch_create_job
from airt_service.azure.utils import get_batch_account_pool_job_names
from airt_service.batch_job_components.base import BatchJobContext

In [None]:
from _pytest.monkeypatch import MonkeyPatch
from fastcore.utils import patch

from airt_service.helpers import set_env_variable_context

In [None]:
# | exporti

# Module-level logger used by AzureBatchJobContext.create_job. It is kept as a
# plain global (rather than captured elsewhere) so it can be rebound; the
# notebook's test setup re-creates it after each MonkeyPatch.setattr call.
logger = get_logger(__name__)

In [None]:
# Capture pytest's original `MonkeyPatch.setattr` only once. If this cell is
# re-run, `MonkeyPatch.setattr` is already the patched version below; storing
# that in `old_setattr` would make the patched method (which looks up
# `old_setattr` at call time) call itself until RecursionError.
if "old_setattr" not in globals():
    old_setattr = MonkeyPatch.setattr


@patch
def setattr(self: MonkeyPatch, *args, **kwargs):
    """Run pytest's original `MonkeyPatch.setattr`, then re-create the module logger.

    Installed onto `MonkeyPatch` by fastcore's `@patch`. After the original
    setattr has applied the patch, the global `logger` is rebound to a fresh
    `get_logger(__name__)` (presumably so it reflects any patched logging
    state — TODO confirm).

    Args:
        self: The MonkeyPatch instance.
        *args: Forwarded unchanged to the original `MonkeyPatch.setattr`.
        **kwargs: Forwarded unchanged to the original `MonkeyPatch.setattr`.
    """
    global logger
    old_setattr(self, *args, **kwargs)
    logger = get_logger(__name__)

In [None]:
# | export


class AzureBatchJobContext(BatchJobContext):
    """A class for creating AzureBatchJobContext

    Submits batch jobs to Azure Batch. Objects should be obtained through the
    `create` factory method rather than by calling `__init__` directly.
    """

    # ToDo: We have batch accounts available only in northeurope for now, so
    # `create_job` always submits there regardless of `self.region`.
    SUPPORTED_REGION = "northeurope"

    def __init__(self, task: str, **kwargs):
        """Azure Batch Job Context

        Do not use __init__, please use factory method `create` to initiate object

        Args:
            task: Name of the task; passed to
                `get_batch_account_pool_job_names` when a job is created.
            kwargs: Must contain `region`, the region requested by the caller.

        Raises:
            KeyError: If `region` is missing from kwargs.
        """
        BatchJobContext.__init__(self, task=task)
        self.region = kwargs["region"]

    def create_job(self, command: str, environment_vars: Dict[str, str]):
        """Create a new job

        Submits `command` to Azure Batch inside the
        `ghcr.io/airtai/airt-service` container image. The job always runs in
        `SUPPORTED_REGION`; a warning is logged if the context was created for
        a different region.

        Args:
            command: Command to execute in job
            environment_vars: Environment vars to set in the container

        Raises:
            KeyError: If the `DOMAIN` environment variable is not set.
        """
        logger.info(
            f"{self.__class__.__name__}.create_job({self=}, {command=}, {environment_vars=})"
        )
        region = self.SUPPORTED_REGION
        if self.region != region:
            # Do not silently drop the caller's region choice.
            logger.warning(
                f"{self.__class__.__name__}.create_job(): region {self.region!r} is not supported yet, using {region!r} instead"
            )

        # Resolve helpers through their module attributes (not the names imported
        # at the top of the file) so tests can monkeypatch them by dotted path.
        (
            batch_account_name,
            batch_pool_name,
            batch_job_name,
        ) = airt_service.azure.utils.get_batch_account_pool_job_names(self.task, region)

        # Domain `api.airt.ai` runs the `latest` image; every other domain uses `dev`.
        tag = "latest" if environ["DOMAIN"] == "api.airt.ai" else "dev"

        container_settings = batchmodels.TaskContainerSettings(
            image_name=f"ghcr.io/airtai/airt-service:{tag}"
        )

        airt_service.azure.batch_utils.azure_batch_create_job(
            command=command,
            batch_job_name=batch_job_name,
            batch_pool_name=batch_pool_name,
            batch_account_name=batch_account_name,
            region=region,
            container_settings=container_settings,
            environment_vars=environment_vars,
        )


AzureBatchJobContext.add_factory()

In [None]:
with MonkeyPatch.context() as monkeypatch:
    batch_account_name = "batch_account_name"
    batch_pool_name = "batch_pool_name"
    batch_job_name = "batch_job_name"
    region = "northeurope"
    task_name = "csv_processing"
    command = "process_csv_for 3 PersonId OccurredTime --blocksize 256MB --deduplicate_data"

    # Record every call to the patched submitter. Assertions inside the patch
    # alone would pass vacuously if `create_job` never reached it.
    create_job_calls = []

    def test_patch_get_names(task_arg, region_arg):
        assert task_arg == task_name
        assert region_arg == region
        return batch_account_name, batch_pool_name, batch_job_name

    monkeypatch.setattr(
        "airt_service.azure.utils.get_batch_account_pool_job_names",
        test_patch_get_names,
    )

    def test_patch_create_job(*args, **kwargs):
        display(f"{kwargs=}")
        create_job_calls.append(kwargs)
        assert kwargs["batch_account_name"] == batch_account_name
        assert kwargs["batch_pool_name"] == batch_pool_name
        assert kwargs["batch_job_name"] == batch_job_name
        assert kwargs["region"] == region
        assert kwargs["command"] == command
        assert "AZURE_SUBSCRIPTION_ID" in kwargs["environment_vars"]

    monkeypatch.setattr(
        "airt_service.azure.batch_utils.azure_batch_create_job", test_patch_create_job
    )

    with BatchJobContext.create(
        task_name, cloud_provider="azure", region=region
    ) as batch_ctx:
        display(batch_ctx)
        assert batch_ctx.__class__.__name__ == "AzureBatchJobContext"
        batch_ctx.create_job(
            command=command,
            environment_vars={
                "AZURE_SUBSCRIPTION_ID": "random_value",
            },
        )

    # The patched submitter must have been invoked exactly once.
    assert len(create_job_calls) == 1, create_job_calls

[INFO] airt_service.batch_job_components.base: Entering AzureBatchJobContext(task=csv_processing)


AzureBatchJobContext(task=csv_processing)

[INFO] __main__: AzureBatchJobContext.create_job(self=AzureBatchJobContext(task=csv_processing), command='process_csv_for 3 PersonId OccurredTime --blocksize 256MB --deduplicate_data', environment_vars={'AZURE_SUBSCRIPTION_ID': '************************************'})


"kwargs={'command': 'process_csv_for 3 PersonId OccurredTime --blocksize 256MB --deduplicate_data', 'batch_job_name': 'batch_job_name', 'batch_pool_name': 'batch_pool_name', 'batch_account_name': 'batch_account_name', 'region': 'northeurope', 'container_settings': <azure.batch.models._models_py3.TaskContainerSettings object>, 'environment_vars': {'AZURE_SUBSCRIPTION_ID': '************************************'}}"

[INFO] airt_service.batch_job_components.base: Exiting AzureBatchJobContext(task=csv_processing): exc_type=None, exc=None, None
