Implementing a Custom Join 
==============

## Sanity Checks Added

This refined version includes comprehensive sanity checks at each major step to track row counts and identify where flights are being dropped.

**Checkpoints:**
1. After Initial Load - Baseline row count
2. After dropDuplicates - Check for duplicate removal impact
3. After Airport Join - Verify airport data join (LEFT JOIN preserves all flights)
4. After Station Join (Left) - Check station matching (LEFT JOIN preserves all flights)
5. After Station Filter - **CRITICAL**: This is where flights without stations are dropped
6. After Weather Join - Check weather data join (INNER JOIN may drop flights)
7. Final Before Save - Final row count before saving

**Summary Cell:** A comparison table showing row counts and losses at each step is included before the final save.

---


# Setup

In [0]:
# Selects which slice of the airline/weather datasets the notebook loads.
# The value is appended to the dataset folder names and to the output file names;
# an empty string selects the full dataset.
data_version = "3m" # "3m", "6m", "1y", "" -> blank is full

In [0]:
# Imports

from pyspark.sql.functions import col, regexp_replace, split, trim, to_timestamp, date_format, broadcast
from pyspark.sql.window import Window
import pyspark.sql.functions as F
import pandas as pd

##### Spark settings for efficient execution

In [0]:
# Read the current session time zone (the value is neither stored nor printed here),
# then pin the session time zone to UTC.
spark.conf.get("spark.sql.session.timeZone")
spark.conf.set("spark.sql.session.timeZone", "UTC")

In [0]:
# Enable Adaptive Query Execution (AQE) and join-related optimizations.
# These are session-level settings applied through spark.conf.
spark.conf.set("spark.sql.adaptive.enabled", "true")                  # Adaptive Query Execution (auto optimizes joins/shuffles)
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")         # Automatically handles skewed joins
#spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")  # Merge small partitions post-shuffle
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 50 * 1024 * 1024) # Optional: increase broadcast threshold to 50MB (value is in bytes)


Reading back to see the spark settings

In [0]:
# Echo the Spark version followed by the session settings that matter for the joins below.
print("Spark version:", spark.version)

_reported_settings = (
    ("AQE", "spark.sql.adaptive.enabled"),
    ("Skew join", "spark.sql.adaptive.skewJoin.enabled"),
    ("Coalesce", "spark.sql.adaptive.coalescePartitions.enabled"),
    ("AutoBroadcastJoinThreshold", "spark.sql.autoBroadcastJoinThreshold"),
    ("spark.sql.shuffle.partitions", "spark.sql.shuffle.partitions"),
)
for _label, _conf_key in _reported_settings:
    print(f"{_label}:", spark.conf.get(_conf_key))


In [0]:
# Paths & Defines
#
# Defines every input/output location used by the notebook, makes sure the
# group's directory tree exists, and registers the Spark checkpoint directory.

data_BASE_DIR = "dbfs:/mnt/mids-w261/"
display(dbutils.fs.ls(f"{data_BASE_DIR}"))

# Section Folder
section = "4"
number = "2"
section_DIR = f"dbfs:/mnt/mids-w261/student-groups/Group_{section}_{number}"

# Subdirectories for organization
raw_DIR = f"{section_DIR}/raw"
processed_DIR = f"{section_DIR}/processed"
checkpoints_DIR = f"{section_DIR}/checkpoints"
intermediate_DIR = f"{section_DIR}/intermediate"

# Output filenames (using variables for maintainability)
FLIGHTS_WITH_STATION = f"flights_with_station_{data_version}"
FLIGHTS_WEATHER_JOINED = f"flights_weather_joined_{data_version}"
FLIGHTS_WITH_AIRPORTS = f"flights_with_airports_{data_version}"

# Full paths for outputs
flights_with_station_path = f"{intermediate_DIR}/{FLIGHTS_WITH_STATION}"
flights_weather_joined_path = f"{processed_DIR}/{FLIGHTS_WEATHER_JOINED}"
flights_with_airports_path = f"{intermediate_DIR}/{FLIGHTS_WITH_AIRPORTS}"

# Check if section_DIR exists, print contents or create it.
# The listing is attempted first so that "exists" is only reported after it
# has actually succeeded; any listing failure is treated as "missing".
try:
    contents = dbutils.fs.ls(section_DIR)
except Exception as e:
    print(f"✗ Section directory does not exist: {section_DIR}")
    print(f"  (listing failed with: {e})")
    print("Creating directory structure...")
    dbutils.fs.mkdirs(section_DIR)
    print(f"✓ Base directory created: {section_DIR}")
else:
    print(f"✓ Section directory exists: {section_DIR}")
    print("\nContents:")
    for item in contents:
        print(f"  - {item.name} ({'DIR' if item.isDir() else 'FILE'}) - {item.size} bytes")
    print(f"\nTotal items: {len(contents)}")

# Create subdirectories (failures are reported per directory, not raised)
print("\nCreating/verifying subdirectories...")
for subdir_name, subdir_path in [
    ("raw", raw_DIR),
    ("processed", processed_DIR),
    ("checkpoints", checkpoints_DIR),
    ("intermediate", intermediate_DIR),
]:
    try:
        dbutils.fs.mkdirs(subdir_path)
        print(f"✓ {subdir_name}: {subdir_path}")
    except Exception as e:
        print(f"✗ Error creating {subdir_name}: {e}")

# Set checkpoint directory for Spark
spark.sparkContext.setCheckpointDir(checkpoints_DIR)

print("\n" + "="*60)
print("Directory structure ready!")
print("="*60)
print(f"\nKey paths:")
print(f"  Raw: {raw_DIR}")
print(f"  Intermediate: {intermediate_DIR}")
print(f"  Processed: {processed_DIR}")
print(f"  Checkpoints: {checkpoints_DIR}")
print(f"\nOutput files:")
print(f"  Flights+Station: {flights_with_station_path}")
print(f"  Flights+Weather: {flights_weather_joined_path}")

# Load Data

#### Provided Datasets

In [0]:
# List the provided datasets folder; as the cell's last expression the listing is shown as the cell result.
dbutils.fs.ls("dbfs:/mnt/mids-w261/datasets_final_project_2022/")

In [0]:
# Load the three provided datasets. A non-empty data_version picks the
# "<name>_<version>" folder; an empty one picks the full dataset folder.
_datasets_root = "dbfs:/mnt/mids-w261/datasets_final_project_2022"
_version_suffix = "" if data_version == "" else f"_{data_version}"

# Airline data
df_flights = spark.read.parquet(f"{_datasets_root}/parquet_airlines_data{_version_suffix}/")

# Stations data (a single dataset, independent of data_version)
df_stations = spark.read.parquet(f"{_datasets_root}/stations_data/stations_with_neighbors.parquet/")

# Weather data
df_weather = spark.read.parquet(f"{_datasets_root}/parquet_weather_data{_version_suffix}/")


In [None]:
# ============================================================
# SANITY CHECK: After Initial Load
# ============================================================
# Baseline checkpoint: prints the flight row count, the NULL count of each key
# column that is present, and the origins that have no station id (once that
# column exists).

print("\n" + "="*80)
print(f"SANITY CHECK: After Initial Load")
print("="*80)

# Row count
try:
    row_count = df_flights.count()
    print(f"\n✓ Row count: {row_count:,}")
except Exception as e:
    print(f"\n✗ Error counting rows: {e}")
    row_count = None

# Check for NULLs in key columns.
# Column presence is tested case-insensitively: this cell runs before the
# column names are lowercased, and a plain `'origin' in df.columns` test would
# silently skip a column stored as e.g. 'ORIGIN'. Spark itself resolves
# F.col('origin') case-insensitively under its default settings.
print("\n--- NULL Analysis ---")
_present_columns = {c.lower() for c in df_flights.columns}
key_columns = [
    c
    for c in ('origin', 'dest', 'origin_station_id', 'origin_latitude', 'origin_longitude')
    if c in _present_columns
]

for col_name in key_columns:
    try:
        null_count = df_flights.filter(F.col(col_name).isNull()).count()
        # Guard against a failed or zero row count.
        null_pct = (null_count / row_count * 100) if row_count else 0
        print(f"  {col_name}: {null_count:,} NULLs ({null_pct:.2f}%)")
    except Exception as e:
        print(f"  {col_name}: Error - {e}")

# Identify dropped airports (if applicable)
if 'origin_station_id' in _present_columns:
    try:
        dropped_airports = df_flights.filter(
            F.col('origin_station_id').isNull()
        ).select('origin').distinct()
        dropped_count = dropped_airports.count()
        if dropped_count > 0:
            print(f"\n--- Airports Without Stations: {dropped_count} ---")
            print("Sample airports without stations:")
            dropped_airports.show(20, truncate=False)
    except Exception as e:
        print(f"\nError analyzing dropped airports: {e}")

print("="*80)


In [0]:
# Print the column names and types of the flights DataFrame.
df_flights.printSchema()

In [0]:
# Print the column names and types of the weather DataFrame.
df_weather.printSchema()

In [0]:
# Print the column names and types of the stations DataFrame.
df_stations.printSchema()

#### Airport Codes - Additional 

This dataset supplies information that is missing from the flights DataFrame, such as airport latitude and longitude.

In [0]:
# Download the airport-codes CSV to a driver-local temp file, copy it to a
# stable DBFS location, and load it from there as a cached Spark DataFrame.
import requests
import os

url = "https://datahub.io/core/airport-codes/r/airport-codes.csv"
local_tmp_path = "/tmp/airport-codes.csv"   # driver-local temp file

# The timeout keeps the cell from hanging indefinitely on a stalled connection.
response = requests.get(url, timeout=60)
response.raise_for_status()  # fail loudly rather than saving an HTTP error page as the CSV

with open(local_tmp_path, "wb") as f:
    f.write(response.content)

# Target DBFS directory and file path (stable location)
target_dir = "/mnt/mids-w261/student-groups/Group_4_2/raw"
target_path = f"{target_dir}/airport-codes.csv"

# Make sure the directory exists
dbutils.fs.mkdirs(target_dir)

# Remove any previous copy first so the cp below does not fail on an existing file
dbutils.fs.rm(target_path, recurse=False)

# Copy from driver-local file system to DBFS mount
dbutils.fs.cp("file://" + local_tmp_path, target_path)

# Read the CSV from the same DBFS path it was just copied to.
# (The "/dbfs/..." spelling is the driver-local mount meant for Python file
# APIs; passed to Spark it would be looked up as a different DBFS location.)
df_airport_codes = (
    spark.read
    .option("header", "true")
    .csv(target_path)
    .cache()
)

# Force materialization of the cache so later stages don't re-read the CSV
df_airport_codes.count()

# display(df_airport_codes)

In [0]:
# df_airport_codes.printSchema()

In [0]:
# df_airport_codes.describe()

In [0]:
# display(df_airport_codes.limit(5))

#### Airport - Timezone string - Additional Dataset

Source: https://github.com/opentraveldata/opentraveldata/blob/master/opentraveldata/optd_por_public.csv 


In [0]:
# Load the airport time-zone reference table, preferring a Parquet copy from
# an earlier run and falling back to downloading and converting the CSV.
url = "https://raw.githubusercontent.com/opentraveldata/opentraveldata/master/opentraveldata/optd_por_public.csv"

# Driver-local path used to write the download with plain Python file APIs.
local_path = "/dbfs/tmp/airport-timezones.csv"

# Use the corresponding DBFS path for Spark.
# NOTE(review): "/dbfs/tmp/..." is expected to be the driver-local view of
# "dbfs:/tmp/...", so no extra copy step is needed — confirm on this cluster.
dbfs_path = "dbfs:/tmp/airport-timezones.csv"

# Parquet copy that is reused on subsequent runs
parquet_path = "dbfs:/tmp/airport-timezones-cached.parquet"

cols_to_keep = [
    "iata_code",
    "icao_code",
    "faa_code",
    "timezone",
    "latitude",
    "longitude"
]

try:
    # Try to read existing Parquet file first
    df_airport_timezones = spark.read.parquet(parquet_path)
    df_airport_timezones.count()  # Materialize; also surfaces a missing/unreadable file here
    print("✓ Using existing Parquet file")
except Exception as exc:
    # Parquet copy is missing or unreadable: rebuild it from the CSV.
    print(f"Existing Parquet file not usable ({type(exc).__name__})")
    print("Creating Parquet file from CSV...")

    tz_response = requests.get(url, timeout=60)
    tz_response.raise_for_status()  # do not cache an HTTP error page as data
    with open(local_path, "wb") as f:
        f.write(tz_response.content)

    df_airport_timezones = (
        spark.read
            .option("header", True)
            .option("delimiter", "^")
            .option("inferSchema", True)
            .csv(dbfs_path)
            .select(cols_to_keep)
    )

    # Write to Parquet and read back so later stages do not depend on the CSV
    df_airport_timezones.write.mode("overwrite").parquet(parquet_path)
    df_airport_timezones = spark.read.parquet(parquet_path)
    df_airport_timezones.count()  # Materialize to break lineage from CSV
    print("✓ Parquet file created and materialized")

# display(df_airport_timezones)

In [0]:
def get_df_stats(df, sample_size_estimate=False, tmp_path="/tmp/df_profile_tmp"):
    """
    Return the number of rows, number of columns, and a size figure for a Spark DataFrame.

    Parameters:
        df (DataFrame): Input Spark DataFrame.
        sample_size_estimate (bool):
            If True -> size is the summed length of ``str(row)`` over every row
            (a character count over all rows, not a sample and not bytes on disk).
            If False -> the DataFrame is written to ``tmp_path`` as Parquet and
            the sizes of the entries listed directly under that path are summed.
        tmp_path (str): Scratch path for the Parquet write when
            sample_size_estimate=False. It is overwritten and removed afterwards.

    Returns:
        dict: {
            "rows": <int>,
            "columns": <int>,
            "size_bytes": <int>
        }
    """
    rows = df.count()
    cols = len(df.columns)

    if sample_size_estimate:
        size_bytes = df.rdd.map(lambda row: len(str(row))).sum()
    else:
        # Best-effort removal of leftovers from an earlier run; the overwrite
        # mode below replaces the path anyway, so a failure here is not fatal.
        try:
            dbutils.fs.rm(tmp_path, recurse=True)
        except Exception:
            pass

        df.write.mode("overwrite").parquet(tmp_path)
        try:
            # Sum the sizes of the entries listed directly under tmp_path
            size_bytes = sum(f.size for f in dbutils.fs.ls(tmp_path))
        finally:
            # Always clean up, even if the listing fails
            dbutils.fs.rm(tmp_path, recurse=True)

    return {
        "rows": rows,
        "columns": cols,
        "size_bytes": size_bytes
    }

# Profile the time-zone table. With the default arguments get_df_stats writes
# the DataFrame to a temporary Parquet location to measure its size.
print(get_df_stats(df_airport_timezones))

In [0]:
# Profile the airport-codes table (rows, columns, size) the same way.
print(get_df_stats(df_airport_codes))

In [0]:
# Attach time-zone rows to the airport-codes table.
# Left join: every airport-codes row is kept; a time-zone row matches when
# either its IATA code or its ICAO code equals the airport's.
_codes = df_airport_codes.alias("a")
_zones = df_airport_timezones.alias("b")

_same_iata = F.col("a.iata_code") == F.col("b.iata_code")
_same_icao = F.col("a.icao_code") == F.col("b.icao_code")

df_airport_joined = _codes.join(_zones, _same_iata | _same_icao, how="left")

# Display result
# display(df_airport_joined) 

In [0]:
# print("Row count before join:", df_airport_codes.count())
# print("Row count after join:", df_airport_joined.count())

In [0]:
# Keep the airport-codes columns (side "a") and add the coordinates, time zone
# and FAA code from the time-zone table (side "b") under unambiguous names.
_airport_code_columns = [
    "ident", "type", "name", "elevation_ft",
    "continent", "iso_country", "iso_region", "municipality",
    "a.icao_code", "a.iata_code", "gps_code", "local_code",
    "coordinates",
]
_timezone_table_columns = [
    F.col("b.latitude").alias("latitude"),
    F.col("b.longitude").alias("longitude"),
    F.col("b.timezone").alias("timezone"),
    F.col("b.faa_code").alias("faa_code_otd"),
]

df_airport_joined_clean = df_airport_joined.select(
    *_airport_code_columns,
    *_timezone_table_columns,
)

# display(df_airport_joined_clean)


In [0]:

df_airport_joined = df_airport_joined_clean


def _non_null_first(column_name):
    """Sort key that is 1 when the column has a value and 2 when it is NULL."""
    return F.when(F.col(column_name).isNotNull(), 1).otherwise(2)


# One window per (iata_code, icao_code) pair; rows with a latitude come first,
# then rows with a time zone.
w = Window.partitionBy("iata_code", "icao_code").orderBy(
    _non_null_first("latitude"),
    _non_null_first("timezone"),
)

# Number the rows inside each window and keep only the first one.
_ranked_airports = df_airport_joined.withColumn("row_num", F.row_number().over(w))
df_airport_dedup = _ranked_airports.filter(F.col("row_num") == 1).drop("row_num")

# print("Row count before de-dup:", df_airport_joined.count())
# print("Row count after de-dup:", df_airport_dedup.count()) 

In [0]:
# Add parsed timestamp columns to the flights and weather DataFrames.

# Flights: parse fl_date as a calendar date (yyyy-MM-dd).
_flight_date_ts = F.to_timestamp(F.col("fl_date"), "yyyy-MM-dd")
df_flights = df_flights.withColumn("fl_date_timestamp", _flight_date_ts)

# Weather: parse the full date-time string so hours and minutes are kept.
_observation_ts = F.to_timestamp(F.col("date"), "yyyy-MM-dd'T'HH:mm:ss")
df_weather = df_weather.withColumn("date_timestamp", _observation_ts)

In [0]:
# display(df_airport_dedup.select("timezone").groupBy("timezone").count().orderBy("count", ascending=False))

In [0]:
# From here on the de-duplicated airport table is referred to as df_airports.
df_airports = df_airport_dedup

In [0]:
# De-duplicate: remove rows that are identical across all columns in each table.
df_flights = df_flights.dropDuplicates() # This has known duplicates
df_weather = df_weather.dropDuplicates()
df_stations = df_stations.dropDuplicates()
df_airports = df_airports.dropDuplicates()

In [None]:
# ============================================================
# SANITY CHECK: After dropDuplicates
# ============================================================
# Prints the flight row count after de-duplication, the NULL count of each
# key column that is present, and the origins without a station id (once that
# column exists).

print("\n" + "="*80)
print(f"SANITY CHECK: After dropDuplicates")
print("="*80)

# Row count
try:
    row_count = df_flights.count()
    print(f"\n✓ Row count: {row_count:,}")
except Exception as e:
    print(f"\n✗ Error counting rows: {e}")
    row_count = None

# Check for NULLs in key columns.
# Column presence is tested case-insensitively: this cell runs before the
# column names are lowercased, and a plain `'origin' in df.columns` test would
# silently skip a column stored as e.g. 'ORIGIN'. Spark itself resolves
# F.col('origin') case-insensitively under its default settings.
print("\n--- NULL Analysis ---")
_present_columns = {c.lower() for c in df_flights.columns}
key_columns = [
    c
    for c in ('origin', 'dest', 'origin_station_id', 'origin_latitude', 'origin_longitude')
    if c in _present_columns
]

for col_name in key_columns:
    try:
        null_count = df_flights.filter(F.col(col_name).isNull()).count()
        # Guard against a failed or zero row count.
        null_pct = (null_count / row_count * 100) if row_count else 0
        print(f"  {col_name}: {null_count:,} NULLs ({null_pct:.2f}%)")
    except Exception as e:
        print(f"  {col_name}: Error - {e}")

# Identify dropped airports (if applicable)
if 'origin_station_id' in _present_columns:
    try:
        dropped_airports = df_flights.filter(
            F.col('origin_station_id').isNull()
        ).select('origin').distinct()
        dropped_count = dropped_airports.count()
        if dropped_count > 0:
            print(f"\n--- Airports Without Stations: {dropped_count} ---")
            print("Sample airports without stations:")
            dropped_airports.show(20, truncate=False)
    except Exception as e:
        print(f"\nError analyzing dropped airports: {e}")

print("="*80)


In [None]:
# ============================================================
# SANITY CHECK: After dropDuplicates
# ============================================================
# NOTE(review): this cell repeats the "After dropDuplicates" checkpoint on the
# same DataFrame; it is kept so the notebook's cell sequence is unchanged.
# Prints the flight row count, the NULL count of each key column that is
# present, and the origins without a station id (once that column exists).

print("\n" + "="*80)
print(f"SANITY CHECK: After dropDuplicates")
print("="*80)

# Row count
try:
    row_count = df_flights.count()
    print(f"\n✓ Row count: {row_count:,}")
except Exception as e:
    print(f"\n✗ Error counting rows: {e}")
    row_count = None

# Check for NULLs in key columns.
# Column presence is tested case-insensitively: this cell runs before the
# column names are lowercased, and a plain `'origin' in df.columns` test would
# silently skip a column stored as e.g. 'ORIGIN'. Spark itself resolves
# F.col('origin') case-insensitively under its default settings.
print("\n--- NULL Analysis ---")
_present_columns = {c.lower() for c in df_flights.columns}
key_columns = [
    c
    for c in ('origin', 'dest', 'origin_station_id', 'origin_latitude', 'origin_longitude')
    if c in _present_columns
]

for col_name in key_columns:
    try:
        null_count = df_flights.filter(F.col(col_name).isNull()).count()
        # Guard against a failed or zero row count.
        null_pct = (null_count / row_count * 100) if row_count else 0
        print(f"  {col_name}: {null_count:,} NULLs ({null_pct:.2f}%)")
    except Exception as e:
        print(f"  {col_name}: Error - {e}")

# Identify dropped airports (if applicable)
if 'origin_station_id' in _present_columns:
    try:
        dropped_airports = df_flights.filter(
            F.col('origin_station_id').isNull()
        ).select('origin').distinct()
        dropped_count = dropped_airports.count()
        if dropped_count > 0:
            print(f"\n--- Airports Without Stations: {dropped_count} ---")
            print("Sample airports without stations:")
            dropped_airports.show(20, truncate=False)
    except Exception as e:
        print(f"\nError analyzing dropped airports: {e}")

print("="*80)


In [0]:

def _with_lowercase_columns(frame):
    """Return the DataFrame with every column renamed to its lowercase form."""
    return frame.toDF(*(name.lower() for name in frame.columns))


# Normalize column names to lowercase in all four tables.
df_flights = _with_lowercase_columns(df_flights)
df_weather = _with_lowercase_columns(df_weather)
df_stations = _with_lowercase_columns(df_stations)
df_airports = _with_lowercase_columns(df_airports)


##### Split coordinates column into latitude and longitude

In [0]:
# Derive latitude and longitude columns from the "coordinates" string of the
# airport codes table. These replace the latitude/longitude columns that were
# taken from the time-zone table earlier.
# NOTE(review): the first comma-separated item is treated as latitude and the
# second as longitude — confirm this matches the order used in the CSV.
df_airports = (
    df_airports
    # strip any parentheses around the coordinate pair
    .withColumn("coordinates", F.regexp_replace("coordinates", "[()]", ""))
    # temporary array column holding the two comma-separated parts
    .withColumn("lat_lon", F.split("coordinates", ","))
    .withColumn("latitude", F.trim(F.col("lat_lon").getItem(0)))
    .withColumn("longitude", F.trim(F.col("lat_lon").getItem(1)))
    .drop("lat_lon")
)

df_airports.cache()
# display(df_airports)

In [0]:
# display(df_stations.limit(5))

##### create neighbor_iata column

Since the stations data contains 4-letter ICAO codes, we need to convert them to 3-letter IATA codes for US airports.  

Example: 
ICAO (4-letter): e.g., KJFK, KLAX, EGLL
→ This is what neighbor_call contains.

IATA (3-letter): e.g., JFK, LAX, LHR
→ This is what flight data uses.



In [0]:

# Add neighbor_iata: the three characters of neighbor_call starting at its
# second character. The substring is applied to every row; it does not check
# that the first character is actually 'K'.
df_stations = df_stations.withColumn(
    "neighbor_iata",
    F.expr("substring(neighbor_call, 2, 3)")  # remove leading 'K' for US airports
) 

In [0]:
# display(df_stations.select("neighbor_call","neighbor_iata").limit(5))

##### create scheduled departure time  (local)

In [0]:
from pyspark.sql.functions import col, lpad, concat, to_timestamp, expr

# Build sched_depart_date_time by combining fl_date with crs_dep_time.
# crs_dep_time is left-padded with zeros to four characters so it parses as HHmm.
_dep_hhmm = F.lpad(F.col("crs_dep_time"), 4, "0")
_dep_date_and_time = F.concat(F.col("fl_date"), _dep_hhmm)

df_flights = df_flights.withColumn(
    "sched_depart_date_time",
    F.to_timestamp(_dep_date_and_time, "yyyy-MM-ddHHmm"),
)


In [0]:
# display(df_airport_codes.limit(5))

# Implement Joins

#### Step 1: Join Flights data with Airports

This join is done using IATA code of the airport. We do this join to get data like latitude/longitude of the airport - both origin & destination. 

In [0]:

# Prepare airports dataframe - select only needed columns and keep only rows
# that have an IATA code and usable coordinates.
df_airports_clean = df_airports.select(
    col("iata_code"),
    col("name"),
    col("latitude").cast("double").alias("latitude"),
    col("longitude").cast("double").alias("longitude"),
    col("iso_country"),
    col("timezone")
).filter(
    col("iata_code").isNotNull() &
    col("latitude").isNotNull() &
    col("longitude").isNotNull()
)

# Guarantee exactly one row per IATA code before joining.
# The earlier de-duplication keys on (iata_code, icao_code), so the same IATA
# code can still appear on several rows; each extra row would multiply the
# matching flights in the left joins below and inflate the row count.
# Rows that have a time zone are preferred; the remaining keys only make the
# choice reproducible.
_one_per_iata = Window.partitionBy("iata_code").orderBy(
    col("timezone").isNull(),  # False sorts first, i.e. rows with a time zone
    col("name"),
    col("latitude"),
    col("longitude"),
)
df_airports_clean = (
    df_airports_clean
    .withColumn("_iata_rank", F.row_number().over(_one_per_iata))
    .filter(col("_iata_rank") == 1)
    .drop("_iata_rank")
)

# print(f"Airports count: {df_airports_clean.count()}")

# Broadcast airports (it's small enough)
airports_broadcast = broadcast(df_airports_clean)

# Perform the join: the airports table is joined twice, once on the origin
# code ("ao") and once on the destination code ("ad"). Both are LEFT joins, so
# flights whose airport is unknown are kept with NULL airport attributes.
df_flights_with_airports = (
    df_flights.alias("f")
    .join(
        airports_broadcast.alias("ao"),
        col("f.origin") == col("ao.iata_code"),
        "left"
    )
    .join(
        airports_broadcast.alias("ad"),
        col("f.dest") == col("ad.iata_code"),
        "left"
    )
    .select(
        col("f.*"),
        # Origin airport info
        col("ao.name").alias("origin_airport_name"),
        col("ao.latitude").alias("origin_latitude"),
        col("ao.longitude").alias("origin_longitude"),
        col("ao.iso_country").alias("origin_country"),
        col("ao.timezone").alias("origin_timezone"),
        # Destination airport info
        col("ad.name").alias("destination_airport_name"),
        col("ad.latitude").alias("destination_latitude"),
        col("ad.longitude").alias("destination_longitude"),
        col("ad.iso_country").alias("destination_country"),
        col("ad.timezone").alias("destination_timezone")
    )
)


In [None]:
# ============================================================
# SANITY CHECK: After Airport Join
# ============================================================
# Checkpoint on df_flights_with_airports: row count, NULL counts of the key
# columns that exist, and origins lacking a station id (once that column exists).

_rule = "=" * 80
print("\n" + _rule)
print("SANITY CHECK: After Airport Join")
print(_rule)

# Row count
try:
    row_count = df_flights_with_airports.count()
except Exception as e:
    print(f"\n✗ Error counting rows: {e}")
    row_count = None
else:
    print(f"\n✓ Row count: {row_count:,}")

# NULL analysis, limited to the key columns present in the DataFrame
print("\n--- NULL Analysis ---")
key_columns = [
    candidate
    for candidate in (
        'origin',
        'dest',
        'origin_station_id',
        'origin_latitude',
        'origin_longitude',
    )
    if candidate in df_flights_with_airports.columns
]

for col_name in key_columns:
    try:
        _is_missing = F.col(col_name).isNull()
        null_count = df_flights_with_airports.filter(_is_missing).count()
        # A failed or zero row count yields 0% instead of a division error
        null_pct = null_count / row_count * 100 if row_count else 0
        print(f"  {col_name}: {null_count:,} NULLs ({null_pct:.2f}%)")
    except Exception as e:
        print(f"  {col_name}: Error - {e}")

# Origins without a station id (only once the station column exists)
if 'origin_station_id' in df_flights_with_airports.columns:
    try:
        dropped_airports = (
            df_flights_with_airports
            .filter(F.col('origin_station_id').isNull())
            .select('origin')
            .distinct()
        )
        dropped_count = dropped_airports.count()
        if dropped_count > 0:
            print(f"\n--- Airports Without Stations: {dropped_count} ---")
            print("Sample airports without stations:")
            dropped_airports.show(20, truncate=False)
    except Exception as e:
        print(f"\nError analyzing dropped airports: {e}")

print(_rule)


In [0]:
# display(df_flights_combined.describe())

Get all timestamps in UTC using timezone info

In [0]:
# Convert the scheduled departure to UTC using the origin airport's time zone,
# then derive the instants two and four hours before departure.
_departure_utc = F.to_utc_timestamp("sched_depart_date_time", F.col("origin_timezone"))

df_flights_with_airports = (
    df_flights_with_airports
    .withColumn("sched_depart_date_time_UTC", _departure_utc)
    # Two hours prior to scheduled departure in UTC
    .withColumn(
        "two_hours_prior_depart_UTC",
        F.expr("sched_depart_date_time_UTC - INTERVAL 2 HOURS"),
    )
    # Four hours prior to scheduled departure in UTC
    .withColumn(
        "four_hours_prior_depart_UTC",
        F.expr("sched_depart_date_time_UTC - INTERVAL 4 HOURS"),
    )
)

# display(df_flights_with_airports.select(
#     "fl_date", "crs_dep_time", "sched_depart_date_time",
#     "sched_depart_date_time_UTC", "two_hours_prior_depart_UTC", "four_hours_prior_depart_UTC"
# ).limit(5))

In [None]:
# ============================================================
# SANITY CHECK: After Airport Join
# ============================================================
# Same checkpoint as before, re-run on df_flights_with_airports after the UTC
# timestamp columns were added.

_separator = "=" * 80
_candidate_columns = (
    'origin',
    'dest',
    'origin_station_id',
    'origin_latitude',
    'origin_longitude',
)

print("\n" + _separator)
print("SANITY CHECK: After Airport Join")
print(_separator)

# Row count
try:
    row_count = df_flights_with_airports.count()
except Exception as e:
    print(f"\n✗ Error counting rows: {e}")
    row_count = None
else:
    print(f"\n✓ Row count: {row_count:,}")

# NULL counts for whichever key columns exist
print("\n--- NULL Analysis ---")
key_columns = []
for _name in _candidate_columns:
    if _name in df_flights_with_airports.columns:
        key_columns.append(_name)

for col_name in key_columns:
    try:
        null_count = (
            df_flights_with_airports
            .filter(F.col(col_name).isNull())
            .count()
        )
        if row_count:
            null_pct = null_count / row_count * 100
        else:
            null_pct = 0  # no usable row count
        print(f"  {col_name}: {null_count:,} NULLs ({null_pct:.2f}%)")
    except Exception as e:
        print(f"  {col_name}: Error - {e}")

# Origins without a station id (skipped until the station column exists)
if 'origin_station_id' in df_flights_with_airports.columns:
    try:
        _without_station = df_flights_with_airports.filter(
            F.col('origin_station_id').isNull()
        )
        dropped_airports = _without_station.select('origin').distinct()
        dropped_count = dropped_airports.count()
        if dropped_count > 0:
            print(f"\n--- Airports Without Stations: {dropped_count} ---")
            print("Sample airports without stations:")
            dropped_airports.show(20, truncate=False)
    except Exception as e:
        print(f"\nError analyzing dropped airports: {e}")

print(_separator)


In [0]:


# Progress marker: the flights/airports join and the UTC columns are defined.
print("\n✓ Join 1 Complete")

##### Checkpoint 1: Store after first join 

In [0]:
# df_flights_combined 

# df_flights_with_station_notnull = spark.read.parquet(f"{section_DIR}/df_flights_with_station_{data_version}")

#### Step 2: Join Flights+Airport with Stations 

Here we try two different approaches:
 - Using neighbor_call/neighbor_iata 
 - Using latitude & longitude


In [0]:
# display(df_stations.limit(5))

So that shows all of them match! 

##### Join type 2: get stationID from stations table using lat/long



Cross-joining the full flights data with the stations data would create far too many combinations. Instead, we match stations against the distinct origin airports only, and then join the result back to the original flights data.

TODO: Try Haversine approach

In [0]:

# Step 1: Get distinct origin airports (only those with coordinates)
df_distinct_origins = (
    df_flights_with_airports
    .select("origin", "origin_latitude", "origin_longitude")
    .distinct()
    .filter(
        col("origin_latitude").isNotNull() &
        col("origin_longitude").isNotNull()
    )
    .repartition(200, "origin")
    #.cache()
)

print(f"Distinct origins: {df_distinct_origins.count()}")

# Step 2: Prepare stations dataframe.
# Only (station_id, lat, lon) are needed here, so identical triples are
# collapsed with distinct(). Without it, any station that occurs on several
# input rows would be broadcast and distance-checked once per row.
df_stations_clean = (
    df_stations.select(
        col("station_id"),
        col("lat").cast("double"),
        col("lon").cast("double")
    )
    .filter(
        col("station_id").isNotNull() &
        col("lat").isNotNull() &
        col("lon").isNotNull()
    )
    .distinct()
)

stations_broadcast = broadcast(df_stations_clean)

# Step 3: Perform spatial join with bounding box filter.
# The +/-0.5 degree box is a cheap pre-filter; the spherical distance is then
# computed only for the surviving pairs and capped at 50 km.
df_candidates = (
    df_distinct_origins.alias("a")
    .join(
        stations_broadcast.alias("s"),
        (col("s.lat").between(col("a.origin_latitude") - 0.5, col("a.origin_latitude") + 0.5)) &
        (col("s.lon").between(col("a.origin_longitude") - 0.5, col("a.origin_longitude") + 0.5)),
        "inner"
    )
    .withColumn(
        "distance_km",
        # acos is only defined on [-1, 1]; floating-point rounding can push
        # the cosine sum slightly outside, so it is clamped on both sides.
        F.expr("""
            6371 * acos(
                least(1.0, greatest(-1.0,
                    cos(radians(a.origin_latitude)) * cos(radians(s.lat)) *
                    cos(radians(s.lon) - radians(a.origin_longitude)) +
                    sin(radians(a.origin_latitude)) * sin(radians(s.lat))
                ))
            )
        """)
    )
    .filter(col("distance_km") < 50)
)

# Step 4: Get nearest station per airport.
# station_id is a secondary sort key so that equal distances always resolve
# to the same station from run to run.
window_spec = Window.partitionBy("a.origin").orderBy("distance_km", col("s.station_id"))

df_nearest_stations = (
    df_candidates
    .withColumn("rank", F.row_number().over(window_spec))
    .filter(col("rank") == 1)
    .select(
        col("a.origin").alias("origin"),
        col("s.station_id").alias("origin_station_id"),
        col("distance_km").alias("station_distance_km")
    )
)

# print(f"Airports with stations: {df_nearest_stations.count()}")


In [0]:

# Step 5: Attach the nearest station to every flight via its origin.
# LEFT join: flights whose origin has no station keep a NULL origin_station_id.
df_flights_with_station = df_flights_with_airports.join(
    df_nearest_stations,
    on="origin",
    how="left",
)

# Check quality: how many flights ended up without a station
_flights_missing_station = df_flights_with_station.filter(
    F.col("origin_station_id").isNull()
)
null_count = _flights_missing_station.count()
print(f"Flights without station: {null_count:,}")

# Keep only the flights that did get a station
df_flights_with_station_clean = df_flights_with_station.filter(
    F.col("origin_station_id").isNotNull()
)

# print(f"Final flights with stations: {df_flights_with_station_clean.count():,}")
print("\n✓ Join 2 Complete")

In [None]:
# ============================================================
# SANITY CHECK: After Station Filter (DROPS ROWS)
# ============================================================
# Checkpoint on df_flights_with_station_clean: row count after the filter and
# NULL counts of the key columns that exist.
# NOTE(review): this DataFrame was already filtered to non-NULL
# origin_station_id, so the "airports without stations" lookup below is not
# expected to find anything here.

_divider = "=" * 80
print("\n" + _divider)
print("SANITY CHECK: After Station Filter (DROPS ROWS)")
print(_divider)

# Row count
try:
    row_count = df_flights_with_station_clean.count()
except Exception as e:
    print(f"\n✗ Error counting rows: {e}")
    row_count = None
else:
    print(f"\n✓ Row count: {row_count:,}")

# NULL analysis for the key columns that are present
print("\n--- NULL Analysis ---")
key_columns = [
    wanted
    for wanted in (
        'origin',
        'dest',
        'origin_station_id',
        'origin_latitude',
        'origin_longitude',
    )
    if wanted in df_flights_with_station_clean.columns
]

for col_name in key_columns:
    try:
        _null_rows = df_flights_with_station_clean.filter(F.col(col_name).isNull())
        null_count = _null_rows.count()
        # Fall back to 0% when the row count is missing or zero
        null_pct = null_count / row_count * 100 if row_count else 0
        print(f"  {col_name}: {null_count:,} NULLs ({null_pct:.2f}%)")
    except Exception as e:
        print(f"  {col_name}: Error - {e}")

# Origins without a station id
if 'origin_station_id' in df_flights_with_station_clean.columns:
    try:
        dropped_airports = (
            df_flights_with_station_clean
            .filter(F.col('origin_station_id').isNull())
            .select('origin')
            .distinct()
        )
        dropped_count = dropped_airports.count()
        if dropped_count > 0:
            print(f"\n--- Airports Without Stations: {dropped_count} ---")
            print("Sample airports without stations:")
            dropped_airports.show(20, truncate=False)
    except Exception as e:
        print(f"\nError analyzing dropped airports: {e}")

print(_divider)


In [None]:
# ============================================================
# SANITY CHECK: After Station Join (Left)
# ============================================================
# Reports on df_flights_with_station: total rows, NULL count for each key
# column that exists, and a sample of origins that have no station id.

_rule = "=" * 80
print("\n" + _rule)
print("SANITY CHECK: After Station Join (Left)")
print(_rule)

# Total rows; row_count becomes None if counting raises.
try:
    row_count = df_flights_with_station.count()
    print(f"\n✓ Row count: {row_count:,}")
except Exception as e:
    print(f"\n✗ Error counting rows: {e}")
    row_count = None

# NULL analysis, restricted to the candidate columns actually present.
print("\n--- NULL Analysis ---")
key_columns = [
    name
    for name in (
        'origin',
        'dest',
        'origin_station_id',
        'origin_latitude',
        'origin_longitude',
    )
    if name in df_flights_with_station.columns
]

for col_name in key_columns:
    try:
        null_count = df_flights_with_station.where(F.col(col_name).isNull()).count()
        if row_count:
            null_pct = null_count / row_count * 100
        else:
            # No usable denominator (count failed or frame is empty).
            null_pct = 0
        print(f"  {col_name}: {null_count:,} NULLs ({null_pct:.2f}%)")
    except Exception as e:
        print(f"  {col_name}: Error - {e}")

# Origins whose flights carry no station id after the left join.
if 'origin_station_id' in df_flights_with_station.columns:
    try:
        dropped_airports = (
            df_flights_with_station
            .where(F.col('origin_station_id').isNull())
            .select('origin')
            .distinct()
        )
        dropped_count = dropped_airports.count()
        if dropped_count > 0:
            print(f"\n--- Airports Without Stations: {dropped_count} ---")
            print("Sample airports without stations:")
            dropped_airports.show(20, truncate=False)
    except Exception as e:
        print(f"\nError analyzing dropped airports: {e}")

print(_rule)


In [0]:


# Unpersist cached data: df_distinct_origins is not referenced again in the
# cells that follow, so its cached blocks can be released here.
df_distinct_origins.unpersist()

# Optional preview of the filtered flights, left disabled.
# display(df_flights_with_station_clean.limit(5))




Inspect the rows where `origin_station_id` is null (flights whose origin airport has no matched station)

In [None]:
# ============================================================
# SANITY CHECK: After Station Filter (DROPS ROWS)
# ============================================================
# Reports on df_flights_with_station_clean: total rows, NULL count for each
# key column that exists, and any origins still lacking a station id.

_rule = "=" * 80
print("\n" + _rule)
print("SANITY CHECK: After Station Filter (DROPS ROWS)")
print(_rule)

# Total rows; row_count becomes None if counting raises.
try:
    row_count = df_flights_with_station_clean.count()
    print(f"\n✓ Row count: {row_count:,}")
except Exception as e:
    print(f"\n✗ Error counting rows: {e}")
    row_count = None

# NULL analysis, restricted to the candidate columns actually present.
print("\n--- NULL Analysis ---")
key_columns = [
    name
    for name in (
        'origin',
        'dest',
        'origin_station_id',
        'origin_latitude',
        'origin_longitude',
    )
    if name in df_flights_with_station_clean.columns
]

for col_name in key_columns:
    try:
        null_count = df_flights_with_station_clean.where(F.col(col_name).isNull()).count()
        if row_count:
            null_pct = null_count / row_count * 100
        else:
            # No usable denominator (count failed or frame is empty).
            null_pct = 0
        print(f"  {col_name}: {null_count:,} NULLs ({null_pct:.2f}%)")
    except Exception as e:
        print(f"  {col_name}: Error - {e}")

# Origins with a NULL station id in the filtered frame.
# NOTE(review): this frame was produced by filtering on a non-NULL
# origin_station_id, so this section is expected to print nothing.
if 'origin_station_id' in df_flights_with_station_clean.columns:
    try:
        dropped_airports = (
            df_flights_with_station_clean
            .where(F.col('origin_station_id').isNull())
            .select('origin')
            .distinct()
        )
        dropped_count = dropped_airports.count()
        if dropped_count > 0:
            print(f"\n--- Airports Without Stations: {dropped_count} ---")
            print("Sample airports without stations:")
            dropped_airports.show(20, truncate=False)
    except Exception as e:
        print(f"\nError analyzing dropped airports: {e}")

print(_rule)


In [None]:
# ============================================================
# SANITY CHECK: After Station Join (Left)
# ============================================================
# Reports on df_flights_with_station: total rows, NULL count for each key
# column that exists, and a sample of origins that have no station id.

_rule = "=" * 80
print("\n" + _rule)
print("SANITY CHECK: After Station Join (Left)")
print(_rule)

# Total rows; row_count becomes None if counting raises.
try:
    row_count = df_flights_with_station.count()
    print(f"\n✓ Row count: {row_count:,}")
except Exception as e:
    print(f"\n✗ Error counting rows: {e}")
    row_count = None

# NULL analysis, restricted to the candidate columns actually present.
print("\n--- NULL Analysis ---")
_candidate_columns = (
    'origin',
    'dest',
    'origin_station_id',
    'origin_latitude',
    'origin_longitude',
)
key_columns = [name for name in _candidate_columns if name in df_flights_with_station.columns]

for col_name in key_columns:
    try:
        null_count = df_flights_with_station.where(F.col(col_name).isNull()).count()
        if row_count:
            null_pct = null_count / row_count * 100
        else:
            # No usable denominator (count failed or frame is empty).
            null_pct = 0
        print(f"  {col_name}: {null_count:,} NULLs ({null_pct:.2f}%)")
    except Exception as e:
        print(f"  {col_name}: Error - {e}")

# Origins whose flights carry no station id after the left join.
if 'origin_station_id' in df_flights_with_station.columns:
    try:
        dropped_airports = (
            df_flights_with_station
            .where(F.col('origin_station_id').isNull())
            .select('origin')
            .distinct()
        )
        dropped_count = dropped_airports.count()
        if dropped_count > 0:
            print(f"\n--- Airports Without Stations: {dropped_count} ---")
            print("Sample airports without stations:")
            dropped_airports.show(20, truncate=False)
    except Exception as e:
        print(f"\nError analyzing dropped airports: {e}")

print(_rule)


Most of the flights with a null `origin_station_id` originate from these airports:

- PSE — Mercedita Airport, Ponce, Puerto Rico 🇵🇷
- GUM — Antonio B. Won Pat International Airport, Guam 🇬🇺
- PPG — Pago Pago International Airport (Tafuna Airport), American Samoa 🇦🇸
- ISN — Sloulin Field International Airport (now replaced by Williston Basin International Airport), Williston, North Dakota, USA 🇺🇸

In [0]:
# Flights whose origin airport has a station id. This applies the same
# non-NULL origin_station_id predicate that produced
# df_flights_with_station_clean earlier in the notebook.
# NOTE(review): the Step 3 weather join below reads df_flights_with_station
# (unfiltered), not this frame - confirm which input is intended.
df_flights_with_station_notnull = df_flights_with_station.filter(F.col("origin_station_id").isNotNull())
# display(df_flights_with_station_notnull.limit(5))


##### Checkpoint 2: Before Final Join



In [0]:

# Load intermediate result later
#df_flights_with_station_notnull = spark.read.parquet(flights_with_station_path)


#### Step 3: Join Flights+Airports+Stations with Weather 

##### Run the full join

In [0]:
# Register the directory Spark writes to when a DataFrame/RDD is checkpointed.
# checkpoints_DIR is defined elsewhere in the notebook.
spark.sparkContext.setCheckpointDir(
    checkpoints_DIR
)

In [0]:
# COMMAND ----------
# MAGIC %md
# MAGIC ## Step 3: Weather Join with Station + Bucket Co-Partitioning
# MAGIC 
# MAGIC *Goal:*
# MAGIC 1. Partition both flights and weather by (station, bucket).
# MAGIC 2. Ensure each partition sees weather from the current and previous bucket.
# MAGIC 3. Within each partition, pick the closest weather record **before** the flight's two-hours-prior time.
# MAGIC 4. Retain all flights, even if no weather is found.

# COMMAND ----------
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql.functions import col

print("=" * 60)
print("JOIN 3: Flights + Weather (Station + Bucket Co-Partitioned)")
print("=" * 60)

# Session settings for this step: Photon off, adaptive execution and
# adaptive skew-join handling on.
spark.conf.set("spark.databricks.photon.enabled", "false")
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

# -------------------------------------------------------------------
# Time bucketing configuration
# -------------------------------------------------------------------
BUCKET_INTERVAL_MINUTES = 30
BUCKET_INTERVAL_SECONDS = 60 * BUCKET_INTERVAL_MINUTES


def _bucket_start(ts_column):
    """Return ts_column's epoch seconds truncated to a multiple of the bucket length."""
    whole_buckets = (
        ts_column.cast("long") / F.lit(BUCKET_INTERVAL_SECONDS)
    ).cast("long")
    return whole_buckets * F.lit(BUCKET_INTERVAL_SECONDS)


# -------------------------------------------------------------------
# 1. Normalize types and add time buckets
# -------------------------------------------------------------------

# Both join keys are cast to string so the station equality compares like types.
df_weather = df_weather.withColumn("station", col("station").cast("string"))
df_flights_with_station = df_flights_with_station.withColumn(
    "origin_station_id", col("origin_station_id").cast("string")
)

# Weather: timestamp column plus the bucket that timestamp falls in.
df_weather_bucketed = df_weather.withColumn(
    "weather_ts", col("date_timestamp").cast("timestamp")
)
df_weather_bucketed = df_weather_bucketed.withColumn(
    "bucket", _bucket_start(col("weather_ts"))
)

# A second copy of the weather rows, relabelled one bucket later. After the
# union, the rows carrying bucket B are:
#   - weather observed in B
#   - weather observed in B - interval (relabelled to B)
# so an equality join on bucket also reaches the previous bucket's weather.
df_weather_shifted_for_prev = df_weather_bucketed.withColumn(
    "bucket",
    col("bucket") + F.lit(BUCKET_INTERVAL_SECONDS)
)

# Original + shifted rows, hash-partitioned on the join keys.
df_weather_for_join = (
    df_weather_bucketed
    .unionByName(df_weather_shifted_for_prev)
    .repartition("station", "bucket")
)


# Flights: bucket derived from "two_hours_prior_depart_utc", then
# hash-partitioned on the matching keys (origin_station_id, bucket).
df_flights_with_buckets = df_flights_with_station.withColumn(
    "flight_ts", col("two_hours_prior_depart_utc").cast("timestamp")
)
df_flights_with_buckets = (
    df_flights_with_buckets
    .withColumn("bucket", _bucket_start(col("flight_ts")))
    .repartition("origin_station_id", "bucket")
)

# -------------------------------------------------------------------
# 2. Join on station + bucket (current + previous weather bucket present)
# -------------------------------------------------------------------
print("\nExecuting station + bucket join...")

# remove duplicate column
df_weather_for_join = df_weather_for_join.withColumnRenamed("year", "weather_year")

# LEFT JOIN flights to weather on (station, bucket), accepting only weather
# observed at or before the flight's reference time.
#
# The time predicate is part of the JOIN condition on purpose. When it was
# applied as a filter after the join, a flight whose matching weather rows
# were all LATER than flight_ts had every one of its rows removed and the
# flight disappeared from the output. With the predicate inside a left join,
# such a flight has no qualifying match and is kept once with NULL weather
# columns.
df_joined = (
    df_flights_with_buckets.alias("f")
    .join(
        df_weather_for_join.alias("w"),
        (
            (col("f.origin_station_id") == col("w.station")) &
            (col("f.bucket") == col("w.bucket")) &
            (col("w.weather_ts") <= col("f.flight_ts"))
        ),
        "left"  # retain all flights
    )
    .select(
        col("f.*"),
        col("w.*")
    )
)

# -------------------------------------------------------------------
# 3. On each partition: keep weather only if it is BEFORE the flight,
#    and retain the closest one.
# -------------------------------------------------------------------
print("\nFiltering weather <= flight time and ranking...")

# Every row of df_joined already satisfies
#   weather_ts IS NULL (no match)  OR  weather_ts <= flight_ts
# because of the join condition above, so no further row filter is needed.
# The name is kept because the ranking code that follows reads it.
df_joined_filtered = df_joined

# Seconds between the flight reference time and the weather observation
# (flight_ts - weather_ts). Rows without weather get NULL: the WHEN has no
# ELSE branch, and the value is only computed when weather_ts is present.
df_with_diff = df_joined_filtered.withColumn(
    "time_diff_sec",
    F.when(
        col("w.weather_ts").isNotNull(),
        col("f.flight_ts").cast("long") - col("w.weather_ts").cast("long")
    )
)

# One window partition per flight, smallest time difference first and
# NULL differences last.
# NOTE(review): these four columns are assumed to identify a flight uniquely;
# if two distinct flights share them, only one of the two keeps a weather row
# below - confirm against the flight table's real key.
_flight_key_columns = [
    "f.origin_station_id",
    "f.fl_date",
    "f.crs_dep_time",
    "f.op_carrier_fl_num",
]
window_closest = Window.partitionBy(*_flight_key_columns).orderBy(
    col("time_diff_sec").isNull().asc(),  # non-null weather first
    col("time_diff_sec").asc_nulls_last()
)

df_ranked = df_with_diff.withColumn("has_weather", col("w.weather_ts").isNotNull())
df_ranked = df_ranked.withColumn("weather_rank", F.row_number().over(window_closest))

# Rows to discard are exactly those that have weather but are not the closest
# candidate; everything else (no weather, or rank 1) is kept. The helper
# columns used for the selection are dropped afterwards.
_helper_columns = [
    "weather_rank",
    "time_diff_sec",
    "bucket",            # present on both sides of the join
    "station",           # weather-side station id
    "has_weather",
    "weather_ts",        # weather-side timestamp derived from date_timestamp
    "flight_ts",         # flight-side timestamp derived from two_hours_prior_depart_utc
]
df_final = (
    df_ranked
    .filter(~(col("has_weather") & (col("weather_rank") != 1)))
    .drop(*_helper_columns)
)

print("✓ Closest-weather selection complete (all flights retained).")

# df_final is not counted or written here; that is left to a later step.
# Example:
# df_final.write.format("delta").mode("overwrite").save("/path/to/output")

# Turn Photon back on for the cells that follow.
# NOTE(review): nothing above triggers a Spark action, so the join is only
# planned here, not executed - confirm the Photon-off setting is still in
# effect at the point where df_final is actually materialised.
spark.conf.set("spark.databricks.photon.enabled", "true")

print("\n" + "=" * 60)
print("✓ JOIN 3 COMPLETE (Station + Bucket Co-Partitioned)")
print("=" * 60)


In [None]:
# ============================================================
# SANITY CHECK: After Weather Join
# ============================================================
# Reports on df_final, the output of the weather join step with one selected
# weather row per flight. The intermediate df_joined is deliberately NOT used:
# it holds one row per (flight, candidate weather record), so its row count
# and NULL percentages do not describe flights and cannot be compared with
# the earlier checkpoints.

print("\n" + "="*80)
print(f"SANITY CHECK: After Weather Join")
print("="*80)

# Row count (row_count stays None when the count itself fails)
try:
    row_count = df_final.count()
    print(f"\n✓ Row count: {row_count:,}")
except Exception as e:
    print(f"\n✗ Error counting rows: {e}")
    row_count = None

# Check for NULLs in key columns (only those present in the frame)
print("\n--- NULL Analysis ---")
key_columns = [
    c
    for c in ('origin', 'dest', 'origin_station_id', 'origin_latitude', 'origin_longitude')
    if c in df_final.columns
]

for col_name in key_columns:
    try:
        null_count = df_final.filter(F.col(col_name).isNull()).count()
        null_pct = (null_count / row_count * 100) if row_count else 0
        print(f"  {col_name}: {null_count:,} NULLs ({null_pct:.2f}%)")
    except Exception as e:
        print(f"  {col_name}: Error - {e}")

# Identify origins whose flights have no station id (if applicable)
if 'origin_station_id' in df_final.columns:
    try:
        dropped_airports = df_final.filter(
            F.col('origin_station_id').isNull()
        ).select('origin').distinct()
        dropped_count = dropped_airports.count()
        if dropped_count > 0:
            print(f"\n--- Airports Without Stations: {dropped_count} ---")
            print("Sample airports without stations:")
            dropped_airports.show(20, truncate=False)
    except Exception as e:
        print(f"\nError analyzing dropped airports: {e}")

print("="*80)


## Flight Lineage Join - Final Step

This section adds flight lineage features by joining each flight to its previous flight in the lineage (same aircraft, previous flight). 

**Key Features Added:**
- Previous flight information (origin, dest, times, delays, etc.)
- Turnover time (minutes from the previous flight's actual arrival to this flight's actual departure)
- Cumulative delays
- Sequence information (lineage rank)
- Jump detection

**CRITICAL: NO ROWS ARE DROPPED** - All flights are preserved. Flights without previous flight data get NULL values which are handled via imputation.


In [None]:
# ============================================================================
# FLIGHT LINEAGE JOIN - FINAL STEP
# ============================================================================
# Steps 1-2 of the lineage feature engineering: find the tail-number column
# in df_final, then derive the arrival timestamp used to order each
# aircraft's flights.

print("=" * 60)
print("FLIGHT LINEAGE JOIN")
print("=" * 60)

# Step 1: Identify tail_num column (handle variations)
# Exact candidate names are tried first, in order; failing that, the first
# column whose name contains "tail" (case-insensitive) is used.
tail_num_candidates = ['tail_num', 'TAIL_NUM', 'tail_number', 'TAIL_NUMBER', 'op_unique_carrier_tail_num']
tail_num_col = next((c for c in tail_num_candidates if c in df_final.columns), None)

if tail_num_col is not None:
    print(f"✓ Found tail_num column: {tail_num_col}")
else:
    # Try pattern matching
    tail_cols = [c for c in df_final.columns if 'tail' in c.lower()]
    if not tail_cols:
        raise ValueError(f"Could not find tail_num column. Available columns: {df_final.columns[:20]}...")
    tail_num_col = tail_cols[0]
    print(f"✓ Found tail_num column via pattern matching: {tail_num_col}")

# Step 2: Prepare arrival timestamp for ranking
# Use scheduled arrival time as fallback if actual unavailable (preserves all rows)
print("\nStep 2: Creating arrival timestamp for ranking...")
df_final = df_final.withColumn(
    'arrival_time_for_ranking',
    F.coalesce(col('arr_time'), col('crs_arr_time'))
)

# Convert to timestamp for proper temporal ordering.
#
# The HHMM value is normalised to an integer before formatting, for two reasons:
#   * a float-typed time such as 930.0 would otherwise be rendered "930.0" and
#     padded/truncated to "930.", which the HHmm pattern cannot parse;
#   * 2400 (midnight written as end of day) has an hour outside the 0-23 range
#     of the HH pattern, so it is mapped to 00:00 on the following date
#     instead of producing a NULL timestamp.
# Integer HHMM values from 0000 to 2359 give the same timestamp as before.
# Assumes fl_date is a DATE or an ISO 'yyyy-MM-dd' string, which the
# 'yyyy-MM-ddHHmm' pattern already required.
# NOTE(review): the arrival clock time is combined with fl_date even when the
# flight lands after midnight, so an overnight arrival is dated one day early
# and may sort ahead of earlier flights - confirm whether a rollover rule or a
# UTC timestamp should be used for ordering.
_arr_hhmm = col('arrival_time_for_ranking').cast('int')
_arr_is_2400 = _arr_hhmm == 2400
_arr_date = F.to_date(col('fl_date'))

df_final = df_final.withColumn(
    'arrival_timestamp',
    F.when(
        col('arrival_time_for_ranking').isNotNull() & col('fl_date').isNotNull(),
        F.to_timestamp(
            F.concat(
                F.when(_arr_is_2400, F.date_add(_arr_date, 1))
                .otherwise(_arr_date)
                .cast('string'),
                F.when(_arr_is_2400, F.lit('0000'))
                .otherwise(F.lpad(_arr_hhmm.cast('string'), 4, '0'))
            ),
            'yyyy-MM-ddHHmm'
        )
    ).otherwise(None)
)

print("✓ Arrival timestamp created")


In [None]:
# Step 3: Create window specification and rank flights
print("\nStep 3: Creating window specification and ranking flights...")

# One window partition per tail number, ordered by arrival timestamp
# ascending with NULL timestamps last. Ascending order matters: the later
# LAG(…, 1) calls over this window then read the row that precedes the
# current flight.
# NOTE(review): rows with a NULL tail number would all share one partition -
# confirm how those flights should be treated.
_by_arrival = F.col('arrival_timestamp').asc_nulls_last()
window_spec = Window.partitionBy(tail_num_col).orderBy(_by_arrival)

# lineage_rank: 1 for the earliest flight in the partition, increasing with
# arrival time.
_position_in_lineage = F.row_number().over(window_spec)
df_final = df_final.withColumn('lineage_rank', _position_in_lineage)

print("✓ Flights ranked")


In [None]:
# Step 4: Get Previous Flight Data Using LAG
print("\nStep 4: Getting previous flight data using LAG...")

# (new column, source column) pairs. Each new column holds the source
# column's value from the preceding row of window_spec. The list order is the
# order in which the columns are appended to df_final.
_prev_flight_columns = [
    # Core previous flight information
    ('prev_flight_origin', 'origin'),
    ('prev_flight_dest', 'dest'),
    ('prev_flight_actual_dep_time', 'dep_time'),
    ('prev_flight_actual_arr_time', 'arr_time'),
    ('prev_flight_dep_delay', 'dep_delay'),
    ('prev_flight_arr_delay', 'arr_delay'),
    ('prev_flight_air_time', 'air_time'),
    # Scheduled times (fallback when actual values are unavailable)
    ('prev_flight_crs_dep_time', 'crs_dep_time'),
    ('prev_flight_crs_arr_time', 'crs_arr_time'),
    ('prev_flight_crs_elapsed_time', 'crs_elapsed_time'),
    # Time components (turn time and taxi time calculations)
    ('prev_flight_taxi_in', 'taxi_in'),
    ('prev_flight_taxi_out', 'taxi_out'),
    ('prev_flight_wheels_off', 'wheels_off'),
    ('prev_flight_wheels_on', 'wheels_on'),
    ('prev_flight_actual_elapsed_time', 'actual_elapsed_time'),
    # Route and flight information
    ('prev_flight_distance', 'distance'),
    ('prev_flight_op_carrier', 'op_carrier'),
    ('prev_flight_op_carrier_fl_num', 'op_carrier_fl_num'),
    ('prev_flight_fl_date', 'fl_date'),
    # Status flags
    ('prev_flight_cancelled', 'cancelled'),
    ('prev_flight_diverted', 'diverted'),
]

for _prev_name, _source_name in _prev_flight_columns:
    df_final = df_final.withColumn(_prev_name, F.lag(_source_name, 1).over(window_spec))

print("✓ Previous flight data retrieved")


In [None]:
# ============================================================
# SANITY CHECK: After Weather Join
# ============================================================
# Reports on df_final, which carries one selected weather row per flight.
# The intermediate df_joined is deliberately NOT used: it holds one row per
# (flight, candidate weather record), so its row count and NULL percentages
# do not describe flights and cannot be compared with the earlier checkpoints.

print("\n" + "="*80)
print(f"SANITY CHECK: After Weather Join")
print("="*80)

# Row count (row_count stays None when the count itself fails)
try:
    row_count = df_final.count()
    print(f"\n✓ Row count: {row_count:,}")
except Exception as e:
    print(f"\n✗ Error counting rows: {e}")
    row_count = None

# Check for NULLs in key columns (only those present in the frame)
print("\n--- NULL Analysis ---")
key_columns = [
    c
    for c in ('origin', 'dest', 'origin_station_id', 'origin_latitude', 'origin_longitude')
    if c in df_final.columns
]

for col_name in key_columns:
    try:
        null_count = df_final.filter(F.col(col_name).isNull()).count()
        null_pct = (null_count / row_count * 100) if row_count else 0
        print(f"  {col_name}: {null_count:,} NULLs ({null_pct:.2f}%)")
    except Exception as e:
        print(f"  {col_name}: Error - {e}")

# Identify origins whose flights have no station id (if applicable)
if 'origin_station_id' in df_final.columns:
    try:
        dropped_airports = df_final.filter(
            F.col('origin_station_id').isNull()
        ).select('origin').distinct()
        dropped_count = dropped_airports.count()
        if dropped_count > 0:
            print(f"\n--- Airports Without Stations: {dropped_count} ---")
            print("Sample airports without stations:")
            dropped_airports.show(20, truncate=False)
    except Exception as e:
        print(f"\nError analyzing dropped airports: {e}")

print("="*80)


In [None]:
# Step 6: Compute Actual Turnover Time (with data leakage check)
print("\nStep 6: Computing actual turnover time (with data leakage check)...")


def _hhmm_to_minutes(column_name):
    """Return a Column turning an HHMM clock value into minutes since midnight.

    NULL input stays NULL. The result is floor(value / 100) * 60 + value % 100.
    """
    hhmm = col(column_name)
    return F.when(
        hhmm.isNotNull(),
        (F.floor(hhmm / 100) * 60 + (hhmm % 100))
    ).otherwise(None)


# Minute-of-day versions of the three clock columns used below. Each is
# computed once (crs_dep_time_minutes was previously derived twice with the
# identical expression).
for _minutes_name, _hhmm_name in (
    ('crs_dep_time_minutes', 'crs_dep_time'),
    ('prev_flight_actual_arr_time_minutes', 'prev_flight_actual_arr_time'),
    ('actual_dep_time_minutes', 'dep_time'),
):
    df_final = df_final.withColumn(_minutes_name, _hhmm_to_minutes(_hhmm_name))

_prev_arr_min = col('prev_flight_actual_arr_time_minutes')
_crs_dep_min = col('crs_dep_time_minutes')
_act_dep_min = col('actual_dep_time_minutes')

# Check data leakage: prev_arr_time must be <= crs_dep_time - 2 hours (120 minutes).
# False whenever either side is NULL.
# NOTE(review): both operands are minutes since midnight with no date part,
# so a previous arrival on an earlier calendar day, or a scheduled departure
# before 02:00 (negative cutoff), is judged on clock time alone - confirm
# this is acceptable or switch to full timestamps.
df_final = df_final.withColumn(
    'prev_arr_time_safe_to_use',
    F.when(
        _prev_arr_min.isNotNull() & _crs_dep_min.isNotNull(),
        _prev_arr_min <= (_crs_dep_min - 120)
    ).otherwise(False)
)

# Actual turnover time (only computed when the previous arrival is safe to use):
# actual departure minus previous actual arrival, adding 1440 minutes when the
# departure clock time is earlier than the arrival clock time.
df_final = df_final.withColumn(
    'lineage_actual_turnover_time_minutes',
    F.when(
        (col('prev_arr_time_safe_to_use') == True) &
        _act_dep_min.isNotNull() &
        _prev_arr_min.isNotNull(),
        F.when(
            _act_dep_min >= _prev_arr_min,
            _act_dep_min - _prev_arr_min
        ).otherwise(
            _act_dep_min + 1440 - _prev_arr_min
        )
    ).otherwise(None)
)

# Create aliases: both columns are copies of the turnover time.
df_final = df_final.withColumn('lineage_actual_taxi_time_minutes', col('lineage_actual_turnover_time_minutes'))
df_final = df_final.withColumn('lineage_actual_turn_time_minutes', col('lineage_actual_turnover_time_minutes'))

print("✓ Actual turnover time computed with data leakage check")


In [None]:
# Step 7: Compute Expected Flight Time and Cumulative Features
print("\nStep 7: Computing expected flight time and cumulative features...")

# Scheduled arrival as minutes since midnight (NULL stays NULL).
_crs_arr_hhmm = col('crs_arr_time')
df_final = df_final.withColumn(
    'crs_arr_time_minutes',
    F.when(
        _crs_arr_hhmm.isNotNull(),
        (F.floor(_crs_arr_hhmm / 100) * 60 + (_crs_arr_hhmm % 100))
    ).otherwise(None)
)

# Expected flight time = scheduled arrival - scheduled departure, adding a
# full day (1440 minutes) when the arrival clock time is earlier than the
# departure clock time. NULL when either input is NULL.
_sched_arr = col('crs_arr_time_minutes')
_sched_dep = col('crs_dep_time_minutes')
_expected_minutes = F.when(
    _sched_arr >= _sched_dep,
    _sched_arr - _sched_dep
).otherwise(
    _sched_arr + 1440 - _sched_dep  # Day rollover
)
df_final = df_final.withColumn(
    'lineage_expected_flight_time_minutes',
    F.when(
        (_sched_arr.isNotNull()) & (_sched_dep.isNotNull()),
        _expected_minutes
    ).otherwise(None)
)

# Window over all earlier rows of the same tail number (current row
# excluded), in arrival-timestamp order with NULL timestamps last.
window_spec_cumulative = (
    Window.partitionBy(tail_num_col)
    .orderBy(F.col('arrival_timestamp').asc_nulls_last())
    .rowsBetween(Window.unboundedPreceding, -1)
)

# Aggregates over the previous flights, appended in this order:
# total dep_delay, number of rows, mean dep_delay, max dep_delay.
for _feature_name, _aggregate in (
    ('lineage_cumulative_delay', F.sum('dep_delay')),
    ('lineage_num_previous_flights', F.count('*')),
    ('lineage_avg_delay_previous_flights', F.avg('dep_delay')),
    ('lineage_max_delay_previous_flights', F.max('dep_delay')),
):
    df_final = df_final.withColumn(_feature_name, _aggregate.over(window_spec_cumulative))

print("✓ Cumulative features computed")


In [None]:
# Step 8: Jump Detection
print("\nStep 8: Detecting jumps (aircraft repositioning)...")

# A jump is a flight whose origin differs from the previous flight's
# destination. The first flight of a lineage (rank 1) is never a jump, and a
# comparison that evaluates to NULL (e.g. no previous destination) falls
# through to False.
_is_first_in_lineage = col('lineage_rank') == 1
_route_mismatch = col('prev_flight_dest') != col('origin')

_is_jump = (
    F.when(_is_first_in_lineage, F.lit(False))
    .when(_route_mismatch, F.lit(True))
    .otherwise(F.lit(False))
)
df_final = df_final.withColumn('lineage_is_jump', _is_jump)

print("✓ Jump detection complete")


In [None]:
# Step 9: Check Data Leakage for All Risky Columns and Create columns_with_data_leakage Array
print("\nStep 9: Checking data leakage for all risky columns and creating columns_with_data_leakage array...")

# Create prediction_cutoff column: scheduled departure time - 2 hours (in minutes since midnight)
# This is the cutoff time for data leakage - any actual event after this time has data leakage
# NULL when crs_dep_time_minutes is NULL.
# NOTE(review): the value is negative for scheduled departures before 02:00
# (crs_dep_time_minutes < 120). The minute-based checks later in this cell
# compare clock minutes against it with no date component, so every non-NULL
# previous-flight time then exceeds the cutoff - confirm whether the
# timestamp cutoff below should be used for those comparisons instead.
df_final = df_final.withColumn(
    'prediction_cutoff_minutes',
    F.when(
        col('crs_dep_time_minutes').isNotNull(),
        col('crs_dep_time_minutes') - 120
    ).otherwise(None)
)

# Also create as timestamp for easier comparison:
# sched_depart_date_time minus two hours, NULL when that column is NULL.
df_final = df_final.withColumn(
    'prediction_cutoff_timestamp',
    F.when(
        col('sched_depart_date_time').isNotNull(),
        F.expr("sched_depart_date_time - INTERVAL 2 HOURS")
    ).otherwise(None)
)

# Data leakage = timestamp is AFTER prediction_cutoff, i.e., timestamp > prediction_cutoff_minutes

# Convert all actual time columns to minutes for checking
df_final = df_final.withColumn(
    'prev_flight_actual_dep_time_minutes',
    F.when(
        col('prev_flight_actual_dep_time').isNotNull(),
        (F.floor(col('prev_flight_actual_dep_time') / 100) * 60 + (col('prev_flight_actual_dep_time') % 100))
    ).otherwise(None)
)
# prev_flight_actual_arr_time_minutes already exists from Step 6

df_final = df_final.withColumn(
    'prev_flight_wheels_off_minutes',
    F.when(
        col('prev_flight_wheels_off').isNotNull(),
        (F.floor(col('prev_flight_wheels_off') / 100) * 60 + (col('prev_flight_wheels_off') % 100))
    ).otherwise(None)
)

df_final = df_final.withColumn(
    'prev_flight_wheels_on_minutes',
    F.when(
        col('prev_flight_wheels_on').isNotNull(),
        (F.floor(col('prev_flight_wheels_on') / 100) * 60 + (col('prev_flight_wheels_on') % 100))
    ).otherwise(None)
)

# Check each risky timestamp column for data leakage
# Data leakage = timestamp is AFTER prediction_cutoff, i.e., timestamp > prediction_cutoff_minutes
df_final = df_final.withColumn(
    'prev_flight_actual_dep_time_has_leakage',
    F.when(
        (col('prev_flight_actual_dep_time_minutes').isNotNull()) &
        (col('prediction_cutoff_minutes').isNotNull()),
        col('prev_flight_actual_dep_time_minutes') > col('prediction_cutoff_minutes')
    ).otherwise(False)
)

# Time-of-day fields: a value is flagged as leaking when both it and the
# prediction cutoff are known and the value falls after the cutoff.
# Unknown values are never flagged (otherwise(False)).
_cutoff_minutes = col('prediction_cutoff_minutes')
for _time_field in ('actual_arr_time', 'wheels_off', 'wheels_on'):
    _field_minutes = col(f'prev_flight_{_time_field}_minutes')
    df_final = df_final.withColumn(
        f'prev_flight_{_time_field}_has_leakage',
        F.when(
            _field_minutes.isNotNull() & _cutoff_minutes.isNotNull(),
            _field_minutes > _cutoff_minutes
        ).otherwise(False)
    )

# Shorthand for the source flags that every derived flag below is built from.
_dep_leak = col('prev_flight_actual_dep_time_has_leakage')
_arr_leak = col('prev_flight_actual_arr_time_has_leakage')
_wheels_off_leak = col('prev_flight_wheels_off_has_leakage')
_wheels_on_leak = col('prev_flight_wheels_on_has_leakage')
_dep_or_arr_leak = _dep_leak | _arr_leak
_turnover_leak = col('lineage_actual_turnover_time_minutes_has_leakage')
_dep_delay_leak = col('prev_flight_dep_delay_has_leakage')

# Derived flags, applied strictly in this order because later entries read
# flag columns created by earlier ones.
_derived_leakage_flags = [
    # Durations inherit the flag of the timestamp they are derived from.
    ('prev_flight_dep_delay_has_leakage', _dep_leak),        # from actual_dep_time
    ('prev_flight_arr_delay_has_leakage', _arr_leak),        # from actual_arr_time
    ('prev_flight_taxi_in_has_leakage', _wheels_on_leak),    # from wheels_on (arrival)
    ('prev_flight_taxi_out_has_leakage', _wheels_off_leak),  # from wheels_off (departure)
    # air_time / actual_elapsed_time leak if either actual time leaks.
    ('prev_flight_air_time_has_leakage', _dep_or_arr_leak),
    ('prev_flight_actual_elapsed_time_has_leakage', _dep_or_arr_leak),
    # Status flags: only flagged when the status itself is known and an
    # actual time leaks.
    (
        'prev_flight_cancelled_has_leakage',
        F.when(col('prev_flight_cancelled').isNotNull(), _dep_or_arr_leak).otherwise(False),
    ),
    (
        'prev_flight_diverted_has_leakage',
        F.when(col('prev_flight_diverted').isNotNull(), _dep_or_arr_leak).otherwise(False),
    ),
    # Engineered features built from actual arrival/departure times.
    ('lineage_actual_turnover_time_minutes_has_leakage', _arr_leak | _dep_leak),
    ('lineage_actual_taxi_time_minutes_has_leakage', _turnover_leak),
    ('lineage_actual_turn_time_minutes_has_leakage', _turnover_leak),
    # Cumulative features follow the departure-delay flag.
    ('lineage_cumulative_delay_has_leakage', _dep_delay_leak),
    ('lineage_avg_delay_previous_flights_has_leakage', _dep_delay_leak),
    ('lineage_max_delay_previous_flights_has_leakage', _dep_delay_leak),
]
for _flag_name, _flag_expr in _derived_leakage_flags:
    df_final = df_final.withColumn(_flag_name, _flag_expr)

# Create an array column listing all columns that have data leakage for this row.
# This is per-row, so each row knows which specific columns have leakage.
#
# Every tracked column <name> has a boolean flag column <name>_has_leakage.
# F.when(...) without otherwise() yields NULL for unflagged columns; those
# NULLs are then stripped so a row with no leakage gets an empty array [].
#
# NOTE: array_remove(arr, NULL) must not be used for the stripping step: the
# function is null-intolerant, so a NULL element argument makes the whole
# result NULL for every row. F.filter (Spark 3.1+) removes NULL entries
# correctly.
leakage_tracked_columns = [
    'prev_flight_actual_dep_time',
    'prev_flight_actual_arr_time',
    'prev_flight_wheels_off',
    'prev_flight_wheels_on',
    'prev_flight_dep_delay',
    'prev_flight_arr_delay',
    'prev_flight_air_time',
    'prev_flight_taxi_in',
    'prev_flight_taxi_out',
    'prev_flight_actual_elapsed_time',
    'prev_flight_cancelled',
    'prev_flight_diverted',
    'lineage_actual_turnover_time_minutes',
    'lineage_actual_taxi_time_minutes',
    'lineage_actual_turn_time_minutes',
    'lineage_cumulative_delay',
    'lineage_avg_delay_previous_flights',
    'lineage_max_delay_previous_flights',
]

df_final = df_final.withColumn(
    'columns_with_data_leakage',
    F.filter(
        F.array(*[
            F.when(col(f'{tracked_name}_has_leakage'), F.lit(tracked_name))
            for tracked_name in leakage_tracked_columns
        ]),
        lambda flagged_name: flagged_name.isNotNull()
    )
)

# Backward-compatible safety flags: each is simply the negation of the
# corresponding *_has_leakage flag.
for _safe_flag, _leak_flag in (
    ('prev_arr_time_safe_to_use', 'prev_flight_actual_arr_time_has_leakage'),
    ('prev_dep_time_safe_to_use', 'prev_flight_actual_dep_time_has_leakage'),
):
    df_final = df_final.withColumn(_safe_flag, ~col(_leak_flag))

print(
    "✓ Data leakage checks complete",
    "   Created prediction_cutoff_minutes and prediction_cutoff_timestamp columns",
    "   Created columns_with_data_leakage array (per-row list of columns with leakage)",
    "   Each row knows which specific columns have data leakage",
    sep="\n",
)


In [None]:
# Step 10: Apply Imputation for NULL Values (First Flight Handling)
print("\nStep 10: Applying imputation for NULL values (first flight handling)...\n")

# Strategy: when the previous flight's scheduled departure is missing, pretend
# the aircraft was already at the airport 4 hours (240 minutes) before the
# current scheduled departure.
# NOTE(review): coalesce fills every NULL it meets; nothing here restricts the
# imputation to lineage_rank == 1 rows.


def _hhmm_to_minutes(hhmm):
    """Convert an HHMM-encoded column expression to minutes since midnight."""
    return F.floor(hhmm / 100) * 60 + (hhmm % 100)


def _minutes_to_hhmm(minutes):
    """Convert a minutes-since-midnight column expression to HHMM encoding."""
    return (F.floor(minutes / 60) * 100) + (minutes % 60)


_prev_dep_hhmm = col('prev_flight_crs_dep_time')
_curr_dep_minutes = col('crs_dep_time_minutes')
_curr_dep_hhmm = col('crs_dep_time')

# STEP 1a: temporary minutes column. Preference order:
#   1. the real previous scheduled departure, converted to minutes
#   2. current scheduled departure (already in minutes) minus 240
#   3. current scheduled departure (HHMM) converted to minutes, minus 240
df_final = df_final.withColumn(
    'prev_flight_crs_dep_time_minutes',
    F.coalesce(
        F.when(_prev_dep_hhmm.isNotNull(), _hhmm_to_minutes(_prev_dep_hhmm)),
        F.when(_curr_dep_minutes.isNotNull(), _curr_dep_minutes - 240)
        .when(_curr_dep_hhmm.isNotNull(), _hhmm_to_minutes(_curr_dep_hhmm) - 240)
    )
)

# STEP 1b: fill the HHMM column from the minutes column. A negative minute
# value means the imputed departure fell on the previous day, so it is
# shifted forward by 1440 minutes (24 hours) before conversion.
_prev_dep_minutes = col('prev_flight_crs_dep_time_minutes')
df_final = df_final.withColumn(
    'prev_flight_crs_dep_time',
    F.coalesce(
        _prev_dep_hhmm,
        F.when(
            _prev_dep_minutes.isNotNull(),
            F.when(_prev_dep_minutes >= 0, _minutes_to_hhmm(_prev_dep_minutes))
            .otherwise(_minutes_to_hhmm(_prev_dep_minutes + 1440))
        )
    )
)

# STEP 1c: remaining scheduled fields.
#   - scheduled arrival falls back to the (possibly imputed) scheduled departure
#   - scheduled elapsed time falls back to the current flight's scheduled elapsed time
for _target, _fallback in (
    ('prev_flight_crs_arr_time', 'prev_flight_crs_dep_time'),
    ('prev_flight_crs_elapsed_time', 'crs_elapsed_time'),
):
    df_final = df_final.withColumn(_target, F.coalesce(col(_target), col(_fallback)))

# Ordered imputation table: (column, fallback expression). Each column keeps
# its own value where present and takes the fallback where NULL. Order matters:
# several fallbacks read columns that were imputed by an earlier entry.
_imputation_plan = [
    # STEP 2: delays (anti-delay: -10 minutes)
    ('prev_flight_dep_delay', F.lit(-10.0)),
    ('prev_flight_arr_delay', F.lit(-10.0)),
    # STEP 3: actual times fall back to scheduled times (predetermined, safe)
    ('prev_flight_actual_dep_time', col('prev_flight_crs_dep_time')),
    ('prev_flight_actual_arr_time', col('prev_flight_crs_arr_time')),
    # STEP 4: route information falls back to the current origin
    ('prev_flight_dest', col('origin')),
    ('prev_flight_origin', col('origin')),
    # STEP 5: time components
    ('prev_flight_air_time', col('prev_flight_crs_elapsed_time')),
    ('prev_flight_taxi_in', F.lit(10.0)),
    ('prev_flight_taxi_out', F.lit(15.0)),
    ('prev_flight_actual_elapsed_time', col('prev_flight_crs_elapsed_time')),
    # STEP 6: wheels off/on fall back to scheduled times
    ('prev_flight_wheels_off', col('prev_flight_crs_dep_time')),
    ('prev_flight_wheels_on', col('prev_flight_crs_arr_time')),
    # STEP 7: route distance
    ('prev_flight_distance', F.lit(0.0)),
    # STEP 8: status flags
    ('prev_flight_cancelled', F.lit(0)),
    ('prev_flight_diverted', F.lit(0)),
    # STEP 9: flight metadata falls back to the current flight's values
    ('prev_flight_fl_date', col('fl_date')),
    ('prev_flight_op_carrier', col('op_carrier')),
    ('prev_flight_op_carrier_fl_num', col('op_carrier_fl_num')),
    # STEP 10: engineered lineage features
    # Turnover time: 240 minutes (4 hours) - represents overnight/maintenance gap
    ('lineage_turnover_time_minutes', F.lit(240.0)),
    ('lineage_taxi_time_minutes', col('lineage_turnover_time_minutes')),
    ('lineage_turn_time_minutes', col('lineage_turnover_time_minutes')),
    ('lineage_actual_turnover_time_minutes', F.lit(240.0)),
    ('lineage_actual_taxi_time_minutes', col('lineage_actual_turnover_time_minutes')),
    ('lineage_actual_turn_time_minutes', col('lineage_actual_turnover_time_minutes')),
    # Cumulative delays: 0 (no previous flights)
    ('lineage_cumulative_delay', F.lit(0.0)),
    ('lineage_avg_delay_previous_flights', F.lit(-10.0)),
    ('lineage_max_delay_previous_flights', F.lit(-10.0)),
    ('lineage_num_previous_flights', F.lit(0)),
    ('lineage_expected_flight_time_minutes', col('crs_elapsed_time')),
]

for _column_name, _fallback_expr in _imputation_plan:
    df_final = df_final.withColumn(
        _column_name,
        F.coalesce(col(_column_name), _fallback_expr)
    )

# The minutes column was only a scratch value for the scheduled-time imputation.
df_final = df_final.drop('prev_flight_crs_dep_time_minutes')

print(
    "✓ Imputation complete - all NULLs replaced with design doc values",
    "  Scheduled times calculated backwards: current_crs_dep_time - 4 hours",
    "  Delays: -10 minutes (anti-delay)",
    "  Turnover time: 240 minutes (4 hours)",
    sep="\n",
)



## Data Dictionary - Flight Lineage Features

This section documents all new columns added by the Flight Lineage Join.

### Prediction Cutoff
- **`prediction_cutoff_minutes`**: double (nullable) - Scheduled departure time minus 2 hours, in minutes since midnight. Used to determine data leakage cutoff.
- **`prediction_cutoff_timestamp`**: timestamp (nullable) - Scheduled departure time minus 2 hours as timestamp. Alternative format for comparison.

### Previous Flight Raw Data (prev_flight_*)
- **`lineage_rank`**: int - Rank of flight in aircraft's sequence (1 = earliest, higher = more recent). Always non-null.
- **`prev_flight_origin`**: string (nullable) - Previous flight's origin airport code
- **`prev_flight_dest`**: string (nullable) - Previous flight's destination airport code. Key for jump detection.
- **`prev_flight_fl_date`**: date (nullable) - Previous flight's date
- **`prev_flight_op_carrier`**: string (nullable) - Previous flight's operating carrier code
- **`prev_flight_op_carrier_fl_num`**: string (nullable) - Previous flight's flight number
- **`prev_flight_crs_dep_time`**: int (nullable) - Previous flight's scheduled departure time (HHMM format) [SAFE]
- **`prev_flight_crs_arr_time`**: int (nullable) - Previous flight's scheduled arrival time (HHMM format) [SAFE]
- **`prev_flight_crs_elapsed_time`**: double (nullable) - Previous flight's scheduled elapsed time (minutes) [SAFE]
- **`prev_flight_distance`**: double (nullable) - Previous flight's distance (miles) [SAFE]
- **`prev_flight_actual_dep_time`**: int (nullable) - Previous flight's actual departure time (HHMM format) [DATA LEAKAGE RISK]
- **`prev_flight_actual_arr_time`**: int (nullable) - Previous flight's actual arrival time (HHMM format) [DATA LEAKAGE RISK]
- **`prev_flight_dep_delay`**: double (nullable) - Previous flight's departure delay (minutes) [DATA LEAKAGE RISK] (derived from actual_dep_time)
- **`prev_flight_arr_delay`**: double (nullable) - Previous flight's arrival delay (minutes) [DATA LEAKAGE RISK] (derived from actual_arr_time)
- **`prev_flight_air_time`**: double (nullable) - Previous flight's air time (minutes) [DATA LEAKAGE RISK] (derived from actual times)
- **`prev_flight_taxi_in`**: double (nullable) - Previous flight's taxi-in time (minutes) [DATA LEAKAGE RISK] (derived from wheels_on)
- **`prev_flight_taxi_out`**: double (nullable) - Previous flight's taxi-out time (minutes) [DATA LEAKAGE RISK] (derived from wheels_off)
- **`prev_flight_wheels_off`**: int (nullable) - Previous flight's wheels off time (HHMM format) [DATA LEAKAGE RISK]
- **`prev_flight_wheels_on`**: int (nullable) - Previous flight's wheels on time (HHMM format) [DATA LEAKAGE RISK]
- **`prev_flight_actual_elapsed_time`**: double (nullable) - Previous flight's actual elapsed time (minutes) [DATA LEAKAGE RISK] (derived from actual times)
- **`prev_flight_cancelled`**: int (nullable) - Previous flight's cancellation status (0/1) [DATA LEAKAGE RISK] (may be known late)
- **`prev_flight_diverted`**: int (nullable) - Previous flight's diversion status (0/1) [DATA LEAKAGE RISK] (may be known late)

### Lineage Engineered Features (lineage_*)
- **`lineage_is_jump`**: boolean - Flag indicating if this is a jump (aircraft repositioning) or data gap. True if prev_flight_dest != origin.
- **`lineage_turnover_time_minutes`**: double (nullable) - Expected turnover time: time between previous flight's scheduled arrival and current flight's scheduled departure (minutes) [SAFE]
- **`lineage_taxi_time_minutes`**: double (nullable) - Alias for `lineage_turnover_time_minutes` [SAFE]
- **`lineage_turn_time_minutes`**: double (nullable) - Alias for `lineage_turnover_time_minutes` [SAFE]
- **`lineage_actual_turnover_time_minutes`**: double (nullable) - Actual turnover time: time between previous flight's actual arrival and current flight's actual departure (minutes) [DATA LEAKAGE RISK]
- **`lineage_actual_taxi_time_minutes`**: double (nullable) - Alias for `lineage_actual_turnover_time_minutes` [DATA LEAKAGE RISK]
- **`lineage_actual_turn_time_minutes`**: double (nullable) - Alias for `lineage_actual_turnover_time_minutes` [DATA LEAKAGE RISK]
- **`lineage_expected_flight_time_minutes`**: double (nullable) - Expected flight time: scheduled arrival - scheduled departure (minutes) [SAFE]
- **`lineage_cumulative_delay`**: double (nullable) - Total delay accumulated by previous flights (minutes) [DATA LEAKAGE RISK] (derived from actual delays)
- **`lineage_num_previous_flights`**: long (nullable) - Number of flights the aircraft has already completed [SAFE] (count only)
- **`lineage_avg_delay_previous_flights`**: double (nullable) - Average delay across previous flights (minutes) [DATA LEAKAGE RISK] (derived from actual delays)
- **`lineage_max_delay_previous_flights`**: double (nullable) - Maximum delay in previous flights (minutes) [DATA LEAKAGE RISK] (derived from actual delays)

### Data Leakage Flags
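- **`prev_arr_time_safe_to_use`**: boolean - Flag indicating if previous arrival time is safe to use (no data leakage). Computed as the negation of `prev_flight_actual_arr_time_has_leakage`.
- **`prev_dep_time_safe_to_use`**: boolean - Flag indicating if previous departure time is safe to use (no data leakage). Computed as the negation of `prev_flight_actual_dep_time_has_leakage`.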
- **`columns_with_data_leakage`**: array<string> - **Per-row array** listing column names that have data leakage for this specific flight. Empty array `[]` means no leakage. Example: `['prev_flight_actual_dep_time', 'prev_flight_dep_delay', ...]`

### Data Leakage Rules
- **Cutoff**: `prediction_cutoff_minutes` = scheduled departure time - 2 hours
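- **Timestamp columns**: Have leakage when both the column's minutes value (`prev_flight_*_minutes`) and `prediction_cutoff_minutes` are non-null and the column's value is greater than `prediction_cutoff_minutes`. If either value is null, the flag is `False`.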
- **Duration columns**: Have leakage if their source timestamp columns have leakage
- **Status flags**: Have leakage if any related actual times have leakage

### Imputation Values (First Flight)
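- Delays (`prev_flight_dep_delay`, `prev_flight_arr_delay`): -10 minutes (anti-delay, early departure/arrival)
- Average and maximum delay of previous flights: -10 minutes
- Turnover time (scheduled and actual): 240 minutes (4 hours, overnight/maintenance gap)
- Taxi-in: 10 minutes; taxi-out: 15 minutes
- Distance: 0.0; cancelled and diverted flags: 0
- Cumulative delays: 0.0
- Number of previous flights: 0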

**See `FLIGHT_LINEAGE_JOIN_DESIGN.md` for complete documentation.**


In [None]:
# Summary banner for the flight lineage join; purely informational output.
_rule = "=" * 60
_summary_lines = (
    "\n" + _rule,
    "✓ FLIGHT LINEAGE JOIN COMPLETE",
    _rule,
    "\nNew columns added: ~38 lineage features",
    "Data leakage flags: prev_arr_time_safe_to_use, prev_dep_time_safe_to_use",
    "Risky columns documented in: columns_with_data_leakage (18 risky columns)",
    "\nAll flights preserved - no rows dropped",
    _rule,
)
for _line in _summary_lines:
    print(_line)


In [0]:

# Write
# Persist df_final as Parquet at flights_weather_joined_path after
# repartitioning to 200 partitions. mode("overwrite") replaces any output
# already present at that path.
df_final.repartition(200).write.mode("overwrite").parquet(flights_weather_joined_path)
print(f"✓ Data saved to: {flights_weather_joined_path}")

In [0]:
# Display the column names of df_final (rendered as the cell's output).
df_final.columns

In [0]:
def check_memory_usage():
    """Print a best-effort memory/cache usage report for the Spark session.

    Each section is probed independently. A section that cannot be read
    (API missing in this Spark version, access blocked by the cluster,
    session unavailable) is reported as unknown instead of aborting the
    report or being silently skipped. Returns None; all output goes to stdout.
    """

    print("Memory Usage Report")
    print("="*60)

    # 1. Temporary tables/views visible to the session (label kept as-is).
    # F.col is used rather than the bare `col` name, which other cells in
    # this notebook can rebind to a plain string.
    try:
        cached_tables = spark.sql("SHOW TABLES").filter(F.col("isTemporary"))
        print(f"\n1. Cached Tables: {cached_tables.count()}")
    except Exception:
        print("\n1. Cached Tables: Unknown")

    # 2. Check RDD cache (private JVM handle; may be unavailable).
    try:
        cached_rdds = spark.sparkContext._jsc.getPersistentRDDs()
        print(f"\n2. Cached RDDs: {len(cached_rdds)}")
    except Exception:
        print("\n2. Cached RDDs: Unknown")

    # 3. Check broadcast variables
    try:
        broadcast_count = len(spark.sparkContext._jsc.sc().getBroadcastVariables())
        print(f"\n3. Broadcast Variables: {broadcast_count}")
    except Exception:
        print("\n3. Broadcast Variables: Unknown")

    # 4. Storage memory, summed over whatever the storage status call returns.
    try:
        storage_status = spark.sparkContext._jsc.sc().getExecutorStorageStatus()
        total_memory = sum(s.maxMem() for s in storage_status)
        used_memory = sum(s.memUsed() for s in storage_status)
        # Guard the percentage so an empty/zero-capacity status list does not
        # raise after part of the section has already been printed.
        usage_pct = (used_memory / total_memory * 100) if total_memory else 0.0

        print("\n4. Storage Memory:")
        print(f"   Used: {used_memory / 1024**3:.2f} GB")
        print(f"   Total: {total_memory / 1024**3:.2f} GB")
        print(f"   Usage: {usage_pct:.1f}%")
    except Exception:
        print("\n4. Storage Memory: Unable to retrieve")

    print("="*60)

# Check before and after cleanup
check_memory_usage()

In [0]:
# Check for duplicate columns
# Counter gives the occurrences of every name in one pass. The loop variables
# are deliberately not called `col`, so the pyspark `col` function imported
# at the top of the notebook is not overwritten by a plain string.
from collections import Counter

print("Checking for duplicate columns...")
print("="*60)

columns = df_final.columns
duplicates = [
    column_name
    for column_name, occurrences in Counter(columns).items()
    if occurrences > 1
]

if duplicates:
    print(f"❌ Found duplicate columns: {duplicates}")
    print(f"\nAll columns ({len(columns)}):")
    for position, column_name in enumerate(columns, start=1):
        print(f"  {position}. {column_name}")
else:
    print("✓ No duplicates found")

In [0]:
# Collect every df_final column whose name contains 'weather_year'.
year_columns = list(filter(lambda name: 'weather_year' in name, df_final.columns))
print(year_columns)

In [0]:

# Quick validation (optional - comment out if too slow)
# Re-read the Parquet output from flights_weather_joined_path and print its
# row count; the count triggers a read of the saved dataset.
print(f"Full join path: {flights_weather_joined_path}")
df_verification = spark.read.parquet(flights_weather_joined_path)
print(f"Verification count: {df_verification.count():,}")

# OTPW Data

In [0]:
# List the contents of the OTPW backup directory on DBFS.
dbutils.fs.ls("dbfs:/mnt/mids-w261/OTPW_60M_Backup/")

In [0]:

# Load the OTPW Parquet dataset and lower-case all of its column names.
df_otpw_path = "dbfs:/mnt/mids-w261/OTPW_60M_Backup/"
df_otpw = spark.read.parquet(df_otpw_path)

df_otpw = df_otpw.toDF(*map(str.lower, df_otpw.columns))



In [0]:
# Row count of the OTPW dataset (displayed as the cell's output).
df_otpw.count()

#### Compare OTPW and Custom Join - full

In [0]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, sum as _sum

def compare_dataframes(df1: DataFrame, df2: DataFrame, name1: str = "df1", name2: str = "df2"):
    """Print a side-by-side comparison of two DataFrames.

    The report has five sections: row counts, column-name overlap, data-type
    differences on shared columns, per-column NULL counts, and summary
    statistics for the columns that are numeric in both DataFrames.

    Args:
        df1: First DataFrame.
        df2: Second DataFrame.
        name1: Label used for df1 in the printed report.
        name2: Label used for df2 in the printed report.

    Returns:
        None. Everything is printed; several Spark actions are triggered
        (count, show, summary).
    """
    # Imported here so the function carries its own dependency.
    from pyspark.sql.types import NumericType

    print(f"===== 📊 Comparing {name1} and {name2} =====\n")

    # -----------------------------
    # 1. Row Count
    # -----------------------------
    print("➡ Row Counts:")
    print(f"{name1}: {df1.count()}")
    print(f"{name2}: {df2.count()}\n")

    # -----------------------------
    # 2. Column Names
    # -----------------------------
    cols1 = set(df1.columns)
    cols2 = set(df2.columns)
    common = cols1 & cols2

    print("➡ Columns Present in Both:")
    print(sorted(common), "\n")

    print("➡ Columns Only in", name1)
    print(sorted(cols1 - cols2), "\n")

    print("➡ Columns Only in", name2)
    print(sorted(cols2 - cols1), "\n")

    # -----------------------------
    # 3. Schema Comparison
    # -----------------------------
    print("➡ Schema Differences:")
    schema1 = {f.name: f.dataType for f in df1.schema.fields}
    schema2 = {f.name: f.dataType for f in df2.schema.fields}

    diffs = {c: (schema1[c], schema2[c]) for c in common if schema1[c] != schema2[c]}

    if diffs:
        for col_name, (t1, t2) in diffs.items():
            print(f" - Column '{col_name}': {name1}={t1}, {name2}={t2}")
    else:
        print("No schema differences.\n")

    # -----------------------------
    # 4. Null Count Comparison
    # -----------------------------
    print("\n➡ Null Counts Per Column:")

    print(f"\n{name1} Null Counts:")
    nulls1 = df1.select([_sum(col(c).isNull().cast("int")).alias(c) for c in df1.columns])
    nulls1.show(truncate=False)

    print(f"\n{name2} Null Counts:")
    nulls2 = df2.select([_sum(col(c).isNull().cast("int")).alias(c) for c in df2.columns])
    nulls2.show(truncate=False)

    # -----------------------------
    # 5. Numeric Summary Comparison (optional)
    # -----------------------------
    # Numeric columns are detected by type, not by matching lowercase
    # substrings against str(dataType): the string form is e.g.
    # "IntegerType()" / "DoubleType()", which never contains "int" or
    # "double", so a substring test selects nothing.
    numeric_cols1 = {f.name for f in df1.schema.fields if isinstance(f.dataType, NumericType)}
    numeric_cols2 = {f.name for f in df2.schema.fields if isinstance(f.dataType, NumericType)}
    # Sorted so the report lists columns in a stable order between runs.
    numeric_common = sorted(numeric_cols1 & numeric_cols2)

    if numeric_common:
        print("\n➡ Summary Statistics (common numeric columns):")
        print(f"Common numeric columns: {numeric_common}\n")

        print(f"{name1} summary:")
        df1.select(numeric_common).summary().show(truncate=False)

        print(f"{name2} summary:")
        df2.select(numeric_common).summary().show(truncate=False)
    else:
        print("\nNo common numeric columns to summarize.")


In [0]:
# Print the comparison report between the OTPW dataset and the custom join.
compare_dataframes(df_otpw, df_final, "df_otpw", "df_final")


In [0]:
# Number of distinct (fl_date, dep_time, op_carrier, origin, dest)
# combinations in df_otpw: group by the key, then count the groups.
df_otpw.groupBy(
    "fl_date", "dep_time", "op_carrier", "origin", "dest"
).count().count()

In [0]:
# Number of distinct (fl_date, dep_time, op_carrier, origin, dest)
# combinations in df_final: group by the key, then count the groups.
df_final.groupBy(
    "fl_date", "dep_time", "op_carrier", "origin", "dest"
).count().count()

# Make Custom Join DF match OTPW Schema



In [0]:
from pyspark.sql.functions import col

# Reshape df_final towards the OTPW schema in three steps:
#   1. cast the columns whose types differ to string  -> df_final_casted
#   2. rename columns to the OTPW names               -> df_final_renamed
#   3. drop columns that should not be carried over   -> df_final_clean

# mismatched columns extracted from schema diff
cols_to_string = [
    'carrier_delay', 'wheels_on', 'first_dep_time', 'actual_elapsed_time', 'wheels_off',
    'dest_airport_seq_id', 'dep_delay', 'sched_depart_date_time', 'day_of_week',
    'origin_city_market_id', 'taxi_out', 'arr_time', 'crs_arr_time', 'late_aircraft_delay',
    'month', 'longest_add_gtime', 'origin_airport_seq_id', 'dep_time', 'origin_state_fips',
    'arr_delay_new', 'flights', 'air_time', 'dest_city_market_id', 'arr_delay_group',
    'dest_airport_id', 'dep_del15', 'security_delay', 'crs_dep_time', 'crs_elapsed_time',
    'arr_del15', 'op_carrier_fl_num', 'total_add_gtime', 'diverted', 'day_of_month', 'taxi_in',
    'op_carrier_airline_id', 'distance_group', 'arr_delay', 'origin_wac', 'dep_delay_new',
    'quarter', 'dest_wac', 'origin_airport_id', 'weather_delay', 'nas_delay', 'distance',
    'cancelled', 'dest_state_fips', 'dep_delay_group'
]

# Look the column list up once instead of on every loop iteration.
_final_columns = set(df_final.columns)
df_final_casted = df_final
for c in cols_to_string:
    if c in _final_columns:
        df_final_casted = df_final_casted.withColumn(c, col(c).cast("string"))

# rename column
rename_map = {
    'destination_airport_name': 'dest_airport_name',
    'destination_latitude': 'dest_airport_lat',
    'destination_longitude': 'dest_airport_lon',
    'destination_country': 'dest_region',
    'destination_timezone': 'dest_type',

    'origin_latitude': 'origin_airport_lat',
    'origin_longitude': 'origin_airport_lon',
    'origin_country': 'origin_region',
    'origin_timezone': 'origin_type',

    'sched_depart_date_time_UTC': 'sched_depart_date_time_utc',
    'two_hours_prior_depart_UTC': 'two_hours_prior_depart_utc',
    'four_hours_prior_depart_UTC': 'four_hours_prior_depart_utc',

    'date_timestamp': 'date',     # if appropriate
    'fl_date_timestamp': 'fl_date'
}

# Renaming onto a name that is already present would leave two columns with
# the same name, making later references to that name ambiguous. Such renames
# are skipped with a warning and the source column keeps its original name.
df_final_renamed = df_final_casted
_current_columns = set(df_final_renamed.columns)
for old, new in rename_map.items():
    if old not in _current_columns:
        continue
    if new in _current_columns:
        print(f"⚠ Skipping rename '{old}' -> '{new}': column '{new}' already exists")
        continue
    df_final_renamed = df_final_renamed.withColumnRenamed(old, new)
    _current_columns.discard(old)
    _current_columns.add(new)

# Drop extra columns
cols_to_drop = [
    'div1_airport','div1_airport_id','div1_airport_seq_id','div1_longest_gtime',
    'div1_tail_num','div1_total_gtime','div1_wheels_off','div1_wheels_on',
    'div2_airport','div2_airport_id','div2_airport_seq_id','div2_longest_gtime',
    'div2_tail_num','div2_total_gtime','div2_wheels_off','div2_wheels_on',
    'div3_airport', 'div4_airport', 'div5_airport',
    'div_airport_landings','div_reached_dest','div_actual_elapsed_time','div_arr_delay',
    'div_distance','station_distance_km','date_timestamp'
]

df_final_clean = df_final_renamed.drop(*[c for c in cols_to_drop if c in _current_columns])



In [0]:
# Display the column names of the schema-aligned DataFrame.
df_final_clean.columns

In [None]:
# ============================================================
# SUMMARY: Row Count Comparison Across Pipeline Steps
# ============================================================
print("\n" + "="*80)
print("ROW COUNT SUMMARY - Tracking Flight Losses Through Pipeline")
print("="*80)

# Checkpoints in pipeline order. Each DataFrame is looked up lazily through a
# lambda so a checkpoint whose cell has not been run (name undefined) is
# reported as 'N/A' instead of stopping the summary.
# NOTE(review): 'Initial Load' and 'After dropDuplicates' both read
# df_flights, so that pair can never show a difference - confirm which
# variable holds the pre-dedup data.
checkpoint_sources = [
    ('Initial Load', lambda: df_flights),
    ('After dropDuplicates', lambda: df_flights),
    ('After Airport Join', lambda: df_flights_with_airports),
    ('After Station Join (Left)', lambda: df_flights_with_station),
    ('After Station Filter', lambda: df_flights_with_station_clean),
    ('After Weather Join', lambda: df_joined),
    ('Final Before Save', lambda: df_final),
]

# Store counts (populated from whichever checkpoints are available)
counts = {}
for step_name, get_dataframe in checkpoint_sources:
    try:
        counts[step_name] = get_dataframe().count()
    except Exception:
        # Exception (not a bare except) so an interrupt during a long count
        # still stops the cell.
        counts[step_name] = 'N/A'

# Print summary table
print("\nStep-by-Step Row Counts:")
print("-" * 80)
print(f"{'Step':<40} {'Row Count':>20} {'Change':>15}")
print("-" * 80)

prev_count = None
for step, count in counts.items():
    if not isinstance(count, int):
        print(f'{step:<40} {count:>20}')
        continue
    change = ''
    if prev_count is not None:
        # Change is measured against the previous checkpoint that has a count.
        diff = count - prev_count
        pct = (diff / prev_count * 100) if prev_count > 0 else 0
        change = f'{diff:+,} ({pct:+.2f}%)'
    print(f'{step:<40} {count:>20,} {change:>15}')
    prev_count = count

print("-" * 80)

initial_count = counts.get('Initial Load')
final_count = counts.get('Final Before Save')

# Calculate total loss
if isinstance(initial_count, int) and isinstance(final_count, int):
    total_loss = initial_count - final_count
    loss_pct = (total_loss / initial_count * 100) if initial_count > 0 else 0
    print(f'\nTotal Flights Lost: {total_loss:,} ({loss_pct:.2f}%)')
    print(f'Final Retention Rate: {100 - loss_pct:.2f}%')

# Identify biggest drop between consecutive checkpoints that have counts
if isinstance(initial_count, int):
    steps_with_counts = [(k, v) for k, v in counts.items() if isinstance(v, int)]
    max_drop = 0
    max_drop_step = None
    for (_, before), (curr_step, after) in zip(steps_with_counts, steps_with_counts[1:]):
        drop = before - after
        if drop > max_drop:  # strict: the earliest of equal drops wins
            max_drop = drop
            max_drop_step = curr_step
    if max_drop > 0:
        print(f'\n⚠ Largest Drop: {max_drop:,} flights at \'{max_drop_step}\'')

print("="*80)


In [0]:
# COMMAND ----------
# MAGIC %md
# MAGIC ## Save Final Result

# COMMAND ----------
_save_rule = "=" * 60
print(_save_rule)
print("SAVING FINAL RESULT")
print(_save_rule)

# Inspect how the data is currently partitioned
num_partitions = df_final.rdd.getNumPartitions()
print(f"Current partitions: {num_partitions}")

# Leave the partitioning alone when it is within 10..500; otherwise shrink
# with coalesce or grow with repartition.
if 10 <= num_partitions <= 500:
    print(f"✓ Partition count looks good")
elif num_partitions > 500:
    print(f"⚠ Too many partitions, coalescing to 200")
    df_final = df_final.coalesce(200)
else:
    print(f"⚠ Too few partitions, repartitioning to 50")
    df_final = df_final.repartition(50)

# Write (overwrite mode replaces any existing output at the path)
_parquet_writer = df_final.write.mode("overwrite")
_parquet_writer.parquet(flights_weather_joined_path)
print(f"✓ Data saved to: {flights_weather_joined_path}")

# Quick validation (optional - comment out if too slow)
# df_verification = spark.read.parquet(flights_weather_joined_path)
# print(f"Verification count: {df_verification.count():,}")

# Cleanup: release the persisted airports DataFrame
df_airports_clean.unpersist()

print(_save_rule)
print("ALL JOINS COMPLETE!")
print(_save_rule)

In [None]:
# ============================================================
# SANITY CHECK: Final Before Save
# ============================================================

_check_rule = "=" * 80
print("\n" + _check_rule)
print(f"SANITY CHECK: Final Before Save")
print(_check_rule)

# Row count (None when counting fails, which disables the percentages below)
try:
    row_count = df_final.count()
    print(f"\n✓ Row count: {row_count:,}")
except Exception as e:
    print(f"\n✗ Error counting rows: {e}")
    row_count = None

# NULL analysis, restricted to the key columns that actually exist
print("\n--- NULL Analysis ---")
_candidate_keys = (
    'origin',
    'dest',
    'origin_station_id',
    'origin_latitude',
    'origin_longitude',
)
_present_columns = df_final.columns
key_columns = [name for name in _candidate_keys if name in _present_columns]

for col_name in key_columns:
    try:
        null_count = df_final.filter(F.col(col_name).isNull()).count()
        null_pct = (null_count / row_count * 100) if row_count else 0
        print(f"  {col_name}: {null_count:,} NULLs ({null_pct:.2f}%)")
    except Exception as e:
        print(f"  {col_name}: Error - {e}")

# Distinct origins that still have no station id (if the column exists)
if 'origin_station_id' in df_final.columns:
    try:
        _missing_station = F.col('origin_station_id').isNull()
        dropped_airports = df_final.filter(_missing_station).select('origin').distinct()
        dropped_count = dropped_airports.count()
        if dropped_count > 0:
            print(f"\n--- Airports Without Stations: {dropped_count} ---")
            print("Sample airports without stations:")
            dropped_airports.show(20, truncate=False)
    except Exception as e:
        print(f"\nError analyzing dropped airports: {e}")

print(_check_rule)
