Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/iron_sql/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def render_enum_class(
to_pascal_fn: Callable[[str], str],
to_snake_fn: Callable[[str], str],
) -> str:
class_name = to_pascal_fn(f"{package_name}_{enum.name}")
class_name = to_pascal_fn(f"{package_name}_{to_snake_fn(enum.name)}")
members = []
seen_names: dict[str, int] = {}

Expand All @@ -301,7 +301,7 @@ def render_enum_class(
name = "".join(c if c.isalnum() else "_" for c in name)
name = name.strip("_") or "EMPTY"
if name[0].isdigit():
name = "_" + name
name = "NUM" + name
if name in seen_names:
seen_names[name] += 1
name = f"{name}_{seen_names[name]}"
Expand Down Expand Up @@ -589,7 +589,7 @@ def column_py_spec( # noqa: C901, PLR0912
catalog: Catalog,
package_name: str,
to_pascal_fn: Callable[[str], str],
_to_snake_fn: Callable[[str], str] = inflection.underscore,
to_snake_fn: Callable[[str], str] = inflection.underscore,
number: int = 0,
) -> ColumnPySpec:
db_type = column.type.name.removeprefix("pg_catalog.")
Expand Down Expand Up @@ -628,7 +628,11 @@ def column_py_spec( # noqa: C901, PLR0912
case "any" | "anyelement":
py_type = "object"
case enum if catalog.schema_by_ref(column.type).has_enum(enum):
py_type = to_pascal_fn(f"{package_name}_{enum}") if package_name else "str"
py_type = (
to_pascal_fn(f"{package_name}_{to_snake_fn(enum)}")
if package_name
else "str"
)
case _:
logger.warning(f"Unknown SQL type: {column.type.name} ({column.name})")
py_type = "object"
Expand Down
13 changes: 5 additions & 8 deletions tests/test_type_system.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import keyword
from enum import StrEnum

import pytest
Expand Down Expand Up @@ -101,7 +100,7 @@ async def test_enum_naming_normalization(test_project: ProjectBuilder) -> None:

mod = test_project.generate()

assert hasattr(mod, "TestdbCamelcaseenum")
assert hasattr(mod, "TestdbCamelCaseEnum")
assert hasattr(mod, "TestdbScreamingEnum")


Expand All @@ -117,13 +116,11 @@ async def test_enum_value_name_normalization(test_project: ProjectBuilder) -> No
mod = test_project.generate()

enum_cls = mod.TestdbWeirdEnum
expected = {"1st", "foo-bar", "foo_bar"}
assert enum_cls.NUM1ST == "1st"
assert enum_cls.FOO_BAR == "foo-bar"
assert enum_cls.FOO_BAR_2 == "foo_bar"

assert {member.value for member in enum_cls} == expected
assert len(enum_cls.__members__) == len(expected)
for name in enum_cls.__members__:
assert name.isidentifier()
assert not keyword.iskeyword(name.lower())
assert len(enum_cls.__members__) == 3


async def test_enum_empty_label_value(test_project: ProjectBuilder) -> None:
Expand Down