# 3W Dataset (WELL-only) — Row-per-file Event Classification

Goal: Convert each time-series file into a single feature row, then train and evaluate multi-class classifiers for `event_type_code` (0–9), using splits grouped by well so no well leaks between train and test.


# 3W Dataset (WELL-only) — Row-per-file Classification (v3_1)

## Summary (Engineer Story)
- Built a file index from the 3W dataset and filtered to **WELL files only** (1119 files, 40 wells).
- Cleaned each file: parsed timestamp index, renamed sensors, converted to numeric types.
- Engineered **row-per-file features**:
  - Continuous sensors: raw median/IQR/last + robust z-stats (mean/std/min/max/last) + delta(first→last) + missing ratio.
  - State/valve signals: last state + transition count/rate + time-in-state proportions (+ “other” state).
- Evaluated with **Group splits by `well_id`** to avoid leakage across wells (harder but realistic).
- Compared baseline (Logistic Regression) vs boosting (HistGradientBoosting with class-weighted sample_weight).
- Result: **HGB outperformed the baseline**, reaching **macro-F1 ≈ 0.42 (repeated group shuffle)** and **fault-vs-normal F1 ≈ 0.76**. Scores vary noticeably from split to split because some classes are very rare (e.g., class 1 has only 4 files).


## 0. Setup & Configuration

- Paths, random seeds, output folders
- Sensor name mapping and label mapping
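Renaming does more than improve readability. Later cells treat any column whose name ends in `_state` as a discrete state sensor. The minimal sketch below uses a hypothetical three-entry subset of `VAR_RENAME` and a toy frame. It shows that unmapped columns keep their original names, and that the `_state` suffix is what separates state sensors from continuous ones.

```python
import pandas as pd

# Hypothetical subset of VAR_RENAME; the full mapping lives in the setup cell.
VAR_RENAME_SUBSET = {
    "P-PDG": "pdg_downhole_pressure_pa",
    "ESTADO-DHSV": "dhsv_state",
    "class": "class_code",
}
EVENT_NAMES_SUBSET = {0: "Normal Operation", 4: "Flow Instability"}

# Toy frame; "EXTRA" is a made-up column that is not in the mapping.
raw = pd.DataFrame({
    "P-PDG": [1.0e7, 1.1e7],
    "ESTADO-DHSV": [1, 1],
    "class": [0, 0],
    "EXTRA": [5, 6],
})
renamed = raw.rename(columns=VAR_RENAME_SUBSET)  # unmapped columns are left untouched

# Same suffix rule the feature extractor uses to split state vs. continuous sensors;
# the label column is excluded by hand here (the notebook drops LABEL_COLS instead).
state_cols = [c for c in renamed.columns if c.endswith("_state")]
cont_cols = [c for c in renamed.columns if c not in state_cols and c != "class_code"]

# Unknown event codes fall back to "Unknown", as in build_row_per_file_dataset
name_for_7 = EVENT_NAMES_SUBSET.get(7, "Unknown")
print(list(renamed.columns), state_cols, cont_cols, name_for_7)
```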

In [1]:
# ============================================================
# 3W Dataset (WELL-only) — Row-per-file Classification (v3_1)
# Fully runnable, leakage-safe, interview-ready (+ stability eval)
#
# ✅ Key upgrades included in v3_1
# 1) Holdout diagnostics for HGB:
#    - confusion matrix (counts) saved as CSV
#    - per-class precision/recall/F1/support table saved as CSV
# 2) Holdout wells are logged into results_summary.json (traceability)
# 3) Probability-imputation semantics fixed:
#    - if a state sensor has ALL probability features missing -> set p_other=1 and p_state_* = 0
#    - otherwise: fill missing probs with 0.0 and renormalize so probs sum to 1 (per sensor)
#    - plus missing indicators for all features (kept)
# ============================================================

import os, json, re
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin, clone
from sklearn.model_selection import GroupShuffleSplit, StratifiedGroupKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
    f1_score,
    classification_report,
    confusion_matrix,
)
from sklearn.inspection import permutation_importance


# -------------------------
# Config
# -------------------------
BASE = "/kaggle/input/3w-dataset/2.0.0"
RANDOM_STATE = 42
N_WELL_FILES = None  # None = all WELL files

FEATURES_VERSION = "v3_1_fixed_timewin_zlast_hybridImputer_probSemanticFix"
OUT_DIR = f"/kaggle/working/3w_prepared_{FEATURES_VERSION}"
os.makedirs(OUT_DIR, exist_ok=True)

CACHE_PATH = f"{OUT_DIR}/df_ml_well_{FEATURES_VERSION}.parquet"
print("OUT_DIR:", OUT_DIR)
print("CACHE_PATH:", CACHE_PATH)

# -------------------------
# Dataset constants / mappings
# -------------------------
VAR_RENAME = {
    "ABER-CKGL": "gl_choke_opening_pct",
    "ABER-CKP":  "prod_choke_opening_pct",
    "ESTADO-DHSV":   "dhsv_state",
    "ESTADO-M1":     "prod_master_valve_state",
    "ESTADO-M2":     "ann_master_valve_state",
    "ESTADO-PXO":    "pig_crossover_valve_state",
    "ESTADO-SDV-GL": "gl_shutdown_valve_state",
    "ESTADO-SDV-P":  "prod_shutdown_valve_state",
    "ESTADO-W1":     "prod_wing_valve_state",
    "ESTADO-W2":     "ann_wing_valve_state",
    "ESTADO-XO":     "crossover_valve_state",
    "P-ANULAR":     "annulus_pressure_pa",
    "P-JUS-BS":     "svc_pump_downstream_pressure_pa",
    "P-JUS-CKGL":   "gl_choke_downstream_pressure_pa",
    "P-JUS-CKP":    "prod_choke_downstream_pressure_pa",
    "P-MON-CKGL":   "gl_choke_upstream_pressure_pa",
    "P-MON-CKP":    "prod_choke_upstream_pressure_pa",
    "P-MON-SDV-P":  "prod_sdv_upstream_pressure_pa",
    "P-PDG":        "pdg_downhole_pressure_pa",
    "PT-P":         "xmas_tree_prod_line_pressure_pa",
    "P-TPT":        "tpt_pressure_pa",
    "QBS": "svc_pump_flow_m3s",
    "QGL": "gas_lift_flow_m3s",
    "T-JUS-CKP": "prod_choke_downstream_temp_c",
    "T-MON-CKP": "prod_choke_upstream_temp_c",
    "T-PDG":     "pdg_downhole_temp_c",
    "T-TPT":     "tpt_temp_c",
    "class": "class_code",
    "state": "state_code",
}

EVENT_TYPE_CODE_TO_NAME = {
    0:"Normal Operation", 1:"Abrupt Increase of BSW", 2:"Spurious Closure of DHSV",
    3:"Severe Slugging", 4:"Flow Instability", 5:"Rapid Productivity Loss",
    6:"Quick Restriction in PCK", 7:"Scaling in PCK",
    8:"Hydrate in Production Line", 9:"Hydrate in Service Line",
}

LABEL_COLS = {"class_code", "state_code", "class_label", "state_label"}

OUT_DIR: /kaggle/working/3w_prepared_v3_1_fixed_timewin_zlast_hybridImputer_probSemanticFix
CACHE_PATH: /kaggle/working/3w_prepared_v3_1_fixed_timewin_zlast_hybridImputer_probSemanticFix/df_ml_well_v3_1_fixed_timewin_zlast_hybridImputer_probSemanticFix.parquet


## 1. Build File Index (Metadata Table)

- Scan all parquet files
- Extract: event_type_code, source, well_id, run timestamp
- Filter to WELL-only
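All of the metadata comes from the path and the file name. The sketch below applies the same regular expressions as `build_file_index` to two hypothetical paths. The exact file names are assumptions chosen to match those patterns. A file without a `WELL-<digits>` token or a 14-digit timestamp gets `NaN`/`NaT`, which is why the WELL-only filter is followed by an assertion on `well_id`.

```python
import pandas as pd

# Hypothetical paths following the directory layout the regexes expect
paths = pd.Series([
    "/kaggle/input/3w-dataset/2.0.0/4/WELL-00001_20170201010207.parquet",
    "/kaggle/input/3w-dataset/2.0.0/4/SIMULATED_00010.parquet",
])
df = pd.DataFrame({"path": paths})

# Event code = the numeric folder right after the dataset version folder
codes = df["path"].str.extract(r"/2\.0\.0/(\d+)/", expand=False)
df["event_type_code"] = pd.to_numeric(codes, errors="coerce").astype(int)

df["file"] = df["path"].str.split("/").str[-1]
df["source"] = df["file"].str.extract(r"^(WELL|SIMULATED|DRAWN)", expand=False)
df["well_id"] = df["file"].str.extract(r"(WELL-\d+)", expand=False)  # NaN for non-WELL files
ts = df["file"].str.extract(r"_(\d{14})", expand=False)              # NaN if no 14-digit stamp
df["run_ts"] = pd.to_datetime(ts, format="%Y%m%d%H%M%S", errors="coerce")
print(df[["event_type_code", "source", "well_id", "run_ts"]])
```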

In [2]:
# ============================================================
# 1) Build file index
# ============================================================
def build_file_index(base: str) -> pd.DataFrame:
    paths = []
    for root, _, files in os.walk(base):
        for f in files:
            if f.endswith(".parquet"):
                paths.append(os.path.join(root, f))

    df = pd.DataFrame({"path": paths})
    codes = df["path"].str.extract(r"/2\.0\.0/(\d+)/", expand=False)
    df["event_type_code"] = pd.to_numeric(codes, errors="coerce").astype("Int64")
    df = df.dropna(subset=["event_type_code"])
    df["event_type_code"] = df["event_type_code"].astype(int)

    df["file"] = df["path"].str.split("/").str[-1]
    df["source"] = df["file"].str.extract(r"^(WELL|SIMULATED|DRAWN)", expand=False)
    df["well_id"] = df["file"].str.extract(r"(WELL-\d+)", expand=False)
    df["run_ts"] = df["file"].str.extract(r"_(\d{14})", expand=False)
    df["run_ts"] = pd.to_datetime(df["run_ts"], format="%Y%m%d%H%M%S", errors="coerce")

    return df.sort_values(["event_type_code","source","well_id","run_ts"]).reset_index(drop=True)

df_files = build_file_index(BASE)
df_w_files = df_files[df_files["source"] == "WELL"].copy()
assert df_w_files["well_id"].notna().all()

print("Total files:", len(df_files), "| WELL files:", len(df_w_files), "| Wells:", df_w_files["well_id"].nunique())
print("Class counts:\n", df_w_files["event_type_code"].value_counts().sort_index())

Total files: 2228 | WELL files: 1119 | Wells: 40
Class counts:
 event_type_code
0    594
1      4
2     22
3     32
4    343
5     11
6      6
7     36
8     14
9     57
Name: count, dtype: int64


## 2. Data Cleaning (Per-file)

- Parse timestamps and sort
- Rename sensors
- Coerce numeric dtypes
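The three cleaning steps can be checked on a toy frame. The sketch below is a condensed, hypothetical `clean_minimal`, not the notebook's `clean_3w_instance` itself. It assumes a `timestamp` column is present. Unparseable timestamps are dropped, rows are sorted by time, and non-numeric sensor strings become `NaN` instead of raising an error.

```python
import numpy as np
import pandas as pd

def clean_minimal(df: pd.DataFrame, rename: dict) -> pd.DataFrame:
    """Condensed illustration of clean_3w_instance (assumes a 'timestamp' column)."""
    df = df.copy()
    df["timestamp"] = pd.to_datetime(df["timestamp"], errors="coerce")  # bad stamps -> NaT
    df = df.set_index("timestamp")
    df = df[~df.index.isna()].sort_index()                              # drop NaT, sort by time
    df = df.rename(columns=rename)
    for c in df.columns:
        values = pd.to_numeric(df[c], errors="coerce")                  # bad strings -> NaN
        df[c] = values.astype("Int16") if c == "class_code" else values.astype("float64")
    return df

raw = pd.DataFrame({
    "timestamp": ["2017-02-01 00:00:02", "not a time", "2017-02-01 00:00:01"],
    "P-PDG": ["1.5e7", "1.6e7", "bad"],
    "class": [0, 0, None],
})
out = clean_minimal(raw, {"P-PDG": "pdg_downhole_pressure_pa", "class": "class_code"})
print(out)
```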


In [3]:
# ============================================================
# 2) Cleaning
# ============================================================
def clean_3w_instance(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    if "timestamp" in df.columns:
        df["timestamp"] = pd.to_datetime(df["timestamp"], errors="coerce")
        df = df.set_index("timestamp")
    else:
        df.index = pd.to_datetime(df.index, errors="coerce")

    df = df[~df.index.isna()].sort_index()
    df.index.name = "timestamp"
    df = df.rename(columns=VAR_RENAME)

    for c in df.columns:
        if c in ("class_code", "state_code"):
            df[c] = pd.to_numeric(df[c], errors="coerce").astype("Int16")
        else:
            df[c] = pd.to_numeric(df[c], errors="coerce").astype("float64")
    return df


## 3. Feature Engineering (Per-file → Single Row)

- Continuous sensors: raw median/IQR/last value + robust z-stats (incl. last valid z) + deltas between first/last time windows + missingness
- State sensors: last state, transition counts/rates, state proportions
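For continuous sensors, the robust z-score is $z = (x - \mathrm{median}) / \mathrm{IQR}$, computed per file. The delta compares the mean $z$ in the first and last `frac` share of the file's *time span*, not its row count. A small worked example with ten 1-second samples of $0, 1, \dots, 9$ gives median $= 4.5$ and IQR $= 6.75 - 2.25 = 4.5$. With `frac=0.1`, each window covers 0.9 s, so the delta is $1 - (-1) = 2$. The sketch below reproduces this outside the notebook's `summarize_timeseries_v3_1`.

```python
import numpy as np
import pandas as pd

# Ten samples, one second apart (epoch-based timestamps for illustration)
idx = pd.to_datetime(np.arange(10), unit="s")
x = pd.Series(np.arange(10, dtype=float), index=idx)

med = x.median()                               # 4.5
iqr = x.quantile(0.75) - x.quantile(0.25)      # 6.75 - 2.25 = 4.5 (linear interpolation)
z = (x - med) / iqr

frac = 0.1
w = (idx.max() - idx.min()) * frac             # window length = 10% of the duration (0.9 s)
first = z.loc[idx <= idx.min() + w].mean()     # only t=0 falls in the first window
last = z.loc[idx >= idx.max() - w].mean()      # only t=9 falls in the last window
delta = last - first
z_last = z.dropna().iloc[-1]                   # last VALID z, as in the notebook
print(med, iqr, first, last, delta, z_last)
```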

In [4]:


# ============================================================
# 3) Feature extraction (time-window first/last + last-valid z_last)
# ============================================================
def summarize_timeseries_v3_1(df_clean: pd.DataFrame, frac: float = 0.1, state_max: int = 3) -> dict:
    sensors = df_clean.drop(columns=list(LABEL_COLS), errors="ignore")
    num = sensors.select_dtypes(include=[np.number])

    out = {
        "n_obs": int(len(df_clean)),
        "duration_s": float((df_clean.index.max() - df_clean.index.min()).total_seconds())
                      if len(df_clean) else np.nan,
    }
    if num.shape[1] == 0 or len(num) == 0:
        return out

    state_cols = [c for c in num.columns if c.endswith("_state")]
    cont_cols  = [c for c in num.columns if c not in state_cols]

    def _first_last_masks(index, frac: float, n: int):
        k = max(1, int(n * frac))
        if n < 2:
            m = np.ones(n, dtype=bool)
            return m, m

        try:
            idx = pd.DatetimeIndex(index)
        except Exception:
            idx = None

        if idx is not None and idx.notna().all():
            tmin, tmax = idx.min(), idx.max()
            dur_s = (tmax - tmin).total_seconds()
            if np.isfinite(dur_s) and dur_s > 0:
                w = dur_s * frac
                t_first_end = tmin + pd.Timedelta(seconds=w)
                t_last_start = tmax - pd.Timedelta(seconds=w)

                first_mask = np.asarray(idx <= t_first_end, dtype=bool)
                last_mask  = np.asarray(idx >= t_last_start, dtype=bool)

                if first_mask.sum() == 0:
                    first_mask = np.zeros(n, dtype=bool); first_mask[:k] = True
                if last_mask.sum() == 0:
                    last_mask = np.zeros(n, dtype=bool); last_mask[-k:] = True
                return first_mask, last_mask

        first_mask = np.zeros(n, dtype=bool); first_mask[:k] = True
        last_mask  = np.zeros(n, dtype=bool); last_mask[-k:] = True
        return first_mask, last_mask

    # Continuous sensors
    if len(cont_cols):
        cont = num[cont_cols]
        med_raw = cont.median()
        iqr_raw = (cont.quantile(0.75) - cont.quantile(0.25))

        iqr = iqr_raw.replace(0, np.nan)
        z = (cont - med_raw) / iqr

        first_mask, last_mask = _first_last_masks(df_clean.index, frac=frac, n=len(z))
        first = z.iloc[first_mask].mean()
        last  = z.iloc[last_mask].mean()

        agg = z.agg(["mean", "std", "min", "max"]).T
        miss = cont.isna().mean()

        for col in cont_cols:
            out[f"{col}__raw_median"] = med_raw[col]
            out[f"{col}__raw_iqr"]    = iqr_raw[col]

            s_raw = cont[col].dropna()
            out[f"{col}__raw_last"] = s_raw.iloc[-1] if len(s_raw) else np.nan

            out[f"{col}__z_mean"] = agg.loc[col, "mean"]
            out[f"{col}__z_std"]  = agg.loc[col, "std"]
            out[f"{col}__z_min"]  = agg.loc[col, "min"]
            out[f"{col}__z_max"]  = agg.loc[col, "max"]

            # last VALID z (not last row)
            s_z = z[col].dropna()
            out[f"{col}__z_last"] = s_z.iloc[-1] if len(s_z) else np.nan

            out[f"{col}__delta_last_first"] = (last[col] - first[col])
            out[f"{col}__abs_delta"]        = abs(last[col] - first[col])
            out[f"{col}__missing_frac"]     = miss[col]

    # State-like sensors
    if len(state_cols):
        st = num[state_cols]
        for col in state_cols:
            s = st[col]
            s_non = s.dropna()

            out[f"{col}__missing_frac"] = float(s.isna().mean())
            out[f"{col}__last"] = float(s_non.iloc[-1]) if len(s_non) else np.nan

            if len(s_non) >= 2:
                n_trans = int((s_non != s_non.shift()).sum() - 1)
            else:
                n_trans = 0

            out[f"{col}__n_transitions"] = n_trans
            out[f"{col}__transitions_rate"] = n_trans / max(1, len(s_non))

            known = 0.0
            for v in range(state_max + 1):
                p = float((s_non == v).mean()) if len(s_non) else np.nan
                out[f"{col}__p_state_{v}"] = p
                if not np.isnan(p):
                    known += p
            out[f"{col}__p_state_other"] = (1.0 - known) if len(s_non) else np.nan

    return out
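The state-sensor block can be traced by hand on a short series. The extractor drops `NaN` before counting, so a gap inside a constant run does not create a spurious transition. Comparing each value with its predecessor marks the first element as "changed", which is why one is subtracted. The proportions cover states 0..`state_max`, and whatever is left over goes to `p_state_other`. A minimal sketch on a made-up series:

```python
import numpy as np
import pandas as pd

s = pd.Series([0, 0, 1, 1, np.nan, 1, 0, 2])   # toy valve-state series
s_non = s.dropna()                              # [0, 0, 1, 1, 1, 0, 2]

# Changes vs. previous value; the first element always counts, hence "- 1"
n_trans = int((s_non != s_non.shift()).sum() - 1) if len(s_non) >= 2 else 0  # 0->1, 1->0, 0->2
rate = n_trans / max(1, len(s_non))

state_max = 3
props = {v: float((s_non == v).mean()) for v in range(state_max + 1)}
p_other = 1.0 - sum(props.values())             # ~0 here: every value is in 0..3
print(n_trans, rate, props, p_other)
```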


## 4. Build Row-per-file Dataset (+ Caching)

- Loop over WELL files
- Read parquet → clean → summarize → append
- Save/load cached parquet to speed up reruns
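The caching pattern is "load if the versioned file exists, otherwise build and save". Because `FEATURES_VERSION` is part of the cache file name, changing the feature code (and bumping the version) forces a rebuild instead of silently reusing stale features. The sketch below uses a hypothetical `load_or_build` helper and a fake builder. It uses pickle rather than parquet so that it does not depend on a parquet engine being installed.

```python
import os
import tempfile
import pandas as pd

def load_or_build(cache_path, build_fn):
    """Hypothetical helper: return (df, cache_hit)."""
    if os.path.exists(cache_path):
        return pd.read_pickle(cache_path), True
    df = build_fn()
    df.to_pickle(cache_path)
    return df, False

calls = []
def fake_build():
    calls.append(1)  # count how often the expensive step runs
    return pd.DataFrame({"n_obs": [10, 20], "event_type_code": [0, 4]})

version = "v3_1_demo"  # hypothetical version tag
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, f"df_ml_well_{version}.pkl")
    df1, hit1 = load_or_build(path, fake_build)  # builds and saves
    df2, hit2 = load_or_build(path, fake_build)  # loads from cache
print(hit1, hit2, len(calls))
```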

In [5]:
# ============================================================
# 4) Build row-per-file dataset (with caching)
# ============================================================
def build_row_per_file_dataset(df_files: pd.DataFrame, n_files: int | None = None, random_state: int = 42) -> pd.DataFrame:
    if n_files is None:
        sample = df_files.reset_index(drop=True)
    else:
        sample = df_files.sample(n_files, random_state=random_state).reset_index(drop=True)

    rows = []
    for _, r in tqdm(sample.iterrows(), total=len(sample), desc="Building WELL dataset"):
        df_raw = pd.read_parquet(r["path"])
        df_clean = clean_3w_instance(df_raw)

        feats = summarize_timeseries_v3_1(df_clean)
        feats["event_type_code"] = int(r["event_type_code"])
        feats["event_type_name"] = EVENT_TYPE_CODE_TO_NAME.get(int(r["event_type_code"]), "Unknown")
        feats["well_id"] = r["well_id"]
        feats["run_ts"] = r["run_ts"]
        feats["file"] = r["file"]
        rows.append(feats)

    return pd.DataFrame(rows)

if os.path.exists(CACHE_PATH):
    df_ml_well = pd.read_parquet(CACHE_PATH)
    print("Loaded cached features:", df_ml_well.shape)
else:
    df_ml_well = build_row_per_file_dataset(df_w_files, n_files=N_WELL_FILES, random_state=RANDOM_STATE)
    df_ml_well.to_parquet(CACHE_PATH, index=False)
    print("Built + saved features:", df_ml_well.shape)

with open(f"{OUT_DIR}/dataset_config.json", "w") as f:
    json.dump({
        "base": BASE,
        "random_state": RANDOM_STATE,
        "n_well_files_used": int(len(df_ml_well)),
        "features_version": FEATURES_VERSION,
        "notes": "continuous: raw median/iqr/last + robust-z stats + deltas; "
                 "state: last + transitions + proportions (+other); WELL-only; "
                 "fixes: time-window first/last + z_last last-valid; "
                 "imputer: hybrid median/-1 + prob semantic fix + indicators; "
                 "cache versioned to avoid staleness"
    }, f, indent=2)


Loaded cached features: (1119, 286)


## 5. Prepare ML Matrices

- X: engineered features
- y: event_type_code
- groups: well_id (for leakage-safe splitting)
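Passing `well_id` as `groups` is what makes the split leakage-safe: every file from a given well lands entirely on one side. A quick sanity check on synthetic data, with 40 rows spread over 8 made-up well IDs, confirms that `GroupShuffleSplit` yields disjoint well sets. It also shows that `test_size` is applied to the number of groups, not rows: here $\lceil 0.25 \times 8 \rceil = 2$ wells are held out.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

rng = np.random.default_rng(0)
groups = pd.Series([f"WELL-{i % 8:05d}" for i in range(40)])  # 8 hypothetical wells
X = pd.DataFrame({"f": rng.normal(size=40)})
y = pd.Series(rng.integers(0, 3, size=40))

gss = GroupShuffleSplit(n_splits=1, test_size=0.25, random_state=42)
tr, te = next(gss.split(X, y, groups=groups))

train_wells = set(groups.iloc[tr])
test_wells = set(groups.iloc[te])
print(sorted(test_wells), train_wells.isdisjoint(test_wells))
```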

In [6]:
# ============================================================
# 5) Build X/y/groups (NO global filtering leakage)
# ============================================================
def make_Xy_groups(df: pd.DataFrame):
    y = df["event_type_code"].astype(int).copy()
    groups = df["well_id"].astype(str).copy()
    drop_cols = ["event_type_code","event_type_name","file","run_ts","well_id"]
    X = df.drop(columns=drop_cols, errors="ignore").copy()
    X = X.replace([np.inf, -np.inf], np.nan)
    return X, y, groups

X, y, groups = make_Xy_groups(df_ml_well)
print("X:", X.shape, "| y:", y.shape, "| wells:", groups.nunique())
print("Label counts:\n", y.value_counts().sort_index())


X: (1119, 281) | y: (1119,) | wells: 40
Label counts:
 event_type_code
0    594
1      4
2     22
3     32
4    343
5     11
6      6
7     36
8     14
9     57
Name: count, dtype: int64


## 6. Leakage-safe Preprocessing Inside the Pipeline

- Train-only column filtering
- Hybrid imputation with probability semantics fix
- Missing indicators for all features
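The probability rule applies per state sensor. If every probability for a file is missing, the file is treated as "all mass on other". Otherwise missing entries become 0 and the row is renormalized to sum to 1. The sketch below isolates that rule in a hypothetical `fix_prob_block` function on a three-row toy block. It is a simplified version of what `HybridImputer.transform` does, and it assumes the `p_state_other` column is present.

```python
import numpy as np
import pandas as pd

def fix_prob_block(block: pd.DataFrame, other_col: str) -> pd.DataFrame:
    """Simplified sketch of the per-sensor probability semantics."""
    all_nan = block.isna().all(axis=1)
    out = block.fillna(0.0)
    out.loc[all_nan, :] = 0.0
    out.loc[all_nan, other_col] = 1.0          # no information -> "other" state
    s = out.sum(axis=1)
    s = s.mask(s == 0, 1.0)                     # avoid 0/0; zero rows stay zero
    return out.div(s, axis=0)                   # each row now sums to 1

block = pd.DataFrame({
    "dhsv_state__p_state_other": [np.nan, 0.0, np.nan],
    "dhsv_state__p_state_0":     [np.nan, 0.3, 0.2],
    "dhsv_state__p_state_1":     [np.nan, 0.3, 0.6],
})
fixed = fix_prob_block(block, "dhsv_state__p_state_other")
print(fixed)
```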

In [7]:
# ============================================================
# 6) Train-only column filter (inside pipeline)
# ============================================================
class TrainOnlyColumnFilter(BaseEstimator, TransformerMixin):
    def __init__(self, missing_threshold: float = 0.98):
        self.missing_threshold = missing_threshold

    def fit(self, X, y=None):
        X = X.copy().replace([np.inf, -np.inf], np.nan)
        all_missing = X.columns[X.isna().all()]
        miss = X.isna().mean()
        high_missing = miss[miss > self.missing_threshold].index
        const_cols = [c for c in X.columns if X[c].nunique(dropna=True) <= 1]
        drop = set(all_missing) | set(high_missing) | set(const_cols)
        self.keep_columns_ = [c for c in X.columns if c not in drop]
        return self

    def transform(self, X):
        X = X.copy().replace([np.inf, -np.inf], np.nan)
        for c in getattr(self, "keep_columns_", []):
            if c not in X.columns:
                X[c] = np.nan
        return X[self.keep_columns_]

# ============================================================
# 7) Hybrid imputer (probability semantics fixed)
#    - median for continuous
#    - -1 for state-ish "level/last/transition" features
#    - probability block per state sensor:
#         * if ALL probs missing -> p_other=1, p_state_*=0
#         * else fill NaN with 0 and renormalize to sum to 1
#    - + missing indicators for all columns
# ============================================================
class HybridImputer(BaseEstimator, TransformerMixin):
    def __init__(self, state_token: str = "_state", fill_value_state: float = -1.0):
        self.state_token = state_token
        self.fill_value_state = fill_value_state

    @staticmethod
    def _parse_prob_block(col: str):
        """
        Expect columns like: <sensor>_state__p_state_<k> or <sensor>_state__p_state_other
        Returns (base_sensor, is_other, k or None), or None if the name does not match.
        """
        # examples:
        # dhsv_state__p_state_0
        # dhsv_state__p_state_other
        m_k = re.match(r"^(.*_state)__p_state_(\d+)$", col)
        if m_k:
            return m_k.group(1), False, int(m_k.group(2))
        m_o = re.match(r"^(.*_state)__p_state_other$", col)
        if m_o:
            return m_o.group(1), True, None
        return None

    def fit(self, X, y=None):
        X = X.copy()
        self.cols_ = list(X.columns)

        # any feature derived from a state sensor has "_state" in name
        state_like = [c for c in self.cols_ if self.state_token in c]

        # probability columns + group them by state sensor base
        prob_cols = []
        prob_groups = {}  # base -> {"other": col or None, "states": {k: col}}
        for c in self.cols_:
            parsed = self._parse_prob_block(c)
            if parsed is None:
                continue
            base, is_other, k = parsed
            prob_cols.append(c)
            prob_groups.setdefault(base, {"other": None, "states": {}})
            if is_other:
                prob_groups[base]["other"] = c
            else:
                prob_groups[base]["states"][k] = c

        self.prob_cols_ = prob_cols
        self.prob_groups_ = prob_groups

        # remaining state-like (non-prob) -> sentinel -1
        self.state_cols_ = [c for c in state_like if c not in self.prob_cols_]

        # continuous-ish are not state-derived -> median
        self.cont_cols_ = [c for c in self.cols_ if c not in state_like]
        self.cont_medians_ = (
            X[self.cont_cols_].median(numeric_only=True)
            if len(self.cont_cols_) else pd.Series(dtype=float)
        )

        return self

    def transform(self, X):
        X = X.copy()
        # Ensure all fit columns exist
        for c in self.cols_:
            if c not in X.columns:
                X[c] = np.nan
        X = X[self.cols_]

        # Missing indicators
        miss = X.isna().astype(np.int8)
        miss.columns = [f"{c}__isna" for c in miss.columns]

        # Continuous: median then fallback to 0.0 if median was NaN
        if len(self.cont_cols_):
            X[self.cont_cols_] = X[self.cont_cols_].fillna(self.cont_medians_)
            X[self.cont_cols_] = X[self.cont_cols_].fillna(0.0)

        # Probabilities: semantic fix per sensor group
        # - if all probs are NaN -> set p_other=1, p_state_k=0
        # - else fill NaN with 0 and renormalize
        if len(self.prob_cols_):
            # work group-by-group
            for base, g in self.prob_groups_.items():
                other_col = g.get("other", None)
                state_map = g.get("states", {})
                cols = []
                if other_col is not None:
                    cols.append(other_col)
                cols.extend([state_map[k] for k in sorted(state_map.keys())])

                if not cols:
                    continue

                block = X[cols]
                all_nan = block.isna().all(axis=1)

                # initialize by filling NaNs with 0
                block_filled = block.fillna(0.0)

                # rows where all probs were missing: force p_other=1 and others 0
                if other_col is not None:
                    block_filled.loc[all_nan, other_col] = 1.0
                # ensure all p_state_k are 0 for all_nan rows
                for c in cols:
                    if c != other_col:
                        block_filled.loc[all_nan, c] = 0.0

                # renormalize rows where NOT all_nan to sum to 1
                not_all = ~all_nan
                if not_all.any():
                    s = block_filled.loc[not_all, cols].sum(axis=1)
                    # the sum can be 0 when every observed prob in the kept columns is 0
                    # (e.g. the sensor's observed states fall in columns dropped by the filter);
                    # fall back to "other=1" when available, else leave zeros.
                    zero_sum = s == 0
                    if zero_sum.any() and other_col is not None:
                        block_filled.loc[not_all & zero_sum, other_col] = 1.0
                        for c in cols:
                            if c != other_col:
                                block_filled.loc[not_all & zero_sum, c] = 0.0
                        s = block_filled.loc[not_all, cols].sum(axis=1)

                    # divide by sum safely: rows that still sum to 0 stay all-zero
                    # instead of becoming 0/0 = NaN (which LogisticRegression rejects)
                    s = s.mask(s == 0, 1.0)
                    block_filled.loc[not_all, cols] = block_filled.loc[not_all, cols].div(s, axis=0)

                X[cols] = block_filled

        # State-level / transition features: fill missing with -1 sentinel
        if len(self.state_cols_):
            X[self.state_cols_] = X[self.state_cols_].fillna(self.fill_value_state)

        return pd.concat([X, miss], axis=1)


## 7. Models

- Baseline: Logistic Regression
- Main: Weighted HistGradientBoosting (class-balanced sample weights)
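With `class_weight="balanced"`, each class gets weight $w_c = n / (k \cdot n_c)$, where $n$ is the number of samples, $k$ the number of classes and $n_c$ the count of class $c$. As a result, every class contributes the same total weight $n/k$ to the loss. The toy labels below are made up, but the weight computation mirrors `WeightedHGBClassifier.fit`.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array([0, 0, 0, 0, 0, 0, 4, 4, 9])       # 6 / 2 / 1 samples per class
classes = np.unique(y)
cw = compute_class_weight(class_weight="balanced", classes=classes, y=y)

# Same formula by hand: n / (k * n_c)
manual = len(y) / (len(classes) * np.bincount(np.searchsorted(classes, y)))

cw_map = dict(zip(classes, cw))
sample_weight = np.array([cw_map[v] for v in y], dtype=float)
totals = {int(c): float(sample_weight[y == c].sum()) for c in classes}
print(cw, totals)  # each class sums to n / k = 3.0
```

Recent scikit-learn releases (1.2+) also accept `class_weight="balanced"` directly on `HistGradientBoostingClassifier`. That could replace the wrapper, assuming the installed version supports it.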

In [8]:
# ============================================================
# 8) Weighted HGB estimator (cloneable) + predict_proba
# ============================================================
class WeightedHGBClassifier(BaseEstimator, ClassifierMixin):
    def __init__(
        self,
        max_depth=6,
        learning_rate=0.06,
        max_iter=700,
        max_leaf_nodes=31,
        min_samples_leaf=20,
        l2_regularization=0.1,
        random_state=42,
    ):
        self.max_depth = max_depth
        self.learning_rate = learning_rate
        self.max_iter = max_iter
        self.max_leaf_nodes = max_leaf_nodes
        self.min_samples_leaf = min_samples_leaf
        self.l2_regularization = l2_regularization
        self.random_state = random_state

    def fit(self, X, y):
        self.clf_ = HistGradientBoostingClassifier(
            max_depth=self.max_depth,
            learning_rate=self.learning_rate,
            max_iter=self.max_iter,
            max_leaf_nodes=self.max_leaf_nodes,
            min_samples_leaf=self.min_samples_leaf,
            l2_regularization=self.l2_regularization,
            random_state=self.random_state,
        )

        y_arr = np.asarray(y)
        classes = np.unique(y_arr)
        cw = compute_class_weight(class_weight="balanced", classes=classes, y=y_arr)
        cw_map = dict(zip(classes, cw))
        sample_weight = np.array([cw_map[v] for v in y_arr], dtype=float)

        self.clf_.fit(X, y_arr, sample_weight=sample_weight)
        self.classes_ = classes
        return self

    def predict(self, X):
        return self.clf_.predict(X)

    def predict_proba(self, X):
        return self.clf_.predict_proba(X)


## 8. Evaluation

- Final holdout split by well_id (GroupShuffleSplit)
- CV on train wells only (StratifiedGroupKFold)
- Stability: repeated group holdout
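All three helpers pass an explicit `labels=` list to `f1_score`. The CV and repeated-holdout helpers take it from all labels in their input, and the final holdout takes it from train ∪ holdout. A class that is absent from a test fold and never predicted therefore contributes F1 = 0 to the macro average under `zero_division=0`. This makes macro-F1 conservative whenever a rare class misses a fold, and it accounts for part of the split-to-split variance. A two-line illustration on made-up predictions:

```python
import numpy as np
from sklearn.metrics import f1_score

y_true = np.array([0, 0, 4, 4])
y_pred = np.array([0, 0, 4, 4])   # perfect on the classes that are present

f1_present = f1_score(y_true, y_pred, average="macro", labels=[0, 4], zero_division=0)
f1_all = f1_score(y_true, y_pred, average="macro", labels=[0, 4, 9], zero_division=0)
print(f1_present, f1_all)  # class 9 is absent from both, so it adds a 0 to the mean
```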


In [9]:
# ============================================================
# 9) Evaluation helpers
# ============================================================
def choose_stratified_group_n_splits(X, y, groups, max_splits=4, random_state=42) -> int:
    """Pick the largest n_splits (<= max_splits) that doesn't raise ValueError."""
    for k in range(max_splits, 1, -1):
        try:
            cv = StratifiedGroupKFold(n_splits=k, shuffle=True, random_state=random_state)
            _ = next(cv.split(X, y, groups=groups))
            return k
        except ValueError:
            continue
    return 2

def evaluate_stratified_group_cv(model, X, y, groups, n_splits=4, random_state=42, major_min_support=10):
    labels = np.sort(pd.unique(y))
    cv = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    macro_f1s, major_f1s, bin_f1s = [], [], []
    per_class_f1 = {int(lbl): [] for lbl in labels}

    for tr, te in cv.split(X, y, groups=groups):
        Xtr, Xte = X.iloc[tr], X.iloc[te]
        ytr, yte = y.iloc[tr], y.iloc[te]

        m = clone(model)
        m.fit(Xtr, ytr)
        pred = m.predict(Xte)

        macro_f1s.append(f1_score(yte, pred, average="macro", labels=labels, zero_division=0))

        major_classes = set(ytr.value_counts()[lambda s: s >= major_min_support].index.astype(int))
        if major_classes:
            major_labels = np.array(sorted(major_classes))
            major_mask = yte.isin(major_labels)
            if major_mask.any():
                major_f1s.append(
                    f1_score(
                        yte[major_mask], pred[major_mask],
                        average="macro", labels=major_labels, zero_division=0
                    )
                )
            else:
                major_f1s.append(np.nan)
        else:
            major_f1s.append(np.nan)

        bin_f1s.append(
            f1_score((yte != 0).astype(int), (pred != 0).astype(int), zero_division=0)
        )

        rep = classification_report(yte, pred, labels=labels, output_dict=True, zero_division=0)
        for lbl in labels:
            per_class_f1[int(lbl)].append(rep[str(int(lbl))]["f1-score"])

    return {
        "macro_f1_mean": float(np.nanmean(macro_f1s)),
        "macro_f1_std":  float(np.nanstd(macro_f1s)),
        "major_macro_f1_mean": float(np.nanmean(major_f1s)),
        "fault_vs_normal_f1_mean": float(np.nanmean(bin_f1s)),
        "per_class_f1_mean": {str(k): float(np.nanmean(v)) for k, v in per_class_f1.items()},
    }

def evaluate_on_holdout(model, Xtr, ytr, Xho, yho, major_min_support=10):
    labels = np.sort(pd.unique(pd.concat([ytr, yho], axis=0)))

    m = clone(model)
    m.fit(Xtr, ytr)
    pred = m.predict(Xho)

    macro = f1_score(yho, pred, average="macro", labels=labels, zero_division=0)

    major_classes = set(ytr.value_counts()[lambda s: s >= major_min_support].index.astype(int))
    major_labels = np.array(sorted(major_classes)) if major_classes else np.array([], dtype=int)
    if len(major_labels) and yho.isin(major_labels).any():
        major = f1_score(
            yho[yho.isin(major_labels)], pred[yho.isin(major_labels)],
            average="macro", labels=major_labels, zero_division=0
        )
    else:
        major = np.nan

    bin_f1 = f1_score((yho != 0).astype(int), (pred != 0).astype(int), zero_division=0)
    report_text = classification_report(yho, pred, labels=labels, zero_division=0)
    report_dict = classification_report(yho, pred, labels=labels, zero_division=0, output_dict=True)

    return {
        "macro_f1": float(macro),
        "major_macro_f1": float(major) if np.isfinite(major) else np.nan,
        "fault_vs_normal_f1": float(bin_f1),
        "classification_report": report_text,
        "classification_report_dict": report_dict,  # for tables
        "labels": labels,
        "y_pred": pred,
    }

def repeated_holdout(model, X, y, groups, repeats=30, test_size=0.2, random_state=42, major_min_support=10):
    labels = np.sort(pd.unique(y))
    gss = GroupShuffleSplit(n_splits=repeats, test_size=test_size, random_state=random_state)

    macs, bins, majors = [], [], []
    for tr, te in gss.split(X, y, groups=groups):
        Xtr, Xte = X.iloc[tr], X.iloc[te]
        ytr, yte = y.iloc[tr], y.iloc[te]

        m = clone(model)
        m.fit(Xtr, ytr)
        pred = m.predict(Xte)

        macs.append(f1_score(yte, pred, average="macro", labels=labels, zero_division=0))
        bins.append(f1_score((yte != 0).astype(int), (pred != 0).astype(int), zero_division=0))

        major_classes = set(ytr.value_counts()[lambda s: s >= major_min_support].index.astype(int))
        major_labels = np.array(sorted(major_classes)) if major_classes else np.array([], dtype=int)
        if len(major_labels) and yte.isin(major_labels).any():
            majors.append(
                f1_score(
                    yte[yte.isin(major_labels)], pred[yte.isin(major_labels)],
                    average="macro", labels=major_labels, zero_division=0
                )
            )
        else:
            majors.append(np.nan)

    return {
        "macro_f1_mean": float(np.nanmean(macs)),
        "macro_f1_std": float(np.nanstd(macs)),
        "major_macro_f1_mean": float(np.nanmean(majors)),
        "fault_vs_normal_f1_mean": float(np.nanmean(bins)),
        "fault_vs_normal_f1_std": float(np.nanstd(bins)),
    }

## 9. Diagnostics & Explainability

- Confusion matrix (CSV)
- Top confusions table (CSV)
- Per-class metrics (CSV)
- Permutation importance on holdout (CSV)
- Per-well performance (CSV)
- Save a single results_summary.json with full traceability
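A "top confusions" table lists the off-diagonal cells of the confusion matrix, ranked by count. The sketch below is a vectorized alternative, not the notebook's loop, applied to made-up labels. Normalizing each count by its true-class support (`frac_of_true`) is an assumption added here for readability.

```python
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

labels = [0, 4, 9]
y_true = [0, 0, 0, 4, 4, 4, 9, 9]
y_pred = [0, 4, 0, 4, 0, 0, 9, 4]

cm = confusion_matrix(y_true, y_pred, labels=labels)
off = cm.copy()
np.fill_diagonal(off, 0)                 # keep only the errors
ii, jj = np.nonzero(off)                 # (true, pred) index pairs with count > 0
support = cm.sum(axis=1)                 # true-class support per row

lab = np.array(labels)
top = pd.DataFrame({
    "true_class": lab[ii],
    "pred_class": lab[jj],
    "count": off[ii, jj],
    "frac_of_true": off[ii, jj] / support[ii],
}).sort_values("count", ascending=False, kind="stable").reset_index(drop=True)
print(top)
```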

In [None]:
# ============================================================
# 10) Models (consistent preprocessing)
# ============================================================
preprocess_common = [
    ("col_filter", TrainOnlyColumnFilter(missing_threshold=0.98)),
    ("imputer", HybridImputer(state_token="_state", fill_value_state=-1.0)),
]

logreg = Pipeline(preprocess_common + [
    ("scaler", StandardScaler(with_mean=False)),
    ("clf", LogisticRegression(max_iter=8000, class_weight="balanced"))
])

hgb = Pipeline(preprocess_common + [
    ("clf", WeightedHGBClassifier(
        max_depth=6,
        learning_rate=0.06,
        max_iter=700,
        max_leaf_nodes=31,
        min_samples_leaf=20,
        l2_regularization=0.1,
        random_state=RANDOM_STATE
    ))
])

# ============================================================
# 11) Final holdout by WELL + CV on train only + repeated holdout
#     + NEW: Save holdout diagnostics (HGB confusion matrix + per-class table)
#     + NEW: Log holdout well IDs in JSON
# ============================================================
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=RANDOM_STATE)
tr_idx, ho_idx = next(gss.split(X, y, groups=groups))

X_tr, y_tr, g_tr = X.iloc[tr_idx], y.iloc[tr_idx], groups.iloc[tr_idx]
X_ho, y_ho, g_ho = X.iloc[ho_idx], y.iloc[ho_idx], groups.iloc[ho_idx]

holdout_wells = sorted(g_ho.unique().tolist())
train_wells = sorted(g_tr.unique().tolist())

print("\nTrain wells:", len(train_wells), "| Holdout wells:", len(holdout_wells))
print("Holdout wells:", holdout_wells)
print("Holdout label counts:\n", y_ho.value_counts().sort_index())

n_splits = choose_stratified_group_n_splits(X_tr, y_tr, g_tr, max_splits=4, random_state=RANDOM_STATE)
print(f"\nUsing StratifiedGroupKFold with n_splits={n_splits}")

print("\n=== CV on TRAIN ONLY ===")
logreg_cv = evaluate_stratified_group_cv(logreg, X_tr, y_tr, g_tr, n_splits=n_splits, random_state=RANDOM_STATE)
hgb_cv    = evaluate_stratified_group_cv(hgb,   X_tr, y_tr, g_tr, n_splits=n_splits, random_state=RANDOM_STATE)
print("LogReg CV:", logreg_cv)
print("HGB   CV:", hgb_cv)

print("\n=== FINAL HOLDOUT (train wells -> holdout wells) ===")
logreg_hold = evaluate_on_holdout(logreg, X_tr, y_tr, X_ho, y_ho)
hgb_hold    = evaluate_on_holdout(hgb,   X_tr, y_tr, X_ho, y_ho)

print("LogReg Holdout:", {k: v for k, v in logreg_hold.items() if k not in {"classification_report","classification_report_dict","labels","y_pred"}})
print("HGB   Holdout:", {k: v for k, v in hgb_hold.items() if k not in {"classification_report","classification_report_dict","labels","y_pred"}})

print("\nLogReg classification report (holdout):\n", logreg_hold["classification_report"])
print("\nHGB classification report (holdout):\n", hgb_hold["classification_report"])

# ---- NEW: Holdout diagnostics for HGB (confusion matrix + per-class table)
labels = hgb_hold["labels"]
pred_hgb = hgb_hold["y_pred"]

cm = confusion_matrix(y_ho, pred_hgb, labels=labels)
cm_df = pd.DataFrame(cm, index=[f"true_{l}" for l in labels], columns=[f"pred_{l}" for l in labels])
cm_path = f"{OUT_DIR}/hgb_holdout_confusion_matrix.csv"
cm_df.to_csv(cm_path, index=True)
print(f"\nSaved HGB holdout confusion matrix: {cm_path}")


# ---- (A) Error analysis: Top confusions (off-diagonal)
cm_counts = cm.copy()
np.fill_diagonal(cm_counts, 0)

row_support = cm.sum(axis=1)  # support per true class
pairs = []
for i, t in enumerate(labels):
    for j, p in enumerate(labels):
        if i == j:
            continue
        c = int(cm_counts[i, j])
        if c > 0:
            pairs.append({
                "true_class": int(t),
                "true_name": EVENT_TYPE_CODE_TO_NAME.get(int(t), "Unknown"),
                "pred_class": int(p),
                "pred_name": EVENT_TYPE_CODE_TO_NAME.get(int(p), "Unknown"),
                "count": c,
                "pct_of_true_class": c / max(1, int(row_support[i]))
            })

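# Explicit columns keep the frame sortable even when there are no off-diagonal errors
top_conf_cols = ["true_class", "true_name", "pred_class", "pred_name", "count", "pct_of_true_class"]
top_conf_df = (pd.DataFrame(pairs, columns=top_conf_cols)
               .sort_values(["count", "pct_of_true_class"], ascending=False)
               .head(10))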

top_conf_path = f"{OUT_DIR}/hgb_holdout_top_confusions.csv"
top_conf_df.to_csv(top_conf_path, index=False)
print(f"Saved top confusions table: {top_conf_path}")

print("\nTop-3 confusions (HGB holdout):")
print(top_conf_df.head(3).to_string(index=False))

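# ---- (B) Permutation importance on HOLDOUT (fixed-label macro-F1)
from sklearn.metrics import f1_score, make_scorer

hgb_final = clone(hgb).fit(X_tr, y_tr)

# Use the same fixed label set and zero_division=0 as the holdout metrics, so predicting
# a class absent from the holdout does not change the set of classes being averaged
f1_macro_fixed = make_scorer(f1_score, average="macro", labels=labels, zero_division=0)

perm = permutation_importance(
    hgb_final,
    X_ho, y_ho,
    n_repeats=8,
    random_state=RANDOM_STATE,
    scoring=f1_macro_fixed,
    n_jobs=1
)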

perm_df = (pd.DataFrame({
    "feature": X_ho.columns,
    "importance_mean": perm.importances_mean,
    "importance_std": perm.importances_std
}).sort_values("importance_mean", ascending=False))

perm_path = f"{OUT_DIR}/hgb_holdout_permutation_importance_f1macro.csv"
perm_df.to_csv(perm_path, index=False)
print(f"\nSaved permutation importance: {perm_path}")

print("\nTop-15 important features (perm importance, macro-F1):")
print(perm_df.head(15).to_string(index=False))

# ---- (C) Per-well performance on HOLDOUT
pred_hgb_final = hgb_final.predict(X_ho)

df_well_perf = pd.DataFrame({
    "well_id": g_ho.values,
    "y_true": y_ho.values,
    "y_pred": pred_hgb_final
})

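well_rows = []
for well, gdf in df_well_perf.groupby("well_id"):
    yt = gdf["y_true"].astype(int)
    yp = gdf["y_pred"].astype(int)
    # A single well usually contains only a few classes; averaging over the full holdout
    # label set would add F1=0 for every absent class and understate the well's score
    well_labels = sorted(set(yt) | set(yp))
    well_rows.append({
        "well_id": well,
        "n_samples": int(len(gdf)),
        "macro_f1": float(f1_score(yt, yp, average="macro", labels=well_labels, zero_division=0)),
        "fault_vs_normal_f1": float(f1_score((yt != 0).astype(int), (yp != 0).astype(int), zero_division=0)),
    })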

well_metrics_df = (pd.DataFrame(well_rows)
                   .sort_values(["macro_f1", "n_samples"], ascending=[True, False]))

well_metrics_path = f"{OUT_DIR}/hgb_holdout_per_well_metrics.csv"
well_metrics_df.to_csv(well_metrics_path, index=False)
print(f"\nSaved per-well metrics: {well_metrics_path}")

print("\nWorst 5 wells by macro-F1 (HGB holdout):")
print(well_metrics_df.head(5).to_string(index=False))

print("\nBest 5 wells by macro-F1 (HGB holdout):")
print(well_metrics_df.tail(5).to_string(index=False))


rep_dict = hgb_hold["classification_report_dict"]
rows = []
for l in labels:
    key = str(int(l))
    if key in rep_dict:
        rows.append({
            "class": int(l),
            "name": EVENT_TYPE_CODE_TO_NAME.get(int(l), "Unknown"),
            "precision": rep_dict[key]["precision"],
            "recall": rep_dict[key]["recall"],
            "f1": rep_dict[key]["f1-score"],
            "support": rep_dict[key]["support"],
        })
per_class_df = pd.DataFrame(rows).sort_values("class")
per_class_path = f"{OUT_DIR}/hgb_holdout_per_class_metrics.csv"
per_class_df.to_csv(per_class_path, index=False)
print(f"Saved HGB holdout per-class metrics table: {per_class_path}")

# Optional print: quick recall view
print("\nHGB holdout per-class recall (quick view):")
print(per_class_df[["class","name","recall","support"]].to_string(index=False))

print("\n=== Repeated Group Holdout (30 splits) ===")
logreg_rep = repeated_holdout(logreg, X, y, groups, repeats=30, test_size=0.2, random_state=RANDOM_STATE)
hgb_rep    = repeated_holdout(hgb,   X, y, groups, repeats=30, test_size=0.2, random_state=RANDOM_STATE)
print("LogReg:", logreg_rep)
print("HGB   :", hgb_rep)

# ============================================================
# Save results (JSON) + NEW: holdout wells + artifact paths
# ============================================================
results = {
    "features_version": FEATURES_VERSION,
    "random_state": RANDOM_STATE,

    "split_traceability": {
        "train_wells": train_wells,
        "holdout_wells": holdout_wells,
        "n_train_wells": len(train_wells),
        "n_holdout_wells": len(holdout_wells),
    },

    "logreg_cv": logreg_cv,
    "hgb_cv": hgb_cv,

    "logreg_holdout": {
        k: v for k, v in logreg_hold.items()
        if k not in {"classification_report","classification_report_dict","labels","y_pred"}
    },
    "hgb_holdout": {
        k: v for k, v in hgb_hold.items()
        if k not in {"classification_report","classification_report_dict","labels","y_pred"}
    },

    "holdout_diagnostics": {
        "hgb_confusion_matrix_csv": cm_path,
        "hgb_per_class_metrics_csv": per_class_path,
        "hgb_top_confusions_csv": top_conf_path,
        "hgb_permutation_importance_csv": perm_path,
        "hgb_per_well_metrics_csv": well_metrics_path,
    },

    "logreg_repeated_holdout": logreg_rep,
    "hgb_repeated_holdout": hgb_rep,

    "notes": {
        "leakage_control": [
            "train-only column filtering inside pipeline",
            "group splitting by well_id",
            "CV on train wells only (plus a separate holdout)",
            "major classes computed from y_train per fold",
            "fixed-label macro-F1 with zero_division=0",
            "repeated group holdout metrics (mean/std)",
            "hybrid imputation: median for continuous; prob semantic fix + renorm; state-level->-1; + missing indicators",
            "cache path versioned to avoid stale features",
        ]
    }
}

with open(f"{OUT_DIR}/results_summary.json", "w") as f:
    json.dump(results, f, indent=2)

print(f"\nSaved: {OUT_DIR}/results_summary.json")



Train wells: 32 | Holdout wells: 8
Holdout wells: ['WELL-00005', 'WELL-00013', 'WELL-00016', 'WELL-00019', 'WELL-00022', 'WELL-00029', 'WELL-00030', 'WELL-00040']
Holdout label counts:
 event_type_code
0    82
2     1
4    38
5     4
7     8
8     6
9     4
Name: count, dtype: int64

Using StratifiedGroupKFold with n_splits=4

=== CV on TRAIN ONLY ===
LogReg CV: {'macro_f1_mean': 0.23411257009793876, 'macro_f1_std': 0.11036159661262168, 'major_macro_f1_mean': 0.3466105934245418, 'fault_vs_normal_f1_mean': 0.6515985012280396, 'per_class_f1_mean': {'0': 0.022547945205479453, '1': 0.0, '2': 0.5570821185617104, '3': 0.0, '4': 0.35916489738145785, '5': 0.0, '6': 0.0, '7': 0.4297297297297298, '8': 0.25, '9': 0.7226010101010102}}
HGB   CV: {'macro_f1_mean': 0.414913949489049, 'macro_f1_std': 0.04739455413528661, 'major_macro_f1_mean': 0.6226991965158268, 'fault_vs_normal_f1_mean': 0.9249655129770701, 'per_class_f1_mean': {'0': 0.6873414934167565, '1': 0.08333333333333333, '2': 0.907407407407

## Results (Current)
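- Logistic Regression (repeated group shuffle, 30 splits): macro-F1 ≈ 0.22, fault-vs-normal F1 ≈ 0.61
- HistGradientBoosting (repeated group shuffle, 30 splits): macro-F1 ≈ 0.42, fault-vs-normal F1 ≈ 0.76
- Note: Macro-F1 varies strongly across splits because some classes have very few samples (e.g., class 1 has only 4 files). Every class carries equal weight in the average, so a single misclassified file of a rare class can move macro-F1 substantially.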
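To see why a rare class makes macro-F1 unstable, consider a toy holdout in which one file of a singleton class flips between correct and wrong. The data and the helper name `macro_f1` are hypothetical. The helper mirrors the fixed-label, `zero_division=0` setting described in the leakage-control notes.

```python
from sklearn.metrics import f1_score


def macro_f1(y_true, y_pred, labels):
    # Fixed-label macro-F1: every listed class counts equally; undefined F1 becomes 0
    return f1_score(y_true, y_pred, average="macro", labels=labels, zero_division=0)


# Toy holdout (not 3W results): 8 normal files, 1 file of rare class 2, 3 files of class 4
labels = [0, 2, 4]
y_true = [0] * 8 + [2] + [4] * 3
y_pred_hit = list(y_true)                  # every file correct
y_pred_miss = [0] * 8 + [0] + [4] * 3      # only the single rare file is called normal

print("hit ", round(macro_f1(y_true, y_pred_hit, labels), 3))   # → 1.0
print("miss", round(macro_f1(y_true, y_pred_miss, labels), 3))  # → 0.647 (= 11/17)
```

One wrong file out of 12 (about 8%) lowers macro-F1 by roughly 0.35. The rare class's F1 drops from 1 to 0 and makes up a third of the average, while class 0 falls only to 16/17.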
