# 03 — FNO Surrogate Fidelity Validation

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SharathSPhD/RLpower/blob/main/notebooks/03_surrogate_validation.ipynb)

Validate the **Fourier Neural Operator (FNO)** surrogate model trained to emulate the sCO₂ FMU physics simulation.

**All data is loaded from `data/`, so the notebook runs on Google Colab without the FMU or the model weights.**

## Architecture
FNO implemented with **NVIDIA PhysicsNeMo** (`physicsnemo.models.fno.FNO`): 546K parameters, spectral convolutions over 719 time steps.

## Dataset
- **V1 (failed)**: 75K trajectories but only 2,100 unique initial states — the FNO overfit to the repeated sequences. R² = −77.
- **V2 (current)**: 76,600 strictly unique FMU roll-outs drawn by Latin hypercube sampling (LHS), 3.98 GB in total; the FNO was trained on a DGX Spark GB10.

In [None]:
# Environment bootstrap: detect Google Colab, make the repository checkout the
# working directory, and import the plotting/numeric stack used further down.
import os
import subprocess
import sys

# A successful import of google.colab is treated as "running inside Colab".
try:
    import google.colab
except ImportError:
    IN_COLAB = False
else:
    IN_COLAB = True

REPO_URL = "https://github.com/SharathSPhD/RLpower.git"
if IN_COLAB:
    REPO_DIR = "/content/RLpower"
else:
    # Local runs honour WORKSPACE_DIR and otherwise stay in the current directory.
    REPO_DIR = os.environ.get("WORKSPACE_DIR", os.getcwd())

repo_present = os.path.exists(REPO_DIR)
if IN_COLAB:
    if not repo_present:
        # Shallow clone: only the latest commit is fetched.
        subprocess.run(["git", "clone", "--depth=1", REPO_URL, REPO_DIR], check=True)
    os.chdir(REPO_DIR)
    subprocess.run([sys.executable, "-m", "pip", "install", "-q", "matplotlib", "numpy"], check=True)
elif repo_present:
    os.chdir(REPO_DIR)

# Select the Agg backend before pyplot is imported.
# NOTE(review): Agg renders off-screen, so the plt.show() call in the plotting
# cell presumably displays nothing inline — confirm this is intended for a
# notebook, or rely on the PNG that cell saves.
import matplotlib
matplotlib.use("Agg")

import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

print(f"Environment: {'Colab' if IN_COLAB else 'Local'}  |  cwd: {os.getcwd()}")

In [None]:
# Absolute path of the current working directory; every report path below is
# resolved against it.
ROOT = Path('.').resolve()


def _load(data_rel, artifact_rel):
    """Return the parsed JSON from the first candidate path that is a file.

    Args:
        data_rel: Preferred location, relative to ``ROOT`` (an absolute path
            is used as-is).
        artifact_rel: Fallback location, tried only when ``data_rel`` is not
            a regular file.

    Returns:
        The decoded JSON document, or ``None`` when neither path is a file.

    Raises:
        json.JSONDecodeError: If the selected file is not valid JSON.
    """
    for candidate in (ROOT / data_rel, ROOT / artifact_rel):
        # is_file() rather than exists(): a directory that happens to carry
        # the report's name must not block the fallback location.
        if candidate.is_file():
            # Decode explicitly as UTF-8 instead of the platform locale so the
            # report reads the same on every machine.
            return json.loads(candidate.read_text(encoding='utf-8'))
    return None

# Load the V1 fidelity report (the copy under data/ is tried first, the
# training artifact second) and print its headline numbers, followed by the
# acceptance thresholds set for V2.
v1 = _load('data/surrogate_fidelity_report.json', 'artifacts/surrogate/fidelity_report.json')

print('=== V1 fidelity (degenerate 75K, upsampled) ===')
if not v1:
    # Reached for a missing report (None) and for an empty one alike.
    print('  Report not found.')
else:
    print(f'  overall R2:   {v1["overall_r2"]:.4f}  (negative = worse than mean predictor)')
    print(f'  overall RMSE: {v1["overall_rmse_normalized"]:.4f}')
    print(f'  passed:       {v1["passed"]}')
    print()
    print('Root cause: 75,000 trajectories had only 2,100 unique initial states.')
    print('Remediation: 76,600 unique LHS-sampled trajectories (V2).')

print()
for threshold_line in (
    '=== V2 target thresholds ===',
    '  overall R2 >= 0.80',
    '  T_compressor_inlet R2 >= 0.95',
    '  normalized RMSE <= 0.10',
):
    print(threshold_line)

In [None]:
# Per-variable fidelity chart for the V1 report: R² on the left panel, RMSE on
# the right, each with the V2 acceptance gates drawn as reference lines.
if v1 is not None:
    key_vars = [
        'T_compressor_inlet', 'T_turbine_inlet', 'T_compressor_outlet',
        'T_turbine_outlet', 'W_turbine', 'W_main_compressor', 'Q_recuperator',
    ]
    pv = v1.get('per_variable', {})
    # Plot only the variables present in the report, in key_vars order.
    names = [name for name in key_vars if name in pv]
    r2_v1 = [pv[name]['r2'] for name in names]
    # NOTE(review): shown on a "Normalized RMSE" axis — confirm the report's
    # per-variable 'rmse' field is in fact normalized.
    rmse_v1 = [pv[name]['rmse'] for name in names]

    fig, axs = plt.subplots(1, 2, figsize=(14, 5))
    ax_r2, ax_rmse = axs
    v1_bar_style = {'color': '#c0392b', 'alpha': 0.82}

    # Left panel: R² per variable against both V2 gates and the zero line.
    ax_r2.barh(names, r2_v1, label='V1 R² (degenerate)', **v1_bar_style)
    ax_r2.axvline(0.80, color='#27ae60', linestyle='--', lw=2, label='V2 gate R²=0.80')
    ax_r2.axvline(0.95, color='#2980b9', linestyle=':', lw=2, label='V2 critical R²=0.95')
    ax_r2.axvline(0.0, color='black', lw=1)
    ax_r2.set_title('FNO R² per variable — V1 failure')
    ax_r2.legend(fontsize=8)
    ax_r2.set_xlabel('R²')
    ax_r2.grid(True, axis='x', alpha=0.3)

    # Right panel: RMSE per variable against the V2 RMSE gate.
    ax_rmse.barh(names, rmse_v1, label='V1 RMSE', **v1_bar_style)
    ax_rmse.axvline(0.10, color='#27ae60', linestyle='--', lw=2, label='V2 gate RMSE<=0.10')
    ax_rmse.set_title('FNO normalized RMSE — V1 failure')
    ax_rmse.legend(fontsize=8)
    ax_rmse.set_xlabel('Normalized RMSE')
    ax_rmse.grid(True, axis='x', alpha=0.3)

    fig.suptitle(
        'FNO Surrogate V1 (failed) → V2 Remediation with 76.6K Unique LHS Trajectories',
        fontsize=11,
        fontweight='bold',
    )
    # Leave the top 7% of the figure free for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.93])
    plt.savefig('/tmp/surrogate_fidelity.png', dpi=100, bbox_inches='tight')
    plt.show()
    print('V1 overall R² =', round(v1['overall_r2'], 2), '(negative: catastrophic failure)')
else:
    print('No V1 report — skipping plot.')

In [None]:
# Static summary of the V2 surrogate configuration as (label, value) pairs,
# printed as a two-column table with labels padded to 20 characters.
specs = [
    ('Library',       'NVIDIA PhysicsNeMo (nvidia-physicsnemo)'),
    ('Model class',   'physicsnemo.models.fno.FNO'),
    ('Parameters',    '546,190'),
    ('Input dim',     '18 (14 obs + 4 actions) x 719 time steps'),
    ('Output dim',    '14 (predicted observations) x 719 steps'),
    ('Dataset V2',    '76,600 unique LHS FMU trajectories (3.98 GB)'),
    ('Training hw',   'NVIDIA DGX Spark GB10 Grace Blackwell GPU'),
    ('Train epochs',  '200 (early-stop patience=20)'),
    ('Optimizer',     'Adam lr=1e-3, wd=1e-4'),
    ('Gate: R²',      '>= 0.80 overall, >= 0.95 critical vars'),
    ('Gate: RMSE',    '<= 0.10 normalized'),
]

# Assemble the whole table first, then emit it with a single print call.
table_lines = ['PhysicsNeMo FNO Architecture (V2)', '=' * 45]
table_lines.extend('  ' + label.ljust(20) + ' ' + value for label, value in specs)
print('\n'.join(table_lines))