In [1]:
import logging

logger = logging.getLogger("blogpost_logger")
logger.setLevel(logging.DEBUG) # lowest level for the logger

logger.handlers.clear() # remove existing handlers to not accidentally duplicate them
sh = logging.StreamHandler() # handler for printing messages to console. Will need file handler in prod

sh.setLevel(logging.INFO) # lowest level for the handler to display
f = logging.Formatter("[%(asctime)s] %(levelname)s: %(message)s", "%Y-%m-%d %H:%M")
sh.setFormatter(f)
logger.addHandler(sh)

logger.info("Blogpost logger is initialized!")


[2026-01-07 16:37] INFO: Blogpost logger is initialized!


In [2]:
import findspark

from pyspark.sql import SparkSession

import pyspark.sql.functions as F
from pyspark.sql import Window
from pyspark.sql import DataFrame

from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.pipeline import PipelineModel
from pyspark.ml.evaluation import RegressionEvaluator

In [3]:
# little trick to make spark work locally
findspark.init()
# Create SparkSession
spark = SparkSession.builder.master("local[1]").appName("blogpost").config("spark.driver.bindAddress", "127.0.0.1").getOrCreate()

26/01/07 16:37:44 WARN Utils: Your hostname, Alexandrs-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 192.168.1.108 instead (on interface en0)
26/01/07 16:37:44 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/01/07 16:37:45 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
26/01/07 16:37:45 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [4]:
def read_dataset(dataset_name: str) -> DataFrame:
    file_path = f"../data/{dataset_name}.csv"
    return spark.read.csv(file_path, header=True, inferSchema=True)

def extract() -> dict[str, DataFrame]:
    logger.info("Extracting datasets")
    dataset_names = [
        "taxi_trip_data",
        "taxi_zone_geo",
    ]
    return {dataset_name: read_dataset(dataset_name) for dataset_name in dataset_names}

In [5]:
def limit_history_to_a_range(sdf: DataFrame) -> DataFrame:
    pickup_month = F.date_format(F.col("pickup_datetime"), "yyyyMM")
    return (
        sdf
        .filter(pickup_month > history_start_month)
        .filter(pickup_month <= history_end_month)
    )

def keep_evening_rides_only(sdf: DataFrame) -> DataFrame:
    dropoff_hour = F.date_format(F.col("dropoff_datetime"), "HH")
    return (
        sdf
        .filter(dropoff_hour >= first_evening_hour)
        .filter(dropoff_hour <= last_evening_hour)
    )

In [6]:
def exclude_airports_by_location(sdf: DataFrame, location_id_col_name: str) -> DataFrame:
    sdf_zone_geo_no_airport = (
        sdfs["taxi_zone_geo"]
        .filter(~F.lower(F.col("zone_name")).like("%airport%"))
    )
    return (
        sdf
        .join(
            sdf_zone_geo_no_airport,
            on=[F.col(location_id_col_name) == F.col("zone_id")],
            how="leftsemi"
        )
    )

In [7]:
def keep_first_n_daily_rides_only(sdf: DataFrame) -> DataFrame:
    pickup_date = F.date_format(F.col("pickup_datetime"), "yyyyMMdd")
    window = (
        Window
        .partitionBy("pickup_location_id", pickup_date)
        .orderBy(F.asc("pickup_datetime"))
    )
    return (
        sdf
        .withColumn("ride_number", F.row_number().over(window))
        .filter(F.col("ride_number") <= n_first_daily_rides_to_keep)
        .drop("ride_number")
    )

In [8]:
def filter_data(sdf: DataFrame) -> DataFrame:
    return (
        sdf
        .dropDuplicates()
        .transform(limit_history_to_a_range)
        .transform(keep_evening_rides_only)
        .transform(exclude_airports_by_location, "pickup_location_id")
        .transform(exclude_airports_by_location, "dropoff_location_id")
        .transform(keep_first_n_daily_rides_only)
    )

def add_features(sdf: DataFrame) -> DataFrame:
    return (
        sdf
        .withColumn("month", F.month(F.col("pickup_datetime")))
        .withColumn("day_of_week", F.dayofweek(F.col("pickup_datetime")))
        .withColumn("day_of_month", F.dayofmonth(F.col("pickup_datetime")))
        .withColumn("store_and_fwd_flag", F.when(F.col("store_and_fwd_flag") == "N", 0).otherwise(1))
    )

def train_test_split(sdf: DataFrame) -> tuple[DataFrame]:
    return (
        sdf
        .randomSplit(
            weights=[1-test_fraction, test_fraction],
            seed=42
        )
    )

In [9]:
def train_model() -> PipelineModel:

    assembler = VectorAssembler(
        inputCols=feature_cols,
        outputCol="features"
    )

    rf = RandomForestRegressor(
        labelCol="tip_amount",
        featuresCol="features",
        predictionCol="prediction",
        numTrees=10,
        maxDepth=4,
        featureSubsetStrategy="auto",
        seed=42,
        bootstrap=True,
    )

    pipeline = Pipeline(stages=[assembler, rf])

    return pipeline.fit(sdfs["training"])

In [10]:
def transform() -> PipelineModel:
    logger.info("Processing the data")
    sdf_prepared_data = (
        sdfs["taxi_trip_data"]
        .transform(filter_data)
        .transform(add_features)
    )
    sdfs["training"], sdfs["test"] = train_test_split(sdf_prepared_data)

    logger.info("Training the model")
    model = train_model()
    logger.info("Model is trained")

    return model

In [11]:
history_start_month = "201703"
history_end_month = "201811"
first_evening_hour = "17"
last_evening_hour = "23"
n_first_daily_rides_to_keep = 3
test_fraction = 0.2
feature_cols = [
    "passenger_count",
    "trip_distance",
    "rate_code",
    "store_and_fwd_flag",
    "payment_type",
    "fare_amount",
    "tolls_amount",
    "imp_surcharge",
    "month",
    "day_of_week",
    "day_of_month",
]

In [12]:
def check_features_importances() -> None:
    logger.info("  Features importances")
    importances = zip(feature_cols, model.stages[-1].featureImportances, strict=False)
    for name, importance in sorted(importances, key=lambda item: item[1], reverse=True):
        logger.info("%22s = %.2g", name, importance)

def evaluate_on_dataset(sdf: DataFrame) -> None:
    evaluator = RegressionEvaluator()
    evaluator.setPredictionCol("prediction")
    evaluator.setLabelCol("tip_amount")

    evaluation_metrics = ["rmse", "mae", "r2"]

    sdf_predictions = model.transform(sdf)

    for metric_name in evaluation_metrics:
        value = evaluator.evaluate(sdf_predictions, {evaluator.metricName: metric_name})
        logger.info("%8s = %.2g", metric_name, value)

def check_evaluation_metrics() -> None:
    for set_name in ["training", "test"]:
        logger.info("  Evaluation on the %s set", set_name)
        evaluate_on_dataset(sdfs[set_name])

def validate() -> None:
    logger.info("Start validation")
    check_features_importances()
    check_evaluation_metrics()

In [13]:
def load() -> None:
    logger.info("Saving the model")
    model.write().overwrite().save("../data/model")
    logger.info("The model is saved")

In [14]:
sdfs = extract()
model = transform()
validate()
load()

[2026-01-07 16:37] INFO: Extracting datasets
[2026-01-07 16:37] INFO: Processing the data                                    
[2026-01-07 16:37] INFO: Training the model
[2026-01-07 16:37] INFO: Model is trained                                       
[2026-01-07 16:37] INFO: Start validation
[2026-01-07 16:37] INFO:   Features importances
[2026-01-07 16:37] INFO:           payment_type = 0.66
[2026-01-07 16:37] INFO:            fare_amount = 0.2
[2026-01-07 16:37] INFO:              rate_code = 0.06
[2026-01-07 16:37] INFO:          trip_distance = 0.042
[2026-01-07 16:37] INFO:           tolls_amount = 0.03
[2026-01-07 16:37] INFO:                  month = 0.0063
[2026-01-07 16:37] INFO:           day_of_month = 0.00082
[2026-01-07 16:37] INFO:        passenger_count = 0.00047
[2026-01-07 16:37] INFO:          imp_surcharge = 0.00046
[2026-01-07 16:37] INFO:            day_of_week = 0.00013
[2026-01-07 16:37] INFO:     store_and_fwd_flag = 0
[2026-01-07 16:37] INFO:   Evaluation on th