Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion tests/qdrant_syncronizer/test_qdrant_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,19 @@ def setUp(self):
def tearDown(self):
self.client.close()

def test_slice_without_embedding_model_should_go_to_none_collection(self):
doc_id = uuid.uuid4()
qdrant_connector = self.client
fake_slice = FakeSlice(doc_id)
fake_slice.embedding_model = None
slices = [fake_slice]
collections_names = classify_documents_per_collection(qdrant_connector, slices)

expected = {
None: {fake_slice.document_id},
}
self.assertEqual(dict(collections_names), expected)

def test_should_get_collections_names_for_given_slices(self):
doc_id = uuid.uuid4()
qdrant_connector = self.client
Expand All @@ -83,7 +96,10 @@ def test_should_get_collections_names_for_given_slices(self):
slices = [fake_slice]
collections_names = classify_documents_per_collection(qdrant_connector, slices)

expected = {"collection_welearn_en_english-embmodel": {fake_slice.document_id}}
expected = {
None: set(),
"collection_welearn_en_english-embmodel": {fake_slice.document_id},
}
self.assertEqual(dict(collections_names), expected)

def test_should_handle_multiple_slices_for_same_collection(self):
Expand All @@ -101,6 +117,7 @@ def test_should_handle_multiple_slices_for_same_collection(self):
slices = [fake_slice0, fake_slice1, fake_slice2]
collections_names = classify_documents_per_collection(qdrant_connector, slices)
expected = {
None: set(),
"collection_welearn_en_english-embmodel": {doc_id0},
"collection_welearn_fr_french-embmodel": {doc_id1},
}
Expand Down Expand Up @@ -130,6 +147,7 @@ def test_should_handle_multiple_slices_for_same_collection_with_multi_lingual_co
slices = [fake_slice0, fake_slice1, fake_slice2]
collections_names = classify_documents_per_collection(qdrant_connector, slices)
expected = {
None: set(),
"collection_welearn_en_english-embmodel": {doc_id0},
"collection_welearn_mul_mulembmodel": {doc_id1},
}
Expand Down
14 changes: 11 additions & 3 deletions welearn_datastack/modules/qdrant_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

def classify_documents_per_collection(
qdrant_connector: QdrantClient, slices: Collection[Type[DocumentSlice]]
) -> Dict[str, Set[UUID]]:
) -> Dict[str | None, Set[UUID]]:
"""
Classify documents per collection in Qdrant.

Expand All @@ -29,10 +29,18 @@ def classify_documents_per_collection(
tmp_collections_names_in_qdrant = qdrant_connector.get_collections().collections
collections_names_in_qdrant = [c.name for c in tmp_collections_names_in_qdrant]

ret: Dict[str, Set[UUID]] = {}
ret: Dict[str | None, Set[UUID]] = {None: set()}
for dslice in slices:
lang = dslice.document.lang
model = dslice.embedding_model.title
try:
model = dslice.embedding_model.title
except AttributeError:
logger.error(
f"Slice {dslice.id} has no updated embedding model, document ({dslice.document_id}) put in error",
)
ret[None].add(dslice.document_id) # type: ignore
continue

collection_name = None
multilingual_collection = f"collection_welearn_mul_{model}"
mono_collection = f"collection_welearn_{lang}_{model}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,21 @@ def main() -> None:
qdrant_connector=qdrant_client, slices=slices
)

# Flag documents with no collection
logger.info(
"Flag documents with no collection: %s", len(documents_per_collection[None])
)
for docid in documents_per_collection[None]:
db_session.add(
ProcessState(
id=uuid.uuid4(),
document_id=docid,
title=Step.KEPT_FOR_TRACE.value,
)
)
del documents_per_collection[None]
db_session.commit()

# Iterate on each collection
for collection_name in documents_per_collection:
logger.info(f"We are working on collection : {collection_name}")
Expand Down
Loading