In [0]:
import pytest
from pyspark.sql.functions import *

In [0]:
# Fully qualified name of the table that the checks in this notebook validate.
# `spark` is not defined in the visible cells; it is expected to be supplied by
# the notebook runtime.
SILVER_TABLE = "nyc_taxi.silver.data"
df = spark.table(SILVER_TABLE)

In [0]:
def test_no_nulls_in_required_columns():
    """Assert that none of the required columns of ``df`` contains a null.

    All null counts are computed in a single aggregation over ``df`` instead
    of one full scan per column, and every offending column is reported in
    the failure message rather than only the first one encountered.

    Raises:
        AssertionError: if at least one required column contains nulls. The
            message lists each offending column with its null count.
    """
    # Local import keeps the block independent of names pulled in (and
    # possibly shadowed) by the notebook's star import.
    from pyspark.sql import functions as F

    required_cols = [
        "VendorID", "lpep_pickup_datetime", "lpep_dropoff_datetime",
        "passenger_count", "trip_distance", "fare_amount", "pickup_month", "trip_id"
    ]

    # when() without otherwise() yields NULL for non-matching rows and count()
    # skips NULLs, so each aggregate is the number of null rows in that column.
    null_counts = df.agg(
        *[
            F.count(F.when(F.col(col_name).isNull(), 1)).alias(col_name)
            for col_name in required_cols
        ]
    ).collect()[0].asDict()

    failures = [
        f"Column '{col_name}' has {null_count} nulls"
        for col_name, null_count in null_counts.items()
        if null_count > 0
    ]
    assert not failures, "; ".join(failures)

In [0]:
def test_no_duplicate_trip_ids():
    """Assert that ``df`` has exactly as many rows as distinct ``trip_id`` values.

    Raises:
        AssertionError: if the row count and the distinct ``trip_id`` count
            differ; the message reports the difference.
    """
    unique_ids = df.select("trip_id").distinct()
    n_rows, n_unique = df.count(), unique_ids.count()
    surplus = n_rows - n_unique
    assert surplus == 0, f"Found {surplus} duplicate trip_id rows"


In [0]:
def test_valid_trip_distance():
    """Assert that no row of ``df`` has a ``trip_distance`` of zero or less.

    Raises:
        AssertionError: if any row matches ``trip_distance <= 0``; the message
            reports how many.
    """
    non_positive = col("trip_distance") <= 0
    offenders = df.where(non_positive).count()  # where() is an alias of filter()
    assert offenders == 0, f"Found {offenders} rows with non-positive trip_distance"


In [0]:
def test_fare_amount_non_negative():
    """Assert that no row of ``df`` has a ``fare_amount`` below zero.

    Raises:
        AssertionError: if any row matches ``fare_amount < 0``; the message
            reports how many.
    """
    negative_fares = df.filter(col("fare_amount") < 0)
    n_negative = negative_fares.count()
    assert n_negative == 0, f"Found {n_negative} rows with negative fare_amount"


In [0]:
def test_trip_speed_reasonable():
    """Assert that the largest ``trip_speed_mph`` in ``df`` does not exceed 150.

    The check is a no-op when ``df`` has no ``trip_speed_mph`` column, and a
    maximum of ``None`` is accepted.

    Raises:
        AssertionError: if the maximum speed is above 150; the message
            includes the offending value.
    """
    speed_col = "trip_speed_mph"
    if speed_col not in df.columns:
        # Column is treated as optional: nothing to verify.
        return

    summary_row = df.agg({speed_col: "max"}).collect()[0]
    fastest = summary_row[0]
    assert fastest is None or fastest <= 150, f"Unrealistic trip speed detected: {fastest} mph"

In [0]:
def test_minimum_row_count():
    """Assert that ``df`` contains at least one row.

    Only emptiness is checked, so at most one row is fetched instead of
    counting the entire table.

    Raises:
        AssertionError: if ``df`` has no rows.
    """
    # take(1) returns a list holding at most one Row; an empty list is falsy.
    assert df.take(1), "Silver table is empty!"

In [0]:
def test_warning_trip_speed():
    """Check how many rows of ``df`` have a null ``trip_speed_mph``.

    Prints a short message when the null count is 10 or fewer; otherwise
    raises ``Warning`` with the count.

    Raises:
        Warning: if more than 10 rows have a null ``trip_speed_mph``.
    """
    field = "trip_speed_mph"
    missing = df.where(col(field).isNull()).count()

    if missing <= 10:
        print(f"Only {missing} nulls in {field}, continuing.")
        return

    # NOTE(review): raising Warning aborts the caller like any other exception;
    # if the intent is a non-fatal notice, warnings.warn() may be what is
    # wanted here -- confirm before changing, since callers may rely on this.
    raise Warning(f"⚠️ {missing} nulls in {field}. Investigate.")