# **F1 Pit Stop Prediction**

This project uses data from the **FastF1** library to predict whether a driver will pit on the next lap.<br>
To collect the data, the team developed a `data_collection.py` script, which gathers lap, telemetry and weather data.

The metadata for the datasets used in this project is available in the **FastF1** documentation at https://docs.fastf1.dev/core.html (visited on May 10th, 2025).

## 0. Imports, Constants and Helper Functions

**Libraries**

In [1]:
# Import libraries
import os
from pathlib import Path
import math

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.collections import LineCollection
from matplotlib.lines import Line2D
import seaborn as sns

from pyspark.sql import SparkSession
from pyspark.sql.functions import input_file_name, regexp_extract, regexp_replace, col, when, to_timestamp, lead, avg, lag, max, first, last, split, coalesce, lit, row_number, sum
from pyspark.sql.types import IntegerType, BooleanType, FloatType
from pyspark.sql.window import Window

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.evaluation import BinaryClassificationEvaluator

**Constants**

In [2]:
# Define constants

# Directories
ROOT = Path.cwd().parent
DATA_DIR = ROOT / "data"

# Colors
BLUE = '#003773'
RED = '#ED1E36'
YELLOW = '#FABB23'
TEAM_COLORS = {
    'Red Bull Racing': BLUE
    ,'Mercedes': '#00A19B'
    ,'Ferrari': '#A6051A'
    ,'McLaren': '#FF8700'
    ,'Aston Martin': '#00665E'
    ,'Alpine': '#FD4BC7'
    ,'AlphaTauri': '#00293F'
    ,'Alfa Romeo': '#972738'
    ,'Haas F1 Team': '#AEAEAE'
    ,'Williams': '#00A3E0'
}

# Colormap
CMAP = (
    mcolors
    .LinearSegmentedColormap
    .from_list(
        "my_cmap"
        ,[BLUE, YELLOW, RED]
    )
)

COMPOUND_COLORS = {
    'SOFT': '#F20704'
    ,'MEDIUM': '#FACA08'
    ,'HARD': '#000000'
    ,'INTERMEDIATE': '#029405'
    ,'WET': '#078CD1'
}

# Plot constants
EVENT_PLT = 'British Grand Prix'
YEAR_PLT = 2022
DRIVER_PLT = 'VER'

# Features
LAP_METRIC = [
    'LapTime'
    ,'LapNumber'
    ,'Stint'
    ,'PitOutTime'
    ,'PitInTime'
    ,'Sector1Time'
    ,'Sector2Time'
    ,'Sector3Time'
    ,'SpeedI1'
    ,'SpeedI2'
    ,'SpeedFL'
    ,'SpeedST'
    ,'TyreLife'
    ,'LapStartTime'
    ,'LapSessionTime'
]
LAP_CATEGORICAL = [
    'Driver'
    ,'DriverNumber'
    ,'IsPersonalBest'
    ,'Compound'
    ,'FreshTyre'
    ,'Team'
    ,'TrackStatus'
    ,'Position'
    ,'Deleted'
    ,'FastF1Generated'
    ,'IsAccurate'
    ,'Year'
    ,'EventName'
    ,'Session'
]
TELEMETRY_METRIC = [
    'RPM'
    ,'Speed'
    ,'Throttle'
    ,'SessionTime'
    ,'Distance'
    ,'LapNumber'
]
TELEMETRY_CATEGORICAL = [
    'nGear'
    ,'Brake'
    ,'DRS'
    ,'Driver'
    ,'Year'
    ,'EventName'
    ,'Session'
    ,'IsDRSActive'
]
WEATHER_METRIC = [
    'Time'
    ,'AirTemp'
    ,'Humidity'
    ,'Pressure'
    ,'TrackTemp'
    ,'WindSpeed'
]
WEATHER_CATEGORICAL = [
    'Rainfall'
    ,'WindDirection'
    ,'Year'
    ,'EventName'
    ,'Session'
]

**Helper Functions**

In [3]:
# Define helper function for number of bins in histogram
def sturges_bins(data, column_name):
    """
    Calculates the number of bins for a histogram using Sturges' Rule.
    
    Parameters:
    data (pd.DataFrame): The dataset containing the column.
    column_name (str): The name of the column to calculate the number of bins for.
    
    Returns:
    int: The number of bins calculated using Sturges' Rule.
    """
    # Number of data points
    n = len(data[column_name])
    
    # Calculate the number of bins using Sturges' Rule
    k = int(np.ceil(np.log2(n) + 1))
    
    return k
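
As a quick check of the rule, Sturges' formula is k = ceil(log2(n) + 1). The sketch below recomputes it from a row count using only the standard library. The name `sturges_bins_from_count` is hypothetical and is not part of the notebook.

```python
import math

def sturges_bins_from_count(n):
    """Number of histogram bins for n observations using Sturges' Rule."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    return int(math.ceil(math.log2(n) + 1))

# log2(1000) is about 9.97, so 1000 observations map to 11 bins
bins_for_1000 = sturges_bins_from_count(1000)
```

The bin count grows only logarithmically with the number of rows, so very large columns still get a modest number of bins.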

## 1. Data Loading

In this section, the Spark session is initialized and the data is ingested.

In [4]:
# Initialize Spark Session
spark = SparkSession.builder \
    .appName("Lap Data Aggregation") \
    .master("local[*]") \
    .config("spark.driver.memory", "12g") \
    .config("spark.executor.memory", "12g") \
    .getOrCreate()

`TODO`: Load the data from Databricks.

To upload a file through the Databricks UI:
- Go to the Data tab on the left sidebar of your Databricks workspace.
- Click "Add Data", then choose "Upload File".
- Upload your file (e.g., CSV, JSON, Parquet).
- Databricks will store it at a path such as `/FileStore/tables/your_filename.csv`.

You can then access it like this:

    df = spark.read.option("header", "true").csv("/FileStore/tables/your_filename.csv")
    df.show()


### 1.1. Lap Data

In [5]:
# Load the data
lap_data = (
    spark.read
    .option("header", True)
    .option("inferSchema", "true")
    .csv(f"{DATA_DIR}/laps.csv")
)

#### 1.1.1. Feature Analysis

After analysing the available metadata, features that add no value for predicting pit stops are removed from the dataframe.

In [6]:
# Dropping irrelevant columns
lap_data = lap_data.drop(
    "Sector1SessionTime"
    ,"Sector2SessionTime"
    ,"Sector3SessionTime"
    ,"LapStartDate"
    ,"DeletedReason"
)

#### 1.1.2. Fixing Datatypes

In [7]:
# Check datatypes
lap_data.printSchema()

root
 |-- Time: string (nullable = true)
 |-- Driver: string (nullable = true)
 |-- DriverNumber: integer (nullable = true)
 |-- LapTime: string (nullable = true)
 |-- LapNumber: double (nullable = true)
 |-- Stint: double (nullable = true)
 |-- PitOutTime: string (nullable = true)
 |-- PitInTime: string (nullable = true)
 |-- Sector1Time: string (nullable = true)
 |-- Sector2Time: string (nullable = true)
 |-- Sector3Time: string (nullable = true)
 |-- SpeedI1: double (nullable = true)
 |-- SpeedI2: double (nullable = true)
 |-- SpeedFL: double (nullable = true)
 |-- SpeedST: double (nullable = true)
 |-- IsPersonalBest: boolean (nullable = true)
 |-- Compound: string (nullable = true)
 |-- TyreLife: double (nullable = true)
 |-- FreshTyre: boolean (nullable = true)
 |-- Team: string (nullable = true)
 |-- LapStartTime: string (nullable = true)
 |-- TrackStatus: integer (nullable = true)
 |-- Position: double (nullable = true)
 |-- Deleted: boolean (nullable = true)
 |-- FastF1Generat

In [8]:
# Fix datatypes
lap_data = (
    lap_data
    .withColumn("LapSessionTime", split(regexp_replace(col("Time"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("Time"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("Time"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("LapTime", split(regexp_replace(col("LapTime"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("LapTime"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("LapTime"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("LapNumber", col("LapNumber").cast(IntegerType()))
    .withColumn("Stint", col("Stint").cast(IntegerType()))
    .withColumn("PitOutTime", split(regexp_replace(col("PitOutTime"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("PitOutTime"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("PitOutTime"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("PitInTime", split(regexp_replace(col("PitInTime"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("PitInTime"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("PitInTime"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("Sector1Time", split(regexp_replace(col("Sector1Time"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("Sector1Time"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("Sector1Time"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("Sector2Time", split(regexp_replace(col("Sector2Time"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("Sector2Time"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("Sector2Time"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("Sector3Time", split(regexp_replace(col("Sector3Time"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("Sector3Time"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("Sector3Time"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("SpeedI1", col("SpeedI1").cast(IntegerType()))
    .withColumn("SpeedI2", col("SpeedI2").cast(IntegerType()))
    .withColumn("SpeedFL", col("SpeedFL").cast(IntegerType()))
    .withColumn("SpeedST", col("SpeedST").cast(IntegerType()))
    .withColumn("TyreLife", col("TyreLife").cast(IntegerType()))
    .withColumn("LapStartTime", split(regexp_replace(col("LapStartTime"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("LapStartTime"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("LapStartTime"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("Position", col("Position").cast(IntegerType()))
    .withColumn("Year", col("Year").cast(IntegerType()))
)

lap_data = lap_data.drop(col("Time"))
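
The repeated column expression above turns a timedelta string into seconds as hours * 3600 + minutes * 60 + seconds. The plain-Python sketch below mirrors that arithmetic so it can be checked by hand. It assumes the CSV stores values in the pandas form `0 days HH:MM:SS.ffffff`. The function name is hypothetical.

```python
import re

def timedelta_string_to_seconds(value):
    """Convert a string such as '0 days 01:02:34.971000' to seconds."""
    if value is None:
        return None
    # Strip the leading "0 days " prefix, as the Spark expression does
    clock = re.sub(r"^0 days ", "", value)
    hours, minutes, seconds = clock.split(":")
    return int(hours) * 3600 + int(minutes) * 60 + float(seconds)

# 1 h + 2 min + 34.971 s
example_seconds = timedelta_string_to_seconds("0 days 01:02:34.971000")
```

Only the `0 days ` prefix is stripped, so a value starting with `1 days ` would not parse: this sketch raises a `ValueError`, whereas the Spark casts would yield a null.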

In [9]:
# Show the result
lap_data.show(1)

+------+------------+------------------+---------+-----+----------+---------+-----------+-----------+-----------+-------+-------+-------+-------+--------------+--------+--------+---------+---------------+------------+-----------+--------+-------+---------------+----------+----+------------------+-------+--------------+
|Driver|DriverNumber|           LapTime|LapNumber|Stint|PitOutTime|PitInTime|Sector1Time|Sector2Time|Sector3Time|SpeedI1|SpeedI2|SpeedFL|SpeedST|IsPersonalBest|Compound|TyreLife|FreshTyre|           Team|LapStartTime|TrackStatus|Position|Deleted|FastF1Generated|IsAccurate|Year|         EventName|Session|LapSessionTime|
+------+------------+------------------+---------+-----+----------+---------+-----------+-----------+-----------+-------+-------+-------+-------+--------------+--------+--------+---------+---------------+------------+-----------+--------+-------+---------------+----------+----+------------------+-------+--------------+
|   VER|           1|100.235999999999

### 1.2. Telemetry Data

In [10]:
telemetry_data = (
    spark.read
    .option("header", True)
    .option("inferSchema", "true")
    .csv(f"{DATA_DIR}/telemetry.csv")
)

#### 1.2.1. Feature Analysis

After analysing the available metadata, features that add no value for predicting pit stops are removed from the dataframe.

In [11]:
telemetry_data = telemetry_data.drop(
    "Date"
    ,"DataCollectionTime"
    ,"Time"
    ,"Source"
)

#### 1.2.2. Fixing Datatypes

In [12]:
# Check datatypes
telemetry_data.printSchema()

root
 |-- RPM: double (nullable = true)
 |-- Speed: double (nullable = true)
 |-- nGear: integer (nullable = true)
 |-- Throttle: double (nullable = true)
 |-- Brake: boolean (nullable = true)
 |-- DRS: integer (nullable = true)
 |-- SessionTime: string (nullable = true)
 |-- Distance: double (nullable = true)
 |-- Driver: string (nullable = true)
 |-- LapNumber: double (nullable = true)
 |-- Year: integer (nullable = true)
 |-- EventName: string (nullable = true)
 |-- Session: string (nullable = true)



In [13]:
telemetry_data = (
    telemetry_data
    .withColumn("RPM", col("RPM").cast(IntegerType()))
    .withColumn("Speed", col("Speed").cast(IntegerType()))
    .withColumn("Throttle", col("Throttle").cast(IntegerType()))
    .withColumn("Brake", col("Brake").cast(BooleanType()).cast(IntegerType()))
    .withColumn(
        "SessionTime", split(regexp_replace(col("SessionTime"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("SessionTime"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("SessionTime"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("LapNumber", col("LapNumber").cast(IntegerType()))
    .withColumn("Year", col("Year").cast(IntegerType()))
    .withColumn(
        "IsDRSActive", when(
            col("DRS").isin(10, 12, 14), 1
        ).otherwise(0)
    )
)

In [14]:
# Show the result
telemetry_data.show(1)

+----+-----+-----+--------+-----+---+-----------+--------+------+---------+----+------------------+-------+-----------+
| RPM|Speed|nGear|Throttle|Brake|DRS|SessionTime|Distance|Driver|LapNumber|Year|         EventName|Session|IsDRSActive|
+----+-----+-----+--------+-----+---+-----------+--------+------+---------+----+------------------+-------+-----------+
|9802|    0|    1|      10|    1|  1|   3754.971|     0.0|   VER|        1|2022|Bahrain Grand Prix|      R|          0|
+----+-----+-----+--------+-----+---+-----------+--------+------+---------+----+------------------+-------+-----------+
only showing top 1 row



### 1.3. Weather Data

In [15]:
weather_data = (
    spark.read
    .option("header", True)
    .option("inferSchema", "true")
    .csv(f"{DATA_DIR}/weather.csv")
)

#### 1.3.1. Fixing Datatypes

In [16]:
# Check datatypes
weather_data.printSchema()

root
 |-- Time: string (nullable = true)
 |-- AirTemp: double (nullable = true)
 |-- Humidity: double (nullable = true)
 |-- Pressure: double (nullable = true)
 |-- Rainfall: boolean (nullable = true)
 |-- TrackTemp: double (nullable = true)
 |-- WindDirection: integer (nullable = true)
 |-- WindSpeed: double (nullable = true)
 |-- Year: integer (nullable = true)
 |-- EventName: string (nullable = true)
 |-- Session: string (nullable = true)



In [17]:
weather_data = (
    weather_data
    .withColumn(
        "Time", split(regexp_replace(col("Time"), r"^0 days ", ""), ":").getItem(0).cast("int") * 3600 +
        split(regexp_replace(col("Time"), r"^0 days ", ""), ":").getItem(1).cast("int") * 60 +
        split(regexp_replace(col("Time"), r"^0 days ", ""), ":").getItem(2).cast("double")
    )
    .withColumn("Rainfall", col("Rainfall").cast(IntegerType()))
)

In [18]:
# Show the result
weather_data.show(1)

+------+-------+--------+--------+--------+---------+-------------+---------+----+------------------+-------+
|  Time|AirTemp|Humidity|Pressure|Rainfall|TrackTemp|WindDirection|WindSpeed|Year|         EventName|Session|
+------+-------+--------+--------+--------+---------+-------------+---------+----+------------------+-------+
|63.204|   25.6|    17.0|  1010.2|       0|     32.3|          346|      0.5|2022|Bahrain Grand Prix|      R|
+------+-------+--------+--------+--------+---------+-------------+---------+----+------------------+-------+
only showing top 1 row



## 2. Data Visualization

### 2.1. Lap Data

In [19]:
# Convert the dataframe to pandas for visualization purposes
lap_pd = lap_data.toPandas()

#### 2.1.1. Histograms

#### 2.1.2. Boxplots

#### 2.1.3. Other Plots

This plot shows that most drivers pit when the track status changes. During the third track status change, all but two of the drivers still racing decided to pit.

This plot suggests that the first couple of laps on a new set of tyres are usually much slower than the laps that follow.

### 2.2. Telemetry Data

In [20]:
# Filter to a single event, since the full dataset is too large, and convert to pandas for visualization purposes
telemetry_pd = (
    telemetry_data
    .filter(
        (col("EventName") == EVENT_PLT) &
        (col("Year") == YEAR_PLT)
    )
    .toPandas()
)

#### 2.2.1. Histograms

#### 2.2.2. Boxplots

#### 2.2.3. Other Plots

### 2.3. Weather Data

In [21]:
# Filter to a single event, since the full dataset is too large, and convert to pandas for visualization purposes
weather_pd = (
    weather_data
    .filter(
        (col("EventName") == EVENT_PLT) &
        (col("Year") == YEAR_PLT)
    )
    .toPandas()
)

#### 2.3.1. Histograms

#### 2.3.2. Boxplots

#### 2.3.3. Other Plots

## 3. Data Processing

In this section, inconsistencies and missing values are handled.

### 3.1. Inconsistencies

In [22]:
# Checking laps with no time information
lap_data.filter((
    col("Sector1Time").isNull() & 
    col("Sector2Time").isNull() & 
    col("Sector3Time").isNull()
)).show(5)

+------+------------+-------+---------+-----+----------+---------+-----------+-----------+-----------+-------+-------+-------+-------+--------------+--------+--------+---------+---------------+------------+-----------+--------+-------+---------------+----------+----+--------------------+-------+--------------+
|Driver|DriverNumber|LapTime|LapNumber|Stint|PitOutTime|PitInTime|Sector1Time|Sector2Time|Sector3Time|SpeedI1|SpeedI2|SpeedFL|SpeedST|IsPersonalBest|Compound|TyreLife|FreshTyre|           Team|LapStartTime|TrackStatus|Position|Deleted|FastF1Generated|IsAccurate|Year|           EventName|Session|LapSessionTime|
+------+------------+-------+---------+-----+----------+---------+-----------+-----------+-----------+-------+-------+-------+-------+--------------+--------+--------+---------+---------------+------------+-----------+--------+-------+---------------+----------+----+--------------------+-------+--------------+
|   GAS|          10|   NULL|       45|    3|      NULL|     NUL

In [23]:
# Flag laps where the driver DNFs on the following lap
# (the next lap has no sector times at all)
window_spec = Window.partitionBy("Year", "EventName", "Session", "Driver").orderBy("LapNumber")

lap_data = lap_data.withColumn(
    'DNF',
    when(
        # lead(..., 1) looks one lap ahead; equivalent to the previous lag(..., -1)
        lead(
            when(
                (col("Sector1Time").isNull() &
                 col("Sector2Time").isNull() &
                 col("Sector3Time").isNull()),
                1
            ).otherwise(0),
            1
        ).over(window_spec) == 1,
        1
    ).otherwise(0)
)
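
The look-ahead logic can be sketched in plain Python for a single driver: a lap is flagged when the next lap exists and has no sector times. The function name and the tuple layout are assumptions made for illustration.

```python
def flag_dnf_next_lap(laps):
    """laps: list of (sector1, sector2, sector3) tuples in lap order for one driver."""
    flags = []
    for i in range(len(laps)):
        # Look one lap ahead; the final lap has no successor
        next_lap = laps[i + 1] if i + 1 < len(laps) else None
        next_has_no_times = next_lap is not None and all(t is None for t in next_lap)
        flags.append(1 if next_has_no_times else 0)
    return flags

laps = [(30.1, 40.2, 29.9), (30.0, 40.5, 30.2), (None, None, None)]
dnf_flags = flag_dnf_next_lap(laps)
```

The flag lands on the last lap that still has timing data, which is why the empty laps themselves can be removed afterwards without losing the DNF information.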

In [24]:
# Remove the laps with no sector times
lap_data = (
    lap_data
    .filter(~(
        col("Sector1Time").isNull() & 
        col("Sector2Time").isNull() & 
        col("Sector3Time").isNull()
    ))
)

### 3.2. Missing Values

#### 3.2.1. Lap Data

In [25]:
# Compute null counts
null_counts = lap_data.select([sum(col(c).isNull().cast("int")).alias(c) for c in lap_data.columns])

# Convert to a Row to filter in Python
null_counts_dict = null_counts.first().asDict()

# Filter and print only columns with nulls
for col_name, count in null_counts_dict.items():
    if count > 0:
        print(f"{col_name}: {count}")

LapTime: 1074
PitOutTime: 71889
PitInTime: 71860
Sector1Time: 1465
Sector2Time: 5
Sector3Time: 115
SpeedI1: 11226
SpeedI2: 26
SpeedFL: 2629
SpeedST: 5998


**Lap Time**

In [26]:
# Check missing values
lap_data.filter(col("LapTime").isNull()).count()

1074

In [27]:
# Fix missing values only - lap end time minus lap start time; existing lap times are kept
lap_data = lap_data.withColumn("LapTime", coalesce(col("LapTime"), col("LapSessionTime") - col("LapStartTime")))

In [28]:
# Recheck
lap_data.filter(col("LapTime").isNull()).count()

0

**PitOutTime**, **PitInTime** - These are only populated on laps where the driver leaves or enters the pits, so filling the missing values would not make sense; both features are used later for feature engineering.

**Sector1Time**, **Sector2Time**, **Sector3Time** - These features are only used for feature engineering, so there is no need to fill them.

**SpeedI1**, **SpeedI2**, **SpeedFL**, **SpeedST**

In [29]:
# Check missing values
lap_data.filter(
    col("SpeedI1").isNull() |
    col("SpeedI2").isNull() |
    col("SpeedFL").isNull() |
    col("SpeedST").isNull()
).count()

18057

In [30]:
# Fill missing values - cumulative average of the driver's speed over the previous laps
driver_lap_window = (
    Window
    .partitionBy("Year", "EventName", "Session", "Driver")
    .orderBy("LapNumber")
    .rowsBetween(Window.unboundedPreceding, -1)
)

# List of columns to process
speed_cols = ["SpeedI1", "SpeedI2", "SpeedFL", "SpeedST"]

# Fill missing values
for col_name in speed_cols:
    cumulative_avg = avg(col(col_name)).over(driver_lap_window)
    lap_data = (
        lap_data
        .withColumn(
            col_name,
            when(col(col_name).isNull(), cumulative_avg).otherwise(col(col_name))
        )
    )
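
A plain-Python sketch of this fill for one driver and one speed column: each missing value is replaced by the mean of the values observed on earlier laps. As in the window above, the mean is taken over the original values, and a gap with no earlier observation stays missing, which is consistent with some nulls remaining after this step. The function name is hypothetical.

```python
def fill_with_prior_mean(values):
    """Replace None with the mean of the non-missing values seen on earlier laps."""
    filled = []
    total, count = 0.0, 0
    for value in values:
        if value is None:
            # No earlier observation: nothing to average, so the gap remains
            filled.append(total / count if count else None)
        else:
            filled.append(value)
            total += value
            count += 1
    return filled

speeds = [None, 300, 310, None, 320]
filled_speeds = fill_with_prior_mean(speeds)
```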

In [31]:
# Recheck
lap_data.filter(
    col("SpeedI1").isNull() |
    col("SpeedI2").isNull() |
    col("SpeedFL").isNull() |
    col("SpeedST").isNull()
).count()

143

In [32]:
# Fill missing values - teammate's speed in same lap

# Self-join on teammate info
teammate_join = lap_data.alias("self").join(
    lap_data.alias("tm"),
    on=[
        col("self.Year") == col("tm.Year"),
        col("self.EventName") == col("tm.EventName"),
        col("self.Session") == col("tm.Session"),
        col("self.Team") == col("tm.Team"),
        col("self.LapNumber") == col("tm.LapNumber"),
        col("self.Driver") != col("tm.Driver")
    ],
    how="left"
)

# Replace missing values from teammate values
updated_cols = [
    coalesce(col(f"self.{col_name}"), col(f"tm.{col_name}")).alias(col_name)
    if col_name in speed_cols else col(f"self.{col_name}")
    for col_name in lap_data.columns
]

lap_data = teammate_join.select(*updated_cols)

In [33]:
# Recheck
lap_data.filter(
    col("SpeedI1").isNull() |
    col("SpeedI2").isNull() |
    col("SpeedFL").isNull() |
    col("SpeedST").isNull()
).count()

92

In [34]:
# Fill missing values - finish line speed with longest straight speed
lap_data = (
    lap_data
    .withColumn(
        "SpeedFL",
        when(col("SpeedFL").isNull(), col("SpeedST")).otherwise(col("SpeedFL"))
    )
)

In [35]:
# Recheck
lap_data.filter(
    col("SpeedI1").isNull() |
    col("SpeedI2").isNull() |
    col("SpeedFL").isNull() |
    col("SpeedST").isNull()
).count()

0

#### 3.2.2. Telemetry Data

In [36]:
# Compute null counts
null_counts = telemetry_data.select([sum(col(c).isNull().cast("int")).alias(c) for c in telemetry_data.columns])

# Convert to a Row to filter in Python
null_counts_dict = null_counts.first().asDict()

# Filter and print only columns with nulls
for col_name, count in null_counts_dict.items():
    if count > 0:
        print(f"{col_name}: {count}")

No missing values :)

#### 3.2.3. Weather Data

In [37]:
# Compute null counts
null_counts = weather_data.select([sum(col(c).isNull().cast("int")).alias(c) for c in weather_data.columns])

# Convert to a Row to filter in Python
null_counts_dict = null_counts.first().asDict()

# Filter and print only columns with nulls
for col_name, count in null_counts_dict.items():
    if count > 0:
        print(f"{col_name}: {count}")

No missing values :)

## 4. Data Engineering

### 4.1. Lap Data

In [38]:
# Define windows
start_position_window = Window.partitionBy("Year", "EventName", "Session", "Driver")
lap_order_window = start_position_window.orderBy("LapNumber")

In [39]:
# Creating new features
lap_data = (
    lap_data
    .withColumn("rolling_avg_laptime", avg("LapTime").over(lap_order_window.rowsBetween(Window.unboundedPreceding, 0)))
    .withColumn("pit_in_lap", when(col("PitInTime").isNotNull(), 1).otherwise(0))
    .withColumn("pit_exit_lap", when(col("PitOutTime").isNotNull(), 1).otherwise(0))
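    .withColumn(
        "last_pit_lap",
        coalesce(
            # Lap number of the most recent pit exit (not the 0/1 pit_exit_lap flag)
            max(when(col("pit_exit_lap") == 1, col("LapNumber"))).over(lap_order_window.rowsBetween(Window.unboundedPreceding, 0)),
            lit(0)
        )
    )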
    .withColumn("laps_since_last_pit", col("LapNumber") - col("last_pit_lap"))
    .withColumn(
        "prev_compound", 
        when(
            col("LapNumber") == 1, col("Compound")
        ).otherwise(
            lag("Compound").over(lap_order_window)
        )
    )
    .withColumn(
        "pit_stop_duration",
        when(
            (col("PitOutTime").isNull()) | ((col("PitOutTime").isNotNull()) & (col("LapNumber") == 1)),
            lit(0)
        ).otherwise(
            col("PitOutTime") - lag("PitInTime").over(lap_order_window)
        )
    )
    .withColumn("max_pit_stop_duration", max("pit_stop_duration").over(lap_order_window))
    .withColumn("start_position", first(when(col("LapNumber") == 1, col("Position")), ignorenulls=True).over(start_position_window))
    .withColumn("position_change_since_race_start", col("start_position") - col("Position"))
    .withColumn(
        "fastest_sector", when(
            (col("Sector1Time") <= col("Sector2Time")) & (col("Sector1Time") <= col("Sector3Time")), 1
        ).when(
            (col("Sector2Time") <= col("Sector1Time")) & (col("Sector2Time") <= col("Sector3Time")), 2
        ).otherwise(3)
    )
)
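
The `pit_stop_duration` feature pairs the pit exit time of a lap with the pit entry time of the previous lap. A plain-Python sketch for one driver follows. The dictionary layout and the function name are assumptions made for illustration, and times are in seconds.

```python
def pit_stop_durations(laps):
    """laps: list of dicts with LapNumber, PitInTime, PitOutTime, in lap order."""
    durations = []
    previous_pit_in = None
    for lap in laps:
        pit_out = lap["PitOutTime"]
        if pit_out is None or lap["LapNumber"] == 1:
            # No pit exit on this lap, or the car is only leaving for the start
            durations.append(0)
        elif previous_pit_in is None:
            # Pit exit without a recorded entry on the previous lap
            durations.append(None)
        else:
            durations.append(pit_out - previous_pit_in)
        previous_pit_in = lap["PitInTime"]
    return durations

laps = [
    {"LapNumber": 1, "PitInTime": None, "PitOutTime": 100.0},
    {"LapNumber": 2, "PitInTime": 250.0, "PitOutTime": None},
    {"LapNumber": 3, "PitInTime": None, "PitOutTime": 275.5},
    {"LapNumber": 4, "PitInTime": None, "PitOutTime": None},
]
durations = pit_stop_durations(laps)
```

The duration is attributed to the out-lap (lap 3 in the example), not to the lap on which the driver entered the pits.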

In [40]:
lap_data = lap_data.drop(
    "Sector1Time", "Sector2Time", "Sector3Time", "PitOutTime"
)

In [41]:
lap_data.show(1)

+------+------------+-----------------+---------+-----+---------+-------+-------+-------+-------+--------------+--------+--------+---------+------+------------+-----------+--------+-------+---------------+----------+----+--------------------+-------+--------------+---+-------------------+----------+------------+------------+-------------------+-------------+-----------------+---------------------+--------------+--------------------------------+--------------+
|Driver|DriverNumber|          LapTime|LapNumber|Stint|PitInTime|SpeedI1|SpeedI2|SpeedFL|SpeedST|IsPersonalBest|Compound|TyreLife|FreshTyre|  Team|LapStartTime|TrackStatus|Position|Deleted|FastF1Generated|IsAccurate|Year|           EventName|Session|LapSessionTime|DNF|rolling_avg_laptime|pit_in_lap|pit_exit_lap|last_pit_lap|laps_since_last_pit|prev_compound|pit_stop_duration|max_pit_stop_duration|start_position|position_change_since_race_start|fastest_sector|
+------+------------+-----------------+---------+-----+---------+-------

### 4.2. Telemetry Data

In [42]:
# Define window
window_spec = Window.partitionBy("Year", "EventName", "Driver", "LapNumber").orderBy("SessionTime")
last_50_window = window_spec.rowsBetween(-49, 0)

In [43]:
# Compute per-lap aggregates
# Unordered window over the whole lap, so each aggregate is identical on every row of that lap
lap_window = Window.partitionBy("Year", "EventName", "Driver", "LapNumber")

telemetry_data = (
    telemetry_data
    .withColumn("avg_speed_last_lap", avg("Speed").over(lap_window))
    .withColumn("max_speed_last_lap", max("Speed").over(lap_window))
    .withColumn("avg_throttle_last_lap", avg("Throttle").over(lap_window))
    .withColumn("avg_brake_last_lap", avg("Brake").over(lap_window))
    .withColumn("avg_rpm", avg("RPM").over(lap_window))
    # Gear changes need the sample order, so the ordered window is used for lag
    .withColumn("gear_change", when(col("nGear") != lag("nGear").over(window_spec), 1).otherwise(0))
    .withColumn("gear_change_count", sum("gear_change").over(lap_window))
    .withColumn(
        "DRS_activation_count",
        sum(
            when(
                (~lag("DRS").over(window_spec).isin(10, 12, 14)) & (col("DRS").isin(10, 12, 14)),
                1
            ).otherwise(0)
        ).over(window_spec.rowsBetween(Window.unboundedPreceding, 0))
    )
    .withColumn("rolling_throttle_mean", avg("Throttle").over(last_50_window))
    .withColumn("rolling_brake_intensity", avg("Brake").over(last_50_window))
    .withColumn("rolling_gear_change", when(col("nGear") != lag("nGear").over(window_spec), 1).otherwise(0))
    .withColumn("rolling_gear_change_rate", avg("rolling_gear_change").over(last_50_window))
    .withColumn("rolling_speed_mean", avg("Speed").over(last_50_window))
)
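
`DRS_activation_count` counts transitions from an inactive DRS code to an active one (10, 12 or 14, as used in the notebook). The sketch below reproduces that edge counting for the ordered samples of one lap in plain Python. The function and constant names are hypothetical.

```python
ACTIVE_DRS_CODES = {10, 12, 14}

def count_drs_activations(drs_samples):
    """Count inactive -> active transitions in an ordered list of DRS codes."""
    count = 0
    previous = None
    for code in drs_samples:
        # The first sample has no predecessor, so it never counts as an activation
        if previous is not None and previous not in ACTIVE_DRS_CODES and code in ACTIVE_DRS_CODES:
            count += 1
        previous = code
    return count

activations = count_drs_activations([1, 8, 10, 12, 12, 8, 14, 14, 1])
```

Consecutive active codes (for example 10 followed by 12) count as a single activation, and a lap that starts with DRS already open does not register one.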

In [44]:
# Final sector features (define final 5% of distance per lap)
max_distance = telemetry_data.groupBy("Year", "EventName", "Driver", "LapNumber").agg(max("Distance").alias("max_dist"))
telemetry_data = telemetry_data.join(max_distance, on=["Year", "EventName", "Driver", "LapNumber"])
telemetry_data = telemetry_data.withColumn("in_final_sector", col("Distance") >= col("max_dist") * 0.95)

# Define new window
final_sector_window = Window.partitionBy("Year", "EventName", "Driver", "LapNumber").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

telemetry_data = (
    telemetry_data
    .withColumn("final_sector_avg_speed", avg(when(col("in_final_sector"), col("Speed"))).over(final_sector_window))
    .withColumn("final_sector_throttle", avg(when(col("in_final_sector"), col("Throttle"))).over(final_sector_window))
    .withColumn("final_sector_brake", avg(when(col("in_final_sector"), col("Brake"))).over(final_sector_window))
)
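
Note that `in_final_sector` here means the final 5% of the lap distance, not timing sector 3. A plain-Python sketch of the same selection and average for one lap is shown below. The function name and the `(distance, speed)` layout are assumptions made for illustration.

```python
def final_stretch_mean_speed(samples, fraction=0.95):
    """samples: list of (distance, speed) pairs for one lap."""
    max_distance = max(distance for distance, _ in samples)
    threshold = max_distance * fraction
    # Keep only the samples in the last part of the lap
    speeds = [speed for distance, speed in samples if distance >= threshold]
    return sum(speeds) / len(speeds)

samples = [(0, 100), (2500, 200), (4800, 280), (5000, 300)]
mean_speed = final_stretch_mean_speed(samples)
```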

In [45]:
# telemetry_data.show(1)

In [46]:
# Select final per-lap features
lap_feature_cols = [
    "EventName", "Driver", "LapNumber", "Year", "Session",
    "avg_speed_last_lap", "max_speed_last_lap",
    "avg_throttle_last_lap", "avg_brake_last_lap",
    "gear_change_count", "avg_rpm", "DRS_activation_count",
    "rolling_throttle_mean", "rolling_brake_intensity",
    "rolling_gear_change_rate", "rolling_speed_mean",
    "final_sector_avg_speed", "final_sector_throttle", 
    "final_sector_brake"
]

# Collapse to one row per lap: take the first value of each feature,
# except DRS_activation_count, a running count, whose last value is taken.
# Note: first/last after a groupBy do not guarantee row order in Spark.
aggregated_laps = (
    telemetry_data
    .select(*lap_feature_cols)
    .groupBy("Year", "EventName", "Session", "Driver", "LapNumber")
    .agg(*[
        first(col_name).alias(col_name) 
        if col_name != "DRS_activation_count" 
        else last(col_name).alias(col_name) 
        for col_name in lap_feature_cols 
        if col_name not in ("Year", "EventName", "Session", "Driver", "LapNumber")
    ])
)

### 4.3. Weather Data

In [47]:

# Join weather and lap data on session keys
joined = weather_data.alias("w").join(
    lap_data.select(
        "LapNumber", "LapStartTime", "LapSessionTime", "Year", "EventName", "Session"
    ).alias("l"),
    on=[
        col("w.Year") == col("l.Year"),
        col("w.EventName") == col("l.EventName"),
        col("w.Session") == col("l.Session")
    ],
    how="left"
)

# Filter to keep only rows where weather Time is inside lap range
filtered = joined.filter(
    (col("w.Time") >= col("l.LapStartTime")) &
    (col("w.Time") <= col("l.LapSessionTime"))
)

# Keep the weather columns and attach the matching lap number
weather_data = filtered.select(
    col("w.*"), 
    col("l.LapNumber")
)
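
The join and filter above assign each weather sample to every lap whose time range contains it. A plain-Python sketch of that interval match follows. The function name and tuple layout are assumptions, and times are session seconds.

```python
def assign_weather_to_laps(weather_times, laps):
    """laps: list of (lap_number, start_time, end_time) tuples."""
    matches = []
    for sample_time in weather_times:
        for lap_number, start, end in laps:
            # Inclusive on both ends, like the Spark filter
            if start <= sample_time <= end:
                matches.append((sample_time, lap_number))
    return matches

laps = [(1, 50.0, 150.0), (2, 150.0, 250.0)]
matches = assign_weather_to_laps([60.0, 120.0, 200.0], laps)
```

Because both bounds are inclusive, a sample that falls exactly on a lap boundary matches both laps, and a lap with no sample inside its range gets no weather row at all.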

In [48]:
# Define window
# window_spec = Window.partitionBy("Year", "EventName", "LapNumber").orderBy("Time")

In [49]:
# Compute per-lap aggregates
weather_data = (
    weather_data
    .groupBy("Year", "EventName", "Session", "LapNumber")
    .agg(
        avg("AirTemp").alias("avg_air_temp"),
        avg("Humidity").alias("avg_humidity"),
        avg("Pressure").alias("avg_pressure"),
        max("Rainfall").alias("max_rainfall"),
        avg("TrackTemp").alias("avg_track_temp"),
        avg("WindDirection").alias("avg_wind_direction"),
        avg("WindSpeed").alias("avg_wind_speed")
    )
)
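
One caveat on `avg_wind_direction`: wind direction is an angle, so an arithmetic mean can mislead when samples straddle north (350 and 10 degrees average to 180). A possible alternative, not used in the notebook, is a circular mean. The sketch below is a minimal standard-library version with a hypothetical function name.

```python
import math

def circular_mean_degrees(angles):
    """Mean direction of angles in degrees, returned in the range [0, 360]."""
    sin_sum = sum(math.sin(math.radians(a)) for a in angles)
    cos_sum = sum(math.cos(math.radians(a)) for a in angles)
    return math.degrees(math.atan2(sin_sum, cos_sum)) % 360

arithmetic = sum([350, 10]) / 2            # 180.0, pointing south
circular = circular_mean_degrees([350, 10])  # close to north (0 or 360 up to rounding)
```

In Spark the same idea could be expressed by averaging the sine and cosine of the direction per lap and combining them afterwards.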

### 4.4. Joining the Data

In [50]:
# Join lap_data with the per-lap telemetry and weather aggregates
data = (
    lap_data.alias('lap')
    .join(
        aggregated_laps.alias('telemetry'),
        on=["Year", "EventName", "Session", "Driver", "LapNumber"],
        how="inner"
    )
    .join(
        weather_data.alias('weather'),  # <- use the aggregated weather data
        on=["Year", "EventName", "Session", "LapNumber"],
        how="inner"
    )
)

In [51]:
# Create target variable
data = (
    data
    .withColumn(
        "WillPitNextLap",
        when(
            lead("PitInTime", 1).over(
                Window.partitionBy("Year", "EventName", "Session", "Driver").orderBy("LapNumber")
            ).isNotNull(),
            1
        )
        .otherwise(0)
        .cast(IntegerType())
    )
)
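
The target is built by looking one lap ahead: a lap is labelled 1 when the following lap has a pit entry time. A plain-Python sketch for one driver, with a hypothetical function name:

```python
def will_pit_next_lap(pit_in_times):
    """pit_in_times: PitInTime per lap (seconds or None), in lap order."""
    labels = []
    for i in range(len(pit_in_times)):
        # The final lap has no successor, so it is always labelled 0
        next_pit_in = pit_in_times[i + 1] if i + 1 < len(pit_in_times) else None
        labels.append(1 if next_pit_in is not None else 0)
    return labels

labels = will_pit_next_lap([None, None, 3021.4, None])
```

Since `PitInTime` is recorded on the in-lap, the positive label falls on the lap before it, which is the moment the prediction would be needed.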

data = data.drop("PitInTime")

In [52]:
# data.printSchema()

## 5. Data Modelling

In [53]:
data.columns

['Year',
 'EventName',
 'Session',
 'LapNumber',
 'Driver',
 'DriverNumber',
 'LapTime',
 'Stint',
 'SpeedI1',
 'SpeedI2',
 'SpeedFL',
 'SpeedST',
 'IsPersonalBest',
 'Compound',
 'TyreLife',
 'FreshTyre',
 'Team',
 'LapStartTime',
 'TrackStatus',
 'Position',
 'Deleted',
 'FastF1Generated',
 'IsAccurate',
 'LapSessionTime',
 'DNF',
 'rolling_avg_laptime',
 'pit_in_lap',
 'pit_exit_lap',
 'last_pit_lap',
 'laps_since_last_pit',
 'prev_compound',
 'pit_stop_duration',
 'max_pit_stop_duration',
 'start_position',
 'position_change_since_race_start',
 'fastest_sector',
 'avg_speed_last_lap',
 'max_speed_last_lap',
 'avg_throttle_last_lap',
 'avg_brake_last_lap',
 'gear_change_count',
 'avg_rpm',
 'DRS_activation_count',
 'rolling_throttle_mean',
 'rolling_brake_intensity',
 'rolling_gear_change_rate',
 'rolling_speed_mean',
 'final_sector_avg_speed',
 'final_sector_throttle',
 'final_sector_brake',
 'avg_air_temp',
 'avg_humidity',
 'avg_pressure',
 'max_rainfall',
 'avg_track_temp',
 'av

In [54]:
data = data.drop("Session")

In [55]:
data.columns

['Year',
 'EventName',
 'LapNumber',
 'Driver',
 'DriverNumber',
 'LapTime',
 'Stint',
 'SpeedI1',
 'SpeedI2',
 'SpeedFL',
 'SpeedST',
 'IsPersonalBest',
 'Compound',
 'TyreLife',
 'FreshTyre',
 'Team',
 'LapStartTime',
 'TrackStatus',
 'Position',
 'Deleted',
 'FastF1Generated',
 'IsAccurate',
 'LapSessionTime',
 'DNF',
 'rolling_avg_laptime',
 'pit_in_lap',
 'pit_exit_lap',
 'last_pit_lap',
 'laps_since_last_pit',
 'prev_compound',
 'pit_stop_duration',
 'max_pit_stop_duration',
 'start_position',
 'position_change_since_race_start',
 'fastest_sector',
 'avg_speed_last_lap',
 'max_speed_last_lap',
 'avg_throttle_last_lap',
 'avg_brake_last_lap',
 'gear_change_count',
 'avg_rpm',
 'DRS_activation_count',
 'rolling_throttle_mean',
 'rolling_brake_intensity',
 'rolling_gear_change_rate',
 'rolling_speed_mean',
 'final_sector_avg_speed',
 'final_sector_throttle',
 'final_sector_brake',
 'avg_air_temp',
 'avg_humidity',
 'avg_pressure',
 'max_rainfall',
 'avg_track_temp',
 'avg_wind_direc

In [56]:
data.select("Driver").distinct().show()

+------+
|Driver|
+------+
|   OCO|
|   BOT|
|   HAM|
|   MSC|
|   VER|
|   ZHO|
|   MAG|
|   SAR|
|   NOR|
|   TSU|
|   HUL|
|   ALB|
|   PER|
|   STR|
|   LAW|
|   GAS|
|   LEC|
|   DEV|
|   RUS|
|   COL|
+------+
only showing top 20 rows



In [57]:
# Train: 2022 Abu Dhabi Grand Prix
train_data = data.filter(
    (col("EventName") == "Abu Dhabi Grand Prix") &
    (col("Year") == 2022)
)

# Validation: 2023 Abu Dhabi Grand Prix
val_data = data.filter(
    (col("EventName") == "Abu Dhabi Grand Prix") &
    (col("Year") == 2023)
)

# Test: 2024 Abu Dhabi Grand Prix
test_data = data.filter(
    (col("EventName") == "Abu Dhabi Grand Prix") &
    (col("Year") == 2024)
)


In [None]:
from pyspark.sql.functions import col
from pyspark.ml.functions import vector_to_array
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Filter for a single race (this overwrites the train_data defined in the previous cell)
train_data = data.filter(
    (col("EventName") == "Abu Dhabi Grand Prix") & 
    (col("Year") == 2024)
)

# Check unique values for categorical columns
categorical_cols = ["Team", "Compound", "Driver", "EventName", "prev_compound"]
for cat_col in categorical_cols:
    unique_count = train_data.select(cat_col).distinct().count()
    print(f"Unique values in {cat_col}: {unique_count}")

# Define preprocessing
indexers_and_encoders = [
    StringIndexer(inputCol="Team", outputCol="TeamIndex", handleInvalid="keep"),
    StringIndexer(inputCol="Compound", outputCol="CompoundIndex", handleInvalid="keep"),
    StringIndexer(inputCol="Driver", outputCol="DriverIndex", handleInvalid="keep"),
    StringIndexer(inputCol="EventName", outputCol="EventNameIndex", handleInvalid="keep"),
    StringIndexer(inputCol="prev_compound", outputCol="prev_compound_index", handleInvalid="keep"),
    OneHotEncoder(inputCol="CompoundIndex", outputCol="CompoundIndex_ohe"),
    OneHotEncoder(inputCol="prev_compound_index", outputCol="prev_compound_ohe"),
    OneHotEncoder(inputCol="DriverIndex", outputCol="DriverIndex_ohe")
]

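def expand_one_hot_vectors(df, ohe_columns):
    """Split each one-hot vector column into one numeric column per vector slot."""
    for vec_col in ohe_columns:
        df = df.withColumn(f"{vec_col}_array", vector_to_array(col(vec_col)))
        # Vector length is read from the first row; all rows share the same size
        size = len(df.select(vec_col).first()[0])
        for i in range(size):
            df = df.withColumn(f"{vec_col}_{i}", col(f"{vec_col}_array")[i])
    # Loop variable named `c` so it cannot be confused with the imported `col` function
    return df.drop(*[f"{c}_array" for c in ohe_columns])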

# Store AUPRC scores
auprc_by_lap = {}

# Get max lap number
max_lap = train_data.agg({"LapNumber": "max"}).collect()[0][0]

# Start from lap 5
for lap in range(5, max_lap):
    print(f"Training on laps ≤ {lap}, predicting lap {lap + 1}")

    train_subset = train_data.filter(col("LapNumber") <= lap)
    test_subset = train_data.filter(col("LapNumber") == lap + 1)

    if test_subset.count() == 0:
        print(f"No data for lap {lap+1}, skipping.")
        continue

    # Fit and transform pipeline
    pipeline = Pipeline(stages=indexers_and_encoders)
    model = pipeline.fit(train_subset)

    train_transformed = model.transform(train_subset)
    test_transformed = model.transform(test_subset)

    # Expand OHE
    ohe_columns = ["CompoundIndex_ohe", "prev_compound_ohe", "DriverIndex_ohe"]
    train_transformed = expand_one_hot_vectors(train_transformed, ohe_columns)
    test_transformed = expand_one_hot_vectors(test_transformed, ohe_columns)

    # Assemble features
    feature_cols = [c for c in train_transformed.columns if c.startswith("CompoundIndex_ohe_") or 
                    c.startswith("prev_compound_ohe_") or 
                    c.startswith("DriverIndex_ohe_") or 
                    c in ["TeamIndex"]]

    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
    train_assembled = assembler.transform(train_transformed)
    test_assembled = assembler.transform(test_transformed)

    # Train classifier
    rf = RandomForestClassifier(labelCol="WillPitNextLap", featuresCol="features", maxBins=64)
    rf_model = rf.fit(train_assembled)

    # Predict
    predictions = rf_model.transform(test_assembled)

    # Evaluate AUPRC
    evaluator = BinaryClassificationEvaluator(
        labelCol="WillPitNextLap", 
        rawPredictionCol="rawPrediction", 
        metricName="areaUnderPR"
    )
    auprc = evaluator.evaluate(predictions)
    auprc_by_lap[lap + 1] = auprc

    print(f"Lap {lap + 1} AUPRC: {auprc:.4f}")

Unique values in Team: 10
Unique values in Compound: 3
Unique values in Driver: 19
Unique values in EventName: 1
Unique values in prev_compound: 3
Training on laps ≤ 5, predicting lap 6
Lap 6 AUPRC: 0.0000
Training on laps ≤ 6, predicting lap 7
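
The `Lap 6 AUPRC: 0.0000` line is possibly a symptom of scoring one lap at a time: if the held-out lap contains no rows with `WillPitNextLap = 1`, there are no positive labels and a precision-recall area carries no information. One option is to skip such laps, or to pool predictions across laps before scoring. The sketch below uses average precision in plain Python as a stand-in for Spark's `areaUnderPR` (the two are related but not identical). `average_precision` and the toy labels and scores are hypothetical.

```python
def average_precision(labels, scores):
    """Average precision for binary labels; None when there are no positives."""
    n_pos = sum(labels)
    if n_pos == 0:
        return None  # undefined: there is nothing to recall
    # Rank rows from the highest to the lowest predicted score
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    hits = 0
    precision_sum = 0.0
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            precision_sum += hits / rank  # precision at this positive's rank
    return precision_sum / n_pos

# Hypothetical per-lap predictions: {lap: (labels, scores)}
per_lap = {
    6: ([0, 0, 0], [0.2, 0.1, 0.25]),  # no positive labels on this lap
    7: ([0, 1, 0], [0.1, 0.9, 0.8]),
    8: ([1, 0, 0], [0.3, 0.6, 0.2]),
}

# Score each lap separately and keep only laps where the metric is defined
scored = {lap: average_precision(y, s) for lap, (y, s) in per_lap.items()}
usable = {lap: ap for lap, ap in scored.items() if ap is not None}

# Pool all laps into one ranking for a single, more stable score
pooled_labels = [y for ys, _ in per_lap.values() for y in ys]
pooled_scores = [s for _, ss in per_lap.values() for s in ss]
pooled_ap = average_precision(pooled_labels, pooled_scores)
print(usable, pooled_ap)
```

With roughly twenty drivers per lap and very few pit stops on any given lap, per-lap scores are expected to be noisy, which is the motivation for pooling here.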
