Skip to content

Commit

Permalink
Fix incorrect typing, remove hardcoded argument values and improve co…
Browse files Browse the repository at this point in the history
…de in AzureContainerInstancesOperator (#11408)
  • Loading branch information
ephraimbuddy committed Oct 11, 2020
1 parent 5bc5994 commit 686e0ee
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 18 deletions.
Expand Up @@ -28,6 +28,8 @@
ResourceRequests,
ResourceRequirements,
VolumeMount,
IpAddress,
ContainerPort,
)
from msrestazure.azure_exceptions import CloudError

Expand Down Expand Up @@ -88,37 +90,44 @@ class AzureContainerInstancesOperator(BaseOperator):
:param gpu: GPU Resource for the container.
:type gpu: azure.mgmt.containerinstance.models.GpuResource
:param command: the command to run inside the container
:type command: Optional[str]
:type command: Optional[List[str]]
:param container_timeout: max time allowed for the execution of
the container instance.
:type container_timeout: datetime.timedelta
:param tags: azure tags as dict of str:str
:type tags: Optional[dict[str, str]]
:param os_type: The operating system type required by the containers
in the container group. Possible values include: 'Windows', 'Linux'
:type os_type: str
:param restart_policy: Restart policy for all containers within the container group.
Possible values include: 'Always', 'OnFailure', 'Never'
:type restart_policy: str
:param ip_address: The IP address type of the container group.
:type ip_address: IpAddress
**Example**::
AzureContainerInstancesOperator(
"azure_service_principal",
"azure_registry_user",
"my-resource-group",
"my-container-name-{{ ds }}",
"myprivateregistry.azurecr.io/my_container:latest",
"westeurope",
{"MODEL_PATH": "my_value",
ci_conn_id = "azure_service_principal",
registry_conn_id = "azure_registry_user",
resource_group = "my-resource-group",
name = "my-container-name-{{ ds }}",
image = "myprivateregistry.azurecr.io/my_container:latest",
region = "westeurope",
environment_variables = {"MODEL_PATH": "my_value",
"POSTGRES_LOGIN": "{{ macros.connection('postgres_default').login }}",
"POSTGRES_PASSWORD": "{{ macros.connection('postgres_default').password }}",
"JOB_GUID": "{{ ti.xcom_pull(task_ids='task1', key='guid') }}" },
['POSTGRES_PASSWORD'],
[("azure_wasb_conn_id",
"my_storage_container",
"my_fileshare",
"/input-data",
True),],
secured_variables = ['POSTGRES_PASSWORD'],
volumes = [("azure_wasb_conn_id",
"my_storage_container",
"my_fileshare",
"/input-data",
True),],
memory_in_gb=14.0,
cpu=4.0,
gpu=GpuResource(count=1, sku='K80'),
command=["/bin/echo", "world"],
container_timeout=timedelta(hours=2),
task_id="start_container"
)
"""
Expand All @@ -142,10 +151,14 @@ def __init__(
memory_in_gb: Optional[Any] = None,
cpu: Optional[Any] = None,
gpu: Optional[Any] = None,
command: Optional[str] = None,
command: Optional[List[str]] = None,
remove_on_error: bool = True,
fail_if_exists: bool = True,
tags: Optional[Dict[str, str]] = None,
os_type: str = 'Linux',
restart_policy: str = 'Never',
ip_address: Optional[IpAddress] = None,
ports: Optional[List[ContainerPort]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -167,6 +180,22 @@ def __init__(
self.fail_if_exists = fail_if_exists
self._ci_hook: Any = None
self.tags = tags
self.os_type = os_type
if self.os_type not in ['Linux', 'Windows']:
raise AirflowException(
"Invalid value for the os_type argument. "
"Please set 'Linux' or 'Windows' as the os_type. "
f"Found `{self.os_type}`."
)
self.restart_policy = restart_policy
if self.restart_policy not in ['Always', 'OnFailure', 'Never']:
raise AirflowException(
"Invalid value for the restart_policy argument. "
"Please set one of 'Always', 'OnFailure','Never' as the restart_policy. "
f"Found `{self.restart_policy}`"
)
self.ip_address = ip_address
self.ports = ports

def execute(self, context: dict) -> int:
# Check name again in case it was templated.
Expand Down Expand Up @@ -214,13 +243,18 @@ def execute(self, context: dict) -> int:
requests=ResourceRequests(memory_in_gb=self.memory_in_gb, cpu=self.cpu, gpu=self.gpu)
)

if self.ip_address and not self.ports:
self.ports = [ContainerPort(port=80)]
self.log.info("Default port set. Container will listen on port 80")

container = Container(
name=self.name,
image=self.image,
resources=resources,
command=self.command,
environment_variables=environment_variables,
volume_mounts=volume_mounts,
ports=self.ports,
)

container_group = ContainerGroup(
Expand All @@ -230,9 +264,10 @@ def execute(self, context: dict) -> int:
],
image_registry_credentials=image_registry_credentials,
volumes=volumes,
restart_policy='Never',
os_type='Linux',
restart_policy=self.restart_policy,
os_type=self.os_type,
tags=self.tags,
ip_address=self.ip_address,
)

self._ci_hook.create_or_update(self.resource_group, self.name, container_group)
Expand Down
Expand Up @@ -19,6 +19,7 @@

import unittest
from collections import namedtuple
from unittest.mock import MagicMock

import mock
from azure.mgmt.containerinstance.models import ContainerState, Event
Expand Down Expand Up @@ -197,3 +198,116 @@ def test_name_checker(self):
for name in valid_names:
checked_name = AzureContainerInstancesOperator._check_name(name)
self.assertEqual(checked_name, name)

@mock.patch(
"airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook"
)
def test_execute_with_ipaddress(self, aci_mock):
expected_c_state = ContainerState(state='Terminated', exit_code=0, detail_status='test')
expected_cg = make_mock_cg(expected_c_state)
ipaddress = MagicMock()

aci_mock.return_value.get_state.return_value = expected_cg
aci_mock.return_value.exists.return_value = False

aci = AzureContainerInstancesOperator(
ci_conn_id=None,
registry_conn_id=None,
resource_group='resource-group',
name='container-name',
image='container-image',
region='region',
task_id='task',
ip_address=ipaddress,
)
aci.execute(None)
self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1)
(_, _, called_cg), _ = aci_mock.return_value.create_or_update.call_args

self.assertEqual(called_cg.ip_address, ipaddress)

@mock.patch(
"airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook"
)
def test_execute_with_windows_os_and_diff_restart_policy(self, aci_mock):
expected_c_state = ContainerState(state='Terminated', exit_code=0, detail_status='test')
expected_cg = make_mock_cg(expected_c_state)

aci_mock.return_value.get_state.return_value = expected_cg
aci_mock.return_value.exists.return_value = False

aci = AzureContainerInstancesOperator(
ci_conn_id=None,
registry_conn_id=None,
resource_group='resource-group',
name='container-name',
image='container-image',
region='region',
task_id='task',
restart_policy="Always",
os_type='Windows',
)
aci.execute(None)
self.assertEqual(aci_mock.return_value.create_or_update.call_count, 1)
(_, _, called_cg), _ = aci_mock.return_value.create_or_update.call_args

self.assertEqual(called_cg.restart_policy, 'Always')
self.assertEqual(called_cg.os_type, 'Windows')

@mock.patch(
"airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook"
)
def test_execute_fails_with_incorrect_os_type(self, aci_mock):
expected_c_state = ContainerState(state='Terminated', exit_code=0, detail_status='test')
expected_cg = make_mock_cg(expected_c_state)

aci_mock.return_value.get_state.return_value = expected_cg
aci_mock.return_value.exists.return_value = False

with self.assertRaises(AirflowException) as e:
AzureContainerInstancesOperator(
ci_conn_id=None,
registry_conn_id=None,
resource_group='resource-group',
name='container-name',
image='container-image',
region='region',
task_id='task',
os_type='MacOs',
)

self.assertEqual(
str(e.exception),
"Invalid value for the os_type argument. "
"Please set 'Linux' or 'Windows' as the os_type. "
"Found `MacOs`.",
)

@mock.patch(
"airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook"
)
def test_execute_fails_with_incorrect_restart_policy(self, aci_mock):
expected_c_state = ContainerState(state='Terminated', exit_code=0, detail_status='test')
expected_cg = make_mock_cg(expected_c_state)

aci_mock.return_value.get_state.return_value = expected_cg
aci_mock.return_value.exists.return_value = False

with self.assertRaises(AirflowException) as e:
AzureContainerInstancesOperator(
ci_conn_id=None,
registry_conn_id=None,
resource_group='resource-group',
name='container-name',
image='container-image',
region='region',
task_id='task',
restart_policy='Everyday',
)

self.assertEqual(
str(e.exception),
"Invalid value for the restart_policy argument. "
"Please set one of 'Always', 'OnFailure','Never' as the restart_policy. "
"Found `Everyday`",
)

0 comments on commit 686e0ee

Please sign in to comment.