Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/case support #922

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions src/prisma/generator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,14 @@
from pydantic.fields import PrivateAttr

from .. import config
from .utils import Faker, Sampler, clean_multiline
from .utils import (
Faker,
Sampler,
clean_multiline,
to_camel_case,
to_pascal_case,
to_snake_case,
)
from ..utils import DEBUG_GENERATOR, assert_never
from ..errors import UnsupportedListTypeError
from .._compat import (
Expand Down Expand Up @@ -251,6 +258,14 @@ def __str__(self) -> str:
return self.value


class ClientCasing(str, enum.Enum):
snake_case = 'snake_case'
camel_case = 'camel_case'
lower_case = 'lower_case'
upper_case = 'upper_case'
pascal_case = 'pascal_case'


class Module(BaseModel):
if TYPE_CHECKING:
spec: machinery.ModuleSpec
Expand Down Expand Up @@ -428,9 +443,7 @@ class BinaryPaths(BaseModel):
else:

class Config(BaseModel.Config): # pyright: ignore[reportDeprecated]
extra: Any = (
pydantic.Extra.allow # pyright: ignore[reportDeprecated]
)
extra: Any = pydantic.Extra.allow # pyright: ignore[reportDeprecated]


class Datasource(BaseModel):
Expand Down Expand Up @@ -501,6 +514,7 @@ class Config(BaseSettings):
env='PRISMA_PY_CONFIG_RECURSIVE_TYPE_DEPTH',
)
engine_type: EngineType = FieldInfo(default=EngineType.binary, env='PRISMA_PY_CONFIG_ENGINE_TYPE')
client_casing: ClientCasing = FieldInfo(default=ClientCasing.lower_case)

# this should be a list of experimental features
# https://github.com/prisma/prisma/issues/12442
Expand Down Expand Up @@ -600,9 +614,7 @@ def partial_type_generator_converter(cls, values: Dict[str, Any]) -> Dict[str, A
@classmethod
def _partial_type_generator_converter(cls, value: Optional[str]) -> Optional[Module]:
try:
return Module(
spec=value # pyright: ignore[reportArgumentType]
)
return Module(spec=value) # pyright: ignore[reportArgumentType]
except ValueError:
if value is None:
# no config value passed and the default location was not found
Expand Down Expand Up @@ -734,7 +746,8 @@ def name_validator(cls, name: str) -> str:
f'use a different model name with \'@@map("{name}")\'.'
)

if iskeyword(name.lower()):
config = get_config()
if isinstance(config, Config) and config.client_casing == ClientCasing.lower_case and iskeyword(name.lower()):
raise ValueError(
f'Model name "{name}" results in a client property that shadows a Python keyword; '
f'use a different model name with \'@@map("{name}")\'.'
Expand Down Expand Up @@ -800,8 +813,19 @@ def instance_name(self) -> str:
"""
if self.extension and self.extension.instance_name:
return self.extension.instance_name

return self.name.lower()
config = get_config()
if isinstance(config, Config) and config.client_casing == ClientCasing.camel_case:
return to_camel_case(self.name)
elif isinstance(config, Config) and config.client_casing == ClientCasing.pascal_case:
return to_pascal_case(self.name)
elif isinstance(config, Config) and config.client_casing == ClientCasing.snake_case:
return to_snake_case(self.name)
elif isinstance(config, Config) and config.client_casing == ClientCasing.upper_case:
return self.name.upper()
elif isinstance(config, Config) and config.client_casing == ClientCasing.lower_case:
return self.name.lower()
else:
return assert_never()

@property
def plural_name(self) -> str:
Expand Down
Loading