diff --git a/psqlextra/backend/introspection.py b/psqlextra/backend/introspection.py
index 03e1fdef..90717b6a 100644
--- a/psqlextra/backend/introspection.py
+++ b/psqlextra/backend/introspection.py
@@ -1,9 +1,6 @@
-from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
 
-from django.db import transaction
-
 from psqlextra.types import PostgresPartitioningMethod
 
 from . import base_impl
@@ -250,6 +247,7 @@ def get_storage_settings(self, cursor, table_name: str) -> Dict[str, str]:
             pg_catalog.pg_am am ON (c.relam = am.oid)
         WHERE
             c.relname::text = %s
+            AND pg_catalog.pg_table_is_visible(c.oid)
         """
 
         cursor.execute(sql, (table_name,))
@@ -288,29 +286,3 @@ def get_relations(self, cursor, table_name: str):
             [table_name],
         )
         return {row[0]: (row[2], row[1]) for row in cursor.fetchall()}
-
-    @contextmanager
-    def in_search_path(self, search_path: List[str]):
-        """Changes the Postgres `search_path` within the context and switches
-        it back when it exits."""
-
-        # Wrap in a transaction so a savepoint is created. If
-        # something goes wrong, the `SET LOCAL search_path`
-        # statement will be rolled back.
-        with transaction.atomic(using=self.connection.alias):
-            with self.connection.cursor() as cursor:
-                cursor.execute("SHOW search_path")
-                (original_search_path,) = cursor.fetchone()
-
-                # Syntax in Postgres is a bit weird here. It isn't really
-                # a list of names like in `WHERE bla in (val1, val2)`.
-                placeholder = ", ".join(["%s" for _ in search_path])
-                cursor.execute(
-                    f"SET LOCAL search_path = {placeholder}", search_path
-                )
-
-                yield self
-
-                cursor.execute(
-                    f"SET LOCAL search_path = {original_search_path}"
-                )
diff --git a/psqlextra/backend/operations.py b/psqlextra/backend/operations.py
index 24adf5d0..52793fac 100644
--- a/psqlextra/backend/operations.py
+++ b/psqlextra/backend/operations.py
@@ -21,6 +21,3 @@ class PostgresOperations(base_impl.operations()):
         SQLUpdateCompiler,
         SQLInsertCompiler,
     ]
-
-    def default_schema_name(self) -> str:
-        return "public"
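Note: the `pg_catalog.pg_table_is_visible(c.oid)` predicate added above is what makes introspection respect the active `search_path`; `relname` alone is ambiguous once the same table name exists in more than one schema. A rough illustration of the difference (the table name `mytable` is made up; this sketch is not part of the change itself):

    from django.db import connection

    with connection.cursor() as cursor:
        # Without the predicate, a name-only lookup matches one row for
        # every schema that defines "mytable", regardless of search_path.
        cursor.execute(
            """
            SELECT c.relname, n.nspname
            FROM pg_catalog.pg_class c
            JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relname::text = %s
            """,
            ("mytable",),
        )

        # With it, only the relation that an unqualified "mytable"
        # currently resolves to is returned, so a SET LOCAL search_path
        # steers which schema gets introspected.
        cursor.execute(
            """
            SELECT c.relname
            FROM pg_catalog.pg_class c
            WHERE c.relname::text = %s
              AND pg_catalog.pg_table_is_visible(c.oid)
            """,
            ("mytable",),
        )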
diff --git a/psqlextra/backend/schema.py b/psqlextra/backend/schema.py
index 435b0bdd..1e21b366 100644
--- a/psqlextra/backend/schema.py
+++ b/psqlextra/backend/schema.py
@@ -12,6 +12,10 @@
 from django.db.backends.ddl_references import Statement
 from django.db.models import Field, Model
 
+from psqlextra.settings import (
+    postgres_prepend_local_search_path,
+    postgres_reset_local_search_path,
+)
 from psqlextra.type_assertions import is_sql_with_params
 from psqlextra.types import PostgresPartitioningMethod
 
@@ -40,6 +44,8 @@ class PostgresSchemaEditor(SchemaEditor):
     sql_alter_table_storage_setting = "ALTER TABLE %s SET (%s = %s)"
     sql_reset_table_storage_setting = "ALTER TABLE %s RESET (%s)"
 
+    sql_alter_table_schema = "ALTER TABLE %s SET SCHEMA %s"
+
     sql_create_view = "CREATE VIEW %s AS (%s)"
     sql_replace_view = "CREATE OR REPLACE VIEW %s AS (%s)"
     sql_drop_view = "DROP VIEW IF EXISTS %s"
@@ -203,8 +209,8 @@ def clone_model_constraints_and_indexes_to_schema(
             resides.
         """
 
-        with self.introspection.in_search_path(
-            [schema_name, self.connection.ops.default_schema_name()]
+        with postgres_prepend_local_search_path(
+            [schema_name], using=self.connection.alias
         ):
             for constraint in model._meta.constraints:
                 self.add_constraint(model, constraint)
@@ -226,8 +232,8 @@ def clone_model_constraints_and_indexes_to_schema(
                 # Django creates primary keys later added to the model with
                 # a custom name. We want the name as it was created originally.
                 if field.primary_key:
-                    with self.introspection.in_search_path(
-                        [self.connection.ops.default_schema_name()]
+                    with postgres_reset_local_search_path(
+                        using=self.connection.alias
                     ):
                         [primary_key_name] = self._constraint_names(
                             model, primary_key=True
@@ -251,8 +257,8 @@
                 # a separate transaction later to validate the entries without
                 # acquiring a AccessExclusiveLock.
                 if field.remote_field:
-                    with self.introspection.in_search_path(
-                        [self.connection.ops.default_schema_name()]
+                    with postgres_reset_local_search_path(
+                        using=self.connection.alias
                     ):
                         [fk_name] = self._constraint_names(
                             model, [field.column], foreign_key=True
@@ -277,8 +283,8 @@
                 # manually.
                 field_check = field.db_parameters(self.connection).get("check")
                 if field_check:
-                    with self.introspection.in_search_path(
-                        [self.connection.ops.default_schema_name()]
+                    with postgres_reset_local_search_path(
+                        using=self.connection.alias
                    ):
                        [field_check_name] = self._constraint_names(
                            model,
@@ -337,10 +343,12 @@ def clone_model_foreign_keys_to_schema(
             resides.
         """
 
-        with self.introspection.in_search_path(
-            [schema_name, self.connection.ops.default_schema_name()]
+        constraint_names = self._constraint_names(model, foreign_key=True)
+
+        with postgres_prepend_local_search_path(
+            [schema_name], using=self.connection.alias
         ):
-            for fk_name in self._constraint_names(model, foreign_key=True):
+            for fk_name in constraint_names:
                 self.execute(
                     self.sql_validate_fk
                     % (
@@ -438,6 +446,47 @@ def reset_model_storage_setting(
 
         self.reset_table_storage_setting(model._meta.db_table, name)
 
+    def alter_table_schema(self, table_name: str, schema_name: str) -> None:
+        """Moves the specified table into the specified schema.
+
+        WARNING: Moving models into a different schema than the default
+        will break querying the model.
+
+        Arguments:
+            table_name:
+                Name of the table to move into the specified schema.
+
+            schema_name:
+                Name of the schema to move the table to.
+        """
+
+        self.execute(
+            self.sql_alter_table_schema
+            % (self.quote_name(table_name), self.quote_name(schema_name))
+        )
+
+    def alter_model_schema(self, model: Type[Model], schema_name: str) -> None:
+        """Moves the specified model's table into the specified schema.
+
+        WARNING: Moving models into a different schema than the default
+        will break querying the model.
+
+        Arguments:
+            model:
+                Model of which to move the table.
+
+            schema_name:
+                Name of the schema to move the model's table to.
+        """
+
+        self.execute(
+            self.sql_alter_table_schema
+            % (
+                self.quote_name(model._meta.db_table),
+                self.quote_name(schema_name),
+            )
+        )
+
     def refresh_materialized_view_model(
         self, model: Type[Model], concurrently: bool = False
     ) -> None:
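Note: a minimal usage sketch for the new `alter_table_schema`/`alter_model_schema` helpers. The `archive` schema, `MyModel`, and `some_audit_table` are hypothetical and must already exist; as the docstrings warn, the ORM will no longer find the table once it leaves the default schema:

    from django.db import connection

    from psqlextra.backend.schema import PostgresSchemaEditor

    from myapp.models import MyModel  # hypothetical model

    with PostgresSchemaEditor(connection) as schema_editor:
        # Emits: ALTER TABLE "myapp_mymodel" SET SCHEMA "archive"
        schema_editor.alter_model_schema(MyModel, "archive")

        # Same operation, addressed by table name instead of model:
        schema_editor.alter_table_schema("some_audit_table", "archive")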
diff --git a/psqlextra/settings.py b/psqlextra/settings.py
new file mode 100644
index 00000000..6dd32f37
--- /dev/null
+++ b/psqlextra/settings.py
@@ -0,0 +1,118 @@
+from contextlib import contextmanager
+from typing import List, Optional, Union
+
+from django.core.exceptions import SuspiciousOperation
+from django.db import DEFAULT_DB_ALIAS, connections
+
+
+@contextmanager
+def postgres_set_local(
+    *,
+    using: str = DEFAULT_DB_ALIAS,
+    **options: Optional[Union[str, int, float, List[str]]],
+) -> None:
+    """Sets the specified PostgreSQL options using SET LOCAL so that they apply
+    to the current transaction only.
+
+    The effect is undone when the context manager exits.
+
+    See https://www.postgresql.org/docs/current/runtime-config-client.html
+    for an overview of all available options.
+    """
+
+    connection = connections[using]
+    qn = connection.ops.quote_name
+
+    if not connection.in_atomic_block:
+        raise SuspiciousOperation(
+            "SET LOCAL makes no sense outside a transaction. Start a transaction first."
+        )
+
+    sql = []
+    params = []
+    for name, value in options.items():
+        if value is None:
+            sql.append(f"SET LOCAL {qn(name)} TO DEFAULT")
+            continue
+
+        # Settings that accept a list of values are actually
+        # stored as string lists. We cannot just pass a list
+        # of values. We have to create the comma separated
+        # string ourselves.
+        if isinstance(value, list) or isinstance(value, tuple):
+            placeholder = ", ".join(["%s" for _ in value])
+            params.extend(value)
+        else:
+            placeholder = "%s"
+            params.append(value)
+
+        sql.append(f"SET LOCAL {qn(name)} = {placeholder}")
+
+    with connection.cursor() as cursor:
+        cursor.execute(
+            "SELECT name, setting FROM pg_settings WHERE name = ANY(%s)",
+            (list(options.keys()),),
+        )
+        original_values = dict(cursor.fetchall())
+        cursor.execute("; ".join(sql), params)
+
+    yield
+
+    # Put everything back to how it was. DEFAULT is
+    # not good enough as an outer SET LOCAL might
+    # have set a different value.
+    with connection.cursor() as cursor:
+        sql = []
+        params = []
+
+        for name, value in options.items():
+            original_value = original_values.get(name)
+            if original_value:
+                sql.append(f"SET LOCAL {qn(name)} = {original_value}")
+            else:
+                sql.append(f"SET LOCAL {qn(name)} TO DEFAULT")
+
+        cursor.execute("; ".join(sql), params)
+
+
+@contextmanager
+def postgres_set_local_search_path(
+    search_path: List[str], *, using: str = DEFAULT_DB_ALIAS
+) -> None:
+    """Sets the search path to the specified schemas."""
+
+    with postgres_set_local(search_path=search_path, using=using):
+        yield
+
+
+@contextmanager
+def postgres_prepend_local_search_path(
+    search_path: List[str], *, using: str = DEFAULT_DB_ALIAS
+) -> None:
+    """Prepends the current local search path with the specified schemas."""
+
+    connection = connections[using]
+
+    with connection.cursor() as cursor:
+        cursor.execute("SHOW search_path")
+        [
+            original_search_path,
+        ] = cursor.fetchone()
+
+        placeholders = ", ".join(["%s" for _ in search_path])
+        cursor.execute(
+            f"SET LOCAL search_path = {placeholders}, {original_search_path}",
+            tuple(search_path),
+        )
+
+        yield
+
+        cursor.execute(f"SET LOCAL search_path = {original_search_path}")
+
+
+@contextmanager
+def postgres_reset_local_search_path(*, using: str = DEFAULT_DB_ALIAS) -> None:
+    """Resets the local search path to the default."""
+
+    with postgres_set_local(search_path=None, using=using):
+        yield
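Note: a sketch of how the new helpers compose (the setting names are real PostgreSQL options; the values are arbitrary examples). `postgres_set_local` deliberately raises `SuspiciousOperation` outside a transaction, since `SET LOCAL` only lasts until the current transaction ends:

    from django.db import transaction

    from psqlextra.settings import (
        postgres_prepend_local_search_path,
        postgres_set_local,
        postgres_set_local_search_path,
    )

    with transaction.atomic():
        # Plain settings revert to their previous values on exit.
        with postgres_set_local(statement_timeout="5s"):
            ...

        # List values are expanded to a comma separated list,
        # e.g. search_path becomes "tenant_a, public".
        with postgres_set_local_search_path(["tenant_a", "public"]):
            ...

        # Prepending keeps whatever the search_path already was.
        with postgres_prepend_local_search_path(["tenant_a"]):
            ...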
diff --git a/tests/db_introspection.py b/tests/db_introspection.py
index eabc7414..285cd0e4 100644
--- a/tests/db_introspection.py
+++ b/tests/db_introspection.py
@@ -9,15 +9,14 @@
 
 from django.db import connection
 
+from psqlextra.settings import postgres_set_local
+
 
 @contextmanager
 def introspect(schema_name: Optional[str] = None):
-    default_schema_name = connection.ops.default_schema_name()
-    search_path = [schema_name or default_schema_name]
-
-    with connection.introspection.in_search_path(search_path) as introspection:
+    with postgres_set_local(search_path=schema_name or None):
         with connection.cursor() as cursor:
-            yield introspection, cursor
+            yield connection.introspection, cursor
 
 
 def table_names(
diff --git a/tests/test_schema_editor_alter_schema.py b/tests/test_schema_editor_alter_schema.py
new file mode 100644
index 00000000..7fda103b
--- /dev/null
+++ b/tests/test_schema_editor_alter_schema.py
@@ -0,0 +1,44 @@
+import pytest
+
+from django.db import connection, models
+
+from psqlextra.backend.schema import PostgresSchemaEditor
+
+from .fake_model import get_fake_model
+
+
+@pytest.fixture
+def fake_model():
+    return get_fake_model(
+        {
+            "text": models.TextField(),
+        }
+    )
+
+
+def test_schema_editor_alter_table_schema(fake_model):
+    obj = fake_model.objects.create(text="hello")
+
+    with connection.cursor() as cursor:
+        cursor.execute("CREATE SCHEMA target")
+
+    schema_editor = PostgresSchemaEditor(connection)
+    schema_editor.alter_table_schema(fake_model._meta.db_table, "target")
+
+    with connection.cursor() as cursor:
+        cursor.execute(f"SELECT * FROM target.{fake_model._meta.db_table}")
+        assert cursor.fetchall() == [(obj.id, obj.text)]
+
+
+def test_schema_editor_alter_model_schema(fake_model):
+    obj = fake_model.objects.create(text="hello")
+
+    with connection.cursor() as cursor:
+        cursor.execute("CREATE SCHEMA target")
+
+    schema_editor = PostgresSchemaEditor(connection)
+    schema_editor.alter_model_schema(fake_model, "target")
+
+    with connection.cursor() as cursor:
+        cursor.execute(f"SELECT * FROM target.{fake_model._meta.db_table}")
+        assert cursor.fetchall() == [(obj.id, obj.text)]
diff --git a/tests/test_schema_editor_clone_model_to_schema.py b/tests/test_schema_editor_clone_model_to_schema.py
index 712d1433..c3d41917 100644
--- a/tests/test_schema_editor_clone_model_to_schema.py
+++ b/tests/test_schema_editor_clone_model_to_schema.py
@@ -22,6 +22,11 @@ def _create_schema() -> str:
     name = os.urandom(4).hex()
 
     with connection.cursor() as cursor:
+        cursor.execute(
+            "DROP SCHEMA IF EXISTS %s CASCADE"
+            % connection.ops.quote_name(name),
+            tuple(),
+        )
         cursor.execute(
             "CREATE SCHEMA %s" % connection.ops.quote_name(name), tuple()
         )
@@ -29,6 +34,7 @@
     return name
 
 
+@transaction.atomic
 def _assert_cloned_table_is_same(
     source_table_fqn: Tuple[str, str],
     target_table_fqn: Tuple[str, str],
@@ -40,49 +46,49 @@ def _assert_cloned_table_is_same(
     source_columns = db_introspection.get_columns(
         source_table_name, schema_name=source_schema_name
     )
-    source_columns = db_introspection.get_columns(
+    target_columns = db_introspection.get_columns(
         target_table_name, schema_name=target_schema_name
     )
-    assert source_columns == source_columns
+    assert source_columns == target_columns
 
     source_relations = db_introspection.get_relations(
         source_table_name, schema_name=source_schema_name
     )
-    source_relations = db_introspection.get_relations(
+    target_relations = db_introspection.get_relations(
         target_table_name, schema_name=target_schema_name
     )
 
     if excluding_constraints_and_indexes:
-        assert source_relations == {}
+        assert target_relations == {}
     else:
-        assert source_relations == source_relations
+        assert source_relations == target_relations
 
     source_constraints = db_introspection.get_constraints(
         source_table_name, schema_name=source_schema_name
     )
-    source_constraints = db_introspection.get_constraints(
+    target_constraints = db_introspection.get_constraints(
         target_table_name, schema_name=target_schema_name
     )
 
     if excluding_constraints_and_indexes:
-        assert source_constraints == {}
+        assert target_constraints == {}
    else:
-        assert source_constraints == source_constraints
+        assert source_constraints == target_constraints
 
     source_sequences = db_introspection.get_sequences(
         source_table_name, schema_name=source_schema_name
     )
-    source_sequences = db_introspection.get_sequences(
+    target_sequences = db_introspection.get_sequences(
         target_table_name, schema_name=target_schema_name
     )
-    assert source_sequences == source_sequences
+    assert source_sequences == target_sequences
 
     source_storage_settings = db_introspection.get_storage_settings(
         source_table_name,
         schema_name=source_schema_name,
     )
-    source_storage_settings = db_introspection.get_storage_settings(
+    target_storage_settings = db_introspection.get_storage_settings(
         target_table_name, schema_name=target_schema_name
     )
-    assert source_storage_settings == source_storage_settings
+    assert source_storage_settings == target_storage_settings
@@ -108,16 +114,16 @@ def _list_lock_modes_in_schema(schema_name: str) -> Set[str]:
 def _clone_model_into_schema(model):
     schema_name = _create_schema()
 
-    schema_editor = PostgresSchemaEditor(connection)
-    schema_editor.clone_model_structure_to_schema(
-        model, schema_name=schema_name
-    )
-    schema_editor.clone_model_constraints_and_indexes_to_schema(
-        model, schema_name=schema_name
-    )
-    schema_editor.clone_model_foreign_keys_to_schema(
-        model, schema_name=schema_name
-    )
+    with PostgresSchemaEditor(connection) as schema_editor:
+        schema_editor.clone_model_structure_to_schema(
+            model, schema_name=schema_name
+        )
+        schema_editor.clone_model_constraints_and_indexes_to_schema(
+            model, schema_name=schema_name
+        )
+        schema_editor.clone_model_foreign_keys_to_schema(
+            model, schema_name=schema_name
+        )
 
     return schema_name
@@ -208,15 +214,17 @@ def test_schema_editor_clone_model_to_schema(
     AccessExclusiveLock on the source table works as expected."""
 
     schema_editor = PostgresSchemaEditor(connection)
-    schema_editor.alter_table_storage_setting(
-        fake_model._meta.db_table, "autovacuum_enabled", "false"
-    )
+
+    with schema_editor:
+        schema_editor.alter_table_storage_setting(
+            fake_model._meta.db_table, "autovacuum_enabled", "false"
+        )
 
     table_name = fake_model._meta.db_table
-    source_schema_name = connection.ops.default_schema_name()
+    source_schema_name = "public"
     target_schema_name = _create_schema()
 
-    with transaction.atomic(durable=True):
+    with schema_editor:
         schema_editor.clone_model_structure_to_schema(
             fake_model, schema_name=target_schema_name
         )
@@ -231,7 +239,7 @@
         excluding_constraints_and_indexes=True,
     )
 
-    with transaction.atomic(durable=True):
+    with schema_editor:
         schema_editor.clone_model_constraints_and_indexes_to_schema(
             fake_model, schema_name=target_schema_name
         )
@@ -246,7 +254,7 @@
         (target_schema_name, table_name),
     )
 
-    with transaction.atomic(durable=True):
+    with schema_editor:
         schema_editor.clone_model_foreign_keys_to_schema(
             fake_model, schema_name=target_schema_name
         )
@@ -267,13 +275,13 @@
     reason=django_32_skip_reason,
 )
 def test_schema_editor_clone_model_to_schema_custom_constraint_names(
-    fake_model,
+    fake_model, fake_model_fk_target_1
 ):
     """Tests that even if constraints were given custom names, the cloned
     table has those same custom names."""
 
     table_name = fake_model._meta.db_table
-    source_schema_name = connection.ops.default_schema_name()
+    source_schema_name = "public"
 
     constraints = db_introspection.get_constraints(table_name)
 
@@ -290,6 +298,7 @@
             name
             for name, constraint in constraints.items()
             if constraint["foreign_key"]
+            == (fake_model_fk_target_1._meta.db_table, "id")
         ),
         None,
     )
@@ -297,7 +306,7 @@ def test_schema_editor_clone_model_to_schema_custom_constraint_names(
         (
             name
            for name, constraint in constraints.items()
-            if constraint["check"]
+            if constraint["check"] and constraint["columns"] == ["age"]
         ),
         None,
     )
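Note: the body of `_list_lock_modes_in_schema` is not part of this diff. A plausible sketch of such a helper (an assumption; the real implementation may differ) queries `pg_locks` for the relations of the schema:

    from typing import Set

    from django.db import connection


    def _list_lock_modes_in_schema(schema_name: str) -> Set[str]:
        # Lock modes currently held on relations in the given schema;
        # the cloning test asserts that AccessExclusiveLock never
        # shows up for the source table's schema.
        with connection.cursor() as cursor:
            cursor.execute(
                """
                SELECT DISTINCT l.mode
                FROM pg_locks l
                JOIN pg_class c ON c.oid = l.relation
                JOIN pg_namespace n ON n.oid = c.relnamespace
                WHERE n.nspname = %s
                """,
                (schema_name,),
            )
            return {mode for (mode,) in cursor.fetchall()}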
diff --git a/tests/test_settings.py b/tests/test_settings.py
new file mode 100644
index 00000000..44519714
--- /dev/null
+++ b/tests/test_settings.py
@@ -0,0 +1,93 @@
+import pytest
+
+from django.core.exceptions import SuspiciousOperation
+from django.db import connection
+
+from psqlextra.settings import (
+    postgres_prepend_local_search_path,
+    postgres_reset_local_search_path,
+    postgres_set_local,
+    postgres_set_local_search_path,
+)
+
+
+def _get_current_setting(name: str) -> str:
+    with connection.cursor() as cursor:
+        cursor.execute(f"SHOW {name}")
+        return cursor.fetchone()[0]
+
+
+@postgres_set_local(statement_timeout="2s", lock_timeout="3s")
+def test_postgres_set_local_function_decorator():
+    assert _get_current_setting("statement_timeout") == "2s"
+    assert _get_current_setting("lock_timeout") == "3s"
+
+
+def test_postgres_set_local_context_manager():
+    with postgres_set_local(statement_timeout="2s"):
+        assert _get_current_setting("statement_timeout") == "2s"
+
+    assert _get_current_setting("statement_timeout") == "0"
+
+
+def test_postgres_set_local_iterable():
+    with postgres_set_local(search_path=["a", "public"]):
+        assert _get_current_setting("search_path") == "a, public"
+
+    assert _get_current_setting("search_path") == '"$user", public'
+
+
+def test_postgres_set_local_nested():
+    with postgres_set_local(statement_timeout="2s"):
+        assert _get_current_setting("statement_timeout") == "2s"
+
+        with postgres_set_local(statement_timeout="3s"):
+            assert _get_current_setting("statement_timeout") == "3s"
+
+        assert _get_current_setting("statement_timeout") == "2s"
+
+    assert _get_current_setting("statement_timeout") == "0"
+
+
+@pytest.mark.django_db(transaction=True)
+def test_postgres_set_local_no_transaction():
+    with pytest.raises(SuspiciousOperation):
+        with postgres_set_local(statement_timeout="2s"):
+            pass
+
+
+def test_postgres_set_local_search_path():
+    with postgres_set_local_search_path(["a", "public"]):
+        assert _get_current_setting("search_path") == "a, public"
+
+    assert _get_current_setting("search_path") == '"$user", public'
+
+
+def test_postgres_reset_local_search_path():
+    with postgres_set_local_search_path(["a", "public"]):
+        with postgres_reset_local_search_path():
+            assert _get_current_setting("search_path") == '"$user", public'
+
+        assert _get_current_setting("search_path") == "a, public"
+
+    assert _get_current_setting("search_path") == '"$user", public'
+
+
+def test_postgres_prepend_local_search_path():
+    with postgres_prepend_local_search_path(["a", "b"]):
+        assert _get_current_setting("search_path") == 'a, b, "$user", public'
+
+    assert _get_current_setting("search_path") == '"$user", public'
+
+
+def test_postgres_prepend_local_search_path_nested():
+    with postgres_prepend_local_search_path(["a", "b"]):
+        with postgres_prepend_local_search_path(["c"]):
+            assert (
+                _get_current_setting("search_path")
+                == 'c, a, b, "$user", public'
+            )
+
+        assert _get_current_setting("search_path") == 'a, b, "$user", public'
+
+    assert _get_current_setting("search_path") == '"$user", public'