In [None]:
import os
from pathlib import Path

import pandas as pd
from src import analysis, util
from src.analysis import evaluation

# Expose only physical GPU 2 to CUDA for this kernel.
# NOTE(review): this only takes effect if it runs before CUDA is initialized;
# the `src` imports above execute first — confirm they do not touch CUDA at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [None]:
# ---------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------
# Name of the trained run. The Paths cell uses it as the directory name under
# ../data/processed/ for both the checkpoint and the analysis outputs.
model = "FNO_lhs_var10_plog100_seed9_20251125_171807"

# In-distribution dataset; used as the directory name under the raw-data root
# and as the Parquet artifact filename.
dataset_name_id = "lhs_var10_plog100_seed9"
# Out-of-distribution dataset. A falsy value (None or "") makes
# `dataset_cases_path_ood` None in the Paths cell, which skips OOD artifact loading.
dataset_name_ood = "lhs_var20_plog100_seed9"

In [None]:
# ---------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------
# Weights of the run selected by `model`.
checkpoint_path = Path(f"../data/processed/{model}/best_model_state_dict.pt")

# Output directories for analysis artifacts; ID and OOD results are kept apart.
save_root_id = Path(f"../data/processed/{model}/analysis/id")
save_root_ood = Path(f"../data/processed/{model}/analysis/ood")

# Raw `cases/` folders of the datasets.
# NOTE(review): raw data is resolved two levels up (../../data/raw) while
# processed data is one level up (../data/processed) — confirm this is intended.
dataset_cases_path_id = Path(f"../../data/raw/{dataset_name_id}/cases")

if dataset_name_ood:
    dataset_cases_path_ood = Path(f"../../data/raw/{dataset_name_ood}/cases")
else:
    # No OOD dataset configured: downstream cells check for None and skip OOD.
    dataset_cases_path_ood = None

In [None]:
# ---------------------------------------------------------------------
# Helper: run only if parquet does NOT exist
# ---------------------------------------------------------------------
def run_or_load_artifacts_evaluation(
    dataset_name: str,
    save_root: Path,
    dataset_path: Path,
) -> tuple[pd.DataFrame, Path]:
    """
    Return the evaluation DataFrame for a dataset, generating it only on a cache miss.

    The cache key is the file `<save_root>/<dataset_name>.parquet`. When that file
    exists it is read and returned as-is. Otherwise an inference context is built
    via `analysis.analysis_interference.load_inference_context` and handed to
    `analysis.analysis_artifacts.generate_artifacts`, whose result is returned.
    According to the original notes, that pipeline produces:
        - one NPZ file per case
        - one Parquet file with global per-case statistics

    Note:
        The checkpoint is not a parameter; the slow path reads the notebook-level
        `checkpoint_path` variable, so the Paths cell must have been executed.
        The inference context is always requested with `batch_size=1` and
        `ood_fraction=1.0`.

    Args:
        dataset_name:
            Name of the dataset (used for the Parquet filename).
        save_root:
            Directory where artifacts are stored (holds `<dataset_name>.parquet`).
        dataset_path:
            Path to the raw `cases/` folder of the dataset; only used on a cache miss.

    Returns:
        tuple[pd.DataFrame, Path]:
            The loaded or newly generated evaluation DataFrame, and the path of the
            Parquet artifact it corresponds to.

    """
    parquet_path = save_root / f"{dataset_name}.parquet"

    # Cache miss: run inference over the dataset and write the artifacts.
    if not parquet_path.exists():
        print(f"[INFO] Creating artifacts for dataset: {dataset_name}")

        # Local names deliberately differ from the notebook-level `model` string.
        context = analysis.analysis_interference.load_inference_context(
            dataset_path=dataset_path,
            checkpoint_path=checkpoint_path,
            batch_size=1,
            ood_fraction=1.0,
        )
        net, data_loader, data_processor, torch_device = context

        frame, artifact_path = analysis.analysis_artifacts.generate_artifacts(
            model=net,
            loader=data_loader,
            processor=data_processor,
            device=torch_device,
            save_root=save_root,
            dataset_name=dataset_name,
        )
        return frame, artifact_path

    # Cache hit: reuse the stored table without touching the model or the raw data.
    print(f"[INFO] Found existing parquet → loading: {parquet_path}")
    frame = pd.read_parquet(parquet_path)
    return frame, parquet_path

In [None]:
# ---------------------------------------------------------------------
# ID artifacts
# ---------------------------------------------------------------------
# Loads `<save_root_id>/<dataset_name_id>.parquet` if present, otherwise
# generates it (see `run_or_load_artifacts_evaluation`).
df_id, parquet_id = run_or_load_artifacts_evaluation(
    dataset_name=dataset_name_id,
    save_root=save_root_id,
    dataset_path=dataset_cases_path_id,
)

# ---------------------------------------------------------------------
# OOD artifacts (optional)
# ---------------------------------------------------------------------
# Bound up front so later cells can test `df_ood is not None` even when no
# OOD dataset is configured.
df_ood = None

# `dataset_cases_path_ood` is None when `dataset_name_ood` is falsy.
# Note: `parquet_ood` is only bound inside this branch.
if dataset_cases_path_ood is not None:
    df_ood, parquet_ood = run_or_load_artifacts_evaluation(
        dataset_name=dataset_name_ood,
        save_root=save_root_ood,
        dataset_path=dataset_cases_path_ood,
    )

In [None]:
# ---------------------------------------------------------------------
# Interactive evaluation panel
# ---------------------------------------------------------------------
# Uses the evaluation frames bound by the artifacts cell: `df_id` (always) and
# `df_ood` (None when no OOD dataset is configured). The previous version of this
# cell referenced `df_eval_id` / `df_eval_ood`, which are never defined in this
# notebook and raised NameError on a fresh kernel.
eval_plots = evaluation.evaluation_plots

# `toggle(label, callback)` wraps one plot entry; each callback is a zero-argument
# lambda, so no plotting happens while this cell builds the lists.
# NOTE(review): exact widget behavior is defined in `util.util_nb` — confirm there.
toggle = util.util_nb.make_toggle_shortcut(df_id, dataset_name_id)

global_error_analysis_plots = [
    toggle("1-1. Global error metrics", lambda: eval_plots.plot_global_error_metrics(df_id)),
    toggle("1-2. Global error distribution (|U_error|)", lambda: eval_plots.plot_error_distribution(df_id, field="U")),
    toggle("1-3. GT vs Prediction (U, global)", lambda: eval_plots.plot_global_gt_vs_pred(df_id, field="U")),
    toggle("1-4. Mean error maps", lambda: eval_plots.plot_mean_error_maps(df_id)),
    toggle("1-5. Std error maps", lambda: eval_plots.plot_std_error_maps(df_id)),
]

# The comparison group only exists when an OOD frame was loaded.
id_ood_comparison_plots = None

if df_ood is not None:
    id_ood_comparison_plots = [
        toggle("2-1. ID vs OOD metrics", lambda: eval_plots.plot_id_vs_ood_metrics(df_id, df_ood)),
        toggle(
            "2-2. ID vs OOD error distributions (|U_error|)",
            lambda: eval_plots.plot_id_vs_ood_error_distributions(df_id, df_ood),
        ),
        toggle("2-3. OOD - ID mean error map", lambda: eval_plots.plot_id_vs_ood_mean_error_difference(df_id, df_ood)),
    ]

permeability_sensitivity_plots = [
    toggle("3-1. Error vs permeability magnitude", lambda: eval_plots.plot_error_vs_kappa_magnitude(df_id)),
    toggle("3-2. Error vs anisotropy ratio", lambda: eval_plots.plot_error_vs_anisotropy_ratio(df_id)),
    toggle("3-3. Error vs mean permeability", lambda: eval_plots.plot_error_vs_mean_kappa(df_id)),
    toggle("3-4. Error vs permeability gradient", lambda: eval_plots.plot_error_vs_kappa_gradient(df_id)),
]

sample_viewer_plots = [
    toggle("4-1. Sample Viewer — GT / Prediction / Error", lambda: eval_plots.plot_sample_prediction_overview(df_id)),
    toggle("4-2. Sample Viewer — kappa tensor (3×3) overlays", lambda: eval_plots.plot_sample_kappa_tensor_with_overlay(df_id)),
]

# Single source of truth for the tabs: (tab label, plot group, section heading).
# Sections and numbered tab titles are both derived from this list so they cannot
# get out of step when the optional ID/OOD tab is absent.
panel_tabs = [
    ("Global Error Analysis", global_error_analysis_plots, f"{dataset_name_id} — Global Error Analysis"),
]

if id_ood_comparison_plots is not None:
    panel_tabs.append(
        ("ID/OOD Comparison", id_ood_comparison_plots, f"{dataset_name_id} vs {dataset_name_ood} — ID/OOD Comparison")
    )

panel_tabs.append(
    ("Permeability Sensitivity", permeability_sensitivity_plots, f"{dataset_name_id} — Permeability Sensitivity")
)
panel_tabs.append(("Sample Viewer", sample_viewer_plots, f"{dataset_name_id} — Sample Viewer"))

sections = [util.util_nb.make_dropdown_section(tab_plots, heading) for _, tab_plots, heading in panel_tabs]

# Tabs are numbered by position, so without the OOD tab the later tabs become
# "2." and "3." (the per-plot labels above keep their fixed "3-x" / "4-x" prefixes).
tab_titles = [f"{position}. {label}" for position, (label, _, _) in enumerate(panel_tabs, start=1)]

evaluation_panel = util.util_nb.make_lazy_panel_with_tabs(
    sections,
    tab_titles=tab_titles,
    open_btn_text=f"{dataset_name_id} — Open Evaluation",
    close_btn_text="Close",
)

display(evaluation_panel)