Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
dist
build
venv*
docker-compose.yaml
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Version history
- Fixed incorrect package name used in ``importlib.metadata.version`` for
``sqlalchemy-citext``, resolving ``PackageNotFoundError`` (PR by @oaimtiaz)
- Prevent double pluralization (PR by @dkratzert)
- Fixes DOMAIN extending JSON/JSONB data types (PR by @sheinbergon)

**3.0.0**

Expand Down
25 changes: 20 additions & 5 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
TypeDecorator,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import DOMAIN, JSONB
from sqlalchemy.dialects.postgresql import DOMAIN, JSON, JSONB
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.exc import CompileError
from sqlalchemy.sql.elements import TextClause
Expand Down Expand Up @@ -222,7 +222,7 @@ def collect_imports_for_column(self, column: Column[Any]) -> None:

if isinstance(column.type, ARRAY):
self.add_import(column.type.item_type.__class__)
elif isinstance(column.type, JSONB):
elif isinstance(column.type, (JSONB, JSON)):
if (
not isinstance(column.type.astext_type, Text)
or column.type.astext_type.length is not None
Expand Down Expand Up @@ -499,7 +499,7 @@ def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> s
else:
return render_callable("mapped_column", *args, kwargs=kwargs)

def render_column_type(self, coltype: object) -> str:
def render_column_type(self, coltype: TypeEngine[Any]) -> str:
args = []
kwargs: dict[str, Any] = {}
sig = inspect.signature(coltype.__class__.__init__)
Expand All @@ -515,6 +515,17 @@ def render_column_type(self, coltype: object) -> str:
continue

value = getattr(coltype, param.name, missing)

if isinstance(value, (JSONB, JSON)):
# Remove astext_type if it's the default
if (
isinstance(value.astext_type, Text)
and value.astext_type.length is None
):
value.astext_type = None # type: ignore[assignment]
else:
self.add_import(Text)

default = defaults.get(param.name, missing)
if isinstance(value, TextClause):
self.add_literal_import("sqlalchemy", "text")
Expand Down Expand Up @@ -547,7 +558,7 @@ def render_column_type(self, coltype: object) -> str:
if (value := getattr(coltype, colname)) is not None:
kwargs[colname] = repr(value)

if isinstance(coltype, JSONB):
if isinstance(coltype, (JSONB, JSON)):
# Remove astext_type if it's the default
if (
isinstance(coltype.astext_type, Text)
Expand Down Expand Up @@ -1224,7 +1235,11 @@ def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]:
return "".join(pre), column_type, "]" * post_size

def render_python_type(column_type: TypeEngine[Any]) -> str:
python_type = column_type.python_type
if isinstance(column_type, DOMAIN):
python_type = column_type.data_type.python_type
else:
python_type = column_type.python_type

python_type_name = python_type.__name__
python_type_module = python_type.__module__
if python_type_module == "builtins":
Expand Down
88 changes: 87 additions & 1 deletion tests/test_generator_declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import pytest
from _pytest.fixtures import FixtureRequest
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import BIGINT, PrimaryKeyConstraint
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import JSON, JSONB
from sqlalchemy.engine import Engine
from sqlalchemy.schema import (
CheckConstraint,
Expand Down Expand Up @@ -1592,3 +1594,87 @@ class WithItems(Base):
str_matrix: Mapped[Optional[list[list[str]]]] = mapped_column(ARRAY(VARCHAR(), dimensions=2))
""",
)


@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
def test_domain_json(generator: CodeGenerator) -> None:
Table(
"test_domain_json",
generator.metadata,
Column("id", BIGINT, primary_key=True),
Column(
"foo",
postgresql.DOMAIN(
"domain_json",
JSON,
not_null=False,
),
nullable=True,
),
)

validate_code(
generator.generate(),
"""\
from typing import Optional

from sqlalchemy import BigInteger
from sqlalchemy.dialects.postgresql import DOMAIN, JSON
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
pass


class TestDomainJson(Base):
__tablename__ = 'test_domain_json'

id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
foo: Mapped[Optional[dict]] = mapped_column(DOMAIN('domain_json', JSON(), not_null=False))
""",
)


@pytest.mark.parametrize(
"domain_type",
[JSONB, JSON],
)
def test_domain_non_default_json(
generator: CodeGenerator,
domain_type: type[JSON] | type[JSONB],
) -> None:
Table(
"test_domain_json",
generator.metadata,
Column("id", BIGINT, primary_key=True),
Column(
"foo",
postgresql.DOMAIN(
"domain_json",
domain_type(astext_type=Text(128)),
not_null=False,
),
nullable=True,
),
)

validate_code(
generator.generate(),
f"""\
from typing import Optional

from sqlalchemy import BigInteger, Text
from sqlalchemy.dialects.postgresql import DOMAIN, {domain_type.__name__}
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
pass


class TestDomainJson(Base):
__tablename__ = 'test_domain_json'

id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
foo: Mapped[Optional[dict]] = mapped_column(DOMAIN('domain_json', {domain_type.__name__}(astext_type=Text(length=128)), not_null=False))
""",
)
20 changes: 20 additions & 0 deletions tests/test_generator_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,26 @@ def test_jsonb_default(generator: CodeGenerator) -> None:
)


@pytest.mark.parametrize("engine", ["postgresql"], indirect=["engine"])
def test_json_default(generator: CodeGenerator) -> None:
Table("simple_items", generator.metadata, Column("json", postgresql.JSON))

validate_code(
generator.generate(),
"""\
from sqlalchemy import Column, JSON, MetaData, Table

metadata = MetaData()


t_simple_items = Table(
'simple_items', metadata,
Column('json', JSON)
)
""",
)


def test_enum_detection(generator: CodeGenerator) -> None:
Table(
"simple_items",
Expand Down