In [0]:
%run ../utils/validation

In [0]:
import pytest
from pyspark.sql.functions import *

In [0]:
# Silver-layer table under test; every test function below reads this frame.
df = spark.read.table("nyc_taxi.silver.data")

In [0]:
def test_no_nulls_in_required_columns():
    """Required silver columns must contain no nulls.

    Collects every offending column and reports them in a single failure (with
    their null counts) instead of stopping at the first one.
    """
    # NOTE(review): assumes validation.null_check returns a one-row DataFrame with a
    # `<column>_null_count` field per column — confirm against ../utils/validation.
    first_row = validation.null_check(df).first()
    assert first_row is not None, "null_check returned no rows"
    null_counts = first_row.asDict()

    required_cols = [
        "VendorID", "lpep_pickup_datetime", "lpep_dropoff_datetime",
        "passenger_count", "trip_distance", "fare_amount", "pickup_month", "trip_id"
    ]

    # A required column with no count key would otherwise surface as a bare KeyError.
    missing = [c for c in required_cols if f"{c}_null_count" not in null_counts]
    assert not missing, f"No null count reported for required columns: {missing}"

    offenders = {
        c: null_counts[f"{c}_null_count"]
        for c in required_cols
        if null_counts[f"{c}_null_count"] != 0
    }
    assert not offenders, f"Required columns have nulls (column -> count): {offenders}"

In [0]:
def test_no_duplicate_trip_ids():
    """Fail if validation.check_duplicates reports any rows sharing a trip_id."""
    duplicate_rows = validation.check_duplicates(df, ["trip_id"]).count()
    assert duplicate_rows == 0, "Duplicate trip_id values found!"

In [0]:
def test_valid_trip_distance():
    """Fail if any trip_distance falls outside the 0.01..500 range.

    Inclusivity of the bounds is whatever validation.check_value_range implements.
    """
    out_of_range = validation.check_value_range(df, "trip_distance", 0.01, 500)
    bad_rows = out_of_range.count()
    assert bad_rows == 0, "Found invalid trip_distance values"

In [0]:
def test_fare_amount_non_negative():
    """Fail if any fare_amount is below 0 (range check with no upper bound)."""
    no_upper_bound = float('inf')
    out_of_range = validation.check_value_range(df, "fare_amount", 0, no_upper_bound)
    assert out_of_range.count() == 0, "Found negative fare_amount"


In [0]:
def test_trip_speed_reasonable():
    """trip_speed_mph, when present, must lie within the 0..150 range.

    If the column is absent the test is explicitly skipped rather than silently
    passing, so a missing derived column is visible in the test report.
    """
    if "trip_speed_mph" not in df.columns:
        pytest.skip("trip_speed_mph column not present in silver table")
    invalids = validation.check_value_range(df, "trip_speed_mph", 0, 150)
    assert invalids.count() == 0, "Unrealistic trip speeds found"

In [0]:
def test_minimum_row_count():
    """Fail if validation.count_rows reports a non-positive count (empty table)."""
    row_total = validation.count_rows(df)
    assert row_total > 0, "Silver table is empty!"

In [0]:
def test_warning_trip_speed(max_nulls=10):
    """Fail when trip_speed_mph has `max_nulls` or more null values.

    Args:
        max_nulls: exclusive upper bound on tolerated nulls (default 10). It has a
            default, so pytest does not try to resolve it as a fixture.

    Skips, consistently with test_trip_speed_reasonable, when the column is absent;
    otherwise `col("trip_speed_mph")` would raise an analysis error.
    """
    if "trip_speed_mph" not in df.columns:
        pytest.skip("trip_speed_mph column not present in silver table")
    nulls = df.filter(col("trip_speed_mph").isNull()).count()
    assert nulls < max_nulls, f"⚠️ {nulls} nulls in trip_speed_mph. Warning threshold exceeded."