Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(generator): support custom default casing for client properties #877

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions src/prisma/generator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from pydantic.fields import PrivateAttr

from .. import config
from .utils import Faker, Sampler, clean_multiline
from .utils import Faker, Sampler, clean_multiline, to_camel_case, to_pascal_case, to_snake_case
from ..utils import DEBUG_GENERATOR, assert_never
from ..errors import UnsupportedListTypeError
from .._compat import (
Expand Down Expand Up @@ -248,6 +248,14 @@ def __str__(self) -> str:
return self.value


class ClientCasing(str, enum.Enum):
snake_case = 'snake_case'
camel_case = 'camel_case'
lower_case = 'lower_case'
upper_case = 'upper_case'
pascal_case = 'pascal_case'


class Module(BaseModel):
if TYPE_CHECKING:
spec: machinery.ModuleSpec
Expand Down Expand Up @@ -496,6 +504,7 @@ class Config(BaseSettings):
env='PRISMA_PY_CONFIG_RECURSIVE_TYPE_DEPTH',
)
engine_type: EngineType = FieldInfo(default=EngineType.binary, env='PRISMA_PY_CONFIG_ENGINE_TYPE')
client_casing: ClientCasing = FieldInfo(default=ClientCasing.lower_case)

# this should be a list of experimental features
# https://github.com/prisma/prisma/issues/12442
Expand Down Expand Up @@ -684,7 +693,8 @@ def name_validator(cls, name: str) -> str:
f'use a different model name with \'@@map("{name}")\'.'
)

if iskeyword(name.lower()):
config = get_config()
if isinstance(config, Config) and config.client_casing == ClientCasing.lower_case and iskeyword(name.lower()):
raise ValueError(
f'Model name "{name}" results in a client property that shadows a Python keyword; '
f'use a different model name with \'@@map("{name}")\'.'
Expand Down Expand Up @@ -748,7 +758,19 @@ def instance_name(self) -> str:

`User` -> `Prisma().user`
"""
return self.name.lower()
config = get_config()
if isinstance(config, Config) and config.client_casing == ClientCasing.camel_case:
return to_camel_case(self.name)
elif isinstance(config, Config) and config.client_casing == ClientCasing.pascal_case:
return to_pascal_case(self.name)
elif isinstance(config, Config) and config.client_casing == ClientCasing.snake_case:
return to_snake_case(self.name)
elif isinstance(config, Config) and config.client_casing == ClientCasing.upper_case:
return self.name.upper()
elif isinstance(config, Config) and config.client_casing == ClientCasing.lower_case:
return self.name.lower()
else:
return assert_never()

@property
def plural_name(self) -> str:
Expand Down
35 changes: 35 additions & 0 deletions src/prisma/generator/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
import shutil
from typing import TYPE_CHECKING, Any, Dict, List, Union, TypeVar, Iterator
from pathlib import Path
Expand Down Expand Up @@ -122,3 +123,37 @@ def clean_multiline(string: str) -> str:
assert string, 'Expected non-empty string'
lines = string.splitlines()
return '\n'.join([dedent(lines[0]), *lines[1:]])


# https://github.com/nficano/humps/blob/master/humps/main.py

ACRONYM_RE = re.compile(r'([A-Z\d]+)(?=[A-Z\d]|$)')
PASCAL_RE = re.compile(r'([^\-_]+)')
SPLIT_RE = re.compile(r'([\-_]*[A-Z][^A-Z]*[\-_]*)')
UNDERSCORE_RE = re.compile(r'(?<=[^\-_])[\-_]+[^\-_]')
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: would be great to add test cases for all these functions!



def to_snake_case(input_str: str) -> str:
if to_camel_case(input_str) == input_str or to_pascal_case(input_str) == input_str: # if camel case or pascal case
input_str = ACRONYM_RE.sub(lambda m: m.group(0).title(), input_str)
input_str = '_'.join(s for s in SPLIT_RE.split(input_str) if s)
return input_str.lower()
else:
input_str = re.sub(r'[^a-zA-Z0-9]', '_', input_str)
input_str = input_str.lower().strip('_')

return input_str


def to_camel_case(input_str: str) -> str:
if len(input_str) != 0 and not input_str[:2].isupper():
input_str = input_str[0].lower() + input_str[1:]
return UNDERSCORE_RE.sub(lambda m: m.group(0)[-1].upper(), input_str)


def to_pascal_case(input_str: str) -> str:
def _replace_fn(match: re.Match[str]) -> str:
return match.group(1)[0].upper() + match.group(1)[1:]

input_str = to_camel_case(PASCAL_RE.sub(_replace_fn, input_str))
return input_str[0].upper() + input_str[1:] if len(input_str) != 0 else input_str
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: where did you get these functions from? did you come up with it yourself?

I'm only asking because, if you copied these from somewhere, it would be great to add a link to that place in case these need to be updated in the future.

43 changes: 42 additions & 1 deletion tests/test_generation/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from prisma.generator.utils import copy_tree
from prisma.generator.utils import (
copy_tree,
to_camel_case,
to_pascal_case,
to_snake_case,
)

from ..utils import Testdir

Expand Down Expand Up @@ -26,3 +31,39 @@ def test_copy_tree_ignores_files(testdir: Testdir) -> None:
assert files[0].name == 'bar.py'
assert files[1].name == 'foo.py'
assert files[2].name == 'hello.py'


def test_to_snake_case() -> None:
snake_case = 'snake_case_test'
pascal_case = 'PascalCaseTest'
camel_case = 'camelCaseTest'
mixed_case = 'Mixed_Case_Test'

assert to_snake_case(snake_case) == 'snake_case_test'
assert to_snake_case(pascal_case) == 'pascal_case_test'
assert to_snake_case(camel_case) == 'camel_case_test'
assert to_snake_case(mixed_case) == 'mixed_case_test'


def test_to_pascal_case() -> None:
snake_case = 'snake_case_test'
pascal_case = 'PascalCaseTest'
camel_case = 'camelCaseTest'
mixed_case = 'Mixed_Case_Test'

assert to_pascal_case(snake_case) == 'SnakeCaseTest'
assert to_pascal_case(pascal_case) == 'PascalCaseTest'
assert to_pascal_case(camel_case) == 'CamelCaseTest'
assert to_pascal_case(mixed_case) == 'MixedCaseTest'


def test_to_camel_case() -> None:
snake_case = 'snake_case_test'
pascal_case = 'PascalCaseTest'
camel_case = 'camelCaseTest'
mixed_case = 'Mixed_Case_Test'

assert to_camel_case(snake_case) == 'snakeCaseTest'
assert to_camel_case(pascal_case) == 'pascalCaseTest'
assert to_camel_case(camel_case) == 'camelCaseTest'
assert to_camel_case(mixed_case) == 'mixedCaseTest'