diff --git a/src/prisma/generator/models.py b/src/prisma/generator/models.py index 198722912..07affa90d 100644 --- a/src/prisma/generator/models.py +++ b/src/prisma/generator/models.py @@ -33,7 +33,7 @@ from pydantic.fields import PrivateAttr from .. import config -from .utils import Faker, Sampler, clean_multiline +from .utils import Faker, Sampler, clean_multiline, to_camel_case, to_pascal_case, to_snake_case from ..utils import DEBUG_GENERATOR, assert_never from ..errors import UnsupportedListTypeError from .._compat import ( @@ -248,6 +248,14 @@ def __str__(self) -> str: return self.value +class ClientCasing(str, enum.Enum): + snake_case = 'snake_case' + camel_case = 'camel_case' + lower_case = 'lower_case' + upper_case = 'upper_case' + pascal_case = 'pascal_case' + + class Module(BaseModel): if TYPE_CHECKING: spec: machinery.ModuleSpec @@ -496,6 +504,7 @@ class Config(BaseSettings): env='PRISMA_PY_CONFIG_RECURSIVE_TYPE_DEPTH', ) engine_type: EngineType = FieldInfo(default=EngineType.binary, env='PRISMA_PY_CONFIG_ENGINE_TYPE') + client_casing: ClientCasing = FieldInfo(default=ClientCasing.lower_case) # this should be a list of experimental features # https://github.com/prisma/prisma/issues/12442 @@ -684,7 +693,8 @@ def name_validator(cls, name: str) -> str: f'use a different model name with \'@@map("{name}")\'.' ) - if iskeyword(name.lower()): + config = get_config() + if isinstance(config, Config) and config.client_casing == ClientCasing.lower_case and iskeyword(name.lower()): raise ValueError( f'Model name "{name}" results in a client property that shadows a Python keyword; ' f'use a different model name with \'@@map("{name}")\'.' @@ -748,7 +758,19 @@ def instance_name(self) -> str: `User` -> `Prisma().user` """ - return self.name.lower() + config = get_config() + if isinstance(config, Config) and config.client_casing == ClientCasing.camel_case: + return to_camel_case(self.name) + elif isinstance(config, Config) and config.client_casing == ClientCasing.pascal_case: + return to_pascal_case(self.name) + elif isinstance(config, Config) and config.client_casing == ClientCasing.snake_case: + return to_snake_case(self.name) + elif isinstance(config, Config) and config.client_casing == ClientCasing.upper_case: + return self.name.upper() + elif isinstance(config, Config) and config.client_casing == ClientCasing.lower_case: + return self.name.lower() + else: + return assert_never() @property def plural_name(self) -> str: diff --git a/src/prisma/generator/utils.py b/src/prisma/generator/utils.py index c7eca3443..b55d68a07 100644 --- a/src/prisma/generator/utils.py +++ b/src/prisma/generator/utils.py @@ -1,4 +1,5 @@ import os +import re import shutil from typing import TYPE_CHECKING, Any, Dict, List, Union, TypeVar, Iterator from pathlib import Path @@ -122,3 +123,37 @@ def clean_multiline(string: str) -> str: assert string, 'Expected non-empty string' lines = string.splitlines() return '\n'.join([dedent(lines[0]), *lines[1:]]) + + +# https://github.com/nficano/humps/blob/master/humps/main.py + +ACRONYM_RE = re.compile(r'([A-Z\d]+)(?=[A-Z\d]|$)') +PASCAL_RE = re.compile(r'([^\-_]+)') +SPLIT_RE = re.compile(r'([\-_]*[A-Z][^A-Z]*[\-_]*)') +UNDERSCORE_RE = re.compile(r'(?<=[^\-_])[\-_]+[^\-_]') + + +def to_snake_case(input_str: str) -> str: + if to_camel_case(input_str) == input_str or to_pascal_case(input_str) == input_str: # if camel case or pascal case + input_str = ACRONYM_RE.sub(lambda m: m.group(0).title(), input_str) + input_str = '_'.join(s for s in SPLIT_RE.split(input_str) if s) + return input_str.lower() + else: + input_str = re.sub(r'[^a-zA-Z0-9]', '_', input_str) + input_str = input_str.lower().strip('_') + + return input_str + + +def to_camel_case(input_str: str) -> str: + if len(input_str) != 0 and not input_str[:2].isupper(): + input_str = input_str[0].lower() + input_str[1:] + return UNDERSCORE_RE.sub(lambda m: m.group(0)[-1].upper(), input_str) + + +def to_pascal_case(input_str: str) -> str: + def _replace_fn(match: re.Match[str]) -> str: + return match.group(1)[0].upper() + match.group(1)[1:] + + input_str = to_camel_case(PASCAL_RE.sub(_replace_fn, input_str)) + return input_str[0].upper() + input_str[1:] if len(input_str) != 0 else input_str diff --git a/tests/test_generation/test_utils.py b/tests/test_generation/test_utils.py index a77569cb4..d012ba2b5 100644 --- a/tests/test_generation/test_utils.py +++ b/tests/test_generation/test_utils.py @@ -1,4 +1,9 @@ -from prisma.generator.utils import copy_tree +from prisma.generator.utils import ( + copy_tree, + to_camel_case, + to_pascal_case, + to_snake_case, +) from ..utils import Testdir @@ -26,3 +31,39 @@ def test_copy_tree_ignores_files(testdir: Testdir) -> None: assert files[0].name == 'bar.py' assert files[1].name == 'foo.py' assert files[2].name == 'hello.py' + + +def test_to_snake_case() -> None: + snake_case = 'snake_case_test' + pascal_case = 'PascalCaseTest' + camel_case = 'camelCaseTest' + mixed_case = 'Mixed_Case_Test' + + assert to_snake_case(snake_case) == 'snake_case_test' + assert to_snake_case(pascal_case) == 'pascal_case_test' + assert to_snake_case(camel_case) == 'camel_case_test' + assert to_snake_case(mixed_case) == 'mixed_case_test' + + +def test_to_pascal_case() -> None: + snake_case = 'snake_case_test' + pascal_case = 'PascalCaseTest' + camel_case = 'camelCaseTest' + mixed_case = 'Mixed_Case_Test' + + assert to_pascal_case(snake_case) == 'SnakeCaseTest' + assert to_pascal_case(pascal_case) == 'PascalCaseTest' + assert to_pascal_case(camel_case) == 'CamelCaseTest' + assert to_pascal_case(mixed_case) == 'MixedCaseTest' + + +def test_to_camel_case() -> None: + snake_case = 'snake_case_test' + pascal_case = 'PascalCaseTest' + camel_case = 'camelCaseTest' + mixed_case = 'Mixed_Case_Test' + + assert to_camel_case(snake_case) == 'snakeCaseTest' + assert to_camel_case(pascal_case) == 'pascalCaseTest' + assert to_camel_case(camel_case) == 'camelCaseTest' + assert to_camel_case(mixed_case) == 'mixedCaseTest'