Skip to content

Commit

Permalink
Rename schema to database in PostgresHook (#26744)
Browse files Browse the repository at this point in the history
* Rename schema to database in PostgresHook

In PostgresHook the "schema" field is only being called like that to make it compatible with the underlying DbApiHook which uses the schema for the sql alchemy connector. The postgres connector library however does not allow setting a schema in the connection instead a database can be set. To clarify that, we change all references in the PostgresHook code and documentation.
  • Loading branch information
feluelle committed Oct 31, 2022
1 parent 124fb39 commit 39caf1d
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 32 deletions.
39 changes: 35 additions & 4 deletions airflow/providers/postgres/hooks/postgres.py
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import os
import warnings
from contextlib import closing
from copy import deepcopy
from typing import Any, Iterable, Union
Expand Down Expand Up @@ -67,10 +68,38 @@ class PostgresHook(DbApiHook):
supports_autocommit = True

def __init__(self, *args, **kwargs) -> None:
if "schema" in kwargs:
warnings.warn(
'The "schema" arg has been renamed to "database" as it contained the database name.'
'Please use "database" to set the database name.',
DeprecationWarning,
stacklevel=2,
)
kwargs["database"] = kwargs["schema"]
super().__init__(*args, **kwargs)
self.connection: Connection | None = kwargs.pop("connection", None)
self.conn: connection = None
self.schema: str | None = kwargs.pop("schema", None)
self.database: str | None = kwargs.pop("database", None)

@property
def schema(self):
warnings.warn(
'The "schema" variable has been renamed to "database" as it contained the database name.'
'Please use "database" to get the database name.',
DeprecationWarning,
stacklevel=2,
)
return self.database

@schema.setter
def schema(self, value):
warnings.warn(
'The "schema" variable has been renamed to "database" as it contained the database name.'
'Please use "database" to set the database name.',
DeprecationWarning,
stacklevel=2,
)
self.database = value

def _get_cursor(self, raw_cursor: str) -> CursorType:
_cursor = raw_cursor.lower()
Expand All @@ -95,7 +124,7 @@ def get_conn(self) -> connection:
host=conn.host,
user=conn.login,
password=conn.password,
dbname=self.schema or conn.schema,
dbname=self.database or conn.schema,
port=conn.port,
)
raw_cursor = conn.extra_dejson.get("cursor", False)
Expand Down Expand Up @@ -143,7 +172,9 @@ def get_uri(self) -> str:
Extract the URI from the connection.
:return: the extracted uri.
"""
uri = super().get_uri().replace("postgres://", "postgresql://")
conn = self.get_connection(getattr(self, self.conn_name_attr))
conn.schema = self.database or conn.schema
uri = conn.get_uri().replace("postgres://", "postgresql://")
return uri

def bulk_load(self, table: str, tmp_file: str) -> None:
Expand Down Expand Up @@ -196,7 +227,7 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
cluster_creds = redshift_client.get_cluster_credentials(
DbUser=login,
DbName=self.schema or conn.schema,
DbName=self.database or conn.schema,
ClusterIdentifier=cluster_identifier,
AutoCreate=False,
)
Expand Down
2 changes: 2 additions & 0 deletions airflow/providers/postgres/operators/postgres.py
Expand Up @@ -38,6 +38,8 @@ class PostgresOperator(SQLExecuteQueryOperator):
(default value: False)
:param parameters: (optional) the parameters to render the SQL query with.
:param database: name of database which overwrite defined one in connection
:param runtime_parameters: a mapping of runtime params added to the final sql being executed.
For example, you could set the schema via `{"search_path": "CUSTOM_SCHEMA"}`.
"""

template_fields: Sequence[str] = ("sql",)
Expand Down
Expand Up @@ -29,7 +29,14 @@ Host (required)
The host to connect to.

Schema (optional)
Specify the schema name to be used in the database.
Specify the name of the database to connect to.

.. note::

If you want to define a default database schema:

* using ``PostgresOperator`` see :ref:`Passing Server Configuration Parameters into PostgresOperator <howto/operators:postgres>`
* using ``PostgresHook`` see `search_path <https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-PATH>_`

Login (required)
Specify the user name to connect.
Expand Down
Expand Up @@ -15,6 +15,8 @@
specific language governing permissions and limitations
under the License.
.. _howto/operators:postgres:

How-to Guide for PostgresOperator
=================================

Expand Down
1 change: 0 additions & 1 deletion tests/providers/common/sql/operators/test_sql.py
Expand Up @@ -448,7 +448,6 @@ def test_get_hook(self, mock_get_conn, database):
if database:
self._operator.database = database
assert isinstance(self._operator._hook, PostgresHook)
assert self._operator._hook.schema == database
mock_get_conn.assert_called_once_with(self.conn_id)

def test_not_allowed_conn_type(self, mock_get_conn):
Expand Down
4 changes: 2 additions & 2 deletions tests/providers/common/sql/sensors/test_sql.py
Expand Up @@ -263,8 +263,8 @@ def test_sql_sensor_hook_params(self):
conn_id="postgres_default",
sql="SELECT 1",
hook_params={
"schema": "public",
"log_sql": False,
},
)
hook = op._get_hook()
assert hook.schema == "public"
assert hook.log_sql == op.hook_params["log_sql"]
43 changes: 24 additions & 19 deletions tests/providers/postgres/hooks/test_postgres.py
Expand Up @@ -33,7 +33,7 @@
class TestPostgresHookConn:
@pytest.fixture(autouse=True)
def setup(self):
self.connection = Connection(login="login", password="password", host="host", schema="schema")
self.connection = Connection(login="login", password="password", host="host", schema="database")

class UnitTestPostgresHook(PostgresHook):
conn_name_attr = "test_conn_id"
Expand All @@ -47,15 +47,15 @@ def test_get_conn_non_default_id(self, mock_connect):
self.db_hook.test_conn_id = "non_default"
self.db_hook.get_conn()
mock_connect.assert_called_once_with(
user="login", password="password", host="host", dbname="schema", port=None
user="login", password="password", host="host", dbname="database", port=None
)
self.db_hook.get_connection.assert_called_once_with("non_default")

@mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
def test_get_conn(self, mock_connect):
self.db_hook.get_conn()
mock_connect.assert_called_once_with(
user="login", password="password", host="host", dbname="schema", port=None
user="login", password="password", host="host", dbname="database", port=None
)

@mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
Expand All @@ -64,7 +64,7 @@ def test_get_uri(self, mock_connect):
self.connection.conn_type = "postgres"
self.db_hook.get_conn()
assert mock_connect.call_count == 1
assert self.db_hook.get_uri() == "postgresql://login:password@host/schema?client_encoding=utf-8"
assert self.db_hook.get_uri() == "postgresql://login:password@host/database?client_encoding=utf-8"

@mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
def test_get_conn_cursor(self, mock_connect):
Expand All @@ -75,7 +75,7 @@ def test_get_conn_cursor(self, mock_connect):
user="login",
password="password",
host="host",
dbname="schema",
dbname="database",
port=None,
)

Expand All @@ -87,20 +87,20 @@ def test_get_conn_with_invalid_cursor(self, mock_connect):

@mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
def test_get_conn_from_connection(self, mock_connect):
conn = Connection(login="login-conn", password="password-conn", host="host", schema="schema")
conn = Connection(login="login-conn", password="password-conn", host="host", schema="database")
hook = PostgresHook(connection=conn)
hook.get_conn()
mock_connect.assert_called_once_with(
user="login-conn", password="password-conn", host="host", dbname="schema", port=None
user="login-conn", password="password-conn", host="host", dbname="database", port=None
)

@mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
def test_get_conn_from_connection_with_schema(self, mock_connect):
conn = Connection(login="login-conn", password="password-conn", host="host", schema="schema")
hook = PostgresHook(connection=conn, schema="schema-override")
def test_get_conn_from_connection_with_database(self, mock_connect):
conn = Connection(login="login-conn", password="password-conn", host="host", schema="database")
hook = PostgresHook(connection=conn, database="database-override")
hook.get_conn()
mock_connect.assert_called_once_with(
user="login-conn", password="password-conn", host="host", dbname="schema-override", port=None
user="login-conn", password="password-conn", host="host", dbname="database-override", port=None
)

@mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
Expand Down Expand Up @@ -146,7 +146,7 @@ def test_get_conn_extra(self, mock_connect):
self.connection.extra = '{"connect_timeout": 3}'
self.db_hook.get_conn()
mock_connect.assert_called_once_with(
user="login", password="password", host="host", dbname="schema", port=None, connect_timeout=3
user="login", password="password", host="host", dbname="database", port=None, connect_timeout=3
)

@mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
Expand Down Expand Up @@ -225,32 +225,37 @@ def test_get_conn_rds_iam_redshift(
port=(port or 5439),
)

def test_get_uri_from_connection_without_schema_override(self):
def test_get_uri_from_connection_without_database_override(self):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="postgres",
host="host",
login="login",
password="password",
schema="schema",
schema="database",
port=1,
)
)
assert "postgresql://login:password@host:1/schema" == self.db_hook.get_uri()
assert "postgresql://login:password@host:1/database" == self.db_hook.get_uri()

def test_get_uri_from_connection_with_schema_override(self):
hook = PostgresHook(schema="schema-override")
def test_get_uri_from_connection_with_database_override(self):
hook = PostgresHook(database="database-override")
hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="postgres",
host="host",
login="login",
password="password",
schema="schema",
schema="database",
port=1,
)
)
assert "postgresql://login:password@host:1/schema-override" == hook.get_uri()
assert "postgresql://login:password@host:1/database-override" == hook.get_uri()

def test_schema_kwarg_database_kwarg_compatibility(self):
database = "database-override"
hook = PostgresHook(schema=database)
assert hook.database == database


class TestPostgresHook(unittest.TestCase):
Expand Down
6 changes: 3 additions & 3 deletions tests/providers/postgres/operators/test_postgres.py
Expand Up @@ -79,14 +79,14 @@ def test_vacuum(self):
op = PostgresOperator(task_id="postgres_operator_test_vacuum", sql=sql, dag=self.dag, autocommit=True)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

def test_overwrite_schema(self):
def test_overwrite_database(self):
"""
Verifies option to overwrite connection schema
Verifies option to overwrite connection database
"""

sql = "SELECT 1;"
op = PostgresOperator(
task_id="postgres_operator_test_schema_overwrite",
task_id="postgres_operator_test_database_overwrite",
sql=sql,
dag=self.dag,
autocommit=True,
Expand Down
4 changes: 2 additions & 2 deletions tests/providers/slack/transfers/test_sql_to_slack.py
Expand Up @@ -186,11 +186,11 @@ def test_hook_params(self, mock_get_conn):
sql="SELECT 1",
slack_message="message: {{ ds }}, {{ xxxx }}",
sql_hook_params={
"schema": "public",
"log_sql": False,
},
)
hook = op._get_hook()
assert hook.schema == "public"
assert hook.log_sql == op.sql_hook_params["log_sql"]

@mock.patch("airflow.providers.common.sql.operators.sql.BaseHook.get_connection")
def test_hook_params_snowflake(self, mock_get_conn):
Expand Down

0 comments on commit 39caf1d

Please sign in to comment.