diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3937190..57300b7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,7 +56,7 @@ jobs: strategy: matrix: python: - - "3.7" + - "3.8" - "3.10" platform: - ubuntu-latest diff --git a/.gitignore b/.gitignore index e9e1e9b..ff7b152 100644 --- a/.gitignore +++ b/.gitignore @@ -52,3 +52,4 @@ MANIFEST .venv*/ .conda*/ .python-version +notebooks diff --git a/README.md b/README.md index efff2d7..f27fe99 100644 --- a/README.md +++ b/README.md @@ -30,14 +30,14 @@ pip install fastapi-rowsecurity ## Basic Usage -In your SQLAlchemy model, create a `classmethod` named `__rls_policies__` that returns a list of `Permissive` or `Restrictive` policies: +In your SQLAlchemy model, create an attribute named `__rls_policies__` that is a list of `Permissive` or `Restrictive` policies: ```py -from fastapi_rowsecurity import Permissive, set_rls_policies +from fastapi_rowsecurity import Permissive, register_rls from fastapi_rowsecurity.principals import Authenticated, UserOwner Base = declarative_base() -set_rls_policies(Base) # <- create all policies +register_rls(Base) # <- create all policies class Item(Base): @@ -48,18 +48,17 @@ class Item(Base): owner_id = Column(Integer, ForeignKey("users.id")) owner = relationship("User", back_populates="items") - @classmethod - def __rls_policies__(cls): - return [ - Permissive(principal=Authenticated, policy="SELECT"), - Permissive(principal=UserOwner, policy=["INSERT", "UPDATE", "DELETE"]), + + __rls_policies__ = [ + Permissive(expr=Authenticated, cmd="SELECT"), + Permissive(expr=UserOwner, cmd=["INSERT", "UPDATE", "DELETE"]), ] ``` The above implies that any authenticated user can read all items; but can only insert, update or delete owned items. -- `principal`: any Boolean expression as a string; -- `policy`: any of `ALL`/`SELECT`/`INSERT`/`UPDATE`/`DELETE`. +- `expr`: any Boolean expression as a string; +- `cmd`: any command of `ALL`/`SELECT`/`INSERT`/`UPDATE`/`DELETE`. Next, attach the `current_user_id` (or other [runtime parameters](https://www.postgresql.org/docs/current/sql-set.html) that you need) to the user session: @@ -78,9 +77,9 @@ Find a simple example in the ![tests](./tests/simple_model.py). then ... - [ ] Support for Alembic -- [ ] How to deal with `BYPASSRLS` such as table owners? - [ ] When item is tried to delete, no error is raised? - [ ] Python 3.11 +- [ ] Coverage report ## Final note diff --git a/setup.cfg b/setup.cfg index a81cb15..754803f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,18 +5,18 @@ [metadata] name = fastapi-rowsecurity -description = Add a short description here! +description = Row-Level Security (RLS) in SQLAlchemy. author = jwdobken author_email = joost@dobken.nl license = MIT license_files = LICENSE.txt long_description = file: README.md long_description_content_type = text/markdown; charset=UTF-8; variant=GFM -url = https://github.com/pyscaffold/pyscaffold/ +url = https://github.com/JWDobken/fastapi-rowsecurity # Add here related links, for example: project_urls = - Documentation = https://pyscaffold.org/ -# Source = https://github.com/pyscaffold/pyscaffold/ + Documentation = https://github.com/JWDobken/fastapi-rowsecurity + Source = https://github.com/JWDobken/fastapi-rowsecurity # Changelog = https://pyscaffold.org/en/latest/changelog.html # Tracker = https://github.com/pyscaffold/pyscaffold/issues # Conda-Forge = https://anaconda.org/conda-forge/pyscaffold @@ -49,8 +49,8 @@ package_dir = # For more information, check out https://semver.org/. install_requires = importlib-metadata; python_version<"3.8" - alembic_utils>=0.8 pydantic>=2.5 + sqlalchemy [options.packages.find] @@ -72,6 +72,7 @@ testing = faker asyncpg greenlet + alembic [options.entry_points] # Add here console scripts like: diff --git a/src/fastapi_rowsecurity/__init__.py b/src/fastapi_rowsecurity/__init__.py index 8c8253f..5ee9f08 100644 --- a/src/fastapi_rowsecurity/__init__.py +++ b/src/fastapi_rowsecurity/__init__.py @@ -15,7 +15,7 @@ finally: del version, PackageNotFoundError -from .main import set_rls_policies +from .register_rls import register_rls from .schemas import Permissive, Policy, Restrictive -__all__ = ["set_rls_policies", "Permissive", "Policy", "Restrictive"] +__all__ = ["register_rls", "Permissive", "Policy", "Restrictive"] diff --git a/src/fastapi_rowsecurity/create_policies.py b/src/fastapi_rowsecurity/create_policies.py new file mode 100644 index 0000000..9614790 --- /dev/null +++ b/src/fastapi_rowsecurity/create_policies.py @@ -0,0 +1,24 @@ +from typing import Type + +from sqlalchemy import text +from sqlalchemy.engine import Connection +from sqlalchemy.ext.declarative import DeclarativeMeta + + +def create_policies(Base: Type[DeclarativeMeta], connection: Connection): + """Create policies for `Base.metadata.create_all()`.""" + for table, settings in Base.metadata.info["rls_policies"].items(): + # enable + stmt = text(f"ALTER TABLE {table} ENABLE ROW LEVEL SECURITY;") + connection.execute(stmt) + # force by default + stmt = text(f"ALTER TABLE {table} FORCE ROW LEVEL SECURITY;") + connection.execute(stmt) + # policies + print("SETTINGS", settings) + for ix, policy in enumerate(settings): + for pol_stmt in policy.get_sql_policies( + table_name=table, name_suffix=str(ix) + ): + connection.execute(pol_stmt) + connection.commit() diff --git a/src/fastapi_rowsecurity/functions.py b/src/fastapi_rowsecurity/functions.py deleted file mode 100644 index 3286c7e..0000000 --- a/src/fastapi_rowsecurity/functions.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import List - -from alembic_utils.pg_function import PGFunction - -drop_all_policies_on_table = PGFunction( - schema="public", - signature="drop_all_policies_on_table(target_schema text,target_table_name text)", - definition=""" -RETURNS void LANGUAGE plpgsql AS $$ - DECLARE - policy_name text; - sql_text text; - BEGIN - FOR policy_name IN ( - SELECT policyname - FROM pg_policies - WHERE schemaname = target_schema AND tablename = target_table_name - ) - LOOP - sql_text := format('DROP POLICY "%s" on %I.%I', - policy_name, target_schema, target_table_name); - RAISE NOTICE '%', sql_text; - EXECUTE sql_text; - END LOOP; -END $$; -""", -) - - -def get_functions() -> List[PGFunction]: - return [] diff --git a/src/fastapi_rowsecurity/main.py b/src/fastapi_rowsecurity/main.py deleted file mode 100644 index c7fd941..0000000 --- a/src/fastapi_rowsecurity/main.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Type - -from sqlalchemy import event -from sqlalchemy.ext.declarative import DeclarativeMeta - -from .policies import get_policies - - -def set_rls_policies(Base: Type[DeclarativeMeta]): - @event.listens_for(Base.metadata, "after_create") - def receive_after_create(target, connection, tables, **kw): - policies = get_policies(Base) - for ent in policies: - connection.execute(ent.to_sql_statement_create()) - connection.commit() diff --git a/src/fastapi_rowsecurity/policies.py b/src/fastapi_rowsecurity/policies.py deleted file mode 100644 index d5bf614..0000000 --- a/src/fastapi_rowsecurity/policies.py +++ /dev/null @@ -1,61 +0,0 @@ -from typing import List - -from alembic_utils.pg_policy import PGPolicy - -from .rls_entity import EnableRowLevelSecurity, ForceRowLevelSecurity - - -def get_policies(Base) -> List[PGPolicy]: - policy_lists = [] - for mapper in Base.registry.mappers: - if not hasattr(mapper.class_, "__rls_policies__"): - continue - table_name = mapper.tables[0].fullname - schema_name = mapper.tables[0].schema or "public" - # Set the default row-level security policy - policy_lists.append( - EnableRowLevelSecurity(schema=schema_name, on_entity=table_name) - ) - policy_lists.append( - ForceRowLevelSecurity(schema=schema_name, on_entity=table_name) - ) - for ix, permission in enumerate(mapper.class_.__rls_policies__()): - table_policies = ( - [permission.policy] - if isinstance(permission.policy, str) - else permission.policy - ) - for pol in table_policies: - policy_name = ( - f"{table_name}_{permission.__class__.__name__}" - f"_{pol}_policy_{ix}".lower() - ) - if pol in ["ALL", "SELECT", "UPDATE", "DELETE"]: - policy_lists.append( - PGPolicy( - schema=schema_name, - signature=policy_name, - on_entity=table_name, - definition=f""" - AS {permission.__class__.__name__.upper()} - FOR {pol} - USING ({permission.principal}) - """, - ) - ) - elif pol in ["INSERT"]: - policy_lists.append( - PGPolicy( - schema=schema_name, - signature=policy_name, - on_entity=table_name, - definition=f""" - AS {permission.__class__.__name__.upper()} - FOR {pol} - WITH CHECK ({permission.principal}) - """, - ) - ) - else: - raise ValueError(f'Unknown policy "{pol}"') - return policy_lists diff --git a/src/fastapi_rowsecurity/register_rls.py b/src/fastapi_rowsecurity/register_rls.py new file mode 100644 index 0000000..00372d5 --- /dev/null +++ b/src/fastapi_rowsecurity/register_rls.py @@ -0,0 +1,169 @@ +from typing import Type + +from alembic.autogenerate import comparators, renderers +from alembic.operations import MigrateOperation, Operations +from sqlalchemy import event, text +from sqlalchemy.ext.declarative import DeclarativeMeta + +from .create_policies import create_policies + +############################ +# OPERATIONS +############################ + + +@Operations.register_operation("enable_rls") +class EnableRlsOp(MigrateOperation): + """Enable RowLevelSecurity.""" + + def __init__(self, tablename, schemaname=None): + self.tablename = tablename + self.schemaname = schemaname + + @classmethod + def enable_rls(cls, operations, tablename, **kw): + """Issue a "CREATE SEQUENCE" instruction.""" + + op = EnableRlsOp(tablename, **kw) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return DisableRlsOp(self.tablename, schemaname=self.schemaname) + + +@Operations.register_operation("disable_rls") +class DisableRlsOp(MigrateOperation): + """Drop a SEQUENCE.""" + + def __init__(self, tablename, schemaname=None): + self.tablename = tablename + self.schemaname = schemaname + + @classmethod + def disable_rls(cls, operations, tablename, **kw): + """Issue a "DROP SEQUENCE" instruction.""" + + op = DisableRlsOp(tablename, **kw) + return operations.invoke(op) + + def reverse(self): + # only needed to support autogenerate + return EnableRlsOp(self.tablename, schemaname=self.schemaname) + + +############################ +# IMPLEMENTATION +############################ + + +@Operations.implementation_for(EnableRlsOp) +def enable_rls(operations, operation): + if operation.schemaname is not None: + name = "%s.%s" % (operation.schemaname, operation.tablename) + else: + name = operation.tablename + operations.execute("ALTER TABLE %s ENABLE ROW LEVEL SECURITY" % name) + + +@Operations.implementation_for(DisableRlsOp) +def disable_rls(operations, operation): + if operation.schemaname is not None: + name = "%s.%s" % (operation.schemaname, operation.sequence_name) + else: + name = operation.tablename + operations.execute("ALTER TABLE %s DISABLE ROW LEVEL SECURITY" % name) + + +############################ +# RENDER +############################ + + +@renderers.dispatch_for(EnableRlsOp) +def render_enable_rls(autogen_context, op): + return "op.enable_rls(%r, **%r)" % (op.tablename, {"schemaname": op.schemaname}) + + +@renderers.dispatch_for(DisableRlsOp) +def render_disable_rls(autogen_context, op): + return "op.disable_rls(%r, **%r)" % (op.tablename, {"schemaname": op.schemaname}) + + +############################ +# COMPARATORS +############################ + + +def check_table_exists(conn, schemaname, tablename) -> bool: + result = conn.execute( + text( + f"""SELECT EXISTS ( + SELECT 1 + FROM information_schema.tables + WHERE table_schema = '{schemaname if schemaname else "public"}' + AND table_name = '{tablename}' +);""" + ) + ).scalar() + return result + + +def check_rls_enabled(conn, schemaname, tablename) -> bool: + result = conn.execute( + text( + f"""select relrowsecurity + from pg_class + where oid = '{tablename}'::regclass;""" + ) + ).scalar() + return result + + +@comparators.dispatch_for("table") +def compare_table_level( + autogen_context, modify_ops, schemaname, tablename, conn_table, metadata_table +): + # STEP 1. check table exists and RLS is enabled + table_exists = check_table_exists(autogen_context.connection, schemaname, tablename) + rls_enabled_db = ( + check_rls_enabled(autogen_context.connection, schemaname, tablename) + if table_exists + else False + ) + + # STEP 2. check if RLS should be enabled + rls_enabled_meta = tablename in metadata_table.metadata.info["rls_policies"] + + # STEP 3. apply + if rls_enabled_meta and not rls_enabled_db: + modify_ops.ops.append(EnableRlsOp(tablename=tablename, schemaname=schemaname)) + if rls_enabled_db and not rls_enabled_meta: + modify_ops.ops.append(DisableRlsOp(tablename=tablename, schemaname=schemaname)) + + +def set_metadata_info(Base: Type[DeclarativeMeta]): + """RLS policies are first added to the Metadata before applied.""" + Base.metadata.info.setdefault("rls_policies", dict()) + for mapper in Base.registry.mappers: + if not hasattr(mapper.class_, "__rls_policies__"): + continue + Base.metadata.info["rls_policies"][ + mapper.tables[0].fullname + ] = mapper.class_.__rls_policies__ + # [ + # p.model_dump(mode="json") for p in mapper.class_.__rls_policies__ + # ] + + +def register_rls(Base: Type[DeclarativeMeta]): + + # required for `alembic revision --autogenerate`` + set_metadata_info(Base) + + @event.listens_for(Base.metadata, "after_create") + def receive_after_create(target, connection, tables, **kw): + + # required for `Base.metadata.create_all()` + set_metadata_info(Base) + create_policies(Base, connection) diff --git a/src/fastapi_rowsecurity/rls_entity.py b/src/fastapi_rowsecurity/rls_entity.py deleted file mode 100644 index 8b7dd13..0000000 --- a/src/fastapi_rowsecurity/rls_entity.py +++ /dev/null @@ -1,29 +0,0 @@ -from alembic_utils.on_entity_mixin import OnEntityMixin -from alembic_utils.replaceable_entity import ReplaceableEntity -from sqlalchemy import text as sql_text - - -class EnableRowLevelSecurity(OnEntityMixin, ReplaceableEntity): - def __init__(self, schema: str, on_entity: str, **kwargs): - super().__init__( - schema=schema, on_entity=on_entity, definition="", signature="" - ) - - def to_sql_statement_create(self): - return sql_text(f"ALTER TABLE {self.on_entity} ENABLE ROW LEVEL SECURITY;") - - def to_sql_statement_drop(self): - return sql_text(f"ALTER TABLE {self.on_entity} DISABLE ROW LEVEL SECURITY;") - - -class ForceRowLevelSecurity(OnEntityMixin, ReplaceableEntity): - def __init__(self, schema: str, on_entity: str, **kwargs): - super().__init__( - schema=schema, on_entity=on_entity, definition="", signature="" - ) - - def to_sql_statement_create(self): - return sql_text(f"ALTER TABLE {self.on_entity} FORCE ROW LEVEL SECURITY;") - - def to_sql_statement_drop(self): - return sql_text(f"ALTER TABLE {self.on_entity} DISABLE ROW LEVEL SECURITY;") diff --git a/src/fastapi_rowsecurity/schemas.py b/src/fastapi_rowsecurity/schemas.py index 9e58745..ac6cc23 100644 --- a/src/fastapi_rowsecurity/schemas.py +++ b/src/fastapi_rowsecurity/schemas.py @@ -1,10 +1,11 @@ from enum import Enum -from typing import List, Union +from typing import List, Literal, Union from pydantic import BaseModel +from sqlalchemy import text -class Policy(str, Enum): +class Command(str, Enum): # policies: https://www.postgresql.org/docs/current/sql-createpolicy.html all = "ALL" select = "SELECT" @@ -13,11 +14,48 @@ class Policy(str, Enum): delete = "DELETE" -class Permissive(BaseModel): - principal: str - policy: Union[Policy, List[Policy]] +class Policy(BaseModel): + definition: str + expr: str + cmd: Union[Command, List[Command]] + def get_sql_policies(self, table_name: str, name_suffix: str = "0"): + commands = [self.cmd] if isinstance(self.cmd, str) else self.cmd + policy_lists = [] + for cmd in commands: + policy_name = ( + f"{table_name}_{self.definition}" f"_{cmd}_policy_{name_suffix}".lower() + ) + if cmd in ["ALL", "SELECT", "UPDATE", "DELETE"]: + policy_lists.append( + text( + f""" + CREATE POLICY {policy_name} ON {table_name} + AS {self.definition} + FOR {cmd} + USING ({self.expr}) + """ + ) + ) + elif cmd in ["INSERT"]: + policy_lists.append( + text( + f""" + CREATE POLICY {policy_name} ON {table_name} + AS {self.definition} + FOR {cmd} + WITH CHECK ({self.expr}) + """ + ) + ) + else: + raise ValueError(f'Unknown policy command"{cmd}"') + return policy_lists -class Restrictive(BaseModel): - principal: str - policy: Union[Policy, List[Policy]] + +class Permissive(Policy): + definition: Literal["PERMISSIVE"] = "PERMISSIVE" + + +class Restrictive(Policy): + definition: Literal["RESTRICTIVE"] = "RESTRICTIVE" diff --git a/tests/simple_model.py b/tests/simple_model.py index 6704f50..2ae5c4e 100644 --- a/tests/simple_model.py +++ b/tests/simple_model.py @@ -1,11 +1,11 @@ from sqlalchemy import Column, ForeignKey, Integer, String from sqlalchemy.orm import declarative_base, relationship -from fastapi_rowsecurity import Permissive, set_rls_policies +from fastapi_rowsecurity import Permissive, register_rls from fastapi_rowsecurity.principals import Authenticated, UserOwner Base = declarative_base() -set_rls_policies(Base) +register_rls(Base) class User(Base): @@ -24,9 +24,7 @@ class Item(Base): owner_id = Column(Integer, ForeignKey("users.id")) owner = relationship("User", back_populates="items") - @classmethod - def __rls_policies__(cls): - return [ - Permissive(principal=Authenticated, policy="SELECT"), - Permissive(principal=UserOwner, policy=["INSERT", "UPDATE", "DELETE"]), - ] + __rls_policies__ = [ + Permissive(expr=Authenticated, cmd="SELECT"), + Permissive(expr=UserOwner, cmd=["INSERT", "UPDATE", "DELETE"]), + ]