In [1]:
# scripts/tri_compare_predictions.py
from __future__ import annotations
import os, json, time, argparse, threading
from pathlib import Path
from typing import Iterable, Tuple
import sys
import numpy as np
import pandas as pd
import requests
import joblib

# Make the repository root importable so the package imports below resolve.
try:
    # Script execution: `__file__` exists; go up from the file's directory to its parent.
    ROOT = Path(__file__).resolve().parents[1]
except NameError:
    # Jupyter kernel: `__file__` is undefined, so use the parent of the working directory.
    ROOT = Path().resolve().parents[0]

# Prepend once; repeated execution of this cell must not pile up duplicates.
_root_entry = str(ROOT)
if sys.path.count(_root_entry) == 0:
    sys.path.insert(0, _root_entry)

# Import from your package (assumes `pip install -e .`)
from traffic_flow.service.app import create_app
from traffic_flow.service.runtime import InferenceRuntime
from traffic_flow.pipeline.data_pipeline_orchestrator import TrafficDataPipelineOrchestrator
from traffic_flow.inference.prediction_protocol import make_prediction_frame
from traffic_flow.evaluation.model_comparison import ModelEvaluator

In [2]:
# ---------------- Utilities ----------------

# Weather columns that get_raw_test() removes from the raw frame (when present)
# before the rows are sent to any of the prediction paths.
BAD_WEATHER_COLS = [
    "Snow_depth_surface",
    "Water_equivalent_of_accumulated_snow_depth_surface",
]

class ServerThread(threading.Thread):
    """Run Flask in-process so you don't need a second terminal.

    The WSGI server is created (and its port bound) in ``__init__``; calling
    ``start()`` runs ``serve_forever()`` on this daemon thread, and
    ``shutdown()`` stops the loop and releases the listening socket.

    Args:
        artifact_path: Forwarded to ``create_app(artifact_path=...)``.
        host: Interface the server binds to.
        port: TCP port the server binds to.
    """

    def __init__(self, artifact_path: str, host="127.0.0.1", port=8080):
        super().__init__(daemon=True)
        self.host = host
        self.port = port
        self.app = create_app(artifact_path=artifact_path)
        from werkzeug.serving import make_server
        self.srv = make_server(host, port, self.app)
        # The app context is pushed on the constructing thread and popped again in shutdown().
        self.ctx = self.app.app_context()
        self.ctx.push()

    def run(self):
        """Thread body: serve requests until shutdown() is called."""
        self.srv.serve_forever()

    def shutdown(self):
        """Stop serving, wait for the thread to finish and free the port."""
        # socketserver's shutdown() blocks until serve_forever() exits, so it
        # would deadlock if the thread was never started or has already ended.
        if self.is_alive():
            self.srv.shutdown()
            self.join(timeout=5)
        # Close the listening socket; without this the port stays bound and
        # re-creating the server in the same kernel fails with "address in use".
        self.srv.server_close()
        self.ctx.pop()

def iter_batches(df: pd.DataFrame, batch_size: int) -> Iterable[pd.DataFrame]:
    """Lazily yield consecutive positional slices of ``df`` with at most ``batch_size`` rows.

    Every yielded chunk is an independent copy, so callers may modify it
    without touching ``df``. The last chunk may be shorter than ``batch_size``.
    """
    yield from (
        df.iloc[start:start + batch_size].copy()
        for start in range(0, len(df), batch_size)
    )

def to_json_records_strict(df: pd.DataFrame) -> list[dict]:
    """Convert ``df`` to a list of row dicts that is safe to send as strict JSON.

    The frame is serialised by pandas (record orientation, ISO dates), which
    writes missing values as ``null``, and the text is then parsed back into
    plain Python objects.
    """
    serialised = df.to_json(orient="records", date_format="iso")
    return json.loads(serialised)


In [3]:
def make_batches_list(data, batch_size):
    """Eagerly slice ``data`` into consecutive chunks of at most ``batch_size`` items.

    Prints the nominal slice bounds of every chunk as it is produced; the upper
    bound printed for the final chunk may exceed ``len(data)``.
    """
    chunks = []
    for start in range(0, len(data), batch_size):
        stop = start + batch_size
        chunks.append(data[start:stop])
        print(f"going from i = {start} to i+batch_size = {stop}")
    return chunks

In [4]:
# ---------------- Build test RAW rows with same cleaning params ----------------

def load_artifact(artifact: str | Path) -> dict:
    """Load a joblib artifact and expose the parts the comparison needs.

    Returns a dict with:
      bundle  - the object returned by ``joblib.load`` (indexed like a mapping)
      states  - ``bundle["states"]``
      horizon - ``bundle["horizon"]`` as int, 15 when the key is absent

    NOTE: joblib files are pickle-based; only load artifacts you trust.
    """
    bundle = joblib.load(artifact)
    states = bundle["states"]
    horizon = int(bundle.get("horizon", 15))
    return {"bundle": bundle, "states": states, "horizon": horizon}

def make_orchestrator_from_states(raw_path: str | Path, states: dict) -> TrafficDataPipelineOrchestrator:
    """Build an orchestrator over ``raw_path`` and run its base-feature step.

    The smoothing window, relative threshold and median/mean switch are taken
    from ``states["clean_state"]``; extreme-change filtering and speed
    smoothing are always switched on, and sensor encoding is fixed to "mean".
    """
    clean_state = states["clean_state"]
    orchestrator = TrafficDataPipelineOrchestrator(
        file_path=str(raw_path),
        sensor_encoding_type="mean",
    )
    base_feature_kwargs = dict(
        window_size=clean_state["smoothing_window"],
        filter_extreme_changes=True,
        smooth_speeds=True,
        relative_threshold=clean_state["relative_threshold"],
        use_median_instead_of_mean_smoothing=clean_state["use_median"],
    )
    orchestrator.prepare_base_features(**base_feature_kwargs)
    return orchestrator

def get_raw_test(raw_path: str | Path, states: dict) -> Tuple[pd.DataFrame, TrafficDataPipelineOrchestrator]:
    """Return the raw test-period rows together with the orchestrator that defines the split.

    Args:
        raw_path: Parquet file with the raw rows (needs ``date`` and ``sensor_id`` columns).
        states: Artifact states; forwarded to ``make_orchestrator_from_states``.

    Returns:
        ``(test, tdp)`` where ``test`` holds the rows of the parquet file with
        ``date >= tdp.first_test_timestamp``, without the ``BAD_WEATHER_COLS``
        columns, stably sorted by ``date`` then ``sensor_id``; ``tdp`` is the
        orchestrator built from ``states``.
    """
    tdp = make_orchestrator_from_states(raw_path, states)
    raw = pd.read_parquet(raw_path)
    raw = raw.drop(columns=[c for c in BAD_WEATHER_COLS if c in raw.columns])  # drop problematic cols
    test = (
        raw.loc[raw["date"] >= tdp.first_test_timestamp]
        .sort_values(["date", "sensor_id"], kind="mergesort")
    )
    return test, tdp

In [5]:
# ---------------- Three ways to predict ----------------

def api_predict_canonical(base_url: str, raw_test: pd.DataFrame, batch_size=20000, timeout=300) -> pd.DataFrame:
    """Send ``raw_test`` to the HTTP API in batches and collect the predictions.

    Each batch is POSTed to ``{base_url}/predict`` as ``{"records": [...]}``;
    the ``predictions`` list of every response is turned into a frame and the
    frames are concatenated.

    Args:
        base_url: Server root, without a trailing ``/predict``.
        raw_test: Raw rows to predict on; must contain a ``date`` column.
        batch_size: Maximum number of rows per request.
        timeout: Per-request timeout in seconds, passed to ``requests.post``.

    Returns:
        Predictions with ``prediction_time`` parsed to datetime, stably sorted
        by ``prediction_time`` then ``sensor_id``. When nothing was predicted
        the result is an empty frame that still has the ``prediction_time``,
        ``sensor_id`` and ``y_pred_total`` columns.

    Raises:
        requests.HTTPError: When a batch is answered with an error status.
    """
    outs = []
    for i, batch in enumerate(iter_batches(raw_test, batch_size), start=1):
        # Format datetimes as plain strings; assign() builds a new frame, so the
        # batch itself is not modified and no extra copy is needed.
        payload_df = batch.assign(date=pd.to_datetime(batch["date"]).dt.strftime("%Y-%m-%d %H:%M:%S"))
        # strict json (nulls etc.)
        records = to_json_records_strict(payload_df)
        print(f"Sending {len(records)} records")
        print(f"records[0] {records[0]}")  # Inspect 1st sample
        r = requests.post(f"{base_url}/predict", json={"records": records}, timeout=timeout)
        print(f"r.status_code : {r.status_code}")
        if not r.ok:
            # Show the (truncated) body only for failures; dumping every
            # successful response would print the full prediction payload.
            print(f"r.text {r.text[:2000]}")
        r.raise_for_status()
        part = pd.DataFrame(r.json()["predictions"])
        outs.append(part)
        print(f"[API] batch {i} -> {len(part)} preds")
    out = pd.concat(outs, ignore_index=True) if outs else pd.DataFrame()
    # With no batches (or only empty answers) the frame has no columns at all;
    # add the expected ones so the normalisation below does not raise KeyError.
    missing = [c for c in ("prediction_time", "sensor_id", "y_pred_total") if c not in out.columns]
    if out.empty and missing:
        out = out.reindex(columns=[*out.columns, *missing])
    # Normalize dtypes
    out["prediction_time"] = pd.to_datetime(out["prediction_time"])
    out = out.sort_values(["prediction_time", "sensor_id"], kind="mergesort").reset_index(drop=True)
    return out

def local_predict_canonical(artifact_path: str, raw_test: pd.DataFrame, batch_size=20000) -> pd.DataFrame:
    """Predict on ``raw_test`` with an in-process ``InferenceRuntime``, batch by batch.

    Args:
        artifact_path: Passed to ``InferenceRuntime``.
        raw_test: Raw rows to predict on.
        batch_size: Maximum number of rows handed to ``predict_df`` at once.

    Returns:
        The concatenated first element of every ``predict_df`` result, with
        ``prediction_time`` parsed to datetime and stably sorted by
        ``prediction_time`` then ``sensor_id``. When nothing was predicted the
        result is an empty frame that still has the ``prediction_time``,
        ``sensor_id`` and ``y_pred_total`` columns.
    """
    rt = InferenceRuntime(artifact_path)
    outs = []
    for i, batch in enumerate(iter_batches(raw_test, batch_size), start=1):
        pred_df, _ = rt.predict_df(batch)  # already canonical
        outs.append(pred_df)
        print(f"[LOCAL] batch {i} -> {len(pred_df)} preds")
    out = pd.concat(outs, ignore_index=True) if outs else pd.DataFrame()
    # With no batches the frame has no columns at all; add the expected ones so
    # the normalisation below does not raise KeyError.
    missing = [c for c in ("prediction_time", "sensor_id", "y_pred_total") if c not in out.columns]
    if out.empty and missing:
        out = out.reindex(columns=[*out.columns, *missing])
    out["prediction_time"] = pd.to_datetime(out["prediction_time"])
    out = out.sort_values(["prediction_time", "sensor_id"], kind="mergesort").reset_index(drop=True)
    return out

def evaluator_offline_canonical(artifact_path: str, raw_path: str | Path) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Rebuilds training-like X_test/df using the artifact's states, then uses
    ModelEvaluator to produce canonical predictions for test rows.

    Args:
        artifact_path: Joblib artifact; read through ``load_artifact``.
        raw_path: Raw parquet file the orchestrator is built on.

    Returns (pred_df, truth_lookup) where truth_lookup has:
      sensor_id, prediction_time, y_act
    taken from the orchestrator frame's test rows (``date_of_prediction`` and
    ``target_total_speed``), stably sorted by prediction_time then sensor_id.
    """
    art = load_artifact(artifact_path)
    states, horizon = art["states"], art["horizon"]

    # 1) Rebuild the split & X/y using the same params; keep the full frame
    tdp = make_orchestrator_from_states(raw_path, states)
    # NOTE(review): an earlier comment here said "keep 'date' visible" although
    # drop_datetime=True is passed - confirm which behaviour is intended.
    tdp.finalise_for_horizon(horizon=horizon, drop_datetime=True)
    df_all = tdp.df

    # 2) Evaluator + canonical preds (test only)
    me = ModelEvaluator(
        X_test=tdp.X_test,
        df_for_ML=df_all,       # evaluator will internally take test_set
        y_train=tdp.y_train,
        y_test=tdp.y_test,
        target_is_gman_error_prediction=False,
        y_is_normalized=False,
        rounding=6,
    )
    model = art["bundle"]["model"]
    pred_df = me.to_canonical_predictions(model=model, states=states, horizon_min=horizon)

    # 3) Ground truth at prediction_time (training target), built without in-place mutation
    truth = (
        df_all.loc[df_all["test_set"], ["sensor_id", "date_of_prediction", "target_total_speed"]]
        .rename(columns={"date_of_prediction": "prediction_time", "target_total_speed": "y_act"})
        .assign(prediction_time=lambda d: pd.to_datetime(d["prediction_time"]))
        .sort_values(["prediction_time", "sensor_id"], kind="mergesort")
        .reset_index(drop=True)
    )
    return pred_df, truth

In [None]:
def get_lag_steps(states: dict, default: int = 25) -> int:
    """Return ``states["lag_state"]["lags"]`` as an int, or ``default`` when either key is missing."""
    # Adjust the key path to your states layout if necessary
    lag_state = states.get("lag_state", {})
    lags = lag_state.get("lags", default)
    return int(lags)

In [6]:
# ---------------- Compare helpers ----------------

def attach_y_act(pred_df: pd.DataFrame, truth: pd.DataFrame) -> pd.DataFrame:
    """Left-join ``truth`` onto the predictions and return a four-column comparison frame.

    The join keys are ``sensor_id`` and ``prediction_time`` (the latter parsed
    to datetime on the prediction side). The result has the columns
    ``prediction_time, sensor_id, y_act, y_pred``, where ``y_pred`` is the
    renamed ``y_pred_total``; ``pred_df`` itself is left untouched.
    """
    keyed = pred_df.assign(prediction_time=pd.to_datetime(pred_df["prediction_time"]))
    joined = keyed.merge(truth, how="left", on=["sensor_id", "prediction_time"])
    wanted = ["prediction_time", "sensor_id", "y_act", "y_pred_total"]
    return joined[wanted].rename(columns={"y_pred_total": "y_pred"})

def check_equal(a: pd.DataFrame, b: pd.DataFrame, tol=1e-8) -> Tuple[bool,float]:
    """Compare ``a["y_pred_a"]`` with ``b["y_pred_b"]`` on their shared keys.

    Rows are matched with an inner join on ``sensor_id`` and
    ``prediction_time`` (parsed to datetime on both sides), so rows present in
    only one frame are ignored. Returns ``(within_tol, max_abs_diff)``; when no
    key is shared the result is ``(False, inf)``.
    """
    join_keys = ["sensor_id", "prediction_time"]
    left = a.assign(prediction_time=pd.to_datetime(a["prediction_time"]))
    right = b.assign(prediction_time=pd.to_datetime(b["prediction_time"]))
    merged = left.merge(right, how="inner", on=join_keys, suffixes=("_a", "_b"))
    if merged.empty:
        return False, float("inf")
    deltas = (merged["y_pred_a"] - merged["y_pred_b"]).to_numpy()
    worst = np.max(np.abs(deltas))
    return bool(worst <= tol), float(worst)

In [7]:
from pathlib import Path
import os
import time
import requests

# Paths (adapt as needed); both are relative to the notebook's working directory.
# Joblib artifact used by the local runtime, the offline evaluator and the in-process server.
artifact = Path("../../artifacts/traffic_pipeline_h-15.joblib")
# Raw parquet file from which the test rows are taken.
raw_path = Path("../../data/NDW/ndw_three_weeks.parquet")
# Base URL of the prediction API; its port is also used when the server is started in-process.
url = "http://127.0.0.1:8080"
# True: start the API inside this kernel (ServerThread) instead of relying on an external server.
start_server = True
# Rows per batch for both the local and the API prediction path.
batch_size = 20000
# Maximum absolute prediction difference accepted by the parity checks.
tolerance = 1e-6
# True: write the three comparison frames to CSV files under outputs/.
save_outputs = False

In [None]:
# Optionally start the API inside this kernel and wait until /healthz answers OK.
server = None
if start_server:
    os.environ["ARTIFACT_PATH"] = str(artifact)
    host, port = "127.0.0.1", int(url.rsplit(":", 1)[-1])
    server = ServerThread(artifact_path=str(artifact), host=host, port=port)
    server.start()

    # Poll for up to 60 attempts, pausing 0.25 s after every unsuccessful try.
    for _ in range(60):
        try:
            if requests.get(f"{url}/healthz", timeout=1).ok:
                print("Server is up!")
                break
        except requests.RequestException:
            # Not reachable yet; retry after the pause below.
            pass
        # Sleep after a non-OK answer as well, not only after a connection
        # error; otherwise all attempts are burned in a tight loop.
        time.sleep(0.25)
    else:
        # Do not leave a half-started server thread behind.
        server.shutdown()
        server = None
        raise RuntimeError("API failed to come up.")

In [None]:
# Load the artifact once and reuse it (it was previously unpickled twice).
artifact_load = load_artifact(artifact)
states = artifact_load["states"]

# 1) Prepare RAW
raw_test, tdp = get_raw_test(raw_path, states)
# Unprocessed test-period rows, read from the configured path rather than a second hard-coded one.
raw = pd.read_parquet(raw_path)
raw = raw.loc[raw["date"] >= tdp.first_test_timestamp]

# 2) Evaluator (offline)
eval_pred, truth = evaluator_offline_canonical(artifact, raw_path)
df_eval = attach_y_act(eval_pred, truth)

# 3) Local runtime
df_local_pred = local_predict_canonical(str(artifact), raw_test, batch_size=batch_size)
df_local = attach_y_act(df_local_pred, truth)

# 4) API
df_api_pred = api_predict_canonical(url, raw_test, batch_size=batch_size)
df_api = attach_y_act(df_api_pred, truth)

In [None]:
# Inspect the lag setting stored in the artifact's states.
states['lag_state']['lags']

In [None]:
# Display the ground-truth lookup returned by evaluator_offline_canonical
# (columns: sensor_id, prediction_time, y_act).
truth

In [None]:
# Keys of the dict built by load_artifact; `artifact_load` from the run cell is
# reused instead of unpickling the artifact from disk once more.
artifact_load.keys()

In [None]:
# Top-level keys of the object that joblib loaded from the artifact file.
artifact_load['bundle'].keys()

In [None]:
# compare outputs
# check_equal expects the left frame to carry y_pred_a and the right one y_pred_b,
# so each frame is renamed once per role and then reused.
api_as_a = df_api.rename(columns={"y_pred":"y_pred_a"})
local_as_a = df_local.rename(columns={"y_pred":"y_pred_a"})
local_as_b = df_local.rename(columns={"y_pred":"y_pred_b"})
eval_as_b = df_eval.rename(columns={"y_pred":"y_pred_b"})

ok_al, maxdiff_al = check_equal(api_as_a, local_as_b, tol=tolerance)
ok_ae, maxdiff_ae = check_equal(api_as_a, eval_as_b, tol=tolerance)
ok_le, maxdiff_le = check_equal(local_as_a, eval_as_b, tol=tolerance)

print("\n=== Parity checks (predictions only) ===")
print(f"API vs LOCAL:  equal within tol={tolerance}? {ok_al}  (max|Δ|={maxdiff_al:.3g})")
print(f"API vs EVAL :  equal within tol={tolerance}? {ok_ae}  (max|Δ|={maxdiff_ae:.3g})")
print(f"LOCAL vs EVAL: equal within tol={tolerance}? {ok_le}  (max|Δ|={maxdiff_le:.3g})")

In [None]:
# view/save results
# Show the first rows of every comparison frame.
print("API head:\n", df_api.head())
print("LOCAL head:\n", df_local.head())
print("EVAL head:\n", df_eval.head())

# Optionally persist the three frames as CSV (without the index).
if save_outputs:
    Path("outputs").mkdir(exist_ok=True)
    csv_targets = {
        "outputs/api_with_y_act.csv": df_api,
        "outputs/local_with_y_act.csv": df_local,
        "outputs/eval_with_y_act.csv": df_eval,
    }
    for csv_path, frame in csv_targets.items():
        frame.to_csv(csv_path, index=False)
    print("Saved CSVs to outputs/")

In [None]:
# shutdown server
if server is not None:
    print("Shutting down server...")
    try:
        server.shutdown()
    finally:
        # Forget the handle so that re-running this cell does not call
        # shutdown() a second time on an already stopped server.
        server = None