Skip to content

Commit

Permalink
feat(client): add transaction isolation level
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie authored and jonathanblade committed Jan 31, 2024
1 parent 22cc236 commit ac76cbb
Show file tree
Hide file tree
Showing 8 changed files with 200 additions and 19 deletions.
59 changes: 58 additions & 1 deletion databases/sync_tests/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from prisma import Prisma
from prisma.models import User, Profile

from ..utils import CURRENT_DATABASE
from ..utils import CURRENT_DATABASE, RawQueries


def test_model_query(client: Prisma) -> None:
Expand Down Expand Up @@ -201,3 +201,60 @@ def test_transaction_already_closed(client: Prisma) -> None:
transaction.user.delete_many()

assert exc.match('Transaction already closed')


@pytest.mark.parametrize(
('input_level', 'expected_level'),
[
pytest.param(
prisma.TransactionIsolationLevel.READ_UNCOMMITTED,
'READ_UNCOMMITTED',
id='read uncommitted',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
prisma.TransactionIsolationLevel.READ_COMMITTED,
'READ_COMMITTED',
id='read committed',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
prisma.TransactionIsolationLevel.REPEATABLE_READ,
'REPEATABLE_READ',
id='repeatable read',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
prisma.TransactionIsolationLevel.SNAPSHOT,
'SNAPSHOT',
id='snapshot',
marks=pytest.mark.skipif(True, reason='Available for SQL Server only'),
),
pytest.param(
prisma.TransactionIsolationLevel.SERIALIZABLE,
'SERIALIZABLE',
id='serializable',
marks=pytest.mark.skipif(
CURRENT_DATABASE == 'sqlite', reason='PRAGMA has only effect in shared-cache mode'
),
),
],
)
# TODO: remove after issue will be resolved
@pytest.mark.skipif(CURRENT_DATABASE in ['mysql', 'mariadb'], reason='https://github.com/prisma/prisma/issues/22890')
def test_isolation_level(
client: Prisma, raw_queries: RawQueries, input_level: prisma.TransactionIsolationLevel, expected_level: str
) -> None:
"""Ensure that transaction isolation level is set correctly"""
with client.tx(isolation_level=input_level) as tx:
results = tx.query_raw(raw_queries.select_tx_isolation)

assert len(results) == 1

row = results[0]
assert any(row)

level = next(iter(row.values()))
# The result can depends on the database, so we do upper() and replace()
level = str(level).upper().replace(' ', '_').replace('-', '_')
assert level == expected_level
60 changes: 59 additions & 1 deletion databases/tests/test_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from prisma import Prisma
from prisma.models import User, Profile

from ..utils import CURRENT_DATABASE
from ..utils import CURRENT_DATABASE, RawQueries


@pytest.mark.asyncio
Expand Down Expand Up @@ -212,3 +212,61 @@ async def test_transaction_already_closed(client: Prisma) -> None:
await transaction.user.delete_many()

assert exc.match('Transaction already closed')


@pytest.mark.asyncio
@pytest.mark.parametrize(
('input_level', 'expected_level'),
[
pytest.param(
prisma.TransactionIsolationLevel.READ_UNCOMMITTED,
'READ_UNCOMMITTED',
id='read uncommitted',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
prisma.TransactionIsolationLevel.READ_COMMITTED,
'READ_COMMITTED',
id='read committed',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
prisma.TransactionIsolationLevel.REPEATABLE_READ,
'REPEATABLE_READ',
id='repeatable read',
marks=pytest.mark.skipif(CURRENT_DATABASE in ['cockroachdb', 'sqlite'], reason='Not available'),
),
pytest.param(
prisma.TransactionIsolationLevel.SNAPSHOT,
'SNAPSHOT',
id='snapshot',
marks=pytest.mark.skipif(True, reason='Available for SQL Server only'),
),
pytest.param(
prisma.TransactionIsolationLevel.SERIALIZABLE,
'SERIALIZABLE',
id='serializable',
marks=pytest.mark.skipif(
CURRENT_DATABASE == 'sqlite', reason='PRAGMA has only effect in shared-cache mode'
),
),
],
)
# TODO: remove after issue will be resolved
@pytest.mark.skipif(CURRENT_DATABASE in ['mysql', 'mariadb'], reason='https://github.com/prisma/prisma/issues/22890')
async def test_isolation_level(
client: Prisma, raw_queries: RawQueries, input_level: prisma.TransactionIsolationLevel, expected_level: str
) -> None:
"""Ensure that transaction isolation level is set correctly"""
async with client.tx(isolation_level=input_level) as tx:
results = await tx.query_raw(raw_queries.select_tx_isolation)

assert len(results) == 1

row = results[0]
assert any(row)

level = next(iter(row.values()))
# The result can depends on the database, so we do upper() and replace()
level = str(level).upper().replace(' ', '_').replace('-', '_')
assert level == expected_level
12 changes: 12 additions & 0 deletions databases/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class RawQueries(BaseModel):
test_query_raw_no_result: LiteralString
test_execute_raw_no_result: LiteralString

select_tx_isolation: LiteralString


_mysql_queries = RawQueries(
count_posts="""
Expand Down Expand Up @@ -136,8 +138,12 @@ class RawQueries(BaseModel):
SET title = 'updated title'
WHERE id = 'sdldsd'
""",
select_tx_isolation="""
SELECT @@transaction_isolation
""",
)


_postgresql_queries = RawQueries(
count_posts="""
SELECT COUNT(*) as count
Expand Down Expand Up @@ -188,6 +194,9 @@ class RawQueries(BaseModel):
SET title = 'updated title'
WHERE id = 'sdldsd'
""",
select_tx_isolation="""
SHOW transaction_isolation
""",
)

RAW_QUERIES_MAPPING: DatabaseMapping[RawQueries] = {
Expand Down Expand Up @@ -245,5 +254,8 @@ class RawQueries(BaseModel):
SET title = 'updated title'
WHERE id = 'sdldsd'
""",
select_tx_isolation="""
PRAGMA read_uncommitted = 1
""",
),
}
11 changes: 11 additions & 0 deletions docs/reference/transactions.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,17 @@ In the case that this example runs successfully, then both database writes are c
)
```

## Isolation levels

By default, Prisma sets the isolation level to the value currently configured in the database. You can modify this
default with the `isolation_level` argument (see [supported isolation levels](https://www.prisma.io/docs/orm/prisma-client/queries/transactions#supported-isolation-levels)).

```py
prisma.tx(
isolation_level=prisma.TransactionIsolationLevel.READ_UNCOMMITTED,
)
```

## Timeouts

You can pass the following options to configure how timeouts are applied to your transaction:
Expand Down
56 changes: 39 additions & 17 deletions src/prisma/_transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,39 @@
import logging
import warnings
from types import TracebackType
from typing import TYPE_CHECKING, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from datetime import timedelta

from ._types import TransactionId
from .errors import TransactionNotStartedError
from ._compat import StrEnum
from ._builder import dumps

if TYPE_CHECKING:
from ._base_client import SyncBasePrisma, AsyncBasePrisma

log: logging.Logger = logging.getLogger(__name__)

__all__ = (
'TransactionIsolationLevel',
'AsyncTransactionManager',
'SyncTransactionManager',
)


_SyncPrismaT = TypeVar('_SyncPrismaT', bound='SyncBasePrisma')
_AsyncPrismaT = TypeVar('_AsyncPrismaT', bound='AsyncBasePrisma')


# See here: https://www.prisma.io/docs/orm/prisma-client/queries/transactions#supported-isolation-levels
class TransactionIsolationLevel(StrEnum):
READ_UNCOMMITTED = 'ReadUncommitted'
READ_COMMITTED = 'ReadCommitted'
REPEATABLE_READ = 'RepeatableRead'
SNAPSHOT = 'Snapshot'
SERIALIZABLE = 'Serializable'


class AsyncTransactionManager(Generic[_AsyncPrismaT]):
"""Context manager for wrapping a Prisma instance within a transaction.
Expand All @@ -33,8 +49,10 @@ def __init__(
client: _AsyncPrismaT,
max_wait: int | timedelta,
timeout: int | timedelta,
isolation_level: TransactionIsolationLevel | None,
) -> None:
self.__client = client
self._isolation_level = isolation_level

if isinstance(max_wait, int):
message = (
Expand Down Expand Up @@ -71,14 +89,15 @@ async def start(self, *, _from_context: bool = False) -> _AsyncPrismaT:
stacklevel=3 if _from_context else 2,
)

tx_id = await self.__client._engine.start_transaction(
content=dumps(
{
'timeout': int(self._timeout.total_seconds() * 1000),
'max_wait': int(self._max_wait.total_seconds() * 1000),
}
),
)
content_dict: dict[str, Any] = {
'timeout': int(self._timeout.total_seconds() * 1000),
'max_wait': int(self._max_wait.total_seconds() * 1000),
}
if self._isolation_level is not None:
content_dict['isolation_level'] = self._isolation_level.value

tx_id = await self.__client._engine.start_transaction(content=dumps(content_dict))

self._tx_id = tx_id
client = self.__client._copy()
client._tx_id = tx_id
Expand Down Expand Up @@ -135,8 +154,10 @@ def __init__(
client: _SyncPrismaT,
max_wait: int | timedelta,
timeout: int | timedelta,
isolation_level: TransactionIsolationLevel | None,
) -> None:
self.__client = client
self._isolation_level = isolation_level

if isinstance(max_wait, int):
message = (
Expand Down Expand Up @@ -173,14 +194,15 @@ def start(self, *, _from_context: bool = False) -> _SyncPrismaT:
stacklevel=3 if _from_context else 2,
)

tx_id = self.__client._engine.start_transaction(
content=dumps(
{
'timeout': int(self._timeout.total_seconds() * 1000),
'max_wait': int(self._max_wait.total_seconds() * 1000),
}
),
)
content_dict: dict[str, Any] = {
'timeout': int(self._timeout.total_seconds() * 1000),
'max_wait': int(self._max_wait.total_seconds() * 1000),
}
if self._isolation_level is not None:
content_dict['isolation_level'] = self._isolation_level.value

tx_id = self.__client._engine.start_transaction(content=dumps(content_dict))

self._tx_id = tx_id
client = self.__client._copy()
client._tx_id = tx_id
Expand Down
7 changes: 7 additions & 0 deletions src/prisma/generator/templates/client.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ from .metadata import PRISMA_MODELS, RELATIONAL_FIELD_MAPPINGS
from ._transactions import AsyncTransactionManager, SyncTransactionManager

# re-exports
from ._transactions import TransactionIsolationLevel
from ._base_client import SyncBasePrisma, AsyncBasePrisma, load_env as load_env
from ._registry import (
register as register,
Expand All @@ -37,6 +38,7 @@ __all__ = (
'ENGINE_TYPE',
'SCHEMA_PATH',
'BINARY_PATHS',
'TransactionIsolationLevel',
'Batch',
'Prisma',
'Client',
Expand Down Expand Up @@ -202,6 +204,7 @@ class Prisma({% if is_async %}AsyncBasePrisma{% else %}SyncBasePrisma{% endif %}
def tx(
self,
*,
isolation_level: Optional[TransactionIsolationLevel] = None,
max_wait: Union[int, timedelta] = DEFAULT_TX_MAX_WAIT,
timeout: Union[int, timedelta] = DEFAULT_TX_TIMEOUT,
) -> TransactionManager:
Expand All @@ -211,6 +214,9 @@ class Prisma({% if is_async %}AsyncBasePrisma{% else %}SyncBasePrisma{% endif %}
actions within a transaction, queries will be isolated to the Prisma instance and
will not be commited to the database until the context manager exits.

By default, Prisma sets the isolation level to the value currently configured in the database. You can modify this
default with the `isolation_level` argument (see [supported isolation levels](https://www.prisma.io/docs/orm/prisma-client/queries/transactions#supported-isolation-levels)).

By default, Prisma will wait a maximum of 2 seconds to acquire a transaction from the database. You can modify this
default with the `max_wait` argument which accepts a value in milliseconds or `datetime.timedelta`.

Expand All @@ -231,6 +237,7 @@ class Prisma({% if is_async %}AsyncBasePrisma{% else %}SyncBasePrisma{% endif %}
client=self,
max_wait=max_wait,
timeout=timeout,
isolation_level=isolation_level,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ from .metadata import PRISMA_MODELS, RELATIONAL_FIELD_MAPPINGS
from ._transactions import AsyncTransactionManager, SyncTransactionManager

# re-exports
from ._transactions import TransactionIsolationLevel
from ._base_client import SyncBasePrisma, AsyncBasePrisma, load_env as load_env
from ._registry import (
register as register,
Expand All @@ -75,6 +76,7 @@ __all__ = (
'ENGINE_TYPE',
'SCHEMA_PATH',
'BINARY_PATHS',
'TransactionIsolationLevel',
'Batch',
'Prisma',
'Client',
Expand Down Expand Up @@ -265,6 +267,7 @@ class Prisma(AsyncBasePrisma):
def tx(
self,
*,
isolation_level: Optional[TransactionIsolationLevel] = None,
max_wait: Union[int, timedelta] = DEFAULT_TX_MAX_WAIT,
timeout: Union[int, timedelta] = DEFAULT_TX_TIMEOUT,
) -> TransactionManager:
Expand All @@ -274,6 +277,9 @@ class Prisma(AsyncBasePrisma):
actions within a transaction, queries will be isolated to the Prisma instance and
will not be commited to the database until the context manager exits.

By default, Prisma sets the isolation level to the value currently configured in the database. You can modify this
default with the `isolation_level` argument (see [supported isolation levels](https://www.prisma.io/docs/orm/prisma-client/queries/transactions#supported-isolation-levels)).

By default, Prisma will wait a maximum of 2 seconds to acquire a transaction from the database. You can modify this
default with the `max_wait` argument which accepts a value in milliseconds or `datetime.timedelta`.

Expand All @@ -294,6 +300,7 @@ class Prisma(AsyncBasePrisma):
client=self,
max_wait=max_wait,
timeout=timeout,
isolation_level=isolation_level,
)


Expand Down

0 comments on commit ac76cbb

Please sign in to comment.