# Apex Weld Quality – Phase 2 Dashboard
## Defect Detection (Binary Classification) with Confidence

This notebook is the **evaluation dashboard** for Phase 2. It covers:

1. **Dataset overview** – counts, durations, missing/corrupt stats
2. **Label distribution** – defect vs non-defect, defect-type counts
3. **Representative examples** – sensor signals, inspection photo, video frame preview, audio waveform/spectrogram
4. **Data quality indicators** – per-run statistics, outliers, noise
5. **Model loading, calibration, and predictions** – checkpoint, temperature, threshold, val and test sets
6. **Threshold analysis** – sweep plot, operating point
7. **Core binary metrics** – accuracy, precision, recall, F1, AUC-ROC, AUC-PR, Brier score
8. **Confusion matrix** – test set
9. **ROC and precision-recall curves** – test set
10. **Calibration curve** – reliability diagram
11. **Error breakdown** – false positives / false negatives with examples
12. **Exportable reports** – save all plots/tables, generate summary
13. **Data card** – model, data, and artefact summary

---

In [None]:
# ── Imports & Setup ───────────────────────────────────────────
import sys, os, json, logging, warnings
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns

try:
    import plotly.express as px
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
    HAS_PLOTLY = True
except ImportError:
    HAS_PLOTLY = False

import torch
from torch.utils.data import DataLoader
from PIL import Image
import scipy.signal

# Project modules
PROJECT_ROOT = Path().resolve()
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from src.config import (
    GOOD_WELD_DIR, DEFECT_WELD_DIR, LABELS_CSV, SPLIT_DIR, OUTPUT_DIR,
    DASHBOARD_DIR, MODEL_DIR, LABEL_COL, SAMPLE_ID_COL, LABEL_MAP,
    SENSOR_COLUMNS, FIXED_SEQ_LEN, IMAGE_SIZE, BATCH_SIZE,
    DEFECT_TYPES, CATEGORY_COL, DEFECT_TYPE_COL,
)
from src.data_ingestion import ingest
from src.feature_engineering import build_feature_table, sensor_to_fixed_tensor
from src.splitter import load_split
from src.dataset import WeldDataset, compute_normalize_stats
from src.trainer import WeldClassifier, _get_device
from src.calibration import TemperatureScaler, fit_temperature, predict_calibrated
from src.evaluation import (
    compute_binary_metrics, threshold_sweep, select_threshold,
    error_breakdown, full_evaluation_report,
    plot_confusion_matrix, plot_roc_and_pr, plot_calibration,
    plot_threshold_sweep, plot_error_examples,
)

# Style
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
warnings.filterwarnings("ignore", category=FutureWarning)
sns.set_theme(style="whitegrid", palette="muted", font_scale=1.05)
plt.rcParams.update({"figure.dpi": 120, "savefig.dpi": 150, "figure.figsize": (12, 5)})

DASHBOARD_DIR.mkdir(parents=True, exist_ok=True)
print("All imports OK ✓")

## 1. Data Ingestion & Dataset Overview

Ingest all weld runs from `good_weld/` and `defect-weld/`, compute features, and display summary statistics: total runs, durations, sensor row counts, missing/corrupt data.

In [None]:
# ── Ingest all data ───────────────────────────────────────────
manifest, sensor_data = ingest(GOOD_WELD_DIR, DEFECT_WELD_DIR)

n_good = (manifest[LABEL_COL] == 0).sum()
n_defect = (manifest[LABEL_COL] == 1).sum()
durations = manifest["duration_s"].dropna()

print(f"{'=' * 55}")
print(f"  DATASET OVERVIEW")
print(f"{'=' * 55}")
print(f"  Total weld runs     : {len(manifest)}")
print(f"  Labelled good       : {n_good}")
print(f"  Labelled defect     : {n_defect}")
print(f"  Sensor channels     : {len(SENSOR_COLUMNS)}  ({', '.join(SENSOR_COLUMNS)})")
print(f"  Rows per run        : {manifest['n_sensor_rows'].min()}–{manifest['n_sensor_rows'].max()}"
      f"  (mean={manifest['n_sensor_rows'].mean():.0f})")
print(f"  Duration (s)        : {durations.min():.1f}–{durations.max():.1f}"
      f"  (mean={durations.mean():.1f})")
print(f"  Images per run      : {manifest['n_images'].min()}–{manifest['n_images'].max()}")
n_issues = manifest["issues"].apply(len).gt(0).sum()
print(f"  Runs with issues    : {n_issues} / {len(manifest)}")
print(f"{'=' * 55}")

# Overview plots
fig, axes = plt.subplots(2, 2, figsize=(14, 9))
fig.suptitle("Dataset Overview", fontsize=15, fontweight="bold")

axes[0, 0].hist(durations, bins=20, color="steelblue", edgecolor="white")
axes[0, 0].axvline(durations.mean(), color="red", ls="--", label=f"mean={durations.mean():.1f}s")
axes[0, 0].set_xlabel("Duration (s)"); axes[0, 0].set_ylabel("Count")
axes[0, 0].set_title("Weld Duration Distribution"); axes[0, 0].legend()

axes[0, 1].hist(manifest["n_sensor_rows"], bins=20, color="darkorange", edgecolor="white")
axes[0, 1].set_xlabel("Rows"); axes[0, 1].set_ylabel("Count")
axes[0, 1].set_title("Sensor Readings per Run")

# Missing data heatmap
miss_data = []
for _, row in manifest.iterrows():
    sid = row[SAMPLE_ID_COL]
    sdf = sensor_data[sid]
    nan_pct = sdf[SENSOR_COLUMNS].isnull().mean() * 100
    miss_data.append(nan_pct.values)
miss_arr = np.array(miss_data)
axes[1, 0].imshow(miss_arr.T, aspect="auto", cmap="Reds", interpolation="nearest")
axes[1, 0].set_yticks(range(len(SENSOR_COLUMNS)))
axes[1, 0].set_yticklabels(SENSOR_COLUMNS, fontsize=8)
axes[1, 0].set_xlabel("Run index"); axes[1, 0].set_title("Missing Data (% NaN per channel)")

# Global sensor stats
all_sensor = pd.concat([sensor_data[sid][SENSOR_COLUMNS] for sid in sensor_data], ignore_index=True)
stats = all_sensor.describe().T[["mean", "std", "min", "max"]].round(2)
axes[1, 1].axis("off")
tbl = axes[1, 1].table(cellText=stats.values, rowLabels=stats.index, colLabels=stats.columns,
                         loc="center", cellLoc="center")
tbl.auto_set_font_size(False); tbl.set_fontsize(9); tbl.scale(1.2, 1.4)
axes[1, 1].set_title("Global Sensor Statistics", fontsize=12, pad=20)

plt.tight_layout(rect=[0, 0, 1, 0.95])
fig.savefig(DASHBOARD_DIR / "p2_01_dataset_overview.png", bbox_inches="tight")
plt.show()

## 2. Label Distribution – Defect vs Non-Defect & Defect Sub-Types

In [None]:
# ── Label distribution ────────────────────────────────────────
labelled = manifest[manifest[LABEL_COL].notna()].copy()
labelled["label_name"] = labelled[LABEL_COL].map(LABEL_MAP)

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# (a) Binary: good vs defect
vc = labelled["label_name"].value_counts()
colors_bin = ["#4CAF50" if x == "good" else "#F44336" for x in vc.index]
axes[0].bar(vc.index, vc.values, color=colors_bin, edgecolor="white", width=0.5)
for i, (lbl, cnt) in enumerate(vc.items()):
    axes[0].text(i, cnt + 3, str(cnt), ha="center", fontweight="bold")
axes[0].set_title("Good vs Defect")
axes[0].set_ylabel("# Runs")

# (b) Pie chart
axes[1].pie(vc.values, labels=vc.index, autopct="%1.0f%%", colors=colors_bin,
            startangle=90, textprops={"fontsize": 12})
axes[1].set_title("Label Proportions")

# (c) Defect sub-types
if DEFECT_TYPE_COL in manifest.columns:
    defect_only = manifest[manifest[LABEL_COL] == 1]
    dt_counts = defect_only[DEFECT_TYPE_COL].value_counts()
    palette = sns.color_palette("Set2", len(dt_counts))
    axes[2].barh(dt_counts.index, dt_counts.values, color=palette, edgecolor="white")
    for i, (dtype, cnt) in enumerate(dt_counts.items()):
        axes[2].text(cnt + 1, i, str(cnt), va="center", fontweight="bold")
    axes[2].set_xlabel("# Runs")
    axes[2].set_title("Defect Sub-Type Counts")
    axes[2].invert_yaxis()
else:
    axes[2].text(0.5, 0.5, "No defect_type column", transform=axes[2].transAxes,
                 ha="center", va="center")
    axes[2].set_title("Defect Sub-Types (N/A)")

plt.suptitle("Label Distribution Analysis", fontsize=14, fontweight="bold")
plt.tight_layout(rect=[0, 0, 1, 0.95])
fig.savefig(DASHBOARD_DIR / "p2_02_label_distribution.png", bbox_inches="tight")
plt.show()

# Class imbalance assessment
majority, minority = vc.max(), vc.min()
ratio = majority / minority if minority > 0 else float("inf")
print(f"\nClass imbalance ratio: {ratio:.1f}:1  (majority={vc.idxmax()}, minority={vc.idxmin()})")
if ratio > 3:
    print("⚠ Significant imbalance – using weighted loss + stratified splits")
else:
    print("✓ Imbalance within acceptable range")

## 3. Representative Examples – Sensor Signals, Video Frame, Audio Waveform

For selected good and defect runs, we visualise:
- **6-channel sensor time-series** with weld-phase annotations
- **Inspection photo** (middle image from the run's `images/` folder, when present)
- **Video frame preview** (frame taken 25% of the way into the `.avi`)
- **Audio waveform & spectrogram** (from `.flac`)
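
The shaded phase annotations in the sensor plots rely on a simple arc-on rule: samples where the weld current exceeds 10 A count as arcing, and the spans before the first and after the last such sample are shaded. A minimal sketch of that rule follows; the helper name `arc_window` and the current values are made up for illustration.

```python
def arc_window(current, arc_on_threshold=10.0):
    """Return (first, last) sample indices where current exceeds the threshold.

    Returns None when the arc never ignites. The 10 A default mirrors the
    cut-off used in the plotting cell; the helper itself is illustrative.
    """
    arc_on = [i for i, amps in enumerate(current) if amps > arc_on_threshold]
    return (arc_on[0], arc_on[-1]) if arc_on else None


# Hypothetical current trace (A): pre-arc, arcing, post-arc
current = [0.0, 0.5, 2.0, 85.0, 120.0, 118.0, 121.0, 60.0, 3.0, 0.0]
print(arc_window(current))  # (3, 7)
```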

In [None]:
# ── Helper: locate run directory from manifest ───────────────
def _find_run_dir(sid, manifest_df):
    """Locate the run directory for a sample_id by searching known bases."""
    row = manifest_df[manifest_df[SAMPLE_ID_COL] == sid].iloc[0]
    cat = row[CATEGORY_COL]
    lbl = row[LABEL_COL]
    base = GOOD_WELD_DIR if lbl == 0 else DEFECT_WELD_DIR
    run_dir = base / cat / sid
    if run_dir.exists():
        return run_dir
    # Fallback: search
    for p in base.rglob(sid):
        if p.is_dir():
            return p
    return None

# Pick example runs: 1 good, 1 defect per sub-type (up to 4 total)
example_ids = []
good_ids = manifest[manifest[LABEL_COL] == 0][SAMPLE_ID_COL].tolist()
if good_ids:
    example_ids.append(good_ids[len(good_ids)//2])  # middle good sample

if DEFECT_TYPE_COL in manifest.columns:
    defects = manifest[manifest[LABEL_COL] == 1]
    # dropna() keeps sorted() from failing on runs with a missing sub-type
    for dtype in sorted(defects[DEFECT_TYPE_COL].dropna().unique())[:3]:
        subset = defects[defects[DEFECT_TYPE_COL] == dtype]
        if len(subset) > 0:
            example_ids.append(subset[SAMPLE_ID_COL].iloc[len(subset)//2])
else:
    defect_ids = manifest[manifest[LABEL_COL] == 1][SAMPLE_ID_COL].tolist()
    if defect_ids:
        example_ids.append(defect_ids[0])

print(f"Plotting {len(example_ids)} representative runs\n")

for sid in example_ids:
    run_dir = _find_run_dir(sid, manifest)
    sdf = sensor_data[sid].copy()
    lbl_val = manifest[manifest[SAMPLE_ID_COL] == sid][LABEL_COL].values[0]
    lbl_name = LABEL_MAP.get(int(lbl_val), "?")
    dtype_name = ""
    if DEFECT_TYPE_COL in manifest.columns:
        dtype_name = manifest[manifest[SAMPLE_ID_COL] == sid][DEFECT_TYPE_COL].values[0]

    # Time axis
    if "datetime" in sdf.columns and sdf["datetime"].notna().any():
        t = (sdf["datetime"] - sdf["datetime"].min()).dt.total_seconds().values
    else:
        t = np.arange(len(sdf)) * 0.11

    # ── Build figure: 4 rows (sensors top 3 rows, media bottom row) ──
    fig = plt.figure(figsize=(18, 16))
    gs = gridspec.GridSpec(4, 3, figure=fig, height_ratios=[1, 1, 1, 1.2], hspace=0.35, wspace=0.3)
    title = f"Run: {sid}  |  {lbl_name.upper()}"
    if dtype_name and dtype_name != "good":
        title += f"  ({dtype_name})"
    fig.suptitle(title, fontsize=15, fontweight="bold")

    # Sensor signals (6 channels, 3×2 grid in top 3 rows)
    for i, col in enumerate(SENSOR_COLUMNS):
        row_i, col_i = i // 2, i % 2
        ax = fig.add_subplot(gs[row_i, col_i])
        ax.plot(t, sdf[col].values, linewidth=0.8, color="steelblue")
        ax.set_ylabel(col, fontsize=9)
        ax.grid(True, alpha=0.3)
        if row_i == 2:
            ax.set_xlabel("Time (s)")

        # Phase annotations
        current = sdf["Primary Weld Current"].values
        arc_on = np.where(current > 10.0)[0]
        if len(arc_on) > 0:
            t_start, t_end = t[arc_on[0]], t[arc_on[-1]]
            ax.axvspan(0, t_start, alpha=0.06, color="blue")
            ax.axvspan(t_end, t[-1], alpha=0.06, color="gray")

    # ── Image preview (top-right area) ──
    ax_img = fig.add_subplot(gs[0, 2])
    if run_dir:
        img_dir = run_dir / "images"
        imgs = sorted(img_dir.glob("*.jpg")) + sorted(img_dir.glob("*.png")) if img_dir.exists() else []
        if imgs:
            img = Image.open(imgs[len(imgs)//2]).convert("RGB")
            ax_img.imshow(img)
            ax_img.set_title(f"Inspection Photo ({len(imgs)} total)", fontsize=10)
        else:
            ax_img.text(0.5, 0.5, "No images", ha="center", va="center", transform=ax_img.transAxes)
    ax_img.axis("off")

    # ── Video frame preview ──
    ax_vid = fig.add_subplot(gs[1, 2])
    avi_path = run_dir / f"{sid}.avi" if run_dir else None
    if avi_path and avi_path.exists():
        try:
            import cv2
            cap = cv2.VideoCapture(str(avi_path))
            # Read frame at 25% into the video
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames // 4)
            ret, frame = cap.read()
            cap.release()
            if ret:
                ax_vid.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                ax_vid.set_title(f"Video Frame ({total_frames} frames)", fontsize=10)
            else:
                ax_vid.text(0.5, 0.5, "Could not read frame", ha="center", va="center",
                           transform=ax_vid.transAxes)
        except ImportError:
            ax_vid.text(0.5, 0.5, f"Video: {avi_path.name}\n(cv2 not installed)", ha="center",
                       va="center", transform=ax_vid.transAxes, fontsize=9)
    else:
        ax_vid.text(0.5, 0.5, "No .avi file", ha="center", va="center", transform=ax_vid.transAxes)
    ax_vid.axis("off")

    # ── Audio waveform + spectrogram ──
    ax_wave = fig.add_subplot(gs[2, 2])
    ax_spec = fig.add_subplot(gs[3, :])

    flac_path = run_dir / f"{sid}.flac" if run_dir else None
    audio_loaded = False
    if flac_path and flac_path.exists():
        try:
            import soundfile as sf
            audio, sr = sf.read(str(flac_path))
            audio_loaded = True
        except ImportError:
            pass
        if not audio_loaded:
            try:
                from scipy.io import wavfile as _wf
                # FLAC may not work with scipy; try anyway
                sr, audio = _wf.read(str(flac_path))
                audio = audio.astype(np.float32) / max(np.abs(audio).max(), 1)
                audio_loaded = True
            except Exception:
                pass

    if audio_loaded:
        if audio.ndim > 1:
            audio = audio.mean(axis=1)
        t_audio = np.arange(len(audio)) / sr

        # Waveform
        ax_wave.plot(t_audio, audio, linewidth=0.3, color="purple")
        ax_wave.set_ylabel("Amplitude")
        ax_wave.set_title(f"Audio Waveform (sr={sr} Hz, {len(audio)/sr:.1f}s)", fontsize=10)
        ax_wave.grid(True, alpha=0.3)

        # Spectrogram
        nperseg = min(1024, len(audio) // 4)
        f_spec, t_spec, Sxx = scipy.signal.spectrogram(audio, fs=sr, nperseg=nperseg)
        ax_spec.pcolormesh(t_spec, f_spec, 10 * np.log10(Sxx + 1e-10), shading="gouraud", cmap="magma")
        ax_spec.set_ylabel("Frequency (Hz)")
        ax_spec.set_xlabel("Time (s)")
        ax_spec.set_title("Audio Spectrogram", fontsize=11)
    else:
        if flac_path and flac_path.exists():
            ax_wave.text(0.5, 0.5, f"Audio: {flac_path.name}\n(no audio library to decode FLAC)",
                        ha="center", va="center", transform=ax_wave.transAxes, fontsize=9)
        else:
            ax_wave.text(0.5, 0.5, "No .flac file", ha="center", va="center", transform=ax_wave.transAxes)
        ax_wave.axis("off")
        ax_spec.text(0.5, 0.5, "Spectrogram unavailable", ha="center", va="center",
                    transform=ax_spec.transAxes)
        ax_spec.axis("off")

    fig.savefig(DASHBOARD_DIR / f"p2_03_example_{sid}.png", bbox_inches="tight")
    plt.show()
    print()

## 4. Data Quality Indicators – Class Imbalance, Outliers, Noise

Per-run sensor statistics compared by class, signal-to-noise ratio of the weld current during arcing, and IQR-based outlier detection on per-run channel means.
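
As a rough illustration of the IQR rule used for the outlier check, the sketch below flags values outside `[Q1 − 1.5·IQR, Q3 + 1.5·IQR]`. It is written in plain Python so it runs without the project modules; the function name `iqr_outliers` and the sample values are hypothetical, and the notebook cell applies the same rule with pandas quantiles.

```python
from statistics import quantiles


def iqr_outliers(values, k=1.5):
    """Return indices of values outside [Q1 - k*IQR, Q3 + k*IQR]."""
    # method="inclusive" interpolates linearly between data points,
    # which matches the default behaviour of pandas' Series.quantile.
    q1, _, q3 = quantiles(values, n=4, method="inclusive")
    iqr = q3 - q1
    lo, hi = q1 - k * iqr, q3 + k * iqr
    return [i for i, v in enumerate(values) if v < lo or v > hi]


# Hypothetical per-run mean currents (A); the last run sits far from the rest.
run_means = [148.0, 150.0, 151.0, 149.0, 152.0, 150.0, 210.0]
print(iqr_outliers(run_means))  # [6]
```

Because the rule is applied to per-run means, it flags whole runs whose average level is unusual rather than individual noisy samples within a run.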

In [None]:
# ── Per-run quality stats ─────────────────────────────────────
run_stats = []
for sid, sdf in sensor_data.items():
    row = {"sample_id": sid}
    lbl_rows = manifest[manifest[SAMPLE_ID_COL] == sid]
    row["label"] = int(lbl_rows[LABEL_COL].values[0]) if len(lbl_rows) else -1

    for col in SENSOR_COLUMNS:
        vals = sdf[col].dropna()
        row[f"{col}__mean"] = vals.mean() if len(vals) else 0
        row[f"{col}__std"] = vals.std() if len(vals) else 0

    current = sdf["Primary Weld Current"].values
    arcing = current > 10.0
    if arcing.sum() > 5:
        arc_current = current[arcing]
        row["arc_current_snr"] = arc_current.mean() / arc_current.std() if arc_current.std() > 0 else np.inf
    else:
        row["arc_current_snr"] = 0.0
    run_stats.append(row)

run_stats_df = pd.DataFrame(run_stats).set_index("sample_id")

# ── Boxplots by class ────────────────────────────────────────
fig, axes = plt.subplots(2, 3, figsize=(16, 9))
fig.suptitle("Per-Run Sensor Statistics – Good vs Defect", fontsize=14, fontweight="bold")

for i, col in enumerate(SENSOR_COLUMNS):
    ax = axes[i // 3, i % 3]
    good_vals = run_stats_df[run_stats_df["label"] == 0][f"{col}__mean"].dropna()
    defect_vals = run_stats_df[run_stats_df["label"] == 1][f"{col}__mean"].dropna()
    bp = ax.boxplot([good_vals, defect_vals], labels=["Good", "Defect"], patch_artist=True)
    bp["boxes"][0].set_facecolor("#4CAF50"); bp["boxes"][0].set_alpha(0.6)
    bp["boxes"][1].set_facecolor("#F44336"); bp["boxes"][1].set_alpha(0.6)
    ax.set_title(f"{col} (mean)", fontsize=10)
    ax.grid(True, alpha=0.3)

plt.tight_layout(rect=[0, 0, 1, 0.95])
fig.savefig(DASHBOARD_DIR / "p2_04_quality_boxplots.png", bbox_inches="tight")
plt.show()

# ── Outlier detection ─────────────────────────────────────────
print("\n=== IQR Outlier Runs per Sensor ===")
outlier_sids = set()
for col in SENSOR_COLUMNS:
    means = run_stats_df[f"{col}__mean"].dropna()
    Q1, Q3 = means.quantile(0.25), means.quantile(0.75)
    IQR = Q3 - Q1
    mask = (means < Q1 - 1.5*IQR) | (means > Q3 + 1.5*IQR)
    n_out = mask.sum()
    outlier_sids.update(means[mask].index.tolist())
    print(f"  {col:30s}: {n_out} outlier runs")

print(f"\n  Total unique outlier runs: {len(outlier_sids)} / {len(run_stats_df)}")

# SNR summary
snr = run_stats_df["arc_current_snr"].replace([np.inf, -np.inf], np.nan).dropna()
print(f"\n  Arc Current SNR: mean={snr.mean():.2f}  median={snr.median():.2f}  "
      f"min={snr.min():.2f}  max={snr.max():.2f}")

## 5. Load Model, Calibrate, and Generate Predictions

Load the trained checkpoint, fit temperature scaling on the validation set, select the F1-maximising threshold on that same split, and produce calibrated predictions for both the val and test splits.
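
Temperature scaling divides the logits by a single scalar `T` before the softmax and picks the `T` that minimises the negative log-likelihood on held-out data. The sketch below shows the idea on made-up 2-class logits using a plain grid search; this is a stand-in for illustration, and how `src.calibration.fit_temperature` actually optimises `T` is not shown in this notebook. All names in the block are hypothetical.

```python
import math


def p_defect_at(logits, T=1.0):
    """P(defect) from 2-class logits [z_good, z_defect] at temperature T."""
    z_good, z_defect = logits[0] / T, logits[1] / T
    # Two-class softmax written as a sigmoid of the logit margin.
    return 1.0 / (1.0 + math.exp(z_good - z_defect))


def nll(all_logits, labels, T):
    """Mean negative log-likelihood of the true labels at temperature T."""
    total = 0.0
    for logits, y in zip(all_logits, labels):
        p = p_defect_at(logits, T)
        total -= math.log(max(p if y == 1 else 1.0 - p, 1e-12))
    return total / len(labels)


def fit_temperature_grid(all_logits, labels, grid=None):
    """Pick the T on a coarse grid (0.25 to 6.0) that minimises the NLL."""
    grid = grid or [0.25 + 0.05 * i for i in range(116)]
    return min(grid, key=lambda T: nll(all_logits, labels, T))


# Made-up validation logits: large margins, but only 4 of 6 are correct,
# so the uncalibrated model is overconfident.
val_logits = [[0.0, 4.0], [0.0, 4.0], [0.0, 4.0], [4.0, 0.0], [4.0, 0.0], [4.0, 0.0]]
val_labels = [1, 1, 0, 0, 0, 1]

T_fit = fit_temperature_grid(val_logits, val_labels)
print(f"T = {T_fit:.2f}")
print(f"p_defect before: {p_defect_at([0.0, 4.0]):.3f}, after: {p_defect_at([0.0, 4.0], T_fit):.3f}")
```

A fitted `T` above 1 softens the probabilities (the model was overconfident), while a `T` below 1 sharpens them. Dividing by a positive scalar never changes which class has the larger logit, so ranking metrics such as ROC AUC are unaffected.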

In [None]:
# ── Build features, datasets, loaders ─────────────────────────
feature_df = build_feature_table(manifest, sensor_data)
split_map = load_split()

# Norm stats from training set
norm_stats = compute_normalize_stats(sensor_data, split_map["train"])

datasets, loaders = {}, {}
for split_name, ids in split_map.items():
    ds = WeldDataset(
        manifest=manifest, sensor_data=sensor_data,
        sample_ids=ids, feature_df=feature_df, normalize_stats=norm_stats,
    )
    datasets[split_name] = ds
    loaders[split_name] = DataLoader(
        ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True,
    )

n_features = datasets["train"][0]["features"].shape[0]
n_channels = len(SENSOR_COLUMNS)

for k, v in split_map.items():
    print(f"  {k:5s}: {len(v)} samples")
print(f"  Feature vector size: {n_features}")

# ── Load model ────────────────────────────────────────────────
device = _get_device()
model_path = MODEL_DIR / "weld_classifier.pt"
model = WeldClassifier(n_channels, n_features).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
print(f"\n✓ Model loaded from {model_path}")

# ── Temperature scaling ──────────────────────────────────────
scaler = fit_temperature(model, loaders["val"], device=device)
T = scaler.temperature.item()
print(f"✓ Learned temperature: T = {T:.4f}")

# ── Generate calibrated predictions ──────────────────────────
pred_dfs = {}
for split_name in ["val", "test"]:
    results = predict_calibrated(scaler, loaders[split_name], device=device)
    rows = []
    for sid, info in results.items():
        rows.append({
            "sample_id": sid,
            "label": info["label"],
            "p_defect": round(info["p_defect"], 6),
            "p_good": round(info["p_good"], 6),
            "confidence": round(info["confidence"], 6),
        })
    pred_dfs[split_name] = pd.DataFrame(rows)
    print(f"  {split_name}: {len(pred_dfs[split_name])} predictions")

# ── Threshold selection (on val ONLY) ────────────────────────
val_labelled = pred_dfs["val"][pred_dfs["val"]["label"].notna()].copy()
threshold = select_threshold(
    val_labelled["label"].values.astype(int),
    val_labelled["p_defect"].values,
    strategy="f1",
)
print(f"\n✓ Optimal threshold (F1 on val): {threshold:.4f}")

# Apply threshold
for split_name, df in pred_dfs.items():
    df["pred_defect"] = (df["p_defect"] >= threshold).astype(int)
    # Confidence = calibrated probability of the predicted class
    df["confidence"] = np.where(
        df["pred_defect"] == 1, df["p_defect"], 1 - df["p_defect"]
    ).round(6)

# Save predictions_binary.csv
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
pred_dfs["test"].to_csv(OUTPUT_DIR / "predictions_binary.csv", index=False)
pred_dfs["val"].to_csv(OUTPUT_DIR / "predictions_binary_val.csv", index=False)
print(f"✓ Predictions saved")

## 6. Threshold Analysis – Metric Sweep

How do precision, recall, F1, and accuracy change as the decision threshold varies? The red dashed line marks the threshold chosen by maximising F1 on the validation set.
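
For intuition, a threshold sweep can be sketched as: classify at each candidate threshold, compute F1, and keep the threshold with the highest score. The code below is a dependency-free illustration on made-up labels and probabilities; the helper names are hypothetical, and the internals of `threshold_sweep` and `select_threshold` in `src.evaluation` may differ (for example in grid resolution or tie-breaking).

```python
def confusion_counts(y_true, p_defect, threshold):
    """Return (tp, fp, fn, tn) when predicting defect for p_defect >= threshold."""
    tp = fp = fn = tn = 0
    for y, p in zip(y_true, p_defect):
        pred = 1 if p >= threshold else 0
        if pred == 1 and y == 1:
            tp += 1
        elif pred == 1 and y == 0:
            fp += 1
        elif pred == 0 and y == 1:
            fn += 1
        else:
            tn += 1
    return tp, fp, fn, tn


def f1_at(y_true, p_defect, threshold):
    tp, fp, fn, _ = confusion_counts(y_true, p_defect, threshold)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def best_f1_threshold(y_true, p_defect, thresholds):
    # max() returns the first maximum, so ties resolve to the lowest threshold.
    return max(thresholds, key=lambda t: f1_at(y_true, p_defect, t))


# Made-up validation labels and calibrated defect probabilities
y_val = [0, 0, 0, 1, 0, 1, 1, 1]
p_val = [0.04, 0.12, 0.31, 0.36, 0.44, 0.62, 0.81, 0.93]
grid = [i / 20 for i in range(1, 20)]  # 0.05, 0.10, ..., 0.95

t_best = best_f1_threshold(y_val, p_val, grid)
print(f"best threshold = {t_best:.2f}, F1 = {f1_at(y_val, p_val, t_best):.3f}")
```

With a validation set this small the F1 curve is a step function, so the selected threshold can shift noticeably when a single sample changes.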

In [None]:
# ── Threshold sweep on validation set ─────────────────────────
sweep_df = threshold_sweep(
    val_labelled["label"].values.astype(int),
    val_labelled["p_defect"].values,
)
plot_threshold_sweep(sweep_df, threshold, save_path=DASHBOARD_DIR / "p2_05_threshold_sweep.png")
plt.show()

# Show metrics at the chosen threshold
m_val = compute_binary_metrics(
    val_labelled["label"].values.astype(int),
    val_labelled["p_defect"].values,
    threshold=threshold,
)
print(f"\n{'=' * 50}")
print(f"  VALIDATION METRICS @ threshold={threshold:.4f}")
print(f"{'=' * 50}")
for k in ["accuracy", "precision", "recall", "specificity", "f1", "roc_auc", "avg_precision", "brier_score"]:
    print(f"  {k:20s}: {m_val[k]:.4f}")
print(f"  TP={m_val['tp']}  FP={m_val['fp']}  FN={m_val['fn']}  TN={m_val['tn']}")

## 7. Test Set Evaluation – Core Binary Metrics

The threshold was fixed on the validation set. These are unseen test-set results with **no threshold tuning**.
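
Once the threshold is frozen, the threshold-dependent metrics in the report follow directly from the four confusion-matrix counts. A minimal sketch, using made-up counts and a hypothetical helper name (the project's `compute_binary_metrics` additionally reports probability-based metrics such as ROC AUC and the Brier score):

```python
def metrics_from_counts(tp, fp, fn, tn):
    """Threshold-dependent metrics derived from confusion-matrix counts."""

    def ratio(num, den):
        # Guard against empty denominators (e.g. no predicted positives).
        return num / den if den else 0.0

    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    return {
        "accuracy": ratio(tp + tn, tp + fp + fn + tn),
        "precision": precision,
        "recall": recall,
        "specificity": ratio(tn, tn + fp),
        "f1": ratio(2 * precision * recall, precision + recall),
    }


# Made-up counts for illustration only
m = metrics_from_counts(tp=40, fp=10, fn=5, tn=45)
for name, value in m.items():
    print(f"{name:12s}: {value:.4f}")
```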

In [None]:
# ── Full evaluation on TEST set ───────────────────────────────
test_labelled = pred_dfs["test"][pred_dfs["test"]["label"].notna()].copy()
report = full_evaluation_report(
    predictions_df=test_labelled,
    threshold=threshold,
    split_name="test",
    save_dir=DASHBOARD_DIR,
)
print(report["report_text"])

## 8. Confusion Matrix – Test Set

In [None]:
# ── Confusion matrix (already saved by full_evaluation_report) ─
y_true = test_labelled["label"].values.astype(int)
y_prob = test_labelled["p_defect"].values
y_pred = (y_prob >= threshold).astype(int)

plot_confusion_matrix(y_true, y_pred, save_path=DASHBOARD_DIR / "p2_06_confusion_matrix.png")
plt.show()

## 9. ROC Curve & Precision-Recall Curve – Test Set

In [None]:
plot_roc_and_pr(y_true, y_prob, threshold=threshold,
               save_path=DASHBOARD_DIR / "p2_07_roc_pr.png")
plt.show()

## 10. Calibration Analysis – Reliability Diagram

A well-calibrated model's predicted probabilities match the true fraction of positives. Temperature scaling helps close any gap.
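
A reliability diagram is built by grouping predictions into probability bins and comparing, per bin, the mean predicted probability with the observed defect fraction. The sketch below assumes equal-width bins and uses made-up data; the binning used inside `plot_calibration` may differ. It also shows the Brier score as the mean squared error between probability and label.

```python
def reliability_bins(y_true, p_defect, n_bins=10):
    """Per non-empty bin: (mean predicted prob, observed defect fraction, count)."""
    acc = [[0.0, 0, 0] for _ in range(n_bins)]  # [sum of p, positives, count]
    for y, p in zip(y_true, p_defect):
        b = min(int(p * n_bins), n_bins - 1)  # p == 1.0 falls in the last bin
        acc[b][0] += p
        acc[b][1] += y
        acc[b][2] += 1
    return [(s / n, pos / n, n) for s, pos, n in acc if n > 0]


def brier_score(y_true, p_defect):
    """Mean squared difference between predicted probability and label."""
    return sum((p - y) ** 2 for y, p in zip(y_true, p_defect)) / len(y_true)


# Made-up labels and calibrated probabilities
y = [0, 0, 1, 1, 1, 0]
p = [0.10, 0.20, 0.25, 0.80, 0.90, 0.75]

bins = reliability_bins(y, p, n_bins=2)
for mean_p, frac_pos, n in bins:
    print(f"mean p = {mean_p:.3f}  observed = {frac_pos:.3f}  n = {n}")
print(f"Brier = {brier_score(y, p):.4f}")
```

Points near the diagonal (mean predicted probability close to observed fraction) indicate good calibration. The Brier score also rewards separating the classes, so it is not a pure calibration measure.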

In [None]:
plot_calibration(y_true, y_prob, n_bins=10, label="Temp-scaled",
                save_path=DASHBOARD_DIR / "p2_08_calibration.png")
plt.show()

# Brier score context (it reflects both calibration and class separation)
brier = report["metrics"]["brier_score"]
print(f"Brier Score: {brier:.4f}  (0 = perfect, 0.25 = always predicting 0.5)")
if brier < 0.1:
    print("✓ Excellent calibration")
elif brier < 0.2:
    print("○ Reasonable calibration")
else:
    print("⚠ Poor calibration – consider re-calibration or different model")

## 11. Error Breakdown – False Positives & False Negatives

Which samples did the model get wrong? The bar charts show the worst misclassifications ranked by predicted probability.
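
The split into false positives and false negatives can be sketched as follows, with the most confident mistakes listed first. The record layout, sample IDs, and helper name are assumptions for illustration; `error_breakdown` in `src.evaluation` operates on the predictions DataFrame and may order rows differently.

```python
def split_errors(rows, threshold):
    """Split misclassified rows into (false_positives, false_negatives).

    rows: iterable of dicts with keys "sample_id", "label", "p_defect".
    """
    fp = [r for r in rows if r["label"] == 0 and r["p_defect"] >= threshold]
    fn = [r for r in rows if r["label"] == 1 and r["p_defect"] < threshold]
    # Worst first: FPs with the highest p_defect, FNs with the lowest.
    fp.sort(key=lambda r: r["p_defect"], reverse=True)
    fn.sort(key=lambda r: r["p_defect"])
    return fp, fn


# Hypothetical test-set predictions
rows = [
    {"sample_id": "run_001", "label": 0, "p_defect": 0.08},
    {"sample_id": "run_002", "label": 0, "p_defect": 0.71},
    {"sample_id": "run_003", "label": 1, "p_defect": 0.93},
    {"sample_id": "run_004", "label": 1, "p_defect": 0.22},
    {"sample_id": "run_005", "label": 0, "p_defect": 0.55},
    {"sample_id": "run_006", "label": 1, "p_defect": 0.40},
]

fp_rows, fn_rows = split_errors(rows, threshold=0.5)
print("FP:", [r["sample_id"] for r in fp_rows])
print("FN:", [r["sample_id"] for r in fn_rows])
```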

In [None]:
# ── Error breakdown ───────────────────────────────────────────
fp_df, fn_df = error_breakdown(test_labelled, threshold=threshold)

plot_error_examples(fp_df, fn_df, top_n=10,
                   save_path=DASHBOARD_DIR / "p2_09_error_examples.png")
plt.show()

print(f"\nFalse Positives: {len(fp_df)}  (good runs wrongly flagged as defect)")
if len(fp_df) > 0:
    display(fp_df[["sample_id", "label", "p_defect", "confidence"]].head(10))

print(f"\nFalse Negatives: {len(fn_df)}  (defect runs missed by the model)")
if len(fn_df) > 0:
    display(fn_df[["sample_id", "label", "p_defect", "confidence"]].head(10))

## 12. Exportable Reports & Artefacts

Save a consolidated metrics CSV and a plain-text report that can be attached to pull-requests or audit logs.

In [None]:
# ── Export artefacts ──────────────────────────────────────────
# 1) Save a single-row metrics CSV for programmatic comparison
test_metrics = report["metrics"]  # metrics dict from full_evaluation_report (section 7)
metrics_csv = DASHBOARD_DIR / "phase2_metrics_summary.csv"
pd.DataFrame([test_metrics]).to_csv(metrics_csv, index=False)
print(f"[✓] Metrics summary saved → {metrics_csv}")

# 2) Save the full text report next to the plots
report_path = DASHBOARD_DIR / "phase2_report.txt"
report_lines = [
    "=" * 60,
    " Phase 2 – Binary Classification Evaluation Report",
    "=" * 60,
    f"\nModel          : WeldClassifier (SensorCNN + FeatureMLP)",
    f"Threshold      : {threshold:.4f}  (strategy: F1-optimal on val set)",
    f"Temperature    : {T:.4f}",
    f"Test samples   : {len(test_labelled)}",
    "",
    "── Key Metrics ──────────────────────────────────────────",
    f"  Accuracy     : {test_metrics['accuracy']:.4f}",
    f"  Precision    : {test_metrics['precision']:.4f}",
    f"  Recall       : {test_metrics['recall']:.4f}",
    f"  F1 Score     : {test_metrics['f1']:.4f}",
    f"  Specificity  : {test_metrics['specificity']:.4f}",
    f"  ROC AUC      : {test_metrics['roc_auc']:.4f}",
    f"  PR AUC (AP)  : {test_metrics['avg_precision']:.4f}",
    f"  Brier Score  : {test_metrics['brier_score']:.4f}",
    f"  Log Loss     : {test_metrics['log_loss']:.4f}",
    "",
    "── Confusion Matrix (rows=actual, cols=predicted) ──────",
    f"  TN={test_metrics['tn']}  FP={test_metrics['fp']}",
    f"  FN={test_metrics['fn']}  TP={test_metrics['tp']}",
    "",
    "── Error Summary ────────────────────────────────────────",
    f"  False Positives : {len(fp_df)}  (good → flagged)",
    f"  False Negatives : {len(fn_df)}  (defect → missed)",
    "",
    "── Saved Artefacts ──────────────────────────────────────",
]

# List all saved plots
for p in sorted(DASHBOARD_DIR.glob("p2_*.png")):
    report_lines.append(f"  {p.name}")
report_lines.append(f"  phase2_metrics_summary.csv")
report_lines.append(f"  phase2_report.txt")
report_lines.append("\n" + "=" * 60)

report_text = "\n".join(report_lines)
report_path.write_text(report_text, encoding="utf-8")
print(f"[✓] Full report saved → {report_path}")
print("\n" + report_text)

## 13. Phase 2 – Data Card

| Field | Value |
|---|---|
| **Task** | Binary weld defect detection (good vs defect) |
| **Model** | `WeldClassifier` — dual-branch: 1-D SensorCNN (6 channels) + Feature MLP (104 dims) → 192-d fusion → 2-class logits |
| **Calibration** | Temperature scaling (Guo et al., ICML 2017) fitted on validation set |
| **Threshold** | F1-optimal on validation set, saved in `threshold.json` |
| **Confidence** | Calibrated probability of the predicted class: `p_defect` if `pred_defect = 1`, else `1 − p_defect` |
| **Output** | `predictions_binary.csv` — columns: `sample_id, label, p_defect, p_good, confidence, pred_defect` |
| **Dataset** | 1 552 weld runs (good_weld + defect-weld), 70 / 15 / 15 % train / val / test split |
| **Input per run** | 6-channel sensor time-series (padded to 400 steps) + 104 engineered features |
| **Sensor columns** | current, voltage, wire_feed_speed, gas_flow_rate, heat_input, energy |
| **Normalisation** | Per-channel z-score computed on training set (`normalize_stats.json`) |
| **Artefacts** | `weld_classifier.pt`, `temperature.json`, `threshold.json`, `normalize_stats.json`, `predictions_binary.csv` |
| **Limitations** | Trained on GTAW butt joints only; may not generalise to other joint types or welding processes |