diff --git a/abdiff/core/collate_ab_transforms.py b/abdiff/core/collate_ab_transforms.py
index a69765b..f3eadeb 100644
--- a/abdiff/core/collate_ab_transforms.py
+++ b/abdiff/core/collate_ab_transforms.py
@@ -1,3 +1,5 @@
+# ruff: noqa: TRY003
+
 import itertools
 import json
 import logging
@@ -8,6 +10,7 @@
 import duckdb
 import ijson
+import pandas as pd
 import pyarrow as pa
 
 from abdiff.core.exceptions import OutputValidationError
@@ -20,15 +23,21 @@
     (
         pa.field("timdex_record_id", pa.string()),
         pa.field("source", pa.string()),
+        pa.field("run_date", pa.date32()),
+        pa.field("run_type", pa.string()),
+        pa.field("action", pa.string()),
         pa.field("record", pa.binary()),
         pa.field("version", pa.string()),
         pa.field("transformed_file_name", pa.string()),
     )
 )
 
-JOINED_DATASET_SCHEMA = pa.schema(
+COLLATED_DATASET_SCHEMA = pa.schema(
     (
         pa.field("timdex_record_id", pa.string()),
         pa.field("source", pa.string()),
+        pa.field("run_date", pa.date32()),
+        pa.field("run_type", pa.string()),
+        pa.field("action", pa.string()),
         pa.field("record_a", pa.binary()),
         pa.field("record_b", pa.binary()),
     )
@@ -40,18 +49,19 @@ def collate_ab_transforms(
 ) -> str:
     """Collates A/B transformed files into a Parquet dataset.
 
-    This process can be summarized into two (2) important steps:
+    This process can be summarized into three (3) important steps:
         1. Write all transformed JSON records into a temporary Parquet dataset
             partitioned by the transformed file name.
         2. For every transformed file, use DuckDB to join A/B Parquet tables using the
             TIMDEX record ID and write joined records to a Parquet dataset.
-
-    This function (and its subfunctions) uses DuckDB, generators, and batching to
-    write records to Parquet datasets in a memory-efficient manner.
+        3. Dedupe the joined records so that only the most recent, non-deleted version
+            of each timdex_record_id is present in the final output.
     """
     transformed_dataset_path = tempfile.TemporaryDirectory()
+    joined_dataset_path = tempfile.TemporaryDirectory()
     collated_dataset_path = str(Path(run_directory) / "collated")
 
+    # build temporary transformed dataset
     transformed_written_files = write_to_dataset(
         get_transformed_batches_iter(run_directory, ab_transformed_file_lists),
         schema=TRANSFORMED_DATASET_SCHEMA,
@@ -62,15 +72,30 @@
         f"Wrote {len(transformed_written_files)} parquet file(s) to transformed dataset"
     )
 
+    # build temporary joined dataset
     joined_written_files = write_to_dataset(
         get_joined_batches_iter(transformed_dataset_path.name),
-        base_dir=collated_dataset_path,
-        schema=JOINED_DATASET_SCHEMA,
+        base_dir=joined_dataset_path.name,
+        schema=COLLATED_DATASET_SCHEMA,
     )
-    logger.info(f"Wrote {len(joined_written_files)} parquet file(s) to collated dataset")
+    logger.info(f"Wrote {len(joined_written_files)} parquet file(s) to joined dataset")
 
+    # build final deduped and collated dataset
+    deduped_written_files = write_to_dataset(
+        get_deduped_batches_iter(joined_dataset_path.name),
+        base_dir=collated_dataset_path,
+        schema=COLLATED_DATASET_SCHEMA,
+    )
+    logger.info(
+        f"Wrote {len(deduped_written_files)} parquet file(s) to deduped collated dataset"
+    )
+
     validate_output(collated_dataset_path)
 
+    # ensure temporary artifacts removed
+    transformed_dataset_path.cleanup()
+    joined_dataset_path.cleanup()
+
     return collated_dataset_path
@@ -85,7 +110,9 @@ def get_transformed_records_iter(
         * timdex_record_id: The TIMDEX record ID.
         * source: The shorthand name of the source as denoted in by Transmogrifier
-            (see https://github.com/MITLibraries/transmogrifier/blob/main/transmogrifier/config.py).
+        * run_date: Run date from the TIMDEX ETL run.
+        * run_type: "full" or "daily".
+        * action: "index" or "delete".
         * record: The TIMDEX record serialized to a JSON string then encoded to bytes.
         * version: The version of the transform, parsed from the absolute filepath
             to a transformed file.
@@ -93,14 +120,34 @@
     """
     version = get_transform_version(transformed_file)
     filename_details = parse_timdex_filename(transformed_file)
-    with open(transformed_file, "rb") as file:
-        for record in ijson.items(file, "item"):
+
+    base_record = {
+        "source": filename_details["source"],
+        "run_date": filename_details["run-date"],
+        "run_type": filename_details["run-type"],
+        "action": filename_details["action"],
+        "version": version,
+        "transformed_file_name": transformed_file.split("/")[-1],
+    }
+
+    # handle JSON files with records to index
+    if transformed_file.endswith(".json"):
+        with open(transformed_file, "rb") as file:
+            for record in ijson.items(file, "item"):
+                yield {
+                    **base_record,
+                    "timdex_record_id": record["timdex_record_id"],
+                    "record": json.dumps(record).encode(),
+                }
+
+    # handle TXT files with records to delete
+    else:
+        deleted_records_df = pd.read_csv(transformed_file, header=None)
+        for row in deleted_records_df.itertuples():
             yield {
-                "timdex_record_id": record["timdex_record_id"],
-                "source": filename_details["source"],
-                "record": json.dumps(record).encode(),
-                "version": version,
-                "transformed_file_name": transformed_file.split("/")[-1],
+                **base_record,
+                "timdex_record_id": row[1],
+                "record": None,
             }
@@ -192,6 +239,9 @@ def get_joined_batches_iter(dataset_directory: str) -> Generator[pa.RecordBatch]
             SELECT
                 COALESCE(a.timdex_record_id, b.timdex_record_id) timdex_record_id,
                 COALESCE(a.source, b.source) source,
+                COALESCE(a.run_date, b.run_date) run_date,
+                COALESCE(a.run_type, b.run_type) run_type,
+                COALESCE(a.action, b.action) "action",
                 a.record as record_a,
                 b.record as record_b
             FROM a
@@ -210,6 +260,64 @@ def get_joined_batches_iter(dataset_directory: str) -> Generator[pa.RecordBatch]
             break
 
 
+def get_deduped_batches_iter(dataset_directory: str) -> Generator[pa.RecordBatch]:
+    """Yield pyarrow.RecordBatch objects of deduped rows from the joined dataset.
+
+    ABDiff should be able to handle many input files, where a single timdex_record_id
+    may be duplicated across multiple files ("full" vs "daily" runs, incrementing run
+    dates, etc.).
+
+    This function writes the final dataset by deduping records from the temporary
+    joined dataset, using the following logic:
+        - use the MOST RECENT record based on 'run_date'
+        - if the MOST RECENT record is action='delete', then omit the record entirely
+
+    The same mechanism used by get_joined_batches_iter() applies here: perform a DuckDB
+    query, then stream-write batches to a Parquet dataset.
+    """
+    with duckdb.connect(":memory:") as con:
+
+        results = con.execute(
+            """
+            WITH collated AS (
+                SELECT * FROM read_parquet($collated_parquet_glob, hive_partitioning=true)
+            ),
+            latest_records AS (
+                SELECT
+                    *,
+                    ROW_NUMBER() OVER (
+                        PARTITION BY timdex_record_id
+                        ORDER BY run_date DESC
+                    ) AS rn
+                FROM collated
+            ),
+            deduped_records AS (
+                SELECT *
+                FROM latest_records
+                WHERE rn = 1 AND action != 'delete'
+            )
+            SELECT
+                timdex_record_id,
+                source,
+                run_date,
+                run_type,
+                action,
+                record_a,
+                record_b
+            FROM deduped_records;
+            """,
+            {
+                "collated_parquet_glob": f"{dataset_directory}/**/*.parquet",
+            },
+        ).fetch_record_batch(READ_BATCH_SIZE)
+
+        while True:
+            try:
+                yield results.read_next_batch()
+            except StopIteration:
+                break  # pragma: nocover
+
+
 def validate_output(dataset_path: str) -> None:
     """Validate the output of collate_ab_transforms.
@@ -217,8 +325,14 @@ def validate_output(dataset_path: str) -> None:
     and whether any or both 'record_a' or 'record_b' columns are totally empty.
     """
+
+    def fetch_single_value(query: str) -> int:
+        result = con.execute(query).fetchone()
+        if result is None:
+            raise RuntimeError(f"Query returned no results: {query}")  # pragma: nocover
+        return int(result[0])
+
     with duckdb.connect(":memory:") as con:
-        # create view of collated table
         con.execute(
             f"""
             CREATE VIEW collated AS (
@@ -228,41 +342,48 @@
         )
 
         # check if the table is empty
-        record_count = con.execute("SELECT COUNT(*) FROM collated").fetchone()[0]  # type: ignore[index]
+        record_count = fetch_single_value("SELECT COUNT(*) FROM collated")
         if record_count == 0:
-            raise OutputValidationError(  # noqa: TRY003
+            raise OutputValidationError(
                 "The collated dataset does not contain any records."
             )
 
         # check if any of the 'record_*' columns are empty
-        record_a_null_count = con.execute(
+        record_a_null_count = fetch_single_value(
            "SELECT COUNT(*) FROM collated WHERE record_a ISNULL"
-        ).fetchone()[
-            0
-        ]  # type: ignore[index]
-
-        record_b_null_count = con.execute(
+        )
+        record_b_null_count = fetch_single_value(
            "SELECT COUNT(*) FROM collated WHERE record_b ISNULL"
-        ).fetchone()[
-            0
-        ]  # type: ignore[index]
+        )
 
         if record_count in {record_a_null_count, record_b_null_count}:
-            raise OutputValidationError(  # noqa: TRY003
+            raise OutputValidationError(
                 "At least one or both record column(s) ['record_a', 'record_b'] "
                 "in the collated dataset are empty."
             )
 
+        # check that timdex_record_id column is unique
+        non_unique_count = fetch_single_value(
+            """
+            SELECT COUNT(*)
+            FROM (
+                SELECT timdex_record_id
+                FROM collated
+                GROUP BY timdex_record_id
+                HAVING COUNT(*) > 1
+            ) AS duplicates;
+            """
+        )
+        if non_unique_count > 0:
+            raise OutputValidationError(
+                "The collated dataset contains duplicate 'timdex_record_id' records."
+            )
+
 
 def get_transform_version(transformed_filepath: str) -> str:
     """Get A/B transform version, either 'a' or 'b'."""
-    match_result = re.match(
-        r".*transformed\/(.*)\/.*.json",
-        transformed_filepath,
-    )
+    match_result = re.match(r".*transformed\/(.*)\/.*", transformed_filepath)
     if not match_result:
-        raise ValueError(  # noqa: TRY003
-            f"Transformed filepath is invalid: {transformed_filepath}."
-        )
+        raise ValueError(f"Transformed filepath is invalid: {transformed_filepath}.")
     return match_result.groups()[0]
diff --git a/abdiff/core/run_ab_transforms.py b/abdiff/core/run_ab_transforms.py
index 28a447b..cfd21d8 100644
--- a/abdiff/core/run_ab_transforms.py
+++ b/abdiff/core/run_ab_transforms.py
@@ -110,7 +110,7 @@ def run_ab_transforms(
             "to complete successfully."
         )
     ab_transformed_file_lists = get_transformed_files(run_directory)
-    validate_output(ab_transformed_file_lists, len(input_files))
+    validate_output(ab_transformed_file_lists, input_files)
 
     # write and return results
     run_data = {
@@ -278,11 +278,11 @@ def get_transformed_files(run_directory: str) -> tuple[list[str], ...]:
     Returns:
         tuple[list[str]]: Tuple containing lists of paths to transformed
-            JSON files for each image, relative to 'run_directory'.
+            JSON and TXT (deletions) files for each image, relative to 'run_directory'.
     """
     ordered_files = []
     for version in ["a", "b"]:
-        absolute_filepaths = glob.glob(f"{run_directory}/transformed/{version}/*.json")
+        absolute_filepaths = glob.glob(f"{run_directory}/transformed/{version}/*")
         relative_filepaths = [
             os.path.relpath(file, run_directory) for file in absolute_filepaths
         ]
@@ -291,24 +291,39 @@ def get_transformed_files(run_directory: str) -> tuple[list[str], ...]:
 
 
 def validate_output(
-    ab_transformed_file_lists: tuple[list[str], ...], input_files_count: int
+    ab_transformed_file_lists: tuple[list[str], ...], input_files: list[str]
 ) -> None:
     """Validate the output of run_ab_transforms.
 
-    This function checks that the number of files in each of the A/B
-    transformed file directories matches the number of input files
-    provided to run_ab_transforms (i.e., the expected number of
-    files that are transformed).
+    Transmogrifier produces JSON files for records that need indexing, and TXT files for
+    records that need deletion. Every run of Transmogrifier should produce one or both of
+    these. Some TIMDEX sources provide Transmogrifier a single file that contains both
+    records to index and records to delete, while others provide separate files for each.
+
+    The net effect for validation is that, given an input file, we expect to see one or
+    more corresponding files in both the A and B output, regardless of whether they
+    contain records to index or records to delete.
     """
-    if any(
-        len(transformed_files) != input_files_count
-        for transformed_files in ab_transformed_file_lists
-    ):
-        raise OutputValidationError(  # noqa: TRY003
-            "At least one or more transformed JSON file(s) are missing. "
-            f"Expecting {input_files_count} transformed JSON file(s) per A/B version. "
-            "Check the transformed file directories."
-        )
+    for input_file in input_files:
+        file_parts = parse_timdex_filename(input_file)
+        logger.debug(f"Validating output for input file root: {file_parts}")
+
+        for version_files in ab_transformed_file_lists:
+            file_found = False
+            for version_file in version_files:
+                if (
+                    file_parts["source"] in version_file  # type: ignore[operator]
+                    and file_parts["run-date"] in version_file  # type: ignore[operator]
+                    and file_parts["run-type"] in version_file  # type: ignore[operator]
+                    and (not file_parts["index"] or file_parts["index"] in version_file)
+                ):
+                    file_found = True
+                    break
+
+            if not file_found:
+                raise OutputValidationError(  # noqa: TRY003
+                    f"Transmogrifier output was not found for input file '{input_file}'."
+                )
 
 
 def get_transformed_filename(filename_details: dict) -> str:
@@ -318,7 +333,7 @@ def get_transformed_filename(filename_details: dict) -> str:
         index=f"_{sequence}" if (sequence := filename_details["index"]) else "",
     )
     output_filename = (
-        "{source}-{run_date}-{run_type}-{stage}-records-to-index{index}.{file_type}"
+        "{source}-{run_date}-{run_type}-{stage}-records-to-{action}{index}.json"
     )
     return output_filename.format(
         source=filename_details["source"],
@@ -326,5 +341,5 @@
         run_type=filename_details["run-type"],
         stage=filename_details["stage"],
         index=filename_details["index"],
-        file_type=filename_details["file_type"],
+        action=filename_details["action"],
     )
diff --git a/tests/conftest.py b/tests/conftest.py
index 38f7561..7084a89 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -334,6 +334,9 @@ def collated_dataset_directory(run_directory):
             {
                 "timdex_record_id": "abc123",
                 "source": "alma",
+                "run_date": "2024-10-01",
+                "run_type": "full",
+                "action": "index",
                 "record_a": json.dumps(
                     {"material": "concrete", "color": "green", "number": 42}
                 ).encode(),
@@ -344,6 +347,9 @@ def collated_dataset_directory(run_directory):
             {
                 "timdex_record_id": "def456",
                 "source": "dspace",
+                "run_date": "2024-10-01",
+                "run_type": "full",
+                "action": "index",
                 "record_a": json.dumps(
                     {"material": "concrete", "color": "blue", "number": 101}
                 ).encode(),
@@ -354,6 +360,9 @@ def collated_dataset_directory(run_directory):
             {
                 "timdex_record_id": "ghi789",
                 "source": "libguides",
+                "run_date": "2024-10-01",
+                "run_type": "full",
+                "action": "index",
                 "record_a": json.dumps(
                     {
                         "material": "concrete",
@@ -371,7 +380,101 @@ def collated_dataset_directory(run_directory):
     write_to_dataset(
         pa.Table.from_pandas(df),
         base_dir=dataset_directory,
-        partition_columns=["source"],
+    )
+    return dataset_directory
+
+
+@pytest.fixture
+def collated_with_dupe_dataset_directory(run_directory):
+    """Simulate a collated dataset that contains duplicate timdex_record_ids."""
+    dataset_directory = str(Path(run_directory) / "collated")
+    df = pd.DataFrame(
+        [
+            {
+                "timdex_record_id": "abc123",
+                "source": "alma",
+                "run_date": "2024-10-01",
+                "run_type": "full",
+                "action": "index",
+                "record_a": json.dumps(
+                    {"material": "concrete", "color": "green", "number": 42}
+                ).encode(),
+                "record_b": json.dumps(
+                    {"material": "concrete", "color": "red", "number": 42}
+                ).encode(),
+            },
+            {
+                "timdex_record_id": "def456",
+                "source": "dspace",
+                "run_date": "2024-10-01",
+                "run_type": "full",
+                "action": "index",
+                "record_a": json.dumps(
+                    {"material": "concrete", "color": "blue", "number": 101}
+                ).encode(),
+                "record_b": json.dumps(
+                    {"material": "concrete", "color": "blue", "number": 101}
+                ).encode(),
+            },
+            {
+                "timdex_record_id": "def456",
+                "source": "dspace",
+                "run_date": "2024-10-02",
+                "run_type": "full",
+                "action": "delete",
+                "record_a": json.dumps(
+                    {"material": "concrete", "color": "blue", "number": 101}
+                ).encode(),
+                "record_b": json.dumps(
+                    {"material": "concrete", "color": "blue", "number": 101}
+                ).encode(),
+            },
+            {
+                "timdex_record_id": "ghi789",
+                "source": "libguides",
+                "run_date": "2024-10-01",
+                "run_type": "full",
+                "action": "index",
+                "record_a": json.dumps(
+                    {
+                        "material": "concrete",
+                        "color": "purple",
+                        "number": 13,
+                        "fruit": "apple",
+                    }
+                ).encode(),
+                "record_b": json.dumps(
+                    {"material": "concrete", "color": "brown", "number": 99}
+                ).encode(),
+            },
+            {
+                "timdex_record_id": "ghi789",
+                "source": "libguides",
+                "run_date": "2024-10-02",
+                "run_type": "daily",
+                "action": "index",
+                "record_a": json.dumps(
+                    {
+                        "material": "stucco",
+                        "color": "green",
+                        "number": 42,
+                        "fruit": "banana",
+                    }
+                ).encode(),
+                "record_b": json.dumps(
+                    {
+                        "material": "stucco",
+                        "color": "green",
+                        "number": 42,
+                        "fruit": "banana",
+                    }
+                ).encode(),
+            },
+        ]
+    )
+    write_to_dataset(
+        pa.Table.from_pandas(df),
+        base_dir=dataset_directory,
     )
     return dataset_directory
diff --git a/tests/test_collate_ab_transforms.py b/tests/test_collate_ab_transforms.py
index 80f16ac..b95244b 100644
--- a/tests/test_collate_ab_transforms.py
+++ b/tests/test_collate_ab_transforms.py
@@ -1,4 +1,5 @@
 # ruff: noqa: PLR2004
+import json
 import os
 import re
 from pathlib import Path
@@ -9,10 +10,11 @@
 import pytest
 
 from abdiff.core.collate_ab_transforms import (
-    JOINED_DATASET_SCHEMA,
+    COLLATED_DATASET_SCHEMA,
     READ_BATCH_SIZE,
     TRANSFORMED_DATASET_SCHEMA,
     collate_ab_transforms,
+    get_deduped_batches_iter,
     get_joined_batches_iter,
     get_transform_version,
     get_transformed_batches_iter,
@@ -73,13 +75,16 @@ def test_get_transformed_records_iter_success(example_transformed_directory):
     )
     timdex_record_dict = next(records_iter)
 
-    assert list(timdex_record_dict.keys()) == [
+    assert set(timdex_record_dict.keys()) == {
         "timdex_record_id",
         "source",
+        "run_date",
+        "run_type",
+        "action",
         "record",
         "version",
         "transformed_file_name",
-    ]
+    }
     assert isinstance(timdex_record_dict["record"], bytes)
     assert timdex_record_dict["version"] == "a"
     assert (
@@ -99,7 +104,7 @@ def test_get_transformed_batches_iter_success(
 
     assert isinstance(transformed_batch, pa.RecordBatch)
     assert transformed_batch.num_rows <= READ_BATCH_SIZE
-    assert transformed_batch.schema == TRANSFORMED_DATASET_SCHEMA
+    assert set(transformed_batch.schema.names) == set(TRANSFORMED_DATASET_SCHEMA.names)
 
 
 def test_get_joined_batches_iter_success(transformed_parquet_dataset):
@@ -120,7 +125,20 @@ def test_get_joined_batches_iter_success(transformed_parquet_dataset):
     joined_batch = joined_batches[0]
     assert isinstance(joined_batch, pa.RecordBatch)
     assert joined_batch.num_rows <= max_rows_per_file
-    assert joined_batch.schema == JOINED_DATASET_SCHEMA
+    assert joined_batch.schema.names == COLLATED_DATASET_SCHEMA.names
+
+
+def test_get_deduped_batches_iter_success(collated_with_dupe_dataset_directory):
+    deduped_batches_iter = get_deduped_batches_iter(collated_with_dupe_dataset_directory)
+    deduped_df = next(deduped_batches_iter).to_pandas()
+
+    # assert record 'def456' was dropped because its most recent version is action=delete
+    assert len(deduped_df) == 2
+    assert set(deduped_df.timdex_record_id) == {"abc123", "ghi789"}
+
+    # assert record 'ghi789' kept its most recent (2024-10-02) version
+    deduped_record = deduped_df.set_index("timdex_record_id").loc["ghi789"]
+    assert json.loads(deduped_record.record_a)["material"] == "stucco"
 
 
 def test_validate_output_success(collated_dataset_directory):
@@ -128,7 +146,7 @@ def test_validate_output_raises_error_if_dataset_is_empty(run_directory):
-    empty_table = pa.Table.from_batches(batches=[], schema=JOINED_DATASET_SCHEMA)
+    empty_table = pa.Table.from_batches(batches=[], schema=COLLATED_DATASET_SCHEMA)
 
     empty_dataset_path = Path(run_directory) / "empty_dataset"
     os.makedirs(empty_dataset_path)
@@ -151,7 +169,7 @@ def test_validate_output_raises_error_if_missing_record_column(run_directory):
                 "record_b": None,
             }
         ],
-        schema=JOINED_DATASET_SCHEMA,
+        schema=COLLATED_DATASET_SCHEMA,
     )
 
     missing_record_cols_dataset_path = Path(run_directory) / "missing_record_cols_dataset"
@@ -171,6 +189,16 @@ def test_validate_output_raises_error_if_missing_record_column(run_directory):
         validate_output(dataset_path=missing_record_cols_dataset_path)
 
 
+def test_validate_output_raises_error_if_duplicate_records(
+    collated_with_dupe_dataset_directory,
+):
+    with pytest.raises(
+        OutputValidationError,
+        match="The collated dataset contains duplicate 'timdex_record_id' records.",
+    ):
+        validate_output(dataset_path=collated_with_dupe_dataset_directory)
+
+
 def test_get_transform_version_success(transformed_directories, output_filename):
     transformed_directory_a, transformed_directory_b = transformed_directories
     transformed_file_a = str(Path(transformed_directory_a) / output_filename)
diff --git a/tests/test_run_ab_transforms.py b/tests/test_run_ab_transforms.py
index fa82e88..73012aa 100644
--- a/tests/test_run_ab_transforms.py
+++ b/tests/test_run_ab_transforms.py
@@ -166,19 +166,86 @@ def test_get_transformed_files_success(
     )
 
 
-def test_validate_output_success():
+@pytest.mark.parametrize(
+    ("ab_files", "input_files"),
+    [
+        # single JSON output from a single input file
+        (
+            (
+                ["dspace-2024-04-10-daily-extracted-records-to-index.json"],
+                ["dspace-2024-04-10-daily-extracted-records-to-index.json"],
+            ),
+            ["s3://X/dspace-2024-04-10-daily-extracted-records-to-index.xml"],
+        ),
+        # JSON and TXT outputs from a single input file
+        (
+            (
+                [
+                    "dspace-2024-04-10-daily-extracted-records-to-index.json",
+                    "dspace-2024-04-10-daily-extracted-records-to-delete.txt",
+                ],
+                [
+                    "dspace-2024-04-10-daily-extracted-records-to-index.json",
+                    "dspace-2024-04-10-daily-extracted-records-to-delete.txt",
+                ],
+            ),
+            ["s3://X/dspace-2024-04-10-daily-extracted-records-to-index.xml"],
+        ),
+        # handles sequenced (indexed) files when a source has multiple input files
+        (
+            (
+                ["alma-2024-04-10-daily-extracted-records-to-index_09.json"],
+                ["alma-2024-04-10-daily-extracted-records-to-index_09.json"],
+            ),
+            ["s3://X/alma-2024-04-10-daily-extracted-records-to-index_09.xml"],
+        ),
+        # handles delete-only output (e.g. Alma deletes)
+        (
+            (
+                ["alma-2024-04-10-daily-extracted-records-to-delete.txt"],
+                ["alma-2024-04-10-daily-extracted-records-to-delete.txt"],
+            ),
+            ["s3://X/alma-2024-04-10-daily-extracted-records-to-delete.xml"],
+        ),
+    ],
+)
+def test_validate_output_success(ab_files, input_files):
     assert (
         validate_output(
-            ab_transformed_file_lists=(["transformed/a/file1"], ["transformed/b/file2"]),
-            input_files_count=1,
+            ab_transformed_file_lists=ab_files,
+            input_files=input_files,
         )
         is None
     )
 
 
-def test_validate_output_error():
+@pytest.mark.parametrize(
+    ("ab_files", "input_files"),
+    [
+        # nothing returned
+        (
+            ([], []),
+            ["s3://X/dspace-2024-04-10-daily-extracted-records-to-index.xml"],
+        ),
+        # output files are missing the index, or have the wrong index, so no direct match
+        (
+            (
+                [
+                    "alma-2024-04-10-daily-extracted-records-to-index.json",
+                    "alma-2024-04-10-daily-extracted-records-to-index_04.json",
+                ],
+                [
+                    "alma-2024-04-10-daily-extracted-records-to-index.json",
+                    "alma-2024-04-10-daily-extracted-records-to-index_04.json",
+                ],
+            ),
+            ["s3://X/alma-2024-04-10-daily-extracted-records-to-index_09.xml"],
+        ),
+    ],
+)
+def test_validate_output_error(ab_files, input_files):
     with pytest.raises(OutputValidationError):
-        validate_output(ab_transformed_file_lists=([], []), input_files_count=1)
+        validate_output(ab_transformed_file_lists=ab_files, input_files=input_files)
 
 
 def test_get_output_filename_success():
@@ -191,10 +258,10 @@
             "stage": "extracted",
             "action": "index",
             "index": None,
-            "file_type": "xml",
+            "file_type": "json",
         }
     )
-    == "source-2024-01-01-full-transformed-records-to-index.xml"
+    == "source-2024-01-01-full-transformed-records-to-index.json"
 )
 
 
@@ -297,8 +364,8 @@ def test_get_output_filename_indexed_success():
             "stage": "extracted",
             "action": "index",
             "index": "01",
-            "file_type": "xml",
+            "file_type": "json",
         }
     )
-    == "source-2024-01-01-full-transformed-records-to-index_01.xml"
+    == "source-2024-01-01-full-transformed-records-to-index_01.json"
 )