In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Pre-merge inspection for the UK DfT Road Safety dataset.

WHY THIS SCRIPT EXISTS
----------------------
Before you join/merge tables or build models, you should *profile* the raw data:
- Are the keys unique? (e.g., collision_index)
- Do foreign keys exist? (vehicles/casualties -> collisions)
- Are there out-of-range values? (lat/long, speed, ages)
- What are the shapes, dtypes, nulls, top categories?
- Are there duplicates on key combos?
- What years are present?

WHAT THIS SCRIPT PRODUCES
-------------------------
It writes three artifacts under 'data/processed/_profile/':

1) premerge_profile.md    (human-readable summary you can read & commit)
2) premerge_profile.json  (machine-readable profile for CI/tests)
3) columns_profile.csv    (column-level stats for diffing over time)

HOW TO RUN
----------
From the repo root:

    python scripts/premerge_inspect.py
    # Faster test run (reads only first N rows per CSV):
    python scripts/premerge_inspect.py --nrows 300000 --emit-samples

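The script is Jupyter/VS Code friendly (it ignores unknown CLI args).
If the data does not live under data/raw/, set the UK_RS_DATA_DIR environment
variable to the directory that contains the CSVs.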
"""

from __future__ import annotations

import argparse
import json
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype

# --------------------------- Configuration ---------------------------

# Common "unknown/not applicable" sentinel codes in DfT data that we treat as missing
SENTINELS = {-1, 97, 98, 99, 997, 998, 999}
# NOTE(review): apart from -1, these codes mean "unknown" only for some coded fields; in
# measurement/identifier columns (ages, engine cc, road numbers) they can be genuine values.

# Where the profile artifacts are written, relative to the repository root.
OUTDIR_REL = Path("data", "processed", "_profile")

# Report sizing knobs.
SAMPLE_K = 10            # example values kept per column in the profile
TOP_K = 12               # most-frequent values listed per categorical column
CSV_SAMPLE_SIZE = 1_000  # rows written to each optional sample CSV

# Join contracts. The 2020+ schema uses lower_snake_case column names.
COLLISION_KEY = "collision_index"
VEH_KEYS = [COLLISION_KEY, "vehicle_reference"]
CAS_KEYS = [*VEH_KEYS, "casualty_reference"]

# Loose plausibility bounds: red-flag detectors, not ground truth.
RANGE_RULES = dict(
    latitude=(-90, 90),
    longitude=(-180, 180),
    collision_year=(2010, 2100),
    age_of_driver=(0, 110),
    age_of_vehicle=(0, 80),
    engine_capacity_cc=(0, 12000),
)

# Permitted speed limits in mph; 5/15/80 are tolerated for private roads and edge cases.
ALLOWED_SPEEDS = {5, 10, 15, *range(20, 90, 10)}

# Columns that are usually categorical here; used for the top-values preview.
CATEGORICAL_HINTS = [
    "collision_severity",
    "day_of_week",
    "road_type",
    "speed_limit",
    "weather_conditions",
    "light_conditions",
    "urban_or_rural_area",
    "vehicle_type",
    "vehicle_manoeuvre",
    "casualty_severity",
    "casualty_type",
]

# Table -> accepted file names (no legacy aliases).
EXPECTED_FILES = dict(
    collisions=["Collisions.csv"],
    vehicles=["Vehicles.csv"],
    casualties=["Casualties.csv"],
)

# ------------------------------ Utilities ----------------------------

def repo_root(start: Path | None = None) -> Path:
    """Walk upward until we find a folder that looks like a Git repo (has .git or README.md)."""
    cur = start or Path.cwd()
    for p in [cur, *cur.parents]:
        if (p / ".git").exists() or (p / "README.md").exists():
            return p
    return cur

def find_dataset_root(raw_base: Path) -> Path | None:
    """
    Try to auto-detect the folder that contains the road safety CSVs under data/raw/.
    We pick the first folder that either looks like "road safety" or contains expected files.
    """
    cands: List[Path] = []
    if raw_base.exists():
        for d in raw_base.iterdir():
            if d.is_dir():
                has_expected = any(d.joinpath(n).exists()
                                   for names in EXPECTED_FILES.values() for n in names)
                looks_like = ("road" in d.name.lower() and "safety" in d.name.lower()) or has_expected
                if looks_like:
                    cands.append(d)
    return cands[0] if cands else None

def find_file(root: Path, names: List[str]) -> Path | None:
    """Find the first file matching exact name (case-insensitive), otherwise fallback to substring match."""
    all_csvs = list(root.rglob("*.csv")) + list(root.rglob("*.CSV"))
    # exact match first
    for nm in names:
        for p in all_csvs:
            if p.name.lower() == nm.lower():
                return p
    # substring fallback (e.g., Vehicles_2024.csv)
    for nm in names:
        stem = nm.lower().replace(".csv", "")
        for p in all_csvs:
            if stem in p.name.lower():
                return p
    return None

def read_csv_any(path: Path, nrows: int | None = None) -> pd.DataFrame:
    """Use pyarrow if available for speed; gracefully fallback otherwise."""
    try:
        return pd.read_csv(path, engine="pyarrow", nrows=nrows)
    except Exception:
        return pd.read_csv(path, nrows=nrows)

def coerce_sentinels_to_na(df: pd.DataFrame) -> pd.DataFrame:
    """
    Replace common sentinel codes (e.g., -1, 99) with proper NA values on integer/categorical columns.
    This prevents silent bugs where "unknown" is treated as a real category or value.
    """
    for c in df.columns:
        if pd.api.types.is_integer_dtype(df[c]) or isinstance(df[c].dtype, CategoricalDtype):
            df[c] = df[c].replace(list(SENTINELS), pd.NA)
    return df

def json_default(o: Any):
    """Convert values the stdlib ``json`` encoder cannot handle into JSON-friendly ones.

    Used as ``json.dump(default=...)``. ``top_values`` also calls it directly to normalise
    value-count keys.

    - numpy arrays -> lists (via ``tolist()``).
    - numpy scalars -> the equivalent Python scalar (via ``.item()``).
    - ``None``, ``pd.NA``, ``pd.NaT`` and non-finite floats (NaN/±inf, including numpy
      float32/float64) -> ``None``; strict JSON has no NaN/Infinity literals.
    - Python ``bool``/``int``/``float``/``str`` pass through unchanged.
    - date/time objects (anything with ``isoformat``) -> ISO-8601 string.
    - anything else -> ``str(o)``.
    """
    if isinstance(o, np.ndarray):
        return o.tolist()
    if isinstance(o, np.generic):
        o = o.item()  # numpy scalar -> Python scalar (np.datetime64('NaT') -> None)
    if o is None or o is pd.NA or o is pd.NaT:
        return None
    if isinstance(o, float):
        return o if np.isfinite(o) else None
    if isinstance(o, (bool, int, str)):
        return o
    if hasattr(o, "isoformat"):
        return o.isoformat()
    return str(o)

def top_values(df: pd.DataFrame, col: str, k: int = 12) -> List[Tuple[Any, int]]:
    """Return up to *k* ``(value, count)`` pairs for *col*, most frequent first.

    Missing values are counted too and reported with a value of ``None``. Other values
    go through ``json_default`` so they are JSON-serialisable. Returns an empty list
    when *col* is not a column of *df*.
    """
    if col not in df.columns:
        return []

    def _is_missing(value: Any) -> bool:
        return value is pd.NA or (isinstance(value, float) and pd.isna(value))

    counts = df[col].value_counts(dropna=False).head(k)
    return [
        (None if _is_missing(value) else json_default(value), int(n))
        for value, n in counts.items()
    ]

def profile_table(df: pd.DataFrame, name: str) -> Dict[str, Any]:
    """Column-level overview: dtype, non-null counts, % nulls, uniques, sample values, memory."""
    info: Dict[str, Any] = {
        "name": name,
        "shape": [int(df.shape[0]), int(df.shape[1])],
        "memory_mb": round(df.memory_usage(deep=True).sum() / 1_048_576, 2),
        "columns": [],
    }
    for c in df.columns:
        s = df[c]
        info["columns"].append({
            "name": c,
            "dtype": str(s.dtype),
            "non_null": int(s.notna().sum()),
            "null_pct": round(float(s.isna().mean() * 100), 2),
            "nunique": int(s.nunique(dropna=True)),
            "examples": s.dropna().astype(str).head(SAMPLE_K).tolist(),
        })
    return info

def check_duplicates(df: pd.DataFrame, keys: List[str]) -> int:
    """Return how many rows share their *keys* combination with another row.

    Every member of a duplicated group is counted, not just the extra copies. Returns
    -1 when any of *keys* is not a column of *df*.
    """
    absent = [k for k in keys if k not in df.columns]
    if absent:
        return -1
    dup_mask = df.duplicated(subset=keys, keep=False)
    return int(dup_mask.sum())

def range_violations(df: pd.DataFrame, col: str, lo: float, hi: float) -> int:
    """Return the number of numeric values in *col* lying outside the closed range [lo, hi].

    Values that cannot be parsed as numbers become NA and are ignored, as are missing
    values. A column absent from *df* yields 0.
    """
    if col not in df.columns:
        return 0
    numeric = pd.to_numeric(df[col], errors="coerce").dropna()
    inside = numeric.between(lo, hi)
    return int((~inside).sum())

def membership_violations(df: pd.DataFrame, col: str, allowed: set[int]) -> int:
    """Count non-null values in *col* that are not in *allowed* (e.g., speed limits).

    Values are coerced with ``pd.to_numeric``; non-numeric entries become NA and are
    ignored. A column absent from *df* yields 0. Values are compared as-is rather than
    cast to int, because an int cast would truncate e.g. 30.5 to an "allowed" 30 and
    would raise on ±inf.
    """
    if col not in df.columns:
        return 0
    s = pd.to_numeric(df[col], errors="coerce").dropna()
    return int((~s.isin(list(allowed))).sum())

def write_markdown(md_path: Path, prof: Dict[str, Any], issues: List[str], ranges: Dict[str, Any],
                   years: Dict[str, List[int]], cats: Dict[str, Dict[str, List[Tuple[Any, int]]]]):
    """Render a concise, human-readable report and write it to *md_path* as UTF-8 Markdown.

    Args:
        md_path: destination file (overwritten).
        prof: assembled profile; reads ``paths.data_root``, ``shapes`` and ``contracts``.
        issues: issue strings listed under "Issues found".
        ranges: column -> {"bounds": [lo, hi], "<Table>": out-of-bounds count}.
        years: table name -> sorted years seen in ``collision_year``.
        cats: column -> table -> top ``(value, count)`` pairs; the first 6 are shown.
    """
    shapes = prof['shapes']
    contracts = prof['contracts']
    lines: List[str] = []
    lines.append("# Pre-merge Inspection Report\n")
    lines.append(f"- Data root: `{prof['paths']['data_root']}`")
    lines.append(f"- Rows × Cols: "
                 f"Collisions={shapes['Collisions'][0]:,}×{shapes['Collisions'][1]}, "
                 f"Vehicles={shapes['Vehicles'][0]:,}×{shapes['Vehicles'][1]}, "
                 f"Casualties={shapes['Casualties'][0]:,}×{shapes['Casualties'][1]}\n")

    lines.append("## Years present")
    if not years:
        lines.append("- (no `collision_year` column found)")
    for k, v in years.items():
        lines.append(f"- **{k}**: {', '.join(map(str, v))}")

    lines.append("\n## Range checks (out-of-bounds counts)")
    for col, r in ranges.items():
        b = r['bounds']
        lines.append(f"- **{col}** [{b[0]}, {b[1]}]: "
                     f"Collisions={r['Collisions']}, Vehicles={r['Vehicles']}, Casualties={r['Casualties']}")

    lines.append("\n## Key integrity")
    lines.append("- Collisions.collision_index unique: " +
                 ("✅" if contracts['collisions_unique'] else "❌"))
    lines.append(f"- Vehicles duplicates on {VEH_KEYS}: {contracts['vehicles_dup']}")
    lines.append(f"- Casualties duplicates on {CAS_KEYS}: {contracts['casualties_dup']}")
    lines.append(f"- Vehicles missing FK -> Collisions: {contracts['vehicles_missing_fk']}")
    lines.append(f"- Casualties missing FK -> Collisions: {contracts['casualties_missing_fk']}")

    speed = contracts.get("speed_limit_outside_allowed")
    if speed is not None:
        lines.append(f"- Speed limits not in {sorted(ALLOWED_SPEEDS)}: "
                     f"Collisions={speed['Collisions']}, "
                     f"Vehicles={speed['Vehicles']}, "
                     f"Casualties={speed['Casualties']}")

    lines.append("\n## Top categorical values (previews)")
    for col, d in cats.items():
        rows = [(table, items) for table, items in d.items() if items]
        if not rows:
            continue  # column absent from every table: don't emit an orphan label
        # Blank lines around the label: without them CommonMark treats the label as a
        # lazy continuation of the previous bullet and glues it onto that line.
        lines.extend(["", f"**{col}**", ""])
        for table, items in rows:
            preview = ", ".join(f"{k if k is not None else 'NA'}×{v}" for k, v in items[:6])
            lines.append(f"- {table}: {preview}")

    if issues:
        lines.append("\n## Issues found")
        for s in issues:
            lines.append(f"- {s}")
    else:
        lines.append("\n## Issues found\n- None detected.")

    # POSIX text files end with a newline; this also keeps diffs of the report clean.
    md_path.write_text("\n".join(lines) + "\n", encoding="utf-8")

# --------------------------------- Main ---------------------------------

def main():
    """Run the pre-merge inspection end to end.

    Steps:
    1. Parse CLI flags (``--nrows``, ``--emit-samples``).
    2. Locate the repo root and the DfT CSV directory (``UK_RS_DATA_DIR`` or
       auto-detection under ``data/raw``).
    3. Load the Collisions/Vehicles/Casualties tables and coerce sentinel codes to NA.
    4. Run key/FK/duplicate/range/domain checks.
    5. Write ``premerge_profile.md``, ``premerge_profile.json`` and
       ``columns_profile.csv`` (plus optional samples) under ``<repo root>/OUTDIR_REL``.

    Raises:
        FileNotFoundError: if the data directory or any of the three CSVs is missing.
    """
    parser = argparse.ArgumentParser(description="Pre-merge inspection for DfT road safety data")
    parser.add_argument("--nrows", type=int, default=None,
                        help="Optional row cap per CSV for faster runs (e.g., 300000)")
    parser.add_argument("--emit-samples", action="store_true",
                        help="Also write small 1k-row CSV samples for each table")
    # JUPYTER/VS CODE SAFE: ignore unknown args like --f=... injected by kernels
    args, _ = parser.parse_known_args()

    logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(message)s")

    # Discover repo and data location
    ROOT = repo_root()
    RAW_BASE = ROOT / "data" / "raw"
    data_override = os.getenv("UK_RS_DATA_DIR")  # env var override if data lives elsewhere
    DATA_ROOT = Path(data_override) if data_override else find_dataset_root(RAW_BASE)
    if not DATA_ROOT or not DATA_ROOT.exists():
        # Name the location that was actually searched. With the env override set,
        # RAW_BASE is never consulted, so naming it would send the user to the wrong place.
        searched = f"{data_override} (from UK_RS_DATA_DIR)" if data_override else str(RAW_BASE)
        raise FileNotFoundError(
            f"Couldn’t find road-safety data under {searched}. "
            "Unzip the DfT files into data/raw/ (e.g., data/raw/dft_road_safety_last_5_years/), "
            "or set UK_RS_DATA_DIR to the directory."
        )

    # Locate the three expected CSVs
    acc_path = find_file(DATA_ROOT, EXPECTED_FILES["collisions"])
    veh_path = find_file(DATA_ROOT, EXPECTED_FILES["vehicles"])
    cas_path = find_file(DATA_ROOT, EXPECTED_FILES["casualties"])
    if not all([acc_path, veh_path, cas_path]):
        missing = [fname for fname, p in (("Collisions.csv", acc_path),
                                          ("Vehicles.csv", veh_path),
                                          ("Casualties.csv", cas_path)) if not p]
        raise FileNotFoundError(f"Missing expected tables: {', '.join(missing)}")

    logging.info(f"Data root: {DATA_ROOT}")
    logging.info(f"Detected:\n- {acc_path}\n- {veh_path}\n- {cas_path}")

    # Load (optionally capped by --nrows for speed)
    acc = read_csv_any(acc_path, nrows=args.nrows)
    veh = read_csv_any(veh_path, nrows=args.nrows)
    cas = read_csv_any(cas_path, nrows=args.nrows)
    if args.nrows is not None:
        # Each file is capped independently, so child rows may reference collisions
        # beyond the Collisions cap; treat FK coverage numbers as indicative only.
        logging.warning(f"--nrows={args.nrows}: tables are truncated independently; "
                        "foreign-key coverage counts may include spurious misses.")

    # Normalize sentinels to proper NAs (so EDA isn't lying)
    def _coerce(name: str, df: pd.DataFrame) -> pd.DataFrame:
        before = df.isna().sum().sum()
        df = coerce_sentinels_to_na(df)
        after = df.isna().sum().sum()
        logging.info(f"{name}: coerced sentinels -> NA (+{after - before} new NAs)")
        return df

    acc = _coerce("Collisions", acc)
    veh = _coerce("Vehicles", veh)
    cas = _coerce("Casualties", cas)

    # ------------------- Integrity & Contracts -------------------

    issues: List[str] = []

    # 1) Primary key uniqueness on collisions
    has_pk = COLLISION_KEY in acc.columns
    collisions_unique = has_pk and acc[COLLISION_KEY].is_unique
    if not collisions_unique:
        dup_n = int(acc.duplicated(subset=[COLLISION_KEY]).sum()) if has_pk else -1
        issues.append(f"[Collisions] {COLLISION_KEY} not unique (dups={dup_n})")

    # 2) Foreign key coverage: vehicles/casualties -> collisions.
    #    -1 means "could not check" because the key column is missing on either side,
    #    mirroring check_duplicates. Guarding the parent side avoids a KeyError.
    def _missing_fk(child: pd.DataFrame) -> int:
        if not has_pk or COLLISION_KEY not in child.columns:
            return -1
        return int((~child[COLLISION_KEY].isin(acc[COLLISION_KEY])).sum())

    vehicles_missing_fk = _missing_fk(veh)
    casualties_missing_fk = _missing_fk(cas)
    if vehicles_missing_fk > 0:
        issues.append(f"[Vehicles] {vehicles_missing_fk} rows missing FK -> Collisions")
    if casualties_missing_fk > 0:
        issues.append(f"[Casualties] {casualties_missing_fk} rows missing FK -> Collisions")

    # 3) Duplicate key combos on vehicles/casualties
    vehicles_dup = check_duplicates(veh, VEH_KEYS)
    casualties_dup = check_duplicates(cas, CAS_KEYS)
    if vehicles_dup > 0:
        issues.append(f"[Vehicles] duplicate rows on {VEH_KEYS}: {vehicles_dup}")
    if casualties_dup > 0:
        issues.append(f"[Casualties] duplicate rows on {CAS_KEYS}: {casualties_dup}")

    # ------------------- Range & Domain Checks -------------------

    # 4) Ranges (bounds)
    range_report: Dict[str, Dict[str, Any]] = {}
    for col, (lo, hi) in RANGE_RULES.items():
        range_report[col] = {
            "Collisions": range_violations(acc, col, lo, hi),
            "Vehicles":   range_violations(veh, col, lo, hi),
            "Casualties": range_violations(cas, col, lo, hi),
            "bounds": [lo, hi],
        }

    # 5) Allowed values for speed_limit
    speed_out = {
        "Collisions": membership_violations(acc, "speed_limit", ALLOWED_SPEEDS),
        "Vehicles":   membership_violations(veh, "speed_limit", ALLOWED_SPEEDS),
        "Casualties": membership_violations(cas, "speed_limit", ALLOWED_SPEEDS),
    }

    # 6) Years present (helps pick train/test windows & detect partial years)
    years: Dict[str, List[int]] = {}
    for name, df in [("Collisions", acc), ("Vehicles", veh), ("Casualties", cas)]:
        if "collision_year" in df.columns:
            yrs = pd.to_numeric(df["collision_year"], errors="coerce").dropna().astype(int).unique()
            years[name] = sorted(map(int, yrs))

    # 7) Top values for common categorical columns (quick distribution snapshots).
    #    Pass TOP_K explicitly so the configuration knob actually takes effect.
    cats: Dict[str, Dict[str, List[Tuple[Any, int]]]] = {}
    for col in CATEGORICAL_HINTS:
        cats[col] = {
            "Collisions": top_values(acc, col, TOP_K),
            "Vehicles":   top_values(veh, col, TOP_K),
            "Casualties": top_values(cas, col, TOP_K),
        }

    # ------------------- Column Profile CSV -------------------

    OUTDIR = ROOT / OUTDIR_REL
    OUTDIR.mkdir(parents=True, exist_ok=True)

    def cols_df(df: pd.DataFrame, table: str) -> pd.DataFrame:
        """One row of column stats per column of *df*, tagged with *table*."""
        n_rows = len(df)
        rows = []
        for c in df.columns:
            s = df[c]
            rows.append({
                "table": table,
                "column": c,
                "dtype": str(s.dtype),
                "non_null": int(s.notna().sum()),
                # isna().mean() is NaN on a zero-row table; report 0% instead.
                "null_pct": float(s.isna().mean() * 100) if n_rows else 0.0,
                "nunique": int(s.nunique(dropna=True)),
                "memory_bytes": int(s.memory_usage(deep=True)),
            })
        return pd.DataFrame(rows)

    cols_profile = pd.concat([
        cols_df(acc, "Collisions"),
        cols_df(veh, "Vehicles"),
        cols_df(cas, "Casualties")
    ], ignore_index=True)
    cols_profile.sort_values(["table", "column"]).to_csv(OUTDIR / "columns_profile.csv", index=False)

    # ------------------- Assemble & Write Reports -------------------

    prof = {
        "paths": {
            "root": str(ROOT),
            "data_root": str(DATA_ROOT),
            "collisions": str(acc_path),
            "vehicles": str(veh_path),
            "casualties": str(cas_path),
        },
        "shapes": {
            "Collisions": [int(acc.shape[0]), int(acc.shape[1])],
            "Vehicles":   [int(veh.shape[0]), int(veh.shape[1])],
            "Casualties": [int(cas.shape[0]), int(cas.shape[1])],
        },
        "profiles": {
            "Collisions": profile_table(acc, "Collisions"),
            "Vehicles":   profile_table(veh, "Vehicles"),
            "Casualties": profile_table(cas, "Casualties"),
        },
        "ranges": range_report,
        "categories_top": cats,
        "years_present": years,
        "contracts": {
            "collisions_unique": bool(collisions_unique),
            "vehicles_dup": int(vehicles_dup),
            "casualties_dup": int(casualties_dup),
            "vehicles_missing_fk": int(vehicles_missing_fk),
            "casualties_missing_fk": int(casualties_missing_fk),
            "speed_limit_outside_allowed": speed_out,
        },
        "issues": issues,
        "run": {
            "nrows": args.nrows,
            "emit_samples": bool(args.emit_samples),
        }
    }

    json_path = OUTDIR / "premerge_profile.json"
    md_path = OUTDIR / "premerge_profile.md"

    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(prof, f, indent=2, default=json_default)

    write_markdown(md_path, prof, issues, range_report, years, cats)

    # Optional: write small human-friendly samples for each table
    if args.emit_samples:
        (OUTDIR / "samples").mkdir(exist_ok=True, parents=True)
        for name, df in [("collisions", acc), ("vehicles", veh), ("casualties", cas)]:
            df.head(CSV_SAMPLE_SIZE).to_csv(OUTDIR / "samples" / f"{name}_sample_{CSV_SAMPLE_SIZE}.csv", index=False)

    logging.info(f"Wrote:\n- {md_path}\n- {json_path}\n- {OUTDIR/'columns_profile.csv'}")
    if args.emit_samples:
        logging.info(f"- samples: {OUTDIR/'samples'}/*_sample_{CSV_SAMPLE_SIZE}.csv")

    # CI NOTE:
    # To fail CI when issues are detected, uncomment the next two lines.
    # if issues:
    #     raise SystemExit(1)

if __name__ == "__main__":
    main()


INFO | Data root: c:\Users\James\Documents\GitHub\uk-road-safety-data-pipeline-dash\data\raw\dft_road_safety_last_5_years
INFO | Detected:
- c:\Users\James\Documents\GitHub\uk-road-safety-data-pipeline-dash\data\raw\dft_road_safety_last_5_years\Collisions.csv
- c:\Users\James\Documents\GitHub\uk-road-safety-data-pipeline-dash\data\raw\dft_road_safety_last_5_years\Vehicles.csv
- c:\Users\James\Documents\GitHub\uk-road-safety-data-pipeline-dash\data\raw\dft_road_safety_last_5_years\Casualties.csv
INFO | Collisions: coerced sentinels -> NA (+1528559 new NAs)
INFO | Vehicles: coerced sentinels -> NA (+2894042 new NAs)
INFO | Casualties: coerced sentinels -> NA (+925886 new NAs)
