Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixing SSHHook bug when using allow_host_key_change param #24116

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 6 additions & 3 deletions airflow/providers/ssh/hooks/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,17 +266,16 @@ def get_conn(self) -> paramiko.SSHClient:
self.log.debug('Creating SSH client for conn_id: %s', self.ssh_conn_id)
client = paramiko.SSHClient()

if not self.allow_host_key_change:
if self.allow_host_key_change:
self.log.warning(
"Remote Identification Change is not verified. "
"This won't protect against Man-In-The-Middle attacks"
)
else:
client.load_system_host_keys()

if self.no_host_key_check:
self.log.warning("No Host Key Verification. This won't protect against Man-In-The-Middle attacks")
# Default is RejectPolicy
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
else:
if self.host_key is not None:
client_host_keys = client.get_host_keys()
Expand All @@ -289,6 +288,10 @@ def get_conn(self) -> paramiko.SSHClient:
else:
pass # will fallback to system host keys if none explicitly specified in conn extra

if self.no_host_key_check or self.allow_host_key_change:
# Default is RejectPolicy
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

connect_kwargs: Dict[str, Any] = dict(
hostname=self.remote_host,
username=self.username,
Expand Down
48 changes: 48 additions & 0 deletions tests/providers/ssh/hooks/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ class TestSSHHook(unittest.TestCase):
CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE = 'ssh_with_host_key_and_no_host_key_check_false'
CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE = 'ssh_with_host_key_and_no_host_key_check_true'
CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE = 'ssh_with_no_host_key_and_no_host_key_check_false'
CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE = 'ssh_with_no_host_key_and_no_host_key_check_true'
CONN_SSH_WITH_HOST_KEY_AND_ALLOW_HOST_KEY_CHANGES_TRUE = (
'ssh_with_host_key_and_allow_host_key_changes_true'
)

@classmethod
def tearDownClass(cls) -> None:
Expand All @@ -110,6 +114,7 @@ def tearDownClass(cls) -> None:
cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
cls.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE,
cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_FALSE,
cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE,
]
connections = session.query(Connection).filter(Connection.conn_id.in_(conns_to_reset))
connections.delete(synchronize_session=False)
Expand Down Expand Up @@ -236,6 +241,28 @@ def setUpClass(cls) -> None:
extra=json.dumps({"private_key": TEST_PRIVATE_KEY, "no_host_key_check": False}),
)
)
db.merge_conn(
Connection(
conn_id=cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE,
host='remote_host',
conn_type='ssh',
extra=json.dumps({"private_key": TEST_PRIVATE_KEY, "no_host_key_check": True}),
)
)
db.merge_conn(
Connection(
conn_id=cls.CONN_SSH_WITH_HOST_KEY_AND_ALLOW_HOST_KEY_CHANGES_TRUE,
host='remote_host',
conn_type='ssh',
extra=json.dumps(
{
"private_key": TEST_PRIVATE_KEY,
"host_key": TEST_HOST_KEY,
"allow_host_key_change": True,
}
),
)
)

@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_connection_with_password(self, ssh_mock):
Expand Down Expand Up @@ -522,6 +549,27 @@ def test_ssh_connection_with_no_host_key_where_no_host_key_check_is_false(self,
assert ssh_client.return_value.connect.called is True
assert ssh_client.return_value.get_host_keys.return_value.add.called is False

def test_ssh_connection_with_host_key_where_no_host_key_check_is_true(self):
with pytest.raises(ValueError):
SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE)

@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_connection_with_no_host_key_where_no_host_key_check_is_true(self, ssh_client):
hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE)
assert hook.host_key is None
with hook.get_conn():
assert ssh_client.return_value.connect.called is True
assert ssh_client.return_value.set_missing_host_key_policy.called is True

@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_connection_with_host_key_where_allow_host_key_change_is_true(self, ssh_client):
hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_HOST_KEY_AND_ALLOW_HOST_KEY_CHANGES_TRUE)
assert hook.host_key is not None
with hook.get_conn():
assert ssh_client.return_value.connect.called is True
assert ssh_client.return_value.load_system_host_keys.called is False
assert ssh_client.return_value.set_missing_host_key_policy.called is True

@mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
def test_ssh_connection_with_conn_timeout(self, ssh_mock):
hook = SSHHook(
Expand Down