Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(client): start of JSON Protocol implementation #748

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
6 changes: 5 additions & 1 deletion src/prisma/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from .utils import setup_logging
from . import errors as errors
from .validator import *
from ._types import PrismaMethod as PrismaMethod
from ._types import (
PrismaMethod as PrismaMethod,
NotGiven as NotGiven,
NOT_GIVEN as NOT_GIVEN,
)
from ._metrics import (
Metric as Metric,
Metrics as Metrics,
Expand Down
12 changes: 12 additions & 0 deletions src/prisma/_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from __future__ import annotations

from typing import Mapping
from typing_extensions import TypeGuard


def is_list(value: object) -> TypeGuard[list[object]]:
return isinstance(value, list)


def is_mapping(value: object) -> TypeGuard[Mapping[str, object]]:
return isinstance(value, Mapping)
34 changes: 34 additions & 0 deletions src/prisma/_parsers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from __future__ import annotations

from typing import Any, Callable, TypeVar

from ._helpers import is_list


_T = TypeVar('_T')


def allow_none(parser: Callable[[Any], _T]) -> Callable[[Any], _T | None]:
"""Wrap the given parser function to allow passing in None values."""

def wrapped(value: Any) -> _T | None:
if value is None:
return None

return parser(value)

return wrapped


def as_list(parser: Callable[[Any], _T]) -> Callable[[Any], list[_T]]:
"""Wrap the given parser function to accept a list and invoke it for each entry"""

def wrapped(value: Any) -> list[_T]:
if not is_list(value):
raise TypeError(
f'Expected value to be a list but got {type(value)}'
)

return [parser(entry) for entry in value]

return wrapped
21 changes: 20 additions & 1 deletion src/prisma/_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Coroutine, TypeVar, Type, Tuple, Any
from typing import Callable, Coroutine, TypeVar, Type, Tuple, Any, Union
from pydantic import BaseModel
from typing_extensions import (
TypeGuard as TypeGuard,
Expand All @@ -11,6 +11,8 @@

Method = Literal['GET', 'POST']

_T = TypeVar('_T')

CallableT = TypeVar('CallableT', bound='FuncType')
BaseModelT = TypeVar('BaseModelT', bound=BaseModel)

Expand All @@ -28,6 +30,23 @@ class _GenericAlias(Protocol):
__origin__: Type[object]


class NotGiven:
"""Represents cases where a value has not been explicitly given.
Useful when `None` is not a possible default value.
"""

def __bool__(self) -> bool:
return False

def __repr__(self) -> str:
return 'NOT_GIVEN'


NOT_GIVEN = NotGiven()
NotGivenOr = Union[_T, NotGiven]

PrismaMethod = Literal[
# raw queries
'query_raw',
Expand Down
7 changes: 7 additions & 0 deletions src/prisma/engine/json/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .serializer import (
QueryInput as QueryInput,
build_single_query as build_single_query,
serialize_single_query as serialize_single_query,
serialize_batched_query as serialize_batched_query,
)
from .deserializer import deserialize as deserialize
36 changes: 36 additions & 0 deletions src/prisma/engine/json/deserializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations

from typing import Mapping
from typing_extensions import TypeGuard

from .types import OutputTaggedValue
from ..._helpers import is_list, is_mapping


def deserialize(value: object) -> object:
if not value:
return value

if is_list(value):
return [deserialize(entry) for entry in value]

if is_mapping(value):
if is_tagged_value(value):
return deserialize_tagged_value(value)

return {key: deserialize(item) for key, item in value.items()}

return value


def is_tagged_value(
value: Mapping[str, object]
) -> TypeGuard[OutputTaggedValue]:
return isinstance(value.get('$type'), str)


def deserialize_tagged_value(tagged: OutputTaggedValue) -> object:
if tagged['$type'] == 'FieldRef':
raise RuntimeError('Cannot deserialize FieldRef values yet')

return tagged['value']
Loading
Loading