Skip to content

Commit

Permalink
SSH Extension v.1.0.1 (#4474)
Browse files Browse the repository at this point in the history
  • Loading branch information
vthiebaut10 committed Mar 10, 2022
1 parent 17164cb commit 4e4b9c2
Show file tree
Hide file tree
Showing 11 changed files with 777 additions and 271 deletions.
10 changes: 10 additions & 0 deletions src/ssh/HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
Release History
===============
1.0.1
-----
* Added --ssh-client-folder parameter.
* Fixed issues caused when there are spaces or non-english characters in paths provided by users.
* Ensure all paths provided by users are converted to absolute paths.
* Print OpenSSH error messages to console on "az ssh vm".
* Print level1 SSH client log messages when running "az ssh vm" in debug mode.
* Change "isPreview".
* Correctly find pre-installed OpenSSH binaries on Windows 32bit machines.

1.0.0
-----
* Delete all keys and certificates created during execution of ssh vm.
Expand Down
9 changes: 9 additions & 0 deletions src/ssh/azext_ssh/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def load_arguments(self, _):
c.argument('cert_file', options_list=['--certificate-file', '-c'],
help='Path to a certificate file used for authentication when using local user credentials.')
c.argument('port', options_list=['--port'], help='SSH port')
c.argument('ssh_client_folder', options_list=['--ssh-client-folder'],
help='Folder path that contains ssh executables (ssh.exe, ssh-keygen.exe, etc). '
'Default to ssh pre-installed if not provided.')
c.positional('ssh_args', nargs='*', help='Additional arguments passed to OpenSSH')

with self.argument_context('ssh config') as c:
Expand All @@ -37,10 +40,16 @@ def load_arguments(self, _):
help='Folder where new generated keys will be stored.')
c.argument('cert_file', options_list=['--certificate-file', '-c'], help='Path to certificate file')
c.argument('port', options_list=['--port'], help='SSH port')
c.argument('ssh_client_folder', options_list=['--ssh-client-folder'],
help='Folder path that contains ssh executables (ssh.exe, ssh-keygen.exe, etc). '
'Default to ssh pre-installed if not provided.')

with self.argument_context('ssh cert') as c:
c.argument('cert_path', options_list=['--file', '-f'],
help='The file path to write the SSH cert to, defaults to public key path with -aadcert.pub appened')
c.argument('public_key_file', options_list=['--public-key-file', '-p'],
help='The RSA public key file path. If not provided, '
'generated key pair is stored in the same directory as --file.')
c.argument('ssh_client_folder', options_list=['--ssh-client-folder'],
help='Folder path that contains ssh executables (ssh.exe, ssh-keygen.exe, etc). '
'Default to ssh pre-installed if not provided.')
2 changes: 1 addition & 1 deletion src/ssh/azext_ssh/azext_metadata.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"azext.isPreview": true,
"azext.isPreview": false,
"azext.minCliCoreVersion": "2.4.0"
}
131 changes: 90 additions & 41 deletions src/ssh/azext_ssh/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,99 +3,145 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import functools
import os
import hashlib
import json
import tempfile

import colorama
from colorama import Fore
from colorama import Style

from knack import log
from azure.cli.core import azclierror

from . import ip_utils
from . import rsa_parser
from . import ssh_utils
from . import ssh_info

logger = log.get_logger(__name__)


def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_file=None,
private_key_file=None, use_private_ip=False, local_user=None, cert_file=None, port=None,
ssh_args=None):
ssh_client_folder=None, ssh_args=None):

if '--debug' in cmd.cli_ctx.data['safe_params'] and set(['-v', '-vv', '-vvv']).isdisjoint(ssh_args):
ssh_args = ['-v'] if not ssh_args else ['-v'] + ssh_args

_assert_args(resource_group_name, vm_name, ssh_ip, cert_file, local_user)
credentials_folder = None
op_call = functools.partial(ssh_utils.start_ssh_connection, port, ssh_args)
_do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip,
local_user, cert_file, credentials_folder, op_call)
op_call = ssh_utils.start_ssh_connection
ssh_session = ssh_info.SSHSession(resource_group_name, vm_name, ssh_ip, public_key_file,
private_key_file, use_private_ip, local_user, cert_file, port,
ssh_client_folder, ssh_args)
_do_ssh_op(cmd, ssh_session, credentials_folder, op_call)


def ssh_config(cmd, config_path, resource_group_name=None, vm_name=None, ssh_ip=None,
public_key_file=None, private_key_file=None, overwrite=False, use_private_ip=False,
local_user=None, cert_file=None, port=None, credentials_folder=None):
local_user=None, cert_file=None, port=None, credentials_folder=None, ssh_client_folder=None):

_assert_args(resource_group_name, vm_name, ssh_ip, cert_file, local_user)
# If user provides their own key pair, certificate will be written in the same folder as public key.
if (public_key_file or private_key_file) and credentials_folder:
raise azclierror.ArgumentUsageError("--keys-destination-folder can't be used in conjunction with "
"--public-key-file/-p or --private-key-file/-i.")

op_call = functools.partial(ssh_utils.write_ssh_config, config_path, resource_group_name, vm_name, overwrite, port)
config_session = ssh_info.ConfigSession(config_path, resource_group_name, vm_name, ssh_ip, public_key_file,
private_key_file, overwrite, use_private_ip, local_user, cert_file,
port, ssh_client_folder)

op_call = ssh_utils.write_ssh_config

# if the folder doesn't exist, this extension won't create a new one.
config_folder = os.path.dirname(config_session.config_path)
if not os.path.isdir(config_folder):
raise azclierror.InvalidArgumentValueError(f"Config file destination folder {config_folder} "
"does not exist.")

# Default credential location
# Add logic to test if credentials folder is valid
if not credentials_folder:
config_folder = os.path.dirname(config_path)
if not os.path.isdir(config_folder):
raise azclierror.InvalidArgumentValueError(f"Config file destination folder {config_folder} "
"does not exist.")
folder_name = ssh_ip
if resource_group_name and vm_name:
folder_name = resource_group_name + "-" + vm_name
# * is not a valid name for a folder in Windows. Treat this as a special case.
folder_name = config_session.ip if config_session.ip != "*" else "all_ips"
if config_session.resource_group_name and config_session.vm_name:
folder_name = config_session.resource_group_name + "-" + config_session.vm_name
credentials_folder = os.path.join(config_folder, os.path.join("az_ssh_config", folder_name))
else:
credentials_folder = os.path.abspath(credentials_folder)

_do_ssh_op(cmd, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip,
local_user, cert_file, credentials_folder, op_call)
_do_ssh_op(cmd, config_session, credentials_folder, op_call)


def ssh_cert(cmd, cert_path=None, public_key_file=None):
def ssh_cert(cmd, cert_path=None, public_key_file=None, ssh_client_folder=None):
if not cert_path and not public_key_file:
raise azclierror.RequiredArgumentMissingError("--file or --public-key-file must be provided.")
if cert_path and not os.path.isdir(os.path.dirname(cert_path)):
raise azclierror.InvalidArgumentValueError(f"{os.path.dirname(cert_path)} folder doesn't exist")

if public_key_file:
public_key_file = os.path.abspath(public_key_file)
if cert_path:
cert_path = os.path.abspath(cert_path)
if ssh_client_folder:
ssh_client_folder = os.path.abspath(ssh_client_folder)

# If user doesn't provide a public key, save generated key pair to the same folder as --file
keys_folder = None
if not public_key_file:
keys_folder = os.path.dirname(cert_path)
logger.warning("The generated SSH keys are stored at %s. Please delete SSH keys when the certificate "
"is no longer being used.", keys_folder)
public_key_file, _, _ = _check_or_create_public_private_files(public_key_file, None, keys_folder)
cert_file, _ = _get_and_write_certificate(cmd, public_key_file, cert_path)
print(cert_file + "\n")

public_key_file, _, _ = _check_or_create_public_private_files(public_key_file, None, keys_folder, ssh_client_folder)
cert_file, _ = _get_and_write_certificate(cmd, public_key_file, cert_path, ssh_client_folder)

if keys_folder:
logger.warning("%s contains sensitive information (id_rsa, id_rsa.pub). "
"Please delete once this certificate is no longer being used.", keys_folder)

colorama.init()
# pylint: disable=broad-except
try:
cert_expiration = ssh_utils.get_certificate_start_and_end_times(cert_file, ssh_client_folder)[1]
print(Fore.GREEN + f"Generated SSH certificate {cert_file} is valid until {cert_expiration} in local time."
+ Style.RESET_ALL)
except Exception as e:
logger.warning("Couldn't determine certificate validity. Error: %s", str(e))
print(Fore.GREEN + f"Generated SSH certificate {cert_file}." + Style.RESET_ALL)


def _do_ssh_op(cmd, resource_group, vm_name, ssh_ip, public_key_file, private_key_file, use_private_ip,
username, cert_file, credentials_folder, op_call):
def _do_ssh_op(cmd, op_info, credentials_folder, op_call):
# Get ssh_ip before getting public key to avoid getting "ResourceNotFound" exception after creating the keys
ssh_ip = ssh_ip or ip_utils.get_ssh_ip(cmd, resource_group, vm_name, use_private_ip)
op_info.ip = op_info.ip or ip_utils.get_ssh_ip(cmd, op_info.resource_group_name,
op_info.vm_name, op_info.use_private_ip)

if not ssh_ip:
if not use_private_ip:
raise azclierror.ResourceNotFoundError(f"VM '{vm_name}' does not have a public IP address to SSH to")
if not op_info.ip:
if not op_info.use_private_ip:
raise azclierror.ResourceNotFoundError(f"VM '{op_info.vm_name}' does not have a public "
"IP address to SSH to")

raise azclierror.ResourceNotFoundError(f"VM '{vm_name}' does not have a public or private IP address to SSH to")
raise azclierror.ResourceNotFoundError("Internal Error. Couldn't determine the IP address.")

# If user provides local user, no credentials should be deleted.
delete_keys = False
delete_cert = False
# If user provides a local user, use the provided credentials for authentication
if not username:
if not op_info.local_user:
delete_cert = True
public_key_file, private_key_file, delete_keys = _check_or_create_public_private_files(public_key_file,
private_key_file,
credentials_folder)
cert_file, username = _get_and_write_certificate(cmd, public_key_file, None)
op_info.public_key_file, op_info.private_key_file, delete_keys = \
_check_or_create_public_private_files(op_info.public_key_file,
op_info.private_key_file,
credentials_folder,
op_info.ssh_client_folder)

op_call(ssh_ip, username, cert_file, private_key_file, delete_keys, delete_cert)
op_info.cert_file, op_info.local_user = _get_and_write_certificate(cmd, op_info.public_key_file,
None, op_info.ssh_client_folder)

op_call(op_info, delete_keys, delete_cert)

def _get_and_write_certificate(cmd, public_key_file, cert_file):

def _get_and_write_certificate(cmd, public_key_file, cert_file, ssh_client_folder):
cloudtoscope = {
"azurecloud": "https://pas.windows.net/CheckMyAccess/Linux/.default",
"azurechinacloud": "https://pas.chinacloudapi.cn/CheckMyAccess/Linux/.default",
Expand Down Expand Up @@ -124,9 +170,11 @@ def _get_and_write_certificate(cmd, public_key_file, cert_file):

if not cert_file:
cert_file = public_key_file + "-aadcert.pub"

logger.debug("Generating certificate %s", cert_file)
_write_cert_file(certificate, cert_file)
# instead we use the validprincipals from the cert due to mismatched upn and email in guest scenarios
username = ssh_utils.get_ssh_cert_principals(cert_file)[0]
username = ssh_utils.get_ssh_cert_principals(cert_file, ssh_client_folder)[0]
return cert_file, username.lower()


Expand Down Expand Up @@ -172,7 +220,8 @@ def _assert_args(resource_group, vm_name, ssh_ip, cert_file, username):
raise azclierror.FileOperationError(f"Certificate file {cert_file} not found")


def _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder):
def _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder,
ssh_client_folder=None):
delete_keys = False
# If nothing is passed, then create a directory with a ephemeral keypair
if not public_key_file and not private_key_file:
Expand All @@ -189,7 +238,7 @@ def _check_or_create_public_private_files(public_key_file, private_key_file, cre
os.makedirs(credentials_folder)
public_key_file = os.path.join(credentials_folder, "id_rsa.pub")
private_key_file = os.path.join(credentials_folder, "id_rsa")
ssh_utils.create_ssh_keyfile(private_key_file)
ssh_utils.create_ssh_keyfile(private_key_file, ssh_client_folder)

if not public_key_file:
if private_key_file:
Expand All @@ -210,7 +259,7 @@ def _check_or_create_public_private_files(public_key_file, private_key_file, cre


def _write_cert_file(certificate_contents, cert_file):
with open(cert_file, 'w') as f:
with open(cert_file, 'w', encoding='utf-8') as f:
f.write(f"ssh-rsa-cert-v01@openssh.com {certificate_contents}")

return cert_file
Expand All @@ -220,7 +269,7 @@ def _get_modulus_exponent(public_key_file):
if not os.path.isfile(public_key_file):
raise azclierror.FileOperationError(f"Public key file '{public_key_file}' was not found")

with open(public_key_file, 'r') as f:
with open(public_key_file, 'r', encoding='utf-8') as f:
public_key_text = f.read()

parser = rsa_parser.RSAParser()
Expand Down
2 changes: 2 additions & 0 deletions src/ssh/azext_ssh/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def mkdir_p(path):


def delete_file(file_path, message, warning=False):
# pylint: disable=broad-except
if os.path.isfile(file_path):
try:
os.remove(file_path)
Expand All @@ -39,6 +40,7 @@ def delete_file(file_path, message, warning=False):


def delete_folder(dir_path, message, warning=False):
# pylint: disable=broad-except
if os.path.isdir(dir_path):
try:
os.rmdir(dir_path)
Expand Down
95 changes: 95 additions & 0 deletions src/ssh/azext_ssh/ssh_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import os
from azure.cli.core import azclierror


class SSHSession():
# pylint: disable=too-many-instance-attributes
def __init__(self, resource_group_name, vm_name, ssh_ip, public_key_file, private_key_file,
use_private_ip, local_user, cert_file, port, ssh_client_folder, ssh_args):
self.resource_group_name = resource_group_name
self.vm_name = vm_name
self.ip = ssh_ip
self.use_private_ip = use_private_ip
self.local_user = local_user
self.port = port
self.ssh_args = ssh_args
self.public_key_file = os.path.abspath(public_key_file) if public_key_file else None
self.private_key_file = os.path.abspath(private_key_file) if private_key_file else None
self.cert_file = os.path.abspath(cert_file) if cert_file else None
self.ssh_client_folder = os.path.abspath(ssh_client_folder) if ssh_client_folder else None

def get_host(self):
if self.local_user and self.ip:
return self.local_user + "@" + self.ip
raise azclierror.BadRequestError("Unable to determine host.")

def build_args(self):
private_key = []
port_arg = []
certificate = []
if self.private_key_file:
private_key = ["-i", self.private_key_file]
if self.port:
port_arg = ["-p", self.port]
if self.cert_file:
certificate = ["-o", "CertificateFile=\"" + self.cert_file + "\""]
return private_key + certificate + port_arg


class ConfigSession():
# pylint: disable=too-few-public-methods
# pylint: disable=too-many-instance-attributes
def __init__(self, config_path, resource_group_name, vm_name, ssh_ip, public_key_file,
private_key_file, overwrite, use_private_ip, local_user, cert_file, port,
ssh_client_folder):
self.config_path = os.path.abspath(config_path)
self.resource_group_name = resource_group_name
self.vm_name = vm_name
self.ip = ssh_ip
self.overwrite = overwrite
self.use_private_ip = use_private_ip
self.local_user = local_user
self.port = port
self.public_key_file = os.path.abspath(public_key_file) if public_key_file else None
self.private_key_file = os.path.abspath(private_key_file) if private_key_file else None
self.cert_file = os.path.abspath(cert_file) if cert_file else None
self.ssh_client_folder = os.path.abspath(ssh_client_folder) if ssh_client_folder else None

def get_config_text(self):
lines = [""]
if self.resource_group_name and self.vm_name and self.ip:
lines = lines + self._get_rg_and_vm_entry()
# default to all hosts for config
if not self.ip:
self.ip = "*"
lines = lines + self._get_ip_entry()
return lines

def _get_rg_and_vm_entry(self):
lines = []
lines.append("Host " + self.resource_group_name + "-" + self.vm_name)
lines.append("\tUser " + self.local_user)
lines.append("\tHostName " + self.ip)
if self.cert_file:
lines.append("\tCertificateFile \"" + self.cert_file + "\"")
if self.private_key_file:
lines.append("\tIdentityFile \"" + self.private_key_file + "\"")
if self.port:
lines.append("\tPort " + self.port)
return lines

def _get_ip_entry(self):
lines = []
lines.append("Host " + self.ip)
lines.append("\tUser " + self.local_user)
if self.cert_file:
lines.append("\tCertificateFile \"" + self.cert_file + "\"")
if self.private_key_file:
lines.append("\tIdentityFile \"" + self.private_key_file + "\"")
if self.port:
lines.append("\tPort " + self.port)
return lines
Loading

0 comments on commit 4e4b9c2

Please sign in to comment.