Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ WORKSPACE=### Set to `dev` for local development, this will be set to `stage` an
TE_MODEL_URI=# HuggingFace model URI
TE_MODEL_PATH=# Path where the model will be downloaded to and loaded from
HF_HUB_DISABLE_PROGRESS_BARS=#boolean to use progress bars for HuggingFace model downloads; defaults to 'true' in deployed contexts
# inference performance tuning
TE_TORCH_DEVICE=# defaults to 'cpu', but can be set to 'mps' for Apple Silicon, or theoretically 'cuda' for GPUs
TE_BATCH_SIZE=# batch size for each inference worker, defaults to 32
TE_NUM_WORKERS=# number of parallel model inference workers, defaults to 1
TE_CHUNK_SIZE=# number of batches each parallel worker grabs; no effect if TE_NUM_WORKERS=1
OMP_NUM_THREADS=# torch env var that sets thread usage during inference, default is not setting and using torch defaults
MKL_NUM_THREADS=# torch env var that sets thread usage during inference, default is not setting and using torch defaults
```

## Configuring an Embedding Model
Expand Down Expand Up @@ -106,14 +113,15 @@ Options:
[required]
--model-path PATH Path where the model will be downloaded to and
loaded from, e.g. '/path/to/model'. [required]
-d, --dataset-location PATH TIMDEX dataset location, e.g.
--dataset-location PATH TIMDEX dataset location, e.g.
's3://timdex/dataset', to read records from.
[required]
--run-id TEXT TIMDEX ETL run id. [required]
--run-id TEXT TIMDEX ETL run id.
--run-record-offset INTEGER TIMDEX ETL run record offset to start from,
default = 0. [required]
default = 0.
--record-limit INTEGER Limit number of records after --run-record-
offset, default = None (unlimited). [required]
offset, default = None (unlimited).
--input-jsonl TEXT Optional filepath to JSONLines file containing
TIMDEX records to create embeddings from.
--strategy [full_record] Pre-embedding record transformation strategy.
Repeatable to apply multiple strategies.
[required]
Expand Down
4 changes: 2 additions & 2 deletions embeddings/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ class Embedding:
run_record_offset: int
model_uri: str
embedding_strategy: str
embedding_vector: list[float]
embedding_token_weights: dict
embedding_vector: list[float] | None
embedding_token_weights: dict | None

timestamp: datetime.datetime = field(
default_factory=lambda: datetime.datetime.now(datetime.UTC)
Expand Down
20 changes: 13 additions & 7 deletions embeddings/models/os_neural_sparse_doc_v3_gte.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,11 @@ def create_embeddings(

# read env vars for configurations
num_workers = int(os.getenv("TE_NUM_WORKERS", "1"))
batch_size = int(
os.getenv("TE_BATCH_SIZE", "32")
) # sentence-transformers default
batch_size = int(os.getenv("TE_BATCH_SIZE", "32"))
chunk_size_env = os.getenv("TE_CHUNK_SIZE")
chunk_size = int(chunk_size_env) if chunk_size_env else None

# configure device and worker pool based on number of workers requested
# configure for inference
if num_workers > 1 or self.device == "mps":
device = None
pool = self._model.start_multi_process_pool(
Expand All @@ -206,17 +206,20 @@ def create_embeddings(
pool = None
logger.info(
f"Num workers: {num_workers}, batch size: {batch_size}, "
f"device: {device}, pool: {pool}"
f"chunk size: {chunk_size, }device: {device}, pool: {pool}"
)

# get sparse vector embedding for input text(s)
inference_start = time.perf_counter()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Appreciate the var and logging name change for specificity on what it is tracking

sparse_vectors = self._model.encode_document(
texts,
batch_size=batch_size,
device=device,
pool=pool,
save_to_cpu=True,
chunk_size=chunk_size,
)
logger.info(f"Inference elapsed: {time.perf_counter()-inference_start}s")
sparse_vectors = cast("list[Tensor]", sparse_vectors)

for i, embedding_input in enumerate(embedding_inputs_list):
Expand Down Expand Up @@ -244,8 +247,11 @@ def _get_embedding_from_sparse_vector(
decoded_token_weights = cast("list[tuple[str, float]]", decoded_token_weights)
embedding_token_weights = dict(decoded_token_weights)

# prepare sparse vector for JSON serialization
embedding_vector = sparse_vector.to_dense().tolist()
# # prepare sparse vector for JSON serialization
# NOTE: at this time we are NOT including the sparse vector for output. This
# block can be uncommented in the future to include it when wanted.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A good approach to this change!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thought you'd like this update @ehanson8! Glad we have this stubbed if we do want them in the future, but I think your instinct was right to not include them at first. Going to be lots of churn in the embeddings creation for a bit as we tune things.

# embedding_vector = sparse_vector.to_dense().tolist() # noqa: ERA001
embedding_vector = None

return Embedding(
timdex_record_id=embedding_input.timdex_record_id,
Expand Down
13 changes: 11 additions & 2 deletions embeddings/strategies/full_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,14 @@ class FullRecordStrategy(BaseStrategy):
STRATEGY_NAME = "full_record"

def extract_text(self, timdex_record: dict) -> str:
"""Serialize the entire transformed_record as JSON."""
return json.dumps(timdex_record)
"""Serialize the entire TIMDEX record.

The final string form is:
<field>: <value as JSON><newline>
<field>: <value as JSON><newline>
...
"""
final_string = ""
for k, v in timdex_record.items():
final_string += f"{k}: {json.dumps(v)}\n"
return final_string
6 changes: 3 additions & 3 deletions tests/test_os_neural_sparse_doc_v3_gte.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def test_create_embedding_returns_embedding_object(tmp_path):
assert embedding.run_record_offset == 42
assert embedding.model_uri == model.model_uri
assert embedding.embedding_strategy == "title_only"
assert embedding.embedding_vector == pytest.approx([0.1, 0.2])
assert embedding.embedding_vector is None
assert embedding.embedding_token_weights == {"sum": pytest.approx(0.3)}


Expand Down Expand Up @@ -257,6 +257,6 @@ def test_create_embeddings_consumes_iterator_and_returns_embeddings(

assert len(embeddings) == 2
assert embeddings[0].timdex_record_id == "id-1"
assert embeddings[0].embedding_vector == pytest.approx([0.1, 0.2])
assert embeddings[0].embedding_vector is None
assert embeddings[1].timdex_record_id == "id-2"
assert embeddings[1].embedding_vector == pytest.approx([0.3, 0.4])
assert embeddings[1].embedding_vector is None
4 changes: 1 addition & 3 deletions tests/test_strategies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import json

import pytest

from embeddings.strategies.base import BaseStrategy
Expand All @@ -14,7 +12,7 @@ def test_full_record_strategy_extracts_text():

text = strategy.extract_text(timdex_record)

assert text == json.dumps(timdex_record)
assert text == """timdex_record_id: "test-123"\ntitle: ["Test Title"]\n"""
assert strategy.STRATEGY_NAME == "full_record"


Expand Down
Loading