<a href="https://colab.research.google.com/github/andrewbermingham/PACBayesPINNs/blob/main/PINNs_Experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Burgers PINN + PAC-Bayes — Clean Experiments (Colab)
# ----------------------------------------------------
# - Uses Raissi et al.'s burgers_shock.mat to build a fixed IC/BC pool and dense eval grid
# - Deterministic PINN baseline (8 hidden tanh layers × 20 units, i.e. 9 linear layers) + PAC-Bayes PINN (diagonal Gaussian, PBB-style)
# - Reproducible (Nu, Nf) subsampling; relative L2 on dense grid; PAC-Bayes certificate

# If SciPy isn't present:  (Colab usually has it)
# !pip install scipy --quiet

import os, math, time, random, json, urllib.request, gc
from dataclasses import dataclass
from typing import Optional, Dict, Tuple, List

import numpy as np
import scipy.io as sio
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

# Run everything in double precision (common in PINN reproductions: L-BFGS line
# searches are sensitive to float32 round-off). Only tensors created after this
# call pick up the new default dtype.
torch.set_default_dtype(torch.float64)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

Device: cuda


In [None]:
# --- REPLACE wandb_start with this ---
import wandb
from dataclasses import asdict

def wandb_start(cfg, project, group, name, tags=None):
    """Start a W&B run for one experiment and make "epoch" the shared x-axis.

    Args:
        cfg: dataclass config; logged in full as the run config via ``asdict(cfg)``.
        project: W&B project name.
        group: W&B run group (e.g. "deterministic").
        name: run name.
        tags: optional iterable of tags; a fresh list is passed to W&B.

    Returns:
        The run object returned by ``wandb.init``.
    """
    run = wandb.init(
        project=project,
        group=group,
        name=name,
        tags=list(tags) if tags else [],
        config=asdict(cfg),
        reinit=True,
    )
    # Make "epoch" the shared step axis for every metric namespace this notebook
    # logs. "time/*" is logged by train_pinn and "sigma/*" keys are produced by
    # _posterior_sigma_stats, so they are bound to the same axis as the rest.
    wandb.define_metric("epoch")
    for prefix in ("train", "lbfgs", "pbb", "eval", "time", "sigma"):
        wandb.define_metric(f"{prefix}/*", step_metric="epoch")
    return run


In [None]:
import matplotlib.pyplot as plt

def _wandb_images_from_prediction(U_pred, U_ref, t_grid, x_grid, title_prefix=""):
    """Build W&B diagnostic images for a space-time prediction.

    Args:
        U_pred, U_ref: 2-D arrays of the same shape; row k corresponds to
            ``t_grid[k]`` and columns to ``x_grid``.
        t_grid, x_grid: 1-D coordinates (NumPy arrays, or anything
            ``np.asarray`` accepts, such as CPU tensors).
        title_prefix: text prepended to every figure title.

    Returns:
        dict with keys 'pred_heatmap', 'error_heatmap' and 'slices', each a
        ``wandb.Image``.

    The slice times are 4 evenly spaced targets in ``[t_grid[0], t_grid[-1]]``
    snapped to the nearest grid time, so grids whose max time is < 1.0
    (e.g. 0.99...) are handled.
    """
    U_pred = np.asarray(U_pred)
    U_ref = np.asarray(U_ref)
    t_grid = np.asarray(t_grid)
    x_grid = np.asarray(x_grid)
    # With origin='upper', extent=[left, right, bottom, top] puts t_grid[0] at the top.
    extent = [x_grid[0], x_grid[-1], t_grid[-1], t_grid[0]]

    def _to_image(fig):
        # Close the figure even if wandb.Image raises, so figures never leak.
        try:
            return wandb.Image(fig)
        finally:
            plt.close(fig)

    def _heatmap(values, label):
        fig, ax = plt.subplots(figsize=(6, 3.6))
        im = ax.imshow(values, extent=extent, aspect='auto', origin='upper')
        fig.colorbar(im, ax=ax)
        ax.set_xlabel('x')
        ax.set_ylabel('t')
        ax.set_title(f'{title_prefix} {label}')
        return _to_image(fig)

    imgs = {
        'pred_heatmap': _heatmap(U_pred, 'U_pred'),
        'error_heatmap': _heatmap(np.abs(U_pred - U_ref), '|U_pred - U_ref|'),
    }

    # Snap each target time to its nearest grid index, clamp into U_ref's row
    # range, and de-duplicate (short grids can map two targets to one row).
    Nt = U_ref.shape[0]
    t_targets = np.linspace(float(t_grid[0]), float(t_grid[-1]), 4)
    idxs = sorted({min(max(int(np.argmin(np.abs(t_grid - tv))), 0), Nt - 1)
                   for tv in t_targets})

    fig, ax = plt.subplots(figsize=(6, 3.6))
    for k in idxs:
        ax.plot(x_grid, U_ref[k, :], '--', linewidth=1.5, label=f'exact t≈{t_grid[k]:.2f}')
        ax.plot(x_grid, U_pred[k, :], linewidth=1.0, label=f'pred  t≈{t_grid[k]:.2f}')
    ax.set_xlabel('x')
    ax.set_ylabel('u')
    ax.set_title(f'{title_prefix} slices')
    ax.legend(ncol=2, fontsize=9)
    imgs['slices'] = _to_image(fig)

    return imgs



In [None]:
from dataclasses import dataclass
import math
import torch
from typing import Optional

@dataclass
class PINNConfig:
    """All hyper-parameters for the Burgers PINN and PAC-Bayes PINN experiments.

    ``Nu`` / ``Nf`` are read/write legacy aliases of ``N_ic`` / ``N_f`` for old
    code. They are properties rather than dataclass fields, so they always
    reflect the current value of the field they alias and never appear in
    ``asdict(cfg)``.
    """
    # --- Domain / PDE ---
    t_min: float = 0.0
    t_max: float = 1.0
    x_min: float = -1.0
    x_max: float = 1.0
    nu: float = 0.01 / math.pi   # viscosity in u_t + u u_x - nu u_xx = 0

    # --- Data sizes ---
    N_ic: int = 100          # total IC+BC points drawn from pool
    N_f: int = 10000         # collocation points actually used
    Nf_master: int = 20000   # size of fixed master LHS (>= max N_f you'll sweep)

    # --- Network (Raissi baseline: 8 hidden layers × 20 units = 9 linear layers) ---
    hidden_layers: int = 8
    hidden_width: int = 20
    activation: str = "tanh"

    # --- Training / seeds ---
    seed: int = 1234

    # --- Deterministic PINN (Adam warmup) ---
    adam_epochs: int = 2000
    adam_lr: float = 1e-3

    # --- Bounded loss (kept = 1) ---
    loss_cap_B: float = 1.0
    loss_type: str = "clip"
    s_scale: float = 1e-3
    alpha: float = 2.2

    # --- Certificate tracking during PAC-Bayes training ---
    cert_track_every: int = 200    # evaluate & compare bound every N epochs
    cert_track_mc: int = 512       # MC samples for the *tracking* bound during training
    save_best_path: Optional[str] = None  # e.g., "pbb_best_state.pt" (if None, don't write to disk)

    # --- PAC-Bayes PINN objective (clipping budgets and weights) ---
    Bu: float = 1.0          # squared data errors are clipped at Bu, then divided by Bu
    Bf: float = 1.0          # squared PDE residuals are clipped at Bf, then divided by Bf
    lambda_u: float = 0.5    # weight of the data term
    lambda_f: float = 0.5    # weight of the residual term

    # --- PAC-Bayes confidence splits ---
    delta: float = 0.05
    delta_prime: float = 0.01

    # --- Certificates / eval ---
    eval_mat_path: Optional[str] = "burgers_shock.mat"
    relL2_mc: int = 64
    use_mat_for_data: bool = True

    # --- L-BFGS refinement for deterministic baseline ---
    use_lbfgs: bool = True
    lbfgs_max_iter: int = 10
    lbfgs_history_size: int = 50
    lbfgs_tolerance_grad: float = 1e-10
    lbfgs_tolerance_change: float = 1e-12

    # --- Device (default is evaluated once, when the class is defined) ---
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # ===========================
    # PAC-Bayes training setup (paper-aligned)
    # ===========================
    prior_sigma0: float = 0.10
    pbb_optimizer: str = "sgdm"
    pbb_lr: float = 1e-3
    pbb_epochs: int = 5000
    pbb_momentum: float = 0.95
    pbb_mc_train: int = 1
    cert_mc: int = 150000

    mat_path: Optional[str] = "burgers_shock.mat"

    # --- Legacy aliases (kept in sync with N_ic / N_f) ---
    @property
    def Nu(self) -> int:
        """Legacy alias for ``N_ic`` (number of IC+BC points)."""
        return self.N_ic

    @Nu.setter
    def Nu(self, value: int) -> None:
        self.N_ic = value

    @property
    def Nf(self) -> int:
        """Legacy alias for ``N_f`` (number of collocation points)."""
        return self.N_f

    @Nf.setter
    def Nf(self, value: int) -> None:
        self.N_f = value


In [None]:
# -----------------------
# Utilities
# -----------------------
def set_seed(seed: int):
    """Seed every RNG the notebook uses: Python's ``random``, NumPy's legacy
    global RNG, torch's CPU generator and all CUDA device generators."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

def ensure_burgers_mat(path: str = "burgers_shock.mat") -> str:
    """Return ``path``, downloading Raissi's ``burgers_shock.mat`` there if needed.

    An existing non-empty file is reused as-is. Otherwise each mirror is tried in
    turn. The download is written to ``path + ".part"`` and only renamed to
    ``path`` once it is non-empty, so a failed or interrupted download never
    leaves a truncated file that later calls would accept.

    Raises:
        FileNotFoundError: if every mirror fails.
    """
    if os.path.exists(path) and os.path.getsize(path) > 0:
        return path
    urls = [
        "https://raw.githubusercontent.com/maziarraissi/PINNs/master/appendix/Data/burgers_shock.mat",
        "https://github.com/maziarraissi/PINNs/raw/master/appendix/Data/burgers_shock.mat",
    ]
    tmp_path = path + ".part"
    for url in urls:
        try:
            print(f"[Data] downloading {url} -> {path}")
            urllib.request.urlretrieve(url, tmp_path)
            if os.path.getsize(tmp_path) > 0:
                os.replace(tmp_path, path)  # atomic rename: path is either absent or complete
                return path
        except Exception as e:
            # Best-effort: report and fall through to the next mirror.
            print("Download failed from:", url, "error:", e)
        finally:
            # Drop any partial/empty download so it cannot be mistaken for data.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    raise FileNotFoundError("Could not acquire burgers_shock.mat")

def lhs(n: int, d: int, rng: np.random.Generator) -> np.ndarray:
    """Latin-hypercube sample of ``n`` points in the unit cube ``[0, 1)^d``.

    ``[0, 1)`` is split into ``n`` equal strata. Every dimension gets exactly one
    uniformly jittered point per stratum, and the strata are permuted
    independently per dimension (the classic scheme used by pyDOE's ``lhs``).

    Args:
        n: number of samples.
        d: number of dimensions.
        rng: NumPy generator used for the jitter and the per-dimension permutations.

    Returns:
        Array of shape ``(n, d)``.
    """
    cut = np.linspace(0, 1, n + 1)
    u = rng.random((n, d))
    a = cut[:n]
    b = cut[1 : n + 1]
    # rd[i, j]: jittered point inside stratum i for dimension j.
    rd = u * (b - a)[:, None] + a[:, None]
    H = np.zeros_like(rd)
    for j in range(d):
        order = rng.permutation(n)
        # Each dimension uses its own jittered column; reusing column 0 here would
        # give every dimension the identical set of values (only reordered).
        H[:, j] = rd[order, j]
    return H

def scale_tx_to_unit(tx: torch.Tensor, t_min, t_max, x_min, x_max) -> torch.Tensor:
    """Affinely map (t, x) columns from [t_min, t_max] × [x_min, x_max] to [-1, 1]².

    ``tx`` carries t in its first column and x in its second. The bounds are
    built on ``tx``'s device with ``tx``'s dtype.
    """
    lower = tx.new_tensor([t_min, x_min])
    upper = tx.new_tensor([t_max, x_max])
    span = upper - lower
    return 2.0 * (tx - lower) / span - 1.0


In [None]:
# -----------------------
# Dataset manager
# -----------------------
class BurgersData:
    """
    Loads the Raissi .mat, builds:
      - pool of IC/BC points X_u_pool (Nx + 2*Nt points, ≈456 for burgers_shock.mat)
        with exact u values
      - a master collocation set X_f_master (Nf_master LHS points, then X_u_pool appended)
      - an eval grid (eval_t, eval_x, eval_U) equal to the .mat grid and exact solution
    Provides deterministic subsampling for any (Nu, Nf) given a seed.
    """
    def __init__(self, cfg: PINNConfig):
        self.cfg = cfg
        set_seed(cfg.seed)

        # mat_path / eval_mat_path are Optional: skip None/empty values instead of
        # handing None to ensure_burgers_mat.
        mat_path = ensure_burgers_mat(
            getattr(cfg, "mat_path", None)
            or getattr(cfg, "eval_mat_path", None)
            or "burgers_shock.mat"
        )

        data = sio.loadmat(mat_path)
        t = data["t"].flatten()[:, None]  # (Nt,1)
        x = data["x"].flatten()[:, None]  # (Nx,1)
        U = np.real(data["usol"]).T       # (Nt, Nx), u(t,x)
        self.t_all = t
        self.x_all = x
        self.U_all = U

        # Pool IC/BC from mat grid (Dirichlet boundaries in this dataset)
        T, X = np.meshgrid(t, x, indexing="ij")
        xx_ic = np.hstack([T[0:1, :].T, X[0:1, :].T])  # (Nx,2) at first time slice
        uu_ic = U[0:1, :].T

        xx_bcL = np.hstack([T[:, 0:1], X[:, 0:1]])    # (Nt,2) at left boundary x[0]
        uu_bcL = U[:, 0:1]
        xx_bcR = np.hstack([T[:, -1:], X[:, -1:]])    # (Nt,2) at right boundary x[-1]
        uu_bcR = U[:, -1:]

        X_u_pool = np.vstack([xx_ic, xx_bcL, xx_bcR])
        u_pool = np.vstack([uu_ic, uu_bcL, uu_bcR])   # shape (Nx + 2*Nt, 1)
        self.X_u_pool = X_u_pool
        self.u_pool = u_pool

        # Master collocation design in [t_min,t_max] × [x_min,x_max], then append X_u_pool
        rng = np.random.default_rng(cfg.seed)
        H = lhs(cfg.Nf_master, 2, rng)
        lb = np.array([cfg.t_min, cfg.x_min])
        ub = np.array([cfg.t_max, cfg.x_max])
        X_f_master = lb + (ub - lb) * H
        self.X_f_master = np.vstack([X_f_master, X_u_pool])

        # Eval grid: the .mat grid and its exact solution
        self.eval_t = t[:, 0]
        self.eval_x = x[:, 0]
        self.eval_U = U  # (Nt, Nx)

    def sample_train_sets(self, Nu: int, Nf: int, seed: Optional[int] = None) -> Dict[str, torch.Tensor]:
        """Return tensors for (t_ic,x_ic,u_ic), (t_bc,x_bc,u_bc), and (t_f,x_f).

        Args:
            Nu: number of IC+BC points drawn without replacement from the pool
                (capped at the pool size, with a warning).
            Nf: number of collocation points, taken as the first Nf rows of
                X_f_master so that sets for increasing Nf are nested (capped at
                its size, with a warning).
            seed: seed for the IC/BC draw; defaults to cfg.seed.
        """
        import warnings

        if seed is None:
            seed = self.cfg.seed
        rng = np.random.default_rng(seed)

        # Subsample Nu from pool without replacement
        poolN = self.X_u_pool.shape[0]
        if Nu > poolN:
            warnings.warn(f"Requested Nu={Nu} IC/BC points but the pool has only {poolN}; using {poolN}.")
        idx = rng.choice(poolN, size=min(Nu, poolN), replace=False)
        X_u = self.X_u_pool[idx, :]
        u_u = self.u_pool[idx, :]

        # Split into IC vs BC: IC points lie on the first time slice of the .mat grid
        # (t = 0 for burgers_shock.mat). Corner points on that slice count as IC.
        is_ic = np.isclose(X_u[:, 0], self.t_all[0, 0])
        t_ic = torch.tensor(X_u[is_ic, 0:1])
        x_ic = torch.tensor(X_u[is_ic, 1:2])
        u_ic = torch.tensor(u_u[is_ic, :])

        t_bc = torch.tensor(X_u[~is_ic, 0:1])
        x_bc = torch.tensor(X_u[~is_ic, 1:2])
        u_bc = torch.tensor(u_u[~is_ic, :])

        # Collocation: nested subset from master
        n_master = self.X_f_master.shape[0]
        if Nf > n_master:
            warnings.warn(
                f"Requested Nf={Nf} collocation points but only {n_master} are available; "
                f"using {n_master}. Increase cfg.Nf_master."
            )
        Xf = self.X_f_master[:Nf, :]
        t_f = torch.tensor(Xf[:, 0:1])
        x_f = torch.tensor(Xf[:, 1:2])

        return {
            "t_ic": t_ic, "x_ic": x_ic, "u_ic": u_ic,
            "t_bc": t_bc, "x_bc": x_bc, "u_bc": u_bc,
            "t_f":  t_f,  "x_f":  x_f,
        }

    def eval_grid(self) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]:
        """Return (t_grid, x_grid, U_exact) where t_grid shape (Nt,), x_grid shape (Nx,)."""
        t = torch.tensor(self.eval_t)
        x = torch.tensor(self.eval_x)
        U = self.eval_U
        return t, x, U

    def rel_l2(self, U_pred: np.ndarray) -> float:
        """Relative L2 (Frobenius) error ||U_pred - U_exact|| / ||U_exact|| on the eval grid."""
        U_ref = self.eval_U
        return float(np.linalg.norm(U_pred - U_ref) / np.linalg.norm(U_ref))

In [None]:




# -----------------------
# Models and losses
# -----------------------
class FCNet(nn.Module):
    """Fully connected MLP: ``depth`` hidden layers of size ``width`` (at least one).

    Layout inside ``self.net``: Linear, Act, [Linear, Act] * (depth - 1), Linear.
    Unknown activation names fall back to tanh. Weights use Xavier-uniform
    initialisation and biases start at zero.
    """
    def __init__(self, in_features=2, out_features=1, width=20, depth=8, activation="tanh"):
        super().__init__()
        act_cls = {"tanh": nn.Tanh, "relu": nn.ReLU, "gelu": nn.GELU}.get(activation, nn.Tanh)
        # The input layer always exists, so depth <= 1 still yields one hidden layer.
        dims = [in_features] + [width] * max(depth, 1)
        modules = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            modules.append(nn.Linear(d_in, d_out))
            modules.append(act_cls())
        modules.append(nn.Linear(width, out_features))
        self.net = nn.Sequential(*modules)
        self.apply(self._init)

    @staticmethod
    def _init(m):
        # Xavier-uniform weights, zero biases (applied to every nn.Linear).
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_uniform_(m.weight)
        nn.init.zeros_(m.bias)

    def forward(self, tx_scaled):
        """Evaluate the network on already-scaled inputs."""
        return self.net(tx_scaled)

def pinn_residual(model: nn.Module, t: torch.Tensor, x: torch.Tensor, cfg: PINNConfig):
    """Burgers residual f = u_t + u·u_x - nu·u_xx at the points (t, x).

    Fresh grad-tracking copies of ``t``/``x`` are made on the global ``device``,
    so the caller's tensors are never modified. The graph is kept
    (create_graph=True) so that ``f`` can be back-propagated into the model
    parameters.
    """
    t_leaf = t.detach().clone().to(device).requires_grad_(True)
    x_leaf = x.detach().clone().to(device).requires_grad_(True)
    tx_scaled = scale_tx_to_unit(torch.cat([t_leaf, x_leaf], dim=1),
                                 cfg.t_min, cfg.t_max, cfg.x_min, cfg.x_max)

    u = model(tx_scaled)
    seed_u = torch.ones_like(u)
    (u_t,) = torch.autograd.grad(u, t_leaf, grad_outputs=seed_u, retain_graph=True, create_graph=True)
    (u_x,) = torch.autograd.grad(u, x_leaf, grad_outputs=seed_u, retain_graph=True, create_graph=True)
    (u_xx,) = torch.autograd.grad(u_x, x_leaf, grad_outputs=torch.ones_like(u_x),
                                  retain_graph=True, create_graph=True)

    return u_t + u * u_x - cfg.nu * u_xx

def pinn_loss(model: nn.Module, batch: Dict[str, torch.Tensor], cfg: PINNConfig) -> Tuple[torch.Tensor, Dict]:
    """Unweighted PINN loss: MSE(IC) + MSE(BC) + MSE(PDE residual).

    An empty IC or BC subset contributes a zero loss term.

    Returns:
        (loss, parts) where parts maps "ic", "bc", "f" to the float value of each term.
    """
    def _data_mse(tag):
        t_pts = batch[f"t_{tag}"]
        if t_pts.numel() == 0:
            return torch.tensor(0., device=device)
        tx = torch.cat([t_pts.to(device), batch[f"x_{tag}"].to(device)], dim=1)
        u_hat = model(scale_tx_to_unit(tx, cfg.t_min, cfg.t_max, cfg.x_min, cfg.x_max))
        return torch.mean((u_hat - batch[f"u_{tag}"].to(device))**2)

    loss_ic = _data_mse("ic")
    loss_bc = _data_mse("bc")
    loss_f = torch.mean(pinn_residual(model, batch["t_f"], batch["x_f"], cfg)**2)

    total = loss_ic + loss_bc + loss_f
    parts = {name: float(term.detach().cpu())
             for name, term in (("ic", loss_ic), ("bc", loss_bc), ("f", loss_f))}
    return total, parts

In [None]:
# --- REPLACE train_pinn with this ---
def train_pinn(cfg: PINNConfig, data: BurgersData,
               Nu: Optional[int]=None, Nf: Optional[int]=None,
               verbose=True,
               log_wandb: bool=False,
               wb_project: str="pinn-burgers",
               log_every: int = 100):
    """Train the deterministic PINN baseline (Adam, then optional L-BFGS) and evaluate it.

    Args:
        cfg: experiment configuration (network size, optimiser settings, seed, domain).
        data: dataset manager providing the training subsets and the dense eval grid.
        Nu, Nf: number of IC/BC and collocation points; default to cfg.N_ic / cfg.N_f.
        verbose: print progress to stdout.
        log_wandb: log metrics and diagnostic figures to W&B.
        wb_project: W&B project name.
        log_every: Adam logging period in epochs (epoch 1 is always logged);
            values < 1 are treated as 1.

    Returns:
        (model, rel, info): the trained FCNet, the relative L2 error on the eval
        grid, and {"time": seconds}.
    """
    # use new config fields by default
    if Nu is None: Nu = cfg.N_ic
    if Nf is None: Nf = cfg.N_f
    # Guard the modulo below: log_every == 0 would raise ZeroDivisionError.
    log_every = max(1, int(log_every))

    set_seed(cfg.seed)
    batch = data.sample_train_sets(Nu, Nf, seed=cfg.seed)
    model = FCNet(width=cfg.hidden_width, depth=cfg.hidden_layers, activation=cfg.activation).to(device)
    opt = Adam(model.parameters(), lr=cfg.adam_lr)

    run = None
    if log_wandb:
        run = wandb_start(cfg, wb_project, group="deterministic",
                          name=f"PINN_Nu{Nu}_Nf{Nf}_seed{cfg.seed}",
                          tags=[f"Nu={Nu}", f"Nf={Nf}", "baseline"])

    # The W&B run is closed in `finally`, so an exception mid-training
    # (e.g. CUDA OOM or a KeyboardInterrupt) does not leave it open.
    try:
        t0 = time.time()

        # --- Adam warm-up (full batch) ---
        for ep in range(1, cfg.adam_epochs + 1):
            opt.zero_grad(set_to_none=True)
            loss, parts = pinn_loss(model, batch, cfg)
            loss.backward()
            opt.step()

            if (ep % log_every == 0) or (ep == 1):
                if verbose:
                    print(f"[Adam] ep={ep:4d} loss={loss.item():.3e} ic={parts['ic']:.3e} bc={parts['bc']:.3e} f={parts['f']:.3e}")
                if run:
                    wandb.log({
                        "train/loss": float(loss.item()),
                        "train/ic": parts['ic'],
                        "train/bc": parts['bc'],
                        "train/f":  parts['f'],
                        "epoch": ep
                    })

        # --- Optional L-BFGS refinement (single .step call, full batch) ---
        if cfg.use_lbfgs:
            lbfgs = torch.optim.LBFGS(
                model.parameters(),
                lr=1.0,
                max_iter=cfg.lbfgs_max_iter,
                max_eval=cfg.lbfgs_max_iter,
                tolerance_grad=cfg.lbfgs_tolerance_grad,
                tolerance_change=cfg.lbfgs_tolerance_change,
                history_size=cfg.lbfgs_history_size,
                line_search_fn='strong_wolfe'
            )

            def closure():
                lbfgs.zero_grad(set_to_none=True)
                l, _ = pinn_loss(model, batch, cfg)
                l.backward()
                return l

            l0 = closure().item()
            lbfgs.step(closure)
            l1, parts1 = pinn_loss(model, batch, cfg)
            if verbose:
                print(f"[LBFGS] before={l0:.3e} after={l1.item():.3e} ic={parts1['ic']:.3e} bc={parts1['bc']:.3e} f={parts1['f']:.3e}")
            if run:
                wandb.log({
                    "lbfgs/loss_before": float(l0),
                    "lbfgs/loss_after": float(l1.item()),
                    "lbfgs/ic": parts1['ic'],
                    "lbfgs/bc": parts1['bc'],
                    "lbfgs/f":  parts1['f'],
                    "epoch": cfg.adam_epochs  # final Adam epoch index
                })

        # --- Dense-grid evaluation AFTER L-BFGS ---
        t_grid, x_grid, U_exact = data.eval_grid()
        TT, XX = np.meshgrid(t_grid.numpy(), x_grid.numpy(), indexing='ij')
        tx = torch.tensor(np.stack([TT.ravel(), XX.ravel()], axis=1), dtype=torch.get_default_dtype(), device=device)
        with torch.no_grad():
            U_pred = model(scale_tx_to_unit(tx, cfg.t_min, cfg.t_max, cfg.x_min, cfg.x_max)).cpu().numpy().reshape(U_exact.shape)

        rel = data.rel_l2(U_pred)
        elapsed = time.time() - t0
        if verbose:
            print(f"[Eval] rel-L2={rel:.3e} | time={elapsed:.1f}s")

        if run:
            imgs = _wandb_images_from_prediction(U_pred, U_exact, t_grid.numpy(), x_grid.numpy(), title_prefix="PINN")
            wandb.log({"eval/relL2": rel, "time/seconds": elapsed, **imgs, "epoch": cfg.adam_epochs})

        return model, rel, {"time": elapsed}
    finally:
        if run is not None:
            run.finish()


In [None]:
# -----------------------
# PAC-Bayes PINN (Diagonal Gaussian posterior)
# -----------------------
# --- REPLACE BayesLinear with this ---
class BayesLinear(nn.Module):
    # NOTE(review): this class is re-defined verbatim later in this same cell; that
    # later definition shadows this one, so this copy has no runtime effect.
    # Delete one of the two copies.
    """Linear layer with a factorised Gaussian posterior N(mu, softplus(rho)^2)
    over weights and biases, and a Gaussian prior N(mu0, prior_sigma0^2)."""
    def __init__(self, in_f, out_f, prior_sigma0=1.0):
        super().__init__()
        self.prior_sigma0 = float(prior_sigma0)

        # posterior parameters
        self.mu_w = nn.Parameter(torch.empty(out_f, in_f))
        self.mu_b = nn.Parameter(torch.empty(out_f))
        # inverse softplus, so softplus(rho0) == prior_sigma0 (requires prior_sigma0 > 0)
        rho0 = math.log(math.expm1(prior_sigma0))
        self.rho_w = nn.Parameter(torch.full((out_f, in_f), rho0))
        self.rho_b = nn.Parameter(torch.full((out_f,), rho0))

        # prior centers μ0 (buffers; fixed during PAC-Bayes training)
        self.register_buffer("mu0_w", torch.empty(out_f, in_f))
        self.register_buffer("mu0_b", torch.empty(out_f))

        # prior centres ~ N(0, std^2), std = 1/√in_f, resampled until |value| <= 2·std
        std = 1.0 / math.sqrt(in_f)
        with torch.no_grad():
            self.mu0_w.copy_(truncated_normal_like(self.mu0_w, std))
            self.mu0_b.copy_(truncated_normal_like(self.mu0_b, std))
            # posterior initialised at the prior (centers & scales)
            self.mu_w.copy_(self.mu0_w)
            self.mu_b.copy_(self.mu0_b)

    # posterior standard deviations (softplus keeps them positive)
    @property
    def sigma_w(self): return F.softplus(self.rho_w)
    @property
    def sigma_b(self): return F.softplus(self.rho_b)

    def sample_params(self):
        """Reparameterised draw (W, b) = mu + sigma * eps, eps ~ N(0, I)."""
        eps_w = torch.randn_like(self.mu_w)
        eps_b = torch.randn_like(self.mu_b)
        W = self.mu_w + self.sigma_w * eps_w
        b = self.mu_b + self.sigma_b * eps_b
        return W, b

    def mean_params(self):
        """Posterior means (mu_w, mu_b)."""
        return self.mu_w, self.mu_b
    def kl(self):
        """Closed-form KL(posterior || prior) summed over all weights and biases."""
        s0 = self.prior_sigma0
        s02 = s0 * s0
        def _kl(mu, mu0, sigma):
            # 1e-12 guards log() against a sigma that underflows to 0
            return 0.5 * torch.sum((sigma**2)/s02 + ((mu - mu0)**2)/s02 - 1.0 + torch.log(s02/(sigma**2 + 1e-12)))
        return _kl(self.mu_w, self.mu0_w, self.sigma_w) + _kl(self.mu_b, self.mu0_b, self.sigma_b)

# NOTE(review): shadowed by an identical re-definition later in this cell; delete one copy.
def truncated_normal_like(t, std):
    """Tensor shaped like ``t`` with N(0, std^2) entries, each resampled until |value| <= 2·std."""
    out = torch.empty_like(t)
    ok = torch.zeros_like(out, dtype=torch.bool)
    while not torch.all(ok):
        s = torch.randn_like(out) * std
        mask = (s.abs() <= 2*std)
        # keep already-accepted entries; overwrite the rest with the fresh draw
        out = torch.where(ok, out, s)
        ok = ok | mask
    return out



# --- REPLACE BayesFCNet with this ---
class BayesFCNet(nn.Module):
    # NOTE(review): this class is re-defined verbatim later in this same cell; that
    # later definition shadows this one, so this copy has no runtime effect.
    # Delete one of the two copies.
    """MLP of BayesLinear layers: input + (depth - 1) hidden layers of size width, then a linear output layer."""
    def __init__(self, in_features=2, out_features=1, width=20, depth=8, activation="tanh", prior_sigma0=1.0):
        super().__init__()
        acts = {"tanh": nn.Tanh, "relu": nn.ReLU, "gelu": nn.GELU}
        Act = acts.get(activation, nn.Tanh)  # unknown names fall back to tanh
        self.layers = nn.ModuleList()
        # NOTE(review): depth activations vs max(depth, 1) layers -> mismatch when depth < 1
        self.acts = nn.ModuleList([Act() for _ in range(depth)])
        # input
        self.layers.append(BayesLinear(in_features, width, prior_sigma0))
        # hidden
        for _ in range(depth - 1):
            self.layers.append(BayesLinear(width, width, prior_sigma0))
        # output
        self.out = BayesLinear(width, out_features, prior_sigma0)

    # ---- NEW: sample once and reuse ----
    def sample_weights(self):
        """One posterior draw: list of (W, b) per hidden layer, plus (W, b) for the output layer."""
        weights = []
        for layer in self.layers:
            W, b = layer.sample_params()
            weights.append((W, b))
        W, b = self.out.sample_params()
        out_w = (W, b)
        return weights, out_w

    # ---- NEW: KL once per step ----
    def kl_total(self):
        """Sum of the per-layer KL terms (hidden layers and output layer)."""
        kl = torch.tensor(0.0, device=self.layers[0].mu_w.device)
        for layer in self.layers:
            kl = kl + layer.kl()
        kl = kl + self.out.kl()
        return kl

    # ---- NEW: deterministic forward with provided weights ----
    def forward_with_weights(self, tx_scaled, weights, out_w):
        """Forward pass using explicit weights (e.g. from sample_weights)."""
        h = tx_scaled
        for i, ((W, b), act) in enumerate(zip(weights, self.acts)):
            h = h @ W.T + b
            h = act(h)
        W, b = out_w
        y = h @ W.T + b
        return y

    # (keep these two for convenience)
    def forward_sample(self, tx_scaled):
        """Forward with a fresh posterior draw; returns (y, KL)."""
        weights, out_w = self.sample_weights()
        y = self.forward_with_weights(tx_scaled, weights, out_w)
        return y, self.kl_total()

    def forward_mean(self, tx_scaled):
        """Forward using the posterior means instead of a sample."""
        h = tx_scaled
        for i, layer in enumerate(self.layers):
            W, b = layer.mean_params()
            h = h @ W.T + b
            h = self.acts[i](h)
        W, b = self.out.mean_params()
        y = h @ W.T + b
        return y


# NOTE(review): shadowed by an identical re-definition later in this cell; delete one copy.
def bounded_mse(y, y_true, B):
    """Element-wise squared error clipped at B and divided by B (values in [0, 1] for B > 0)."""
    return torch.clamp((y - y_true)**2, max=B) / B

# -----------------------
# PAC-Bayes PINN (Diagonal Gaussian posterior)
# -----------------------
# --- REPLACE BayesLinear with this ---
class BayesLinear(nn.Module):
    """Linear layer with a factorised Gaussian posterior over weights and biases.

    Posterior: each entry ~ N(mu, softplus(rho)^2).
    Prior: each entry ~ N(mu0, prior_sigma0^2). The prior centres mu0 are drawn
    once at construction from N(0, 1/in_f) truncated at ±2 std, and are stored
    as buffers (not trained).
    The posterior starts exactly at the prior: mu = mu0 and sigma = prior_sigma0.
    """
    def __init__(self, in_f, out_f, prior_sigma0=1.0):
        super().__init__()
        self.prior_sigma0 = float(prior_sigma0)

        # Inverse softplus, so softplus(rho) == prior_sigma0 at initialisation.
        rho_init = math.log(math.expm1(prior_sigma0))

        # posterior parameters
        self.mu_w = nn.Parameter(torch.empty(out_f, in_f))
        self.mu_b = nn.Parameter(torch.empty(out_f))
        self.rho_w = nn.Parameter(torch.full((out_f, in_f), rho_init))
        self.rho_b = nn.Parameter(torch.full((out_f,), rho_init))

        # prior centres (buffers: saved with the module, never optimised)
        self.register_buffer("mu0_w", torch.empty(out_f, in_f))
        self.register_buffer("mu0_b", torch.empty(out_f))

        init_std = 1.0 / math.sqrt(in_f)
        with torch.no_grad():
            for prior_centre, posterior_mean in ((self.mu0_w, self.mu_w), (self.mu0_b, self.mu_b)):
                prior_centre.copy_(truncated_normal_like(prior_centre, init_std))
                posterior_mean.copy_(prior_centre)

    @property
    def sigma_w(self):
        """Posterior std of the weights, softplus(rho_w)."""
        return F.softplus(self.rho_w)

    @property
    def sigma_b(self):
        """Posterior std of the biases, softplus(rho_b)."""
        return F.softplus(self.rho_b)

    def sample_params(self):
        """Reparameterised draw: W = mu_w + sigma_w*eps_w, b = mu_b + sigma_b*eps_b."""
        noise_w = torch.randn_like(self.mu_w)
        noise_b = torch.randn_like(self.mu_b)
        return self.mu_w + self.sigma_w * noise_w, self.mu_b + self.sigma_b * noise_b

    def mean_params(self):
        """Posterior means (mu_w, mu_b)."""
        return self.mu_w, self.mu_b

    def kl(self):
        """Closed-form KL(posterior || prior), summed over all weights and biases."""
        sigma0 = self.prior_sigma0
        var0 = sigma0 * sigma0

        def _gauss_kl(mu, mu0, sigma):
            var = sigma**2
            # 1e-12 keeps log() finite if sigma underflows to zero
            return 0.5 * torch.sum(var/var0 + ((mu - mu0)**2)/var0 - 1.0 + torch.log(var0/(var + 1e-12)))

        kl_weights = _gauss_kl(self.mu_w, self.mu0_w, self.sigma_w)
        kl_biases = _gauss_kl(self.mu_b, self.mu0_b, self.sigma_b)
        return kl_weights + kl_biases

def truncated_normal_like(t, std):
    """Sample a tensor shaped like ``t`` (same dtype/device) from N(0, std^2)
    truncated to [-2·std, 2·std], by rejection: rejected entries are redrawn
    until every entry has been accepted."""
    result = torch.empty_like(t)
    accepted = torch.zeros_like(result, dtype=torch.bool)
    bound = 2*std
    while not torch.all(accepted):
        draw = torch.randn_like(result) * std
        # Entries already accepted keep their value; all others take the new draw.
        result = torch.where(accepted, result, draw)
        accepted = accepted | (draw.abs() <= bound)
    return result



# --- REPLACE BayesFCNet with this ---
class BayesFCNet(nn.Module):
    """MLP of BayesLinear layers for the PAC-Bayes PINN.

    Layout: input layer (in_features -> width), then (depth - 1) hidden layers
    (width -> width), each followed by the activation, then a linear output
    layer (width -> out_features). There is always at least the input layer.
    Unknown activation names fall back to tanh.
    """
    def __init__(self, in_features=2, out_features=1, width=20, depth=8, activation="tanh", prior_sigma0=1.0):
        super().__init__()
        acts = {"tanh": nn.Tanh, "relu": nn.ReLU, "gelu": nn.GELU}
        Act = acts.get(activation, nn.Tanh)
        self.layers = nn.ModuleList()
        # input
        self.layers.append(BayesLinear(in_features, width, prior_sigma0))
        # hidden
        for _ in range(depth - 1):
            self.layers.append(BayesLinear(width, width, prior_sigma0))
        # One activation per layer in self.layers. This is sized from the layer list
        # (not `depth`), so depth < 1 still pairs the input layer with an activation
        # instead of zip() silently skipping it in forward_with_weights.
        self.acts = nn.ModuleList([Act() for _ in self.layers])
        # output
        self.out = BayesLinear(width, out_features, prior_sigma0)

    def sample_weights(self):
        """Draw one posterior sample.

        Returns:
            (weights, out_w): a list of (W, b) for self.layers in order, and (W, b)
            for the output layer. Pass both to forward_with_weights to reuse
            the same draw across several inputs.
        """
        weights = [layer.sample_params() for layer in self.layers]
        return weights, self.out.sample_params()

    def kl_total(self):
        """Total KL(posterior || prior): the sum of the per-layer KL terms."""
        kl = torch.tensor(0.0, device=self.layers[0].mu_w.device)
        for layer in self.layers:
            kl = kl + layer.kl()
        kl = kl + self.out.kl()
        return kl

    def forward_with_weights(self, tx_scaled, weights, out_w):
        """Forward pass with explicit weights: h <- act(h @ W.T + b) per layer, then a linear output."""
        h = tx_scaled
        for (W, b), act in zip(weights, self.acts):
            h = act(h @ W.T + b)
        W, b = out_w
        return h @ W.T + b

    def forward_sample(self, tx_scaled):
        """Forward with a fresh posterior draw; returns (y, KL)."""
        weights, out_w = self.sample_weights()
        y = self.forward_with_weights(tx_scaled, weights, out_w)
        return y, self.kl_total()

    def forward_mean(self, tx_scaled):
        """Forward using the posterior means (no sampling)."""
        weights = [layer.mean_params() for layer in self.layers]
        return self.forward_with_weights(tx_scaled, weights, self.out.mean_params())


def bounded_mse(y, y_true, B):
    """Element-wise squared error, clipped at ``B`` and rescaled by ``1/B``.

    For B > 0 every entry lies in [0, 1]. No reduction is applied.
    """
    squared_error = (y - y_true)**2
    return squared_error.clamp(max=B) / B

# MODIFIED FUNCTION
def pbb_objective(modelB: BayesFCNet, batch, cfg: PINNConfig, mc_samples: int = 1):
    """PBB-style PAC-Bayes training objective for the Burgers PINN.

    For ``mc_samples`` posterior draws (each draw is shared by the data and
    residual terms) this computes

        J = lam_u * Lu + lam_f * Lf
            + lam_u * sqrt((KL + c_u) / (2 m_u)) + lam_f * sqrt((KL + c_f) / (2 m_f)),
        c = log(4 * sqrt(m) / delta),

    where
      - Lu is the MC mean of the clipped, normalised squared data error (the IC
        mean and the BC mean are averaged with equal weight);
      - Lf is the MC mean of the clipped, normalised squared PDE residual;
      - m_u = #IC + #BC points and m_f = #collocation points.

    Args:
        modelB: Bayesian network (sample_weights / forward_with_weights / kl_total).
        batch: dict with t_ic/x_ic/u_ic, t_bc/x_bc/u_bc and t_f/x_f tensors.
        cfg: supplies Bu, Bf, lambda_u, lambda_f, delta, nu and the domain bounds.
        mc_samples: number of posterior draws to average over (must be >= 1).

    Returns:
        (J, info): J is a differentiable scalar tensor. info holds the floats Lu,
        Lf, KL, term_u, term_f, clip_frac_u, clip_frac_f, plus the unclipped
        mse_u / mse_f of the last draw.

    Raises:
        ValueError: if mc_samples < 1.
    """
    if mc_samples < 1:
        raise ValueError(f"mc_samples must be >= 1, got {mc_samples}")

    Bu, Bf = cfg.Bu, cfg.Bf
    lam_u, lam_f = cfg.lambda_u, cfg.lambda_f

    m_u = int(batch["t_ic"].shape[0] + batch["t_bc"].shape[0])
    m_f = int(batch["t_f"].shape[0])
    cu = math.log(4.0*math.sqrt(max(1, m_u)) / cfg.delta)
    cf = math.log(4.0*math.sqrt(max(1, m_f)) / cfg.delta)

    def _scaled(t, x):
        tx = torch.cat([t.to(device), x.to(device)], dim=1)
        return scale_tx_to_unit(tx, cfg.t_min, cfg.t_max, cfg.x_min, cfg.x_max)

    # Keep only non-empty IC/BC subsets. The mean over an empty subset is NaN and
    # would poison J; this happens whenever the Nu-point draw contains only IC
    # points or only BC points.
    # NOTE(review): Lu averages the IC mean and the BC mean with equal weight,
    # not the pooled mean over all m_u points that the sqrt((KL + c_u)/(2 m_u))
    # term is usually stated for — confirm this is intended for the certificate.
    data_sets = [
        (_scaled(batch["t_" + k], batch["x_" + k]), batch["u_" + k].to(device))
        for k in ("ic", "bc")
        if batch["t_" + k].shape[0] > 0
    ]

    # Residual inputs: fresh grad-tracking leaves on `device`.
    t_f = batch["t_f"].to(device).detach().clone().requires_grad_(True)
    x_f = batch["x_f"].to(device).detach().clone().requires_grad_(True)
    tx_f_s = scale_tx_to_unit(torch.cat([t_f, x_f], dim=1), cfg.t_min, cfg.t_max, cfg.x_min, cfg.x_max)

    # KL(Q||P) depends only on the posterior parameters, not on the weight draw,
    # so compute it once instead of once per MC sample.
    KL_bar = modelB.kl_total()

    Lu_acc = 0.0; Lf_acc = 0.0
    frac_clip_u_acc = 0.0; frac_clip_f_acc = 0.0
    sq_u = []
    f = None

    for _ in range(mc_samples):
        weights, out_w = modelB.sample_weights()

        # bounded data term (per-subset means, averaged across subsets)
        if data_sets:
            sq_u = [(modelB.forward_with_weights(tx_s, weights, out_w) - u_true)**2
                    for tx_s, u_true in data_sets]
            Lu = sum((torch.clamp(sq, max=Bu)/Bu).mean() for sq in sq_u) / len(sq_u)
            frac_clip_u = sum((sq >= Bu).float().mean() for sq in sq_u) / len(sq_u)
        else:
            Lu = torch.tensor(0.0, device=device)
            frac_clip_u = torch.tensor(0.0, device=device)

        # bounded residual term; create_graph=True so J can be back-propagated
        y_f = modelB.forward_with_weights(tx_f_s, weights, out_w)
        ones_f = torch.ones_like(y_f)
        du_dt = torch.autograd.grad(y_f, t_f, grad_outputs=ones_f, retain_graph=True, create_graph=True)[0]
        du_dx = torch.autograd.grad(y_f, x_f, grad_outputs=ones_f, retain_graph=True, create_graph=True)[0]
        d2u_dx2 = torch.autograd.grad(du_dx, x_f, grad_outputs=torch.ones_like(du_dx),
                                      retain_graph=True, create_graph=True)[0]
        f = du_dt + y_f * du_dx - cfg.nu * d2u_dx2

        sq_f = f**2
        Lf = torch.clamp(sq_f, max=Bf).mean() / Bf
        frac_clip_f = (sq_f >= Bf).float().mean()

        Lu_acc += Lu; Lf_acc += Lf
        frac_clip_u_acc += frac_clip_u; frac_clip_f_acc += frac_clip_f

    Lu_bar = Lu_acc / mc_samples
    Lf_bar = Lf_acc / mc_samples
    frac_clip_u_bar = frac_clip_u_acc / mc_samples
    frac_clip_f_bar = frac_clip_f_acc / mc_samples

    term_u = lam_u * torch.sqrt(torch.clamp((KL_bar + cu) / (2.0*max(1, m_u)), min=0.0))
    term_f = lam_f * torch.sqrt(torch.clamp((KL_bar + cf) / (2.0*max(1, m_f)), min=0.0))

    Lpin = lam_u*Lu_bar + lam_f*Lf_bar
    J = Lpin + term_u + term_f

    # Unclipped diagnostics from the last MC draw.
    with torch.no_grad():
        if sq_u:
            mse_u = sum(sq.mean() for sq in sq_u) / len(sq_u)
        else:
            mse_u = torch.tensor(0.0, device=device)
        mse_f = (f**2).mean()

    return J, {
        "Lu": float(Lu_bar.detach().cpu()),
        "Lf": float(Lf_bar.detach().cpu()),
        "KL": float(KL_bar.detach().cpu()),
        "term_u": float(term_u.detach().cpu()),
        "term_f": float(term_f.detach().cpu()),
        "clip_frac_u": float(frac_clip_u_bar.detach().cpu()),
        "clip_frac_f": float(frac_clip_f_bar.detach().cpu()),
        "mse_u": float(mse_u.detach().cpu()),
        "mse_f": float(mse_f.detach().cpu()),
    }




def _coarse_relL2_mean(modelB, data: BurgersData, cfg: PINNConfig, stride_t=4, stride_x=4):
    """Cheaper periodic check: rel-L2 on a coarser grid (no MC).

    The eval grid is sub-sampled with the given strides, the posterior-mean
    network is evaluated there, and ||U_mean - U_ref|| / ||U_ref|| is returned
    as a float.
    """
    t_grid, x_grid, U_exact = data.eval_grid()
    U_ref = U_exact[::stride_t, ::stride_x]
    t_mesh, x_mesh = np.meshgrid(t_grid[::stride_t].numpy(), x_grid[::stride_x].numpy(), indexing='ij')
    points = np.column_stack([t_mesh.ravel(), x_mesh.ravel()])
    tx = torch.tensor(points, dtype=torch.get_default_dtype(), device=device)
    tx_scaled = scale_tx_to_unit(tx, cfg.t_min, cfg.t_max, cfg.x_min, cfg.x_max)
    with torch.no_grad():
        U_mean = modelB.forward_mean(tx_scaled).cpu().numpy().reshape(U_ref.shape)
    abs_err = np.linalg.norm(U_mean - U_ref)
    return float(abs_err / np.linalg.norm(U_ref))

def _posterior_sigma_stats(modelB: 'BayesFCNet'):
    stats = {}
    L = len(modelB.layers)
    for i, layer in enumerate(list(modelB.layers) + [modelB.out]):
        sw = layer.sigma_w.detach().cpu().numpy().ravel()
        sb = layer.sigma_b.detach().cpu().numpy().ravel()
        stats[f"sigma/mean_w_L{i}"] = float(np.mean(sw))
        stats[f"sigma/max_w_L{i}"]  = float(np.max(sw))
        stats[f"sigma/mean_b_L{i}"] = float(np.mean(sb))
        stats[f"sigma/max_b_L{i}"]  = float(np.max(sb))
    return stats






In [None]:
from typing import Optional, Callable

@torch.no_grad()
def calculate_certificate_batched(modelB, batch, cfg, total_mc, batch_size=64):
    """
    Monte-Carlo estimate of the PBB objective, evaluated in chunks of MC samples.

    ``pbb_objective`` is called on successive chunks of at most ``batch_size``
    samples. Each chunk's objective is weighted by its sample count, and the
    weighted sum is divided by ``total_mc``. Peak memory is therefore bounded by
    one chunk's autograd graph rather than all ``total_mc`` samples at once. This
    avoids OutOfMemoryError for the final high-sample certificate.

    Args:
        modelB: posterior model, passed through to ``pbb_objective``.
        batch: training sets, passed through to ``pbb_objective``.
        cfg: config, passed through to ``pbb_objective``.
        total_mc: total number of posterior samples. ``<= 0`` returns ``0.0``.
        batch_size: maximum number of samples per chunk. Must be ``>= 1``.

    Returns:
        float: sample-weighted mean of the per-chunk objectives.

    Raises:
        ValueError: if ``batch_size < 1`` while ``total_mc > 0``.
    """
    if total_mc <= 0:
        return 0.0
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    J_total = 0.0
    num_batches = (total_mc + batch_size - 1) // batch_size
    print(f"Calculating certificate with {total_mc} samples in {num_batches} batches of size {batch_size}...")

    for start in range(0, total_mc, batch_size):
        mc_samples_in_batch = min(batch_size, total_mc - start)

        # Gradients must be enabled here because pbb_objective computes the PDE
        # residual via autograd (this function is otherwise under no_grad).
        with torch.enable_grad():
            J_batch, _ = pbb_objective(modelB, batch, cfg, mc_samples=mc_samples_in_batch)

        # Keep only the scalar and drop the chunk's graph before the next chunk.
        J_total += J_batch.item() * mc_samples_in_batch
        del J_batch
        _gpu_hygiene()  # gc.collect() + torch.cuda.empty_cache()

    return J_total / total_mc


# --- REPLACE your old train_pacbayes with this version ---
def train_pacbayes(cfg: PINNConfig, data: BurgersData,
                   Nu: Optional[int]=None, Nf: Optional[int]=None,
                   verbose=True,
                   log_wandb: bool=False,
                   wb_project: str="pinn-burgers",
                   log_every: int = 100,
                   epochs: Optional[int] = None,
                   mem_check_fn: Optional[Callable] = None): # Optional: for memory debugging
    """Train a PAC-Bayes (PBB) PINN with SGD, keep the best tracked snapshot, and evaluate it.

    Each epoch takes one SGD step on ``pbb_objective`` using ``cfg.pbb_mc_train``
    posterior samples. The objective is re-estimated with ``cfg.cert_track_mc``
    samples every ``cfg.cert_track_every`` epochs, plus on the first and last
    epoch. The state dict with the lowest estimate is kept on CPU and restored
    after training. It is also saved to ``cfg.save_best_path`` when that is set.

    After training it computes:
      * ``certificate``: ``pbb_objective`` averaged over ``cfg.cert_mc`` samples,
        evaluated in chunks of ``cfg.cert_batch_size`` (default 64) to bound memory;
      * ``relL2_mean``: relative L2 of the posterior-mean network on the dense grid;
      * ``relL2_stoch``: mean relative L2 over ``min(cfg.relL2_mc, cfg.cert_mc)``
        sampled networks (NaN when no samples are drawn).

    Args:
        cfg: experiment config. Optional attributes are read with ``getattr``
            defaults: ``pbb_epochs`` (5000), ``cert_track_every`` (200),
            ``cert_track_mc`` (64), ``save_best_path`` (None), ``relL2_mc`` (64)
            and ``cert_batch_size`` (64).
        data: source of the training subsets and the dense evaluation grid.
        Nu, Nf: training-set sizes. Default to ``cfg.N_ic`` / ``cfg.N_f``.
        verbose: print progress lines.
        log_wandb: log metrics and images to Weights & Biases.
        wb_project: W&B project name.
        log_every: epoch interval for progress logging (values < 1 act as 1).
        epochs: number of SGD steps. Defaults to ``cfg.pbb_epochs``.
        mem_check_fn: optional callback ``fn(label)`` invoked around the certificate step.

    Returns:
        (modelB, metrics): the model (at the best tracked snapshot when one was
        recorded) and a dict with keys ``certificate``, ``relL2_mean``,
        ``relL2_stoch``, ``time``, ``certificate_track_best`` and ``best_epoch``.
    """
    # sizes
    if Nu is None: Nu = cfg.N_ic
    if Nf is None: Nf = cfg.N_f
    if epochs is None:
        epochs = getattr(cfg, "pbb_epochs", 5000)

    # Both intervals feed ``ep % interval`` below; clamp so 0/negatives cannot divide by zero.
    log_every = max(1, log_every)
    track_every = max(1, getattr(cfg, "cert_track_every", 200))
    track_mc = getattr(cfg, "cert_track_mc", 64)
    save_path = getattr(cfg, "save_best_path", None)
    relL2_mc = getattr(cfg, "relL2_mc", 64)
    cert_batch_size = getattr(cfg, "cert_batch_size", 64)  # MC samples per certificate chunk

    set_seed(cfg.seed)
    batch = data.sample_train_sets(Nu, Nf, seed=cfg.seed)

    modelB = BayesFCNet(width=cfg.hidden_width, depth=cfg.hidden_layers,
                        activation=cfg.activation, prior_sigma0=cfg.prior_sigma0).to(device)

    opt = torch.optim.SGD(modelB.parameters(), lr=cfg.pbb_lr, momentum=cfg.pbb_momentum)

    run = None
    if log_wandb:
        run = wandb_start(cfg, wb_project, group="pacbayes",
                          name=f"PBB_Nu{Nu}_Nf{Nf}_seed{cfg.seed}",
                          tags=[f"Nu={Nu}", f"Nf={Nf}", "pbb"])

    best_est_cert = float("inf")
    best_epoch = -1
    best_state = None

    t0 = time.time()
    for ep in range(1, epochs + 1):
        # One SGD step on the PBB objective.
        opt.zero_grad(set_to_none=True)
        J, parts = pbb_objective(modelB, batch, cfg, mc_samples=cfg.pbb_mc_train)
        J.backward()
        opt.step()

        if (ep % log_every == 0) or (ep == 1):
            if verbose:
                print(f"[PBB] ep={ep:4d} J={J.item():.3e} Lu={parts['Lu']:.3e} Lf={parts['Lf']:.3e} "
                      f"KL={parts['KL']:.3e} tu={parts.get('term_u',0):.3e} tf={parts.get('term_f',0):.3e} "
                      f"clip_u={parts.get('clip_frac_u',0):.2f} clip_f={parts.get('clip_frac_f',0):.2f}")
            if run:
                # Log J together with every component returned by pbb_objective.
                log_data = {"pbb/J": J.item(), "epoch": ep}
                for k, v in parts.items():
                    log_data[f"pbb/{k}"] = v
                wandb.log(log_data)

        if (ep % track_every == 0) or (ep == 1) or (ep == epochs):
            # pbb_objective needs autograd for the PDE residual, so grads stay enabled;
            # only the scalar is kept and the graph is released immediately.
            with torch.enable_grad():
                J_est, _ = pbb_objective(modelB, batch, cfg, mc_samples=track_mc)
            J_est_val = float(J_est.detach().cpu())
            del J_est; _gpu_hygiene()

            if run:
                wandb.log({"cert/track_estimate": J_est_val, "epoch": ep})
            if J_est_val < best_est_cert:
                best_est_cert = J_est_val
                best_epoch = ep
                # Snapshot on CPU so the best state does not hold extra GPU memory.
                state_gpu = modelB.state_dict()
                best_state = {k: v.detach().cpu().clone() for k, v in state_gpu.items()}
                del state_gpu; _gpu_hygiene()
                if run:
                    wandb.log({"cert/best_track_estimate": best_est_cert,
                               "cert/best_epoch": best_epoch,
                               "epoch": ep})

    if best_state is not None:
        modelB.load_state_dict(best_state)
        if save_path:
            torch.save(best_state, save_path)
            if run: wandb.save(save_path)

    # High-sample estimate, evaluated in chunks to avoid OOM.
    if mem_check_fn: mem_check_fn("PBB - Before Certificate")
    cert = calculate_certificate_batched(
        modelB, batch, cfg,
        total_mc=cfg.cert_mc,
        batch_size=cert_batch_size,
    )
    if mem_check_fn: mem_check_fn("PBB - After Certificate")

    t_grid, x_grid, U_exact = data.eval_grid()
    TT, XX = np.meshgrid(t_grid.numpy(), x_grid.numpy(), indexing='ij')
    tx = torch.tensor(np.stack([TT.ravel(), XX.ravel()], axis=1), dtype=torch.get_default_dtype(), device=device)

    with torch.no_grad():
        # The rescaled inputs are loop-invariant: reuse them for the mean and for every sample.
        tx_unit = scale_tx_to_unit(tx, cfg.t_min, cfg.t_max, cfg.x_min, cfg.x_max)
        U_mean = modelB.forward_mean(tx_unit).cpu().numpy().reshape(U_exact.shape)
    rel_mean = data.rel_l2(U_mean)

    rel_samples = []
    with torch.no_grad():
        for _ in range(min(relL2_mc, cfg.cert_mc)):
            weights, out_w = modelB.sample_weights()
            U_s = modelB.forward_with_weights(tx_unit, weights, out_w).cpu().numpy().reshape(U_exact.shape)
            rel_samples.append(data.rel_l2(U_s))
    rel_stoch = float(np.mean(rel_samples)) if rel_samples else float('nan')

    elapsed = time.time() - t0
    if verbose:
        print(f"[PBB Eval] cert={cert:.3e} relL2_mean={rel_mean:.3e} relL2_stoch={rel_stoch:.3e} time={elapsed:.1f}s")

    if run:
        imgs = _wandb_images_from_prediction(U_mean, U_exact, t_grid.numpy(), x_grid.numpy(), title_prefix="PBB mean")
        wandb.log({
            "cert/final": cert,
            "cert/best_track_estimate": best_est_cert,
            "cert/best_epoch": best_epoch,
            "eval/relL2_mean": rel_mean,
            "eval/relL2_stoch_avg": rel_stoch,
            "time/seconds": elapsed,
            "epoch": epochs,
            **imgs
        })
        run.finish()

    return modelB, {
        "certificate": cert,
        "relL2_mean": rel_mean,
        "relL2_stoch": rel_stoch,
        "time": elapsed,
        "certificate_track_best": best_est_cert,
        "best_epoch": best_epoch,
    }


In [None]:
# -----------------------
# Sweep utilities
# -----------------------
def sweep_pinn(cfg: PINNConfig, Nu_list: List[int], Nf_list: List[int]) -> pd.DataFrame:
    """Deterministic-PINN relative L2 error over a (Nu, Nf) grid.

    Each grid point trains ``train_pinn`` on a per-run copy of ``cfg``, with
    ``N_ic = Nu`` and ``N_f = Nf``. The original ``cfg`` is never mutated.

    Returns:
        pd.DataFrame: indexed by ``Nu``, with one column per ``Nf`` labelled
        ``str(Nf)``, each holding the rel-L2 returned by ``train_pinn``. An empty
        ``Nu_list`` yields an empty frame with the same columns.
    """
    from dataclasses import replace

    data = BurgersData(cfg)
    rows = []
    for Nu in Nu_list:
        row = {"Nu": Nu}
        for Nf in Nf_list:
            print(f"=== [PINN] Nu={Nu}, Nf={Nf} ===")
            # replace() copies only real dataclass init fields; cfg.__dict__ could also
            # carry ad-hoc attributes that PINNConfig.__init__ would reject.
            cfg_run = replace(cfg)
            cfg_run.N_ic = Nu
            cfg_run.N_f  = Nf
            _, rel, _meta = train_pinn(cfg_run, data, verbose=False)
            row[str(Nf)] = rel
            print(f"-> rel-L2={rel:.3e}")
        rows.append(row)
    # Explicit columns keep set_index("Nu") valid even when Nu_list is empty.
    columns = ["Nu", *dict.fromkeys(str(Nf) for Nf in Nf_list)]
    df = pd.DataFrame(rows, columns=columns).set_index("Nu")
    return df

def sweep_pacbayes(cfg: PINNConfig, Nu_list: List[int], Nf_list: List[int]) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """PAC-Bayes PINN certificate and posterior-mean rel-L2 over a (Nu, Nf) grid.

    Each grid point trains ``train_pacbayes`` on a per-run copy of ``cfg``, with
    ``N_ic = Nu`` and ``N_f = Nf``. The sizes are also passed explicitly as
    ``Nu``/``Nf``. The original ``cfg`` is never mutated.

    Returns:
        (df_cert, df_rel): both indexed by ``Nu`` with one column per ``Nf``
        (labelled ``str(Nf)``), holding ``res["certificate"]`` and
        ``res["relL2_mean"]`` respectively.
    """
    from dataclasses import replace

    data = BurgersData(cfg)
    rows_cert = []
    rows_rel = []
    for Nu in Nu_list:
        row_cert = {"Nu": Nu}
        row_rel  = {"Nu": Nu}
        for Nf in Nf_list:
            print(f"=== [PBB] Nu={Nu}, Nf={Nf} ===")
            # train_pacbayes reads cfg.N_ic / cfg.N_f (setting cfg.Nu / cfg.Nf had no
            # effect), so set the real fields on a fresh copy and pass the sizes explicitly.
            cfg_run = replace(cfg)
            cfg_run.N_ic = Nu
            cfg_run.N_f  = Nf
            _, res = train_pacbayes(cfg_run, data, Nu=Nu, Nf=Nf, verbose=False)
            row_cert[str(Nf)] = res["certificate"]
            row_rel[str(Nf)]  = res["relL2_mean"]  # you can also store relL2_stoch
            print(f"-> cert={res['certificate']:.3e} relL2_mean={res['relL2_mean']:.3e} relL2_stoch={res['relL2_stoch']:.3e}")
        rows_cert.append(row_cert)
        rows_rel.append(row_rel)
    # Explicit columns keep set_index("Nu") valid even when Nu_list is empty.
    columns = ["Nu", *dict.fromkeys(str(Nf) for Nf in Nf_list)]
    df_cert = pd.DataFrame(rows_cert, columns=columns).set_index("Nu")
    df_rel  = pd.DataFrame(rows_rel, columns=columns).set_index("Nu")
    return df_cert, df_rel

In [None]:
from dataclasses import asdict
# === Colab/Drive helpers & sweep-and-save ===
import os, gc, json, pandas as pd, torch, torch.nn as nn

def ensure_dir(p):
    """Create directory ``p`` (including missing parents) if absent, then return ``p``."""
    os.makedirs(name=p, exist_ok=True)
    return p

def export_posterior_mean_fcnet(pb_model: BayesFCNet, cfg: PINNConfig, path: str):
    """Save a vanilla FCNet whose weights equal the posterior means of ``pb_model``.

    The i-th ``nn.Linear`` of a freshly built CPU ``FCNet`` receives
    ``pb_model.layers[i].mean_params()``. Its last ``nn.Linear`` receives
    ``pb_model.out.mean_params()``. The resulting ``state_dict()`` is written to
    ``path`` with ``torch.save``.

    Raises:
        ValueError: if the FCNet built from ``cfg`` does not have exactly one
            linear layer per Bayesian layer (hidden layers + output), i.e. the two
            architectures do not correspond.
    """
    det = FCNet(width=cfg.hidden_width, depth=cfg.hidden_layers, activation=cfg.activation).to("cpu")
    lin = [m for m in det.net if isinstance(m, nn.Linear)]
    bayes_layers = list(pb_model.layers) + [pb_model.out]
    if len(lin) != len(bayes_layers):
        # A mismatch would otherwise leave some FCNet layers at random init (or fail on shapes).
        raise ValueError(
            f"FCNet has {len(lin)} linear layers but the posterior model has "
            f"{len(bayes_layers)} (hidden + output); architectures do not match."
        )
    with torch.no_grad():
        for det_layer, bayes_layer in zip(lin, bayes_layers):
            Wmu, bmu = bayes_layer.mean_params()
            det_layer.weight.copy_(Wmu.detach().cpu())
            det_layer.bias.copy_(bmu.detach().cpu())
    torch.save(det.state_dict(), path)

@torch.no_grad()
def _gpu_hygiene():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def run_grid_and_save(cfg: PINNConfig,
                      data: BurgersData,
                      Nu_list, Nf_list,
                      outdir_base: str,
                      log_wandb: bool = True):
    """
    Run the deterministic and PAC-Bayes PINNs over a (Nu, Nf) grid and save the results.

    For each (Nu, Nf), with tag ``Nu{Nu}_Nf{Nf}_seed{seed}``:
      - train the deterministic PINN baseline (``train_pinn``);
      - train the PAC-Bayes PINN (``train_pacbayes``), which restores its best tracked
        snapshot and saves it to ``<outdir>/checkpoints/pbb_best_state_<tag>.pt``;
      - export the posterior-mean weights as a plain FCNet state dict to
        ``<outdir>/posterior_mean_det/pbb_posterior_mean_det_<tag>.pt``;
      - write that run's metrics to ``metrics_<tag>.csv`` and ``metrics_<tag>.json``.

    ``<outdir>/sweep_summary_seed<seed>.csv`` is rewritten after every completed run.
    Finished runs therefore survive an interrupted sweep (e.g. a Colab disconnect),
    and the final file holds the whole grid.

    Returns:
        pd.DataFrame: one row per (Nu, Nf) with the recorded metrics and file paths.
    """
    # Prepare output folders
    outdir = ensure_dir(outdir_base.rstrip("/"))
    ckpt_dir = ensure_dir(os.path.join(outdir, "checkpoints"))
    mean_dir  = ensure_dir(os.path.join(outdir, "posterior_mean_det"))

    rows = []
    seed = cfg.seed
    sweep_csv = os.path.join(outdir, f"sweep_summary_seed{seed}.csv")

    for Nu in Nu_list:
        for Nf in Nf_list:
            print(f"\n=== RUN (Nu={Nu}, Nf={Nf}, seed={seed}) ===")
            # Fresh per-run config built from the real dataclass fields (the original is not mutated).
            cfg_run = PINNConfig(**asdict(cfg))
            cfg_run.N_ic, cfg_run.N_f = Nu, Nf

            # unique filenames
            tag = f"Nu{Nu}_Nf{Nf}_seed{seed}"
            cfg_run.save_best_path = os.path.join(ckpt_dir, f"pbb_best_state_{tag}.pt")
            mean_path = os.path.join(mean_dir, f"pbb_posterior_mean_det_{tag}.pt")
            metrics_path_csv = os.path.join(outdir, f"metrics_{tag}.csv")
            metrics_path_json = os.path.join(outdir, f"metrics_{tag}.json")

            # --- Deterministic baseline ---
            det_model, rel_det, meta_det = train_pinn(cfg_run, data, log_wandb=log_wandb, wb_project="pinn-burgers")
            time_det = meta_det.get("time", float("nan"))
            # free
            del det_model; _gpu_hygiene()

            # --- PAC-Bayes ---
            pb_model, out_pbb = train_pacbayes(cfg_run, data, log_wandb=log_wandb,
                                               wb_project="pinn-burgers",
                                               log_every=100, epochs=cfg_run.pbb_epochs)
            # export posterior-mean deterministic weights (for full reproducibility of relL2_mean)
            export_posterior_mean_fcnet(pb_model, cfg_run, mean_path)

            # gather metrics
            row = {
                "Nu": Nu, "Nf": Nf, "seed": seed,
                "relL2_det_baseline": rel_det,                    # deterministic PINN baseline
                "relL2_posterior_mean": out_pbb["relL2_mean"],    # deterministic w/ posterior mean @ best snapshot
                "relL2_probabilistic": out_pbb["relL2_stoch"],    # stochastic predictor @ best snapshot
                "certificate_final": out_pbb["certificate"],      # high-MC cert at best snapshot
                "certificate_track_best": out_pbb.get("certificate_track_best", float("nan")),
                "best_epoch": out_pbb.get("best_epoch", -1),
                "time_det_sec": time_det,
                "time_pbb_sec": out_pbb.get("time", float("nan")),
                "best_state_path": cfg_run.save_best_path,
                "posterior_mean_det_path": mean_path,
            }
            # per-run CSV/JSON (optional but handy)
            pd.DataFrame([row]).to_csv(metrics_path_csv, index=False)
            with open(metrics_path_json, "w") as f:
                # default=float lets numeric scalars that json cannot encode natively
                # (e.g. NumPy float32, if train_pinn returns one) serialize as plain floats.
                json.dump(row, f, indent=2, default=float)

            rows.append(row)
            # Checkpoint the summary so far so an interrupted sweep keeps completed runs.
            pd.DataFrame(rows).to_csv(sweep_csv, index=False)

            # free
            del pb_model; _gpu_hygiene()

    # --- Write a single tidy CSV for the whole sweep ---
    df = pd.DataFrame(rows)
    df.to_csv(sweep_csv, index=False)
    print(f"\nSaved sweep summary to: {sweep_csv}")
    return df


In [None]:
# Sweep grids: Nu = initial/boundary training points, Nf = PDE collocation points.
Nu_list = list((20, 50, 100, 200))
Nf_list = list(range(2000, 10001, 2000))  # 2000, 4000, 6000, 8000, 10000

In [None]:
cfg = PINNConfig(
    # Sized to the largest Nf in the sweep (presumably the pool that smaller Nf
    # subsets are drawn from; see BurgersData).
    Nf_master=max(Nf_list),

    # training knobs
    adam_epochs=1000, lbfgs_max_iter=200,
    pbb_epochs=1000, pbb_lr=1e-4, pbb_momentum=0.95,
    pbb_mc_train=1,

    # Final certificate: MC samples used once after training (in chunks), and also
    # an upper cap on the number of sampled networks for the stochastic rel-L2.
    cert_mc=512,

    # Tracking estimate during training, used to pick the best snapshot: evaluated
    # every `cert_track_every` epochs with `cert_track_mc` posterior samples.
    # Setting cert_track_mc = cert_mc gives "full" tracking at a higher cost;
    # a small value keeps tracking cheap but noisier.
    cert_track_every=200,
    cert_track_mc=16,
)
data = BurgersData(cfg)


[Data] downloading https://raw.githubusercontent.com/maziarraissi/PINNs/master/appendix/Data/burgers_shock.mat -> burgers_shock.mat


In [None]:
# Mount Google Drive (once per runtime) so checkpoints and CSVs outlive the Colab VM.
from google.colab import drive
drive.mount('/content/drive')

# Drive folder receiving checkpoints, per-run metrics and the sweep summary.
outdir = "/content/drive/MyDrive/pinn_pbb_results/burgers_clean"

# Run the (Nu, Nf) sweep over the grids defined earlier and preview the summary.
df_summary = run_grid_and_save(
    cfg, data, Nu_list, Nf_list,
    outdir_base=outdir,
    log_wandb=True,
)
df_summary.head()


MessageError: Error: credential propagation was unsuccessful