In [0]:
# spark
from pyspark.sql.functions import col, count, countDistinct
from pyspark.sql import functions as F

# time
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta

# plot
import matplotlib.pyplot as plt

# pandas
import pandas as pd

# sys
import os

In [0]:
# Create dir for group
SECTION = "4"
NUMBER = "2"
GROUP_FOLDER_PATH = f"dbfs:/student-groups/Group_{SECTION}_{NUMBER}"
DATA_BASE_DIR = "dbfs:/mnt/mids-w261"

# dbutils.fs.ls raises when the path does not exist (it does not return an
# empty list), so existence has to be probed with try/except. The concrete
# exception class differs between runtime versions, hence the broad catch;
# any genuine access problem will resurface from mkdirs below.
try:
    dbutils.fs.ls(GROUP_FOLDER_PATH)
    group_folder_exists = True
except Exception:
    group_folder_exists = False

if group_folder_exists:
    print("Directory already exists:")
else:
    dbutils.fs.mkdirs(GROUP_FOLDER_PATH)
    print("Directory created:")

print(GROUP_FOLDER_PATH)

In [0]:
# Utils for checkpointing data
def checkpoint_data(df, name, folder_path=GROUP_FOLDER_PATH, mode="errorifexists"):
    """
    Persist a Spark DataFrame as Parquet at ``<folder_path>/<name>.parquet``.

    Args:
        df (DataFrame): Spark DataFrame to write.
        name (str): Checkpoint name; ".parquet" is appended to build the path.
        folder_path (str, optional): Destination folder. Defaults to GROUP_FOLDER_PATH.
        mode (str, optional): Spark save mode ("errorifexists", "overwrite",
            "append" or "ignore"). The default keeps the previous behaviour of
            failing when the checkpoint already exists; pass "overwrite" to
            refresh an existing checkpoint.
    """
    df.write.parquet(f"{folder_path}/{name}.parquet", mode=mode)

def load_checkpointed_data(name, folder_path=GROUP_FOLDER_PATH):
    """
    Read a Parquet checkpoint back into a Spark DataFrame.

    Args:
        name (str): Checkpoint name; ".parquet" is appended to build the path.
        folder_path (str, optional): Folder holding the checkpoint. Defaults to GROUP_FOLDER_PATH.

    Returns:
        DataFrame: The result of ``spark.read.parquet`` on ``<folder_path>/<name>.parquet``.
    """
    checkpoint_path = f"{folder_path}/{name}.parquet"
    reader = spark.read
    return reader.parquet(checkpoint_path)

In [0]:
# Load or create data

# NOTE(review): the original fallback never split the raw data (train_otpw was
# used before assignment and test_otpw was a copy of train). The split below
# holds out the most recent TEST_FRACTION of calendar days as the test set;
# confirm this ratio matches the intended experimental design.
TEST_FRACTION = 0.2


def get_non_null_cols(df, threshold=10):
    """
    Calculate and return column names with null percentages at or below the threshold.

    Args:
        df (DataFrame): Spark DataFrame to analyze.
        threshold (float, optional): Maximum percentage of nulls to include a column in the output. Defaults to 10%.

    Returns:
        list[str]: List of column names with null percentage <= threshold, sorted by percentage descending.
            For a DataFrame with no rows every column is returned, because a
            null percentage cannot be computed.
    """
    total_count = df.count()
    if total_count == 0:
        # Dividing by zero would yield null percentages that cannot be compared
        # against the threshold.
        return list(df.columns)

    null_counts = df.select([
        (F.sum(F.col(c).isNull().cast("int")) / total_count * 100).alias(c)
        for c in df.columns
    ])
    null_dict = null_counts.collect()[0].asDict()
    # Loop names deliberately avoid `col`, which is the imported Spark function.
    low_null_cols = {
        col_name: null_pct
        for col_name, null_pct in null_dict.items()
        if null_pct <= threshold
    }
    sorted_cols = sorted(low_null_cols.items(), key=lambda item: item[1], reverse=True)
    return [col_name for col_name, _ in sorted_cols]


try:
    train_otpw = load_checkpointed_data("OTPW_3M_TRAIN")
    test_otpw = load_checkpointed_data("OTPW_3M_TEST")
except Exception as load_error:
    # Report why the checkpoints were not used instead of hiding the error.
    print(f"Checkpoints not loaded ({type(load_error).__name__}); rebuilding from raw OTPW data")

    # Load OTPW data
    df_otpw = spark.read.format("csv").option("header","true").load(f"{DATA_BASE_DIR}/OTPW_3M_2015.csv")

    # Filter
    df_otpw = df_otpw.filter(col("DEP_DELAY").isNotNull())

    # Chronological train/test split on FL_DATE (no shuffling, so the test set
    # is strictly later than the training set).
    date_bounds = df_otpw.select(
        F.min(F.col("FL_DATE")).alias("first_day"),
        F.max(F.col("FL_DATE")).alias("last_day")
    ).first()
    if date_bounds["first_day"] is None or date_bounds["last_day"] is None:
        raise ValueError("OTPW data has no FL_DATE values to split on")

    # [:10] keeps the YYYY-MM-DD part whether the column is a string, date or timestamp.
    first_day = datetime.strptime(str(date_bounds["first_day"])[:10], '%Y-%m-%d')
    last_day = datetime.strptime(str(date_bounds["last_day"])[:10], '%Y-%m-%d')
    n_days = (last_day - first_day).days + 1
    n_test_days = max(1, round(n_days * TEST_FRACTION))
    test_start = (last_day - timedelta(days=n_test_days - 1)).strftime('%Y-%m-%d')

    train_raw = df_otpw.filter(F.col("FL_DATE") < F.lit(test_start))
    test_raw = df_otpw.filter(F.col("FL_DATE") >= F.lit(test_start))

    # Column selection is decided on the training data only, then applied to
    # both splits so they share one schema.
    non_null_cols = get_non_null_cols(train_raw, threshold=10)

    # only non-null columns
    train_otpw = train_raw.select(non_null_cols)
    test_otpw = test_raw.select(non_null_cols)

    # checkpoint data
    checkpoint_data(train_otpw, "OTPW_3M_TRAIN")
    checkpoint_data(test_otpw, "OTPW_3M_TEST")

In [0]:
# Preview the first five rows of the training split
display(train_otpw.limit(5))

In [0]:
def create_expanding_window_folds(
    df, 
    start_date: str | None = None, 
    end_date: str | None = None, 
    date_col: str = "FL_DATE",
    n_folds: int = 4
):
    """
    Create expanding window time-series cross-validation folds
    Automatically divides the time range into equal periods

    The range [start_date, end_date] is treated as inclusive on both ends and
    cut into ``n_folds + 1`` periods of equal length. Fold k trains on
    everything before validation period k and validates on period k. Days left
    over from the integer division are added to the last validation window so
    that no data up to and including end_date is dropped.

    Also prints the detected range, a text timeline and per-fold row counts
    (each count triggers a Spark job).

    Args:
        df: PySpark DataFrame
        start_date: Starting date as string 'YYYY-MM-DD' or None for auto-detection
        end_date: Ending date (inclusive) as string 'YYYY-MM-DD' or None for auto-detection
        date_col: Name of the date column
        n_folds: Number of folds to create

    Returns: List of tuples (train_df, val_df) for each fold

    Raises:
        ValueError: if n_folds < 1, if the date range cannot be detected, or if
            the range is too short to give every period at least one day.
    """
    if n_folds < 1:
        raise ValueError(f"n_folds must be at least 1, got {n_folds}")

    # Auto-detect dates if not provided
    if start_date is None or end_date is None:
        date_range = df.select(
            F.min(F.col(date_col)).alias("start_date"),
            F.max(F.col(date_col)).alias("end_date")
        ).first()

        if date_range["start_date"] is None or date_range["end_date"] is None:
            raise ValueError(f"Cannot detect a date range: column '{date_col}' has no non-null values")

        # [:10] keeps the YYYY-MM-DD part whether the column is a string, date or timestamp
        start_date = start_date or str(date_range["start_date"])[:10]
        end_date = end_date or str(date_range["end_date"])[:10]

    # Convert to datetime
    start_dt = datetime.strptime(start_date, '%Y-%m-%d')
    end_dt = datetime.strptime(end_date, '%Y-%m-%d')

    # Filters below use half-open windows [lo, hi), so the inclusive end date
    # is represented by the day after it.
    end_exclusive = end_dt + timedelta(days=1)

    # Calculate total days and period size
    total_days = (end_exclusive - start_dt).days
    period_days = total_days // (n_folds + 1)  # +1 because first period is just training

    # "< 1" also rejects a reversed range, where floor division gives a negative period
    if period_days < 1:
        raise ValueError(f"Dataset too small for {n_folds} folds. Total days: {total_days}")

    # Print header
    print(f"Detected date range: {start_date} to {end_date}")
    print(f"Total days: {total_days}, Period size: {period_days} days (~{period_days/7:.1f} weeks)")
    print()

    # Create visual timeline
    _print_timeline(start_dt, end_dt, period_days, n_folds)
    print()

    # Create folds
    folds = []

    for fold_num in range(n_folds):
        # Calculate validation period
        val_start = start_dt + timedelta(days=period_days * (fold_num + 1))
        if fold_num == n_folds - 1:
            # Last fold absorbs the division remainder and reaches end_date itself
            val_end = end_exclusive
        else:
            val_end = start_dt + timedelta(days=period_days * (fold_num + 2))

        val_start_str = val_start.strftime('%Y-%m-%d')
        val_end_str = val_end.strftime('%Y-%m-%d')

        # Training: all data from start to val_start
        train_df = df.filter(
            (F.col(date_col) >= F.lit(start_date)) & 
            (F.col(date_col) < F.lit(val_start_str))
        )

        # Validation: data in the validation window
        val_df = df.filter(
            (F.col(date_col) >= F.lit(val_start_str)) & 
            (F.col(date_col) < F.lit(val_end_str))
        )

        folds.append((train_df, val_df))

        train_days = (val_start - start_dt).days
        val_days = (val_end - val_start).days

        print(f"Fold {fold_num + 1}: Train [{start_date} to {val_start_str}), "
              f"Val [{val_start_str} to {val_end_str})")
        print(f"  Train: {train_days} days, Val: {val_days} days")
        print(f"  Train size: {train_df.count()}, Val size: {val_df.count()}")
        print()

    return folds


def _print_timeline(start_dt, end_dt, period_days, n_folds):
    """
    Print visual timeline of the cross-validation strategy

    Everything is drawn on one fixed grid: period i starts at text column
    ``fold_indent + segment_width * i``. The header date labels, the tick
    marks and each fold's "[Val: ...]" label are all placed on that grid.

    Args:
        start_dt (datetime): First day of the timeline.
        end_dt (datetime): Last day; period starts are clamped to it.
        period_days (int): Length of one period in days.
        n_folds (int): Number of folds; n_folds + 1 periods are drawn.
    """
    # Fixed width for each period segment
    segment_width = 17
    # Width of the left gutter; equals len("Dataset Timeline: ")
    fold_indent = 18

    # Start-date label of every period (first one is training only)
    labels = []
    current = start_dt
    for _ in range(n_folds + 1):
        labels.append(current.strftime('%Y-%m-%d'))
        current = min(current + timedelta(days=period_days), end_dt)

    # Timeline header: each label plus its separator spans exactly
    # segment_width columns so that labels sit above the tick marks below.
    header = "Dataset Timeline: " + labels[0]
    for label in labels[1:]:
        # -2 accounts for the single spaces on both sides of the dashes
        dash_count = max(1, segment_width - len(label) - 2)
        header += f" {'─' * dash_count} {label}"
    print(header)

    # Timeline bars: one tick per period start plus a closing tick
    print(" " * fold_indent + ("|" + "═" * (segment_width - 1)) * len(labels) + "|")
    print()

    # Print each fold
    for fold_num in range(n_folds):
        train_periods = fold_num + 1
        val_period = fold_num + 1

        fold_label = f"Fold {fold_num + 1}:"

        # Build train label
        if train_periods == 1:
            train_label = f"[Train: {labels[0]}]"
        else:
            train_label = f"[Train: {labels[0]}─────{labels[train_periods - 1]}]"

        # Build val label
        val_label = f"[Val: {labels[val_period]}]"

        # The fold label is padded to the gutter width, so the text written so
        # far is len(prefix) wide; measuring the real prefix avoids counting
        # the fold label twice.
        prefix = f"{fold_label:<{fold_indent}}{train_label}"
        val_position = fold_indent + (segment_width * val_period)
        spacing = max(0, val_position - len(prefix))

        # Print the fold description line
        print(prefix + " " * spacing + val_label)

        # Visual representation: train arrow ending at the validation period,
        # followed by the validation bar closing on the next tick.
        train_width = segment_width * train_periods
        train_arrow = "|" + "═" * (train_width - 1) + ">"
        val_bar = "|" + "─" * (segment_width - 2) + "|"
        print(" " * fold_indent + train_arrow + val_bar)
        print()

    print("Legend:  [═══] = Training Data    [───] = Validation Data")

In [0]:
# Build 4 expanding-window (train, validation) folds from the training split;
# the call also prints the fold timeline and per-fold row counts.
folds = create_expanding_window_folds(train_otpw, n_folds=4)

In [0]:
# Echo the list of (train_df, val_df) tuples as the cell output
folds

In [0]:
def sanity_check_folds(folds, date_col="FL_DATE"):
    """
    Perform sanity checks on cross-validation folds

    Per fold: train must end before validation starts, validation should start
    the day after train ends, neither set may be empty, and no validation date
    may occur in train. Across folds: train sets should grow and validation
    sets should have similar sizes. Results are printed; nothing is returned.

    Row counts are computed once per fold and reused for the cross-fold
    checks. Folds without any dates (for example empty sets) are reported as
    failures instead of raising while parsing dates.

    Args:
        folds: List of tuples (train_df, val_df)
        date_col: Name of the date column
    """
    print("="*70)
    print("SANITY CHECKS FOR FOLDS")
    print("="*70)
    print()

    all_checks_passed = True
    # Counts collected in the per-fold loop and reused by the cross-fold checks
    train_sizes = []
    val_sizes = []

    for fold_num, (train_df, val_df) in enumerate(folds, 1):
        print(f"Fold {fold_num}:")
        print("-" * 50)

        # Get date ranges
        train_stats = train_df.select(
            F.min(F.col(date_col)).alias("train_min"),
            F.max(F.col(date_col)).alias("train_max"),
            F.count("*").alias("train_count")
        ).first()

        val_stats = val_df.select(
            F.min(F.col(date_col)).alias("val_min"),
            F.max(F.col(date_col)).alias("val_max"),
            F.count("*").alias("val_count")
        ).first()

        train_min = str(train_stats["train_min"])
        train_max = str(train_stats["train_max"])
        val_min = str(val_stats["val_min"])
        val_max = str(val_stats["val_max"])
        train_count = train_stats["train_count"]
        val_count = val_stats["val_count"]
        train_sizes.append(train_count)
        val_sizes.append(val_count)

        print(f"  Train date range: {train_min} to {train_max} ({train_count:,} rows)")
        print(f"  Val date range:   {val_min} to {val_max} ({val_count:,} rows)")

        # min/max are None when a set is empty or has only null dates; the
        # date comparisons below cannot be evaluated in that case.
        dates_available = (
            train_stats["train_max"] is not None
            and val_stats["val_min"] is not None
        )

        if dates_available:
            # Check 1: Train max < Val min (no temporal overlap)
            if train_max >= val_min:
                print(f"  ❌ FAIL: Train max ({train_max}) >= Val min ({val_min}) - TEMPORAL OVERLAP!")
                all_checks_passed = False
            else:
                print(f"  ✓ PASS: No temporal overlap (train ends before val starts)")

            # Check 2: Val comes immediately after train (no gap)
            # [:10] keeps the YYYY-MM-DD part if the column is a timestamp
            train_max_dt = datetime.strptime(train_max[:10], '%Y-%m-%d')
            val_min_dt = datetime.strptime(val_min[:10], '%Y-%m-%d')
            gap_days = (val_min_dt - train_max_dt).days

            if gap_days == 1:
                print(f"  ✓ PASS: Val starts immediately after train (1 day gap as expected)")
            elif gap_days > 1:
                print(f"  ⚠️  WARNING: {gap_days} day gap between train and val")
            else:
                print(f"  ❌ FAIL: Val starts before train ends (gap = {gap_days} days)")
                all_checks_passed = False
        else:
            print(f"  ❌ FAIL: Cannot compare date ranges (train or val has no dates)")
            all_checks_passed = False

        # Check 3: No empty datasets
        if train_count == 0:
            print(f"  ❌ FAIL: Train set is empty!")
            all_checks_passed = False
        else:
            print(f"  ✓ PASS: Train set has {train_count:,} rows")

        if val_count == 0:
            print(f"  ❌ FAIL: Validation set is empty!")
            all_checks_passed = False
        else:
            print(f"  ✓ PASS: Validation set has {val_count:,} rows")

        # Check 4: Check for data leakage (any val dates in train)
        # Semi-join against the distinct validation dates: counts train rows
        # whose date also occurs in val, without the many-to-many blow-up of a
        # row-level inner join on date.
        val_dates = val_df.select(F.col(date_col).alias("val_date")).distinct()
        leakage_count = train_df.join(
            val_dates,
            F.col(date_col) == F.col("val_date"),
            "left_semi"
        ).count()

        if leakage_count > 0:
            print(f"  ❌ FAIL: {leakage_count} rows with same dates in both train and val - DATA LEAKAGE!")
            all_checks_passed = False
        else:
            print(f"  ✓ PASS: No data leakage (no overlapping dates)")

        print()

    # Cross-fold checks
    print("="*70)
    print("CROSS-FOLD CHECKS")
    print("="*70)
    print()

    # Check that training sets are expanding
    for i in range(1, len(train_sizes)):
        prev_count = train_sizes[i-1]
        curr_count = train_sizes[i]

        print(f"Fold {i} → Fold {i+1}: Train size {prev_count:,} → {curr_count:,}", end="")

        if curr_count > prev_count:
            print(" ✓ (expanding)")
        elif curr_count == prev_count:
            print(" ⚠️  WARNING: Train set not growing")
        else:
            print(" ❌ FAIL: Train set shrinking!")
            all_checks_passed = False

    print()

    # Check validation set sizes are similar
    if not val_sizes:
        print("❌ FAIL: No folds were provided")
        all_checks_passed = False
    else:
        avg_val_size = sum(val_sizes) / len(val_sizes)
        if avg_val_size > 0:
            max_deviation = max(abs(size - avg_val_size) / avg_val_size for size in val_sizes)
        else:
            # Every validation set is empty (already reported above); there is
            # no meaningful relative deviation.
            max_deviation = 0.0

        print(f"Validation set sizes: {val_sizes}")
        print(f"Average: {avg_val_size:,.0f}, Max deviation: {max_deviation*100:.1f}%", end="")

        if max_deviation < 0.1:
            print(" ✓ (consistent sizes)")
        else:
            print(" ⚠️  WARNING: Validation sets have inconsistent sizes")

    print()
    print("="*70)
    if all_checks_passed:
        print("✓ ALL CHECKS PASSED!")
    else:
        print("❌ SOME CHECKS FAILED - REVIEW FOLD CREATION!")
    print("="*70)


# Run sanity checks on the folds built above, using FL_DATE as the date column;
# the report is printed to the cell output.
sanity_check_folds(folds, date_col="FL_DATE")