Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Fix CLI connections import and migrate logic from secrets to Connection model (#15425)

* Add field 'extra' to Connection init

* Fix connections import CLI

In connections_import, each connection was deserialized and stored into a
Connection model instance rather than a dictionary, so an erroneous call to the
dictionary methods .items() resulted in an AttributeError. With this fix,
connection information is loaded from dictionaries directly into the
Connection constructor and committed to the DB.

* Apply suggestions from code review

* Use load_connections_dict in connections import

Co-authored-by: Ash Berlin-Taylor <ash_github@firemirror.com>
  • Loading branch information
natanweinberger and ashb committed Jun 11, 2021
1 parent 7432c4d commit 002075a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 56 deletions.
27 changes: 5 additions & 22 deletions airflow/cli/commands/connection_command.py
Expand Up @@ -28,7 +28,7 @@
from airflow.exceptions import AirflowNotFoundException
from airflow.hooks.base import BaseHook
from airflow.models import Connection
from airflow.secrets.local_filesystem import _create_connection, load_connections_dict
from airflow.secrets.local_filesystem import load_connections_dict
from airflow.utils import cli as cli_utils, yaml
from airflow.utils.cli import suppress_logs_and_warning
from airflow.utils.session import create_session
Expand Down Expand Up @@ -238,39 +238,22 @@ def connections_delete(args):

@cli_utils.action_logging
def connections_import(args):
    """Import connections from a file into the Airflow metadata DB.

    :param args: parsed CLI arguments; ``args.file`` is the path to a
        JSON/YAML/env connections file.
    :raises SystemExit: if the given file does not exist.
    """
    # Fail fast with a clear message when the path is missing; parsing and
    # DB insertion are delegated to _import_helper.
    if os.path.exists(args.file):
        _import_helper(args.file)
    else:
        raise SystemExit("Missing connections file.")


def _import_helper(file_path):
    """Load connections from a file and save them to the DB. On collision, skip.

    :param file_path: path to a JSON/YAML/env file of connection definitions.
    """
    # load_connections_dict returns a mapping of conn_id -> Connection model
    # instance, so each value can be added to the session directly.
    connections_dict = load_connections_dict(file_path)
    with create_session() as session:
        for conn_id, conn in connections_dict.items():
            # Existing connections win: skip the import rather than overwrite.
            if session.query(Connection).filter(Connection.conn_id == conn_id).first():
                print(f'Could not import connection {conn_id}: connection already exists.')
                continue

            session.add(conn)
            # Commit per connection so one failure doesn't roll back the
            # connections already imported in this run.
            session.commit()
            print(f'Imported connection {conn_id}')
6 changes: 4 additions & 2 deletions airflow/models/connection.py
Expand Up @@ -19,7 +19,7 @@
import json
import warnings
from json import JSONDecodeError
from typing import Dict, Optional
from typing import Dict, Optional, Union
from urllib.parse import parse_qsl, quote, unquote, urlencode, urlparse

from sqlalchemy import Boolean, Column, Integer, String, Text
Expand Down Expand Up @@ -117,12 +117,14 @@ def __init__( # pylint: disable=too-many-arguments
password: Optional[str] = None,
schema: Optional[str] = None,
port: Optional[int] = None,
extra: Optional[str] = None,
extra: Optional[Union[str, dict]] = None,
uri: Optional[str] = None,
):
super().__init__()
self.conn_id = conn_id
self.description = description
if extra and not isinstance(extra, str):
extra = json.dumps(extra)
if uri and ( # pylint: disable=too-many-boolean-expressions
conn_type or host or login or password or schema or port or extra
):
Expand Down
66 changes: 34 additions & 32 deletions tests/cli/commands/test_connection_command.py
Expand Up @@ -758,9 +758,9 @@ def test_cli_connections_import_should_return_error_if_file_format_is_invalid(
):
connection_command.connections_import(self.parser.parse_args(["connections", "import", filepath]))

@mock.patch('airflow.cli.commands.connection_command.load_connections_dict')
@mock.patch('airflow.secrets.local_filesystem._parse_secret_file')
@mock.patch('os.path.exists')
def test_cli_connections_import_should_load_connections(self, mock_exists, mock_load_connections_dict):
def test_cli_connections_import_should_load_connections(self, mock_exists, mock_parse_secret_file):
mock_exists.return_value = True

# Sample connections to import
Expand All @@ -769,26 +769,26 @@ def test_cli_connections_import_should_load_connections(self, mock_exists, mock_
"conn_type": "postgres",
"description": "new0 description",
"host": "host",
"is_encrypted": False,
"is_extra_encrypted": False,
"login": "airflow",
"password": "password",
"port": 5432,
"schema": "airflow",
"extra": "test",
},
"new1": {
"conn_type": "mysql",
"description": "new1 description",
"host": "host",
"is_encrypted": False,
"is_extra_encrypted": False,
"login": "airflow",
"password": "password",
"port": 3306,
"schema": "airflow",
"extra": "test",
},
}

# We're not testing the behavior of load_connections_dict, assume successfully reads JSON, YAML or env
mock_load_connections_dict.return_value = expected_connections
# We're not testing the behavior of _parse_secret_file, assume it successfully reads JSON, YAML or env
mock_parse_secret_file.return_value = expected_connections

connection_command.connections_import(
self.parser.parse_args(["connections", "import", 'sample.json'])
Expand All @@ -799,14 +799,15 @@ def test_cli_connections_import_should_load_connections(self, mock_exists, mock_
current_conns = session.query(Connection).all()

comparable_attrs = [
"conn_id",
"conn_type",
"description",
"host",
"is_encrypted",
"is_extra_encrypted",
"login",
"password",
"port",
"schema",
"extra",
]

current_conns_as_dicts = {
Expand All @@ -816,80 +817,81 @@ def test_cli_connections_import_should_load_connections(self, mock_exists, mock_
assert expected_connections == current_conns_as_dicts

@provide_session
@mock.patch('airflow.secrets.local_filesystem._parse_secret_file')
@mock.patch('os.path.exists')
def test_cli_connections_import_should_not_overwrite_existing_connections(
    self, mock_exists, mock_parse_secret_file, session=None
):
    """Importing a file containing an existing conn_id must skip it, not overwrite it."""
    mock_exists.return_value = True

    # Add a pre-existing connection "new3" that the import will collide with.
    merge_conn(
        Connection(
            conn_id="new3",
            conn_type="mysql",
            description="original description",
            host="mysql",
            login="root",
            password="password",
            schema="airflow",
        ),
        session=session,
    )

    # Sample connections to import, including a collision with "new3".
    expected_connections = {
        "new2": {
            "conn_type": "postgres",
            "description": "new2 description",
            "host": "host",
            "login": "airflow",
            "password": "password",
            "port": 5432,
            "schema": "airflow",
            "extra": "test",
        },
        "new3": {
            "conn_type": "mysql",
            "description": "updated description",
            "host": "host",
            "login": "airflow",
            "password": "new password",
            "port": 3306,
            "schema": "airflow",
            "extra": "test",
        },
    }

    # We're not testing the behavior of _parse_secret_file, assume it successfully reads JSON, YAML or env
    mock_parse_secret_file.return_value = expected_connections

    with redirect_stdout(io.StringIO()) as stdout:
        connection_command.connections_import(
            self.parser.parse_args(["connections", "import", 'sample.json'])
        )

    # The collision must be reported, not silently overwritten.
    assert 'Could not import connection new3: connection already exists.' in stdout.getvalue()

    # Verify that the imported connections match the expected, sample connections
    current_conns = session.query(Connection).all()

    comparable_attrs = [
        "conn_id",
        "conn_type",
        "description",
        "host",
        "login",
        "password",
        "port",
        "schema",
        "extra",
    ]

    current_conns_as_dicts = {
        current_conn.conn_id: {attr: getattr(current_conn, attr) for attr in comparable_attrs}
        for current_conn in current_conns
    }
    # The non-colliding connection was imported as-is.
    assert current_conns_as_dicts['new2'] == expected_connections['new2']

    # The existing connection's description should not have changed
    assert current_conns_as_dicts['new3']['description'] == 'original description'

0 comments on commit 002075a

Please sign in to comment.