In [None]:
# | default_exp batch_job_components.aws

In [None]:
from airt.testing import activate_by_import

[INFO] airt.testing.activate_by_import: Testing environment activated.
[INFO] numexpr.utils: Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
[INFO] numexpr.utils: NumExpr defaulting to 8 threads.
[INFO] airt.keras.helpers: Using a single GPU #0 with memory_limit 1024 MB


In [None]:
# | export

from typing import *

from airt.logger import get_logger

import airt_service
import airt_service.sanitizer
from airt_service.aws.batch_utils import aws_batch_create_job
from airt_service.aws.utils import get_queue_definition_arns
from airt_service.batch_job_components.base import BatchJobContext

In [None]:
from _pytest.monkeypatch import MonkeyPatch
from fastcore.utils import patch

from airt_service.helpers import set_env_variable_context

In [None]:
# | exporti

# Module-level logger obtained from airt's `get_logger`, keyed by this module's name.
logger = get_logger(__name__)

In [None]:
# Stash the pristine pytest implementation on the class exactly once.
# Without this guard, re-running the cell would bind `old_setattr` to the
# already-patched method; that method looks `old_setattr` up as a global at
# call time, so it would end up calling itself and raise RecursionError.
if not hasattr(MonkeyPatch, "_airt_unpatched_setattr"):
    MonkeyPatch._airt_unpatched_setattr = MonkeyPatch.setattr
old_setattr = MonkeyPatch._airt_unpatched_setattr


@patch
def setattr(self: MonkeyPatch, *args, **kwargs):
    """Apply the patch via the original `MonkeyPatch.setattr`, then rebind the module logger.

    All positional and keyword arguments are forwarded unchanged to the
    original implementation. Afterwards the module-level `logger` is replaced
    with a freshly obtained one.

    NOTE(review): presumably the logger is re-created so it does not keep
    referring to state replaced by the patch - confirm the exact reason.
    """
    global logger
    old_setattr(self, *args, **kwargs)
    logger = get_logger(__name__)

In [None]:
# | export


class AwsBatchJobContext(BatchJobContext):
    """Batch job context that creates jobs on AWS Batch.

    Instances are meant to be obtained through the `create` factory method
    (e.g. `BatchJobContext.create(task, region=...)`) rather than by calling
    the constructor directly.
    """

    def __init__(self, task: str, **kwargs: Any):
        """AWS Batch Job Context

        Do not use __init__, please use factory method `create` to initiate object

        Args:
            task: Name of the task; used to look up the job queue and job definition ARNs
            **kwargs: Must contain `region`, the AWS region passed on to the AWS helpers

        Raises:
            KeyError: If `region` is not supplied
        """
        BatchJobContext.__init__(self, task=task)
        try:
            self.region = kwargs["region"]
        except KeyError:
            # Same exception type as before, but with a message that tells the
            # caller what is actually missing.
            raise KeyError(
                f"{self.__class__.__name__} requires a 'region' keyword argument"
            ) from None

    def create_job(self, command: str, environment_vars: Dict[str, str]) -> None:
        """Create a new job

        Resolves the job queue and job definition ARNs for this context's task
        and region, then creates the job with the given command and environment.

        Args:
            command: Command to execute in job
            environment_vars: Environment vars to set in the container
        """
        # Log only the variable names with masked values: the values may carry
        # credentials (e.g. AWS_SECRET_ACCESS_KEY) and must not reach log storage.
        masked_environment_vars = dict.fromkeys(environment_vars or (), "***")
        logger.info(
            f"{self.__class__.__name__}.create_job({self=}, {command=}, environment_vars={masked_environment_vars})"
        )
        # The helpers are referenced through their fully qualified module path
        # so that patches applied to those module attributes take effect here.
        (
            job_queue_arn,
            job_definition_arn,
        ) = airt_service.aws.utils.get_queue_definition_arns(self.task, self.region)

        airt_service.aws.batch_utils.aws_batch_create_job(
            job_queue_arn=job_queue_arn,
            job_definition_arn=job_definition_arn,
            region=self.region,
            command=command,
            environment_vars=environment_vars,
        )


# NOTE(review): `add_factory` is inherited from the base class; presumably it
# registers this class so that `BatchJobContext.create` can return it - confirm.
AwsBatchJobContext.add_factory()

In [None]:
with MonkeyPatch.context() as monkeypatch:
    job_queue_arn = "aws:job_queue_arn"
    job_definition_arn = "aws:job_definition_arn"
    region = "eu-west-1"
    expected_command = "process_csv_for 3 PersonId OccurredTime --blocksize 256MB --deduplicate_data"
    # Every invocation of the stub is recorded here so the cell can verify
    # that `create_job` really reached it; otherwise the assertions inside
    # the stub would be skipped silently and the cell would pass vacuously.
    create_job_calls = []

    # Replace the ARN lookup with a stub returning fixed values.
    monkeypatch.setattr(
        "airt_service.aws.utils.get_queue_definition_arns",
        lambda task, region: (job_queue_arn, job_definition_arn),
    )

    def test_patch_create_job(*args, **kwargs):
        """Stand-in for `aws_batch_create_job` that checks the arguments it receives."""
        create_job_calls.append(kwargs)
        display(f"{kwargs=}")
        assert kwargs["job_queue_arn"] == job_queue_arn
        assert kwargs["job_definition_arn"] == job_definition_arn
        assert kwargs["region"] == region
        assert kwargs["command"] == expected_command
        assert "AWS_ACCESS_KEY_ID" in kwargs["environment_vars"]
        assert "AWS_SECRET_ACCESS_KEY" in kwargs["environment_vars"]

    monkeypatch.setattr(
        "airt_service.aws.batch_utils.aws_batch_create_job", test_patch_create_job
    )

    with BatchJobContext.create("csv_processing", region=region) as batch_ctx:
        display(batch_ctx)
        assert batch_ctx.__class__.__name__ == "AwsBatchJobContext"
        batch_ctx.create_job(
            command=expected_command,
            environment_vars={
                "AWS_ACCESS_KEY_ID": "random_value",
                "AWS_SECRET_ACCESS_KEY": "random_value",
            },
        )

    # The stub must have been invoked exactly once by `create_job`.
    assert len(create_job_calls) == 1, create_job_calls

[INFO] airt_service.batch_job_components.base: Entering AwsBatchJobContext(task=csv_processing)


AwsBatchJobContext(task=csv_processing)

[INFO] __main__: AwsBatchJobContext.create_job(self=AwsBatchJobContext(task=csv_processing), command='process_csv_for 3 PersonId OccurredTime --blocksize 256MB --deduplicate_data', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************'})


"kwargs={'job_queue_arn': 'aws:job_queue_arn', 'job_definition_arn': 'aws:job_definition_arn', 'region': 'eu-west-1', 'command': 'process_csv_for 3 PersonId OccurredTime --blocksize 256MB --deduplicate_data', 'environment_vars': {'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************'}}"

[INFO] airt_service.batch_job_components.base: Exiting AwsBatchJobContext(task=csv_processing): exc_type=None, exc=None, None
