diff --git a/.gitignore b/.gitignore
index 31c4ac5ab..9c09001ca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,6 +15,7 @@ downloads/
 .idea/
 .DS_Store
 .vscode/
+.zed/
 eggs/
 .eggs/
 lib/
diff --git a/app/__main__.py b/app/__main__.py
new file mode 100644
index 000000000..f403abe63
--- /dev/null
+++ b/app/__main__.py
@@ -0,0 +1,5 @@
+"""MultiDirectory.
+
+Copyright (c) 2024 MultiFactor
+License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE
+"""
diff --git a/app/alembic/versions/275222846605_initial_ldap_schema.py b/app/alembic/versions/275222846605_initial_ldap_schema.py
index e64a0b661..273cbd984 100644
--- a/app/alembic/versions/275222846605_initial_ldap_schema.py
+++ b/app/alembic/versions/275222846605_initial_ldap_schema.py
@@ -400,8 +400,6 @@ def downgrade() -> None:
         postgresql_using="gin",
         postgresql_ops={"name": "gin_trgm_ops"},
     )
-    op.execute(sa.text("DROP EXTENSION IF EXISTS pg_trgm"))
-
     op.drop_constraint(
         "object_class_must_attribute_type_uc",
         "ObjectClassAttributeTypeMustMemberships",
diff --git a/app/alembic/versions/a7971f00ba4d_index_single_level.py b/app/alembic/versions/a7971f00ba4d_index_single_level.py
new file mode 100644
index 000000000..79d68ef91
--- /dev/null
+++ b/app/alembic/versions/a7971f00ba4d_index_single_level.py
@@ -0,0 +1,96 @@
+"""index_single_level.
+
+Revision ID: a7971f00ba4d
+Revises: 35d1542d2505
+Create Date: 2025-07-22 11:13:48.397808
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "a7971f00ba4d"
+down_revision = "35d1542d2505"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    """Create indexes on Users, Groups, Directory and Attributes tables."""
+    op.execute(
+        sa.text(
+            'CREATE INDEX "idx_User_san_gin" '
+            'ON "Users" USING gin ("sAMAccountName" gin_trgm_ops);'
+        )
+    )
+    op.execute(
+        sa.text(
+            'CREATE INDEX "idx_User_upn_gin" '
+            'ON "Users" USING gin ("userPrincipalName" gin_trgm_ops);'
+        )
+    )
+    op.execute(
+        sa.text(
+            'CREATE INDEX "idx_User_display_name_gin" '
+            'ON "Users" USING gin ("displayName" gin_trgm_ops);'
+        )
+    )
+    op.execute(
+        sa.text(
+            'CREATE INDEX "idx_user_hash_dir_id" '
+            'ON "Users" USING hash ("directoryId");'
+        )
+    )
+    op.execute(
+        sa.text(
+            'CREATE INDEX "idx_entity_type_dir_id" '
+            'ON "Directory" USING hash ("entity_type_id");'
+        )
+    )
+    op.execute(
+        sa.text(
+            'CREATE INDEX "idx_group_dir_id" '
+            'ON "Groups" USING hash ("directoryId");'
+        )
+    )
+    op.execute(
+        sa.text(
+            'CREATE INDEX "idx_Directory_depth_hash" '
+            'ON "Directory" '
+            "USING hash (depth);"
+        )
+    )
+    op.execute(
+        sa.text(
+            'CREATE INDEX "idx_composite_attributes_directory_id_name" '
+            'ON "Attributes" ("directoryId", lower("name"));'
+        )
+    )
+    op.execute(
+        sa.text(
+            'CREATE INDEX "idx_attributes_value" '
+            'ON "Attributes" USING gin ("value" gin_trgm_ops);'
+        )
+    )
+    op.execute(
+        sa.text(
+            'CREATE INDEX "idx_attributes_name_value_trgm" '
+            'ON "Attributes" USING gin '
+            '("name" gin_trgm_ops, "value" gin_trgm_ops);'
+        )
+    )
+
+
+def downgrade() -> None:
+    """Remove indexes from Users, Groups, Directory and Attributes tables."""
+    op.drop_index("idx_User_san_gin", "Users")
+    op.drop_index("idx_User_upn_gin", "Users")
+    op.drop_index("idx_User_display_name_gin", "Users")
+    op.drop_index("idx_user_hash_dir_id", "Users")
+    op.drop_index("idx_entity_type_dir_id", "Directory")
+    op.drop_index("idx_group_dir_id", "Groups")
+    op.drop_index("idx_Directory_depth_hash", "Directory")
+    op.drop_index("idx_composite_attributes_directory_id_name", "Attributes")
+    op.drop_index("idx_attributes_value", "Attributes")
+    op.drop_index("idx_attributes_name_value_trgm", "Attributes")
diff --git a/app/config.py b/app/config.py
index ea12ccede..16e2f0d9d 100644
--- a/app/config.py
+++ b/app/config.py
@@ -164,7 +164,6 @@ def engine(self) -> AsyncEngine:
             str(self.POSTGRES_URI),
             poolclass=NullPool,
             future=True,
-            pool_pre_ping=True,
         )
 
     @classmethod
diff --git a/app/ioc.py b/app/ioc.py
index 6803ca355..97572ee79 100644
--- a/app/ioc.py
+++ b/app/ioc.py
@@ -1,4 +1,4 @@
-"""DI Provider MiltiDirecory module.
+"""DI Provider MultiDirectory module.
 
 Copyright (c) 2024 MultiFactor
 License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE
@@ -195,6 +195,10 @@ async def get_session_storage(
             settings.SESSION_KEY_EXPIRE_SECONDS,
         )
 
+    entity_type_dao = provide(EntityTypeDAO, scope=Scope.REQUEST)
+    object_class_dao = provide(ObjectClassDAO, scope=Scope.REQUEST)
+    attribute_type_dao = provide(AttributeTypeDAO, scope=Scope.REQUEST)
+
 
 class HTTPProvider(Provider):
     """HTTP LDAP session."""
@@ -206,32 +210,6 @@ async def get_session(self) -> LDAPSession:
         """Create ldap session."""
         return LDAPSession()
 
-    @provide(scope=Scope.REQUEST)
-    async def get_attribute_type_dao(
-        self,
-        session: AsyncSession,
-    ) -> AttributeTypeDAO:
-        """Get Attribute Type DAO."""
-        return AttributeTypeDAO(session)
-
-    @provide(scope=Scope.REQUEST)
-    async def get_object_class_dao(
-        self,
-        attribute_type_dao: AttributeTypeDAO,
-        session: AsyncSession,
-    ) -> ObjectClassDAO:
-        """Get Object Class DAO."""
-        return ObjectClassDAO(session, attribute_type_dao)
-
-    @provide(scope=Scope.REQUEST)
-    async def get_entity_type_dao(
-        self,
-        object_class_dao: ObjectClassDAO,
-        session: AsyncSession,
-    ) -> EntityTypeDAO:
-        """Get Entity Type DAO."""
-        return EntityTypeDAO(session, object_class_dao)
-
     identity_fastapi_adapter = provide(
         IdentityFastAPIAdapter,
         scope=Scope.REQUEST,
diff --git a/app/ldap_protocol/dependency.py b/app/ldap_protocol/dependency.py
index 8ec0e8448..c067560ab 100644
--- a/app/ldap_protocol/dependency.py
+++ b/app/ldap_protocol/dependency.py
@@ -1,4 +1,4 @@
-"""DI Resolver MiltiDirecory module.
+"""DI Resolver MultiDirectory module.
 
 Copyright (c) 2024 MultiFactor
 License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE
diff --git a/app/ldap_protocol/filter_interpreter.py b/app/ldap_protocol/filter_interpreter.py
index 5a883365f..23043b0d3 100644
--- a/app/ldap_protocol/filter_interpreter.py
+++ b/app/ldap_protocol/filter_interpreter.py
@@ -13,7 +13,11 @@
 
 from ldap_filter import Filter
 from sqlalchemy import and_, func, not_, or_, select
-from sqlalchemy.sql.elements import ColumnElement, UnaryExpression
+from sqlalchemy.sql.elements import (
+    BinaryExpression,
+    ColumnElement,
+    UnaryExpression,
+)
 
 from models import Attribute, Directory, EntityType, Group, User
 
@@ -28,6 +32,20 @@
 }
 
 
+def _get_filter_condition(
+    attr: str,
+    condition: BinaryExpression | None = None,
+) -> ColumnElement:
+    if condition is None:
+        f = Directory.attributes.any(Attribute.name.ilike(attr))
+    else:
+        f = Directory.attributes.any(
+            and_(Attribute.name.ilike(attr), condition)
+        )
+
+    return f
+
+
 def _get_substring(right: ASN1Row) -> str:  # RFC 4511
     expr = right.value[0]
     value = expr.value
@@ -151,7 +169,7 @@ def _cast_item(item: ASN1Row) -> UnaryExpression | ColumnElement:
         if attr in Directory.search_fields:
             return not_(eq(getattr(Directory, attr), None))
 
-        return Directory.attributes.any(Attribute.name.ilike(item.value))
+        return _get_filter_condition(attr)
 
     if (
         len(item.value) == 3
@@ -172,7 +190,7 @@ def _cast_item(item: ASN1Row) -> UnaryExpression | ColumnElement:
     elif attr in MEMBERS_ATTRS:  # NOTE: without oid
         return _ldap_filter_by_attribute(None, left, right)
     elif attr == "entitytypename":
-        return func.lower(EntityType.name) == right
+        return func.lower(EntityType.name) == right.value.lower()
     else:
         if is_substring:
             cond = Attribute.value.ilike(_get_substring(right))
@@ -182,7 +200,7 @@ def _cast_item(item: ASN1Row) -> UnaryExpression | ColumnElement:
             else:
                 cond = Attribute.bvalue == right.value
 
-        return Directory.attributes.any(and_(Attribute.name.ilike(attr), cond))
+        return _get_filter_condition(attr, cond)
 
 
 def cast_filter2sql(expr: ASN1Row) -> UnaryExpression | ColumnElement:
@@ -229,7 +247,7 @@ def _cast_filt_item(item: Filter) -> UnaryExpression | ColumnElement:
         if item.attr in Directory.search_fields:
             return not_(eq(getattr(Directory, item.attr), None))
 
-        return Directory.attributes.any(Attribute.name.ilike(item.attr))
+        return _get_filter_condition(item.attr)
 
     is_substring = item.val.startswith("*") or item.val.endswith("*")
 
@@ -240,16 +258,14 @@ def _cast_filt_item(item: Filter) -> UnaryExpression | ColumnElement:
     elif item.attr in MEMBERS_ATTRS:
         return _api_filter(item)
     elif item.attr == "entitytypename":
-        return func.lower(EntityType.name) == item.val
+        return func.lower(EntityType.name) == item.val.lower()
     else:
         if is_substring:
             cond = Attribute.value.ilike(item.val.replace("*", "%"))
         else:
             cond = Attribute.value.ilike(item.val)
 
-        return Directory.attributes.any(
-            and_(Attribute.name.ilike(item.attr), cond),
-        )
+        return _get_filter_condition(item.attr, cond)
 
 
 def cast_str_filter2sql(expr: Filter) -> UnaryExpression | ColumnElement:
diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py
index f4706cea4..cee3b893a 100644
--- a/app/ldap_protocol/ldap_requests/search.py
+++ b/app/ldap_protocol/ldap_requests/search.py
@@ -15,7 +15,7 @@
 from sqlalchemy import func, or_
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.future import select
-from sqlalchemy.orm import joinedload, selectinload, with_loader_criteria
+from sqlalchemy.orm import selectinload, with_loader_criteria
 from sqlalchemy.sql.elements import ColumnElement, UnaryExpression
 from sqlalchemy.sql.expression import Select
 
@@ -307,6 +307,9 @@ async def get_result(
 
         query, pages_total, count = await self.paginate_query(query, session)
 
+        if self.size_limit != 0:
+            query = query.limit(self.size_limit)
+
         async for response in self.tree_view(query, session):
             yield response
 
@@ -342,7 +345,10 @@ def _mutate_query_with_attributes_to_load(
     ) -> Select:
         """Get attributes to load."""
         if self.entity_type_name:
-            query = query.options(selectinload(Directory.entity_type))
+            query = (
+                query.join(Directory.entity_type)
+                .options(selectinload(Directory.entity_type))
+            )  # fmt: skip
 
         if self.all_attrs:
             return query.options(selectinload(Directory.attributes))
@@ -370,9 +376,8 @@ def build_query(
         query = (
             select(Directory)
             .join(Directory.user, isouter=True)
-            .join(Directory.group, isouter=True)
-            .join(Directory.entity_type, isouter=True)
-        )  # fmt: skip
+            .options(selectinload(Directory.group))
+        )
 
         query = self._mutate_query_with_attributes_to_load(query)
         query = mutate_ap(query, user)
@@ -402,9 +407,9 @@ def build_query(
 
         elif self.scope == Scope.SINGLE_LEVEL:
             query = query.filter(
-                func.cardinality(Directory.path) == len(search_path) + 1,
+                Directory.depth == len(search_path) + 1,
                 get_path_filter(
-                    column=Directory.path[0 : len(search_path)],
+                    column=Directory.path[1 : len(search_path)],
                     path=search_path,
                 ),
             )
@@ -419,7 +424,7 @@ def build_query(
 
         if self.member:
             query = query.options(
-                joinedload(Directory.group).selectinload(Group.members)
+                selectinload(Directory.group).selectinload(Group.members)
             )
 
         if self.member_of or self.token_groups:
@@ -448,8 +453,7 @@ async def paginate_query(
         count = (await session.scalars(count_q)).one()
 
         start = (self.page_number - 1) * self.size_limit
-        end = start + self.size_limit
-        query = query.offset(start).limit(end)
+        query = query.offset(start).limit(self.size_limit)
 
         return query, int(ceil(count / float(self.size_limit))), count
 
diff --git a/app/ldap_protocol/utils/helpers.py b/app/ldap_protocol/utils/helpers.py
index 493feb1d3..45be5d67d 100644
--- a/app/ldap_protocol/utils/helpers.py
+++ b/app/ldap_protocol/utils/helpers.py
@@ -144,6 +144,10 @@
 from zoneinfo import ZoneInfo
 
 from loguru import logger
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.sql.compiler import SQLCompiler
+from sqlalchemy.sql.expression import ClauseElement, Executable, Visitable
 
 from models import Directory
 
@@ -324,3 +328,42 @@ async def wrapper(*args, **kwargs) -> object:  # type: ignore
         return result
 
     return wrapper
+
+
+class explain(Executable, ClauseElement):  # noqa: N801
+    """EXPLAIN statement for PostgreSQL."""
+
+    inherit_cache = False
+
+    def __init__(
+        self,
+        stmt: Visitable,
+        analyze: bool = False,
+    ) -> None:
+        """Initialize EXPLAIN statement."""
+        self.statement = stmt
+        self.analyze = analyze
+
+
+@compiles(explain, "postgresql")
+def pg_explain(element: explain, compiler: SQLCompiler, **kw: dict) -> str:
+    """Compile EXPLAIN statement for PostgreSQL."""
+    text = "EXPLAIN "
+    if element.analyze:
+        text += "ANALYZE "
+    text += compiler.process(element.statement, **kw)
+
+    return text
+
+
+async def explain_query(
+    query: Visitable,
+    session: AsyncSession,
+) -> None:
+    """Log the EXPLAIN ANALYZE plan; note ANALYZE executes the query."""
+    logger.debug(
+        "\n".join(
+            row[0]
+            for row in await session.execute(explain(query, analyze=True))
+        )
+    )
diff --git a/app/models.py b/app/models.py
index f94c0e369..6a76ca57c 100644
--- a/app/models.py
+++ b/app/models.py
@@ -282,7 +282,7 @@ def object_class_names_set(self) -> set[str]:
         onupdate=func.now(),
         nullable=True,
     )
-    depth: Mapped[int]
+    depth: Mapped[int] = mapped_column(index=True)
     object_sid: Mapped[str] = mapped_column("objectSid")
     objectsid: Mapped[str] = synonym("object_sid")
 
diff --git a/app/multidirectory.py b/app/multidirectory.py
index 56a6c30bf..2c60cef12 100644
--- a/app/multidirectory.py
+++ b/app/multidirectory.py
@@ -1,4 +1,4 @@
-"""Main MiltiDirecory module.
+"""Main MultiDirectory module.
 
 Copyright (c) 2024 MultiFactor
 License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE