diff --git a/lib/cloudregister/registerutils.py b/lib/cloudregister/registerutils.py index 173b499a..1e7157d7 100644 --- a/lib/cloudregister/registerutils.py +++ b/lib/cloudregister/registerutils.py @@ -54,7 +54,7 @@ def add_hosts_entry(smt_server): smt_hosts_entry_comment = '\n# Added by SMT registration do not remove, ' smt_hosts_entry_comment += 'retain comment as well\n' smt_ip = smt_server.get_ipv4() - if has_ipv6_access(smt_server): + if has_rmt_ipv6_access(smt_server): smt_ip = smt_server.get_ipv6() entry = '%s\t%s\t%s\n' % ( smt_ip, @@ -271,7 +271,7 @@ def fetch_smt_data(cfg, proxies, quiet=False): region_servers_ipv4.append(ip_addr) random.shuffle(region_servers_ipv4) random.shuffle(region_servers_ipv6) - if socket.has_ipv6: + if has_ipv6_access(): region_servers = region_servers_ipv6 + region_servers_ipv4 else: region_servers = region_servers_ipv4 @@ -894,10 +894,10 @@ def get_zypper_target_root(): # ---------------------------------------------------------------------------- -def has_ipv6_access(smt): +def has_rmt_ipv6_access(smt): """IPv6 access is possible if we have an SMT server that has an IPv6 address and it can be accessed over IPv6""" - if not smt.get_ipv6(): + if not has_ipv6_access() or not smt.get_ipv6(): return False logging.info('Attempt to access update server over IPv6') protocol = 'http' # Default for backward compatibility @@ -1442,6 +1442,31 @@ def write_framework_identifier(cfg): framework_file.write(json.dumps(identifier)) +# ---------------------------------------------------------------------------- +def has_ipv4_access(): + """Check if we have IPv4 network configuration""" + return has_network_access_by_ip_address('8.8.8.8') + + +# ---------------------------------------------------------------------------- +def has_ipv6_access(): + """Check if we have IPv6 network configuration""" + return has_network_access_by_ip_address('2001:4860:4860::8888') + + +# ---------------------------------------------------------------------------- +def has_network_access_by_ip_address(server_ip): + """Check if we can connect to the given server""" + try: + connection = socket.create_connection((server_ip, 443), timeout=2) + except OSError as e: + logging.info('Network access error: "%s"', e) + return False + + connection.close() + return True + + # Private # ---------------------------------------------------------------------------- def __get_framework_plugin(cfg): diff --git a/tests/test_registerutils.py b/tests/test_registerutils.py index f4829ff3..aab3474f 100644 --- a/tests/test_registerutils.py +++ b/tests/test_registerutils.py @@ -542,8 +542,8 @@ def test_clean_host_file_raised_exception(): assert m().write.mock_calls == [] -@patch('cloudregister.registerutils.has_ipv6_access') -def test_add_hosts_entry(mock_has_ipv6_access): +@patch('cloudregister.registerutils.has_rmt_ipv6_access') +def test_add_hosts_entry(mock_has_rmt_ipv6_access): """Test hosts entry has a new entry added by us.""" smt_data_ipv46 = dedent('''\ ''') smt_server = SMT(etree.fromstring(smt_data_ipv46)) - mock_has_ipv6_access.return_value = True + mock_has_rmt_ipv6_access.return_value = True with patch('builtins.open', create=True) as mock_open: mock_open.return_value = MagicMock(spec=io.IOBase) file_handle = mock_open.return_value.__enter__.return_value utils.add_hosts_entry(smt_server) - mock_open.assert_called_once_with('/etc/hosts', 'a') - file_content_comment = ( - '\n# Added by SMT registration do not remove, ' - 'retain comment as well\n' - ) - file_content_entry = '{ip}\t{fqdn}\t{name}\n'.format( - ip=smt_server.get_ipv6(), - fqdn=smt_server.get_FQDN(), - name=smt_server.get_name() - ) - assert file_handle.write.mock_calls == [ - call(file_content_comment), - call(file_content_entry) - ] + mock_open.assert_called_once_with('/etc/hosts', 'a') + file_content_comment = ( + '\n# Added by SMT registration do not remove, ' + 'retain comment as well\n' + ) + file_content_entry = '{ip}\t{fqdn}\t{name}\n'.format( + ip=smt_server.get_ipv6(), + fqdn=smt_server.get_FQDN(), + name=smt_server.get_name() + ) + assert file_handle.write.mock_calls == [ + call(file_content_comment), + call(file_content_entry) + ] @patch('cloudregister.amazonec2.generateRegionSrvArgs') @@ -747,17 +747,20 @@ def test_fetch_smt_data_metadata_server( etree.tostring(smt_server, encoding='utf-8') +@patch('cloudregister.registerutils.has_network_access_by_ip_address') @patch('cloudregister.registerutils.time.sleep') @patch('cloudregister.registerutils.logging') def test_fetch_smt_data_api_no_answer( mock_logging, - mock_time_sleep + mock_time_sleep, + mock_has_network_access ): cfg = get_test_config() del cfg['server']['metadata_server'] cfg.set('server', 'regionsrv', '1.1.1.1') with raises(SystemExit): utils.fetch_smt_data(cfg, None) + mock_has_network_access.return_value = False assert mock_logging.info.call_args_list == [ call('Using API: regionInfo'), call('Getting update server information, attempt 1'), @@ -788,7 +791,7 @@ def test_fetch_smt_data_api_no_answer( ] -@patch('cloudregister.registerutils.socket.has_ipv6', False) +@patch('cloudregister.registerutils.has_network_access_by_ip_address') @patch('cloudregister.registerutils.requests.get') @patch('cloudregister.registerutils.os.path.isfile') @patch('cloudregister.registerutils.time.sleep') @@ -798,6 +801,7 @@ def test_fetch_smt_data_api_answered( mock_time_sleep, mock_os_path_isfile, mock_request_get, + mock_has_network_access ): cfg = get_test_config() del cfg['server']['metadata_server'] @@ -815,6 +819,7 @@ def test_fetch_smt_data_api_answered( ''') response.text = smt_xml mock_request_get.return_value = response + mock_has_network_access.return_value = False utils.fetch_smt_data(cfg, None) assert mock_logging.info.call_args_list == [ call('Using API: regionInfo'), @@ -858,6 +863,7 @@ def test_fetch_smt_data_api_no_valid_ip( assert etree.tostring(smt_data, encoding='utf-8') == smt_xml.encode() +@patch('cloudregister.registerutils.has_network_access_by_ip_address') @patch('cloudregister.registerutils.requests.get') @patch('cloudregister.registerutils.os.path.isfile') @patch('cloudregister.registerutils.time.sleep') @@ -867,6 +873,7 @@ def test_fetch_smt_data_api_error_response( mock_time_sleep, mock_os_path_isfile, mock_request_get, + mock_has_network_access ): cfg = get_test_config() del cfg['server']['metadata_server'] @@ -876,8 +883,10 @@ def test_fetch_smt_data_api_error_response( response.status_code = 422 response.reason = 'well, you shall not pass' mock_request_get.return_value = response + mock_has_network_access.return_value = False with raises(SystemExit): utils.fetch_smt_data(cfg, None) + print(mock_logging.info.call_args_list) assert mock_logging.info.call_args_list == [ call('Using API: regionInfo'), call('Getting update server information, attempt 1'), @@ -910,6 +919,7 @@ def test_fetch_smt_data_api_error_response( ] +@patch('cloudregister.registerutils.has_network_access_by_ip_address') @patch('cloudregister.registerutils.requests.get') @patch('cloudregister.registerutils.os.path.isfile') @patch('cloudregister.registerutils.time.sleep') @@ -918,7 +928,8 @@ def test_fetch_smt_data_api_exception( mock_logging, mock_time_sleep, mock_os_path_isfile, - mock_request_get + mock_request_get, + mock_has_network_access ): cfg = get_test_config() del cfg['server']['metadata_server'] @@ -928,6 +939,7 @@ def test_fetch_smt_data_api_exception( response.status_code = 422 response.reason = 'well, you shall not pass' mock_request_get.side_effect = requests.exceptions.RequestException('foo') + mock_has_network_access.return_value = True with raises(SystemExit): utils.fetch_smt_data(cfg, None) assert mock_logging.info.call_args_list == [ @@ -956,6 +968,7 @@ def test_fetch_smt_data_api_exception( ] +@patch('cloudregister.registerutils.has_network_access_by_ip_address') @patch('cloudregister.registerutils.requests.get') @patch('cloudregister.registerutils.os.path.isfile') @patch('cloudregister.registerutils.time.sleep') @@ -964,7 +977,8 @@ def test_fetch_smt_data_api_exception_quiet( mock_logging, mock_time_sleep, mock_os_path_isfile, - mock_request_get + mock_request_get, + mock_has_network_access ): cfg = get_test_config() del cfg['server']['metadata_server'] @@ -974,6 +988,7 @@ def test_fetch_smt_data_api_exception_quiet( response.status_code = 422 response.reason = 'well, you shall not pass' mock_request_get.side_effect = requests.exceptions.RequestException('foo') + mock_has_network_access.return_value = True with raises(SystemExit): utils.fetch_smt_data(cfg, 'foo', quiet=True) assert mock_logging.info.call_args_list == [ @@ -1217,16 +1232,21 @@ def test_get_current_smt_no_match(mock_get_smt_from_store, mock_os_unlink): utils.get_current_smt() +@patch('cloudregister.registerutils.glob.glob') @patch('cloudregister.registerutils.get_smt_from_store') -def test_get_current_smt_no_registered(mock_get_smt_from_store): +def test_get_current_smt_no_registered( + mock_get_smt_from_store, mock_glob_glob +): smt_data_ipv46 = dedent('''\ ''') - smt_server = SMT(etree.fromstring(smt_data_ipv46)) - mock_get_smt_from_store.return_value = smt_server + mock_get_smt_from_store.return_value = SMT( + etree.fromstring(smt_data_ipv46) + ) + mock_glob_glob.return_value = [] hosts_content = """ # simulates hosts file containing the ipv4 we are looking for in the test @@ -1822,20 +1842,26 @@ def test_get_zypper_pid(mock_popen): assert utils.get_zypper_pid() == 'pid' -def test_has_ipv6_access_no_ipv6_defined(): +@patch('cloudregister.registerutils.has_ipv6_access') +def test_has_rmt_ipv6_access_no_ipv6_defined(mock_ipv6_access): smt_data_ipv4 = dedent('''\ ''') smt_server = SMT(etree.fromstring(smt_data_ipv4)) - assert utils.has_ipv6_access(smt_server) is False + mock_ipv6_access.return_value = True + assert utils.has_rmt_ipv6_access(smt_server) is False +@patch('cloudregister.registerutils.has_ipv6_access') @patch('cloudregister.registerutils.get_config') @patch('cloudregister.registerutils.requests.get') @patch('cloudregister.registerutils.https_only') -def test_has_ipv6_access_https(mock_https_only, mock_request, mock_get_config): +def test_has_rmt_ipv6_access_https( + mock_https_only, mock_request, + mock_get_config, mock_ipv6_access +): smt_data_ipv46 = dedent('''\