Commit

add ignore_existing_datums argument for bulk uploads (#580)
ekorman committed May 16, 2024
1 parent 1c10888 commit bc569fb
Showing 13 changed files with 268 additions and 129 deletions.
2 changes: 1 addition & 1 deletion api/tests/functional-tests/backend/core/test_annotation.py
@@ -89,7 +89,7 @@ def test_create_annotation_already_exists_error(

     core.create_groundtruths(db, empty_groundtruths)
     core.create_predictions(db, empty_predictions)
-    with pytest.raises(exceptions.DatumAlreadyExistsError):
+    with pytest.raises(exceptions.DatumsAlreadyExistError):
         core.create_groundtruths(db, empty_groundtruths[0:1])
     with pytest.raises(exceptions.AnnotationAlreadyExistsError):
         core.create_predictions(db, empty_predictions[0:1])
60 changes: 60 additions & 0 deletions api/tests/functional-tests/backend/core/test_datum.py
@@ -49,6 +49,66 @@ def test_create_datum(
     assert db.scalar(select(func.count()).select_from(models.Datum)) == 2


+def test_create_datums(
+    db: Session,
+    created_dataset: str,
+):
+    assert db.scalar(select(func.count()).select_from(models.Datum)) == 0
+    dataset = core.fetch_dataset(db=db, name=created_dataset)
+
+    assert (
+        len(
+            core.create_datums(
+                db=db,
+                datums=[
+                    schemas.Datum(uid="uid1"),
+                    schemas.Datum(uid="uid2"),
+                    schemas.Datum(uid="uid3"),
+                ],
+                datasets=[dataset] * 3,
+                ignore_existing_datums=True,
+            )
+        )
+        == 3
+    )
+
+    assert db.scalar(select(func.count()).select_from(models.Datum)) == 3
+
+    assert (
+        len(
+            core.create_datums(
+                db=db,
+                datums=[
+                    schemas.Datum(uid="uid1"),
+                    schemas.Datum(uid="uid4"),
+                    schemas.Datum(uid="uid3"),
+                ],
+                datasets=[dataset] * 3,
+                ignore_existing_datums=True,
+            )
+        )
+        == 1  # only one new datum was created (uid4)
+    )
+
+    assert db.scalar(select(func.count()).select_from(models.Datum)) == 4
+
+    with pytest.raises(exceptions.DatumsAlreadyExistError) as exc_info:
+        core.create_datums(
+            db=db,
+            datums=[
+                schemas.Datum(uid="uid2"),
+                schemas.Datum(uid="uid3"),
+                schemas.Datum(uid="uid7"),
+            ],
+            datasets=[dataset] * 3,
+            ignore_existing_datums=False,
+        )
+    assert "Datums with uids" in str(exc_info.value)
+    assert "uid2" in str(exc_info.value)
+    assert "uid3" in str(exc_info.value)
+    assert "uid7" not in str(exc_info.value)
+
+
 def test_get_paginated_datums(
     db: Session,
     created_dataset: str,
63 changes: 56 additions & 7 deletions api/valor_api/backend/core/datum.py
@@ -1,4 +1,5 @@
 from sqlalchemy import and_, desc, func
+from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.orm import Session

@@ -8,7 +9,10 @@


 def create_datums(
-    db: Session, datums: list[schemas.Datum], datasets: list[models.Dataset]
+    db: Session,
+    datums: list[schemas.Datum],
+    datasets: list[models.Dataset],
+    ignore_existing_datums: bool,
 ) -> list[models.Datum]:
     """Creates datums in bulk
@@ -20,6 +24,10 @@ def create_datums(
         The datums to add to the database.
     datasets
         The datasets to link to the datums. This list should be the same length as the datums list.
+    ignore_existing_datums
+        If True, will ignore datums that already exist in the database.
+        If False, will raise an error if any datums already exist.
+        Default is False.
     Returns
     -------
@@ -36,14 +44,55 @@ def create_datums(
     ]

     try:
-        db.add_all(rows)
+        if not ignore_existing_datums:
+            db.add_all(rows)
+            db.commit()
+            return rows
+
+        values = [
+            {
+                "uid": row.uid,
+                "dataset_id": row.dataset_id,
+                "meta": row.meta,
+            }
+            for row in rows
+        ]
+        insert_stmt = (
+            insert(models.Datum)
+            .values(values)
+            .on_conflict_do_nothing(index_elements=["dataset_id", "uid"])
+            .returning(models.Datum.id, models.Datum.uid)
+        )
+
+        ids_uids = db.execute(insert_stmt)
         db.commit()
-    except IntegrityError:
+
+        uid_to_id = {uid: id_ for id_, uid in ids_uids}
+        new_rows = []
+        for row in rows:
+            if row.uid in uid_to_id:
+                row.id = uid_to_id[row.uid]
+                new_rows.append(row)
+        return new_rows
+
+    except IntegrityError as e:
         db.rollback()
-        # TODO: fix this exception
-        raise exceptions.DatumAlreadyExistsError("")
-
-    return rows
+        if (
+            "duplicate key value violates unique constraint" not in str(e)
+            or ignore_existing_datums
+        ):
+            raise e
+
+        # get existing datums
+        existing_datums: list[models.Datum] = []
+        for datum, dataset in zip(datums, datasets):
+            try:
+                existing_datums.append(fetch_datum(db, dataset.id, datum.uid))
+            except exceptions.DatumDoesNotExistError:
+                pass
+
+        raise exceptions.DatumsAlreadyExistError(
+            [datum.uid for datum in existing_datums]
+        )


 def create_datum(
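
The bulk path in create_datums above relies on PostgreSQL's INSERT ... ON CONFLICT DO NOTHING combined with RETURNING: rows that collide on the (dataset_id, uid) unique constraint are silently skipped, and only the rows that were actually written come back, which is how uid_to_id ends up holding just the new datums. A minimal, self-contained sketch of that pattern (using a toy Item table and a placeholder connection string, not valor's models.Datum) could look like this:

# Standalone sketch of the ON CONFLICT DO NOTHING + RETURNING pattern used in
# create_datums above; the Item table and the DSN are illustrative only.
from sqlalchemy import Column, Integer, String, UniqueConstraint, create_engine
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Item(Base):
    __tablename__ = "item"
    id = Column(Integer, primary_key=True)
    dataset_id = Column(Integer, nullable=False)
    uid = Column(String, nullable=False)
    __table_args__ = (UniqueConstraint("dataset_id", "uid"),)


engine = create_engine("postgresql://user:password@localhost/demo")  # placeholder DSN
Base.metadata.create_all(engine)

with Session(engine) as db:
    values = [{"dataset_id": 1, "uid": "uid1"}, {"dataset_id": 1, "uid": "uid2"}]
    stmt = (
        insert(Item)  # postgresql-dialect insert, so on_conflict_do_nothing is available
        .values(values)
        .on_conflict_do_nothing(index_elements=["dataset_id", "uid"])
        .returning(Item.id, Item.uid)
    )
    inserted = {uid: id_ for id_, uid in db.execute(stmt)}  # only non-conflicting rows
    db.commit()
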
21 changes: 20 additions & 1 deletion api/valor_api/backend/core/groundtruth.py
@@ -6,7 +6,11 @@
 from valor_api.backend import core, models


-def create_groundtruths(db: Session, groundtruths: list[schemas.GroundTruth]):
+def create_groundtruths(
+    db: Session,
+    groundtruths: list[schemas.GroundTruth],
+    ignore_existing_datums: bool = False,
+):
     """Create ground truths in bulk.
     Parameters
@@ -15,6 +19,10 @@ def create_groundtruths(db: Session, groundtruths: list[schemas.GroundTruth]):
         The database Session to query against.
     groundtruths
         The ground truths to create.
+    ignore_existing_datums
+        If True, will ignore datums that already exist in the database.
+        If False, will raise an error if any datums already exist.
+        Default is False.
     Returns
     -------
@@ -39,7 +47,18 @@ def create_groundtruths(db: Session, groundtruths: list[schemas.GroundTruth]):
             dataset_name_to_dataset[groundtruth.dataset_name]
             for groundtruth in groundtruths
         ],
+        ignore_existing_datums=ignore_existing_datums,
     )

+    if ignore_existing_datums:
+        # datums only contains the newly created ones, so we need to filter out
+        # the ones that already existed
+        groundtruths = [
+            gt
+            for gt in groundtruths
+            if gt.datum.uid in [datum.uid for datum in datums]
+        ]
+
     all_labels = [
         label
         for groundtruth in groundtruths
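
A side note on the filtering step added above: the membership test gt.datum.uid in [datum.uid for datum in datums] rebuilds the uid list and scans it for every ground truth, which is quadratic in the number of datums. A hypothetical, behaviour-preserving variant that precomputes a set (shown here with stand-in dataclasses rather than valor's schema types):

# Hypothetical illustration of the filtering in create_groundtruths, using
# stand-in FakeDatum/FakeGroundTruth types; a set makes each uid check O(1).
from dataclasses import dataclass


@dataclass
class FakeDatum:
    uid: str


@dataclass
class FakeGroundTruth:
    datum: FakeDatum


created = [FakeDatum("uid1"), FakeDatum("uid4")]  # stands in for the rows create_datums returned
groundtruths = [FakeGroundTruth(FakeDatum(u)) for u in ("uid1", "uid3", "uid4")]

created_uids = {d.uid for d in created}
groundtruths = [gt for gt in groundtruths if gt.datum.uid in created_uids]
assert [gt.datum.uid for gt in groundtruths] == ["uid1", "uid4"]
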
7 changes: 6 additions & 1 deletion api/valor_api/crud/_create.py
@@ -54,6 +54,7 @@ def create_groundtruths(
     *,
     db: Session,
     groundtruths: list[schemas.GroundTruth],
+    ignore_existing_datums: bool = False,
 ):
     """
     Creates a ground truth.
@@ -65,7 +66,11 @@ def create_groundtruths(
     groundtruth: schemas.GroundTruth
         The ground truth to create.
     """
-    backend.create_groundtruths(db, groundtruths=groundtruths)
+    backend.create_groundtruths(
+        db,
+        groundtruths=groundtruths,
+        ignore_existing_datums=ignore_existing_datums,
+    )


 def create_predictions(
14 changes: 14 additions & 0 deletions api/valor_api/exceptions.py
@@ -232,6 +232,20 @@ def __init__(self, uid: str):
         super().__init__(f"Datum with uid: `{uid}` already exists.")


+class DatumsAlreadyExistError(Exception):
+    """
+    Raises an exception if the user tries to create a datum that already exists.
+    Parameters
+    -------
+    uids
+        The UIDs of the datums.
+    """
+
+    def __init__(self, uids: list[str]):
+        super().__init__(f"Datums with uids: `{uids}` already exist.")
+
+
 """ Annotation """


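
The new exception interpolates the whole uid list into its message, which is what test_create_datums above checks by asserting that "Datums with uids" appears in str(exc_info.value). A standalone copy of the class (the real one lives in api/valor_api/exceptions.py, as in the diff above) shows the exact text it produces:

# Standalone copy of the new exception, only to show its message format; the
# real class is the one added to valor_api/exceptions.py in this commit.
class DatumsAlreadyExistError(Exception):
    def __init__(self, uids: list[str]):
        super().__init__(f"Datums with uids: `{uids}` already exist.")


print(str(DatumsAlreadyExistError(["uid2", "uid3"])))
# -> Datums with uids: `['uid2', 'uid3']` already exist.
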
12 changes: 10 additions & 2 deletions api/valor_api/main.py
@@ -63,7 +63,9 @@ def get_db():
     tags=["GroundTruths"],
 )
 def create_groundtruths(
-    groundtruths: list[schemas.GroundTruth], db: Session = Depends(get_db)
+    groundtruths: list[schemas.GroundTruth],
+    ignore_existing_datums: bool = False,
+    db: Session = Depends(get_db),
 ):
     """
     Create a ground truth in the database.
@@ -76,6 +78,8 @@ def create_groundtruths(
         The ground truths to add to the database.
     db : Session
         The database session to use. This parameter is a sqlalchemy dependency and shouldn't be submitted by the user.
+    ignore_existing_datums : bool, optional
+        If True, will ignore datums that already exist in the database.
     Raises
     ------
@@ -85,7 +89,11 @@ def create_groundtruths(
         If the dataset has been finalized, or if the datum already exists.
     """
     try:
-        crud.create_groundtruths(db=db, groundtruths=groundtruths)
+        crud.create_groundtruths(
+            db=db,
+            groundtruths=groundtruths,
+            ignore_existing_datums=ignore_existing_datums,
+        )
     except Exception as e:
         raise exceptions.create_http_error(e)

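
Because ignore_existing_datums is declared as a plain bool on the route handler rather than as part of the request body, FastAPI exposes it as a query parameter. Below is a hypothetical client-side sketch; the host, port, and /groundtruths path are assumptions (the route decorator's path is not visible in this hunk), and the empty list stands in for serialized GroundTruth objects:

# Hypothetical request against the endpoint above; URL and payload are placeholders.
import requests

response = requests.post(
    "http://localhost:8000/groundtruths",  # assumed path
    params={"ignore_existing_datums": "true"},  # parsed into the bool query parameter
    json=[],  # would be a list of GroundTruth objects serialized to JSON
)
response.raise_for_status()
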
24 changes: 1 addition & 23 deletions client/unit-tests/test_client.py
@@ -2,36 +2,14 @@

 import pytest

-from valor.client import _chunk_list, connect, get_connection, reset_connection
+from valor.client import connect, get_connection, reset_connection
 from valor.exceptions import (
     ClientAlreadyConnectedError,
     ClientConnectionFailed,
     ClientNotConnectedError,
 )


-def test__chunk_list():
-    # edge case
-    data = [{"key": "value"}]
-    results = _chunk_list(json_list=data, chunk_size_bytes=10)
-
-    assert results == [data]
-
-    # 100 elements with an average element size of 23.8
-    data = [{f"key_{i}": f"value_{i}"} for i in range(100)]
-
-    # standard case
-    results = _chunk_list(json_list=data, chunk_size_bytes=1000)
-    assert (
-        len(results) == 4
-    )  # recursively chunked once, which added an extra split
-    assert [len(x) for x in results] == [42, 41, 1, 16]
-
-    # edge case with small chunk size
-    results = _chunk_list(json_list=data, chunk_size_bytes=1)
-    assert results == [[x] for x in data]
-
-
 @patch("valor.client.ClientConnection")
 def test_connect(ClientConnection):
     connect(host="host")