<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/visualize_states_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
import argparse
import json
import math
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import matplotlib.pyplot as plt

try:
    import pandas as pd
except ImportError:
    pd = None


def load_artifacts(dir_path: Path) -> List[Dict]:
    """Load every ``*.json`` artifact found directly inside ``dir_path``.

    Files are visited in sorted path order and decoded as UTF-8 (the JSON
    interchange encoding) instead of the platform default. Loading is
    best-effort: files that cannot be read, are not valid JSON, or whose
    top-level value is not an object are skipped rather than aborting the run.

    Args:
        dir_path: Directory to scan (a ``str`` is accepted too); not recursive.

    Returns:
        The parsed artifact dicts. Each one gains a ``"_file"`` key holding its
        path as a string.
    """
    artifacts = []
    for p in sorted(Path(dir_path).glob("*.json")):
        try:
            with p.open("r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError, RecursionError):
            # Unreadable file, undecodable bytes or malformed/over-nested JSON
            # (UnicodeDecodeError and JSONDecodeError are ValueError subclasses).
            continue
        if not isinstance(data, dict):
            # e.g. a bare JSON list: not an artifact, and nowhere to put "_file".
            continue
        data["_file"] = str(p)
        artifacts.append(data)
    return artifacts


def _score_of(artifact: Dict) -> float:
    """Return an artifact's numeric ``score``; missing/non-numeric/NaN rank as -inf."""
    try:
        score = float(artifact.get("score"))
    except (TypeError, ValueError, OverflowError):
        return -math.inf
    # NaN compares False with everything and would corrupt the ranking.
    return -math.inf if math.isnan(score) else score


def _find_artifact(artifacts: List[Dict], artifact_file: str) -> Dict:
    """Return the artifact whose ``"_file"`` refers to ``artifact_file``.

    Matching is tried in order: identical string, same file after path
    resolution, then (for a bare file name only) a unique file-name match.
    """
    for a in artifacts:
        if a.get("_file") == artifact_file:
            return a
    wanted = Path(artifact_file).resolve()
    for a in artifacts:
        f = a.get("_file")
        if isinstance(f, str) and Path(f).resolve() == wanted:
            return a
    if Path(artifact_file).name == artifact_file:
        by_name = [
            a for a in artifacts
            if isinstance(a.get("_file"), str) and Path(a["_file"]).name == artifact_file
        ]
        if len(by_name) == 1:
            return by_name[0]
    raise FileNotFoundError(f"Artifact not found: {artifact_file}")


def pick_artifact(artifacts: List[Dict], artifact_file: Optional[str]) -> Dict:
    """Choose the artifact to visualize.

    Args:
        artifacts: Artifact dicts, each carrying a ``"_file"`` path string.
        artifact_file: Optional specific artifact. It matches an artifact whose
            ``"_file"`` is the same string, points at the same file once resolved
            (``./run/a.json`` vs ``run/a.json``), or, when only a bare file name
            is given, has that file name uniquely.

    Returns:
        The requested artifact. Otherwise, the one with the highest numeric
        ``"score"``: missing, non-numeric and NaN scores rank lowest, and ties
        keep the earliest artifact in ``artifacts``.

    Raises:
        FileNotFoundError: ``artifact_file`` matches nothing, or ``artifacts``
            is empty.
    """
    if artifact_file:
        return _find_artifact(artifacts, artifact_file)
    if not artifacts:
        raise FileNotFoundError("No JSON artifacts found.")
    # max() returns the first maximal element, same as a stable descending sort.
    return max(artifacts, key=_score_of)


def _derivative_variable(eq: str) -> Optional[str]:
    """Return the state named on the left of ``"d<state>/dt = ..."``, or None."""
    lhs, sep, _ = eq.partition("=")
    if not sep or "/dt" not in lhs:
        return None
    head = lhs.split("/dt", 1)[0].strip()
    # Remove exactly one leading 'd' (the derivative operator) so state names
    # that themselves start with 'd' survive: "ddx/dt" -> "dx".
    if head.startswith("d"):
        head = head[1:].strip()
    return head or None


def _linear_terms(rhs: str, var_index: Dict[str, int]) -> List[Tuple[int, float]]:
    """Return ``(column, coefficient)`` pairs for the linear terms of ``rhs``.

    A term is a signed product of numeric literals and exactly one known state,
    e.g. ``0.5*x``, ``x*2``, ``-y``, ``1e-3*z``. Constants, unknown symbols and
    products of several states are skipped: they have no place in a
    homogeneous linear system ``dx/dt = A x``.
    """
    import re

    terms = []
    sign = 1.0
    # Split before every '+'/'-' except the exponent sign of a float literal
    # such as "1e-3" or "2.E+4".
    for chunk in re.split(r"(?<![\d.][eE])(?=[+-])", rhs):
        body = chunk.strip()
        # Fold leading signs; "x - -y" arrives as "x", "-", "-y".
        while body[:1] in ("+", "-"):
            if body[0] == "-":
                sign = -sign
            body = body[1:].lstrip()
        if not body:
            continue  # sign-only chunk: its sign carries over to the next term
        term_sign, sign = sign, 1.0
        coeff, col, linear = 1.0, None, True
        for factor in body.split("*"):
            factor = factor.strip()
            if factor in var_index:
                if col is not None:
                    linear = False  # product of two states
                    break
                col = var_index[factor]
            else:
                try:
                    coeff *= float(factor)
                except ValueError:
                    linear = False  # unknown symbol or malformed factor
                    break
        if linear and col is not None:
            terms.append((col, term_sign * coeff))
    return terms


def extract_A(art: Dict) -> np.ndarray:
    """Build the matrix ``A`` of the linear system ``dx/dt = A x`` from an artifact.

    Explicit matrix fields are checked first, in the order ``A``,
    ``coef_matrix``, ``coeff_matrix``, ``matrix``; the first present field
    whose value is a 2-D array is returned. Otherwise a list of equation
    strings under ``equations`` (or ``system``) is parsed, e.g.::

        ["dx/dt = -1.0*x + 0.0*y", "dy/dt = 0.2*x - 0.8*y"]

    Each ``d<state>/dt = ...`` equation fills the row of its state. Rows and
    columns follow the order in which states first appear on left-hand sides.
    Right-hand terms may be ``c*state``, ``state*c`` or ``±state``, with
    coefficients in scientific notation if needed. Constant terms, unknown
    symbols and products of states are ignored.

    Raises:
        ValueError: If neither a 2-D matrix field nor any derivative equation
            is present.
    """
    for key in ("A", "coef_matrix", "coeff_matrix", "matrix"):
        if key in art:
            A = np.array(art[key], dtype=float)
            if A.ndim == 2:
                return A

    eqs = art.get("equations") or art.get("system") or []
    if isinstance(eqs, list):
        rows = []  # (state, right-hand side) per derivative equation
        for eq in eqs:
            if not isinstance(eq, str):
                continue
            var = _derivative_variable(eq)
            if var is not None:
                rows.append((var, eq.partition("=")[2]))
        var_index: Dict[str, int] = {}
        for var, _ in rows:
            var_index.setdefault(var, len(var_index))
        if var_index:
            n = len(var_index)
            A = np.zeros((n, n), dtype=float)
            for var, rhs in rows:
                i = var_index[var]
                for j, c in _linear_terms(rhs, var_index):
                    A[i, j] += c
            return A
    raise ValueError("Could not extract A-matrix from artifact.")


def extract_truth_from_artifact(art: Dict) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """Locate a ground-truth trajectory and its time axis inside an artifact.

    Truth is the first list-valued key among ``true_state``, ``X_true``,
    ``y_true``, ``states_true``. It is searched at the top level of ``art``
    first, then in a ``data`` sub-dict, then in a ``dataset`` sub-dict.

    The time axis is a list-valued ``time`` key. It is taken, in order of
    preference, from:
      1. the top level;
      2. the same place the truth was found;
      3. the first of ``data``/``dataset`` that has one.

    This keeps the time axis paired with its truth. It avoids dropping a
    ``data.time`` axis when the truth is top-level, and avoids pairing truth
    from ``dataset`` with an unrelated ``data.time``.

    Returns:
        ``(t, X_true)`` as float arrays; either may be None. A 1-D truth array
        is reshaped to a single-column ``(T, 1)`` matrix.
    """
    truth_keys = ("true_state", "X_true", "y_true", "states_true")

    def time_in(src: Dict) -> Optional[np.ndarray]:
        value = src.get("time")
        return np.array(value, dtype=float) if isinstance(value, list) else None

    def truth_in(src: Dict) -> Optional[np.ndarray]:
        for key in truth_keys:
            if isinstance(src.get(key), list):
                return np.array(src[key], dtype=float)
        return None

    containers = [art[c] for c in ("data", "dataset") if isinstance(art.get(c), dict)]

    t = time_in(art)
    X_true = None
    for src in [art] + containers:
        X_true = truth_in(src)
        if X_true is not None:
            if t is None:
                t = time_in(src)
            break
    if t is None:
        for src in containers:
            t = time_in(src)
            if t is not None:
                break

    # Ensure shape (T, n)
    if X_true is not None and X_true.ndim == 1:
        X_true = X_true.reshape(-1, 1)
    return t, X_true


def load_csv_truth(csv_path: Optional[str], time_col: Optional[str], state_cols: List[str]) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """Read a ground-truth trajectory from a CSV file.

    Args:
        csv_path: CSV file to read. A falsy value means "no CSV" and yields
            ``(None, None)``.
        time_col: Column with sample times. If falsy or not present in the file,
            no time axis is returned.
        state_cols: Columns holding the state. When empty, every numeric column
            other than ``time_col`` is used.

    Returns:
        ``(t, X_true)`` as float arrays; ``t`` may be None.

    Raises:
        RuntimeError: If a CSV is requested but pandas is not installed.
    """
    if not csv_path:
        return None, None
    if pd is None:
        raise RuntimeError("pandas is required for CSV input. Install with: pip install pandas")

    frame = pd.read_csv(csv_path)

    times = None
    if time_col and time_col in frame.columns:
        times = frame[time_col].to_numpy(dtype=float)

    excluded = time_col or ""
    columns = state_cols or [
        name
        for name in frame.select_dtypes(include=[np.number]).columns.tolist()
        if name != excluded
    ]
    return times, frame[columns].to_numpy(dtype=float)


def rk4_step(f, t, x, dt):
    """Advance ``x`` by one classical fourth-order Runge-Kutta step.

    Args:
        f: Right-hand side, called as ``f(t, x)`` and returning dx/dt.
        t: Time at the start of the step.
        x: State at time ``t`` (a scalar or NumPy array).
        dt: Step size.

    Returns:
        The RK4 estimate of the state at ``t + dt``.
    """
    half = 0.5 * dt
    slope_start = f(t, x)
    slope_mid_a = f(t + half, x + half * slope_start)
    slope_mid_b = f(t + half, x + half * slope_mid_a)
    slope_end = f(t + dt, x + dt * slope_mid_b)
    weighted = slope_start + 2 * slope_mid_a + 2 * slope_mid_b + slope_end
    return x + (dt / 6.0) * weighted


def integrate_linear(A: np.ndarray, t: np.ndarray, x0: np.ndarray) -> np.ndarray:
    """Integrate ``dx/dt = A x`` over the time grid ``t`` with RK4.

    One ``rk4_step`` is taken per interval ``t[k-1] -> t[k]``, so non-uniform
    spacing is followed exactly.

    Args:
        A: Square ``(n, n)`` state matrix (array-like is accepted).
        t: 1-D sample times; ``t[0]`` is the time at which ``x0`` holds.
        x0: Initial state of length ``n``.

    Returns:
        Array of shape ``(len(t), n)`` whose row ``k`` approximates ``x(t[k])``.
        Row 0 is ``x0``. An empty ``t`` yields an empty ``(0, n)`` array
        instead of an IndexError.
    """
    A = np.asarray(A, dtype=float)
    n = A.shape[0]
    X = np.zeros((len(t), n), dtype=float)
    if len(t) == 0:
        return X
    X[0] = x0

    def f(_t, x):
        return A.dot(x)

    for k in range(1, len(t)):
        dt = float(t[k] - t[k-1])
        X[k] = rk4_step(f, t[k-1], X[k-1], dt)
    return X


def summarize_errors(t: np.ndarray, X_true: np.ndarray, X_pred: np.ndarray) -> Dict:
    """Compute RMSE and MAE between predicted and true trajectories.

    Args:
        t: Sample times. Not used in the computation; kept for interface
            compatibility.
        X_true: True states, shape ``(T, n)``.
        X_pred: Predicted states, same shape as ``X_true``.

    Returns:
        Dict with ``rmse_per_state`` and ``mae_per_state`` (per column, via
        ``.tolist()``) and the scalar floats ``rmse_overall`` and ``mae_overall``.

    Raises:
        ValueError: If the two arrays differ in shape. Without this check, NumPy
            broadcasting (e.g. ``(T,)`` vs ``(T, 1)``) would silently produce
            meaningless metrics.
    """
    X_true = np.asarray(X_true, dtype=float)
    X_pred = np.asarray(X_pred, dtype=float)
    if X_true.shape != X_pred.shape:
        raise ValueError(f"Shape mismatch: X_true {X_true.shape} vs X_pred {X_pred.shape}")
    err = X_pred - X_true
    sq_err = err**2
    abs_err = np.abs(err)
    return {
        "rmse_per_state": np.sqrt(np.mean(sq_err, axis=0)).tolist(),
        "mae_per_state": np.mean(abs_err, axis=0).tolist(),
        "rmse_overall": float(np.sqrt(np.mean(sq_err))),
        "mae_overall": float(np.mean(abs_err)),
    }


def plot_comparison(outdir: Path, t: np.ndarray, X_true: np.ndarray, X_pred: np.ndarray, artifact_file: str, divergence_frac: float):
    outdir.mkdir(parents=True, exist_ok=True)
    n = X_true.shape[1]
    fig, axes = plt.subplots(nrows=n, ncols=1, figsize=(10, 3.2*n), sharex=True)
    if n == 1:
        axes = [axes]
    # threshold based on true dynamic range per state
    ranges = np.maximum(np.ptp(X_true, axis=0), 1e-9)
    thresholds = divergence_frac * ranges
    for i, ax in enumerate(axes):
        ax.plot(t, X_true[:, i], label=f"True s{i}", color="#1f77b4", lw=2)
        ax.plot(t, X_pred[:, i], label=f"Pred s{i}", color="#ff7f0e", lw=1.8, alpha=0.9)
        # shade divergence
        abs_err = np.abs(X_pred[:, i] - X_true[:, i])
        diverge_mask = abs_err > thresholds[i]
        if np.any(diverge_mask):
            # group contiguous segments
            idx = np.where(diverge_mask)[0]
            splits = np.split(idx, np.where(np.diff(idx) != 1)[0] + 1)
            for seg in splits:
                ax.axvspan(t[seg[0]], t[seg[-1]], color="#d62728", alpha=0.12)
        ax.set_ylabel(f"s{i}")
        ax.grid(True, alpha=0.25)
        ax.legend(loc="upper right", frameon=False)
    axes[-1].set_xlabel("time")
    fig.suptitle(f"Predicted vs True (artifact: {os.path.basename(artifact_file)})", y=0.995)
    fig.tight_layout(rect=[0, 0.02, 1, 0.98])
    png_path = outdir / "state_comparison.png"
    fig.savefig(png_path, dpi=150)
    plt.close(fig)
    return str(png_path)


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser used by :func:`main`."""
    parser = argparse.ArgumentParser(description="Visualize predicted vs true state curves from artifacts.")
    parser.add_argument("--artifacts", type=str, default="rge_run", help="Directory with artifact JSONs.")
    parser.add_argument("--artifact-file", type=str, default=None, help="Specific artifact JSON to use.")
    parser.add_argument("--csv", type=str, default=None, help="Optional CSV with ground-truth time/states.")
    parser.add_argument("--time-col", type=str, default=None, help="Time column in CSV.")
    parser.add_argument("--state-cols", nargs="+", default=None, help="State columns in CSV (e.g., x y z).")
    parser.add_argument("--divergence-frac", type=float, default=0.1, help="Shade where |err| > frac * range(true).")
    parser.add_argument("--x0", nargs="+", type=float, default=None, help="Initial state override (e.g., --x0 1 0).")
    parser.add_argument("--save-dir", type=str, default=None, help="Output directory (default: <--artifacts dir>/plots).")
    return parser


def main(argv: Optional[List[str]] = None):
    """Compare an artifact's linear model against ground truth and record the result.

    Steps:
      1. Load the artifact JSONs from ``--artifacts``.
      2. Pick one (``--artifact-file`` or highest ``score``) and extract its A-matrix.
      3. Load ground truth; CSV truth/time override the artifact's.
      4. Integrate ``dx/dt = A x`` from ``--x0`` or the first true sample.
      5. Save a comparison plot and ``comparison_audit.json`` into the output
         directory, and print the audit record.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``. When running
            inside a Jupyter/Colab kernel with ``argv=None``, arguments the
            parser does not recognise (the kernel's own, e.g. ``-f kernel.json``)
            are ignored instead of aborting with "unrecognized arguments".

    Raises:
        RuntimeError: Missing or empty ground truth, or inconsistent
            time/state/matrix/x0 dimensions.
    """
    import sys

    parser = _build_parser()
    if argv is None and "ipykernel" in sys.modules:
        args, _unknown = parser.parse_known_args()
    else:
        args = parser.parse_args(argv)

    art_dir = Path(args.artifacts)
    artifacts = load_artifacts(art_dir)
    artifact = pick_artifact(artifacts, args.artifact_file)
    A = extract_A(artifact)

    t_art, X_true_art = extract_truth_from_artifact(artifact)
    t_csv, X_true_csv = load_csv_truth(args.csv, args.time_col, args.state_cols or [])

    # Merge truth/time preference: CSV overrides artifact if provided
    t = t_csv if t_csv is not None else t_art
    X_true = X_true_csv if X_true_csv is not None else X_true_art

    if X_true is None:
        raise RuntimeError("No ground-truth states found. Provide CSV via --csv or include true_state in the artifact.")
    if t is None:
        # If no explicit time, assume uniform spacing
        t = np.arange(len(X_true), dtype=float)
    t = np.asarray(t, dtype=float)
    if X_true.shape[0] != len(t):
        raise RuntimeError(f"Time length {len(t)} != truth length {X_true.shape[0]}")
    if len(t) == 0:
        raise RuntimeError("Ground-truth trajectory is empty.")

    n = X_true.shape[1]
    if A.shape != (n, n):
        raise RuntimeError(f"A-matrix shape {A.shape} does not match state dimension {n}")

    # Initial condition: explicit override, else the first true sample
    # (X_true is known to be non-empty here).
    if args.x0 is not None:
        x0 = np.array(args.x0, dtype=float)
        if x0.shape[0] != n:
            raise RuntimeError(f"x0 length {len(x0)} != state dimension {n}")
    else:
        x0 = X_true[0]

    X_pred = integrate_linear(A, t, x0)
    metrics = summarize_errors(t, X_true, X_pred)

    outdir = Path(args.save_dir) if args.save_dir else (art_dir / "plots")
    png_path = plot_comparison(outdir, t, X_true, X_pred, artifact_file=artifact.get("_file", "?"), divergence_frac=args.divergence_frac)

    # Write a compact audit record
    audit = {
        "artifact_file": artifact.get("_file"),
        "score": artifact.get("score"),
        "shape_A": list(A.shape),
        "rmse_overall": metrics["rmse_overall"],
        "mae_overall": metrics["mae_overall"],
        "rmse_per_state": metrics["rmse_per_state"],
        "mae_per_state": metrics["mae_per_state"],
        "plot_path": png_path,
        "time_len": int(len(t)),
        "state_dim": int(X_true.shape[1]),
        "divergence_fraction_threshold": args.divergence_frac,
    }
    (outdir / "comparison_audit.json").write_text(json.dumps(audit, indent=2))
    print(json.dumps(audit, indent=2))


if __name__ == "__main__":
    main()