25 changes: 25 additions & 0 deletions tests/qdrant_syncronizer/test_qdrant_handler.py
@@ -134,3 +134,28 @@ def test_should_handle_multiple_slices_for_same_collection_with_multi_lingual_co
"collection_welearn_mul_mulembmodel": {doc_id1},
}
self.assertDictEqual(dict(collections_names), expected)

    def test_should_handle_multiple_slices_for_same_collection_with_multi_lingual_collection_and_gibberish(
        self,
    ):
        self.client.create_collection(
            collection_name="collection_welearn_mul_mulembmodel_og",
            vectors_config=models.VectorParams(
                size=50, distance=models.Distance.COSINE
            ),
        )

        doc_id0 = uuid.uuid4()
        doc_id1 = uuid.uuid4()
        qdrant_connector = self.client
        fake_slice0 = FakeSlice(doc_id0, embedding_model_name="english-embmodel")
        fake_slice1 = FakeSlice(doc_id0, embedding_model_name="english-embmodel")

        fake_slice1.order_sequence = 1

        fake_slice2 = FakeSlice(doc_id1, embedding_model_name="mulembmodel")
        fake_slice2.document.lang = "pt"

        slices = [fake_slice0, fake_slice1, fake_slice2]
        collections_names = classify_documents_per_collection(qdrant_connector, slices)
        self.assertNotIn("collection_welearn_mul_mulembmodel_og", collections_names)
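(The "_og"-suffixed collection is the gibberish case: the new classifier derives exact names of the form collection_welearn_{lang}_{model} or collection_welearn_mul_{model}, which can never match it, whereas the old parts[3]-based parsing would presumably have mapped "mulembmodel" to that collection as well.)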
44 changes: 20 additions & 24 deletions welearn_datastack/modules/qdrant_handler.py
@@ -1,5 +1,4 @@
import logging
from collections import defaultdict
from typing import Collection, Dict, List, Set, Type
from uuid import UUID

@@ -9,9 +8,7 @@
from qdrant_client.http.models import models

from welearn_datastack.data.db_models import DocumentSlice
from welearn_datastack.exceptions import (
    ErrorWhileDeletingChunks,
)
from welearn_datastack.exceptions import ErrorWhileDeletingChunks

logger = logging.getLogger(__name__)

@@ -31,30 +28,30 @@ def classify_documents_per_collection(
"""
tmp_collections_names_in_qdrant = qdrant_connector.get_collections().collections
collections_names_in_qdrant = [c.name for c in tmp_collections_names_in_qdrant]
model_name_collection_name = {}
for x in collections_names_in_qdrant:
parts = x.split("_")
if len(parts) >= 4:
model_name_collection_name[parts[3]] = x
else:
logger.warning(
"Collection name '%s' does not follow the expected format", x
)

ret: Dict[str, Set[UUID]] = defaultdict(set)
ret: Dict[str, Set[UUID]] = {}
for dslice in slices:
model_name = dslice.embedding_model.title
try:
collection_name = model_name_collection_name[model_name]
ret[collection_name].add(dslice.document_id) # type: ignore
except KeyError:
logger.warning(
"No collection found for model %s, document %s",
model_name,
dslice.document_id,
        lang = dslice.document.lang
        model = dslice.embedding_model.title
        collection_name = None
        multilingual_collection = f"collection_welearn_mul_{model}"
        mono_collection = f"collection_welearn_{lang}_{model}"

        # Prefer the multilingual collection; fall back to the monolingual one
        if multilingual_collection in collections_names_in_qdrant:
            collection_name = multilingual_collection
        elif mono_collection in collections_names_in_qdrant:
            collection_name = mono_collection
        else:
            logger.error(
                f"Neither {multilingual_collection} nor {mono_collection} "
                f"found in Qdrant, slice {dslice.id} ignored"
            )
            continue

        if collection_name not in ret:
            ret[collection_name] = set()
        ret[collection_name].add(dslice.document_id)  # type: ignore

    return ret
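For orientation, a minimal standalone sketch of the resolution order introduced above. The helper name and parameters are illustrative stand-ins for dslice.document.lang, dslice.embedding_model.title, and the collection names fetched from Qdrant; it is not part of the module:

    from typing import Optional, Set

    # Sketch of the naming scheme used above: multilingual first, then monolingual.
    def resolve_collection(lang: str, model: str, existing: Set[str]) -> Optional[str]:
        multilingual = f"collection_welearn_mul_{model}"
        mono = f"collection_welearn_{lang}_{model}"
        if multilingual in existing:
            return multilingual  # the multilingual collection wins when both exist
        if mono in existing:
            return mono
        return None  # caller logs an error and skips the slice

    # A Portuguese slice embedded with "mulembmodel" resolves to the multilingual collection:
    assert (
        resolve_collection("pt", "mulembmodel", {"collection_welearn_mul_mulembmodel"})
        == "collection_welearn_mul_mulembmodel"
    )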


@@ -73,7 +70,6 @@ def delete_points_related_to_document(
"""
logger.info("Deletion started")
logger.debug(f"Deleting points related to {documents_ids} in {collection_name}")
op_res = None

try:
op_res = qdrant_connector.delete(
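The call is truncated in the diff; for reference, a deletion of this shape with qdrant-client is typically driven by a payload filter. A sketch under the assumption that points carry the owning document's id in a "document_id" payload field (the actual key and selector are not visible here):

        op_res = qdrant_connector.delete(
            collection_name=collection_name,
            points_selector=models.FilterSelector(
                filter=models.Filter(
                    must=[
                        models.FieldCondition(
                            key="document_id",  # assumed payload key, not shown in the diff
                            match=models.MatchAny(any=[str(d) for d in documents_ids]),
                        )
                    ]
                )
            ),
        )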
@@ -114,6 +114,7 @@ def main() -> None:

    # Iterate on each collection
    for collection_name in documents_per_collection:
        logger.info(f"Working on collection: {collection_name}")
        # Delete all points related to these documents first, to avoid duplicates
        del_res = delete_points_related_to_document(
            collection_name=collection_name,