Skip to content

Commit

Permalink
Refactor SSH tests to not use SSH server in operator tests (#21326)
Browse files Browse the repository at this point in the history
This required a slight refactor to the SSHOperator (moving
`exec_ssh_client_command` "down" in to the Hook) but the SSH _Operator_
tests now just use stubbing, and the only place that connects to a real
SSH server is the one test of `test_exec_ssh_client_command` in SSHHook.

This is both better structured, and hopefully produces less (or ideally
no) random failures in our tests
  • Loading branch information
ashb committed Feb 4, 2022
1 parent 3e98280 commit ab762a5
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 252 deletions.
68 changes: 68 additions & 0 deletions airflow/providers/ssh/hooks/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import warnings
from base64 import decodebytes
from io import StringIO
from select import select
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

import paramiko
Expand Down Expand Up @@ -428,3 +429,70 @@ def _pkey_from_private_key(self, private_key: str, passphrase: Optional[str] = N
'Ensure key provided is valid for one of the following'
'key formats: RSA, DSS, ECDSA, or Ed25519'
)

def exec_ssh_client_command(
self,
ssh_client: paramiko.SSHClient,
command: str,
get_pty: bool,
environment: Optional[dict],
timeout: Optional[int],
) -> Tuple[int, bytes, bytes]:
self.log.info("Running command: %s", command)

# set timeout taken as params
stdin, stdout, stderr = ssh_client.exec_command(
command=command,
get_pty=get_pty,
timeout=timeout,
environment=environment,
)
# get channels
channel = stdout.channel

# closing stdin
stdin.close()
channel.shutdown_write()

agg_stdout = b''
agg_stderr = b''

# capture any initial output in case channel is closed already
stdout_buffer_length = len(stdout.channel.in_buffer)

if stdout_buffer_length > 0:
agg_stdout += stdout.channel.recv(stdout_buffer_length)

# read from both stdout and stderr
while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready():
readq, _, _ = select([channel], [], [], timeout)
for recv in readq:
if recv.recv_ready():
line = stdout.channel.recv(len(recv.in_buffer))
agg_stdout += line
self.log.info(line.decode('utf-8', 'replace').strip('\n'))
if recv.recv_stderr_ready():
line = stderr.channel.recv_stderr(len(recv.in_stderr_buffer))
agg_stderr += line
self.log.warning(line.decode('utf-8', 'replace').strip('\n'))
if (
stdout.channel.exit_status_ready()
and not stderr.channel.recv_stderr_ready()
and not stdout.channel.recv_ready()
):
stdout.channel.shutdown_read()
try:
stdout.channel.close()
except Exception:
# there is a race that when shutdown_read has been called and when
# you try to close the connection, the socket is already closed
# We should ignore such errors (but we should log them with warning)
self.log.warning("Ignoring exception on close", exc_info=True)
break

stdout.close()
stderr.close()

exit_status = stdout.channel.recv_exit_status()

return exit_status, agg_stdout, agg_stderr
95 changes: 26 additions & 69 deletions airflow/providers/ssh/operators/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@

import warnings
from base64 import b64encode
from select import select
from typing import Optional, Sequence, Tuple, Union

from paramiko.client import SSHClient
from typing import TYPE_CHECKING, Optional, Sequence, Union

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.ssh.hooks.ssh import SSHHook

if TYPE_CHECKING:
from paramiko.client import SSHClient

from airflow.providers.ssh.hooks.ssh import SSHHook

CMD_TIMEOUT = 10

Expand Down Expand Up @@ -66,7 +67,7 @@ class SSHOperator(BaseOperator):
def __init__(
self,
*,
ssh_hook: Optional[SSHHook] = None,
ssh_hook: Optional["SSHHook"] = None,
ssh_conn_id: Optional[str] = None,
remote_host: Optional[str] = None,
command: Optional[str] = None,
Expand Down Expand Up @@ -100,10 +101,12 @@ def __init__(
'Please use `conn_timeout` and `cmd_timeout` instead.'
'The old option `timeout` will be removed in a future version.',
DeprecationWarning,
stacklevel=1,
stacklevel=2,
)

def get_hook(self) -> SSHHook:
def get_hook(self) -> "SSHHook":
from airflow.providers.ssh.hooks.ssh import SSHHook

if self.ssh_conn_id:
if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
Expand All @@ -128,78 +131,32 @@ def get_hook(self) -> SSHHook:

return self.ssh_hook

def get_ssh_client(self) -> SSHClient:
def get_ssh_client(self) -> "SSHClient":
# Remember to use context manager or call .close() on this when done
self.log.info('Creating ssh_client')
return self.get_hook().get_conn()

def exec_ssh_client_command(self, ssh_client: SSHClient, command: str) -> Tuple[int, bytes, bytes]:
self.log.info("Running command: %s", command)

# set timeout taken as params
stdin, stdout, stderr = ssh_client.exec_command(
command=command,
get_pty=self.get_pty,
timeout=self.timeout,
environment=self.environment,
def exec_ssh_client_command(self, ssh_client: "SSHClient", command: str):
warnings.warn(
'exec_ssh_client_command method on SSHOperator is deprecated, call '
'`ssh_hook.exec_ssh_client_command` instead',
DeprecationWarning,
)
assert self.ssh_hook
return self.ssh_hook.exec_ssh_client_command(
ssh_client, command, timeout=self.timeout, environment=self.environment, get_pty=self.get_pty
)
# get channels
channel = stdout.channel

# closing stdin
stdin.close()
channel.shutdown_write()

agg_stdout = b''
agg_stderr = b''

# capture any initial output in case channel is closed already
stdout_buffer_length = len(stdout.channel.in_buffer)

if stdout_buffer_length > 0:
agg_stdout += stdout.channel.recv(stdout_buffer_length)

# read from both stdout and stderr
while not channel.closed or channel.recv_ready() or channel.recv_stderr_ready():
readq, _, _ = select([channel], [], [], self.cmd_timeout)
for recv in readq:
if recv.recv_ready():
line = stdout.channel.recv(len(recv.in_buffer))
agg_stdout += line
self.log.info(line.decode('utf-8', 'replace').strip('\n'))
if recv.recv_stderr_ready():
line = stderr.channel.recv_stderr(len(recv.in_stderr_buffer))
agg_stderr += line
self.log.warning(line.decode('utf-8', 'replace').strip('\n'))
if (
stdout.channel.exit_status_ready()
and not stderr.channel.recv_stderr_ready()
and not stdout.channel.recv_ready()
):
stdout.channel.shutdown_read()
try:
stdout.channel.close()
except Exception:
# there is a race that when shutdown_read has been called and when
# you try to close the connection, the socket is already closed
# We should ignore such errors (but we should log them with warning)
self.log.warning("Ignoring exception on close", exc_info=True)
break

stdout.close()
stderr.close()

exit_status = stdout.channel.recv_exit_status()

return exit_status, agg_stdout, agg_stderr

def raise_for_status(self, exit_status: int, stderr: bytes) -> None:
if exit_status != 0:
error_msg = stderr.decode('utf-8')
raise AirflowException(f"error running cmd: {self.command}, error: {error_msg}")

def run_ssh_client_command(self, ssh_client: SSHClient, command: str) -> bytes:
exit_status, agg_stdout, agg_stderr = self.exec_ssh_client_command(ssh_client, command)
def run_ssh_client_command(self, ssh_client: "SSHClient", command: str) -> bytes:
assert self.ssh_hook
exit_status, agg_stdout, agg_stderr = self.ssh_hook.exec_ssh_client_command(
ssh_client, command, timeout=self.timeout, environment=self.environment, get_pty=self.get_pty
)
self.raise_for_status(exit_status, agg_stderr)
return agg_stdout

Expand Down
20 changes: 18 additions & 2 deletions tests/providers/ssh/hooks/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from unittest import mock

import paramiko
import tenacity
from parameterized import parameterized

from airflow import settings
Expand Down Expand Up @@ -739,6 +740,21 @@ def test_openssh_private_key(self):
session.delete(conn)
session.commit()

def test_exec_ssh_client_command(self):
hook = SSHHook(
ssh_conn_id='ssh_default',
conn_timeout=30,
banner_timeout=100,
)

if __name__ == '__main__':
unittest.main()
for attempt in tenacity.Retrying(stop=tenacity.stop_after_attempt(5)):
with attempt, hook.get_conn() as client:
ret = hook.exec_ssh_client_command(
client,
'echo airflow',
False,
None,
30,
)

assert ret == (0, b'airflow\n', b'')
Loading

0 comments on commit ab762a5

Please sign in to comment.