Skip to content

Commit

Permalink
feat(client): add create_many action method
Browse files Browse the repository at this point in the history
closes #8
  • Loading branch information
RobertCraigie committed Jul 17, 2021
1 parent 6d7ff25 commit cc2872c
Show file tree
Hide file tree
Showing 14 changed files with 1,115 additions and 31 deletions.
10 changes: 10 additions & 0 deletions prisma/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ def __init__(self) -> None:
super().__init__('Cannot make a request from a closed client.')


class UnsupportedDatabaseError(PrismaError):
context: str
database: str

def __init__(self, database: str, context: str) -> None:
super().__init__(f'{context} is not supported by {database}')
self.database = database
self.context = context


class DataError(PrismaError):
data: Any
code: Any
Expand Down
16 changes: 15 additions & 1 deletion prisma/generator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ class Data(BaseModel):
generator: 'Generator'
dmmf: 'DMMF' = FieldInfo(alias='dmmf')
schema_path: str = FieldInfo(alias='schemaPath')
datasources: List['Datasource'] = FieldInfo(alias='datasources')

# TODO
data_sources: Any = FieldInfo(alias='dataSources')
other_generators: List[Any] = FieldInfo(alias='otherGenerators')

@classmethod
Expand All @@ -143,6 +143,14 @@ def parse_obj(cls, obj: Any) -> 'Data':
return data


class Datasource(BaseModel):
# TODO: provider enums
name: str
provider: List[str]
active_provider: str = FieldInfo(alias='activeProvider')
url: 'OptionalValueFromEnvVar'


class Generator(BaseModel):
name: str
output: 'ValueFromEnvVar'
Expand All @@ -157,6 +165,11 @@ class ValueFromEnvVar(BaseModel):
from_env_var: Optional[str] = FieldInfo(alias='fromEnvVar')


class OptionalValueFromEnvVar(BaseModel):
value: Optional[str]
from_env_var: Optional[str] = FieldInfo(alias='fromEnvVar')


class Config(BaseSettings):
"""Custom generator config options."""

Expand Down Expand Up @@ -466,3 +479,4 @@ class DefaultValue(BaseModel):
Model.update_forward_refs()
Datamodel.update_forward_refs()
Generator.update_forward_refs()
Datasource.update_forward_refs()
2 changes: 2 additions & 0 deletions prisma/generator/templates/_utils.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ time.sleep({{ duration }})
'update': 'mutation',
'upsert': 'mutation',
'query_raw': 'mutation',
'create_many': 'mutation',
'execute_raw': 'mutation',
'delete_many': 'mutation',
'update_many': 'mutation',
Expand All @@ -40,6 +41,7 @@ time.sleep({{ duration }})
'update': 'updateOne',
'upsert': 'upsertOne',
'query_raw': 'queryRaw',
'create_many': 'createMany',
'execute_raw': 'executeRaw',
'delete_many': 'deleteMany',
'update_many': 'updateMany',
Expand Down
18 changes: 8 additions & 10 deletions prisma/generator/templates/builder.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -399,16 +399,14 @@ class Arguments(Node):
Key(arg, node=Data.create(self.builder, data=value))
)
elif isinstance(value, (list, tuple, set)):
# NOTE: this is a special case for execute_raw and query_raw
# that expect parameters to be passed as a json string.
#
# this special case only works as there are no other root level arguments
# that take an Iterable, all other Iterable arguments are encapsulated within
# a Data node which handles Iterables differently, passing them to ListNode.
#
# prisma expects a json string value like "[\"John\",\"123\"]"
# we encode twice to ensure that only the inner quotes are escaped
children.append(f'{arg}: {dumps(dumps(value))}')
# NOTE: we have a special case for execute_raw and query_raw
# here as prisma expects parameters to be passed as a json string
# value like "[\"John\",\"123\"]", and we encode twice to ensure
# that only the inner quotes are escaped
if self.builder.method in {'queryRaw', 'executeRaw'}:
children.append(f'{arg}: {dumps(dumps(value))}')
else:
children.append(Key(arg, node=ListNode.create(self.builder, data=value)))
else:
children.append(f'{arg}: {dumps(value)}')

Expand Down
43 changes: 43 additions & 0 deletions prisma/generator/templates/client.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Client:
self.{{ model.name.lower() }} = {{ model.name }}Actions(self)
{% endfor %}
self.__engine: Optional[QueryEngine] = None
self._active_provider = '{{ datasources[0].active_provider }}'

if use_dotenv:
load_env()
Expand Down Expand Up @@ -151,6 +152,7 @@ class Batch:
def __init__(self, client: Client) -> None:
self.__client = client
self.__queries: List[str] = []
self._active_provider = client._active_provider
{% for model in dmmf.datamodel.models %}
self.{{ model.name.lower() }} = {{ model.name }}BatchActions(self)
{% endfor %}
Expand Down Expand Up @@ -227,6 +229,27 @@ class {{ model.name }}Actions:
)
return models.{{ model.name }}.parse_obj(resp['data']['result'])

{{ maybe_async_def }}create_many(
self,
data: List[types.{{ model.name }}CreateWithoutRelationsInput],
*,
skip_duplicates: Optional[bool] = None,
) -> int:
if self._client._active_provider == 'sqlite':
raise errors.UnsupportedDatabaseError('sqlite', 'create_many()')

resp = {{ maybe_await }}self._client._execute(
operation='{{ operations.create_many }}',
method='{{ methods.create_many }}',
model='{{ model.name }}',
arguments={
'data': data,
'skipDuplicates': skip_duplicates,
},
root_selection=['count'],
)
return resp['data']['result']['count']

{{ maybe_async_def }}delete(
self,
where: types.{{ model.name }}WhereUniqueInput,
Expand Down Expand Up @@ -430,6 +453,26 @@ class {{ model.name }}BatchActions:
},
)

def create_many(
self,
data: List[types.{{ model.name }}CreateWithoutRelationsInput],
*,
skip_duplicates: Optional[bool] = None,
) -> None:
if self._batcher._active_provider == 'sqlite':
raise errors.UnsupportedDatabaseError('sqlite', 'create_many()')

self._batcher._add(
operation='{{ operations.create_many }}',
method='{{ methods.create_many }}',
model='{{ model.name }}',
arguments={
'data': data,
'skipDuplicates': skip_duplicates,
},
root_selection=['count'],
)

def delete(
self,
where: types.{{ model.name }}WhereUniqueInput,
Expand Down
11 changes: 11 additions & 0 deletions tests/integrations/postgresql/tests/test_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pytest
from prisma import Client


@pytest.mark.asyncio
async def test_create_many(client: Client) -> None:
async with client.batch_() as batcher:
batcher.user.create({'name': 'Robert'})
batcher.user.create({'name': 'Tegan'})

assert await client.user.count() == 2
34 changes: 34 additions & 0 deletions tests/integrations/postgresql/tests/test_create_many.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest
import prisma
from prisma import Client
from prisma.enums import Role


@pytest.mark.asyncio
async def test_create_many(client: Client) -> None:
total = await client.user.create_many(
[{'name': 'Robert', 'role': Role.ADMIN}, {'name': 'Tegan'}]
)
assert total == 2

user = await client.user.find_first(where={'name': 'Robert'})
assert user is not None
assert user.name == 'Robert'
assert user.role == Role.ADMIN

assert await client.user.count() == 2


@pytest.mark.asyncio
async def test_skip_duplicates(client: Client) -> None:
user = await client.user.create({'name': 'Robert'})

with pytest.raises(prisma.errors.UniqueViolationError) as exc:
await client.user.create_many([{'id': user.id, 'name': 'Robert 2'}])

assert exc.match(r'Unique constraint failed on the fields: \(`id`\)')

count = await client.user.create_many(
[{'id': user.id, 'name': 'Robert 2'}, {'name': 'Tegan'}], skip_duplicates=True
)
assert count == 1
9 changes: 9 additions & 0 deletions tests/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,15 @@ async def test_delete_many(client: Client) -> None:
assert await client.user.count() == 0


@pytest.mark.asyncio
async def test_create_many_unsupported(client: Client) -> None:
with pytest.raises(prisma.errors.UnsupportedDatabaseError) as exc:
async with client.batch_() as batcher:
batcher.user.create_many([{'name': 'Robert'}])

assert exc.match(r'create_many\(\) is not supported by sqlite')


def test_ensure_batch_and_action_signatures_are_equal(client: Client) -> None:
# ensure tests will fail if an action method is updated without
# updating the corresponding batch method
Expand Down
8 changes: 8 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,11 @@ async def test_catches_not_connected() -> None:
await client.post.delete_many()

assert 'await client.connect()' in str(exc)


@pytest.mark.asyncio
async def test_create_many_invalid_provider(client: Client) -> None:
with pytest.raises(errors.UnsupportedDatabaseError) as exc:
await client.user.create_many([{'name': 'Robert'}])

assert exc.match(r'create_many\(\) is not supported by sqlite')
Original file line number Diff line number Diff line change
Expand Up @@ -550,16 +550,14 @@ class Arguments(Node):
Key(arg, node=Data.create(self.builder, data=value))
)
elif isinstance(value, (list, tuple, set)):
# NOTE: this is a special case for execute_raw and query_raw
# that expect parameters to be passed as a json string.
#
# this special case only works as there are no other root level arguments
# that take an Iterable, all other Iterable arguments are encapsulated within
# a Data node which handles Iterables differently, passing them to ListNode.
#
# prisma expects a json string value like "[\"John\",\"123\"]"
# we encode twice to ensure that only the inner quotes are escaped
children.append(f'{arg}: {dumps(dumps(value))}')
# NOTE: we have a special case for execute_raw and query_raw
# here as prisma expects parameters to be passed as a json string
# value like "[\"John\",\"123\"]", and we encode twice to ensure
# that only the inner quotes are escaped
if self.builder.method in {'queryRaw', 'executeRaw'}:
children.append(f'{arg}: {dumps(dumps(value))}')
else:
children.append(Key(arg, node=ListNode.create(self.builder, data=value)))
else:
children.append(f'{arg}: {dumps(value)}')

Expand Down
Loading

0 comments on commit cc2872c

Please sign in to comment.