Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions data_transformations/citibike/domain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from dataclasses import dataclass
from typing import Optional


@dataclass
class Trip:
    """
    Domain model for a Citibike trip with validation logic.
    """

    start_station_latitude: Optional[float]
    start_station_longitude: Optional[float]
    end_station_latitude: Optional[float]
    end_station_longitude: Optional[float]

    def is_valid(self) -> bool:
        """
        Check that the trip carries a complete, in-range coordinate set.

        A trip is valid when all four coordinate fields are non-null,
        both latitudes lie in [-90, 90], and both longitudes lie in
        [-180, 180].
        """
        latitudes = (self.start_station_latitude, self.end_station_latitude)
        longitudes = (self.start_station_longitude, self.end_station_longitude)

        # Any missing coordinate makes the whole trip invalid.
        if any(coord is None for coord in latitudes + longitudes):
            return False

        # All coordinates present: enforce geographic bounds.
        lats_ok = all(-90.0 <= lat <= 90.0 for lat in latitudes)
        lons_ok = all(-180.0 <= lon <= 180.0 for lon in longitudes)
        return lats_ok and lons_ok
42 changes: 42 additions & 0 deletions data_transformations/citibike/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import List

from pyspark.sql import Column, DataFrame
from pyspark.sql import functions as F

# Reuse the same coordinate names used across the codebase
# (these mirror the Trip dataclass fields in data_transformations.citibike.domain).
REQUIRED_COORDS: List[str] = [
    "start_station_latitude",
    "start_station_longitude",
    "end_station_latitude",
    "end_station_longitude",
]


def trip_is_valid_expr() -> Column:
    """
    Build a boolean Spark Column that marks a trip row as valid.

    A row is valid when:
      - All coordinate columns in REQUIRED_COORDS are non-null
      - Latitudes are in [-90, 90]
      - Longitudes are in [-180, 180]

    Note: because the null checks come first and ``False AND NULL``
    evaluates to False under SQL three-valued logic, this expression
    never yields NULL — rows with missing coordinates come out False.
    """
    # Derive the non-null guard from the shared column list so this
    # expression stays in sync with REQUIRED_COORDS (previously the
    # same four names were duplicated here by hand).
    expr: Column = F.lit(True)
    for name in REQUIRED_COORDS:
        expr = expr & F.col(name).isNotNull()

    # Geographic bounds: latitudes in [-90, 90].
    for name in ("start_station_latitude", "end_station_latitude"):
        expr = expr & F.col(name).between(-90.0, 90.0)

    # Geographic bounds: longitudes in [-180, 180].
    for name in ("start_station_longitude", "end_station_longitude"):
        expr = expr & F.col(name).between(-180.0, 180.0)

    return expr


def add_trip_validity_column(df: DataFrame, col_name: str = "trip_is_valid") -> DataFrame:
    """
    Attach a boolean validity flag to *df* under the name *col_name*.

    The validity expression is coalesced to False so the resulting
    column is strictly boolean and never NULL.
    """
    flag = F.coalesce(trip_is_valid_expr(), F.lit(False))
    return df.withColumn(col_name, flag)
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import pytest
from pyspark.sql import SparkSession


@pytest.fixture(scope="session")
def spark() -> SparkSession:
    """
    Session-scoped SparkSession shared by all unit tests.

    An explicit local master is set so the suite runs under plain
    ``pytest`` without any spark-submit configuration (otherwise Spark
    raises "A master URL must be set in your configuration").
    ``getOrCreate()`` still reuses an already-active session if one
    exists, so an externally configured master takes precedence.
    """
    return (
        SparkSession.builder
        .master("local[2]")
        .appName("UnitTests")
        .getOrCreate()
    )
63 changes: 63 additions & 0 deletions tests/unit/test_validation_vs_domain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import pytest
from pyspark.sql import Row
from pyspark.sql.types import DoubleType, StructField, StructType

from data_transformations.citibike.domain import Trip
from data_transformations.citibike.validation import add_trip_validity_column

# Explicit nullable schema so DataFrames built from rows containing None
# still get DoubleType coordinate columns.
TRIP_SCHEMA = StructType([
    StructField(column_name, DoubleType(), True)
    for column_name in (
        "start_station_latitude",
        "start_station_longitude",
        "end_station_latitude",
        "end_station_longitude",
    )
])


def build_trip_row(start_lat, start_lon, end_lat, end_lon):
    """Assemble a pyspark Row carrying the four TRIP_SCHEMA coordinate fields."""
    coordinates = {
        "start_station_latitude": start_lat,
        "start_station_longitude": start_lon,
        "end_station_latitude": end_lat,
        "end_station_longitude": end_lon,
    }
    return Row(**coordinates)


@pytest.mark.parametrize(
    "start_lat,start_lon,end_lat,end_lon",
    [
        (40.0, -73.0, 40.001, -73.001),  # valid, small distance
        (40.0, -73.0, 40.0, -73.0),      # same point -> valid, distance 0
        (None, -73.0, 40.0, -73.0),      # missing start lat -> invalid
        (95.0, -73.0, 40.0, -73.0),      # start lat out of range -> invalid
        (40.0, -200.0, 40.0, -73.0),     # start lon out of range -> invalid
        (40.0, -73.0, None, None),       # missing end coords -> invalid
    ],
)
def test_spark_validation_matches_domain(spark, start_lat, start_lon, end_lat, end_lon):
    """
    The Spark-side validity flag must agree with Trip.is_valid() for every
    sampled coordinate combination.
    """
    source = spark.createDataFrame(
        [build_trip_row(start_lat, start_lon, end_lat, end_lon)],
        schema=TRIP_SCHEMA,
    )

    flagged = add_trip_validity_column(source, col_name="trip_is_valid").collect()
    assert len(flagged) == 1
    row = flagged[0].asDict()

    # Domain-level verdict computed from the values Spark actually stored.
    domain_valid = Trip(
        start_station_latitude=row.get("start_station_latitude"),
        start_station_longitude=row.get("start_station_longitude"),
        end_station_latitude=row.get("end_station_latitude"),
        end_station_longitude=row.get("end_station_longitude"),
    ).is_valid()

    # Spark-side verdict (the helper coalesces NULLs to False).
    spark_valid = bool(row.get("trip_is_valid"))

    assert spark_valid == domain_valid, f"Mismatch for row {row}: spark={spark_valid}, domain={domain_valid}"