diff --git a/.gitignore b/.gitignore index b5b2f478..cf791adf 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ dist build venv* +docker-compose.yaml diff --git a/CHANGES.rst b/CHANGES.rst index 60a67d59..58a3923b 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -10,6 +10,7 @@ Version history - Fixed incorrect package name used in ``importlib.metadata.version`` for ``sqlalchemy-citext``, resolving ``PackageNotFoundError`` (PR by @oaimtiaz) - Prevent double pluralization (PR by @dkratzert) +- Fixes DOMAIN extending JSON/JSONB data types (PR by @sheinbergon) **3.0.0** diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index f98fe4e9..5feb1e96 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -38,7 +38,7 @@ TypeDecorator, UniqueConstraint, ) -from sqlalchemy.dialects.postgresql import DOMAIN, JSONB +from sqlalchemy.dialects.postgresql import DOMAIN, JSON, JSONB from sqlalchemy.engine import Connection, Engine from sqlalchemy.exc import CompileError from sqlalchemy.sql.elements import TextClause @@ -222,7 +222,7 @@ def collect_imports_for_column(self, column: Column[Any]) -> None: if isinstance(column.type, ARRAY): self.add_import(column.type.item_type.__class__) - elif isinstance(column.type, JSONB): + elif isinstance(column.type, (JSONB, JSON)): if ( not isinstance(column.type.astext_type, Text) or column.type.astext_type.length is not None @@ -499,7 +499,7 @@ def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> s else: return render_callable("mapped_column", *args, kwargs=kwargs) - def render_column_type(self, coltype: object) -> str: + def render_column_type(self, coltype: TypeEngine[Any]) -> str: args = [] kwargs: dict[str, Any] = {} sig = inspect.signature(coltype.__class__.__init__) @@ -515,6 +515,17 @@ def render_column_type(self, coltype: object) -> str: continue value = getattr(coltype, param.name, missing) + + if isinstance(value, (JSONB, JSON)): + # Remove astext_type if it's the default + if ( + isinstance(value.astext_type, Text) + and value.astext_type.length is None + ): + value.astext_type = None # type: ignore[assignment] + else: + self.add_import(Text) + default = defaults.get(param.name, missing) if isinstance(value, TextClause): self.add_literal_import("sqlalchemy", "text") @@ -547,7 +558,7 @@ def render_column_type(self, coltype: object) -> str: if (value := getattr(coltype, colname)) is not None: kwargs[colname] = repr(value) - if isinstance(coltype, JSONB): + if isinstance(coltype, (JSONB, JSON)): # Remove astext_type if it's the default if ( isinstance(coltype.astext_type, Text) @@ -1224,7 +1235,11 @@ def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]: return "".join(pre), column_type, "]" * post_size def render_python_type(column_type: TypeEngine[Any]) -> str: - python_type = column_type.python_type + if isinstance(column_type, DOMAIN): + python_type = column_type.data_type.python_type + else: + python_type = column_type.python_type + python_type_name = python_type.__name__ python_type_module = python_type.__module__ if python_type_module == "builtins": diff --git a/tests/test_generator_declarative.py b/tests/test_generator_declarative.py index 9beb4a01..2d30782c 100644 --- a/tests/test_generator_declarative.py +++ b/tests/test_generator_declarative.py @@ -2,7 +2,9 @@ import pytest from _pytest.fixtures import FixtureRequest -from sqlalchemy import PrimaryKeyConstraint +from sqlalchemy import BIGINT, PrimaryKeyConstraint +from sqlalchemy.dialects import postgresql +from sqlalchemy.dialects.postgresql import JSON, JSONB from sqlalchemy.engine import Engine from sqlalchemy.schema import ( CheckConstraint, @@ -1592,3 +1594,87 @@ class WithItems(Base): str_matrix: Mapped[Optional[list[list[str]]]] = mapped_column(ARRAY(VARCHAR(), dimensions=2)) """, ) + + +@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) +def test_domain_json(generator: CodeGenerator) -> None: + Table( + "test_domain_json", + generator.metadata, + Column("id", BIGINT, primary_key=True), + Column( + "foo", + postgresql.DOMAIN( + "domain_json", + JSON, + not_null=False, + ), + nullable=True, + ), + ) + + validate_code( + generator.generate(), + """\ +from typing import Optional + +from sqlalchemy import BigInteger +from sqlalchemy.dialects.postgresql import DOMAIN, JSON +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class TestDomainJson(Base): + __tablename__ = 'test_domain_json' + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + foo: Mapped[Optional[dict]] = mapped_column(DOMAIN('domain_json', JSON(), not_null=False)) +""", + ) + + +@pytest.mark.parametrize( + "domain_type", + [JSONB, JSON], +) +def test_domain_non_default_json( + generator: CodeGenerator, + domain_type: type[JSON] | type[JSONB], +) -> None: + Table( + "test_domain_json", + generator.metadata, + Column("id", BIGINT, primary_key=True), + Column( + "foo", + postgresql.DOMAIN( + "domain_json", + domain_type(astext_type=Text(128)), + not_null=False, + ), + nullable=True, + ), + ) + + validate_code( + generator.generate(), + f"""\ +from typing import Optional + +from sqlalchemy import BigInteger, Text +from sqlalchemy.dialects.postgresql import DOMAIN, {domain_type.__name__} +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class TestDomainJson(Base): + __tablename__ = 'test_domain_json' + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + foo: Mapped[Optional[dict]] = mapped_column(DOMAIN('domain_json', {domain_type.__name__}(astext_type=Text(length=128)), not_null=False)) +""", + ) diff --git a/tests/test_generator_tables.py b/tests/test_generator_tables.py index fd1e544d..fe6e35e8 100644 --- a/tests/test_generator_tables.py +++ b/tests/test_generator_tables.py @@ -181,6 +181,26 @@ def test_jsonb_default(generator: CodeGenerator) -> None: ) +@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"]) +def test_json_default(generator: CodeGenerator) -> None: + Table("simple_items", generator.metadata, Column("json", postgresql.JSON)) + + validate_code( + generator.generate(), + """\ + from sqlalchemy import Column, JSON, MetaData, Table + + metadata = MetaData() + + + t_simple_items = Table( + 'simple_items', metadata, + Column('json', JSON) + ) + """, + ) + + def test_enum_detection(generator: CodeGenerator) -> None: Table( "simple_items",