
Add Spark-side trip validation helper with domain parity test#1

Draft
Copilot wants to merge 2 commits into dev from copilot/add-spark-trip-validation-helper

Conversation


Copilot AI commented Jan 7, 2026

Adds a Spark boolean-expression helper for trip coordinate validation and verifies parity with the domain-level Trip.is_valid() implementation.

Changes

  • data_transformations/citibike/domain.py: Trip dataclass with is_valid() validating coordinate nullability and ranges (lat: [-90, 90], lon: [-180, 180])

  • data_transformations/citibike/validation.py:

    • trip_is_valid_expr(): Spark Column expression mirroring domain validation logic
    • add_trip_validity_column(): Applies expression with NULL coalesced to False for strict boolean output
  • tests/conftest.py: Session-scoped spark fixture for unit tests

  • tests/unit/test_validation_vs_domain.py: Parametrized test comparing Spark and domain validation across 6 scenarios (valid trips, missing coords, out-of-range values)
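The Trip dataclass itself is not reproduced in this description; a minimal sketch of what domain.py's is_valid() is described as doing (field names assumed to match the Spark helper, so treat this as illustrative rather than the actual implementation) might look like:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Trip:
    # Field names assumed to mirror the Spark-side REQUIRED_COORDS columns.
    start_station_latitude: Optional[float]
    start_station_longitude: Optional[float]
    end_station_latitude: Optional[float]
    end_station_longitude: Optional[float]

    def is_valid(self) -> bool:
        """True when all coordinates are present and within valid ranges."""
        coords = (
            self.start_station_latitude,
            self.start_station_longitude,
            self.end_station_latitude,
            self.end_station_longitude,
        )
        # Nullability check first: any missing coordinate invalidates the trip.
        if any(c is None for c in coords):
            return False
        # Range checks: lat in [-90, 90], lon in [-180, 180].
        return (
            -90.0 <= self.start_station_latitude <= 90.0
            and -90.0 <= self.end_station_latitude <= 90.0
            and -180.0 <= self.start_station_longitude <= 180.0
            and -180.0 <= self.end_station_longitude <= 180.0
        )
```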

Usage

from data_transformations.citibike.validation import add_trip_validity_column

df = spark.read.parquet("trips.parquet")
validated_df = add_trip_validity_column(df)  # Adds 'trip_is_valid' boolean column

Warning

Firewall rules blocked me from connecting to one or more addresses

I tried to connect to the following addresses, but was blocked by firewall rules:

  • api.launchpad.net
    • Triggering command: /usr/bin/add-apt-repository ppa:deadsnakes/ppa -y (dns block)

If you need me to access, download, or install something from one of these locations, you can either:

Original prompt

Add Spark-side trip validation helper and unit test to mirror Trip.is_valid()

This PR adds a Spark boolean expression helper and a unit test that verifies parity between the Spark-side validation and the domain-level Trip.is_valid() implementation. Files to add:

  1. data_transformations/citibike/validation.py
  • Provides trip_is_valid_expr() that returns a pyspark Column boolean expression checking:
    • all four coordinates are non-null
    • latitudes between -90 and 90
    • longitudes between -180 and 180
  • Provides add_trip_validity_column(df, col_name="trip_is_valid") which applies the expression and coalesces NULL to False to produce a strict boolean column.

File contents:

from typing import List

from pyspark.sql import Column, DataFrame
from pyspark.sql import functions as F

# Reuse the same coordinate names used across the codebase
REQUIRED_COORDS: List[str] = [
    "start_station_latitude",
    "start_station_longitude",
    "end_station_latitude",
    "end_station_longitude",
]


def trip_is_valid_expr() -> Column:
    """
    Returns a Spark Column (boolean expression) that is True when:
      - All four coordinate columns are non-null
      - Latitudes are in [-90, 90]
      - Longitudes are in [-180, 180]
    Because each range check is guarded by an isNotNull() check, NULL inputs
    short-circuit the conjunction to False under SQL three-valued logic
    (FALSE AND NULL = FALSE); coalescing to False remains a safe defensive
    measure when a strict boolean is required.
    """
    return (
        F.col("start_station_latitude").isNotNull()
        & F.col("start_station_longitude").isNotNull()
        & F.col("end_station_latitude").isNotNull()
        & F.col("end_station_longitude").isNotNull()
        & F.col("start_station_latitude").between(-90.0, 90.0)
        & F.col("end_station_latitude").between(-90.0, 90.0)
        & F.col("start_station_longitude").between(-180.0, 180.0)
        & F.col("end_station_longitude").between(-180.0, 180.0)
    )


def add_trip_validity_column(df: DataFrame, col_name: str = "trip_is_valid") -> DataFrame:
    """
    Returns a new DataFrame with a boolean column named `col_name` indicating validity.
    Uses coalesce(..., False) so the column is always boolean (no NULLs).
    """
    expr = trip_is_valid_expr()
    return df.withColumn(col_name, F.coalesce(expr, F.lit(False)))
  2. tests/unit/test_validation_vs_domain.py
  • Unit test that builds a set of sample rows, applies add_trip_validity_column and compares the boolean to Trip.is_valid() from data_transformations/citibike/domain.py
  • Depends on the existing session-scoped spark fixture in tests/conftest.py

File contents:

import pytest
from pyspark.sql import Row
from data_transformations.citibike.validation import add_trip_validity_column
from data_transformations.citibike.domain import Trip


def build_trip_row(start_lat, start_lon, end_lat, end_lon):
    return Row(
        start_station_latitude=start_lat,
        start_station_longitude=start_lon,
        end_station_latitude=end_lat,
        end_station_longitude=end_lon,
    )


@pytest.mark.parametrize(
    "start_lat,start_lon,end_lat,end_lon",
    [
        (40.0, -73.0, 40.001, -73.001),  # valid, small distance
        (40.0, -73.0, 40.0, -73.0),      # same point -> valid, distance 0
        (None, -73.0, 40.0, -73.0),      # missing start lat -> invalid
        (95.0, -73.0, 40.0, -73.0),      # start lat out of range -> invalid
        (40.0, -200.0, 40.0, -73.0),     # start lon out of range -> invalid
        (40.0, -73.0, None, None),       # missing end coords -> invalid
    ],
)
def test_spark_validation_matches_domain(spark, start_lat, start_lon, end_lat, end_lon):
    """
    For each sample row, compare the boolean result from the Spark expression
    with the Trip.is_valid() domain implementation.
    """
    rows = [build_trip_row(start_lat, start_lon, end_lat, end_lon)]

    # An explicit schema is required: Spark cannot infer a column type when a
    # coordinate column's only value is None (the missing-coords cases).
    from pyspark.sql.types import DoubleType, StructField, StructType

    schema = StructType([
        StructField(name, DoubleType(), nullable=True)
        for name in (
            "start_station_latitude",
            "start_station_longitude",
            "end_station_latitude",
            "end_station_longitude",
        )
    ])
    df = spark.createDataFrame(rows, schema=schema)

    df_with_flag = add_trip_validity_column(df, col_name="trip_is_valid")
    collected = df_with_flag.collect()

    assert len(collected) == 1
    row = collected[0].asDict()

    # Domain-level validation using the Trip dataclass
    trip = Trip(
        start_station_latitude=row.get("start_station_latitude"),
        start_station_longitude=row.get("start_station_longitude"),
        end_station_latitude=row.get("end_station_latitude"),
        end_station_longitude=row.get("end_station_longitude"),
    )
    domain_valid = trip.is_valid()

    # Spark-side validation (coalesced to False by helper) should be boolean
    spark_valid = bool(row.get("trip_is_valid"))

    assert spark_valid == domain_valid, f"Mismatch for row {row}: spark={spark_valid}, domain={domain_valid}"
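A side note on the coalesce in add_trip_validity_column: under SQL three-valued logic, a NULL comparison result propagates through AND unless a False operand short-circuits it. A pure-Python model of SQL AND (None standing in for SQL NULL; this is an illustration, not Spark code) shows the behavior:

```python
def sql_and(a, b):
    """SQL AND over {True, False, None}, where None models SQL NULL."""
    if a is False or b is False:
        return False  # FALSE AND anything is FALSE
    if a is None or b is None:
        return None   # TRUE AND NULL is NULL
    return True


# NULL BETWEEN -90 AND 90 yields NULL, which propagates through AND:
assert sql_and(True, None) is None
# ...but the isNotNull() guard in trip_is_valid_expr() contributes a False
# operand whenever an input is NULL, so the conjunction collapses to False:
assert sql_and(False, None) is False
```

Under this model the coalesce is a defensive guarantee rather than a strict necessity, since every range check in trip_is_valid_expr() is paired with an isNotNull() guard.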

Notes

  • The tests use the existing tests/conftest.py spark fixture (session-scoped). Ensure tests/ is included in pytest discovery.
  • The new validation.py module is self-contained and aligns with the Trip domain model already present in dev.
  • Request: create a PR that adds these two files on top of the dev branch.

This pull request was created from Copilot chat.



Co-authored-by: abhiab54 <65629795+abhiab54@users.noreply.github.com>
Copilot AI changed the title [WIP] Add Spark-side trip validation helper and unit test Add Spark-side trip validation helper with domain parity test Jan 7, 2026
Copilot AI requested a review from abhiab54 January 7, 2026 23:51
