diff --git a/src/ssh/HISTORY.md b/src/ssh/HISTORY.md index 84449613487..9065fdee500 100644 --- a/src/ssh/HISTORY.md +++ b/src/ssh/HISTORY.md @@ -1,5 +1,15 @@ Release History =============== +1.0.1 +----- +* Added --ssh-client-folder parameter. +* Fixed issues caused when there are spaces or non-english characters in paths provided by users. +* Ensure all paths provided by users are converted to absolute paths. +* Print OpenSSH error messages to console on "az ssh vm". +* Print level1 SSH client log messages when running "az ssh vm" in debug mode. +* Change "isPreview". +* Correctly find pre-installed OpenSSH binaries on Windows 32bit machines. + 1.0.0 ----- * Delete all keys and certificates created during execution of ssh vm. diff --git a/src/ssh/azext_ssh/_params.py b/src/ssh/azext_ssh/_params.py index 2a1a5b33ce9..bd7b59f9241 100644 --- a/src/ssh/azext_ssh/_params.py +++ b/src/ssh/azext_ssh/_params.py @@ -19,6 +19,9 @@ def load_arguments(self, _): c.argument('cert_file', options_list=['--certificate-file', '-c'], help='Path to a certificate file used for authentication when using local user credentials.') c.argument('port', options_list=['--port'], help='SSH port') + c.argument('ssh_client_folder', options_list=['--ssh-client-folder'], + help='Folder path that contains ssh executables (ssh.exe, ssh-keygen.exe, etc). ' + 'Default to ssh pre-installed if not provided.') c.positional('ssh_args', nargs='*', help='Additional arguments passed to OpenSSH') with self.argument_context('ssh config') as c: @@ -37,6 +40,9 @@ def load_arguments(self, _): help='Folder where new generated keys will be stored.') c.argument('cert_file', options_list=['--certificate-file', '-c'], help='Path to certificate file') c.argument('port', options_list=['--port'], help='SSH port') + c.argument('ssh_client_folder', options_list=['--ssh-client-folder'], + help='Folder path that contains ssh executables (ssh.exe, ssh-keygen.exe, etc). ' + 'Default to ssh pre-installed if not provided.') with self.argument_context('ssh cert') as c: c.argument('cert_path', options_list=['--file', '-f'], @@ -44,3 +50,6 @@ def load_arguments(self, _): c.argument('public_key_file', options_list=['--public-key-file', '-p'], help='The RSA public key file path. If not provided, ' 'generated key pair is stored in the same directory as --file.') + c.argument('ssh_client_folder', options_list=['--ssh-client-folder'], + help='Folder path that contains ssh executables (ssh.exe, ssh-keygen.exe, etc). ' + 'Default to ssh pre-installed if not provided.') diff --git a/src/ssh/azext_ssh/azext_metadata.json b/src/ssh/azext_ssh/azext_metadata.json index 6a44beb25b4..17c19bca98e 100644 --- a/src/ssh/azext_ssh/azext_metadata.json +++ b/src/ssh/azext_ssh/azext_metadata.json @@ -1,4 +1,4 @@ { - "azext.isPreview": true, + "azext.isPreview": false, "azext.minCliCoreVersion": "2.4.0" } \ No newline at end of file diff --git a/src/ssh/azext_ssh/custom.py b/src/ssh/azext_ssh/custom.py index 1fd295d2c3a..78100f6aaa7 100644 --- a/src/ssh/azext_ssh/custom.py +++ b/src/ssh/azext_ssh/custom.py @@ -3,99 +3,145 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -import functools import os import hashlib import json import tempfile +import colorama +from colorama import Fore +from colorama import Style + from knack import log from azure.cli.core import azclierror from . import ip_utils from . import rsa_parser from . import ssh_utils +from . import ssh_info logger = log.get_logger(__name__) def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None, use_private_ip=False, local_user=None, cert_file=None, port=None, - ssh_args=None): + ssh_client_folder=None, ssh_args=None): + + if '--debug' in cmd.cli_ctx.data['safe_params'] and set(['-v', '-vv', '-vvv']).isdisjoint(ssh_args): + ssh_args = ['-v'] if not ssh_args else ['-v'] + ssh_args + _assert_args(resource_group_name, vm_name, ssh_ip, cert_file, local_user) credentials_folder = None - op_call = functools.partial(ssh_utils.start_ssh_connection, port, ssh_args) - _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, - local_user, cert_file, credentials_folder, op_call) + op_call = ssh_utils.start_ssh_connection + ssh_session = ssh_info.SSHSession(resource_group_name, vm_name, ssh_ip, public_key_file, + private_key_file, use_private_ip, local_user, cert_file, port, + ssh_client_folder, ssh_args) + _do_ssh_op(cmd, ssh_session, credentials_folder, op_call) def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None, private_key_file=None, overwrite=False, use_private_ip=False, - local_user=None, cert_file=None, port=None, credentials_folder=None): + local_user=None, cert_file=None, port=None, credentials_folder=None, ssh_client_folder=None): + _assert_args(resource_group_name, vm_name, ssh_ip, cert_file, local_user) # If user provides their own key pair, certificate will be written in the same folder as public key. if (public_key_file or private_key_file) and credentials_folder: raise azclierror.ArgumentUsageError("--keys-destination-folder can't be used in conjunction with " "--public-key-file/-p or --private-key-file/-i.") - op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name, overwrite, port) + config_session = ssh_info.ConfigSession(config_path, resource_group_name, vm_name, ssh_ip, public_key_file, + private_key_file, overwrite, use_private_ip, local_user, cert_file, + port, ssh_client_folder) + + op_call = ssh_utils.write_ssh_config + + # if the folder doesn't exist, this extension won't create a new one. + config_folder = os.path.dirname(config_session.config_path) + if not os.path.isdir(config_folder): + raise azclierror.InvalidArgumentValueError(f"Config file destination folder {config_folder} " + "does not exist.") + # Default credential location + # Add logic to test if credentials folder is valid if not credentials_folder: - config_folder = os.path.dirname(config_path) - if not os.path.isdir(config_folder): - raise azclierror.InvalidArgumentValueError(f"Config file destination folder {config_folder} " - "does not exist.") - folder_name = ssh_ip - if resource_group_name and vm_name: - folder_name = resource_group_name + "-" + vm_name + # * is not a valid name for a folder in Windows. Treat this as a special case. + folder_name = config_session.ip if config_session.ip != "*" else "all_ips" + if config_session.resource_group_name and config_session.vm_name: + folder_name = config_session.resource_group_name + "-" + config_session.vm_name credentials_folder = os.path.join(config_folder, os.path.join("az_ssh_config", folder_name)) + else: + credentials_folder = os.path.abspath(credentials_folder) - _do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, - local_user, cert_file, credentials_folder, op_call) + _do_ssh_op(cmd, config_session, credentials_folder, op_call) -def ssh_cert(cmd, cert_path=None, public_key_file=None): +def ssh_cert(cmd, cert_path=None, public_key_file=None, ssh_client_folder=None): if not cert_path and not public_key_file: raise azclierror.RequiredArgumentMissingError("--file or --public-key-file must be provided.") if cert_path and not os.path.isdir(os.path.dirname(cert_path)): raise azclierror.InvalidArgumentValueError(f"{os.path.dirname(cert_path)} folder doesn't exist") + + if public_key_file: + public_key_file = os.path.abspath(public_key_file) + if cert_path: + cert_path = os.path.abspath(cert_path) + if ssh_client_folder: + ssh_client_folder = os.path.abspath(ssh_client_folder) + # If user doesn't provide a public key, save generated key pair to the same folder as --file keys_folder = None if not public_key_file: keys_folder = os.path.dirname(cert_path) - logger.warning("The generated SSH keys are stored at %s. Please delete SSH keys when the certificate " - "is no longer being used.", keys_folder) - public_key_file, _, _ = _check_or_create_public_private_files(public_key_file, None, keys_folder) - cert_file, _ = _get_and_write_certificate(cmd, public_key_file, cert_path) - print(cert_file + "\n") + + public_key_file, _, _ = _check_or_create_public_private_files(public_key_file, None, keys_folder, ssh_client_folder) + cert_file, _ = _get_and_write_certificate(cmd, public_key_file, cert_path, ssh_client_folder) + + if keys_folder: + logger.warning("%s contains sensitive information (id_rsa, id_rsa.pub). " + "Please delete once this certificate is no longer being used.", keys_folder) + + colorama.init() + # pylint: disable=broad-except + try: + cert_expiration = ssh_utils.get_certificate_start_and_end_times(cert_file, ssh_client_folder)[1] + print(Fore.GREEN + f"Generated SSH certificate {cert_file} is valid until {cert_expiration} in local time." + + Style.RESET_ALL) + except Exception as e: + logger.warning("Couldn't determine certificate validity. Error: %s", str(e)) + print(Fore.GREEN + f"Generated SSH certificate {cert_file}." + Style.RESET_ALL) -def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, - username, cert_file, credentials_folder, op_call): +def _do_ssh_op(cmd, op_info, credentials_folder, op_call): # Get ssh_ip before getting public key to avoid getting "ResourceNotFound" exception after creating the keys - ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group, vm_name, use_private_ip) + op_info.ip = op_info.ip or ip_utils.get_ssh_ip(cmd, op_info.resource_group_name, + op_info.vm_name, op_info.use_private_ip) - if not ssh_ip: - if not use_private_ip: - raise azclierror.ResourceNotFoundError(f"VM '{vm_name}' does not have a public IP address to SSH to") + if not op_info.ip: + if not op_info.use_private_ip: + raise azclierror.ResourceNotFoundError(f"VM '{op_info.vm_name}' does not have a public " + "IP address to SSH to") - raise azclierror.ResourceNotFoundError(f"VM '{vm_name}' does not have a public or private IP address to SSH to") + raise azclierror.ResourceNotFoundError("Internal Error. Couldn't determine the IP address.") # If user provides local user, no credentials should be deleted. delete_keys = False delete_cert = False # If user provides a local user, use the provided credentials for authentication - if not username: + if not op_info.local_user: delete_cert = True - public_key_file, private_key_file, delete_keys = _check_or_create_public_private_files(public_key_file, - private_key_file, - credentials_folder) - cert_file, username = _get_and_write_certificate(cmd, public_key_file, None) + op_info.public_key_file, op_info.private_key_file, delete_keys = \ + _check_or_create_public_private_files(op_info.public_key_file, + op_info.private_key_file, + credentials_folder, + op_info.ssh_client_folder) - op_call(ssh_ip, username, cert_file, private_key_file, delete_keys, delete_cert) + op_info.cert_file, op_info.local_user = _get_and_write_certificate(cmd, op_info.public_key_file, + None, op_info.ssh_client_folder) + op_call(op_info, delete_keys, delete_cert) -def _get_and_write_certificate(cmd, public_key_file, cert_file): + +def _get_and_write_certificate(cmd, public_key_file, cert_file, ssh_client_folder): cloudtoscope = { "azurecloud": "https://pas.windows.net/CheckMyAccess/Linux/.default", "azurechinacloud": "https://pas.chinacloudapi.cn/CheckMyAccess/Linux/.default", @@ -124,9 +170,11 @@ def _get_and_write_certificate(cmd, public_key_file, cert_file): if not cert_file: cert_file = public_key_file + "-aadcert.pub" + + logger.debug("Generating certificate %s", cert_file) _write_cert_file(certificate, cert_file) # instead we use the validprincipals from the cert due to mismatched upn and email in guest scenarios - username = ssh_utils.get_ssh_cert_principals(cert_file)[0] + username = ssh_utils.get_ssh_cert_principals(cert_file, ssh_client_folder)[0] return cert_file, username.lower() @@ -172,7 +220,8 @@ def _assert_args(resource_group, vm_name, ssh_ip, cert_file, username): raise azclierror.FileOperationError(f"Certificate file {cert_file} not found") -def _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder): +def _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder, + ssh_client_folder=None): delete_keys = False # If nothing is passed, then create a directory with a ephemeral keypair if not public_key_file and not private_key_file: @@ -189,7 +238,7 @@ def _check_or_create_public_private_files(public_key_file, private_key_file, cre os.makedirs(credentials_folder) public_key_file = os.path.join(credentials_folder, "id_rsa.pub") private_key_file = os.path.join(credentials_folder, "id_rsa") - ssh_utils.create_ssh_keyfile(private_key_file) + ssh_utils.create_ssh_keyfile(private_key_file, ssh_client_folder) if not public_key_file: if private_key_file: @@ -210,7 +259,7 @@ def _check_or_create_public_private_files(public_key_file, private_key_file, cre def _write_cert_file(certificate_contents, cert_file): - with open(cert_file, 'w') as f: + with open(cert_file, 'w', encoding='utf-8') as f: f.write(f"ssh-rsa-cert-v01@openssh.com {certificate_contents}") return cert_file @@ -220,7 +269,7 @@ def _get_modulus_exponent(public_key_file): if not os.path.isfile(public_key_file): raise azclierror.FileOperationError(f"Public key file '{public_key_file}' was not found") - with open(public_key_file, 'r') as f: + with open(public_key_file, 'r', encoding='utf-8') as f: public_key_text = f.read() parser = rsa_parser.RSAParser() diff --git a/src/ssh/azext_ssh/file_utils.py b/src/ssh/azext_ssh/file_utils.py index b31927b9268..1e178ab44a2 100644 --- a/src/ssh/azext_ssh/file_utils.py +++ b/src/ssh/azext_ssh/file_utils.py @@ -28,6 +28,7 @@ def mkdir_p(path): def delete_file(file_path, message, warning=False): + # pylint: disable=broad-except if os.path.isfile(file_path): try: os.remove(file_path) @@ -39,6 +40,7 @@ def delete_file(file_path, message, warning=False): def delete_folder(dir_path, message, warning=False): + # pylint: disable=broad-except if os.path.isdir(dir_path): try: os.rmdir(dir_path) diff --git a/src/ssh/azext_ssh/ssh_info.py b/src/ssh/azext_ssh/ssh_info.py new file mode 100644 index 00000000000..6ef02efab67 --- /dev/null +++ b/src/ssh/azext_ssh/ssh_info.py @@ -0,0 +1,95 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +import os +from azure.cli.core import azclierror + + +class SSHSession(): + # pylint: disable=too-many-instance-attributes + def __init__(self, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, + use_private_ip, local_user, cert_file, port, ssh_client_folder, ssh_args): + self.resource_group_name = resource_group_name + self.vm_name = vm_name + self.ip = ssh_ip + self.use_private_ip = use_private_ip + self.local_user = local_user + self.port = port + self.ssh_args = ssh_args + self.public_key_file = os.path.abspath(public_key_file) if public_key_file else None + self.private_key_file = os.path.abspath(private_key_file) if private_key_file else None + self.cert_file = os.path.abspath(cert_file) if cert_file else None + self.ssh_client_folder = os.path.abspath(ssh_client_folder) if ssh_client_folder else None + + def get_host(self): + if self.local_user and self.ip: + return self.local_user + "@" + self.ip + raise azclierror.BadRequestError("Unable to determine host.") + + def build_args(self): + private_key = [] + port_arg = [] + certificate = [] + if self.private_key_file: + private_key = ["-i", self.private_key_file] + if self.port: + port_arg = ["-p", self.port] + if self.cert_file: + certificate = ["-o", "CertificateFile=\"" + self.cert_file + "\""] + return private_key + certificate + port_arg + + +class ConfigSession(): + # pylint: disable=too-few-public-methods + # pylint: disable=too-many-instance-attributes + def __init__(self, config_path, resource_group_name, vm_name, ssh_ip, public_key_file, + private_key_file, overwrite, use_private_ip, local_user, cert_file, port, + ssh_client_folder): + self.config_path = os.path.abspath(config_path) + self.resource_group_name = resource_group_name + self.vm_name = vm_name + self.ip = ssh_ip + self.overwrite = overwrite + self.use_private_ip = use_private_ip + self.local_user = local_user + self.port = port + self.public_key_file = os.path.abspath(public_key_file) if public_key_file else None + self.private_key_file = os.path.abspath(private_key_file) if private_key_file else None + self.cert_file = os.path.abspath(cert_file) if cert_file else None + self.ssh_client_folder = os.path.abspath(ssh_client_folder) if ssh_client_folder else None + + def get_config_text(self): + lines = [""] + if self.resource_group_name and self.vm_name and self.ip: + lines = lines + self._get_rg_and_vm_entry() + # default to all hosts for config + if not self.ip: + self.ip = "*" + lines = lines + self._get_ip_entry() + return lines + + def _get_rg_and_vm_entry(self): + lines = [] + lines.append("Host " + self.resource_group_name + "-" + self.vm_name) + lines.append("\tUser " + self.local_user) + lines.append("\tHostName " + self.ip) + if self.cert_file: + lines.append("\tCertificateFile \"" + self.cert_file + "\"") + if self.private_key_file: + lines.append("\tIdentityFile \"" + self.private_key_file + "\"") + if self.port: + lines.append("\tPort " + self.port) + return lines + + def _get_ip_entry(self): + lines = [] + lines.append("Host " + self.ip) + lines.append("\tUser " + self.local_user) + if self.cert_file: + lines.append("\tCertificateFile \"" + self.cert_file + "\"") + if self.private_key_file: + lines.append("\tIdentityFile \"" + self.private_key_file + "\"") + if self.port: + lines.append("\tPort " + self.port) + return lines diff --git a/src/ssh/azext_ssh/ssh_utils.py b/src/ssh/azext_ssh/ssh_utils.py index bf27741aa7a..6b33833c47b 100644 --- a/src/ssh/azext_ssh/ssh_utils.py +++ b/src/ssh/azext_ssh/ssh_utils.py @@ -7,6 +7,8 @@ import subprocess import time import multiprocessing as mp +import datetime +import re from azext_ssh import file_utils from knack import log @@ -19,11 +21,13 @@ CLEANUP_AWAIT_TERMINATION_IN_SECONDS = 30 -def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_file, delete_keys, delete_cert): +def start_ssh_connection(ssh_info, delete_keys, delete_cert): ssh_arg_list = [] - if ssh_args: - ssh_arg_list = ssh_args + if ssh_info.ssh_args: + ssh_arg_list = ssh_info.ssh_args + + command = [_get_ssh_client_path('ssh', ssh_info.ssh_client_folder), ssh_info.get_host()] log_file = None if delete_keys or delete_cert: @@ -31,20 +35,19 @@ def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_fi # If the user either provides his own client log file (-E) or # wants the client log messages to be printed to the console (-vvv/-vv/-v), # we should not use the log files to check for connection success. - log_file_dir = os.path.dirname(cert_file) + log_file_dir = os.path.dirname(ssh_info.cert_file) log_file_name = 'ssh_client_log_' + str(os.getpid()) log_file = os.path.join(log_file_dir, log_file_name) ssh_arg_list = ['-E', log_file, '-v'] + ssh_arg_list # Create a new process that will wait until the connection is established and then delete keys. - cleanup_process = mp.Process(target=_do_cleanup, args=(delete_keys, delete_cert, cert_file, private_key_file, - log_file, True)) + cleanup_process = mp.Process(target=_do_cleanup, args=(delete_keys, delete_cert, ssh_info.cert_file, + ssh_info.private_key_file, ssh_info.public_key_file, log_file, True)) cleanup_process.start() - command = [_get_ssh_path(), _get_host(username, ip)] - command = command + _build_args(cert_file, private_key_file, port) + ssh_arg_list + command = command + ssh_info.build_args() + ssh_arg_list logger.debug("Running ssh command %s", ' '.join(command)) - subprocess.call(command, shell=platform.system() == 'Windows') + connection_status = subprocess.call(command, shell=platform.system() == 'Windows') if delete_keys or delete_cert: if cleanup_process.is_alive(): @@ -54,29 +57,69 @@ def start_ssh_connection(port, ssh_args, ip, username, cert_file, private_key_fi while cleanup_process.is_alive() and (time.time() - t0) < CLEANUP_AWAIT_TERMINATION_IN_SECONDS: time.sleep(1) + if log_file: + _print_error_messages_from_ssh_log(log_file, connection_status) + # Make sure all files have been properly removed. - _do_cleanup(delete_keys, delete_cert, cert_file, private_key_file) + _do_cleanup(delete_keys, delete_cert, ssh_info.cert_file, ssh_info.private_key_file, ssh_info.public_key_file) if log_file: file_utils.delete_file(log_file, f"Couldn't delete temporary log file {log_file}. ", True) if delete_keys: - temp_dir = os.path.dirname(cert_file) + temp_dir = os.path.dirname(ssh_info.cert_file) file_utils.delete_folder(temp_dir, f"Couldn't delete temporary folder {temp_dir}", True) -def create_ssh_keyfile(private_key_file): - command = [_get_ssh_path("ssh-keygen"), "-f", private_key_file, "-t", "rsa", "-q", "-N", ""] +def write_ssh_config(config_info, delete_keys, delete_cert): + + if delete_keys or delete_cert: + # Warn users to delete credentials once config file is no longer being used. + # If user provided keys, only ask them to delete the certificate. + path_to_delete = os.path.dirname(config_info.cert_file) + items_to_delete = " (id_rsa, id_rsa.pub, id_rsa.pub-aadcert.pub)" + if not delete_keys: + path_to_delete = config_info.cert_file + items_to_delete = "" + + expiration = None + # pylint: disable=broad-except + try: + expiration = get_certificate_start_and_end_times(config_info.cert_file, config_info.ssh_client_folder)[1] + expiration = expiration.strftime("%Y-%m-%d %I:%M:%S %p") + except Exception as e: + logger.warning("Couldn't determine certificate expiration. Error: %s", str(e)) + + if expiration: + logger.warning("The generated certificate %s is valid until %s in local time.", + config_info.cert_file, expiration) + logger.warning("%s contains sensitive information%s. Please delete it once you no longer this config file.", + path_to_delete, items_to_delete) + + config_text = config_info.get_config_text() + + if config_info.overwrite: + mode = 'w' + else: + mode = 'a' + with open(config_info.config_path, mode, encoding='utf-8') as f: + f.write('\n'.join(config_text)) + + +def create_ssh_keyfile(private_key_file, ssh_client_folder=None): + sshkeygen_path = _get_ssh_client_path("ssh-keygen", ssh_client_folder) + command = [sshkeygen_path, "-f", private_key_file, "-t", "rsa", "-q", "-N", ""] logger.debug("Running ssh-keygen command %s", ' '.join(command)) subprocess.call(command, shell=platform.system() == 'Windows') -def get_ssh_cert_info(cert_file): - command = [_get_ssh_path("ssh-keygen"), "-L", "-f", cert_file] +def get_ssh_cert_info(cert_file, ssh_client_folder=None): + sshkeygen_path = _get_ssh_client_path("ssh-keygen", ssh_client_folder) + command = [sshkeygen_path, "-L", "-f", cert_file] logger.debug("Running ssh-keygen command %s", ' '.join(command)) return subprocess.check_output(command, shell=platform.system() == 'Windows').decode().splitlines() -def get_ssh_cert_principals(cert_file): - info = get_ssh_cert_info(cert_file) +def get_ssh_cert_principals(cert_file, ssh_client_folder=None): + info = get_ssh_cert_info(cert_file, ssh_client_folder) principals = [] in_principal = False for line in info: @@ -91,116 +134,139 @@ def get_ssh_cert_principals(cert_file): return principals -def get_ssh_cert_validity(cert_file): - info = get_ssh_cert_info(cert_file) - for line in info: - if "Valid:" in line: - return line.strip() +def _get_ssh_cert_validity(cert_file, ssh_client_folder=None): + if cert_file: + info = get_ssh_cert_info(cert_file, ssh_client_folder) + for line in info: + if "Valid:" in line: + return line.strip() return None -def write_ssh_config(config_path, resource_group, vm_name, overwrite, port, - ip, username, cert_file, private_key_file, delete_keys, delete_cert): - - if delete_keys or delete_cert: - # Warn users to delete credentials once config file is no longer being used. - # If user provided keys, only ask them to delete the certificate. - path_to_delete = os.path.dirname(cert_file) - items_to_delete = " (id_rsa, id_rsa.pub, id_rsa.pub-aadcert.pub)" - if not delete_keys: - path_to_delete = cert_file - items_to_delete = "" - validity = get_ssh_cert_validity(cert_file) - validity_warning = "" - if validity: - validity_warning = f" {validity.lower()}" - logger.warning("%s contains sensitive information%s%s\n" - "Please delete it once you no longer need this config file. ", - path_to_delete, items_to_delete, validity_warning) - - lines = [""] - - if resource_group and vm_name: - lines.append("Host " + resource_group + "-" + vm_name) - lines.append("\tUser " + username) - lines.append("\tHostName " + ip) - if cert_file: - lines.append("\tCertificateFile " + cert_file) - if private_key_file: - lines.append("\tIdentityFile " + private_key_file) - if port: - lines.append("\tPort " + port) - - # default to all hosts for config - if not ip: - ip = "*" - - lines.append("Host " + ip) - lines.append("\tUser " + username) - if cert_file: - lines.append("\tCertificateFile " + cert_file) - if private_key_file: - lines.append("\tIdentityFile " + private_key_file) - if port: - lines.append("\tPort " + port) - - if overwrite: - mode = 'w' - else: - mode = 'a' +def get_certificate_start_and_end_times(cert_file, ssh_client_folder=None): + validity_str = _get_ssh_cert_validity(cert_file, ssh_client_folder) + times = None + if validity_str and "Valid: from " in validity_str and " to " in validity_str: + times = validity_str.replace("Valid: from ", "").split(" to ") + t0 = datetime.datetime.strptime(times[0], '%Y-%m-%dT%X') + t1 = datetime.datetime.strptime(times[1], '%Y-%m-%dT%X') + times = (t0, t1) + return times + + +def _print_error_messages_from_ssh_log(log_file, connection_status): + with open(log_file, 'r', encoding='utf-8') as ssh_log: + log_text = ssh_log.read() + log_lines = log_text.splitlines() + if "debug1: Authentication succeeded" not in log_text or connection_status != 0: + for line in log_lines: + if "debug1:" not in line: + print(line) + + if "Permission denied (publickey)." in log_text: + # pylint: disable=bare-except + # pylint: disable=too-many-boolean-expressions + # Check if OpenSSH client and server versions are incompatible + try: + regex = 'OpenSSH.*_([0-9]+)\\.([0-9]+)' + local_major, local_minor = re.findall(regex, log_lines[0])[0] + remote_major, remote_minor = re.findall(regex, _get_line_that_contains("remote software version", + log_lines))[0] + local_major = int(local_major) + local_minor = int(local_minor) + remote_major = int(remote_major) + remote_minor = int(remote_minor) + except: + ssh_log.close() + return + + if (remote_major < 7 or (remote_major == 7 and remote_minor < 8)) and \ + (local_major > 8 or (local_major == 8 and local_minor >= 8)): + logger.warning("The OpenSSH server version in the target VM %d.%d is too old. " + "Version incompatible with OpenSSH client version %d.%d. " + "Refer to https://bugzilla.mindrot.org/show_bug.cgi?id=3351 for more information.", + remote_major, remote_minor, local_major, local_minor) + + elif (local_major < 7 or (local_major == 7 and local_minor < 8)) and \ + (remote_major > 8 or (remote_major == 8 and remote_minor >= 8)): + logger.warning("The OpenSSH client version %d.%d is too old. " + "Version incompatible with OpenSSH server version %d.%d in the target VM. " + "Refer to https://bugzilla.mindrot.org/show_bug.cgi?id=3351 for more information.", + local_major, local_minor, remote_major, remote_minor) + ssh_log.close() + + +def _get_line_that_contains(substring, lines): + for line in lines: + if substring in line: + return line + return None - with open(config_path, mode) as f: - f.write('\n'.join(lines)) +def _get_ssh_client_path(ssh_command="ssh", ssh_client_folder=None): + if ssh_client_folder: + ssh_path = os.path.join(ssh_client_folder, ssh_command) + if platform.system() == 'Windows': + ssh_path = ssh_path + '.exe' + if os.path.isfile(ssh_path): + logger.debug("Attempting to run %s from path %s", ssh_command, ssh_path) + return ssh_path + logger.warning("Could not find %s in provided --ssh-client-folder %s. " + "Attempting to get pre-installed OpenSSH bits.", ssh_command, ssh_client_folder) -def _get_ssh_path(ssh_command="ssh"): ssh_path = ssh_command if platform.system() == 'Windows': - arch_data = platform.architecture() - is_32bit = arch_data[0] == '32bit' - sys_path = 'SysNative' if is_32bit else 'System32' + # If OS architecture is 64bit and python architecture is 32bit, + # look for System32 under SysNative folder. + machine = platform.machine() + os_architecture = None + # python interpreter architecture + platform_architecture = platform.architecture()[0] + sys_path = None + + if machine.endswith('64'): + os_architecture = '64bit' + elif machine.endswith('86'): + os_architecture = '32bit' + elif machine == '': + raise azclierror.BadRequestError("Couldn't identify the OS architecture.") + else: + raise azclierror.BadRequestError(f"Unsuported OS architecture: {machine} is not currently supported") + + if os_architecture == "64bit": + sys_path = 'SysNative' if platform_architecture == '32bit' else 'System32' + else: + sys_path = 'System32' + system_root = os.environ['SystemRoot'] system32_path = os.path.join(system_root, sys_path) ssh_path = os.path.join(system32_path, "openSSH", (ssh_command + ".exe")) - logger.debug("Platform architecture: %s", str(arch_data)) + logger.debug("Platform architecture: %s", platform_architecture) + logger.debug("OS architecture: %s", os_architecture) logger.debug("System Root: %s", system_root) - logger.debug("Attempting to run ssh from path %s", ssh_path) + logger.debug("Attempting to run %s from path %s", ssh_command, ssh_path) if not os.path.isfile(ssh_path): raise azclierror.UnclassifiedUserFault( - "Could not find " + ssh_command + ".exe.", - "https://docs.microsoft.com/en-us/windows-server/administration/openssh/openssh_install_firstuse") + "Could not find " + ssh_command + ".exe on path " + ssh_path + ". " + "Make sure OpenSSH is installed correctly: " + "https://docs.microsoft.com/en-us/windows-server/administration/openssh/openssh_install_firstuse . " + "Or use --ssh-client-folder to provide folder path with ssh executables. ") return ssh_path -def _get_host(username, ip): - return username + "@" + ip - - -def _build_args(cert_file, private_key_file, port): - private_key = [] - port_arg = [] - certificate = [] - if private_key_file: - private_key = ["-i", private_key_file] - if port: - port_arg = ["-p", port] - if cert_file: - certificate = ["-o", "CertificateFile=" + cert_file] - return private_key + certificate + port_arg - - -def _do_cleanup(delete_keys, delete_cert, cert_file, private_key, log_file=None, wait=False): +def _do_cleanup(delete_keys, delete_cert, cert_file, private_key, public_key, log_file=None, wait=False): # if there is a log file, use it to check for the connection success if log_file: t0 = time.time() match = False while (time.time() - t0) < CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS and not match: time.sleep(CLEANUP_TIME_INTERVAL_IN_SECONDS) + # pylint: disable=bare-except try: - with open(log_file, 'r') as ssh_client_log: + with open(log_file, 'r', encoding='utf-8') as ssh_client_log: match = "debug1: Authentication succeeded" in ssh_client_log.read() ssh_client_log.close() except: @@ -211,10 +277,9 @@ def _do_cleanup(delete_keys, delete_cert, cert_file, private_key, log_file=None, # if we are not checking the logs, but still want to wait for connection before deleting files time.sleep(CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS) - # TO DO: Once arc changes are merged, delete relay information as well if delete_keys and private_key: - public_key = private_key + '.pub' file_utils.delete_file(private_key, f"Couldn't delete private key {private_key}. ", True) + if delete_keys and public_key: file_utils.delete_file(public_key, f"Couldn't delete public key {public_key}. ", True) if delete_cert and cert_file: file_utils.delete_file(cert_file, f"Couldn't delete certificate {cert_file}. ", True) diff --git a/src/ssh/azext_ssh/tests/latest/test_custom.py b/src/ssh/azext_ssh/tests/latest/test_custom.py index 7952b63f3c1..96ee88879dc 100644 --- a/src/ssh/azext_ssh/tests/latest/test_custom.py +++ b/src/ssh/azext_ssh/tests/latest/test_custom.py @@ -10,42 +10,103 @@ from azext_ssh import custom +from azext_ssh import ssh_info +from azext_ssh import ssh_utils class SshCustomCommandTest(unittest.TestCase): + @mock.patch('azext_ssh.custom._do_ssh_op') @mock.patch('azext_ssh.custom._assert_args') - def test_ssh_vm(self, mock_assert, mock_do_op): + @mock.patch('azext_ssh.ssh_info.SSHSession') + def test_ssh_vm(self, mock_info, mock_assert, mock_do_op): cmd = mock.Mock() - custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", "port", None) + cmd.cli_ctx.data = {'safe_params': []} + + custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", ['-vvv']) + + mock_info.assert_called_once_with("rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", ['-vvv']) mock_assert.assert_called_once_with("rg", "vm", "ip", "cert", "username") - mock_do_op.assert_called_once_with( - cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", None, mock.ANY) - + mock_do_op.assert_called_once_with(cmd, mock.ANY, None, ssh_utils.start_ssh_connection) + + @mock.patch('azext_ssh.custom._do_ssh_op') + @mock.patch('azext_ssh.custom._assert_args') + @mock.patch('azext_ssh.ssh_info.SSHSession') + def test_ssh_vm_debug(self, mock_info, mock_assert, mock_do_op): + cmd = mock.Mock() + + cmd.cli_ctx.data = {'safe_params': ['--debug']} + + custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", []) + + mock_info.assert_called_once_with("rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", ['-v']) + mock_assert.assert_called_once_with("rg", "vm", "ip", "cert", "username") + mock_do_op.assert_called_once_with(cmd, mock.ANY, None, ssh_utils.start_ssh_connection) + @mock.patch('azext_ssh.custom._do_ssh_op') @mock.patch('azext_ssh.ssh_utils.write_ssh_config') @mock.patch('azext_ssh.custom._assert_args') @mock.patch('os.path.isdir') @mock.patch('os.path.dirname') @mock.patch('os.path.join') - def test_ssh_config(self, mock_join, mock_dirname, mock_isdir, mock_assert, mock_ssh_utils, mock_do_op): + @mock.patch('azext_ssh.ssh_info.ConfigSession') + def test_ssh_config(self, mock_info, mock_join, mock_dirname, mock_isdir, mock_assert, mock_ssh_utils, mock_do_op): cmd = mock.Mock() mock_dirname.return_value = "configdir" mock_isdir.return_value = True mock_join.side_effect = ['az_ssh_config/rg-vm', 'path/to/az_ssh_config/rg-vm'] - def do_op_side_effect(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip, local_user, cert_file, credentials_folder, op_call): - op_call(ssh_ip, "username", "cert", private_key_file, False, False) - - mock_do_op.side_effect = do_op_side_effect - custom.ssh_config(cmd, "path/to/file", "rg", "vm", "ip", "public", "private", False, False, "username", "cert", "port", None) + custom.ssh_config(cmd, "path/to/file", "rg", "vm", "ip", "public", "private", False, False, "username", "cert", "port", None, "client/folder") - mock_ssh_utils.assert_called_once_with("path/to/file", "rg", "vm", False, "port", "ip", "username", "cert", "private", False, False) + mock_info.assert_called_once_with("path/to/file", "rg", "vm", "ip", "public", "private", False, False, "username", "cert", "port", "client/folder") mock_assert.assert_called_once_with("rg", "vm", "ip", "cert", "username") - mock_do_op.assert_called_once_with( - cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", 'path/to/az_ssh_config/rg-vm', mock.ANY) + mock_do_op.assert_called_once_with(cmd, mock.ANY, 'path/to/az_ssh_config/rg-vm', ssh_utils.write_ssh_config) + + def test_ssh_cert_no_args(self): + cmd = mock.Mock() + self.assertRaises( + azclierror.RequiredArgumentMissingError, custom.ssh_cert, cmd) + @mock.patch('os.path.isdir') + def test_ssh_cert_cert_file_missing(self, mock_isdir): + cmd = mock.Mock() + mock_isdir.return_value = False + self.assertRaises( + azclierror.InvalidArgumentValueError, custom.ssh_cert, cmd, cert_path="cert") + + @mock.patch('os.path.isdir') + @mock.patch('os.path.abspath') + @mock.patch('azext_ssh.custom._check_or_create_public_private_files') + @mock.patch('azext_ssh.custom._get_and_write_certificate') + def test_ssh_cert(self, mock_write_cert, mock_get_keys, mock_abspath, mock_isdir): + cmd = mock.Mock() + mock_isdir.return_value = True + mock_abspath.side_effect = ['/pubkey/path', '/cert/path', '/client/path'] + mock_get_keys.return_value = "pubkey", "privkey", False + mock_write_cert.return_value = "cert", "username" + + custom.ssh_cert(cmd, "cert", "pubkey", "ssh/folder") + + mock_get_keys.assert_called_once_with('/pubkey/path', None, None, '/client/path') + mock_write_cert.assert_called_once_with(cmd, 'pubkey', '/cert/path', '/client/path') + + @mock.patch('os.path.isdir') + @mock.patch('os.path.abspath') + @mock.patch('azext_ssh.custom._check_or_create_public_private_files') + @mock.patch('azext_ssh.custom._get_and_write_certificate') + def test_ssh_cert(self, mock_write_cert, mock_get_keys, mock_abspath, mock_isdir): + cmd = mock.Mock() + mock_isdir.return_value = True + mock_abspath.side_effect = ['/pubkey/path', '/cert/path', '/client/path'] + mock_get_keys.return_value = "pubkey", "privkey", False + mock_write_cert.return_value = "cert", "username" + + custom.ssh_cert(cmd, "cert", "pubkey", "ssh/folder") + + mock_get_keys.assert_called_once_with('/pubkey/path', None, None, '/client/path') + mock_write_cert.assert_called_once_with(cmd, 'pubkey', '/cert/path', '/client/path') + @mock.patch('azext_ssh.ssh_utils.get_ssh_cert_principals') @mock.patch('os.path.join') @mock.patch('azext_ssh.custom._check_or_create_public_private_files') @@ -59,6 +120,16 @@ def test_do_ssh_op_aad_user(self, mock_write_cert, mock_ssh_creds, mock_get_mod_ cmd.cli_ctx = mock.Mock() cmd.cli_ctx.cloud = mock.Mock() cmd.cli_ctx.cloud.name = "azurecloud" + + op_info = mock.Mock() + op_info.ip = "1.2.3.4" + op_info.public_key_file = "publicfile" + op_info.private_key_file = "privatefile" + op_info.use_private_ip = False + op_info.local_user = None + op_info.cert_file = None + op_info.ssh_client_folder = "/client/folder" + mock_op = mock.Mock() mock_check_files.return_value = "public", "private", False mock_principal.return_value = ["username"] @@ -68,15 +139,14 @@ def test_do_ssh_op_aad_user(self, mock_write_cert, mock_ssh_creds, mock_get_mod_ profile.get_msal_token.return_value = "username", "certificate" mock_join.return_value = "public-aadcert.pub" - custom._do_ssh_op(cmd, None, None, "1.2.3.4", "publicfile", "privatefile", False, None, None, "cred/folder", mock_op) + custom._do_ssh_op(cmd, op_info, "cred/folder", mock_op) - mock_check_files.assert_called_once_with("publicfile", "privatefile", "cred/folder") + mock_check_files.assert_called_once_with("publicfile", "privatefile", "cred/folder", "/client/folder") mock_ip.assert_not_called() mock_get_mod_exp.assert_called_once_with("public") mock_write_cert.assert_called_once_with("certificate", "public-aadcert.pub") - mock_op.assert_called_once_with( - "1.2.3.4", "username", "public-aadcert.pub", "private", False, True) - + mock_op.assert_called_once_with(op_info, False, True) + @mock.patch('azext_ssh.custom._check_or_create_public_private_files') @mock.patch('azext_ssh.ip_utils.get_ssh_ip') def test_do_ssh_op_local_user(self, mock_ip, mock_check_files): @@ -84,28 +154,42 @@ def test_do_ssh_op_local_user(self, mock_ip, mock_check_files): mock_op = mock.Mock() mock_ip.return_value = "1.2.3.4" - custom._do_ssh_op(cmd, "vm", "rg", None, "publicfile", "privatefile", False, "username", "cert", "cred/folder", mock_op) + op_info = mock.Mock() + op_info.resource_group_name = "rg" + op_info.vm_name = "vm" + op_info.ip = None + op_info.public_key_file = "publicfile" + op_info.private_key_file = "privatefile" + op_info.use_private_ip = False + op_info.local_user = "username" + op_info.certificate = "cert" + op_info.ssh_client_folder = "/client/folder" + + custom._do_ssh_op(cmd, op_info, "/cred/folder", mock_op) mock_check_files.assert_not_called() - mock_ip.assert_called_once_with(cmd, "vm", "rg", False) - mock_op.assert_called_once_with( - "1.2.3.4", "username", "cert", "privatefile", False, False) - + mock_ip.assert_called_once_with(cmd, "rg", "vm", False) + mock_op.assert_called_once_with(op_info, False, False) + @mock.patch('azext_ssh.custom._check_or_create_public_private_files') @mock.patch('azext_ssh.ip_utils.get_ssh_ip') - @mock.patch('azext_ssh.custom._get_modulus_exponent') - def test_do_ssh_op_no_public_ip(self, mock_get_mod_exp, mock_ip, mock_check_files): + def test_do_ssh_op_no_public_ip(self, mock_ip, mock_check_files): cmd = mock.Mock() mock_op = mock.Mock() - mock_get_mod_exp.return_value = "modulus", "exponent" mock_ip.return_value = None + op_info = mock.Mock() + op_info.vm_name = "vm" + op_info.resource_group_name = "rg" + op_info.ip = None + op_info.use_private_ip = False + self.assertRaises( - azclierror.ResourceNotFoundError, custom._do_ssh_op, cmd, "rg", "vm", None, - "publicfile", "privatefile", False, None, None, "cred/folder", mock_op) + azclierror.ResourceNotFoundError, custom._do_ssh_op, cmd, op_info, "/cred/folder", mock_op) mock_check_files.assert_not_called() mock_ip.assert_called_once_with(cmd, "rg", "vm", False) + mock_op.assert_not_called() def test_assert_args_no_ip_or_vm(self): self.assertRaises(azclierror.RequiredArgumentMissingError, custom._assert_args, None, None, None, None, None) @@ -149,7 +233,7 @@ def test_check_or_create_public_private_files_defaults(self, mock_join, mock_isf mock.call('/tmp/aadtemp/id_rsa') ]) mock_create.assert_has_calls([ - mock.call('/tmp/aadtemp/id_rsa') + mock.call('/tmp/aadtemp/id_rsa', None) ]) @mock.patch('azext_ssh.ssh_utils.create_ssh_keyfile') @@ -160,7 +244,7 @@ def test_check_or_create_public_private_files_defaults_with_cred_folder(self,moc mock_isfile.return_value = True mock_isdir.return_value = True mock_join.side_effect = ['/cred/folder/id_rsa.pub', '/cred/folder/id_rsa'] - public, private, delete_key = custom._check_or_create_public_private_files(None, None, '/cred/folder') + public, private, delete_key = custom._check_or_create_public_private_files(None, None, '/cred/folder', '/ssh/client') self.assertEqual('/cred/folder/id_rsa.pub', public) self.assertEqual('/cred/folder/id_rsa', private) self.assertEqual(True, delete_key) @@ -173,13 +257,11 @@ def test_check_or_create_public_private_files_defaults_with_cred_folder(self,moc mock.call('/cred/folder/id_rsa') ]) mock_create.assert_has_calls([ - mock.call('/cred/folder/id_rsa') + mock.call('/cred/folder/id_rsa', '/ssh/client') ]) - - + @mock.patch('os.path.isfile') - @mock.patch('os.path.join') - def test_check_or_create_public_private_files_no_public(self, mock_join, mock_isfile): + def test_check_or_create_public_private_files_no_public(self, mock_isfile): mock_isfile.side_effect = [False] self.assertRaises( azclierror.FileOperationError, custom._check_or_create_public_private_files, "public", None, None) @@ -207,7 +289,7 @@ def test_write_cert_file(self, mock_open): custom._write_cert_file("cert", "publickey-aadcert.pub") - mock_open.assert_called_once_with("publickey-aadcert.pub", 'w') + mock_open.assert_called_once_with("publickey-aadcert.pub", 'w', encoding='utf-8') mock_file.write.assert_called_once_with("ssh-rsa-cert-v01@openssh.com cert") @mock.patch('azext_ssh.rsa_parser.RSAParser') @@ -222,7 +304,7 @@ def test_get_modulus_exponent_success(self, mock_open, mock_isfile, mock_parser) self.assertEqual(mock_parser.return_value.modulus, modulus) self.assertEqual(mock_parser.return_value.exponent, exponent) mock_isfile.assert_called_once_with('file') - mock_open.assert_called_once_with('file', 'r') + mock_open.assert_called_once_with('file', 'r', encoding='utf-8') mock_parser.return_value.parse.assert_called_once_with('publickey') @mock.patch('os.path.isfile') @@ -244,6 +326,5 @@ def test_get_modulus_exponent_parse_error(self, mock_open, mock_isfile, mock_par self.assertRaises(azclierror.FileOperationError, custom._get_modulus_exponent, 'file') - if __name__ == '__main__': unittest.main() diff --git a/src/ssh/azext_ssh/tests/latest/test_ssh_info.py b/src/ssh/azext_ssh/tests/latest/test_ssh_info.py new file mode 100644 index 00000000000..1157ed2c1e0 --- /dev/null +++ b/src/ssh/azext_ssh/tests/latest/test_ssh_info.py @@ -0,0 +1,126 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import unittest +from unittest import mock + +from azext_ssh import ssh_info + + +class SSHInfoTest(unittest.TestCase): + @mock.patch('os.path.abspath') + def test_ssh_session(self, mock_abspath): + mock_abspath.side_effect = ["pub_path", "priv_path", "cert_path", "client_path"] + expected_abspath_calls = [ + mock.call("pub"), + mock.call("priv"), + mock.call("cert"), + mock.call("client/folder") + ] + session = ssh_info.SSHSession("rg", "vm", "ip", "pub", "priv", False, "user", "cert", "port", "client/folder", ['-v', '-E', 'path']) + mock_abspath.assert_has_calls(expected_abspath_calls) + self.assertEqual(session.resource_group_name, "rg") + self.assertEqual(session.vm_name, "vm") + self.assertEqual(session.ip, "ip") + self.assertEqual(session.public_key_file, "pub_path") + self.assertEqual(session.private_key_file, "priv_path") + self.assertEqual(session.use_private_ip, False) + self.assertEqual(session.local_user, "user") + self.assertEqual(session.port, "port") + self.assertEqual(session.ssh_args, ['-v', '-E', 'path']) + self.assertEqual(session.cert_file, "cert_path") + self.assertEqual(session.ssh_client_folder, "client_path") + + def test_ssh_session_get_host(self): + session = ssh_info.SSHSession(None, None, "ip", None, None, False, "user", None, None, None, []) + self.assertEqual("user@ip", session.get_host()) + + @mock.patch('os.path.abspath') + def test_ssh_session_build_args(self, mock_abspath): + mock_abspath.side_effect = ["pub_path", "priv_path", "cert_path", "client_path"] + session = ssh_info.SSHSession("rg", "vm", "ip", "pub", "priv", False, "user", "cert", "port", "client/folder", []) + self.assertEqual(["-i", "priv_path", "-o", "CertificateFile=\"cert_path\"", "-p", "port"], session.build_args()) + + @mock.patch('os.path.abspath') + def test_config_session(self, mock_abspath): + mock_abspath.side_effect = ["config_path", "pub_path", "priv_path", "cert_path", "client_path"] + expected_abspath_calls = [ + mock.call("config"), + mock.call("pub"), + mock.call("priv"), + mock.call("cert"), + mock.call("client/folder") + ] + session = ssh_info.ConfigSession("config", "rg", "vm", "ip", "pub", "priv", False, False, "user", "cert", "port", "client/folder") + mock_abspath.assert_has_calls(expected_abspath_calls) + self.assertEqual(session.config_path, "config_path") + self.assertEqual(session.resource_group_name, "rg") + self.assertEqual(session.vm_name, "vm") + self.assertEqual(session.ip, "ip") + self.assertEqual(session.public_key_file, "pub_path") + self.assertEqual(session.private_key_file, "priv_path") + self.assertEqual(session.use_private_ip, False) + self.assertEqual(session.overwrite, False) + self.assertEqual(session.local_user, "user") + self.assertEqual(session.port, "port") + self.assertEqual(session.cert_file, "cert_path") + self.assertEqual(session.ssh_client_folder, "client_path") + + @mock.patch('os.path.abspath') + def test_get_rg_and_vm_entry(self, mock_abspath): + expected_lines = [ + "Host rg-vm", + "\tUser user", + "\tHostName ip", + "\tCertificateFile \"cert_path\"", + "\tIdentityFile \"priv_path\"", + "\tPort port", + ] + + mock_abspath.side_effect = ["config_path", "pub_path", "priv_path", "cert_path", "client_path"] + session = ssh_info.ConfigSession("config", "rg", "vm", "ip", "pub", "priv", False, False, "user", "cert", "port", "client/folder") + + self.assertEqual(session._get_rg_and_vm_entry(), expected_lines) + + @mock.patch('os.path.abspath') + def test_get_ip_entry(self, mock_abspath): + expected_lines = [ + "Host ip", + "\tUser user", + "\tCertificateFile \"cert_path\"", + "\tIdentityFile \"priv_path\"" + ] + + mock_abspath.side_effect = ["config_path", "pub_path", "priv_path", "cert_path", "client_path"] + session = ssh_info.ConfigSession("config", "rg", "vm", "ip", "pub", "priv", False, False, "user", "cert", None, "client/folder") + + self.assertEqual(session._get_ip_entry(), expected_lines) + + @mock.patch('os.path.abspath') + def test_get_config_text(self, mock_abspath): + expected_lines = [ + "", + "Host rg-vm", + "\tUser user", + "\tHostName ip", + "\tCertificateFile \"cert_path\"", + "\tIdentityFile \"priv_path\"", + "\tPort port", + "Host ip", + "\tUser user", + "\tCertificateFile \"cert_path\"", + "\tIdentityFile \"priv_path\"", + "\tPort port", + ] + + mock_abspath.side_effect = ["config_path", "pub_path", "priv_path", "cert_path", "client_path"] + session = ssh_info.ConfigSession("config", "rg", "vm", "ip", "pub", "priv", False, False, "user", "cert", "port", "client/folder") + + self.assertEqual(session.get_config_text(), expected_lines) + + + +if __name__ == '__main__': + unittest.main() diff --git a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py index 9e20e610317..c1ef7f2f134 100644 --- a/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py +++ b/src/ssh/azext_ssh/tests/latest/test_ssh_utils.py @@ -11,46 +11,73 @@ from azext_ssh import ssh_utils class SSHUtilsTests(unittest.TestCase): - @mock.patch('os.path.join') - @mock.patch.object(ssh_utils, '_get_ssh_path') - @mock.patch.object(ssh_utils, '_get_host') - @mock.patch.object(ssh_utils, '_build_args') + @mock.patch.object(ssh_utils, '_get_ssh_client_path') @mock.patch('subprocess.call') - def test_start_ssh_connection(self, mock_call, mock_build, mock_host, mock_path, mock_join): + @mock.patch('azext_ssh.ssh_info.SSHSession.build_args') + @mock.patch('azext_ssh.ssh_info.SSHSession.get_host') + @mock.patch('os.path.dirname') + @mock.patch('multiprocessing.Process.start') + @mock.patch('azext_ssh.ssh_utils._print_error_messages_from_ssh_log') + def test_start_ssh_connection(self, mock_print_error, mock_start, mock_dirname, mock_host, mock_build, mock_call, mock_path, mock_join): mock_path.return_value = "ssh" - mock_host.return_value = "user@ip" - mock_build.return_value = ['-i', 'file', '-o', 'option'] mock_join.return_value = "/log/file/path" + mock_build.return_value = ['-i', 'file', '-o', 'option'] + mock_host.return_value = "user@ip" + mock_dirname.return_value = "dirname" + mock_call.return_value = 0 expected_command = ["ssh", "user@ip", "-i", "file", "-o", "option", "-E", "/log/file/path", "-v"] - ssh_utils.start_ssh_connection("port", None, "ip", "user", "cert", "private", True, True) - - mock_path.assert_called_once_with() - mock_host.assert_called_once_with("user", "ip") - mock_build.assert_called_once_with("cert", "private", "port") + op_info = mock.Mock() + op_info.ip = "ip" + op_info.port = "port" + op_info.local_user = "user" + op_info.private_key_file = "private" + op_info.public_key_file = "public" + op_info.cert_file = "cert" + op_info.ssh_args = None + op_info.ssh_client_folder = "client/folder" + op_info.build_args = mock_build + op_info.get_host = mock_host + + ssh_utils.start_ssh_connection(op_info, True, True) + mock_start.assert_called_once() + mock_print_error.assert_called_once_with("/log/file/path", 0) + mock_path.assert_called_once_with('ssh', 'client/folder') mock_call.assert_called_once_with(expected_command, shell=platform.system() == 'Windows') - - @mock.patch.object(ssh_utils, '_get_ssh_path') - @mock.patch.object(ssh_utils, '_get_host') + + @mock.patch.object(ssh_utils, '_get_ssh_client_path') @mock.patch('subprocess.call') - def test_start_ssh_connection_with_args(self, mock_call, mock_host, mock_path): + @mock.patch('azext_ssh.ssh_info.SSHSession.build_args') + @mock.patch('azext_ssh.ssh_info.SSHSession.get_host') + def test_start_ssh_connection_with_args(self, mock_host, mock_build, mock_call, mock_path): mock_path.return_value = "ssh" mock_host.return_value = "user@ip" + mock_build.return_value = ["-i", "private", "-o", "CertificateFile=cert", "-p", "2222"] expected_command = ["ssh", "user@ip", "-i", "private", "-o", "CertificateFile=cert", "-p", "2222", "--thing", "-vv"] - ssh_utils.start_ssh_connection("2222", ["--thing", "-vv"], "ip", "user", "cert", "private", True, True) - - mock_path.assert_called_once_with() - mock_host.assert_called_once_with("user", "ip") + op_info = mock.Mock() + op_info.ip = "ip" + op_info.port = "2222" + op_info.local_user = "user" + op_info.private_key_file = "private" + op_info.public_key_file = "public" + op_info.cert_file = "cert" + op_info.ssh_args = ["--thing", "-vv"] + op_info.ssh_client_folder = "client/folder" + op_info.build_args = mock_build + op_info.get_host = mock_host + + ssh_utils.start_ssh_connection(op_info, True, True) + + mock_path.assert_called_once_with('ssh', 'client/folder') mock_call.assert_called_once_with(expected_command, shell=platform.system() == 'Windows') - - @mock.patch.object(ssh_utils, 'get_ssh_cert_validity') - def test_write_ssh_config_ip_and_vm(self, mock_validity): - mock_validity.return_value = None + @mock.patch.object(ssh_utils, 'get_certificate_start_and_end_times') + @mock.patch('azext_ssh.ssh_info.ConfigSession.get_config_text') + def test_write_ssh_config_ip_and_vm(self, mock_get_text, mock_validity): expected_lines = [ "", "Host rg-vm", @@ -65,19 +92,35 @@ def test_write_ssh_config_ip_and_vm(self, mock_validity): "\tIdentityFile privatekey", "\tPort port" ] + + mock_validity.return_value = None + mock_get_text.return_value = expected_lines + + op_info = mock.Mock() + op_info.config_path = "path/to/file" + op_info.resource_group_name = "rg" + op_info.vm_name = "vm" + op_info.overwrite = True + op_info.port = "port" + op_info.ip = "1.2.3.4" + op_info.local_user = "username" + op_info.cert_file = "cert" + op_info.private_key_file = "privatekey" + op_info.ssh_client_folder = "client/folder" + op_info.get_config_text = mock_get_text with mock.patch('builtins.open') as mock_open: mock_file = mock.Mock() mock_open.return_value.__enter__.return_value = mock_file - ssh_utils.write_ssh_config( - "path/to/file", "rg", "vm", True, "port", "1.2.3.4", "username", "cert", "privatekey", True, False - ) - mock_validity.assert_called_once_with("cert") - mock_open.assert_called_once_with("path/to/file", "w") + ssh_utils.write_ssh_config(op_info, True, False) + mock_validity.assert_called_once_with("cert", "client/folder") + mock_open.assert_called_once_with("path/to/file", "w", encoding='utf-8') mock_file.write.assert_called_once_with('\n'.join(expected_lines)) - @mock.patch.object(ssh_utils, 'get_ssh_cert_validity') - def test_write_ssh_config_append(self, mock_validity): + + @mock.patch.object(ssh_utils, 'get_certificate_start_and_end_times') + @mock.patch('azext_ssh.ssh_info.ConfigSession.get_config_text') + def test_write_ssh_config_append(self, mock_get_text, mock_validity): expected_lines = [ "", "Host rg-vm", @@ -92,94 +135,105 @@ def test_write_ssh_config_append(self, mock_validity): ] mock_validity.return_value = None + mock_get_text.return_value = expected_lines + + op_info = mock.Mock() + op_info.config_path = "path/to/file" + op_info.resource_group_name = "rg" + op_info.vm_name = "vm" + op_info.overwrite = False + op_info.ip = "1.2.3.4" + op_info.local_user = "username" + op_info.cert_file = "cert" + op_info.private_key_file = "privatekey" + op_info.ssh_client_folder = "client/folder" + op_info.get_config_text = mock_get_text with mock.patch('builtins.open') as mock_open: mock_file = mock.Mock() mock_open.return_value.__enter__.return_value = mock_file ssh_utils.write_ssh_config( - "path/to/file", "rg", "vm", False, None, "1.2.3.4", "username", "cert", "privatekey", True, True + op_info, True, True ) - mock_validity.assert_called_once_with("cert") + mock_validity.assert_called_once_with("cert", "client/folder") - mock_open.assert_called_once_with("path/to/file", "a") + mock_open.assert_called_once_with("path/to/file", "a", encoding='utf-8') mock_file.write.assert_called_once_with('\n'.join(expected_lines)) - @mock.patch.object(ssh_utils, 'get_ssh_cert_validity') - def test_write_ssh_config_ip_only(self, mock_validity): - expected_lines = [ - "", - "Host 1.2.3.4", - "\tUser username", - "\tCertificateFile cert", - "\tIdentityFile privatekey" - ] - mock_validity.return_value = None - - with mock.patch('builtins.open') as mock_open: - mock_file = mock.Mock() - mock_open.return_value.__enter__.return_value = mock_file - ssh_utils.write_ssh_config( - "path/to/file", None, None, True, None, "1.2.3.4", "username", "cert", "privatekey", False, False - ) - - mock_validity.assert_not_called() - - mock_open.assert_called_once_with("path/to/file", "w") - mock_file.write.assert_called_once_with('\n'.join(expected_lines)) + @mock.patch('os.path.join') @mock.patch('platform.system') - def test_get_ssh_path_non_windows(self, mock_system): - mock_system.return_value = "Mac" - - actual_path = ssh_utils._get_ssh_path() - self.assertEqual('ssh', actual_path) - mock_system.assert_called_once_with() - - def test_get_ssh_path_windows_32bit(self): - self._test_ssh_path_windows('32bit', 'SysNative') - - def test_get_ssh_path_windows_64bit(self): - self._test_ssh_path_windows('64bit', 'System32') + @mock.patch('os.path.isfile') + def test_get_ssh_client_path_with_client_folder_non_windows(self, mock_isfile, mock_system, mock_join): + mock_join.return_value = "ssh_path" + mock_system.return_value = "Linux" + mock_isfile.return_value = True + actual_path = ssh_utils._get_ssh_client_path(ssh_client_folder='/client/folder') + self.assertEqual(actual_path, "ssh_path") + mock_join.assert_called_once_with('/client/folder', 'ssh') + mock_isfile.assert_called_once_with("ssh_path") + @mock.patch('os.path.join') @mock.patch('platform.system') - @mock.patch('platform.architecture') - @mock.patch('os.environ') @mock.patch('os.path.isfile') - def test_get_ssh_path_windows_ssh_not_found(self, mock_isfile, mock_environ, mock_arch, mock_sys): - mock_sys.return_value = "Windows" - mock_arch.return_value = ("32bit", "foo", "bar") - mock_environ.__getitem__.return_value = "rootpath" + def test_get_ssh_client_path_with_client_folder_windows(self, mock_isfile, mock_system, mock_join): + mock_join.return_value = "ssh_keygen_path" + mock_system.return_value = "Windows" + mock_isfile.return_value = True + actual_path = ssh_utils._get_ssh_client_path(ssh_command='ssh-keygen', ssh_client_folder='/client/folder') + self.assertEqual(actual_path, "ssh_keygen_path.exe") + mock_join.assert_called_once_with('/client/folder', 'ssh-keygen') + mock_isfile.assert_called_once_with("ssh_keygen_path.exe") + + @mock.patch('os.path.join') + @mock.patch('platform.system') + @mock.patch('os.path.isfile') + def test_get_ssh_client_path_with_client_folder_no_file(self, mock_isfile, mock_system, mock_join): + mock_join.return_value = "ssh_path" + mock_system.return_value = "Mac" mock_isfile.return_value = False + actual_path = ssh_utils._get_ssh_client_path(ssh_client_folder='/client/folder') + self.assertEqual(actual_path, "ssh") + mock_join.assert_called_once_with('/client/folder', 'ssh') + mock_isfile.assert_called_once_with("ssh_path") - self.assertRaises(azclierror.UnclassifiedUserFault, ssh_utils._get_ssh_path) + @mock.patch('platform.system') + def test_get_ssh_client_preinstalled_non_windows(self, mock_system): + mock_system.return_value = "Mac" + actual_path = ssh_utils._get_ssh_client_path() + self.assertEqual('ssh', actual_path) + mock_system.assert_called_once_with() - def test_get_host(self): - actual_host = ssh_utils._get_host("username", "10.0.0.1") - self.assertEqual("username@10.0.0.1", actual_host) + def test_get_ssh_client_preinstalled_windows_32bit(self): + self._test_get_ssh_client_path_preinstalled_windows('32bit', 'x86', 'System32') - def test_build_args(self): - actual_args = ssh_utils._build_args("cert", "privatekey", "2222") - expected_args = ["-i", "privatekey", "-o", "CertificateFile=cert", "-p", "2222"] - self.assertEqual(expected_args, actual_args) + def test_get_ssh_client_preinstalled_windows_64bitOS_32bitPlatform(self): + self._test_get_ssh_client_path_preinstalled_windows('32bit', 'x64', 'SysNative') + + def test_get_ssh_client_preinstalled_windows_64bitOS_64bitPlatform(self): + self._test_get_ssh_client_path_preinstalled_windows('64bit', 'x64', 'System32') @mock.patch('platform.system') @mock.patch('platform.architecture') + @mock.patch('platform.machine') @mock.patch('os.path.join') @mock.patch('os.environ') @mock.patch('os.path.isfile') - def _test_ssh_path_windows(self, arch, expected_sys_path, mock_isfile, mock_environ, mock_join, mock_arch, mock_system): + def _test_get_ssh_client_path_preinstalled_windows(self, platform_arch, os_arch, expected_sysfolder, mock_isfile, mock_environ, mock_join, mock_machine, mock_arch, mock_system): mock_system.return_value = "Windows" - mock_arch.return_value = (arch, "foo", "bar") + mock_arch.return_value = (platform_arch, "foo", "bar") + mock_machine.return_value = os_arch mock_environ.__getitem__.return_value = "rootpath" mock_join.side_effect = ["system32path", "sshfilepath"] mock_isfile.return_value = True + expected_join_calls = [ - mock.call("rootpath", expected_sys_path), + mock.call("rootpath", expected_sysfolder), mock.call("system32path", "openSSH", "ssh.exe") ] - - actual_path = ssh_utils._get_ssh_path() + + actual_path = ssh_utils._get_ssh_client_path() self.assertEqual("sshfilepath", actual_path) mock_system.assert_called_once_with() @@ -187,3 +241,18 @@ def _test_ssh_path_windows(self, arch, expected_sys_path, mock_isfile, mock_envi mock_environ.__getitem__.assert_called_once_with("SystemRoot") mock_join.assert_has_calls(expected_join_calls) mock_isfile.assert_called_once_with("sshfilepath") + + + @mock.patch('platform.system') + @mock.patch('platform.architecture') + @mock.patch('platform.machine') + @mock.patch('os.environ') + @mock.patch('os.path.isfile') + def test_get_ssh_path_windows_ssh_preinstalled_not_found(self, mock_isfile, mock_environ, mock_machine, mock_arch, mock_sys): + mock_sys.return_value = "Windows" + mock_arch.return_value = ("32bit", "foo", "bar") + mock_machine.return_value = "x64" + mock_environ.__getitem__.return_value = "rootpath" + mock_isfile.return_value = False + + self.assertRaises(azclierror.UnclassifiedUserFault, ssh_utils._get_ssh_client_path) diff --git a/src/ssh/setup.py b/src/ssh/setup.py index b9731a46c7b..18140019801 100644 --- a/src/ssh/setup.py +++ b/src/ssh/setup.py @@ -7,7 +7,7 @@ from setuptools import setup, find_packages -VERSION = "1.0.0" +VERSION = "1.0.1" CLASSIFIERS = [ 'Development Status :: 4 - Beta',