diff --git a/airflow/providers/ssh/hooks/ssh.py b/airflow/providers/ssh/hooks/ssh.py index 88673ae371c76..f3f7d11a57ea2 100644 --- a/airflow/providers/ssh/hooks/ssh.py +++ b/airflow/providers/ssh/hooks/ssh.py @@ -266,17 +266,16 @@ def get_conn(self) -> paramiko.SSHClient: self.log.debug('Creating SSH client for conn_id: %s', self.ssh_conn_id) client = paramiko.SSHClient() - if not self.allow_host_key_change: + if self.allow_host_key_change: self.log.warning( "Remote Identification Change is not verified. " "This won't protect against Man-In-The-Middle attacks" ) + else: client.load_system_host_keys() if self.no_host_key_check: self.log.warning("No Host Key Verification. This won't protect against Man-In-The-Middle attacks") - # Default is RejectPolicy - client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) else: if self.host_key is not None: client_host_keys = client.get_host_keys() @@ -289,6 +288,10 @@ def get_conn(self) -> paramiko.SSHClient: else: pass # will fallback to system host keys if none explicitly specified in conn extra + if self.no_host_key_check or self.allow_host_key_change: + # Default is RejectPolicy + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + connect_kwargs: Dict[str, Any] = dict( hostname=self.remote_host, username=self.username, diff --git a/tests/providers/ssh/hooks/test_ssh.py b/tests/providers/ssh/hooks/test_ssh.py index c248ebf45d4bf..7362ed918dee8 100644 --- a/tests/providers/ssh/hooks/test_ssh.py +++ b/tests/providers/ssh/hooks/test_ssh.py @@ -92,6 +92,10 @@ class TestSSHHook(unittest.TestCase): CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE = 'ssh_with_host_key_and_no_host_key_check_false' CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE = 'ssh_with_host_key_and_no_host_key_check_true' CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE = 'ssh_with_no_host_key_and_no_host_key_check_false' + CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE = 'ssh_with_no_host_key_and_no_host_key_check_true' + CONN_SSH_WITH_HOST_KEY_AND_ALLOW_HOST_KEY_CHANGES_TRUE = ( + 'ssh_with_host_key_and_allow_host_key_changes_true' + ) @classmethod def tearDownClass(cls) -> None: @@ -110,6 +114,7 @@ def tearDownClass(cls) -> None: cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE, cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE, cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE, + cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE, ] connections = session.query(Connection).filter(Connection.conn_id.in_(conns_to_reset)) connections.delete(synchronize_session=False) @@ -236,6 +241,28 @@ def setUpClass(cls) -> None: extra=json.dumps({"private_key": TEST_PRIVATE_KEY, "no_host_key_check": False}), ) ) + db.merge_conn( + Connection( + conn_id=cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE, + host='remote_host', + conn_type='ssh', + extra=json.dumps({"private_key": TEST_PRIVATE_KEY, "no_host_key_check": True}), + ) + ) + db.merge_conn( + Connection( + conn_id=cls.CONN_SSH_WITH_HOST_KEY_AND_ALLOW_HOST_KEY_CHANGES_TRUE, + host='remote_host', + conn_type='ssh', + extra=json.dumps( + { + "private_key": TEST_PRIVATE_KEY, + "host_key": TEST_HOST_KEY, + "allow_host_key_change": True, + } + ), + ) + ) @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient') def test_ssh_connection_with_password(self, ssh_mock): @@ -522,6 +549,27 @@ def test_ssh_connection_with_no_host_key_where_no_host_key_check_is_false(self, assert ssh_client.return_value.connect.called is True assert ssh_client.return_value.get_host_keys.return_value.add.called is False + def test_ssh_connection_with_host_key_where_no_host_key_check_is_true(self): + with pytest.raises(ValueError): + SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE) + + @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient') + def test_ssh_connection_with_no_host_key_where_no_host_key_check_is_true(self, ssh_client): + hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE) + assert hook.host_key is None + with hook.get_conn(): + assert ssh_client.return_value.connect.called is True + assert ssh_client.return_value.set_missing_host_key_policy.called is True + + @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient') + def test_ssh_connection_with_host_key_where_allow_host_key_change_is_true(self, ssh_client): + hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_ALLOW_HOST_KEY_CHANGES_TRUE) + assert hook.host_key is not None + with hook.get_conn(): + assert ssh_client.return_value.connect.called is True + assert ssh_client.return_value.load_system_host_keys.called is False + assert ssh_client.return_value.set_missing_host_key_policy.called is True + @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient') def test_ssh_connection_with_conn_timeout(self, ssh_mock): hook = SSHHook(