diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bfeb87c37..b722eecac 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -112,7 +112,6 @@ jobs: run: | nox -s typesafety-mypy - lint: name: lint runs-on: ubuntu-latest diff --git a/Makefile b/Makefile index 41dbfbbff..33c0e3858 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ bootstrap: pip install -U wheel pip install -U -e . - pip install -U -r pipelines/requirements/lint.txt + pip install -U -r pipelines/requirements/all.txt python -m prisma_cleanup prisma db push --schema=tests/data/schema.prisma cp tests/data/dev.db dev.db diff --git a/databases/constants.py b/databases/constants.py index b99e04bde..ece6e4733 100644 --- a/databases/constants.py +++ b/databases/constants.py @@ -38,6 +38,7 @@ def _fromdir(path: str) -> list[str]: unsupported_features={ 'json_arrays', 'array_push', + 'transactions', }, ), 'sqlite': DatabaseConfig( @@ -102,6 +103,7 @@ def _fromdir(path: str) -> list[str]: 'arrays': [*_fromdir('arrays'), *_fromdir('types/raw_queries/arrays')], 'array_push': _fromdir('arrays/push'), 'json_arrays': ['arrays/test_json.py', 'arrays/push/test_json.py'], + 'transactions': ['test_transactions.py'], # not yet implemented 'date': [], 'create_many': ['test_create_many.py'], diff --git a/databases/requirements.txt b/databases/requirements.txt index 7c20d1826..546b9cef3 100644 --- a/databases/requirements.txt +++ b/databases/requirements.txt @@ -5,5 +5,6 @@ distro -r ../pipelines/requirements/deps/pyright.txt -r ../pipelines/requirements/deps/pytest.txt +-r ../pipelines/requirements/deps/pytest-mock.txt -r ../pipelines/requirements/deps/pytest-asyncio.txt -r ../pipelines/requirements/deps/syrupy.txt diff --git a/databases/sync_tests/test_transactions.py b/databases/sync_tests/test_transactions.py new file mode 100644 index 000000000..c2cb5e0d4 --- /dev/null +++ b/databases/sync_tests/test_transactions.py @@ -0,0 +1,177 @@ +import time +from typing import Optional + +import pytest + +import prisma +from prisma import Prisma +from prisma.models import User +from ..utils import CURRENT_DATABASE + + +def test_context_manager(client: Prisma) -> None: + """Basic usage within a context manager""" + with client.tx(timeout=10 * 100) as transaction: + user = transaction.user.create({'name': 'Robert'}) + assert user.name == 'Robert' + + # ensure not commited outside transaction + assert client.user.count() == 0 + + transaction.profile.create( + { + 'description': 'Hello, there!', + 'country': 'Scotland', + 'user': { + 'connect': { + 'id': user.id, + }, + }, + }, + ) + + found = client.user.find_unique( + where={'id': user.id}, include={'profile': True} + ) + assert found is not None + assert found.name == 'Robert' + assert found.profile is not None + assert found.profile.description == 'Hello, there!' + + +def test_context_manager_auto_rollback(client: Prisma) -> None: + """An error being thrown when within a context manager causes the transaction to be rolled back""" + user: Optional[User] = None + + with pytest.raises(RuntimeError) as exc: + with client.tx() as tx: + user = tx.user.create({'name': 'Tegan'}) + raise RuntimeError('Error ocurred mid transaction.') + + assert exc.match('Error ocurred mid transaction.') + + assert user is not None + found = client.user.find_unique(where={'id': user.id}) + assert found is None + + +def test_batch_within_transaction(client: Prisma) -> None: + """Query batching can be used within transactions""" + with client.tx(timeout=10000) as transaction: + with transaction.batch_() as batcher: + batcher.user.create({'name': 'Tegan'}) + batcher.user.create({'name': 'Robert'}) + + assert client.user.count() == 0 + assert transaction.user.count() == 2 + + assert client.user.count() == 2 + + +def test_timeout(client: Prisma) -> None: + """A `TransactionExpiredError` is raised when the transaction times out.""" + # this outer block is necessary becuse to the context manager it appears that no error + # ocurred so it will attempt to commit the transaction, triggering the expired error again + with pytest.raises(prisma.errors.TransactionExpiredError): + with client.tx(timeout=50) as transaction: + time.sleep(0.05) + + with pytest.raises(prisma.errors.TransactionExpiredError) as exc: + transaction.user.create({'name': 'Robert'}) + + raise exc.value + + +@pytest.mark.skipif( + CURRENT_DATABASE == 'sqlite', reason='This is currently broken...' +) +def test_concurrent_transactions(client: Prisma) -> None: + """Two separate transactions can be used independently of each other at the same time""" + timeout = 15000 + with client.tx(timeout=timeout) as tx1, client.tx(timeout=timeout) as tx2: + user1 = tx1.user.create({'name': 'Tegan'}) + user2 = tx2.user.create({'name': 'Robert'}) + + assert tx1.user.find_first(where={'name': 'Robert'}) is None + assert tx2.user.find_first(where={'name': 'Tegan'}) is None + + found = tx1.user.find_first(where={'name': 'Tegan'}) + assert found is not None + assert found.id == user1.id + + found = tx2.user.find_first(where={'name': 'Robert'}) + assert found is not None + assert found.id == user2.id + + # ensure not leaked + assert client.user.count() == 0 + assert (tx1.user.find_first(where={'name': user2.name})) is None + assert (tx2.user.find_first(where={'name': user1.name})) is None + + assert client.user.count() == 2 + + +def test_transaction_raises_original_error(client: Prisma) -> None: + """If an error is raised during the execution of the transaction, it is raised""" + with pytest.raises(RuntimeError, match=r'Test error!'): + with client.tx(): + raise RuntimeError('Test error!') + + +def test_transaction_within_transaction_warning(client: Prisma) -> None: + """A warning is raised if a transaction is started from another transaction client""" + tx1 = client.tx().start() + with pytest.warns(UserWarning) as warnings: + tx1.tx().start() + + assert len(warnings) == 1 + record = warnings[0] + assert not isinstance(record.message, str) + assert ( + record.message.args[0] + == 'The current client is already in a transaction. This can lead to surprising behaviour.' + ) + assert record.filename == __file__ + + +def test_transaction_within_transaction_context_warning( + client: Prisma, +) -> None: + """A warning is raised if a transaction is started from another transaction client""" + with client.tx() as tx1: + with pytest.warns(UserWarning) as warnings: + with tx1.tx(): + pass + + assert len(warnings) == 1 + record = warnings[0] + assert not isinstance(record.message, str) + assert ( + record.message.args[0] + == 'The current client is already in a transaction. This can lead to surprising behaviour.' + ) + assert record.filename == __file__ + + +def test_transaction_not_started(client: Prisma) -> None: + """A `TransactionNotStartedError` is raised when attempting to call `commit()` or `rollback()` + on a transaction that hasn't been started yet. + """ + tx = client.tx() + + with pytest.raises(prisma.errors.TransactionNotStartedError): + tx.commit() + + with pytest.raises(prisma.errors.TransactionNotStartedError): + tx.rollback() + + +def test_transaction_already_closed(client: Prisma) -> None: + """Attempting to use a transaction outside of the context block raises an error""" + with client.tx() as transaction: + pass + + with pytest.raises(prisma.errors.TransactionExpiredError) as exc: + transaction.user.delete_many() + + assert exc.match('Transaction already closed') diff --git a/databases/tests/test_transactions.py b/databases/tests/test_transactions.py new file mode 100644 index 000000000..c6cbc6b43 --- /dev/null +++ b/databases/tests/test_transactions.py @@ -0,0 +1,189 @@ +import asyncio +from typing import Optional + +import pytest + +import prisma +from prisma import Prisma +from prisma.models import User +from ..utils import CURRENT_DATABASE + + +@pytest.mark.asyncio +async def test_context_manager(client: Prisma) -> None: + """Basic usage within a context manager""" + async with client.tx(timeout=10 * 100) as transaction: + user = await transaction.user.create({'name': 'Robert'}) + assert user.name == 'Robert' + + # ensure not commited outside transaction + assert await client.user.count() == 0 + + await transaction.profile.create( + { + 'description': 'Hello, there!', + 'country': 'Scotland', + 'user': { + 'connect': { + 'id': user.id, + }, + }, + }, + ) + + found = await client.user.find_unique( + where={'id': user.id}, include={'profile': True} + ) + assert found is not None + assert found.name == 'Robert' + assert found.profile is not None + assert found.profile.description == 'Hello, there!' + + +@pytest.mark.asyncio +async def test_context_manager_auto_rollback(client: Prisma) -> None: + """An error being thrown when within a context manager causes the transaction to be rolled back""" + user: Optional[User] = None + + with pytest.raises(RuntimeError) as exc: + async with client.tx() as tx: + user = await tx.user.create({'name': 'Tegan'}) + raise RuntimeError('Error ocurred mid transaction.') + + assert exc.match('Error ocurred mid transaction.') + + assert user is not None + found = await client.user.find_unique(where={'id': user.id}) + assert found is None + + +@pytest.mark.asyncio +async def test_batch_within_transaction(client: Prisma) -> None: + """Query batching can be used within transactions""" + async with client.tx(timeout=10000) as transaction: + async with transaction.batch_() as batcher: + batcher.user.create({'name': 'Tegan'}) + batcher.user.create({'name': 'Robert'}) + + assert await client.user.count() == 0 + assert await transaction.user.count() == 2 + + assert await client.user.count() == 2 + + +@pytest.mark.asyncio +async def test_timeout(client: Prisma) -> None: + """A `TransactionExpiredError` is raised when the transaction times out.""" + # this outer block is necessary becuse to the context manager it appears that no error + # ocurred so it will attempt to commit the transaction, triggering the expired error again + with pytest.raises(prisma.errors.TransactionExpiredError): + async with client.tx(timeout=50) as transaction: + await asyncio.sleep(0.05) + + with pytest.raises(prisma.errors.TransactionExpiredError) as exc: + await transaction.user.create({'name': 'Robert'}) + + raise exc.value + + +@pytest.mark.asyncio +@pytest.mark.skipif( + CURRENT_DATABASE == 'sqlite', reason='This is currently broken...' +) +async def test_concurrent_transactions(client: Prisma) -> None: + """Two separate transactions can be used independently of each other at the same time""" + timeout = 15000 + async with client.tx(timeout=timeout) as tx1, client.tx( + timeout=timeout + ) as tx2: + user1 = await tx1.user.create({'name': 'Tegan'}) + user2 = await tx2.user.create({'name': 'Robert'}) + + assert await tx1.user.find_first(where={'name': 'Robert'}) is None + assert await tx2.user.find_first(where={'name': 'Tegan'}) is None + + found = await tx1.user.find_first(where={'name': 'Tegan'}) + assert found is not None + assert found.id == user1.id + + found = await tx2.user.find_first(where={'name': 'Robert'}) + assert found is not None + assert found.id == user2.id + + # ensure not leaked + assert await client.user.count() == 0 + assert (await tx1.user.find_first(where={'name': user2.name})) is None + assert (await tx2.user.find_first(where={'name': user1.name})) is None + + assert await client.user.count() == 2 + + +@pytest.mark.asyncio +async def test_transaction_raises_original_error(client: Prisma) -> None: + """If an error is raised during the execution of the transaction, it is raised""" + with pytest.raises(RuntimeError, match=r'Test error!'): + async with client.tx(): + raise RuntimeError('Test error!') + + +@pytest.mark.asyncio +async def test_transaction_within_transaction_warning(client: Prisma) -> None: + """A warning is raised if a transaction is started from another transaction client""" + tx1 = await client.tx().start() + with pytest.warns(UserWarning) as warnings: + await tx1.tx().start() + + assert len(warnings) == 1 + record = warnings[0] + assert not isinstance(record.message, str) + assert ( + record.message.args[0] + == 'The current client is already in a transaction. This can lead to surprising behaviour.' + ) + assert record.filename == __file__ + + +@pytest.mark.asyncio +async def test_transaction_within_transaction_context_warning( + client: Prisma, +) -> None: + """A warning is raised if a transaction is started from another transaction client""" + async with client.tx() as tx1: + with pytest.warns(UserWarning) as warnings: + async with tx1.tx(): + pass + + assert len(warnings) == 1 + record = warnings[0] + assert not isinstance(record.message, str) + assert ( + record.message.args[0] + == 'The current client is already in a transaction. This can lead to surprising behaviour.' + ) + assert record.filename == __file__ + + +@pytest.mark.asyncio +async def test_transaction_not_started(client: Prisma) -> None: + """A `TransactionNotStartedError` is raised when attempting to call `commit()` or `rollback()` + on a transaction that hasn't been started yet. + """ + tx = client.tx() + + with pytest.raises(prisma.errors.TransactionNotStartedError): + await tx.commit() + + with pytest.raises(prisma.errors.TransactionNotStartedError): + await tx.rollback() + + +@pytest.mark.asyncio +async def test_transaction_already_closed(client: Prisma) -> None: + """Attempting to use a transaction outside of the context block raises an error""" + async with client.tx() as transaction: + pass + + with pytest.raises(prisma.errors.TransactionExpiredError) as exc: + await transaction.user.delete_many() + + assert exc.match('Transaction already closed') diff --git a/databases/utils.py b/databases/utils.py index 55e203150..4f4e4d351 100644 --- a/databases/utils.py +++ b/databases/utils.py @@ -21,6 +21,7 @@ 'json_arrays', 'raw_queries', 'create_many', + 'transactions', 'case_sensitivity', ] diff --git a/docs/reference/transactions.md b/docs/reference/transactions.md new file mode 100644 index 000000000..62a6a3cd0 --- /dev/null +++ b/docs/reference/transactions.md @@ -0,0 +1,123 @@ +# Interactive Transactions + +!!! warning + + Transactions are not fully tested against CockroachDB. + +Prisma Client Python supports interactive transactions, this is a generic solution allowing you to run multiple operations as a single, atomic operation - if any operation fails, or any other error is raised during the transaction, Prisma will roll back the entire transaction. + +This differs from [batch queries](./batching.md) as you can perform operations that are dependent on the results of previous operations. + +## Usage + +Transactions can be created using the `Prisma.tx()` method which returns a context manager that when entered returns a separate instance of +`Prisma` that wraps all queries in a transaction. + +=== "async" + + ```python + from prisma import Prisma + + prisma = Prisma() + await prisma.connect() + + async with prisma.tx() as transaction: + user = await transaction.user.update( + where={'id': from_user_id}, + data={'balance': {'decrement': 50}} + ) + if user.balance < 0: + raise ValueError(f'{user.name} does not have enough balance') + + await transaction.user.update( + where={'id': to_user_id}, + data={'balance': {'increment': 50}} + ) + ``` + +=== "sync" + + ```python + from prisma import Prisma + + prisma = Prisma() + prisma.connect() + + with prisma.tx() as transaction: + user = transaction.user.update( + where={'id': from_user_id}, + data={'balance': {'decrement': 50}} + ) + if user.balance < 0: + raise ValueError(f'{user.name} does not have enough balance') + + transaction.user.update( + where={'id': to_user_id}, + data={'balance': {'increment': 50}} + ) + ``` + +In this example, if the `ValueError` is raised then the entire transaction is rolled-back. This means that the first `update` call is reversed. + +In the case that this example runs successfully, then both database writes are committed when the context manager exits, meaning that queries running elsewhere in your application will then access the updated data. + +## Usage with Model Queries + +!!! warning + + Transactions support alongside [model based queries](./model-actions.md) is not stable. + + Do **not** rely on `Model.prisma()` always using the default `Prisma` instance. + This may be changed in the future. + + +=== "async" + + ```python + from prisma import Prisma + from prisma.models import User + + prisma = Prisma(auto_register=True) + await prisma.connect() + + async with prisma.tx() as transaction: + user = await User.prisma(transaction).update( + where={'id': from_user_id}, + data={'balance': {'decrement': 50}} + ) + if user.balance < 0: + raise ValueError(f'{user.name} does not have enough balance') + + user = await User.prisma(transaction).update( + where={'id': to_user_id}, + data={'balance': {'increment': 50}} + ) + ``` + +=== "sync" + + ```python + prisma = Prisma() + prisma.connect() + + with prisma.tx() as transaction: + user = User.prisma(transaction).update( + where={'id': from_user_id}, + data={'balance': {'decrement': 50}} + ) + if user.balance < 0: + raise ValueError(f'{user.name} does not have enough balance') + + user = User.prisma(transaction).update( + where={'id': to_user_id}, + data={'balance': {'increment': 50}} + ) + ``` + +## Timeouts + +You can pass the following options to configure how timeouts are applied to your transaction: + +`max_wait` - The maximum amount of time Prisma will wait to acquire a transaction from the database. This defaults to `2 seconds`. + +`timeout` - The maximum amount of time the transaction can run before being cancelled and rolled back. This defaults to `5 seconds`. diff --git a/lib/testing/shared_conftest/async_client.py b/lib/testing/shared_conftest/async_client.py index 764c132ba..5b574e645 100644 --- a/lib/testing/shared_conftest/async_client.py +++ b/lib/testing/shared_conftest/async_client.py @@ -1,5 +1,5 @@ import inspect -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, AsyncIterator import prisma from prisma import Prisma @@ -11,6 +11,15 @@ from _pytest.fixtures import FixtureRequest +@async_fixture(name='_cleanup_session', scope='session', autouse=True) +async def cleanup_session() -> AsyncIterator[None]: + yield + + client = prisma.get_client() + if client.is_connected(): + await client.disconnect() + + @async_fixture(name='client', scope='session') async def client_fixture() -> Prisma: client = prisma.get_client() diff --git a/lib/testing/shared_conftest/sync_client.py b/lib/testing/shared_conftest/sync_client.py index b13af722a..de512f973 100644 --- a/lib/testing/shared_conftest/sync_client.py +++ b/lib/testing/shared_conftest/sync_client.py @@ -1,5 +1,5 @@ import inspect -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterator import pytest @@ -12,6 +12,15 @@ from _pytest.fixtures import FixtureRequest +@pytest.fixture(name='_cleanup_session', scope='session', autouse=True) +def cleanup_session() -> Iterator[None]: + yield + + client = prisma.get_client() + if client.is_connected(): + client.disconnect() + + @pytest.fixture(name='client', scope='session') def client_fixture() -> Prisma: client = prisma.get_client() diff --git a/mkdocs.yml b/mkdocs.yml index 21f92a814..65e9a67de 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -37,7 +37,8 @@ markdown_extensions: - pymdownx.highlight: use_pygments: false - pymdownx.extra - - pymdownx.tabbed + - pymdownx.tabbed: + alternate_style: true - pymdownx.details - pymdownx.snippets - pymdownx.superfences @@ -58,6 +59,7 @@ nav: - Reference: - Query Operations: reference/operations.md - Client: reference/client.md + - Transactions: reference/transactions.md - Selecting Fields: reference/selecting-fields.md - Model Based Access: reference/model-actions.md - Batching Queries: reference/batching.md diff --git a/pipelines/requirements/all.txt b/pipelines/requirements/all.txt new file mode 100644 index 000000000..58a6c6841 --- /dev/null +++ b/pipelines/requirements/all.txt @@ -0,0 +1,2 @@ +-r lint.txt +-r docs.txt diff --git a/pipelines/requirements/deps/pytest-mock.txt b/pipelines/requirements/deps/pytest-mock.txt new file mode 100644 index 000000000..eafced0c0 --- /dev/null +++ b/pipelines/requirements/deps/pytest-mock.txt @@ -0,0 +1 @@ +pytest-mock==3.10.0 diff --git a/pipelines/requirements/test.txt b/pipelines/requirements/test.txt index 154d545c5..1c0a47e9c 100644 --- a/pipelines/requirements/test.txt +++ b/pipelines/requirements/test.txt @@ -1,6 +1,7 @@ -r coverage.txt -r deps/pytest-asyncio.txt -r deps/pytest.txt +-r deps/pytest-mock.txt -r deps/syrupy.txt pytest-sugar mock==5.0.2 diff --git a/requirements/base.txt b/requirements/base.txt index 961004dbb..3b014ce6f 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -3,7 +3,7 @@ jinja2>=2.11.2 pydantic>=1.8.0 click>=7.1.2 python-dotenv>=0.12.0 -typing-extensions>=3.7 +typing-extensions>=4.0.1 tomlkit nodeenv cached-property; python_version < '3.8' diff --git a/src/prisma/engine/__init__.py b/src/prisma/engine/__init__.py index aed77288f..a58ec3657 100644 --- a/src/prisma/engine/__init__.py +++ b/src/prisma/engine/__init__.py @@ -1,4 +1,5 @@ from .errors import * +from ._types import TransactionId as TransactionId try: from .query import * diff --git a/src/prisma/engine/_types.py b/src/prisma/engine/_types.py new file mode 100644 index 000000000..68ec19f56 --- /dev/null +++ b/src/prisma/engine/_types.py @@ -0,0 +1,3 @@ +from typing_extensions import NewType + +TransactionId = NewType('TransactionId', str) diff --git a/src/prisma/engine/utils.py b/src/prisma/engine/utils.py index c12228939..6efd39a40 100644 --- a/src/prisma/engine/utils.py +++ b/src/prisma/engine/utils.py @@ -155,6 +155,7 @@ def get_open_port() -> int: def handle_response_errors(resp: AbstractResponse[Any], data: Any) -> NoReturn: for error in data: try: + base_error_message = error.get('error', '') user_facing = error.get('user_facing_error', {}) code = user_facing.get('error_code') if code is None: @@ -167,6 +168,14 @@ def handle_response_errors(resp: AbstractResponse[Any], data: Any) -> NoReturn: # As we can only check for this error by searching the message then this # comes with performance concerns. message = user_facing.get('message', '') + + if code == 'P2028': + if base_error_message.startswith('Transaction already closed'): + raise prisma_errors.TransactionExpiredError( + base_error_message + ) + raise prisma_errors.TransactionError(message) + if 'A value is required but not set' in message: raise prisma_errors.MissingRequiredValueError(error) diff --git a/src/prisma/errors.py b/src/prisma/errors.py index fd064634f..5e2bff96d 100644 --- a/src/prisma/errors.py +++ b/src/prisma/errors.py @@ -131,6 +131,22 @@ class InputError(DataError): pass +class TransactionError(PrismaError): + pass + + +class TransactionExpiredError(TransactionError): + pass + + +class TransactionNotStartedError(TransactionError): + def __init__(self) -> None: + super().__init__( + 'Transaction has not been started yet.\n' + 'Transactions must be used within a context manager or started manually.' + ) + + class BuilderError(PrismaError): pass diff --git a/src/prisma/generator/templates/_utils.py.jinja b/src/prisma/generator/templates/_utils.py.jinja index 2137ec7aa..79f37932a 100644 --- a/src/prisma/generator/templates/_utils.py.jinja +++ b/src/prisma/generator/templates/_utils.py.jinja @@ -3,9 +3,11 @@ {% if is_async %} {% set maybe_await = 'await ' %} + {% set maybe_async = 'async ' %} {% set maybe_async_def = 'async def ' %} {% else %} {% set maybe_await = '' %} + {% set maybe_async = '' %} {% set maybe_async_def = 'def ' %} {% endif %} diff --git a/src/prisma/generator/templates/client.py.jinja b/src/prisma/generator/templates/client.py.jinja index 84ba503e9..5caa471ba 100644 --- a/src/prisma/generator/templates/client.py.jinja +++ b/src/prisma/generator/templates/client.py.jinja @@ -1,7 +1,9 @@ {% set annotations = true %} {% include '_header.py.jinja' %} -{% from '_utils.py.jinja' import is_async, maybe_async_def, maybe_await, recursive_types, active_provider with context %} +{% from '_utils.py.jinja' import is_async, maybe_async_def, maybe_await, maybe_async, recursive_types, active_provider with context %} # -- template client.py.jinja -- +import warnings +import logging from pathlib import Path from types import TracebackType @@ -11,8 +13,8 @@ from . import types, models, errors, actions from .types import DatasourceOverride, HttpConfig from ._types import BaseModelT, PrismaMethod from .bases import _PrismaModel -from .engine import AbstractEngine, QueryEngine -from .builder import QueryBuilder +from .engine import AbstractEngine, QueryEngine, TransactionId +from .builder import QueryBuilder, dumps from .generator.models import EngineType, OptionalValueFromEnvVar, BinaryPaths from ._compat import removeprefix from ._raw_query import deserialize_raw_results @@ -29,6 +31,8 @@ __all__ = ( 'get_client', ) +log: logging.Logger = logging.getLogger(__name__) + SCHEMA_PATH = Path('{{ schema_path.as_posix() }}') PACKAGED_SCHEMA_PATH = Path(__file__).parent.joinpath('schema.prisma') ENGINE_TYPE: EngineType = EngineType.{{ generator.config.engine_type }} @@ -125,11 +129,13 @@ class Prisma: '{{ model.name.lower() }}', {% endfor %} '__engine', - '_active_provider', - '_log_queries', + '__copied', + '_tx_id', '_datasource', - '_connect_timeout', + '_log_queries', '_http_config', + '_connect_timeout', + '_active_provider', ) def __init__( @@ -145,12 +151,17 @@ class Prisma: {% for model in dmmf.datamodel.models %} self.{{ model.name.lower() }} = actions.{{ model.name }}Actions[models.{{ model.name }}](self, models.{{ model.name }}) {% endfor %} + + # NOTE: if you add any more properties here then you may also need to forward + # them in the `_copy()` method. self.__engine: Optional[AbstractEngine] = None self._active_provider = '{{ active_provider }}' self._log_queries = log_queries self._datasource = datasource self._connect_timeout = connect_timeout self._http_config: HttpConfig = http or {} + self._tx_id: Optional[TransactionId] = None + self.__copied: bool = False if use_dotenv: load_env() @@ -159,10 +170,15 @@ class Prisma: register(self) def __del__(self) -> None: - if self.__engine is not None: - self.__engine.stop() + # Note: as the transaction manager holds a reference to the original + # client as well as the transaction client the original client cannot + # be `free`d before the transaction is finished. So stopping the engine + # here should be safe. + if self.__engine is not None and not self.__copied: + log.debug('unclosed client - stopping engine') + engine = self.__engine self.__engine = None - + engine.stop() {% if is_async %} async def __aenter__(self) -> 'Prisma': @@ -230,13 +246,14 @@ class Prisma: {{ maybe_async_def }}disconnect(self, timeout: Optional[float] = None) -> None: """Disconnect the Prisma query engine.""" if self.__engine is not None: + engine = self.__engine + self.__engine = None {% if is_async %} - {{ maybe_await }}self.__engine.aclose(timeout=timeout) + await engine.aclose(timeout=timeout) {% else %} - self.__engine.close(timeout=timeout) + engine.close(timeout=timeout) {% endif %} - self.__engine.stop(timeout=timeout) - self.__engine = None + engine.stop(timeout=timeout) {% if active_provider != 'mongodb' %} {{ maybe_async_def }}execute_raw(self, query: LiteralString, *args: Any) -> int: @@ -336,6 +353,35 @@ class Prisma: """Returns a context manager for grouping write queries into a single transaction.""" return Batch(client=self) + def tx(self, *, max_wait: int = 2000, timeout: int = 5000) -> 'TransactionManager': + """Returns a context manager for executing queries within a database transaction. + + Entering the context manager returns a new Prisma instance wrapping all + actions within a transaction, queries will be isolated to the Prisma instance and + will not be commited to the database until the context manager exits. + + By default, Prisma will wait a maximum of 2 seconds to acquire a transaction from the database. You can modify this + defualt with the `max_wait` argument which accepts a value in milliseconds. + + By default, Prisma will cancel and rollback ay transactions that last longer than 5 seconds. You can modify this timeout + with the `timeout` argument which accepts a value in milliseconds. + + Example usage: + + ```py + {{ maybe_async }}with client.tx() as transaction: + user1 = {{ maybe_await }}client.user.create({'name': 'Robert'}) + user2 = {{ maybe_await }}client.user.create({'name': 'Tegan'}) + ``` + + In the above example, if the first database call succeeds but the second does not then neither of the records will be created. + """ + return TransactionManager(client=self, max_wait=max_wait, timeout=timeout) + + def is_transaction(self) -> bool: + """Returns True if the client is wrapped within a transaction""" + return self._tx_id is not None + # TODO: don't return Any {{ maybe_async_def }}_execute( self, @@ -350,7 +396,26 @@ class Prisma: arguments=arguments, root_selection=root_selection, ) - return {{ maybe_await }}self._engine.query(builder.build()) + return {{ maybe_await }}self._engine.query(builder.build(), tx_id=self._tx_id) + + def _copy(self) -> 'Prisma': + """Return a new Prisma instance using the same engine process (if connected). + + This is only intended for private usage, there are no guarantees around this API. + """ + new = Prisma( + use_dotenv=False, + http=self._http_config, + datasource=self._datasource, + log_queries=self._log_queries, + connect_timeout=self._connect_timeout, + ) + new.__copied = True + + if self.__engine is not None: + new._engine = self.__engine + + return new def _create_engine(self, dml_path: Path = PACKAGED_SCHEMA_PATH) -> AbstractEngine: if ENGINE_TYPE == EngineType.binary: @@ -372,6 +437,10 @@ class Prisma: raise errors.ClientNotConnectedError() return engine + @_engine.setter + def _engine(self, engine: AbstractEngine) -> None: + self.__engine = engine + def _make_sqlite_datasource(self) -> DatasourceOverride: return { 'name': '{{ datasources[0].name }}', @@ -396,6 +465,108 @@ class Prisma: } +class TransactionManager: + """Context manager for wrapping a Prisma instance within a transaction. + + This should never be created manually, instead it should be used + through the Prisma.tx() method. + """ + + def __init__(self, *, client: Prisma, max_wait: int, timeout: int) -> None: + self.__client = client + self._max_wait = max_wait + self._timeout = timeout + self._tx_id: Optional[TransactionId] = None + + {{ maybe_async_def }}start(self, *, _from_context: bool = False) -> Prisma: + """Start the transaction and return the wrapped Prisma instance""" + if self.__client.is_transaction(): + # if we were called from the context manager then the stacklevel + # needs to be one higher to warn on the actual offending code + warnings.warn( + 'The current client is already in a transaction. This can lead to surprising behaviour.', + UserWarning, + stacklevel=3 if _from_context else 2 + ) + + tx_id = {{ maybe_await }}self.__client._engine.start_transaction( + content=dumps( + { + 'timeout': self._timeout, + 'max_wait': self._max_wait, + } + ), + ) + self._tx_id = tx_id + client = self.__client._copy() + client._tx_id = tx_id + return client + + {{ maybe_async_def }}commit(self) -> None: + """Commit the transaction to the database, this transaction will no longer be usable""" + if self._tx_id is None: + raise errors.TransactionNotStartedError() + + {{ maybe_await }}self.__client._engine.commit_transaction(self._tx_id) + + {{ maybe_async_def }}rollback(self) -> None: + """Do not commit the changes to the database, this transaction will no longer be usable""" + if self._tx_id is None: + raise errors.TransactionNotStartedError() + + {{ maybe_await }}self.__client._engine.rollback_transaction(self._tx_id) + + {% if is_async %} + async def __aenter__(self) -> Prisma: + return await self.start(_from_context=True) + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if exc is None: + log.debug('Transaction exited with no exception - commiting') + await self.commit() + return + + log.debug('Transaction exited with exc type: %s - rolling back', exc_type) + + try: + await self.rollback() + except Exception as exc: + log.warning( + 'Encountered exc `%s` while rolling back a transaction. Ignoring and raising original exception', + exc + ) + {% else %} + def __enter__(self) -> Prisma: + return self.start(_from_context=True) + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if exc is None: + log.debug('Transaction exited with no exception - commiting') + self.commit() + return + + log.debug('Transaction exited with exc type: %s - rolling back', exc_type) + + try: + self.rollback() + except Exception as exc: + log.warning( + 'Encountered exc `%s` while rolling back a transaction. Ignoring and raising original exception', + exc + ) + {% endif %} + + # TODO: this should return the results as well # TODO: don't require copy-pasting arguments between actions and batch actions class Batch: @@ -418,8 +589,6 @@ class Batch: {{ maybe_async_def }}commit(self) -> None: """Execute the queries""" # TODO: normalise this, we should still call client._execute - from .builder import dumps - queries = self.__queries self.__queries = [] @@ -433,7 +602,10 @@ class Batch: ], 'transaction': True, } - {{ maybe_await }}self.__client._engine.query(dumps(payload)) + {{ maybe_await }}self.__client._engine.query( + dumps(payload), + tx_id=self.__client._tx_id, + ) {% if active_provider != 'mongodb' %} def execute_raw(self, query: LiteralString, *args: Any) -> None: diff --git a/src/prisma/generator/templates/engine/abstract.py.jinja b/src/prisma/generator/templates/engine/abstract.py.jinja index b5ea50187..22dd5c878 100644 --- a/src/prisma/generator/templates/engine/abstract.py.jinja +++ b/src/prisma/generator/templates/engine/abstract.py.jinja @@ -1,14 +1,18 @@ +{% set annotations = true %} {% include '_header.py.jinja' %} {% from '_utils.py.jinja' import maybe_async_def with context %} # -- template engine/abstract.py.jinja -- from abc import ABC, abstractmethod +from ._types import TransactionId from ..types import DatasourceOverride from .._compat import get_running_loop + __all__ = ( 'AbstractEngine', ) + class AbstractEngine(ABC): dml: str @@ -54,7 +58,7 @@ class AbstractEngine(ABC): ... @abstractmethod - {{ maybe_async_def }}query(self, content: str) -> Any: + {{ maybe_async_def }}query(self, content: str, *, tx_id: TransactionId | None) -> Any: """Execute a GraphQL query. This method expects a JSON object matching this structure: @@ -66,3 +70,18 @@ class AbstractEngine(ABC): } """ ... + + @abstractmethod + {{ maybe_async_def }}start_transaction(self, *, content: str) -> TransactionId: + """Start an interactive transaction, returns the transaction ID that can be used to perform subsequent operations""" + ... + + @abstractmethod + {{ maybe_async_def }}commit_transaction(self, tx_id: TransactionId) -> None: + """Commit an interactive transaction, the given transaction will no longer be usable""" + ... + + @abstractmethod + {{ maybe_async_def }}rollback_transaction(self, tx_id: TransactionId) -> None: + """Rollback an interactive transaction, the given transaction will no longer be usable""" + ... diff --git a/src/prisma/generator/templates/engine/http.py.jinja b/src/prisma/generator/templates/engine/http.py.jinja index 75dd179de..7ed9aef88 100644 --- a/src/prisma/generator/templates/engine/http.py.jinja +++ b/src/prisma/generator/templates/engine/http.py.jinja @@ -29,13 +29,11 @@ class HTTPEngine(AbstractEngine): headers: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> None: + super().__init__() self.url = url self.session = HTTP(**kwargs) self.headers = headers if headers is not None else {} - def __del__(self) -> None: - self.stop() - {% if is_async %} def close(self, *, timeout: Optional[float] = None) -> None: pass @@ -54,7 +52,14 @@ class HTTPEngine(AbstractEngine): if self.session and not self.session.closed: {{ maybe_await }}self.session.close() - {{ maybe_async_def }}request(self, method: Method, path: str, *, content: Any = None) -> Any: + {{ maybe_async_def }}request( + self, + method: Method, + path: str, + *, + content: Any = None, + headers: Optional[Dict[str, str]] = None, + ) -> Any: if self.url is None: raise errors.NotConnectedError('Not connected to the query engine') @@ -66,13 +71,19 @@ class HTTPEngine(AbstractEngine): } } + if headers is not None: + kwargs['headers'].update(headers) + if content is not None: kwargs['content'] = content url = self.url + path - log.debug('Sending %s request to %s with content: %s', method, url, content) + log.debug('Sending %s request to %s', method, url) + log.debug('Request headers: %s', kwargs['headers']) + log.debug('Request content: %s', content) resp = {{ maybe_await }}self.session.request(method, url, **kwargs) + log.debug('%s %s returned status %s', method, url, resp.status) if 300 > resp.status >= 200: response = {{ maybe_await }}resp.json() diff --git a/src/prisma/generator/templates/engine/query.py.jinja b/src/prisma/generator/templates/engine/query.py.jinja index 5f9d4f801..707c80f1b 100644 --- a/src/prisma/generator/templates/engine/query.py.jinja +++ b/src/prisma/generator/templates/engine/query.py.jinja @@ -20,6 +20,7 @@ from ..binaries import platform from ..utils import time_since, _env_bool from ..types import DatasourceOverride from ..builder import dumps +from ._types import TransactionId __all__ = ('QueryEngine',) @@ -43,9 +44,6 @@ class QueryEngine(HTTPEngine): # ensure the query engine process is terminated when we are atexit.register(self.stop) - def __del__(self) -> None: - self.stop() - def close(self, *, timeout: Optional[float] = None) -> None: log.debug('Disconnecting query engine...') @@ -178,9 +176,40 @@ class QueryEngine(HTTPEngine): 'Could not connect to the query engine' ) from last_exc - {{ maybe_async_def }}query(self, content: str) -> Any: - return {{ maybe_await }}self.request('POST', '/', content=content) + {{ maybe_async_def }}query( + self, + content: str, + *, + tx_id: TransactionId | None, + ) -> Any: + headers: Dict[str, str] = {} + if tx_id is not None: + headers['X-transaction-id'] = tx_id + + return {{ maybe_await }}self.request( + 'POST', + '/', + content=content, + headers=headers, + ) + + {{ maybe_async_def }}start_transaction(self, *, content: str) -> TransactionId: + result = {{ maybe_await }}self.request( + 'POST', + '/transaction/start', + content=content, + ) + return TransactionId(result['id']) + + {{ maybe_async_def }}commit_transaction(self, tx_id: TransactionId) -> None: + {{ maybe_await }}self.request( + 'POST', f'/transaction/{tx_id}/commit' + ) + {{ maybe_async_def }}rollback_transaction(self, tx_id: TransactionId) -> None: + {{ maybe_await }}self.request( + 'POST', f'/transaction/{tx_id}/rollback' + ) # black does not respect the fmt: off comment without this # fmt: on diff --git a/tests/test_client.py b/tests/test_client.py index d87797648..bd4bdb85f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -204,3 +204,42 @@ def test_sqlite_url(client: Prisma) -> None: url = client._make_sqlite_url('sqlite:sqlite.db') assert url == f'file:{SCHEMA_PATH.parent.joinpath("sqlite.db")}' + + +@pytest.mark.asyncio +async def test_copy() -> None: + """The Prisma._copy() method forwards all relevant properties""" + client1 = Prisma( + log_queries=True, + datasource={ + 'url': 'file:foo.db', + }, + connect_timeout=15, + http={ + 'trust_env': False, + }, + ) + client2 = client1._copy() + assert not client2.is_connected() is None + assert client2._log_queries is True + assert client2._datasource == {'url': 'file:foo.db'} + assert client2._connect_timeout == 15 + assert client2._http_config == {'trust_env': False} + + await client1.connect() + assert client1.is_connected() + client3 = client1._copy() + assert client3.is_connected() + + +@pytest.mark.asyncio +async def test_copied_client_does_not_close_engine(client: Prisma) -> None: + """Deleting a Prisma._copy()'d client does not cause the engine to be stopped""" + copied = client._copy() + assert copied.is_connected() + assert client.is_connected() + + del copied + + assert client.is_connected() + await client.user.count() # ensure queries can still be executed diff --git a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[client.py].raw b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[client.py].raw index 4f814c626..cc9f8e0e2 100644 --- a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[client.py].raw +++ b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[client.py].raw @@ -40,6 +40,8 @@ from typing_extensions import TypedDict, Literal LiteralString = str # -- template client.py.jinja -- +import warnings +import logging from pathlib import Path from types import TracebackType @@ -49,8 +51,8 @@ from . import types, models, errors, actions from .types import DatasourceOverride, HttpConfig from ._types import BaseModelT, PrismaMethod from .bases import _PrismaModel -from .engine import AbstractEngine, QueryEngine -from .builder import QueryBuilder +from .engine import AbstractEngine, QueryEngine, TransactionId +from .builder import QueryBuilder, dumps from .generator.models import EngineType, OptionalValueFromEnvVar, BinaryPaths from ._compat import removeprefix from ._raw_query import deserialize_raw_results @@ -67,6 +69,8 @@ __all__ = ( 'get_client', ) +log: logging.Logger = logging.getLogger(__name__) + SCHEMA_PATH = Path('') PACKAGED_SCHEMA_PATH = Path(__file__).parent.joinpath('schema.prisma') ENGINE_TYPE: EngineType = EngineType.binary @@ -181,11 +185,13 @@ class Prisma: 'd', 'e', '__engine', - '_active_provider', - '_log_queries', + '__copied', + '_tx_id', '_datasource', - '_connect_timeout', + '_log_queries', '_http_config', + '_connect_timeout', + '_active_provider', ) def __init__( @@ -210,12 +216,17 @@ class Prisma: self.c = actions.CActions[models.C](self, models.C) self.d = actions.DActions[models.D](self, models.D) self.e = actions.EActions[models.E](self, models.E) + + # NOTE: if you add any more properties here then you may also need to forward + # them in the `_copy()` method. self.__engine: Optional[AbstractEngine] = None self._active_provider = 'postgresql' self._log_queries = log_queries self._datasource = datasource self._connect_timeout = connect_timeout self._http_config: HttpConfig = http or {} + self._tx_id: Optional[TransactionId] = None + self.__copied: bool = False if use_dotenv: load_env() @@ -224,10 +235,15 @@ class Prisma: register(self) def __del__(self) -> None: - if self.__engine is not None: - self.__engine.stop() + # Note: as the transaction manager holds a reference to the original + # client as well as the transaction client the original client cannot + # be `free`d before the transaction is finished. So stopping the engine + # here should be safe. + if self.__engine is not None and not self.__copied: + log.debug('unclosed client - stopping engine') + engine = self.__engine self.__engine = None - + engine.stop() async def __aenter__(self) -> 'Prisma': await self.connect() @@ -274,9 +290,10 @@ class Prisma: async def disconnect(self, timeout: Optional[float] = None) -> None: """Disconnect the Prisma query engine.""" if self.__engine is not None: - await self.__engine.aclose(timeout=timeout) - self.__engine.stop(timeout=timeout) + engine = self.__engine self.__engine = None + await engine.aclose(timeout=timeout) + engine.stop(timeout=timeout) async def execute_raw(self, query: LiteralString, *args: Any) -> int: resp = await self._execute( @@ -374,6 +391,35 @@ class Prisma: """Returns a context manager for grouping write queries into a single transaction.""" return Batch(client=self) + def tx(self, *, max_wait: int = 2000, timeout: int = 5000) -> 'TransactionManager': + """Returns a context manager for executing queries within a database transaction. + + Entering the context manager returns a new Prisma instance wrapping all + actions within a transaction, queries will be isolated to the Prisma instance and + will not be commited to the database until the context manager exits. + + By default, Prisma will wait a maximum of 2 seconds to acquire a transaction from the database. You can modify this + defualt with the `max_wait` argument which accepts a value in milliseconds. + + By default, Prisma will cancel and rollback ay transactions that last longer than 5 seconds. You can modify this timeout + with the `timeout` argument which accepts a value in milliseconds. + + Example usage: + + ```py + async with client.tx() as transaction: + user1 = await client.user.create({'name': 'Robert'}) + user2 = await client.user.create({'name': 'Tegan'}) + ``` + + In the above example, if the first database call succeeds but the second does not then neither of the records will be created. + """ + return TransactionManager(client=self, max_wait=max_wait, timeout=timeout) + + def is_transaction(self) -> bool: + """Returns True if the client is wrapped within a transaction""" + return self._tx_id is not None + # TODO: don't return Any async def _execute( self, @@ -388,7 +434,26 @@ class Prisma: arguments=arguments, root_selection=root_selection, ) - return await self._engine.query(builder.build()) + return await self._engine.query(builder.build(), tx_id=self._tx_id) + + def _copy(self) -> 'Prisma': + """Return a new Prisma instance using the same engine process (if connected). + + This is only intended for private usage, there are no guarantees around this API. + """ + new = Prisma( + use_dotenv=False, + http=self._http_config, + datasource=self._datasource, + log_queries=self._log_queries, + connect_timeout=self._connect_timeout, + ) + new.__copied = True + + if self.__engine is not None: + new._engine = self.__engine + + return new def _create_engine(self, dml_path: Path = PACKAGED_SCHEMA_PATH) -> AbstractEngine: if ENGINE_TYPE == EngineType.binary: @@ -410,6 +475,10 @@ class Prisma: raise errors.ClientNotConnectedError() return engine + @_engine.setter + def _engine(self, engine: AbstractEngine) -> None: + self.__engine = engine + def _make_sqlite_datasource(self) -> DatasourceOverride: return { 'name': 'db', @@ -434,6 +503,82 @@ class Prisma: } +class TransactionManager: + """Context manager for wrapping a Prisma instance within a transaction. + + This should never be created manually, instead it should be used + through the Prisma.tx() method. + """ + + def __init__(self, *, client: Prisma, max_wait: int, timeout: int) -> None: + self.__client = client + self._max_wait = max_wait + self._timeout = timeout + self._tx_id: Optional[TransactionId] = None + + async def start(self, *, _from_context: bool = False) -> Prisma: + """Start the transaction and return the wrapped Prisma instance""" + if self.__client.is_transaction(): + # if we were called from the context manager then the stacklevel + # needs to be one higher to warn on the actual offending code + warnings.warn( + 'The current client is already in a transaction. This can lead to surprising behaviour.', + UserWarning, + stacklevel=3 if _from_context else 2 + ) + + tx_id = await self.__client._engine.start_transaction( + content=dumps( + { + 'timeout': self._timeout, + 'max_wait': self._max_wait, + } + ), + ) + self._tx_id = tx_id + client = self.__client._copy() + client._tx_id = tx_id + return client + + async def commit(self) -> None: + """Commit the transaction to the database, this transaction will no longer be usable""" + if self._tx_id is None: + raise errors.TransactionNotStartedError() + + await self.__client._engine.commit_transaction(self._tx_id) + + async def rollback(self) -> None: + """Do not commit the changes to the database, this transaction will no longer be usable""" + if self._tx_id is None: + raise errors.TransactionNotStartedError() + + await self.__client._engine.rollback_transaction(self._tx_id) + + async def __aenter__(self) -> Prisma: + return await self.start(_from_context=True) + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if exc is None: + log.debug('Transaction exited with no exception - commiting') + await self.commit() + return + + log.debug('Transaction exited with exc type: %s - rolling back', exc_type) + + try: + await self.rollback() + except Exception as exc: + log.warning( + 'Encountered exc `%s` while rolling back a transaction. Ignoring and raising original exception', + exc + ) + + # TODO: this should return the results as well # TODO: don't require copy-pasting arguments between actions and batch actions class Batch: @@ -474,8 +619,6 @@ class Batch: async def commit(self) -> None: """Execute the queries""" # TODO: normalise this, we should still call client._execute - from .builder import dumps - queries = self.__queries self.__queries = [] @@ -489,7 +632,10 @@ class Batch: ], 'transaction': True, } - await self.__client._engine.query(dumps(payload)) + await self.__client._engine.query( + dumps(payload), + tx_id=self.__client._tx_id, + ) def execute_raw(self, query: LiteralString, *args: Any) -> None: self._add( diff --git a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[engineabstract.py].raw b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[engineabstract.py].raw index dc58d5657..c98bab673 100644 --- a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[engineabstract.py].raw +++ b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[engineabstract.py].raw @@ -3,6 +3,7 @@ # code generated by Prisma. DO NOT EDIT. # pyright: reportUnusedImport=false # fmt: off +from __future__ import annotations # global imports for type checking from builtins import bool as _bool @@ -40,13 +41,16 @@ from typing_extensions import TypedDict, Literal LiteralString = str # -- template engine/abstract.py.jinja -- from abc import ABC, abstractmethod +from ._types import TransactionId from ..types import DatasourceOverride from .._compat import get_running_loop + __all__ = ( 'AbstractEngine', ) + class AbstractEngine(ABC): dml: str @@ -92,7 +96,7 @@ class AbstractEngine(ABC): ... @abstractmethod - async def query(self, content: str) -> Any: + async def query(self, content: str, *, tx_id: TransactionId | None) -> Any: """Execute a GraphQL query. This method expects a JSON object matching this structure: @@ -104,4 +108,19 @@ class AbstractEngine(ABC): } """ ... + + @abstractmethod + async def start_transaction(self, *, content: str) -> TransactionId: + """Start an interactive transaction, returns the transaction ID that can be used to perform subsequent operations""" + ... + + @abstractmethod + async def commit_transaction(self, tx_id: TransactionId) -> None: + """Commit an interactive transaction, the given transaction will no longer be usable""" + ... + + @abstractmethod + async def rollback_transaction(self, tx_id: TransactionId) -> None: + """Rollback an interactive transaction, the given transaction will no longer be usable""" + ... ''' \ No newline at end of file diff --git a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[enginehttp.py].raw b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[enginehttp.py].raw index 1a617b364..ed55a1952 100644 --- a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[enginehttp.py].raw +++ b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[enginehttp.py].raw @@ -67,13 +67,11 @@ class HTTPEngine(AbstractEngine): headers: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> None: + super().__init__() self.url = url self.session = HTTP(**kwargs) self.headers = headers if headers is not None else {} - def __del__(self) -> None: - self.stop() - def close(self, *, timeout: Optional[float] = None) -> None: pass @@ -84,7 +82,14 @@ class HTTPEngine(AbstractEngine): if self.session and not self.session.closed: await self.session.close() - async def request(self, method: Method, path: str, *, content: Any = None) -> Any: + async def request( + self, + method: Method, + path: str, + *, + content: Any = None, + headers: Optional[Dict[str, str]] = None, + ) -> Any: if self.url is None: raise errors.NotConnectedError('Not connected to the query engine') @@ -96,13 +101,19 @@ class HTTPEngine(AbstractEngine): } } + if headers is not None: + kwargs['headers'].update(headers) + if content is not None: kwargs['content'] = content url = self.url + path - log.debug('Sending %s request to %s with content: %s', method, url, content) + log.debug('Sending %s request to %s', method, url) + log.debug('Request headers: %s', kwargs['headers']) + log.debug('Request content: %s', content) resp = await self.session.request(method, url, **kwargs) + log.debug('%s %s returned status %s', method, url, resp.status) if 300 > resp.status >= 200: response = await resp.json() diff --git a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[enginequery.py].raw b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[enginequery.py].raw index 16a3744cd..dac7db595 100644 --- a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[enginequery.py].raw +++ b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_async[enginequery.py].raw @@ -58,6 +58,7 @@ from ..binaries import platform from ..utils import time_since, _env_bool from ..types import DatasourceOverride from ..builder import dumps +from ._types import TransactionId __all__ = ('QueryEngine',) @@ -81,9 +82,6 @@ class QueryEngine(HTTPEngine): # ensure the query engine process is terminated when we are atexit.register(self.stop) - def __del__(self) -> None: - self.stop() - def close(self, *, timeout: Optional[float] = None) -> None: log.debug('Disconnecting query engine...') @@ -215,9 +213,40 @@ class QueryEngine(HTTPEngine): 'Could not connect to the query engine' ) from last_exc - async def query(self, content: str) -> Any: - return await self.request('POST', '/', content=content) + async def query( + self, + content: str, + *, + tx_id: TransactionId | None, + ) -> Any: + headers: Dict[str, str] = {} + if tx_id is not None: + headers['X-transaction-id'] = tx_id + + return await self.request( + 'POST', + '/', + content=content, + headers=headers, + ) + + async def start_transaction(self, *, content: str) -> TransactionId: + result = await self.request( + 'POST', + '/transaction/start', + content=content, + ) + return TransactionId(result['id']) + + async def commit_transaction(self, tx_id: TransactionId) -> None: + await self.request( + 'POST', f'/transaction/{tx_id}/commit' + ) + async def rollback_transaction(self, tx_id: TransactionId) -> None: + await self.request( + 'POST', f'/transaction/{tx_id}/rollback' + ) # black does not respect the fmt: off comment without this # fmt: on diff --git a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[client.py].raw b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[client.py].raw index 8ab9eef81..603792371 100644 --- a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[client.py].raw +++ b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[client.py].raw @@ -40,6 +40,8 @@ from typing_extensions import TypedDict, Literal LiteralString = str # -- template client.py.jinja -- +import warnings +import logging from pathlib import Path from types import TracebackType @@ -49,8 +51,8 @@ from . import types, models, errors, actions from .types import DatasourceOverride, HttpConfig from ._types import BaseModelT, PrismaMethod from .bases import _PrismaModel -from .engine import AbstractEngine, QueryEngine -from .builder import QueryBuilder +from .engine import AbstractEngine, QueryEngine, TransactionId +from .builder import QueryBuilder, dumps from .generator.models import EngineType, OptionalValueFromEnvVar, BinaryPaths from ._compat import removeprefix from ._raw_query import deserialize_raw_results @@ -67,6 +69,8 @@ __all__ = ( 'get_client', ) +log: logging.Logger = logging.getLogger(__name__) + SCHEMA_PATH = Path('') PACKAGED_SCHEMA_PATH = Path(__file__).parent.joinpath('schema.prisma') ENGINE_TYPE: EngineType = EngineType.binary @@ -181,11 +185,13 @@ class Prisma: 'd', 'e', '__engine', - '_active_provider', - '_log_queries', + '__copied', + '_tx_id', '_datasource', - '_connect_timeout', + '_log_queries', '_http_config', + '_connect_timeout', + '_active_provider', ) def __init__( @@ -210,12 +216,17 @@ class Prisma: self.c = actions.CActions[models.C](self, models.C) self.d = actions.DActions[models.D](self, models.D) self.e = actions.EActions[models.E](self, models.E) + + # NOTE: if you add any more properties here then you may also need to forward + # them in the `_copy()` method. self.__engine: Optional[AbstractEngine] = None self._active_provider = 'postgresql' self._log_queries = log_queries self._datasource = datasource self._connect_timeout = connect_timeout self._http_config: HttpConfig = http or {} + self._tx_id: Optional[TransactionId] = None + self.__copied: bool = False if use_dotenv: load_env() @@ -224,10 +235,15 @@ class Prisma: register(self) def __del__(self) -> None: - if self.__engine is not None: - self.__engine.stop() + # Note: as the transaction manager holds a reference to the original + # client as well as the transaction client the original client cannot + # be `free`d before the transaction is finished. So stopping the engine + # here should be safe. + if self.__engine is not None and not self.__copied: + log.debug('unclosed client - stopping engine') + engine = self.__engine self.__engine = None - + engine.stop() def __enter__(self) -> 'Prisma': self.connect() @@ -274,9 +290,10 @@ class Prisma: def disconnect(self, timeout: Optional[float] = None) -> None: """Disconnect the Prisma query engine.""" if self.__engine is not None: - self.__engine.close(timeout=timeout) - self.__engine.stop(timeout=timeout) + engine = self.__engine self.__engine = None + engine.close(timeout=timeout) + engine.stop(timeout=timeout) def execute_raw(self, query: LiteralString, *args: Any) -> int: resp = self._execute( @@ -374,6 +391,35 @@ class Prisma: """Returns a context manager for grouping write queries into a single transaction.""" return Batch(client=self) + def tx(self, *, max_wait: int = 2000, timeout: int = 5000) -> 'TransactionManager': + """Returns a context manager for executing queries within a database transaction. + + Entering the context manager returns a new Prisma instance wrapping all + actions within a transaction, queries will be isolated to the Prisma instance and + will not be commited to the database until the context manager exits. + + By default, Prisma will wait a maximum of 2 seconds to acquire a transaction from the database. You can modify this + defualt with the `max_wait` argument which accepts a value in milliseconds. + + By default, Prisma will cancel and rollback ay transactions that last longer than 5 seconds. You can modify this timeout + with the `timeout` argument which accepts a value in milliseconds. + + Example usage: + + ```py + with client.tx() as transaction: + user1 = client.user.create({'name': 'Robert'}) + user2 = client.user.create({'name': 'Tegan'}) + ``` + + In the above example, if the first database call succeeds but the second does not then neither of the records will be created. + """ + return TransactionManager(client=self, max_wait=max_wait, timeout=timeout) + + def is_transaction(self) -> bool: + """Returns True if the client is wrapped within a transaction""" + return self._tx_id is not None + # TODO: don't return Any def _execute( self, @@ -388,7 +434,26 @@ class Prisma: arguments=arguments, root_selection=root_selection, ) - return self._engine.query(builder.build()) + return self._engine.query(builder.build(), tx_id=self._tx_id) + + def _copy(self) -> 'Prisma': + """Return a new Prisma instance using the same engine process (if connected). + + This is only intended for private usage, there are no guarantees around this API. + """ + new = Prisma( + use_dotenv=False, + http=self._http_config, + datasource=self._datasource, + log_queries=self._log_queries, + connect_timeout=self._connect_timeout, + ) + new.__copied = True + + if self.__engine is not None: + new._engine = self.__engine + + return new def _create_engine(self, dml_path: Path = PACKAGED_SCHEMA_PATH) -> AbstractEngine: if ENGINE_TYPE == EngineType.binary: @@ -410,6 +475,10 @@ class Prisma: raise errors.ClientNotConnectedError() return engine + @_engine.setter + def _engine(self, engine: AbstractEngine) -> None: + self.__engine = engine + def _make_sqlite_datasource(self) -> DatasourceOverride: return { 'name': 'db', @@ -434,6 +503,82 @@ class Prisma: } +class TransactionManager: + """Context manager for wrapping a Prisma instance within a transaction. + + This should never be created manually, instead it should be used + through the Prisma.tx() method. + """ + + def __init__(self, *, client: Prisma, max_wait: int, timeout: int) -> None: + self.__client = client + self._max_wait = max_wait + self._timeout = timeout + self._tx_id: Optional[TransactionId] = None + + def start(self, *, _from_context: bool = False) -> Prisma: + """Start the transaction and return the wrapped Prisma instance""" + if self.__client.is_transaction(): + # if we were called from the context manager then the stacklevel + # needs to be one higher to warn on the actual offending code + warnings.warn( + 'The current client is already in a transaction. This can lead to surprising behaviour.', + UserWarning, + stacklevel=3 if _from_context else 2 + ) + + tx_id = self.__client._engine.start_transaction( + content=dumps( + { + 'timeout': self._timeout, + 'max_wait': self._max_wait, + } + ), + ) + self._tx_id = tx_id + client = self.__client._copy() + client._tx_id = tx_id + return client + + def commit(self) -> None: + """Commit the transaction to the database, this transaction will no longer be usable""" + if self._tx_id is None: + raise errors.TransactionNotStartedError() + + self.__client._engine.commit_transaction(self._tx_id) + + def rollback(self) -> None: + """Do not commit the changes to the database, this transaction will no longer be usable""" + if self._tx_id is None: + raise errors.TransactionNotStartedError() + + self.__client._engine.rollback_transaction(self._tx_id) + + def __enter__(self) -> Prisma: + return self.start(_from_context=True) + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + if exc is None: + log.debug('Transaction exited with no exception - commiting') + self.commit() + return + + log.debug('Transaction exited with exc type: %s - rolling back', exc_type) + + try: + self.rollback() + except Exception as exc: + log.warning( + 'Encountered exc `%s` while rolling back a transaction. Ignoring and raising original exception', + exc + ) + + # TODO: this should return the results as well # TODO: don't require copy-pasting arguments between actions and batch actions class Batch: @@ -474,8 +619,6 @@ class Batch: def commit(self) -> None: """Execute the queries""" # TODO: normalise this, we should still call client._execute - from .builder import dumps - queries = self.__queries self.__queries = [] @@ -489,7 +632,10 @@ class Batch: ], 'transaction': True, } - self.__client._engine.query(dumps(payload)) + self.__client._engine.query( + dumps(payload), + tx_id=self.__client._tx_id, + ) def execute_raw(self, query: LiteralString, *args: Any) -> None: self._add( diff --git a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[engineabstract.py].raw b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[engineabstract.py].raw index 26ef7f690..7c6378055 100644 --- a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[engineabstract.py].raw +++ b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[engineabstract.py].raw @@ -3,6 +3,7 @@ # code generated by Prisma. DO NOT EDIT. # pyright: reportUnusedImport=false # fmt: off +from __future__ import annotations # global imports for type checking from builtins import bool as _bool @@ -40,13 +41,16 @@ from typing_extensions import TypedDict, Literal LiteralString = str # -- template engine/abstract.py.jinja -- from abc import ABC, abstractmethod +from ._types import TransactionId from ..types import DatasourceOverride from .._compat import get_running_loop + __all__ = ( 'AbstractEngine', ) + class AbstractEngine(ABC): dml: str @@ -92,7 +96,7 @@ class AbstractEngine(ABC): ... @abstractmethod - def query(self, content: str) -> Any: + def query(self, content: str, *, tx_id: TransactionId | None) -> Any: """Execute a GraphQL query. This method expects a JSON object matching this structure: @@ -104,4 +108,19 @@ class AbstractEngine(ABC): } """ ... + + @abstractmethod + def start_transaction(self, *, content: str) -> TransactionId: + """Start an interactive transaction, returns the transaction ID that can be used to perform subsequent operations""" + ... + + @abstractmethod + def commit_transaction(self, tx_id: TransactionId) -> None: + """Commit an interactive transaction, the given transaction will no longer be usable""" + ... + + @abstractmethod + def rollback_transaction(self, tx_id: TransactionId) -> None: + """Rollback an interactive transaction, the given transaction will no longer be usable""" + ... ''' \ No newline at end of file diff --git a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[enginehttp.py].raw b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[enginehttp.py].raw index 4e2e81628..14ee31ed0 100644 --- a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[enginehttp.py].raw +++ b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[enginehttp.py].raw @@ -67,13 +67,11 @@ class HTTPEngine(AbstractEngine): headers: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> None: + super().__init__() self.url = url self.session = HTTP(**kwargs) self.headers = headers if headers is not None else {} - def __del__(self) -> None: - self.stop() - def close(self, *, timeout: Optional[float] = None) -> None: self._close_session() @@ -84,7 +82,14 @@ class HTTPEngine(AbstractEngine): if self.session and not self.session.closed: self.session.close() - def request(self, method: Method, path: str, *, content: Any = None) -> Any: + def request( + self, + method: Method, + path: str, + *, + content: Any = None, + headers: Optional[Dict[str, str]] = None, + ) -> Any: if self.url is None: raise errors.NotConnectedError('Not connected to the query engine') @@ -96,13 +101,19 @@ class HTTPEngine(AbstractEngine): } } + if headers is not None: + kwargs['headers'].update(headers) + if content is not None: kwargs['content'] = content url = self.url + path - log.debug('Sending %s request to %s with content: %s', method, url, content) + log.debug('Sending %s request to %s', method, url) + log.debug('Request headers: %s', kwargs['headers']) + log.debug('Request content: %s', content) resp = self.session.request(method, url, **kwargs) + log.debug('%s %s returned status %s', method, url, resp.status) if 300 > resp.status >= 200: response = resp.json() diff --git a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[enginequery.py].raw b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[enginequery.py].raw index 3c6cbbaf6..2401ed8c7 100644 --- a/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[enginequery.py].raw +++ b/tests/test_generation/exhaustive/__snapshots__/test_exhaustive/test_sync[enginequery.py].raw @@ -58,6 +58,7 @@ from ..binaries import platform from ..utils import time_since, _env_bool from ..types import DatasourceOverride from ..builder import dumps +from ._types import TransactionId __all__ = ('QueryEngine',) @@ -81,9 +82,6 @@ class QueryEngine(HTTPEngine): # ensure the query engine process is terminated when we are atexit.register(self.stop) - def __del__(self) -> None: - self.stop() - def close(self, *, timeout: Optional[float] = None) -> None: log.debug('Disconnecting query engine...') @@ -216,9 +214,40 @@ class QueryEngine(HTTPEngine): 'Could not connect to the query engine' ) from last_exc - def query(self, content: str) -> Any: - return self.request('POST', '/', content=content) + def query( + self, + content: str, + *, + tx_id: TransactionId | None, + ) -> Any: + headers: Dict[str, str] = {} + if tx_id is not None: + headers['X-transaction-id'] = tx_id + + return self.request( + 'POST', + '/', + content=content, + headers=headers, + ) + + def start_transaction(self, *, content: str) -> TransactionId: + result = self.request( + 'POST', + '/transaction/start', + content=content, + ) + return TransactionId(result['id']) + + def commit_transaction(self, tx_id: TransactionId) -> None: + self.request( + 'POST', f'/transaction/{tx_id}/commit' + ) + def rollback_transaction(self, tx_id: TransactionId) -> None: + self.request( + 'POST', f'/transaction/{tx_id}/rollback' + ) # black does not respect the fmt: off comment without this # fmt: on