# **F1 Pit Stop Prediction**

## Libraries

In [1]:
import os

from pyspark.sql import SparkSession
from pyspark.sql.functions import input_file_name, regexp_extract, regexp_replace, col, when, to_timestamp, lead, avg, stddev, lag, max, sum, first, last, split, coalesce, lit
from pyspark.sql.types import IntegerType, BooleanType, FloatType
from pyspark.sql.window import Window

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.evaluation import BinaryClassificationEvaluator

os.chdir(os.path.abspath(os.path.join(os.getcwd(), "..", "scripts")))
from constants import LAPS, TELEMETRY
# from preprocessing import add_pit_stop_label, engineer_features

## Data Loading

In [2]:
# Initialize Spark Session
spark = SparkSession.builder \
    .appName("Lap Data Aggregation") \
    .master("local[*]") \
    .config("spark.driver.memory", "12g") \
    .config("spark.executor.memory", "12g") \
    .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/05/09 20:41:35 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Load all CSVs into a Spark DataFrame
lap_data = spark.read.option("header", True).csv(os.path.join(LAPS, "*.csv"))
telemetry_data = spark.read.option("header", True).csv(os.path.join(TELEMETRY, "*.csv"))

                                                                                

## Data Preprocessing

### Lap Data

In [4]:
# Extract the file name from the file path
file_name_col = input_file_name()

In [5]:
# Extract the event name and session from the file name
lap_data = (
    lap_data
    .withColumn("Year", regexp_extract(file_name_col, r"/(\d{4})_[^/]+_[QR]\.csv$", 1))
    .withColumn("EventName", regexp_replace(regexp_extract(file_name_col, r"/\d{4}_(.+)_[QR]\.csv$", 1), "_", " "))
    .withColumn("Session", regexp_extract(file_name_col, r"/\d{4}_[^/]+_([QR])\.csv$", 1))
)

In [6]:
# Create sql view
lap_data.createOrReplaceTempView("laps")

# Filter for only Race sessions
lap_data = spark.sql("""
    SELECT *
    FROM laps
    WHERE Session = 'R'
""")

25/05/09 20:41:43 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


**Fixing Datatypes**

In [7]:
# # Check datatypes
# lap_data.printSchema()

In [8]:
# Fix datatypes
lap_data = (
    lap_data
    .withColumn("LapSessionTime", regexp_replace(col("Time"), r"^0 days ", ""))
    .withColumn("DriverNumber", col("DriverNumber").cast(IntegerType()))
    .withColumn("LapTime", split(regexp_replace(col("LapTime"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("LapTime"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("LapTime"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("LapNumber", col("LapNumber").cast(IntegerType()))
    .withColumn("Stint", col("Stint").cast(IntegerType()))
    .withColumn("PitOutTime", split(regexp_replace(col("PitOutTime"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("PitOutTime"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("PitOutTime"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("PitInTime", split(regexp_replace(col("PitInTime"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("PitInTime"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("PitInTime"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("Sector1Time", split(regexp_replace(col("Sector1Time"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("Sector1Time"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("Sector1Time"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("Sector2Time", split(regexp_replace(col("Sector2Time"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("Sector2Time"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("Sector2Time"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("Sector3Time", split(regexp_replace(col("Sector3Time"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("Sector3Time"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("Sector3Time"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("Sector1SessionTime", split(regexp_replace(col("Sector1SessionTime"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("Sector1SessionTime"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("Sector1SessionTime"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("Sector2SessionTime", split(regexp_replace(col("Sector2SessionTime"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("Sector2SessionTime"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("Sector2SessionTime"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("Sector3SessionTime", split(regexp_replace(col("Sector3SessionTime"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("Sector3SessionTime"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("Sector3SessionTime"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("SpeedI1", col("SpeedI1").cast(IntegerType()))
    .withColumn("SpeedI2", col("SpeedI2").cast(IntegerType()))
    .withColumn("SpeedFL", col("SpeedFL").cast(IntegerType()))
    .withColumn("SpeedST", col("SpeedST").cast(IntegerType()))
    .withColumn("IsPersonalBest", col("IsPersonalBest").cast(BooleanType()))
    .withColumn("TyreLife", col("TyreLife").cast(IntegerType()))
    .withColumn("FreshTyre", col("FreshTyre").cast(BooleanType()))
    .withColumn("LapStartTime", split(regexp_replace(col("LapStartTime"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("LapStartTime"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("LapStartTime"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("LapStartDate", to_timestamp("LapStartDate", "yyyy-MM-dd HH:mm:ss.SSS"))
    .withColumn("TrackStatus", col("TrackStatus").cast(IntegerType()))
    .withColumn("Position", col("Position").cast(IntegerType()))
    .withColumn("Deleted", col("Deleted").cast(BooleanType()))
    .withColumn("FastF1Generated", col("FastF1Generated").cast(BooleanType()))
    .withColumn("IsAccurate", col("IsAccurate").cast(BooleanType()))
    .withColumn("Year", col("Year").cast(IntegerType()))
    .withColumn("LapSessionTime", split(regexp_replace(col("Time"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("Time"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("Time"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
)

lap_data = lap_data.drop(col("Time"))

In [9]:
# # Show the result
# lap_data.show(1)

**Feature Engineering**

In [10]:
# Define windows
start_position_window = Window.partitionBy("Year", "EventName", "Driver")
lap_order_window = start_position_window.orderBy("LapNumber")

In [11]:
# Creating new features
lap_data = (
    lap_data
    .withColumn("rolling_avg_laptime", avg("LapTime").over(lap_order_window.rowsBetween(Window.unboundedPreceding, 0)))
    .withColumn("pit_in_lap", when(col("PitInTime").isNotNull(), 1).otherwise(0))
    .withColumn("pit_exit_lap", when(col("PitOutTime").isNotNull(), 1).otherwise(0))
    .withColumn(
        "last_pit_lap",
        coalesce(
            max("pit_exit_lap").over(lap_order_window.rowsBetween(Window.unboundedPreceding, 0)),
            lit(0)
        )
    )
    .withColumn("laps_since_last_pit", col("LapNumber") - col("last_pit_lap"))
    .withColumn(
        "prev_compound", 
        when(
            col("LapNumber") == 1, col("Compound")
        ).otherwise(
            lag("Compound").over(lap_order_window)
        )
    )
    .withColumn(
        "pit_stop_duration",
        when(
            col("PitOutTime").isNull(),
            lit(0)
        ).otherwise(
            col("PitOutTime") - lag("PitInTime").over(lap_order_window)
        )
    )
    .withColumn("max_pit_stop_duration", max("pit_stop_duration").over(lap_order_window))
    .withColumn("start_position", first(when(col("LapNumber") == 1, col("Position")), ignorenulls=True).over(start_position_window))
    .withColumn("position_change_since_race_start", col("start_position") - col("Position"))
    .withColumn(
        "fastest_sector", when(
            (col("Sector1Time") <= col("Sector2Time")) & (col("Sector1Time") <= col("Sector3Time")), 1
        ).when(
            (col("Sector2Time") <= col("Sector1Time")) & (col("Sector2Time") <= col("Sector3Time")), 2
        ).otherwise(3)
    )
)

lap_data = lap_data.drop("Sector1SessionTime", "Sector2SessionTime", "Sector3SessionTime", "DeletedReason", "IsAccurate", "LapStartDate")

In [12]:
# lap_data.show(1)

### Telemetry Data

In [13]:
# Extract the file name from the file path
file_name_col = input_file_name()

# Extract the event name and session from the file name
telemetry_data = (
    telemetry_data
    .withColumn("Year", regexp_extract(file_name_col, r"/(\d{4})_[^/]+_[QR]\.csv$", 1))
    .withColumn("EventName", regexp_replace(regexp_extract(file_name_col, r"/\d{4}_(.+)_[QR]\.csv$", 1), "_", " "))
    .withColumn("Session", regexp_extract(file_name_col, r"/\d{4}_[^/]+_([QR])\.csv$", 1))
)

In [14]:
# Create sql view
telemetry_data.createOrReplaceTempView("telemetry")

# Filter for only Race events
telemetry_data = spark.sql("""
    SELECT *
    FROM telemetry
    WHERE Session = 'R'
""")

**Fixing Datatypes**

In [15]:
# # Check datatypes
# telemetry_data.printSchema()

In [16]:
telemetry_data = (
    telemetry_data
    .withColumn("Date", to_timestamp("Date", "yyyy-MM-dd HH:mm:ss.SSS"))
    .withColumn("RPM", col("RPM").cast(IntegerType()))
    .withColumn("Speed", col("Speed").cast(IntegerType()))
    .withColumn("nGear", col("nGear").cast(IntegerType()))
    .withColumn("Throttle", col("Throttle").cast(IntegerType()))
    .withColumn("Brake", col("Brake").cast(BooleanType()).cast(IntegerType()))
    .withColumn("DRS", col("DRS").cast(IntegerType()))
    .withColumn(
        "DataCollectionTime", split(regexp_replace(col("Time"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("Time"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("Time"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn(
        "SessionTime", split(regexp_replace(col("SessionTime"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("SessionTime"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("SessionTime"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("Distance", col("Distance").cast(FloatType()))
    .withColumn("LapNumber", col("LapNumber").cast(IntegerType()))
    .withColumn("Year", col("Year").cast(IntegerType()))
    .withColumn(
        "IsDRSActive", when(
            col("DRS").isin(10, 12, 14), 1
        ).otherwise(0)
    )
)

telemetry_data = telemetry_data.drop(col("Time"))

In [17]:
# # Show the result
# telemetry_data.show(1)

**Feature Engineering**

In [18]:
# Define window
window_spec = Window.partitionBy("Year", "EventName", "Driver", "LapNumber").orderBy("SessionTime")
last_50_window = window_spec.rowsBetween(-49, 0)

In [19]:
# Compute per-lap aggregates
telemetry_data = (
    telemetry_data
    .withColumn("avg_speed_last_lap", avg("Speed").over(window_spec))
    .withColumn("max_speed_last_lap", max("Speed").over(window_spec))
    .withColumn("avg_throttle_last_lap", avg("Throttle").over(window_spec))
    .withColumn("avg_brake_last_lap", avg("Brake").over(window_spec))
    .withColumn("avg_rpm", avg("RPM").over(window_spec))
    .withColumn("gear_change", when(col("nGear") != lag("nGear").over(window_spec), 1).otherwise(0))
    .withColumn("gear_change_count", sum("gear_change").over(window_spec))
    .withColumn(
        "DRS_activation_count",
        sum(
            when(
                (~lag("DRS").over(window_spec).isin(10, 12, 14)) & (col("DRS").isin(10, 12, 14)),
                1
            ).otherwise(0)
        ).over(window_spec.rowsBetween(Window.unboundedPreceding, 0))
    )
)

In [20]:
# Rolling features over last 50 telemetry rows
telemetry_data = (
    telemetry_data
    .withColumn("rolling_throttle_mean", avg("Throttle").over(last_50_window))
    .withColumn("rolling_brake_intensity", avg("Brake").over(last_50_window))
    .withColumn("rolling_gear_change", when(col("nGear") != lag("nGear").over(window_spec), 1).otherwise(0))
    .withColumn("rolling_gear_change_rate", avg("rolling_gear_change").over(last_50_window))
    .withColumn("rolling_speed_mean", avg("Speed").over(last_50_window))
)

In [21]:
# Final sector features (define final 5% of distance per lap)
max_distance = telemetry_data.groupBy("Year", "EventName", "Driver", "LapNumber").agg(max("Distance").alias("max_dist"))
telemetry_data = telemetry_data.join(max_distance, on=["Year", "EventName", "Driver", "LapNumber"])
telemetry_data = telemetry_data.withColumn("in_final_sector", col("Distance") >= col("max_dist") * 0.95)

# Define new window
final_sector_window = Window.partitionBy("Year", "EventName", "Driver", "LapNumber").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

telemetry_data = (
    telemetry_data
    .withColumn("final_sector_avg_speed", avg(when(col("in_final_sector"), col("Speed"))).over(final_sector_window))
    .withColumn("final_sector_throttle", avg(when(col("in_final_sector"), col("Throttle"))).over(final_sector_window))
    .withColumn("final_sector_brake", avg(when(col("in_final_sector"), col("Brake"))).over(final_sector_window))
)

In [22]:
# telemetry_data.show(1)

In [23]:
# Select final per-lap features
lap_feature_cols = [
    "EventName", "Driver", "LapNumber", "Year", "Session",
    "avg_speed_last_lap", "max_speed_last_lap",
    "avg_throttle_last_lap", "avg_brake_last_lap",
    "gear_change_count", "avg_rpm",
    "rolling_throttle_mean", "rolling_brake_intensity",
    "rolling_gear_change_rate", "rolling_speed_mean",
    "final_sector_avg_speed", "final_sector_throttle", 
    "final_sector_brake"
]

# For all columns, take the FIRST value per (Driver, LapNumber)
# Because window functions already populated each row with the same value within each lap
aggregated_laps = (
    telemetry_data
    .select(*lap_feature_cols)
    .groupBy("Year", "EventName", "Session", "Driver", "LapNumber")
    .agg(*[
        first(col_name).alias(col_name) 
        if col_name != "DRS_activation_count" 
        else last(col_name).alias(col_name) 
        for col_name in lap_feature_cols 
        if col_name not in ("Year", "EventName", "Session", "Driver", "LapNumber")
    ])
)

In [24]:
# aggregated_laps.show(5)

### Joining the Data

In [25]:
# Join lap_data and telemetry_data
data = (
    lap_data
    .alias('lap')
    .join(
        aggregated_laps.alias('telemetry')
        ,on=["Year", "EventName", "Session", "Driver", "LapNumber"]
        ,how="outer"
    )
)

In [26]:
# Create target variable
data = (
    data
    .withColumn(
        "WillPitNextLap", when(
            lead("PitInTime", 1).over(Window.partitionBy("Year", "EventName", "Session", "Driver").orderBy("LapNumber")).isNotNull(), 1
        )
    .otherwise(0)
    )
)

# data = data.drop("PitInTime")

In [27]:
# data.show(1)

In [28]:
# data.printSchema()

**Handling Missing Values**

In [29]:
# # Compute null counts
# null_counts = data.select([sum(col(c).isNull().cast("int")).alias(c) for c in data.columns])

# # Convert to a Row to filter in Python
# null_counts_dict = null_counts.first().asDict()

# # Filter and print only columns with nulls
# for col_name, count in null_counts_dict.items():
#     if count > 0:
#         print(f"{col_name}: {count}")

LapTime

In [30]:
# # Check missing values
# data.filter(col("LapTime").isNull()).count()

In [31]:
# # Check rows with missing values
# data.filter(col("LapTime").isNull()).show(5)

In [32]:
# Rows with NULL LapTime and Sector<1, 2, 3>Time are DNF so we drop these rows
data = (
    data
    .filter(~(
        col("Sector1Time").isNull() & 
        col("Sector2Time").isNull() & 
        col("Sector3Time").isNull() & 
        col("LapTime").isNull()
    ))
)

In [33]:
# Remove SectorTime columns
data = data.drop("Sector1Time", "Sector2Time", "Sector3Time")

In [34]:
# # Recheck
# data.filter(col("LapTime").isNull()).count()

In [35]:
# Fix missing values - compute by subtracting the time at the end and at the start of the lap
data = data.withColumn("LapTime", col("LapSessionTime") - col("LapStartTime"))

In [36]:
# # Recheck
# data.filter(col("LapTime").isNull()).count()

Missing values recheck, since DNF rows were removed.

In [37]:
# # Compute null counts
# null_counts = data.select([sum(col(c).isNull().cast("int")).alias(c) for c in data.columns])

# # Convert to a Row to filter in Python
# null_counts_dict = null_counts.first().asDict()

# # Filter and print only columns with nulls
# for col_name, count in null_counts_dict.items():
#     if count > 0:
#         print(f"{col_name}: {count}")

SpeedI1, SpeedI2, SpeedFL, SpeedST

In [38]:
# # Check missing values
# data.filter(
#     col("SpeedI1").isNull() |
#     col("SpeedI2").isNull() |
#     col("SpeedFL").isNull() |
#     col("SpeedST").isNull()
# ).count()

In [39]:
# Fill missing values - speed rolling average
driver_lap_window = Window.partitionBy("Year", "EventName", "Session", "Driver").orderBy("LapNumber").rowsBetween(Window.unboundedPreceding, -1)

# List of columns to process
speed_cols = ["SpeedI1", "SpeedI2", "SpeedFL", "SpeedST"]

# Fill missing values
for col_name in speed_cols:
    cumulative_avg = avg(col(col_name)).over(driver_lap_window)
    data = (
        data
        .withColumn(
            col_name,
            when(col(col_name).isNull(), cumulative_avg).otherwise(col(col_name))
        )
    )

In [40]:
# # Recheck
# data.filter(
#     col("SpeedI1").isNull() |
#     col("SpeedI2").isNull() |
#     col("SpeedFL").isNull() |
#     col("SpeedST").isNull()
# ).count()

In [41]:
# # Recheck
# data.filter(
#     col("SpeedI1").isNull() |
#     col("SpeedI2").isNull() |
#     col("SpeedFL").isNull() |
#     col("SpeedST").isNull()
# ).show(1)

In [42]:
# Fill missing values - teammate's speed in same lap

# Self-join on teammate info
teammate_join = data.alias("self").join(
    data.alias("tm"),
    on=[
        col("self.Year") == col("tm.Year"),
        col("self.EventName") == col("tm.EventName"),
        col("self.Session") == col("tm.Session"),
        col("self.Team") == col("tm.Team"),
        col("self.LapNumber") == col("tm.LapNumber"),
        col("self.Driver") != col("tm.Driver")
    ],
    how="left"
)

# Replace missing values from teammate values
updated_cols = [
    coalesce(col(f"self.{col_name}"), col(f"tm.{col_name}")).alias(col_name)
    if col_name in speed_cols else col(f"self.{col_name}")
    for col_name in data.columns
]

# Assign back to `data` (replacing the original one)
data = teammate_join.select(*updated_cols)

In [43]:
# # Recheck
# data.filter(
#     col("SpeedI1").isNull() |
#     col("SpeedI2").isNull() |
#     col("SpeedFL").isNull() |
#     col("SpeedST").isNull()
# ).count()

In [44]:
# Fill missing values - finish line speed with longest straight speed
data = (
    data
    .withColumn(
        "SpeedFL",
        when(col("SpeedFL").isNull(), col("SpeedST")).otherwise(col("SpeedFL"))
    )
)

In [45]:
# # Recheck
# data.filter(
#     col("SpeedI1").isNull() |
#     col("SpeedI2").isNull() |
#     col("SpeedFL").isNull() |
#     col("SpeedST").isNull()
# ).count()

pit_stop_duration

In [46]:
# # Check missing values
# data.filter(
#     col("pit_stop_duration").isNull()
# ).count()

In [47]:
# # Check missing values
# data.filter(
#     col("pit_stop_duration").isNull()
# ).show(3)

In [48]:
# These values seem to be mistakes, so we set PitOutTime to NULL and recompute pit_stop_duration and max_pit_stop_duration
data = (
    data
    .withColumn(
        "PitOutTime",
        when(col("pit_stop_duration").isNull(), None).otherwise(col("PitOutTime"))
    )
    .withColumn(
        "pit_stop_duration",
        when(
            col("PitOutTime").isNull(),
            lit(0)
        ).otherwise(
            col("PitOutTime") - lag("PitInTime").over(lap_order_window)
        )
    )
    .withColumn("max_pit_stop_duration", max("pit_stop_duration").over(lap_order_window))
)

In [49]:
# # Recheck
# data.filter(
#     col("pit_stop_duration").isNull()
# ).count()

max_pit_stop_duration

In [50]:
# # Check missing values
# data.filter(
#     col("max_pit_stop_duration").isNull()
# ).count()

# # No more :)

In [51]:
# # Compute null counts
# null_counts = data.select([sum(col(c).isNull().cast("int")).alias(c) for c in data.columns])

# # Convert to a Row to filter in Python
# null_counts_dict = null_counts.first().asDict()

# # Filter and print only columns with nulls
# for col_name, count in null_counts_dict.items():
#     if count > 0:
#         print(f"{col_name}: {count}")

## Data Modelling

In [52]:
# Convert categorical features to numerical values
indexers = [
    StringIndexer(inputCol="Team", outputCol="TeamIndex"),
    StringIndexer(inputCol="Compound", outputCol="CompoundIndex"),
    StringIndexer(inputCol="Driver", outputCol="DriverIndex"),
    StringIndexer(inputCol="EventName", outputCol="EventNameIndex")
]

In [53]:
# Assemble all features into a single vector column for the model
feature_columns = [
    # Telemetry / strategy features
    'Year', 'LapNumber', 'Stint', 'TyreLife', 'FreshTyre', 'TrackStatus',
    'Position', 'rolling_avg_laptime', 'laps_since_last_pit',
    'pit_stop_duration', 'max_pit_stop_duration',
    'position_change_since_race_start',
    'avg_speed_last_lap', 'avg_throttle_last_lap', 'avg_brake_last_lap',
    'avg_rpm', 'gear_change_count',
    'rolling_throttle_mean', 'rolling_brake_intensity',
    'rolling_gear_change_rate', 'rolling_speed_mean',
    'final_sector_avg_speed', 'final_sector_throttle', 'final_sector_brake',

    # Categorical features (indexed)
    'TeamIndex', 'CompoundIndex', 'DriverIndex', 'EventNameIndex'
]
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")

In [54]:
# Train test split
train_data = data.filter(~((col("EventName") == "Abu Dhabi Grand Prix") & (col("Year") == 2023)))
test_data = data.filter((col("EventName") == "Abu Dhabi Grand Prix") & (col("Year") == 2023))

In [56]:
# Build the model
classifier = RandomForestClassifier(labelCol="WillPitNextLap", featuresCol="features")

# Define the pipeline
pipeline = Pipeline(stages=indexers + [assembler, classifier])

In [57]:
# Train the model
model = pipeline.fit(train_data)



In [58]:
# Make predictions on the test set
predictions = model.transform(test_data)

In [60]:
predictions.select("features", "WillPitNextLap", "prediction").show()



+--------------------+--------------+----------+
|            features|WillPitNextLap|prediction|
+--------------------+--------------+----------+
|[2023.0,1.0,1.0,2...|             0|       0.0|
|[2023.0,2.0,1.0,3...|             0|       0.0|
|[2023.0,3.0,1.0,4...|             0|       0.0|
|[2023.0,4.0,1.0,5...|             0|       0.0|
|[2023.0,5.0,1.0,6...|             0|       0.0|
|[2023.0,6.0,1.0,7...|             0|       0.0|
|[2023.0,7.0,1.0,8...|             0|       0.0|
|[2023.0,8.0,1.0,9...|             0|       0.0|
|[2023.0,9.0,1.0,1...|             0|       0.0|
|[2023.0,10.0,1.0,...|             0|       0.0|
|[2023.0,11.0,1.0,...|             0|       0.0|
|[2023.0,12.0,1.0,...|             0|       0.0|
|[2023.0,13.0,1.0,...|             0|       0.0|
|[2023.0,14.0,1.0,...|             0|       0.0|
|[2023.0,15.0,1.0,...|             1|       0.0|
|[2023.0,16.0,1.0,...|             0|       0.0|
|[2023.0,17.0,2.0,...|             0|       0.0|
|[2023.0,18.0,2.0,..

                                                                                

In [61]:
# Define the evaluator
evaluator = BinaryClassificationEvaluator(labelCol="WillPitNextLap", metricName="areaUnderPR")

In [62]:
# Evaluate and print the areaUnderPR score
aupr = evaluator.evaluate(predictions)
print(f"areaUnderPR: {aupr}")



areaUnderPR: 0.13370017139227544


                                                                                