## Set up

In [1]:
# little trick to make spark work locally
import findspark
findspark.init()

In [2]:
# make sure changes in imported modules become available without a need to restart
%load_ext autoreload
%autoreload 2

In [3]:
# add the folder with our modules to Python's import path (sys.path) so they can be imported below
import sys
sys.path.append('../src')

In [4]:
from tip_amount_model import TipAmountModelConfig, TipAmountModel

[2026-01-09 10:32] INFO: tip_amount_model logger is initialized!


## Debugging

In [5]:
config = TipAmountModelConfig()
job = TipAmountModel(config)

26/01/09 10:32:19 WARN Utils: Your hostname, Alexandrs-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 192.168.1.108 instead (on interface en0)
26/01/09 10:32:19 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/01/09 10:32:19 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [6]:
job.run()

[2026-01-09 10:32] INFO: Extracting datasets
[2026-01-09 10:32] INFO: Preparing the data for training                        
[2026-01-09 10:32] INFO: Training the model
[2026-01-09 10:32] INFO: Start validation                                       
[2026-01-09 10:32] INFO:   Features importances
[2026-01-09 10:32] INFO:           payment_type = 0.7
[2026-01-09 10:32] INFO:            fare_amount = 0.17
[2026-01-09 10:32] INFO:          trip_distance = 0.053
[2026-01-09 10:32] INFO:           tolls_amount = 0.037
[2026-01-09 10:32] INFO:              rate_code = 0.032
[2026-01-09 10:32] INFO:          imp_surcharge = 0.0039
[2026-01-09 10:32] INFO:           day_of_month = 0.0032
[2026-01-09 10:32] INFO:        passenger_count = 0.0031
[2026-01-09 10:32] INFO:                  month = 0.00049
[2026-01-09 10:32] INFO:            day_of_week = 0.00024
[2026-01-09 10:32] INFO: store_and_fwd_flag_is_N = 9.3e-06
[2026-01-09 10:32] INFO:   Evaluation on the training set
[2026-01-09 10:32] I

In [7]:
job.feature_cols

['passenger_count',
 'trip_distance',
 'rate_code',
 'payment_type',
 'fare_amount',
 'tolls_amount',
 'imp_surcharge',
 'month',
 'day_of_week',
 'day_of_month',
 'store_and_fwd_flag_is_N']

In [8]:
job.sdfs["training"].show(5)



+---------+-------------------+-------------------+---------------+-------------+---------+------------------+------------+-----------+-----+-------+----------+------------+-------------+------------+------------------+-------------------+-----+-----------+------------+-----------------------+
|vendor_id|    pickup_datetime|   dropoff_datetime|passenger_count|trip_distance|rate_code|store_and_fwd_flag|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|imp_surcharge|total_amount|pickup_location_id|dropoff_location_id|month|day_of_week|day_of_month|store_and_fwd_flag_is_N|
+---------+-------------------+-------------------+---------------+-------------+---------+------------------+------------+-----------+-----+-------+----------+------------+-------------+------------+------------------+-------------------+-----+-----------+------------+-----------------------+
|        1|2018-01-01 16:48:36|2018-01-01 17:20:58|              4|         10.0|        1|                 N|     

                                                                                

## Trying out different parameter values

In [9]:
# Change one config value on the already-constructed job and re-run only the
# later stages. Judging by the log output of this cell, transform() + validate()
# re-prepare the data, retrain and re-evaluate without extracting the datasets again.
# NOTE(review): this mutates job.config in place, so every later cell sees
# test_fraction = 0.01 as well.
job.config.test_fraction = 0.01
job.transform()
job.validate()

[2026-01-09 10:32] INFO: Preparing the data for training
[2026-01-09 10:32] INFO: Training the model
[2026-01-09 10:32] INFO: Start validation
[2026-01-09 10:32] INFO:   Features importances
[2026-01-09 10:32] INFO:           payment_type = 0.72
[2026-01-09 10:32] INFO:            fare_amount = 0.17
[2026-01-09 10:32] INFO:          trip_distance = 0.045
[2026-01-09 10:32] INFO:           tolls_amount = 0.03
[2026-01-09 10:32] INFO:              rate_code = 0.029
[2026-01-09 10:32] INFO:        passenger_count = 0.0036
[2026-01-09 10:32] INFO:          imp_surcharge = 0.0018
[2026-01-09 10:32] INFO:            day_of_week = 0.0012
[2026-01-09 10:32] INFO:                  month = 0.00072
[2026-01-09 10:32] INFO:           day_of_month = 0.00058
[2026-01-09 10:32] INFO: store_and_fwd_flag_is_N = 0
[2026-01-09 10:32] INFO:   Evaluation on the training set
[2026-01-09 10:32] INFO:     rmse = 3.1
[2026-01-09 10:32] INFO:      mae = 1.6
[2026-01-09 10:32] INFO:       r2 = 0.36
[2026-01-09 1

## Prototyping

In [10]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import GBTRegressor

def train_model(self) -> None:
    """Prototype training step: fit an assembler + GBT pipeline on the training set.

    Reads ``self.feature_cols`` and ``self.sdfs["training"]`` and stores the
    fitted pipeline model in ``self.model``. Written as a free function taking
    ``self`` so it can be bound to an existing job instance while prototyping.
    """
    # Hyper-parameters of the gradient-boosted trees regressor under trial.
    gbt_params = dict(
        labelCol="tip_amount",
        featuresCol="features",
        predictionCol="prediction",
        stepSize=0.1,
        maxDepth=4,
        featureSubsetStrategy="auto",
        seed=42,
    )

    # Stage 1 packs the feature columns into one vector column, stage 2 regresses on it.
    stages = [
        VectorAssembler(inputCols=self.feature_cols, outputCol="features"),
        GBTRegressor(**gbt_params),
    ]

    self.model = Pipeline(stages=stages).fit(self.sdfs["training"])

    # Marker line: shows in the cell output that this prototype version was executed.
    print("Modified content")

In [11]:
import types
# Bind the prototype train_model function to this one instance: the instance
# attribute shadows the class's method for `job` only, so the module code and
# any other TipAmountModel instance keep the original implementation.
job.train_model = types.MethodType(train_model, job)

In [12]:
job.transform()
job.validate()

[2026-01-09 10:32] INFO: Preparing the data for training
[2026-01-09 10:32] INFO: Training the model
[2026-01-09 10:32] INFO: Start validation
[2026-01-09 10:32] INFO:   Features importances
[2026-01-09 10:32] INFO:           payment_type = 0.37
[2026-01-09 10:32] INFO:            fare_amount = 0.22
[2026-01-09 10:32] INFO:           tolls_amount = 0.076
[2026-01-09 10:32] INFO:          trip_distance = 0.069
[2026-01-09 10:32] INFO:        passenger_count = 0.059
[2026-01-09 10:32] INFO:           day_of_month = 0.059
[2026-01-09 10:32] INFO:                  month = 0.05
[2026-01-09 10:32] INFO:              rate_code = 0.047
[2026-01-09 10:32] INFO:          imp_surcharge = 0.026
[2026-01-09 10:32] INFO:            day_of_week = 0.02
[2026-01-09 10:32] INFO: store_and_fwd_flag_is_N = 0
[2026-01-09 10:32] INFO:   Evaluation on the training set


Modified content


26/01/09 10:32:52 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
26/01/09 10:32:52 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.VectorBLAS
[2026-01-09 10:32] INFO:     rmse = 3
[2026-01-09 10:32] INFO:      mae = 1.4
[2026-01-09 10:32] INFO:       r2 = 0.42
[2026-01-09 10:32] INFO:   Evaluation on the test set
[2026-01-09 10:32] INFO:     rmse = 2.8                                         
[2026-01-09 10:32] INFO:      mae = 1.4
[2026-01-09 10:32] INFO:       r2 = 0.48                                        


## Testing the code

In [31]:
from datetime import datetime

def is_subset(a: list, b: list) -> bool:
    """Return True when every element of ``a`` also occurs in ``b``.

    Both arguments are compared as sets, so order and duplicates are ignored.
    """
    candidates = set(a)
    universe = set(b)
    return candidates.issubset(universe)

# Two fake input rows that carry only a pickup timestamp and the
# store-and-forward flag (one "Y", one "N").
data = [
    (datetime(2021, 1, 1, 12, 0, 0), "Y"),
    (datetime(2021, 6, 15, 9, 30, 0), "N")
]

# The last four feature columns; a later cell displays them as
# month, day_of_week, day_of_month and store_and_fwd_flag_is_N.
# NOTE(review): the [-4:] slice depends on the ordering of job.feature_cols.
expected_columns = job.feature_cols[-4:]

# Before add_features: the expected columns must NOT all be present yet.
sdf_fake_input = job.spark.createDataFrame(data, schema=["pickup_datetime", "store_and_fwd_flag"])
assert not is_subset(expected_columns, sdf_fake_input.columns)

# After add_features: every expected column must be present.
sdf_fake_features = job.add_features(sdf_fake_input)
assert is_subset(expected_columns, sdf_fake_features.columns)

In [21]:
job.feature_cols[-4:]

['month', 'day_of_week', 'day_of_month', 'store_and_fwd_flag_is_N']

In [22]:
sdf_fake_features.columns

['pickup_datetime',
 'store_and_fwd_flag',
 'month',
 'day_of_week',
 'day_of_month',
 'store_and_fwd_flag_is_N']

In [19]:
sdf_fake_features.show(5)

+-------------------+------------------+-----+-----------+------------+-----------------------+
|    pickup_datetime|store_and_fwd_flag|month|day_of_week|day_of_month|store_and_fwd_flag_is_N|
+-------------------+------------------+-----+-----------+------------+-----------------------+
|2021-01-01 12:00:00|                 Y|    1|          6|           1|                      0|
|2021-06-15 09:30:00|                 N|    6|          3|          15|                      1|
|2022-03-03 23:59:59|              NULL|    3|          5|           3|                      0|
+-------------------+------------------+-----+-----------+------------+-----------------------+



In [73]:
from pyspark.sql import Row
from datetime import datetime
from typing import TypeVar, Type
from dataclasses import dataclass, asdict

T = TypeVar("T")

@dataclass
class Trip:
    """One taxi trip record with a default value for every field.

    The field names and their order match the first 17 columns of the training
    DataFrame displayed earlier in this notebook (``vendor_id`` through
    ``dropoff_location_id``). Because every field has a default, a test only
    needs to override the fields it cares about.

    NOTE(review): the default values look like arbitrary placeholders and are
    not required to be mutually consistent -- confirm before relying on them.
    """
    vendor_id: int = 1
    # datetime objects are immutable, so sharing one default across instances is safe
    pickup_datetime: datetime = datetime(2018, 2, 4, 11, 0, 0)
    dropoff_datetime: datetime = datetime(2018, 2, 4, 12, 30, 0)
    passenger_count: int = 2
    trip_distance: float = 50.2
    rate_code: int = 3
    store_and_fwd_flag: str = "N"
    payment_type: int = 1
    fare_amount: float = 10.5
    extra: float = 0.1
    mta_tax: float = 0.5
    tip_amount: float = 0.8
    tolls_amount: float = 0.1
    imp_surcharge: float = 1.2
    total_amount: float = 15.2
    pickup_location_id: int = 1
    dropoff_location_id: int = 2

@dataclass
class ZoneGeo:
    """One taxi-zone record with a default value for every field.

    The annotations are not enforced at runtime: a later cell shows
    ``ZoneGeo(**{"zone_id": "5"})`` keeping the string ``'5'`` as ``zone_id``.

    NOTE(review): presumably mirrors the schema of the real zone dataset --
    verify the column names against the source table.
    """
    zone_id: int = 1
    # placeholder names; NOTE(review): apparently chosen to be obviously fake
    zone_name: str = "Snack Zone"
    borough: str = "Food Borough"

def generate_rows(
    data_class: Type[T],
    data: "list[tuple] | None" = None,
    columns: "list[str] | None" = None,
) -> list[Row]:
    """Build Spark ``Row`` objects from a dataclass, overriding selected fields.

    Each record in ``data`` supplies values for the fields named in
    ``columns`` (matched by position); every other field keeps the default
    declared on ``data_class``.

    Args:
        data_class: Dataclass type whose fields all have defaults, or at
            least all fields not listed in ``columns``.
        data: One tuple of override values per row to generate. Defaults to a
            single empty record, which yields one row made of defaults only.
        columns: Field names the values in each record refer to. Defaults to
            no overrides.

    Returns:
        One ``Row`` per record, with the dataclass fields as its columns.

    Raises:
        ValueError: If a record does not have exactly one value per column.
            ``zip`` would otherwise silently drop the surplus values or
            names and produce a row that differs from what the caller wrote.
    """
    # None sentinels instead of mutable list defaults shared between calls.
    records = [()] if data is None else data
    column_names = [] if columns is None else list(columns)

    generated_rows = []
    for record in records:
        values = tuple(record)
        if len(values) != len(column_names):
            raise ValueError(
                f"Record {values!r} has {len(values)} value(s) "
                f"but {len(column_names)} column(s) were given: {column_names!r}"
            )
        overrides = dict(zip(column_names, values))
        instance = data_class(**overrides)
        generated_rows.append(Row(**asdict(instance)))
    return generated_rows

In [74]:
generate_rows(ZoneGeo)

[Row(zone_id=1, zone_name='Snack Zone', borough='Food Borough')]

In [56]:
ZoneGeo(**{"zone_id":"5"})

ZoneGeo(zone_id='5', zone_name='Snack Zone', borough='Food Borough')

In [81]:
from datetime import datetime

def is_subset(a: list, b: list) -> bool:
    """Tell whether all distinct elements of ``a`` are contained in ``b``.

    The comparison is set-based: ordering and repeated elements do not matter.
    """
    return set(a).issubset(set(b))


def test_add_features_columns():
    """Check that ``prepare_data`` produces every column listed in ``feature_cols``.

    A fresh model is fed fake trip and zone data; the feature columns must be
    incomplete on the raw trips and complete on the prepared data.
    """
    override_names = ["pickup_datetime", "store_and_fwd_flag"]
    override_values = [
        (datetime(2021, 1, 1, 12, 0, 0), "Y"),
        (datetime(2021, 6, 15, 9, 30, 0), "N"),
    ]

    model = TipAmountModel(TipAmountModelConfig())

    # Inject the fake inputs under the dataset names the model reads from.
    trip_rows = generate_rows(Trip, override_values, override_names)
    model.sdfs["taxi_trip_data"] = model.spark.createDataFrame(trip_rows)
    zone_rows = generate_rows(ZoneGeo)
    model.sdfs["taxi_zone_geo"] = model.spark.createDataFrame(zone_rows)

    # Raw trips: not all feature columns are there yet.
    features_present_before = is_subset(
        model.feature_cols, model.sdfs["taxi_trip_data"].columns
    )
    assert not features_present_before

    model.prepare_data()

    # Prepared data: all feature columns must exist.
    features_present_after = is_subset(
        model.feature_cols, model.sdfs["prepared_data"].columns
    )
    assert features_present_after


In [82]:
job.spark.stop()