In [0]:
!pip install --upgrade polars
!pip install scikit-learn pandas numpy xgboost


Collecting polars
  Obtaining dependency information for polars from https://files.pythonhosted.org/packages/ec/14/ee34ebe3eb842c83ca1d2d3af6ee02b08377e056ffad156c9a2b15a6d05c/polars-1.32.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata
  Using cached polars-1.32.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (15 kB)
Using cached polars-1.32.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (38.3 MB)
Installing collected packages: polars
Successfully installed polars-1.32.2
[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m
Collecting xgboost
  Obtaining dependency information for xgboost from https://files.pythonhosted.org/packages/dc/76/241d22b2b503e97e222d85d5e18f9cc76a67acb552acef84a78bc9e787a5/xgboost-3.0.4-py3-none-manylinux_2_28_x86_64.whl.metadata
  Using cached xgboost-3.0.4-py3-none-manylinux_2_28_x86_64.whl.metadata (2.1 kB)
Collecting nvidia-nccl

In [0]:
dbutils.library.restartPython()

In [0]:
import polars as pl
import numpy as np
import gc
from typing import List
import pandas as pd
from typing import List, Dict
from sklearn.model_selection import GroupKFold
import os

In [0]:
pl.Config.set_tbl_cols(1000)

polars.config.Config

In [0]:
class config:
    data_path = '/Workspace/Users/rajneesh.tiwari@tide.co/aero/'
    train = 'train.parquet'
    test = 'test.parquet'
    airport_benchmark = 'airports.csv'
    debug = False
    # Columns to be used for EDA
    legs = [0, 1]
    segments = [0, 1, 2, 3]
    trips = ['arrivalTo', 'departureFrom']
    date_cols = ['legs0_arrivalAt', 'legs0_departureAt', 'legs1_arrivalAt', 'legs1_departureAt', 'requestDate']
    N_FOLDS = 5
    EXP = 'NB0047'

In [0]:
save_path = os.path.join('/Workspace/Users/rajneesh.tiwari@tide.co/aero/runs', config.EXP)
os.makedirs(save_path, exist_ok=True)

In [0]:

def reduce_mem_usage_polars(df: pl.DataFrame, verbose: bool = True) -> pl.DataFrame:
    """
    Optimize data types in Polars DataFrame for memory efficiency.
    Note: Polars is already quite memory efficient, but this provides additional optimization.
    """
    try:
        if verbose:
            start_mem = df.estimated_size("mb")
            print(f"Memory usage before optimization: {start_mem:.2f} MB")
        
        # Get numeric columns
        numeric_cols = [col for col, dtype in zip(df.columns, df.dtypes) 
                       if dtype in [pl.Int64, pl.Int32, pl.Int16, pl.Int8, 
                                   pl.Float64, pl.Float32, pl.UInt64, pl.UInt32, pl.UInt16, pl.UInt8]]
        
        optimizations = []
        
        for col in numeric_cols:
            try:
                min_val = df.select(pl.col(col).min()).item()
                max_val = df.select(pl.col(col).max()).item()
                
                # Skip optimization if column has all null values or if min/max calculation failed
                if min_val is None or max_val is None:
                    optimizations.append(pl.col(col))
                    continue
                
                current_dtype = df[col].dtype
                
                if current_dtype in [pl.Int64, pl.Int32, pl.Int16, pl.Int8, pl.UInt64, pl.UInt32, pl.UInt16, pl.UInt8]:
                    # Integer optimization
                    if min_val >= -128 and max_val <= 127:
                        optimizations.append(pl.col(col).cast(pl.Int8))
                    elif min_val >= -32768 and max_val <= 32767:
                        optimizations.append(pl.col(col).cast(pl.Int16))
                    elif min_val >= -2147483648 and max_val <= 2147483647:
                        optimizations.append(pl.col(col).cast(pl.Int32))
                    else:
                        optimizations.append(pl.col(col))
                else:
                    # Float optimization
                    if abs(min_val) <= np.finfo(np.float32).max and abs(max_val) <= np.finfo(np.float32).max:
                        optimizations.append(pl.col(col).cast(pl.Float32))
                    else:
                        optimizations.append(pl.col(col))
            except Exception as e:
                if verbose:
                    print(f"Could not optimize column {col}: {e}")
                optimizations.append(pl.col(col))
        
        # Apply optimizations
        if optimizations:
            df = df.with_columns(optimizations)
        
        if verbose:
            end_mem = df.estimated_size("mb")
            reduction = 100 * (start_mem - end_mem) / start_mem if start_mem > 0 else 0
            print(f'Memory usage after optimization: {end_mem:.2f} MB ({reduction:.1f}% reduction)')
        
        return df
    
    except Exception as e:
        if verbose:
            print(f"Memory optimization failed: {e}. Returning original DataFrame.")
        return df


In [0]:

def read_parquet(path: str, debug: bool = config.debug) -> pl.DataFrame:
    """
    Reads a parquet file from the given path.
    """
    if debug:
        print("----------- Running in debug mode -----------")
        df = pl.read_parquet(path)
        sampled = df.head(100000)
        return sampled.select(pl.exclude("__index_level_0__"))
    else:
        return pl.read_parquet(path).select(pl.exclude("__index_level_0__"))

def read_csv(path: str) -> pl.DataFrame:
    """
    Reads a CSV file from the given path.
    """
    return pl.read_csv(path)

In [0]:
def convert_mixed_time_polars(time_str: str) -> float:
    """
    Converts a time string in either 'Days.Hours:Minutes:Seconds' or 'Hours:Minutes:Seconds' format to seconds.
    """
    if time_str is None or time_str == "" or pd.isna(time_str):
        return None
    
    try:
        time_str = str(time_str).strip()
        if '.' in time_str:
            # Days.Hours:Minutes:Seconds format
            days_part, time_part = time_str.split('.', 1)
            days = int(days_part)
            h, m, s = map(int, time_part.split(':'))
            return days * 86400 + h * 3600 + m * 60 + s
        else:
            # Regular Hours:Minutes:Seconds format
            h, m, s = map(int, time_str.split(':'))
            return h * 3600 + m * 60 + s
    except:
        return None

def extract_time_features(df: pl.DataFrame) -> pl.DataFrame:
    """
    Extracts various time-based features from columns specified in config.date_cols.
    """
    expressions = []
    
    for col in config.date_cols:
        if col in df.columns:
            # Check if column is already datetime or needs conversion
            col_dtype = df[col].dtype
            
            if col_dtype in [pl.Datetime, pl.Datetime("ms"), pl.Datetime("us"), pl.Datetime("ns")]:
                # Already datetime, extract features directly
                expressions.extend([
                    pl.col(col).dt.hour().alias(f'{col}_hour'),
                    pl.col(col).dt.weekday().alias(f'{col}_dayofweek'),
                    pl.col(col).dt.month().alias(f'{col}_month'),
                    pl.col(col).dt.day().alias(f'{col}_day')
                ])
            elif col_dtype == pl.String:
                # String type, convert to datetime first then extract features
                expressions.extend([
                    pl.col(col).str.to_datetime().dt.hour().alias(f'{col}_hour'),
                    pl.col(col).str.to_datetime().dt.weekday().alias(f'{col}_dayofweek'),
                    pl.col(col).str.to_datetime().dt.month().alias(f'{col}_month'),
                    pl.col(col).str.to_datetime().dt.day().alias(f'{col}_day'),
                    pl.col(col).str.to_datetime().alias(col)  # Keep datetime version
                ])
            else:
                # Try to convert other types to datetime
                try:
                    expressions.extend([
                        pl.col(col).cast(pl.Datetime).dt.hour().alias(f'{col}_hour'),
                        pl.col(col).cast(pl.Datetime).dt.weekday().alias(f'{col}_dayofweek'),
                        pl.col(col).cast(pl.Datetime).dt.month().alias(f'{col}_month'),
                        pl.col(col).cast(pl.Datetime).dt.day().alias(f'{col}_day'),
                        pl.col(col).cast(pl.Datetime).alias(col)  # Keep datetime version
                    ])
                except:
                    print(f"Warning: Could not process datetime column {col} with dtype {col_dtype}")
                    continue
    
    if expressions:
        return df.with_columns(expressions)
    else:
        return df

def convert_duration_to_seconds(df: pl.DataFrame) -> pl.DataFrame:
    """
    Converts all columns containing 'duration' in their name to seconds,
    adds them as new columns with '_seconds' suffix, and drops the original columns.
    """
    duration_cols = [col for col in df.columns if 'duration' in col]
    
    expressions = []
    for col in duration_cols:
        print(f"Converting {col} to seconds")
        # Apply the conversion function and cast to Float32
        expressions.append(
            pl.col(col).map_elements(convert_mixed_time_polars, return_dtype=pl.Float64)
            .alias(f'{col}_seconds')
        )
    
    # Add new columns and drop original duration columns
    df = df.with_columns(expressions)
    df = df.drop(duration_cols)
    
    return df

In [0]:
def get_wait_time(df: pl.DataFrame, legs: List[int]) -> pl.DataFrame:
    """
    Calculate the wait time for given legs.
    """
    expressions = []
    
    for leg in legs:
        leg_duration_col = f'legs{leg}_duration_seconds'
        segment_duration_cols = [f'legs{leg}_segments{i}_duration_seconds' for i in range(4)]
        
        # Calculate sum of segment durations (filling nulls with 0)
        segment_sum = pl.sum_horizontal([
            pl.col(col).fill_null(0) for col in segment_duration_cols if col in df.columns
        ])
        
        # Calculate wait time
        expressions.append(
            (pl.col(leg_duration_col) - segment_sum).alias(f'wait_time_leg{leg}_seconds')
        )
    
    return df.with_columns(expressions)

In [0]:
def get_lead_booking_time(df: pl.DataFrame) -> pl.DataFrame:
    """
    Calculate the lead booking time for flights.
    """
    expressions = [
        # Lead time for departure
        (pl.col('legs0_departureAt') - pl.col('requestDate'))
        .dt.total_seconds().cast(pl.Float64).alias('lead_booking_time_seconds'),
        
        # Lead time for return
        (pl.col('legs1_departureAt') - pl.col('requestDate'))
        .dt.total_seconds().cast(pl.Float64).alias('lead_booking_time_wrt_return_seconds')
    ]
    
    return df.with_columns(expressions)

In [0]:
def get_total_length_of_trip(df: pl.DataFrame) -> pl.DataFrame:
    """
    Calculate the total trip duration.
    """
    expressions = [
        # Total trip length
        (pl.col('legs1_departureAt') - pl.col('legs0_departureAt'))
        .dt.total_seconds().cast(pl.Float64).alias('total_length_of_trip_seconds'),
    ]
    
    # Add percentage calculations
    df = df.with_columns(expressions)
    
    percentage_expressions = [
        (pl.col('legs0_duration_seconds') / pl.col('total_length_of_trip_seconds'))
        .cast(pl.Float64).alias('percentage_time_spent_in_flight_leg0'),
        
        (pl.col('legs1_duration_seconds') / pl.col('total_length_of_trip_seconds'))
        .cast(pl.Float64).alias('percentage_time_spent_in_flight_leg1')
    ]
    
    return df.with_columns(percentage_expressions)

In [0]:
def get_trip_type(df: pl.DataFrame) -> pl.DataFrame:
    """
    Calculate the trip type (round trip or single way).
    """
    expression = [
        pl.col('searchRoute').str.contains('/').fill_null(False).cast(pl.Int8).alias('trip_type')
    ]
    
    return df.with_columns(expression)

In [0]:
def get_number_of_stops(df: pl.DataFrame) -> pl.DataFrame:
    """
    Calculate the number of stops for each leg.
    """
    expressions = []
    
    for leg in config.legs:
        # Count non-null segments (segments 1, 2, 3 - segment 0 is always present)
        segment_cols = [f'legs{leg}_segments{i}_arrivalTo_airport_iata' for i in [1, 2, 3]]
        
        valid_segments = pl.sum_horizontal([
            pl.col(col).is_not_null().cast(pl.Int8) for col in segment_cols if col in df.columns
        ])
        
        # Total segments = 1 (segment 0) + number of additional segments
        expressions.append((1 + valid_segments).alias(f'num_segs_leg{leg}'))
    
    return df.with_columns(expressions)

In [0]:
def get_flight_changes_across_segments(df: pl.DataFrame) -> pl.DataFrame:
    """
    Calculate the number of aircraft changes across segments.
    """
    expressions = []
    
    for leg in config.legs:
        changes_list = []
        
        for segment in range(len(config.segments) - 1):
            current_col = f'legs{leg}_segments{segment}_aircraft_code'
            next_col = f'legs{leg}_segments{segment+1}_aircraft_code'
            
            if current_col in df.columns and next_col in df.columns:
                # Count changes where both values are not null and different
                change_expr = (
                    (pl.col(current_col) != pl.col(next_col)) & 
                    pl.col(current_col).is_not_null() & 
                    pl.col(next_col).is_not_null()
                ).cast(pl.Int8)
                
                changes_list.append(change_expr)
        
        if changes_list:
            total_changes = pl.sum_horizontal(changes_list)
            expressions.append(total_changes.alias(f'aircraft_changes_leg{leg}'))
    
    return df.with_columns(expressions)

In [0]:
def get_cabin_changes_across_segments(df: pl.DataFrame) -> pl.DataFrame:
    """
    Calculate the number of cabin changes across segments.
    """
    expressions = []
    
    for leg in config.legs:
        changes_list = []
        
        for segment in range(len(config.segments) - 1):
            current_col = f'legs{leg}_segments{segment}_cabinClass'
            next_col = f'legs{leg}_segments{segment+1}_cabinClass'
            
            if current_col in df.columns and next_col in df.columns:
                change_expr = (
                    (pl.col(current_col) != pl.col(next_col)) & 
                    pl.col(current_col).is_not_null() & 
                    pl.col(next_col).is_not_null()
                ).cast(pl.Int8)
                
                changes_list.append(change_expr)
        
        if changes_list:
            total_changes = pl.sum_horizontal(changes_list)
            expressions.append(total_changes.alias(f'cabin_changes_leg{leg}'))
    
    return df.with_columns(expressions)

In [0]:
def get_baggage_quantity_changes_across_segments(df: pl.DataFrame) -> pl.DataFrame:
    """
    Calculate the number of baggage quantity changes across segments.
    """
    expressions = []
    
    for leg in config.legs:
        changes_list = []
        
        for segment in range(len(config.segments) - 1):
            current_col = f'legs{leg}_segments{segment}_baggageAllowance_quantity'
            next_col = f'legs{leg}_segments{segment+1}_baggageAllowance_quantity'
            
            if current_col in df.columns and next_col in df.columns:
                change_expr = (
                    (pl.col(current_col) != pl.col(next_col)) & 
                    pl.col(current_col).is_not_null() & 
                    pl.col(next_col).is_not_null()
                ).cast(pl.Int8)
                
                changes_list.append(change_expr)
        
        if changes_list:
            total_changes = pl.sum_horizontal(changes_list)
            expressions.append(total_changes.alias(f'baggage_quantity_changes_leg{leg}'))
    
    return df.with_columns(expressions)

In [0]:
def is_frequent_flyer_airline(df: pl.DataFrame) -> pl.DataFrame:
    """
    Check if the flight is using a frequent flyer airline.
    """
    expressions = []
    
    for leg in config.legs:
        for segment in config.segments:
            carrier_col = f'legs{leg}_segments{segment}_marketingCarrier_code'
            
            if carrier_col in df.columns:
                # Check if carrier code is contained in frequentFlyer string
                expr = (
                    pl.col('frequentFlyer').str.contains(pl.col(carrier_col)) &
                    pl.col(carrier_col).is_not_null() &
                    pl.col('frequentFlyer').is_not_null()
                ).fill_null(False).cast(pl.Int8).alias(f'is_frequent_flyer_airline_leg{leg}_segment{segment}')
                
                expressions.append(expr)
    
    return df.with_columns(expressions)

In [0]:
def get_tax_as_percentage_of_price(df: pl.DataFrame) -> pl.DataFrame:
    """
    Calculate the tax as a percentage of the price.
    """
    expression = [
        (pl.col('taxes') / (1+pl.col('totalPrice'))).cast(pl.Float64).alias('tax_as_percentage_of_price')
    ]
    
    return df.with_columns(expression)

In [0]:
# --- NEW FUNCTION ---
def get_rank_features(df: pl.DataFrame, cols_to_rank: List[str], group_col: str = None) -> pl.DataFrame:
    """
    Creates rank features for the specified columns.
    If a group_col is provided, the rank is calculated within each group.
    
    Args:
        df: The input Polars DataFrame.
        cols_to_rank: A list of column names to create rank features for.
        group_col: The column to group by for ranking (e.g., 'searchId').
                   If None, ranks globally.
                   
    Returns:
        A Polars DataFrame with the new rank features.
    """
    print(f"Generating rank features for: {cols_to_rank}")
    expressions = []
    for col in cols_to_rank:
        if col in df.columns:
            if group_col and group_col in df.columns:
                # Rank within each group
                print(f"Ranking {col} over {group_col}")
                expressions.append(
                    pl.col(col).rank("ordinal").over(group_col).alias(f'{col}_rank_in_group')
                )
            else:
                # Rank globally
                print(f"Ranking {col} globally")
                expressions.append(
                    pl.col(col).rank("ordinal").alias(f'{col}_rank')
                )
    
    return df.with_columns(expressions)

In [0]:
def get_percentile_features(df: pl.DataFrame, cols_to_percentile: list, group_col: str = 'ranker_id') -> pl.DataFrame:
    """
    Safe version of percentile calculation that handles edge cases
    """
    print(f"Generating SAFE percentile features for: {cols_to_percentile}")
    
    expressions = []
    for col in cols_to_percentile:
        if col in df.columns:
            print(f"Creating safe percentile for {col} within {group_col}")
            
            # More robust percentile calculation
            expressions.append(
                pl.when(pl.col(col).is_null())
                .then(None)  # Keep nulls as nulls
                .when(pl.col(col).count().over(group_col) <= 1)
                .then(50.0)  # Single item gets median percentile
                .otherwise(
                    # Safe percentile calculation
                    pl.when(pl.col(col).count().over(group_col) > 1)
                    .then(
                        ((pl.col(col).rank("ordinal").over(group_col) - 1.0) / 
                         (pl.col(col).count().over(group_col) - 1.0) * 100.0)
                        .clip(0.0, 100.0)  # Ensure percentiles are between 0-100
                        .round(2)
                    )
                    .otherwise(50.0)
                )
                .cast(pl.Float32)
                .alias(f'{col}_percentile_in_group')
            )
    
    return df.with_columns(expressions)

In [0]:
def is_min_segments_total(df: pl.DataFrame) -> pl.DataFrame:
    print("Calculating is_min_total_segments")
    expressions = [
        # Calculate total segments across both legs
        (pl.col('num_segs_leg0') + pl.col('num_segs_leg1')).alias('total_segments'),
        
        # Fixed: recalculate the sum instead of referencing the new column
        ((pl.col('num_segs_leg0') + pl.col('num_segs_leg1')) == 
         (pl.col('num_segs_leg0') + pl.col('num_segs_leg1')).min().over('ranker_id'))
        .cast(pl.Int8).alias('is_min_segments')
    ]
    
    return df.with_columns(expressions)

def is_min_segments_per_leg(df: pl.DataFrame) -> pl.DataFrame:
    """
    Check if each leg has minimum segments separately
    """
    print("Calculating is_min_total_segments per leg")
    if 'num_segs_leg0' not in df.columns or 'num_segs_leg1' not in df.columns:
        raise ValueError("num_segs_leg0 and num_segs_leg1 must be calculated first")
    
    expressions = [
        # Check if leg 0 has minimum segments in group
        (pl.col('num_segs_leg0') == pl.col('num_segs_leg0').min().over('ranker_id'))
        .cast(pl.Int8).alias('is_min_segments_leg0'),
        
        # Check if leg 1 has minimum segments in group  
        (pl.col('num_segs_leg1') == pl.col('num_segs_leg1').min().over('ranker_id'))
        .cast(pl.Int8).alias('is_min_segments_leg1'),
        
        # Check if BOTH legs have minimum segments
        ((pl.col('num_segs_leg0') == pl.col('num_segs_leg0').min().over('ranker_id')) &
         (pl.col('num_segs_leg1') == pl.col('num_segs_leg1').min().over('ranker_id')))
        .cast(pl.Int8).alias('is_min_segments_both_legs')
    ]
    
    return df.with_columns(expressions)

In [0]:
def create_cv_aware_aggregate_features_no_leakage(
    train_df: pl.DataFrame, 
    test_df: pl.DataFrame,
    agg_configs: Dict[str, Dict]
) -> (pl.DataFrame, pl.DataFrame):
    """
    Creates aggregate features without any leakage for fold-specific test features.
    
    KEY FIX: Uses original raw training data for test feature calculation,
    not the contaminated recombined data.
    """
    print("Creating CV-aware aggregate features (LEAKAGE-FREE VERSION)...")
    
    grouping_col = 'ranker_id'
    if grouping_col not in train_df.columns:
        raise ValueError(f"`{grouping_col}` column required for GroupKFold.")

    # Store the ORIGINAL raw training data before any aggregate features
    original_train_df = train_df.clone()
    
    # --- Fold Assignment ---
    print("Assigning folds...")
    gkf = GroupKFold(n_splits=5) #, shuffle=True, random_state=2025)
    groups = train_df.get_column(grouping_col)
    
    fold_assignments = np.zeros(len(train_df))
    for i, (_, val_idx) in enumerate(gkf.split(X=np.zeros(len(train_df)), groups=groups)):
        fold_assignments[val_idx] = i
    
    train_df = train_df.with_columns(pl.Series("fold", fold_assignments, dtype=pl.Int8))
    original_train_df = original_train_df.with_columns(pl.Series("fold", fold_assignments, dtype=pl.Int8))
    
    # --- CV Loop for Training Features ---
    list_of_enriched_folds = []
    
    for i in range(5):
        print(f"--- Processing Training Fold {i} ---")
        
        # Use ORIGINAL data for calculating aggregates (no contamination)
        df_train_fold = original_train_df.filter(pl.col('fold') != i)
        df_val_fold = original_train_df.filter(pl.col('fold') == i)

        val_fold_with_features = df_val_fold

        for new_feat_name, config in agg_configs.items():
            print(f"  Calculating: {new_feat_name}")
            
            if config.get('filter_cond') is not None:
                agg_base = df_train_fold.filter(config['filter_cond'])
            else:
                agg_base = df_train_fold
            
            agg_df = agg_base.group_by(config['group_by']).agg(
                getattr(pl.col(config['agg_col']), config['agg_func'])().alias(new_feat_name)
            )

            val_fold_with_features = val_fold_with_features.join(agg_df, on=config['group_by'], how='left')

            del agg_base, agg_df
            gc.collect()
            
        list_of_enriched_folds.append(val_fold_with_features)
        del df_train_fold, df_val_fold, val_fold_with_features
        gc.collect()

    print("Recombining training folds...")
    train_df_final = pl.concat(list_of_enriched_folds)
    del list_of_enriched_folds
    gc.collect()

    # --- LEAKAGE-FREE Test Feature Generation ---
    print("\n--- Processing LEAKAGE-FREE Test Set Features ---")
    
    test_with_features = test_df
    if test_df is not None:
        # Create features for each fold using ORIGINAL raw data
        for fold_i in range(5):
            print(f"Creating test features for fold {fold_i} (using original raw data)...")
            
            # CRITICAL FIX: Use original_train_df instead of contaminated train_df_final
            df_train_for_test = original_train_df.filter(pl.col('fold') != fold_i)
            
            for new_feat_name, config in agg_configs.items():
                print(f"  Calculating: {new_feat_name}_fold{fold_i}")
                
                if config.get('filter_cond') is not None:
                    agg_base = df_train_for_test.filter(config['filter_cond'])
                else:
                    agg_base = df_train_for_test
                
                fold_agg_df = agg_base.group_by(config['group_by']).agg(
                    getattr(pl.col(config['agg_col']), config['agg_func'])()
                    .alias(f"{new_feat_name}_fold{fold_i}")
                )
                
                test_with_features = test_with_features.join(
                    fold_agg_df, on=config['group_by'], how='left'
                )
                
                del fold_agg_df, agg_base
                gc.collect()
            
            del df_train_for_test
            gc.collect()
        
        print(f"\nCompleted! Test set now has fold-specific features without leakage")
        print(f"Train shape: {train_df_final.shape}")
        print(f"Test shape: {test_with_features.shape}")
    
    return train_df_final, test_with_features

In [0]:
# ================================================================
# TIER 1 HIGH-PRIORITY FEATURE ENGINEERING FUNCTIONS
# Optimized for large Polars DataFrames with vectorized operations
# ================================================================

import polars as pl
import numpy as np

import polars as pl

def create_segment_tier_position_features(df: pl.DataFrame) -> pl.DataFrame:
    """
    Creates position and ranking features within segment tiers.
    This captures the key insight that options are ordered by segments,
    but there are preferences within each segment tier.

    Note: Uses existing 'total_segments' from your pipeline
    """
    print("Creating segment tier position features...")

    # Define a helper expression to find the minimum price for the minimum segment tier
    # We use a when/then/otherwise approach to isolate the price, which is more idiomatic
    # than filtering within a window function.
    # min_price_in_min_segment_tier = (
    #     pl.when(pl.col("total_segments") == pl.col("total_segments").min().over("ranker_id"))
    #     .then(pl.col("totalPrice"))
    #     .otherwise(None)
    #     .min()
    #     .over("ranker_id")
    # )

    min_price_in_min_segment_tier = (
        pl.col("totalPrice")
        .sort_by(["total_segments", "totalPrice"])
        .first()
        .over("ranker_id")
    )

    expressions = [
        # Position within same segment tier (ranked by price)
        pl.col('totalPrice').rank('ordinal').over(['ranker_id', 'total_segments']).alias('position_within_segment_tier'),

        # Count of options in this segment tier
        pl.col('Id').count().over(['ranker_id', 'total_segments']).alias('options_in_segment_tier'),

        # Position as percentage within tier (0-100)
        ((pl.col('totalPrice').rank('ordinal').over(['ranker_id', 'total_segments']) - 1) /
         (pl.col('Id').count().over(['ranker_id', 'total_segments']) - 1) * 100)
        .fill_null(50.0).alias('position_pct_within_segment_tier'),

        # Distance from minimum segment option
        (pl.col('total_segments') - pl.col('total_segments').min().over('ranker_id'))
        .alias('segment_distance_from_minimum'),

        # Price premium for extra segments vs minimum segment option - REFACTORED
        (pl.col('totalPrice') - min_price_in_min_segment_tier)
        .alias('price_premium_vs_min_segments'),

        # Segment tier (1=minimum, 2=min+1, etc.)
        (pl.col('total_segments') - pl.col('total_segments').min().over('ranker_id') + 1)
        .alias('segment_tier')
    ]

    return df.with_columns(expressions)


def create_value_gap_features(df: pl.DataFrame) -> pl.DataFrame:
    """
    Creates features that capture the value proposition and trade-offs
    between different options within each search.
    """
    print("Creating value gap features...")

    # --- FIX: Pre-calculate min price for each tier to look up the previous one ---
    # Create a temporary dataframe with min price for each segment tier
    min_price_per_tier = df.group_by("ranker_id", "total_segments").agg(
        pl.col("totalPrice").min().alias("min_price_of_tier")
    )

    # Join this back to the main df, matching the *next* tier's price to the current row
    df_with_gaps = df.join(
        min_price_per_tier.with_columns(
            (pl.col("total_segments") + 1).alias("next_tier") # Create join key
        ).rename({"min_price_of_tier": "price_of_better_segment_tier"}),
        left_on=["ranker_id", "total_segments"],
        right_on=["ranker_id", "next_tier"],
        how="left"
    )
    # --- End of Fix ---

    expressions = [
        # Price per segment premium (how much extra per additional segment)
        (pl.col('price_premium_vs_min_segments') /
         pl.col('segment_distance_from_minimum').clip(1, None))
        .alias('price_per_extra_segment'),

        # Time savings per dollar spent (for convenience premium)
        (pl.when(pl.col('price_premium_vs_min_segments') > 0)
         .then((pl.col('wait_time_leg0_seconds').max().over('ranker_id') +
                pl.col('wait_time_leg1_seconds').max().over('ranker_id') -
                pl.col('wait_time_leg0_seconds') - pl.col('wait_time_leg1_seconds')) /
               pl.col('price_premium_vs_min_segments'))
         .otherwise(None))
        .alias('time_saved_per_dollar_premium'),

        # Relative value score (price rank inverted + segment preference)
        (100 - pl.col('totalPrice_percentile_in_group') +
         (pl.col('segment_distance_from_minimum').max().over('ranker_id') -
          pl.col('segment_distance_from_minimum')) * 20)
        .alias('convenience_value_score'),

        # Is this option in the "sweet spot" (good price within acceptable segments)
        ((pl.col('segment_tier') <= 2) &
         (pl.col('totalPrice_percentile_in_group') <= 50))
        .cast(pl.Int8).alias('is_sweet_spot_option'),

        # Price gap to next better segment tier - REWRITTEN
        # This now uses the pre-calculated column from the join
        (pl.col('price_of_better_segment_tier') - pl.col('totalPrice'))
        .alias('price_gap_to_better_segments')
    ]

    return df_with_gaps.with_columns(expressions).drop("price_of_better_segment_tier")

def create_user_convenience_profile_features(df: pl.DataFrame, group_col: str = 'ranker_id') -> pl.DataFrame:
    """
    Creates features that capture user's historical price vs convenience trade-off patterns.
    These will be used in the CV-aware aggregation system.
    """
    print("Creating user convenience profile base features...")
    
    expressions = [
        # User's typical segment choice patterns
        (pl.col('segment_distance_from_minimum') * pl.col('totalPrice'))
        .alias('segment_price_interaction'),
        
        # User's willingness to pay premium (selected options vs available cheaper)
        (pl.col('totalPrice') - pl.col('totalPrice').min().over(group_col))
        .alias('premium_paid_vs_cheapest'),
        
        # User's segment flexibility score
        pl.when(pl.col('segment_distance_from_minimum') > 0)
        .then(pl.col('price_premium_vs_min_segments') / pl.col('segment_distance_from_minimum'))
        .otherwise(0.0)
        .alias('segment_flexibility_score'),
        
        # Choose convenience over price indicator
        ((pl.col('totalPrice_percentile_in_group') > 50) & 
         (pl.col('segment_distance_from_minimum') == 0))
        .cast(pl.Int8).alias('chose_convenience_over_price'),
        
        # Choose price over convenience indicator  
        ((pl.col('totalPrice_percentile_in_group') <= 30) & 
         (pl.col('segment_distance_from_minimum') > 0))
        .cast(pl.Int8).alias('chose_price_over_convenience'),
        
        # Premium segment choice (chose more segments despite higher price)
        ((pl.col('segment_distance_from_minimum') > 0) & 
         (pl.col('totalPrice_percentile_in_group') > 40))
        .cast(pl.Int8).alias('premium_segment_choice')
    ]
    
    return df.with_columns(expressions)


def create_company_travel_policy_features(df: pl.DataFrame) -> pl.DataFrame:
    """
    Creates features that capture company travel policy flexibility and patterns.
    These indicate how strict/flexible the company's travel policy is.
    """
    print("Creating company travel policy base features...")
    
    expressions = [
        # Company size indicator (number of unique travelers)
        pl.col('profileId').n_unique().over('companyID').alias('company_traveler_count'),
        
        # Price discipline indicator (how often they choose cheaper options)
        (pl.col('totalPrice_percentile_in_group') <= 25).cast(pl.Int8).alias('chose_bottom_quartile_price'),
        
        # Segment discipline (how often they choose minimum segments)
        (pl.col('segment_distance_from_minimum') == 0).cast(pl.Int8).alias('chose_minimum_segments'),
        
        # Premium policy indicator (frequent high-price choices)
        (pl.col('totalPrice_percentile_in_group') >= 75).cast(pl.Int8).alias('chose_top_quartile_price'),
        
        # Policy flexibility score
        (pl.col('totalPrice_percentile_in_group') * pl.col('segment_distance_from_minimum'))
        .alias('policy_flexibility_interaction'),
        
        # Company route specialization (how much they use this specific route)
        pl.col('searchRoute').count().over(['companyID', 'searchRoute']).alias('company_route_frequency'),
        
        # Travel intensity (total searches by this company)
        pl.col('ranker_id').count().over('companyID').alias('company_travel_intensity')
    ]
    
    return df.with_columns(expressions)


def create_route_familiarity_features(df: pl.DataFrame) -> pl.DataFrame:
    """
    Creates features capturing route familiarity and frequency patterns.
    """
    print("Creating route familiarity features...")

    # --- FIX: Break nested window function into two steps ---
    # 1. First, calculate the user_route_frequency
    df_with_freq = df.with_columns(
        pl.col('searchRoute').count().over(['profileId', 'searchRoute']).alias('user_route_frequency')
    )
    # 2. Now, create the comparison expression using the new column
    is_frequent_route_expr = (
        pl.col('user_route_frequency') >= pl.col('user_route_frequency').quantile(0.8).over('profileId')
    ).cast(pl.Int8)
    # --- End of Fix ---


    expressions = [
        # User's experience with this route (already calculated)
        pl.col('user_route_frequency'),

        # Route popularity overall
        pl.col('searchRoute').count().over('searchRoute').alias('overall_route_popularity'),

        # User's route specialization (what % of their travel is this route)
        (pl.col('user_route_frequency') /
         pl.col('ranker_id').count().over('profileId') * 100)
        .alias('user_route_specialization_pct'),

        # Company's experience with this route
        pl.col('searchRoute').count().over(['companyID', 'searchRoute']).alias('company_route_experience'),

        # Is this a frequent route for user? (top 20% of their routes) - REWRITTEN
        is_frequent_route_expr.alias('is_frequent_route_for_user'),

        # Route diversity score (how many different routes this user books)
        pl.col('searchRoute').n_unique().over('profileId').alias('user_route_diversity'),

        # New route for user indicator
        (pl.col('user_route_frequency') == 1).cast(pl.Int8).alias('is_new_route_for_user'),

        # Corporate route standardization (what % of company uses this route)
        (pl.col('searchRoute').count().over(['companyID', 'searchRoute']) /
         pl.col('ranker_id').count().over('companyID') * 100)
        .alias('corporate_route_standardization_pct')
    ]

    return df_with_freq.with_columns(expressions)

def create_advanced_ranking_features(df: pl.DataFrame, group_col: str = 'ranker_id') -> pl.DataFrame:
    """
    Creates advanced ranking features that capture relative positioning and competition.
    """
    print("Creating advanced ranking features...")

    # Define these expressions separately for clarity due to their dependency
    price_quintile_expr = (pl.col('totalPrice_percentile_in_group').round(0) / 20).cast(pl.Int8)
    segment_tier_expr = pl.col('segment_tier').cast(pl.String)

    expressions = [
        # Rank within each segment tier by various dimensions
        pl.col('totalPrice').rank('ordinal').over([group_col, 'segment_tier']).alias('price_rank_within_segment_tier'),
        pl.col('wait_time_leg0_seconds').rank('ordinal').over([group_col, 'segment_tier']).alias('wait_time_rank_within_segment_tier'),
        pl.col('lead_booking_time_seconds').rank('ordinal').over([group_col, 'segment_tier']).alias('lead_time_rank_within_segment_tier'),

        # Best option indicators within segment tier
        (pl.col('totalPrice') == pl.col('totalPrice').min().over([group_col, 'segment_tier']))
        .cast(pl.Int8).alias('is_cheapest_in_segment_tier'),
        (pl.col('wait_time_leg0_seconds') == pl.col('wait_time_leg0_seconds').min().over([group_col, 'segment_tier']))
        .cast(pl.Int8).alias('is_fastest_in_segment_tier'),

        # Competitive gap features - SYNTAX FIXED
        (pl.col('totalPrice') - pl.col('totalPrice').shift(1).sort_by('totalPrice').over(group_col))
        .alias('price_gap_to_next_cheaper'),
        (pl.col('totalPrice').shift(-1).sort_by('totalPrice').over(group_col) - pl.col('totalPrice'))
        .alias('price_gap_to_next_expensive'),

        # Market position indicators
        price_quintile_expr.alias('price_quintile'),
        (segment_tier_expr + "_" + price_quintile_expr.cast(pl.String)).alias('segment_price_category'),

        # Option scarcity/abundance
        pl.col('Id').count().over(group_col).alias('total_options_available'),
        (pl.col('Id').count().over(group_col) >= 10).cast(pl.Int8).alias('has_many_options'),
        (pl.col('Id').count().over(group_col) <= 3).cast(pl.Int8).alias('has_few_options')
    ]

    return df.with_columns(expressions)
  
high_priority_aggregation_configs = {
    # User convenience profile aggregations - NEW
    'avg_segment_flexibility_score_by_user': {
        'group_by': ['profileId'],
        'agg_col': 'segment_flexibility_score',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'std_segment_flexibility_score_by_user': {
        'group_by': ['profileId'],
        'agg_col': 'segment_flexibility_score',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'std'
    },
    'user_convenience_over_price_rate': {
        'group_by': ['profileId'],
        'agg_col': 'chose_convenience_over_price',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'user_price_over_convenience_rate': {
        'group_by': ['profileId'],
        'agg_col': 'chose_price_over_convenience',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'avg_premium_paid_vs_cheapest_by_user': {
        'group_by': ['profileId'],
        'agg_col': 'premium_paid_vs_cheapest',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    
    # Company policy flexibility aggregations - NEW
    'company_price_discipline_rate': {
        'group_by': ['companyID'],
        'agg_col': 'chose_bottom_quartile_price',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'company_segment_discipline_rate': {
        'group_by': ['companyID'],
        'agg_col': 'chose_minimum_segments',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'company_premium_policy_rate': {
        'group_by': ['companyID'],
        'agg_col': 'chose_top_quartile_price',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'avg_policy_flexibility_by_company': {
        'group_by': ['companyID'],
        'agg_col': 'policy_flexibility_interaction',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'std_policy_flexibility_by_company': {
        'group_by': ['companyID'],
        'agg_col': 'policy_flexibility_interaction',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'std'
    },
    
    # Route familiarity aggregations - NEW
    'user_avg_segment_tier_for_route': {
        'group_by': ['profileId', 'searchRoute'],
        'agg_col': 'segment_tier',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'user_price_percentile_for_route': {
        'group_by': ['profileId', 'searchRoute'],
        'agg_col': 'totalPrice_percentile_in_group',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'company_avg_segment_tier_for_route': {
        'group_by': ['companyID', 'searchRoute'],
        'agg_col': 'segment_tier',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    
    # Segment tier positioning aggregations - NEW (different from existing user_min_segments_selection_rate)
    'avg_position_within_segment_tier_by_user': {
        'group_by': ['profileId'],
        'agg_col': 'position_pct_within_segment_tier',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'company_segment_tier_preference': {
        'group_by': ['companyID'],
        'agg_col': 'segment_tier',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    
    # Value gap aggregations - NEW
    'user_avg_convenience_value_score': {
        'group_by': ['profileId'],
        'agg_col': 'convenience_value_score',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'user_sweet_spot_selection_rate': {
        'group_by': ['profileId'],
        'agg_col': 'is_sweet_spot_option',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'user_avg_price_per_extra_segment': {
        'group_by': ['profileId'],
        'agg_col': 'price_per_extra_segment',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    
    # Route experience aggregations - NEW
    'user_route_specialization_avg': {
        'group_by': ['profileId'],
        'agg_col': 'user_route_specialization_pct',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'company_route_standardization_avg': {
        'group_by': ['companyID'],
        'agg_col': 'corporate_route_standardization_pct',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    }
}

In [0]:
aggregation_configurations = {
    'avg_price_rank_by_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'totalPrice_rank_in_group',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'mean'
    },
    'std_price_rank_by_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'totalPrice_rank_in_group',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'std'
    },
    'avg_price_rank_by_comp_route_baggageallowance': {
        'group_by': ['companyID', 'searchRoute','legs0_segments0_baggageAllowance_quantity'], 'agg_col': 'totalPrice_rank_in_group',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'mean'
    },
    'std_price_rank_by_comp_route_baggageallowance': {
        'group_by': ['companyID', 'searchRoute','legs0_segments0_baggageAllowance_quantity'], 'agg_col': 'totalPrice_rank_in_group',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'std'
    },
    'avg_price_percentile_by_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'totalPrice_percentile_in_group',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'mean'
    },
    'std_price_percentile_by_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'totalPrice_percentile_in_group',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'std'
    },
    'median_price_percentile_by_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'totalPrice_percentile_in_group',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'median'
    },
    'avg_price_rank_by_comp_route_tariffcode': {
        'group_by': ['companyID', 'searchRoute','corporateTariffCode'], 'agg_col': 'totalPrice_rank_in_group',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'mean'
    },
    'std_price_rank_by_comp_route_tariffcode': {
        'group_by': ['companyID', 'searchRoute','corporateTariffCode'], 'agg_col': 'totalPrice_rank_in_group',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'std'
    },
    'avg_tax_as_percentage_of_price_rank_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'tax_as_percentage_of_price_rank_in_group',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'mean'
    },
    'std_tax_as_percentage_of_price_rank_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'tax_as_percentage_of_price_rank_in_group',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'std'
    },
    'avg_percentage_time_spent_in_flight_leg1_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'percentage_time_spent_in_flight_leg1',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'mean'
    },
    'std_percentage_time_spent_in_flight_leg1_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'percentage_time_spent_in_flight_leg1',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'std'
    },
    'avg_percentage_time_spent_in_flight_leg0_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'percentage_time_spent_in_flight_leg0',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'mean'
    },
    'std_percentage_time_spent_in_flight_leg0_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'percentage_time_spent_in_flight_leg0',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'std'
    },
    'avg_wait_time_leg0_seconds_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'wait_time_leg0_seconds',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'mean'
    },
    'std_wait_time_leg0_seconds_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'wait_time_leg0_seconds',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'std'
    },
    'avg_wait_time_leg1_seconds_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'wait_time_leg1_seconds',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'mean'
    },
    'std_wait_time_leg1_seconds_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'wait_time_leg1_seconds',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'std'
    },
    'avg_num_segs_leg0_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'num_segs_leg0',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'mean'
    },
    'std_num_segs_leg0_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'num_segs_leg0',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'std'
    },
   'avg_num_segs_leg0_comp_route_baggageallowance': {
        'group_by': ['companyID', 'searchRoute','legs0_segments0_baggageAllowance_quantity'], 'agg_col': 'num_segs_leg0',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'mean'
    },
    'std_num_segs_leg0_comp_route_baggageallowance': {
        'group_by': ['companyID', 'searchRoute','legs0_segments0_baggageAllowance_quantity'], 'agg_col': 'num_segs_leg0',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'std'
    },
    'avg_num_segs_leg1_comp_route':{
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'num_segs_leg1',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'mean'
    },
    'std_num_segs_leg1_comp_route': {
        'group_by': ['companyID', 'searchRoute'], 'agg_col': 'num_segs_leg1',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'std'
    },
   'avg_num_segs_leg0_comp_route_miniRules0_statusInfos': {
        'group_by': ['companyID', 'searchRoute','miniRules0_statusInfos'], 'agg_col': 'num_segs_leg0',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'mean'
    },
    'std_num_segs_leg0_comp_route_miniRules0_statusInfos': {
        'group_by': ['companyID', 'searchRoute','miniRules0_statusInfos'], 'agg_col': 'num_segs_leg0',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'std'
    },
   'avg_num_segs_leg1_comp_route_miniRules1_statusInfos': {
        'group_by': ['companyID', 'searchRoute','miniRules1_statusInfos'], 'agg_col': 'num_segs_leg1',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'mean'
    },
    'std_num_segs_leg1_comp_route_miniRules1_statusInfos': {
        'group_by': ['companyID', 'searchRoute','miniRules1_statusInfos'], 'agg_col': 'num_segs_leg1',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'std'
    },
    'max_num_segs_leg1_comp_route_miniRules1_statusInfos': {
        'group_by': ['companyID', 'searchRoute','miniRules1_statusInfos'], 'agg_col': 'num_segs_leg1',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'max'
    },
    'min_num_segs_leg1_comp_route_miniRules1_statusInfos': {
        'group_by': ['companyID', 'searchRoute','miniRules1_statusInfos'], 'agg_col': 'num_segs_leg1',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'min'
    },
    'max_num_segs_leg0_comp_route_miniRules0_statusInfos': {
        'group_by': ['companyID', 'searchRoute','miniRules0_statusInfos'], 'agg_col': 'num_segs_leg0',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'max'
    },
    'min_num_segs_leg0_comp_route_miniRules0_statusInfos': {
        'group_by': ['companyID', 'searchRoute','miniRules0_statusInfos'], 'agg_col': 'num_segs_leg0',
        'filter_cond': (pl.col('selected') == 1), 'agg_func': 'min'
    },
    'avg_total_price_by_user': {
        'group_by': ['profileId'],
        'agg_col': 'totalPrice',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'avg_price_percentile_by_user': {
        'group_by': ['profileId'],
        'agg_col': 'totalPrice_percentile_in_group',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'std_price_percentile_by_user': {
        'group_by': ['profileId'],
        'agg_col': 'totalPrice_percentile_in_group',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'std'
    },
    'avg_price_rank_by_user': {
        'group_by': ['profileId'],
        'agg_col': 'totalPrice_rank_in_group',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'avg_stops_leg0_by_user': {
        'group_by': ['profileId'],
        'agg_col': 'num_segs_leg0',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'avg_stops_leg1_by_user': {
        'group_by': ['profileId'],
        'agg_col': 'num_segs_leg1',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'user_min_segments_selection_rate': {
        'group_by': ['profileId'],
        'agg_col': 'is_min_segments_both_legs',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'avg_wait_time_leg0_by_user': {
        'group_by': ['profileId'],
        'agg_col': 'wait_time_leg0_seconds',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'selection_rate_by_airline_for_user': {
        'group_by': ['profileId', 'legs0_segments0_marketingCarrier_code'],
        'agg_col': 'selected',
        'filter_cond': None, # We aggregate the target directly
        'agg_func': 'mean'
    },
    'avg_baggage_allowance_by_user': {
        'group_by': ['profileId'],
        'agg_col': 'legs0_segments0_baggageAllowance_quantity',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'avg_lead_time_by_user': {
        'group_by': ['profileId'],
        'agg_col': 'lead_booking_time_seconds',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'avg_passenger_count_by_user': {
        'group_by': ['profileId'],
        'agg_col': 'pricingInfo_passengerCount',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    }
}

# miniRules0_statusInfos

In [0]:
def create_passenger_seat_ratio_features(df: pl.DataFrame) -> pl.DataFrame:
    """
    Calculates the ratio of passenger count to available seats for specific flight segments.

    Args:
        df: The input Polars DataFrame.

    Returns:
        The DataFrame with four new ratio columns added.
    """
    # Define the seat availability columns to create ratios for
    seat_cols = [
        'legs0_segments0_seatsAvailable',
        'legs0_segments1_seatsAvailable',
        'legs1_segments0_seatsAvailable',
        'legs1_segments1_seatsAvailable'
    ]

    # Create a list of expressions to generate the new ratio columns
    ratio_expressions = [
        (
            pl.col('pricingInfo_passengerCount') /
            # Safely handle the denominator: replace 0 or null seats with null
            pl.when(pl.col(col) > 0).then(pl.col(col))
        ).alias(f"passenger_to_{col}_ratio")
        for col in seat_cols
    ]

    # Return the dataframe with the new columns
    return df.with_columns(ratio_expressions)

In [0]:
def create_route_segment_intelligence_features(df: pl.DataFrame) -> pl.DataFrame:
    """
    Creates route-specific segment intelligence features.
    NEW: Route-normalized segment analysis (not in existing code)
    """
    print("Creating route-specific segment intelligence features...")
    
    expressions = [
        # Route segment complexity baseline - NEW
        pl.col('total_segments').mean().over('searchRoute').alias('route_avg_segments_all_options'),
        pl.col('total_segments').min().over('searchRoute').alias('route_min_segments_available'),
        pl.col('total_segments').max().over('searchRoute').alias('route_max_segments_available'),
        pl.col('total_segments').std().over('searchRoute').alias('route_segment_diversity'),
        
        # How does this option compare to route norms? - NEW
        (pl.col('total_segments') - pl.col('total_segments').mean().over('searchRoute'))
        .alias('segments_deviation_from_route_avg'),
        
        # Is this the minimum segment option for this route? - NEW (different from global min)
        (pl.col('total_segments') == pl.col('total_segments').min().over('searchRoute'))
        .cast(pl.Int8).alias('is_min_segments_for_route'),
        
        # Route segment tier (different from global segment_tier) - NEW
        (pl.col('total_segments') - pl.col('total_segments').min().over('searchRoute') + 1)
        .alias('route_specific_segment_tier'),
        
        # Route segment complexity score - NEW
        pl.col('total_segments').n_unique().over('searchRoute').alias('route_segment_complexity_score'),
        
        # Position in route segment distribution - NEW
        ((pl.col('total_segments').rank('ordinal').over('searchRoute') - 1) /
         (pl.col('Id').count().over('searchRoute') - 1) * 100)
        .fill_null(50.0).alias('route_segment_percentile')
    ]
    
    return df.with_columns(expressions)


def create_company_segment_discipline_features(df: pl.DataFrame) -> pl.DataFrame:
    """
    Creates company segment discipline features.
    NEW: Quantifies company policy consistency (not in existing code)
    """
    print("Creating company segment discipline features...")
    
    expressions = [
        # Company segment choice consistency - NEW
        pl.col('total_segments').std().over('companyID').alias('company_segment_choice_std'),
        pl.col('total_segments').mean().over('companyID').alias('company_avg_segments_all_searches'),
        
        # Company segment discipline indicators - NEW
        (pl.col('is_min_segments') == 1).cast(pl.Int8).mean().over('companyID')
        .alias('company_min_segment_selection_rate_all'),
        
        # Company segment override patterns - NEW
        (pl.col('segment_distance_from_minimum') > 0).cast(pl.Int8).mean().over('companyID')
        .alias('company_segment_override_rate_all'),
        
        # Company segment flexibility within routes - NEW
        pl.col('segment_distance_from_minimum').std().over(['companyID', 'searchRoute'])
        .alias('company_route_segment_flexibility'),
        
        # Company segment consistency score - NEW
        (1 / (1 + pl.col('total_segments').std().over('companyID')))
        .alias('company_segment_consistency_score'),
        
        # Company direct flight preference rate - NEW
        pl.when(pl.col('segment_distance_from_minimum') == 0)
        .then(1)
        .otherwise(0)
        .cast(pl.Int8)
        .mean()
        .over('companyID')
        .alias('company_direct_flight_preference_rate')
    ]
    
    return df.with_columns(expressions)


def create_user_segment_preference_features(df: pl.DataFrame) -> pl.DataFrame:
    """
    Creates user-specific segment preference features.
    NEW: User segment consistency and deviation patterns (not in existing code)
    """
    print("Creating user segment preference features...")
    
    expressions = [
        # User segment choice patterns - NEW
        pl.col('total_segments').mean().over('profileId').alias('user_avg_segments_all_searches'),
        pl.col('total_segments').std().over('profileId').alias('user_segment_choice_std'),
        
        # User segment preference strength - NEW
        (pl.col('is_min_segments') == 1).cast(pl.Int8).mean().over('profileId')
        .alias('user_min_segment_preference_rate'),
        
        # User segment flexibility - NEW
        pl.col('segment_distance_from_minimum').mean().over('profileId')
        .alias('user_avg_segment_distance_from_min'),
        
        # User route-specific segment patterns - NEW (different from existing route features)
        pl.col('total_segments').mean().over(['profileId', 'searchRoute'])
        .alias('user_route_specific_avg_segments'),
        
        # User segment consistency score - NEW
        (1 / (1 + pl.col('total_segments').std().over('profileId')))
        .alias('user_segment_consistency_score'),
        
        # User deviation from company norm - NEW
        (pl.col('total_segments') - pl.col('total_segments').mean().over('companyID'))
        .alias('user_segments_vs_company_norm'),
        
        # User segment override pattern - NEW
        (pl.col('segment_distance_from_minimum') > 0).cast(pl.Int8)
        .mean().over('profileId').alias('user_segment_override_rate')
    ]
    
    return df.with_columns(expressions)


# ================================================================
# TIER 2: HIGH IMPACT, SLIGHTLY MORE COMPLEX (DUPLICATES REMOVED)
# ================================================================

def create_within_search_competitive_context(df: pl.DataFrame) -> pl.DataFrame:
    """
    Creates features that capture competitive context within each search.
    NEW: Segment-specific competition analysis (builds on existing has_few/many_options)
    """
    print("Creating within-search competitive context features...")
    
    expressions = [
        # Segment availability distribution in this search - NEW
        pl.col('total_segments').n_unique().over('ranker_id')
        .alias('segment_tiers_available_in_search'),
        
        # Count of options in each segment tier - NEW (more specific than existing)
        pl.col('Id').count().over(['ranker_id', 'total_segments'])
        .alias('options_in_this_segment_tier'),
        
        # Is this the only direct option? - NEW
        ((pl.col('total_segments') == pl.col('total_segments').min().over('ranker_id')) &
         (pl.col('Id').count().over(['ranker_id', 'total_segments']) == 1))
        .cast(pl.Int8).alias('is_only_direct_option'),
        
        # Segment tier dominance - NEW
        (pl.col('Id').count().over(['ranker_id', 'total_segments']) /
         pl.col('Id').count().over('ranker_id') * 100)
        .alias('segment_tier_dominance_pct'),
        
        # Price spread within segment tier - NEW
        (pl.col('totalPrice').max().over(['ranker_id', 'total_segments']) -
         pl.col('totalPrice').min().over(['ranker_id', 'total_segments']))
        .alias('price_spread_within_segment_tier'),
        
        # Segment tier scarcity score - NEW
        pl.when(pl.col('Id').count().over(['ranker_id', 'total_segments']) == 1)
        .then(100)
        .when(pl.col('Id').count().over(['ranker_id', 'total_segments']) <= 2)
        .then(75)
        .when(pl.col('Id').count().over(['ranker_id', 'total_segments']) <= 3)
        .then(50)
        .otherwise(25)
        .alias('segment_tier_scarcity_score'),
        
        # Search segment complexity - NEW
        (pl.col('total_segments').max().over('ranker_id') - 
         pl.col('total_segments').min().over('ranker_id'))
        .alias('search_segment_range'),
        
        # Is this a complex search? - NEW
        (pl.col('total_segments').n_unique().over('ranker_id') >= 3)
        .cast(pl.Int8).alias('is_complex_segment_search')
    ]
    
    return df.with_columns(expressions)


def create_enhanced_company_travel_profile(df: pl.DataFrame) -> pl.DataFrame:
    """
    Creates enhanced company travel profile features.
    NEW: Company size inference and department analysis (not in existing code)
    REMOVED: Duplicates of company_traveler_count and company_travel_intensity
    """
    print("Creating enhanced company travel profile features...")
    
    expressions = [
        # Company travel intensity per traveler - NEW (derived from existing)
        (pl.col('ranker_id').count().over('companyID') /
         pl.col('profileId').n_unique().over('companyID'))
        .alias('company_searches_per_traveler'),
        
        # Company route portfolio - NEW
        pl.col('searchRoute').n_unique().over('companyID')
        .alias('company_route_portfolio_size'),
        
        # Company travel diversity - NEW
        (pl.col('searchRoute').n_unique().over('companyID') /
         pl.col('ranker_id').count().over('companyID'))
        .alias('company_route_diversity_ratio'),
        
        # Company size tier inference - NEW
        pl.when(pl.col('profileId').n_unique().over('companyID') >= 50)
        .then(3)  # Large
        .when(pl.col('profileId').n_unique().over('companyID') >= 10)
        .then(2)  # Medium
        .otherwise(1)  # Small
        .alias('company_size_tier'),
        
        # Company travel budget patterns - NEW
        pl.col('totalPrice').mean().over('companyID')
        .alias('company_avg_booking_price'),
        
        pl.col('totalPrice').std().over('companyID')
        .alias('company_price_variance'),
        
        # Company policy indicators - NEW
        (pl.col('totalPrice').std().over('companyID') /
         pl.col('totalPrice').mean().over('companyID'))
        .alias('company_price_coefficient_variation'),
        
        # Department/role inference - NEW
        pl.when(pl.col('totalPrice').mean().over('profileId') >= 
                pl.col('totalPrice').quantile(0.8).over('companyID'))
        .then(3)  # Executive/Senior
        .when(pl.col('totalPrice').mean().over('profileId') >= 
              pl.col('totalPrice').quantile(0.5).over('companyID'))
        .then(2)  # Mid-level
        .otherwise(1)  # Junior/Standard
        .alias('inferred_traveler_tier'),
        
        # Company negotiation power - NEW
        (pl.col('ranker_id').count().over('companyID') >= 100)
        .cast(pl.Int8).alias('has_high_volume_negotiation_power')
    ]
    
    return df.with_columns(expressions)


def create_route_complexity_features(df: pl.DataFrame) -> pl.DataFrame:
    """
    Creates route complexity features that don't duplicate existing route familiarity features.
    NEW: Route complexity inference and geographic patterns
    """
    print("Creating route complexity features...")
    
    expressions = [
        # Route geographic complexity indicators - NEW
        pl.col('searchRoute').str.len_chars().alias('route_string_complexity'),
        
        # Route requires connections indicator - NEW
        (pl.col('total_segments').min().over('searchRoute') > 1)
        .cast(pl.Int8).alias('route_requires_connections'),
        
        # Route connection necessity score - NEW
        (pl.col('total_segments').min().over('searchRoute') - 1)
        .alias('route_minimum_connections_required'),
        
        # Route segment necessity vs choice - NEW
        pl.when(pl.col('total_segments').min().over('searchRoute') == 1)
        .then(pl.col('segment_distance_from_minimum'))  # Pure choice
        .otherwise(pl.col('segment_distance_from_minimum') - 
                  (pl.col('total_segments').min().over('searchRoute') - 1))  # Choice beyond necessity
        .alias('segments_beyond_route_necessity'),
        
        # Route competitive landscape - NEW
        (pl.col('Id').count().over('searchRoute') / 
         pl.col('ranker_id').n_unique().over('searchRoute'))
        .alias('avg_options_per_search_on_route'),
        
        # Route booking concentration - NEW
        (pl.col('ranker_id').n_unique().over('searchRoute') >= 10)
        .cast(pl.Int8).alias('is_popular_route_with_many_searches')
    ]
    
    return df.with_columns(expressions)


# ================================================================
# TIER 3: MEDIUM IMPACT, EASY IMPLEMENTATION (DUPLICATES REMOVED)
# ================================================================

def create_segment_price_interaction_refinements(df: pl.DataFrame) -> pl.DataFrame:
    """
    Creates refined segment-price interaction features.
    NEW: Enhanced price-segment analysis (builds on existing price_premium_vs_min_segments)
    """
    print("Creating segment-price interaction refinements...")
    
    expressions = [
        # Price premium per extra segment (refined) - NEW calculation method
        pl.when(pl.col('segment_distance_from_minimum') > 0)
        .then(pl.col('price_premium_vs_min_segments') / pl.col('segment_distance_from_minimum'))
        .otherwise(0)
        .alias('price_premium_per_extra_segment'),
        
        # Is segment premium justified by company standards? - NEW
        (pl.col('price_premium_vs_min_segments') <= 
         pl.col('price_premium_vs_min_segments').quantile(0.75).over('companyID'))
        .cast(pl.Int8).alias('is_segment_premium_within_company_norm'),
        
        # Segment efficiency score - NEW
        pl.when(pl.col('segment_distance_from_minimum') > 0)
        .then(100 / (1 + pl.col('segment_distance_from_minimum')))
        .otherwise(100)
        .alias('segment_efficiency_score'),
        
        # Premium segment but competitive price - NEW
        ((pl.col('segment_distance_from_minimum') == 0) & 
         (pl.col('totalPrice_percentile_in_group') <= 50))
        .cast(pl.Int8).alias('is_direct_flight_good_value'),
        
        # Connection flight but very cheap - NEW
        ((pl.col('segment_distance_from_minimum') > 0) & 
         (pl.col('totalPrice_percentile_in_group') <= 25))
        .cast(pl.Int8).alias('is_connection_flight_bargain'),
        
        # Route-normalized price premium - NEW
        (pl.col('totalPrice') - pl.col('totalPrice').min().over('searchRoute'))
        .alias('price_premium_vs_route_minimum'),
        
        # Segment value in route context - NEW
        pl.when(pl.col('total_segments') == pl.col('total_segments').min().over('searchRoute'))
        .then(pl.col('totalPrice_percentile_in_group') * -1 + 100)  # Lower price better for route minimum
        .otherwise(pl.col('totalPrice_percentile_in_group') * -1 + 100)  # Lower price better for connections too
        .alias('route_segment_value_score')
    ]
    
    return df.with_columns(expressions)


# ================================================================
# AGGREGATION CONFIGURATIONS FOR CV-AWARE SYSTEM (DUPLICATES REMOVED)
# ================================================================

tier1_aggregation_configs = {
    # Route-specific segment intelligence aggregations - ALL NEW
    'route_min_segment_selection_rate': {
        'group_by': ['searchRoute'],
        'agg_col': 'is_min_segments_for_route',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'route_avg_segments_selected': {
        'group_by': ['searchRoute'],
        'agg_col': 'total_segments',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'route_segment_acceptance_by_tier': {
        'group_by': ['searchRoute', 'route_specific_segment_tier'],
        'agg_col': 'selected',
        'filter_cond': None,
        'agg_func': 'mean'
    },
    
    # Company segment discipline aggregations - ALL NEW
    'company_segment_discipline_score': {
        'group_by': ['companyID'],
        'agg_col': 'company_segment_consistency_score',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'company_direct_preference_by_route': {
        'group_by': ['companyID', 'searchRoute'],
        'agg_col': 'is_min_segments',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'company_segment_override_tolerance': {
        'group_by': ['companyID'],
        'agg_col': 'company_segment_override_rate_all',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    
    # User segment preference aggregations - ALL NEW
    'user_segment_consistency_by_route': {
        'group_by': ['profileId', 'searchRoute'],
        'agg_col': 'user_segment_choice_std',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'user_segment_preference_strength': {
        'group_by': ['profileId'],
        'agg_col': 'user_min_segment_preference_rate',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'user_vs_company_segment_deviation': {
        'group_by': ['profileId'],
        'agg_col': 'user_segments_vs_company_norm',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    }
}

tier2_aggregation_configs = {
    # Within-search competitive context aggregations - ALL NEW
    'user_complex_search_behavior': {
        'group_by': ['profileId'],
        'agg_col': 'is_complex_segment_search',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'company_segment_scarcity_preference': {
        'group_by': ['companyID'],
        'agg_col': 'segment_tier_scarcity_score',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'user_only_direct_selection_rate': {
        'group_by': ['profileId'],
        'agg_col': 'is_only_direct_option',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    
    # Enhanced company profile aggregations - ALL NEW
    'company_size_tier_segment_preference': {
        'group_by': ['company_size_tier'],
        'agg_col': 'segment_distance_from_minimum',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'traveler_tier_segment_choice': {
        'group_by': ['companyID', 'inferred_traveler_tier'],
        'agg_col': 'total_segments',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'company_route_diversity_segment_impact': {
        'group_by': ['companyID'],
        'agg_col': 'company_route_diversity_ratio',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    }
}

tier3_aggregation_configs = {
    # Segment-price interaction refinements - ALL NEW
    'user_segment_value_preference': {
        'group_by': ['profileId'],
        'agg_col': 'route_segment_value_score',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'company_segment_efficiency_tolerance': {
        'group_by': ['companyID'],
        'agg_col': 'segment_efficiency_score',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'user_direct_flight_value_sensitivity': {
        'group_by': ['profileId'],
        'agg_col': 'is_direct_flight_good_value',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'company_bargain_connection_acceptance': {
        'group_by': ['companyID'],
        'agg_col': 'is_connection_flight_bargain',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    },
    'route_complexity_user_adaptation': {
        'group_by': ['profileId', 'route_requires_connections'],
        'agg_col': 'segments_beyond_route_necessity',
        'filter_cond': (pl.col('selected') == 1),
        'agg_func': 'mean'
    }
}

In [0]:
def add_ranker_id_counts(df: pl.DataFrame) -> pl.DataFrame:
    """Adds a column with the total count of rows for each ranker_id."""
    return df.with_columns(
        total_rankerid_rows=pl.len().over('ranker_id').alias('session_total_options')
    )

In [0]:
# Main execution
def main():
    # Load data
    print("Loading data...")
    train = pl.read_parquet("./train.parquet")
    test = pl.read_parquet("./test.parquet")
    # benchmark = read_csv(f"{config.data_path}/{config.airport_benchmark}")
    
    # Apply memory optimization if needed
    print("Optimizing memory usage...")
    train = reduce_mem_usage_polars(train)
    test = reduce_mem_usage_polars(test)
    
    print('adding total count of rows...')
    train = add_ranker_id_counts(train)
    test = add_ranker_id_counts(test)
    
    # Extract time features
    print("Extracting time features...")
    train = extract_time_features(train)
    test = extract_time_features(test)
    
    # Convert duration columns to seconds
    print("Converting duration columns...")
    train = convert_duration_to_seconds(train)
    test = convert_duration_to_seconds(test)
    
    # Generate all flight-related features
    print("Generating flight features...")
    
    # Wait time calculation
    train = get_wait_time(train, config.legs)
    test = get_wait_time(test, config.legs)
    
    # Booking time features
    train = get_lead_booking_time(train)
    test = get_lead_booking_time(test)
    
    # Trip length features
    train = get_total_length_of_trip(train)
    test = get_total_length_of_trip(test)
    
    # Trip type
    train = get_trip_type(train)
    test = get_trip_type(test)
    
    # Number of stops
    train = get_number_of_stops(train)
    test = get_number_of_stops(test)
    
    # Flight changes
    train = get_flight_changes_across_segments(train)
    test = get_flight_changes_across_segments(test)
    
    # Cabin changes
    train = get_cabin_changes_across_segments(train)
    test = get_cabin_changes_across_segments(test)
    
    # Baggage changes
    train = get_baggage_quantity_changes_across_segments(train)
    test = get_baggage_quantity_changes_across_segments(test)
    
    # Frequent flyer features
    train = is_frequent_flyer_airline(train)
    test = is_frequent_flyer_airline(test)
    
    # Tax percentage
    train = get_tax_as_percentage_of_price(train)
    test = get_tax_as_percentage_of_price(test)
    
    #ranking features
    train = get_rank_features(train,['totalPrice','tax_as_percentage_of_price',
                                 'percentage_time_spent_in_flight_leg1','percentage_time_spent_in_flight_leg0',
                                 'wait_time_leg0_seconds','wait_time_leg1_seconds',
                                 'num_segs_leg0','num_segs_leg1'],'ranker_id')
    test = get_rank_features(test,['totalPrice','tax_as_percentage_of_price',
                                    'percentage_time_spent_in_flight_leg1','percentage_time_spent_in_flight_leg0',
                                    'wait_time_leg0_seconds','wait_time_leg1_seconds',
                                    'num_segs_leg0','num_segs_leg1'],'ranker_id')

    train = get_percentile_features(train, ['totalPrice'], 'ranker_id')
    test = get_percentile_features(test, ['totalPrice'], 'ranker_id')

    # is segment min feature
    train = is_min_segments_total(train)
    train = is_min_segments_per_leg(train)
    
    test = is_min_segments_total(test)
    test = is_min_segments_per_leg(test)

    train = create_passenger_seat_ratio_features(train)
    test = create_passenger_seat_ratio_features(test)

    train = create_segment_tier_position_features(train)
    train = create_value_gap_features(train)
    train = create_user_convenience_profile_features(train)
    train = create_company_travel_policy_features(train)
    train = create_route_familiarity_features(train)
    train = create_advanced_ranking_features(train)

    test = create_segment_tier_position_features(test)
    test = create_value_gap_features(test)
    test = create_user_convenience_profile_features(test)
    test = create_company_travel_policy_features(test)
    test = create_route_familiarity_features(test)
    test = create_advanced_ranking_features(test)

    train = create_route_segment_intelligence_features(train)
    train = create_company_segment_discipline_features(train)
    train = create_user_segment_preference_features(train)
    train = create_within_search_competitive_context(train)
    train = create_enhanced_company_travel_profile(train)
    train = create_route_complexity_features(train)  # Replaces route_familiarity (had duplicates)
    train = create_segment_price_interaction_refinements(train)

    test = create_route_segment_intelligence_features(test)
    test = create_company_segment_discipline_features(test)
    test = create_user_segment_preference_features(test)
    test = create_within_search_competitive_context(test)
    test = create_enhanced_company_travel_profile(test)
    test = create_route_complexity_features(test)  # Replaces route_familiarity (had duplicates)
    test = create_segment_price_interaction_refinements(test)


    # Final memory optimization
    print("Final memory optimization...")
    train = reduce_mem_usage_polars(train)
    test = reduce_mem_usage_polars(test)
    
    print("Feature engineering completed!")
    print(f"Train shape: {train.shape}")
    print(f"Test shape: {test.shape}")

    combined_configs = {**aggregation_configurations, **high_priority_aggregation_configs,
                            **tier1_aggregation_configs, **tier2_aggregation_configs, **tier3_aggregation_configs}

    train, test = create_cv_aware_aggregate_features_no_leakage(train, test, 
                                                                combined_configs) 
    # Display first few rows
    print("\nTrain data sample:")
    print(train.head())

    return train, test


In [0]:

train, test = main()

Loading data...
Optimizing memory usage...
Memory usage before optimization: 12296.82 MB
Memory usage after optimization: 8714.73 MB (29.1% reduction)
Memory usage before optimization: 4609.86 MB
Memory usage after optimization: 3504.72 MB (24.0% reduction)
adding total count of rows...
Extracting time features...
Converting duration columns...
Converting legs0_duration to seconds
Converting legs0_segments0_duration to seconds
Converting legs0_segments1_duration to seconds
Converting legs0_segments2_duration to seconds
Converting legs0_segments3_duration to seconds
Converting legs1_duration to seconds
Converting legs1_segments0_duration to seconds
Converting legs1_segments1_duration to seconds
Converting legs1_segments2_duration to seconds
Converting legs1_segments3_duration to seconds
Converting legs0_duration to seconds
Converting legs0_segments0_duration to seconds
Converting legs0_segments1_duration to seconds
Converting legs0_segments2_duration to seconds
Converting legs0_segments

In [0]:
train.head()

Id,bySelf,companyID,corporateTariffCode,frequentFlyer,nationality,isAccess3D,isVip,legs0_arrivalAt,legs0_departureAt,legs0_segments0_aircraft_code,legs0_segments0_arrivalTo_airport_city_iata,legs0_segments0_arrivalTo_airport_iata,legs0_segments0_baggageAllowance_quantity,legs0_segments0_baggageAllowance_weightMeasurementType,legs0_segments0_cabinClass,legs0_segments0_departureFrom_airport_iata,legs0_segments0_flightNumber,legs0_segments0_marketingCarrier_code,legs0_segments0_operatingCarrier_code,legs0_segments0_seatsAvailable,legs0_segments1_aircraft_code,legs0_segments1_arrivalTo_airport_city_iata,legs0_segments1_arrivalTo_airport_iata,legs0_segments1_baggageAllowance_quantity,legs0_segments1_baggageAllowance_weightMeasurementType,legs0_segments1_cabinClass,legs0_segments1_departureFrom_airport_iata,legs0_segments1_flightNumber,legs0_segments1_marketingCarrier_code,legs0_segments1_operatingCarrier_code,legs0_segments1_seatsAvailable,legs0_segments2_aircraft_code,legs0_segments2_arrivalTo_airport_city_iata,legs0_segments2_arrivalTo_airport_iata,legs0_segments2_baggageAllowance_quantity,legs0_segments2_baggageAllowance_weightMeasurementType,legs0_segments2_cabinClass,legs0_segments2_departureFrom_airport_iata,legs0_segments2_flightNumber,legs0_segments2_marketingCarrier_code,legs0_segments2_operatingCarrier_code,legs0_segments2_seatsAvailable,legs0_segments3_aircraft_code,legs0_segments3_arrivalTo_airport_city_iata,legs0_segments3_arrivalTo_airport_iata,legs0_segments3_baggageAllowance_quantity,legs0_segments3_baggageAllowance_weightMeasurementType,legs0_segments3_cabinClass,legs0_segments3_departureFrom_airport_iata,legs0_segments3_flightNumber,legs0_segments3_marketingCarrier_code,legs0_segments3_operatingCarrier_code,legs0_segments3_seatsAvailable,legs1_arrivalAt,legs1_departureAt,legs1_segments0_aircraft_code,legs1_segments0_arrivalTo_airport_city_iata,legs1_segments0_arrivalTo_airport_iata,legs1_segments0_baggageAllowance_quantity,legs1_segments0_baggageAllowance_weightMeasurementType,legs1_segments0_cabinClass,legs1_segments0_departureFrom_airport_iata,legs1_segments0_flightNumber,legs1_segments0_marketingCarrier_code,legs1_segments0_operatingCarrier_code,legs1_segments0_seatsAvailable,legs1_segments1_aircraft_code,legs1_segments1_arrivalTo_airport_city_iata,legs1_segments1_arrivalTo_airport_iata,legs1_segments1_baggageAllowance_quantity,legs1_segments1_baggageAllowance_weightMeasurementType,legs1_segments1_cabinClass,legs1_segments1_departureFrom_airport_iata,legs1_segments1_flightNumber,legs1_segments1_marketingCarrier_code,legs1_segments1_operatingCarrier_code,legs1_segments1_seatsAvailable,legs1_segments2_aircraft_code,legs1_segments2_arrivalTo_airport_city_iata,legs1_segments2_arrivalTo_airport_iata,legs1_segments2_baggageAllowance_quantity,legs1_segments2_baggageAllowance_weightMeasurementType,legs1_segments2_cabinClass,legs1_segments2_departureFrom_airport_iata,legs1_segments2_flightNumber,legs1_segments2_marketingCarrier_code,legs1_segments2_operatingCarrier_code,legs1_segments2_seatsAvailable,legs1_segments3_aircraft_code,legs1_segments3_arrivalTo_airport_city_iata,legs1_segments3_arrivalTo_airport_iata,legs1_segments3_baggageAllowance_quantity,legs1_segments3_baggageAllowance_weightMeasurementType,legs1_segments3_cabinClass,legs1_segments3_departureFrom_airport_iata,legs1_segments3_flightNumber,legs1_segments3_marketingCarrier_code,legs1_segments3_operatingCarrier_code,legs1_segments3_seatsAvailable,miniRules0_monetaryAmount,miniRules0_percentage,miniRules0_statusInfos,miniRules1_monetaryAmount,miniRules1_percentage,miniRules1_statusInfos,pricingInfo_isAccessTP,pricingInfo_passengerCount,profileId,ranker_id,requestDate,searchRoute,sex,taxes,totalPrice,selected,__index_level_0__,total_rankerid_rows,legs0_arrivalAt_hour,legs0_arrivalAt_dayofweek,legs0_arrivalAt_month,legs0_arrivalAt_day,legs0_departureAt_hour,legs0_departureAt_dayofweek,legs0_departureAt_month,legs0_departureAt_day,legs1_arrivalAt_hour,legs1_arrivalAt_dayofweek,legs1_arrivalAt_month,legs1_arrivalAt_day,legs1_departureAt_hour,legs1_departureAt_dayofweek,legs1_departureAt_month,legs1_departureAt_day,requestDate_hour,requestDate_dayofweek,requestDate_month,requestDate_day,legs0_duration_seconds,legs0_segments0_duration_seconds,legs0_segments1_duration_seconds,legs0_segments2_duration_seconds,legs0_segments3_duration_seconds,legs1_duration_seconds,legs1_segments0_duration_seconds,legs1_segments1_duration_seconds,legs1_segments2_duration_seconds,legs1_segments3_duration_seconds,wait_time_leg0_seconds,wait_time_leg1_seconds,lead_booking_time_seconds,lead_booking_time_wrt_return_seconds,total_length_of_trip_seconds,percentage_time_spent_in_flight_leg0,percentage_time_spent_in_flight_leg1,trip_type,num_segs_leg0,num_segs_leg1,aircraft_changes_leg0,aircraft_changes_leg1,cabin_changes_leg0,cabin_changes_leg1,baggage_quantity_changes_leg0,baggage_quantity_changes_leg1,is_frequent_flyer_airline_leg0_segment0,is_frequent_flyer_airline_leg0_segment1,is_frequent_flyer_airline_leg0_segment2,is_frequent_flyer_airline_leg0_segment3,is_frequent_flyer_airline_leg1_segment0,is_frequent_flyer_airline_leg1_segment1,is_frequent_flyer_airline_leg1_segment2,is_frequent_flyer_airline_leg1_segment3,tax_as_percentage_of_price,totalPrice_rank_in_group,tax_as_percentage_of_price_rank_in_group,percentage_time_spent_in_flight_leg1_rank_in_group,percentage_time_spent_in_flight_leg0_rank_in_group,wait_time_leg0_seconds_rank_in_group,wait_time_leg1_seconds_rank_in_group,num_segs_leg0_rank_in_group,num_segs_leg1_rank_in_group,totalPrice_percentile_in_group,total_segments,is_min_segments,is_min_segments_leg0,is_min_segments_leg1,is_min_segments_both_legs,passenger_to_legs0_segments0_seatsAvailable_ratio,passenger_to_legs0_segments1_seatsAvailable_ratio,passenger_to_legs1_segments0_seatsAvailable_ratio,passenger_to_legs1_segments1_seatsAvailable_ratio,position_within_segment_tier,options_in_segment_tier,position_pct_within_segment_tier,segment_distance_from_minimum,price_premium_vs_min_segments,segment_tier,total_segments_right,price_per_extra_segment,time_saved_per_dollar_premium,convenience_value_score,is_sweet_spot_option,price_gap_to_better_segments,segment_price_interaction,premium_paid_vs_cheapest,segment_flexibility_score,chose_convenience_over_price,chose_price_over_convenience,premium_segment_choice,company_traveler_count,chose_bottom_quartile_price,chose_minimum_segments,chose_top_quartile_price,policy_flexibility_interaction,company_route_frequency,company_travel_intensity,user_route_frequency,overall_route_popularity,user_route_specialization_pct,company_route_experience,is_frequent_route_for_user,user_route_diversity,is_new_route_for_user,corporate_route_standardization_pct,price_rank_within_segment_tier,wait_time_rank_within_segment_tier,lead_time_rank_within_segment_tier,is_cheapest_in_segment_tier,is_fastest_in_segment_tier,price_gap_to_next_cheaper,price_gap_to_next_expensive,price_quintile,segment_price_category,total_options_available,has_many_options,has_few_options,route_avg_segments_all_options,route_min_segments_available,route_max_segments_available,route_segment_diversity,segments_deviation_from_route_avg,is_min_segments_for_route,route_specific_segment_tier,route_segment_complexity_score,route_segment_percentile,company_segment_choice_std,company_avg_segments_all_searches,company_min_segment_selection_rate_all,company_segment_override_rate_all,company_route_segment_flexibility,company_segment_consistency_score,company_direct_flight_preference_rate,user_avg_segments_all_searches,user_segment_choice_std,user_min_segment_preference_rate,user_avg_segment_distance_from_min,user_route_specific_avg_segments,user_segment_consistency_score,user_segments_vs_company_norm,user_segment_override_rate,segment_tiers_available_in_search,options_in_this_segment_tier,is_only_direct_option,segment_tier_dominance_pct,price_spread_within_segment_tier,segment_tier_scarcity_score,search_segment_range,is_complex_segment_search,company_searches_per_traveler,company_route_portfolio_size,company_route_diversity_ratio,company_size_tier,company_avg_booking_price,company_price_variance,company_price_coefficient_variation,inferred_traveler_tier,has_high_volume_negotiation_power,route_string_complexity,route_requires_connections,route_minimum_connections_required,segments_beyond_route_necessity,avg_options_per_search_on_route,is_popular_route_with_many_searches,price_premium_per_extra_segment,is_segment_premium_within_company_norm,segment_efficiency_score,is_direct_flight_good_value,is_connection_flight_bargain,price_premium_vs_route_minimum,route_segment_value_score,fold,avg_price_rank_by_comp_route,std_price_rank_by_comp_route,avg_price_rank_by_comp_route_baggageallowance,std_price_rank_by_comp_route_baggageallowance,avg_price_percentile_by_comp_route,std_price_percentile_by_comp_route,median_price_percentile_by_comp_route,avg_price_rank_by_comp_route_tariffcode,std_price_rank_by_comp_route_tariffcode,avg_tax_as_percentage_of_price_rank_comp_route,std_tax_as_percentage_of_price_rank_comp_route,avg_percentage_time_spent_in_flight_leg1_comp_route,std_percentage_time_spent_in_flight_leg1_comp_route,avg_percentage_time_spent_in_flight_leg0_comp_route,std_percentage_time_spent_in_flight_leg0_comp_route,avg_wait_time_leg0_seconds_comp_route,std_wait_time_leg0_seconds_comp_route,avg_wait_time_leg1_seconds_comp_route,std_wait_time_leg1_seconds_comp_route,avg_num_segs_leg0_comp_route,std_num_segs_leg0_comp_route,avg_num_segs_leg0_comp_route_baggageallowance,std_num_segs_leg0_comp_route_baggageallowance,avg_num_segs_leg1_comp_route,std_num_segs_leg1_comp_route,avg_num_segs_leg0_comp_route_miniRules0_statusInfos,std_num_segs_leg0_comp_route_miniRules0_statusInfos,avg_num_segs_leg1_comp_route_miniRules1_statusInfos,std_num_segs_leg1_comp_route_miniRules1_statusInfos,max_num_segs_leg1_comp_route_miniRules1_statusInfos,min_num_segs_leg1_comp_route_miniRules1_statusInfos,max_num_segs_leg0_comp_route_miniRules0_statusInfos,min_num_segs_leg0_comp_route_miniRules0_statusInfos,avg_total_price_by_user,avg_price_percentile_by_user,std_price_percentile_by_user,avg_price_rank_by_user,avg_stops_leg0_by_user,avg_stops_leg1_by_user,user_min_segments_selection_rate,avg_wait_time_leg0_by_user,selection_rate_by_airline_for_user,avg_baggage_allowance_by_user,avg_lead_time_by_user,avg_passenger_count_by_user,avg_segment_flexibility_score_by_user,std_segment_flexibility_score_by_user,user_convenience_over_price_rate,user_price_over_convenience_rate,avg_premium_paid_vs_cheapest_by_user,company_price_discipline_rate,company_segment_discipline_rate,company_premium_policy_rate,avg_policy_flexibility_by_company,std_policy_flexibility_by_company,user_avg_segment_tier_for_route,user_price_percentile_for_route,company_avg_segment_tier_for_route,avg_position_within_segment_tier_by_user,company_segment_tier_preference,user_avg_convenience_value_score,user_sweet_spot_selection_rate,user_avg_price_per_extra_segment,user_route_specialization_avg,company_route_standardization_avg,route_min_segment_selection_rate,route_avg_segments_selected,route_segment_acceptance_by_tier,company_segment_discipline_score,company_direct_preference_by_route,company_segment_override_tolerance,user_segment_consistency_by_route,user_segment_preference_strength,user_vs_company_segment_deviation,user_complex_search_behavior,company_segment_scarcity_preference,user_only_direct_selection_rate,company_size_tier_segment_preference,traveler_tier_segment_choice,company_route_diversity_segment_impact,user_segment_value_preference,company_segment_efficiency_tolerance,user_direct_flight_value_sensitivity,company_bargain_connection_acceptance,route_complexity_user_adaptation
i32,bool,i32,i16,str,i8,bool,bool,datetime[μs],datetime[μs],str,str,str,f32,f32,f32,str,str,str,str,f32,str,str,str,f32,f32,f32,str,str,str,str,f32,str,str,str,f32,f32,f32,str,str,str,str,f32,str,str,str,f32,f32,f32,str,str,str,str,f32,datetime[μs],datetime[μs],str,str,str,f32,f32,f32,str,str,str,str,f32,str,str,str,f32,f32,f32,str,str,str,str,f32,str,str,str,f32,f32,f32,str,str,str,str,f32,str,str,str,f32,f32,f32,str,str,str,str,f32,f32,f32,f32,f32,f32,f32,f32,i8,i32,str,datetime[ns],str,bool,f32,f32,i8,i32,i16,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f64,f64,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,i8,f32,i16,i16,i16,i16,i16,i16,i16,i16,f32,i8,i8,i8,i8,i8,f32,f32,f32,f32,i16,i16,f32,i8,f32,i8,i8,f32,f32,f32,i8,f32,f32,f32,f32,i8,i8,i8,i16,i8,i8,i8,f32,i32,i32,i16,i32,f32,i32,i8,i8,i8,f32,i16,i16,i16,i8,i8,f32,f32,i8,str,i16,i8,i8,f32,i8,i8,f32,f32,i8,i8,i8,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,i8,i16,i8,f32,f32,i8,i8,i8,f32,i16,f32,i8,f32,f32,f32,i8,i8,i8,i8,i8,i8,f32,i8,f32,i8,f32,i8,i8,f32,f32,i8,f64,f64,f64,f64,f32,f32,f32,f64,f64,f64,f64,f64,f64,f64,f64,f32,f32,f32,f32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i8,i8,i8,i8,f32,f32,f32,f64,f64,f64,f64,f32,f64,f32,f32,f64,f32,f32,f64,f64,f32,f64,f64,f64,f32,f32,f64,f32,f64,f32,f64,f32,f64,f32,f32,f32,f64,f64,f64,f32,f64,f32,f32,f32,f32,f64,f64,f64,f64,f64,f32,f32,f32,f64,f64,f64
143,True,53407,,,36,False,False,2024-05-23 14:50:00,2024-05-23 11:35:00,"""AT7""","""IKT""","""IKT""",0.0,0.0,1.0,"""KJA""","""137""","""UT""","""UT""",5.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,6000.0,,1.0,0.0,,0.0,1.0,1,2503394,"""e109b50aca4a43908dd146c55733e3…",2024-05-17 04:06:51,"""KJAIKT""",True,1215.0,3515.0,0,143,33,14,4,5,23,11,4,5,23,,,,,,,,,4,5,5,17,8100.0,8100.0,,,,,,,,,0.0,,545289.0,,,,,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.345563,1,33,,,1,,1,1,0.0,2,1,1,1,1,0.2,,,,1,15,0.0,0,0.0,1,,0.0,,120.0,1,,0.0,0.0,0.0,0,0,0,112,1,1,0,0.0,1265,18962,33,7883,51.5625,1265,1,4,0,6.671237,1,1,13,1,1,,1800.0,0,"""1_0""",33,1,0,2.491818,2,3,0.499965,-0.491818,1,1,2,0.0,0.823906,2.905073,0.474739,0.525261,0.500028,0.548274,0.474739,2.703125,0.460493,0.671875,0.328125,2.545455,0.6847,-0.905073,0.328125,2,15,0,45.454544,34342.0,25,1,0,169.303574,65,0.003428,3,71795.484375,86733.75,1.208067,1,1,6,1,1,-1,26.364548,1,0.0,1,100.0,1,0,0.0,100.0,0,5.914286,3.475726,3.666667,1.154701,17.656855,12.115802,16.67,,,14.028571,7.122812,,,,,0.0,0.0,,,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1,1,1,1,35416.0,34.375,23.662119,3.75,1.5,1.0,1.0,24900.0,,5.75,399243.5,1.0,0.0,0.0,0.0,0.0,10174.0,0.555035,0.957845,0.067916,2.352365,12.714286,,,1.0,,1.056206,70.625,1.0,10174.0,21.484375,3.884973,0.991561,2.008439,0.072598,0.548275,1.0,0.525264,,0.671875,-0.405073,0.0,41.042155,0.25,0.026174,2.123223,0.003428,65.625,97.658089,1.0,0.007026,-1.0
144,True,53407,,,36,False,False,2024-05-23 14:50:00,2024-05-23 11:35:00,"""AT7""","""IKT""","""IKT""",1.0,0.0,1.0,"""KJA""","""137""","""UT""","""UT""",5.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,2000.0,,1.0,0.0,,0.0,1.0,1,2503394,"""e109b50aca4a43908dd146c55733e3…",2024-05-17 04:06:51,"""KJAIKT""",True,1215.0,5315.0,0,144,33,14,4,5,23,11,4,5,23,,,,,,,,,4,5,5,17,8100.0,8100.0,,,,,,,,,0.0,,545289.0,,,,,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.228555,4,32,,,2,,2,2,9.38,2,1,1,1,1,0.2,,,,4,15,21.428572,0,1800.0,1,,1800.0,,110.620003,1,,0.0,1800.0,0.0,0,0,0,112,1,1,0,0.0,1265,18962,33,7883,51.5625,1265,1,4,0,6.671237,4,2,14,0,1,-3200.0,592.0,0,"""1_0""",33,1,0,2.491818,2,3,0.499965,-0.491818,1,1,2,0.012687,0.823906,2.905073,0.474739,0.525261,0.500028,0.548274,0.474739,2.703125,0.460493,0.671875,0.328125,2.545455,0.6847,-0.905073,0.328125,2,15,0,45.454544,34342.0,25,1,0,169.303574,65,0.003428,3,71795.484375,86733.75,1.208067,1,1,6,1,1,-1,26.364548,1,0.0,1,100.0,1,0,1800.0,90.620003,0,5.914286,3.475726,6.125,3.553735,17.656855,12.115802,16.67,,,14.028571,7.122812,,,,,0.0,0.0,,,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1,1,1,1,35416.0,34.375,23.662119,3.75,1.5,1.0,1.0,24900.0,,5.75,399243.5,1.0,0.0,0.0,0.0,0.0,10174.0,0.555035,0.957845,0.067916,2.352365,12.714286,,,1.0,,1.056206,70.625,1.0,10174.0,21.484375,3.884973,0.991561,2.008439,0.072598,0.548275,1.0,0.525264,,0.671875,-0.405073,0.0,41.042155,0.25,0.026174,2.123223,0.003428,65.625,97.658089,1.0,0.007026,-1.0
145,True,53407,,,36,False,False,2024-05-23 14:50:00,2024-05-23 11:35:00,"""AT7""","""IKT""","""IKT""",1.0,0.0,1.0,"""KJA""","""137""","""UT""","""UT""",5.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,2000.0,,1.0,2000.0,,1.0,1.0,1,2503394,"""e109b50aca4a43908dd146c55733e3…",2024-05-17 04:06:51,"""KJAIKT""",True,1215.0,8515.0,0,145,33,14,4,5,23,11,4,5,23,,,,,,,,,4,5,5,17,8100.0,8100.0,,,,,,,,,0.0,,545289.0,,,,,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.142673,7,30,,,3,,3,3,18.75,2,1,1,1,1,0.2,,,,7,15,42.857143,0,5000.0,1,,5000.0,,101.25,1,,0.0,5000.0,0.0,0,0,0,112,1,1,0,0.0,1265,18962,33,7883,51.5625,1265,1,4,0,6.671237,7,3,15,0,1,-29342.0,-2608.0,0,"""1_0""",33,1,0,2.491818,2,3,0.499965,-0.491818,1,1,2,0.025374,0.823906,2.905073,0.474739,0.525261,0.500028,0.548274,0.474739,2.703125,0.460493,0.671875,0.328125,2.545455,0.6847,-0.905073,0.328125,2,15,0,45.454544,34342.0,25,1,0,169.303574,65,0.003428,3,71795.484375,86733.75,1.208067,1,1,6,1,1,-1,26.364548,1,0.0,1,100.0,1,0,5000.0,81.25,0,5.914286,3.475726,6.125,3.553735,17.656855,12.115802,16.67,,,14.028571,7.122812,,,,,0.0,0.0,,,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1,1,1,1,35416.0,34.375,23.662119,3.75,1.5,1.0,1.0,24900.0,,5.75,399243.5,1.0,0.0,0.0,0.0,0.0,10174.0,0.555035,0.957845,0.067916,2.352365,12.714286,,,1.0,,1.056206,70.625,1.0,10174.0,21.484375,3.884973,0.991561,2.008439,0.072598,0.548275,1.0,0.525264,,0.671875,-0.405073,0.0,41.042155,0.25,0.026174,2.123223,0.003428,65.625,97.658089,1.0,0.007026,-1.0
146,True,53407,101.0,,36,True,False,2024-05-23 11:10:00,2024-05-23 08:25:00,"""SU9""","""IKT""","""IKT""",0.0,0.0,1.0,"""KJA""","""6881""","""SU""","""FV""",9.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,4600.0,,1.0,0.0,,0.0,1.0,1,2503394,"""e109b50aca4a43908dd146c55733e3…",2024-05-17 04:06:51,"""KJAIKT""",True,417.0,4417.0,0,146,33,11,4,5,23,8,4,5,23,,,,,,,,,4,5,5,17,6300.0,6300.0,,,,,,,,,0.0,,533889.0,,,,,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.094387,2,26,,,4,,4,4,3.12,2,1,1,1,1,0.111111,,,,2,15,7.142857,0,902.0,1,,902.0,,116.879997,1,,0.0,902.0,0.0,0,0,0,112,1,1,0,0.0,1265,18962,33,7883,51.5625,1265,1,4,0,6.671237,2,4,1,0,1,902.0,4098.0,0,"""1_0""",33,1,0,2.491818,2,3,0.499965,-0.491818,1,1,2,0.038061,0.823906,2.905073,0.474739,0.525261,0.500028,0.548274,0.474739,2.703125,0.460493,0.671875,0.328125,2.545455,0.6847,-0.905073,0.328125,2,15,0,45.454544,34342.0,25,1,0,169.303574,65,0.003428,3,71795.484375,86733.75,1.208067,1,1,6,1,1,-1,26.364548,1,0.0,1,100.0,1,0,902.0,96.879997,0,5.914286,3.475726,3.666667,1.154701,17.656855,12.115802,16.67,5.708333,3.209756,14.028571,7.122812,,,,,0.0,0.0,,,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1,1,1,1,35416.0,34.375,23.662119,3.75,1.5,1.0,1.0,24900.0,1.0,5.75,399243.5,1.0,0.0,0.0,0.0,0.0,10174.0,0.555035,0.957845,0.067916,2.352365,12.714286,,,1.0,,1.056206,70.625,1.0,10174.0,21.484375,3.884973,0.991561,2.008439,0.072598,0.548275,1.0,0.525264,,0.671875,-0.405073,0.0,41.042155,0.25,0.026174,2.123223,0.003428,65.625,97.658089,1.0,0.007026,-1.0
147,True,53407,101.0,,36,True,False,2024-05-23 11:10:00,2024-05-23 08:25:00,"""SU9""","""IKT""","""IKT""",1.0,0.0,1.0,"""KJA""","""6881""","""SU""","""FV""",9.0,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,4600.0,,1.0,4600.0,,1.0,1.0,1,2503394,"""e109b50aca4a43908dd146c55733e3…",2024-05-17 04:06:51,"""KJAIKT""",True,417.0,5907.0,0,147,33,11,4,5,23,8,4,5,23,,,,,,,,,4,5,5,17,6300.0,6300.0,,,,,,,,,0.0,,533889.0,,,,,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0.070582,5,22,,,5,,5,5,12.5,2,1,1,1,1,0.111111,,,,5,15,28.571428,0,2392.0,1,,2392.0,,107.5,1,,0.0,2392.0,0.0,0,0,0,112,1,1,0,0.0,1265,18962,33,7883,51.5625,1265,1,4,0,6.671237,5,5,2,0,1,1490.0,3780.0,0,"""1_0""",33,1,0,2.491818,2,3,0.499965,-0.491818,1,1,2,0.050749,0.823906,2.905073,0.474739,0.525261,0.500028,0.548274,0.474739,2.703125,0.460493,0.671875,0.328125,2.545455,0.6847,-0.905073,0.328125,2,15,0,45.454544,34342.0,25,1,0,169.303574,65,0.003428,3,71795.484375,86733.75,1.208067,1,1,6,1,1,-1,26.364548,1,0.0,1,100.0,1,0,2392.0,87.5,0,5.914286,3.475726,6.125,3.553735,17.656855,12.115802,16.67,5.708333,3.209756,14.028571,7.122812,,,,,0.0,0.0,,,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1.0,0.0,1,1,1,1,35416.0,34.375,23.662119,3.75,1.5,1.0,1.0,24900.0,1.0,5.75,399243.5,1.0,0.0,0.0,0.0,0.0,10174.0,0.555035,0.957845,0.067916,2.352365,12.714286,,,1.0,,1.056206,70.625,1.0,10174.0,21.484375,3.884973,0.991561,2.008439,0.072598,0.548275,1.0,0.525264,,0.671875,-0.405073,0.0,41.042155,0.25,0.026174,2.123223,0.003428,65.625,97.658089,1.0,0.007026,-1.0


In [0]:
print(train.columns)

['Id', 'bySelf', 'companyID', 'corporateTariffCode', 'frequentFlyer', 'nationality', 'isAccess3D', 'isVip', 'legs0_arrivalAt', 'legs0_departureAt', 'legs0_segments0_aircraft_code', 'legs0_segments0_arrivalTo_airport_city_iata', 'legs0_segments0_arrivalTo_airport_iata', 'legs0_segments0_baggageAllowance_quantity', 'legs0_segments0_baggageAllowance_weightMeasurementType', 'legs0_segments0_cabinClass', 'legs0_segments0_departureFrom_airport_iata', 'legs0_segments0_flightNumber', 'legs0_segments0_marketingCarrier_code', 'legs0_segments0_operatingCarrier_code', 'legs0_segments0_seatsAvailable', 'legs0_segments1_aircraft_code', 'legs0_segments1_arrivalTo_airport_city_iata', 'legs0_segments1_arrivalTo_airport_iata', 'legs0_segments1_baggageAllowance_quantity', 'legs0_segments1_baggageAllowance_weightMeasurementType', 'legs0_segments1_cabinClass', 'legs0_segments1_departureFrom_airport_iata', 'legs0_segments1_flightNumber', 'legs0_segments1_marketingCarrier_code', 'legs0_segments1_operatingCar

In [0]:
def prepare_for_xgb(df, missing_threshold=0.5, categorical_mappings=None, train = False):
    # Step 1: Remove columns with >50% missing
    missing_stats = df.null_count() / df.height
    cols_to_keep = [col for col in df.columns 
                    if missing_stats[col][0] <= missing_threshold]
    if train:
        df_filtered = df.select(cols_to_keep)
    else:
        df_filtered = df

    # Step 2: Exclude ID/datetime columns  
    exclude = {'Id', 'searchId', 'ranker_id', 'companyID', 'profileId', 'selected', 
               'fold', 'requestDate', 'legs0_arrivalAt', 'legs0_departureAt', 
               'legs1_arrivalAt', 'legs1_departureAt'}
    
    # Step 3: Optimal dtype conversion
    expressions = []
    for col in df_filtered.columns:
        if col in exclude:
            expressions.append(pl.col(col))
        elif df_filtered[col].dtype == pl.String:
            if categorical_mappings and col in categorical_mappings:
                # VECTORIZED approach - much faster than map_elements
                valid_categories = categorical_mappings[col]
                enum_type = pl.Enum(valid_categories)
                expressions.append(
                    pl.when(pl.col(col).is_null())
                    .then(pl.lit('missing'))
                    .when(pl.col(col).is_in(valid_categories))
                    .then(pl.col(col))
                    .otherwise(pl.lit('missing'))  # Unseen values → missing
                    .cast(enum_type)
                    .to_physical()
                    .alias(col)
                )
            else:
                # Training phase - infer categories
                expressions.append(pl.col(col).fill_null('missing').cast(pl.Categorical).to_physical().alias(col))
        elif df_filtered[col].dtype == pl.Boolean:
            expressions.append(pl.col(col).cast(pl.Int32).alias(col))
        elif df_filtered[col].dtype in [pl.Int8, pl.Int16]:
            expressions.append(pl.col(col).cast(pl.Int32).alias(col))
        # elif df_filtered[col].dtype == pl.Float32:
        #     expressions.append(pl.col(col).cast(pl.Float32).alias(col))
        # elif df_filtered[col].dtype == pl.Float64:
        #     expressions.append(pl.col(col).cast(pl.Float64).alias(col))         
        else:
            expressions.append(pl.col(col))
    return df_filtered.with_columns(expressions)

def extract_categorical_mappings(train_df, missing_threshold=0.5):
    """Extract categorical mappings from training data to ensure consistency."""
    # Get columns that will be kept
    missing_stats = train_df.null_count() / train_df.height
    cols_to_keep = [col for col in train_df.columns 
                    if missing_stats[col][0] <= missing_threshold]
    
    exclude = {'Id', 'searchId', 'ranker_id', 'companyID', 'profileId', 'selected', 
               'fold', 'requestDate', 'legs0_arrivalAt', 'legs0_departureAt', 
               'legs1_arrivalAt', 'legs1_departureAt'}
    
    categorical_mappings = {}
    for col in cols_to_keep:
        if col not in exclude and train_df[col].dtype == pl.String:
            # Get unique categories from training data
            unique_categories = (
                train_df.select(col)
                .fill_null('missing')  # This ensures missing is in the data
                .unique()
                .sort(col)
                .to_series()
                .to_list()
            )
            
            # Always ensure "missing" is in the categories (in case no nulls in training)
            if 'missing' not in unique_categories:
                unique_categories.append('missing')
                unique_categories.sort()  # Keep sorted for consistency
            
            categorical_mappings[col] = unique_categories
    
    return categorical_mappings

# Usage:
# Extract mappings from training data first
categorical_mappings = extract_categorical_mappings(train)

# Apply consistent mappings to both datasets
train = prepare_for_xgb(train, categorical_mappings=categorical_mappings, train=True)
test = prepare_for_xgb(test, categorical_mappings=categorical_mappings, train=False)

print(f"✅ Ready! Shape: train{train.shape}, test{test.shape}")

✅ Ready! Shape: train(18145372, 291), test(6897776, 714)


In [0]:
def remove_zero_variance_features(train_df, test_df):
    """
    Identifies zero-variance columns from the training set and drops them
    from both the training and test sets.
    """
    print("🔎 Identifying zero-variance columns...")
    
    # Define columns to ignore during the check
    exclude_cols = {
        'Id', 'searchId', 'ranker_id', 'companyID', 'profileId', 'selected', 
        'fold', 'requestDate', 'legs0_arrivalAt', 'legs0_departureAt', 
        'legs1_arrivalAt', 'legs1_departureAt', '__index_level_0__'
    }
    
    cols_to_check = [col for col in train_df.columns if col not in exclude_cols]
    
    # Identify columns with 1 or fewer unique values in the training data
    zero_variance_cols = [
        col for col in cols_to_check 
        if train_df[col].n_unique() <= 1
    ]
    
    if zero_variance_cols:
        print(f"🗑️ Dropping {len(zero_variance_cols)} zero-variance columns: {zero_variance_cols}")
        train_df = train_df.drop(zero_variance_cols)
        test_df = test_df.drop(zero_variance_cols)
    else:
        print("✅ No zero-variance columns found.")
        
    return train_df, test_df
  
train, test = remove_zero_variance_features(train, test)

🔎 Identifying zero-variance columns...
🗑️ Dropping 5 zero-variance columns: ['bySelf', 'pricingInfo_passengerCount', 'is_frequent_flyer_airline_leg0_segment3', 'is_frequent_flyer_airline_leg1_segment3', 'route_requires_connections']


In [0]:
drop = ['__index_level_0__','Id','companyID','legs0_arrivalAt', 'legs0_departureAt','legs1_arrivalAt',
'legs1_departureAt','profileId','ranker_id','requestDate','selected','fold']

In [0]:
try:
    del _
except:
    pass

In [0]:
for i in range(10):
    gc.collect()

In [0]:
def hitrate_at_3(y_true, y_pred, groups):
    """
    Your hit rate @ 3 metric function
    """
    df = pl.DataFrame({
        'group': groups,
        'pred': y_pred,
        'true': y_true
    })
    
    return (
        df.filter(pl.col("group").count().over("group") > 10)
        .sort(["group", "pred"], descending=[False, True])
        .group_by("group", maintain_order=True)
        .head(3)
        .group_by("group")
        .agg(pl.col("true").max())
        .select(pl.col("true").mean())
        .item()
    )

In [0]:
top100 = ['segment_distance_from_minimum',
 'is_min_segments',
 'segment_tier',
 'policy_flexibility_interaction',
 'segment_price_category',
 'chose_minimum_segments',
 'is_only_direct_option',
 'segment_flexibility_score',
 'company_avg_segment_tier_for_route',
 'price_premium_per_extra_segment',
 'legs0_segments0_cabinClass',
 'wait_time_leg0_seconds',
 'company_direct_preference_by_route',
 'legs0_segments0_baggageAllowance_quantity',
 'segment_price_interaction',
 'segment_tier_dominance_pct',
 'is_min_segments_both_legs',
 'user_avg_segment_tier_for_route',
 'pricingInfo_isAccessTP',
 'segment_tier_scarcity_score',
 'wait_time_rank_within_segment_tier',
 'wait_time_leg1_seconds',
 'segments_beyond_route_necessity',
 'isAccess3D',
 'total_segments',
 'miniRules0_monetaryAmount',
 'price_quintile',
 'miniRules1_monetaryAmount',
 'max_num_segs_leg1_comp_route_miniRules1_statusInfos',
 'avg_baggage_allowance_by_user',
 'min_num_segs_leg1_comp_route_miniRules1_statusInfos',
 'avg_num_segs_leg1_comp_route',
 'avg_num_segs_leg0_comp_route_baggageallowance',
 'miniRules1_statusInfos',
 'legs1_segments0_marketingCarrier_code',
 'is_min_segments_leg0',
 'is_popular_route_with_many_searches',
 'num_segs_leg0_rank_in_group',
 'avg_wait_time_leg0_seconds_comp_route',
 'legs1_segments0_cabinClass',
 'std_num_segs_leg1_comp_route',
 'is_frequent_flyer_airline_leg0_segment0',
 'is_frequent_flyer_airline_leg0_segment1',
 'min_num_segs_leg0_comp_route_miniRules0_statusInfos',
 'user_min_segments_selection_rate',
 'avg_num_segs_leg0_comp_route',
 'wait_time_leg0_seconds_rank_in_group',
 'isVip',
 'avg_segment_flexibility_score_by_user',
 'avg_wait_time_leg1_seconds_comp_route',
 'user_route_specific_avg_segments',
 'price_premium_vs_min_segments',
 'route_max_segments_available',
 'price_per_extra_segment',
 'selection_rate_by_airline_for_user',
 'legs0_segments0_marketingCarrier_code',
 'avg_num_segs_leg1_comp_route_miniRules1_statusInfos',
 'max_num_segs_leg0_comp_route_miniRules0_statusInfos',
 'median_price_percentile_by_comp_route',
 'miniRules0_statusInfos',
 'is_direct_flight_good_value',
 'legs1_departureAt_hour',
 'avg_price_rank_by_comp_route_baggageallowance',
 'is_frequent_flyer_airline_leg1_segment1',
 'options_in_segment_tier',
 'legs1_segments0_operatingCarrier_code',
 'legs0_segments0_baggageAllowance_weightMeasurementType',
 'legs1_segments0_baggageAllowance_quantity',
 'legs0_segments0_operatingCarrier_code',
 'cabin_changes_leg0',
 'legs0_duration_seconds',
 'chose_bottom_quartile_price',
 'nationality',
 'avg_premium_paid_vs_cheapest_by_user',
 'legs1_duration_seconds',
 'total_length_of_trip_seconds',
 'route_avg_segments_all_options',
 'premium_segment_choice',
 'avg_price_percentile_by_comp_route',
 'std_num_segs_leg0_comp_route',
 'price_rank_within_segment_tier',
 'user_price_percentile_for_route',
 'legs0_departureAt_hour',
 'user_segment_value_preference',
 'route_segment_acceptance_by_tier',
 'overall_route_popularity',
 'legs1_arrivalAt_hour',
 'std_wait_time_leg0_seconds_comp_route',
 'legs1_segments0_arrivalTo_airport_city_iata',
 'percentage_time_spent_in_flight_leg1_rank_in_group',
 'is_frequent_flyer_airline_leg1_segment0',
 'legs0_arrivalAt_hour',
 'num_segs_leg0',
 'legs1_segments0_baggageAllowance_weightMeasurementType',
 'position_within_segment_tier',
 'is_segment_premium_within_company_norm',
 'num_segs_leg1_rank_in_group',
 'is_min_segments_for_route',
 'std_num_segs_leg1_comp_route_miniRules1_statusInfos',
 'route_min_segments_available']

In [0]:
!pip install catboost

Collecting catboost
  Obtaining dependency information for catboost from https://files.pythonhosted.org/packages/e2/47/abee19aae4b2a2a21e40e3c09db784099d189b3a0745e59c1d152700d90a/catboost-1.2.8-cp311-cp311-manylinux2014_x86_64.whl.metadata
  Using cached catboost-1.2.8-cp311-cp311-manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting graphviz (from catboost)
  Obtaining dependency information for graphviz from https://files.pythonhosted.org/packages/91/4c/e0ce1ef95d4000ebc1c11801f9b944fa5910ecc15b5e351865763d8657f8/graphviz-0.21-py3-none-any.whl.metadata
  Using cached graphviz-0.21-py3-none-any.whl.metadata (12 kB)
Using cached catboost-1.2.8-cp311-cp311-manylinux2014_x86_64.whl (99.2 MB)
Using cached graphviz-0.21-py3-none-any.whl (47 kB)
Installing collected packages: graphviz, catboost
Successfully installed catboost-1.2.8 graphviz-0.21
[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [0]:
import catboost

In [0]:
try:
    del _
except:
    pass

for i in range(10):
  gc.collect()

In [0]:
# --- Main Training Loop ---
all_fold_importances = []
all_test_predictions = []
all_oof_preds = []

for fold in range(5):
    print(f"\n=== FOLD {fold} ===")
    
    # --- Training data preparation ---
    train_fold = train.filter(pl.col('fold') != fold)
    val_fold = train.filter(pl.col('fold') == fold)
    
    # 1. Calculate the size of each group.
    # group_sizes = train_fold.group_by('ranker_id').count().rename({'count': 'group_size'})
    # bins = np.arange(0, 51, 5).tolist()  # Bins of size 5 for groups up to 50
    # bins.extend([100,150, 200,250,350, 500,750,1000,2000, 1e8]) # Add bins for the long tail
    # bins = sorted(list(set(bins))) # Ensure bins are unique and sorted

    # group_sizes = group_sizes.with_columns(
    #         pl.col('group_size').cut(breaks=bins).alias('size_stratum')
    #     )
    
    # # 3. Define the fraction of groups to sample from each stratum.
    # sample_fraction = 0.8
    # sampled_ranker_ids_by_stratum = group_sizes.group_by('size_stratum', maintain_order=False).agg(
    #     pl.col('ranker_id').sample(fraction=sample_fraction, shuffle=True, seed=42)
    # )
    # sampled_groups = sampled_ranker_ids_by_stratum.explode('ranker_id')['ranker_id'].to_list()
    
    # # 5. Filter the training fold to keep only the rows from the sampled groups.
    # train_fold = train_fold.filter(pl.col('ranker_id').is_in(sampled_groups))

    groups_tr = train_fold.select('ranker_id').to_numpy().flatten()
    groups_va = val_fold.select('ranker_id').to_numpy().flatten()

    X_tr = train_fold.drop(drop)
    features = X_tr.columns # Define features based on the training data for this fold
    X_tr = X_tr.select(features)
    X_tr = X_tr.to_numpy()
    y_tr = train_fold['selected'].to_numpy().astype(np.int32)
    
    del train_fold
    gc.collect()
    print(f"Fold {fold} training data shape: {X_tr.shape}")
    
    X_val = val_fold.drop(drop)
    # Ensure validation set has the same columns as training
    X_val = X_val.select(features)
    X_val = X_val.to_numpy()
    y_val = val_fold['selected'].to_numpy().astype(np.int32)
    print(f"Fold {fold} validation data shape: {X_val.shape}")
    
    # --- CatBoost Pool Creation ---
    # Pool is CatBoost's internal data structure for efficiency
    train_pool = catboost.Pool(
        data=X_tr, 
        label=y_tr, 
        group_id=groups_tr, 
        feature_names=features
    )
    val_pool = catboost.Pool(
        data=X_val, 
        label=y_val, 
        group_id=groups_va, 
        feature_names=features
    )
    
    # --- CatBoost Model Training ---
    # Define model parameters
    catboost_params = {
        'loss_function': 'PairLogit:max_pairs=32',#'QueryRMSE',#'QuerySoftMax',#'PairLogit:max_pairs=32',#'PairLogit',#'YetiRank',
        'eval_metric': 'NDCG:top=3',
        'iterations': 3000,        
        'learning_rate': 0.1,     
        'depth': 12,               
        'subsample': 0.8,        
        'colsample_bylevel': 0.8, 
        'l2_leaf_reg': 10.0,
        'random_seed': 42,
        'thread_count': -1,
        'verbose': 100,
        # 'early_stopping_rounds': 50,  # ENABLED early stopping
        # 'border_count': 8,       # REDUCED from default 254 to 32
        # 'task_type': 'CPU',       # Explicit CPU optimization
        # 'bootstrap_type': 'Bernoulli',  # Faster than default,
        # 'leaf_estimation_iterations': 1,  # Speed up leaf estimation
        # 'score_function': 'Cosine',       # Faster than default
    }
    
    print(f"Training CatBoost model for fold {fold}...")
    cat_model = catboost.CatBoostRanker(**catboost_params)
    
    cat_model.fit(
        train_pool,
        eval_set=[train_pool, val_pool]
    )
    
    del train_pool, X_tr, y_tr, groups_tr
    gc.collect()

    # --- Validation Predictions & Metric ---
    print(f"Calculating validation hit rate @ 3 for fold {fold}...")
    
    # Get validation predictions
    val_preds = cat_model.predict(val_pool)

    # Store Out-of-Fold (OOF) predictions
    oof_df = val_fold.select(['Id', 'ranker_id', 'selected', 'fold']).with_columns(
        pl.Series("oof_prediction", val_preds)
    )
    all_oof_preds.append(oof_df)
    print(f"✅ Stored OOF predictions for fold {fold}")

    del val_pool
    gc.collect()

    # Calculate hit rate @ 3
    fold_hitrate = hitrate_at_3(y_val, val_preds, groups_va)
    print(f"Fold {fold} Hit Rate @ 3: {fold_hitrate:.4f}")
    
    del X_val, y_val, groups_va
    gc.collect()
    
    # --- Test Predictions ---
    print(f"Making test predictions for fold {fold}...")

    test_fold = test.clone()
    groups_te = test_fold['ranker_id'].to_numpy().flatten()

    # Rename fold-specific columns to match the general feature names used in training
    fold_specific_cols = [col for col in test_fold.columns if col.endswith(f'_fold{fold}')]
    rename_mapping = {col: col.replace(f'_fold{fold}', '') for col in fold_specific_cols}
    
    if rename_mapping:
        print(f"Renaming {len(rename_mapping)} fold-specific columns for fold {fold}")
        test_fold = test_fold.rename(rename_mapping)
    
    # Select the features used in training, ensuring order is the same
    try:
        X_test = test_fold.select(features)
    except Exception as e:
        print(f"Error selecting test features: {e}")
        missing = set(features) - set(test_fold.columns)
        print(f"Missing features in test set: {missing}")
        raise

    X_test = X_test.to_numpy()
    
    # Create a Pool for test data (no labels)
    test_pool = catboost.Pool(data=X_test, group_id=groups_te, feature_names=features)
    print(f"Test data shape for fold {fold}: {X_test.shape}")
    
    # Make predictions
    test_preds = cat_model.predict(test_pool)
    
    # Store predictions with metadata for ensembling
    fold_predictions = {
        'fold': fold,
        'predictions': test_preds,
        'test_ids': test_fold.select(['Id', 'ranker_id']).to_pandas()
    }
    all_test_predictions.append(fold_predictions)
    
    print(f"✅ Test predictions completed for fold {fold}")
    
    # Save the trained model
    model_path = f"{save_path}/catboost_ranker_fold_{fold}.cbm"
    cat_model.save_model(model_path)
    print(f"✅ Saved model to {model_path}")
    
    del cat_model, test_fold, X_test, test_preds, test_pool
    gc.collect()

print("\n" + "="*20)
print("=== TRAINING COMPLETE ===")
print("="*20)
print(f"Trained {len(all_fold_importances)} models")
print(f"Generated predictions from {len(all_test_predictions)} folds")

# --- Save Combined Out-of-Fold Predictions ---
print("\n=== SAVING FINAL OOF PREDICTIONS ===")
if all_oof_preds:
    final_oof_df = pl.concat(all_oof_preds).sort("Id")
    oof_path = f'{save_path}/ensemble_oof_predictions.parquet'
    final_oof_df.write_parquet(oof_path)
    print(f"✅ Saved final OOF predictions to {oof_path}")
    print(f"Final OOF DataFrame shape: {final_oof_df.shape}")
else:
    print("⚠️ No OOF predictions were generated.")

# --- Ensemble Test Predictions ---
print("\n=== CREATING ENSEMBLE PREDICTIONS ===")
if all_test_predictions:
    # Average predictions across all folds
    ensemble_preds = np.mean([fold_pred['predictions'] for fold_pred in all_test_predictions], axis=0)

    # Use IDs from the first fold's predictions as a base
    final_predictions = all_test_predictions[0]['test_ids'].copy()
    final_predictions['prediction'] = ensemble_preds

    print(f"Final ensemble predictions shape: {final_predictions.shape}")

    # --- Rank Within Each Group ---
    print("\n=== RANKING WITHIN RANKER_ID GROUPS ===")
    final_predictions_pl = pl.from_pandas(final_predictions)
    final_predictions_ranked = final_predictions_pl.with_columns(
        pl.col('prediction').rank(method='ordinal', descending=True).over('ranker_id').alias('rank')
    )

    # Save ranked predictions
    ranked_path = f'{save_path}/ensemble_test_predictions_ranked.csv'
    final_predictions_ranked.write_csv(ranked_path)
    print(f"✅ Saved ranked predictions to {ranked_path}")

    print("\n=== FINAL SUMMARY ===")
    print(f"Total predictions: {len(final_predictions_ranked)}")
    print(f"Unique ranker_ids: {final_predictions_ranked['ranker_id'].n_unique()}")

else:
    print("⚠️ No test predictions were generated to ensemble.")


=== FOLD 0 ===
Fold 0 training data shape: (14516297, 274)
Fold 0 validation data shape: (3629075, 274)
Training CatBoost model for fold 0...
0:	test: 0.2722000	test1: 0.2723617	best: 0.2723617 (0)	total: 5.05s	remaining: 4h 12m 39s
100:	test: 0.5543184	test1: 0.4883936	best: 0.4883936 (100)	total: 8m 29s	remaining: 4h 3m 37s
200:	test: 0.6469125	test1: 0.5105529	best: 0.5105529 (200)	total: 16m 48s	remaining: 3h 54m 8s
300:	test: 0.7205995	test1: 0.5224000	best: 0.5224000 (300)	total: 24m 59s	remaining: 3h 44m 1s
400:	test: 0.7799833	test1: 0.5316445	best: 0.5316784 (399)	total: 33m	remaining: 3h 33m 55s
500:	test: 0.8228590	test1: 0.5369195	best: 0.5373538 (497)	total: 41m 1s	remaining: 3h 24m 39s
600:	test: 0.8578240	test1: 0.5422736	best: 0.5424365 (597)	total: 49m 3s	remaining: 3h 15m 49s
700:	test: 0.8821945	test1: 0.5467574	best: 0.5468871 (699)	total: 57m 15s	remaining: 3h 7m 47s
800:	test: 0.9010508	test1: 0.5502067	best: 0.5502067 (795)	total: 1h 5m 29s	remaining: 2h 59m 49s