Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Postgres module_utils: add get_connect_params + unit tests #58067

Merged
merged 2 commits into from Jun 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
98 changes: 58 additions & 40 deletions lib/ansible/module_utils/postgres.py
Expand Up @@ -66,52 +66,24 @@ def ensure_required_libs(module):
module.fail_json(msg='psycopg2 must be at least 2.4.3 in order to use the ca_cert parameter')


def connect_to_db(module, autocommit=False, fail_on_conn=True, warn_db_default=True):
"""Return psycopg2 connection object.

Keyword arguments:
module -- object of ansible.module_utils.basic.AnsibleModule class
autocommit -- commit automatically (default False)
fail_on_conn -- fail if connection failed or just warn and return None (default True)
warn_db_default -- warn that the default DB is used (default True)
"""
ensure_required_libs(module)

# To use defaults values, keyword arguments must be absent, so
# check which values are empty and don't include in the **kw
# dictionary
params_map = {
"login_host": "host",
"login_user": "user",
"login_password": "password",
"port": "port",
"ssl_mode": "sslmode",
"ca_cert": "sslrootcert"
}
def connect_to_db(module, conn_params, autocommit=False, fail_on_conn=True):
"""Connect to a PostgreSQL database.

# Might be different in the modules:
if module.params.get('db'):
params_map['db'] = 'database'
elif module.params.get('database'):
params_map['database'] = 'database'
elif module.params.get('login_db'):
params_map['login_db'] = 'database'
else:
if warn_db_default:
module.warn('Database name has not been passed, '
'used default database to connect to.')
Return psycopg2 connection object.

kw = dict((params_map[k], v) for (k, v) in iteritems(module.params)
if k in params_map and v != '' and v is not None)
Args:
module (AnsibleModule) -- object of ansible.module_utils.basic.AnsibleModule class
conn_params (dict) -- dictionary with connection parameters

# If a login_unix_socket is specified, incorporate it here.
is_localhost = "host" not in kw or kw["host"] is None or kw["host"] == "localhost"
if is_localhost and module.params["login_unix_socket"] != "":
kw["host"] = module.params["login_unix_socket"]
Kwargs:
autocommit (bool) -- commit automatically (default False)
fail_on_conn (bool) -- fail if connection failed or just warn and return None (default True)
"""
ensure_required_libs(module)

db_connection = None
try:
db_connection = psycopg2.connect(**kw)
db_connection = psycopg2.connect(**conn_params)
if autocommit:
if LooseVersion(psycopg2.__version__) >= LooseVersion('2.4.2'):
db_connection.set_session(autocommit=True)
Expand Down Expand Up @@ -179,3 +151,49 @@ def exec_sql(obj, query, ddl=False, add_to_executed=True):
except Exception as e:
obj.module.fail_json(msg="Cannot execute SQL '%s': %s" % (query, to_native(e)))
return False


def get_conn_params(module, params_dict, warn_db_default=True):
"""Get connection parameters from the passed dictionary.

Return a dictionary with parameters to connect to PostgreSQL server.

Args:
module (AnsibleModule) -- object of ansible.module_utils.basic.AnsibleModule class
params_dict (dict) -- dictionary with variables

Kwargs:
warn_db_default (bool) -- warn that the default DB is used (default True)
"""
# To use defaults values, keyword arguments must be absent, so
# check which values are empty and don't include in the return dictionary
params_map = {
"login_host": "host",
"login_user": "user",
"login_password": "password",
"port": "port",
"ssl_mode": "sslmode",
"ca_cert": "sslrootcert"
}

# Might be different in the modules:
if params_dict.get('db'):
params_map['db'] = 'database'
elif params_dict.get('database'):
params_map['database'] = 'database'
elif params_dict.get('login_db'):
params_map['login_db'] = 'database'
else:
if warn_db_default:
module.warn('Database name has not been passed, '
'used default database to connect to.')

kw = dict((params_map[k], v) for (k, v) in iteritems(params_dict)
if k in params_map and v != '' and v is not None)

# If a login_unix_socket is specified, incorporate it here.
is_localhost = "host" not in kw or kw["host"] is None or kw["host"] == "localhost"
if is_localhost and params_dict["login_unix_socket"] != "":
kw["host"] = params_dict["login_unix_socket"]

return kw
4 changes: 3 additions & 1 deletion lib/ansible/modules/database/postgresql/postgresql_copy.py
Expand Up @@ -178,6 +178,7 @@
from ansible.module_utils.postgres import (
connect_to_db,
exec_sql,
get_conn_params,
postgres_common_argument_spec,
)
from ansible.module_utils.six import iteritems
Expand Down Expand Up @@ -351,7 +352,8 @@ def main():
module.fail_json(msg='src param is necessary with copy_to')

# Connect to DB and make cursor object:
db_connection = connect_to_db(module, autocommit=False)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=False)
cursor = db_connection.cursor(cursor_factory=DictCursor)

##############
Expand Down
9 changes: 7 additions & 2 deletions lib/ansible/modules/database/postgresql/postgresql_ext.py
Expand Up @@ -143,7 +143,11 @@
pass

from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
from ansible.module_utils.postgres import (
connect_to_db,
get_conn_params,
postgres_common_argument_spec,
)
from ansible.module_utils._text import to_native
from ansible.module_utils.database import pg_quote_identifier

Expand Down Expand Up @@ -216,7 +220,8 @@ def main():
cascade = module.params["cascade"]
changed = False

db_connection = connect_to_db(module, autocommit=True)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=True)
cursor = db_connection.cursor(cursor_factory=DictCursor)

try:
Expand Down
4 changes: 3 additions & 1 deletion lib/ansible/modules/database/postgresql/postgresql_idx.py
Expand Up @@ -230,6 +230,7 @@
from ansible.module_utils.postgres import (
connect_to_db,
exec_sql,
get_conn_params,
postgres_common_argument_spec,
)

Expand Down Expand Up @@ -474,7 +475,8 @@ def main():
if cascade and state != 'absent':
module.fail_json(msg="cascade parameter used only with state=absent")

db_connection = connect_to_db(module, autocommit=True)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=True)
cursor = db_connection.cursor(cursor_factory=DictCursor)

# Set defaults:
Expand Down
9 changes: 7 additions & 2 deletions lib/ansible/modules/database/postgresql/postgresql_info.py
Expand Up @@ -475,7 +475,11 @@
pass

from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
from ansible.module_utils.postgres import (
connect_to_db,
get_conn_params,
postgres_common_argument_spec,
)
from ansible.module_utils._text import to_native


Expand All @@ -502,7 +506,8 @@ def connect(self):

Note: connection parameters are passed by self.module object.
"""
self.db_conn = connect_to_db(self.module, warn_db_default=False)
conn_params = get_conn_params(self.module, self.module.params, warn_db_default=False)
self.db_conn = connect_to_db(self.module, conn_params)
return self.db_conn.cursor(cursor_factory=DictCursor)

def reconnect(self, dbname):
Expand Down
9 changes: 7 additions & 2 deletions lib/ansible/modules/database/postgresql/postgresql_lang.py
Expand Up @@ -170,7 +170,11 @@
'''

from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
from ansible.module_utils.postgres import (
connect_to_db,
get_conn_params,
postgres_common_argument_spec,
)
from ansible.module_utils._text import to_native
from ansible.module_utils.database import pg_quote_identifier

Expand Down Expand Up @@ -254,7 +258,8 @@ def main():
cascade = module.params["cascade"]
fail_on_drop = module.params["fail_on_drop"]

db_connection = connect_to_db(module, autocommit=False)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=False)
cursor = db_connection.cursor()

changed = False
Expand Down
Expand Up @@ -147,6 +147,7 @@
from ansible.module_utils.postgres import (
connect_to_db,
exec_sql,
get_conn_params,
postgres_common_argument_spec,
)

Expand Down Expand Up @@ -284,7 +285,8 @@ def main():
fail_on_role = module.params['fail_on_role']
state = module.params['state']

db_connection = connect_to_db(module, autocommit=False)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=False)
cursor = db_connection.cursor(cursor_factory=DictCursor)

##############
Expand Down
4 changes: 3 additions & 1 deletion lib/ansible/modules/database/postgresql/postgresql_owner.py
Expand Up @@ -161,6 +161,7 @@
from ansible.module_utils.postgres import (
connect_to_db,
exec_sql,
get_conn_params,
postgres_common_argument_spec,
)

Expand Down Expand Up @@ -415,7 +416,8 @@ def main():
reassign_owned_by = module.params['reassign_owned_by']
fail_on_role = module.params['fail_on_role']

db_connection = connect_to_db(module, autocommit=False)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=False)
cursor = db_connection.cursor(cursor_factory=DictCursor)

##############
Expand Down
4 changes: 3 additions & 1 deletion lib/ansible/modules/database/postgresql/postgresql_ping.py
Expand Up @@ -82,6 +82,7 @@
from ansible.module_utils.postgres import (
connect_to_db,
exec_sql,
get_conn_params,
postgres_common_argument_spec,
)

Expand Down Expand Up @@ -138,7 +139,8 @@ def main():
server_version=dict(),
)

db_connection = connect_to_db(module, fail_on_conn=False)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, fail_on_conn=False)

if db_connection is not None:
cursor = db_connection.cursor(cursor_factory=DictCursor)
Expand Down
9 changes: 7 additions & 2 deletions lib/ansible/modules/database/postgresql/postgresql_query.py
Expand Up @@ -146,7 +146,11 @@
pass

from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
from ansible.module_utils.postgres import (
connect_to_db,
get_conn_params,
postgres_common_argument_spec,
)
from ansible.module_utils._text import to_native


Expand Down Expand Up @@ -189,7 +193,8 @@ def main():
except Exception as e:
module.fail_json(msg="Cannot read file '%s' : %s" % (path_to_script, to_native(e)))

db_connection = connect_to_db(module, autocommit=False)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=False)
cursor = db_connection.cursor(cursor_factory=DictCursor)

# Prepare args:
Expand Down
9 changes: 7 additions & 2 deletions lib/ansible/modules/database/postgresql/postgresql_schema.py
Expand Up @@ -129,7 +129,11 @@
pass

from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
from ansible.module_utils.postgres import (
connect_to_db,
get_conn_params,
postgres_common_argument_spec,
)
from ansible.module_utils.database import SQLParseError, pg_quote_identifier
from ansible.module_utils._text import to_native

Expand Down Expand Up @@ -234,7 +238,8 @@ def main():
cascade_drop = module.params["cascade_drop"]
changed = False

db_connection = connect_to_db(module, autocommit=True)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=True)
cursor = db_connection.cursor(cursor_factory=DictCursor)

try:
Expand Down
Expand Up @@ -287,6 +287,7 @@
from ansible.module_utils.postgres import (
connect_to_db,
exec_sql,
get_conn_params,
postgres_common_argument_spec,
)

Expand Down Expand Up @@ -498,7 +499,8 @@ def main():
# Change autocommit to False if check_mode:
autocommit = not module.check_mode
# Connect to DB and make cursor object:
db_connection = connect_to_db(module, autocommit=autocommit)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=autocommit)
cursor = db_connection.cursor(cursor_factory=DictCursor)

##############
Expand Down
11 changes: 8 additions & 3 deletions lib/ansible/modules/database/postgresql/postgresql_set.py
Expand Up @@ -165,7 +165,11 @@
from copy import deepcopy

from ansible.module_utils.basic import AnsibleModule
from ansible.module_utils.postgres import connect_to_db, postgres_common_argument_spec
from ansible.module_utils.postgres import (
connect_to_db,
get_conn_params,
postgres_common_argument_spec,
)
from ansible.module_utils._text import to_native

PG_REQ_VER = 90400
Expand Down Expand Up @@ -304,7 +308,8 @@ def main():
if not value and not reset:
module.fail_json(msg="%s: at least one of value or reset param must be specified" % name)

db_connection = connect_to_db(module, autocommit=True, warn_db_default=False)
conn_params = get_conn_params(module, module.params, warn_db_default=False)
db_connection = connect_to_db(module, conn_params, autocommit=True)
cursor = db_connection.cursor(cursor_factory=DictCursor)

kw = {}
Expand Down Expand Up @@ -397,7 +402,7 @@ def main():

# Reconnect and recheck current value:
if context in ('sighup', 'superuser-backend', 'backend', 'superuser', 'user'):
db_connection = connect_to_db(module, autocommit=True)
db_connection = connect_to_db(module, conn_params, autocommit=True)
cursor = db_connection.cursor(cursor_factory=DictCursor)

res = param_get(cursor, module, name)
Expand Down
4 changes: 3 additions & 1 deletion lib/ansible/modules/database/postgresql/postgresql_slot.py
Expand Up @@ -152,6 +152,7 @@
from ansible.module_utils.postgres import (
connect_to_db,
exec_sql,
get_conn_params,
postgres_common_argument_spec,
)

Expand Down Expand Up @@ -242,7 +243,8 @@ def main():
if immediately_reserve and slot_type == 'logical':
module.fail_json(msg="Module parameters immediately_reserve and slot_type=logical are mutually exclusive")

db_connection = connect_to_db(module, autocommit=True)
conn_params = get_conn_params(module, module.params)
db_connection = connect_to_db(module, conn_params, autocommit=True)
cursor = db_connection.cursor(cursor_factory=DictCursor)

##################################
Expand Down