Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ repos:
types: ["python"]
- id: pip-audit
name: pip-audit
entry: pipenv run pip-audit --ignore-vuln GHSA-4xh5-x5gv-qwph
entry: pipenv run pip-audit
language: system
pass_filenames: false
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ ruff: # Run 'ruff' linter and print a preview of errors
pipenv run ruff check .

safety: # Check for security vulnerabilities and verify Pipfile.lock is up-to-date
pipenv run pip-audit --ignore-vuln GHSA-4xh5-x5gv-qwph
pipenv run pip-audit
pipenv verify

lint-apply: black-apply ruff-apply # Apply changes with 'black' and resolve 'fixable errors' with 'ruff'
Expand Down
4 changes: 2 additions & 2 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ name = "pypi"
[packages]
attrs = "*"
boto3 = "*"
duckdb = "==1.4.2"
duckdb-engine = "*"
pandas = "*"
pyarrow = "*"
sqlalchemy = "*"
duckdb-engine = "*"
duckdb = "==1.4.2.dev27"

[dev-packages]
black = "*"
Expand Down
1,172 changes: 596 additions & 576 deletions Pipfile.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ classifiers = [
dependencies = [
"attrs",
"boto3",
"duckdb==1.4.2.dev27",
"duckdb==1.4.2",
"duckdb_engine",
"pandas",
"pyarrow",
Expand Down
19 changes: 18 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import moto
import pytest

from tests.utils import generate_sample_records
from tests.utils import generate_sample_embeddings, generate_sample_records
from timdex_dataset_api import TIMDEXDataset, TIMDEXDatasetMetadata
from timdex_dataset_api.dataset import TIMDEXDatasetConfig
from timdex_dataset_api.embeddings import DatasetEmbedding
from timdex_dataset_api.record import DatasetRecord


Expand Down Expand Up @@ -305,3 +306,19 @@ def _generate(num_records: int = 100, **kwargs) -> Iterator[DatasetRecord]:
return generate_sample_records(num_records=num_records, **kwargs)

return _generate


@pytest.fixture
def sample_embeddings() -> Iterator[DatasetEmbedding]:
    """Provide a generator of 100 sample embeddings built with default parameters."""
    embeddings_iter = generate_sample_embeddings(100)
    return embeddings_iter


@pytest.fixture
def sample_embeddings_generator():
    """Factory fixture: return a callable that produces customized sample embeddings.

    The returned callable accepts ``num_embeddings`` (default 100) plus arbitrary
    keyword overrides, which are forwarded to ``generate_sample_embeddings``.
    """

    def _factory(num_embeddings: int = 100, **overrides) -> Iterator[DatasetEmbedding]:
        # forward any field overrides (source, run_id, timestamp, ...) untouched
        return generate_sample_embeddings(num_embeddings=num_embeddings, **overrides)

    return _factory
130 changes: 130 additions & 0 deletions tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# ruff: noqa: PLR2004
import json
import math
import os
from datetime import UTC, datetime

import pyarrow.dataset as ds

from timdex_dataset_api.embeddings import (
TIMDEX_DATASET_EMBEDDINGS_SCHEMA,
DatasetEmbedding,
TIMDEXEmbeddings,
)


def test_dataset_embedding_init():
    """A DatasetEmbedding is constructed from keyword values and parses its timestamp."""
    token_weights = {"token1": 0.1, "token2": 0.2, "token3": 0.3}
    embedding = DatasetEmbedding(
        timdex_record_id="alma:123",
        run_id="test-run-1",
        run_record_offset=0,
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        embedding_strategy="full_record",
        timestamp="2024-12-01T10:00:00+00:00",
        embedding_vector=[0.1, 0.2, 0.3],
        embedding_object=json.dumps(token_weights).encode(),
    )

    assert embedding
    assert embedding.timdex_record_id == "alma:123"
    # the ISO-8601 string input is parsed into a timezone-aware datetime
    assert embedding.timestamp == datetime(2024, 12, 1, 10, 0, tzinfo=UTC)
    assert embedding.embedding_object == b'{"token1": 0.1, "token2": 0.2, "token3": 0.3}'


def test_dataset_embedding_date_properties():
    """year/month/day are zero-padded string components of the timestamp."""
    values = {
        "timdex_record_id": "alma:123",
        "run_id": "test-run-1",
        "run_record_offset": 0,
        "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
        "embedding_strategy": "full_record",
        "timestamp": "2024-12-01T10:00:00+00:00",
        "embedding_vector": [0.1, 0.2, 0.3],
    }
    embedding = DatasetEmbedding(**values)

    assert embedding.year == "2024"
    assert embedding.month == "12"
    assert embedding.day == "01"


def test_dataset_embedding_to_dict():
    """to_dict() includes original fields plus derived date partition columns."""
    embedding = DatasetEmbedding(
        timdex_record_id="alma:123",
        run_id="test-run-1",
        run_record_offset=0,
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        embedding_strategy="full_record",
        timestamp="2024-12-01T10:00:00+00:00",
        embedding_vector=[0.1, 0.2, 0.3],
        embedding_object=None,
    )

    as_dict = embedding.to_dict()

    expected_subset = {
        "timdex_record_id": "alma:123",
        "year": "2024",
        "month": "12",
        "day": "01",
        "embedding_vector": [0.1, 0.2, 0.3],
    }
    for key, expected_value in expected_subset.items():
        assert as_dict[key] == expected_value


def test_embeddings_data_root_property(timdex_dataset_empty):
    """data_embeddings_root is '<dataset location>/data/embeddings'."""
    embeddings = TIMDEXEmbeddings(timdex_dataset_empty)

    base_location = timdex_dataset_empty.location.removesuffix("/")
    assert embeddings.data_embeddings_root == f"{base_location}/data/embeddings"


def test_embeddings_write_basic(timdex_dataset_empty, sample_embeddings_generator):
    """Writing 100 embeddings yields one parquet file whose rows round-trip."""
    embeddings = TIMDEXEmbeddings(timdex_dataset_empty)

    files = embeddings.write(sample_embeddings_generator(100))

    assert len(files) == 1
    assert os.path.exists(files[0].path)

    # re-open with pyarrow directly to confirm the written rows are readable
    readback = ds.dataset(
        embeddings.data_embeddings_root, format="parquet", partitioning="hive"
    )
    assert readback.count_rows() == 100


def test_embeddings_write_partitioning(timdex_dataset_empty, sample_embeddings_generator):
    """Written file paths carry hive partitions for the fixed sample timestamp."""
    embeddings = TIMDEXEmbeddings(timdex_dataset_empty)

    files = embeddings.write(sample_embeddings_generator(10))

    assert len(files) == 1
    # sample embeddings default to 2024-12-01, so the partition path is deterministic
    assert "year=2024/month=12/day=01" in files[0].path


def test_embeddings_write_schema_applied(
    timdex_dataset_empty, sample_embeddings_generator
):
    """The parquet files on disk expose exactly the embeddings schema columns."""
    embeddings = TIMDEXEmbeddings(timdex_dataset_empty)
    embeddings.write(sample_embeddings_generator(10))

    # load the dataset independently of TIMDEXEmbeddings to inspect the schema
    readback = ds.dataset(
        embeddings.data_embeddings_root,
        format="parquet",
        partitioning="hive",
    )

    assert set(readback.schema.names) == set(TIMDEX_DATASET_EMBEDDINGS_SCHEMA.names)


def test_embeddings_create_batches(timdex_dataset_empty, sample_embeddings_generator):
    """Embeddings are chunked into ceil(total / write_batch_size) batches."""
    embeddings = TIMDEXEmbeddings(timdex_dataset_empty)
    total = 101
    batch_size = 50
    timdex_dataset_empty.config.write_batch_size = batch_size

    batch_list = list(
        embeddings.create_embedding_batches(sample_embeddings_generator(total))
    )

    # 101 embeddings at batch size 50 -> 3 batches (50, 50, 1)
    assert len(batch_list) == math.ceil(total / batch_size)
36 changes: 36 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

# ruff: noqa: S311

import json
import random
import uuid
from collections.abc import Iterator

from timdex_dataset_api import DatasetRecord
from timdex_dataset_api.embeddings import DatasetEmbedding


def generate_sample_records(
Expand Down Expand Up @@ -58,3 +60,37 @@ def generate_sample_records_with_simulated_partitions(
action=random.choice(actions),
)
records_remaining -= batch_size


def generate_sample_embeddings(
    num_embeddings: int,
    source: str | None = "alma",
    embedding_model: str | None = "super-org/amazing-model",
    embedding_strategy: str | None = "full_record",
    run_id: str | None = None,
    timestamp: str | None = "2024-12-01T00:00:00+00:00",
) -> Iterator[DatasetEmbedding]:
    """Yield ``num_embeddings`` sample DatasetEmbedding instances.

    Record ids are "<source>:<offset>"; a random run_id is minted when none is
    given.  Each embedding carries a fresh random 768-dim vector and a fixed
    JSON-encoded token-weight payload.
    """
    effective_run_id = run_id if run_id else str(uuid.uuid4())
    # bytes are immutable, so one serialized payload can safely be shared
    serialized_tokens = json.dumps(
        {
            "token1": 0.1,
            "token2": 0.2,
            "token3": 0.3,
        }
    ).encode()

    for offset in range(num_embeddings):
        yield DatasetEmbedding(
            timdex_record_id=f"{source}:{offset}",
            run_id=effective_run_id,
            run_record_offset=offset,
            embedding_model=embedding_model,
            embedding_strategy=embedding_strategy,
            timestamp=timestamp,
            embedding_vector=[random.random() for _ in range(768)],
            embedding_object=serialized_tokens,
        )
2 changes: 1 addition & 1 deletion timdex_dataset_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from timdex_dataset_api.metadata import TIMDEXDatasetMetadata
from timdex_dataset_api.record import DatasetRecord

__version__ = "3.5.0"
__version__ = "3.6.0"

__all__ = [
"DatasetRecord",
Expand Down
Loading