Skip to content

Commit

Permalink
fix: Conversion from view to materialized view.
Browse files Browse the repository at this point in the history
  • Loading branch information
DanCardin committed Dec 16, 2022
1 parent 1b5da69 commit 6f43871
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 63 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
${{ runner.os }}-poetry-
- name: Install dependencies
run: poetry install
run: poetry install -E parse

- name: Install specific sqlalchemy version
run: |
Expand Down
39 changes: 16 additions & 23 deletions src/sqlalchemy_declarative_extensions/view/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from sqlalchemy_declarative_extensions.sql import qualify_name
from sqlalchemy_declarative_extensions.sqlalchemy import HasMetaData

T = TypeVar("T", HasMetaData, MetaData)
T = TypeVar("T", bound=HasMetaData)


def view(base_or_metadata: T, materialized: bool = False) -> Callable[[type], T]:
def view(base: T, materialized: bool = False) -> Callable[[type], T]:
"""Decorate a class or declarative base model in order to register a View.
Given some object with the attributes: `__tablename__`, (optionally for schema) `__table_args__`,
Expand Down Expand Up @@ -48,7 +48,7 @@ def decorator(cls):
table_args = getattr(cls, "__table_args__", None)
view_def = cls.__view__

mapper = instrument_sqlalchemy(base_or_metadata, cls)
mapper = instrument_sqlalchemy(base, cls)

schema = find_schema(table_args)
constraints = find_constraints(table_args)
Expand All @@ -60,17 +60,15 @@ def decorator(cls):
constraints=constraints,
)

register_view(base_or_metadata, instance)
register_view(base, instance)

return mapper # noqa

return decorator


def instrument_sqlalchemy(base_or_metadata: T, cls) -> T:
metadata = get_metadata(base_or_metadata)

temp_metadata = MetaData(naming_convention=metadata.naming_convention)
def instrument_sqlalchemy(base: T, cls) -> T:
temp_metadata = MetaData(naming_convention=base.metadata.naming_convention)
try:
try:
from sqlalchemy import orm
Expand All @@ -87,15 +85,18 @@ def instrument_sqlalchemy(base_or_metadata: T, cls) -> T:
return mapper


def register_view(base_or_metadata: HasMetaData, view: View):
def register_view(base_or_metadata: HasMetaData | MetaData, view: View):
"""Register a view onto the given declarative base or `Metadata`.
This can be used instead of the [view](view) decorator, if you are constructing
`View` objects directly. In this way, you can imperitively register views next
to their corresponding table definitions, rather than at the root declarative
base, like many of the other object types are documented to do.
"""
metadata = get_metadata(base_or_metadata)
if isinstance(base_or_metadata, MetaData):
metadata = base_or_metadata
else:
metadata = base_or_metadata.metadata

if not metadata.info.get("views"):
metadata.info["views"] = Views()
Expand Down Expand Up @@ -128,7 +129,7 @@ def coerce_from_unknown(cls, unknown: Any) -> View:

try:
import alembic_utils # noqa
except ImportError:
except ImportError: # pragma: no cover
pass
else:
from alembic_utils.pg_materialized_view import PGMaterializedView
Expand All @@ -155,7 +156,7 @@ def render_definition(self, dialect: Dialect):
try:
import sqlglot
from sqlglot.optimizer.normalize import normalize
except ImportError:
except ImportError: # pragma: no cover
raise ImportError("View autogeneration requires the 'parse' extra.")

if isinstance(self.definition, str):
Expand Down Expand Up @@ -265,24 +266,22 @@ def append(self, view: View):
self.views.append(view)

def __iter__(self):
for grant in self.grants:
yield grant
for view in self.views:
yield view

def are(self, *views: View):
return replace(self, views=[View.coerce_from_unknown(v) for v in views])


def find_schema(table_args=None):
if table_args is None:
return None

if isinstance(table_args, dict):
return table_args.get("schema")

if isinstance(table_args, Iterable):
for table_arg in table_args:
if isinstance(table_arg, dict):
return table_arg.get("schema")

return None


Expand All @@ -294,9 +293,3 @@ def find_constraints(table_args=None):
return [table_arg for table_arg in table_args if isinstance(table_arg, Index)]

return None


def get_metadata(base_or_metadata: T) -> MetaData:
if isinstance(base_or_metadata, MetaData):
return base_or_metadata
return base_or_metadata.metadata
24 changes: 2 additions & 22 deletions src/sqlalchemy_declarative_extensions/view/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,6 @@
class CreateViewOp:
view: View

@classmethod
def create_view(
cls,
operations,
view_name: str,
definition: str,
*,
schema: str | None = None,
materialized: bool = False,
):
op = cls(View(view_name, definition, schema=schema, materialized=materialized))
return operations.invoke(op)

def reverse(self):
return DropViewOp(self.view)

Expand All @@ -37,11 +24,6 @@ def to_sql(self, dialect: Dialect) -> list[str]:
class DropViewOp:
view: View

@classmethod
def drop_view(cls, operations, view_name: str, schema: str | None = None):
op = cls(View(view_name, definition="", schema=schema))
return operations.invoke(op)

def reverse(self):
return CreateViewOp(self.view)

Expand All @@ -54,8 +36,6 @@ def to_sql(self, dialect: Dialect) -> list[str]:

def compare_views(connection: Connection, views: Views) -> list[Operation]:
result: list[Operation] = []
if not views:
return result

views_by_name = {r.qualified_name: r for r in views.views}
expected_view_names = set(views_by_name)
Expand All @@ -67,7 +47,7 @@ def compare_views(connection: Connection, views: Views) -> list[Operation]:
new_view_names = expected_view_names - existing_view_names
removed_view_names = existing_view_names - expected_view_names

for view in views.views:
for view in views:
view_name = view.qualified_name

if view_name in views.ignore_views:
Expand All @@ -83,7 +63,7 @@ def compare_views(connection: Connection, views: Views) -> list[Operation]:
view_updated = not existing_view.equals(view, connection.dialect)
if view_updated:
existing_view = existing_views_by_name[view_name]
result.append(DropViewOp(view))
result.append(DropViewOp(existing_view))
result.append(CreateViewOp(view))

if not views.ignore_unspecified:
Expand Down
4 changes: 2 additions & 2 deletions tests/examples/test_view_drop_pg/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from sqlalchemy import Column, types
from sqlalchemy.ext.declarative import declarative_base

from sqlalchemy_declarative_extensions import Row, Views, declarative_database
from sqlalchemy_declarative_extensions import Row, declarative_database

_Base = declarative_base()

Expand All @@ -17,7 +17,7 @@ class Base(_Base):
Row("foo", id=11),
Row("foo", id=12),
]
views = Views()
views = []


class Foo(Base):
Expand Down
17 changes: 3 additions & 14 deletions tests/examples/test_view_update_pg/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import sqlalchemy
from sqlalchemy import Column, select, types
from sqlalchemy import Column, types
from sqlalchemy.ext.declarative import declarative_base

from sqlalchemy_declarative_extensions import Row, Views, declarative_database, view
from sqlalchemy_declarative_extensions import Row, View, Views, declarative_database

_Base = declarative_base()

Expand All @@ -17,7 +17,7 @@ class Base(_Base):
Row("foo", id=11),
Row("foo", id=12),
]
views = Views()
views = Views().are(View("bar", "select id from foo where id > 10"))


class Foo(Base):
Expand All @@ -30,14 +30,3 @@ class Foo(Base):
server_default=sqlalchemy.text("CURRENT_TIMESTAMP"),
nullable=False,
)


foo_table = Foo.__table__


@view(Base.metadata)
class Bar:
__tablename__ = "bar"
__view__ = select(foo_table.c.id).where(foo_table.c.id > 10)

id = Column(types.Integer(), autoincrement=True, primary_key=True)
73 changes: 73 additions & 0 deletions tests/view/test_convert_to_materialized.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from pytest_mock_resources import create_postgres_fixture
from sqlalchemy import Column, text, types
from sqlalchemy.ext.declarative import declarative_base

from sqlalchemy_declarative_extensions import (
Row,
Rows,
Schemas,
View,
declarative_database,
register_sqlalchemy_events,
register_view,
)

Base_ = declarative_base()


@declarative_database
class Base(Base_):
__abstract__ = True

schemas = Schemas().are("fooschema")
rows = Rows().are(
Row("fooschema.foo", id=1),
Row("fooschema.foo", id=2),
Row("fooschema.foo", id=12),
Row("fooschema.foo", id=13),
)


class Foo(Base):
__tablename__ = "foo"
__table_args__ = {"schema": "fooschema"}

id = Column(types.Integer(), primary_key=True)


# Register imperitively
view = View(
"bar",
"select id from fooschema.foo where id < 10",
schema="fooschema",
materialized=True,
)

register_view(Base.metadata, view)


register_sqlalchemy_events(Base.metadata, schemas=True, views=True, rows=True)

pg = create_postgres_fixture(
scope="function", engine_kwargs={"echo": True}, session=True
)


def test_create_view_postgresql(pg):
pg.execute(text("CREATE SCHEMA fooschema"))
pg.execute(text("CREATE TABLE fooschema.foo (id integer)"))

pg.execute(
text(
"CREATE VIEW fooschema.bar AS (SELECT id FROM fooschema.foo WHERE id < 10)"
)
)

Base.metadata.create_all(bind=pg.connection())

result = [f.id for f in pg.query(Foo).all()]
assert result == [1, 2, 12, 13]

pg.execute(text("refresh materialized view fooschema.bar"))
result = [f.id for f in pg.execute(text("select * from fooschema.bar")).all()]
assert result == [1, 2]
46 changes: 46 additions & 0 deletions tests/view/test_ignore_unspecified.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from pytest_mock_resources import create_postgres_fixture
from sqlalchemy import Column, text, types
from sqlalchemy.ext.declarative import declarative_base

from sqlalchemy_declarative_extensions import (
Schemas,
Views,
declarative_database,
register_sqlalchemy_events,
)
from sqlalchemy_declarative_extensions.view.compare import compare_views

Base_ = declarative_base()


@declarative_database
class Base(Base_):
__abstract__ = True

schemas = Schemas().are("fooschema")
views = Views(ignore_unspecified=True)


class Foo(Base):
__tablename__ = "foo"
__table_args__ = {"schema": "fooschema"}

id = Column(types.Integer(), primary_key=True)


register_sqlalchemy_events(Base.metadata, schemas=True, views=True, rows=True)

pg = create_postgres_fixture(
scope="function", engine_kwargs={"echo": True}, session=True
)


def test_ignore_views(pg):
Base.metadata.create_all(bind=pg.connection())

pg.execute(text("CREATE VIEW meow as (SELECT id from fooschema.foo)"))

# Verify this no longer sees changes to make! Failing here would imply the autogenerate
# is not fully normalizing the difference.
result = compare_views(pg.connection(), views=Base.views)
assert result == []

0 comments on commit 6f43871

Please sign in to comment.