Add support to schema editor for moving tables between schemas
Photonios committed Apr 11, 2023
1 parent e1a43cd commit 6eff3f1
Showing 8 changed files with 361 additions and 80 deletions.
30 changes: 1 addition & 29 deletions psqlextra/backend/introspection.py
@@ -1,9 +1,6 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from django.db import transaction

from psqlextra.types import PostgresPartitioningMethod

from . import base_impl
@@ -250,6 +247,7 @@ def get_storage_settings(self, cursor, table_name: str) -> Dict[str, str]:
pg_catalog.pg_am am ON (c.relam = am.oid)
WHERE
c.relname::text = %s
AND pg_catalog.pg_table_is_visible(c.oid)
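        -- i.e. only consider tables visible on the current search_path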
"""

cursor.execute(sql, (table_name,))
@@ -288,29 +286,3 @@ def get_relations(self, cursor, table_name: str):
[table_name],
)
return {row[0]: (row[2], row[1]) for row in cursor.fetchall()}

@contextmanager
def in_search_path(self, search_path: List[str]):
"""Changes the Postgres `search_path` within the context and switches
it back when it exits."""

# Wrap in a transaction so a savepoint is created. If
# something goes wrong, the `SET LOCAL search_path`
# statement will be rolled back.
with transaction.atomic(using=self.connection.alias):
with self.connection.cursor() as cursor:
cursor.execute("SHOW search_path")
(original_search_path,) = cursor.fetchone()

# Syntax in Postgres is a bit weird here. It isn't really
# a list of names like in `WHERE bla in (val1, val2)`.
placeholder = ", ".join(["%s" for _ in search_path])
cursor.execute(
f"SET LOCAL search_path = {placeholder}", search_path
)

yield self

cursor.execute(
f"SET LOCAL search_path = {original_search_path}"
)
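
The removed `in_search_path` helper is superseded by the search-path context managers added in `psqlextra/settings.py` further down in this commit. A minimal sketch of the replacement pattern, assuming a hypothetical `tenant` schema and the default connection (`SET LOCAL` only has an effect inside a transaction, hence the `atomic()` block):

from django.db import connection, transaction

from psqlextra.settings import postgres_prepend_local_search_path

with transaction.atomic():
    # Previously: connection.introspection.in_search_path(["tenant", "public"])
    with postgres_prepend_local_search_path(["tenant"], using=connection.alias):
        with connection.cursor() as cursor:
            cursor.execute("SHOW search_path")
            print(cursor.fetchone())  # e.g. ('tenant, "$user", public',)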
3 changes: 0 additions & 3 deletions psqlextra/backend/operations.py
@@ -21,6 +21,3 @@ class PostgresOperations(base_impl.operations()):
SQLUpdateCompiler,
SQLInsertCompiler,
]

def default_schema_name(self) -> str:
return "public"
71 changes: 60 additions & 11 deletions psqlextra/backend/schema.py
@@ -12,6 +12,10 @@
from django.db.backends.ddl_references import Statement
from django.db.models import Field, Model

from psqlextra.settings import (
postgres_prepend_local_search_path,
postgres_reset_local_search_path,
)
from psqlextra.type_assertions import is_sql_with_params
from psqlextra.types import PostgresPartitioningMethod

@@ -40,6 +44,8 @@ class PostgresSchemaEditor(SchemaEditor):
sql_alter_table_storage_setting = "ALTER TABLE %s SET (%s = %s)"
sql_reset_table_storage_setting = "ALTER TABLE %s RESET (%s)"

sql_alter_table_schema = "ALTER TABLE %s SET SCHEMA %s"

sql_create_view = "CREATE VIEW %s AS (%s)"
sql_replace_view = "CREATE OR REPLACE VIEW %s AS (%s)"
sql_drop_view = "DROP VIEW IF EXISTS %s"
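
For reference, a rough illustration of the statement the new `sql_alter_table_schema` template renders once the schema editor has quoted both identifiers (the names below are made up):

sql_alter_table_schema = "ALTER TABLE %s SET SCHEMA %s"
print(sql_alter_table_schema % ('"mytable"', '"target"'))
# ALTER TABLE "mytable" SET SCHEMA "target"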
@@ -203,8 +209,8 @@ def clone_model_constraints_and_indexes_to_schema(
resides.
"""

with self.introspection.in_search_path(
[schema_name, self.connection.ops.default_schema_name()]
with postgres_prepend_local_search_path(
[schema_name], using=self.connection.alias
):
for constraint in model._meta.constraints:
self.add_constraint(model, constraint)
@@ -226,8 +232,8 @@ def clone_model_constraints_and_indexes_to_schema(
            # Primary keys that are added to the model later are created by
            # Django under a custom name. We want the name as it was created originally.
if field.primary_key:
with self.introspection.in_search_path(
[self.connection.ops.default_schema_name()]
with postgres_reset_local_search_path(
using=self.connection.alias
):
[primary_key_name] = self._constraint_names(
model, primary_key=True
@@ -251,8 +257,8 @@ def clone_model_constraints_and_indexes_to_schema(
# a separate transaction later to validate the entries without
            # acquiring an AccessExclusiveLock.
if field.remote_field:
with self.introspection.in_search_path(
[self.connection.ops.default_schema_name()]
with postgres_reset_local_search_path(
using=self.connection.alias
):
[fk_name] = self._constraint_names(
model, [field.column], foreign_key=True
@@ -277,8 +283,8 @@ def clone_model_constraints_and_indexes_to_schema(
# manually.
field_check = field.db_parameters(self.connection).get("check")
if field_check:
with self.introspection.in_search_path(
[self.connection.ops.default_schema_name()]
with postgres_reset_local_search_path(
using=self.connection.alias
):
[field_check_name] = self._constraint_names(
model,
@@ -337,10 +343,12 @@ def clone_model_foreign_keys_to_schema(
resides.
"""

with self.introspection.in_search_path(
[schema_name, self.connection.ops.default_schema_name()]
constraint_names = self._constraint_names(model, foreign_key=True)

with postgres_prepend_local_search_path(
[schema_name], using=self.connection.alias
):
for fk_name in self._constraint_names(model, foreign_key=True):
for fk_name in constraint_names:
self.execute(
self.sql_validate_fk
% (
@@ -438,6 +446,47 @@ def reset_model_storage_setting(

        self.reset_table_storage_setting(model._meta.db_table, name)

    def alter_table_schema(self, table_name: str, schema_name: str) -> None:
        """Moves the specified table into the specified schema.

        WARNING: Moving models into a different schema than the default
        will break querying the model.

        Arguments:
            table_name:
                Name of the table to move into the specified schema.

            schema_name:
                Name of the schema to move the table to.
        """

        self.execute(
            self.sql_alter_table_schema
            % (self.quote_name(table_name), self.quote_name(schema_name))
        )

    def alter_model_schema(self, model: Type[Model], schema_name: str) -> None:
        """Moves the specified model's table into the specified schema.

        WARNING: Moving models into a different schema than the default
        will break querying the model.

        Arguments:
            model:
                Model of which to move the table.

            schema_name:
                Name of the schema to move the model's table to.
        """

        self.execute(
            self.sql_alter_table_schema
            % (
                self.quote_name(model._meta.db_table),
                self.quote_name(schema_name),
            )
        )

    def refresh_materialized_view_model(
        self, model: Type[Model], concurrently: bool = False
    ) -> None:
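
A minimal usage sketch of the two new methods, mirroring the tests added below; the `target` schema and the model are illustrative and must already exist:

from django.db import connection

from psqlextra.backend.schema import PostgresSchemaEditor

schema_editor = PostgresSchemaEditor(connection)

# Move a raw table, or a model's table, into the (pre-existing) "target" schema.
schema_editor.alter_table_schema("mytable", "target")
schema_editor.alter_model_schema(MyModel, "target")  # MyModel: any Django model

As the docstrings warn, a model moved out of the default schema can no longer be queried through the ORM without additional search_path handling.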
118 changes: 118 additions & 0 deletions psqlextra/settings.py
@@ -0,0 +1,118 @@
from contextlib import contextmanager
from typing import Dict, List, Optional, Union

from django.core.exceptions import SuspiciousOperation
from django.db import DEFAULT_DB_ALIAS, connections


@contextmanager
def postgres_set_local(
    *,
    using: str = DEFAULT_DB_ALIAS,
    **options: Dict[str, Optional[Union[str, int, float, List[str]]]],
) -> None:
    """Sets the specified PostgreSQL options using SET LOCAL so that they apply
    to the current transaction only.

    The effect is undone when the context manager exits.

    See https://www.postgresql.org/docs/current/runtime-config-client.html
    for an overview of all available options.
    """

    connection = connections[using]
    qn = connection.ops.quote_name

    if not connection.in_atomic_block:
        raise SuspiciousOperation(
            "SET LOCAL makes no sense outside a transaction. Start a transaction first."
        )

    sql = []
    params = []
    for name, value in options.items():
        if value is None:
            sql.append(f"SET LOCAL {qn(name)} TO DEFAULT")
            continue

        # Settings that accept a list of values are actually
        # stored as string lists. We cannot just pass a list
        # of values. We have to create the comma separated
        # string ourselves.
        if isinstance(value, list) or isinstance(value, tuple):
            placeholder = ", ".join(["%s" for _ in value])
            params.extend(value)
        else:
            placeholder = "%s"
            params.append(value)

        sql.append(f"SET LOCAL {qn(name)} = {placeholder}")

    with connection.cursor() as cursor:
        cursor.execute(
            "SELECT name, setting FROM pg_settings WHERE name = ANY(%s)",
            (list(options.keys()),),
        )
        original_values = dict(cursor.fetchall())
        cursor.execute("; ".join(sql), params)

    yield

    # Put everything back to how it was. DEFAULT is
    # not good enough as an outer SET LOCAL might
    # have set a different value.
    with connection.cursor() as cursor:
        sql = []
        params = []

        for name, value in options.items():
            original_value = original_values.get(name)
            if original_value:
                sql.append(f"SET LOCAL {qn(name)} = {original_value}")
            else:
                sql.append(f"SET LOCAL {qn(name)} TO DEFAULT")

        cursor.execute("; ".join(sql), params)

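A short usage sketch for `postgres_set_local`; the option values are illustrative, and the call must run inside an atomic block or it raises `SuspiciousOperation`:

from django.db import transaction

from psqlextra.settings import postgres_set_local

with transaction.atomic():
    with postgres_set_local(statement_timeout="5s", search_path=["public"]):
        ...  # statements here see the overridden settings
    # on exit the previous values are restored, again via SET LOCAL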

@contextmanager
def postgres_set_local_search_path(
    search_path: List[str], *, using: str = DEFAULT_DB_ALIAS
) -> None:
    """Sets the search path to the specified schemas."""

    with postgres_set_local(search_path=search_path, using=using):
        yield


@contextmanager
def postgres_prepend_local_search_path(
    search_path: List[str], *, using: str = DEFAULT_DB_ALIAS
) -> None:
    """Prepends the current local search path with the specified schemas."""

    connection = connections[using]

    with connection.cursor() as cursor:
        cursor.execute("SHOW search_path")
        [
            original_search_path,
        ] = cursor.fetchone()

        placeholders = ", ".join(["%s" for _ in search_path])
        cursor.execute(
            f"SET LOCAL search_path = {placeholders}, {original_search_path}",
            tuple(search_path),
        )

        yield

        cursor.execute(f"SET LOCAL search_path = {original_search_path}")


@contextmanager
def postgres_reset_local_search_path(*, using: str = DEFAULT_DB_ALIAS) -> None:
"""Resets the local search path to the default."""

with postgres_set_local(search_path=None, using=using):
yield
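
The clone-to-schema routines in `psqlextra/backend/schema.py` combine these helpers: DDL runs with the target schema prepended to the search path, while constraint-name lookups temporarily reset it to the default. A rough sketch of that nesting, with a made-up schema name:

from django.db import transaction

from psqlextra.settings import (
    postgres_prepend_local_search_path,
    postgres_reset_local_search_path,
)

with transaction.atomic():
    with postgres_prepend_local_search_path(["tenant_clone"]):
        # names here resolve against "tenant_clone" before the default path
        with postgres_reset_local_search_path():
            # names here resolve against the default search_path again
            ...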
9 changes: 4 additions & 5 deletions tests/db_introspection.py
@@ -9,15 +9,14 @@

from django.db import connection

from psqlextra.settings import postgres_set_local


@contextmanager
def introspect(schema_name: Optional[str] = None):
    default_schema_name = connection.ops.default_schema_name()
    search_path = [schema_name or default_schema_name]

    with connection.introspection.in_search_path(search_path) as introspection:
    with postgres_set_local(search_path=schema_name or None):
        with connection.cursor() as cursor:
            yield introspection, cursor
            yield connection.introspection, cursor


def table_names(
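
For completeness, a hypothetical use of the reworked `introspect` helper in a test; the schema name is illustrative, and the block must run inside a transaction (e.g. a Django TestCase):

with introspect("tenant") as (introspection, cursor):
    print(introspection.get_table_list(cursor))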
44 changes: 44 additions & 0 deletions tests/test_schema_editor_alter_schema.py
@@ -0,0 +1,44 @@
import pytest

from django.db import connection, models

from psqlextra.backend.schema import PostgresSchemaEditor

from .fake_model import get_fake_model


@pytest.fixture
def fake_model():
    return get_fake_model(
        {
            "text": models.TextField(),
        }
    )


def test_schema_editor_alter_table_schema(fake_model):
    obj = fake_model.objects.create(text="hello")

    with connection.cursor() as cursor:
        cursor.execute("CREATE SCHEMA target")

    schema_editor = PostgresSchemaEditor(connection)
    schema_editor.alter_table_schema(fake_model._meta.db_table, "target")

    with connection.cursor() as cursor:
        cursor.execute(f"SELECT * FROM target.{fake_model._meta.db_table}")
        assert cursor.fetchall() == [(obj.id, obj.text)]


def test_schema_editor_alter_model_schema(fake_model):
    obj = fake_model.objects.create(text="hello")

    with connection.cursor() as cursor:
        cursor.execute("CREATE SCHEMA target")

    schema_editor = PostgresSchemaEditor(connection)
    schema_editor.alter_model_schema(fake_model, "target")

    with connection.cursor() as cursor:
        cursor.execute(f"SELECT * FROM target.{fake_model._meta.db_table}")
        assert cursor.fetchall() == [(obj.id, obj.text)]
