Skip to content

Commit

Permalink
Remove connections scoped to schema
Browse files Browse the repository at this point in the history
It's too tricky to get this right. It'll lead to surprises because:


1. It breaks with transaction pooling.
2. Interaction with `SET LOCAL` is strange. A `SET` command after
a `SET LOCAL` overrides it.

I already shot myself in the foot twice since implementing this.
  • Loading branch information
Photonios committed Apr 10, 2023
1 parent 9d6fedf commit 57d95b1
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 169 deletions.
33 changes: 0 additions & 33 deletions docs/source/schemas.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,36 +153,3 @@ The ``public`` schema cannot be dropped. This is a Postgres built-in and it is a
schema = PostgresSchema.drop("myprefix")
schema = PostgresSchema.drop("myprefix", cascade=True)
Executing queries within a schema
---------------------------------
By default, a connection operates in the ``public`` schema. The schema offers a connection scoped to that schema that sets the Postgres ``search_path`` to only search within that schema.
.. warning::
This can be abused to manage Django models in a custom schema. This is not a supported workflow and there might be unexpected issues from attempting to do so.
.. warning::
Do not use this in the following scenarios:
1. You access the connection from multiple threads. Scoped connections are **NOT** thread safe.
2. The underlying database connection is passed through a connection pooler in transaction pooling mode.
.. code-block:: python
from psqlextra.schema import PostgresSchema
schema = PostgresSchema.create("myschema")
with schema.connection.cursor() as cursor:
# table gets created within the `myschema` schema, without
# explicitly specifying the schema name
cursor.execute("CREATE TABLE mytable AS SELECT 'hello'")
with schema.connection.schema_editor() as schema_editor:
# creates a table for the model within the schema
schema_editor.create_model(MyModel)
46 changes: 0 additions & 46 deletions psqlextra/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,43 +2,11 @@

from contextlib import contextmanager

import wrapt

from django.core.exceptions import SuspiciousOperation, ValidationError
from django.db import DEFAULT_DB_ALIAS, connections, transaction
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.utils import CursorWrapper
from django.utils import timezone


class PostgresSchemaConnectionWrapper(wrapt.ObjectProxy):
"""Wraps a Django database connection and ensures that each cursor operates
within the specified schema."""

def __init__(self, connection, schema) -> None:
super().__init__(connection)

self._self_schema = schema

@contextmanager
def schema_editor(self):
with self.__wrapped__.schema_editor() as schema_editor:
schema_editor.connection = self
yield schema_editor

@contextmanager
def cursor(self) -> CursorWrapper:
schema = self._self_schema

with self.__wrapped__.cursor() as cursor:
quoted_name = self.ops.quote_name(schema.name)
cursor.execute(f"SET search_path = {quoted_name}")
try:
yield cursor
finally:
cursor.execute("SET search_path TO DEFAULT")


class PostgresSchema:
"""Represents a Postgres schema.
Expand Down Expand Up @@ -191,20 +159,6 @@ def delete(self, *, cascade: bool = False) -> None:
with connections[self.using].schema_editor() as schema_editor:
schema_editor.delete_schema(self.name, cascade=cascade)

@property
def connection(self) -> BaseDatabaseWrapper:
"""Obtains a database connection scoped to this schema.
Do not use this in the following scenarios:
1. You access the connection from multiple threads. Scoped
connections are NOT thread safe.
2. The underlying database connection is passed through a
connection pooler in transaction pooling mode.
"""

return PostgresSchemaConnectionWrapper(connections[self.using], self)

@classmethod
def _verify_generated_name_length(cls, prefix: str, suffix: str) -> None:
max_prefix_length = cls.NAME_MAX_LENGTH - len(suffix)
Expand Down
101 changes: 11 additions & 90 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import uuid

import freezegun
import pytest
Expand Down Expand Up @@ -136,7 +135,7 @@ def test_postgres_schema_delete_not_empty():
schema = PostgresSchema.create("test")
assert _does_schema_exist(schema.name)

with schema.connection.cursor() as cursor:
with connection.cursor() as cursor:
cursor.execute("CREATE TABLE test.bla AS SELECT 'hello'")

with pytest.raises(InternalError) as exc_info:
Expand All @@ -150,96 +149,14 @@ def test_postgres_schema_delete_cascade_not_empty():
schema = PostgresSchema.create("test")
assert _does_schema_exist(schema.name)

with schema.connection.cursor() as cursor:
with connection.cursor() as cursor:
cursor.execute("CREATE TABLE test.bla AS SELECT 'hello'")

schema.delete(cascade=True)
assert not _does_schema_exist(schema.name)


def test_postgres_schema_connection():
schema = PostgresSchema.create("test")

with schema.connection.cursor() as cursor:
# Creating a table without specifying the schema should create
# it in our schema and we should be able to select from it without
# specifying the schema.
cursor.execute("CREATE TABLE myschematable AS SELECT 'myschema'")
cursor.execute("SELECT * FROM myschematable")
assert cursor.fetchone() == ("myschema",)

# Proof that the table was created in our schema even though we
# never explicitly told it to do so.
cursor.execute(
"SELECT table_schema FROM information_schema.tables WHERE table_name = %s",
("myschematable",),
)
assert cursor.fetchone() == (schema.name,)

# Creating a table in another schema, we should not be able
# to select it without specifying the schema since our
# schema scoped connection only looks at our schema by default.
cursor.execute(
"CREATE TABLE public.otherschematable AS SELECT 'otherschema'"
)
with pytest.raises(ProgrammingError) as exc_info:
cursor.execute("SELECT * FROM otherschematable")

cursor.execute("ROLLBACK")

pg_error = extract_postgres_error(exc_info.value)
assert pg_error.pgcode == errorcodes.UNDEFINED_TABLE


def test_postgres_schema_connection_does_not_affect_default():
schema = PostgresSchema.create("test")

with schema.connection.cursor() as cursor:
cursor.execute("SHOW search_path")
assert cursor.fetchone() == ("test",)

with connection.cursor() as cursor:
cursor.execute("SHOW search_path")
assert cursor.fetchone() == ('"$user", public',)


@pytest.mark.django_db(transaction=True)
def test_postgres_schema_connection_does_not_affect_default_after_throw():
schema = PostgresSchema.create(str(uuid.uuid4()))

with pytest.raises(ProgrammingError):
with schema.connection.cursor() as cursor:
cursor.execute("COMMIT")
cursor.execute("SELECT frombadtable")

with connection.cursor() as cursor:
cursor.execute("ROLLBACK")
cursor.execute("SHOW search_path")
assert cursor.fetchone() == ('"$user", public',)


def test_postgres_schema_connection_schema_editor():
schema = PostgresSchema.create("test")

with schema.connection.schema_editor() as schema_editor:
with schema_editor.connection.cursor() as cursor:
cursor.execute("SHOW search_path")
assert cursor.fetchone() == ("test",)

with connection.cursor() as cursor:
cursor.execute("SHOW search_path")
assert cursor.fetchone() == ('"$user", public',)


def test_postgres_schema_connection_does_not_catch():
schema = PostgresSchema.create("test")

with pytest.raises(ValueError):
with schema.connection.cursor():
raise ValueError("test")


def test_postgres_schema_connection_no_delete_default():
def test_postgres_schema_no_delete_default():
with pytest.raises(SuspiciousOperation):
PostgresSchema.default.delete()

Expand All @@ -261,17 +178,21 @@ def test_postgres_temporary_schema():
def test_postgres_temporary_schema_not_empty():
with pytest.raises(InternalError) as exc_info:
with postgres_temporary_schema("temp") as schema:
with schema.connection.cursor() as cursor:
cursor.execute("CREATE TABLE mytable AS SELECT 'hello world'")
with connection.cursor() as cursor:
cursor.execute(
f"CREATE TABLE {schema.name}.mytable AS SELECT 'hello world'"
)

pg_error = extract_postgres_error(exc_info.value)
assert pg_error.pgcode == errorcodes.DEPENDENT_OBJECTS_STILL_EXIST


def test_postgres_temporary_schema_not_empty_cascade():
with postgres_temporary_schema("temp", cascade=True) as schema:
with schema.connection.cursor() as cursor:
cursor.execute("CREATE TABLE mytable AS SELECT 'hello world'")
with connection.cursor() as cursor:
cursor.execute(
f"CREATE TABLE {schema.name}.mytable AS SELECT 'hello world'"
)

assert not _does_schema_exist(schema.name)

Expand Down

0 comments on commit 57d95b1

Please sign in to comment.