# In [None]:
# t7_service_classifier_pandas_chrono_featimp.py
import os
from typing import Optional, Dict, Any, List

import numpy as np
import pandas as pd
import pyarrow.dataset as ds

import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    f1_score,
    accuracy_score,
)
from sklearn.inspection import permutation_importance

from utils.constants import (
    TASK2_OUT_ROOT,
    TASK7_FHV_SCHEMA,
    TASK7_FHVHV_SCHEMA,
    TASK7_GREENTAXI_SCHEMA,
    TASK7_YELLOWTAXI_SCHEMA,
    LATEX_ROOT,
    RESULTS_ROOT,
)

# -------------------- STYLE / IO --------------------
sns.set_theme(style="whitegrid")

# Figures are written next to the LaTeX sources; CSV tables under RESULTS_ROOT.
FIG_DIR = os.path.join(LATEX_ROOT, "figures")
for _out_dir in (FIG_DIR, RESULTS_ROOT):
    os.makedirs(_out_dir, exist_ok=True)
del _out_dir

# -------------------- CONFIG --------------------
# Canonical service names, in the order they are loaded and reported.
_ALL_SERVICES = ("Yellow", "Green", "FHV", "FHVHV")

# Task-2 output sub-directory per service ("Yellow" -> "yellow_tripdata", ...).
SERVICE_DIR = {svc: f"{svc.lower()}_tripdata" for svc in _ALL_SERVICES}

# Per-service column -> dtype schema applied right after reading.
SCHEMA_MAP: Dict[str, Dict[str, Any]] = dict(
    zip(
        _ALL_SERVICES,
        (
            TASK7_YELLOWTAXI_SCHEMA,
            TASK7_GREENTAXI_SCHEMA,
            TASK7_FHV_SCHEMA,
            TASK7_FHVHV_SCHEMA,
        ),
    )
)

# Union of columns we may want; read_service_month only requests those present.
DESIRED = [
    # timestamps
    "pickup_datetime", "dropoff_datetime",
    # geography
    "pickup_borough", "dropoff_borough",
    # trip size
    "trip_distance",
    # money: Yellow/Green use fare_amount/tip_amount, FHVHV base_passenger_fare/tips
    "fare_amount", "tip_amount",
    "base_passenger_fare", "tips",
]

# Two consecutive months give a chronological split: earlier = train, later = test.
MONTHS = [(2023, 3), (2023, 4)]
SERVICES = list(_ALL_SERVICES)
MAX_ROWS = 200_000           # per-service, per-month row cap; None (or 0) reads all

# Feature-importance settings
TOP_K = 25                   # number of features shown in the importance plots
PERMUTATION_SAMPLE = 20_000  # rows (class-stratified) for the permutation fallback
PERMUTATION_REPEATS = 5
RANDOM_STATE = 42

# -------------------- HELPERS --------------------
def _astype_with_schema(df: pd.DataFrame, schema: Dict[str, Any]) -> pd.DataFrame:
    """Cast columns present in df to types from schema; coerce datetimes."""
    for col, dt in schema.items():
        if col not in df.columns:
            continue
        try:
            df[col] = df[col].astype(dt)
        except Exception:
            if isinstance(dt, str) and dt.startswith("datetime64"):
                df[col] = pd.to_datetime(df[col], errors="coerce")
    return df

def read_service_month(service: str, year: int, month: int, max_rows: Optional[int] = None) -> pd.DataFrame:
    """Load one calendar month of trips for ``service`` from the Task-2 parquet output.

    Reads ``TASK2_OUT_ROOT/<SERVICE_DIR[service]>/year=<year>`` with pyarrow,
    requesting only the ``DESIRED`` columns that exist there. Rows are limited
    to ``[first day of month, first day of next month)`` on ``pickup_datetime``
    via the ``filters`` argument. Dtypes are then aligned with
    ``SCHEMA_MAP[service]``.

    Service-specific money columns are renamed to common ``fare`` / ``tip``
    columns. Yellow/Green use ``fare_amount``/``tip_amount`` and FHVHV uses
    ``base_passenger_fare``/``tips``. FHV, and any missing source column, gets
    NaN. A ``service`` label column is added.

    Args:
        service: One of the keys of ``SERVICE_DIR``.
        year, month: Calendar month to load.
        max_rows: If truthy and more rows were read, keep a random sample of
            this many rows (seeded with ``RANDOM_STATE``).

    Returns:
        A frame with the subset of ``pickup_datetime``, ``dropoff_datetime``,
        ``pickup_borough``, ``dropoff_borough``, ``trip_distance``, ``fare``,
        ``tip`` and ``service`` that is present, with a fresh RangeIndex. If the
        year directory does not exist, a warning is printed and an empty frame
        is returned.
    """
    year_dir = os.path.join(TASK2_OUT_ROOT, SERVICE_DIR[service], f"year={year}")
    if not os.path.isdir(year_dir):
        print(f"[WARN] missing: {year_dir}")
        return pd.DataFrame()

    # Request only the wanted columns that this service's files actually contain.
    available = set(ds.dataset(year_dir, format="parquet").schema.names)
    columns = [name for name in DESIRED if name in available]

    month_start = pd.Timestamp(year=year, month=month, day=1)
    next_month = month_start + pd.offsets.MonthBegin(1)
    month_filter = [
        ("pickup_datetime", ">=", month_start),
        ("pickup_datetime", "<", next_month),
    ]
    frame = pd.read_parquet(
        year_dir,
        engine="pyarrow",
        columns=columns,
        filters=month_filter,
    )
    frame = _astype_with_schema(frame, SCHEMA_MAP[service])

    # (fare source, tip source) per service; services not listed get NaN for both.
    money_sources = {
        "Yellow": ("fare_amount", "tip_amount"),
        "Green": ("fare_amount", "tip_amount"),
        "FHVHV": ("base_passenger_fare", "tips"),
    }
    fare_src, tip_src = money_sources.get(service, (None, None))
    for target, source in (("fare", fare_src), ("tip", tip_src)):
        frame[target] = frame[source] if source in frame.columns else np.nan

    frame["service"] = service

    if max_rows and len(frame) > max_rows:
        frame = frame.sample(n=max_rows, random_state=RANDOM_STATE)

    wanted = (
        "pickup_datetime", "dropoff_datetime",
        "pickup_borough", "dropoff_borough",
        "trip_distance", "fare", "tip", "service",
    )
    present = [name for name in wanted if name in frame.columns]
    return frame[present].reset_index(drop=True)

def add_features(dfx: pd.DataFrame) -> pd.DataFrame:
    """Add engineered trip features to ``dfx`` in place and return the same frame.

    Required columns are ``pickup_datetime`` and ``dropoff_datetime``
    (datetime64). The optional columns ``trip_distance``, ``pickup_borough``,
    ``dropoff_borough``, ``fare`` and ``tip`` are treated as all-missing when
    absent instead of raising.

    Added / overwritten columns:
      * ``duration_min``: dropoff - pickup in minutes, clipped to [0.1, 360].
      * ``speed_mph``: ``trip_distance`` per hour of duration, clipped to
        [0, 120]; non-finite values become NaN. This assumes trip_distance is
        in miles, as the name suggests.
      * ``hour``, ``dow``: int16 pickup hour and day of week (Monday = 0).
      * ``pickup_borough``, ``dropoff_borough``: cast to pandas ``string``.
      * ``od_pair``: ``"<pickup>→<dropoff>"``, with missing boroughs as ``"NA"``.
      * ``fare``, ``tip``: float32. ``has_fare`` / ``has_tip`` are int8 flags.
        ``tip_rate`` (tip / fare) and ``log_fare`` (log1p(fare)) are NaN
        unless fare > 0.

    NOTE: NaT in ``pickup_datetime`` makes the int16 casts of ``hour``/``dow``
    raise, so such rows must be dropped beforehand.
    """

    def _column(name: str, fill: Any, dtype: Optional[str] = None) -> pd.Series:
        # dfx[name] if present, else an all-`fill` Series aligned to dfx's index.
        if name in dfx.columns:
            return dfx[name]
        return pd.Series(fill, index=dfx.index, dtype=dtype)

    # duration & speed
    dur_min = (dfx["dropoff_datetime"] - dfx["pickup_datetime"]).dt.total_seconds() / 60.0
    dur_min = dur_min.clip(lower=0.1, upper=6 * 60)  # guard against <=0 and runaway durations
    dfx["duration_min"] = dur_min
    dur_hr = dur_min / 60.0
    with np.errstate(divide="ignore", invalid="ignore"):
        speed = _column("trip_distance", np.nan) / dur_hr
    dfx["speed_mph"] = np.where(np.isfinite(speed), speed, np.nan).clip(0, 120)

    # calendar
    dfx["hour"] = dfx["pickup_datetime"].dt.hour.astype("int16")
    dfx["dow"] = dfx["pickup_datetime"].dt.dayofweek.astype("int16")

    # categories
    dfx["pickup_borough"] = _column("pickup_borough", pd.NA, "string").astype("string")
    dfx["dropoff_borough"] = _column("dropoff_borough", pd.NA, "string").astype("string")
    dfx["od_pair"] = (
        dfx["pickup_borough"].fillna("NA") + "→" + dfx["dropoff_borough"].fillna("NA")
    ).astype("string")

    # money-derived (absent fare/tip -> NaN rather than an AttributeError on float.astype)
    dfx["fare"] = _column("fare", np.nan).astype("float32")
    dfx["tip"] = _column("tip", np.nan).astype("float32")
    dfx["has_fare"] = dfx["fare"].notna().astype("int8")
    dfx["has_tip"] = (dfx["tip"].notna() & (dfx["tip"] > 0)).astype("int8")
    with np.errstate(divide="ignore", invalid="ignore"):
        dfx["tip_rate"] = np.where(dfx["fare"] > 0, dfx["tip"] / dfx["fare"], np.nan).astype("float32")
        dfx["log_fare"] = np.where(dfx["fare"] > 0, np.log1p(dfx["fare"]), np.nan).astype("float32")
    return dfx

def OneHotDense(**kwargs):
    """Build a ``OneHotEncoder`` that emits a dense array on any scikit-learn version.

    scikit-learn >= 1.2 spells the flag ``sparse_output``; older releases only
    accept ``sparse``. Each spelling is tried in that order. An unknown keyword
    makes the constructor raise ``TypeError``, which triggers the next attempt.
    The error from the last attempt propagates. All ``kwargs`` are forwarded
    unchanged.
    """
    dense_flags = ("sparse_output", "sparse")  # newest name first
    for flag in dense_flags:
        try:
            return OneHotEncoder(**{flag: False}, **kwargs)
        except TypeError:
            if flag == dense_flags[-1]:
                raise

def pretty_names(names: np.ndarray) -> List[str]:
    """Shorten ColumnTransformer output names for plots and tables.

    Removes the ``num__`` / ``cat__`` transformer prefixes and ``od_pair_``.
    Rewrites ``pickup_borough_`` as ``PU:`` and ``dropoff_borough_`` as
    ``DO:``. Each rewrite is a plain substring replacement, applied in that
    order. Every element is converted with ``str`` first.
    """
    rewrites = (
        ("num__", ""),
        ("cat__", ""),
        ("pickup_borough_", "PU:"),
        ("dropoff_borough_", "DO:"),
        ("od_pair_", ""),
    )

    def _prettify(raw) -> str:
        label = str(raw)
        for old, new in rewrites:
            label = label.replace(old, new)
        return label

    return [_prettify(raw) for raw in names]

# -------------------- LOAD TWO MONTHS --------------------
# One frame per (month, service) with data, each tagged with its "YYYY-MM" month.
frames = []
for year, month in MONTHS:
    tag = f"{year}-{month:02d}"
    for service in SERVICES:
        chunk = read_service_month(service, year, month, max_rows=MAX_ROWS)
        print(f"{service} {tag}: {len(chunk):,} rows")
        if chunk.empty:
            continue
        chunk["yy_mm"] = tag
        frames.append(chunk)

if not frames:
    raise SystemExit("No data loaded. Check paths/months.")

df = (
    pd.concat(frames, ignore_index=True)
    .sort_values("pickup_datetime")
    .reset_index(drop=True)
)

# -------------------- SPLIT, FEATURES --------------------
# Chronological split at the first day of the last configured month:
# earlier pickups are training data, that month onward is test data.
# Rows with a NaT pickup time fail both comparisons and are dropped here.
last_year, last_month = MONTHS[-1]
split_ts = pd.Timestamp(year=last_year, month=last_month, day=1)
train = df[df["pickup_datetime"] < split_ts].copy()
test  = df[df["pickup_datetime"] >= split_ts].copy()

print(f"\nSplit at {split_ts}")
print(f"Train rows: {len(train):,} | Test rows: {len(test):,}")
print("\nClass counts (train):\n", train["service"].value_counts())
print("\nClass counts (test):\n",  test["service"].value_counts())

# Fail fast with a clear message instead of an opaque sklearn error on an empty fit/predict.
if train.empty or test.empty:
    raise SystemExit(
        f"Chronological split at {split_ts} left an empty side "
        f"(train={len(train):,}, test={len(test):,}). Check MONTHS and the input data."
    )

train = add_features(train)
test  = add_features(test)

# Target: the service label of each trip.
y_train = train["service"]
y_test  = test["service"]

# -------------------- PREPROCESSOR --------------------
# Use only behavior-level features (no monetary fields)
num_cols = ["trip_distance", "duration_min", "speed_mph", "hour", "dow"]
cat_cols = ["pickup_borough", "dropoff_borough", "od_pair"]

# Numeric features: median imputation, then standardization.
# Categorical features: dense one-hot encoding through OneHotDense, so this works
# with either the `sparse_output` or the legacy `sparse` flag. Dense output is
# needed because HistGradientBoosting does not accept sparse input. Categories
# seen fewer than 100 times in training are grouped as infrequent. Categories
# unseen at fit time encode as all zeros.
# verbose_feature_names_out=True yields "num__"/"cat__" prefixes (stripped by pretty_names).
preprocess = ColumnTransformer(
    transformers=[
        ("num", Pipeline([
            ("imputer", SimpleImputer(strategy="median")),
            ("scaler", StandardScaler()),
        ]), [c for c in num_cols if c in train.columns]),
        ("cat", OneHotDense(handle_unknown="ignore", min_frequency=100),
         [c for c in cat_cols if c in train.columns]),
    ],
    remainder="drop",
    verbose_feature_names_out=True,
)

# -------------------- MODELS --------------------
# Insertion order = training / reporting order.
models = dict(
    # Linear baseline; class_weight="balanced" weights classes inversely to their frequency.
    LogReg_balanced=LogisticRegression(
        class_weight="balanced",
        max_iter=300,
        random_state=RANDOM_STATE,
    ),
    # Gradient-boosted trees; validation_fraction of the training data drives early stopping.
    HGB=HistGradientBoostingClassifier(
        learning_rate=0.1,
        max_iter=300,
        early_stopping=True,
        validation_fraction=0.1,
        random_state=RANDOM_STATE,
    ),
)

# Sorted union of classes in train and test: fixed order for confusion matrices.
labels = sorted(set(y_train).union(y_test))
results = []

# -------------------- TRAIN / EVAL / FEATURE IMPORTANCE --------------------
def _save_current_figure(stem: str) -> None:
    """Lay out the active figure, save it as FIG_DIR/<stem>.pdf and .png (160 dpi), then close it."""
    plt.tight_layout()
    plt.savefig(os.path.join(FIG_DIR, f"{stem}.pdf"))
    plt.savefig(os.path.join(FIG_DIR, f"{stem}.png"), dpi=160)
    plt.close()


def _stratified_sample(frame: pd.DataFrame, label_col: str, per_class: int) -> pd.DataFrame:
    """Return up to ``per_class`` randomly chosen rows of ``frame`` per value of ``label_col``.

    Shuffles once (seeded with RANDOM_STATE) and keeps the first ``per_class``
    rows of each group. Unlike ``groupby(...).apply(lambda d: d.sample(...))``,
    this keeps ``label_col`` as a regular column on every pandas version.
    Operating on grouping columns in ``apply`` is deprecated since pandas 2.2.
    """
    shuffled = frame.sample(frac=1.0, random_state=RANDOM_STATE)
    return shuffled.groupby(label_col).head(per_class)


def _feature_importances(name: str, pipe: Pipeline, eval_df: pd.DataFrame) -> Optional[np.ndarray]:
    """Return one importance score per transformed feature of a fitted pipeline, or None.

    * ``LogReg*``: L2 norm of ``coef_`` across classes.
    * ``HGB``: ``feature_importances_`` if the estimator exposes it. Otherwise
      (HistGradientBoosting typically does not), the mean permutation importance
      (accuracy drop) on a class-stratified sample of ``eval_df``. It is computed
      on the preprocessor's output, so scores line up one-to-one with
      ``prep.get_feature_names_out()``.

    Reads the module-level ``labels``, ``PERMUTATION_SAMPLE``,
    ``PERMUTATION_REPEATS`` and ``RANDOM_STATE``. ``eval_df`` must contain a
    ``service`` column.
    """
    estimator = pipe.named_steps["clf"]

    if name.startswith("LogReg"):
        # coef_ has one row per class (a single row for binary problems).
        return np.linalg.norm(estimator.coef_, axis=0)

    if name == "HGB":
        if hasattr(estimator, "feature_importances_"):
            return estimator.feature_importances_
        prep = pipe.named_steps["prep"]
        cols_in = getattr(prep, "feature_names_in_", None)
        if cols_in is None:
            cols_in = eval_df.columns  # last resort
        per_class = max(1, PERMUTATION_SAMPLE // max(1, len(labels)))
        sample_df = _stratified_sample(eval_df, "service", per_class)
        perm = permutation_importance(
            estimator,
            prep.transform(sample_df.loc[:, cols_in]),
            sample_df["service"].to_numpy(),
            n_repeats=PERMUTATION_REPEATS,
            random_state=RANDOM_STATE,
            scoring="accuracy",
        )
        return perm.importances_mean

    return None


for name, clf in models.items():
    pipe = Pipeline([("prep", preprocess), ("clf", clf)])
    pipe.fit(train, y_train)
    pred = pipe.predict(test)

    acc = accuracy_score(y_test, pred)
    f1m = f1_score(y_test, pred, average="macro")

    print(f"\n=== {name} (chronological split) ===")
    print(f"Accuracy: {acc:.3f} | Macro-F1: {f1m:.3f}")
    print(classification_report(y_test, pred, digits=3))

    # Confusion matrix: rows = true class, columns = predicted class, in `labels` order.
    cm = confusion_matrix(y_test, pred, labels=labels)
    cm_df = pd.DataFrame(
        cm,
        index=[f"true_{lab}" for lab in labels],
        columns=[f"pred_{lab}" for lab in labels],
    )
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm_df, annot=True, fmt="d", cmap="Blues")
    plt.title(f"Service classifier — {name} (chronological split)")
    _save_current_figure(f"t7_cm_{name}_chrono")

    # Importances are per *transformed* feature (each one-hot column separately).
    feat_names = np.array(pretty_names(pipe.named_steps["prep"].get_feature_names_out()))
    importances = _feature_importances(name, pipe, test)

    if importances is not None:
        imp_df = (
            pd.DataFrame({"feature": feat_names, "importance": importances})
            .sort_values("importance", ascending=False)
            .head(TOP_K)
        )
        imp_df.to_csv(
            os.path.join(RESULTS_ROOT, f"t7_feature_importance_{name}.csv"), index=False
        )

        plt.figure(figsize=(8, max(4, 0.35 * len(imp_df))))
        # Seaborn draws the first row of a horizontal categorical barplot at the top,
        # so the descending frame puts the most important feature first.
        sns.barplot(data=imp_df, x="importance", y="feature", orient="h")
        plt.xlabel("Importance")
        plt.ylabel("")
        # len(imp_df) < TOP_K when there are fewer transformed features.
        plt.title(f"Top-{len(imp_df)} feature importances — {name}")
        _save_current_figure(f"t7_featimp_{name}")

    results.append({"model": name, "accuracy": acc, "macro_f1": f1m})

# Save summary: one row per model, best macro-F1 first.
res_df = pd.DataFrame(results).sort_values("macro_f1", ascending=False)
summary_csv = os.path.join(RESULTS_ROOT, "t7_service_classifier_chrono_results.csv")
res_df.to_csv(summary_csv, index=False)

print("\nSaved results, confusion matrices, and feature-importance plots.")
for _tag, _where in (("[OK] Figures →", FIG_DIR), ("[OK] CSVs    →", RESULTS_ROOT)):
    print(_tag, _where)
