# TransformedEmulator: Sampling vs Analytical (Delta)

This notebook benchmarks speed and accuracy of two prediction modes in `TransformedEmulator`:
- `output_from_samples=False` (analytical approximation: inverse-transform the predictive mean and propagate the variance with the delta method; diagonal variance)
- `output_from_samples=True` (sampling-based inversion, diagonal variance)

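We time `.fit()`, `.predict_mean()`, and `.predict_mean_and_variance()`, and report accuracy on a held-out test set as the RMSE of the predicted mean and the mean per-element Gaussian negative log-likelihood (NLL, using the diagonal variance), across several target dimensionalities (`y_dim`).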

Settings:
- Model: GaussianProcess
- full_covariance=False (always)
- x transforms: Standardize
- y transforms benchmarked: Standardize, Standardize+PCA, Standardize+VAE (PCA components and VAE latent dimension = `max(1, y_dim // 4)`)
- n_samples (sampling path): 256

In [None]:
import math
import time
import warnings
from dataclasses import dataclass

import torch
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from autoemulate.data.utils import set_random_seed
from autoemulate.emulators import GaussianProcess
from autoemulate.emulators.transformed.base import TransformedEmulator
from autoemulate.transforms import StandardizeTransform, PCATransform, VAETransform

# Double precision for every tensor created afterwards with the default dtype.
torch.set_default_dtype(torch.float64)
# Global seed; make_synthetic additionally re-seeds per y_dim.
set_random_seed(123)
# None -> use the default (CPU) device.
device = None

def make_synthetic(n_train=600, n_test=200, x_dim=10, y_dim=32, noise=0.05):
    """
    Build a seeded nonlinear multi-output regression problem.

    Targets are ``Y = X @ C + sin(X @ B) * A + noise * eps`` with ``eps ~ N(0, 1)``.
    ``B`` and ``C`` are ``(x_dim, y_dim)`` standard-normal matrices scaled by
    ``1 / sqrt(x_dim)``, and ``A`` is a length-``y_dim`` standard-normal amplitude
    vector. The RNG is re-seeded with ``123 + y_dim`` via ``set_random_seed`` so
    each output dimensionality gets its own reproducible dataset.

    Returns ``(X_train, Y_train, X_test, Y_test)`` with shapes
    ``(n_train, x_dim)``, ``(n_train, y_dim)``, ``(n_test, x_dim)``, ``(n_test, y_dim)``.
    """
    set_random_seed(123 + y_dim)

    # Draw order matters for reproducibility: inputs, then B, A, C, then noise.
    X_train, X_test = torch.randn(n_train, x_dim), torch.randn(n_test, x_dim)
    scale = math.sqrt(x_dim)
    B = torch.randn(x_dim, y_dim) / scale
    A = torch.randn(y_dim)
    C = torch.randn(x_dim, y_dim) / scale

    def response(inputs):
        # Linear trend plus an amplitude-modulated sinusoidal component.
        return inputs @ C + torch.sin(inputs @ B) * A

    Y_train = response(X_train) + noise * torch.randn(n_train, y_dim)
    Y_test = response(X_test) + noise * torch.randn(n_test, y_dim)
    return X_train, Y_train, X_test, Y_test

@dataclass
class BenchResult:
    """One benchmark row: configuration, timings and accuracy for a single run."""

    # --- configuration ---
    y_dim: int  # number of output dimensions in the synthetic dataset
    y_transform: str  # label of the y-transform pipeline, e.g. "standardize+pca"
    output_from_samples: bool  # True selects the sampling-based prediction path
    # --- wall-clock timings in seconds ---
    fit_time_s: float  # time spent in .fit()
    pred_mean_time_s: float  # time spent in .predict_mean()
    pred_mv_time_s: float  # time spent in .predict_mean_and_variance()
    # --- accuracy on the held-out test split ---
    rmse: float  # root-mean-square error of the predicted mean
    nll: float  # mean diagonal Gaussian negative log-likelihood

def rmse(y_true: torch.Tensor, y_pred: torch.Tensor) -> float:
    """Root-mean-square error over every element of ``y_true - y_pred``, as a Python float."""
    residual = y_true - y_pred
    mse = residual.pow(2).mean()
    return mse.sqrt().item()

def mean_diag_nll(
    y_true: torch.Tensor,
    mean: torch.Tensor,
    var: torch.Tensor,
    *,
    min_var: float = 1e-6,
) -> float:
    """
    Mean elementwise Gaussian negative log-likelihood with a diagonal variance.

    Each element contributes ``0.5 * (log(2*pi*var) + (y_true - mean)**2 / var)``;
    the contributions are averaged over all elements.

    Args:
        y_true: Observed targets.
        mean: Predicted means, elementwise-compatible with ``y_true``.
        var: Predicted (diagonal) variances, elementwise-compatible with ``mean``.
        min_var: Floor applied to ``var`` before evaluation, so zero, tiny or
            negative variances do not produce ``inf``/``nan``. Defaults to 1e-6,
            the value this function has always used.

    Returns:
        The mean NLL as a Python float.
    """
    var = torch.clamp(var, min=min_var)
    nll = 0.5 * (torch.log(2 * torch.pi * var) + (y_true - mean) ** 2 / var)
    return float(nll.mean())

def y_transform_configs(y_dim: int):
    """
    Return the y-transform pipelines to benchmark as ``(label, transforms)`` pairs.

    Dimensionality-reducing transforms use ``max(1, y_dim // 4)`` components;
    PCA is additionally capped at ``max(1, y_dim - 1)``. Every call constructs
    brand-new transform instances.
    """
    reduced = max(1, y_dim // 4)
    pca_components = min(reduced, max(1, y_dim - 1))
    return [
        ("standardize", [StandardizeTransform()]),
        ("standardize+pca", [StandardizeTransform(), PCATransform(n_components=pca_components)]),
        ("standardize+vae", [StandardizeTransform(), VAETransform(latent_dim=reduced)]),
    ]

def _timed(fn, *args):
    """Call ``fn(*args)`` and return ``(result, elapsed_seconds)`` measured with ``time.perf_counter``."""
    start = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - start


def benchmark_one(y_dim: int, y_t_name: str, y_t_list, output_from_samples: bool, n_samples: int = 256) -> BenchResult:
    """
    Fit one ``TransformedEmulator`` configuration and measure its speed and accuracy.

    Args:
        y_dim: Output dimensionality of the synthetic dataset (see ``make_synthetic``).
        y_t_name: Label recorded in the result for the y-transform pipeline.
        y_t_list: y-transform instances handed to the emulator as-is; pass
            fresh instances for each call.
        output_from_samples: True for the sampling-based prediction path,
            False for the analytical path.
        n_samples: Number of samples used by the sampling path.

    Returns:
        A ``BenchResult`` with wall-clock times of ``.fit()``, ``.predict_mean()``
        and ``.predict_mean_and_variance()``, the RMSE of ``predict_mean`` and the
        mean diagonal NLL of ``predict_mean_and_variance`` on the test split.
    """
    Xtr, Ytr, Xte, Yte = make_synthetic(y_dim=y_dim)
    # NOTE(review): only `.fit()` is timed; any work the constructor does with
    # Xtr/Ytr (e.g. fitting the transforms) is not part of fit_time_s — confirm.
    em = TransformedEmulator(
        Xtr,
        Ytr,
        x_transforms=[StandardizeTransform()],
        y_transforms=y_t_list,
        model=GaussianProcess,
        output_from_samples=output_from_samples,
        n_samples=n_samples,
        full_covariance=False,
        device=device,
    )
    _, fit_t = _timed(em.fit, Xtr, Ytr)

    # Warm-up on a small batch so first-call overhead is excluded from the timings.
    # Best-effort: failures are reported rather than silently ignored; the timed
    # calls below will raise if the problem is real.
    Xw = Xte[: min(64, Xte.shape[0])]
    for warm_up in (em.predict_mean, em.predict_mean_and_variance):
        try:
            warm_up(Xw)
        except Exception as err:
            name = getattr(warm_up, "__name__", repr(warm_up))
            warnings.warn(
                f"Warm-up {name} failed for y_dim={y_dim}, y_transform={y_t_name!r}, "
                f"output_from_samples={output_from_samples}: {err!r}",
                RuntimeWarning,
                stacklevel=2,
            )

    mu, pm_t = _timed(em.predict_mean, Xte)
    (mu2, var), pmv_t = _timed(em.predict_mean_and_variance, Xte)

    # Sanity check: the analytical path is presumably deterministic, so both calls
    # should agree on the mean. The sampling path draws samples, so small mean
    # differences are expected there and are not flagged.
    if not output_from_samples and not torch.allclose(mu, mu2, rtol=1e-4, atol=1e-6):
        gap = float(torch.mean((mu - mu2).abs()))
        warnings.warn(
            f"predict_mean and predict_mean_and_variance disagree (mean |diff|={gap:.3e}) "
            f"for y_dim={y_dim}, y_transform={y_t_name!r}",
            RuntimeWarning,
            stacklevel=2,
        )

    return BenchResult(
        y_dim=y_dim,
        y_transform=y_t_name,
        output_from_samples=output_from_samples,
        fit_time_s=fit_t,
        pred_mean_time_s=pm_t,
        pred_mv_time_s=pmv_t,
        rmse=rmse(Yte, mu),
        nll=mean_diag_nll(Yte, mu2, var),
    )

def run_benchmarks(y_dims=(4, 32, 128), n_samples=256) -> pd.DataFrame:
    """
    Benchmark every (y_dim, y-transform pipeline, prediction mode) combination.

    For each output dimensionality in ``y_dims`` and each pipeline returned by
    ``y_transform_configs``, ``benchmark_one`` is run with
    ``output_from_samples=False`` and then ``True``.

    Returns:
        A DataFrame with one row per run and the ``BenchResult`` fields as columns.
    """
    results = []
    for yd in y_dims:
        labels = [name for name, _ in y_transform_configs(yd)]
        for y_t_name in labels:
            for ofs in (False, True):
                # Build fresh transform instances for every run. Reusing the same
                # objects for both modes would let the second emulator see transforms
                # already fitted/mutated by the first, so the runs would not be independent.
                y_t_list = dict(y_transform_configs(yd))[y_t_name]
                r = benchmark_one(yd, y_t_name, y_t_list, ofs, n_samples=n_samples)
                results.append(r.__dict__)
    return pd.DataFrame(results)

In [None]:
# Run the full sweep (pass y_dims=(4, 32) for a quicker pass) and order rows for display.
df = run_benchmarks(y_dims=(4, 32, 128), n_samples=256).sort_values(
    ['y_dim', 'y_transform', 'output_from_samples']
)
df

In [None]:
# Plot timings: one row per timing metric, one column per y-transform pipeline.
timing_metrics = ['fit_time_s', 'pred_mean_time_s', 'pred_mv_time_s']
timing_transforms = sorted(df['y_transform'].unique())
# Size the grid from the data; squeeze=False keeps `axes` 2-D even for a single column.
fig, axes = plt.subplots(
    len(timing_metrics),
    len(timing_transforms),
    figsize=(5 * len(timing_transforms), 12),
    sharey=False,
    squeeze=False,
)
for i, metric in enumerate(timing_metrics):
    for j, y_t in enumerate(timing_transforms):
        ax = axes[i, j]
        sub = df[df['y_transform'] == y_t]
        sns.barplot(data=sub, x='y_dim', y=metric, hue='output_from_samples', ax=ax)
        ax.set_title(f"{metric} — {y_t}")
        ax.set_xlabel('y_dim')
        # Show the legend only on the first row; guard against axes with no legend.
        if i == 0:
            ax.legend(title='from_samples')
        elif ax.get_legend() is not None:
            ax.get_legend().remove()
plt.tight_layout()
plt.show()

In [None]:
# Plot accuracy metrics: RMSE (top row) and mean NLL (bottom row), one column per y-transform.
accuracy_transforms = sorted(df['y_transform'].unique())
accuracy_panels = [('rmse', 'RMSE'), ('nll', 'Mean NLL')]
# squeeze=False keeps `axes` 2-D even when only one y-transform is present.
fig, axes = plt.subplots(
    len(accuracy_panels), len(accuracy_transforms), figsize=(16, 6), sharey=False, squeeze=False
)
for j, y_t in enumerate(accuracy_transforms):
    sub = df[df['y_transform'] == y_t]
    for i, (metric, label) in enumerate(accuracy_panels):
        ax = axes[i, j]
        sns.barplot(data=sub, x='y_dim', y=metric, hue='output_from_samples', ax=ax)
        ax.set_title(f'{label} — {y_t}')
        ax.set_xlabel('y_dim')
        # Show the legend only on the top row; guard against axes with no legend.
        if i == 0:
            ax.legend(title='from_samples')
        elif ax.get_legend() is not None:
            ax.get_legend().remove()
plt.tight_layout()
plt.show()

## Notes
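- Both modes use the same model class, data, and transform configuration, but each mode fits its own separate emulator. Differences come from how predictions are mapped back to the original space and how the variance is estimated.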
- Sampling variance estimates can be noisy for small `n_samples`; increase it to trade speed for stability.
- The analytical (delta-method) path relies on a first-order linearization of the inverse transform, so its variance can be biased when that transform is strongly nonlinear (e.g. a VAE decoder).