From 5913fc92cc08f7ef059ae61368caa770860d1096 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Sun, 30 Nov 2025 21:00:43 +0000 Subject: [PATCH] Simplify params for SQL params Per https://docs.djangoproject.com/en/6.0/releases/6.0/\#custom-orm-expressions-should-return-params-as-a-tuple, ensure we always use a tuple for params, and the simpler unpacking syntax for combining them. --- src/django_mysql/models/expressions.py | 30 +++++++++++------------ src/django_mysql/models/fields/dynamic.py | 2 +- src/django_mysql/models/fields/lists.py | 10 +++++--- src/django_mysql/models/lookups.py | 28 ++++++++++++--------- src/django_mysql/models/transforms.py | 3 +-- tests/testapp/test_dynamicfield.py | 2 +- 6 files changed, 40 insertions(+), 35 deletions(-) diff --git a/src/django_mysql/models/expressions.py b/src/django_mysql/models/expressions.py index f94a8adc..beeea019 100644 --- a/src/django_mysql/models/expressions.py +++ b/src/django_mysql/models/expressions.py @@ -72,10 +72,10 @@ def as_sql( field, field_params = compiler.compile(self.lhs) value, value_params = compiler.compile(self.rhs) - sql = self.sql_expression % (field, value) - params = tuple(value_params) + tuple(field_params) - - return sql, params + return ( + self.sql_expression % (field, value), + (*value_params, *field_params), + ) class AppendLeftListF(TwoSidedExpression): @@ -148,8 +148,7 @@ def as_sql( ) -> tuple[str, tuple[Any, ...]]: field, field_params = compiler.compile(self.lhs) - sql = self.sql_expression % (field) - return sql, tuple(field_params) + return (self.sql_expression % (field), field_params) class PopLeftListF(BaseExpression): @@ -180,8 +179,7 @@ def as_sql( ) -> tuple[str, tuple[Any, ...]]: field, field_params = compiler.compile(self.lhs) - sql = self.sql_expression % (field) - return sql, tuple(field_params) + return (self.sql_expression % (field), field_params) class SetF: @@ -227,10 +225,10 @@ def as_sql( field, field_params = compiler.compile(self.lhs) value, value_params = compiler.compile(self.rhs) - sql = self.sql_expression % (value, field) - params = tuple(value_params) + tuple(field_params) - - return sql, params + return ( + self.sql_expression % (value, field), + (*value_params, *field_params), + ) class RemoveSetF(TwoSidedExpression): @@ -280,7 +278,7 @@ def as_sql( field, field_params = compiler.compile(self.lhs) value, value_params = compiler.compile(self.rhs) - sql = self.sql_expression % (value, field) - params = tuple(value_params) + tuple(field_params) - - return sql, params + return ( + self.sql_expression % (value, field), + (*value_params, *field_params), + ) diff --git a/src/django_mysql/models/fields/dynamic.py b/src/django_mysql/models/fields/dynamic.py index 1e27fb94..e8c1bf00 100644 --- a/src/django_mysql/models/fields/dynamic.py +++ b/src/django_mysql/models/fields/dynamic.py @@ -342,7 +342,7 @@ def as_sql( lhs, params = compiler.compile(self.lhs) return ( f"COLUMN_GET({lhs}, %s AS {self.data_type})", - tuple(params) + (self.key_name,), + (*params, self.key_name), ) diff --git a/src/django_mysql/models/fields/lists.py b/src/django_mysql/models/fields/lists.py index 11f2ab8a..c59bfa93 100644 --- a/src/django_mysql/models/fields/lists.py +++ b/src/django_mysql/models/fields/lists.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Callable, Iterable +from collections.abc import Callable from typing import Any, cast from django.core import checks @@ -216,12 +216,14 @@ def __init__(self, index: int, *args: Any, **kwargs: Any) -> None: def as_sql( self, qn: Callable[[str], str], connection: BaseDatabaseWrapper - ) -> tuple[str, Iterable[Any]]: + ) -> tuple[str, tuple[Any, ...]]: lhs, lhs_params = self.process_lhs(qn, connection) rhs, rhs_params = self.process_rhs(qn, connection) - params = tuple(lhs_params) + tuple(rhs_params) # Put rhs on the left since that's the order FIND_IN_SET uses - return f"(FIND_IN_SET({rhs}, {lhs}) = {self.index})", params + return ( + f"(FIND_IN_SET({rhs}, {lhs}) = {self.index})", + (*lhs_params, *rhs_params), + ) class IndexLookupFactory: diff --git a/src/django_mysql/models/lookups.py b/src/django_mysql/models/lookups.py index ba418670..102e96e5 100644 --- a/src/django_mysql/models/lookups.py +++ b/src/django_mysql/models/lookups.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Callable, Iterable +from collections.abc import Callable from typing import Any from django.db.backends.base.base import BaseDatabaseWrapper @@ -23,11 +23,13 @@ def as_sql( self, qn: Callable[[str], str], connection: BaseDatabaseWrapper, - ) -> tuple[str, Iterable[Any]]: + ) -> tuple[str, tuple[Any, ...]]: lhs, lhs_params = self.process_lhs(qn, connection) rhs, rhs_params = self.process_rhs(qn, connection) - params = tuple(lhs_params) + tuple(rhs_params) - return f"{lhs} SOUNDS LIKE {rhs}", params + return ( + f"{lhs} SOUNDS LIKE {rhs}", + (*lhs_params, *rhs_params), + ) class Soundex(Transform): @@ -36,7 +38,7 @@ class Soundex(Transform): def as_sql( self, compiler: SQLCompiler, connection: BaseDatabaseWrapper - ) -> tuple[str, Iterable[Any]]: + ) -> tuple[str, tuple[Any, ...]]: lhs, params = compiler.compile(self.lhs) return f"SOUNDEX({lhs})", params @@ -62,12 +64,14 @@ def get_prep_lookup(self) -> Any: def as_sql( self, qn: Callable[[str], str], connection: BaseDatabaseWrapper - ) -> tuple[str, Iterable[Any]]: + ) -> tuple[str, tuple[Any, ...]]: lhs, lhs_params = self.process_lhs(qn, connection) rhs, rhs_params = self.process_rhs(qn, connection) # Put rhs (and params) on the left since that's the order FIND_IN_SET uses - params = tuple(rhs_params) + tuple(lhs_params) - return f"FIND_IN_SET({rhs}, {lhs})", params + return ( + f"FIND_IN_SET({rhs}, {lhs})", + (*rhs_params, *lhs_params), + ) class SetIContains(SetContains): @@ -82,8 +86,10 @@ class DynColHasKey(Lookup): def as_sql( self, qn: Callable[[str], str], connection: BaseDatabaseWrapper - ) -> tuple[str, Iterable[Any]]: + ) -> tuple[str, tuple[Any, ...]]: lhs, lhs_params = self.process_lhs(qn, connection) rhs, rhs_params = self.process_rhs(qn, connection) - params = tuple(lhs_params) + tuple(rhs_params) - return f"COLUMN_EXISTS({lhs}, {rhs})", params + return ( + f"COLUMN_EXISTS({lhs}, {rhs})", + (*lhs_params, *rhs_params), + ) diff --git a/src/django_mysql/models/transforms.py b/src/django_mysql/models/transforms.py index 8e38b283..a9e4f8ad 100644 --- a/src/django_mysql/models/transforms.py +++ b/src/django_mysql/models/transforms.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections.abc import Iterable from typing import Any from django.db.backends.base.base import BaseDatabaseWrapper @@ -25,6 +24,6 @@ class SetLength(Transform): def as_sql( self, compiler: SQLCompiler, connection: BaseDatabaseWrapper - ) -> tuple[str, Iterable[Any]]: + ) -> tuple[str, tuple[Any, ...]]: lhs, params = compiler.compile(self.lhs) return self.expr % (lhs, lhs, lhs), params diff --git a/tests/testapp/test_dynamicfield.py b/tests/testapp/test_dynamicfield.py index 31bd7c3a..605f48d2 100644 --- a/tests/testapp/test_dynamicfield.py +++ b/tests/testapp/test_dynamicfield.py @@ -111,7 +111,7 @@ class DumbTransform(Transform): def as_sql(self, compiler, connection): lhs, params = compiler.compile(self.lhs) - return "%s", ["dumb"] + return "%s", ("dumb",) DynamicField.register_lookup(DumbTransform)