diff --git a/tests/qdrant_syncronizer/test_qdrant_handler.py b/tests/qdrant_syncronizer/test_qdrant_handler.py index 788844b..2a49d87 100644 --- a/tests/qdrant_syncronizer/test_qdrant_handler.py +++ b/tests/qdrant_syncronizer/test_qdrant_handler.py @@ -43,12 +43,13 @@ def __init__(self, document_id, corpus_id): class FakeSlice: - def __init__(self, document_id): + def __init__(self, document_id, embedding_model_name="embmodel"): self.document_id = document_id - self.embedding_model_name = "embmodel" + self.embedding_model_name = embedding_model_name self.embedding = numpy.random.uniform(low=-1, high=1, size=(50,)) self.order_sequence = 0 self.document = FakeDocument(document_id, uuid.uuid4()) + self.id = uuid.uuid4() class TestQdrantHandler(unittest.TestCase): @@ -102,3 +103,32 @@ def test_should_handle_multiple_slices_for_same_collection(self): "collection_welearn_fr_embmodel": {doc_id1}, } self.assertEqual(collections_names, expected) + + def test_should_handle_multiple_slices_for_same_collection_with_multi_lingual_collection( + self, + ): + self.client.create_collection( + collection_name="collection_welearn_mul_mulembmodel", + vectors_config=models.VectorParams( + size=50, distance=models.Distance.COSINE + ), + ) + + doc_id0 = uuid.uuid4() + doc_id1 = uuid.uuid4() + qdrant_connector = self.client + fake_slice0 = FakeSlice(doc_id0, embedding_model_name="embmodel") + fake_slice1 = FakeSlice(doc_id0, embedding_model_name="embmodel") + + fake_slice1.order_sequence = 1 + + fake_slice2 = FakeSlice(doc_id1, embedding_model_name="mulembmodel") + fake_slice2.document.lang = "pt" + + slices = [fake_slice0, fake_slice1, fake_slice2] + collections_names = classify_documents_per_collection(qdrant_connector, slices) + expected = { + "collection_welearn_en_embmodel": {doc_id0}, + "collection_welearn_mul_mulembmodel": {doc_id1}, + } + self.assertDictEqual(collections_names, expected) diff --git a/welearn_datastack/constants.py b/welearn_datastack/constants.py index 184bbf6..cfb64f0 100644 --- a/welearn_datastack/constants.py +++ b/welearn_datastack/constants.py @@ -323,3 +323,5 @@ "P4404668943", "P4404677186", ] + +QDRANT_MULTI_LINGUAL_CODE = "mul" diff --git a/welearn_datastack/modules/qdrant_handler.py b/welearn_datastack/modules/qdrant_handler.py index 16db67b..aaffea2 100644 --- a/welearn_datastack/modules/qdrant_handler.py +++ b/welearn_datastack/modules/qdrant_handler.py @@ -8,6 +8,7 @@ from qdrant_client.grpc import UpdateResult from qdrant_client.http.models import models +from welearn_datastack.constants import QDRANT_MULTI_LINGUAL_CODE from welearn_datastack.data.db_models import DocumentSlice from welearn_datastack.exceptions import ( ErrorWhileDeletingChunks, @@ -38,15 +39,22 @@ def classify_documents_per_collection( for dslice in slices: lang = dslice.document.lang model = dslice.embedding_model_name - collection_name = f"collection_welearn_{lang}_{model.lower()}" + collection_name = ( + f"collection_welearn_{QDRANT_MULTI_LINGUAL_CODE}_{model.lower()}" + ) if collection_name not in collections_names_in_qdrant: logger.error( - "Collection %s not found in Qdrant, slice %s ignored", + "Collection %s not found in Qdrant, attempt with language-specific collection name", collection_name, dslice.id, ) - continue + collection_name = f"collection_welearn_{lang}_{model.lower()}" + if collection_name not in collections_names_in_qdrant: + logger.error( + f"Collection {collection_name} not found in Qdrant, {dslice.id} will be ignored", + ) + continue if collection_name not in ret: ret[collection_name] = set()