diff --git a/.gitignore b/.gitignore index b7faf40..2efd350 100644 --- a/.gitignore +++ b/.gitignore @@ -205,3 +205,5 @@ cython_debug/ marimo/_static/ marimo/_lsp/ __marimo__/ + +.idea diff --git a/README.md b/README.md index d048aee..480cf54 100644 --- a/README.md +++ b/README.md @@ -1 +1,245 @@ -# gtfs_diff_engine \ No newline at end of file +# GTFS Diff Engine + +A memory-efficient Python library and CLI for comparing two GTFS feeds and producing a structured diff conforming to the [GTFS Diff v2 schema](https://github.com/MobilityData/gtfs_diff). + +## Overview + +GTFS Diff Engine compares two GTFS feeds (zip archives or directories) file-by-file and row-by-row, emitting a machine-readable JSON document that describes exactly what changed: which files were added or deleted, which columns appeared or disappeared, and which rows were inserted, removed, or modified (with before/after field values). + +The output conforms to the **GTFS Diff v2 schema** maintained by MobilityData: + + +## Features + +- **Memory-efficient streaming diff** — two-pass CSV indexing; no full in-memory table loads +- **Supports `.zip` archives and plain directories** — including zips with a single sub-directory layout +- **Row-level changes with primary key identification** — each change record includes the primary key fields for the affected row +- **Column-level change tracking** — columns added or deleted between feeds are reported with their original positions +- **Configurable row-changes cap** — limit output size per file; omitted changes are counted in a `Truncated` record +- **CLI and Python API** — use as a command-line tool or import directly in your code + +## Installation + +```bash +pip install gtfs-diff-engine +``` + +For a development (editable) install with test dependencies: + +```bash +git clone https://github.com/your-org/gtfs_diff_engine +cd gtfs_diff_engine +pip install -e ".[dev]" +``` + +## Quick Start + +```python +from gtfs_diff.engine import diff_feeds + +result = 
diff_feeds("base.zip", "new.zip") +print(result.summary.total_changes) + +# Save to JSON +with open("diff.json", "w") as f: + f.write(result.model_dump_json(indent=2)) +``` + +## CLI Usage + +``` +Usage: gtfs-diff [OPTIONS] BASE_FEED NEW_FEED + + Compare two GTFS feeds (zip or directory) and output a JSON diff. + + BASE_FEED: path to the base GTFS feed (zip or directory) + NEW_FEED: path to the new GTFS feed (zip or directory) + +Options: + --version Show the version and exit. + -o, --output FILE Write JSON output to FILE instead of stdout. + -c, --cap INTEGER Max row changes per file (0 = omit row-level + detail). + --pretty / --no-pretty Pretty-print JSON (default: --pretty). + --base-downloaded-at TEXT ISO 8601 datetime for when base was downloaded. + --new-downloaded-at TEXT ISO 8601 datetime for when new was downloaded. + --help Show this message and exit. +``` + +**Examples:** + +```bash +# Basic usage — print diff to stdout +gtfs-diff base.zip new.zip + +# Cap row changes to 500 per file +gtfs-diff --cap 500 base.zip new.zip + +# Save output to a file +gtfs-diff -o diff.json base.zip new.zip + +# Omit row-level detail (column diffs and counts are still computed) +gtfs-diff --cap 0 base.zip new.zip + +# With feed download timestamps +gtfs-diff --base-downloaded-at 2024-01-01T00:00:00Z \ + --new-downloaded-at 2024-06-01T00:00:00Z \ + base.zip new.zip +``` + +## Python API Reference + +### `diff_feeds()` + +```python +def diff_feeds( + base_path: str | Path, + new_path: str | Path, + row_changes_cap_per_file: int | None = None, + base_downloaded_at: datetime | None = None, + new_downloaded_at: datetime | None = None, +) -> GtfsDiff +``` + +| Parameter | Type | Description | +|---|---|---| +| `base_path` | `str \| Path` | Path to the base (old) GTFS feed — zip or directory | +| `new_path` | `str \| Path` | Path to the new GTFS feed — zip or directory | +| `row_changes_cap_per_file` | `int \| None` | `None` = include all; `0` = omit row detail; `N` = cap at N per 
file | +| `base_downloaded_at` | `datetime \| None` | When the base feed was downloaded (defaults to now) | +| `new_downloaded_at` | `datetime \| None` | When the new feed was downloaded (defaults to now) | + +**Returns:** a `GtfsDiff` Pydantic model with three top-level fields: + +| Field | Type | Description | +|---|---|---| +| `metadata` | `Metadata` | Schema version, timestamps, feed sources, unsupported files | +| `summary` | `Summary` | Aggregate counts of changed files, rows, columns | +| `file_diffs` | `list[FileDiff]` | Per-file diff records | + +## Supported GTFS Files + +| File | Primary Key | +|---|---| +| `agency.txt` | `agency_id` | +| `stops.txt` | `stop_id` | +| `routes.txt` | `route_id` | +| `trips.txt` | `trip_id` | +| `stop_times.txt` | `trip_id`, `stop_sequence` | +| `calendar.txt` | `service_id` | +| `calendar_dates.txt` | `service_id`, `date` | +| `fare_attributes.txt` | `fare_id` | +| `fare_rules.txt` | `fare_id`, `route_id`, `origin_id`, `destination_id`, `contains_id` | +| `shapes.txt` | `shape_id`, `shape_pt_sequence` | +| `frequencies.txt` | `trip_id`, `start_time` | +| `transfers.txt` | `from_stop_id`, `to_stop_id`, `from_route_id`, `to_route_id`, `from_trip_id`, `to_trip_id` | +| `pathways.txt` | `pathway_id` | +| `levels.txt` | `level_id` | +| `feed_info.txt` | *(all columns — single-row file)* | +| `translations.txt` | `table_name`, `field_name`, `language`, `record_id`, `record_sub_id`, `field_value` | +| `attributions.txt` | `attribution_id` | +| `areas.txt` | `area_id` | +| `stop_areas.txt` | `area_id`, `stop_id` | +| `networks.txt` | `network_id` | +| `route_networks.txt` | `route_id` | +| `fare_media.txt` | `fare_media_id` | +| `fare_products.txt` | `fare_product_id` | +| `fare_leg_rules.txt` | `leg_group_id` | +| `fare_transfer_rules.txt` | `from_leg_group_id`, `to_leg_group_id`, `transfer_count`, `duration_limit` | +| `timeframes.txt` | `timeframe_group_id`, `start_time`, `end_time`, `service_id` | +| `rider_categories.txt` | 
`rider_category_id` | +| `booking_rules.txt` | `booking_rule_id` | +| `location_groups.txt` | `location_group_id` | +| `location_group_stops.txt` | `location_group_id`, `stop_id` | + +Files not in this table (e.g. GeoJSON flex locations) are recorded in `metadata.unsupported_files` and skipped. + +## Output Schema + +The output follows the GTFS Diff v2 schema. Below is a minimal example: + +```json +{ + "metadata": { + "schema_version": "2.0", + "generated_at": "2024-06-01T12:00:00Z", + "row_changes_cap_per_file": null, + "base_feed": { "source": "base.zip", "downloaded_at": "2024-01-01T00:00:00Z" }, + "new_feed": { "source": "new.zip", "downloaded_at": "2024-06-01T00:00:00Z" }, + "unsupported_files": [] + }, + "summary": { + "total_changes": 3, + "files_added_count": 0, + "files_deleted_count": 0, + "files_modified_count": 1, + "files": [ + { + "file_name": "stops.txt", + "status": "modified", + "columns_added_count": 0, + "columns_deleted_count": 0, + "rows_added_count": 1, + "rows_deleted_count": 0, + "rows_modified_count": 2 + } + ] + }, + "file_diffs": [ + { + "file_name": "stops.txt", + "file_action": "modified", + "columns_added": [], + "columns_deleted": [], + "row_changes": { + "primary_key": ["stop_id"], + "columns": ["stop_id", "stop_name", "stop_lat", "stop_lon"], + "added": [ + { + "identifier": { "stop_id": "S999" }, + "raw_value": "S999,New Stop,48.8566,2.3522", + "new_line_number": 42 + } + ], + "deleted": [], + "modified": [ + { + "identifier": { "stop_id": "S001" }, + "raw_value": "S001,Central Station,48.8600,2.3470", + "base_line_number": 5, + "new_line_number": 5, + "field_changes": [ + { "field": "stop_name", "base_value": "Central Stn", "new_value": "Central Station" } + ] + } + ] + }, + "truncated": null + } + ] +} +``` + +## Memory Efficiency + +The engine uses a **streaming two-pass algorithm**: + +1. **Pass 1 (base feed):** stream the CSV line by line, building a `primary_key_tuple → (line_number, raw_csv_string)` in-memory index. +2. 
**Pass 2 (new feed):** same, producing a second index. +3. **Set arithmetic:** `added = new_keys − base_keys`, `deleted = base_keys − new_keys`, `common = intersection`. +4. **Modified detection:** for keys in `common`, parse the stored raw lines and compare only the *shared* columns — this avoids flagging every row as changed when a column is added or removed. + +Only the raw CSV strings are stored in the index (not parsed dicts), keeping memory proportional to the number of rows rather than rows × columns. + +> **Note:** For very large feeds (`stop_times.txt` with 10 M+ rows) the in-memory index may become a bottleneck. A disk-backed index (e.g. SQLite) would be more appropriate for production deployments at that scale; that optimisation is left as future work. + +## Running Tests + +```bash +pip install -e ".[dev]" +pytest tests/ -v +``` + +## License + +See [LICENSE](LICENSE). diff --git a/docs/architecture.md b/docs/architecture.md new file mode 100644 index 0000000..aab3127 --- /dev/null +++ b/docs/architecture.md @@ -0,0 +1,125 @@ +# Architecture + +## Design Goals + +- **Schema compliance** — all output conforms to the GTFS Diff v2 schema () +- **Memory efficiency** — stream CSV files rather than loading entire tables; store only compact raw-CSV strings in the key index +- **Clear public API** — a single `diff_feeds()` function returns a typed Pydantic model, suitable for programmatic use or JSON serialisation + +## Module Structure + +| Module | Responsibility | +|---|---| +| `engine.py` | Core diff logic: feed opener, CSV indexing, per-file diff, public `diff_feeds()` function | +| `models.py` | Pydantic v2 data models for the GTFS Diff v2 output format (`GtfsDiff`, `FileDiff`, `RowChanges`, etc.) 
| +| `gtfs_definitions.py` | Static registry of supported GTFS files and their primary key columns; `get_primary_key()` helper | +| `cli.py` | Click-based CLI entry point (`gtfs-diff`); thin wrapper around `diff_feeds()` | + +## Streaming Algorithm + +The per-file diff (`_diff_file`) operates in two streaming passes: + +### Pass 1 — index the base feed + +``` +for each row in base CSV: + pk_tuple = tuple(row[col] for col in pk_columns) + index[pk_tuple] = (line_number, raw_csv_string) +``` + +Only the compact `raw_csv_string` (re-serialised with `csv.writer`) is stored — not the parsed dict — to keep memory proportional to row count rather than row count × column count. + +### Pass 2 — index the new feed + +Identical pass over the new CSV, producing a second index. + +### Set arithmetic + +``` +added_keys = new_keys − base_keys +deleted_keys = base_keys − new_keys +common_keys = base_keys ∩ new_keys +``` + +### Modified detection + +For each key in `common_keys`, the stored raw lines are parsed back to dicts and compared **on shared columns only**: + +```python +shared_cols = [col for col in base_headers if col in new_header_set] +field_changes = [ + FieldChange(field=col, base_value=b_dict[col], new_value=n_dict[col]) + for col in shared_cols + if b_dict[col] != n_dict[col] +] +``` + +Comparing only shared columns ensures that adding or removing a column from a file does not cause every existing row to appear as modified. + +## Handling Edge Cases + +### Files with no explicit primary key + +`feed_info.txt` is a single-row file and `agency.txt` allows omitting `agency_id` when there is only one agency. Both are defined with an empty primary key list (`[]`) in `GTFS_PRIMARY_KEYS`. 
When the engine encounters an empty PK definition it falls back to using **all base-feed columns** as a composite key: + +```python +if len(pk_def) == 0: + pk_cols = initial_base_headers # all columns form the key +``` + +### BOM handling + +All files are opened with `encoding="utf-8-sig"`, which transparently strips the UTF-8 byte-order mark that some GTFS producers include. + +### Malformed / short rows + +Rows with fewer columns than the header are padded with empty strings; rows wider than the header are trimmed: + +```python +if len(row) < n: + row = row + [""] * (n - len(row)) +row_vals = row[:n] +``` + +### Zip archives with a sub-directory layout + +Some producers wrap all `.txt` files inside a single subdirectory within the zip. The feed opener handles this by mapping `basename → internal_path` regardless of path depth: + +```python +basename = member.rsplit("/", 1)[-1] +``` + +## Cap and Truncation + +When `row_changes_cap_per_file` is set to a positive integer `N`, row changes are filled in priority order until the cap is reached: + +1. **Added** rows (up to `N`) +2. **Deleted** rows (up to `N − len(added)`) +3. **Modified** rows (up to `N − len(added) − len(deleted)`) + +When `cap=0`, row-level detail is omitted entirely (column diffs and true change counts are still computed). + +If the true total exceeds the cap, a `Truncated` record is attached: + +```json +"truncated": { "is_truncated": true, "omitted_count": 4321 } +``` + +The `omitted_count` reflects the number of row changes that were detected but not included in the output. + +## Column Union Ordering + +When a file gains or loses columns between feeds, the `row_changes.columns` list uses a **union** ordering: + +``` +union_columns = base_headers + [col for col in new_headers if col not in base_header_set] +``` + +Base columns appear first (preserving their original order), followed by any new-only columns appended at the end. 
This ordering is used to align `raw_value` strings for both added and modified rows, so consumers can parse `raw_value` with a single consistent header list. + +## Limitations and Future Work + +- **Disk-backed index for huge feeds** — `stop_times.txt` can exceed 10 M rows; the in-memory index should be replaced with a SQLite-backed approach for production deployments at that scale. +- **Parallel file processing** — files within a feed are currently processed sequentially; parallel workers (e.g. `concurrent.futures.ThreadPoolExecutor`) could reduce wall-clock time for feeds with many files. +- **GeoJSON / Flex location support** — `locations.geojson` and other non-CSV GTFS Flex files are not CSV and are currently reported as unsupported. Dedicated diff logic for these formats is left as future work. + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..64ab3c1 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,28 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "gtfs-diff-engine" +version = "0.1.0" +description = "A diff engine for GTFS (General Transit Feed Specification) feeds" +requires-python = ">=3.10" +dependencies = [ + "pydantic>=2.0", + "click>=8.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0", + "pytest-cov", +] + +[project.scripts] +gtfs-diff = "gtfs_diff.cli:main" + +[tool.hatch.build.targets.wheel] +packages = ["src/gtfs_diff"] + +[tool.pytest.ini_options] +testpaths = ["tests"] diff --git a/scripts/compare_feeds.sh b/scripts/compare_feeds.sh new file mode 100755 index 0000000..b0afce1 --- /dev/null +++ b/scripts/compare_feeds.sh @@ -0,0 +1,126 @@ +#!/usr/bin/env bash +# compare_feeds.sh — Download two GTFS feeds by URL and diff them. 
+# +# Usage: +# ./scripts/compare_feeds.sh [OPTIONS] +# +# Arguments: +# BASE_URL URL of the base (old) GTFS zip feed +# NEW_URL URL of the new GTFS zip feed +# +# Options: +# -o, --output FILE Write JSON diff to FILE (default: stdout) +# -c, --cap INT Max row changes per file (default: no cap; 0 = counts only) +# --no-pretty Compact JSON output (default: pretty-printed) +# -h, --help Show this help message +# +# Examples: +# ./scripts/compare_feeds.sh \ +# https://example.com/gtfs-jan.zip \ +# https://example.com/gtfs-feb.zip +# +# ./scripts/compare_feeds.sh \ +# https://example.com/gtfs-jan.zip \ +# https://example.com/gtfs-feb.zip \ +# --output diff.json --cap 1000 + +set -euo pipefail + +# --------------------------------------------------------------------------- # +# Helpers +# --------------------------------------------------------------------------- # + +usage() { + sed -n '2,/^$/p' "$0" | sed 's/^# \{0,1\}//' + exit 0 +} + +die() { echo "ERROR: $*" >&2; exit 1; } + +require_cmd() { + command -v "$1" &>/dev/null || die "'$1' is required but not installed." +} + +# --------------------------------------------------------------------------- # +# Argument parsing +# --------------------------------------------------------------------------- # + +BASE_URL="" +NEW_URL="" +OUTPUT_FILE="" +CAP="" +PRETTY="--pretty" + +while [[ $# -gt 0 ]]; do + case "$1" in + -h|--help) usage ;; + -o|--output) OUTPUT_FILE="$2"; shift 2 ;; + -c|--cap) CAP="$2"; shift 2 ;; + --no-pretty) PRETTY="--no-pretty"; shift ;; + -*) die "Unknown option: $1" ;; + *) + if [[ -z "$BASE_URL" ]]; then BASE_URL="$1" + elif [[ -z "$NEW_URL" ]]; then NEW_URL="$1" + else die "Unexpected argument: $1" + fi + shift ;; + esac +done + +[[ -n "$BASE_URL" ]] || die "BASE_URL is required.\nRun with --help for usage." +[[ -n "$NEW_URL" ]] || die "NEW_URL is required.\nRun with --help for usage." 
+ +# --------------------------------------------------------------------------- # +# Dependency checks +# --------------------------------------------------------------------------- # + +require_cmd curl +require_cmd gtfs-diff + +# --------------------------------------------------------------------------- # +# Temp directory (cleaned up on exit) +# --------------------------------------------------------------------------- # + +WORK_DIR="$(mktemp -d)" +trap 'rm -rf "$WORK_DIR"' EXIT + +BASE_ZIP="$WORK_DIR/base.zip" +NEW_ZIP="$WORK_DIR/new.zip" + +# --------------------------------------------------------------------------- # +# Download feeds +# --------------------------------------------------------------------------- # + +echo "⬇ Downloading base feed..." >&2 +curl -fsSL --retry 3 --retry-delay 2 -o "$BASE_ZIP" "$BASE_URL" \ + || die "Failed to download base feed: $BASE_URL" + +echo "⬇ Downloading new feed..." >&2 +curl -fsSL --retry 3 --retry-delay 2 -o "$NEW_ZIP" "$NEW_URL" \ + || die "Failed to download new feed: $NEW_URL" + +# Basic sanity check — both must look like zip files +file "$BASE_ZIP" | grep -qi "zip" || die "Base feed does not appear to be a valid zip: $BASE_URL" +file "$NEW_ZIP" | grep -qi "zip" || die "New feed does not appear to be a valid zip: $NEW_URL" + +echo "✔ Both feeds downloaded." >&2 + +# --------------------------------------------------------------------------- # +# Build gtfs-diff command +# --------------------------------------------------------------------------- # + +CMD=(gtfs-diff "$BASE_ZIP" "$NEW_ZIP" "$PRETTY") + +[[ -n "$CAP" ]] && CMD+=(--cap "$CAP") +[[ -n "$OUTPUT_FILE" ]] && CMD+=(--output "$OUTPUT_FILE") + +# --------------------------------------------------------------------------- # +# Run diff +# --------------------------------------------------------------------------- # + +echo "🔍 Running diff..." 
>&2 +"${CMD[@]}" + +if [[ -n "$OUTPUT_FILE" ]]; then + echo "✔ Diff written to: $OUTPUT_FILE" >&2 +fi diff --git a/src/gtfs_diff/__init__.py b/src/gtfs_diff/__init__.py new file mode 100644 index 0000000..7561aa6 --- /dev/null +++ b/src/gtfs_diff/__init__.py @@ -0,0 +1,3 @@ +"""GTFS Diff Engine - compare two GTFS feeds and surface structured differences.""" + +__version__ = "0.1.0" diff --git a/src/gtfs_diff/__main__.py b/src/gtfs_diff/__main__.py new file mode 100644 index 0000000..a992e06 --- /dev/null +++ b/src/gtfs_diff/__main__.py @@ -0,0 +1,6 @@ +"""Entry point for `python -m gtfs_diff`.""" + +from gtfs_diff.cli import main + +if __name__ == "__main__": + main() diff --git a/src/gtfs_diff/cli.py b/src/gtfs_diff/cli.py new file mode 100644 index 0000000..6111560 --- /dev/null +++ b/src/gtfs_diff/cli.py @@ -0,0 +1,68 @@ +"""Click-based CLI entry point for gtfs-diff.""" + +import sys +from datetime import datetime +from pathlib import Path + +import click + +from gtfs_diff.engine import diff_feeds + + +@click.command() +@click.version_option(version="0.1.0", prog_name="gtfs-diff-engine") +@click.argument("base_feed", type=click.Path(exists=True, path_type=Path)) +@click.argument("new_feed", type=click.Path(exists=True, path_type=Path)) +@click.option("--output", "-o", type=click.Path(path_type=Path), default=None, + help="Write JSON output to FILE instead of stdout.") +@click.option("--cap", "-c", type=int, default=None, + help="Max row changes per file (0 = omit row-level detail).") +@click.option("--pretty/--no-pretty", default=True, + help="Pretty-print JSON (default: --pretty).") +@click.option("--base-downloaded-at", default=None, + help="ISO 8601 datetime for when base was downloaded.") +@click.option("--new-downloaded-at", default=None, + help="ISO 8601 datetime for when new was downloaded.") +def main( + base_feed: Path, + new_feed: Path, + output: Path | None, + cap: int | None, + pretty: bool, + base_downloaded_at: str | None, + new_downloaded_at: str 
| None, +) -> None: + """Compare two GTFS feeds (zip or directory) and output a JSON diff. + + BASE_FEED: path to the base GTFS feed (zip or directory)\n + NEW_FEED: path to the new GTFS feed (zip or directory) + """ + try: + base_dt = datetime.fromisoformat(base_downloaded_at) if base_downloaded_at else None + new_dt = datetime.fromisoformat(new_downloaded_at) if new_downloaded_at else None + except ValueError as exc: + click.echo(f"Error: {exc}", err=True) + sys.exit(1) + + try: + result = diff_feeds( + base_path=base_feed, + new_path=new_feed, + row_changes_cap_per_file=cap, + base_downloaded_at=base_dt, + new_downloaded_at=new_dt, + ) + except Exception as exc: + click.echo(f"Error: {exc}", err=True) + sys.exit(1) + + json_str = result.model_dump_json(indent=2) if pretty else result.model_dump_json() + + if output is not None: + try: + output.write_text(json_str, encoding="utf-8") + except OSError as exc: + click.echo(f"Error writing output: {exc}", err=True) + sys.exit(1) + else: + click.echo(json_str) diff --git a/src/gtfs_diff/engine.py b/src/gtfs_diff/engine.py new file mode 100644 index 0000000..011f570 --- /dev/null +++ b/src/gtfs_diff/engine.py @@ -0,0 +1,526 @@ +"""Core diff engine: loads two GTFS feeds and computes structured differences. + +Memory note +----------- +The two-pass algorithm builds in-memory indexes mapping primary-key tuples to +(line_number, raw_csv_string) for every row in each file. For typical transit +feeds this is fine. For very large feeds (stop_times.txt can exceed 10 M rows) +a disk-backed index (e.g. SQLite) would be more appropriate; that is left as a +future optimisation. 
+""" + +from __future__ import annotations + +import csv +import io +import zipfile +from contextlib import contextmanager +from datetime import datetime, timezone +from pathlib import Path +from typing import Callable, Generator, TextIO + +from .gtfs_definitions import get_primary_key +from .models import ( + ColumnEntry, + FeedSource, + FieldChange, + FileDiff, + FileSummary, + GtfsDiff, + Metadata, + RowAdded, + RowChanges, + RowDeleted, + RowModified, + Summary, + Truncated, + UnsupportedFile, +) + +SCHEMA_VERSION = "2.0" + +# A "lazy opener" maps a filename (e.g. "stops.txt") to a zero-arg callable +# that opens the file and returns a text stream. +LazyOpeners = dict[str, Callable[[], TextIO]] + + +# --------------------------------------------------------------------------- +# Low-level CSV helpers +# --------------------------------------------------------------------------- + +def _row_to_csv(values: list[str]) -> str: + """Serialise a list of string values to a single CSV line (no trailing newline).""" + buf = io.StringIO() + writer = csv.writer(buf, lineterminator="") + writer.writerow(values) + return buf.getvalue() + + +def _read_headers(text_io: TextIO) -> list[str]: + """Read only the header row from a CSV stream, stripping whitespace.""" + reader = csv.reader(text_io) + try: + row = next(reader) + return [h.strip() for h in row] + except StopIteration: + return [] + + +def _read_csv_index( + text_io: TextIO, + pk_columns: list[str] | None = None, +) -> tuple[list[str], dict[tuple, tuple[int, str]]]: + """Stream a CSV file and build a primary-key → (line_number, raw_csv_string) index. + + Args: + text_io: Open text stream for the CSV file (utf-8-sig recommended). + pk_columns: Columns that form the primary key. `None` / empty list + means use *all* columns as the composite key. + + Returns: + headers: Stripped column names from the header row. + index: Maps `pk_tuple` → `(line_number, raw_csv_string)`. 
+ Line numbers are 1-based; the header row is line 1, so the + first data row is line 2. + """ + reader = csv.reader(text_io) + try: + raw_headers = next(reader) + except StopIteration: + return [], {} + + headers = [h.strip() for h in raw_headers] + n = len(headers) + effective_pk = pk_columns if pk_columns else headers + + index: dict[tuple, tuple[int, str]] = {} + for line_num, row in enumerate(reader, start=2): + # Pad short rows; trim rows wider than the header (malformed CSV safety). + if len(row) < n: + row = row + [""] * (n - len(row)) + row_vals = row[:n] + row_dict = dict(zip(headers, row_vals)) + pk_tuple = tuple(row_dict.get(col, "") for col in effective_pk) + index[pk_tuple] = (line_num, _row_to_csv(row_vals)) + + return headers, index + + +def _parse_raw_line(raw_line: str, headers: list[str]) -> dict[str, str]: + """Deserialise a raw CSV string (as stored in the index) back to a row dict.""" + reader = csv.reader(io.StringIO(raw_line)) + try: + row = next(reader) + except StopIteration: + return {col: "" for col in headers} + if len(row) < len(headers): + row = row + [""] * (len(headers) - len(row)) + return dict(zip(headers, row)) + + +def _compute_raw_value( + row_dict: dict[str, str], + columns: list[str], + present_headers: set[str], +) -> str: + """Build an ordered CSV string aligned to the *union* columns list. + + Columns absent from ``present_headers`` (i.e. the file this row came from + did not have that column) are rendered as empty strings. + """ + values = [row_dict[col] if col in present_headers else "" for col in columns] + return _row_to_csv(values) + + +# --------------------------------------------------------------------------- +# Feed opener +# --------------------------------------------------------------------------- + +@contextmanager +def _open_feed(path: str | Path) -> Generator[LazyOpeners, None, None]: + """Open a GTFS feed (zip archive or directory) and yield lazy file openers. 
+ + Each entry in the returned dict is a zero-arg callable that, when called, + opens the corresponding ``.txt`` file and returns a utf-8-sig text stream. + Callers are responsible for closing each stream. + + Supports: + * ``.zip`` archives (files at root *or* inside a single sub-directory). + * Plain directories containing ``.txt`` files. + """ + path = Path(path) + + if path.is_dir(): + openers: LazyOpeners = {} + for txt_file in sorted(path.glob("*.txt")): + def _make_opener(p: Path) -> Callable[[], TextIO]: + return lambda: p.open(encoding="utf-8-sig") + openers[txt_file.name] = _make_opener(txt_file) + yield openers + + elif zipfile.is_zipfile(path): + zf = zipfile.ZipFile(path, "r") + try: + # Accept both root-level files and files inside exactly one subdirectory. + name_map: dict[str, str] = {} + for member in zf.namelist(): + if member.endswith(".txt"): + basename = member.rsplit("/", 1)[-1] + if basename: + name_map[basename] = member + + openers = {} + for basename, internal_path in name_map.items(): + def _make_opener(ip: str) -> Callable[[], TextIO]: # type: ignore[misc] + return lambda: io.TextIOWrapper(zf.open(ip), encoding="utf-8-sig") + openers[basename] = _make_opener(internal_path) + yield openers + finally: + zf.close() + + else: + raise ValueError( + f"Unsupported feed path: {path!r}. Must be a .zip file or a directory." + ) + + +# --------------------------------------------------------------------------- +# Per-file diff +# --------------------------------------------------------------------------- + +def _diff_file( + file_name: str, + base_opener: Callable[[], TextIO] | None, + new_opener: Callable[[], TextIO] | None, + row_changes_cap: int | None, +) -> tuple[FileDiff, FileSummary]: + """Compute the complete diff for one GTFS file. + + Returns a (FileDiff, FileSummary) pair. 
+ """ + # ------------------------------------------------------------------ + # File added (only in new feed) + # ------------------------------------------------------------------ + if base_opener is None: + assert new_opener is not None + with new_opener() as f: + new_headers = _read_headers(f) + columns_added = [ + ColumnEntry(name=col, position=i + 1) + for i, col in enumerate(new_headers) + ] + file_diff = FileDiff( + file_name=file_name, + file_action="added", + columns_added=columns_added, + columns_deleted=[], + row_changes=None, + truncated=None, + ) + summary = FileSummary( + file_name=file_name, + status="added", + columns_added_count=len(columns_added), + columns_deleted_count=0, + ) + return file_diff, summary + + # ------------------------------------------------------------------ + # File deleted (only in base feed) + # ------------------------------------------------------------------ + if new_opener is None: + with base_opener() as f: + base_headers = _read_headers(f) + columns_deleted = [ + ColumnEntry(name=col, position=i + 1) + for i, col in enumerate(base_headers) + ] + file_diff = FileDiff( + file_name=file_name, + file_action="deleted", + columns_added=[], + columns_deleted=columns_deleted, + row_changes=None, + truncated=None, + ) + summary = FileSummary( + file_name=file_name, + status="deleted", + columns_added_count=0, + columns_deleted_count=len(columns_deleted), + ) + return file_diff, summary + + # ------------------------------------------------------------------ + # File present in both feeds + # ------------------------------------------------------------------ + pk_def = get_primary_key(file_name) + assert pk_def is not None # caller guarantees supported files only + + # For files with an empty PK definition, use all base columns as the key. 
+ if len(pk_def) == 0: + with base_opener() as f: + initial_base_headers = _read_headers(f) + pk_cols: list[str] = initial_base_headers + else: + pk_cols = pk_def + + # Build indexes (two streaming passes, one per file). + with base_opener() as f: + base_headers, base_index = _read_csv_index(f, pk_cols) + + with new_opener() as f: + new_headers, new_index = _read_csv_index(f, pk_cols) + + # Column-level diff + base_header_set = set(base_headers) + new_header_set = set(new_headers) + + columns_added = [ + ColumnEntry(name=col, position=i + 1) + for i, col in enumerate(new_headers) + if col not in base_header_set + ] + columns_deleted = [ + ColumnEntry(name=col, position=i + 1) + for i, col in enumerate(base_headers) + if col not in new_header_set + ] + + # Union columns: base order first, then new-only columns appended. + new_only_cols = [col for col in new_headers if col not in base_header_set] + union_columns: list[str] = base_headers + new_only_cols + + # ------------------------------------------------------------------ + # Row-level diff + # ------------------------------------------------------------------ + base_keys = set(base_index) + new_keys = set(new_index) + + added_keys = new_keys - base_keys + deleted_keys = base_keys - new_keys + common_keys = base_keys & new_keys + + # For modified-row detection, compare only the *shared* columns to avoid + # flagging every row when a column is added to / removed from the file. + shared_cols = [col for col in base_headers if col in new_header_set] + + added_rows: list[RowAdded] = [] + deleted_rows: list[RowDeleted] = [] + modified_rows: list[RowModified] = [] + + # Collect true counts before applying the cap. + true_added = len(added_keys) + true_deleted = len(deleted_keys) + + # We count true modified while building; detect first to get true count. 
+ modified_candidates: list[tuple[tuple, list[FieldChange], int, int]] = [] + + for pk_tuple in common_keys: + b_line, b_raw = base_index[pk_tuple] + n_line, n_raw = new_index[pk_tuple] + b_dict = _parse_raw_line(b_raw, base_headers) + n_dict = _parse_raw_line(n_raw, new_headers) + + field_changes = [ + FieldChange(field=col, base_value=b_dict[col], new_value=n_dict[col]) + for col in shared_cols + if b_dict.get(col, "") != n_dict.get(col, "") + ] + if field_changes: + modified_candidates.append((pk_tuple, field_changes, b_line, n_line)) + + true_modified = len(modified_candidates) + + # Determine row-changes output based on cap. + include_row_changes: bool + if row_changes_cap == 0: + include_row_changes = False + else: + include_row_changes = True + + truncated: Truncated | None = None + + if include_row_changes: + cap = row_changes_cap # None = unlimited + + def _remaining(used: int) -> int | None: + return None if cap is None else max(0, cap - used) + + # Fill added rows up to cap. + added_limit = _remaining(0) + for pk_tuple in list(added_keys)[:added_limit]: + n_line, n_raw = new_index[pk_tuple] + n_dict = _parse_raw_line(n_raw, new_headers) + identifier = {col: n_dict.get(col, "") for col in pk_cols} + raw_value = _compute_raw_value(n_dict, union_columns, new_header_set) + added_rows.append( + RowAdded(identifier=identifier, raw_value=raw_value, new_line_number=n_line) + ) + + # Fill deleted rows up to remaining cap. + deleted_limit = _remaining(len(added_rows)) + for pk_tuple in list(deleted_keys)[:deleted_limit]: + b_line, b_raw = base_index[pk_tuple] + b_dict = _parse_raw_line(b_raw, base_headers) + identifier = {col: b_dict.get(col, "") for col in pk_cols} + raw_value = _compute_raw_value(b_dict, union_columns, base_header_set) + deleted_rows.append( + RowDeleted(identifier=identifier, raw_value=raw_value, base_line_number=b_line) + ) + + # Fill modified rows up to remaining cap. 
+ modified_limit = _remaining(len(added_rows) + len(deleted_rows)) + for pk_tuple, field_changes, b_line, n_line in modified_candidates[:modified_limit]: + n_raw = new_index[pk_tuple][1] + n_dict = _parse_raw_line(n_raw, new_headers) + identifier = {col: n_dict.get(col, "") for col in pk_cols} + raw_value = _compute_raw_value(n_dict, union_columns, new_header_set) + modified_rows.append( + RowModified( + identifier=identifier, + raw_value=raw_value, + base_line_number=b_line, + new_line_number=n_line, + field_changes=field_changes, + ) + ) + + total_included = len(added_rows) + len(deleted_rows) + len(modified_rows) + total_true = true_added + true_deleted + true_modified + if cap is not None and total_true > cap: + truncated = Truncated(is_truncated=True, omitted_count=total_true - total_included) + + row_changes: RowChanges | None = None + if include_row_changes: + # Use pk_cols for the primary_key field; for empty-pk files that means + # all base columns — which is correct (they form the composite key). 
+ row_changes = RowChanges( + primary_key=pk_cols, + columns=union_columns, + added=added_rows, + deleted=deleted_rows, + modified=modified_rows, + ) + + file_diff = FileDiff( + file_name=file_name, + file_action="modified", + columns_added=columns_added, + columns_deleted=columns_deleted, + row_changes=row_changes, + truncated=truncated, + ) + summary = FileSummary( + file_name=file_name, + status="modified", + columns_added_count=len(columns_added), + columns_deleted_count=len(columns_deleted), + rows_added_count=true_added, + rows_deleted_count=true_deleted, + rows_modified_count=true_modified, + ) + return file_diff, summary + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def diff_feeds( + base_path: str | Path, + new_path: str | Path, + row_changes_cap_per_file: int | None = None, + base_downloaded_at: datetime | None = None, + new_downloaded_at: datetime | None = None, +) -> GtfsDiff: + """Compare two GTFS feeds and return a structured :class:`GtfsDiff` result. + + Args: + base_path: Path to the base (old) GTFS feed — zip or directory. + new_path: Path to the new GTFS feed — zip or directory. + row_changes_cap_per_file: + * ``None`` — include all row changes (default). + * ``0`` — omit all row-level detail (column diffs and counts still computed). + * ``N > 0`` — include up to *N* row changes per file (added first, then + deleted, then modified); a :class:`Truncated` record is + attached when the true count exceeds *N*. + base_downloaded_at: When the base feed was downloaded; defaults to *now*. + new_downloaded_at: When the new feed was downloaded; defaults to *now*. + + Returns: + A fully-populated :class:`GtfsDiff` instance. 
+ """ + now = datetime.now(timezone.utc) + base_downloaded_at = base_downloaded_at or now + new_downloaded_at = new_downloaded_at or now + + file_diffs: list[FileDiff] = [] + file_summaries: list[FileSummary] = [] + unsupported_files: list[UnsupportedFile] = [] + + with _open_feed(base_path) as base_openers, _open_feed(new_path) as new_openers: + all_files = sorted(set(base_openers) | set(new_openers)) + + for file_name in all_files: + in_base = file_name in base_openers + in_new = file_name in new_openers + + # Determine support status. + pk_def = get_primary_key(file_name) + if pk_def is None: + # Unsupported file — record it and skip. + if in_base and in_new: + present_in: str = "both" + elif in_base: + present_in = "base" + else: + present_in = "new" + unsupported_files.append( + UnsupportedFile(file_name=file_name, present_in=present_in) # type: ignore[arg-type] + ) + continue + + base_opener = base_openers.get(file_name) + new_opener = new_openers.get(file_name) + + file_diff, file_summary = _diff_file( + file_name=file_name, + base_opener=base_opener, + new_opener=new_opener, + row_changes_cap=row_changes_cap_per_file, + ) + file_diffs.append(file_diff) + file_summaries.append(file_summary) + + # Build summary aggregates. 
+ files_added = sum(1 for s in file_summaries if s.status == "added") + files_deleted = sum(1 for s in file_summaries if s.status == "deleted") + files_modified = sum(1 for s in file_summaries if s.status == "modified") + + total_changes = ( + sum(s.rows_added_count or 0 for s in file_summaries) + + sum(s.rows_deleted_count or 0 for s in file_summaries) + + sum(s.rows_modified_count or 0 for s in file_summaries) + + sum(s.columns_added_count or 0 for s in file_summaries) + + sum(s.columns_deleted_count or 0 for s in file_summaries) + + files_added + + files_deleted + ) + + metadata = Metadata( + schema_version=SCHEMA_VERSION, + generated_at=now, + row_changes_cap_per_file=row_changes_cap_per_file, + base_feed=FeedSource(source=str(base_path), downloaded_at=base_downloaded_at), + new_feed=FeedSource(source=str(new_path), downloaded_at=new_downloaded_at), + unsupported_files=unsupported_files, + ) + summary = Summary( + total_changes=total_changes, + files_added_count=files_added, + files_deleted_count=files_deleted, + files_modified_count=files_modified, + files=file_summaries, + ) + return GtfsDiff(metadata=metadata, summary=summary, file_diffs=file_diffs) diff --git a/src/gtfs_diff/gtfs_definitions.py b/src/gtfs_diff/gtfs_definitions.py new file mode 100644 index 0000000..1fd0853 --- /dev/null +++ b/src/gtfs_diff/gtfs_definitions.py @@ -0,0 +1,65 @@ +""" +GTFS file definitions, primary keys, and schema constants. + +Design decisions: +- Primary keys follow the official GTFS Static specification. Where the spec defines + a composite primary key, all constituent columns are listed. +- Files with an empty primary key list ([]) have no single identifying column (e.g. + feed_info.txt is a single-row file, agency.txt allows omitting agency_id when there + is only one agency). For these, the diff engine falls back to treating ALL columns + collectively as the row identity. +- Files with no known primary key or that are not CSV (e.g. 
geojson locations) are + simply absent from GTFS_PRIMARY_KEYS and therefore absent from SUPPORTED_FILES. + Callers can check ``get_primary_key`` returning None to detect unsupported files. +""" + +GTFS_PRIMARY_KEYS: dict[str, list[str]] = { + # Core required files + "agency.txt": ["agency_id"], + "stops.txt": ["stop_id"], + "routes.txt": ["route_id"], + "trips.txt": ["trip_id"], + "stop_times.txt": ["trip_id", "stop_sequence"], + # Conditionally required + "calendar.txt": ["service_id"], + "calendar_dates.txt": ["service_id", "date"], + # Fares v1 + "fare_attributes.txt": ["fare_id"], + "fare_rules.txt": ["fare_id", "route_id", "origin_id", "destination_id", "contains_id"], + # Shapes / frequencies / transfers + "shapes.txt": ["shape_id", "shape_pt_sequence"], + "frequencies.txt": ["trip_id", "start_time"], + "transfers.txt": ["from_stop_id", "to_stop_id", "from_route_id", "to_route_id", "from_trip_id", "to_trip_id"], + # Pathways / levels + "pathways.txt": ["pathway_id"], + "levels.txt": ["level_id"], + # Feed metadata + "feed_info.txt": [], # single-row file, no primary key + "translations.txt": ["table_name", "field_name", "language", "record_id", "record_sub_id", "field_value"], + "attributions.txt": ["attribution_id"], + # Areas + "areas.txt": ["area_id"], + "stop_areas.txt": ["area_id", "stop_id"], + # Networks + "networks.txt": ["network_id"], + "route_networks.txt": ["route_id"], + # Fares v2 + "fare_media.txt": ["fare_media_id"], + "fare_products.txt": ["fare_product_id"], + "fare_leg_rules.txt": ["leg_group_id"], # partial key, best effort + "fare_transfer_rules.txt": ["from_leg_group_id", "to_leg_group_id", "transfer_count", "duration_limit"], + "timeframes.txt": ["timeframe_group_id", "start_time", "end_time", "service_id"], + # Rider categories / booking + "rider_categories.txt": ["rider_category_id"], + "booking_rules.txt": ["booking_rule_id"], + # Location groups (flex) + "location_groups.txt": ["location_group_id"], + "location_group_stops.txt": 
["location_group_id", "stop_id"], +} + +SUPPORTED_FILES: set[str] = set(GTFS_PRIMARY_KEYS) + + +def get_primary_key(file_name: str) -> list[str] | None: + """Return the primary key columns for a supported GTFS file, or None if unsupported.""" + return GTFS_PRIMARY_KEYS.get(file_name) diff --git a/src/gtfs_diff/models.py b/src/gtfs_diff/models.py new file mode 100644 index 0000000..6665e5e --- /dev/null +++ b/src/gtfs_diff/models.py @@ -0,0 +1,178 @@ +"""Pydantic v2 models for the GTFS Diff v2 output format.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Literal, Optional + +from pydantic import BaseModel, ConfigDict, Field + + +class ColumnEntry(BaseModel): + """A column that was added or deleted, with its name and original position.""" + + model_config = ConfigDict(populate_by_name=True) + + name: str + position: int = Field(..., ge=1) + + +class FeedSource(BaseModel): + """Identifies a GTFS feed by its source URL and the time it was downloaded.""" + + model_config = ConfigDict(populate_by_name=True) + + source: str + downloaded_at: datetime + + +class UnsupportedFile(BaseModel): + """A file present in one or both feeds that the diff engine does not support.""" + + model_config = ConfigDict(populate_by_name=True) + + file_name: str + present_in: Literal["base", "new", "both"] + + +class Metadata(BaseModel): + """Top-level metadata describing the diff run and both feed sources.""" + + model_config = ConfigDict(populate_by_name=True) + + schema_version: str + generated_at: datetime + row_changes_cap_per_file: Optional[int] = Field(None, ge=0) + base_feed: FeedSource + new_feed: FeedSource + unsupported_files: list[UnsupportedFile] + + +class FileSummary(BaseModel): + """High-level change counts for a single GTFS file.""" + + model_config = ConfigDict(populate_by_name=True) + + file_name: str + status: Literal["added", "deleted", "modified"] + columns_added_count: Optional[int] = Field(None, ge=0) + 
columns_deleted_count: Optional[int] = Field(None, ge=0) + rows_added_count: Optional[int] = Field(None, ge=0) + rows_deleted_count: Optional[int] = Field(None, ge=0) + rows_modified_count: Optional[int] = Field(None, ge=0) + + +class Summary(BaseModel): + """Aggregate change counts across all GTFS files in the diff.""" + + model_config = ConfigDict(populate_by_name=True) + + total_changes: int = Field(..., ge=0) + files_added_count: int = Field(..., ge=0) + files_deleted_count: int = Field(..., ge=0) + files_modified_count: int = Field(..., ge=0) + files: list[FileSummary] + + +class FieldChange(BaseModel): + """The before and after values for a single field within a modified row.""" + + model_config = ConfigDict(populate_by_name=True) + + field: str + base_value: str + new_value: str + + +class RowAdded(BaseModel): + """A row that exists only in the new feed.""" + + model_config = ConfigDict(populate_by_name=True) + + identifier: dict[str, str] + raw_value: str + new_line_number: int = Field(..., ge=1) + + +class RowDeleted(BaseModel): + """A row that exists only in the base feed.""" + + model_config = ConfigDict(populate_by_name=True) + + identifier: dict[str, str] + raw_value: str + base_line_number: int = Field(..., ge=1) + + +class RowModified(BaseModel): + """A row present in both feeds whose field values differ.""" + + model_config = ConfigDict(populate_by_name=True) + + identifier: dict[str, str] + raw_value: str + base_line_number: int = Field(..., ge=1) + new_line_number: int = Field(..., ge=1) + field_changes: list[FieldChange] = Field(..., min_length=1) + + +class RowChanges(BaseModel): + """All row-level changes for a file, keyed by primary key columns.""" + + model_config = ConfigDict(populate_by_name=True) + + primary_key: list[str] = Field(..., min_length=1) + columns: list[str] + added: list[RowAdded] + deleted: list[RowDeleted] + modified: list[RowModified] + + +class Truncated(BaseModel): + """Indicates that row changes were capped and some were 
omitted.""" + + model_config = ConfigDict(populate_by_name=True) + + is_truncated: Literal[True] + omitted_count: int = Field(..., ge=1) + + +class FileDiff(BaseModel): + """Complete diff for a single GTFS file, including column and row changes.""" + + model_config = ConfigDict(populate_by_name=True) + + file_name: str + file_action: Literal["modified", "added", "deleted"] + columns_added: list[ColumnEntry] + columns_deleted: list[ColumnEntry] + row_changes: Optional[RowChanges] = None + truncated: Optional[Truncated] = None + + +class GtfsDiff(BaseModel): + """Root model for the GTFS Diff v2 output format.""" + + model_config = ConfigDict(populate_by_name=True) + + metadata: Metadata + summary: Summary + file_diffs: list[FileDiff] + + +__all__ = [ + "ColumnEntry", + "FeedSource", + "UnsupportedFile", + "Metadata", + "FileSummary", + "Summary", + "FieldChange", + "RowAdded", + "RowDeleted", + "RowModified", + "RowChanges", + "Truncated", + "FileDiff", + "GtfsDiff", +] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..275e894 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test suite for gtfs_diff.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1932d91 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,42 @@ +"""Shared pytest fixtures for the gtfs_diff test suite.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from tests.helpers import write_zip + + +# --------------------------------------------------------------------------- +# Reusable fixtures +# --------------------------------------------------------------------------- + +MINIMAL_BASE_FILES = { + "stops.txt": ( + "stop_id,stop_name,stop_lat,stop_lon\n" + "S1,Stop One,1.0,2.0\n" + "S2,Stop Two,3.0,4.0\n" + ), +} + +MINIMAL_NEW_FILES = { + "stops.txt": ( + "stop_id,stop_name,stop_lat,stop_lon\n" + "S1,Stop One,1.0,2.0\n" + "S3,Stop Three,5.0,6.0\n" + ), +} + + +@pytest.fixture +def 
minimal_base_zip(tmp_path: Path) -> Path: + """Return the path to a small GTFS zip for use in integration tests.""" + return write_zip(tmp_path / "base.zip", MINIMAL_BASE_FILES) + + +@pytest.fixture +def minimal_new_zip(tmp_path: Path) -> Path: + """Return the path to a small GTFS zip for use in integration tests.""" + return write_zip(tmp_path / "new.zip", MINIMAL_NEW_FILES) diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 0000000..8e36543 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,22 @@ +"""Shared helper functions for building in-memory GTFS zip archives.""" + +from __future__ import annotations + +import io +import zipfile +from pathlib import Path + + +def make_gtfs_zip(files: dict[str, str]) -> bytes: + """Build an in-memory GTFS zip archive from a {filename: csv_content} dict.""" + buf = io.BytesIO() + with zipfile.ZipFile(buf, mode="w", compression=zipfile.ZIP_DEFLATED) as zf: + for name, content in files.items(): + zf.writestr(name, content) + return buf.getvalue() + + +def write_zip(path: Path, files: dict[str, str]) -> Path: + """Write a GTFS zip to *path* and return the path.""" + path.write_bytes(make_gtfs_zip(files)) + return path diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..d0868ba --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,114 @@ +"""Tests for gtfs_diff.cli.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest +from click.testing import CliRunner + +from gtfs_diff.cli import main + +from tests.helpers import write_zip + +STOPS_HEADER = "stop_id,stop_name,stop_lat,stop_lon\n" + +TINY_BASE_FILES = {"stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n"} +TINY_NEW_FILES = { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n" +} + + +def _make_feeds(tmp_path: Path): + base = write_zip(tmp_path / "base.zip", TINY_BASE_FILES) + new = write_zip(tmp_path / "new.zip", TINY_NEW_FILES) + return base, new + + +class 
TestHelp: + def test_help(self): + runner = CliRunner() + result = runner.invoke(main, ["--help"]) + assert result.exit_code == 0 + assert "Usage" in result.output or "usage" in result.output.lower() + + +class TestVersion: + def test_version(self): + runner = CliRunner() + result = runner.invoke(main, ["--version"]) + assert result.exit_code == 0 + assert "0.1.0" in result.output + + +class TestDiffToStdout: + def test_diff_to_stdout(self, tmp_path: Path): + base, new = _make_feeds(tmp_path) + runner = CliRunner() + result = runner.invoke(main, [str(base), str(new)]) + assert result.exit_code == 0, result.output + data = json.loads(result.output) + assert "metadata" in data + assert "summary" in data + assert "file_diffs" in data + + +class TestDiffToFile: + def test_diff_to_file(self, tmp_path: Path): + base, new = _make_feeds(tmp_path) + out_file = tmp_path / "result.json" + runner = CliRunner() + result = runner.invoke(main, [str(base), str(new), "--output", str(out_file)]) + assert result.exit_code == 0, result.output + assert out_file.exists() + data = json.loads(out_file.read_text(encoding="utf-8")) + assert "metadata" in data + + +class TestCapOption: + def test_cap_zero_produces_null_row_changes(self, tmp_path: Path): + base, new = _make_feeds(tmp_path) + runner = CliRunner() + result = runner.invoke(main, [str(base), str(new), "--cap", "0"]) + assert result.exit_code == 0, result.output + data = json.loads(result.output) + for fd in data["file_diffs"]: + assert fd["row_changes"] is None + + def test_cap_stored_in_metadata(self, tmp_path: Path): + base, new = _make_feeds(tmp_path) + runner = CliRunner() + result = runner.invoke(main, [str(base), str(new), "--cap", "5"]) + assert result.exit_code == 0, result.output + data = json.loads(result.output) + assert data["metadata"]["row_changes_cap_per_file"] == 5 + + +class TestInvalidPath: + def test_invalid_path_exits_nonzero(self, tmp_path: Path): + base = tmp_path / "nonexistent_base.zip" + new = tmp_path / 
"nonexistent_new.zip" + runner = CliRunner() + result = runner.invoke(main, [str(base), str(new)]) + assert result.exit_code != 0 + + +class TestPrettyJson: + def test_pretty_output_is_indented(self, tmp_path: Path): + base, new = _make_feeds(tmp_path) + runner = CliRunner() + result = runner.invoke(main, [str(base), str(new), "--pretty"]) + assert result.exit_code == 0, result.output + # Indented JSON contains newlines beyond just the top level + assert "\n" in result.output + assert " " in result.output # indentation present + + def test_no_pretty_output_is_compact(self, tmp_path: Path): + base, new = _make_feeds(tmp_path) + runner = CliRunner() + result = runner.invoke(main, [str(base), str(new), "--no-pretty"]) + assert result.exit_code == 0, result.output + # Should be valid JSON even without pretty-printing + data = json.loads(result.output) + assert "metadata" in data diff --git a/tests/test_engine.py b/tests/test_engine.py new file mode 100644 index 0000000..69258b0 --- /dev/null +++ b/tests/test_engine.py @@ -0,0 +1,493 @@ +"""Tests for gtfs_diff.engine — diff_feeds().""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from gtfs_diff.engine import diff_feeds +from gtfs_diff.models import GtfsDiff + +from tests.helpers import write_zip + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _get_file_diff(result: GtfsDiff, file_name: str): + """Return the FileDiff for a given file name, or raise.""" + for fd in result.file_diffs: + if fd.file_name == file_name: + return fd + raise KeyError(f"{file_name!r} not found in file_diffs") + + +def _get_file_summary(result: GtfsDiff, file_name: str): + for fs in result.summary.files: + if fs.file_name == file_name: + return fs + raise KeyError(f"{file_name!r} not found in summary.files") + + +STOPS_HEADER = "stop_id,stop_name,stop_lat,stop_lon\n" + + +# 
--------------------------------------------------------------------------- +# File-level tests +# --------------------------------------------------------------------------- + +class TestFileAdded: + def test_file_added(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + "routes.txt": "route_id,route_short_name\nR1,Route 1\n", + }) + result = diff_feeds(base, new) + fd = _get_file_diff(result, "routes.txt") + assert fd.file_action == "added" + assert len(fd.columns_added) == 2 + column_names = [c.name for c in fd.columns_added] + assert "route_id" in column_names + assert "route_short_name" in column_names + + def test_file_added_summary_status(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + "routes.txt": "route_id,route_short_name\nR1,Route 1\n", + }) + result = diff_feeds(base, new) + fs = _get_file_summary(result, "routes.txt") + assert fs.status == "added" + assert result.summary.files_added_count == 1 + + +class TestFileDeleted: + def test_file_deleted(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + "routes.txt": "route_id,route_short_name\nR1,Route 1\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }) + result = diff_feeds(base, new) + fd = _get_file_diff(result, "routes.txt") + assert fd.file_action == "deleted" + assert len(fd.columns_deleted) == 2 + column_names = [c.name for c in fd.columns_deleted] + assert "route_id" in column_names + + def test_file_deleted_summary_status(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": STOPS_HEADER + 
"S1,Stop One,1.0,2.0\n", + "routes.txt": "route_id,route_short_name\nR1,Route 1\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }) + result = diff_feeds(base, new) + fs = _get_file_summary(result, "routes.txt") + assert fs.status == "deleted" + assert result.summary.files_deleted_count == 1 + + +class TestIdenticalFeeds: + def test_identical_feeds(self, tmp_path: Path): + content = STOPS_HEADER + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n" + base = write_zip(tmp_path / "base.zip", {"stops.txt": content}) + new = write_zip(tmp_path / "new.zip", {"stops.txt": content}) + result = diff_feeds(base, new) + fd = _get_file_diff(result, "stops.txt") + assert fd.file_action == "modified" + assert fd.row_changes is not None + assert fd.row_changes.added == [] + assert fd.row_changes.deleted == [] + assert fd.row_changes.modified == [] + + +# --------------------------------------------------------------------------- +# Row-level tests +# --------------------------------------------------------------------------- + +class TestRowsAdded: + def test_rows_added_count(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", + }) + result = diff_feeds(base, new) + fs = _get_file_summary(result, "stops.txt") + assert fs.rows_added_count == 1 + + def test_rows_added_identifier(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", + }) + result = diff_feeds(base, new) + fd = _get_file_diff(result, "stops.txt") + assert len(fd.row_changes.added) == 1 + assert fd.row_changes.added[0].identifier == {"stop_id": "S2"} + + +class TestRowsDeleted: + def 
test_rows_deleted_count(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }) + result = diff_feeds(base, new) + fs = _get_file_summary(result, "stops.txt") + assert fs.rows_deleted_count == 1 + + def test_rows_deleted_identifier(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }) + result = diff_feeds(base, new) + fd = _get_file_diff(result, "stops.txt") + assert len(fd.row_changes.deleted) == 1 + assert fd.row_changes.deleted[0].identifier == {"stop_id": "S2"} + + +class TestRowsModified: + def test_rows_modified_count(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One Renamed,1.0,2.0\n", + }) + result = diff_feeds(base, new) + fs = _get_file_summary(result, "stops.txt") + assert fs.rows_modified_count == 1 + + def test_rows_modified_field_changes(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One Renamed,1.0,2.0\n", + }) + result = diff_feeds(base, new) + fd = _get_file_diff(result, "stops.txt") + assert len(fd.row_changes.modified) == 1 + mod = fd.row_changes.modified[0] + assert mod.identifier == {"stop_id": "S1"} + field_names = [fc.field for fc in mod.field_changes] + assert "stop_name" in field_names + stop_name_change = next(fc for fc in mod.field_changes if fc.field == "stop_name") + assert stop_name_change.base_value == "Stop One" + assert 
stop_name_change.new_value == "Stop One Renamed" + + +class TestNoFalsePositives: + def test_no_false_positives_on_unchanged_rows(self, tmp_path: Path): + content = ( + STOPS_HEADER + + "S1,Stop One,1.0,2.0\n" + + "S2,Stop Two,3.0,4.0\n" + + "S3,Stop Three,5.0,6.0\n" + ) + base = write_zip(tmp_path / "base.zip", {"stops.txt": content}) + # only S3 changes + new_content = ( + STOPS_HEADER + + "S1,Stop One,1.0,2.0\n" + + "S2,Stop Two,3.0,4.0\n" + + "S3,Stop Three RENAMED,5.0,6.0\n" + ) + new = write_zip(tmp_path / "new.zip", {"stops.txt": new_content}) + result = diff_feeds(base, new) + fd = _get_file_diff(result, "stops.txt") + modified_ids = [m.identifier["stop_id"] for m in fd.row_changes.modified] + assert "S1" not in modified_ids + assert "S2" not in modified_ids + assert "S3" in modified_ids + + +# --------------------------------------------------------------------------- +# Column-level tests +# --------------------------------------------------------------------------- + +class TestColumnAdded: + def test_column_added(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": "stop_id,stop_name\nS1,Stop One\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": "stop_id,stop_name,stop_desc\nS1,Stop One,A description\n", + }) + result = diff_feeds(base, new) + fd = _get_file_diff(result, "stops.txt") + added_names = [c.name for c in fd.columns_added] + assert "stop_desc" in added_names + + def test_column_added_position(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": "stop_id,stop_name\nS1,Stop One\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": "stop_id,stop_name,stop_desc\nS1,Stop One,A description\n", + }) + result = diff_feeds(base, new) + fd = _get_file_diff(result, "stops.txt") + stop_desc_col = next(c for c in fd.columns_added if c.name == "stop_desc") + assert stop_desc_col.position == 3 + + +class TestColumnDeleted: + def test_column_deleted(self, tmp_path: Path): + 
base = write_zip(tmp_path / "base.zip", {
+            "stops.txt": "stop_id,stop_name,stop_desc\nS1,Stop One,A description\n",
+        })
+        new = write_zip(tmp_path / "new.zip", {
+            "stops.txt": "stop_id,stop_name\nS1,Stop One\n",
+        })
+        result = diff_feeds(base, new)
+        fd = _get_file_diff(result, "stops.txt")
+        deleted_names = [c.name for c in fd.columns_deleted]
+        assert "stop_desc" in deleted_names
+
+    def test_column_deleted_position(self, tmp_path: Path):
+        base = write_zip(tmp_path / "base.zip", {
+            "stops.txt": "stop_id,stop_name,stop_desc\nS1,Stop One,A description\n",
+        })
+        new = write_zip(tmp_path / "new.zip", {
+            "stops.txt": "stop_id,stop_name\nS1,Stop One\n",
+        })
+        result = diff_feeds(base, new)
+        fd = _get_file_diff(result, "stops.txt")
+        stop_desc_col = next(c for c in fd.columns_deleted if c.name == "stop_desc")
+        assert stop_desc_col.position == 3
+
+
+# ---------------------------------------------------------------------------
+# Cap tests
+# ---------------------------------------------------------------------------
+
+def _make_stops_csv(n: int) -> str:
+    header = "stop_id,stop_name\n"
+    rows = "".join(f"S{i},Stop {i}\n" for i in range(1, n + 1))
+    return header + rows
+
+
+class TestCapZero:
+    def test_cap_zero_omits_row_changes(self, tmp_path: Path):
+        base = write_zip(tmp_path / "base.zip", {
+            "stops.txt": "stop_id,stop_name\nS1,Stop One\n",
+        })
+        new = write_zip(tmp_path / "new.zip", {
+            "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two\n",
+        })
+        result = diff_feeds(base, new, row_changes_cap_per_file=0)
+        fd = _get_file_diff(result, "stops.txt")
+        assert fd.row_changes is None
+
+
+class TestCapLimits:
+    def test_cap_limits_row_changes(self, tmp_path: Path):
+        # 6 rows added (S1–S6) and 1 deleted (S0), cap = 3 → 3 added included, omitted_count = 4
+        base = write_zip(tmp_path / "base.zip", {
+            "stops.txt": "stop_id,stop_name\nS0,Stop Zero\n",
+        })
+        new = write_zip(tmp_path / "new.zip", {
+            "stops.txt": "stop_id,stop_name\n" + "".join(f"S{i},Stop {i}\n" for i in range(1, 
7)), + }) + result = diff_feeds(base, new, row_changes_cap_per_file=3) + fd = _get_file_diff(result, "stops.txt") + # Total included across added+deleted+modified <= 3 + total_included = ( + len(fd.row_changes.added) + + len(fd.row_changes.deleted) + + len(fd.row_changes.modified) + ) + assert total_included <= 3 + assert fd.truncated is not None + assert fd.truncated.is_truncated is True + assert fd.truncated.omitted_count >= 1 + + def test_truncated_omitted_count_correct(self, tmp_path: Path): + # 5 added rows, 0 deleted, 0 modified; cap = 3 → omitted = 2 + base = write_zip(tmp_path / "base.zip", { + "stops.txt": _make_stops_csv(0).replace("\n", "", 1), # header only + }) + # Write header-only base + base = write_zip(tmp_path / "base2.zip", { + "stops.txt": "stop_id,stop_name\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": _make_stops_csv(5), + }) + result = diff_feeds(base, new, row_changes_cap_per_file=3) + fd = _get_file_diff(result, "stops.txt") + assert len(fd.row_changes.added) == 3 + assert fd.truncated.omitted_count == 2 + + +class TestCapNone: + def test_cap_none_includes_all(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": "stop_id,stop_name\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": _make_stops_csv(5), + }) + result = diff_feeds(base, new, row_changes_cap_per_file=None) + fd = _get_file_diff(result, "stops.txt") + assert len(fd.row_changes.added) == 5 + assert fd.truncated is None + + +# --------------------------------------------------------------------------- +# Unsupported files +# --------------------------------------------------------------------------- + +class TestUnsupportedFile: + def test_unsupported_file_in_metadata(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + "custom_data.txt": "foo,bar\n1,2\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop 
One,1.0,2.0\n", + "custom_data.txt": "foo,bar\n1,2\n", + }) + result = diff_feeds(base, new) + unsupported_names = [u.file_name for u in result.metadata.unsupported_files] + assert "custom_data.txt" in unsupported_names + + def test_unsupported_file_present_in_both(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + "custom_data.txt": "foo,bar\n1,2\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + "custom_data.txt": "foo,bar\n1,2\n", + }) + result = diff_feeds(base, new) + uf = next(u for u in result.metadata.unsupported_files if u.file_name == "custom_data.txt") + assert uf.present_in == "both" + + +# --------------------------------------------------------------------------- +# Schema validation (model round-trip) +# --------------------------------------------------------------------------- + +class TestOutputMatchesSchema: + def test_output_matches_schema(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": STOPS_HEADER + "S1,Stop One,1.0,2.0\nS2,Stop Two,3.0,4.0\n", + }) + result = diff_feeds(base, new) + json_str = result.model_dump_json() + restored = GtfsDiff.model_validate_json(json_str) + assert restored.summary.total_changes >= 1 + assert restored.metadata.schema_version == "2.0" + + +# --------------------------------------------------------------------------- +# raw_value column ordering +# --------------------------------------------------------------------------- + +class TestRawValueColumnOrder: + def test_raw_value_has_empty_for_base_only_column(self, tmp_path: Path): + # base has stop_lat; new does NOT have stop_lat (deleted column). + # Added rows in new should have empty string for stop_lat in raw_value. 
+ base = write_zip(tmp_path / "base.zip", { + "stops.txt": "stop_id,stop_name,stop_lat\nS1,Stop One,1.0\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two\n", + }) + result = diff_feeds(base, new) + fd = _get_file_diff(result, "stops.txt") + # S2 is an added row; union_columns = [stop_id, stop_name, stop_lat] + # stop_lat is not in new_header_set → should be empty string + assert len(fd.row_changes.added) == 1 + added = fd.row_changes.added[0] + assert added.identifier == {"stop_id": "S2"} + # Parse raw_value CSV + import csv, io + row = next(csv.reader(io.StringIO(added.raw_value))) + # union_columns = base_headers + new_only_cols = [stop_id, stop_name, stop_lat] + # stop_lat not in new → empty + assert row[2] == "" # stop_lat column + + def test_union_columns_order_base_first(self, tmp_path: Path): + # New feed has an extra column; union_columns must be base_headers + new_only + base = write_zip(tmp_path / "base.zip", { + "stops.txt": "stop_id,stop_name\nS1,Stop One\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": "stop_id,stop_name,stop_desc\nS1,Stop One,Desc One\nS2,Stop Two,Desc Two\n", + }) + result = diff_feeds(base, new) + fd = _get_file_diff(result, "stops.txt") + assert fd.row_changes.columns == ["stop_id", "stop_name", "stop_desc"] + + +# --------------------------------------------------------------------------- +# Line numbers +# --------------------------------------------------------------------------- + +class TestLineNumbers: + def test_added_row_line_number(self, tmp_path: Path): + # Header = line 1, first data row = line 2, second data row = line 3 + base = write_zip(tmp_path / "base.zip", { + "stops.txt": "stop_id,stop_name\nS1,Stop One\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two\n", + }) + result = diff_feeds(base, new) + fd = _get_file_diff(result, "stops.txt") + assert len(fd.row_changes.added) == 1 + 
assert fd.row_changes.added[0].new_line_number == 3 + + def test_deleted_row_line_number(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": "stop_id,stop_name\nS1,Stop One\n", + }) + result = diff_feeds(base, new) + fd = _get_file_diff(result, "stops.txt") + assert len(fd.row_changes.deleted) == 1 + assert fd.row_changes.deleted[0].base_line_number == 3 + + def test_modified_row_line_numbers(self, tmp_path: Path): + base = write_zip(tmp_path / "base.zip", { + "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two\n", + }) + new = write_zip(tmp_path / "new.zip", { + "stops.txt": "stop_id,stop_name\nS1,Stop One\nS2,Stop Two RENAMED\n", + }) + result = diff_feeds(base, new) + fd = _get_file_diff(result, "stops.txt") + assert len(fd.row_changes.modified) == 1 + mod = fd.row_changes.modified[0] + assert mod.base_line_number == 3 + assert mod.new_line_number == 3 diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..1878322 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,248 @@ +"""Tests for gtfs_diff.models.""" + +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest +from pydantic import ValidationError + +from gtfs_diff.models import ( + ColumnEntry, + FieldChange, + FileDiff, + FileSummary, + GtfsDiff, + Metadata, + FeedSource, + RowAdded, + RowChanges, + RowDeleted, + RowModified, + Summary, + Truncated, + UnsupportedFile, +) + +NOW = datetime(2024, 1, 1, tzinfo=timezone.utc) + + +# --------------------------------------------------------------------------- +# Helpers to build valid model instances +# --------------------------------------------------------------------------- + +def _feed_source(url: str = "http://example.com/feed.zip") -> FeedSource: + return FeedSource(source=url, downloaded_at=NOW) + + +def _metadata(**kwargs) -> 
Metadata: + defaults = dict( + schema_version="2.0", + generated_at=NOW, + row_changes_cap_per_file=None, + base_feed=_feed_source("http://base"), + new_feed=_feed_source("http://new"), + unsupported_files=[], + ) + defaults.update(kwargs) + return Metadata(**defaults) + + +def _file_diff(**kwargs) -> FileDiff: + defaults = dict( + file_name="stops.txt", + file_action="modified", + columns_added=[], + columns_deleted=[], + row_changes=None, + truncated=None, + ) + defaults.update(kwargs) + return FileDiff(**defaults) + + +def _summary(**kwargs) -> Summary: + defaults = dict( + total_changes=0, + files_added_count=0, + files_deleted_count=0, + files_modified_count=0, + files=[], + ) + defaults.update(kwargs) + return Summary(**defaults) + + +def _gtfs_diff(**kwargs) -> GtfsDiff: + defaults = dict( + metadata=_metadata(), + summary=_summary(), + file_diffs=[], + ) + defaults.update(kwargs) + return GtfsDiff(**defaults) + + +# --------------------------------------------------------------------------- +# GtfsDiff round-trip +# --------------------------------------------------------------------------- + +class TestGtfsDiffRoundTrip: + def test_round_trip_empty(self): + obj = _gtfs_diff() + dumped = obj.model_dump() + restored = GtfsDiff.model_validate(dumped) + assert restored == obj + + def test_round_trip_with_file_diff(self): + file_diff = _file_diff( + file_action="added", + columns_added=[ColumnEntry(name="stop_id", position=1)], + ) + obj = _gtfs_diff( + summary=_summary(files_added_count=1, total_changes=1), + file_diffs=[file_diff], + ) + restored = GtfsDiff.model_validate(obj.model_dump()) + assert restored.file_diffs[0].file_action == "added" + + def test_round_trip_json(self): + obj = _gtfs_diff() + json_str = obj.model_dump_json() + restored = GtfsDiff.model_validate_json(json_str) + assert restored.metadata.schema_version == "2.0" + + +# --------------------------------------------------------------------------- +# ColumnEntry +# 
--------------------------------------------------------------------------- + +class TestColumnEntry: + def test_valid(self): + col = ColumnEntry(name="stop_id", position=1) + assert col.name == "stop_id" + assert col.position == 1 + + def test_position_zero_rejected(self): + with pytest.raises(ValidationError): + ColumnEntry(name="stop_id", position=0) + + def test_position_negative_rejected(self): + with pytest.raises(ValidationError): + ColumnEntry(name="stop_id", position=-5) + + def test_position_one_accepted(self): + col = ColumnEntry(name="stop_id", position=1) + assert col.position == 1 + + +# --------------------------------------------------------------------------- +# RowChanges +# --------------------------------------------------------------------------- + +class TestRowChanges: + def test_valid(self): + rc = RowChanges( + primary_key=["stop_id"], + columns=["stop_id", "stop_name"], + added=[], + deleted=[], + modified=[], + ) + assert rc.primary_key == ["stop_id"] + + def test_empty_primary_key_rejected(self): + with pytest.raises(ValidationError): + RowChanges( + primary_key=[], + columns=["stop_id"], + added=[], + deleted=[], + modified=[], + ) + + +# --------------------------------------------------------------------------- +# RowModified +# --------------------------------------------------------------------------- + +class TestRowModified: + def test_valid(self): + rm = RowModified( + identifier={"stop_id": "S1"}, + raw_value="S1,Old Name", + base_line_number=2, + new_line_number=2, + field_changes=[ + FieldChange(field="stop_name", base_value="Old", new_value="New") + ], + ) + assert len(rm.field_changes) == 1 + + def test_empty_field_changes_rejected(self): + with pytest.raises(ValidationError): + RowModified( + identifier={"stop_id": "S1"}, + raw_value="S1", + base_line_number=2, + new_line_number=2, + field_changes=[], + ) + + +# --------------------------------------------------------------------------- +# FileSummary +# 
--------------------------------------------------------------------------- + +class TestFileSummary: + def test_all_optional_counts_none(self): + fs = FileSummary(file_name="stops.txt", status="modified") + assert fs.rows_added_count is None + assert fs.rows_deleted_count is None + assert fs.rows_modified_count is None + assert fs.columns_added_count is None + assert fs.columns_deleted_count is None + + def test_with_counts(self): + fs = FileSummary( + file_name="stops.txt", + status="modified", + rows_added_count=3, + rows_deleted_count=1, + rows_modified_count=0, + ) + assert fs.rows_added_count == 3 + + +# --------------------------------------------------------------------------- +# Truncated +# --------------------------------------------------------------------------- + +class TestTruncated: + def test_valid(self): + t = Truncated(is_truncated=True, omitted_count=5) + assert t.is_truncated is True + assert t.omitted_count == 5 + + def test_is_truncated_must_be_true(self): + with pytest.raises(ValidationError): + Truncated(is_truncated=False, omitted_count=1) # type: ignore[arg-type] + + def test_omitted_count_must_be_positive(self): + with pytest.raises(ValidationError): + Truncated(is_truncated=True, omitted_count=0) + + +# --------------------------------------------------------------------------- +# UnsupportedFile +# --------------------------------------------------------------------------- + +class TestUnsupportedFile: + @pytest.mark.parametrize("present_in", ["base", "new", "both"]) + def test_valid_present_in(self, present_in: str): + uf = UnsupportedFile(file_name="custom.txt", present_in=present_in) # type: ignore[arg-type] + assert uf.present_in == present_in + + def test_invalid_present_in(self): + with pytest.raises(ValidationError): + UnsupportedFile(file_name="custom.txt", present_in="neither") # type: ignore[arg-type]