## Set up

In [1]:
# Make a locally installed Spark usable from this kernel: findspark.init()
# is expected to locate the Spark installation and make `pyspark` importable
# (see the findspark README). Must run before any `pyspark` import below.
import findspark
findspark.init()

In [2]:
# Create (or reuse) the SparkSession.
# Both driver address settings are pinned to 127.0.0.1 so the session stays
# on the loopback interface of the local machine.
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .config("spark.driver.bindAddress", "127.0.0.1")
    .config("spark.driver.host", "127.0.0.1")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/01/12 16:10:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Reload imported modules automatically before each cell executes, so edits
# to the project's .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [12]:
# Put the directory two levels above the notebook's working directory on
# Python's module search path (this affects Python imports, not Spark itself),
# presumably so the `pyspark_to_production` package imported in the next cell
# can be resolved.
import sys

# Guard the append so that re-running this cell does not accumulate
# duplicate entries in sys.path.
if '../..' not in sys.path:
    sys.path.append('../..')

In [13]:
from pyspark_to_production.src.tip_amount_model import TipAmountModelConfig, TipAmountModel

[2026-01-12 16:12] INFO: pyspark_to_production.src.tip_amount_model logger is initialized!


In [14]:
# Build the model job from a default-constructed config.
# `job` is reused by later cells (feature check and the final spark.stop()).
config = TipAmountModelConfig()
job = TipAmountModel(config)

In [15]:
from datetime import datetime

def is_subset(a: list, b: list) -> bool:
    """Return True when every distinct element of `a` also occurs in `b`.

    Order and duplicates are ignored: both arguments are compared as sets.
    """
    candidates, universe = set(a), set(b)
    return candidates.issubset(universe)

# Columns that add_features() is expected to contribute: the last four
# entries of the job's feature list.
expected_columns = job.feature_cols[-4:]

# Two hand-written input records carrying only the raw columns used here.
data = [
    (datetime(2021, 1, 1, 12), "Y"),
    (datetime(2021, 6, 15, 9, 30), "N"),
]

sdf_fake_input = job.spark.createDataFrame(
    data,
    schema=["pickup_datetime", "store_and_fwd_flag"],
)
# Sanity check: the raw input must not already contain the derived columns.
assert not is_subset(expected_columns, sdf_fake_input.columns)

# After feature engineering every expected column has to be present.
sdf_fake_features = job.add_features(sdf_fake_input)
assert is_subset(expected_columns, sdf_fake_features.columns)

In [16]:
from pyspark.sql import Row
from datetime import datetime
from typing import TypeVar, Type
from dataclasses import dataclass, asdict

# Type variable used by generate_rows() to annotate the dataclass it receives.
T = TypeVar("T")

@dataclass
class Trip:
    """Fixture record for the `taxi_trip_data` input used by the tests below.

    Every field has a default, so `Trip()` is a complete record and a test
    only needs to override the columns it cares about (see generate_rows).
    The declaration order is significant: `dataclasses.asdict` preserves it,
    so it determines the column order of the generated Rows.
    """
    vendor_id: int = 1
    # Default pickup/dropoff pair: same day, 18:00 -> 19:30.
    pickup_datetime: datetime = datetime(2018, 2, 4, 18, 0, 0)
    dropoff_datetime: datetime = datetime(2018, 2, 4, 19, 30, 0)
    passenger_count: int = 2
    trip_distance: float = 50.2
    rate_code: int = 3
    store_and_fwd_flag: str = "N"
    payment_type: int = 1
    fare_amount: float = 10.5
    extra: float = 0.1
    mta_tax: float = 0.5
    tip_amount: float = 0.8
    tolls_amount: float = 0.1
    imp_surcharge: float = 1.2
    total_amount: float = 15.2
    # Default locations differ from each other (1 -> 2); tests that exercise
    # location-based logic override both fields.
    pickup_location_id: int = 1
    dropoff_location_id: int = 2

@dataclass
class ZoneGeo:
    """Fixture record for the `taxi_zone_geo` input used by the tests below.

    All fields have placeholder defaults, so `ZoneGeo()` yields a complete
    record; tests override `zone_id` / `zone_name` when the zone matters.
    """
    zone_id: int = 1
    zone_name: str = "Snack Zone"
    borough: str = "Food Borough"

def generate_rows(
    data_class: Type[T],
    data: "list[tuple] | None" = None,
    columns: "list[str] | None" = None,
) -> list[Row]:
    """Build Spark Rows from partial records, filling gaps with dataclass defaults.

    Args:
        data_class: Dataclass type whose field defaults complete each record
            (every field not listed in `columns` must have a default).
        data: One tuple per row, holding the values for `columns` in the same
            order. Defaults to a single empty record, i.e. one all-defaults row.
        columns: Field names of `data_class` that the tuple values map to.
            Defaults to no overrides.

    Returns:
        One Row per record, with fields in the dataclass declaration order.

    Note:
        Values and names are paired with `zip`, which stops at the shorter of
        the two: surplus values (or surplus column names) are silently ignored,
        so a record longer than `columns` does NOT raise.
    """
    # None sentinels instead of list literals in the signature: default
    # objects are created once at definition time and shared across calls.
    if data is None:
        data = [()]
    if columns is None:
        columns = []

    return [
        Row(**asdict(data_class(**dict(zip(columns, record)))))
        for record in data
    ]

In [17]:
from datetime import datetime

# NOTE: same helper as defined in an earlier cell; repeated here so this
# cell can run on its own.
def is_subset(a: list, b: list) -> bool:
    """Tell whether all distinct items of `a` are contained in `b`."""
    missing = set(a) - set(b)
    return not missing


def test_add_features_column_names() -> None:
    """Check that transform() yields every feature column of the model.

    The trip input is generated from `Trip` defaults with only the pickup
    timestamp and store-and-forward flag overridden; the zone input is a
    single default `ZoneGeo` row. The feature columns must be absent from
    the raw trips and present in `prepared_data` after the transform.
    """
    model = TipAmountModel(TipAmountModelConfig())

    trip_columns = ["pickup_datetime", "store_and_fwd_flag"]
    trip_records = [
        (datetime(2021, 1, 1, 12), "Y"),
        (datetime(2021, 6, 15, 9, 30), "N"),
    ]

    trip_rows = generate_rows(Trip, trip_records, trip_columns)
    model.sdfs["taxi_trip_data"] = model.spark.createDataFrame(trip_rows)

    zone_rows = generate_rows(ZoneGeo)
    model.sdfs["taxi_zone_geo"] = model.spark.createDataFrame(zone_rows)

    # Raw trips must not already carry the full feature set.
    raw_columns = model.sdfs["taxi_trip_data"].columns
    assert not is_subset(model.feature_cols, raw_columns)

    model.transform()

    prepared_columns = model.sdfs["prepared_data"].columns
    assert is_subset(model.feature_cols, prepared_columns)


test_add_features_column_names()

[2026-01-12 16:12] INFO: Preparing the data for training


In [18]:
def test_exclude_airports_by_location() -> None:
    """Check how many trips survive transform() depending on the zone name.

    Four trips cover every pickup/dropoff combination of location ids 1 and
    100. Zone 100 is first registered under a neutral name (all four trips
    are expected to remain) and then under a name containing "airport" (only
    the trip that touches zone 100 on neither end is expected to remain).
    NOTE(review): the exact exclusion rule lives in TipAmountModel.transform
    and is inferred here from the expected counts - confirm against it.
    """
    columns = ["pickup_location_id", "dropoff_location_id"]
    data = [
        (1, 1),
        (100, 1),
        (1, 100),
        (100, 100),
    ]

    tip_model = TipAmountModel(TipAmountModelConfig())

    # no airports
    tip_model.sdfs["taxi_trip_data"] = tip_model.spark.createDataFrame(generate_rows(Trip, data, columns))
    # The column names must be passed as the third argument. Previously the
    # closing bracket was misplaced, so the name list became a second data
    # record, `columns` fell back to its empty default, and the fixture
    # silently consisted of default ZoneGeo rows instead of zone 100.
    tip_model.sdfs["taxi_zone_geo"] = tip_model.spark.createDataFrame(
        generate_rows(ZoneGeo, [(100, "terrestrial", )], ["zone_id", "zone_name"])
    )

    tip_model.transform()
    assert tip_model.sdfs["prepared_data"].count() == 4

    # all except one have airports
    tip_model.sdfs["taxi_zone_geo"] = tip_model.spark.createDataFrame(
        generate_rows(ZoneGeo, [(100, "is airport or so", )], ["zone_id", "zone_name"])
    )

    tip_model.transform()
    assert tip_model.sdfs["prepared_data"].count() == 1

test_exclude_airports_by_location()

[2026-01-12 16:13] INFO: Preparing the data for training
[2026-01-12 16:13] INFO: Preparing the data for training                        


In [19]:
# Shut down the Spark session held by the job now that all checks have run.
# NOTE(review): this may be the same session as the notebook-level `spark`
# obtained via getOrCreate() - confirm before using `spark` after this cell.
job.spark.stop()