# Location Finding Example (BED)

This notebook demonstrates:

- How to train deterministic and stochastic policies using the Barber–Agakov lower bound, jointly with a CouplingFlow posterior.  
- The simulator, neural networks, and training loop, along with experiments and plotting code to compare policies under prior and simulator shift.


### 1. Import Dependencies

In [None]:
# Main dependencies:
import os
if "KERAS_BACKEND" not in os.environ:
  os.environ["KERAS_BACKEND"] = "torch"

import math
import numpy as np
import copy
import torch
import torch.nn as nn
import torch.distributions as dist
from torch.distributions import transforms
from torch import Size, Tensor
from bayesflow.networks import CouplingFlow, FlowMatching

np.random.seed(42)
torch.manual_seed(42)

# For plotting:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
from matplotlib.colors import LinearSegmentedColormap
from scipy.stats import gaussian_kde
from scipy.interpolate import RBFInterpolator
from matplotlib.ticker import FormatStrFormatter
from matplotlib.ticker import MaxNLocator


### 2. Simulator

- This is the main class used to sample $\theta$ from the prior and to roll out trajectories under a given policy. 
- Each `LocationFinding` instance is associated with its own policy instance. 
- Note that we define separately the bounds of the prior (from which $\theta$ is sampled) and the bounds within which the policy can place designs, as this will be more convenient when analyzing prior shift.


In [None]:
class LocationFinding(nn.Module):
    def __init__(self,
                 policy: nn.Module,
                 prior_bounds: float = 3.0,           # Side length of uniform prior        
                 policy_bounds: float = 3.0,          # Side length of the design space
                 D: int = 1,                          # Number of sensors each time period
                 K: int = 1,                          # Number of sources
                 a: int | list = 1,                   # Signal weights for each source
                 m: float = 0.001,                    # Minimum squared distance to a source
                 b: float = 0.1,                      # Background signal
                 noise_std: float = 0.5,              # Noise added to the aggregate signal
                 T: int = 7,                          # Number of time periods
                 p: int = 2                           # Number of coordinate dimensions
    ) -> None:
        super(LocationFinding, self).__init__()

        self.policy = policy
        self.policy_bounds= policy_bounds
        self.D = D
        self.K = K
        self.m = m
        self.b = b
        self.noise_std = noise_std
        self.T = T
        self.p = p

        dtype = torch.float32
        if not isinstance(a, torch.Tensor):
            a = torch.tensor(a, dtype=dtype)
        self.register_buffer('a', a)
        self.register_buffer('_cov_matrix', noise_std**2 * torch.eye(D, dtype=dtype))
        self.register_buffer('_prior_low', -prior_bounds * torch.ones(K * p, dtype=dtype))
        self.register_buffer('_prior_high', prior_bounds * torch.ones(K * p, dtype=dtype))

    @property
    def device(self):
        return self.a.device

    @property
    def dtype(self):
        return self.a.dtype

    def prior(self):
        base = dist.Uniform(self._prior_low, self._prior_high)  # [K*p]
        return dist.Independent(base, 1)                        # Event shape: [K*p]

    def outcome_likelihood(self, theta: Tensor, designs: Tensor):
        batch_shape = theta.shape[:-1]
        theta = theta.view(*batch_shape, self.K, self.p)        # [*B, K, p]

        # Pair-wise distances between sensors and sources for each batch
        distances = torch.cdist(designs, theta)                 # [*B, D, K], since torch.cdist(BxDxp, BxKxp) -> BxDxK
        signals = self.a / (self.m + distances**2)              # [*B, D, K]
        total_signal = signals.sum(dim=-1) + self.b             # [*B, D]
        loc = torch.log(total_signal)

        noisy_signal = dist.MultivariateNormal(loc=loc, covariance_matrix=self._cov_matrix)

        return noisy_signal

    def forward(self, theta: Tensor, entropy_bonus: bool = False):
        """Simulates trajectories under a given batch of thetas."""
        batch_shape = theta.shape[:-1]
        designs  = torch.zeros(*batch_shape, self.T, self.D, self.p, device=theta.device, dtype=theta.dtype)    # [*B, T, D, p]
        outcomes = torch.zeros(*batch_shape, self.T, self.D, device=theta.device, dtype=theta.dtype)            # [*B, T, D]
        entropies = None

        if entropy_bonus:
            entropies = torch.empty(*batch_shape, self.T, device=theta.device, dtype=theta.dtype)               # [*B, T]

        # Generate trajectories under the current policy
        for t in range(self.T):
            hist_designs = designs[..., :t, :, :]                # [*B, t, D, p]
            hist_outcomes = outcomes[..., :t, :]                 # [*B, t, D]

            xi_t = self.policy(hist_designs, hist_outcomes)
            y_t = self.outcome_likelihood(theta, xi_t).rsample()

            if entropy_bonus:
                entropies[..., t] = self.policy._entropy(hist_designs, hist_outcomes)        # [*B]

            designs[..., t, :, :] = xi_t
            outcomes[..., t, :] = y_t

        return designs, outcomes, entropies

    def sample(self, batch_shape: Size, entropy_bonus: bool = False):
        """Randomly sample theta from prior and generate trajectories."""
        theta = self.prior().sample(batch_shape)      # [*B, K*p]
        designs, outcomes, entropies = self(theta, entropy_bonus=entropy_bonus)

        return {
            "theta": theta,             # [*B, K*p]
            "designs": designs,         # [*B, T, D, p]
            "outcomes": outcomes,       # [*B, T, D]
            "entropies": entropies}     # [*B, T] or None

    @torch.no_grad()
    def run_policy(self, theta: Tensor, entropy_bonus: bool = False):
        """forward() call without gradients and in evaluation mode."""
        self.policy.eval()
        designs, outcomes, entropies = self(theta, entropy_bonus=entropy_bonus)
        self.policy.train()

        return designs, outcomes, entropies

    def plot_designs(self,
                    theta: Tensor,
                    designs: Tensor) -> None:
        """Plot designs from the policy network under a given set of thetas"""
        assert self.p == 2, "only 2-D plotting supported"

        theta = theta.cpu()
        designs = designs.cpu()
        num_periods = designs.shape[1]            # T
        B = theta.size(0)                         # Only plot the first batch dimension
        theta = theta.view(B, self.K, self.p)     # [*B, K, p]

        # Initialize the signal field 
        grid_size = 100
        x_vals = torch.linspace(-self.policy_bounds, self.policy_bounds, grid_size)
        y_vals = torch.linspace(-self.policy_bounds, self.policy_bounds, grid_size)
        X, Y   = torch.meshgrid(x_vals, y_vals, indexing="ij")
        grid_pts = torch.cartesian_prod(x_vals, y_vals)     # [G, 2] := [grid_size*grid_size, 2]

        # Initialize subplots
        max_cols = 4
        ncols = min(B, max_cols)
        nrows = math.ceil(B / ncols)

        fig, axes = plt.subplots(
            nrows, ncols,
            figsize=(4.5*ncols, 4*nrows),
            sharex=True,
            sharey=True,
            squeeze=False,
            dpi=300
        )
        flat_axes = axes.flatten()

        red = np.array(mcolors.to_rgb("#b22222"))
        yellow = np.array(mcolors.to_rgb("#ffeb4d"))

        # Loop over each realization
        for b in range(B):
            ax = flat_axes[b]

            dist_field = torch.cdist(
                grid_pts.unsqueeze(0),              # [1, G, 2]
                theta[b].unsqueeze(0)               # [1, K, 2]
            )                                       # [1, G, K]

            a_cpu = self.a.cpu()
            field = (a_cpu / (self.m + dist_field.pow(2))).sum(dim=-1) + self.b    # [1, G]
            Z = torch.log(field.view(grid_size, grid_size)).detach().numpy()
        
            blues = plt.colormaps['Blues']
            colors = blues(np.linspace(0, 1, 256))
            colors[:, :3] = colors[:, :3] ** 1.25
            dark_blues = LinearSegmentedColormap.from_list('dark_blues', colors)
        
            im = ax.contourf(
                X.detach().numpy(), Y.detach().numpy(), Z,
                levels=30, cmap=dark_blues, alpha=1.0,
                vmin=np.min(Z),
                vmax=np.max(Z)
            )

            ## Plot the design trajectory
            for t in range(num_periods):
                xi_t = designs[b, t]            # [D, 2]
                frac = (t + 1) / num_periods
                color = red * (1 - frac) + yellow * frac
                ax.scatter(
                    xi_t[:, 0].detach().numpy(), xi_t[:, 1].detach().numpy(),
                    c=[color], edgecolors="k", s=70
                )

            ax.set_aspect("equal")
            ax.set_xlim(-self.policy_bounds, self.policy_bounds)
            ax.set_ylim(-self.policy_bounds, self.policy_bounds)
            ax.grid(True, color="lightgray", linewidth=0.3, alpha=0.3)
            ax.set_title(f"Trajectory {b+1}")

        for ax in flat_axes[B:]:
            ax.set_visible(False)

        fig.tight_layout()

        # Colorbar 1: Signal field 
        signal_cbar = fig.colorbar(
            im,
            ax=flat_axes[:B],
            location="right",
            label="Log Total Signal",
            format="%.2f",
            pad=-0.09,
            shrink=0.9,
        )
        vmin, vmax = im.get_clim()
        signal_cbar.set_ticks(np.linspace(vmin, vmax, 6))

        # Colorbar 2: Time-step 
        time_cmap = LinearSegmentedColormap.from_list('red_yellow', [red, yellow], N=num_periods)
        time_sm = cm.ScalarMappable(cmap=time_cmap, norm=plt.Normalize(vmin=1, vmax=num_periods))
        time_sm.set_array([])
        time_cbar = fig.colorbar(
            time_sm,
            ax=flat_axes[:B],
            location="right",
            label="Timestep",
            format="%d",
            pad=0.025,
            shrink=0.9,
        )
        time_cbar.set_ticks(np.arange(1, num_periods + 1))

        plt.show()

### 3. Neural Networks

- Defines the main deterministic and stochastic policy networks, as well as the history encoder module.  
- Defines a separate random policy that places uniformly random designs within the policy bounds. This network has no trainable parameters but is implemented as an `nn.Module` for compatibility with the `LocationFinding` class.  
- Defines a posterior network, which serves as a convenient wrapper around BayesFlow amortized posteriors. Either `CouplingFlow` or `FlowMatching` can be used as the base posterior, but `CouplingFlow` is recommended because it is significantly faster while still reasonably flexible.  


In [None]:
class HistoryEncoder(nn.Module):
    def __init__(self,
                 D: int,
                 p: int,
                 enc_h_dim: int = 256,
                 enc_out_dim: int = 128):
        super().__init__()
        self.enc_out_dim = enc_out_dim
        self.enc_fc1 = nn.Linear(D*p + D, enc_h_dim)
        self.enc_fc2 = nn.Linear(enc_h_dim, enc_out_dim)
        self.relu = nn.ReLU()

    def forward(self, hist_designs, hist_outcomes):
        batch_shape = hist_designs.shape[:-3]
        T = hist_designs.shape[-3]
        device = hist_designs.device
        dtype = hist_designs.dtype

        if T == 0:
            return torch.zeros(*batch_shape, self.enc_out_dim, device=device, dtype=dtype)

        xi_flat = hist_designs.flatten(start_dim=-2, end_dim=-1)    # [*B, T, D*p]
        inp = torch.cat([xi_flat, hist_outcomes], dim=-1)           # [*B, T, D*p + D]

        h = self.relu(self.enc_fc1(inp))    # [*B, T, enc_h_dim]
        out = self.enc_fc2(h)               # [*B, T, enc_out_dim]
        rep = out.mean(dim=-2)              # [*B, enc_out_dim]

        return rep

class DeterministicPolicyNet(nn.Module):
    def __init__(self,
                 D: int,
                 p: int,
                 policy_bounds: float,
                 enc_h_dim: int = 256,
                 enc_out_dim: int = 128,
                 h_dim: int = 64):
        super().__init__()
        self.D = D
        self.p = p
        self.policy_bounds = policy_bounds

        self.history_encoder = HistoryEncoder(D=D, p=p, enc_h_dim=enc_h_dim, enc_out_dim=enc_out_dim)

        self.fc1 = nn.Linear(enc_out_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, D*p)
        self.relu = nn.ReLU()

        # Initialization
        nn.init.xavier_uniform_(self.fc1.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.zeros_(self.fc1.bias)
        nn.init.xavier_uniform_(self.fc2.weight, gain=1.0)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, hist_designs, hist_outcomes):
        batch_shape = hist_designs.shape[:-3]

        rep = self.history_encoder(hist_designs, hist_outcomes)

        h = self.relu(self.fc1(rep))     # [*B, D*p]
        out = self.fc2(h)
        out = out.view(*batch_shape, self.D, self.p)

        return torch.tanh(out) * self.policy_bounds

class StochasticPolicyNet(nn.Module):
    def __init__(self,
                 D: int,
                 p: int,
                 policy_bounds: float,
                 enc_h_dim: int = 256,
                 enc_out_dim: int = 128,
                 h_dim: int = 64,
                 min_std: float = 0.01,
                 init_mean: float = 0.0,
                 init_std: float = 0.5):
        super().__init__()
        self.D = D
        self.p = p
        self.policy_bounds = policy_bounds
        self.min_std = min_std

        self.history_encoder = HistoryEncoder(D=D, p=p, enc_h_dim=enc_h_dim, enc_out_dim=enc_out_dim)

        self.fc1_mean = nn.Linear(enc_out_dim, h_dim)
        self.fc2_mean = nn.Linear(h_dim, D*p)
        self.fc1_raw_std = nn.Linear(enc_out_dim, h_dim)
        self.fc2_raw_std = nn.Linear(h_dim, D*p)
        self.relu = nn.ReLU()

        self.tanh_transform = transforms.TanhTransform(cache_size=1)
        self.scale_transform = transforms.AffineTransform(loc=0, scale=policy_bounds)

        # Initialization
        nn.init.xavier_uniform_(self.fc1_mean.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.zeros_(self.fc1_mean.bias)
        nn.init.xavier_uniform_(self.fc2_mean.weight, gain=1.0)
        nn.init.zeros_(self.fc2_mean.bias)
        self.fc2_mean.bias.data.add_(init_mean)
        nn.init.xavier_uniform_(self.fc1_raw_std.weight, gain=nn.init.calculate_gain("relu"))
        nn.init.zeros_(self.fc1_raw_std.bias)
        nn.init.xavier_uniform_(self.fc2_raw_std.weight, gain=1.0)
        nn.init.zeros_(self.fc2_raw_std.bias)
        self.fc2_raw_std.bias.data.add_(math.log(max(init_std, 1e-8)))

    def _base_distribution(self, hist_designs, hist_outcomes):
        batch_shape = hist_designs.shape[:-3]

        rep = self.history_encoder(hist_designs, hist_outcomes)    # [*B, enc_out_dim]

        h_mean = self.relu(self.fc1_mean(rep))                     # [*B, h_dim]
        out_mean = self.fc2_mean(h_mean)                           # [*B, D*p]
        mean = out_mean.view((*batch_shape, self.D, self.p))       # [*B, D, p]

        h_raw_std = self.relu(self.fc1_raw_std(rep))
        out_raw_std = self.fc2_raw_std(h_raw_std)
        raw_std = out_raw_std.view((*batch_shape, self.D, self.p))
        std = torch.exp(raw_std) + self.min_std

        return dist.Independent(dist.Normal(mean, std), reinterpreted_batch_ndims=2)    # Event shape: [D, p]

    def _distribution(self, hist_designs, hist_outcomes):
        base_dist = self._base_distribution(hist_designs, hist_outcomes)
        return dist.TransformedDistribution(base_dist, [self.tanh_transform, self.scale_transform])

    def _entropy(self, hist_designs, hist_outcomes, n_samples=1000):
        base_dist = self._base_distribution(hist_designs, hist_outcomes)
        base_entropy = base_dist.entropy()                    # [*B]
        x = base_dist.rsample((n_samples,))                   # [n_samples, *B, D, p]

        y_tanh = self.tanh_transform(x)
        log_det_tanh = self.tanh_transform.log_abs_det_jacobian(x, y_tanh)
        y_final = self.scale_transform(y_tanh)
        log_det_scale = self.scale_transform.log_abs_det_jacobian(y_tanh, y_final)

        total_log_det = log_det_tanh + log_det_scale          # [n_samples, *B, D, p]
        total_log_det = total_log_det.sum(dim=[-2, -1])       # [n_samples, *B]

        # Total entropy: H(Y) = H(X) + E[log|det J|]
        return base_entropy + total_log_det.mean(dim=0)

    def forward(self, hist_designs, hist_outcomes):
        return self._distribution(hist_designs, hist_outcomes).rsample()     # [*B, D, p]
    
class RandomPolicyNet(nn.Module):
    def __init__(self,
                D: int,
                p: int,
                policy_bounds: float,
                ):
        super().__init__()
        self.D = D
        self.p = p
        self.policy_bounds = policy_bounds

    def forward(self, hist_designs, hist_outcomes):
            batch_shape = hist_designs.shape[:-3]
            device = hist_designs.device
            dtype = hist_designs.dtype

            rand_design = torch.rand(*batch_shape, self.D, self.p, device=device, dtype=dtype)
            rand_design = rand_design.mul(2 * self.policy_bounds).sub(self.policy_bounds)

            return rand_design
    
class PosteriorNet(nn.Module):
      def __init__(self,
                   D: int,
                   p: int,
                   K: int,
                   enc_h_dim: int = 256,
                   enc_out_dim: int = 128,
                   inf_net: str = "CouplingFlow"   # or "FlowMatching"
                   ):
          super().__init__()
          self.history_encoder = HistoryEncoder(D=D, p=p, enc_h_dim=enc_h_dim, enc_out_dim=enc_out_dim)

          if inf_net == "CouplingFlow":
              self.flow = CouplingFlow(
                  subnet="mlp",
                  depth=6,
                  transform="affine",
                  permutation="random",
                  use_actnorm=True,
                  base_distribution="normal",
                  )

          elif inf_net == "FlowMatching":
              self.flow = FlowMatching(
                  subnet="mlp",
                  subnet_kwargs={"widths": (256,)*6},      
                  base_distribution="normal",
                  use_optimal_transport=True,
                  loss_fn="mse"
                  )

          self.flow.build(xz_shape=(32, K*p), conditions_shape=(32, enc_out_dim))

      def forward(self, theta, designs, outcomes):
          enc = self.history_encoder(designs, outcomes)
          return self.flow.log_prob(theta, conditions=enc)

### 4. Training Helpers

- Defines two functions that are useful for creating the joint policy–posterior training loop.  
  - `train_policy()`: updates only the policy for a given posterior.  
  - `warmup_posterior()`: updates only the posterior, and can be used to optionally warm up the posterior on purely random designs.  


In [None]:
def train_policy(simulator, 
                 posterior,
                 optimizer,
                 scheduler,
                 num_steps: int = 3000,
                 batch_size: int = 256,
                 clip_norm: float | None = 2.0,
                 entropy_bonus: bool = False,
                 alpha: float = 0,
                 alpha_decay: float = 1.0,
                 verbose: bool = True):
    
    metrics = {"ba_loss": [], "grad_norm": [], "mean_entropy":[], "alpha_values": [], "learning_rate": []}
    current_alpha = alpha

    for step in range(num_steps):
        theta, designs, outcomes, entropies = simulator.sample(batch_shape=(batch_size,), entropy_bonus=entropy_bonus).values()

        log_prob = posterior(theta, designs, outcomes)        # [*B]
        loss = -log_prob.mean()

        if entropies is not None and entropies.numel() > 0:
            entropy_sum = entropies.sum(dim=-1)               # [*B]
            mean_entropy = entropy_sum.mean()
        else:
            mean_entropy = torch.tensor(0.0, device=loss.device)

        total_loss = loss - current_alpha*mean_entropy

        optimizer.zero_grad()
        total_loss.backward()

        if clip_norm is not None:
            torch.nn.utils.clip_grad_norm_(simulator.policy.parameters(), clip_norm)

        grad_norm = torch.sqrt(
            sum((p.grad.detach().pow(2).sum() 
                for p in simulator.policy.parameters() 
                if p.grad is not None), 
                torch.tensor(0.0)))

        optimizer.step()
        scheduler.step()

        current_alpha *= alpha_decay
        current_lr = optimizer.param_groups[0]['lr']

        # Store training metrics
        metrics["ba_loss"].append(loss.item())
        metrics["grad_norm"].append(grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm)
        metrics["mean_entropy"].append(mean_entropy.item() if isinstance(mean_entropy, torch.Tensor) else mean_entropy)
        metrics["alpha_values"].append(current_alpha)
        metrics["learning_rate"].append(current_lr)

        if verbose and ((step + 1) % 50 == 0 or step + 1 == 1):
            if entropy_bonus:
                print(f"Step {step + 1}: BA loss {metrics['ba_loss'][-1]:.3f}   grad_norm {metrics['grad_norm'][-1]:.3f}  Entropy {metrics['mean_entropy'][-1]:.3f}  alpha {current_alpha:.4f}  lr {current_lr:.2e}")
            else:
                print(f"Step {step + 1}: BA loss {metrics['ba_loss'][-1]:.3f}   grad_norm {metrics['grad_norm'][-1]:.3f}  lr {current_lr:.2e}")

    return metrics

def warmup_posterior(simulator_rand,
                     posterior,
                     num_warmup_steps: int = 1000,
                     batch_size: int = 256,
                     verbose: bool = True):
    with torch.enable_grad():  
        optimizer = torch.optim.AdamW(posterior.parameters(), lr=1e-4)

        for step in range(num_warmup_steps):
            theta, designs, outcomes, _ = simulator_rand.sample(batch_shape=(batch_size,)).values()

            log_prob = posterior(theta, designs, outcomes)
            loss = -log_prob.mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if verbose and ((step + 1) % 50 == 0 or step + 1 == 1):
                print(f"Posterior warmup step {step + 1}: Loss = {loss.item():.3f}")
    
    return posterior

### 5. Joint Training Loop (Amortized Policy + Posterior)

- Uses the helper functions to jointly train a policy and a posterior network using the Barber–Agakov lower bound.  
- Optionally accepts a pre-trained (frozen) posterior. If provided, joint updates and warm-up are skipped, and only the policy is updated each iteration.  
- Separately defines `posterior_bounds`, which specifies the area over which the posterior is warmed up.  


In [None]:
def train_joint(sim_params,
                train_params,
                policy_type="det",       # "det", "sto", or "rand"
                posterior=None):         # Optional (frozen) posterior overrides joint training and warmup

    # Unwrap simulation params
    (T, D, K, p, a, m, b, noise_std, prior_bounds, policy_bounds) = [sim_params[k] for k in [
        "T", "D", "K", "p", "a", "m", "b",
        "noise_std", "prior_bounds", "policy_bounds"]]

    # Unwrap training params
    (posterior_bounds, num_steps, num_warmup_steps, batch_size, clip_norm,
    entropy_bonus, alpha, alpha_decay, device, verbose, plot_training) = [train_params[k] for k in [
        "posterior_bounds", "num_steps", "num_warmup_steps",
        "batch_size", "clip_norm", "entropy_bonus", "alpha",
        "alpha_decay", "device", "verbose", "plot_training"]]

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if policy_type == "det":
        policy = DeterministicPolicyNet(D=D, p=p, policy_bounds=policy_bounds)
    elif policy_type == "rand":
        policy = RandomPolicyNet(D=D, p=p, policy_bounds=policy_bounds)
    elif policy_type == "sto":
        policy = StochasticPolicyNet(D=D, p=p, policy_bounds=policy_bounds)

    simulator = LocationFinding(
        policy=policy,
        prior_bounds=prior_bounds, 
        policy_bounds=policy_bounds, 
        D=D, K=K, a=a, m=m, b=b, noise_std=noise_std, T=T, p=p).to(device)
    
    if posterior is None:
        simulator_warmup = LocationFinding(
            policy=RandomPolicyNet(D=D, p=p, policy_bounds=posterior_bounds),
            policy_bounds=posterior_bounds,
            prior_bounds=posterior_bounds,
            D=D, K=K, a=a, m=m, b=b, noise_std=noise_std, T=T, p=p).to(device)
    
        with torch.enable_grad():
            posterior = PosteriorNet(D=D, p=p, K=K, inf_net="CouplingFlow").to(device)
            posterior = warmup_posterior(simulator_warmup, posterior, 
                                           num_warmup_steps=num_warmup_steps,
                                           batch_size=batch_size, 
                                           verbose=verbose)
    else:  
        posterior_new = PosteriorNet(D=D, p=p, K=K, inf_net="CouplingFlow").to(device)
        posterior_new.load_state_dict(posterior.state_dict())
        posterior = posterior_new
    
    # Training: for random policy, only train posterior; for others, train both policy and posterior
    with torch.enable_grad():
        if policy_type == "rand":
            optimizer_params = list(posterior.parameters())
        else:
            optimizer_params = list(simulator.policy.parameters()) + list(posterior.parameters())
        
        optimizer = torch.optim.AdamW(optimizer_params)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                        max_lr=1e-3,
                                                        total_steps=num_steps,
                                                        pct_start=0.4,
                                                        anneal_strategy="cos",
                                                        div_factor=25,
                                                        final_div_factor=25)
        # Collect policy training metrics
        metrics = train_policy(simulator=simulator,
                               posterior=posterior,
                               optimizer=optimizer,
                               scheduler=scheduler,
                               num_steps=num_steps,
                               batch_size=batch_size,
                               clip_norm=clip_norm,
                               entropy_bonus=entropy_bonus,
                               alpha=alpha,
                               alpha_decay=alpha_decay,
                               verbose=verbose)
        
    if plot_training: 
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

        ax1.plot(metrics["ba_loss"], linewidth=2)
        ax2.plot(metrics["grad_norm"], linewidth=2)
        ax1.set_xlabel("Step")
        ax1.set_ylabel("BA Loss")
        ax1.set_title("Training Loss")
        ax1.grid(True, alpha=0.3)
        ax2.set_xlabel("Step")
        ax2.set_ylabel("Gradient Norm")
        ax2.set_title("Training Gradient Norm")
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    return simulator, posterior, metrics

### 6. Example Training Run


- Deterministic policy + CouplingFlow posterior with warm-up.  
- To skip warm-up, set `num_warmup_steps = 0`. To train only the policy, provide a frozen posterior to `train_joint()`.  
- For stochastic policies, setting `entropy_bonus = True` enables `alpha` and `alpha_decay`. To train with a fixed `alpha`, set `alpha_decay = 1`.  
- In this example, we set `prior_bounds = policy_bounds = posterior_bounds = 3`, which means that $\theta$ is sampled uniformly from $[-3, 3]^2$, and both the policy and posterior are trained exactly within this area.  


In [None]:
sim_params = {
    "T": 7,
    "D": 1,
    "K": 1,
    "p": 2,
    "a": [1.0],
    "m": 0.001,
    "b": 0.1,
    "noise_std": 0.5,
    "prior_bounds": 3.0,
    "policy_bounds": 3.0,    
}

train_params = {
    "posterior_bounds": 3.0,
    "num_steps": 3000,
    "num_warmup_steps": 1000,
    "batch_size": 256,
    "clip_norm": 2.0,
    "entropy_bonus": False,
    "alpha": 0.0,
    "alpha_decay": 1.0,
    "device": None,
    "verbose": True,
    "plot_training": True,
}

sim_trained, posterior_trained, _ = train_joint(policy_type="det", 
                                                      sim_params=sim_params, 
                                                      train_params=train_params,
                                                      posterior=None)

In [None]:
# Roll out the trained policy on randomly sampled theta
num_test_theta = 3
test_theta = sim_trained.prior().sample((num_test_theta,))
designs, outcomes, entropies= sim_trained.run_policy(test_theta)
sim_trained.plot_designs(test_theta, designs)

- We can also plot the trained policy trajectories alongside a Kernel Density Estimate (KDE) of the posterior conditioned on the history.

In [None]:
def plot_trajectory_and_posterior(simulator, 
                                  posterior,
                                  theta, 
                                  policy_suptitle: str | None = None) -> None:
    
    assert simulator.p == 2, "only 2-D plotting supported"

    # Use provided theta
    designs, outcomes, _ = simulator.run_policy(theta)

    theta_cpu = theta.cpu()
    designs_cpu = designs.cpu()
    outcomes_cpu = outcomes.cpu()

    B = theta.size(0)
    num_periods = designs.shape[1]   # T
    theta_cpu = theta_cpu.view(B, simulator.K, simulator.p)

    # Initialize signal field
    bounds = simulator.policy_bounds
    grid_size = 100
    x_vals = torch.linspace(-bounds, bounds, grid_size)
    y_vals = torch.linspace(-bounds, bounds, grid_size)
    X, Y = torch.meshgrid(x_vals, y_vals, indexing="ij")
    grid_pts = torch.cartesian_prod(x_vals, y_vals)
    positions = np.vstack([X.numpy().ravel(), Y.numpy().ravel()])

    fig, axes = plt.subplots(
        2, B,
        figsize=(4*B, 7),
        sharex=True,
        sharey=True,
        squeeze=False,
        dpi=300,
        constrained_layout=True
    )

    red = np.array(mcolors.to_rgb("#b22222"))
    yellow = np.array(mcolors.to_rgb("#ffeb4d"))

    # Row 1: Design trajectories
    for b in range(B):
        ax = axes[0][b]

        # Compute signal field for this trajectory
        dist_field = torch.cdist(grid_pts.unsqueeze(0), theta_cpu[b].unsqueeze(0))
        a_cpu = simulator.a.cpu()
        field = (a_cpu / (simulator.m + dist_field.pow(2))).sum(dim=-1) + simulator.b
        Z = torch.log(field.view(grid_size, grid_size)).detach().numpy()

        blues = plt.colormaps['Blues']
        colors = blues(np.linspace(0, 1, 256))
        colors[:, :3] = colors[:, :3] ** 1.25
        dark_blues = LinearSegmentedColormap.from_list('dark_blues', colors)

        im = ax.contourf(
            X.numpy(), Y.numpy(), Z,
            levels=40, cmap=dark_blues, alpha=1.0,
            vmin=np.min(Z),
            vmax=np.max(Z)
        )

        # Plot design trajectory
        for t in range(num_periods):
            xi_t = designs_cpu[b, t]
            frac = (t + 1) / num_periods
            color = red * (1 - frac) + yellow * frac
            ax.scatter(
                xi_t[:, 0].numpy(), xi_t[:, 1].numpy(),
                c=[color], edgecolors="k", s=70
            )

        ax.set_aspect("equal")
        ax.set_xlim(-bounds, bounds)
        ax.set_ylim(-bounds, bounds)
        ax.grid(True, color="lightgray", linewidth=0.3, alpha=0.3)
        ax.set_title(f"Scenario {b+1}", fontsize=12)

        if b == 0:
            ax.set_ylabel("Design Trajectory", fontsize=12)

    # Row 2: Corresponding posteriors
    n_samples = 500
    kde_levels = 300

    for b in range(B):
        ax = axes[1][b]

        with torch.no_grad():
            device = next(posterior.parameters()).device
            designs_single = designs_cpu[b:b+1].to(device)  # [1, T, D, p]
            outcomes_single = outcomes_cpu[b:b+1].to(device)  # [1, T, D]
            theta_single = theta[b:b+1].to(device)  # [1, K*p]

            # Compute EIG for this trajectory
            log_prob = posterior(theta_single, designs_single, outcomes_single)
            prior_entropy = simulator.prior().entropy()
            eig = (log_prob.mean() + prior_entropy).item()

            enc = posterior.history_encoder(designs_single, outcomes_single)
            samples = posterior.flow.sample(batch_shape=(n_samples,),
                                                conditions=enc.expand(n_samples, -1)).cpu().numpy()

            flat_samples = samples.reshape(-1, simulator.p).T

            # KDE
            kde = gaussian_kde(flat_samples)
            Z_post = kde(positions).reshape(X.shape)
            Z_post = Z_post / Z_post.max()

            contour = ax.contourf(
                X.numpy(), Y.numpy(), Z_post,
                levels=kde_levels, cmap="Purples", alpha=1.0
            )

            # Plot true parameters
            true_theta = theta_cpu[b].numpy()
            for k in range(simulator.K):
                scatter = ax.scatter(
                    true_theta[k, 0], true_theta[k, 1],
                    c='red', s=150, marker='x',
                    linewidth=2.5, zorder=10
                )

            # Add EIG legend
            ax.text(0.97, 0.97, f"Final EIG: {eig:.2f}",
                    transform=ax.transAxes,
                    verticalalignment='top',
                    horizontalalignment='right',
                    bbox=dict(boxstyle='square', facecolor='white', alpha=0.8),
                    fontsize=10)
    
        ax.set_aspect("equal")
        ax.set_xlim(-bounds, bounds)
        ax.set_ylim(-bounds, bounds)
        ax.grid(True, color="lightgray", linewidth=0.3, alpha=0.3)

        if b == 0:
            ax.set_ylabel("Final Posterior", fontsize=13)

    # Colorbar 1: Signal field (for row 1)
    row1_axes = axes[0, :]
    signal_cbar = fig.colorbar(im, ax=row1_axes, location="right",
                                label="Log Total Signal", format="%.2f",
                                pad=0.02, shrink=0.9)
    signal_cbar.ax.locator_params(nbins=5)
    vmin, vmax = im.get_clim()
    ticks = np.linspace(vmin, vmax, 6)
    signal_cbar.set_ticks(ticks)

    # Colorbar 2: Time step (for row 1)
    time_cmap = LinearSegmentedColormap.from_list('red_yellow', [red, yellow], N=num_periods)
    time_sm = cm.ScalarMappable(cmap=time_cmap, norm=plt.Normalize(vmin=1, vmax=num_periods))
    time_sm.set_array([])
    time_cbar = fig.colorbar(time_sm, ax=row1_axes, location="right",
                            label="Time Step", format="%d",
                            pad=0.03, shrink=0.9)
    time_cbar.set_ticks(np.arange(1, num_periods + 1))

    # Colorbar 3: Posterior density (for row 2)
    row2_axes = axes[1, :]
    posterior_cbar = fig.colorbar(contour, ax=row2_axes, location="right",
                                label="Posterior Density", format="%.2f",
                                pad=0.03, shrink=0.9)
    vmin, vmax = contour.get_clim()
    ticks = np.linspace(vmin, vmax, 6)
    posterior_cbar.set_ticks(ticks)

    if policy_suptitle:
        fig.suptitle(f"{policy_suptitle} Policy in {B} Scenarios", fontsize=16)

    plt.show()

In [None]:
plot_trajectory_and_posterior(sim_trained, posterior_trained, test_theta, policy_suptitle="Deterministic")

### 7. Experiment #1: Effect of Entropy Regularization Coefficient ($\alpha$)

- The experiment code is divided into two stages: training and evaluation.  
- In the training stage, the deterministic model (policy + posterior) and the random model (posterior only) are trained once, since these do not depend on `alpha`.  
- The stochastic policy is then re-trained for each fixed value in `alpha_values` (with `alpha_decay = 1`), and the results are saved.  


In [None]:
# ---------- Training Stage ---------- #

alpha_values = [0.0, 0.001, 0.01, 0.1, 1.0]

sim_params = {
    "T": 7,
    "D": 1,
    "K": 1,
    "p": 2,
    "a": [1.0],
    "m": 0.001,
    "b": 0.1,
    "noise_std": 0.5,
    "prior_bounds": 3.0,
    "policy_bounds": 3.0,
}

train_params = {
    "posterior_bounds": 3.0,
    "num_steps": 3000,          
    "num_warmup_steps": 1000,
    "batch_size": 256,
    "clip_norm": 2.0,
    "entropy_bonus": False,
    "alpha": 0.0,
    "alpha_decay": 1.0,
    "device": None,
    "verbose": True,
    "plot_training": False,
}

device = train_params["device"] or torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Warmup shared posterior once
print("Warming up shared posterior...")
simulator_warmup = LocationFinding(
    policy=RandomPolicyNet(D=sim_params["D"], p=sim_params["p"], policy_bounds=train_params["posterior_bounds"]),
    policy_bounds=train_params["posterior_bounds"],
    prior_bounds=train_params["posterior_bounds"],
    D=sim_params["D"], K=sim_params["K"], a=sim_params["a"], m=sim_params["m"], b=sim_params["b"],
    noise_std=sim_params["noise_std"], T=sim_params["T"], p=sim_params["p"]).to(device)

shared_posterior = PosteriorNet(D=sim_params["D"], p=sim_params["p"], K=sim_params["K"], inf_net="CouplingFlow").to(device)
shared_posterior = warmup_posterior(simulator_warmup,
                                    shared_posterior,
                                    num_warmup_steps=train_params["num_warmup_steps"],
                                    batch_size=train_params["batch_size"],
                                    verbose=train_params["verbose"])

# Train random policy once
print("Training Random Policy...")
sim_rand, post_rand, _ = train_joint(sim_params=sim_params,
                                     train_params=train_params,
                                     policy_type="rand",
                                     posterior=shared_posterior)

# Train deterministic policy once
print("Training Deterministic Policy...")
sim_det, post_det, _ = train_joint(sim_params=sim_params,
                                   train_params=train_params,  
                                   policy_type="det",
                                   posterior=shared_posterior)

# Train stochastic policy with different alpha values
models_by_alpha = {}

for alpha in alpha_values:
    print(f"Training Stochastic Policy (alpha = {alpha})...")
    tp_sto = {**train_params, "entropy_bonus": True, "alpha": alpha, "alpha_decay": 1.0}
    sim_sto, post_sto, _ = train_joint(sim_params=sim_params,
                                       train_params=tp_sto,
                                       policy_type="sto",
                                       posterior=shared_posterior)
    models_by_alpha[alpha] = {"sim_sto": sim_sto, "post_sto": post_sto}

- For the evaluation stage, we first define a helper function to approximate the average EIG of the trained policy–posterior pair on a testing simulator with specified parameters.  
- The parameter `n_mc` controls the number of sampled $\theta$, while `n_trajectories` specifies the number of trajectories rolled out for each $\theta$.  


In [None]:
@torch.no_grad()
def eval_policy(policy, 
                posterior, 
                sim_params_test,
                n_mc=300, 
                n_trajectories=10,
                eval_thetas = None,      # Optionally provide pre-sampled eval thetas 
                metric: str = "eig"      # or "log_prob" to evaluate log p(theta | h_T)" 
                ):
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Unwrap test simulation params
    (T, D, K, p, a, m, b, noise_std, prior_bounds, policy_bounds) = [sim_params_test[k] for k in [
        "T", "D", "K", "p", "a", "m", "b",
        "noise_std", "prior_bounds", "policy_bounds"]]
    
    simulator_test = LocationFinding(
        policy=policy,
        prior_bounds=prior_bounds, 
        policy_bounds=policy_bounds, 
        D=D, K=K, a=a, m=m, b=b, noise_std=noise_std, T=T, p=p).to(device)
    
    metrics = []

    # Sample thetas uniformly from policy_bounds
    if eval_thetas is None:
        eval_thetas = torch.rand(n_mc, K*p, device=device) * (2 * policy_bounds) - policy_bounds     # [n_mc, K*p]

    for theta in eval_thetas:
        theta_batch = theta.unsqueeze(0)      # [1, K*p]

        # Use batched trajectory processing
        theta_repeated = theta_batch.repeat(n_trajectories, 1)              # [n_trajectories, K*p]
        designs, outcomes, _ = simulator_test.run_policy(theta_repeated)
        log_probs = posterior(theta_repeated, designs, outcomes)            # [n_trajectories]

        # Average log probability across trajectories for this theta
        avg_log_prob = log_probs.mean().item()

        if metric == "log_prob":
            metrics.append(avg_log_prob)

        elif metric == "eig":
            prior_entropy = simulator_test.prior().entropy()
            eig_estimate = avg_log_prob + prior_entropy.item()
            metrics.append(eig_estimate)
    
    metrics_np = np.array(metrics)
    mean = np.mean(metrics_np)
    se = np.std(metrics_np) / np.sqrt(len(metrics_np))

    return metrics_np, mean, se

In [None]:
# ---------- Evaluation Stage ---------- #

n_mc = 100            
n_trajectories = 5 

sim_params_test = {
    "T": 7,
    "D": 1,
    "K": 1,
    "p": 2,
    "a": [1.0],
    "m": 0.001,
    "b": 0.1,
    "noise_std": 0.5,
    "prior_bounds": 3.0,
    "policy_bounds": 3.0,
}

# Evaluate deterministic policy
print(f"Evaluating Deterministic Policy...")
det_eigs, det_mean, det_se = eval_policy(sim_det.policy, 
                       post_det, 
                       sim_params_test,
                       n_mc=n_mc, 
                       n_trajectories=n_trajectories,
                       metric="eig")

# Evaluate random policy
print(f"Evaluating Random Policy...")
rand_eigs, rand_mean, rand_se = eval_policy(sim_rand.policy, 
                        post_rand, 
                        sim_params_test,
                        n_mc=n_mc, 
                        n_trajectories=n_trajectories,
                        metric="eig")

# Evaluate stochastic policies with different alpha values
sto_mean_list = []
sto_se_list = []
sto_eigs_list = []

for alpha in alpha_values:
    print(f"Evaluating Stochastic Policy with alpha={alpha}...")
    sim_sto = models_by_alpha[alpha]['sim_sto']
    post_sto = models_by_alpha[alpha]['post_sto']

    sto_eigs, sto_mean, sto_se = eval_policy(sim_sto.policy, 
                           post_sto, 
                           sim_params_test,
                           n_mc=n_mc, 
                           n_trajectories=n_trajectories, 
                           metric="eig")
    
    # Store results
    sto_mean_list.append(sto_mean)
    sto_se_list.append(sto_se)
    sto_eigs_list.append(sto_eigs)
    models_by_alpha[alpha]['eigs'] = sto_eigs

In [None]:
# ---------- Plot Results ---------- #

fig, ax = plt.subplots(figsize=(7, 5), dpi=300)

# Log scale and ticks
ax.set_xscale('log')
alpha_plot = [a if a > 0 else 1e-4 for a in alpha_values]
xticks = [1e-4, 1e-3, 1e-2, 1e-1, 1e0]
ax.set_xticks(xticks)
ax.set_xticklabels(["0", r"$10^{-3}$", r"$10^{-2}$", r"$10^{-1}$", r"$10^{0}$"])
ax.grid(True, which='major', alpha=0.3)
ax.minorticks_off()

# Ensure red/blue line extend to subplot walls
left_buffer  = xticks[0] * 0.8
right_buffer = xticks[-1] * 1.2
ax.set_xlim(left_buffer, right_buffer)
xL, xR = ax.get_xlim()

# Deterministic policy: mean line + SE band 
ax.axhline(y=det_mean, color='red', linestyle='-', linewidth=2, label='Deterministic', alpha=0.8)
ax.fill_between([xL, xR], det_mean - det_se, det_mean + det_se, color='red', alpha=0.2)

# Random policy: mean line + SE band 
ax.axhline(y=rand_mean, color='blue', linestyle='-', linewidth=2, label='Random', alpha=0.8)
ax.fill_between([xL, xR], rand_mean - rand_se, rand_mean + rand_se, color='blue', alpha=0.2)

# Stochastic policy: mean line + SE band 
ax.plot(alpha_plot, sto_mean_list, '-', color='purple', linewidth=2, label='Stochastic')
sto_means_arr = np.asarray(sto_mean_list)
sto_ses_arr = np.asarray(sto_se_list)  
ax.fill_between(alpha_plot, sto_means_arr - sto_ses_arr, sto_means_arr + sto_ses_arr, color='purple', alpha=0.2, zorder=0)

# Labels and formatting
ax.set_xlabel(r'Entropy Regularization Coefficient ($\alpha$)', fontsize=12)
ax.set_ylabel('Final EIG (BA bound)', fontsize=12)
ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
ax.set_title(r'Effect of Entropy Regularization Strength ($T=7$)', fontsize=14, y=1.02)
ax.grid(True, alpha=0.3, which='both')
ax.legend(loc='best', fontsize=11)

plt.tight_layout()
plt.show()

# Print summary statistics
print("\n--- Alpha Experiment Summary ---")
print(f"Deterministic: {det_mean:.3f} ± {det_se:.3f}")
print(f"Random:        {rand_mean:.3f} ± {rand_se:.3f}")
print("\nStochastic by Alpha:")
for alpha, mean, se in zip(alpha_values, sto_mean_list, sto_se_list):
        print(f"  α = {alpha:5.3f}: {mean:.3f} ± {se:.3f}")


### 8. Experiment #2: Prior Shift

- Repeat a similar procedure as in previous experiment, but this time our `prior_bounds` = 3.0 is smaller than `posterior_bounds` = `policy_bounds` = 12.0
- For the stochastic policy, we proceed with a fixed `alpha` = 0.01.

In [None]:
# ---------- Training Stage ---------- #

sim_params = {
    "T": 7,
    "D": 1,
    "K": 1,
    "p": 2,
    "a": [1.0],
    "m": 0.001,
    "b": 0.1,
    "noise_std": 0.5,
    "prior_bounds": 3.0,
    "policy_bounds": 12.0,
}

train_params = {
    "posterior_bounds": 12.0,
    "num_steps": 3000,
    "num_warmup_steps": 1000,
    "batch_size": 256,
    "clip_norm": 2.0,
    "entropy_bonus": False,
    "alpha": 0.0,
    "alpha_decay": 1.0,
    "device": None,
    "verbose": False,
    "plot_training": False,
}

device = train_params["device"] or torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Warmup shared posterior once with random policy exploring the full posterior_bounds space
print("Warming up shared posterior...")
simulator_warmup = LocationFinding(
    policy=RandomPolicyNet(D=sim_params["D"], p=sim_params["p"], policy_bounds=train_params["posterior_bounds"]),
    policy_bounds=train_params["posterior_bounds"],
    prior_bounds=train_params["posterior_bounds"],
    D=sim_params["D"], K=sim_params["K"], a=sim_params["a"], m=sim_params["m"], b=sim_params["b"],
    noise_std=sim_params["noise_std"], T=sim_params["T"], p=sim_params["p"]).to(device)

shared_posterior = PosteriorNet(D=sim_params["D"], p=sim_params["p"], K=sim_params["K"], inf_net="CouplingFlow").to(device)
shared_posterior = warmup_posterior(simulator_warmup,
                                    shared_posterior,
                                    num_warmup_steps=train_params["num_warmup_steps"],
                                    batch_size=train_params["batch_size"],
                                    verbose=train_params["verbose"])

# Train deterministic policy
print("Training Deterministic Policy...")
sim_det, post_det, _ = train_joint(sim_params=sim_params,
                                             train_params=train_params,
                                             policy_type="det",
                                             posterior=shared_posterior)

# Train random policy
print("Training Random Policy...")
sim_rand, post_rand, _ = train_joint(sim_params=sim_params,
                                                train_params=train_params,
                                                policy_type="rand",
                                                posterior=shared_posterior)

# Train stochastic policy
print("Training Stochastic Policy...")
tp_sto = {**train_params, "entropy_bonus": True, "alpha": 0.01, "alpha_decay": 1.0}
sim_sto, post_sto, _ = train_joint(sim_params=sim_params,
                                             train_params=tp_sto,
                                             policy_type="sto",
                                             posterior=shared_posterior)


In [None]:
# ---------- Evaluation Stage ---------- #

n_mc = 1000            
n_trajectories = 10 

# Sample shared eval thetas in [-policy_bounds, policy_bounds]
shared_eval_thetas = torch.rand(n_mc, sim_params["K"]*sim_params["p"], 
                                device=device) * (2 * sim_params["policy_bounds"]) - sim_params["policy_bounds"]

# Evaluate random policy
log_prob_rand, _, _ = eval_policy(sim_rand.policy, 
                            post_rand,
                            sim_params,
                            n_mc=n_mc,
                            n_trajectories=n_trajectories,
                            eval_thetas=shared_eval_thetas,
                            metric="log_prob")

# Evaluate deterministic policy
log_prob_det, _, _ = eval_policy(sim_det.policy, 
                           post_det,
                           sim_params,
                           n_mc=n_mc,
                           n_trajectories=n_trajectories,
                           eval_thetas=shared_eval_thetas,
                           metric="log_prob")

# Evaluate stochastic policy
log_prob_sto, _, _ = eval_policy(sim_sto.policy, 
                           post_sto,
                           sim_params,
                           n_mc=n_mc,
                           n_trajectories=n_trajectories,
                           eval_thetas=shared_eval_thetas,
                           metric="log_prob")

# Store results
results = {"rand": log_prob_rand,
           "det":  log_prob_det,
           "sto":  log_prob_sto}

In [None]:
# ---------- Plot Results ---------- #

def plot_misspecification_heatmaps(eval_thetas, results, sim_params) -> None:

    prior_bounds = sim_params['prior_bounds']      # 3.0
    policy_bounds = sim_params['policy_bounds']    # 12.0

    fig, axes = plt.subplots(1, 3,
                            figsize=(12, 4),
                            sharex=True,
                            sharey=True,
                            constrained_layout=True,
                            dpi=300)

    policy_names = ['Random', 'Deterministic', 'Stochastic']
    policy_keys = ['rand', 'det', 'sto']

    # Helper function to re-scale the log_probs
    def _transform_values(values):
        return -np.log(np.abs(values) + 1e-10)

    all_transformed = []
    for p in policy_keys:
        transformed = _transform_values(results[p])
        all_transformed.append(transformed)

    # Find global min and max for consistent colormap
    all_transformed_concat = np.concatenate(all_transformed)
    vmin, vmax = np.percentile(all_transformed_concat, [5, 95])
    levels = np.linspace(vmin, vmax, 60)

    # Create grid for interpolation
    x_grid = np.linspace(-policy_bounds, policy_bounds, 200)
    y_grid = np.linspace(-policy_bounds, policy_bounds, 200)
    X_grid, Y_grid = np.meshgrid(x_grid, y_grid)

    for idx, (ax, policy_key, policy_name) in enumerate(zip(axes, policy_keys, policy_names)):
        
        log_probs = results[policy_key]
        transformed_values = _transform_values(log_probs)

        XY = np.column_stack([X_grid.ravel(), Y_grid.ravel()])

        rbf = RBFInterpolator(eval_thetas.cpu(), transformed_values,
                              kernel='gaussian',
                              smoothing=0,
                              epsilon=3,
                              neighbors=3)
        
        Z = rbf(XY).reshape(X_grid.shape)
        Zp = np.clip(Z, vmin, vmax)
        im = ax.contourf(X_grid, Y_grid, Zp, levels=levels, cmap='viridis')

        # Add boundary square for prior bounds
        if policy_key == "sto": # Only the stochastic subplot gets a legend 
            rect = plt.Rectangle((-prior_bounds, -prior_bounds),
                                 2*prior_bounds, 2*prior_bounds,
                                 fill=False, edgecolor='red', linewidth=1.3,
                                 linestyle='-', alpha=0.9, label=f"Prior (±{prior_bounds:.1f})")
            ax.add_patch(rect)
            ax.legend(loc="upper right", fontsize=9, frameon=True)
        else:
            rect = plt.Rectangle((-prior_bounds, -prior_bounds),
                                 2*prior_bounds, 2*prior_bounds,
                                 fill=False, edgecolor='red', linewidth=1.3,
                                 linestyle='-', alpha=0.9)
            ax.add_patch(rect)

        # Formatting
        ax.set_aspect('equal')
        ax.set_xlim(-policy_bounds, policy_bounds)
        ax.set_ylim(-policy_bounds, policy_bounds)

        tick_vals = np.arange(-policy_bounds, policy_bounds+1e-6, 3.0)
        ax.set_xticks(tick_vals)
        ax.set_yticks(tick_vals)

        ax.set_title(policy_name, fontsize=14)
        ax.grid(True, alpha=0.2, color='white', linewidth=0.3)

        if idx == 0:
            ax.set_ylabel(r'$\theta_2$', fontsize=12)
        ax.set_xlabel(r'$\theta_1$', fontsize=12)

    # Add colorbar
    cbar = fig.colorbar(im, ax=axes, location='right',
                        label = r'$-\log\,|\,\mathrm{Posterior\ Log\ Prob}\,|$', format='%.2f',
                        pad=0.02, shrink=0.9)
    cbar.set_label(r'$-\log\,|\,\mathrm{Posterior\ Log\ Prob}\,|$', fontsize=12)
    cbar.ax.locator_params(nbins=6)

    plt.show()


plot_misspecification_heatmaps(shared_eval_thetas, results, sim_params) 

### Experiment #3: Simulator Shift

- Same procedure, but this time there is no prior shift, and the policies are trained on baseline parameters $\sigma_y^2 = 0.5$, $a = 1.0$, $b = 0.1$.
- As in the previous experiment, the stochastic policy is trained with fixed `alpha`= 0.01.


In [None]:
# ---------- Training Stage ---------- #

sim_params = {
    "T": 7,
    "D": 1,
    "K": 1,
    "p": 2,
    "a": [1.0],
    "m": 0.001,
    "b": 0.1,
    "noise_std": 0.5,
    "prior_bounds": 3.0,
    "policy_bounds": 3.0,
}

train_params = {
    "posterior_bounds": 3.0,
    "num_steps": 3000,
    "num_warmup_steps": 1000,
    "batch_size": 256,
    "clip_norm": 2.0,
    "entropy_bonus": False,
    "alpha": 0.0,
    "alpha_decay": 1.0,
    "device": None,
    "verbose": False,
    "plot_training": False,
}

device = train_params["device"] or torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Warmup shared posterior once with random policy exploring the full posterior_bounds space
print("Warming up shared posterior...")
simulator_warmup = LocationFinding(
    policy=RandomPolicyNet(D=sim_params["D"], p=sim_params["p"], policy_bounds=train_params["posterior_bounds"]),
    policy_bounds=train_params["posterior_bounds"],
    prior_bounds=train_params["posterior_bounds"],
    D=sim_params["D"], K=sim_params["K"], a=sim_params["a"], m=sim_params["m"], b=sim_params["b"],
    noise_std=sim_params["noise_std"], T=sim_params["T"], p=sim_params["p"]).to(device)

shared_posterior = PosteriorNet(D=sim_params["D"], p=sim_params["p"], K=sim_params["K"], inf_net="CouplingFlow").to(device)
shared_posterior = warmup_posterior(simulator_warmup,
                                    shared_posterior,
                                    num_warmup_steps=train_params["num_warmup_steps"],
                                    batch_size=train_params["batch_size"],
                                    verbose=train_params["verbose"])

# Train deterministic policy
print("Training Deterministic Policy...")
sim_det, post_det, _ = train_joint(sim_params=sim_params,
                                             train_params=train_params,
                                             policy_type="det",
                                             posterior=shared_posterior)

# Train random policy
print("Training Random Policy...")
sim_rand, post_rand, _ = train_joint(sim_params=sim_params,
                                                train_params=train_params,
                                                policy_type="rand",
                                                posterior=shared_posterior)

# Train stochastic policy
print("Training Stochastic Policy...")
tp_sto = {**train_params, "entropy_bonus": True, "alpha": 0.01, "alpha_decay": 1.0}
sim_sto, post_sto, _ = train_joint(sim_params=sim_params,
                                             train_params=tp_sto,
                                             policy_type="sto",
                                             posterior=shared_posterior)


In [None]:
# ---------- Evaluation Stage ---------- #

param_ranges = {'noise_std': np.linspace(0.0001, 2.0, 10),
                'a': np.linspace(0.0, 4.0, 10),
                'b': np.linspace(0.0, 0.8, 10)}

n_mc = 100          
n_trajectories = 5 

# Evaluate each policy on every parameter value in param_ranges
sensitivity_results = {}

for param_name in param_ranges.keys():
    print(f"Evaluating {param_name.upper()}...")

    param_values = param_ranges[param_name]
    sensitivity_results[param_name] = {'values': param_values,
                                       'det': {'means': [], 'ses': [], 'all_eigs': []},
                                       'rand': {'means': [], 'ses': [], 'all_eigs': []},
                                       'sto': {'means': [], 'ses': [], 'all_eigs': []}}

    for i, param_value in enumerate(param_values):
    
        if i % 2 == 0: 
            print(f"  [{i+1}/{len(param_values)}] {param_name}={param_value:.3f}")

        test_params = sim_params.copy()
        test_params[param_name] = param_value

        det_eigs, det_mean, det_se = eval_policy(sim_det.policy,
                               post_det,
                               test_params,
                               n_mc=n_mc,
                               n_trajectories=n_trajectories,
                               metric="eig")
        
        sto_eigs, sto_mean, sto_se = eval_policy(sim_sto.policy,
                               post_sto,
                               test_params,
                               n_mc=n_mc,
                               n_trajectories=n_trajectories,
                               metric="eig")

        rand_eigs, rand_mean, rand_se = eval_policy(sim_rand.policy,
                               post_rand,
                               test_params,
                               n_mc=n_mc,
                               n_trajectories=n_trajectories,
                               metric="eig")

        # Store results
        sensitivity_results[param_name]['det']['means'].append(det_mean)
        sensitivity_results[param_name]['det']['ses'].append(det_se)
        sensitivity_results[param_name]['det']['all_eigs'].append(det_eigs)
        sensitivity_results[param_name]['rand']['means'].append(rand_mean)
        sensitivity_results[param_name]['rand']['ses'].append(rand_se)
        sensitivity_results[param_name]['rand']['all_eigs'].append(rand_eigs)
        sensitivity_results[param_name]['sto']['means'].append(sto_mean)
        sensitivity_results[param_name]['sto']['ses'].append(sto_se)
        sensitivity_results[param_name]['sto']['all_eigs'].append(sto_eigs)


In [None]:
# ---------- Plot Results ---------- #

fig, axes = plt.subplots(1, 3,
                        figsize=(14, 5),
                        sharey=False)

x_labels = {'noise_std': r'Noise Std. ($\sigma_y$)',
            'a': r'Signal Strength ($a$)',
            'b': r'Background Signal ($b$)'
}

policy_styles = {'det': {'color': 'red', 'label': 'Deterministic'},
                 'sto': {'color': 'purple', 'label': 'Stochastic'},
                 'rand': {'color': 'blue', 'label': 'Random'}}

for idx, param_name in enumerate(param_ranges.keys()):
    ax = axes[idx]

    param_values = sensitivity_results[param_name]['values']
    baseline_value = sim_params[param_name]

    for policy_name in ['det', 'sto', 'rand']:
        means = np.array(sensitivity_results[param_name][policy_name]['means'])
        ses = np.array(sensitivity_results[param_name][policy_name]['ses']) 
        style = policy_styles[policy_name]

        ax.plot(param_values, means,
                color=style['color'],
                label=style['label'], linestyle='-',
                linewidth=1.5, markersize=8, alpha=0.8)

        ax.fill_between(param_values,
                        means - ses,
                        means + ses,
                        color=style['color'], alpha=0.2)

    ax.axvline(x=baseline_value, color='black', linestyle='--', alpha=0.7, linewidth=1.5, label="Baseline")

    # Formatting
    ax.set_xlabel(x_labels[param_name], fontsize=12)
    ax.xaxis.set_major_locator(MaxNLocator(nbins=6)) 
    ax.grid(True, alpha=0.3)

    if idx == 2:
        ax.legend(loc='best')

axes[0].set_ylabel('Final EIG (BA bound)', fontsize=12)

plt.suptitle("Policy + Posterior Sensitivity to Simulator Shift", fontsize=14)
plt.tight_layout()
plt.show()