Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ spec:
default: {{ .workflowTemplates.runNode.requests.memory }}
- name: size_limit
default: 10000000000 # In bytes
- name: st_backend
default: "onnx"
steps:
- - name: generate-to-vectorize-batch
templateRef:
Expand Down Expand Up @@ -112,6 +114,9 @@ spec:
- name: memory
value: >-
{{ print "{{inputs.parameters.memory_collect_docs}}" }}
- name: st_backend
value: >-
{{ print "{{inputs.parameters.st_backend}}" }}
artifacts:
- name: batch_ids_csv
from: >-
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ spec:
default: 10000000000
- name: st_device
default: "cpu"
- name: st_backend
default: "onnx"
- name: embedding_model_fr
default: {{ $.Values.common.embeddingModelFr }}
- name: embedding_model_en
Expand Down Expand Up @@ -78,6 +80,9 @@ spec:
- name: ST_DEVICE
value: >-
{{ print "{{inputs.parameters.st_device}}" }}
- name: ST_BACKEND
value: >-
{{ print "{{inputs.parameters.st_backend}}" }}
- name: MODELS_PATH_ROOT
value: {{ $.Values.common.modelsPathRoot }}

Expand Down
1,347 changes: 1,311 additions & 36 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ lingua-language-detector = "^2.1.1"
psycopg2-binary = "^2.9.10"
brotli = "^1.1.0"
scikit-learn = "1.6.1"
optimum = {extras = ["onnxruntime"], version = "^1.26.1"}

[tool.poetry.group.metrics.dependencies]
alembic = "^1.16.1"
Expand Down
4 changes: 1 addition & 3 deletions tests/document_vectorizer/test_embedding_model_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
class TestEmbeddingHelper(TestCase):
def setUp(self) -> None:
get_sub_environ_according_prefix.cache_clear()
os.environ["MODELS_NAME_PREFIX"] = "EMBEDDING_MODEL"
os.environ["EMBEDDING_MODEL_FR"] = "test_fr"
os.environ["EMBEDDING_MODEL_EN"] = "test_en"
os.environ["ST_BACKEND"] = "onnx"

def tearDown(self) -> None:
os.environ.clear()
Expand Down
6 changes: 6 additions & 0 deletions welearn_datastack/modules/embedding_model_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,16 @@ def load_embedding_model(str_path: str) -> SentenceTransformer:
logger.info("Loading embedding model from %s", str_path)

device = os.environ.get("ST_DEVICE", None)
backend = os.environ.get("ST_BACKEND", None)
logger.info("ST_DEVICE: %s", device)
logger.info("ST_BACKEND: %s", backend)

if device not in ["cpu", "cuda", None]:
raise ValueError("ST_DEVICE must be one of 'cpu', 'cuda' or None")

if backend not in ["torch", "onnx", "openvino"]:
raise ValueError("ST_BACKEND must be one of 'torch', 'onnx' or 'openvino'")

model = loaded_models.get(str_path, None)
if model is not None:
logger.info("%s Model already loaded", str_path)
Expand All @@ -98,6 +103,7 @@ def load_embedding_model(str_path: str) -> SentenceTransformer:
loaded_models[str_path] = SentenceTransformer(
model_name_or_path=str_path,
device=device,
backend=backend, # type: ignore
)
return loaded_models[str_path]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def main() -> None:
# Create content slices
docids_processed = 0
docsids_not_processed = 0
bulk_slices: list[DocumentSlice] = []
bulk_process_state: list[ProcessState] = []
for i, document in enumerate(welearn_documents):
logger.info("Processing document %s/%s", i, len(welearn_documents))
try:
Expand All @@ -87,40 +89,49 @@ def main() -> None:
f"No embedding model found for document {document.id}"
)
slices = create_content_slices(document, embedding_model_name=embedding_model_name, embedding_model_id=embedding_model_id) # type: ignore
logger.info("'%s' slices were created", len(slices))
logger.info("Delete old slices")
db_session.query(DocumentSlice).filter(
DocumentSlice.document_id == document.id
).delete()
db_session.commit()
logger.info("Insert new slices")
db_session.add_all(slices)
logger.info("Insert new state")
db_session.add(

logger.info("Adding slices to bulk")
bulk_slices.extend(slices)

logger.info("Adding process state to bulk")
bulk_process_state.append(
ProcessState(
id=uuid.uuid4(),
document_id=document.id,
title=Step.DOCUMENT_VECTORIZED.value,
)
)
db_session.commit()
logger.info("'%s' slices were created", len(slices))

docids_processed += 1
except NoModelFoundError:
logger.error("No model found for document %s", document.id)
db_session.add(
bulk_process_state.append(
ProcessState(
id=uuid.uuid4(),
document_id=document.id,
title=Step.KEPT_FOR_TRACE.value,
)
)
db_session.commit()
docsids_not_processed += 1
continue

logger.info("'%s' documents were processed", docids_processed)
logger.info("'%s' documents were not processed", docsids_not_processed)

db_session.bulk_save_objects(bulk_slices)
logger.info("'%s' slices were added to the session", len(bulk_slices))
db_session.bulk_save_objects(bulk_process_state)
logger.info(
"'%s' process states were added to the session", len(bulk_process_state)
)

db_session.commit()
db_session.close()
logger.info("DocumentVectorizer finished")

Expand Down