Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions tests/qdrant_syncronizer/test_qdrant_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,13 @@ def __init__(self, document_id, corpus_id):


class FakeSlice:
    """Minimal stand-in for a document slice used by the Qdrant handler tests.

    Only the attributes read by the code under test are provided.

    Args:
        document_id: Identifier of the document the slice belongs to.
        embedding_model_name: Name of the embedding model attached to the
            slice. Defaults to ``"embmodel"`` so existing call sites that
            pass only ``document_id`` keep their previous behaviour.
    """

    def __init__(self, document_id, embedding_model_name="embmodel"):
        self.document_id = document_id
        self.embedding_model_name = embedding_model_name
        # Random vector of 50 components drawn uniformly from [-1, 1).
        self.embedding = numpy.random.uniform(low=-1, high=1, size=(50,))
        self.order_sequence = 0
        # Each slice carries its own FakeDocument, built with a fresh UUID
        # as second constructor argument.
        self.document = FakeDocument(document_id, uuid.uuid4())
        self.id = uuid.uuid4()


class TestQdrantHandler(unittest.TestCase):
Expand Down Expand Up @@ -102,3 +103,32 @@ def test_should_handle_multiple_slices_for_same_collection(self):
"collection_welearn_fr_embmodel": {doc_id1},
}
self.assertEqual(collections_names, expected)

def test_should_handle_multiple_slices_for_same_collection_with_multi_lingual_collection(
    self,
):
    """Check classification when a multilingual collection exists for a model.

    A ``collection_welearn_mul_mulembmodel`` collection is created, then
    three slices are classified:

    * two slices of the same document using ``embmodel``, expected under
      ``collection_welearn_en_embmodel`` (presumably relies on the fixture
      document language and on a collection created outside this method);
    * one slice of another document using ``mulembmodel`` with its document
      language set to ``"pt"``, expected under the multilingual collection.
    """
    multilingual_collection = "collection_welearn_mul_mulembmodel"
    self.client.create_collection(
        collection_name=multilingual_collection,
        vectors_config=models.VectorParams(
            size=50,
            distance=models.Distance.COSINE,
        ),
    )

    mono_doc_id = uuid.uuid4()
    multi_doc_id = uuid.uuid4()

    # Two slices of the same document, distinguished by their sequence order.
    first_mono_slice = FakeSlice(mono_doc_id, embedding_model_name="embmodel")
    second_mono_slice = FakeSlice(mono_doc_id, embedding_model_name="embmodel")
    second_mono_slice.order_sequence = 1

    # One slice embedded with the model that has a multilingual collection.
    multi_slice = FakeSlice(multi_doc_id, embedding_model_name="mulembmodel")
    multi_slice.document.lang = "pt"

    classified = classify_documents_per_collection(
        self.client,
        [first_mono_slice, second_mono_slice, multi_slice],
    )

    self.assertDictEqual(
        classified,
        {
            "collection_welearn_en_embmodel": {mono_doc_id},
            multilingual_collection: {multi_doc_id},
        },
    )
2 changes: 2 additions & 0 deletions welearn_datastack/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,5 @@
"P4404668943",
"P4404677186",
]

# Language segment used when building the name of a multilingual Qdrant
# collection, i.e. ``collection_welearn_<code>_<model>``.
QDRANT_MULTI_LINGUAL_CODE = "mul"
14 changes: 11 additions & 3 deletions welearn_datastack/modules/qdrant_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from qdrant_client.grpc import UpdateResult
from qdrant_client.http.models import models

from welearn_datastack.constants import QDRANT_MULTI_LINGUAL_CODE
from welearn_datastack.data.db_models import DocumentSlice
from welearn_datastack.exceptions import (
ErrorWhileDeletingChunks,
Expand Down Expand Up @@ -38,15 +39,22 @@ def classify_documents_per_collection(
for dslice in slices:
lang = dslice.document.lang
model = dslice.embedding_model_name
collection_name = f"collection_welearn_{lang}_{model.lower()}"
collection_name = (
f"collection_welearn_{QDRANT_MULTI_LINGUAL_CODE}_{model.lower()}"
)

if collection_name not in collections_names_in_qdrant:
logger.error(
"Collection %s not found in Qdrant, slice %s ignored",
"Collection %s not found in Qdrant, attempt with language-specific collection name",
collection_name,
dslice.id,
)
continue
collection_name = f"collection_welearn_{lang}_{model.lower()}"
if collection_name not in collections_names_in_qdrant:
logger.error(
f"Collection {collection_name} not found in Qdrant, {dslice.id} will be ignored",
)
continue

if collection_name not in ret:
ret[collection_name] = set()
Expand Down