In [1]:
import sys
import os
from dotenv import load_dotenv

# Treat the parent of the current working directory as the project root.
root_dir = os.path.abspath("..")
# Put the project root on the import path for this kernel session.
sys.path.append(root_dir)
# Load environment variables from the project's .env file;
# load_dotenv returns a bool, which is the cell's displayed output.
dotenv_path = os.path.join(root_dir, ".env")
load_dotenv(dotenv_path)

True

### Country name to two-letter country code (ISO 3166-1 alpha-2)

In [2]:
# Lookup table: country/region name -> two-letter country code.
# build_airline_name_map() uses it to turn an airline's "Country/Region" value
# into a nationality code; names missing from this table resolve to None there.
# Several countries appear under more than one spelling on purpose
# (e.g. "Russia"/"Russian Federation", "Turkey"/"Turkiye", "UAE"/"United Arab Emirates",
# "South Korea"/"Republic of Korea", "Gambia"/"The Gambia").
# NOTE(review): despite the "_iata" suffix, the values look like ISO 3166-1 alpha-2
# country codes — confirm against the consumer of this map.
airline_country_map_iata = {
    "Afghanistan": "AF",
    "Albania": "AL",
    "Algeria": "DZ",
    "Angola": "AO",
    "Antigua and Barbuda": "AG",
    "Argentina": "AR",
    "Armenia": "AM",
    "Aruba": "AW",
    "Australia": "AU",
    "Austria": "AT",
    "Azerbaijan": "AZ",
    "Bahamas": "BS",
    "Bahrain": "BH",
    "Bangladesh": "BD",
    "Belarus": "BY",
    "Belgium": "BE",
    "Belize": "BZ",
    "Benin": "BJ",
    "Bhutan": "BT",
    "Bolivia": "BO",
    "Botswana": "BW",
    "Brazil": "BR",
    "Brunei": "BN",
    "Bulgaria": "BG",
    "Burkina Faso": "BF",
    "Cambodia": "KH",
    "Cameroon": "CM",
    "Canada": "CA",
    "Cape Verde": "CV",
    "Cayman Islands": "KY",
    "Chile": "CL",
    "China": "CN",
    "Colombia": "CO",
    "Cook Islands": "CK",
    "Costa Rica": "CR",
    "Croatia": "HR",
    "Cuba": "CU",
    "Cyprus": "CY",
    "Czech Republic": "CZ",
    "Democratic Republic of the Congo": "CD",
    "Denmark": "DK",
    "Djibouti": "DJ",
    "Dominican Republic": "DO",
    "Ecuador": "EC",
    "Egypt": "EG",
    "Eritrea": "ER",
    "Estonia": "EE",
    "Ethiopia": "ET",
    "Fiji": "FJ",
    "Finland": "FI",
    "France": "FR",
    "French Polynesia": "PF",
    "Gabon": "GA",
    "Gambia": "GM",
    "Georgia": "GE",
    "Germany": "DE",
    "Ghana": "GH",
    "Greece": "GR",
    "Greenland": "GL",
    "Guatemala": "GT",
    "Guinea": "GN",
    "Guinea-Bissau": "GW",
    "Haiti": "HT",
    "Honduras": "HN",
    "Hong Kong": "HK",
    "Hungary": "HU",
    "Iceland": "IS",
    "India": "IN",
    "Indonesia": "ID",
    "Iran": "IR",
    "Iraq": "IQ",
    "Ireland": "IE",
    "Israel": "IL",
    "Italy": "IT",
    "Japan": "JP",
    "Jordan": "JO",
    "Kazakhstan": "KZ",
    "Kenya": "KE",
    "Kuwait": "KW",
    "Kyrgyzstan": "KG",
    "Lao People's Democratic Republic": "LA",
    "Latvia": "LV",
    "Lebanon": "LB",
    "Libya": "LY",
    "Lithuania": "LT",
    "Luxembourg": "LU",
    "Macao": "MO",
    "Madagascar": "MG",
    "Malawi": "MW",
    "Malaysia": "MY",
    "Maldives": "MV",
    "Malta": "MT",
    "Marshall Islands": "MH",
    "Mauritius": "MU",
    "Mexico": "MX",
    "Moldova": "MD",
    "Mongolia": "MN",
    "Montenegro": "ME",
    "Morocco": "MA",
    "Mozambique": "MZ",
    "Myanmar": "MM",
    "Namibia": "NA",
    "Nauru": "NR",
    "Nepal": "NP",
    "Netherlands": "NL",
    "New Zealand": "NZ",
    "Nigeria": "NG",
    "North Korea": "KP",
    "Norway": "NO",
    "Oman": "OM",
    "Pakistan": "PK",
    "Palau": "PW",
    "Panama": "PA",
    "Papua New Guinea": "PG",
    "Paraguay": "PY",
    "Peru": "PE",
    "Philippines": "PH",
    "Poland": "PL",
    "Portugal": "PT",
    "Qatar": "QA",
    "Republic of Korea": "KR",
    "Republic of the Congo": "CG",
    "Romania": "RO",
    "Russia": "RU",
    "Russian Federation": "RU",
    "Rwanda": "RW",
    "Samoa": "WS",
    "Saudi Arabia": "SA",
    "Serbia": "RS",
    "Seychelles": "SC",
    "Sierra Leone": "SL",
    "Singapore": "SG",
    "Slovakia": "SK",
    "Slovenia": "SI",
    "Solomon Islands": "SB",
    "South Africa": "ZA",
    "South Korea": "KR",
    "Spain": "ES",
    "Sri Lanka": "LK",
    "Sudan": "SD",
    "Suriname": "SR",
    "Sweden": "SE",
    "Switzerland": "CH",
    "Syrian Arab Republic": "SY",
    "Taiwan": "TW",
    "Tanzania": "TZ",
    "Thailand": "TH",
    "The Gambia": "GM",
    "Trinidad and Tobago": "TT",
    "Tunisia": "TN",
    "Turkey": "TR",
    "Turkiye": "TR",
    "Turkmenistan": "TM",
    "Turks and Caicos Islands": "TC",
    "UAE": "AE",
    "Uganda": "UG",
    "Ukraine": "UA",
    "United Arab Emirates": "AE",
    "United Kingdom": "GB",
    "United States": "US",
    "Uruguay": "UY",
    "Uzbekistan": "UZ",
    "Vanuatu": "VU",
    "Venezuela": "VE",
    "Vietnam": "VN",
    "Yemen": "YE",
    "Zambia": "ZM",
    "Zimbabwe": "ZW",
}

In [3]:
import polars as pl
from pydantic import BaseModel
from typing import Dict, Tuple, Optional
from math import radians, sin, cos, asin, sqrt
from datetime import datetime

In [4]:
# Input parquet files, located under <project root>/kaggle/.
train_file = os.path.join(root_dir, "kaggle", "train.parquet")
test_file = os.path.join(root_dir, "kaggle", "test.parquet")

In [5]:
# Helper functions and lookup tables built from the support files (airlines, airports, aircraft)
def haversine(lat1, lon1, lat2, lon2) -> Optional[float]:
    """Great-circle distance in kilometres between two points (haversine formula).

    Args:
        lat1, lon1: Latitude/longitude of the first point, in decimal degrees.
        lat2, lon2: Latitude/longitude of the second point, in decimal degrees.

    Returns:
        The distance in km (Earth radius taken as 6371 km), or None when any
        coordinate is None or NaN.
    """
    coords = (lat1, lon1, lat2, lon2)
    # `v != v` is true only for NaN.
    if any(v is None or v != v for v in coords):
        return None

    lat1, lon1, lat2, lon2 = map(radians, coords)
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
    # Floating-point rounding can push `a` marginally outside [0, 1] for
    # (near-)antipodal points, which would make asin() raise ValueError.
    a = min(1.0, max(0.0, a))
    c = 2 * asin(sqrt(a))
    return 6371 * c  # Earth radius in km

def build_airline_name_map(airline_df: pl.DataFrame) -> Tuple[Dict[str, str], Dict[str, bool], Dict[str, str]]:
    """Derive three airline lookups keyed by IATA code.

    Rows whose "IATA" value is None are ignored. Returns, in order:
      - description map: code -> "<Airline> (<Country/Region>'s <service type>)"
      - LCC map: code -> the raw "LCC" column value
      - nationality map: code -> country code from ``airline_country_map_iata``
        (None when the country name is not in that table)
    """
    descriptions: Dict[str, str] = {}
    low_cost_flags: Dict[str, bool] = {}
    nationalities: Dict[str, str] = {}

    for record in airline_df.iter_rows(named=True):
        iata = record.get("IATA")
        if iata is not None:
            carrier = record.get("Airline")
            home = record.get("Country/Region")
            is_lcc = record.get("LCC")

            # Any truthy LCC value marks a low-cost carrier.
            if is_lcc:
                tier = "LCC"
            else:
                tier = "full service carrier"

            descriptions[iata] = f"{carrier} ({home}'s {tier})"
            low_cost_flags[iata] = is_lcc
            nationalities[iata] = airline_country_map_iata.get(home, None)

    return descriptions, low_cost_flags, nationalities

def build_airport_maps(
    airport_df: pl.DataFrame,
) -> Tuple[Dict[str, str], Dict[str, str], Dict[str, float], Dict[str, Tuple[float, float]]]:
    """
    Build four dictionaries from a Polars airport dataframe, all keyed by IATA code:
      - meta_map:    { IATA -> "<City_Name> <AirportName> (<Country_CodeA2>)" }
      - country_map: { IATA -> Country_CodeA2 }
      - offset_map:  { IATA -> UTC_Offset_Hours as float }
      - coord_map:   { IATA -> (GeoPointLat, GeoPointLong) as floats }

    Rows whose IATA code is None, NaN or the empty string are skipped entirely.
    Missing (None/NaN) country, offset or coordinate values simply leave that
    airport out of the corresponding map.

    Returns:
        (meta_map, country_map, offset_map, coord_map)
    """
    meta_map: Dict[str, str] = {}
    offset_map: Dict[str, float] = {}
    coord_map: Dict[str, Tuple[float, float]] = {}
    country_map: Dict[str, str] = {}

    for row in airport_df.iter_rows(named=True):
        code = row.get("IATA")
        if code is None or code != code or code == "":
            continue  # skip invalid codes

        # location metadata
        city = row.get("City_Name")
        name = row.get("AirportName")
        country = row.get("Country_CodeA2")
        meta_map[code] = f"{city} {name} ({country})"

        # `code` is already known to be valid here; only whitespace-only codes
        # and missing/NaN countries still need to be filtered out.
        if country and country == country and code.strip() != "":
            country_map[code] = country

        # UTC offset
        offset = row.get("UTC_Offset_Hours")
        if offset is not None and offset == offset:  # not NaN
            offset_map[code] = float(offset)

        # coordinates
        lat = row.get("GeoPointLat")
        lon = row.get("GeoPointLong")
        # None must be rejected explicitly: `None == None` is True, so the
        # NaN test alone would let a null coordinate reach float() and crash.
        if lat is None or lon is None:
            continue
        if (lat == lat) and (lon == lon):  # not NaN
            if not (lat == 0.0 and lon == 0.0):  # skip placeholder coords
                coord_map[code] = (float(lat), float(lon))

    return meta_map, country_map, offset_map, coord_map

def build_aircraft_map(aircraft_df: pl.DataFrame) -> Tuple[Dict[str, str], Dict[str, str]]:
    """
    Build two lookup dicts keyed by aircraft IATA code:
      - aircraft_map:      { IATA -> "Model" value }
      - aircraft_size_map: { IATA -> "WTC" value }

    Rows without a usable code (None, NaN or empty) contribute to neither map.
    This matters for the size map: callers look up the aircraft code of a
    possibly absent segment, and a stray ``None`` key would make a missing
    segment resolve to a real aircraft size.

    Returns:
        (aircraft_map, aircraft_size_map)
    """
    aircraft_map: Dict[str, str] = {}
    aircraft_size_map: Dict[str, str] = {}
    for row in aircraft_df.iter_rows(named=True):
        code = row.get("IATA")
        # `x == x` is False only for NaN.
        if not (code and code == code):
            continue

        model = row.get("Model")
        if model and model == model:
            aircraft_map[code] = model

        wtc = row.get("WTC")
        if wtc and wtc == wtc:
            aircraft_size_map[code] = wtc

    return aircraft_map, aircraft_size_map

# Cabin class map: raw cabinClass value -> ordinal comfort score used for the
# *_seat_types_* features in engineer_all_features().
CABIN_MAP_COMFORT = {
    1: 0,
    2: 1,
    3: 2,
    4: 4  # NOTE(review): scores jump from 2 to 4 (no 3) — confirm this gap is intentional
}

# WTC code -> human-readable size label.
# NOTE(review): WTC presumably stands for wake turbulence category — confirm.
# Not referenced by the other cells visible in this notebook.
WTC_MAP = {
    "L": "light aircraft",
    "M": "medium-size jet",
    "H": "heavy widebody jet",
    "J": "super heavy jet"
}

# WTC code -> numeric size used for the *_plane_size_* features.
# NOTE(review): values look like rough seat-capacity estimates — confirm their source.
WTC_CAPACITY = {
    "L": 20,
    "M": 150,
    "H": 300,
    "J": 500
}

# Support CSV files live under <project root>/kaggle/support/.

# Airline table (source for airline description, LCC flag and nationality)
airline_file = os.path.join(root_dir, "kaggle", "support", "airlines_lccs.csv")
airline = pl.read_csv(airline_file)

# Airport table (source for airport description, country, UTC offset and coordinates)
airport_file = os.path.join(root_dir, "kaggle", "support", "Global Airports Dataset.csv")
airport = pl.read_csv(airport_file)

# Aircraft table (source for aircraft model and WTC size class)
aircraft_file = os.path.join(root_dir, "kaggle", "support", "Aircraft List.csv")
aircraft = pl.read_csv(aircraft_file)

# Materialise the lookup dictionaries once, up front; all are keyed by IATA code.
airline_map, airline_lcc_map, airline_nat_map = build_airline_name_map(airline)
airport_map, airport_country_map, airport_offset_map, airport_coord_map = build_airport_maps(airport)
aircraft_map, aircraft_size_map = build_aircraft_map(aircraft)

In [6]:
class Airport(BaseModel):
    """Bundle of airport lookup dictionaries, each keyed by airport IATA code."""
    # code -> description string (filled from build_airport_maps' meta map in this notebook)
    meta: dict
    # code -> country code
    country: dict
    # code -> UTC offset in hours
    timezone: dict
    # code -> (latitude, longitude)
    coord: dict

class Aircraft(BaseModel):
    """Bundle of aircraft lookup dictionaries, each keyed by aircraft IATA code."""
    # code -> model description
    meta: dict
    # code -> WTC size class (a key of WTC_CAPACITY / WTC_MAP)
    size: dict

class Airline(BaseModel):
    """Bundle of airline lookup dictionaries, each keyed by airline IATA code."""
    # code -> description string
    meta: dict
    # code -> low-cost-carrier flag
    lcc: dict
    # code -> airline nationality (country code, or None when unknown)
    country: dict

In [7]:
# Wrap the raw lookup dictionaries in their container models so they can be
# passed around as three objects instead of nine separate dicts.
airport_info = Airport(meta=airport_map, country=airport_country_map, timezone=airport_offset_map, coord=airport_coord_map)
aircraft_info = Aircraft(meta=aircraft_map, size=aircraft_size_map)
airline_info = Airline(meta=airline_map, lcc=airline_lcc_map, country=airline_nat_map)

In [8]:
def duration_to_minutes(val):
    """Convert a duration value to whole minutes.

    Accepted inputs:
      - None / NaN / infinity      -> None
      - int or float               -> truncated to int (assumed to already be minutes)
      - "hh:mm:ss"                 -> hours * 60 + minutes (+ whole minutes contained in seconds)
      - "d.hh:mm:ss"               -> as above plus 24 h per day; this is how
                                      durations of a day or more are written in
                                      TimeSpan-style strings, e.g. "1.02:40:00"
      - "hh:mm"                    -> hours * 60 + minutes
      - "mm"                       -> minutes

    Any other type, or a string that cannot be parsed, yields None.
    """
    if val is None or val != val:
        return None

    # If already numeric
    if isinstance(val, (int, float)):
        # int(inf) raises OverflowError; treat infinities as missing instead.
        if val in (float("inf"), float("-inf")):
            return None
        return int(val)

    if not isinstance(val, str):
        return None

    parts = val.split(":")
    try:
        if len(parts) == 3:
            days = 0
            hours_part = parts[0]
            # Optional "<days>." prefix in front of the hours field.
            if "." in hours_part:
                days_part, hours_part = hours_part.split(".", 1)
                days = int(days_part)
            hours = int(hours_part)
            minutes = int(parts[1])
            seconds = int(parts[2])
            return (days * 24 + hours) * 60 + minutes + (seconds // 60)
        if len(parts) == 2:
            hours, minutes = map(int, parts)
            return hours * 60 + minutes
        if len(parts) == 1:
            return int(parts[0])
    except ValueError:
        return None
    return None

def _engineer_leg_features(
    row: dict,
    leg: int,
    prefix: str,
    airline_info: Airline,
    aircraft_info: Aircraft,
    airport_info: Airport,
) -> dict:
    """Compute the features of one leg (up to three segments) of a flight row.

    Args:
        row: Flat record with ``legs{leg}_segments{0..2}_*`` and ``legs{leg}_duration`` keys.
        leg: Leg index used in the input column names (0 = outbound, 1 = return).
        prefix: Prefix for the output keys ("go" or "rtn").
        airline_info: Provides the ``lcc`` lookup (carrier code -> LCC flag).
        aircraft_info: Provides the ``size`` lookup (aircraft code -> WTC class).
        airport_info: Provides the ``coord`` and ``country`` lookups.

    Returns:
        Dict of ``{prefix}_*`` features, in a fixed key order.
    """
    segments = [f"legs{leg}_segments{idx}" for idx in range(3)]

    def seg_values(suffix, default=None):
        # One raw value per segment; `default` only applies when the column is absent.
        return [row.get(f"{seg}_{suffix}", default) for seg in segments]

    def per_segment(name, values):
        # Expand a 3-element list into "<prefix>_<name>_<idx>" entries.
        return {f"{prefix}_{name}_{idx}": value for idx, value in enumerate(values)}

    # Number of segments actually present (more than one means a connection).
    seg_number = sum(code is not None for code in seg_values("marketingCarrier_code"))

    flight_time = [duration_to_minutes(v) for v in seg_values("duration", 0)]
    total_flight_time = sum(t for t in flight_time if t is not None)

    # Time not spent flying = leg duration minus the summed segment durations.
    leg_duration = duration_to_minutes(row.get(f"legs{leg}_duration", 0))
    exchange_wait = None if leg_duration is None else leg_duration - total_flight_time

    bag_allow = seg_values("baggageAllowance_quantity", 0)

    # aircraft code -> WTC class -> numeric size
    plane_size = [
        WTC_CAPACITY.get(aircraft_info.size.get(code, None), None)
        for code in seg_values("aircraft_code")
    ]

    departures = seg_values("departureFrom_airport_iata")
    arrivals = seg_values("arrivalTo_airport_iata")

    # Great-circle distance per segment; None when either airport has no coordinates.
    travel_distance = []
    for dep, arr in zip(departures, arrivals):
        dep_coord = airport_info.coord.get(dep, [None, None])
        arr_coord = airport_info.coord.get(arr, [None, None])
        travel_distance.append(haversine(dep_coord[0], dep_coord[1], arr_coord[0], arr_coord[1]))
    total_travel_distance = sum(d for d in travel_distance if d is not None)

    airline_is_lcc = [
        airline_info.lcc.get(code, None) for code in seg_values("operatingCarrier_code")
    ]
    seat_avail = seg_values("seatsAvailable", 0)
    seat_types = [CABIN_MAP_COMFORT.get(cabin, None) for cabin in seg_values("cabinClass")]

    # A leg is international when any segment departs and arrives in different countries.
    # NOTE(review): an airport missing from the country map on one side only
    # compares as None != "<code>" and therefore counts as international — confirm intended.
    international = any(
        airport_info.country.get(dep, None) != airport_info.country.get(arr, None)
        for dep, arr in zip(departures, arrivals)
    )

    return {
        f"{prefix}_seg_number": seg_number,
        **per_segment("flight_time", flight_time),
        f"{prefix}_total_flight_time": total_flight_time,
        f"{prefix}_exchange_wait": exchange_wait,
        **per_segment("bag_allow", bag_allow),
        **per_segment("plane_size", plane_size),
        **per_segment("travel_distance", travel_distance),
        f"{prefix}_travel_distance": total_travel_distance,
        **per_segment("airline_is_lcc", airline_is_lcc),
        **per_segment("airline_seat_avail", seat_avail),
        **per_segment("seat_types", seat_types),
        f"{prefix}_international": international,
    }


def engineer_all_features(
    row: dict,
    airline_info: Airline,
    aircraft_info: Aircraft,
    airport_info: Airport,
) -> dict:
    """Engineer the flight features of one row for both legs.

    The outbound leg (``legs0_*`` columns) produces the ``go_*`` features and
    the return leg (``legs1_*`` columns) the ``rtn_*`` features; both legs go
    through the same per-leg computation.

    Args:
        row: Flat record of one itinerary (column name -> value).
        airline_info: Airline lookup container.
        aircraft_info: Aircraft lookup container.
        airport_info: Airport lookup container.

    Returns:
        Dict with all ``go_*`` features followed by all ``rtn_*`` features.
    """
    return {
        # LEG0 - GO
        **_engineer_leg_features(row, 0, "go", airline_info, aircraft_info, airport_info),
        # LEG1 - RETURN (RTN)
        **_engineer_leg_features(row, 1, "rtn", airline_info, aircraft_info, airport_info),
    }

In [9]:
# Output schema of the engineered feature frame. Both legs share the same
# feature layout, so the per-leg column list is declared once and expanded
# for the "go" (outbound) and "rtn" (return) prefixes, outbound columns first.
SCHEMA = {
    f"{leg}_{column}": dtype
    for leg in ("go", "rtn")
    for column, dtype in (
        ("seg_number", pl.Int64),
        ("flight_time_0", pl.Int64),
        ("flight_time_1", pl.Int64),
        ("flight_time_2", pl.Int64),
        ("total_flight_time", pl.Int64),
        ("exchange_wait", pl.Int64),
        ("bag_allow_0", pl.Float64),
        ("bag_allow_1", pl.Float64),
        ("bag_allow_2", pl.Float64),
        ("plane_size_0", pl.Int64),
        ("plane_size_1", pl.Int64),
        ("plane_size_2", pl.Int64),
        # NOTE(review): haversine() yields float kilometres while these columns
        # are declared Int64 — confirm the intended cast/rounding.
        ("travel_distance_0", pl.Int64),
        ("travel_distance_1", pl.Int64),
        ("travel_distance_2", pl.Int64),
        ("travel_distance", pl.Int64),
        ("airline_is_lcc_0", pl.Int8),
        ("airline_is_lcc_1", pl.Int8),
        ("airline_is_lcc_2", pl.Int8),
        ("airline_seat_avail_0", pl.Int64),
        ("airline_seat_avail_1", pl.Int64),
        ("airline_seat_avail_2", pl.Int64),
        ("seat_types_0", pl.Int64),
        ("seat_types_1", pl.Int64),
        ("seat_types_2", pl.Int64),
        ("international", pl.Boolean),
    )
}

In [10]:
def process_flight_parquet(
    parquet_path: str,
    airline_info: Airline,
    aircraft_info: Aircraft,
    airport_info: Airport,
    chunk_size: int = 200_000,
    output_path: Optional[str] = None,
):
    """Engineer flight features for every row of a parquet file, chunk by chunk.

    Args:
        parquet_path: Path of the input parquet file.
        airline_info: Airline lookup container passed to engineer_all_features.
        aircraft_info: Aircraft lookup container passed to engineer_all_features.
        airport_info: Airport lookup container passed to engineer_all_features.
        chunk_size: Number of input rows materialised and processed at a time.
        output_path: When given, each chunk is written to
            ``f"{output_path}_part_{i}.parquet"`` and nothing is kept in memory.
            When None, the chunks are concatenated and returned.

    Returns:
        A DataFrame with the SCHEMA columns when ``output_path`` is None
        (empty if the input has no rows), otherwise None.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
        Exception: Any error raised while engineering a row is re-raised after
            the offending row has been printed.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    def sanitize_feature_row(row: dict) -> dict:
        def clean(x):
            # Plain Python scalars and None pass through untouched.
            if x is None or isinstance(x, (int, float, str, bool)):
                return x
            # Anything else is coerced to float when possible, else dropped.
            # Only conversion failures are caught so that KeyboardInterrupt /
            # SystemExit are never swallowed.
            try:
                return float(x)
            except (TypeError, ValueError, OverflowError):
                return None
        return {k: clean(v) for k, v in row.items()}

    write_to_disk = output_path is not None

    scan = pl.scan_parquet(parquet_path)
    row_count = scan.select(pl.len()).collect(engine="streaming")[0, 0]

    all_rows = []

    for start in range(0, row_count, chunk_size):
        batch_df = scan.slice(start, chunk_size).collect(engine="streaming")
        feats = []

        for row in batch_df.iter_rows(named=True):
            try:
                feats.append(engineer_all_features(row, airline_info, aircraft_info, airport_info))
            except Exception as e:
                print(f"Error processing row {row}: {e}")
                raise  # bare raise keeps the original traceback
        feats = [sanitize_feature_row(row) for row in feats]
        result_df = pl.DataFrame(feats, schema=SCHEMA)

        if write_to_disk:
            result_df.write_parquet(f"{output_path}_part_{start // chunk_size}.parquet")
        else:
            all_rows.append(result_df)

        # Drop references to the chunk before loading the next one.
        del batch_df, result_df, feats

        print("[INFO] chunk_done")

    if write_to_disk:
        return None
    if not all_rows:
        # pl.concat() rejects an empty list; return an empty, correctly typed frame.
        return pl.DataFrame(schema=SCHEMA)
    return pl.concat(all_rows)

In [None]:
# Engineer features for the training set; no output_path is given, so the
# result is returned as a single in-memory DataFrame.
flight_train = process_flight_parquet(
    parquet_path=train_file,
    airline_info=airline_info,
    aircraft_info=aircraft_info,
    airport_info=airport_info,
    chunk_size=100_000
)

# Same processing for the test set.
flight_test = process_flight_parquet(
    parquet_path=test_file,
    airline_info=airline_info,
    aircraft_info=aircraft_info,
    airport_info=airport_info,
    chunk_size=100_000
)

In [223]:
# Persist the engineered features under <project root>/data/v1/.
processed_dir = os.path.join(root_dir, "data", "v1")
# write_parquet does not create missing directories.
os.makedirs(processed_dir, exist_ok=True)

# Add the positional row_id only once so this cell can be re-run safely:
# adding an index column whose name already exists would fail.
if "row_id" not in flight_train.columns:
    flight_train = flight_train.with_row_index(name='row_id')
flight_train.write_parquet(os.path.join(processed_dir, "processed_flight_features_train.parquet"))

if "row_id" not in flight_test.columns:
    flight_test = flight_test.with_row_index(name='row_id')
flight_test.write_parquet(os.path.join(processed_dir, "processed_flight_features_test.parquet"))