diff --git a/CHANGES.rst b/CHANGES.rst index 6c0fd657..c2df8fc1 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -14,6 +14,8 @@ Version history - Temporarily restrict SQLAlchemy version to 2.0.41 (PR by @sheinbergon) - Fixes ``add_import`` behavior when adding imports from sqlalchemy and overall better alignment of import behavior(s) across generators +- Fixes ``nullable`` column behavior for non-null columns for both + ``sqlmodels`` and ``declarative`` generators (PR by @sheinbergon) **3.0.0** diff --git a/src/sqlacodegen/generators.py b/src/sqlacodegen/generators.py index e9a08cde..7b4901a7 100644 --- a/src/sqlacodegen/generators.py +++ b/src/sqlacodegen/generators.py @@ -410,7 +410,9 @@ def render_column( args = [] kwargs: dict[str, Any] = {} kwarg = [] - is_sole_pk = column.primary_key and len(column.table.primary_key) == 1 + is_part_of_composite_pk = ( + column.primary_key and len(column.table.primary_key) > 1 + ) dedicated_fks = [ c for c in column.foreign_keys @@ -460,8 +462,10 @@ def render_column( kwargs["key"] = column.key if is_primary: kwargs["primary_key"] = True - if not column.nullable and not is_sole_pk and is_table: + if not column.nullable and not column.primary_key: kwargs["nullable"] = False + if column.nullable and is_part_of_composite_pk: + kwargs["nullable"] = True if is_unique: column.unique = True diff --git a/tests/test_cli.py b/tests/test_cli.py index f3a09fec..d4410ff7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -82,7 +82,7 @@ class Foo(Base): __tablename__ = 'foo' id: Mapped[int] = mapped_column(Integer, primary_key=True) - name: Mapped[str] = mapped_column(Text) + name: Mapped[str] = mapped_column(Text, nullable=False) """ ) @@ -115,7 +115,7 @@ class Foo(Base): __tablename__ = 'foo' id: Mapped[int] = mapped_column(Integer, primary_key=True) - name: Mapped[str] = mapped_column(Text) + name: Mapped[str] = mapped_column(Text, nullable=False) """ ) @@ -142,7 +142,7 @@ def test_cli_sqlmodels(db_path: Path, tmp_path: Path) -> None: class Foo(SQLModel, table=True): id: int = Field(sa_column=Column('id', Integer, primary_key=True)) - name: str = Field(sa_column=Column('name', Text)) + name: str = Field(sa_column=Column('name', Text, nullable=False)) """ ) diff --git a/tests/test_generator_dataclass.py b/tests/test_generator_dataclass.py index b2f9ebb1..307f865c 100644 --- a/tests/test_generator_dataclass.py +++ b/tests/test_generator_dataclass.py @@ -77,7 +77,7 @@ class Simple(Base): __tablename__ = 'simple' id: Mapped[int] = mapped_column(Integer, primary_key=True) - age: Mapped[int] = mapped_column(Integer) + age: Mapped[int] = mapped_column(Integer, nullable=False) name: Mapped[Optional[str]] = mapped_column(String(20), \ server_default=text('foo')) """, diff --git a/tests/test_generator_declarative.py b/tests/test_generator_declarative.py index 2d30782c..931d5965 100644 --- a/tests/test_generator_declarative.py +++ b/tests/test_generator_declarative.py @@ -341,7 +341,7 @@ class SimpleItems(Base): id: Mapped[int] = mapped_column(Integer, primary_key=True) top_container_id: Mapped[int] = \ -mapped_column(ForeignKey('simple_containers.id')) +mapped_column(ForeignKey('simple_containers.id'), nullable=False) parent_container_id: Mapped[Optional[int]] = \ mapped_column(ForeignKey('simple_containers.id')) @@ -812,6 +812,34 @@ class SimpleItems(Base): ) +def test_composite_nullable_pk(generator: CodeGenerator) -> None: + Table( + "simple_items", + generator.metadata, + Column("id1", INTEGER, primary_key=True), + Column("id2", INTEGER, primary_key=True, nullable=True), + ) + validate_code( + generator.generate(), + """\ +from typing import Optional + +from sqlalchemy import Integer +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column + +class Base(DeclarativeBase): + pass + + +class SimpleItems(Base): + __tablename__ = 'simple_items' + + id1: Mapped[int] = mapped_column(Integer, primary_key=True) + id2: Mapped[Optional[int]] = mapped_column(Integer, primary_key=True, nullable=True) + """, + ) + + def test_joined_inheritance(generator: CodeGenerator) -> None: Table( "simple_sub_items", @@ -1045,7 +1073,7 @@ class Group(Base): ) groups_id: Mapped[int] = mapped_column(Integer, primary_key=True) - group_name: Mapped[str] = mapped_column(Text(50)) + group_name: Mapped[str] = mapped_column(Text(50), nullable=False) users: Mapped[list['User']] = relationship('User', back_populates='group') @@ -1590,7 +1618,7 @@ class WithItems(Base): __tablename__ = 'with_items' id: Mapped[int] = mapped_column(Integer, primary_key=True) - int_items_not_optional: Mapped[list[int]] = mapped_column(ARRAY(INTEGER())) + int_items_not_optional: Mapped[list[int]] = mapped_column(ARRAY(INTEGER()), nullable=False) str_matrix: Mapped[Optional[list[list[str]]]] = mapped_column(ARRAY(VARCHAR(), dimensions=2)) """, ) diff --git a/tests/test_generator_sqlmodel.py b/tests/test_generator_sqlmodel.py index 1fce5830..32a736e2 100644 --- a/tests/test_generator_sqlmodel.py +++ b/tests/test_generator_sqlmodel.py @@ -33,7 +33,7 @@ def test_indexes(generator: CodeGenerator) -> None: "item", generator.metadata, Column("id", INTEGER, primary_key=True), - Column("number", INTEGER), + Column("number", INTEGER, nullable=False), Column("text", VARCHAR), ) simple_items.indexes.add(Index("idx_number", simple_items.c.number)) @@ -58,8 +58,8 @@ class Item(SQLModel, table=True): ) id: int = Field(sa_column=Column('id', Integer, primary_key=True)) - number: Optional[int] = Field(default=None, sa_column=Column(\ -'number', Integer)) + number: int = Field(sa_column=Column(\ +'number', Integer, nullable=False)) text: Optional[str] = Field(default=None, sa_column=Column(\ 'text', String)) """,