Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: add support for transactions #427

Closed
wants to merge 13 commits into from
Closed
6 changes: 6 additions & 0 deletions src/prisma/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ def handle_response_errors(resp: AbstractResponse[Any], data: Any) -> NoReturn:
# As we can only check for this error by searching the message then this
# comes with performance concerns.
message = user_facing.get('message', '')

if code == 'P2028':
if message.endswith("Last state: 'Expired'."):
raise prisma_errors.TransactionExpiredError()
raise prisma_errors.TransactionError(message)

if 'A value is required but not set' in message:
raise prisma_errors.MissingRequiredValueError(error)

Expand Down
17 changes: 17 additions & 0 deletions src/prisma/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,23 @@ class InputError(DataError):
pass


class TransactionError(PrismaError):
pass


class TransactionExpiredError(TransactionError):
def __init__(self) -> None:
super().__init__('Attempted operation on an expired transaction.')


class TransactionNotStartedError(TransactionError):
def __init__(self) -> None:
super().__init__(
'Transaction has not been started yet.\n'
'Transactions must be used within a context manager.'
)


class BuilderError(PrismaError):
pass

Expand Down
1 change: 1 addition & 0 deletions src/prisma/generator/templates/_header.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import decimal
import datetime
from typing import (
TYPE_CHECKING,
NoReturn,
Optional,
Iterable,
Iterator,
Expand Down
165 changes: 156 additions & 9 deletions src/prisma/generator/templates/client.py.jinja
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
{% include '_header.py.jinja' %}
{% from '_utils.py.jinja' import is_async, maybe_async_def, maybe_await, methods, operations, recursive_types, active_provider with context %}
# -- template client.py.jinja --
import warnings
from pathlib import Path
from types import TracebackType

from . import types, models, errors, actions
from .types import DatasourceOverride, HttpConfig
from ._types import BaseModelT
from .engine import AbstractEngine, QueryEngine
from .builder import QueryBuilder
from .builder import QueryBuilder, dumps
from .generator.models import EngineType, OptionalValueFromEnvVar


Expand All @@ -26,6 +27,21 @@ __all__ = (
SCHEMA_PATH = Path('{{ schema_path.as_posix() }}')
PACKAGED_SCHEMA_PATH = Path(__file__).parent.joinpath('schema.prisma')

{% set itx_missing_doc = clean_multiline('''
Interactive transactions require a preview feature flag to be set.

If you would like to enable interactive transactions then you should modify
your client definition in your Prisma Schema file:

```prisma
generator py {
provider = "prisma-client-py"
previewFeatures = ["interactiveTransactions"]
}
```
''')
%}

ENGINE_TYPE: EngineType = EngineType.{{ generator.config.engine_type }}

RegisteredClient = Union['Prisma', Callable[[], 'Prisma']]
Expand Down Expand Up @@ -123,11 +139,12 @@ class Prisma:
'{{ model.name.lower() }}',
{% endfor %}
'__engine',
'_active_provider',
'_log_queries',
'_tx_id',
'_datasource',
'_connect_timeout',
'_log_queries',
'_http_config',
'_connect_timeout',
'_active_provider',
)

def __init__(
Expand All @@ -153,6 +170,7 @@ class Prisma:
self._datasource = datasource
self._connect_timeout = connect_timeout
self._http_config: HttpConfig = http or {}
self._tx_id: Optional[str] = None

if use_dotenv:
load_env()
Expand All @@ -161,7 +179,10 @@ class Prisma:
register(self)

def __del__(self) -> None:
if self.__engine is not None:
# we are not in a transaction, should be safe to stop
# the query engine
# TODO: what if there are outstanding transactions?
if self.__engine is not None and not self.is_transaction():
self.__engine.stop()
self.__engine = None

Expand Down Expand Up @@ -302,6 +323,32 @@ class Prisma:
"""Returns a context manager for grouping write queries into a single transaction."""
return Batch(client=self)

{% if 'interactiveTransactions' in generator.preview_features %}
def tx(self, *, max_wait: int = 2000, timeout: int = 5000) -> 'TransactionManager':
"""Returns a context manager for executing queries within a transaction.

Entering the context manager returns a new Prisma instance wrapping all
actions within a transaction, queries will be isolated to the Prisma instance and
will not be commited to the database until the context manager exits.
"""
return TransactionManager(client=self, max_wait=max_wait, timeout=timeout)
{% else %}
def tx(self, *, max_wait: int = ..., timeout: int = ...) -> NoReturn:
"""
{{ itx_missing_doc }}
"""
# TODO: fix formatting for this message.
raise RuntimeError(
'''
{{ itx_missing_doc }}
'''
)
{% endif %}

def is_transaction(self) -> bool:
"""Returns true if the client is wrapped within a transaction"""
return self._tx_id is not None

# TODO: don't return Any
{{ maybe_async_def }}_execute(
self,
Expand All @@ -318,7 +365,16 @@ class Prisma:
arguments=arguments,
root_selection=root_selection,
)
return {{ maybe_await }}self._engine.query(builder.build())
return {{ maybe_await }}self._engine.query(builder.build(), tx_id=self._tx_id)

def _copy(self) -> 'Prisma':
"""Return a new Prisma instance using the same engine process (if connected)."""
new = Prisma(use_dotenv=False, log_queries=self._log_queries)

if self.__engine is not None:
new._engine = self.__engine

return new

def _create_engine(self, dml_path: Path = PACKAGED_SCHEMA_PATH) -> AbstractEngine:
if ENGINE_TYPE == EngineType.binary:
Expand All @@ -340,6 +396,10 @@ class Prisma:
raise errors.ClientNotConnectedError()
return engine

@_engine.setter
def _engine(self, engine: QueryEngine) -> None:
self.__engine = engine

def _make_sqlite_datasource(self) -> DatasourceOverride:
return {
'name': '{{ datasources[0].name }}',
Expand All @@ -364,6 +424,92 @@ class Prisma:
}


class TransactionManager:
"""Context manager for wrapping a Prisma instance within a transaction.

This should never be created manually, instead it should be used
through the Prisma().tx() method.
"""

def __init__(self, client: Prisma, max_wait: int, timeout: int) -> None:
self.__client = client
self._max_wait = max_wait
self._timeout = timeout
self._tx_id: Optional[str] = None

# TODO: these endpoints should be abstract methods in the engine class

{{ maybe_async_def }}start(self, *, _from_context: bool = False) -> Prisma:
"""Start the transaction and return the wrapped Prisma instance"""
if self.__client.is_transaction():
# if we were called from the context manager then the stacklevel
# needs to be one higher to warn on the actual offending code
warnings.warn(
'The current client is already in a transaction.',
UserWarning,
stacklevel=3 if _from_context else 2
)

result = {{ maybe_await }}self.__client._engine.request(
'POST',
'/transaction/start',
content=dumps(
{
'timeout': self._timeout,
'max_wait': self._max_wait,
}
),
)
self._tx_id = tx_id = result['id']
client = self.__client._copy()
client._tx_id = tx_id
return client

{{ maybe_async_def }}commit(self) -> None:
"""Commit the transaction to the database, this transaction will no longer be usable"""
if self._tx_id is None:
raise errors.TransactionNotStartedError()

{{ maybe_await }}self.__client._engine.request(
'POST', f'/transaction/{self._tx_id}/commit'
)

{{ maybe_async_def }}rollback(self) -> None:
"""Do not commit the changes to the database, this transaction will no longer be usable"""
if self._tx_id is None:
raise errors.TransactionNotStartedError()

{{ maybe_await }}self.__client._engine.request(
'POST', f'/transaction/{self._tx_id}/rollback'
)

{% if is_async %}
async def __aenter__(self) -> Prisma:
return await self.start(_from_context=True)

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
if exc is None:
await self.commit()
{% else %}
def __enter__(self) -> Prisma:
return self.start(_from_context=True)

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
if exc is None:
self.commit()
{% endif %}


# TODO: this should return the results as well
# TODO: don't require copy-pasting arguments between actions and batch actions
class Batch:
Expand All @@ -386,8 +532,6 @@ class Batch:
{{ maybe_async_def }}commit(self) -> None:
"""Execute the queries"""
# TODO: normalise this, we should still call client._execute
from .builder import dumps

queries = self.__queries
self.__queries = []

Expand All @@ -401,7 +545,10 @@ class Batch:
],
'transaction': True,
}
{{ maybe_await }}self.__client._engine.query(dumps(payload))
{{ maybe_await }}self.__client._engine.query(
dumps(payload),
tx_id=self.__client._tx_id,
)

{% if is_async %}
async def __aenter__(self) -> 'Batch':
Expand Down
2 changes: 1 addition & 1 deletion src/prisma/generator/templates/engine/abstract.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class AbstractEngine(ABC):
...

@abstractmethod
{{ maybe_async_def }}query(self, content: str) -> Any:
{{ maybe_async_def }}query(self, content: str, tx_id: Optional[str] = None) -> Any:
"""Execute a GraphQL query.

This method expects a JSON object matching this structure:
Expand Down
12 changes: 11 additions & 1 deletion src/prisma/generator/templates/engine/http.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,14 @@ class HTTPEngine(AbstractEngine):
if self.session and not self.session.closed:
{{ maybe_await }}self.session.close()

{{ maybe_async_def }}request(self, method: Method, path: str, *, content: Any = None) -> Any:
{{ maybe_async_def }}request(
self,
method: Method,
path: str,
*,
content: Any = None,
headers: Optional[Dict[str, str]] = None,
) -> Any:
if self.url is None:
raise errors.NotConnectedError('Not connected to the query engine')

Expand All @@ -66,6 +73,9 @@ class HTTPEngine(AbstractEngine):
}
}

if headers is not None:
kwargs['headers'].update(headers)

if content is not None:
kwargs['content'] = content

Expand Down
17 changes: 15 additions & 2 deletions src/prisma/generator/templates/engine/query.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,21 @@ class QueryEngine(HTTPEngine):
'Could not connect to the query engine'
) from last_exc

{{ maybe_async_def }}query(self, content: str) -> Any:
return {{ maybe_await }}self.request('POST', '/', content=content)
{{ maybe_async_def }}query(
self,
content: str,
tx_id: Optional[str] = None
) -> Any:
headers: Dict[str, str] = {}
if tx_id is not None:
headers['X-transaction-id'] = tx_id

return {{ maybe_await }}self.request(
'POST',
'/',
content=content,
headers=headers,
)


# black does not respect the fmt: off comment without this
Expand Down
3 changes: 2 additions & 1 deletion tests/data/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ datasource db {
url = "file:dev.db"
}

generator client {
generator db {
provider = "prisma-client-py"
interface = "asyncio"
recursive_type_depth = 5
partial_type_generator = "tests/scripts/partial_type_generator.py"
previewFeatures = ["interactiveTransactions"]
enable_experimental_decimal = true
}

Expand Down
13 changes: 11 additions & 2 deletions tests/integrations/sync/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,22 @@ datasource db {
}

generator client {
provider = "coverage run -m prisma"
interface = "sync"
provider = "coverage run -m prisma"
interface = "sync"
previewFeatures = ["interactiveTransactions"]
}

model User {
id String @id @default(cuid())
created_at DateTime @default(now())
updated_at DateTime @updatedAt
name String
profile Profile?
}

model Profile {
id Int @id @default(autoincrement())
user User @relation(fields: [user_id], references: [id])
user_id String
bio String
}
Loading