In [14]:
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Tuple
import locale
locale.setlocale(locale.LC_ALL, "en_US.UTF-8")

import os
os.environ["PYTHONIOENCODING"] = "utf-8"
os.environ["LC_ALL"] = "en_US.UTF-8"

import json
import numpy as np
import pandas as pd
from pathlib import Path

# Robust repo root detection
try:
    # Running as a script
    REPO_ROOT = Path(__file__).resolve().parents[1]
except NameError:
    # Running in Colab or notebook
    REPO_ROOT = Path.cwd()

DEFAULT_DATA_DIR = REPO_ROOT / "data" / "eclss_synthetic_dataset_full"



# ============================================================
# CONFIGURATION
# ============================================================

@dataclass
class DatasetConfig:
    # Size
    n_samples_per_system_class: int = 30
    n_timesteps: int = 1000
    sampling_rate_hz: float = 2.0

    # Nominal values (ISS-aligned)
    O2_nominal: float = 20.9          # %
    CO2_nominal: float = 0.3          # % (~2 mmHg target)
    pressure_nominal_psi: float = 14.7

    # Nominal waveform amplitudes
    O2_amp: float = 0.3
    CO2_amp: float = 0.10
    P_amp: float = 0.3

    # Sensor noise
    O2_noise_std: float = 0.05
    CO2_noise_std: float = 0.02
    P_noise_std: float = 0.05

    # Slow drift (sensor aging / slow process drift)
    enable_drift: bool = True
    drift_std_O2: float = 0.01
    drift_std_CO2: float = 0.003
    drift_std_P: float = 0.01

    # Cycle timing jitter
    freq_jitter_range: Tuple[float, float] = (0.95, 1.05)

    # System fault severity ranges
    co2_leak_magnitude_range: Tuple[float, float] = (0.05, 0.40)
    valve_stiction_alpha_range: Tuple[float, float] = (0.85, 0.99)
    vacuum_drop_range: Tuple[float, float] = (0.5, 2.0)
    cdra_offset_range: Tuple[float, float] = (0.05, 0.5)
    oga_offset_range: Tuple[float, float] = (0.5, 4.0)

    # Sensor faults
    enable_sensor_faults: bool = True
    p_sensor_fault: float = 0.15  # 15% of cycles get a sensor fault

    # Sensor fault parameters
    bias_drift_max: float = 0.5
    high_noise_factor: float = 4.0
    freeze_min_fraction: float = 0.2
    freeze_max_fraction: float = 0.6
    n_spikes_min: int = 5
    n_spikes_max: int = 30
    spike_magnitude_range: Tuple[float, float] = (0.5, 3.0)

    # Physical limits (sanity clips)
    O2_min: float = 0.0
    O2_max: float = 100.0
    CO2_min: float = 0.0
    CO2_max: float = 5.0       # allow up to 5% for severe faults
    P_min: float = 0.0
    P_max: float = 20.0

    # NASA-inspired health thresholds (for flags only)
    CO2_warn: float = 0.7
    CO2_crit: float = 1.0
    O2_warn_low: float = 19.0   # closer to 20.9 nominal
    O2_crit_low: float = 16.0
    P_warn_low: float = 14.0
    P_warn_high: float = 15.4
    P_crit_low: float = 13.5
    P_crit_high: float = 15.8

    # Random seed
    random_seed: int = 42


# System modes (system-level faults)
SYSTEM_CLASSES = [
    (0, "Nominal"),
    (1, "CO2_Leak"),
    (2, "Valve_Stiction"),
    (3, "Vacuum_Anomaly"),
    (4, "CDRA_Degradation"),
    (5, "OGA_Degradation"),
]

# Sensor fault types
SENSOR_FAULTS = [
    (0, "None"),
    (1, "Bias_Drift"),
    (2, "High_Noise"),
    (3, "Partial_Freeze"),
    (4, "Spike_Outliers"),
]


# ============================================================
# BASELINE NOMINAL GENERATION
# ============================================================

def generate_nominal_cycle(cfg: DatasetConfig,
                           rng: np.random.Generator) -> np.ndarray:
    """
    Generate one nominal adsorption–desorption cycle.
    Returns array of shape (T, 3) with columns [O2, CO2, P].
    """
    T = cfg.n_timesteps
    t = np.linspace(0.0, 1.0, T)

    # Small random frequency jitter
    freq_scale = rng.uniform(*cfg.freq_jitter_range)

    # Periodic behavior
    O2 = cfg.O2_nominal + cfg.O2_amp * np.sin(2 * np.pi * freq_scale * t)
    CO2 = cfg.CO2_nominal - cfg.CO2_amp * np.sin(2 * np.pi * freq_scale * t)
    P = cfg.pressure_nominal_psi + cfg.P_amp * np.cos(2 * np.pi * freq_scale * t)

    # Slow drift (random walk)
    if cfg.enable_drift:
        O2 += np.cumsum(rng.normal(0.0, cfg.drift_std_O2, T)) / T
        CO2 += np.cumsum(rng.normal(0.0, cfg.drift_std_CO2, T)) / T
        P += np.cumsum(rng.normal(0.0, cfg.drift_std_P, T)) / T

    # Sensor noise
    O2 += rng.normal(0.0, cfg.O2_noise_std, T)
    CO2 += rng.normal(0.0, cfg.CO2_noise_std, T)
    P += rng.normal(0.0, cfg.P_noise_std, T)

    # Physical clipping
    O2 = np.clip(O2, cfg.O2_min, cfg.O2_max)
    CO2 = np.clip(CO2, cfg.CO2_min, cfg.CO2_max)
    P = np.clip(P, cfg.P_min, cfg.P_max)

    return np.stack([O2, CO2, P], axis=1)


# ============================================================
# SYSTEM FAULT MODELS
# ============================================================

def _interp_from_severity(severity: float, low: float, high: float) -> float:
    s = float(np.clip(severity, 0.0, 1.0))
    return low + s * (high - low)


def inject_co2_leak(cycle: np.ndarray, cfg: DatasetConfig,
                    rng: np.random.Generator, severity: float) -> np.ndarray:
    """CO₂ leak: gradual elevation of CO₂ after mid-cycle."""
    faulty = cycle.copy()
    T = faulty.shape[0]
    mid = T // 2

    mag = _interp_from_severity(severity, *cfg.co2_leak_magnitude_range)
    ramp = np.linspace(0.0, mag, T - mid)
    faulty[mid:, 1] += ramp

    faulty[:, 1] = np.clip(faulty[:, 1], cfg.CO2_min, cfg.CO2_max)
    return faulty


def _lowpass_first_order(x: np.ndarray, alpha: float) -> np.ndarray:
    """First-order IIR low-pass filter."""
    y = np.empty_like(x)
    y[0] = x[0]
    for t in range(1, len(x)):
        y[t] = alpha * y[t - 1] + (1.0 - alpha) * x[t]
    return y


def inject_valve_stiction(cycle: np.ndarray, cfg: DatasetConfig,
                          rng: np.random.Generator, severity: float) -> np.ndarray:
    """Valve stiction: slower pressure dynamics via low-pass filtering."""
    faulty = cycle.copy()
    alpha = _interp_from_severity(severity, *cfg.valve_stiction_alpha_range)
    P = faulty[:, 2]
    P_slow = _lowpass_first_order(P, alpha)
    faulty[:, 2] = np.clip(P_slow, cfg.P_min, cfg.P_max)
    return faulty


def inject_vacuum_anomaly(cycle: np.ndarray, cfg: DatasetConfig,
                          rng: np.random.Generator, severity: float) -> np.ndarray:
    """Vacuum anomaly: localized Gaussian-shaped pressure drop near mid-cycle."""
    faulty = cycle.copy()
    T = faulty.shape[0]
    center = T // 2

    base_width = 80
    width = int(base_width * (1.2 - 0.7 * severity))  # narrower for high severity
    drop_psi = _interp_from_severity(severity, *cfg.vacuum_drop_range)

    idx = np.arange(T)
    dist2 = (idx - center) ** 2
    sigma2 = (width / 3.0) ** 2
    pulse = drop_psi * np.exp(-dist2 / (2.0 * sigma2))

    P = faulty[:, 2]
    P_faulty = np.clip(P - pulse, cfg.P_min, cfg.P_max)
    faulty[:, 2] = P_faulty
    return faulty


def inject_cdra_degradation(cycle: np.ndarray, cfg: DatasetConfig,
                            rng: np.random.Generator, severity: float) -> np.ndarray:
    """CDRA degradation: whole-cycle CO₂ baseline elevated."""
    faulty = cycle.copy()
    offset = _interp_from_severity(severity, *cfg.cdra_offset_range)
    faulty[:, 1] += offset
    faulty[:, 1] = np.clip(faulty[:, 1], cfg.CO2_min, cfg.CO2_max)
    return faulty


def inject_oga_degradation(cycle: np.ndarray, cfg: DatasetConfig,
                           rng: np.random.Generator, severity: float) -> np.ndarray:
    """OGA degradation: whole-cycle O₂ baseline reduced."""
    faulty = cycle.copy()
    reduction = _interp_from_severity(severity, *cfg.oga_offset_range)
    faulty[:, 0] -= reduction
    faulty[:, 0] = np.clip(faulty[:, 0], cfg.O2_min, cfg.O2_max)
    return faulty


def apply_system_fault(system_class_id: int, cycle: np.ndarray,
                       cfg: DatasetConfig, rng: np.random.Generator,
                       severity: float) -> np.ndarray:
    """Dispatch to system fault injector based on class id."""
    if system_class_id == 0:
        return cycle  # Nominal
    elif system_class_id == 1:
        return inject_co2_leak(cycle, cfg, rng, severity)
    elif system_class_id == 2:
        return inject_valve_stiction(cycle, cfg, rng, severity)
    elif system_class_id == 3:
        return inject_vacuum_anomaly(cycle, cfg, rng, severity)
    elif system_class_id == 4:
        return inject_cdra_degradation(cycle, cfg, rng, severity)
    elif system_class_id == 5:
        return inject_oga_degradation(cycle, cfg, rng, severity)
    else:
        raise ValueError(f"Unknown system_class_id: {system_class_id}")


# ============================================================
# SENSOR FAULT MODELS
# ============================================================

def apply_sensor_bias_drift(cycle: np.ndarray, cfg: DatasetConfig,
                            rng: np.random.Generator) -> np.ndarray:
    """Slow additive drift over the cycle for one random sensor channel."""
    faulty = cycle.copy()
    T = faulty.shape[0]
    ch = rng.integers(0, 3)

    end_drift = rng.uniform(-cfg.bias_drift_max, cfg.bias_drift_max)
    drift = np.linspace(0.0, end_drift, T)
    faulty[:, ch] += drift
    return faulty


def apply_sensor_high_noise(cycle: np.ndarray, cfg: DatasetConfig,
                            rng: np.random.Generator) -> np.ndarray:
    """Increase noise for some channels."""
    faulty = cycle.copy()
    T = faulty.shape[0]

    n_ch = rng.integers(1, 4)
    channels = rng.choice([0, 1, 2], size=n_ch, replace=False)

    noise_scales = {
        0: cfg.O2_noise_std,
        1: cfg.CO2_noise_std,
        2: cfg.P_noise_std,
    }

    for ch in channels:
        std = noise_scales[ch] * cfg.high_noise_factor
        faulty[:, ch] += rng.normal(0.0, std, T)
    return faulty


def apply_sensor_partial_freeze(cycle: np.ndarray, cfg: DatasetConfig,
                                rng: np.random.Generator) -> np.ndarray:
    """Segment of the time series where the sensor output freezes."""
    faulty = cycle.copy()
    T = faulty.shape[0]
    ch = rng.integers(0, 3)

    frac_len = rng.uniform(cfg.freeze_min_fraction, cfg.freeze_max_fraction)
    seg_len = int(T * frac_len)
    if seg_len <= 1:
        return faulty

    start = rng.integers(0, T - seg_len)
    end = start + seg_len

    frozen_value = faulty[start, ch]
    faulty[start:end, ch] = frozen_value
    return faulty


def apply_sensor_spike_outliers(cycle: np.ndarray, cfg: DatasetConfig,
                                rng: np.random.Generator) -> np.ndarray:
    """Random spike outliers on a random channel."""
    faulty = cycle.copy()
    T = faulty.shape[0]
    ch = rng.integers(0, 3)

    n_spikes = rng.integers(cfg.n_spikes_min, cfg.n_spikes_max + 1)
    indices = rng.integers(0, T, size=n_spikes)

    sign = rng.choice([-1.0, 1.0], size=n_spikes)
    mag = rng.uniform(cfg.spike_magnitude_range[0],
                      cfg.spike_magnitude_range[1],
                      size=n_spikes)
    faulty[indices, ch] += sign * mag
    return faulty


def apply_random_sensor_fault(cycle: np.ndarray, cfg: DatasetConfig,
                              rng: np.random.Generator) -> Tuple[np.ndarray, int, str]:
    """Randomly choose a sensor fault type (excluding None) and apply."""
    fault_id, fault_name = SENSOR_FAULTS[rng.integers(1, len(SENSOR_FAULTS))]

    if fault_id == 1:
        new_cycle = apply_sensor_bias_drift(cycle, cfg, rng)
    elif fault_id == 2:
        new_cycle = apply_sensor_high_noise(cycle, cfg, rng)
    elif fault_id == 3:
        new_cycle = apply_sensor_partial_freeze(cycle, cfg, rng)
    elif fault_id == 4:
        new_cycle = apply_sensor_spike_outliers(cycle, cfg, rng)
    else:
        raise ValueError(f"Unexpected sensor fault id: {fault_id}")

    return new_cycle, fault_id, fault_name


# ============================================================
# SAFETY FLAGS
# ============================================================

def compute_safety_flags(cycle: np.ndarray, cfg: DatasetConfig) -> Dict[str, bool]:
    """Compute simple NASA-inspired safety flags on a cycle."""
    O2 = cycle[:, 0]
    CO2 = cycle[:, 1]
    P = cycle[:, 2]

    mean_O2 = float(O2.mean())
    max_CO2 = float(CO2.max())
    mean_P = float(P.mean())

    return {
        "flag_CO2_warn": max_CO2 > cfg.CO2_warn,
        "flag_CO2_crit": max_CO2 > cfg.CO2_crit,
        "flag_O2_warn_low": mean_O2 < cfg.O2_warn_low,
        "flag_O2_crit_low": mean_O2 < cfg.O2_crit_low,
        "flag_P_warn": not (cfg.P_warn_low <= mean_P <= cfg.P_warn_high),
        "flag_P_crit": not (cfg.P_crit_low <= mean_P <= cfg.P_crit_high),
    }


# ============================================================
# DATASET GENERATION
# ============================================================

def generate_dataset(cfg: DatasetConfig,
                     out_dir: str | Path = DEFAULT_DATA_DIR
                    ) -> Dict[str, np.ndarray]:
    """
    Generate full dataset with:
      - 6 system classes
      - 5 sensor fault types
      - progressive severity per mode
      - safety flags
    """
    rng = np.random.default_rng(cfg.random_seed)
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    all_cycles: List[np.ndarray] = []
    system_labels: List[int] = []
    sensor_labels: List[int] = []
    meta_rows: List[Dict] = []

    sample_id = 0

    for class_id, class_name in SYSTEM_CLASSES:
        n = cfg.n_samples_per_system_class

        for k in range(n):
            # Severity: linearly increasing with some noise
            base_frac = (k + 0.5) / n
            severity = float(np.clip(base_frac + rng.normal(0.0, 0.1), 0.0, 1.0))

            cycle_nominal = generate_nominal_cycle(cfg, rng)
            cycle_sys = apply_system_fault(class_id, cycle_nominal, cfg, rng, severity)

            # Sensor fault?
            if cfg.enable_sensor_faults and rng.random() < cfg.p_sensor_fault:
                cycle_faulty, sensor_fault_id, sensor_fault_name = apply_random_sensor_fault(
                    cycle_sys, cfg, rng
                )
            else:
                cycle_faulty = cycle_sys
                sensor_fault_id, sensor_fault_name = SENSOR_FAULTS[0]

            # Enforce physical sensor range after sensor faults
            cycle_final = cycle_faulty.copy()
            cycle_final[:, 0] = np.clip(cycle_final[:, 0], cfg.O2_min, cfg.O2_max)
            cycle_final[:, 1] = np.clip(cycle_final[:, 1], cfg.CO2_min, cfg.CO2_max)
            cycle_final[:, 2] = np.clip(cycle_final[:, 2], cfg.P_min, cfg.P_max)

            flags = compute_safety_flags(cycle_final, cfg)


            all_cycles.append(cycle_final)
            system_labels.append(class_id)
            sensor_labels.append(sensor_fault_id)

            row = {
                "sample_id": sample_id,
                "system_class_id": class_id,
                "system_class_name": class_name,
                "sensor_fault_id": sensor_fault_id,
                "sensor_fault_name": sensor_fault_name,
                "severity": severity,
                "replicate_idx": k,
            }
            row.update(flags)
            meta_rows.append(row)

            sample_id += 1

    data_3d = np.stack(all_cycles, axis=0)             # (N, T, 3)
    labels_system = np.array(system_labels, dtype=int)
    labels_sensor = np.array(sensor_labels, dtype=int)
    data_flat = data_3d.reshape(data_3d.shape[0], -1)  # (N, 3T)

    # Quick sanity prints
    print("=== BASIC VALIDATION (quick) ===")
    print("Shape (N, T, C):", data_3d.shape)
    print("System class counts:", np.bincount(labels_system))
    print("Sensor fault counts:", np.bincount(labels_sensor))

    O2 = data_3d[:, :, 0]
    CO2 = data_3d[:, :, 1]
    P = data_3d[:, :, 2]

    print("O2 in [0,100]:", bool(np.all((O2 >= cfg.O2_min) & (O2 <= cfg.O2_max))))
    print("CO2 in [0,5]:  ", bool(np.all((CO2 >= cfg.CO2_min) & (CO2 <= cfg.CO2_max))))
    print("P in [0,20]:   ", bool(np.all((P >= cfg.P_min) & (P <= cfg.P_max))))

    df_meta = pd.DataFrame(meta_rows)
    for class_id, class_name in SYSTEM_CLASSES:
        idx = np.where(labels_system == class_id)[0]
        O2_c = O2[idx].mean()
        CO2_c = CO2[idx].mean()
        P_c = P[idx].mean()
        print(f"Class {class_id} ({class_name}): "
              f"mean O2={O2_c:.3f}, CO2={CO2_c:.3f}, P={P_c:.3f}")

    # Save arrays + metadata
    np.save(out_path / "cycles_3d.npy", data_3d)
    np.save(out_path / "cycles_flat.npy", data_flat)
    np.save(out_path / "labels_system.npy", labels_system)
    np.save(out_path / "labels_sensor.npy", labels_sensor)

    df_meta.to_csv(out_path / "metadata.csv", index=False)
    with open(out_path / "config.json", "w") as f:
        json.dump(asdict(cfg), f, indent=2)

    print(f"\nSaved dataset to: {out_path.resolve()}")

    return {
        "data_3d": data_3d,
        "data_flat": data_flat,
        "labels_system": labels_system,
        "labels_sensor": labels_sensor,
        "metadata": df_meta,
        "out_dir": out_dir,
    }


# ============================================================
# COMPREHENSIVE VALIDATION
# ============================================================

def validate_dataset(data_3d: np.ndarray,
                     labels_system: np.ndarray,
                     cfg: DatasetConfig) -> bool:
    """
    Comprehensive validation:
      - Physical constraints
      - NaN / Inf
      - Temporal coherence (lag-1 autocorr)
      - Optional class separability (PCA + silhouette, if sklearn available)
    """
    print("\n" + "=" * 60)
    print("COMPREHENSIVE VALIDATION")
    print("=" * 60)

    all_pass = True

    O2 = data_3d[:, :, 0]
    CO2 = data_3d[:, :, 1]
    P = data_3d[:, :, 2]

    # 1. Physical constraints
    O2_valid = np.all((O2 >= cfg.O2_min) & (O2 <= cfg.O2_max))
    CO2_valid = np.all((CO2 >= cfg.CO2_min) & (CO2 <= cfg.CO2_max))
    P_valid = np.all((P >= cfg.P_min) & (P <= cfg.P_max))

    print("\n1. Physical Constraints:")
    print(f"   O2 in [{cfg.O2_min}, {cfg.O2_max}]: {'✅' if O2_valid else '❌'}")
    print(f"   CO2 in [{cfg.CO2_min}, {cfg.CO2_max}]: {'✅' if CO2_valid else '❌'}")
    print(f"   P in [{cfg.P_min}, {cfg.P_max}]:   {'✅' if P_valid else '❌'}")

    all_pass &= (O2_valid and CO2_valid and P_valid)

    # 2. NaN / Inf
    has_nan = np.isnan(data_3d).any()
    has_inf = np.isinf(data_3d).any()

    print("\n2. Invalid Values:")
    print(f"   No NaN: {'✅' if not has_nan else '❌'}")
    print(f"   No Inf: {'✅' if not has_inf else '❌'}")

    all_pass &= (not has_nan and not has_inf)

    # 3. Temporal coherence via lag-1 autocorrelation
    print("\n3. Temporal Coherence:")
    autocorrs = []
    n_samples = min(20, len(data_3d))
    for i in range(n_samples):
        for sensor_idx in range(3):
            signal = data_3d[i, :, sensor_idx]
            if signal.std() > 0:
                x = signal[:-1]
                y = signal[1:]
                cov = np.cov(x, y)[0, 1]
                denom = np.std(x) * np.std(y)
                if denom > 0:
                    acf = cov / denom
                    autocorrs.append(acf)

    if autocorrs:
        mean_acf = float(np.mean(autocorrs))
        temporal_ok = mean_acf > 0.7
        print(f"   Mean lag-1 autocorr: {mean_acf:.3f} "
              f"{'✅' if temporal_ok else '⚠️'}")
    else:
        temporal_ok = False
        print("   Could not compute autocorrelation (degenerate signals) ⚠️")

    # 4. Class separability (optional, requires sklearn)
    print("\n4. Class Separability (PCA + silhouette):")
    try:
        from sklearn.decomposition import PCA
        from sklearn.metrics import silhouette_score

        data_flat = data_3d.reshape(len(data_3d), -1)
        pca = PCA(n_components=2)
        data_2d = pca.fit_transform(data_flat)
        sil_score = float(silhouette_score(data_2d, labels_system))
        separable = sil_score > 0.2
        print(f"   Silhouette score (2D PCA): {sil_score:.3f} "
              f"{'✅' if separable else '⚠️'}")
    except ImportError:
        print("   sklearn not available → skipping separability check (OK).")

    print("\n" + "=" * 60)
    print(f"OVERALL PHYSICAL/NUMERIC: {'✅ PASS' if all_pass else '❌ FAIL'}")
    print("=" * 60 + "\n")

    return all_pass


# ============================================================
# VISUAL VALIDATION (OVERLAY PLOTS)
# ============================================================

def visualize_samples(data_3d: np.ndarray,
                      labels_system: np.ndarray,
                      out_dir: str = "eclss_synthetic_dataset_full") -> None:
    """Generate overlay plots (O2, CO2, P) for each system class."""
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("Matplotlib not available, skipping visualization.")
        return

    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    n_classes = len(SYSTEM_CLASSES)
    sensor_names = ['O2 (%)', 'CO2 (%)', 'Pressure (psi)']

    fig, axes = plt.subplots(n_classes, 3, figsize=(15, 2.8 * n_classes), sharex=True)
    if n_classes == 1:
        axes = np.expand_dims(axes, axis=0)  # ensure 2D

    for row_idx, (class_id, class_name) in enumerate(SYSTEM_CLASSES):
        class_data = data_3d[labels_system == class_id]
        if class_data.size == 0:
            continue

        for sensor_idx, sensor_name in enumerate(sensor_names):
            ax = axes[row_idx, sensor_idx]

            n_plot = min(10, len(class_data))
            for sample in class_data[:n_plot]:
                ax.plot(sample[:, sensor_idx], alpha=0.3, linewidth=0.8)

            mean_signal = class_data[:, :, sensor_idx].mean(axis=0)
            ax.plot(mean_signal, 'k-', linewidth=2.0, label='Mean')

            if sensor_idx == 0:
                ax.set_ylabel(class_name, fontsize=10, fontweight='bold')
            if row_idx == 0:
                ax.set_title(sensor_name, fontsize=11, fontweight='bold')
            if row_idx == n_classes - 1:
                ax.set_xlabel('Timestep', fontsize=9)

            ax.grid(True, alpha=0.3)
            ax.legend(fontsize=7, loc='upper right')

    plt.tight_layout()
    fig_file = out_path / "dataset_validation_overlay.png"
    plt.savefig(fig_file, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"✅ Visual validation saved: {fig_file}")


# ============================================================
# SUMMARY REPORT
# ============================================================

def generate_summary_report(data_3d: np.ndarray,
                            labels_system: np.ndarray,
                            labels_sensor: np.ndarray,
                            df_meta: pd.DataFrame,
                            cfg: DatasetConfig,
                            out_dir: str) -> None:
    """Generate a human-readable markdown summary report."""
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    O2 = data_3d[:, :, 0]
    CO2 = data_3d[:, :, 1]
    P = data_3d[:, :, 2]

    report = f"""# ECLSS Synthetic Dataset Summary

## Dataset Statistics

- Total Samples: {len(data_3d)}
- Timesteps per Sample: {cfg.n_timesteps}
- Sensors: 3 (O₂, CO₂, Pressure)
- Sampling Rate: {cfg.sampling_rate_hz} Hz
- Cycle Duration: {cfg.n_timesteps / cfg.sampling_rate_hz / 60:.1f} minutes

## System Class Distribution

| Class ID | Class Name | Count | Percentage |
|----------|------------|-------|------------|
"""
    for class_id, class_name in SYSTEM_CLASSES:
        count = int((labels_system == class_id).sum())
        pct = count / len(labels_system) * 100
        report += f"| {class_id} | {class_name} | {count} | {pct:.1f}% |\n"

    report += "\n## Sensor Fault Distribution\n\n"
    report += "| Fault ID | Fault Name | Count | Percentage |\n"
    report += "|----------|------------|-------|------------|\n"

    for fault_id, fault_name in SENSOR_FAULTS:
        count = int((labels_sensor == fault_id).sum())
        pct = count / len(labels_sensor) * 100
        report += f"| {fault_id} | {fault_name} | {count} | {pct:.1f}% |\n"

    # Safety flags summary
    report += "\n## Safety Flags Triggered\n\n"
    for col in df_meta.columns:
        if col.startswith("flag_"):
            count = int(df_meta[col].sum())
            pct = count / len(df_meta) * 100
            report += f"- {col}: {count} samples ({pct:.1f}%)\n"

    # Sensor statistics
    report += f"""

## Sensor Value Ranges

| Sensor | Min | Max | Mean | Std |
|--------|-----|-----|------|-----|
| O₂ (%) | {O2.min():.3f} | {O2.max():.3f} | {O2.mean():.3f} | {O2.std():.3f} |
| CO₂ (%) | {CO2.min():.3f} | {CO2.max():.3f} | {CO2.mean():.3f} | {CO2.std():.3f} |
| Pressure (psi) | {P.min():.3f} | {P.max():.3f} | {P.mean():.3f} | {P.std():.3f} |

## Configuration (JSON)

{json.dumps(asdict(cfg), indent=2)}

---

*Generated on {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}*
"""
    # ============================================================
    # Write summary file
    # ============================================================
    summary_file = out_path / "dataset_summary.md"
    with open(summary_file, "w", encoding="utf-8") as f:
        f.write(report)

    print(f"✅ Summary report saved: {summary_file}")



# ============================================================
# MAIN
# ============================================================

if __name__ == "__main__":
    cfg = DatasetConfig(
        n_samples_per_system_class=30,
        n_timesteps=1000,
        sampling_rate_hz=2.0,
    )

    out = generate_dataset(cfg)
    passed = validate_dataset(out["data_3d"], out["labels_system"], cfg)

    if passed:
        visualize_samples(out["data_3d"], out["labels_system"], out["out_dir"])
        generate_summary_report(
            out["data_3d"],
            out["labels_system"],
            out["labels_sensor"],
            out["metadata"],
            cfg,
            out["out_dir"],
        )
        print("\n✅ Dataset generation, validation, plots, and summary complete.")
        print("   Ready for VAE/SVM training.")
    else:
        print("\n❌ Validation failed. Please inspect logs before training.")


=== BASIC VALIDATION (quick) ===
Shape (N, T, C): (180, 1000, 3)
System class counts: [30 30 30 30 30 30]
Sensor fault counts: [153   5   4   8  10]
O2 in [0,100]: True
CO2 in [0,5]:   True
P in [0,20]:    True
Class 0 (Nominal): mean O2=20.908, CO2=0.301, P=14.697
Class 1 (CO2_Leak): mean O2=20.901, CO2=0.356, P=14.708
Class 2 (Valve_Stiction): mean O2=20.901, CO2=0.305, P=14.702
Class 3 (Vacuum_Anomaly): mean O2=20.902, CO2=0.300, P=14.635
Class 4 (CDRA_Degradation): mean O2=20.901, CO2=0.571, P=14.700
Class 5 (OGA_Degradation): mean O2=18.693, CO2=0.302, P=14.695

Saved dataset to: C:\Users\ahasa\data\eclss_synthetic_dataset_full

COMPREHENSIVE VALIDATION

1. Physical Constraints:
   O2 in [0.0, 100.0]: ✅
   CO2 in [0.0, 5.0]: ✅
   P in [0.0, 20.0]:   ✅

2. Invalid Values:
   No NaN: ✅
   No Inf: ✅

3. Temporal Coherence:
   Mean lag-1 autocorr: 0.890 ✅

4. Class Separability (PCA + silhouette):
   Silhouette score (2D PCA): 0.209 ✅

OVERALL PHYSICAL/NUMERIC: ✅ PASS

✅ Visual valida