From 6c50c89ed2b068012b66c763c3ce9424920f5138 Mon Sep 17 00:00:00 2001 From: "idan.sheinberg" Date: Fri, 27 Jun 2025 01:37:28 +0300 Subject: [PATCH 1/7] Test formatting fix --- .gitignore | 1 + src/sqlacodegen/generators.py | 22 +++++++++++++++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index b5b2f478..052c997a 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ dist build venv* +compose.yaml diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index f98fe4e9..063be2ef 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -38,7 +38,7 @@ TypeDecorator, UniqueConstraint, ) -from sqlalchemy.dialects.postgresql import DOMAIN, JSONB +from sqlalchemy.dialects.postgresql import DOMAIN, JSON, JSONB from sqlalchemy.engine import Connection, Engine from sqlalchemy.exc import CompileError from sqlalchemy.sql.elements import TextClause @@ -222,7 +222,7 @@ def collect_imports_for_column(self, column: Column[Any]) -> None: if isinstance(column.type, ARRAY): self.add_import(column.type.item_type.__class__) - elif isinstance(column.type, JSONB): + elif isinstance(column.type, (JSONB, JSON)): if ( not isinstance(column.type.astext_type, Text) or column.type.astext_type.length is not None @@ -499,7 +499,7 @@ def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> s else: return render_callable("mapped_column", *args, kwargs=kwargs) - def render_column_type(self, coltype: object) -> str: + def render_column_type(self, coltype: TypeEngine[Any]) -> str: args = [] kwargs: dict[str, Any] = {} sig = inspect.signature(coltype.__class__.__init__) @@ -515,6 +515,15 @@ def render_column_type(self, coltype: object) -> str: continue value = getattr(coltype, param.name, missing) + + if isinstance(value, (JSONB, JSON)): + # Remove astext_type if it's the default + if ( + isinstance(value.astext_type, Text) + and value.astext_type.length is None + ): + value.astext_type = None # type: ignore[assignment] + default = defaults.get(param.name, missing) if isinstance(value, TextClause): self.add_literal_import("sqlalchemy", "text") @@ -547,7 +556,7 @@ def render_column_type(self, coltype: object) -> str: if (value := getattr(coltype, colname)) is not None: kwargs[colname] = repr(value) - if isinstance(coltype, JSONB): + if isinstance(coltype, (JSONB, JSON)): # Remove astext_type if it's the default if ( isinstance(coltype.astext_type, Text) @@ -1224,7 +1233,10 @@ def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]: return "".join(pre), column_type, "]" * post_size def render_python_type(column_type: TypeEngine[Any]) -> str: - python_type = column_type.python_type + if isinstance(column_type, DOMAIN): + python_type = column_type.data_type.python_type + else: + python_type = column_type.python_type python_type_name = python_type.__name__ python_type_module = python_type.__module__ if python_type_module == "builtins": From 826392dbaca945cb13951e67149b6b8f87726e5b Mon Sep 17 00:00:00 2001 From: "idan.sheinberg" Date: Fri, 27 Jun 2025 13:59:24 +0300 Subject: [PATCH 2/7] Added test coverage --- CHANGES.rst | 1 + tests/test_generator_declarative.py | 43 ++++++++++++++++++++++++++++- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/CHANGES.rst b/CHANGES.rst index 60a67d59..58a3923b 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -10,6 +10,7 @@ Version history - Fixed incorrect package name used in ``importlib.metadata.version`` for ``sqlalchemy-citext``, resolving ``PackageNotFoundError`` (PR by @oaimtiaz) - Prevent double pluralization (PR by @dkratzert) +- Fixes DOMAIN extending JSON/JSONB data types (PR by @sheinbergon) **3.0.0** diff --git a/tests/test_generator_declarative.py b/tests/test_generator_declarative.py index 9beb4a01..b73a0885 100644 --- a/tests/test_generator_declarative.py +++ b/tests/test_generator_declarative.py @@ -2,7 +2,9 @@ import pytest from _pytest.fixtures import FixtureRequest -from sqlalchemy import PrimaryKeyConstraint +from sqlalchemy import BIGINT, PrimaryKeyConstraint +from sqlalchemy.dialects import postgresql +from sqlalchemy.dialects.postgresql import JSON from sqlalchemy.engine import Engine from sqlalchemy.schema import ( CheckConstraint, @@ -1592,3 +1594,42 @@ class WithItems(Base): str_matrix: Mapped[Optional[list[list[str]]]] = mapped_column(ARRAY(VARCHAR(), dimensions=2)) """, ) + + +@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) +def test_domain_text(generator: CodeGenerator) -> None: + Table( + "test_domain_json", + generator.metadata, + Column("id", BIGINT, primary_key=True), + Column( + "foo", + postgresql.DOMAIN( + "domain_json", + JSON, + not_null=False, + ), + nullable=True, + ), + ) + + validate_code( + generator.generate(), + """\ +from typing import Optional + +from sqlalchemy import BigInteger +from sqlalchemy.dialects.postgresql import DOMAIN, JSON +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class TestDomainJson(Base): + __tablename__ = 'test_domain_json' + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + foo: Mapped[Optional[dict]] = mapped_column(DOMAIN('domain_json', JSON(), not_null=False)) +""", + ) From e3275c8375938bd0a6fdc1cf44799306758e79df Mon Sep 17 00:00:00 2001 From: "idan.sheinberg" Date: Fri, 27 Jun 2025 14:21:33 +0300 Subject: [PATCH 3/7] test name fix --- tests/test_generator_declarative.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_generator_declarative.py b/tests/test_generator_declarative.py index b73a0885..901d6e23 100644 --- a/tests/test_generator_declarative.py +++ b/tests/test_generator_declarative.py @@ -1597,7 +1597,7 @@ class WithItems(Base): @pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) -def test_domain_text(generator: CodeGenerator) -> None: +def test_domain_json(generator: CodeGenerator) -> None: Table( "test_domain_json", generator.metadata, From b90a4a2874e6886e548909ed9ef895ecaec492ab Mon Sep 17 00:00:00 2001 From: "idan.sheinberg" Date: Fri, 27 Jun 2025 14:25:10 +0300 Subject: [PATCH 4/7] Support non-default jsons --- src/sqlacodegen/generators.py | 2 ++ tests/test_generator_declarative.py | 39 +++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index 063be2ef..9629efc5 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -523,6 +523,8 @@ def render_column_type(self, coltype: TypeEngine[Any]) -> str: and value.astext_type.length is None ): value.astext_type = None # type: ignore[assignment] + else: + self.add_import(Text) default = defaults.get(param.name, missing) if isinstance(value, TextClause): diff --git a/tests/test_generator_declarative.py b/tests/test_generator_declarative.py index 901d6e23..38aa2f7b 100644 --- a/tests/test_generator_declarative.py +++ b/tests/test_generator_declarative.py @@ -1633,3 +1633,42 @@ class TestDomainJson(Base): foo: Mapped[Optional[dict]] = mapped_column(DOMAIN('domain_json', JSON(), not_null=False)) """, ) + + +@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) +def test_domain_non_default_json(generator: CodeGenerator) -> None: + Table( + "test_domain_json", + generator.metadata, + Column("id", BIGINT, primary_key=True), + Column( + "foo", + postgresql.DOMAIN( + "domain_json", + JSON(astext_type=Text(128)), + not_null=False, + ), + nullable=True, + ), + ) + + validate_code( + generator.generate(), + """\ +from typing import Optional + +from sqlalchemy import BigInteger, Text +from sqlalchemy.dialects.postgresql import DOMAIN, JSON +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class TestDomainJson(Base): + __tablename__ = 'test_domain_json' + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + foo: Mapped[Optional[dict]] = mapped_column(DOMAIN('domain_json', JSON(astext_type=Text(length=128)), not_null=False)) +""", + ) From c520dcf6493092f8f5deed5013d0a9db95522004 Mon Sep 17 00:00:00 2001 From: Idan Sheinberg Date: Fri, 27 Jun 2025 18:12:56 +0300 Subject: [PATCH 5/7] Update src/sqlacodegen/generators.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Alex Grönholm --- src/sqlacodegen/generators.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index 9629efc5..5feb1e96 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -1239,6 +1239,7 @@ def render_python_type(column_type: TypeEngine[Any]) -> str: python_type = column_type.data_type.python_type else: python_type = column_type.python_type + python_type_name = python_type.__name__ python_type_module = python_type.__module__ if python_type_module == "builtins": From e21de66ba5640abe0d60deacc484383506d73f03 Mon Sep 17 00:00:00 2001 From: "idan.sheinberg" Date: Sat, 28 Jun 2025 02:31:30 +0300 Subject: [PATCH 6/7] PR Fixes --- tests/test_generator_declarative.py | 20 +++++++++++++------- tests/test_generator_tables.py | 20 ++++++++++++++++++++ 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/tests/test_generator_declarative.py b/tests/test_generator_declarative.py index 38aa2f7b..2d30782c 100644 --- a/tests/test_generator_declarative.py +++ b/tests/test_generator_declarative.py @@ -4,7 +4,7 @@ from _pytest.fixtures import FixtureRequest from sqlalchemy import BIGINT, PrimaryKeyConstraint from sqlalchemy.dialects import postgresql -from sqlalchemy.dialects.postgresql import JSON +from sqlalchemy.dialects.postgresql import JSON, JSONB from sqlalchemy.engine import Engine from sqlalchemy.schema import ( CheckConstraint, @@ -1635,8 +1635,14 @@ class TestDomainJson(Base): ) -@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) -def test_domain_non_default_json(generator: CodeGenerator) -> None: +@pytest.mark.parametrize( + "domain_type", + [JSONB, JSON], +) +def test_domain_non_default_json( + generator: CodeGenerator, + domain_type: type[JSON] | type[JSONB], +) -> None: Table( "test_domain_json", generator.metadata, @@ -1645,7 +1651,7 @@ def test_domain_non_default_json(generator: CodeGenerator) -> None: "foo", postgresql.DOMAIN( "domain_json", - JSON(astext_type=Text(128)), + domain_type(astext_type=Text(128)), not_null=False, ), nullable=True, @@ -1654,11 +1660,11 @@ def test_domain_non_default_json(generator: CodeGenerator) -> None: validate_code( generator.generate(), - """\ + f"""\ from typing import Optional from sqlalchemy import BigInteger, Text -from sqlalchemy.dialects.postgresql import DOMAIN, JSON +from sqlalchemy.dialects.postgresql import DOMAIN, {domain_type.__name__} from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column class Base(DeclarativeBase): @@ -1669,6 +1675,6 @@ class TestDomainJson(Base): __tablename__ = 'test_domain_json' id: Mapped[int] = mapped_column(BigInteger, primary_key=True) - foo: Mapped[Optional[dict]] = mapped_column(DOMAIN('domain_json', JSON(astext_type=Text(length=128)), not_null=False)) + foo: Mapped[Optional[dict]] = mapped_column(DOMAIN('domain_json', {domain_type.__name__}(astext_type=Text(length=128)), not_null=False)) """, ) diff --git a/tests/test_generator_tables.py b/tests/test_generator_tables.py index fd1e544d..fe6e35e8 100644 --- a/tests/test_generator_tables.py +++ b/tests/test_generator_tables.py @@ -181,6 +181,26 @@ def test_jsonb_default(generator: CodeGenerator) -> None: ) +@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) +def test_json_default(generator: CodeGenerator) -> None: + Table("simple_items", generator.metadata, Column("json", postgresql.JSON)) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, JSON, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('json', JSON) + ) + """, + ) + + def test_enum_detection(generator: CodeGenerator) -> None: Table( "simple_items", From f64e27a73789d262596cbdc1c59c918ab84c6622 Mon Sep 17 00:00:00 2001 From: "idan.sheinberg" Date: Sat, 28 Jun 2025 02:33:22 +0300 Subject: [PATCH 7/7] PR Fixes --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 052c997a..cf791adf 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,4 @@ dist build venv* -compose.yaml +docker-compose.yaml