In [1]:
# === Notebook-friendly: API vs Local (with smoothed y_act) ===================
# Paste this cell into a Jupyter notebook.

from __future__ import annotations

import json
import math
import os
import sys
import threading
import time
from pathlib import Path
from typing import Iterable, Optional, Tuple
from urllib.parse import urlparse

import joblib
import numpy as np
import pandas as pd
import requests

# -------------------------------------------------------------------
# Import your package (assumes `pip install -e .` from repo root)
# If needed, adjust ROOT so imports resolve in your environment.
# -------------------------------------------------------------------
try:
    # Walk upwards from the current working directory until a directory named
    # "traffic_flow_package_src" is found; only that directory is put on sys.path.
    _cwd = Path.cwd()
    ROOT = next(
        (p for p in (_cwd, *_cwd.parents) if p.name == "traffic_flow_package_src"),
        None,
    )
    if ROOT is None:
        # Not inside the repo tree: leave sys.path untouched (never insert the
        # filesystem root) and rely on the installed package instead.
        ROOT = _cwd
    elif str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))
except OSError as exc:
    # Best effort only: the package imports below may still resolve without it.
    ROOT = None
    print(f"Could not determine the repo root ({exc}); relying on the installed package.")

from traffic_flow.service.app import create_app
from traffic_flow.service.runtime import InferenceRuntime
from traffic_flow.pipeline.data_pipeline_orchestrator import TrafficDataPipelineOrchestrator
from traffic_flow.evaluation.model_comparison import ModelEvaluator  # only to reuse some helpers if needed

# -------------------------------------------------------------------
# Config you can tweak in the notebook right after executing this cell
# -------------------------------------------------------------------
# Input files, given relative to the notebook's working directory.
ARTIFACT = Path("../../artifacts/traffic_pipeline_h-15.joblib")
RAW_PATH = Path("../../data/NDW/ndw_three_weeks.parquet")

# Address of the prediction service.
BASE_URL = "http://127.0.0.1:8080"
PORT = 8080

# Per-sensor chunk size used when batching rows for prediction.
BATCH_ROWS = 20000

# Equality tolerance (absolute max difference) for the API-vs-local check.
TOL = 1.0e-6

# Weather columns removed from the raw test rows before they are sent for prediction.
BAD_WEATHER_COLS = [
    "Snow_depth_surface",
    "Water_equivalent_of_accumulated_snow_depth_surface",
]


# ================== Server runner (in-process) =====================

class ServerThread(threading.Thread):
    """Daemon thread that serves the Flask app from inside this process.

    The app and its werkzeug server are created in the constructor; an
    application context is pushed there as well and popped again by
    ``shutdown()``.
    """

    def __init__(self, artifact_path: str, host="127.0.0.1", port=8080):
        super().__init__(daemon=True)
        self.host, self.port = host, port
        self.app = create_app(artifact_path=artifact_path)

        from werkzeug.serving import make_server
        self._srv = make_server(self.host, self.port, self.app)

        self._ctx = self.app.app_context()
        self._ctx.push()

    def run(self):
        """Thread body: serve requests until ``shutdown()`` is called."""
        self._srv.serve_forever()

    def shutdown(self):
        """Stop the serving loop, then pop the application context."""
        self._srv.shutdown()
        self._ctx.pop()


# ================= Helpers: states, cleaning, truth =================

def load_artifact(artifact_path: str | Path) -> dict:
    """Load the joblib bundle and return it together with its states and horizon.

    Returns a dict with keys ``"bundle"`` (the loaded object), ``"states"``
    (``bundle["states"]``) and ``"horizon"`` (``bundle["horizon"]`` as int,
    15 when the bundle has no such key).
    """
    # NOTE: joblib.load unpickles the file, so only load artifacts you trust.
    bundle = joblib.load(artifact_path)
    states = bundle["states"]
    horizon = int(bundle.get("horizon", 15))
    return {"bundle": bundle, "states": states, "horizon": horizon}

def get_lag_steps(states: dict, default: int = 25) -> int:
    """Return the number of lag steps stored in the artifact states.

    Reads ``states["lag_state"]["lags"]`` and converts it to int; ``default``
    is used when either key is absent.
    """
    # The lag feature adder's state is expected under "lag_state" with an
    # integer "lags" entry -- adjust here if the artifact uses another key.
    lag_state = states.get("lag_state", {})
    lags = lag_state.get("lags", default)
    return int(lags)

def make_orchestrator_from_states(raw_path: str | Path, states: dict) -> TrafficDataPipelineOrchestrator:
    """Build an orchestrator on ``raw_path`` and run its base-feature preparation.

    The smoothing window, relative threshold and median/mean choice are taken
    from ``states["clean_state"]`` so the cleaning uses the artifact's parameters.
    """
    clean_state = states["clean_state"]
    orchestrator = TrafficDataPipelineOrchestrator(
        file_path=str(raw_path),
        sensor_encoding_type="mean",
    )
    base_feature_kwargs = {
        "window_size": clean_state["smoothing_window"],
        "filter_extreme_changes": True,
        "smooth_speeds": True,
        "relative_threshold": clean_state["relative_threshold"],
        "use_median_instead_of_mean_smoothing": clean_state["use_median"],
    }
    orchestrator.prepare_base_features(**base_feature_kwargs)
    return orchestrator

def get_raw_test(raw_path: str | Path, states: dict) -> pd.DataFrame:
    """Return the raw rows dated on or after the orchestrator's first test timestamp.

    Columns listed in ``BAD_WEATHER_COLS`` are dropped when present, and the
    result is stably sorted by ``sensor_id`` then ``date`` (original index kept).
    """
    orchestrator = make_orchestrator_from_states(raw_path, states)
    frame = pd.read_parquet(raw_path)

    # Drop the problematic weather columns (also reduces the payload size).
    droppable = [name for name in BAD_WEATHER_COLS if name in frame.columns]
    frame = frame.drop(columns=droppable, errors="ignore")

    in_test_split = frame["date"] >= orchestrator.first_test_timestamp
    return frame.loc[in_test_split].sort_values(["sensor_id", "date"], kind="mergesort")

def build_smoothed_y_act(raw_path: str | Path, states: dict, horizon: int) -> pd.DataFrame:
    """Rebuild the training-style target for the test rows.

    ``y_act`` is the orchestrator's ``target_total_speed`` and
    ``prediction_time`` its ``date_of_prediction``; only rows flagged
    ``test_set`` are kept. The result has the columns ``sensor_id``,
    ``prediction_time`` and ``y_act``, stably sorted by sensor and time.
    """
    tdp = make_orchestrator_from_states(raw_path, states)
    # drop_datetime=False keeps the date columns available for selection below.
    tdp.finalise_for_horizon(horizon=horizon, drop_datetime=False)

    wanted = ["sensor_id", "date_of_prediction", "target_total_speed"]
    new_names = {"date_of_prediction": "prediction_time", "target_total_speed": "y_act"}
    truth = tdp.df.loc[tdp.df["test_set"], wanted].rename(columns=new_names)
    truth["prediction_time"] = pd.to_datetime(truth["prediction_time"])
    return truth.sort_values(["sensor_id", "prediction_time"], kind="mergesort")


# ========== Overlapped per-sensor batching (preserves lags) ==========

def api_predict_sensor_overlapped(base_url: str,
                                  df_sensor: pd.DataFrame,
                                  lag_steps: int,
                                  chunk_rows: int = 20_000,
                                  timeout: int = 300) -> pd.DataFrame:
    """
    Send one sensor's rows to ``{base_url}/predict`` in chunks.

    Every chunk after the first is prefixed with the previous ``lag_steps``
    rows as warm-up, so lag features are computed with correct history; the
    predictions whose ``input_time`` precedes the chunk's first real row are
    dropped afterwards.

    Parameters:
        base_url:   service root, e.g. "http://127.0.0.1:8080".
        df_sensor:  rows of a single sensor; must contain a "date" column.
        lag_steps:  number of warm-up rows (negative values are treated as 0).
        chunk_rows: rows per request, excluding warm-up; must be positive.
        timeout:    per-request timeout passed to ``requests.post``.

    Returns:
        Predictions sorted by ("sensor_id", "prediction_time") with a fresh
        index, or an empty DataFrame when nothing was predicted.

    Raises:
        ValueError: if ``chunk_rows`` is not positive.
        requests.HTTPError: on a non-2xx response; the message includes the
            beginning of the response body to make server errors diagnosable.
    """
    if chunk_rows <= 0:
        # The chunk start would never advance: fail fast instead of looping forever.
        raise ValueError(f"chunk_rows must be a positive integer, got {chunk_rows!r}")
    lag_steps = max(0, lag_steps)  # a negative lag would skip rows instead of adding warm-up

    df_sensor = df_sensor.sort_values("date", kind="mergesort").reset_index(drop=True)
    n = len(df_sensor)
    out = []
    for start in range(0, n, chunk_rows):
        end = min(start + chunk_rows, n)
        warm = max(0, start - lag_steps)
        chunk = df_sensor.iloc[warm:end].copy()
        # datetime -> string for JSON
        chunk["date"] = pd.to_datetime(chunk["date"]).dt.strftime("%Y-%m-%d %H:%M:%S")
        # Round-trip through to_json so NaN/NaT and numpy scalars become plain JSON values.
        records = json.loads(chunk.to_json(orient="records", date_format="iso"))

        r = requests.post(f"{base_url}/predict", json={"records": records}, timeout=timeout)
        try:
            r.raise_for_status()
        except requests.HTTPError as exc:
            # Surface the server's explanation; a bare status line is rarely enough.
            body = (r.text or "")[:2000]
            raise requests.HTTPError(f"{exc} | response body: {body}", response=r) from exc

        pred = pd.DataFrame(r.json()["predictions"])
        if pred.empty:
            # No predictions for this chunk: there is no "input_time" column to filter on.
            continue
        pred["input_time"] = pd.to_datetime(pred["input_time"])

        # Drop the warm-up predictions from this batch
        keep_from = pd.to_datetime(df_sensor.loc[start, "date"])
        out.append(pred.loc[pred["input_time"] >= keep_from].copy())

    if not out:
        return pd.DataFrame()
    out = pd.concat(out, ignore_index=True)
    out["prediction_time"] = pd.to_datetime(out["prediction_time"])
    out = out.sort_values(["sensor_id", "prediction_time"], kind="mergesort").reset_index(drop=True)
    return out

def local_predict_sensor_overlapped(rt: InferenceRuntime,
                                    df_sensor: pd.DataFrame,
                                    lag_steps: int,
                                    chunk_rows: int = 20_000) -> pd.DataFrame:
    """
    Predict one sensor's rows with the in-process runtime, chunk by chunk.

    Mirrors the API batching: every chunk after the first is prefixed with the
    previous ``lag_steps`` rows as warm-up, and predictions whose
    ``input_time`` precedes the chunk's first real row are dropped.

    Parameters:
        rt:         object exposing ``predict_df(chunk)``; the first element of
                    its return value is used as the prediction frame.
        df_sensor:  rows of a single sensor; must contain a "date" column.
        lag_steps:  number of warm-up rows (negative values are treated as 0).
        chunk_rows: rows per call, excluding warm-up; must be positive.

    Returns:
        Predictions sorted by ("sensor_id", "prediction_time") with a fresh
        index, or an empty DataFrame when nothing was predicted.

    Raises:
        ValueError: if ``chunk_rows`` is not positive.
    """
    if chunk_rows <= 0:
        # The chunk start would never advance: fail fast instead of looping forever.
        raise ValueError(f"chunk_rows must be a positive integer, got {chunk_rows!r}")
    lag_steps = max(0, lag_steps)  # a negative lag would skip rows instead of adding warm-up

    df_sensor = df_sensor.sort_values("date", kind="mergesort").reset_index(drop=True)
    n = len(df_sensor)
    out = []
    for start in range(0, n, chunk_rows):
        end = min(start + chunk_rows, n)
        warm = max(0, start - lag_steps)
        chunk = df_sensor.iloc[warm:end].copy()

        pred_df, _ = rt.predict_df(chunk)
        if pred_df.empty:
            # No predictions for this chunk: there may be no "input_time" column to filter on.
            continue
        pred_df["input_time"] = pd.to_datetime(pred_df["input_time"])

        # Drop the warm-up predictions from this batch
        keep_from = pd.to_datetime(df_sensor.loc[start, "date"])
        out.append(pred_df.loc[pred_df["input_time"] >= keep_from].copy())

    if not out:
        return pd.DataFrame()
    out = pd.concat(out, ignore_index=True)
    out["prediction_time"] = pd.to_datetime(out["prediction_time"])
    out = out.sort_values(["sensor_id", "prediction_time"], kind="mergesort").reset_index(drop=True)
    return out

def api_predict_all_sensors(base_url: str,
                            raw_test: pd.DataFrame,
                            states: dict,
                            chunk_rows: int = 20_000,
                            timeout: int = 300) -> pd.DataFrame:
    """
    Run the overlapped API prediction for every sensor in ``raw_test``.

    Rows are grouped by "sensor_id" (in order of first appearance) and each
    group is sent through ``api_predict_sensor_overlapped`` using the lag
    length taken from ``states``.

    Returns:
        All predictions sorted by ("sensor_id", "prediction_time") with a
        fresh index, or an empty DataFrame when ``raw_test`` has no rows or
        no sensor produced any prediction.
    """
    if raw_test.empty:
        # Nothing to send; an empty frame has no "prediction_time" column to post-process.
        return pd.DataFrame()

    lag_steps = get_lag_steps(states)
    per_sensor = [
        api_predict_sensor_overlapped(base_url, df_s, lag_steps, chunk_rows, timeout)
        for _, df_s in raw_test.groupby("sensor_id", sort=False)
    ]
    # Per-sensor results may be column-less empty frames; keep only real predictions.
    per_sensor = [frame for frame in per_sensor if not frame.empty]
    if not per_sensor:
        return pd.DataFrame()

    out = pd.concat(per_sensor, ignore_index=True)
    out["prediction_time"] = pd.to_datetime(out["prediction_time"])
    out = out.sort_values(["sensor_id", "prediction_time"], kind="mergesort").reset_index(drop=True)
    return out

def local_predict_all_sensors(artifact_path: str,
                              raw_test: pd.DataFrame,
                              states: dict,
                              chunk_rows: int = 20_000) -> pd.DataFrame:
    """
    Run the overlapped in-process prediction for every sensor in ``raw_test``.

    One ``InferenceRuntime`` is created from ``artifact_path`` and reused for
    all sensors; rows are grouped by "sensor_id" (in order of first
    appearance) and each group goes through ``local_predict_sensor_overlapped``
    using the lag length taken from ``states``.

    Returns:
        All predictions sorted by ("sensor_id", "prediction_time") with a
        fresh index, or an empty DataFrame when ``raw_test`` has no rows or
        no sensor produced any prediction.
    """
    if raw_test.empty:
        # Nothing to predict: skip creating the runtime and avoid post-processing
        # a frame that has no "prediction_time" column.
        return pd.DataFrame()

    lag_steps = get_lag_steps(states)
    rt = InferenceRuntime(str(artifact_path))
    per_sensor = [
        local_predict_sensor_overlapped(rt, df_s, lag_steps, chunk_rows)
        for _, df_s in raw_test.groupby("sensor_id", sort=False)
    ]
    # Per-sensor results may be column-less empty frames; keep only real predictions.
    per_sensor = [frame for frame in per_sensor if not frame.empty]
    if not per_sensor:
        return pd.DataFrame()

    out = pd.concat(per_sensor, ignore_index=True)
    out["prediction_time"] = pd.to_datetime(out["prediction_time"])
    out = out.sort_values(["sensor_id", "prediction_time"], kind="mergesort").reset_index(drop=True)
    return out


# ===================== Ground-truth & compare =======================

def attach_y_act(pred_df: pd.DataFrame, truth_df: pd.DataFrame) -> pd.DataFrame:
    """Left-join the ground truth onto the predictions and return the canonical columns.

    The join keys are ``sensor_id`` and ``prediction_time`` (the prediction
    side is converted to datetime first). The result holds
    ``prediction_time, sensor_id, y_act, y_pred`` -- ``y_pred`` being the
    renamed ``y_pred_total`` -- stably sorted by sensor and time with a fresh
    index. ``pred_df`` itself is not modified.
    """
    join_keys = ["sensor_id", "prediction_time"]
    predictions = pred_df.assign(prediction_time=pd.to_datetime(pred_df["prediction_time"]))
    merged = predictions.merge(truth_df, on=join_keys, how="left")
    return (
        merged.loc[:, ["prediction_time", "sensor_id", "y_act", "y_pred_total"]]
        .rename(columns={"y_pred_total": "y_pred"})
        .sort_values(join_keys, kind="mergesort")
        .reset_index(drop=True)
    )

def check_equal(a: pd.DataFrame, b: pd.DataFrame, tol: float = 1e-6) -> Tuple[bool,float]:
    """Compare ``y_pred`` of two frames on their shared (sensor_id, prediction_time) keys.

    Returns ``(within_tolerance, max_abs_diff)``. Only keys present in both
    frames are compared (inner join); when there is no shared key the result
    is ``(False, inf)``.
    """
    keys = ["sensor_id", "prediction_time"]
    left = a.loc[:, keys + ["y_pred"]].rename(columns={"y_pred": "y_pred_a"})
    right = b.loc[:, keys + ["y_pred"]].rename(columns={"y_pred": "y_pred_b"})
    joined = left.merge(right, on=keys, how="inner")
    if joined.empty:
        return False, float("inf")
    deltas = (joined["y_pred_a"] - joined["y_pred_b"]).to_numpy()
    worst = np.max(np.abs(deltas))
    return bool(worst <= tol), float(worst)


# =========================== Driver =================================

def _wait_until_healthy(url: str, attempts: int = 60, delay: float = 0.25) -> bool:
    """Poll ``{url}/healthz`` until it answers OK; return False after ``attempts`` tries."""
    for _ in range(attempts):
        try:
            if requests.get(f"{url}/healthz", timeout=1.0).ok:
                return True
        except requests.RequestException:
            pass  # not accepting connections yet
        # Pause after every failed attempt, including non-OK responses.
        time.sleep(delay)
    return False


def run_all(artifact: str | Path = ARTIFACT,
            raw_path: str | Path = RAW_PATH,
            url: str = BASE_URL,
            start_server: bool = True,
            batch_rows: int = BATCH_ROWS,
            for_sensor: Optional[str] = None):
    """
    Predict the test split through the HTTP API and through the in-process
    runtime, attach the smoothed ground truth, and report whether both agree.

    Parameters:
      artifact:     path of the joblib pipeline bundle.
      raw_path:     path of the raw parquet data.
      url:          base URL of the prediction service; must contain an
                    explicit port when ``start_server`` is True.
      start_server: start the service in-process on 127.0.0.1 at the URL's
                    port (and stop it afterwards); when False an already
                    running service at ``url`` is used.
      batch_rows:   per-sensor chunk size for both prediction paths.
      for_sensor:   restrict the run to this sensor_id (None = all sensors).

    Returns:
      api_df, local_df, truth_df  (each canonical: prediction_time, sensor_id, y_act/y_pred)

    Raises:
      ValueError:   ``url`` has no usable port (only when starting the server),
                    or ``for_sensor`` has no rows in the test split.
      RuntimeError: the in-process service did not become healthy.
    """
    # 0) Load states & horizon
    art = load_artifact(artifact)
    states, horizon = art["states"], art["horizon"]

    # 1) Optional: start API here
    server = None
    if start_server:
        # urlparse copes with trailing slashes and paths, unlike splitting on ":".
        port = urlparse(url).port
        if port is None:
            raise ValueError(f"url must include an explicit port to start the server, got {url!r}")
        os.environ["ARTIFACT_PATH"] = str(artifact)
        server = ServerThread(artifact_path=str(artifact), host="127.0.0.1", port=port)
        server.start()
        if not _wait_until_healthy(url):
            server.shutdown()
            raise RuntimeError("Service did not become healthy on /healthz")

    try:
        # 2) RAW test rows (aligned with training cleaning)
        raw_test = get_raw_test(raw_path, states)
        if for_sensor is not None:
            raw_test = raw_test.loc[raw_test["sensor_id"] == for_sensor].copy()
            if raw_test.empty:
                raise ValueError(f"Sensor_id '{for_sensor}' not found in test split.")

        # 3) Build smoothed ground-truth (training target)
        truth_df = build_smoothed_y_act(raw_path, states, horizon)

        # 4) Predictions
        api_pred = api_predict_all_sensors(url, raw_test, states, chunk_rows=batch_rows)
        local_pred = local_predict_all_sensors(artifact, raw_test, states, chunk_rows=batch_rows)

        # 5) Attach y_act for plotting / analysis
        api_df = attach_y_act(api_pred, truth_df)
        local_df = attach_y_act(local_pred, truth_df)

        # 6) Equality check
        ok, max_abs = check_equal(api_df, local_df, tol=TOL)
        print(f"API vs Local equal within tol {TOL}: {ok}  |  max|Δ| = {max_abs:.3g}")
        return api_df, local_df, truth_df

    finally:
        # Only stop a server that this call started.
        if server is not None:
            server.shutdown()

In [2]:
# Choose a sensor (or leave None for all sensors)
sensor = None  # e.g., "RWS01_MONIBAS_0041hrr0592ra"
BASE_URL = "http://127.0.0.1:8080"
api_df, local_df, truth_df = run_all(
    artifact=ARTIFACT,
    raw_path=RAW_PATH,
    url=BASE_URL,
    start_server=False,   # reuse a service that is already running at BASE_URL
    batch_rows=BATCH_ROWS,
    for_sensor=sensor,    # honour the selection made above
)

api_df.head()

Running prepare_base_features!!!!!!!!!!!!!!!!
[MeanSensorEncoder] Mean encoding learned for 204 sensors. Global mean=93.85.
[AdjacentSensorFeatureAdder] Adding adjacent sensor features.
[AdjacentSensorFeatureAdder] Added features: downstream_sensor_1, upstream_sensor_1
Running prepare_base_features!!!!!!!!!!!!!!!!
[MeanSensorEncoder] Mean encoding learned for 204 sensors. Global mean=93.85.
[AdjacentSensorFeatureAdder] Adding adjacent sensor features.
[AdjacentSensorFeatureAdder] Added features: downstream_sensor_1, upstream_sensor_1




[PreviousWeekdayWindowFeatureEngineer] horizon=15′  window=[-0,+0]′ step=1′  aggs=-  mode=local
[WeatherFeatureDropper] Will drop ['incremental_id', 'Per_cent_frozen_precipitation_surface', 'Precipitable_water_entire_atmosphere_single_layer', 'Precipitation_rate_surface_3_Hour_Average', 'Storm_relative_helicity_height_above_ground_layer', 'Total_precipitation_surface_3_Hour_Accumulation', 'Categorical_Rain_surface_3_Hour_Average', 'Categorical_Freezing_Rain_surface_3_Hour_Average', 'Categorical_Ice_Pellets_surface_3_Hour_Average', 'Categorical_Snow_surface_3_Hour_Average', 'Convective_Precipitation_Rate_surface_3_Hour_Average', 'Convective_precipitation_surface_3_Hour_Accumulation', 'U-Component_Storm_Motion_height_above_ground_layer', 'V-Component_Storm_Motion_height_above_ground_layer', 'Geopotential_height_highest_tropospheric_freezing', 'Relative_humidity_highest_tropospheric_freezing', 'Ice_cover_surface', 'Snow_depth_surface', 'Water_equivalent_of_accumulated_snow_depth_surface',

HTTPError: 500 Server Error: INTERNAL SERVER ERROR for url: http://127.0.0.1:8080/predict