Skip to content

Commit

Permalink
Entrypoint support in docker operator (#14642)
Browse files Browse the repository at this point in the history
* include entrypoint in DockerOperator

* update DockerOperator test with entrypoint

* de-duplicate container removal logic

* rename get_command method
  • Loading branch information
jsemric committed Apr 9, 2021
1 parent cf67bb8 commit 594d93d
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 32 deletions.
71 changes: 40 additions & 31 deletions airflow/providers/docker/operators/docker.py
Expand Up @@ -99,6 +99,8 @@ class DockerOperator(BaseOperator):
:param volumes: List of volumes to mount into the container, e.g.
``['/host/path:/container/path', '/host/path2:/container/path2:ro']``.
:type volumes: list
:param entrypoint: Overwrite the default ENTRYPOINT of the image
:type entrypoint: str or list
:param working_dir: Working directory to
set on the container (equivalent to the -w switch of the docker client)
:type working_dir: str
Expand Down Expand Up @@ -158,6 +160,7 @@ def __init__(
tmp_dir: str = '/tmp/airflow',
user: Optional[Union[str, int]] = None,
volumes: Optional[List[str]] = None,
entrypoint: Optional[Union[str, List[str]]] = None,
working_dir: Optional[str] = None,
xcom_all: bool = False,
docker_conn_id: Optional[str] = None,
Expand Down Expand Up @@ -196,6 +199,7 @@ def __init__(
self.tmp_dir = tmp_dir
self.user = user
self.volumes = volumes or []
self.entrypoint = entrypoint
self.working_dir = working_dir
self.xcom_all = xcom_all
self.docker_conn_id = docker_conn_id
Expand Down Expand Up @@ -233,7 +237,7 @@ def _run_image(self) -> Optional[str]:
if not self.cli:
raise Exception("The 'cli' should be initialized before!")
self.container = self.cli.create_container(
command=self.get_command(),
command=self.format_command(self.command),
name=self.container_name,
environment={**self.environment, **self._private_environment},
host_config=self.cli.create_host_config(
Expand All @@ -251,39 +255,42 @@ def _run_image(self) -> Optional[str]:
),
image=self.image,
user=self.user,
entrypoint=self.format_command(self.entrypoint),
working_dir=self.working_dir,
tty=self.tty,
)

lines = self.cli.attach(container=self.container['Id'], stdout=True, stderr=True, stream=True)
self.cli.start(self.container['Id'])

line = ''
res_lines = []
for line in lines:
line = line.strip()
if hasattr(line, 'decode'):
# Note that lines returned can also be byte sequences so we have to handle decode here
line = line.decode('utf-8')
res_lines.append(line)
self.log.info(line)
try:
self.cli.start(self.container['Id'])

result = self.cli.wait(self.container['Id'])
if result['StatusCode'] != 0:
if self.auto_remove:
self.cli.remove_container(self.container['Id'])
res_lines = "\n".join(res_lines)
raise AirflowException('docker container failed: ' + repr(result) + f"lines {res_lines}")
line = ''
res_lines = []
for line in lines:
line = line.strip()
if hasattr(line, 'decode'):
# Note that lines returned can also be byte sequences so we have to handle decode here
line = line.decode('utf-8')
res_lines.append(line)
self.log.info(line)

# duplicated conditional logic because of expensive operation
ret = None
if self.do_xcom_push:
ret = self.cli.logs(container=self.container['Id']) if self.xcom_all else line.encode('utf-8')
result = self.cli.wait(self.container['Id'])
if result['StatusCode'] != 0:
res_lines = "\n".join(res_lines)
raise AirflowException('docker container failed: ' + repr(result) + f"lines {res_lines}")

if self.auto_remove:
self.cli.remove_container(self.container['Id'])

return ret
ret = None
if self.do_xcom_push:
ret = (
self.cli.logs(container=self.container['Id'])
if self.xcom_all
else line.encode('utf-8')
)
return ret
finally:
if self.auto_remove:
self.cli.remove_container(self.container['Id'])

def execute(self, context) -> Optional[str]:
self.cli = self._get_cli()
Expand Down Expand Up @@ -320,18 +327,20 @@ def _get_cli(self) -> APIClient:
tls_config = self.__get_tls_config()
return APIClient(base_url=self.docker_url, version=self.api_version, tls=tls_config)

def get_command(self) -> Union[List[str], str]:
@staticmethod
def format_command(command: Union[str, List[str]]) -> Union[List[str], str]:
"""
Retrieve command(s). If the command string starts with [, it returns the command list.
:param command: Docker command or entrypoint
:type command: str | List[str]
:return: the command (or commands)
:rtype: str | List[str]
"""
if isinstance(self.command, str) and self.command.strip().find('[') == 0:
commands = ast.literal_eval(self.command)
else:
commands = self.command
return commands
if isinstance(command, str) and command.strip().find('[') == 0:
return ast.literal_eval(command)
return command

def on_kill(self) -> None:
if self.cli is not None:
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/docker/operators/docker_swarm.py
Expand Up @@ -118,7 +118,7 @@ def _run_service(self) -> None:
types.TaskTemplate(
container_spec=types.ContainerSpec(
image=self.image,
command=self.get_command(),
command=self.format_command(self.command),
env=self.environment,
user=self.user,
tty=self.tty,
Expand Down
2 changes: 2 additions & 0 deletions tests/providers/docker/operators/test_docker.py
Expand Up @@ -67,6 +67,7 @@ def test_execute(self):
owner='unittest',
task_id='unittest',
volumes=['/host/path:/container/path'],
entrypoint='["sh", "-c"]',
working_dir='/container/path',
shm_size=1000,
host_tmp_dir='/host/airflow',
Expand All @@ -86,6 +87,7 @@ def test_execute(self):
host_config=self.client_mock.create_host_config.return_value,
image='ubuntu:latest',
user=None,
entrypoint=['sh', '-c'],
working_dir='/container/path',
tty=True,
)
Expand Down

0 comments on commit 594d93d

Please sign in to comment.