Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .github/CODEOWNERS
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# CODEOWNERS file (from GitHub template at
# https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners)
# Each line is a file pattern followed by one or more owners.

################################################################################
# These owners will be the default owners for everything in the repo. This is commented
# out in favor of using a team as the default (see below). It is left here as a comment
# to indicate the primary expert for this code.
# * @adamshire123

# Teams can be specified as code owners as well. Teams should be identified in
# the format @org/team-name. Teams must have explicit write access to the
# repository.
* @mitlibraries/dataeng

# We set the senior engineer in the team as the owner of the CODEOWNERS file as
# a layer of protection for unauthorized changes.
/.github/CODEOWNERS @ghukill
15 changes: 2 additions & 13 deletions .github/pull-request-template.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,5 @@ YES | NO
### What are the relevant tickets?
- Include links to Jira Software and/or Jira Service Management tickets here.

### Developer
- [ ] All new ENV is documented in README
- [ ] All new ENV has been added to staging and production environments
- [ ] All related Jira tickets are linked in commit message(s)
- [ ] Stakeholder approval has been confirmed (or is not needed)

### Code Reviewer(s)
- [ ] The commit message is clear and follows our guidelines (not just this PR message)
- [ ] There are appropriate tests covering any new functionality
- [ ] The provided documentation is sufficient for understanding any new functionality introduced
- [ ] Any manual tests have been performed **or** provided examples verified
- [ ] New dependencies are appropriate or there were no changes

### Code review
* Code review best practices are documented [here](https://mitlibraries.github.io/guides/collaboration/code_review.html) and you are encouraged to have a constructive dialogue with your reviewers about their preferences and expectations.
7 changes: 2 additions & 5 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ name = "pypi"
[packages]
attrs = "*"
boto3 = "*"
duckdb = "==1.4.2"
duckdb = "*"
duckdb-engine = "*"
pandas = "*"
pyarrow = "*"
Expand All @@ -29,7 +29,4 @@ setuptools = "*"
pip-audit = "*"

[requires]
python_version = "3.12"

[pipenv]
allow_prereleases = true
python_version = "3.12"
622 changes: 376 additions & 246 deletions Pipfile.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ classifiers = [
dependencies = [
"attrs",
"boto3",
"duckdb==1.4.2",
"duckdb",
"duckdb_engine",
"pandas",
"pyarrow",
Expand Down
17 changes: 16 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from tests.utils import generate_sample_embeddings, generate_sample_records
from timdex_dataset_api import TIMDEXDataset, TIMDEXDatasetMetadata
from timdex_dataset_api.dataset import TIMDEXDatasetConfig
from timdex_dataset_api.embeddings import DatasetEmbedding
from timdex_dataset_api.embeddings import (
DatasetEmbedding,
TIMDEXEmbeddings,
)
from timdex_dataset_api.record import DatasetRecord


Expand Down Expand Up @@ -287,6 +290,18 @@ def timdex_metadata_merged_deltas(
return metadata


# ================================================================================
# Dataset Embeddings Fixtures
# ================================================================================
@pytest.fixture
def timdex_embeddings_with_runs(timdex_dataset_empty):
"""TIMDEXEmbeddings with multiple runs for single strategy."""
embeddings = TIMDEXEmbeddings(timdex_dataset_empty)
embeddings.write(generate_sample_embeddings(100, run_id="abc123")) # run 1
embeddings.write(generate_sample_embeddings(50, run_id="def456")) # run 2
return TIMDEXEmbeddings(timdex_dataset_empty)


# ================================================================================
# Utility Fixtures
# ================================================================================
Expand Down
83 changes: 83 additions & 0 deletions tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import os
from datetime import UTC, datetime

import pandas as pd
import pyarrow as pa
import pyarrow.dataset as ds

from timdex_dataset_api.embeddings import (
Expand All @@ -12,6 +14,8 @@
TIMDEXEmbeddings,
)

EMBEDDINGS_COLUMNS_SET = set(TIMDEX_DATASET_EMBEDDINGS_SCHEMA.names)


def test_dataset_embedding_init():
values = {
Expand Down Expand Up @@ -128,3 +132,82 @@ def test_embeddings_create_batches(timdex_dataset_empty, sample_embeddings_gener
assert len(batches) == math.ceil(
total_embeddings / timdex_dataset_empty.config.write_batch_size
)


def test_embeddings_read_batches_yields_pyarrow_record_batches(
timdex_dataset_empty, sample_embeddings_generator
):
timdex_embeddings = TIMDEXEmbeddings(timdex_dataset_empty)
timdex_embeddings.write(sample_embeddings_generator(100))
timdex_embeddings = TIMDEXEmbeddings(timdex_dataset_empty)

batches = timdex_embeddings.read_batches_iter()
batch = next(batches)
assert isinstance(batch, pa.RecordBatch)


def test_embeddings_read_batches_all_columns_by_default(timdex_embeddings_with_runs):
batches = timdex_embeddings_with_runs.read_batches_iter()
batch = next(batches)
assert set(batch.column_names) == EMBEDDINGS_COLUMNS_SET


def test_embeddings_read_batches_filter_columns(timdex_embeddings_with_runs):
columns_subset = ["timdex_record_id", "run_id", "embedding_strategy"]
batches = timdex_embeddings_with_runs.read_batches_iter(columns=columns_subset)
batch = next(batches)
assert set(batch.column_names) == set(columns_subset)


def test_embeddings_read_batches_gets_full_dataset(timdex_embeddings_with_runs):
batches = timdex_embeddings_with_runs.read_batches_iter()
table = pa.Table.from_batches(batches)
dataset = ds.dataset(
timdex_embeddings_with_runs.data_embeddings_root,
format="parquet",
partitioning="hive",
)
assert len(table) == dataset.count_rows()


def test_embeddings_read_batches_with_filters_gets_subset_of_dataset(
timdex_embeddings_with_runs,
):
batches = timdex_embeddings_with_runs.read_batches_iter(
run_id="abc123", embedding_strategy="full_record"
)
table = pa.Table.from_batches(batches)
dataset = ds.dataset(
timdex_embeddings_with_runs.data_embeddings_root,
format="parquet",
partitioning="hive",
)
assert len(table) == 100
assert len(table) < dataset.count_rows()


def test_embeddings_read_dataframes_yields_dataframes(timdex_embeddings_with_runs):
df_iter = timdex_embeddings_with_runs.read_dataframes_iter()
df_batch = next(df_iter)
assert isinstance(df_batch, pd.DataFrame)
assert len(df_batch) == 150


def test_embeddings_read_dataframe_gets_full_dataset(timdex_embeddings_with_runs):
df = timdex_embeddings_with_runs.read_dataframe()
dataset = ds.dataset(
timdex_embeddings_with_runs.data_embeddings_root,
format="parquet",
partitioning="hive",
)
assert isinstance(df, pd.DataFrame)
assert len(df) == dataset.count_rows()


def test_embeddings_read_dicts_yields_dictionary_for_each_embeddings_record(
timdex_embeddings_with_runs,
):
dict_iter = timdex_embeddings_with_runs.read_dicts_iter()
record = next(dict_iter)
assert isinstance(record, dict)
assert set(record.keys()) == EMBEDDINGS_COLUMNS_SET
4 changes: 2 additions & 2 deletions tests/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ def test_read_batches_with_filters_gets_subset_of_dataset(timdex_dataset_multi_s
assert timdex_dataset_multi_source.dataset.count_rows() == 5_000


def test_read_dataframe_batches_yields_dataframes(timdex_dataset_multi_source):
def test_read_dataframes_yields_dataframes(timdex_dataset_multi_source):
df_iter = timdex_dataset_multi_source.read_dataframes_iter()
df_batch = next(df_iter)
assert isinstance(df_batch, pd.DataFrame)
assert len(df_batch) == 1_000


def test_read_dataframe_reads_all_dataset_rows_after_filtering(
def test_read_dataframe_gets_full_dataset(
timdex_dataset_multi_source,
):
df = timdex_dataset_multi_source.read_dataframe()
Expand Down
2 changes: 1 addition & 1 deletion timdex_dataset_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from timdex_dataset_api.metadata import TIMDEXDatasetMetadata
from timdex_dataset_api.record import DatasetRecord

__version__ = "3.6.1"
__version__ = "3.7.0"

__all__ = [
"DatasetEmbedding",
Expand Down
7 changes: 6 additions & 1 deletion timdex_dataset_api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pyarrow import fs

from timdex_dataset_api.config import configure_logger
from timdex_dataset_api.embeddings import TIMDEXEmbeddings
from timdex_dataset_api.metadata import TIMDEXDatasetMetadata

if TYPE_CHECKING:
Expand Down Expand Up @@ -140,6 +141,9 @@ def __init__(
# DuckDB context
self.conn = self.setup_duckdb_context()

# dataset embeddings
self.embeddings = TIMDEXEmbeddings(self)

@property
def location_scheme(self) -> Literal["file", "s3"]:
scheme = urlparse(self.location).scheme
Expand Down Expand Up @@ -255,7 +259,8 @@ def setup_duckdb_context(self) -> DuckDBPyConnection:
conn.execute("""create schema data;""")

logger.debug(
f"DuckDB data context created, {round(time.perf_counter()-start_time,2)}s"
"DuckDB context created for TIMDEXDataset, "
f"{round(time.perf_counter()-start_time,2)}s"
)
return conn

Expand Down
Loading