## Set up

In [1]:
# Locate the local Spark installation and make the pyspark package importable.
# This has to run before the first `import pyspark` in the notebook.
import findspark
findspark.init()

In [2]:
# Start (or reuse) a SparkSession whose driver is pinned to the loopback address,
# so the session works on a machine without a resolvable hostname.
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .config("spark.driver.bindAddress", "127.0.0.1")
    .config("spark.driver.host", "127.0.0.1")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/01/12 16:16:07 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Enable IPython's autoreload extension in mode 2: imported modules are reloaded
# before each cell executes, so edits to the project's source files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
# Add the directory two levels above the current working directory to Python's
# module search path so the project's modules can be imported from this notebook.
# (This affects Python imports on the driver, not Spark itself.)
# The membership check keeps the cell idempotent: re-running it does not stack
# duplicate entries onto sys.path.
import sys
if '../..' not in sys.path:
    sys.path.append('../..')

In [5]:
from pyspark_to_production.src.tip_amount_model import TipAmountModelConfig, TipAmountModel

[2026-01-12 16:16] INFO: pyspark_to_production.src.tip_amount_model logger is initialized!


## Debugging

In [6]:
# Build the job from a default-constructed config (no overrides are passed here).
config = TipAmountModelConfig()
job = TipAmountModel(config)

In [7]:
# Execute the whole job end to end. Per the log output below, this extracts the
# datasets, prepares the training data, trains the model and validates it.
job.run()

[2026-01-12 16:16] INFO: Extracting datasets
[2026-01-12 16:16] INFO:   taxi_trip_data...
[2026-01-12 16:16] INFO:   taxi_zone_geo...                                     
[2026-01-12 16:16] INFO: Preparing the data for training
[2026-01-12 16:16] INFO: Training the model
[2026-01-12 16:16] INFO: Start validation                                       
[2026-01-12 16:16] INFO:   Features importances
[2026-01-12 16:16] INFO:           payment_type = 0.54
[2026-01-12 16:16] INFO:            fare_amount = 0.21
[2026-01-12 16:16] INFO:              rate_code = 0.15
[2026-01-12 16:16] INFO:           tolls_amount = 0.055
[2026-01-12 16:16] INFO:          trip_distance = 0.035
[2026-01-12 16:16] INFO:            day_of_week = 0.0039
[2026-01-12 16:16] INFO:          imp_surcharge = 0.0022
[2026-01-12 16:16] INFO:        passenger_count = 0.0011
[2026-01-12 16:16] INFO:           day_of_month = 0.0011
[2026-01-12 16:16] INFO:                  month = 0.00075
[2026-01-12 16:16] INFO: store_and_f

In [8]:
# Inspect intermediate state: print the first 5 rows of the "training" DataFrame
# that the job keeps in its `sdfs` mapping.
job.sdfs["training"].show(5)



+---------+-------------------+-------------------+---------------+-------------+---------+------------------+------------+-----------+-----+-------+----------+------------+-------------+------------+------------------+-------------------+-----+-----------+------------+-----------------------+
|vendor_id|    pickup_datetime|   dropoff_datetime|passenger_count|trip_distance|rate_code|store_and_fwd_flag|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|imp_surcharge|total_amount|pickup_location_id|dropoff_location_id|month|day_of_week|day_of_month|store_and_fwd_flag_is_N|
+---------+-------------------+-------------------+---------------+-------------+---------+------------------+------------+-----------+-----+-------+----------+------------+-------------+------------+------------------+-------------------+-----+-----------+------------+-----------------------+
|        1|2018-01-01 16:48:36|2018-01-01 17:20:58|              4|         10.0|        1|                 N|     

                                                                                

## Trying out different parameter values

In [9]:
# Override test_fraction on the existing job's config, then re-run only the
# training and validation steps instead of the full job.run().
# The config is mutated in place, so the new value persists for all later cells.
# NOTE(review): presumably train() re-splits the data using config.test_fraction;
# confirm in TipAmountModel, otherwise this override has no effect here.
job.config.test_fraction = 0.01
job.train()
job.validate()

[2026-01-12 16:16] INFO: Training the model
[2026-01-12 16:16] INFO: Start validation                                       
[2026-01-12 16:16] INFO:   Features importances
[2026-01-12 16:16] INFO:           payment_type = 0.55
[2026-01-12 16:16] INFO:            fare_amount = 0.23
[2026-01-12 16:16] INFO:              rate_code = 0.12
[2026-01-12 16:16] INFO:           tolls_amount = 0.055
[2026-01-12 16:16] INFO:          trip_distance = 0.034
[2026-01-12 16:16] INFO:        passenger_count = 0.0037
[2026-01-12 16:16] INFO:          imp_surcharge = 0.002
[2026-01-12 16:16] INFO:           day_of_month = 0.0013
[2026-01-12 16:16] INFO:            day_of_week = 0.0012
[2026-01-12 16:16] INFO:                  month = 0.00073
[2026-01-12 16:16] INFO: store_and_fwd_flag_is_N = 0.00021
[2026-01-12 16:16] INFO:   Evaluation on the training set
[2026-01-12 16:16] INFO:     rmse = 4.2
[2026-01-12 16:16] INFO:      mae = 1.9
[2026-01-12 16:16] INFO:       r2 = 0.33
[2026-01-12 16:16] INFO:   

## Prototyping

In [10]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import GBTRegressor

def train_model(self) -> None:
    """Prototype replacement for the job's ``train_model`` step.

    Builds a two-stage Spark ML pipeline: a ``VectorAssembler`` that packs
    ``self.feature_cols`` into a single ``features`` column, followed by a
    ``GBTRegressor`` that predicts ``tip_amount``. The pipeline is fitted on
    ``self.sdfs["training"]`` and the fitted model is stored on ``self.model``.

    A marker line is printed at the end, which makes it visible in the cell
    output that this version of the method was the one executed.
    """
    stages = [
        VectorAssembler(inputCols=self.feature_cols, outputCol="features"),
        GBTRegressor(
            labelCol="tip_amount",
            featuresCol="features",
            predictionCol="prediction",
            stepSize=0.1,
            maxDepth=4,
            featureSubsetStrategy="auto",
            seed=42,  # fixed seed so repeated fits are comparable
        ),
    ]

    self.model = Pipeline(stages=stages).fit(self.sdfs["training"])

    print("Modified content")

In [11]:
# Bind the prototype function to this one job instance as its `train_model`
# method. Only `job` is patched; the TipAmountModel class itself is unchanged.
import types
job.train_model = types.MethodType(train_model, job)

In [12]:
# Re-run training and validation with the patched method in place. The
# "Modified content" line in the output below shows the prototype was used.
job.train()
job.validate()

[2026-01-12 16:16] INFO: Training the model
[2026-01-12 16:16] INFO: Start validation
[2026-01-12 16:16] INFO:   Features importances
[2026-01-12 16:16] INFO:           payment_type = 0.32
[2026-01-12 16:16] INFO:            fare_amount = 0.26
[2026-01-12 16:16] INFO:              rate_code = 0.14
[2026-01-12 16:16] INFO:          trip_distance = 0.097
[2026-01-12 16:16] INFO:           tolls_amount = 0.058
[2026-01-12 16:16] INFO:        passenger_count = 0.036
[2026-01-12 16:16] INFO:                  month = 0.034
[2026-01-12 16:16] INFO:           day_of_month = 0.029
[2026-01-12 16:16] INFO:          imp_surcharge = 0.02
[2026-01-12 16:16] INFO:            day_of_week = 0.014
[2026-01-12 16:16] INFO: store_and_fwd_flag_is_N = 0
[2026-01-12 16:16] INFO:   Evaluation on the training set


Modified content


26/01/12 16:16:42 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
26/01/12 16:16:42 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.VectorBLAS
[2026-01-12 16:16] INFO:     rmse = 3.9
[2026-01-12 16:16] INFO:      mae = 1.6
[2026-01-12 16:16] INFO:       r2 = 0.42
[2026-01-12 16:16] INFO:   Evaluation on the test set
[2026-01-12 16:16] INFO:     rmse = 3.8
[2026-01-12 16:16] INFO:      mae = 1.6
[2026-01-12 16:16] INFO:       r2 = 0.37


In [13]:
# Stop the Spark session referenced by the job once the experiments are done.
# NOTE(review): presumably the same session as `spark` created above — confirm.
job.spark.stop()