In [None]:
"""
Week 1 — BTS On-Time Performance: Data Acquisition & Initial EDA

Pipeline:
1) Download monthly ZIPs from BTS On-Time (1987–present)
   https://transtats.bts.gov/PREZIP/On_Time_Reporting_Carrier_On_Time_Performance_1987_present_YYYY_M.zip
2) Unzip CSVs
3) Concatenate → write Parquet (full + year/month partitions)
4) Initial EDA: columns/dtypes, missingness, numeric summary, correlations, plots (PNG)
5) (Optional) Spark demo (--use-spark)

Install:
pip install numpy pandas pyarrow fastparquet matplotlib tqdm scikit-learn
# Optional: pip install pyspark
"""

import argparse
import re
import time
import zipfile
import urllib.request
from urllib.error import HTTPError, URLError
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd

# ---------------------------
# CLI (robust to Jupyter args)
# ---------------------------
def parse_args(argv=None):
    """Build the command-line parser and return the parsed options.

    Args:
        argv: Argument strings to parse. When ``None``, argparse falls back
            to the interpreter's own command line.

    Returns:
        The ``argparse.Namespace`` holding ``years``, ``months``,
        ``auto_recent``, ``outdir``, ``limit_csv_rows_per_file``,
        ``sample_n`` and ``use_spark``.
    """
    cli = argparse.ArgumentParser(
        description="Week 1 — BTS On-Time: download → unzip → parquet → EDA"
    )

    # One (flag, keyword-arguments) entry per option, registered in order.
    option_table = (
        ("--years", dict(nargs="+", type=int, default=None,
                         help="Year list, e.g., 2024 2025")),
        ("--months", dict(nargs="+", type=int, default=None,
                          help="Month list (1-12), can cross years")),
        ("--auto-recent", dict(type=int, default=0,
                               help="If >0, pull the most recent N months (overrides years/months)")),
        ("--outdir", dict(type=str, default="./bts_on_time_data",
                          help="Output root directory")),
        ("--limit-csv-rows-per-file", dict(type=int, default=None,
                                           help="Max rows per CSV to read (for memory-limited dry runs)")),
        ("--sample-n", dict(type=int, default=100_000,
                            help="Save a Parquet sample for fast iteration")),
        ("--use-spark", dict(action="store_true",
                             help="Run a simple Spark aggregation demo (requires pyspark)")),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    # parse_known_args (rather than parse_args) tolerates arguments this
    # parser does not define, such as the kernel-connection flag a Jupyter
    # kernel is started with; they are reported and otherwise ignored.
    namespace, leftovers = cli.parse_known_args(argv)
    if leftovers:
        print(f"[INFO] Ignoring unknown arguments: {leftovers}")
    return namespace

def recent_year_month_pairs(n_recent: int) -> dict:
    """Return the most recent ``n_recent`` calendar months grouped by year.

    The window ends at (and includes) the current month in UTC and extends
    backwards across year boundaries as needed.

    NOTE(review): the current month is included even though the data source
    may not have published it yet — confirm whether a publication lag should
    be applied by the caller.

    Args:
        n_recent: Number of months to return; must be positive.

    Returns:
        A dict mapping each year to its ascending list of month numbers
        (1-12). Years are inserted oldest first.

    Raises:
        ValueError: If ``n_recent`` is not positive. (An ``assert`` would be
            stripped when Python runs with ``-O``.)
    """
    # Local import: the module header only imports the `datetime` class.
    from datetime import timezone

    if n_recent <= 0:
        raise ValueError(f"n_recent must be a positive integer, got {n_recent!r}")

    # datetime.utcnow() is deprecated (Python 3.12+); use an aware UTC "now".
    now = datetime.now(timezone.utc)

    # Put every month on one integer axis (year * 12 + zero-based month) so
    # stepping back over a year boundary is plain arithmetic.
    newest = now.year * 12 + (now.month - 1)

    years = {}
    for month_index in range(newest - n_recent + 1, newest + 1):
        y, zero_based_month = divmod(month_index, 12)
        # Walking oldest -> newest keeps each year's month list sorted.
        years.setdefault(y, []).append(zero_based_month + 1)
    return years

# ---------------------------
# Download & unzip
# ---------------------------
# URL template for one monthly ZIP; `{y}` is the year and `{m}` the month number.
BASE = "https://transtats.bts.gov/PREZIP/On_Time_Reporting_Carrier_On_Time_Performance_1987_present_{y}_{m}.zip"


def download_one(y: int, m: int, out: Path, retries: int = 2, sleep=2, base_url: str = BASE):
    """Download the monthly ZIP for year ``y`` / month ``m`` to ``out``.

    The payload is first written to ``<out>.part`` and only renamed to
    ``out`` once the transfer finished with a non-empty file, so an
    interrupted download can never be mistaken for a cached one later.

    Args:
        y: Year substituted into the URL template.
        m: Month number substituted into the URL template.
        out: Destination path of the ZIP file.
        retries: Number of additional attempts after the first failure.
        sleep: Base delay in seconds; attempt ``k`` (0-based) waits
            ``sleep * (k + 1)`` before the next try.
        base_url: URL template containing ``{y}`` and ``{m}`` placeholders.
            Defaults to ``BASE``.

    Returns:
        ``True`` if ``out`` exists and is non-empty (already cached or freshly
        downloaded), ``False`` if every attempt failed.
    """
    # Cache hit: a previous run already produced a non-empty file.
    if out.exists() and out.stat().st_size > 0:
        return True

    url = base_url.format(y=y, m=m)
    out.parent.mkdir(parents=True, exist_ok=True)
    partial = out.with_name(out.name + ".part")

    for k in range(retries + 1):
        try:
            print(f"[DL] {url}")
            urllib.request.urlretrieve(url, partial.as_posix())
            if partial.exists() and partial.stat().st_size > 0:
                partial.replace(out)  # rename only a complete download
                return True
            print(f"[WARN] download failed {y}-{m}: empty response")
        except HTTPError as e:
            print(f"[WARN] download failed {y}-{m}: {e}")
            if e.code == 404:
                # The file does not exist on the server; retrying cannot help.
                return False
        except OSError as e:
            # URLError (incl. ContentTooShortError) derives from OSError, as do
            # socket timeouts and connection resets raised mid-transfer.
            print(f"[WARN] download failed {y}-{m}: {e}")
        finally:
            # Never leave a truncated file behind (no-op after a successful rename).
            if partial.exists():
                partial.unlink()
        if k < retries:
            time.sleep(sleep * (k + 1))
    return False

def unzip_all(zip_paths, out_dir: Path):
    """Extract every ``.csv`` member of the given ZIP archives into ``out_dir``.

    Members are matched case-insensitively on their ``.csv`` suffix; all
    other members are ignored. An archive that cannot be read as a ZIP is
    reported with a warning and skipped.

    Args:
        zip_paths: Iterable of ZIP archive paths.
        out_dir: Directory the CSV members are extracted into.

    Returns:
        The number of CSV members extracted.
    """
    extracted = 0
    for archive_path in zip_paths:
        try:
            with zipfile.ZipFile(archive_path, "r") as archive:
                csv_members = (
                    member for member in archive.namelist()
                    if member.lower().endswith(".csv")
                )
                for member in csv_members:
                    archive.extract(member, out_dir)
                    extracted += 1
        except zipfile.BadZipFile:
            print(f"[WARN] bad zip: {archive_path}")
    print(f"[UNZIP] Extracted CSV files: {extracted}")
    return extracted

# ---------------------------
# Read → concat → Parquet
# ---------------------------
# "<4-digit year><_ or -><1-2 digit month>", not embedded in a longer digit run.
# The look-arounds stop e.g. "20240_1" from being read as year 240 / month 1
# and "2024_123" from being read as month 12.
_YEAR_MONTH_RE = re.compile(r"(?<!\d)(\d{4})[_-](\d{1,2})(?!\d)")


def infer_year_month_from_name(p: Path):
    """Infer ``(year, month)`` from a file name such as ``..._2024_1.csv``.

    The file name (not the full path) is scanned left to right for a
    four-digit year followed by ``_`` or ``-`` and a one- or two-digit month.
    The first candidate whose month lies in 1-12 wins; candidates with an
    impossible month are skipped.

    Args:
        p: Path whose ``name`` is inspected.

    Returns:
        ``(year, month)`` as ints, or ``(None, None)`` when no valid
        year/month pair is present.
    """
    for hit in _YEAR_MONTH_RE.finditer(p.name):
        year, month = int(hit.group(1)), int(hit.group(2))
        if 1 <= month <= 12:
            return year, month
    return None, None

def read_concat_csvs(csv_files, per_file_nrows=None):
    """Read each CSV with pandas and stack the frames into one DataFrame.

    The first file is previewed (1000 rows) only to print its leading column
    names. Every file is then read in full, or up to ``per_file_nrows`` rows.
    When a frame lacks a ``Year`` and/or ``Month`` column, the missing one is
    filled from the value inferred from the file name, if any.

    Args:
        csv_files: Non-empty sequence of CSV paths.
        per_file_nrows: Optional cap on the rows read from each file.

    Returns:
        The concatenated DataFrame with a fresh 0..n-1 index.
    """
    assert len(csv_files) > 0, "No CSV files found. Check download/unzip steps."
    file_count = len(csv_files)
    print(f"[READ] CSV files: {file_count}")

    preview = pd.read_csv(csv_files[0], nrows=1000, low_memory=False)
    print("[COLUMNS-SAMPLE]", list(preview.columns)[:20], f"... total: {len(preview.columns)}")

    frames = []
    for position, csv_path in enumerate(csv_files, start=1):
        print(f"[READ] {position}/{file_count} {csv_path.name}")
        frame = pd.read_csv(csv_path, nrows=per_file_nrows, low_memory=False)

        absent = [key for key in ("Year", "Month") if key not in frame.columns]
        if absent:
            # Fall back to whatever the file name reveals; either value may be None.
            inferred = dict(zip(("Year", "Month"), infer_year_month_from_name(csv_path)))
            for key in absent:
                if inferred[key] is not None:
                    frame[key] = inferred[key]

        frames.append(frame)

    combined = pd.concat(frames, ignore_index=True)
    print("[CONCAT] shape:", combined.shape)
    return combined

def write_parquet(full_df: pd.DataFrame, parquet_dir: Path):
    """Persist ``full_df`` as one Parquet file plus per-year/month partitions.

    Two best-effort steps, each reporting failure with a warning instead of
    raising:

    1. ``<parquet_dir>/bts_on_time_all.parquet`` with the whole frame.
    2. ``<parquet_dir>/by_year_month/Year=<y>/Month=<m>/part.parquet`` for
       every group of whichever of ``Year`` / ``Month`` exist as columns.

    Args:
        full_df: Frame to write.
        parquet_dir: Output directory; created if missing.
    """
    parquet_dir.mkdir(parents=True, exist_ok=True)

    # Step 1: single file holding everything.
    combined_path = parquet_dir / "bts_on_time_all.parquet"
    try:
        full_df.to_parquet(combined_path, index=False)
        print("[PARQUET] Saved:", combined_path.resolve())
    except Exception as err:
        print("[WARN] Failed to save all.parquet:", err)

    # Step 2: hive-style "column=value" directory per group.
    try:
        partition_keys = [name for name in ("Year", "Month") if name in full_df.columns]
        if not partition_keys:
            print("[WARN] Year/Month not found; skip partition writing.")
        else:
            partition_root = parquet_dir / "by_year_month"
            for key_values, chunk in full_df.groupby(partition_keys):
                # A single grouping key may come back as a bare scalar.
                if not isinstance(key_values, tuple):
                    key_values = (key_values,)
                segments = [f"{name}={value}" for name, value in zip(partition_keys, key_values)]
                target_dir = partition_root / "/".join(segments)
                target_dir.mkdir(parents=True, exist_ok=True)
                chunk.to_parquet(target_dir / "part.parquet", index=False)
            print("[PARQUET] Partitioned parquet under:", partition_root.resolve())
    except Exception as err:
        print("[WARN] Failed partition parquet:", err)

# ---------------------------
# EDA
# ---------------------------
# Column names under which the carrier code is looked up, in priority order.
# "Carrier" and "UniqueCarrier" are the original candidates; "Reporting_Airline"
# is added for current "Reporting Carrier On-Time Performance" extracts.
# NOTE(review): confirm the header name against a freshly downloaded CSV.
_EDA_CARRIER_COLUMNS = ("Carrier", "UniqueCarrier", "Reporting_Airline")


def _write_eda_tables(df: pd.DataFrame, out_dir: Path) -> None:
    """Write the missingness, numeric-summary and correlation CSVs."""
    miss = df.isna().sum().sort_values(ascending=False)
    # max(..., 1) keeps the ratio at 0 (instead of NaN) for a zero-row frame.
    miss_pct = (miss / max(len(df), 1)).round(4)
    pd.DataFrame({"missing": miss, "missing_pct": miss_pct}).head(25).to_csv(out_dir / "missing_top25.csv")
    print("[EDA] missing_top25.csv saved")

    numeric_cols = df.select_dtypes(include="number").columns.tolist()
    if not numeric_cols:
        # describe(include="number") raises when nothing is numeric.
        print("[WARN] no numeric columns; skip numeric_describe.csv and corr_subset.csv")
        return

    num_desc = df.describe(include="number").transpose().sort_values("count", ascending=False)
    num_desc.to_csv(out_dir / "numeric_describe.csv")
    print("[EDA] numeric_describe.csv saved")

    # Correlate only the columns that are written out. pandas computes each
    # coefficient pairwise, so this equals slicing the full matrix but avoids
    # correlating every numeric column against every other one.
    keep = [c for c in numeric_cols if ("Delay" in c or "Taxi" in c or c == "Distance")]
    df[keep].corr(numeric_only=True).round(3).to_csv(out_dir / "corr_subset.csv")
    print("[EDA] corr_subset.csv saved")


def _save_clipped_delay_hist(series: pd.Series, label: str, png_path: Path) -> None:
    """Save a 60-bin histogram of ``series`` clipped to [-60, 180]."""
    import matplotlib.pyplot as plt

    plt.figure()
    series.dropna().clip(-60, 180).hist(bins=60)
    plt.title(f"Histogram of {label} (clipped [-60, 180])")
    plt.xlabel(f"{label} (minutes)")
    plt.ylabel("Count")
    plt.tight_layout()
    plt.savefig(png_path, dpi=120)
    plt.close()


def _save_carrier_boxplot(df: pd.DataFrame, carrier_col: str, png_path: Path) -> bool:
    """Save a per-carrier ArrDelay box plot; return False if nothing to plot."""
    import matplotlib.pyplot as plt

    small = df[[carrier_col, "ArrDelay"]].dropna()
    if len(small) > 200_000:
        # Fixed seed keeps the plot reproducible between runs.
        small = small.sample(200_000, random_state=42)

    # The 15 carriers with the LOWEST median delay (ascending sort, first 15).
    order = list(small.groupby(carrier_col)["ArrDelay"].median().sort_values().index[:15])
    if not order:
        return False

    data = [small.loc[small[carrier_col] == c, "ArrDelay"].clip(-60, 180) for c in order]
    plt.figure()
    # Tick labels are set through xticks rather than boxplot's `labels=`
    # keyword (deprecated in Matplotlib 3.9 in favour of `tick_labels=`), so
    # this works on old and new Matplotlib alike. Boxes sit at x = 1..N and
    # are vertical by default.
    plt.boxplot(data, showfliers=False)
    plt.xticks(range(1, len(order) + 1), order, rotation=45)
    plt.title("ArrDelay by Carrier (top 15, clipped)")
    plt.xlabel("Carrier")
    plt.ylabel("ArrDelay (minutes)")
    plt.tight_layout()
    plt.savefig(png_path, dpi=120)
    plt.close()
    return True


def eda_basic(df: pd.DataFrame, out_dir: Path, sample_n: int):
    """Write first-pass EDA artefacts for ``df`` into ``out_dir``.

    Produces, where the needed columns exist:

    * ``missing_top25.csv`` — the 25 columns with the most missing values.
    * ``numeric_describe.csv`` — ``describe()`` of the numeric columns.
    * ``corr_subset.csv`` — correlations among numeric columns whose name
      contains ``Delay`` or ``Taxi`` or equals ``Distance``.
    * ``hist_arrdelay.png`` / ``hist_depdelay.png`` — clipped histograms.
    * ``box_arrdelay_by_carrier.png`` — ArrDelay by carrier box plot.
    * ``sample_100k.parquet`` — a seeded random sample of at most
      ``sample_n`` rows (the file name is fixed regardless of ``sample_n``).

    Args:
        df: Frame to analyse.
        out_dir: Output directory; created if missing.
        sample_n: Upper bound on the number of sampled rows.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    print("\n[EDA] columns:")
    print(df.columns.tolist())

    print("\n[EDA] dtypes:")
    print(df.dtypes)

    _write_eda_tables(df, out_dir)

    plots_written = False
    for column, png_name in (("ArrDelay", "hist_arrdelay.png"), ("DepDelay", "hist_depdelay.png")):
        if column in df.columns:
            _save_clipped_delay_hist(df[column], column, out_dir / png_name)
            plots_written = True

    carrier_col = next((c for c in _EDA_CARRIER_COLUMNS if c in df.columns), None)
    if carrier_col and "ArrDelay" in df.columns:
        if _save_carrier_boxplot(df, carrier_col, out_dir / "box_arrdelay_by_carrier.png"):
            plots_written = True

    # Report once, whichever plots were produced.
    if plots_written:
        print("[EDA] plots saved:", out_dir)

    sample_n = min(sample_n, len(df))
    try:
        df.sample(sample_n, random_state=42).to_parquet(out_dir / "sample_100k.parquet", index=False)
        print("[EDA] sample_100k.parquet saved")
    except Exception as e:
        # Best effort: a missing Parquet engine must not abort the EDA step.
        print("[WARN] save sample parquet failed:", e)

# ---------------------------
# Spark (optional)
# ---------------------------
def spark_demo(parquet_dir: Path):
    """Run a small Spark aggregation over the partitioned Parquet output.

    Reads ``<parquet_dir>/by_year_month``, prints the schema and row count,
    and — when a carrier column and ``ArrDelay`` are present — shows mean and
    approximate median arrival delay plus row count per carrier, largest
    carriers first.

    The demo is optional: it prints a warning and returns when pyspark cannot
    be imported or when the partitioned directory does not exist.

    Args:
        parquet_dir: Directory that contains the ``by_year_month`` tree.
    """
    try:
        from pyspark.sql import SparkSession
        from pyspark.sql import functions as F
    except Exception as e:
        print("[WARN] pyspark not available:", e)
        return

    base_dir = parquet_dir / "by_year_month"
    if not base_dir.is_dir():
        # Partition writing is best-effort upstream; bail out before starting
        # a Spark session that would only fail on a missing path.
        print("[WARN] partitioned parquet not found; skip Spark demo:", base_dir)
        return

    # getOrCreate() may hand back a session shared with the host notebook, so
    # it is deliberately not stopped here.
    spark = SparkSession.builder.appName("bts-week1-eda").getOrCreate()
    sdf = spark.read.parquet(base_dir.as_posix())
    sdf.printSchema()
    print("[SPARK] rows:", sdf.count())

    # "Reporting_Airline" is checked in addition to the two original names.
    # NOTE(review): confirm it is the carrier header in current BTS extracts.
    carrier_col = next(
        (c for c in ("Carrier", "UniqueCarrier", "Reporting_Airline") if c in sdf.columns),
        None,
    )
    if carrier_col and "ArrDelay" in sdf.columns:
        agg = (sdf.groupBy(carrier_col)
               .agg(F.mean("ArrDelay").alias("mean_arr_delay"),
                    F.expr("percentile_approx(ArrDelay, 0.5)").alias("median_arr_delay"),
                    F.count("*").alias("n"))
               .orderBy(F.col("n").desc()))
        agg.show(20, truncate=False)

# ---------------------------
# Main
# ---------------------------
def main(argv=None):
    """Run the whole pipeline: download → unzip → concat → Parquet → EDA.

    Args:
        argv: Command-line argument strings forwarded to ``parse_args``.
    """
    args = parse_args(argv)

    # Which (year, month) files to fetch.
    if args.auto_recent and args.auto_recent > 0:
        year_to_months = recent_year_month_pairs(args.auto_recent)
    else:
        chosen_years = args.years or [2024, 2025]
        chosen_months = args.months or [9, 10, 11, 12, 1, 2, 3, 4]
        # NOTE(review): every selected year receives every selected month, so
        # the defaults request 2 years x 8 months = 16 files rather than one
        # contiguous Sep-Apr window — confirm this is intended.
        year_to_months = {
            y: sorted(m for m in chosen_months if 1 <= m <= 12)
            for y in chosen_years
        }

    # Output layout under the root directory.
    out_root = Path(args.outdir)
    zip_dir, raw_dir, parquet_dir, eda_dir = (
        out_root / leaf for leaf in ("zip", "raw_csv", "parquet", "eda")
    )
    for folder in (out_root, zip_dir, raw_dir, parquet_dir, eda_dir):
        folder.mkdir(parents=True, exist_ok=True)

    # 1) Download, years ascending and months ascending within each year.
    targets = [(y, m) for y, ms in sorted(year_to_months.items()) for m in ms]
    fetched = []
    for position, (y, m) in enumerate(targets, start=1):
        zip_path = zip_dir / f"On_Time_Reporting_Carrier_On_Time_Performance_1987_present_{y}_{m}.zip"
        if download_one(y, m, zip_path):
            fetched.append(zip_path)
        print(f"[DL] progress: {position}/{len(targets)}")
    print(f"[DONE] downloaded: {len(fetched)} ZIP files")

    # 2) Unzip the archives that are available locally.
    unzip_all(fetched, raw_dir)

    # 3) Read & concat every CSV sitting directly in the raw directory
    #    (this includes files left there by earlier runs).
    csv_paths = sorted(raw_dir.glob("*.csv"))
    full = read_concat_csvs(csv_paths, per_file_nrows=args.limit_csv_rows_per_file)

    # 4) Write Parquet
    write_parquet(full, parquet_dir)

    # 5) EDA
    eda_basic(full, eda_dir, sample_n=args.sample_n)

    # 6) Spark (optional)
    if args.use_spark:
        spark_demo(parquet_dir)

    print("\n[ALL DONE] Week 1 complete: download → unzip → concat → Parquet → EDA")

if __name__ == "__main__":
    # Entry point for direct runs AND `%run script.py ...` in notebooks.
    # The interpreter's real argument list (minus the program name) is passed
    # through; arguments the parser does not define — e.g. the kernel
    # connection file flag a Jupyter kernel is launched with — are reported
    # and ignored by parse_args rather than aborting the run.
    import sys
    main(sys.argv[1:])

[INFO] Ignoring unknown arguments: ['--f=/Users/guohaoyang/Library/Jupyter/runtime/kernel-v32410e68be1e33ec94c578d6ef13aec1991e655e5.json']
[DL] https://transtats.bts.gov/PREZIP/On_Time_Reporting_Carrier_On_Time_Performance_1987_present_2024_1.zip
[DL] progress: 1/16
[DL] https://transtats.bts.gov/PREZIP/On_Time_Reporting_Carrier_On_Time_Performance_1987_present_2024_2.zip
[DL] progress: 2/16
[DL] https://transtats.bts.gov/PREZIP/On_Time_Reporting_Carrier_On_Time_Performance_1987_present_2024_3.zip
[DL] progress: 3/16
[DL] https://transtats.bts.gov/PREZIP/On_Time_Reporting_Carrier_On_Time_Performance_1987_present_2024_4.zip
[DL] progress: 4/16
[DL] https://transtats.bts.gov/PREZIP/On_Time_Reporting_Carrier_On_Time_Performance_1987_present_2024_9.zip
[DL] progress: 5/16
[DL] https://transtats.bts.gov/PREZIP/On_Time_Reporting_Carrier_On_Time_Performance_1987_present_2024_10.zip
[DL] progress: 6/16
[DL] https://transtats.bts.gov/PREZIP/On_Time_Reporting_Carrier_On_Time_Performance_1987_pr