Skip to content

Commit

Permalink
feat(client): add transaction isolation level
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie authored and jonathanblade committed Feb 12, 2024
1 parent 22cc236 commit 69cc396
Show file tree
Hide file tree
Showing 8 changed files with 298 additions and 29 deletions.
65 changes: 64 additions & 1 deletion databases/sync_tests/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from prisma import Prisma
from prisma.models import User, Profile

from ..utils import CURRENT_DATABASE
from ..utils import CURRENT_DATABASE, ISOLATION_LEVELS_MAPPING, RawQueries


def test_model_query(client: Prisma) -> None:
Expand Down Expand Up @@ -201,3 +201,66 @@ def test_transaction_already_closed(client: Prisma) -> None:
transaction.user.delete_many()

assert exc.match('Transaction already closed')


@pytest.mark.parametrize(
('input_level',),
[
pytest.param(
'READ_UNCOMMITTED',
id='read uncommitted',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
'READ_COMMITTED',
id='read committed',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
'REPEATABLE_READ',
id='repeatable read',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
'SNAPSHOT',
id='snapshot',
marks=pytest.mark.skipif(CURRENT_DATABASE != 'sqlserver', reason='Not available'),
),
pytest.param(
'SERIALIZABLE',
id='serializable',
marks=pytest.mark.skipif(
CURRENT_DATABASE == 'sqlite',
reason="SQLite doesn't have the way to query the current transaction isolation level",
),
),
],
)
@pytest.mark.skipif(CURRENT_DATABASE == 'mongodb', reason='Not available')
@pytest.mark.skipif(
CURRENT_DATABASE in ['mysql', 'mariadb'],
reason="""
MySQL 8.0 doesn't have the way to query the current transaction isolation level.
See https://bugs.mysql.com/bug.php?id=53341
Refs:
* https://github.com/prisma/prisma/issues/22890
""",
)
def test_isolation_level(
client: Prisma,
database: str,
raw_queries: RawQueries,
input_level: str,
) -> None:
"""Ensure that transaction isolation level is set correctly"""
with client.tx(isolation_level=getattr(prisma.TransactionIsolationLevel, input_level)) as tx:
results = tx.query_raw(raw_queries.select_tx_isolation)

assert len(results) == 1

row = results[0]
assert any(row)

level = next(iter(row.values()))
assert level == ISOLATION_LEVELS_MAPPING[input_level][database]
66 changes: 65 additions & 1 deletion databases/tests/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from prisma import Prisma
from prisma.models import User, Profile

from ..utils import CURRENT_DATABASE
from ..utils import CURRENT_DATABASE, ISOLATION_LEVELS_MAPPING, RawQueries


@pytest.mark.asyncio
Expand Down Expand Up @@ -212,3 +212,67 @@ async def test_transaction_already_closed(client: Prisma) -> None:
await transaction.user.delete_many()

assert exc.match('Transaction already closed')


@pytest.mark.asyncio
@pytest.mark.parametrize(
('input_level',),
[
pytest.param(
'READ_UNCOMMITTED',
id='read uncommitted',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
'READ_COMMITTED',
id='read committed',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
'REPEATABLE_READ',
id='repeatable read',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
'SNAPSHOT',
id='snapshot',
marks=pytest.mark.skipif(CURRENT_DATABASE != 'sqlserver', reason='Not available'),
),
pytest.param(
'SERIALIZABLE',
id='serializable',
marks=pytest.mark.skipif(
CURRENT_DATABASE == 'sqlite',
reason="SQLite doesn't have the way to query the current transaction isolation level",
),
),
],
)
@pytest.mark.skipif(CURRENT_DATABASE == 'mongodb', reason='Not available')
@pytest.mark.skipif(
CURRENT_DATABASE in ['mysql', 'mariadb'],
reason="""
MySQL 8.0 doesn't have the way to query the current transaction isolation level.
See https://bugs.mysql.com/bug.php?id=53341
Refs:
* https://github.com/prisma/prisma/issues/22890
""",
)
async def test_isolation_level(
client: Prisma,
database: str,
raw_queries: RawQueries,
input_level: str,
) -> None:
"""Ensure that transaction isolation level is set correctly"""
async with client.tx(isolation_level=getattr(prisma.TransactionIsolationLevel, input_level)) as tx:
results = await tx.query_raw(raw_queries.select_tx_isolation)

assert len(results) == 1

row = results[0]
assert any(row)

level = next(iter(row.values()))
assert level == ISOLATION_LEVELS_MAPPING[input_level][database]
63 changes: 61 additions & 2 deletions databases/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import os
from typing import Set
from typing import Set, Optional
from pathlib import Path
from typing_extensions import Literal, get_args, override
from typing_extensions import Literal, TypedDict, get_args, override

from pydantic import BaseModel
from syrupy.extensions.amber import AmberSnapshotExtension
Expand Down Expand Up @@ -85,6 +85,8 @@ class RawQueries(BaseModel):
test_query_raw_no_result: LiteralString
test_execute_raw_no_result: LiteralString

select_tx_isolation: LiteralString


_mysql_queries = RawQueries(
count_posts="""
Expand Down Expand Up @@ -136,8 +138,12 @@ class RawQueries(BaseModel):
SET title = 'updated title'
WHERE id = 'sdldsd'
""",
select_tx_isolation="""
SELECT @@transaction_isolation
""",
)


_postgresql_queries = RawQueries(
count_posts="""
SELECT COUNT(*) as count
Expand Down Expand Up @@ -188,6 +194,9 @@ class RawQueries(BaseModel):
SET title = 'updated title'
WHERE id = 'sdldsd'
""",
select_tx_isolation="""
SHOW transaction_isolation
""",
)

RAW_QUERIES_MAPPING: DatabaseMapping[RawQueries] = {
Expand Down Expand Up @@ -245,5 +254,55 @@ class RawQueries(BaseModel):
SET title = 'updated title'
WHERE id = 'sdldsd'
""",
select_tx_isolation="""
Not avaliable
""",
),
}


class IsolationLevelsMapping(TypedDict):
READ_UNCOMMITTED: DatabaseMapping[Optional[LiteralString]]
READ_COMMITTED: DatabaseMapping[Optional[LiteralString]]
REPEATABLE_READ: DatabaseMapping[Optional[LiteralString]]
SNAPSHOT: DatabaseMapping[Optional[LiteralString]]
SERIALIZABLE: DatabaseMapping[Optional[LiteralString]]


ISOLATION_LEVELS_MAPPING: IsolationLevelsMapping = {
'READ_UNCOMMITTED': {
'postgresql': 'read uncommitted',
'cockroachdb': None,
'mysql': 'READ-UNCOMMITTED',
'mariadb': 'READ-UNCOMMITTED',
'sqlite': None,
},
'READ_COMMITTED': {
'postgresql': 'read committed',
'cockroachdb': None,
'mysql': 'READ-COMMITTED',
'mariadb': 'READ-COMMITTED',
'sqlite': None,
},
'REPEATABLE_READ': {
'postgresql': 'repeatable read',
'cockroachdb': None,
'mysql': 'REPEATABLE-READ',
'mariadb': 'REPEATABLE-READ',
'sqlite': None,
},
'SNAPSHOT': {
'postgresql': None,
'cockroachdb': None,
'mysql': None,
'mariadb': None,
'sqlite': None,
},
'SERIALIZABLE': {
'postgresql': 'serializable',
'cockroachdb': 'SERIALIZABLE',
'mysql': 'SERIALIZABLE',
'mariadb': 'SERIALIZABLE',
'sqlite': None,
},
}
17 changes: 17 additions & 0 deletions docs/reference/transactions.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,23 @@ In the case that this example runs successfully, then both database writes are c
)
```

## Isolation levels

By default, Prisma sets the isolation level to the value currently configured in the database. You can modify this
default with the `isolation_level` argument (see [supported isolation levels](https://www.prisma.io/docs/orm/prisma-client/queries/transactions#supported-isolation-levels)).

!!! note
Prisma Client Python generates `TransactionIsolationLevel` enumeration that includes only the options supported by the current database.

```py
from prisma import Prisma, TransactionIsolationLevel

client = Prisma()
client.tx(
isolation_level=TransactionIsolationLevel.READ_UNCOMMITTED,
)
```

## Timeouts

You can pass the following options to configure how timeouts are applied to your transaction:
Expand Down
51 changes: 32 additions & 19 deletions src/prisma/_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,31 @@
import logging
import warnings
from types import TracebackType
from typing import TYPE_CHECKING, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from datetime import timedelta

from ._types import TransactionId
from .errors import TransactionNotStartedError
from ._compat import StrEnum
from ._builder import dumps

if TYPE_CHECKING:
from ._base_client import SyncBasePrisma, AsyncBasePrisma

log: logging.Logger = logging.getLogger(__name__)

__all__ = (
'AsyncTransactionManager',
'SyncTransactionManager',
)


_SyncPrismaT = TypeVar('_SyncPrismaT', bound='SyncBasePrisma')
_AsyncPrismaT = TypeVar('_AsyncPrismaT', bound='AsyncBasePrisma')
_IsolationLevelT = TypeVar('_IsolationLevelT', bound=StrEnum)


class AsyncTransactionManager(Generic[_AsyncPrismaT]):
class AsyncTransactionManager(Generic[_AsyncPrismaT, _IsolationLevelT]):
"""Context manager for wrapping a Prisma instance within a transaction.
This should never be created manually, instead it should be used
Expand All @@ -33,8 +40,10 @@ def __init__(
client: _AsyncPrismaT,
max_wait: int | timedelta,
timeout: int | timedelta,
isolation_level: _IsolationLevelT | None,
) -> None:
self.__client = client
self._isolation_level = isolation_level

if isinstance(max_wait, int):
message = (
Expand Down Expand Up @@ -71,14 +80,15 @@ async def start(self, *, _from_context: bool = False) -> _AsyncPrismaT:
stacklevel=3 if _from_context else 2,
)

tx_id = await self.__client._engine.start_transaction(
content=dumps(
{
'timeout': int(self._timeout.total_seconds() * 1000),
'max_wait': int(self._max_wait.total_seconds() * 1000),
}
),
)
content_dict: dict[str, Any] = {
'timeout': int(self._timeout.total_seconds() * 1000),
'max_wait': int(self._max_wait.total_seconds() * 1000),
}
if self._isolation_level is not None:
content_dict['isolation_level'] = self._isolation_level.value

tx_id = await self.__client._engine.start_transaction(content=dumps(content_dict))

self._tx_id = tx_id
client = self.__client._copy()
client._tx_id = tx_id
Expand Down Expand Up @@ -122,7 +132,7 @@ async def __aexit__(
)


class SyncTransactionManager(Generic[_SyncPrismaT]):
class SyncTransactionManager(Generic[_SyncPrismaT, _IsolationLevelT]):
"""Context manager for wrapping a Prisma instance within a transaction.
This should never be created manually, instead it should be used
Expand All @@ -135,8 +145,10 @@ def __init__(
client: _SyncPrismaT,
max_wait: int | timedelta,
timeout: int | timedelta,
isolation_level: _IsolationLevelT | None,
) -> None:
self.__client = client
self._isolation_level = isolation_level

if isinstance(max_wait, int):
message = (
Expand Down Expand Up @@ -173,14 +185,15 @@ def start(self, *, _from_context: bool = False) -> _SyncPrismaT:
stacklevel=3 if _from_context else 2,
)

tx_id = self.__client._engine.start_transaction(
content=dumps(
{
'timeout': int(self._timeout.total_seconds() * 1000),
'max_wait': int(self._max_wait.total_seconds() * 1000),
}
),
)
content_dict: dict[str, Any] = {
'timeout': int(self._timeout.total_seconds() * 1000),
'max_wait': int(self._max_wait.total_seconds() * 1000),
}
if self._isolation_level is not None:
content_dict['isolation_level'] = self._isolation_level.value

tx_id = self.__client._engine.start_transaction(content=dumps(content_dict))

self._tx_id = tx_id
client = self.__client._copy()
client._tx_id = tx_id
Expand Down
Loading

0 comments on commit 69cc396

Please sign in to comment.