# **F1 Pit Stop Prediction**

## Libraries

In [1]:
import glob
import os
from concurrent.futures import ThreadPoolExecutor

import optuna
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.feature import StringIndexer, UnivariateFeatureSelector, VectorAssembler
from pyspark.sql import SparkSession
# NOTE: `max` and `sum` below shadow the Python builtins for the rest of the notebook.
from pyspark.sql.functions import (
    avg, coalesce, col, first, input_file_name, lag, last, lead, lit, max, regexp_extract,
    regexp_replace, row_number, split, stddev, sum, to_timestamp, when,
)
from pyspark.sql.types import BooleanType, FloatType, IntegerType, NumericType
from pyspark.sql.window import Window

os.chdir(os.path.abspath(os.path.join(os.getcwd(), "..", "scripts")))
from constants import LAPS, TELEMETRY
# from preprocessing import add_pit_stop_label, engineer_features

## Data Loading

In [2]:
# Initialize Spark Session: local mode on all cores, 24g requested for driver and executor
spark_builder = (
    SparkSession.builder
    .appName("Lap Data Aggregation")
    .master("local[*]")
)
for memory_key in ("spark.driver.memory", "spark.executor.memory"):
    spark_builder = spark_builder.config(memory_key, "24g")
spark = spark_builder.getOrCreate()

In [3]:
# Every *.csv directly under the telemetry / laps folders, each loaded with a header row
all_telemetry_files, all_laps_files = (
    glob.glob(os.path.join(folder, "*.csv")) for folder in (TELEMETRY, LAPS)
)

telemetry_data, lap_data = (
    spark.read.option("header", True).csv(paths)
    for paths in (all_telemetry_files, all_laps_files)
)

## Data Preprocessing

### Lap Data

In [4]:
# Column expression holding the source file path of each row; the Year / EventName / Session
# columns below are parsed out of this path.
file_name_col = input_file_name()

In [5]:
# Extract the event name and session from the file name.
# The patterns expect paths ending in "/<4-digit year>_<event>_<Q|R>.csv"; underscores in the
# event part become spaces in EventName.
YEAR_FROM_PATH = r"/(\d{4})_[^/]+_[QR]\.csv$"
EVENT_FROM_PATH = r"/\d{4}_(.+)_[QR]\.csv$"
SESSION_FROM_PATH = r"/\d{4}_[^/]+_([QR])\.csv$"

event_with_underscores = regexp_extract(file_name_col, EVENT_FROM_PATH, 1)

lap_data = lap_data.withColumn("Year", regexp_extract(file_name_col, YEAR_FROM_PATH, 1))
lap_data = lap_data.withColumn("EventName", regexp_replace(event_with_underscores, "_", " "))
lap_data = lap_data.withColumn("Session", regexp_extract(file_name_col, SESSION_FROM_PATH, 1))

In [6]:
# Register the lap data as the SQL view "laps", then keep only Race ("R") session rows
race_only_query = """
    SELECT *
    FROM laps
    WHERE Session = 'R'
"""

lap_data.createOrReplaceTempView("laps")
lap_data = spark.sql(race_only_query)

**Fixing Datatypes**

In [7]:
# # Check datatypes
# lap_data.printSchema()

In [8]:
# Fix datatypes
def timedelta_str_to_seconds(column_name):
    """Parse a "0 days HH:MM:SS.ffffff" duration string column into seconds (double).

    Only a leading "0 days " prefix is stripped; the rest is split on ":" into hours,
    minutes and (fractional) seconds. With Spark's default (non-ANSI) casting, null or
    unparseable values -- including durations of one day or more -- come out as null.

    Args:
        column_name: Name of the string column to parse.

    Returns:
        A Column expression with the duration in seconds.
    """
    hms = split(regexp_replace(col(column_name), r"^0 days ", ""), ":")
    return (
        hms.getItem(0).cast("int") * 3600
        + hms.getItem(1).cast("int") * 60
        + hms.getItem(2).cast("double")
    )


# Duration strings -> seconds
lap_duration_cols = [
    "LapTime", "PitOutTime", "PitInTime",
    "Sector1Time", "Sector2Time", "Sector3Time",
    "Sector1SessionTime", "Sector2SessionTime", "Sector3SessionTime",
    "LapStartTime",
]
lap_int_cols = [
    "DriverNumber", "LapNumber", "Stint",
    "SpeedI1", "SpeedI2", "SpeedFL", "SpeedST",
    "TyreLife", "TrackStatus", "Position", "Year",
]
lap_bool_cols = ["IsPersonalBest", "FreshTyre", "Deleted", "FastF1Generated", "IsAccurate"]

for column_name in lap_duration_cols:
    lap_data = lap_data.withColumn(column_name, timedelta_str_to_seconds(column_name))
for column_name in lap_int_cols:
    lap_data = lap_data.withColumn(column_name, col(column_name).cast(IntegerType()))
for column_name in lap_bool_cols:
    lap_data = lap_data.withColumn(column_name, col(column_name).cast(BooleanType()))

lap_data = (
    lap_data
    .withColumn("LapStartDate", to_timestamp("LapStartDate", "yyyy-MM-dd HH:mm:ss.SSS"))
    # LapSessionTime (seconds) comes from the raw "Time" column, which is then dropped.
    # It is computed once here; previously it was first added as a string and later overwritten.
    .withColumn("LapSessionTime", timedelta_str_to_seconds("Time"))
    .drop(col("Time"))
)

In [9]:
# # Show the result
# lap_data.show(1)

**Feature Engineering**

In [10]:
# Define windows
# One partition per (Year, EventName, Driver); unordered, so the frame is the whole partition
start_position_window = Window.partitionBy("Year", "EventName", "Driver")
# Same partition, ordered by lap number
lap_order_window = Window.partitionBy("Year", "EventName", "Driver").orderBy("LapNumber")

In [11]:
# Creating new features
# Partitions are (Year, EventName, Driver); lap_order_window orders by LapNumber, so the
# aggregates over it only look at the current lap and the laps before it.
lap_data = (
    lap_data
    # Mean lap time over all laps so far, current lap included
    .withColumn("rolling_avg_laptime", avg("LapTime").over(lap_order_window.rowsBetween(Window.unboundedPreceding, 0)))
    # Flags for laps with a recorded PitInTime / PitOutTime
    .withColumn("pit_in_lap", when(col("PitInTime").isNotNull(), 1).otherwise(0))
    .withColumn("pit_exit_lap", when(col("PitOutTime").isNotNull(), 1).otherwise(0))
    # Lap NUMBER of the most recent pit exit so far, 0 if the driver has not pitted yet.
    # (A max over the 0/1 pit_exit_lap flag itself could only ever be 0 or 1.)
    .withColumn(
        "last_pit_lap",
        coalesce(
            max(when(col("pit_exit_lap") == 1, col("LapNumber"))).over(
                lap_order_window.rowsBetween(Window.unboundedPreceding, 0)
            ),
            lit(0)
        )
    )
    .withColumn("laps_since_last_pit", col("LapNumber") - col("last_pit_lap"))
    # Compound of the previous lap; lap 1 falls back to its own compound
    .withColumn(
        "prev_compound",
        when(
            col("LapNumber") == 1, col("Compound")
        ).otherwise(
            lag("Compound").over(lap_order_window)
        )
    )
    # This lap's PitOutTime minus the previous lap's PitInTime; 0 when there is no PitOutTime,
    # null when there is one but the previous lap has no PitInTime
    .withColumn(
        "pit_stop_duration",
        when(
            col("PitOutTime").isNull(),
            lit(0)
        ).otherwise(
            col("PitOutTime") - lag("PitInTime").over(lap_order_window)
        )
    )
    # Largest pit_stop_duration up to the current lap (default frame of the ordered window)
    .withColumn("max_pit_stop_duration", max("pit_stop_duration").over(lap_order_window))
    # Position recorded on lap 1, broadcast to every lap of the partition
    .withColumn("start_position", first(when(col("LapNumber") == 1, col("Position")), ignorenulls=True).over(start_position_window))
    # Positive when Position is now a lower number than on lap 1
    .withColumn("position_change_since_race_start", col("start_position") - col("Position"))
    # Index (1-3) of the quickest sector; null comparisons fall through to 3
    .withColumn(
        "fastest_sector", when(
            (col("Sector1Time") <= col("Sector2Time")) & (col("Sector1Time") <= col("Sector3Time")), 1
        ).when(
            (col("Sector2Time") <= col("Sector1Time")) & (col("Sector2Time") <= col("Sector3Time")), 2
        ).otherwise(3)
    )
)

lap_data = lap_data.drop("Sector1SessionTime", "Sector2SessionTime", "Sector3SessionTime", "DeletedReason", "IsAccurate", "LapStartDate")

In [12]:
# Preview one row of the engineered lap data
lap_data.show(n=1)

+------+------------+-------+---------+-----+----------+---------+-----------+-----------+-----------+-------+-------+-------+-------+--------------+--------+--------+---------+--------+------------+-----------+--------+-------+---------------+----+--------------------+-------+--------------+-------------------+----------+------------+------------+-------------------+-------------+-----------------+---------------------+--------------+--------------------------------+--------------+
|Driver|DriverNumber|LapTime|LapNumber|Stint|PitOutTime|PitInTime|Sector1Time|Sector2Time|Sector3Time|SpeedI1|SpeedI2|SpeedFL|SpeedST|IsPersonalBest|Compound|TyreLife|FreshTyre|    Team|LapStartTime|TrackStatus|Position|Deleted|FastF1Generated|Year|           EventName|Session|LapSessionTime|rolling_avg_laptime|pit_in_lap|pit_exit_lap|last_pit_lap|laps_since_last_pit|prev_compound|pit_stop_duration|max_pit_stop_duration|start_position|position_change_since_race_start|fastest_sector|
+------+------------+---

### Telemetry Data

In [13]:
# Source file path of each telemetry row
file_name_col = input_file_name()

# Paths are expected to end in "/<4-digit year>_<event>_<Q|R>.csv"; derive Year, EventName
# (underscores replaced by spaces) and Session from them, in that order.
raw_event_name = regexp_extract(file_name_col, r"/\d{4}_(.+)_[QR]\.csv$", 1)
for new_column, column_expr in [
    ("Year", regexp_extract(file_name_col, r"/(\d{4})_[^/]+_[QR]\.csv$", 1)),
    ("EventName", regexp_replace(raw_event_name, "_", " ")),
    ("Session", regexp_extract(file_name_col, r"/\d{4}_[^/]+_([QR])\.csv$", 1)),
]:
    telemetry_data = telemetry_data.withColumn(new_column, column_expr)

In [14]:
# Expose the telemetry as the SQL view "telemetry" and keep only Race ("R") rows
telemetry_race_query = """
    SELECT *
    FROM telemetry
    WHERE Session = 'R'
"""

telemetry_data.createOrReplaceTempView("telemetry")
telemetry_data = spark.sql(telemetry_race_query)

**Fixing Datatypes**

In [15]:
# # Check datatypes
# telemetry_data.printSchema()

In [16]:
# "0 days HH:MM:SS.ffffff" strings split into [hours, minutes, seconds]
time_hms = split(regexp_replace(col("Time"), r"^0 days ", ""), ":")
session_time_hms = split(regexp_replace(col("SessionTime"), r"^0 days ", ""), ":")

telemetry_data = telemetry_data.withColumn("Date", to_timestamp("Date", "yyyy-MM-dd HH:mm:ss.SSS"))
for int_column in ("RPM", "Speed", "nGear", "Throttle"):
    telemetry_data = telemetry_data.withColumn(int_column, col(int_column).cast(IntegerType()))
telemetry_data = (
    telemetry_data
    # Brake arrives as text; go through boolean to get 0/1
    .withColumn("Brake", col("Brake").cast(BooleanType()).cast(IntegerType()))
    .withColumn("DRS", col("DRS").cast(IntegerType()))
    .withColumn(
        "DataCollectionTime",
        time_hms.getItem(0).cast("int") * 3600
        + time_hms.getItem(1).cast("int") * 60
        + time_hms.getItem(2).cast("double")
    )
    .withColumn(
        "SessionTime",
        session_time_hms.getItem(0).cast("int") * 3600
        + session_time_hms.getItem(1).cast("int") * 60
        + session_time_hms.getItem(2).cast("double")
    )
    .withColumn("Distance", col("Distance").cast(FloatType()))
    .withColumn("LapNumber", col("LapNumber").cast(IntegerType()))
    .withColumn("Year", col("Year").cast(IntegerType()))
    # 1 when the (integer) DRS code is 10, 12 or 14
    .withColumn("IsDRSActive", when(col("DRS").isin(10, 12, 14), 1).otherwise(0))
    .drop(col("Time"))
)

In [17]:
# Show the result
telemetry_data.show(n=1)

+--------------------+-----+-----+-----+--------+-----+---+------+-----------+--------+------+---------+----+----------------+-------+------------------+-----------+
|                Date|  RPM|Speed|nGear|Throttle|Brake|DRS|Source|SessionTime|Distance|Driver|LapNumber|Year|       EventName|Session|DataCollectionTime|IsDRSActive|
+--------------------+-----+-----+-----+--------+-----+---+------+-----------+--------+------+---------+----+----------------+-------+------------------+-----------+
|2023-08-27 13:03:...|10093|    0|    1|      15|    0|  1|   car|   3725.042|     0.0|   VER|        1|2023|Dutch Grand Prix|      R|             0.082|          0|
+--------------------+-----+-----+-----+--------+-----+---+------+-----------+--------+------+---------+----+----------------+-------+------------------+-----------+
only showing top 1 row



**Feature Engineering**

In [18]:
# Samples of one driver's lap, in SessionTime order
window_spec = (
    Window
    .partitionBy("Year", "EventName", "Driver", "LapNumber")
    .orderBy("SessionTime")
)
# Sliding frame: the current sample plus the 49 before it
last_50_window = window_spec.rowsBetween(-49, 0)

In [19]:
# Compute per-lap aggregates
# Whole-lap window: same partition as window_spec but UNORDERED, so Spark's default frame is the
# entire partition and every sample of a lap gets the same lap-level value. With window_spec's
# ORDER BY the default frame ends at the current row, which gives running values instead.
per_lap_window = Window.partitionBy("Year", "EventName", "Driver", "LapNumber")

telemetry_data = (
    telemetry_data
    .withColumn("avg_speed_last_lap", avg("Speed").over(per_lap_window))
    .withColumn("max_speed_last_lap", max("Speed").over(per_lap_window))
    .withColumn("avg_throttle_last_lap", avg("Throttle").over(per_lap_window))
    .withColumn("avg_brake_last_lap", avg("Brake").over(per_lap_window))
    .withColumn("avg_rpm", avg("RPM").over(per_lap_window))
    # lag() needs time order: 1 when the gear differs from the previous sample
    .withColumn("gear_change", when(col("nGear") != lag("nGear").over(window_spec), 1).otherwise(0))
    .withColumn("gear_change_count", sum("gear_change").over(per_lap_window))
    # 1 on a sample whose DRS code enters {10, 12, 14} (the codes IsDRSActive treats as active)
    # from a code outside it. Kept as its own column so lag() is not nested inside the windowed sum.
    .withColumn(
        "drs_activation",
        when(
            (~lag("DRS").over(window_spec).isin(10, 12, 14)) & (col("DRS").isin(10, 12, 14)),
            1
        ).otherwise(0)
    )
    # Number of DRS activations in the whole lap
    .withColumn("DRS_activation_count", sum("drs_activation").over(per_lap_window))
)

In [20]:
# Rolling features over last 50 telemetry rows (current sample + 49 preceding, same lap)
gear_shift_flag = when(col("nGear") != lag("nGear").over(window_spec), 1).otherwise(0)

telemetry_data = telemetry_data.withColumn("rolling_throttle_mean", avg("Throttle").over(last_50_window))
telemetry_data = telemetry_data.withColumn("rolling_brake_intensity", avg("Brake").over(last_50_window))
telemetry_data = telemetry_data.withColumn("rolling_gear_change", gear_shift_flag)
telemetry_data = telemetry_data.withColumn("rolling_gear_change_rate", avg("rolling_gear_change").over(last_50_window))
telemetry_data = telemetry_data.withColumn("rolling_speed_mean", avg("Speed").over(last_50_window))

In [21]:
# "Final sector" here means the last 5% of the lap's maximum recorded Distance
final_sector_keys = ["Year", "EventName", "Driver", "LapNumber"]

max_distance = telemetry_data.groupBy(*final_sector_keys).agg(max("Distance").alias("max_dist"))
telemetry_data = (
    telemetry_data
    .join(max_distance, on=final_sector_keys)
    .withColumn("in_final_sector", col("Distance") >= col("max_dist") * 0.95)
)

# Whole-lap frame: averages only over samples flagged in_final_sector (others map to null)
final_sector_window = Window.partitionBy(*final_sector_keys).rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

for feature_name, source_column in (
    ("final_sector_avg_speed", "Speed"),
    ("final_sector_throttle", "Throttle"),
    ("final_sector_brake", "Brake"),
):
    telemetry_data = telemetry_data.withColumn(
        feature_name,
        avg(when(col("in_final_sector"), col(source_column))).over(final_sector_window),
    )

In [22]:
# telemetry_data.show(1)

In [23]:
# Select final per-lap features
lap_feature_cols = [
    "EventName", "Driver", "LapNumber", "Year", "Session",
    "avg_speed_last_lap", "max_speed_last_lap",
    "avg_throttle_last_lap", "avg_brake_last_lap",
    "gear_change_count", "avg_rpm",
    "rolling_throttle_mean", "rolling_brake_intensity",
    "rolling_gear_change_rate", "rolling_speed_mean",
    "final_sector_avg_speed", "final_sector_throttle",
    "final_sector_brake",
    # DRS activations in the lap (the lap's final sample carries the lap total)
    "DRS_activation_count",
]
lap_key_cols = ["Year", "EventName", "Session", "Driver", "LapNumber"]

# Collapse telemetry to one row per lap by keeping each lap's LAST sample (latest SessionTime).
# groupBy(...).agg(first(...)) returns the value of an arbitrary row, which is only safe for
# columns that are constant within a lap. The rolling_* columns differ from row to row, so the
# final sample is picked explicitly: rolling_* values describe the last 50 samples of the lap.
last_sample_window = Window.partitionBy(*lap_key_cols).orderBy(col("SessionTime").desc())

aggregated_laps = (
    telemetry_data
    .withColumn("_sample_rank", row_number().over(last_sample_window))
    .filter(col("_sample_rank") == 1)
    .select(*lap_key_cols, *[c for c in lap_feature_cols if c not in lap_key_cols])
)

In [None]:
# Preview the per-lap telemetry features
aggregated_laps.show(n=5)

### Joining the Data

In [None]:
# Join lap_data and the per-lap telemetry features. The full outer join keeps laps without
# telemetry and telemetry laps without a lap row.
lap_join_keys = ["Year", "EventName", "Session", "Driver", "LapNumber"]

data = lap_data.alias('lap').join(
    aggregated_laps.alias('telemetry'),
    on=lap_join_keys,
    how="outer",
)

In [None]:
# Create target variable: WillPitNextLap = 1 when the driver's next lap (by LapNumber, same
# race) has a PitInTime; a lap with no following lap gets 0.
next_lap_window = Window.partitionBy("Year", "EventName", "Session", "Driver").orderBy("LapNumber")
next_lap_has_pit_in = lead("PitInTime", 1).over(next_lap_window).isNotNull()

data = data.withColumn("WillPitNextLap", when(next_lap_has_pit_in, 1).otherwise(0))

# data = data.drop("PitInTime")

In [None]:
#data.show(1)

In [None]:
# Inspect the schema of the joined lap + telemetry DataFrame (including WillPitNextLap)
data.printSchema()

root
 |-- Year: integer (nullable = true)
 |-- EventName: string (nullable = true)
 |-- Session: string (nullable = true)
 |-- Driver: string (nullable = true)
 |-- LapNumber: integer (nullable = true)
 |-- DriverNumber: integer (nullable = true)
 |-- LapTime: double (nullable = true)
 |-- Stint: integer (nullable = true)
 |-- PitOutTime: double (nullable = true)
 |-- PitInTime: double (nullable = true)
 |-- Sector1Time: double (nullable = true)
 |-- Sector2Time: double (nullable = true)
 |-- Sector3Time: double (nullable = true)
 |-- SpeedI1: integer (nullable = true)
 |-- SpeedI2: integer (nullable = true)
 |-- SpeedFL: integer (nullable = true)
 |-- SpeedST: integer (nullable = true)
 |-- IsPersonalBest: boolean (nullable = true)
 |-- Compound: string (nullable = true)
 |-- TyreLife: integer (nullable = true)
 |-- FreshTyre: boolean (nullable = true)
 |-- Team: string (nullable = true)
 |-- LapStartTime: double (nullable = true)
 |-- TrackStatus: integer (nullable = true)
 |-- Posit

**Handling Missing Values**

In [None]:
# # Compute null counts
# null_counts = data.select([sum(col(c).isNull().cast("int")).alias(c) for c in data.columns])

# # Convert to a Row to filter in Python
# null_counts_dict = null_counts.first().asDict()

# # Filter and print only columns with nulls
# for col_name, count in null_counts_dict.items():
#     if count > 0:
#         print(f"{col_name}: {count}")

LapTime

In [None]:
# # Check missing values
# data.filter(col("LapTime").isNull()).count()

In [None]:
# # Check rows with missing values
# data.filter(col("LapTime").isNull()).show(5)

In [None]:
# Rows with NULL LapTime and Sector<1, 2, 3>Time are DNF so we drop these rows.
# Equivalently: keep a row when at least one of the four timing columns is present.
timing_cols = ["Sector1Time", "Sector2Time", "Sector3Time", "LapTime"]

has_any_timing = col(timing_cols[0]).isNotNull()
for timing_col in timing_cols[1:]:
    has_any_timing = has_any_timing | col(timing_col).isNotNull()

data = data.filter(has_any_timing)

In [None]:
# Remove the Sector1Time / Sector2Time / Sector3Time columns
data = data.drop(*[f"Sector{i}Time" for i in (1, 2, 3)])

In [None]:
# # Recheck
# data.filter(col("LapTime").isNull()).count()

In [None]:
# Fix missing values - compute by subtracting the time at the end and at the start of the lap.
# Only null LapTime values are filled; recorded lap times are kept as they are
# (previously every LapTime was overwritten by this difference).
data = data.withColumn(
    "LapTime",
    coalesce(col("LapTime"), col("LapSessionTime") - col("LapStartTime"))
)

In [None]:
# # Recheck
# data.filter(col("LapTime").isNull()).count()

Re-check missing values now that the DNF rows have been removed.

In [None]:
# # Compute null counts
# null_counts = data.select([sum(col(c).isNull().cast("int")).alias(c) for c in data.columns])

# # Convert to a Row to filter in Python
# null_counts_dict = null_counts.first().asDict()

# # Filter and print only columns with nulls
# for col_name, count in null_counts_dict.items():
#     if count > 0:
#         print(f"{col_name}: {count}")

SpeedI1, SpeedI2, SpeedFL, SpeedST

In [None]:
# # Check missing values
# data.filter(
#     col("SpeedI1").isNull() |
#     col("SpeedI2").isNull() |
#     col("SpeedFL").isNull() |
#     col("SpeedST").isNull()
# ).count()

In [None]:
# Fill missing values - speed rolling average.
# Frame: all EARLIER laps of the driver in the same race (current lap excluded).
driver_lap_window = (
    Window.partitionBy("Year", "EventName", "Session", "Driver")
    .orderBy("LapNumber")
    .rowsBetween(Window.unboundedPreceding, -1)
)

# List of columns to process
speed_cols = ["SpeedI1", "SpeedI2", "SpeedFL", "SpeedST"]

# A missing speed-trap value takes the mean of that driver's previous laps (null if none)
for speed_col in speed_cols:
    data = data.withColumn(
        speed_col,
        coalesce(col(speed_col), avg(col(speed_col)).over(driver_lap_window))
    )

In [None]:
# # Recheck
# data.filter(
#     col("SpeedI1").isNull() |
#     col("SpeedI2").isNull() |
#     col("SpeedFL").isNull() |
#     col("SpeedST").isNull()
# ).count()

In [None]:
# # Recheck
# data.filter(
#     col("SpeedI1").isNull() |
#     col("SpeedI2").isNull() |
#     col("SpeedFL").isNull() |
#     col("SpeedST").isNull()
# ).show(1)

In [None]:
# Fill missing values - teammate's speed in same lap

# Self-join on teammate info: each row is paired with the row(s) of a DIFFERENT driver of the same
# Team in the same Year/EventName/Session and LapNumber. The left join keeps rows that have no such
# match (their tm.* columns are null).
# NOTE(review): if more than one other driver of the same Team has a row for that lap, the row is
# duplicated once per match — confirm each team has at most two drivers per race in this data.
teammate_join = data.alias("self").join(
    data.alias("tm"),
    on=[
        col("self.Year") == col("tm.Year"),
        col("self.EventName") == col("tm.EventName"),
        col("self.Session") == col("tm.Session"),
        col("self.Team") == col("tm.Team"),
        col("self.LapNumber") == col("tm.LapNumber"),
        col("self.Driver") != col("tm.Driver")
    ],
    how="left"
)

# Replace missing values from teammate values: speed-trap columns take the first non-null of
# (own value, teammate value); every other column is taken from the driver's own row.
updated_cols = [
    coalesce(col(f"self.{col_name}"), col(f"tm.{col_name}")).alias(col_name)
    if col_name in speed_cols else col(f"self.{col_name}")
    for col_name in data.columns
]

# Assign back to `data` (replacing the original one)
data = teammate_join.select(*updated_cols)

In [None]:
# # Recheck
# data.filter(
#     col("SpeedI1").isNull() |
#     col("SpeedI2").isNull() |
#     col("SpeedFL").isNull() |
#     col("SpeedST").isNull()
# ).count()

In [None]:
# Fill missing values - finish line speed with longest straight speed:
# where SpeedFL is still null, use SpeedST instead.
data = data.withColumn("SpeedFL", coalesce(col("SpeedFL"), col("SpeedST")))

In [None]:
# # Recheck
# data.filter(
#     col("SpeedI1").isNull() |
#     col("SpeedI2").isNull() |
#     col("SpeedFL").isNull() |
#     col("SpeedST").isNull()
# ).count()

pit_stop_duration

In [None]:
# # Check missing values
# data.filter(
#     col("pit_stop_duration").isNull()
# ).count()

In [None]:
# # Check missing values
# data.filter(
#     col("pit_stop_duration").isNull()
# ).show(3)

In [None]:
# These values seem to be mistakes, so we set PitOutTime to NULL and recompute pit_stop_duration and max_pit_stop_duration.
# pit_stop_duration is null where a lap has a PitOutTime but the previous lap has no PitInTime.
# NOTE(review): pit_exit_lap / last_pit_lap / laps_since_last_pit were derived from the original
# PitOutTime and are not recomputed here — confirm that is intended.
keep_pit_out_time = col("pit_stop_duration").isNotNull()

recomputed_pit_stop_duration = when(
    col("PitOutTime").isNotNull(),
    col("PitOutTime") - lag("PitInTime").over(lap_order_window)
).otherwise(lit(0))

data = (
    data
    # `when` without `otherwise` yields NULL where the duration was not computable
    .withColumn("PitOutTime", when(keep_pit_out_time, col("PitOutTime")))
    .withColumn("pit_stop_duration", recomputed_pit_stop_duration)
    .withColumn("max_pit_stop_duration", max("pit_stop_duration").over(lap_order_window))
)

In [None]:
# # Recheck
# data.filter(
#     col("pit_stop_duration").isNull()
# ).count()

max_pit_stop_duration

In [None]:
# # Check missing values
# data.filter(
#     col("max_pit_stop_duration").isNull()
# ).count()

# # No more :)

In [None]:
# # Compute null counts
# null_counts = data.select([sum(col(c).isNull().cast("int")).alias(c) for c in data.columns])

# # Convert to a Row to filter in Python
# null_counts_dict = null_counts.first().asDict()

# # Filter and print only columns with nulls
# for col_name, count in null_counts_dict.items():
#     if count > 0:
#         print(f"{col_name}: {count}")

## Data Modelling

In [None]:
# Train test split: Las Vegas 2023 is the validation race, Abu Dhabi 2023 the test race,
# every other row is training data.
# NOTE(review): if the data contains races after these two, they end up in training — confirm
# this is acceptable (temporal leakage).
is_val_race = (col("EventName") == "Las Vegas Grand Prix") & (col("Year") == 2023)
is_test_race = (col("EventName") == "Abu Dhabi Grand Prix") & (col("Year") == 2023)

train_data = data.filter(~is_test_race & ~is_val_race)
val_data = data.filter(is_val_race)
test_data = data.filter(is_test_race)

In [None]:
import optuna
from pyspark.ml.classification import RandomForestClassifier, LogisticRegression, MultilayerPerceptronClassifier
from xgboost.spark import SparkXGBClassifier
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.feature import StandardScaler, VectorAssembler, StringIndexer, ChiSqSelector
from pyspark.sql.functions import col
import uuid

def train_model(target, train_data, val_data, test_data, model_type, optimize, num_features=20, n_trials=2):
    """
    Train a binary classifier with optional Optuna hyperparameter optimization, univariate
    feature selection, and standard scaling for LR/MLP, then score it on the test split.

    Preprocessing is fitted on ``train_data`` only and applied unchanged to every split:
      * Team/Compound/Driver/EventName are label-indexed into ``<name>Index`` columns; categories
        unseen during training (and nulls) map to one extra index.
      * Candidate features are all numeric/boolean columns except ``target``. Other string columns
        (e.g. Session, prev_compound) cannot go through VectorAssembler and are skipped.
      * Nulls in candidate feature columns are filled with 0 (numeric) / False (boolean).
        NOTE(review): 0 is a placeholder for "missing" (e.g. PitInTime on non-pit laps);
        revisit if a model-specific missing-value strategy is preferred.

    Args:
        target (str): Binary target column name.
        train_data: PySpark DataFrame for training (also used to fit indexers and the selector).
        val_data: PySpark DataFrame scored by each Optuna trial.
        test_data: PySpark DataFrame for the final evaluation.
        model_type (str): Model type ('xgb', 'rf', 'lr', 'mlp').
        optimize (bool): If True, optimize hyperparameters with Optuna; if False, use defaults.
        num_features (int): Number of features kept by the ANOVA F-test feature selector.
        n_trials (int): Number of Optuna trials if optimize=True.

    Returns:
        tuple: (fitted PipelineModel, areaUnderPR score on test data)

    Raises:
        ValueError: If model_type is not one of 'xgb', 'rf', 'lr', 'mlp'.
    """
    # Validate model_type
    valid_models = ['xgb', 'rf', 'lr', 'mlp']
    if model_type not in valid_models:
        raise ValueError(f"model_type must be one of {valid_models}")

    categorical_cols = ["Team", "Compound", "Driver", "EventName"]

    # Fit the indexers ONCE on the training split and reuse them for val/test. Fitting one indexer
    # per split gives every split its own category->index mapping.
    indexer_model = Pipeline(stages=[
        StringIndexer(inputCol=name, outputCol=f"{name}Index", handleInvalid="keep")
        for name in categorical_cols
    ]).fit(train_data)

    # Candidate features: numeric/boolean columns after indexing, minus the target
    indexed_schema = indexer_model.transform(train_data).schema
    feature_cols = [
        field.name for field in indexed_schema.fields
        if field.name != target
        and field.name not in categorical_cols
        and isinstance(field.dataType, (NumericType, BooleanType))
    ]
    boolean_cols = [name for name in feature_cols if isinstance(indexed_schema[name].dataType, BooleanType)]
    numeric_cols = [name for name in feature_cols if name not in boolean_cols]

    def prepare(frame):
        """Apply the fitted indexers and fill nulls in the candidate feature columns."""
        return (
            indexer_model.transform(frame)
            .fillna(0, subset=numeric_cols)
            .fillna(False, subset=boolean_cols)
        )

    # Cache the prepared splits (re-read by every fit/evaluation) and keep handles to exactly these
    # frames so they are the ones released in `finally`.
    prepared_splits = [prepare(frame).cache() for frame in (train_data, val_data, test_data)]
    train_prepared, val_prepared, test_prepared = prepared_splits

    try:
        # Rank candidates with an ANOVA F-test (continuous features vs. categorical label).
        # ChiSqSelector assumes categorical features, which continuous lap/session times are not.
        selection_assembler = VectorAssembler(
            inputCols=feature_cols, outputCol="_candidate_features", handleInvalid="keep"
        )
        selector = UnivariateFeatureSelector(
            featuresCol="_candidate_features",
            outputCol="_selected_features",
            labelCol=target,
            selectionMode="numTopFeatures",
        )
        selector.setFeatureType("continuous").setLabelType("categorical").setSelectionThreshold(num_features)
        selector_model = selector.fit(selection_assembler.transform(train_prepared))
        selected_feature_names = [feature_cols[i] for i in selector_model.selectedFeatures]
        print(f"Selected {len(selected_feature_names)} features: {selected_feature_names}")

        # "features" is produced only inside the model pipeline. Pre-assembling it onto the
        # DataFrames makes the pipeline's VectorAssembler fail on an already-existing output column.
        assembler = VectorAssembler(inputCols=selected_feature_names, outputCol="features", handleInvalid="keep")
        evaluator = BinaryClassificationEvaluator(labelCol=target, metricName="areaUnderPR")

        def get_classifier(trial=None):
            """Define classifier based on model_type and optional trial parameters."""
            if model_type == 'rf':
                if optimize and trial:
                    num_trees = trial.suggest_int("numTrees", 10, 100)
                    max_depth = trial.suggest_int("maxDepth", 5, 30)
                    min_instances_per_node = trial.suggest_int("minInstancesPerNode", 1, 10)
                    subsampling_rate = trial.suggest_float("subsamplingRate", 0.5, 1.0)
                    max_bins = trial.suggest_int("maxBins", 10, 50)
                    return RandomForestClassifier(labelCol=target, featuresCol="features",
                                                  numTrees=num_trees, maxDepth=max_depth,
                                                  minInstancesPerNode=min_instances_per_node,
                                                  subsamplingRate=subsampling_rate, maxBins=max_bins)
                return RandomForestClassifier(labelCol=target, featuresCol="features")

            elif model_type == 'xgb':
                # SparkXGBClassifier takes snake_case column params (label_col/features_col) and
                # sklearn-style booster params; the number of boosting rounds is n_estimators.
                if optimize and trial:
                    max_depth = trial.suggest_int("maxDepth", 3, 10)
                    num_round = trial.suggest_int("num_round", 10, 100)
                    eta = trial.suggest_float("eta", 0.01, 0.3)
                    subsample = trial.suggest_float("subsample", 0.5, 1.0)
                    colsample_bytree = trial.suggest_float("colsample_bytree", 0.5, 1.0)
                    return SparkXGBClassifier(
                        label_col=target,
                        features_col="features",
                        max_depth=max_depth,
                        n_estimators=num_round,
                        learning_rate=eta,
                        subsample=subsample,
                        colsample_bytree=colsample_bytree
                    )
                return SparkXGBClassifier(label_col=target, features_col="features")

            elif model_type == 'lr':
                if optimize and trial:
                    reg_param = trial.suggest_float("regParam", 1e-5, 0.5, log=True)
                    elastic_net = trial.suggest_float("elasticNetParam", 0.0, 1.0)
                    tol = trial.suggest_float("tol", 1e-6, 1e-3, log=True)
                    max_iter = trial.suggest_int("maxIter", 50, 300)
                    fit_intercept = trial.suggest_categorical("fitIntercept", [True, False])
                    return LogisticRegression(labelCol=target, featuresCol="scaled_features",
                                              regParam=reg_param, elasticNetParam=elastic_net,
                                              tol=tol, maxIter=max_iter, fitIntercept=fit_intercept)
                return LogisticRegression(labelCol=target, featuresCol="scaled_features")

            elif model_type == 'mlp':
                # Input layer width = number of selected features; 2 output classes
                layers = [len(selected_feature_names), 64, 32, 2]
                if optimize and trial:
                    max_iter = trial.suggest_int("maxIter", 50, 200)
                    block_size = trial.suggest_int("blockSize", 32, 128)
                    step_size = trial.suggest_float("stepSize", 0.001, 0.1, log=True)
                    return MultilayerPerceptronClassifier(labelCol=target, featuresCol="scaled_features",
                                                          layers=layers, maxIter=max_iter, blockSize=block_size,
                                                          stepSize=step_size)
                return MultilayerPerceptronClassifier(labelCol=target, featuresCol="scaled_features",
                                                      layers=layers)

        def build_pipeline(classifier):
            """Assembler -> (StandardScaler for lr/mlp) -> classifier."""
            stages = [assembler]
            if model_type in ['lr', 'mlp']:
                stages.append(StandardScaler(inputCol="features", outputCol="scaled_features",
                                             withMean=True, withStd=True))
            stages.append(classifier)
            return Pipeline(stages=stages)

        def objective(trial):
            """Objective function for Optuna optimization using validation data."""
            classifier = get_classifier(trial)
            try:
                model = build_pipeline(classifier).fit(train_prepared)
                return evaluator.evaluate(model.transform(val_prepared))
            except Exception as e:
                # Best-effort: a failing configuration scores 0 instead of aborting the study
                print(f"Trial failed: {e}")
                return 0.0

        # Train model
        if optimize:
            study = optuna.create_study(
                direction="maximize",
                study_name=f"{model_type}_optimization_{uuid.uuid4()}"
            )
            study.optimize(objective, n_trials=n_trials)
            print(f"Best trial for {model_type}: {study.best_trial.params}")
            print(f"Best areaUnderPR: {study.best_value}")

            # FrozenTrial replays the stored parameter values through the suggest_* calls
            classifier = get_classifier(study.best_trial)
        else:
            classifier = get_classifier()

        # Fit on train and evaluate on test
        model = build_pipeline(classifier).fit(train_prepared)
        auc = evaluator.evaluate(model.transform(test_prepared))
        return model, auc
    finally:
        # Release exactly the frames cached above, also when training fails
        for frame in prepared_splits:
            frame.unpersist()

In [None]:
# XGBoost run: Optuna tuning (20 trials) on the validation race, 30 selected features
xgb_run_config = dict(
    target="WillPitNextLap",
    model_type="xgb",
    optimize=True,
    num_features=30,
    n_trials=20,
)
model, auc = train_model(train_data=train_data, val_data=val_data, test_data=test_data, **xgb_run_config)
print(f"Model AUC on test data: {auc}")

ERROR:root:KeyboardInterrupt while sending command.
Traceback (most recent call last):
  File "c:\Users\Admin\anaconda\envs\ML\Lib\site-packages\py4j\java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Admin\anaconda\envs\ML\Lib\site-packages\py4j\clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Admin\anaconda\envs\ML\Lib\socket.py", line 720, in readinto
    return self._sock.recv_into(b)
           ^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt


KeyboardInterrupt: 