Skip to content

Commit

Permalink
Merge pull request #11 from DanCardin/dc/fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
DanCardin committed Dec 7, 2022
2 parents c329b9e + 912e3d9 commit 4d822ca
Show file tree
Hide file tree
Showing 16 changed files with 307 additions and 38 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.9"]
# Test our minimum version bound, the highest version available,
# and something in the middle (i.e. what gets run locally).
python-version: ["3.7", "3.9", "3.11"]
sqlalchemy-version: ["1.3", "1.4"]

name: Python ${{ matrix.python-version }} Tests
Expand Down
13 changes: 10 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ The primary function(s) of this library include:
- (Optionally) Registers into Alembic such that `alembic revision --autogenerate`
automatically creates/updates/deletes declared objects.

## Example Usage
## Kitchen Sink Example Usage

```python
from sqlalchemy import Column, types
from sqlalchemy import Column, types, select
from sqlalchemy.orm import as_declarative
from sqlalchemy_declarative_extensions import (
declarative_database, Schemas, Roles, Grants, Rows, Row
declarative_database, Schemas, Roles, Grants, Rows, Row, Views, View, view
)
from sqlalchemy_declarative_extensions.dialects.postgresql import DefaultGrant, Role

Expand All @@ -52,12 +52,19 @@ class Base:
rows = Rows().are(
Row('foo', id=1),
)
views = Views().are(View("low_foo", "select * from foo where i < 10"))


class Foo(Base):
__tablename__ = 'foo'

id = Column(types.Integer(), primary_key=True)


@view()
class HighFoo:
__tablename__ = "high_foo"
__view__ = select(Foo.__table__).where(Foo.__table__.c.id >= 10)
```

Note, there is also support for declaring objects directly through the `MetaData` for
Expand Down
59 changes: 55 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sqlalchemy-declarative-extensions"
version = "0.3.1"
version = "0.3.2"
description = "Library to declare additional kinds of objects not natively supported by SqlAlchemy/Alembic."

authors = ["Dan Cardin <ddcardin@gmail.com>"]
Expand Down Expand Up @@ -43,6 +43,7 @@ alembic = { version = ">=1.0", optional = true }
psycopg2-binary = "*"

[tool.poetry.group.dev.dependencies]
alembic-utils = "*"
black = ">=22.3.0"
coverage = ">=5"
ruff = ">=0.0.165"
Expand Down
6 changes: 3 additions & 3 deletions src/sqlalchemy_declarative_extensions/alembic/grant.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ def compare_grants(autogen_context: AutogenContext, upgrade_ops: UpgradeOps, _):


@renderers.dispatch_for(GrantPrivilegesOp)
def render_grant(_, op: RevokePrivilegesOp):
return f'op.execute(sa.text("""{op.grant.to_sql()}"""))'
def render_grant(_, op: GrantPrivilegesOp):
return f'op.execute(sa.text("""{op.to_sql()}"""))'


@renderers.dispatch_for(RevokePrivilegesOp)
def render_revoke(_, op: RevokePrivilegesOp):
return f'op.execute(sa.text("""{op.grant.to_sql()}"""))'
return f'op.execute(sa.text("""{op.to_sql()}"""))'
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def _str_to_kind(cls):
"S": cls.sequence,
"r": cls.table,
"T": cls.type,
"v": cls.table,
}

@classmethod
Expand Down
16 changes: 10 additions & 6 deletions src/sqlalchemy_declarative_extensions/dialects/postgresql/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,20 @@ def to_sql_create(self) -> str:
def to_sql_update(self, to_role: Role) -> list[str]:
role_name = to_role.name
diff = RoleDiff.diff(self, to_role)
segments = ["ALTER ROLE", role_name, "WITH"]
segments.extend(postgres_render_role_options(diff))

alter = " ".join(segments) + ";"
result = [alter]
result = []

diff_options = postgres_render_role_options(diff)
if diff_options:
segments = ["ALTER ROLE", role_name, "WITH", *diff_options]
alter_role = " ".join(segments) + ";"
result.append(alter_role)

for add_name in diff.add_roles:
result.append(f"GRANT {add_name} to {role_name};")
result.append(f"GRANT {add_name} TO {role_name};")

for remove_name in diff.remove_roles:
result.append(f"GRANT {remove_name} to {role_name};")
result.append(f"REVOKE {remove_name} FROM {role_name};")

return result

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _schema_not_pg(column=pg_namespace.c.nspname):
pg_authid, pg_class.c.relowner == pg_authid.c.oid
)
)
.where(pg_class.c.relkind.in_(["r", "S", "f", "n", "T"]))
.where(pg_class.c.relkind.in_(["r", "S", "f", "n", "T", "v"]))
.where(_table_not_pg)
.where(_schema_not_pg()),
select(
Expand All @@ -166,7 +166,7 @@ def _schema_not_pg(column=pg_namespace.c.nspname):
.select_from(
pg_class.join(pg_namespace, pg_class.c.relnamespace == pg_namespace.c.oid)
)
.where(pg_class.c.relkind.in_(["r", "S", "f", "n", "T"]))
.where(pg_class.c.relkind.in_(["r", "S", "f", "n", "T", "v"]))
.where(_table_not_pg)
.where(_schema_not_pg())
)
Expand Down
24 changes: 14 additions & 10 deletions src/sqlalchemy_declarative_extensions/grant/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,21 @@ class GrantPrivilegesOp:
grant: DefaultGrantStatement | GrantStatement

def reverse(self):
return RevokePrivilegesOp(self.grant.invert())
return RevokePrivilegesOp(self.grant)

def to_sql(self):
return self.grant.to_sql()


@dataclass
class RevokePrivilegesOp:
grant: DefaultGrantStatement | GrantStatement

def reverse(self):
return GrantPrivilegesOp(self.grant.invert())
return GrantPrivilegesOp(self.grant)

def to_sql(self):
return self.grant.invert().to_sql()


Operation = Union[GrantPrivilegesOp, RevokePrivilegesOp]
Expand Down Expand Up @@ -82,12 +88,11 @@ def compare_default_grants(
extra_grants = set(existing_default_grants) - set(expected_grants)

if not grants.ignore_unspecified:
revoke_statements = [extra_grant.invert() for extra_grant in extra_grants]
for revoke in DefaultGrantStatement.combine(revoke_statements):
result.append(RevokePrivilegesOp(revoke))
for grant in DefaultGrantStatement.combine(list(extra_grants)):
result.append(RevokePrivilegesOp(grant))

for grant in DefaultGrantStatement.combine(list(missing_grants)):
result.append(RevokePrivilegesOp(grant))
result.append(GrantPrivilegesOp(grant))

return result

Expand Down Expand Up @@ -141,11 +146,10 @@ def compare_object_grants(
extra_grants = set(existing_grants) - set(expected_grants)

if not grants.ignore_unspecified:
revoke_statements = [extra_grant.invert() for extra_grant in extra_grants]
for revoke in GrantStatement.combine(revoke_statements):
result.append(RevokePrivilegesOp(revoke))
for grant in GrantStatement.combine(list(extra_grants)):
result.append(RevokePrivilegesOp(grant))

for grant in GrantStatement.combine(list(missing_grants)):
result.append(RevokePrivilegesOp(grant))
result.append(GrantPrivilegesOp(grant))

return result
2 changes: 1 addition & 1 deletion src/sqlalchemy_declarative_extensions/grant/ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ def receive_event(metadata: MetaData, connection: Connection, **_):
roles: Optional[Roles] = metadata.info.get("roles")
result = compare_grants(connection, grants, roles=roles)
for op in result:
connection.execute(op.grant.to_sql())
connection.execute(op.to_sql())

return receive_event
42 changes: 37 additions & 5 deletions src/sqlalchemy_declarative_extensions/view/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass, field, replace
from typing import Callable, Container, Iterable, TypeVar, Union
from typing import Any, Callable, Container, Iterable, TypeVar, Union

from sqlalchemy import MetaData, column, select, table
from sqlalchemy.engine import Dialect
Expand All @@ -12,7 +12,7 @@

try:
from sqlalchemy.orm import DeclarativeMeta
except ImportError:
except ImportError: # pragma: no cover
from sqlalchemy.ext.orm import DeclarativeMeta # type: ignore

T = TypeVar("T", bound=Union[HasMetaData, MetaData])
Expand Down Expand Up @@ -139,6 +139,32 @@ class View:
schema: str | None = None
materialized: bool = False

@classmethod
def coerce_from_unknown(cls, unknown: Any) -> View:
if isinstance(unknown, View):
return unknown

try:
import alembic_utils # noqa
except ImportError:
pass
else:
from alembic_utils.pg_materialized_view import PGMaterializedView
from alembic_utils.pg_view import PGView

if isinstance(unknown, (PGView, PGMaterializedView)):
materialized = isinstance(unknown, PGMaterializedView)
return cls(
name=unknown.signature,
definition=unknown.definition,
schema=unknown.schema,
materialized=materialized,
)

raise NotImplementedError( # pragma: no cover
f"Unsupported view source object {unknown}"
)

@property
def qualified_name(self):
return qualify_name(self.schema, self.name)
Expand Down Expand Up @@ -197,7 +223,13 @@ def to_sql_drop(self):

@dataclass
class Views:
"""The collection of views and associated options comparisons."""
"""The collection of views and associated options comparisons.
Note, `Views` supports views being specified from certain alternative sources, such
as `alembic_utils`'s `PGView` and `PGMaterializedView`. In order for that to work,
one needs to either call `View.coerce_from_unknown(alembic_utils_view)` directly, or
use `Views().are(...)` (which internally calls `coerce_from_unknown`).
"""

views: list[View] = field(default_factory=list)

Expand All @@ -212,7 +244,7 @@ def coerce_from_unknown(
return unknown

if isinstance(unknown, Iterable):
return cls(list(unknown))
return cls().are(*unknown)

return None

Expand All @@ -224,7 +256,7 @@ def __iter__(self):
yield grant

def are(self, *views: View):
return replace(self, views=list(views))
return replace(self, views=[View.coerce_from_unknown(v) for v in views])


def find_schema(table_args=None):
Expand Down

0 comments on commit 4d822ca

Please sign in to comment.