From 919c03b1c45513e87716b01ab7b3e91fca2e7612 Mon Sep 17 00:00:00 2001 From: Tasos Papaioannou Date: Wed, 10 Apr 2024 16:09:45 -0400 Subject: [PATCH] Add support for hussh and ansible-pylibssh clients. --- broker/binds/hussh.py | 155 +++++++++++++ broker/binds/pylibssh.py | 254 ++++++++++++++++++++ broker/binds/ssh2.py | 281 +++++++++++++++++++++++ broker/helpers.py | 2 +- broker/hosts.py | 2 +- broker/session.py | 343 +++------------------------- broker/settings.py | 1 + broker/ssh_session.py | 55 +++++ broker_settings.yaml.example | 1 + pyproject.toml | 5 + tests/functional/test_containers.py | 4 +- tests/functional/test_rh_beaker.py | 4 +- tests/functional/test_satlab.py | 15 +- 13 files changed, 793 insertions(+), 329 deletions(-) create mode 100644 broker/binds/hussh.py create mode 100644 broker/binds/pylibssh.py create mode 100644 broker/binds/ssh2.py create mode 100644 broker/ssh_session.py diff --git a/broker/binds/hussh.py b/broker/binds/hussh.py new file mode 100644 index 00000000..e475f41c --- /dev/null +++ b/broker/binds/hussh.py @@ -0,0 +1,155 @@ +"""Module providing classes to establish ssh or ssh-like connections to hosts. + +Classes: + Session - Wrapper around hussh's auth/connection system. + +Note: You typically want to use a Host object instance to create sessions, + not these classes directly. +""" +from contextlib import contextmanager +from pathlib import Path + +from hussh import Connection +from logzero import logger + +from broker import exceptions, helpers + + +class Session: + """Wrapper around hussh's auth/connection system.""" + + def __init__(self, **kwargs): + """Initialize a Session object. + + kwargs: + hostname (str): The hostname or IP address of the remote host. Defaults to 'localhost'. + username (str): The username to authenticate with. Defaults to 'root'. + timeout (float): The timeout for the connection in seconds. Defaults to 60. + port (int): The port number to connect to. Defaults to 22. + key_filename (str): The path to the private key file to use for authentication. + password (str): The password to use for authentication. + ipv6 (bool): Whether or not to use IPv6. Defaults to False. + ipv4_fallback (bool): Whether or not to fallback to IPv4 if IPv6 fails. Defaults to True. + + Raises: + AuthException: If no password or key file is provided. + ConnectionError: If the connection fails. + FileNotFoundError: If the key file is not found. + """ + host = kwargs.get("hostname", "localhost") + user = kwargs.get("username", "root") + port = kwargs.get("port", 22) + timeout = kwargs.get("timeout", 60) * 1000 + + key_filename = kwargs.get("key_filename") + password = kwargs.get("password") + + # TODO Create and use socket if hussh allows user to specify one + self.session = None + + conn_kwargs = {"username": user, "port": port, "timeout": timeout} + try: + if key_filename: + auth_type = "Key" + if not Path(key_filename).exists(): + raise FileNotFoundError(f"Key not found in '{key_filename}'") + conn_kwargs["private_key"] = key_filename + elif password: + auth_type = "Password" + conn_kwargs["password"] = password + elif user: + auth_type = "Session" + else: + raise exceptions.AuthenticationError("No password or key file provided.") + + logger.info(f"{conn_kwargs=}") + self.session = Connection(host, **conn_kwargs) + + except Exception as err: # noqa: BLE001 + raise exceptions.AuthenticationError( + f"{auth_type}-based authentication failed." + ) from err + + @staticmethod + def _set_destination(source, destination): + dest = destination or source + if dest.endswith("/"): + dest = dest + Path(source).name + return dest + + def disconnect(self): + """Disconnect session.""" + + def remote_copy(self, source, dest_host, dest_path=None, ensure_dir=True): + """Copy a file from this host to another.""" + dest_path = dest_path or source + if ensure_dir: + dest_host.session.run(f"mkdir -p {Path(dest_path).absolute().parent}") + + # Copy from this host to destination host + self.session.remote_copy( + source_path=source, dest_conn=dest_host.session.session, dest_path=dest_path + ) + + def run(self, command, timeout=0): + """Run a command on the host and return the results.""" + # TODO support timeout parameter + result = self.session.execute(command) + + # Create broker Result from hussh SSHResult + return helpers.Result( + status=result.status, + stderr=result.stderr, + stdout=result.stdout, + ) + + def scp_read(self, source, destination=None, return_data=False): + """SCP read a remote file into a local destination or return a bytes object if return_data is True.""" + destination = self._set_destination(source, destination) + if return_data: + return self.session.scp_read(remote_path=source) + self.session.scp_read(remote_path=source, local_path=destination) + + def scp_write(self, source, destination=None, ensure_dir=True): + """SCP write a local file to a remote destination.""" + destination = self._set_destination(source, destination) + if ensure_dir: + self.run(f"mkdir -p {Path(destination).absolute().parent}") + self.session.scp_write(source, destination) + + def sftp_read(self, source, destination=None, return_data=False): + """Read a remote file into a local destination or return a bytes object if return_data is True.""" + if return_data: + return self.session.sftp_read(remote_path=source).encode("utf-8") + + destination = self._set_destination(source, destination) + + # Create the destination path if it doesn't exist + Path(destination).parent.mkdir(parents=True, exist_ok=True) + + self.session.sftp_read(remote_path=source, local_path=destination) + + def sftp_write(self, source, destination=None, ensure_dir=True): + """Sftp write a local file to a remote destination.""" + destination = self._set_destination(source, destination) + if ensure_dir: + self.run(f"mkdir -p {Path(destination).absolute().parent}") + self.session.sftp_write(local_path=source, remote_path=destination) + + def shell(self, pty=False): + """Create and return an interactive shell instance.""" + return self.session.shell(pty=pty) + + @contextmanager + def tail_file(self, filename): + """Tail a file on the remote host.""" + with self.session.tail(filename) as _tailer: + yield (tailer := FileTailer(tailer=_tailer)) + tailer.contents = _tailer.contents + + +class FileTailer: + """Wrapper for hussh's FileTailer class.""" + + def __init__(self, **kwargs): + self.tailer = kwargs.get("tailer") diff --git a/broker/binds/pylibssh.py b/broker/binds/pylibssh.py new file mode 100644 index 00000000..4cc091ca --- /dev/null +++ b/broker/binds/pylibssh.py @@ -0,0 +1,254 @@ +"""Module providing classes to establish ssh or ssh-like connections to hosts. + +Classes: + Session - Wrapper around ansible-pylibssh auth/connection system. + InteractiveShell - Wrapper around ansible-pylibssh non-blocking channel system. + +Note: You typically want to use a Host object instance to create sessions, + not these classes directly. +""" +from contextlib import contextmanager +from pathlib import Path +from tempfile import NamedTemporaryFile + +from pylibsshext.session import Session as _Session + +from broker import exceptions, helpers +from broker.ssh_session import _create_connect_socket + + +class Session: + """Wrapper around ansible-pylibssh's auth/connection system.""" + + def __init__(self, **kwargs): + """Initialize a Session object. + + kwargs: + hostname (str): The hostname or IP address of the remote host. Defaults to 'localhost'. + username (str): The username to authenticate with. Defaults to 'root'. + timeout (float): The timeout for the connection in seconds. Defaults to 60. + port (int): The port number to connect to. Defaults to 22. + key_filename (str): The path to the private key file to use for authentication. + password (str): The password to use for authentication. + ipv6 (bool): Whether or not to use IPv6. Defaults to False. + ipv4_fallback (bool): Whether or not to fallback to IPv4 if IPv6 fails. Defaults to True. + + Raises: + AuthException: If no password or key file is provided. + ConnectionError: If the connection fails. + FileNotFoundError: If the key file is not found. + """ + host = kwargs.get("hostname", "localhost") + user = kwargs.get("username", "root") + port = kwargs.get("port", 22) + key_filename = kwargs.get("key_filename") + password = kwargs.get("password") + timeout = kwargs.get("timeout", 60) + + # Create the socket + self.sock, self.is_ipv6 = _create_connect_socket( + host, + port, + timeout, + ipv6=kwargs.get("ipv6", False), + ipv4_fallback=kwargs.get("ipv4_fallback", True), + ) + + self.session = _Session() + try: + if key_filename: + auth_type = "Key" + key_path = Path(key_filename) + if not key_path.exists(): + raise FileNotFoundError(f"Key not found in '{key_filename}'") + self.session.connect( + fd=self.sock.fileno(), + host=host, + host_key_checking=False, + port=port, + private_key=key_path.read_bytes(), + timeout=timeout, + user=user, + ) + elif password: + auth_type = "Password" + self.session.connect( + fd=self.sock.fileno(), + host=host, + host_key_checking=False, + password=password, + port=port, + timeout=timeout, + user=user, + ) + elif user: + auth_type = "Session" + raise exceptions.NotImplementedError("Session-based auth for ansible-pylibssh") + else: + raise exceptions.AuthenticationError("No password or key file provided.") + except Exception as err: # noqa: BLE001 + raise exceptions.AuthenticationError( + f"{auth_type}-based authentication failed." + ) from err + + @staticmethod + def _set_destination(source, destination): + dest = destination or source + if dest.endswith("/"): + dest = dest + Path(source).name + return dest + + def disconnect(self): + """Disconnect session.""" + + def remote_copy(self, source, dest_host, dest_path=None, ensure_dir=True): + """Copy a file from this host to another.""" + dest_path = dest_path or source + if ensure_dir: + dest_host.session.run(f"mkdir -p {Path(dest_path).absolute().parent}") + + # TODO read/write without local dest_path intermediate + sftp_down = self.session.sftp() + sftp_up = dest_host.session.session.sftp() + try: + with NamedTemporaryFile() as tmp: + sftp_down.get(source, tmp.file.name) + sftp_up.put(tmp.file.name, dest_path) + finally: + sftp_down.close() + sftp_up.close() + + def run(self, command, timeout=0): + """Run a command on the host and return the results.""" + channel = self.session.new_channel() + try: + res = channel.exec_command(command) + return helpers.Result( + status=res.returncode, + stdout=res.stdout.decode("utf-8"), + stderr=res.stderr.decode("utf-8"), + ) + finally: + channel.close() + + def scp_write(self, source, destination=None, ensure_dir=True): + """SCP write a local file to a remote destination.""" + destination = self._set_destination(source, destination) + if ensure_dir: + self.run(f"mkdir -p {Path(destination).absolute().parent}") + + scp = self.session.scp() + scp.put(destination, source) + + def sftp_read(self, source, destination=None, return_data=False): + """Read a remote file into a local destination or return a bytes object if return_data is True.""" + # TODO read contents directly into bytes object if return_data is True + destination = self._set_destination(source, destination) + # Create the destination path if it doesn't exist + destination = Path(destination) + destination.parent.mkdir(parents=True, exist_ok=True) + + # Initiate the sftp session, read data, write it to a local destination + sftp = self.session.sftp() + try: + sftp.get(source, destination) + if return_data: + return destination.read_bytes() + finally: + if return_data: + destination.unlink() + sftp.close() + + def sftp_write(self, source, destination=None, ensure_dir=True): + """Sftp write a local file to a remote destination.""" + destination = self._set_destination(source, destination) + if ensure_dir: + self.run(f"mkdir -p {Path(destination).absolute().parent}") + + sftp = self.session.sftp() + try: + sftp.put(source, destination) + finally: + sftp.close() + + def shell(self, pty=False): + """Create and return an interactive shell instance.""" + return InteractiveShell(self.session, pty=pty) + + @contextmanager + def tail_file(self, filename): + """Tail a file on the remote host.""" + # TODO re-factor to use SFTP instead + initial_size = int(self.run(f"stat -c %s {filename}").stdout.strip()) + yield (tailer := FileTailer(session=self.session, filename=filename)) + tailer.contents = self.run(f"tail -c +{initial_size} {filename}").stdout + + +class FileTailer: + """FileTailer class.""" + + def __init__(self, **kwargs): + self.session = kwargs.get("session") + self.filename = kwargs.get("filename") + + +class InteractiveShell: + """A helper class that provides an interactive shell interface. + + Preferred use of this class is via its context manager + + with InteractiveShell(my_session) as shell: + shell.send("some-command --argument") + shell.send("another-command") + time.sleep(5) # give time for things to complete + assert "expected text" in shell.result.stdout + + """ + + def __init__(self, session, pty=False): + # FIXME: invoke_shell() always requests pty + # self._channel = session.invoke_shell(pty=pty) + if pty: + self._channel = session.invoke_shell() + else: + raise exceptions.NotImplementedError("Interactive shell with pty=False") + + def __enter__(self): + """Return the shell object.""" + return self + + def __exit__(self, *exc_args): + """Close the channel and read stdout/stderr and status.""" + self.send("exit") # ensure shell has exited + self._channel.send_eof() + + stdout = self._channel.read_bulk_response(timeout=0.5) + stderr = self._channel.read_bulk_response(stderr=1) + status = self._channel.get_channel_exit_status() + + self._channel.close() + + self.result = helpers.Result( + status=status, + stdout=stdout.decode("utf-8"), + stderr=stderr.decode("utf-8"), + ) + + def __getattribute__(self, name): + """Expose non-duplicate attributes from the Channel instance.""" + try: + return object.__getattribute__(self, name) + except AttributeError: + return getattr(self._channel, name) + + def send(self, cmd): + """Send a command to the channel, ensuring a newline character.""" + if not cmd.endswith("\n"): + cmd += "\n" + self._channel.write(cmd.encode("utf-8")) + + def stdout(self): + """Read the contents of a channel's stdout.""" + # FIXME handle read on open channel + res = self._channel.read_bulk_response() + return res.stdout.decode("utf-8") diff --git a/broker/binds/ssh2.py b/broker/binds/ssh2.py new file mode 100644 index 00000000..0e1c3e26 --- /dev/null +++ b/broker/binds/ssh2.py @@ -0,0 +1,281 @@ +"""Module providing classes to establish ssh or ssh-like connections to hosts. + +Classes: + Session - Wrapper around ssh2-python's auth/connection system. + InteractiveShell - Wrapper around ssh2-python's non-blocking channel system. + +Note: You typically want to use a Host object instance to create sessions, + not these classes directly. +""" +from contextlib import contextmanager +from pathlib import Path + +from logzero import logger +from ssh2 import sftp as _sftp +from ssh2.exceptions import SocketSendError +from ssh2.session import Session as _Session + +from broker import exceptions, helpers +from broker.ssh_session import _create_connect_socket + +SFTP_MODE = ( + _sftp.LIBSSH2_SFTP_S_IRUSR + | _sftp.LIBSSH2_SFTP_S_IWUSR + | _sftp.LIBSSH2_SFTP_S_IRGRP + | _sftp.LIBSSH2_SFTP_S_IROTH +) +FILE_FLAGS = _sftp.LIBSSH2_FXF_CREAT | _sftp.LIBSSH2_FXF_WRITE | _sftp.LIBSSH2_FXF_TRUNC + + +class Session: + """Wrapper around ssh2-python's auth/connection system.""" + + def __init__(self, **kwargs): + """Initialize a Session object. + + kwargs: + hostname (str): The hostname or IP address of the remote host. Defaults to 'localhost'. + username (str): The username to authenticate with. Defaults to 'root'. + timeout (float): The timeout for the connection in seconds. Defaults to 60. + port (int): The port number to connect to. Defaults to 22. + key_filename (str): The path to the private key file to use for authentication. + password (str): The password to use for authentication. + ipv6 (bool): Whether or not to use IPv6. Defaults to False. + ipv4_fallback (bool): Whether or not to fallback to IPv4 if IPv6 fails. Defaults to True. + + Raises: + AuthException: If no password or key file is provided. + ConnectionError: If the connection fails. + FileNotFoundError: If the key file is not found. + """ + host = kwargs.get("hostname", "localhost") + user = kwargs.get("username", "root") + port = kwargs.get("port", 22) + key_filename = kwargs.get("key_filename") + password = kwargs.get("password") + timeout = kwargs.get("timeout", 60) + + # Create the socket + self.sock, self.is_ipv6 = _create_connect_socket( + host, + port, + timeout, + ipv6=kwargs.get("ipv6", False), + ipv4_fallback=kwargs.get("ipv4_fallback", True), + ) + + self.session = _Session() + + self.session.handshake(self.sock) + try: + if key_filename: + auth_type = "Key" + if not Path(key_filename).exists(): + raise FileNotFoundError(f"Key not found in '{key_filename}'") + self.session.userauth_publickey_fromfile(user, key_filename) + elif password: + auth_type = "Password" + self.session.userauth_password(user, password) + elif user: + auth_type = "Session" + self.session.agent_auth(user) + else: + raise exceptions.AuthenticationError("No password or key file provided.") + except Exception as err: # noqa: BLE001 + raise exceptions.AuthenticationError( + f"{auth_type}-based authentication failed." + ) from err + + @staticmethod + def _read(channel): + """Read the contents of a channel.""" + size, data = channel.read() + results = "" + while size > 0: + try: + results += data.decode("utf-8") + except UnicodeDecodeError as err: + logger.error(f"Skipping data chunk due to {err}\nReceived: {data}") + size, data = channel.read() + return helpers.Result.from_ssh( + stdout=results, + channel=channel, + ) + + @staticmethod + def _set_destination(source, destination): + dest = destination or source + if dest.endswith("/"): + dest = dest + Path(source).name + return dest + + def disconnect(self): + """Disconnect session.""" + self.session.disconnect() + + def remote_copy(self, source, dest_host, dest_path=None, ensure_dir=True): + """Copy a file from this host to another.""" + dest_path = dest_path or source + sftp_down = self.session.sftp_init() + sftp_up = dest_host.session.session.sftp_init() + if ensure_dir: + dest_host.session.run(f"mkdir -p {Path(dest_path).absolute().parent}") + with sftp_down.open( + source, _sftp.LIBSSH2_FXF_READ, _sftp.LIBSSH2_SFTP_S_IRUSR + ) as download, sftp_up.open(dest_path, FILE_FLAGS, SFTP_MODE) as upload: + for _size, data in download: + upload.write(data) + + def run(self, command, timeout=0): + """Run a command on the host and return the results.""" + self.session.set_timeout(helpers.translate_timeout(timeout)) + try: + channel = self.session.open_session() + except SocketSendError as err: + logger.warning( + f"Encountered connection issue. Attempting to reconnect and retry.\n{err}" + ) + # FIXME _session is on the Host, not Session + del self._session + channel = self.session.open_session() + channel.execute( + command, + ) + channel.wait_eof() + channel.close() + channel.wait_closed() + results = self._read(channel) + return results + + def scp_write(self, source, destination=None, ensure_dir=True): + """SCP write a local file to a remote destination.""" + destination = self._set_destination(source, destination) + fileinfo = (source := Path(source).stat()) + + chan = self.session.scp_send64( + destination, + fileinfo.st_mode & 0o777, + fileinfo.st_size, + fileinfo.st_mtime, + fileinfo.st_atime, + ) + if ensure_dir: + self.run(f"mkdir -p {Path(destination).absolute().parent}") + with source.open("rb") as local: + for data in local: + chan.write(data) + + def sftp_read(self, source, destination=None, return_data=False): + """Read a remote file into a local destination or return a bytes object if return_data is True.""" + if not return_data: + destination = self._set_destination(source, destination) + + # create the destination path if it doesn't exist + destination = Path(destination) + destination.parent.mkdir(parents=True, exist_ok=True) + + # initiate the sftp session, read data, write it to a local destination + sftp = self.session.sftp_init() + with sftp.open(source, _sftp.LIBSSH2_FXF_READ, _sftp.LIBSSH2_SFTP_S_IRUSR) as remote: + captured_data = b"" + for _rc, data in remote: + captured_data += data + if return_data: + return captured_data + destination.write_bytes(data) + + def sftp_write(self, source, destination=None, ensure_dir=True): + """Sftp write a local file to a remote destination.""" + destination = self._set_destination(source, destination) + + data = Path(source).read_bytes() + if ensure_dir: + self.run(f"mkdir -p {Path(destination).absolute().parent}") + + sftp = self.session.sftp_init() + with sftp.open(destination, FILE_FLAGS, SFTP_MODE) as remote: + remote.write(data) + + def shell(self, pty=False): + """Create and return an interactive shell instance.""" + channel = self.session.open_session() + return InteractiveShell(channel, pty) + + @contextmanager + def tail_file(self, filename): + """Simulate tailing a file on the remote host. + + Example: + with my_host.session.tail_file("/var/log/messages") as res: + # do something that creates new messages + print(res.stdout) + + yields a FileTailer object with contents attr set to the string output + """ + # TODO refactor to use SFTP instead + initial_size = int(self.run(f"stat -c %s {filename}").stdout.strip()) + yield (tailer := FileTailer(session=self.session, filename=filename)) + tailer.contents = self.run(f"tail -c +{initial_size} {filename}").stdout + + +class FileTailer: + """FileTailer class.""" + + def __init__(self, **kwargs): + self.session = kwargs.get("session") + self.filename = kwargs.get("filename") + + +class InteractiveShell: + """A helper class that provides an interactive shell interface. + + Preferred use of this class is via its context manager + + with InteractiveShell(channel=my_channel) as shell: + shell.send("some-command --argument") + shell.send("another-command") + time.sleep(5) # give time for things to complete + assert "expected text" in shell.result.stdout + + """ + + def __init__(self, channel, pty=False): + self._channel = channel + if pty: + self._channel.pty() + self._channel.shell() + + def __enter__(self): + """Return the shell object.""" + return self + + def __exit__(self, *exc_args): + """Close the channel and read stdout/stderr and status.""" + self._channel.close() + self.result = Session._read(self._channel) + + def __getattribute__(self, name): + """Expose non-duplicate attributes from the channel.""" + try: + return object.__getattribute__(self, name) + except AttributeError: + return getattr(self._channel, name) + + def send(self, cmd): + """Send a command to the channel, ensuring a newline character.""" + if not cmd.endswith("\n"): + cmd += "\n" + self._channel.write(cmd) + + def stdout(self): + """Read the contents of a channel's stdout.""" + if not self._channel.eof(): + _, data = self._channel.read(65535) + results = data.decode("utf-8") + else: + results = None + size, data = self._channel.read() + while size > 0: + results += data.decode("utf-8") + size, data = self._channel.read() + return results diff --git a/broker/helpers.py b/broker/helpers.py index 1b0fe0a6..1d66dbe3 100644 --- a/broker/helpers.py +++ b/broker/helpers.py @@ -507,7 +507,7 @@ def from_ssh(cls, stdout, channel): return cls( stdout=stdout, status=channel.get_exit_status(), - stderr=channel.read_stderr(), + stderr=channel.read_stderr()[1].decode("utf-8"), ) @classmethod diff --git a/broker/hosts.py b/broker/hosts.py index dbcba84b..ffc6c84b 100644 --- a/broker/hosts.py +++ b/broker/hosts.py @@ -136,7 +136,7 @@ def close(self): """Close the SSH connection to the host.""" # This attribute may be missing after pickling if isinstance(getattr(self, "_session", None), Session): - self._session.session.disconnect() + self._session.disconnect() self._session = None def release(self): diff --git a/broker/session.py b/broker/session.py index f6ed3b68..038585c2 100644 --- a/broker/session.py +++ b/broker/session.py @@ -10,326 +10,42 @@ """ from contextlib import contextmanager from pathlib import Path -import socket import tempfile from logzero import logger -from broker import exceptions, helpers - -try: - from ssh2 import sftp as ssh2_sftp - from ssh2.exceptions import SocketSendError - from ssh2.session import Session as ssh2_Session - - SFTP_MODE = ( - ssh2_sftp.LIBSSH2_SFTP_S_IRUSR - | ssh2_sftp.LIBSSH2_SFTP_S_IWUSR - | ssh2_sftp.LIBSSH2_SFTP_S_IRGRP - | ssh2_sftp.LIBSSH2_SFTP_S_IROTH - ) - FILE_FLAGS = ( - ssh2_sftp.LIBSSH2_FXF_CREAT | ssh2_sftp.LIBSSH2_FXF_WRITE | ssh2_sftp.LIBSSH2_FXF_TRUNC - ) -except ImportError: - logger.warning( - "ssh2-python is not installed, ssh actions will not work.\n" - "To use ssh, run pip install broker[ssh2]." - ) +from broker import helpers +from broker.exceptions import NotImplementedError +from broker.settings import settings +SSH_BACKENDS = ("ssh2-python", "ssh2-python312", "ansible-pylibssh", "hussh") +SSH_BACKEND = settings.SSH_BACKEND -def _create_connect_socket(host, port, timeout, ipv6=False, ipv4_fallback=True, sock=None): - """Create a socket and establish a connection to the specified host and port. - - Args: - host (str): The hostname or IP address of the remote server. - port (int): The port number to connect to. - timeout (float): The timeout value in seconds for the socket connection. - ipv6 (bool, optional): Whether to use IPv6. Defaults to False. - ipv4_fallback (bool, optional): Whether to fallback to IPv4 if IPv6 fails. Defaults to True. - sock (socket.socket, optional): An existing socket object to use. Defaults to None. - - Returns: - socket.socket: The connected socket object. - bool: True if IPv6 was used, False otherwise. - - Raises: - exceptions.ConnectionError: If unable to establish a connection to the host. - """ - if ipv6 and not sock: - try: - sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - except OSError as err: - if ipv4_fallback: - logger.warning(f"IPv6 failed with {err}. Falling back to IPv4.") - return _create_connect_socket(host, port, timeout, ipv6=False) - else: - raise exceptions.ConnectionError( - f"Unable to establish IPv6 connection to {host}." - ) from err - elif not sock: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(timeout) - if ipv6: - try: - sock.connect((host, port)) - except socket.gaierror as err: - if ipv4_fallback: - logger.warning(f"IPv6 connection failed to {host}. Falling back to IPv4.") - return _create_connect_socket(host, port, timeout, ipv6=False, sock=sock) - else: - raise exceptions.ConnectionError( - f"Unable to establish IPv6 connection to {host}." - ) from err - else: - sock.connect((host, port)) - return sock, ipv6 +logger.debug(f"{SSH_BACKEND=}") class Session: - """Wrapper around ssh2-python's auth/connection system.""" - - def __init__(self, **kwargs): - """Initialize a Session object. - - kwargs: - hostname (str): The hostname or IP address of the remote host. Defaults to 'localhost'. - username (str): The username to authenticate with. Defaults to 'root'. - timeout (float): The timeout for the connection in seconds. Defaults to 60. - port (int): The port number to connect to. Defaults to 22. - key_filename (str): The path to the private key file to use for authentication. - password (str): The password to use for authentication. - ipv6 (bool): Whether or not to use IPv6. Defaults to False. - ipv4_fallback (bool): Whether or not to fallback to IPv4 if IPv6 fails. Defaults to True. - - Raises: - AuthException: If no password or key file is provided. - ConnectionError: If the connection fails. - FileNotFoundError: If the key file is not found. - """ - host = kwargs.get("hostname", "localhost") - user = kwargs.get("username", "root") - port = kwargs.get("port", 22) - key_filename = kwargs.get("key_filename") - password = kwargs.get("password") - timeout = kwargs.get("timeout", 60) - # create the socket - self.sock, self.is_ipv6 = _create_connect_socket( - host, - port, - timeout, - ipv6=kwargs.get("ipv6", False), - ipv4_fallback=kwargs.get("ipv4_fallback", True), - ) - self.session = ssh2_Session() - self.session.handshake(self.sock) - try: - if key_filename: - auth_type = "Key" - if not Path(key_filename).exists(): - raise FileNotFoundError(f"Key not found in '{key_filename}'") - self.session.userauth_publickey_fromfile(user, key_filename) - elif password: - auth_type = "Password" - self.session.userauth_password(user, password) - elif user: - auth_type = "Session" - self.session.agent_auth(user) - else: - raise exceptions.AuthenticationError("No password or key file provided.") - except Exception as err: # noqa: BLE001 - raise exceptions.AuthenticationError( - f"{auth_type}-based authentication failed." - ) from err - - @staticmethod - def _read(channel): - """Read the contents of a channel.""" - size, data = channel.read() - results = "" - while size > 0: - try: - results += data.decode("utf-8") - except UnicodeDecodeError as err: - logger.error(f"Skipping data chunk due to {err}\nReceived: {data}") - size, data = channel.read() - return helpers.Result.from_ssh( - stdout=results, - channel=channel, - ) - - def run(self, command, timeout=0): - """Run a command on the host and return the results.""" - self.session.set_timeout(helpers.translate_timeout(timeout)) - try: - channel = self.session.open_session() - except SocketSendError as err: - logger.warning( - f"Encountered connection issue. Attempting to reconnect and retry.\n{err}" - ) - del self._session - channel = self.session.open_session() - channel.execute( - command, - ) - channel.wait_eof() - channel.close() - channel.wait_closed() - results = self._read(channel) - return results - - def shell(self, pty=False): - """Create and return an interactive shell instance.""" - channel = self.session.open_session() - return InteractiveShell(channel, pty) - - @contextmanager - def tail_file(self, filename): - """Simulate tailing a file on the remote host. - - Example: - with my_host.session.tail_file("/var/log/messages") as res: - # do something that creates new messages - print(res.stdout) - - returns a Result object with stdout, stderr, and status - """ - initial_size = int(self.run(f"stat -c %s {filename}").stdout.strip()) - yield (res := helpers.Result()) - # get the contents of the file from the initial size to the end - result = self.run(f"tail -c +{initial_size} {filename}") - res.__dict__.update(result.__dict__) - - def sftp_read(self, source, destination=None, return_data=False): - """Read a remote file into a local destination or return a bytes object if return_data is True.""" - if not return_data: - if not destination: - destination = source - elif destination.endswith("/"): - destination = destination + Path(source).name - # create the destination path if it doesn't exist - destination = Path(destination) - destination.parent.mkdir(parents=True, exist_ok=True) - # initiate the sftp session, read data, write it to a local destination - sftp = self.session.sftp_init() - with sftp.open( - source, ssh2_sftp.LIBSSH2_FXF_READ, ssh2_sftp.LIBSSH2_SFTP_S_IRUSR - ) as remote: - captured_data = b"" - for _rc, data in remote: - captured_data += data - if return_data: - return captured_data - destination.write_bytes(data) - - def sftp_write(self, source, destination=None, ensure_dir=True): - """Sftp write a local file to a remote destination.""" - if not destination: - destination = source - elif destination.endswith("/"): - destination = destination + Path(source).name - data = Path(source).read_bytes() - if ensure_dir: - self.run(f"mkdir -p {Path(destination).absolute().parent}") - sftp = self.session.sftp_init() - with sftp.open(destination, FILE_FLAGS, SFTP_MODE) as remote: - remote.write(data) - - def remote_copy(self, source, dest_host, dest_path=None, ensure_dir=True): - """Copy a file from this host to another.""" - dest_path = dest_path or source - sftp_down = self.session.sftp_init() - sftp_up = dest_host.session.session.sftp_init() - if ensure_dir: - dest_host.session.run(f"mkdir -p {Path(dest_path).absolute().parent}") - with sftp_down.open( - source, ssh2_sftp.LIBSSH2_FXF_READ, ssh2_sftp.LIBSSH2_SFTP_S_IRUSR - ) as download, sftp_up.open(dest_path, FILE_FLAGS, SFTP_MODE) as upload: - for _size, data in download: - upload.write(data) + """Default wrapper around ssh backend's auth/connection system.""" - def scp_write(self, source, destination=None, ensure_dir=True): - """SCP write a local file to a remote destination.""" - if not destination: - destination = source - elif destination.endswith("/"): - destination = destination + Path(source).name - fileinfo = (source := Path(source).stat()) - chan = self.session.scp_send64( - destination, - fileinfo.st_mode & 0o777, - fileinfo.st_size, - fileinfo.st_mtime, - fileinfo.st_atime, +try: + if SSH_BACKEND == "ansible-pylibssh": + from broker.binds.pylibssh import InteractiveShell, Session + elif SSH_BACKEND == "hussh": + from broker.binds.hussh import Session + elif SSH_BACKEND in ("ssh2-python", "ssh2-python312"): + from broker.binds.ssh2 import InteractiveShell, Session # noqa: F401 + else: + logger.warning( + f"SSH backend {SSH_BACKEND!r} not supported.\n" + "Supported ssh backends:\n" + f"{SSH_BACKENDS}" ) - if ensure_dir: - self.run(f"mkdir -p {Path(destination).absolute().parent}") - with source.open("rb") as local: - for data in local: - chan.write(data) - - def __enter__(self): - """Return the session object.""" - return self - - def __exit__(self, *args): - """Close the session.""" - self.session.disconnect() - - -class InteractiveShell: - """A helper class that provides an interactive shell interface. - - Preferred use of this class is via its context manager - - with InteractiveShell(channel=my_channel) as shell: - shell.send("some-command --argument") - shell.send("another-command") - time.sleep(5) # give time for things to complete - assert "expected text" in shell.result.stdout - - """ - - def __init__(self, channel, pty=False): - self._chan = channel - if pty: - self._chan.pty() - self._chan.shell() - - def __enter__(self): - """Return the shell object.""" - return self - - def __exit__(self, *exc_args): - """Close the channel and read stdout/stderr and status.""" - self._chan.close() - self.result = Session._read(self._chan) - - def __getattribute__(self, name): - """Expose non-duplicate attributes from the channel.""" - try: - return object.__getattribute__(self, name) - except AttributeError: - return getattr(self._chan, name) - - def send(self, cmd): - """Send a command to the channel, ensuring a newline character.""" - if not cmd.endswith("\n"): - cmd += "\n" - self._chan.write(cmd) - - def stdout(self): - """Read the contents of a channel's stdout.""" - if not self._chan.eof(): - _, data = self._chan.read(65535) - results = data.decode("utf-8") - else: - results = None - size, data = self._chan.read() - while size > 0: - results += data.decode("utf-8") - size, data = self._chan.read() - return results +except ImportError: + logger.warning( + f"{SSH_BACKEND} is not installed.\n" + "ssh actions will not work.\n" + f"To use ssh, run 'pip install broker[{SSH_BACKEND}]'." + ) class ContainerSession: @@ -359,7 +75,7 @@ def run(self, command, demux=True, **kwargs): return result def disconnect(self): - """Needed for simple compatability with Session.""" + """Needed for simple compatibility with Session.""" @contextmanager def tail_file(self, filename): @@ -419,10 +135,3 @@ def sftp_read(self, source, destination=None, return_data=False): def shell(self, pty=False): """Create and return an interactive shell instance.""" raise NotImplementedError("ContainerSession.shell has not been implemented") - - def __enter__(self): - """Return the session object.""" - return self - - def __exit__(self, *args): - """Do nothing on exit.""" diff --git a/broker/settings.py b/broker/settings.py index 853a5f4c..ea6dae56 100644 --- a/broker/settings.py +++ b/broker/settings.py @@ -94,6 +94,7 @@ def init_settings(settings_path, interactive=False): Validator("HOST_SSH_KEY_FILENAME", default=None), Validator("HOST_IPV6", default=False), Validator("HOST_IPV4_FALLBACK", default=True), + Validator("SSH_BACKEND", default="ssh2-python312"), Validator("LOGGING", is_type_of=dict), Validator( "LOGGING.CONSOLE_LEVEL", diff --git a/broker/ssh_session.py b/broker/ssh_session.py new file mode 100644 index 00000000..bacad5de --- /dev/null +++ b/broker/ssh_session.py @@ -0,0 +1,55 @@ +"""Module providing base SSH methods and classes.""" +import socket + +from logzero import logger + +from broker import exceptions + + +def _create_connect_socket(host, port, timeout, ipv6=False, ipv4_fallback=True, sock=None): + """Create a socket and establish a connection to the specified host and port. + + Args: + host (str): The hostname or IP address of the remote server. + port (int): The port number to connect to. + timeout (float): The timeout value in seconds for the socket connection. + ipv6 (bool, optional): Whether to use IPv6. Defaults to False. + ipv4_fallback (bool, optional): Whether to fallback to IPv4 if IPv6 fails. Defaults to True. + sock (socket.socket, optional): An existing socket object to use. Defaults to None. + + Returns: + socket.socket: The connected socket object. + bool: True if IPv6 was used, False otherwise. + + Raises: + exceptions.ConnectionError: If unable to establish a connection to the host. + """ + if ipv6 and not sock: + try: + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + except OSError as err: + if ipv4_fallback: + logger.warning(f"IPv6 failed with {err}. Falling back to IPv4.") + return _create_connect_socket(host, port, timeout, ipv6=False) + else: + raise exceptions.ConnectionError( + f"Unable to establish IPv6 connection to {host}." + ) from err + elif not sock: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(timeout) + if ipv6: + try: + sock.connect((host, port)) + except socket.gaierror as err: + if ipv4_fallback: + logger.warning(f"IPv6 connection failed to {host}. Falling back to IPv4.") + # FIXME this socket was created for AF_INET6. We shouldn't reuse it with ipv6=False. + return _create_connect_socket(host, port, timeout, ipv6=False, sock=sock) + else: + raise exceptions.ConnectionError( + f"Unable to establish IPv6 connection to {host}." + ) from err + else: + sock.connect((host, port)) + return sock, ipv6 diff --git a/broker_settings.yaml.example b/broker_settings.yaml.example index 44add3e3..858c4b6a 100644 --- a/broker_settings.yaml.example +++ b/broker_settings.yaml.example @@ -13,6 +13,7 @@ host_ssh_key_filename: "" host_ipv6: False # If IPv6 connection attempts fail, fallback to IPv4 host_ipv4_fallback: True +ssh_backend: ssh2-python312 # Provider settings AnsibleTower: base_url: "https:///" diff --git a/pyproject.toml b/pyproject.toml index dc755a23..04b5d66d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,12 @@ setup = [ "build", "twine", ] + ssh2_py311 = ["ssh2-python"] +ssh2_python = ["ssh2-python"] +ssh2_python312 = ["ssh2-python312"] +ansible_pylibssh = ["ansible-pylibssh"] +hussh = ["hussh"] [project.scripts] broker = "broker.commands:cli" diff --git a/tests/functional/test_containers.py b/tests/functional/test_containers.py index 61cd5d91..89d86fab 100644 --- a/tests/functional/test_containers.py +++ b/tests/functional/test_containers.py @@ -32,7 +32,7 @@ def temp_inventory(): @pytest.mark.parametrize( - "args_file", [f for f in SCENARIO_DIR.iterdir() if f.name.startswith("checkout_")] + "args_file", [f for f in SCENARIO_DIR.iterdir() if f.name.startswith("checkout_")], ids=lambda f: f.name.split(".")[0] ) def test_checkout_scenarios(args_file, temp_inventory): result = CliRunner().invoke(cli, ["checkout", "--args-file", args_file]) @@ -40,7 +40,7 @@ def test_checkout_scenarios(args_file, temp_inventory): @pytest.mark.parametrize( - "args_file", [f for f in SCENARIO_DIR.iterdir() if f.name.startswith("execute_")] + "args_file", [f for f in SCENARIO_DIR.iterdir() if f.name.startswith("execute_")], ids=lambda f: f.name.split(".")[0] ) def test_execute_scenarios(args_file): result = CliRunner().invoke(cli, ["execute", "--args-file", args_file]) diff --git a/tests/functional/test_rh_beaker.py b/tests/functional/test_rh_beaker.py index 5aa39daf..89098d71 100644 --- a/tests/functional/test_rh_beaker.py +++ b/tests/functional/test_rh_beaker.py @@ -29,7 +29,7 @@ def temp_inventory(): @pytest.mark.parametrize( - "args_file", [f for f in SCENARIO_DIR.iterdir() if f.name.startswith("checkout_")] + "args_file", [f for f in SCENARIO_DIR.iterdir() if f.name.startswith("checkout_")], ids=lambda f: f.name.split(".")[0] ) def test_checkout_scenarios(args_file, temp_inventory): result = CliRunner().invoke(cli, ["checkout", "--args-file", args_file]) @@ -37,7 +37,7 @@ def test_checkout_scenarios(args_file, temp_inventory): # @pytest.mark.parametrize( -# "args_file", [f for f in SCENARIO_DIR.iterdir() if f.name.startswith("execute_")] +# "args_file", [f for f in SCENARIO_DIR.iterdir() if f.name.startswith("execute_")], ids=lambda f: f.name.split(".")[0] # ) # def test_execute_scenarios(args_file): # result = CliRunner().invoke(cli, ["execute", "--args-file", args_file]) diff --git a/tests/functional/test_satlab.py b/tests/functional/test_satlab.py index bac090f3..9652d4e8 100644 --- a/tests/functional/test_satlab.py +++ b/tests/functional/test_satlab.py @@ -1,9 +1,12 @@ from pathlib import Path from tempfile import NamedTemporaryFile + import pytest from click.testing import CliRunner + from broker import Broker from broker.commands import cli +from broker.hosts import Host from broker.providers.ansible_tower import AnsibleTower from broker.settings import inventory_path @@ -29,7 +32,7 @@ def temp_inventory(): @pytest.mark.parametrize( - "args_file", [f for f in SCENARIO_DIR.iterdir() if f.name.startswith("checkout_")] + "args_file", [f for f in SCENARIO_DIR.iterdir() if f.name.startswith("checkout_")], ids=lambda f: f.name.split(".")[0] ) def test_checkout_scenarios(args_file, temp_inventory): result = CliRunner().invoke(cli, ["checkout", "--args-file", args_file]) @@ -37,7 +40,7 @@ def test_checkout_scenarios(args_file, temp_inventory): @pytest.mark.parametrize( - "args_file", [f for f in SCENARIO_DIR.iterdir() if f.name.startswith("execute_")] + "args_file", [f for f in SCENARIO_DIR.iterdir() if f.name.startswith("execute_")], ids=lambda f: f.name.split(".")[0] ) def test_execute_scenarios(args_file): result = CliRunner().invoke(cli, ["execute", "--args-file", args_file]) @@ -68,7 +71,7 @@ def test_tower_host(): assert res.stdout.strip() == r_host.hostname loc_settings_path = Path("broker_settings.yaml") remote_dir = "/tmp/fake" - r_host.session.sftp_write(loc_settings_path.name, f"{remote_dir}/") + r_host.session.sftp_write(loc_settings_path.name, f"{remote_dir}/", ensure_dir=True) res = r_host.execute(f"ls {remote_dir}") assert str(loc_settings_path) in res.stdout with NamedTemporaryFile() as tmp: @@ -88,8 +91,8 @@ def test_tower_host(): r_host.execute(f"echo 'hello world' > {tailed_file}") with r_host.session.tail_file(tailed_file) as tf: r_host.execute(f"echo 'this is a new line' >> {tailed_file}") - assert "this is a new line" in tf.stdout - assert "hello world" not in tf.stdout + assert "this is a new line" in tf.contents + assert "hello world" not in tf.contents def test_tower_host_mp(): @@ -99,7 +102,7 @@ def test_tower_host_mp(): assert res.stdout.strip() == r_host.hostname loc_settings_path = Path("broker_settings.yaml") remote_dir = "/tmp/fake" - r_host.session.sftp_write(loc_settings_path.name, f"{remote_dir}/") + r_host.session.sftp_write(loc_settings_path.name, f"{remote_dir}/", ensure_dir=True) res = r_host.execute(f"ls {remote_dir}") assert str(loc_settings_path) in res.stdout with NamedTemporaryFile() as tmp: