191 changes: 156 additions & 35 deletions abdiff/core/collate_ab_transforms.py
@@ -1,3 +1,5 @@
# ruff: noqa: TRY003

import itertools
import json
import logging
@@ -8,6 +10,7 @@

import duckdb
import ijson
import pandas as pd
import pyarrow as pa

from abdiff.core.exceptions import OutputValidationError
@@ -20,15 +23,21 @@
(
pa.field("timdex_record_id", pa.string()),
pa.field("source", pa.string()),
pa.field("run_date", pa.date32()),
pa.field("run_type", pa.string()),
pa.field("action", pa.string()),
pa.field("record", pa.binary()),
pa.field("version", pa.string()),
pa.field("transformed_file_name", pa.string()),
)
)
JOINED_DATASET_SCHEMA = pa.schema(
COLLATED_DATASET_SCHEMA = pa.schema(
(
pa.field("timdex_record_id", pa.string()),
pa.field("source", pa.string()),
pa.field("run_date", pa.date32()),
pa.field("run_type", pa.string()),
pa.field("action", pa.string()),
pa.field("record_a", pa.binary()),
pa.field("record_b", pa.binary()),
)
@@ -40,18 +49,19 @@ def collate_ab_transforms(
) -> str:
"""Collates A/B transformed files into a Parquet dataset.

This process can be summarized into two (2) important steps:
This process can be summarized into three (3) important steps:
1. Write all transformed JSON records into a temporary Parquet dataset
partitioned by the transformed file name.
2. For every transformed file, use DuckDB to join A/B Parquet tables
using the TIMDEX record ID and write joined records to a Parquet dataset.

This function (and its subfunctions) uses DuckDB, generators, and batching to
write records to Parquet datasets in a memory-efficient manner.
3. Dedupe joined records to ensure that only the most recent, non-deleted version
of each timdex_record_id is present in the final output.
"""
transformed_dataset_path = tempfile.TemporaryDirectory()
joined_dataset_path = tempfile.TemporaryDirectory()
collated_dataset_path = str(Path(run_directory) / "collated")

# build temporary transformed dataset
transformed_written_files = write_to_dataset(
get_transformed_batches_iter(run_directory, ab_transformed_file_lists),
schema=TRANSFORMED_DATASET_SCHEMA,
@@ -62,15 +72,30 @@
f"Wrote {len(transformed_written_files)} parquet file(s) to transformed dataset"
)

# build temporary joined dataset
joined_written_files = write_to_dataset(
get_joined_batches_iter(transformed_dataset_path.name),
base_dir=collated_dataset_path,
schema=JOINED_DATASET_SCHEMA,
base_dir=joined_dataset_path.name,
schema=COLLATED_DATASET_SCHEMA,
)
logger.info(f"Wrote {len(joined_written_files)} parquet file(s) to collated dataset")

# build final deduped and collated dataset
deduped_written_files = write_to_dataset(
get_deduped_batches_iter(joined_dataset_path.name),
base_dir=collated_dataset_path,
schema=COLLATED_DATASET_SCHEMA,
)
logger.info(
f"Wrote {len(deduped_written_files)} parquet file(s) to deduped collated dataset"
)

validate_output(collated_dataset_path)

# ensure temporary artifacts removed
transformed_dataset_path.cleanup()
joined_dataset_path.cleanup()

return collated_dataset_path


@@ -85,22 +110,44 @@ def get_transformed_records_iter(

* timdex_record_id: The TIMDEX record ID.
* source: The shorthand name of the source as denoted by Transmogrifier
(see https://github.com/MITLibraries/transmogrifier/blob/main/transmogrifier/config.py).
* run_date: The run date from the TIMDEX ETL run.
* run_type: "full" or "daily".
* action: "index" or "delete".
* record: The TIMDEX record serialized to a JSON string then encoded to bytes.
* version: The version of the transform, parsed from the absolute filepath to a
transformed file.
* transformed_file_name: The name of the transformed file, excluding file extension.
"""
version = get_transform_version(transformed_file)
filename_details = parse_timdex_filename(transformed_file)
with open(transformed_file, "rb") as file:
for record in ijson.items(file, "item"):

base_record = {
"source": filename_details["source"],
"run_date": filename_details["run-date"],
"run_type": filename_details["run-type"],
"action": filename_details["action"],
"version": version,
"transformed_file_name": transformed_file.split("/")[-1],
}
Comment on lines +124 to +131 (@ghukill, Contributor Author, Oct 30, 2024):
These fields in the collated dataset are shared between JSON and TXT files, so broken out here.


# handle JSON files with records to index
if transformed_file.endswith(".json"):
with open(transformed_file, "rb") as file:
for record in ijson.items(file, "item"):
yield {
**base_record,
"timdex_record_id": record["timdex_record_id"],
"record": json.dumps(record).encode(),
}

# handle TXT files with records to delete
else:
deleted_records_df = pd.read_csv(transformed_file, header=None)
Comment (@ghukill, Contributor Author):
Unlike a JSON file to iterate over, we just have a CSV with record IDs. So using pandas to quickly parse and loop through those values.

for row in deleted_records_df.itertuples():
yield {
"timdex_record_id": record["timdex_record_id"],
"source": filename_details["source"],
"record": json.dumps(record).encode(),
"version": version,
"transformed_file_name": transformed_file.split("/")[-1],
**base_record,
"timdex_record_id": row[1],
"record": None,
Comment (@ghukill, Contributor Author, Oct 30, 2024):
This is None, because we don't care about a deleted record's actual record body! If it's the last instance of that record in the run, then it will be removed entirely. Otherwise, the more recent version, which will have a record, will be utilized.

}


@@ -192,6 +239,9 @@ def get_joined_batches_iter(dataset_directory: str) -> Generator[pa.RecordBatch]
SELECT
COALESCE(a.timdex_record_id, b.timdex_record_id) timdex_record_id,
COALESCE(a.source, b.source) source,
COALESCE(a.run_date, b.run_date) run_date,
COALESCE(a.run_type, b.run_type) run_type,
COALESCE(a.action, b.action) "action",
a.record as record_a,
b.record as record_b
FROM a
@@ -210,15 +260,79 @@
break


def get_deduped_batches_iter(dataset_directory: str) -> Generator[pa.RecordBatch]:
"""Yield pyarrow.RecordBatch objects of deduped rows from the joined dataset.

ABDiff should be able to handle many input files, where a single timdex_record_id may
be duplicated across multiple files ("full" vs "daily" runs, incrementing date runs,
etc.)

This function writes the final dataset by deduping records from the temporary joined
dataset, using the following logic:
- use the MOST RECENT record based on 'run_date'
- if the MOST RECENT record is action='delete', then omit record entirely

The same mechanism as get_joined_batches_iter() is used: perform a DuckDB query, then
stream-write batches to a Parquet dataset. (See the toy example after this function.)
"""
with duckdb.connect(":memory:") as con:

results = con.execute(
"""
WITH collated as (
select * from read_parquet($collated_parquet_glob, hive_partitioning=true)
),
latest_records AS (
SELECT
*,
ROW_NUMBER() OVER (
PARTITION BY timdex_record_id
ORDER BY run_date DESC
) AS rn
FROM collated
),
deduped_records AS (
SELECT *
FROM latest_records
WHERE rn = 1 AND action != 'delete'
)
SELECT
timdex_record_id,
source,
run_date,
run_type,
action,
record_a,
record_b
FROM deduped_records;
""",
{
"collated_parquet_glob": f"{dataset_directory}/**/*.parquet",
},
).fetch_record_batch(READ_BATCH_SIZE)

while True:
try:
yield results.read_next_batch()
except StopIteration:
break # pragma: nocover
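
As a toy, self-contained illustration of the dedupe rule described in the docstring above (the table contents, record IDs, and dates here are made up for this sketch), the same window-function pattern can be run against an in-memory DuckDB table:

import duckdb

# Toy illustration only: keep the most recent row per timdex_record_id,
# and drop the ID entirely if that most recent row is a delete.
with duckdb.connect(":memory:") as con:
    con.execute(
        """
        CREATE TABLE collated AS
        SELECT * FROM (VALUES
            ('rec-1', DATE '2024-10-01', 'index'),
            ('rec-1', DATE '2024-10-02', 'delete'),
            ('rec-2', DATE '2024-10-01', 'index')
        ) AS t(timdex_record_id, run_date, action);
        """
    )
    kept = con.execute(
        """
        WITH latest AS (
            SELECT *, ROW_NUMBER() OVER (
                PARTITION BY timdex_record_id ORDER BY run_date DESC
            ) AS rn
            FROM collated
        )
        SELECT timdex_record_id FROM latest WHERE rn = 1 AND action != 'delete';
        """
    ).fetchall()
    print(kept)  # [('rec-2',)] -- rec-1's latest row is a delete, so it is omitted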


def validate_output(dataset_path: str) -> None:
"""Validate the output of collate_ab_transforms.

This function checks that the collated dataset is not empty, that neither the
'record_a' nor the 'record_b' column is entirely empty, and that 'timdex_record_id'
values are unique.
"""

def fetch_single_value(query: str) -> int:
result = con.execute(query).fetchone()
if result is None:
raise RuntimeError(f"Query returned no results: {query}") # pragma: nocover
return int(result[0])
Comment on lines +329 to +333 (@ghukill, Contributor Author):
I think it's possible this could even be a more global abdiff.core.utils helper function... but this felt acceptable for the time being.

Using fetchone()[0] certainly makes sense logically, but type hinting doesn't like it. This was an attempt to remove a handful of typing and ruff ignores.
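
For illustration, a minimal standalone sketch of the typing issue the helper works around (the SELECT 42 query and variable names are placeholders, not from the PR):

import duckdb

con = duckdb.connect(":memory:")
row = con.execute("SELECT 42").fetchone()
# fetchone() is annotated as returning an optional tuple, so indexing it directly
# needs a `# type: ignore[index]` or an explicit None guard -- which is what
# fetch_single_value() centralizes in one place.
value = int(row[0]) if row is not None else 0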


with duckdb.connect(":memory:") as con:
# create view of collated table
con.execute(
f"""
CREATE VIEW collated AS (
@@ -228,41 +342,48 @@ def validate_output(dataset_path: str) -> None:
)

# check if the table is empty
record_count = con.execute("SELECT COUNT(*) FROM collated").fetchone()[0] # type: ignore[index]
record_count = fetch_single_value("SELECT COUNT(*) FROM collated")
if record_count == 0:
raise OutputValidationError( # noqa: TRY003
raise OutputValidationError(
"The collated dataset does not contain any records."
)

# check if any of the 'record_*' columns are empty
record_a_null_count = con.execute(
record_a_null_count = fetch_single_value(
"SELECT COUNT(*) FROM collated WHERE record_a ISNULL"
).fetchone()[
0
] # type: ignore[index]

record_b_null_count = con.execute(
)
record_b_null_count = fetch_single_value(
"SELECT COUNT(*) FROM collated WHERE record_b ISNULL"
).fetchone()[
0
] # type: ignore[index]
)

if record_count in {record_a_null_count, record_b_null_count}:
raise OutputValidationError( # noqa: TRY003
raise OutputValidationError(
"At least one or both record column(s) ['record_a', 'record_b'] "
"in the collated dataset are empty."
)

# check that timdex_record_id column is unique
non_unique_count = fetch_single_value(
"""
SELECT COUNT(*)
FROM (
SELECT timdex_record_id
FROM collated
GROUP BY timdex_record_id
HAVING COUNT(*) > 1
) as duplicates;
"""
)
if non_unique_count > 0:
raise OutputValidationError(
"The collated dataset contains duplicate 'timdex_record_id' records."
)
Comment on lines +377 to +380 (Contributor):
Unclear if this is too much to ask (please push back if so), but it seems it would be helpful to know which timdex_record_id is duplicated.

Reply (@ghukill, Contributor Author):
I hear ya, but I would push back, for reasons of scale. If this were not working correctly, it's conceivable that tens, hundreds, or thousands of records could be duplicated. I would posit it's sufficient that if any records are duplicated, something is intrinsically wrong with the deduplication logic; it's that which needs attention, not a specific record.

Counter-point: we could show a sample of, say, 10 records, and then while debugging you could look for those. But I suppose my preference would be to skip that for now, unless we have trouble with this in the future.

Reply (Contributor):
Fine with me, the scale argument makes perfect sense! Agree there's no need to add unless we find it to be a problem.
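
Purely as a hypothetical follow-up (not part of this PR), if duplicate reporting is ever wanted, a small sample of offending IDs could be pulled with a query along these lines (the dataset path is a placeholder):

import duckdb

# Hypothetical sketch: sample up to 10 duplicated timdex_record_id values
# from a collated Parquet dataset.
dataset_path = "path/to/collated"  # placeholder
with duckdb.connect(":memory:") as con:
    sample = con.execute(
        f"""
        SELECT timdex_record_id, COUNT(*) AS occurrences
        FROM read_parquet('{dataset_path}/**/*.parquet')
        GROUP BY timdex_record_id
        HAVING COUNT(*) > 1
        LIMIT 10;
        """
    ).fetchall()
    print(sample)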



def get_transform_version(transformed_filepath: str) -> str:
"""Get A/B transform version, either 'a' or 'b'."""
match_result = re.match(
r".*transformed\/(.*)\/.*.json",
transformed_filepath,
)
match_result = re.match(r".*transformed\/(.*)\/.*", transformed_filepath)
if not match_result:
raise ValueError( # noqa: TRY003
f"Transformed filepath is invalid: {transformed_filepath}."
)
raise ValueError(f"Transformed filepath is invalid: {transformed_filepath}.")

return match_result.groups()[0]
53 changes: 34 additions & 19 deletions abdiff/core/run_ab_transforms.py
@@ -110,7 +110,7 @@ def run_ab_transforms(
"to complete successfully."
)
ab_transformed_file_lists = get_transformed_files(run_directory)
validate_output(ab_transformed_file_lists, len(input_files))
validate_output(ab_transformed_file_lists, input_files)

# write and return results
run_data = {
@@ -278,11 +278,11 @@ def get_transformed_files(run_directory: str) -> tuple[list[str], ...]:

Returns:
tuple[list[str]]: Tuple containing lists of paths to transformed
JSON files for each image, relative to 'run_directory'.
JSON and TXT (deletions) files for each image, relative to 'run_directory'.
"""
ordered_files = []
for version in ["a", "b"]:
absolute_filepaths = glob.glob(f"{run_directory}/transformed/{version}/*.json")
absolute_filepaths = glob.glob(f"{run_directory}/transformed/{version}/*")
relative_filepaths = [
os.path.relpath(file, run_directory) for file in absolute_filepaths
]
@@ -291,24 +291,39 @@


def validate_output(
ab_transformed_file_lists: tuple[list[str], ...], input_files_count: int
ab_transformed_file_lists: tuple[list[str], ...], input_files: list[str]
) -> None:
"""Validate the output of run_ab_transforms.

This function checks that the number of files in each of the A/B
transformed file directories matches the number of input files
provided to run_ab_transforms (i.e., the expected number of
files that are transformed).
Transmogrifier produces JSON files for records that need indexing, and TXT files for
records that need deletion. Every run of Transmogrifier should produce one OR both of
these. Some TIMDEX sources provide one file to Transmogrifier that contains both
records to index and delete, and others provide separate files for each.

The net effect for validation is that, given an input file, we should expect to see
1+ corresponding files in both the A and B output for that input file, regardless of
whether those files contain records to index or records to delete.
"""
if any(
len(transformed_files) != input_files_count
for transformed_files in ab_transformed_file_lists
):
raise OutputValidationError( # noqa: TRY003
"At least one or more transformed JSON file(s) are missing. "
f"Expecting {input_files_count} transformed JSON file(s) per A/B version. "
"Check the transformed file directories."
)
for input_file in input_files:
file_parts = parse_timdex_filename(input_file)
logger.debug(f"Validating output for input file root: {file_parts}")

file_found = False
for version_files in ab_transformed_file_lists:
for version_file in version_files:
if (
file_parts["source"] in version_file # type: ignore[operator]
and file_parts["run-date"] in version_file # type: ignore[operator]
and file_parts["run-type"] in version_file # type: ignore[operator]
and (not file_parts["index"] or file_parts["index"] in version_file)
):
Comment on lines +314 to +319 (@ghukill, Contributor Author, Oct 31, 2024):
This approach allows for avoiding finicky regex to see if the input file has artifacts in the output files.

For example, we'd want to know that alma-2024-10-02-full exists in at least one file in the A/B output files. But... for alma it could have an index underscore like ..._01.xml so we'd need to confirm that as well.

If we think of each output filename as containing nuggets of information like source, run-date, run-type, index, etc., then it kind of makes sense that we could look for pieces of information independently, but required in the same filename.

This is obviously kind of a loopy, naive approach to doing this, but the scale of this makes it inconsequential; we're looking at max 2-3k input files, against max 4-6k output files, making this a 1-2 second check tops.
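
As a made-up illustration of this membership check (the filename, the dict values, and the assumed keys returned by parse_timdex_filename are invented for this sketch, based on the fields referenced in the diff above):

# Hypothetical sketch of the "nuggets of information" check described above.
file_parts = {"source": "alma", "run-date": "2024-10-02", "run-type": "full", "index": "01"}
version_file = "transformed/a/alma-2024-10-02-full-transformed-records-to-index_01.json"

matched = (
    file_parts["source"] in version_file
    and file_parts["run-date"] in version_file
    and file_parts["run-type"] in version_file
    and (not file_parts["index"] or file_parts["index"] in version_file)
)
print(matched)  # True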

file_found = True
break

if not file_found:
raise OutputValidationError( # noqa: TRY003
f"Transmogrifier output was not found for input file '{input_file}'"
)


def get_transformed_filename(filename_details: dict) -> str:
@@ -318,13 +333,13 @@ def get_transformed_filename(filename_details: dict) -> str:
index=f"_{sequence}" if (sequence := filename_details["index"]) else "",
)
output_filename = (
"{source}-{run_date}-{run_type}-{stage}-records-to-index{index}.{file_type}"
"{source}-{run_date}-{run_type}-{stage}-records-to-{action}{index}.json"
)
return output_filename.format(
source=filename_details["source"],
run_date=filename_details["run-date"],
run_type=filename_details["run-type"],
stage=filename_details["stage"],
index=filename_details["index"],
file_type=filename_details["file_type"],
action=filename_details["action"],
)