# Final Chain DA-MH Investigation

Train a selected MLP architecture on the full 50k-step training window and evaluate
the delayed-acceptance (DA) rejection probability on the final portion of the chain.

## Notebook Overview
1. Configure data paths, model size, and optimizer hyperparameters.
2. Load datasets and construct the final training/validation splits just like `test_mlp.py`.
3. Instantiate the MLP architecture.
4. Train using the Adam+LBFGS routine from the test script.
5. Compute the mean DA rejection probability for the validation window at the chain end.

In [1]:
import numpy as np
import torch

from test_mlp import (
    MLP,
    load_data,
    collect_training_indices,
    standardize_features,
    evaluate_da_metrics,
    log_posterior_unnorm_numpy,
)

In [2]:
from torch import nn
from torch.utils.data import DataLoader


def _logpi_from_preds(par_batch, obs_pred, y_obs_tensor, sigma_prior, sigma_lik):
    prior_term = -0.5 * torch.sum((par_batch / sigma_prior) ** 2, dim=1)
    resid = (obs_pred - y_obs_tensor) / sigma_lik
    lik_term = -0.5 * torch.sum(resid ** 2, dim=1)
    return prior_term + lik_term


def train_mlp(
    model,
    X_train,
    y_train,
    device,
    max_adam_epochs=200,
    adam_lr=1e-3,
    adam_patience=30,
    tol=1e-5,
    max_lbfgs_iter=50,
    loss_name="l1",
    train_loops=1,
    batch_size=None,
    loss_domain="obs",
    par_train_raw=None,
    logpi_targets=None,
    y_obs=None,
    sigma_prior=1.0,
    sigma_lik=1.0,
):
    """Mini-batch Adam followed by full-batch LBFGS training."""
    model.to(device)
    model.train()

    X_tensor = torch.from_numpy(X_train.astype(np.float32)).to(device)
    y_tensor = torch.from_numpy(y_train.astype(np.float32)).to(device)
    par_tensor_raw = (
        torch.from_numpy(par_train_raw.astype(np.float32)).to(device)
        if par_train_raw is not None
        else None
    )
    logpi_tensor = (
        torch.from_numpy(logpi_targets.astype(np.float32)).to(device)
        if logpi_targets is not None
        else None
    )
    y_obs_tensor = (
        torch.from_numpy(y_obs.astype(np.float32)).to(device)
        if y_obs is not None
        else None
    )

    n_samples = X_tensor.shape[0]
    if n_samples == 0:
        raise ValueError("Training set is empty.")

    if batch_size is None or batch_size <= 0 or batch_size > n_samples:
        effective_batch = n_samples
    else:
        effective_batch = batch_size

    loss_name_lower = loss_name.lower()
    if loss_name_lower not in {"l1", "mse", "mixed"}:
        raise ValueError(
            f"Unsupported loss_name '{loss_name}'. Expected 'l1', 'mse', or 'mixed'."
        )

    def make_criterion(name: str):
        if name == "l1":
            return nn.L1Loss()
        return nn.MSELoss()

    loss_domain = loss_domain.lower()
    if loss_domain not in {"obs", "logpi"}:
        raise ValueError("loss_domain must be 'obs' or 'logpi'.")
    use_logpi_loss = loss_domain == "logpi"
    if use_logpi_loss and (par_tensor_raw is None or logpi_tensor is None or y_obs_tensor is None):
        raise ValueError(
            "logpi loss requires par_train_raw, logpi_targets, and y_obs inputs."
        )

    if train_loops <= 0:
        raise ValueError("train_loops must be a positive integer.")

    best_loss = float("inf")
    current_adam_lr = adam_lr
    indices = torch.arange(n_samples)

    for loop_idx in range(train_loops):
        loop_str = f"[train][loop {loop_idx + 1}/{train_loops}]"
        if loss_name_lower == "mixed":
            current_loss_name = "mse" if loop_idx % 2 == 0 else "l1"
        else:
            current_loss_name = loss_name_lower
        criterion = make_criterion(current_loss_name)
        optimizer_adam = torch.optim.Adam(model.parameters(), lr=current_adam_lr)
        no_improve = 0
        print(
            f"{loop_str} Starting Adam: epochs={max_adam_epochs}, loss={current_loss_name.upper()}, "
            f"lr={optimizer_adam.param_groups[0]['lr']:.3e}, batch={effective_batch}, domain={loss_domain}"
        )
        for epoch in range(1, max_adam_epochs + 1):
            loader = DataLoader(indices, batch_size=effective_batch, shuffle=True)
            for batch_idx in loader:
                batch_idx = batch_idx.to(device)
                batch_X = X_tensor[batch_idx]
                optimizer_adam.zero_grad()
                preds = model(batch_X)

                if use_logpi_loss:
                    batch_par = par_tensor_raw[batch_idx]
                    batch_logpi = logpi_tensor[batch_idx]
                    logpi_pred = _logpi_from_preds(
                        batch_par, preds, y_obs_tensor, sigma_prior, sigma_lik
                    )
                    loss = criterion(logpi_pred, batch_logpi)
                else:
                    batch_y = y_tensor[batch_idx]
                    loss = criterion(preds, batch_y)

                loss.backward()
                optimizer_adam.step()

            with torch.no_grad():
                full_preds = model(X_tensor)
                if use_logpi_loss:
                    logpi_pred_full = _logpi_from_preds(
                        par_tensor_raw, full_preds, y_obs_tensor, sigma_prior, sigma_lik
                    )
                    full_loss = criterion(logpi_pred_full, logpi_tensor)
                else:
                    full_loss = criterion(full_preds, y_tensor)
                loss_val = float(full_loss.item())

            if epoch == 1 or epoch % 10 == 0:
                current_lr = optimizer_adam.param_groups[0]["lr"]
                print(
                    f"{loop_str}[Adam] epoch {epoch:4d} | loss({current_loss_name}) = {loss_val:.6e} | lr = {current_lr:.3e}"
                )

            if best_loss == float("inf") or loss_val < best_loss - tol * (abs(best_loss) + 1e-12):
                best_loss = loss_val
                no_improve = 0
            else:
                no_improve += 1
                if no_improve >= adam_patience:
                    current_adam_lr *= 0.25
                    print(
                        f"{loop_str}[Adam] plateau at epoch {epoch}, reducing LR to {current_adam_lr:.3e} and stopping early."
                    )
                    if current_adam_lr < 1e-6:
                        print(f'LR {current_adam_lr:.3e} below 1e-6 threshold; terminating training early.')
                        return best_loss if best_loss != float("inf") else loss_val
                    break

        optimizer_lbfgs = torch.optim.LBFGS(
            model.parameters(),
            lr=1.0,
            max_iter=1,
            history_size=100,
            line_search_fn="strong_wolfe",
        )
        print(
            f"{loop_str} Starting LBFGS: max_iter={max_lbfgs_iter}, loss={current_loss_name.upper()}, "
            f"lr={optimizer_lbfgs.param_groups[0]['lr']:.3e}"
        )
        no_improve = 0
        for it in range(1, max_lbfgs_iter + 1):

            def closure():
                optimizer_lbfgs.zero_grad()
                preds = model(X_tensor)
                if use_logpi_loss:
                    logpi_pred = _logpi_from_preds(
                        par_tensor_raw, preds, y_obs_tensor, sigma_prior, sigma_lik
                    )
                    loss = criterion(logpi_pred, logpi_tensor)
                else:
                    loss = criterion(preds, y_tensor)
                loss.backward()
                return loss

            loss = optimizer_lbfgs.step(closure)
            loss_val = float(loss.item())
            if it == 1 or it % 5 == 0:
                current_lr = optimizer_lbfgs.param_groups[0]["lr"]
                print(
                    f"{loop_str}[LBFGS] iter {it:3d} | loss({current_loss_name}) = {loss_val:.6e} | lr = {current_lr:.3e}"
                )

            if loss_val < best_loss - tol * (abs(best_loss) + 1e-12):
                best_loss = loss_val
                no_improve = 0
            else:
                no_improve += 1
                if no_improve >= adam_patience:
                    print(f"{loop_str}[LBFGS] plateau detected at iter {it}, stopping early.")
                    break

    final_loss = best_loss
    summary_label = "MIXED" if loss_name_lower == "mixed" else loss_name_lower.upper()
    print(f"[train] Finished training with final train loss({summary_label}) = {final_loss:.6e}")
    return final_loss


## Configure Experiment Inputs
Adjust the values below to explore different architectures or optimization settings.

In [32]:
DATA_PATH = 'data1.h5'  # HDF5 file with par/obs datasets
SIGMA_PRIOR = 1.0
SIGMA_LIK = 0.3

HIDDEN_SIZES = [128, 128, 128]  # e.g., [32, 32, 32] for depth-3
USE_STANDARDIZATION = True

ADAM_LR = 1e-3
ADAM_EPOCHS = 4000
ADAM_PATIENCE = 100
LBFGS_STEPS = 100
TRAIN_LOOPS = 20
TRAIN_LOSS = 'l1'  # 'l1', 'mse', or 'mixed' (alternates per train loop)
LOSS_DOMAIN = 'obs'  # 'obs' for observation loss, 'logpi' for log posterior loss
BATCH_SIZE = 2048  # Mini-batch size for Adam (None for full batch)

CHAIN_FINAL_SIZE = 50000  # number of chain steps to use for training
VAL_SIZE = 1000  # evaluation window length following the training range
SEED = 123

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


## Load Data and Build Final Training Window
Mirrors the preprocessing inside `test_mlp.py`. The training set uses the first 50k chain
steps (chain + proposals), and the validation window immediately follows.

In [4]:
np.random.seed(SEED)
torch.manual_seed(SEED)

par, obs, y_obs, chain, props, logpi_true = load_data(DATA_PATH, SIGMA_PRIOR, SIGMA_LIK)

if CHAIN_FINAL_SIZE > chain.shape[0]:
    raise ValueError('CHAIN_FINAL_SIZE exceeds available chain length.')
val_start = CHAIN_FINAL_SIZE
if val_start + VAL_SIZE > chain.shape[0]:
    raise ValueError('Validation window exceeds available chain length. Reduce VAL_SIZE or CHAIN_FINAL_SIZE.')

train_idx = collect_training_indices(chain, props, CHAIN_FINAL_SIZE)
X_train_raw = par[train_idx]
y_train = obs[train_idx]
logpi_train = logpi_true[train_idx]

if USE_STANDARDIZATION:
    X_train_std, x_mean, x_std = standardize_features(X_train_raw)
else:
    X_train_std = X_train_raw
    x_mean = np.zeros(X_train_raw.shape[1], dtype=X_train_raw.dtype)
    x_std = np.ones(X_train_raw.shape[1], dtype=X_train_raw.dtype)

print(f'Training samples: {X_train_std.shape}')
print(f'Validation window: start={val_start}, length={VAL_SIZE}')

[data] Loaded 'logpi' from file.
[data] par shape   : (28324, 30)
[data] obs shape   : (28324, 52)
[data] y_obs shape : (52,)
[data] chain shape : (56646,)
[data] props shape : (56646,)
Training samples: (25001, 30)
Validation window: start=50000, length=1000


## Define the MLP Architecture
Uses the same `MLP` class as the automated test harness.

In [33]:
input_dim = X_train_std.shape[1]
output_dim = y_train.shape[1]
model = MLP(input_dim=input_dim, hidden_sizes=HIDDEN_SIZES, output_dim=output_dim)
model.to(device)
model

MLP(
  (net): Sequential(
    (0): Linear(in_features=30, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=128, bias=True)
    (5): ReLU()
    (6): Linear(in_features=128, out_features=52, bias=True)
  )
)

## Train with Adam + LBFGS
Runs the same `train_mlp` routine as `test_mlp.py`, including repeated outer loops.

In [37]:
final_train_loss = train_mlp(
    model,
    X_train_std,
    y_train,
    device=device,
    max_adam_epochs=100,
    adam_lr=1e-3,
    adam_patience=50,
    tol=1e-5,
    max_lbfgs_iter=100,
    loss_name='l1',
    train_loops=20,
    batch_size=64,
    loss_domain='obs',
    par_train_raw=X_train_raw,
    logpi_targets=logpi_train,
    y_obs=y_obs,
    sigma_prior=SIGMA_PRIOR,
    sigma_lik=SIGMA_LIK,
)
print(f'Final training loss: {final_train_loss:.6e}')

[train][loop 1/20] Starting Adam: epochs=100, loss=L1, lr=1.000e-03, batch=64, domain=obs
[train][loop 1/20][Adam] epoch    1 | loss(l1) = 2.593388e-02 | lr = 1.000e-03
[train][loop 1/20][Adam] epoch   10 | loss(l1) = 2.598331e-02 | lr = 1.000e-03
[train][loop 1/20][Adam] epoch   20 | loss(l1) = 2.525317e-02 | lr = 1.000e-03
[train][loop 1/20][Adam] epoch   30 | loss(l1) = 2.578109e-02 | lr = 1.000e-03
[train][loop 1/20][Adam] epoch   40 | loss(l1) = 2.539129e-02 | lr = 1.000e-03
[train][loop 1/20][Adam] epoch   50 | loss(l1) = 2.536447e-02 | lr = 1.000e-03
[train][loop 1/20][Adam] epoch   60 | loss(l1) = 2.485856e-02 | lr = 1.000e-03
[train][loop 1/20][Adam] epoch   70 | loss(l1) = 2.388637e-02 | lr = 1.000e-03
[train][loop 1/20][Adam] epoch   80 | loss(l1) = 2.407998e-02 | lr = 1.000e-03
[train][loop 1/20][Adam] epoch   90 | loss(l1) = 2.498211e-02 | lr = 1.000e-03
[train][loop 1/20][Adam] plateau at epoch 93, reducing LR to 2.500e-04 and stopping early.
[train][loop 1/20] Starting L

## Evaluate DA Rejection Probability
Evaluates on the validation window immediately after the training range and reports the
mean DA rejection probability (a1 * (1 - a2)).

In [38]:
metrics = evaluate_da_metrics(
    model,
    par,
    obs,
    logpi_true,
    y_obs,
    chain,
    props,
    val_start=val_start,
    val_len=VAL_SIZE,
    x_mean=x_mean,
    x_std=x_std,
    sigma_prior=SIGMA_PRIOR,
    sigma_lik=SIGMA_LIK,
    device=device,
)
avg_da_reject = metrics['mean_da_reject']
print(f'Average DA rejection probability: {avg_da_reject:.6e}')
metrics

[eval] val_start=50000 | val_mse_obs=2.257594e-03 | mean_da_reject=2.074483e-01 | mean_da_accept=3.575873e-01
Average DA rejection probability: 2.074483e-01


{'val_mse_obs': 0.002257593944117132,
 'mean_da_reject': 0.20744831806455477,
 'mean_da_accept': 0.35758728931265865,
 'mean_a1': 0.5650356073772134,
 'mean_a2_reject': 0.3433650793235764}

## Log-Likelihood L1 Error on Validation Samples
Compute the mean absolute difference between the true log posterior and the surrogate
log posterior built from the MLP predictions over unique states appearing in the validation window.

In [39]:
val_chain = chain[val_start:val_start + VAL_SIZE]
val_props = props[val_start:val_start + VAL_SIZE]
unique_val_idx = np.unique(np.concatenate([val_chain, val_props]))

par_val = par[unique_val_idx]
if USE_STANDARDIZATION:
    X_val = (par_val - x_mean) / x_std
else:
    X_val = par_val

with torch.no_grad():
    X_val_t = torch.from_numpy(X_val.astype(np.float32)).to(device)
    obs_val_pred = model(X_val_t).cpu().numpy()

true_logpi = logpi_true[unique_val_idx]
pred_logpi = log_posterior_unnorm_numpy(par_val, obs_val_pred, y_obs, SIGMA_PRIOR, SIGMA_LIK)
l1_logpi_error = np.mean(np.abs(true_logpi - pred_logpi))
print(f'Mean L1 error |logpi_true - logpi_pred|: {l1_logpi_error:.6e}')


Mean L1 error |logpi_true - logpi_pred|: 1.565981e+00


In [None]:
metrics