In [None]:
# | default_exp batch_job_components.base

In [None]:
from airt.testing import activate_by_import

[INFO] airt.testing.activate_by_import: Testing environment activated.
[INFO] numexpr.utils: Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
[INFO] numexpr.utils: NumExpr defaulting to 8 threads.
[INFO] airt.keras.helpers: Using a single GPU #0 with memory_limit 1024 MB


In [None]:
# | export

from os import environ
from typing import *

import airt_service.sanitizer
from airt.logger import get_logger
from airt_service.aws.utils import get_available_aws_regions
from airt_service.azure.utils import get_available_azure_regions

In [None]:
import pytest
from airt_service.helpers import set_env_variable_context

In [None]:
# | exporti

# Module-level logger; BatchJobContext.__enter__/__exit__ use it to log context entry and exit
logger = get_logger(__name__)

In [None]:
# | export


class BatchJobContext:
    """Base class for batch job execution contexts

    Concrete subclasses register themselves with `add_factory` and are
    instantiated through the `create` factory method, which selects a subclass
    from the `cloud_provider` keyword argument and the `JOB_EXECUTOR`
    environment variable. Instances are context managers that log on entry
    and exit and never suppress exceptions.
    """

    def __init__(self, task: str):
        """Batch Job Context

        Do not use __init__, please use factory method `create` to initiate object

        Args:
            task: Task name to get batch environment info
        """
        self.task = task

    def create_job(self, command: str, environment_vars: Dict[str, str]):
        """Create a new job

        Must be implemented by subclasses; the base class always raises.

        Args:
            command: Command to execute in job
            environment_vars: Environment vars to set in the container

        Raises:
            NotImplementedError: Always, in the base class
        """
        raise NotImplementedError()

    # Registry shared by the base class and all subclasses: class name -> class.
    # Filled by `add_factory`, read by `create`.
    _factories: Dict[str, Any] = {}

    @classmethod
    def create(cls, task: str, **kwargs) -> "BatchJobContext":
        """Factory method to create a new batch job context

        The concrete context class is selected by name as follows:

        - `JOB_EXECUTOR` unset or "aws": `AzureBatchJobContext` when
          `cloud_provider="azure"` is passed, otherwise `AwsBatchJobContext`
          (also when `cloud_provider` is not passed at all);
        - `JOB_EXECUTOR` == "fastapi": `FastAPIBatchJobContext`;
        - `JOB_EXECUTOR` == "none": `NoneBatchJobContext`.

        Args:
            task: Task name to get batch environment info; One of csv_processing, predictions, preprocessing and training
            kwargs: Key word arguments which will be passed to the constructor of inherited class

        Returns:
            The initialized object of the inherited class

        Raises:
            ValueError: If `JOB_EXECUTOR` is set to an unknown value
            KeyError: If the selected context class was not registered via `add_factory`
        """
        # Cloud execution (the default) chooses between AWS and Azure. A missing
        # `cloud_provider` falls back to AWS both when JOB_EXECUTOR is unset and
        # when it is "aws" -- previously the latter raised KeyError.
        cloud_ctx_name = (
            "AzureBatchJobContext"
            if kwargs.get("cloud_provider") == "azure"
            else "AwsBatchJobContext"
        )
        executor_to_ctx_name = {
            "aws": cloud_ctx_name,
            "fastapi": "FastAPIBatchJobContext",
            "none": "NoneBatchJobContext",
        }

        job_executor = environ.get("JOB_EXECUTOR", "aws")
        if job_executor not in executor_to_ctx_name:
            # Only reachable when JOB_EXECUTOR is set, so the lookup below is safe
            raise ValueError(f'Unknown value: {environ["JOB_EXECUTOR"]=}')
        ctx_name = executor_to_ctx_name[job_executor]

        try:
            factory = BatchJobContext._factories[ctx_name]
        except KeyError:
            raise KeyError(
                f"{ctx_name} is not registered; call {ctx_name}.add_factory() first. "
                f"Registered contexts: {sorted(BatchJobContext._factories)}"
            ) from None
        return factory(task=task, **kwargs)

    @classmethod
    def add_factory(cls):
        """Register `cls` under its class name so `create` can instantiate it

        Presumably called once on each concrete subclass
        (e.g. `AwsBatchJobContext.add_factory()`).
        """
        BatchJobContext._factories[cls.__name__] = cls

    def __enter__(self):
        """Log entry into the context and return the context itself"""
        logger.info(f"Entering {self}")
        return self

    def __exit__(self, exc_type, exc, exc_tb):
        """Log exit from the context

        Returns False so that any exception raised inside the `with` block propagates.
        """
        logger.info(f"Exiting {self}: {exc_type=}, {exc=}, {exc_tb}")
        return False

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(task={self.task})"

In [None]:
# The base class cannot create jobs by itself
with pytest.raises(NotImplementedError):
    with BatchJobContext(task="csv_processing") as batch_ctx:
        batch_ctx.create_job(command="ls", environment_vars={})

# An unrecognised JOB_EXECUTOR value is rejected before any factory lookup
with set_env_variable_context(variable="JOB_EXECUTOR", value="something"):
    with pytest.raises(ValueError) as e:
        with BatchJobContext.create("csv_processing", region="eu-west-1") as batch_ctx:
            pass
    display(e)
    assert str(e.value) == 'Unknown value: environ["JOB_EXECUTOR"]=\'something\''

[INFO] __main__: Entering BatchJobContext(task=csv_processing)
[INFO] __main__: Exiting BatchJobContext(task=csv_processing): exc_type=<class 'NotImplementedError'>, exc=NotImplementedError(), <traceback object>


<ExceptionInfo ValueError('Unknown value: environ["JOB_EXECUTOR"]=\'something\'') tblen=2>