# Surrogate MLP with LayerNorm & Dropout
Train and evaluate the new `SurrogateMLP` architecture on the DA chain data.
Use the configuration cell below to pick the data slice, hidden-layer structure, loss, and dropout;
the notebook then reports the L1 log-posterior error on the tail of the chain (index ≥ `EVAL_TAIL_START`, 50,000 by default).

**Workflow**
1. Adjust the configuration cell (data path, number of chain steps, loss type, etc.).
2. Run the data-loading cell to fetch `par`, `obs`, `y_obs`, `chain`, `props`, and `logpi_true`.
3. Execute the preprocessing + training cells to fit `SurrogateMLP`.
4. Inspect the training history and final metrics, including the requested L1 logπ error
   computed on the remainder of the chain beyond index 50,000.

In [1]:
import math
from pathlib import Path
from typing import Sequence, Literal

import numpy as np
import torch
from torch import nn

from test_mlp import (
    load_data,
    collect_training_indices,
    standardize_features,
    train_mlp,
    compute_logpi_l1_error,
)


In [2]:
# Activation names accepted by `_make_activation` (matched case-insensitively).
ActivationName = Literal["silu", "relu", "tanh", "gelu"]


def _make_activation(name: ActivationName) -> nn.Module:
    """Return a new activation module for ``name``.

    Args:
        name: One of "silu", "relu", "tanh" or "gelu"; letter case is ignored.

    Returns:
        A freshly constructed ``nn.Module`` (a new instance on every call).

    Raises:
        ValueError: If the lower-cased name is not one of the supported names.
    """
    key = name.lower()
    # Map each supported name to its module class; instantiate only on a hit.
    factories = {
        "silu": nn.SiLU,
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
        "gelu": nn.GELU,
    }
    factory = factories.get(key)
    if factory is None:
        raise ValueError(f"Unknown activation: {key}")
    return factory()


class SurrogateMLP(nn.Module):
    """Fully connected network with optional LayerNorm and dropout.

    Every hidden block is ``Linear -> [LayerNorm] -> activation -> [Dropout]``;
    the stack of blocks is followed by one ``Linear`` output layer with no
    activation.

    Args:
        input_dim: Number of input features of the first ``Linear`` layer.
        output_dim: Number of output features of the final ``Linear`` layer.
        hidden_dims: Width of each hidden block, in order. Must contain at
            least one entry; any iterable of ints is accepted.
        dropout: Dropout probability used after every hidden activation.
            ``0`` disables dropout entirely (no ``nn.Dropout`` is inserted).
        activation: Name understood by ``_make_activation`` (case-insensitive).
        use_layer_norm: Insert ``nn.LayerNorm`` after each hidden ``Linear``.

    Raises:
        ValueError: If ``hidden_dims`` is empty, if ``dropout`` is outside
            ``[0, 1]``, or if ``activation`` is not a known name.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dims: Sequence[int],
        dropout: float = 0.1,
        activation: ActivationName = "silu",
        use_layer_norm: bool = True,
    ) -> None:
        super().__init__()
        # Work on a private copy so later indexing never depends on the
        # caller's container type (tuple, generator, ...).
        hidden = list(hidden_dims)
        if not hidden:
            raise ValueError("hidden_dims must be non-empty")
        dropout_p = float(dropout)
        # A negative value used to be silently treated as "no dropout";
        # reject it up front together with values above 1 (and NaN).
        if not 0.0 <= dropout_p <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {dropout}")

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dims = hidden
        self.dropout_p = dropout_p
        self.use_layer_norm = bool(use_layer_norm)
        self.activation_name: ActivationName = activation.lower()  # type: ignore

        layers = []
        dims = [input_dim] + hidden
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            block = [nn.Linear(in_dim, out_dim)]
            if self.use_layer_norm:
                block.append(nn.LayerNorm(out_dim))
            block.append(_make_activation(self.activation_name))
            if self.dropout_p > 0.0:
                block.append(nn.Dropout(p=self.dropout_p))
            layers.append(nn.Sequential(*block))

        self.blocks = nn.ModuleList(layers)
        self.out_layer = nn.Linear(hidden[-1], output_dim)
        self._init_weights()

    def _init_weights(self) -> None:
        """Re-initialise every ``nn.Linear`` in the model, output layer included.

        Weights use Kaiming-uniform initialisation with the ReLU gain for
        "relu", "silu" and "gelu" and the tanh gain otherwise; biases are
        drawn uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.
        LayerNorm parameters keep their PyTorch defaults.
        """
        nonlinearity = "relu" if self.activation_name in ("relu", "silu", "gelu") else "tanh"
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            # The negative-slope argument `a` only affects the "leaky_relu"
            # gain, so it is not passed for "relu"/"tanh".
            nn.init.kaiming_uniform_(module.weight, nonlinearity=nonlinearity)
            if module.bias is not None:
                # For a Linear weight of shape (out, in), fan_in == in_features.
                fan_in = module.in_features
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0.0
                nn.init.uniform_(module.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the hidden blocks in order, then the linear output layer."""
        h = x
        for block in self.blocks:
            h = block(h)
        return self.out_layer(h)



In [92]:
# --- Configuration ---
# Every tunable value lives in this cell; the cells below only read these names.

# Data file and the two sigma values handed to load_data / train_mlp / the evaluation.
DATA_PATH = Path('data1.h5')
SIGMA_PRIOR, SIGMA_LIK = 1.0, 0.3

# Training-set selection and input preprocessing.
TRAIN_STEPS = 8_000  # number of chain steps (chain + props) for training
STANDARDIZE_INPUTS = True

# SurrogateMLP architecture.
HIDDEN_DIMS = [1024] * 3
USE_LAYER_NORM = True
ACTIVATION = 'silu'
DROPOUT_P = 0.5

# Loss selection, forwarded to train_mlp.
LOSS_NAME = 'mse'      # 'l1', 'mse', or 'mixed'
LOSS_DOMAIN = 'obs'    # 'obs' or 'logpi'

# Optimiser schedule, forwarded to train_mlp.
TRAIN_LOOPS = 80
ADAM_EPOCHS, ADAM_LR, ADAM_PATIENCE = 2000, 1.1e-3, 1000
LBFGS_STEPS = 100
TOL = 1e-5
BATCH_SIZE = 32
BATCH_GROWTH = 1.4     # e.g. 1.2 to grow batch every loop
LOOP_IMPROVEMENT_PCT = -0.01  # NOTE(review): meaning is defined inside train_mlp; confirm the sign convention there

# Evaluation: chain index at which the held-out tail starts.
EVAL_TAIL_START = 50_000

# Reproducibility and compute device.
SEED = 123
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

torch.manual_seed(SEED)
np.random.seed(SEED)
np.set_printoptions(suppress=True, precision=4)
print(f'Using device: {DEVICE}')



Using device: cuda


In [4]:
# --- Load data ---
# load_data (imported from test_mlp) is given the file path as a string plus the
# two sigma settings and returns six values, unpacked here in order.
par, obs, y_obs, chain, props, logpi_true = load_data(
    str(DATA_PATH),
    SIGMA_PRIOR,
    SIGMA_LIK,
)

# Short summary of what was loaded, one item per line.
print(
    f'par shape   : {par.shape}',
    f'obs shape   : {obs.shape}',
    f'chain length: {chain.shape[0]}',
    f'props length: {props.shape[0]}',
    sep='\n',
)

[data] Loaded 'logpi' from file.
[data] par shape   : (28324, 30)
[data] obs shape   : (28324, 52)
[data] y_obs shape : (52,)
[data] chain shape : (56646,)
[data] props shape : (56646,)
par shape   : (28324, 30)
obs shape   : (28324, 52)
chain length: 56646
props length: 56646


In [86]:
# --- Prepare training tensors ---

# Row indices into par / obs / logpi_true selected from the first TRAIN_STEPS
# chain steps (selection logic lives in test_mlp.collect_training_indices).
train_idx = collect_training_indices(chain, props, TRAIN_STEPS)
X_train_raw, y_train, logpi_train = par[train_idx], obs[train_idx], logpi_true[train_idx]

if not STANDARDIZE_INPUTS:
    # Identity scaling: keep the raw inputs and use mean 0 / std 1 so that the
    # evaluation cell can be given x_mean / x_std either way.
    X_train = X_train_raw
    x_mean = np.zeros(X_train_raw.shape[1], dtype=X_train_raw.dtype)
    x_std = np.ones(X_train_raw.shape[1], dtype=X_train_raw.dtype)
else:
    X_train, x_mean, x_std = standardize_features(X_train_raw)

print(
    f'unique training samples: {X_train.shape[0]}',
    f'feature dimension     : {X_train.shape[1]}',
    f'observation dimension : {y_train.shape[1]}',
    sep='\n',
)



unique training samples: 4001
feature dimension     : 30
observation dimension : 52


In [93]:
# --- Initialize SurrogateMLP (training happens in the next cell via test_mlp.train_mlp) ---

# Re-seed torch right before building the model so that re-running this cell
# (and the training cell after it) always starts from the same initial weights,
# independent of how much of the global RNG stream earlier cells consumed.
# NOTE(review): on a fresh Restart & Run All this should give the same weights
# as before, provided the loading/preprocessing helpers draw no torch random
# numbers -- confirm in test_mlp.
torch.manual_seed(SEED)

model = SurrogateMLP(
    input_dim=par.shape[1],   # one input per column of par
    output_dim=obs.shape[1],  # one output per column of obs
    hidden_dims=HIDDEN_DIMS,
    dropout=DROPOUT_P,
    activation=ACTIVATION,
    use_layer_norm=USE_LAYER_NORM,
).to(DEVICE)


In [94]:

# --- Train the model via test_mlp.train_mlp ---
# All arguments are passed by keyword; they are grouped by purpose below.
final_train_loss = train_mlp(
    # model, training data and device
    model=model,
    X_train=X_train,
    y_train=y_train,
    device=DEVICE,
    # loss definition
    loss_name=LOSS_NAME,
    loss_domain=LOSS_DOMAIN,
    # extra inputs for the log-posterior -- presumably used when
    # loss_domain is 'logpi'; confirm in test_mlp.train_mlp
    par_train_raw=X_train_raw,
    logpi_targets=logpi_train,
    y_obs=y_obs,
    sigma_prior=SIGMA_PRIOR,
    sigma_lik=SIGMA_LIK,
    # outer training loops and batch schedule
    train_loops=TRAIN_LOOPS,
    batch_size=BATCH_SIZE,
    batch_growth=BATCH_GROWTH,
    loop_improvement_pct=LOOP_IMPROVEMENT_PCT,
    # Adam phase
    max_adam_epochs=ADAM_EPOCHS,
    adam_lr=ADAM_LR,
    adam_patience=ADAM_PATIENCE,
    # L-BFGS phase and convergence tolerance
    max_lbfgs_iter=LBFGS_STEPS,
    tol=TOL,
    # logging level
    verbose=2,
)

print(f'Final training loss ({LOSS_NAME}): {final_train_loss:.6e}')



[train][loop 1/80] Starting Adam: epochs=2000, loss=MSE, lr=1.100e-03, batch=32, domain=obs
[train][loop 1/80][Adam] epoch    1 | loss(mse) = 1.134268e-01 | lr = 1.100e-03
[train][loop 1/80][Adam] epoch   10 | loss(mse) = 2.928639e-02 | lr = 1.100e-03
[train][loop 1/80][Adam] epoch   20 | loss(mse) = 2.311688e-02 | lr = 1.100e-03
[train][loop 1/80][Adam] epoch   30 | loss(mse) = 2.212776e-02 | lr = 1.100e-03
[train][loop 1/80][Adam] epoch   40 | loss(mse) = 1.920779e-02 | lr = 1.100e-03
[train][loop 1/80][Adam] epoch   50 | loss(mse) = 1.963058e-02 | lr = 1.100e-03
[train][loop 1/80][Adam] epoch   60 | loss(mse) = 1.638783e-02 | lr = 1.100e-03
[train][loop 1/80][Adam] epoch   70 | loss(mse) = 1.485144e-02 | lr = 1.100e-03
[train][loop 1/80][Adam] epoch   80 | loss(mse) = 1.605938e-02 | lr = 1.100e-03
[train][loop 1/80][Adam] epoch   90 | loss(mse) = 1.314875e-02 | lr = 1.100e-03
[train][loop 1/80][Adam] epoch  100 | loss(mse) = 1.296617e-02 | lr = 1.100e-03
[train][loop 1/80][Adam] epo

In [91]:
# --- Compute L1 logpi error on chain tail (index >= tail_start) ---

# Number of chain steps from EVAL_TAIL_START to the end of the chain.
tail_len = chain.shape[0] - EVAL_TAIL_START
if tail_len <= 0:
    raise ValueError('EVAL_TAIL_START must be smaller than the chain length.')

# The model contains Dropout layers whenever DROPOUT_P > 0, so the metric is
# only meaningful in eval mode. compute_logpi_l1_error lives in test_mlp and is
# not visible here, so switch modes explicitly instead of relying on it, and
# restore the previous mode afterwards so a later training run is unaffected.
was_training = model.training
model.eval()
try:
    tail_l1 = compute_logpi_l1_error(
        model=model,
        par=par,
        logpi_true=logpi_true,
        chain=chain,
        props=props,
        val_start=EVAL_TAIL_START,
        val_len=tail_len,
        x_mean=x_mean,
        x_std=x_std,
        y_obs=y_obs,
        sigma_prior=SIGMA_PRIOR,
        sigma_lik=SIGMA_LIK,
        device=DEVICE,
    )
finally:
    model.train(was_training)

print(f'L1 logpi error on chain tail [{EVAL_TAIL_START}, end): {tail_l1:.6f}')



L1 logpi error on chain tail [50000, end): 3.742622


5.242563
4.899523
4.232338


iteration,train_limit,base_hidden_sizes,final_hidden_sizes,val_error,master_val_error,growth_steps,final_train_loss
1,2000,[8],"[8, 8]",5.867817679563296,7.095813837634286,1,0.08197729289531708
2,4000,"[8, 8]","[16, 32]",4.594648975642015,4.918739791627595,3,0.07835499197244644
3,6000,"[16, 32]","[16, 32]",4.907302817124111,4.165183718501073,0,0.05713796243071556
4,8000,"[16, 32]","[32, 32]",3.5517553089907246,3.359570796361799,1,0.058842990547418594
5,10000,"[32, 32]","[64, 64]",2.4533132625766885,2.6955348713135168,2,0.0424831286072731
6,12000,"[64, 64]","[64, 64]",2.5594358762574174,2.5561185208652413,0,0.028931600973010063
7,14000,"[64, 64]","[64, 64]",2.175952113928763,2.2197988679717255,0,0.027288204059004784
8,16000,"[64, 64]","[64, 64]",1.9016387067773832,2.331814047863837,0,0.028165634721517563
9,18000,"[64, 64]","[64, 64, 64]",1.9504109285714764,1.7103362801803619,1,0.026211336255073547
10,20000,"[64, 64, 64]","[64, 64, 64]",2.2927826321948723,1.6362665044396796,0,0.019929099828004837
11,22000,"[64, 64, 64]","[64, 64, 64, 64]",1.977168851832621,1.6241933904734225,1,0.019994590431451797