Skip to content

Commit

Permalink
feat: add count_submitted_responses as property to Records databa…
Browse files Browse the repository at this point in the history
…se model (#5118)

# Description

This PR include the following changes:
* Added `count_submitted_responses` as a property of `Record` database
model.
  * This property requires record responses to be pre-loaded. 
  * This value is get from the database using a subquery.
* Added `count_submitted_responses` to search engine mapping.
* Record `status` is exposed by API schemas and is calculated based in
`count_submitted_responses` column property from `Record` database
model.
* This `status` is defined as a property inside `Record` database model
and it's using the `dataset` distribution strategy to calculate the
value.

## Missing changes in this PR
- [ ] Make test suite to pass after changes.
- [ ] Add support to `status` value in search endpoints so we can filter
by `status=pending&response_status=pending`.
- [ ] Check that we are refreshing the record
`count_submitted_responses` values before indexing the record and add a
partial update into the search engine when some associated entity (like
responses) are create/updated/deleted for a record. (We probably should
add a partial update of the index for this record attribute).
- [ ] Change dataset progress metrics.
- [ ] Change user metrics.

Refs #5069

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Refactor (change restructuring the codebase without changing
functionality)
- [ ] Improvement (change adding some improvement to an existing
functionality)
- [ ] Documentation update

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I added relevant documentation
- [ ] follows the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: Paco Aranda <francis@argilla.io>
  • Loading branch information
jfcalvo and frascuchon committed Jul 1, 2024
1 parent c8aa1a9 commit 58c8257
Show file tree
Hide file tree
Showing 26 changed files with 769 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,8 @@ export class RecordRepository {
constructor(private readonly axios: NuxtAxiosInstance) {}

getRecords(criteria: RecordCriteria): Promise<BackendRecords> {
if (criteria.isFilteringByAdvanceSearch)
return this.getRecordsByAdvanceSearch(criteria);

return this.getRecordsByDatasetId(criteria);
return this.getRecordsByAdvanceSearch(criteria);
// return this.getRecordsByDatasetId(criteria);
}

async getRecord(recordId: string): Promise<BackendRecord> {
Expand Down Expand Up @@ -264,6 +262,30 @@ export class RecordRepository {
};
}

body.filters = {
and: [
{
type: "terms",
scope: {
entity: "response",
property: "status",
},
values: [status],
},
],
};

if (status === "pending") {
body.filters.and.push({
type: "terms",
scope: {
entity: "record",
property: "status",
},
values: ["pending"],
});
}

if (
isFilteringByMetadata ||
isFilteringByResponse ||
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""add status column to records table
Revision ID: 237f7c674d74
Revises: 45a12f74448b
Create Date: 2024-06-18 17:59:36.992165
"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "237f7c674d74"
down_revision = "45a12f74448b"
branch_labels = None
depends_on = None


record_status_enum = sa.Enum("pending", "completed", name="record_status_enum")


def upgrade() -> None:
record_status_enum.create(op.get_bind())

op.add_column("records", sa.Column("status", record_status_enum, server_default="pending", nullable=False))
op.create_index(op.f("ix_records_status"), "records", ["status"], unique=False)

# NOTE: Updating existent records to have "completed" status when they have
# at least one response with "submitted" status.
op.execute("""
UPDATE records
SET status = 'completed'
WHERE id IN (
SELECT DISTINCT record_id
FROM responses
WHERE status = 'submitted'
);
""")


def downgrade() -> None:
op.drop_index(op.f("ix_records_status"), table_name="records")
op.drop_column("records", "status")

record_status_enum.drop(op.get_bind())
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ async def update_response(
response = await Response.get_or_raise(
db,
response_id,
options=[selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions)],
options=[
selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions),
],
)

await authorize(current_user, ResponsePolicy.update(response))
Expand All @@ -83,7 +85,9 @@ async def delete_response(
response = await Response.get_or_raise(
db,
response_id,
options=[selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions)],
options=[
selectinload(Response.record).selectinload(Record.dataset).selectinload(Dataset.questions),
],
)

await authorize(current_user, ResponsePolicy.delete(response))
Expand Down
5 changes: 3 additions & 2 deletions argilla-server/src/argilla_server/api/schemas/v1/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from argilla_server.api.schemas.v1.metadata_properties import MetadataPropertyName
from argilla_server.api.schemas.v1.responses import Response, ResponseFilterScope, UserResponseCreate
from argilla_server.api.schemas.v1.suggestions import Suggestion, SuggestionCreate, SuggestionFilterScope
from argilla_server.enums import RecordInclude, RecordSortField, SimilarityOrder, SortOrder
from argilla_server.enums import RecordInclude, RecordSortField, SimilarityOrder, SortOrder, RecordStatus
from argilla_server.pydantic_v1 import BaseModel, Field, StrictStr, root_validator, validator
from argilla_server.pydantic_v1.utils import GetterDict
from argilla_server.search_engine import TextQuery
Expand Down Expand Up @@ -66,6 +66,7 @@ def get(self, key: str, default: Any) -> Any:

class Record(BaseModel):
id: UUID
status: RecordStatus
fields: Dict[str, Any]
metadata: Optional[Dict[str, Any]]
external_id: Optional[str]
Expand Down Expand Up @@ -196,7 +197,7 @@ def _has_relationships(self):

class RecordFilterScope(BaseModel):
entity: Literal["record"]
property: Union[Literal[RecordSortField.inserted_at], Literal[RecordSortField.updated_at]]
property: Union[Literal[RecordSortField.inserted_at], Literal[RecordSortField.updated_at], Literal["status"]]


class Records(BaseModel):
Expand Down
4 changes: 4 additions & 0 deletions argilla-server/src/argilla_server/bulk/records_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from argilla_server.api.schemas.v1.responses import UserResponseCreate
from argilla_server.api.schemas.v1.suggestions import SuggestionCreate
from argilla_server.contexts import distribution
from argilla_server.contexts.accounts import fetch_users_by_ids_as_dict
from argilla_server.contexts.records import (
fetch_records_by_external_ids_as_dict,
Expand Down Expand Up @@ -67,6 +68,7 @@ async def create_records_bulk(self, dataset: Dataset, bulk_create: RecordsBulkCr

await self._upsert_records_relationships(records, bulk_create.items)
await _preload_records_relationships_before_index(self._db, records)
await distribution.update_records_status(self._db, records)
await self._search_engine.index_records(dataset, records)

await self._db.commit()
Expand Down Expand Up @@ -207,6 +209,7 @@ async def upsert_records_bulk(self, dataset: Dataset, bulk_upsert: RecordsBulkUp

await self._upsert_records_relationships(records, bulk_upsert.items)
await _preload_records_relationships_before_index(self._db, records)
await distribution.update_records_status(self._db, records)
await self._search_engine.index_records(dataset, records)

await self._db.commit()
Expand Down Expand Up @@ -237,6 +240,7 @@ async def _preload_records_relationships_before_index(db: "AsyncSession", record
.filter(Record.id.in_([record.id for record in records]))
.options(
selectinload(Record.responses).selectinload(Response.user),
selectinload(Record.responses_submitted),
selectinload(Record.suggestions).selectinload(Suggestion.question),
selectinload(Record.vectors),
)
Expand Down
15 changes: 14 additions & 1 deletion argilla-server/src/argilla_server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
VectorSettingsCreate,
)
from argilla_server.api.schemas.v1.vectors import Vector as VectorSchema
from argilla_server.contexts import accounts
from argilla_server.contexts import accounts, distribution
from argilla_server.enums import DatasetStatus, RecordInclude, UserRole
from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError
from argilla_server.models import (
Expand Down Expand Up @@ -940,6 +940,9 @@ async def create_response(
await db.flush([response])
await _touch_dataset_last_activity_at(db, record.dataset)
await search_engine.update_record_response(response)
await db.refresh(record, attribute_names=[Record.responses_submitted.key])
await distribution.update_record_status(db, record)
await search_engine.partial_record_update(record, status=record.status)

await db.commit()

Expand All @@ -963,6 +966,9 @@ async def update_response(
await _load_users_from_responses(response)
await _touch_dataset_last_activity_at(db, response.record.dataset)
await search_engine.update_record_response(response)
await db.refresh(response.record, attribute_names=[Record.responses_submitted.key])
await distribution.update_record_status(db, response.record)
await search_engine.partial_record_update(response.record, status=response.record.status)

await db.commit()

Expand Down Expand Up @@ -992,6 +998,9 @@ async def upsert_response(
await _load_users_from_responses(response)
await _touch_dataset_last_activity_at(db, response.record.dataset)
await search_engine.update_record_response(response)
await db.refresh(record, attribute_names=[Record.responses_submitted.key])
await distribution.update_record_status(db, record)
await search_engine.partial_record_update(record, status=record.status)

await db.commit()

Expand All @@ -1001,9 +1010,13 @@ async def upsert_response(
async def delete_response(db: AsyncSession, search_engine: SearchEngine, response: Response) -> Response:
async with db.begin_nested():
response = await response.delete(db, autocommit=False)

await _load_users_from_responses(response)
await _touch_dataset_last_activity_at(db, response.record.dataset)
await search_engine.delete_record_response(response)
await db.refresh(response.record, attribute_names=[Record.responses_submitted.key])
await distribution.update_record_status(db, response.record)
await search_engine.partial_record_update(record=response.record, status=response.record.status)

await db.commit()

Expand Down
42 changes: 42 additions & 0 deletions argilla-server/src/argilla_server/contexts/distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from sqlalchemy.ext.asyncio import AsyncSession

from argilla_server.enums import DatasetDistributionStrategy, RecordStatus
from argilla_server.models import Record


# TODO: Do this with one single update statement for all records if possible to avoid too many queries.
async def update_records_status(db: AsyncSession, records: List[Record]):
for record in records:
await update_record_status(db, record)


async def update_record_status(db: AsyncSession, record: Record) -> Record:
if record.dataset.distribution_strategy == DatasetDistributionStrategy.overlap:
return await _update_record_status_with_overlap_strategy(db, record)

raise NotImplementedError(f"unsupported distribution strategy `{record.dataset.distribution_strategy}`")


async def _update_record_status_with_overlap_strategy(db: AsyncSession, record: Record) -> Record:
if len(record.responses_submitted) >= record.dataset.distribution["min_submitted"]:
record.status = RecordStatus.completed
else:
record.status = RecordStatus.pending

return await record.save(db, autocommit=False)
5 changes: 5 additions & 0 deletions argilla-server/src/argilla_server/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ class UserRole(str, Enum):
annotator = "annotator"


class RecordStatus(str, Enum):
pending = "pending"
completed = "completed"


class RecordInclude(str, Enum):
responses = "responses"
suggestions = "suggestions"
Expand Down
30 changes: 25 additions & 5 deletions argilla-server/src/argilla_server/models/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@
DatasetStatus,
MetadataPropertyType,
QuestionType,
RecordStatus,
ResponseStatus,
SuggestionType,
UserRole,
DatasetDistributionStrategy,
RecordStatus,
)
from argilla_server.models.base import DatabaseModel
from argilla_server.models.metadata_properties import MetadataPropertySettings
Expand Down Expand Up @@ -180,11 +183,17 @@ def __repr__(self) -> str:
)


RecordStatusEnum = SAEnum(RecordStatus, name="record_status_enum")


class Record(DatabaseModel):
__tablename__ = "records"

fields: Mapped[dict] = mapped_column(JSON, default={})
metadata_: Mapped[Optional[dict]] = mapped_column("metadata", MutableDict.as_mutable(JSON), nullable=True)
status: Mapped[RecordStatus] = mapped_column(
RecordStatusEnum, default=RecordStatus.pending, server_default=RecordStatus.pending, index=True
)
external_id: Mapped[Optional[str]] = mapped_column(index=True)
dataset_id: Mapped[UUID] = mapped_column(ForeignKey("datasets.id", ondelete="CASCADE"), index=True)

Expand All @@ -195,6 +204,13 @@ class Record(DatabaseModel):
passive_deletes=True,
order_by=Response.inserted_at.asc(),
)
responses_submitted: Mapped[List["Response"]] = relationship(
back_populates="record",
cascade="all, delete-orphan",
passive_deletes=True,
primaryjoin=f"and_(Record.id==Response.record_id, Response.status=='{ResponseStatus.submitted}')",
order_by=Response.inserted_at.asc(),
)
suggestions: Mapped[List["Suggestion"]] = relationship(
back_populates="record",
cascade="all, delete-orphan",
Expand All @@ -210,17 +226,17 @@ class Record(DatabaseModel):

__table_args__ = (UniqueConstraint("external_id", "dataset_id", name="record_external_id_dataset_id_uq"),)

def vector_value_by_vector_settings(self, vector_settings: "VectorSettings") -> Union[List[float], None]:
for vector in self.vectors:
if vector.vector_settings_id == vector_settings.id:
return vector.value

def __repr__(self):
return (
f"Record(id={str(self.id)!r}, external_id={self.external_id!r}, dataset_id={str(self.dataset_id)!r}, "
f"inserted_at={str(self.inserted_at)!r}, updated_at={str(self.updated_at)!r})"
)

def vector_value_by_vector_settings(self, vector_settings: "VectorSettings") -> Union[List[float], None]:
for vector in self.vectors:
if vector.vector_settings_id == vector_settings.id:
return vector.value


class Question(DatabaseModel):
__tablename__ = "questions"
Expand Down Expand Up @@ -354,6 +370,10 @@ def is_draft(self):
def is_ready(self):
return self.status == DatasetStatus.ready

@property
def distribution_strategy(self) -> DatasetDistributionStrategy:
return DatasetDistributionStrategy(self.distribution["strategy"])

def metadata_property_by_name(self, name: str) -> Union["MetadataProperty", None]:
for metadata_property in self.metadata_properties:
if metadata_property.name == name:
Expand Down
4 changes: 4 additions & 0 deletions argilla-server/src/argilla_server/search_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,10 @@ async def configure_metadata_property(self, dataset: Dataset, metadata_property:
async def index_records(self, dataset: Dataset, records: Iterable[Record]):
pass

@abstractmethod
async def partial_record_update(self, record: Record, **update):
pass

@abstractmethod
async def delete_records(self, dataset: Dataset, records: Iterable[Record]):
pass
Expand Down
Loading

0 comments on commit 58c8257

Please sign in to comment.