Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SSH] Add SSHARC support to ARM64 architecture #7338

Merged
merged 11 commits into from
Mar 6, 2024
Merged
4 changes: 4 additions & 0 deletions src/ssh/HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
Release History
===============
upcoming
-----
* Add support to ARM64 clients when connecting to Arc Machines. Connect proxy now available for ARM64 architecture.

2.0.2
-----
* [Bug Fix] Fix logic that checks for the OS of the target machine to avoid "cannot unpack non-iterable NoneType object" error
Expand Down
1 change: 1 addition & 0 deletions src/ssh/azext_ssh/_process_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

# pylint: disable=too-few-public-methods
# pylint: disable=consider-using-with
# pylint: disable=superfluous-parens

import subprocess
from ctypes import WinDLL, c_int, c_size_t, Structure, WinError, sizeof, pointer
Expand Down
46 changes: 44 additions & 2 deletions src/ssh/azext_ssh/connectivity_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ def get_client_side_proxy(arc_proxy_folder):
os.chmod(install_location, os.stat(install_location).st_mode | stat.S_IXUSR)
print_styled_text((Style.SUCCESS, f"SSH Client Proxy saved to {install_location}"))

_download_proxy_license()

return install_location


Expand All @@ -235,7 +237,9 @@ def _get_proxy_filename_and_url(arc_proxy_folder):
logger.debug("Platform OS: %s", operating_system)
logger.debug("Platform architecture: %s", machine)

if machine.endswith('64') and 'ARM' not in machine.upper():
if "arm64" in machine.lower() or "aarch64" in machine.lower():
architecture = 'arm64'
elif machine.endswith('64'):
architecture = 'amd64'
elif machine.endswith('86'):
architecture = '386'
Expand All @@ -244,7 +248,7 @@ def _get_proxy_filename_and_url(arc_proxy_folder):
else:
raise azclierror.BadRequestError(f"Unsuported architecture: {machine} is not currently supported")

# define the request url and install location based on the os and architecture
# define the request url and install location based on the os and architecture.
proxy_name = f"sshProxy_{operating_system.lower()}_{architecture}"
request_uri = (f"{consts.CLIENT_PROXY_STORAGE_URL}/{consts.CLIENT_PROXY_RELEASE}"
f"/{proxy_name}_{consts.CLIENT_PROXY_VERSION}")
Expand All @@ -268,6 +272,44 @@ def _get_proxy_filename_and_url(arc_proxy_folder):
return request_uri, install_location, older_location


def _download_proxy_license():
proxy_dir = os.path.join('~', ".clientsshproxy")
license_uri = f"{consts.CLIENT_PROXY_STORAGE_URL}/{consts.CLIENT_PROXY_RELEASE}/LICENSE.txt"
license_install_location = os.path.expanduser(os.path.join(proxy_dir, "LICENSE.txt"))

notice_uri = f"{consts.CLIENT_PROXY_STORAGE_URL}/{consts.CLIENT_PROXY_RELEASE}/ThirdPartyNotice.txt"
notice_install_location = os.path.expanduser(os.path.join(proxy_dir, "ThirdPartyNotice.txt"))

_get_and_write_proxy_license_files(license_uri, license_install_location, "License")
_get_and_write_proxy_license_files(notice_uri, notice_install_location, "Third Party Notice")


def _get_and_write_proxy_license_files(uri, install_location, target_name):
try:
license_content = _download_from_uri(uri)
file_utils.write_to_file(file_path=install_location,
mode='wb',
content=license_content,
error_message=f"Failed to create {target_name} file at {install_location}.")
# pylint: disable=broad-except
except Exception:
logger.warning("Failed to download Connection Proxy %s file from %s.", target_name, uri)

print_styled_text((Style.SUCCESS, f"SSH Connection Proxy {target_name} saved to {install_location}."))


def _download_from_uri(request_uri):
response_content = None
with urllib.request.urlopen(request_uri) as response:
response_content = response.read()
response.close()

if response_content is None:
raise azclierror.ClientRequestError(f"Failed to download file from {request_uri}")

return response_content


def format_relay_info_string(relay_info):
relay_info_string = json.dumps(
{
Expand Down
4 changes: 2 additions & 2 deletions src/ssh/azext_ssh/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

AGENT_MINIMUM_VERSION_MAJOR = 1
AGENT_MINIMUM_VERSION_MINOR = 31
CLIENT_PROXY_VERSION = "1.3.022941"
CLIENT_PROXY_RELEASE = "release21-04-23"
CLIENT_PROXY_VERSION = "1.3.026031"
CLIENT_PROXY_RELEASE = "release17-02-24"
CLIENT_PROXY_STORAGE_URL = "https://sshproxysa.blob.core.windows.net"
CLEANUP_TOTAL_TIME_LIMIT_IN_SECONDS = 120
CLEANUP_TIME_INTERVAL_IN_SECONDS = 10
Expand Down
1 change: 0 additions & 1 deletion src/ssh/azext_ssh/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def ssh_vm(cmd, resource_group_name=None, vm_name=None, ssh_ip=None, public_key_
if platform.system() != 'Windows':
raise azclierror.BadRequestError("RDP connection is not supported for this platform. "
"Supported platforms: Windows")
logger.warning("RDP feature is in preview.")
op_call = rdp_utils.start_rdp_connection

ssh_session = ssh_info.SSHSession(resource_group_name, vm_name, ssh_ip, public_key_file,
Expand Down
2 changes: 2 additions & 0 deletions src/ssh/azext_ssh/rdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

# pylint: disable=used-before-assignment

import os
import platform
import subprocess
Expand Down
2 changes: 1 addition & 1 deletion src/ssh/azext_ssh/ssh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def _check_for_known_errors(error_message, delete_cert, log_lines):

def check_for_service_config_delay_error(error_message):
service_config_delay_error = False
regex = ("{\"level\":\"fatal\",\"msg\":\"sshproxy: error connecting to the address: 404 Endpoint does not exist.*")
regex = "{\"level\":\"fatal\",\"msg\":\"sshproxy: error connecting to the address: 404 Endpoint does not exist.*"
if re.search(regex, error_message):
service_config_delay_error = True
return service_config_delay_error
Expand Down
59 changes: 28 additions & 31 deletions src/ssh/azext_ssh/tests/latest/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,33 @@ class SshCustomCommandTest(unittest.TestCase):
@mock.patch('azext_ssh.custom._assert_args')
@mock.patch('azext_ssh.ssh_info.SSHSession')
@mock.patch('azext_ssh.resource_type_utils.decide_resource_type')
def test_ssh_vm(self, mock_type, mock_info, mock_assert, mock_do_op):
def test_ssh_vm(self, mock_type, mock_info, mock_assert, mock_do_op):
cmd = mock.Mock()
ssh_info = mock.Mock()
mock_info.return_value = ssh_info

cmd.cli_ctx.data = {'safe_params': []}

custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", False, "type", "proxy", False, False, ['-vvv'])

mock_info.assert_called_once_with("rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", ['-vvv'], False, "type", "proxy", None, False, False)
mock_assert.assert_called_once_with("rg", "vm", "ip", "type", "cert", "username")
mock_type.assert_called_once_with(cmd, ssh_info)
mock_do_op.assert_called_once_with(cmd, ssh_info, ssh_utils.start_ssh_connection)

@mock.patch('azext_ssh.custom._do_ssh_op')
@mock.patch('azext_ssh.custom._assert_args')
@mock.patch('azext_ssh.ssh_info.SSHSession')
@mock.patch('azext_ssh.resource_type_utils.decide_resource_type')
@mock.patch('platform.system')
def test_ssh_vm_rdp(self, mock_sys, mock_type, mock_info, mock_assert, mock_do_op):
def test_ssh_vm_rdp(self, mock_sys, mock_type, mock_info, mock_assert, mock_do_op):
cmd = mock.Mock()
ssh_info = mock.Mock()
mock_info.return_value = ssh_info
mock_sys.return_value = 'Windows'

cmd.cli_ctx.data = {'safe_params': []}

custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", False, "type", "proxy", True, False, ['-vvv'])

mock_info.assert_called_once_with("rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", ['-vvv'], False, "type", "proxy", None, True, False)
Expand All @@ -57,13 +57,13 @@ def test_ssh_vm_rdp(self, mock_sys, mock_type, mock_info, mock_assert, mock_do_o
@mock.patch('azext_ssh.custom._assert_args')
@mock.patch('azext_ssh.ssh_info.SSHSession')
@mock.patch('azext_ssh.resource_type_utils.decide_resource_type')
def test_ssh_vm_debug(self, mock_type, mock_info, mock_assert, mock_do_op):
def test_ssh_vm_debug(self, mock_type, mock_info, mock_assert, mock_do_op):
cmd = mock.Mock()
ssh_info = mock.Mock()
mock_info.return_value = ssh_info

cmd.cli_ctx.data = {'safe_params': ['--debug']}

custom.ssh_vm(cmd, "rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", False, "type", "proxy", False, False, [])

mock_info.assert_called_once_with("rg", "vm", "ip", "public", "private", False, "username", "cert", "port", "ssh_folder", ['-vvv'], False, "type", "proxy", None, False, False)
Expand All @@ -89,7 +89,7 @@ def test_ssh_vm_delete_credentials_cloudshell(self, mock_info, mock_assert, mock
mock_assert.assert_called_once_with("rg", "vm", "ip", "type", "cert", "username")
mock_type.assert_called_once_with(cmd, ssh_info)
mock_op.assert_called_once_with(cmd, ssh_info, ssh_utils.start_ssh_connection)

@mock.patch('os.environ.get')
def test_delete_credentials_not_cloudshell(self, mock_getenv):
mock_getenv.return_value = None
Expand All @@ -106,7 +106,7 @@ def test_delete_credentials_not_cloudshell(self, mock_getenv):
@mock.patch('os.path.join')
def test_ssh_config_no_cred_folder(self, mock_join, mock_info, mock_isdir, mock_dirname, mock_type, mock_do_op, mock_assert):
cmd = mock.Mock()

config_info = mock.Mock()
config_info.ip = "ip"
config_info.resource_group_name = "rg"
Expand All @@ -130,7 +130,7 @@ def test_ssh_config_no_cred_folder(self, mock_join, mock_info, mock_isdir, mock_
mock_type.assert_called_once_with(cmd, config_info)
mock_assert.assert_called_once_with("rg", "vm", "ip", "type", "cert", "user")
mock_do_op.assert_called_once_with(cmd, config_info, ssh_utils.write_ssh_config)

@mock.patch('azext_ssh.custom._assert_args')
@mock.patch('azext_ssh.custom._do_ssh_op')
@mock.patch('azext_ssh.resource_type_utils.decide_resource_type')
Expand All @@ -145,13 +145,12 @@ def test_ssh_config_cred_folder(self, mock_dirname, mock_isdir, mock_info, mock_
mock_dirname.return_value = "config_folder"

custom.ssh_config(cmd, "config", "rg", "vm", "ip", None, None, True, False, "user", "cert", "port", "type", "cred", "proxy", "client", False)

mock_type.assert_called_once_with(cmd, config_info)
mock_info.assert_called_once_with("config", "rg", "vm", "ip", None, None, True, False, "user", "cert", "port", "type", "cred", "proxy", "client", False)
mock_assert.assert_called_once_with("rg", "vm", "ip", "type", "cert", "user")
mock_op.assert_called_once_with(cmd, config_info, ssh_utils.write_ssh_config)


def test_ssh_config_credentials_folder_and_key(self):
cmd = mock.Mock()
self.assertRaises(
Expand All @@ -169,7 +168,7 @@ def test_ssh_cert_no_args(self):
cmd = mock.Mock()
self.assertRaises(
azclierror.RequiredArgumentMissingError, custom.ssh_cert, cmd)

@mock.patch('os.path.isdir')
def test_ssh_cert_cert_file_missing(self, mock_isdir):
cmd = mock.Mock()
Expand All @@ -187,15 +186,15 @@ def test_ssh_cert(self, mock_write_cert, mock_get_keys, mock_abspath, mock_isdir
mock_abspath.side_effect = ['/pubkey/path', '/cert/path', '/client/path']
mock_get_keys.return_value = "pubkey", "privkey", False
mock_write_cert.return_value = "cert", "username"

custom.ssh_cert(cmd, "cert", "pubkey", "ssh/folder")

mock_get_keys.assert_called_once_with('/pubkey/path', None, None, '/client/path')
mock_write_cert.assert_called_once_with(cmd, 'pubkey', '/cert/path', '/client/path')

def test_assert_args_invalid_resource_type(self):
self.assertRaises(azclierror.InvalidArgumentValueError, custom._assert_args, 'rg', 'vm', 'ip', "Microsoft.Network", 'cert', 'user')

def test_assert_args_no_ip_or_vm(self):
self.assertRaises(azclierror.RequiredArgumentMissingError, custom._assert_args, None, None, None, None, None, None)

Expand All @@ -207,7 +206,7 @@ def test_assert_args_ip_with_vm_or_rg(self):
self.assertRaises(azclierror.MutuallyExclusiveArgumentError, custom._assert_args, None, "vm", "ip", None, None, None)
self.assertRaises(azclierror.MutuallyExclusiveArgumentError, custom._assert_args, "rg", None, "ip", None, None, None)
self.assertRaises(azclierror.MutuallyExclusiveArgumentError, custom._assert_args, "rg", "vm", "ip", None, None, None)

def test_assert_args_cert_with_no_user(self):
self.assertRaises(azclierror.MutuallyExclusiveArgumentError, custom._assert_args, None, None, "ip", None, "certificate", None)

Expand Down Expand Up @@ -247,7 +246,7 @@ def test_check_or_create_public_private_files_defaults(self, mock_join, mock_isf
@mock.patch('os.path.isdir')
@mock.patch('os.path.isfile')
@mock.patch('os.path.join')
def test_check_or_create_public_private_files_defaults_with_cred_folder(self,mock_join, mock_isfile, mock_isdir, mock_create):
def test_check_or_create_public_private_files_defaults_with_cred_folder(self, mock_join, mock_isfile, mock_isdir, mock_create):
mock_isfile.return_value = True
mock_isdir.return_value = True
mock_join.side_effect = ['/cred/folder/id_rsa.pub', '/cred/folder/id_rsa']
Expand All @@ -266,7 +265,7 @@ def test_check_or_create_public_private_files_defaults_with_cred_folder(self,moc
mock_create.assert_has_calls([
mock.call('/cred/folder/id_rsa', '/ssh/client')
])

@mock.patch('os.path.isfile')
def test_check_or_create_public_private_files_no_public(self, mock_isfile):
mock_isfile.side_effect = [False]
Expand All @@ -284,7 +283,6 @@ def test_check_or_create_public_private_files_no_public(self, mock_isfile):
self.assertEqual(delete, False)
mock_isfile.assert_has_calls([mock.call("key.pub"), mock.call('key')])


@mock.patch('os.path.isfile')
@mock.patch('os.path.join')
def test_check_or_create_public_private_files_no_private(self, mock_join, mock_isfile):
Expand All @@ -298,7 +296,6 @@ def test_check_or_create_public_private_files_no_private(self, mock_join, mock_i
mock.call("public"),
mock.call("private")
])


@mock.patch('builtins.open')
@mock.patch('oschmod.set_mode')
Expand All @@ -311,7 +308,7 @@ def test_write_cert_file(self, mock_mode, mock_open):
mock_mode.assert_called_once_with("publickey-aadcert.pub", 0o644)
mock_open.assert_called_once_with("publickey-aadcert.pub", 'w', encoding='utf-8')
mock_file.write.assert_called_once_with("ssh-rsa-cert-v01@openssh.com cert")

@mock.patch('azext_ssh.rsa_parser.RSAParser')
@mock.patch('os.path.isfile')
@mock.patch('builtins.open')
Expand Down Expand Up @@ -345,7 +342,7 @@ def test_get_modulus_exponent_parse_error(self, mock_open, mock_isfile, mock_par
mock_parser_obj.parse.side_effect = ValueError

self.assertRaises(azclierror.FileOperationError, custom._get_modulus_exponent, 'file')

@mock.patch('azext_ssh.ssh_utils.get_ssh_cert_principals')
@mock.patch('os.path.join')
@mock.patch('azext_ssh.custom._check_or_create_public_private_files')
Expand Down Expand Up @@ -373,9 +370,9 @@ def test_do_ssh_op_aad_user_compute(self, mock_write_cert, mock_ssh_creds, mock_
profile._adal_cache = True
profile.get_msal_token.return_value = "username", "certificate"
mock_join.return_value = "public-aadcert.pub"

custom._do_ssh_op(cmd, op_info, mock_op)

mock_check_files.assert_called_once_with("publicfile", "privatefile", None, "/client/folder")
mock_ip.assert_not_called()
mock_get_mod_exp.assert_called_once_with("public")
Expand All @@ -393,7 +390,7 @@ def test_do_ssh_op_local_user_compute(self, mock_ip, mock_check_files):
op_info.public_key_file = "publicfile"
op_info.private_key_file = "privatefile"
op_info.cert_file = "cert"
op_info.ssh_client_folder = "/client/folder"
op_info.ssh_client_folder = "/client/folder"

custom._do_ssh_op(cmd, op_info, mock_op)

Expand All @@ -416,15 +413,15 @@ def test_do_ssh_op_no_public_ip(self, mock_ip, mock_check_files):
mock_check_files.assert_not_called()
mock_ip.assert_called_once_with(cmd, "rg", "vm", False)
mock_op.assert_not_called()

@mock.patch('azext_ssh.connectivity_utils.get_client_side_proxy')
@mock.patch('azext_ssh.connectivity_utils.get_relay_information')
@mock.patch('azext_ssh.ssh_utils.start_ssh_connection')
@mock.patch('azext_ssh.custom._check_or_create_public_private_files')
@mock.patch('azext_ssh.custom._get_and_write_certificate')
def test_do_ssh_op_arc_local_user(self, mock_get_cert, mock_check_keys, mock_start_ssh, mock_get_relay_info, mock_get_proxy):
mock_get_relay_info.return_value = ('relay', False)
cmd = mock.Mock()
cmd = mock.Mock()
mock_op = mock.Mock()

op_info = ssh_info.SSHSession("rg", "vm", None, None, None, False, "user", None, "port", None, [], False, "Microsoft.HybridCompute/machines", None, None, False, False)
Expand All @@ -434,13 +431,13 @@ def test_do_ssh_op_arc_local_user(self, mock_get_cert, mock_check_keys, mock_sta
op_info.ssh_proxy_folder = "proxy"

custom._do_ssh_op(cmd, op_info, mock_op)

mock_get_proxy.assert_called_once_with('proxy')
mock_get_relay_info.assert_called_once_with(cmd, 'rg', 'vm', 'Microsoft.HybridCompute/machines', None, "port", False)
mock_op.assert_called_once_with(op_info, False, False)
mock_get_cert.assert_not_called()
mock_check_keys.assert_not_called()

@mock.patch('azext_ssh.connectivity_utils.get_client_side_proxy')
@mock.patch('azext_ssh.custom.connectivity_utils.get_relay_information')
@mock.patch('azext_ssh.ssh_utils.get_ssh_cert_principals')
Expand All @@ -451,7 +448,7 @@ def test_do_ssh_op_arc_local_user(self, mock_get_cert, mock_check_keys, mock_sta
@mock.patch('azext_ssh.custom._write_cert_file')
@mock.patch('azext_ssh.ssh_utils.start_ssh_connection')
@mock.patch('azext_ssh.ssh_utils.get_certificate_lifetime')
def test_do_ssh_arc_op_aad_user(self, mock_cert_exp, mock_start_ssh, mock_write_cert, mock_ssh_creds, mock_get_mod_exp, mock_check_files,
def test_do_ssh_arc_op_aad_user(self, mock_cert_exp, mock_start_ssh, mock_write_cert, mock_ssh_creds, mock_get_mod_exp, mock_check_files,
mock_join, mock_principal, mock_get_relay_info, mock_get_proxy):

mock_get_proxy.return_value = '/path/to/proxy'
Expand Down Expand Up @@ -489,6 +486,6 @@ def test_do_ssh_arc_op_aad_user(self, mock_cert_exp, mock_start_ssh, mock_write_
mock_get_proxy.assert_called_once_with('proxy')
mock_get_relay_info.assert_called_once_with(cmd, 'rg', 'vm', 'Microsoft.HybridCompute/machines', 3600, 'port', False)
mock_op.assert_called_once_with(op_info, False, True)

if __name__ == '__main__':
unittest.main()
Loading