### What are we missing?

<div style="font-size:14px">

| **Aspect** | **MLE (Maximum Likelihood Estimation) 最大似然估计** | **MAP (Maximum A Posteriori Estimation) 最大后验估计** |
|------------|------------------------------------------|--------------------------------------------|
| **Objective** | Estimate parameter $\theta$ that maximizes the likelihood of the observed data. | Estimate parameter $\theta$ that maximizes the posterior probability given the data. |
| **Optimization Goal** | $\displaystyle \hat{\theta}_{\text{MLE}} = \arg\max_{\theta} P(D \mid \theta)$ | $\displaystyle \hat{\theta}_{\text{MAP}} = \arg\max_{\theta} P(\theta \mid D)$ |
| **Formula Derivation** | Maximize the likelihood: <br> $\displaystyle \mathcal{L}(\theta) = \prod_{i=1}^{n} P(x_i \mid \theta)$ <br> Take the log: <br> $\displaystyle \log \mathcal{L}(\theta) = \sum_{i=1}^{n} \log P(x_i \mid \theta)$ <br> Then: <br> $\displaystyle \hat{\theta}_{\text{MLE}} = \arg\max_{\theta} \log P(D \mid \theta)$ | Use Bayes’ Theorem: <br> $\displaystyle P(\theta \mid D) = \frac{P(D \mid \theta) P(\theta)}{P(D)}$ <br> Ignore constant $P(D)$: <br> $\displaystyle \hat{\theta}_{\text{MAP}} = \arg\max_{\theta} P(D \mid \theta) P(\theta)$ <br> or log-form: <br> $\displaystyle \hat{\theta}_{\text{MAP}} = \arg\max_{\theta} \left[ \log P(D \mid \theta) + \log P(\theta) \right]$ |
| **Includes Prior?** | ❌ No | ✅ Yes |
| **Sensitivity to Prior** | Not sensitive (no prior used) | Sensitive to prior choice |
| **Overfitting Risk** | Higher, especially for small data | Lower, prior acts as regularizer |
| **Asymptotic Behavior** | As $n \to \infty$, MLE is consistent | As $n \to \infty$, MAP $\to$ MLE |
| **Computational Complexity** | Lower (no prior term) | Higher (includes prior) |
| **Interpretation** | Frequentist — parameters are fixed | Bayesian — parameters are random variables |
| **Uniform Prior Case** | Coincides with MAP when the prior is uniform | Reduces to MLE if $P(\theta)$ is uniform |
| **Regularization View** | No regularization | Prior acts like regularization <br> Gaussian prior $\Rightarrow L_2$ <br> Laplace prior $\Rightarrow L_1$ |
| **Example: Gaussian Likelihood** | $x_i \sim \mathcal{N}(\mu, \sigma^2)$ <br> $\displaystyle \hat{\mu}_{\text{MLE}} = \frac{1}{n} \sum x_i$ | Prior: $\mu \sim \mathcal{N}(\mu_0, \tau^2)$ <br> $\displaystyle \hat{\mu}_{\text{MAP}} = \frac{n\sigma^{-2}}{n\sigma^{-2} + \tau^{-2}} \bar{x} + \frac{\tau^{-2}}{n\sigma^{-2} + \tau^{-2}} \mu_0$ |

</div>
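
The Gaussian row of the table can be checked numerically. The sketch below is a minimal illustration, assuming a known noise variance and made-up values for the prior (`mu0`, `tau2`) and the data; the helper names are hypothetical and not part of any library.

```python
import numpy as np

def gaussian_mean_mle(x):
    """MLE of the mean of N(mu, sigma^2): the sample average."""
    return float(np.mean(x))

def gaussian_mean_map(x, sigma2, mu0, tau2):
    """MAP of the mean under the prior mu ~ N(mu0, tau2), with known variance sigma2."""
    n = len(x)
    data_precision = n / sigma2      # n * sigma^{-2}
    prior_precision = 1.0 / tau2     # tau^{-2}
    weight = data_precision / (data_precision + prior_precision)
    return float(weight * np.mean(x) + (1.0 - weight) * mu0)

rng = np.random.default_rng(0)       # seed chosen only for reproducibility
x_small = rng.normal(loc=2.0, scale=1.0, size=5)
x_large = rng.normal(loc=2.0, scale=1.0, size=50_000)

mle_small = gaussian_mean_mle(x_small)
map_small = gaussian_mean_map(x_small, sigma2=1.0, mu0=0.0, tau2=0.5)
mle_large = gaussian_mean_mle(x_large)
map_large = gaussian_mean_map(x_large, sigma2=1.0, mu0=0.0, tau2=0.5)

# A very wide prior behaves like a uniform prior, so MAP approaches MLE.
map_flat_prior = gaussian_mean_map(x_small, sigma2=1.0, mu0=0.0, tau2=1e12)
```

With few observations the MAP estimate is pulled towards `mu0`; with many observations, or with a very wide prior, it approaches the MLE.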


<div style="font-size:14px">
<p>Modeling uncertainty is key to capturing sparse signals in low-SNR environments.<br>
Alternatives such as the voting/agreement rate in Random-Forest–like models are not good enough.<br>
Probabilistic programming models provide more principled uncertainty estimation (❌ though not perfect).</p>

$
\underbrace{P(w \mid D)}_{\text{posterior}} 
= \frac{
    \overbrace{P(D \mid w)}^{\text{likelihood}} \cdot 
    \overbrace{P(w)}^{\text{prior}}
}{
    \underbrace{P(D)}_{\text{evidence}}
}
$

<p>General Form of Prior P(w):</p>

$
P(w) := \mathbb{E}_{x, t, \theta, \varepsilon} \left[ P(w \mid x, t, \theta, \varepsilon) \right] = \int P(w \mid x, t, \theta, \varepsilon) \, P(x) P(t) P(\theta) P(\varepsilon) \, dx \, dt \, d\theta \, d\varepsilon \ \text{(weighted average)}\\
P(w \mid x, t, \theta, \varepsilon) = \mathcal{F}(x, t, \theta, \varepsilon) \approx \underbrace{P(w \mid \theta)}_{
\begin{array}{c}
    \text{usually in (Deep)ProbProg models (e.g. BNN)}\\
    \text{assume static, noise-free and feature-independent}
\end{array}
}
$

<p>Conditioning terms:</p>

- $x$: input/context — adapts prior to input
- $t$: time — allows temporal dynamics
- $\theta$: hyperparameters — controls prior structure
- $\varepsilon$: noise — models stochasticity


## Bayesian Neural Network (BNN) in Deep Probabilistic Programming as an Approximate Implementation of MAP Estimation

<p>Assuming:</p>

- Likelihood $P(D \mid w) = P(y_{1:N} \mid x_{1:N}, w) \overbrace{=}^{
    \begin{array}{c}
        \text{autoregressive}\\
        \text{decomposition}
    \end{array}
}\\
\prod_{i=1}^N P(y_i \mid y_{<i}, x_{\le i}, w) \overbrace{\approx}^{\text{model}}\\
\prod_{i=1}^N \mathcal{N}(y_i; f_w(y_{<i}, x_{\le i}), \sigma^2)\\
\Rightarrow \log P(D \mid w) = \sum_{i=1}^N \log P(y_i \mid y_{<i}, x_{\le i}, w) = -\frac{1}{2\sigma^2} \sum_{i=1}^N (y_i - f_w(y_{<i}, x_{\le i}))^2 + \text{const}$

- Prior $P(w) = \mathbb{E}_{x,t,\varepsilon}[P(w \mid x, t, \theta, \varepsilon)] \overbrace{\approx}^{\text{model}} P(w \mid \theta) = \mathcal{N}(w; 0, \tau^2 I)\\
\Rightarrow \log P(w) = -\frac{1}{2\tau^2} \|w\|^2 + \text{const}$


<p>The MAP objective becomes:</p>

$
w^* = \arg\max_w P(w \mid D) = \arg\max_w P(D \mid w) \cdot P(w) = \arg\max_w \left[ \log P(D \mid w) + \log P(w) \right]\\
= \arg\max_w \left( \sum_{i=1}^N \log P(y_i \mid y_{<i}, x_{\le i}, w) + \log P(w) \right)\\
= \arg\min_w \left( - \sum_{i=1}^N \log P(y_i \mid y_{<i}, x_{\le i}, w) - \log P(w) \right)
$

- Note that because **MLE** and **MAP** are formulated with argmax/argmin, they are **point estimates**; they can still serve as the starting point for approximating the full distribution (e.g. a Gaussian approximation around the MAP)

<p>with Gaussian assumption:</p>

$
= \arg\min_w \left(\frac{1}{2\sigma^2} \sum_{i=1}^N (y_i - f_w(y_{<i}, x_{\le i}))^2 + \frac{1}{2\tau^2} \|w\|^2 \right)
$

<p>with i.i.d. assumption:</p>

$
= \arg\min_w \left(\frac{1}{2\sigma^2} \sum_{i=1}^N (y_i - f_w(x_i))^2 + \frac{1}{2\tau^2} \|w\|^2 \right)
$

<p>Which is equivalent to:</p>

$
\text{Loss}(w) = \text{MSE loss} + \text{L2 regularization} = \sum_{i=1}^N \left( y_i - f_w(x_i) \right)^2 + \lambda \cdot \|w\|^2, \quad \text{with } \lambda = \frac{\sigma^2}{\tau^2}
$
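
The equivalence between the Gaussian MAP objective and an L2-regularized squared-error loss can be verified numerically. The sketch below assumes a linear model $f_w(x) = x^\top w$ (chosen only because its minimizer has a closed form) and illustrative values for $\sigma^2$ and $\tau^2$; none of these choices come from the text above.

```python
import numpy as np

rng = np.random.default_rng(1)
N, d = 40, 3
sigma2, tau2 = 0.25, 2.0             # illustrative noise and prior variances
lam = sigma2 / tau2                  # lambda = sigma^2 / tau^2

X = rng.normal(size=(N, d))
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true + rng.normal(scale=np.sqrt(sigma2), size=N)

def neg_log_posterior(w):
    """Negative log posterior up to a constant, for f_w(x) = x @ w."""
    resid = y - X @ w
    return resid @ resid / (2 * sigma2) + w @ w / (2 * tau2)

def ridge_loss(w):
    """Sum of squared errors plus an L2 penalty with lambda = sigma^2 / tau^2."""
    resid = y - X @ w
    return resid @ resid + lam * (w @ w)

# Closed-form minimizer of the ridge loss.
w_map = np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)

# Gradient of the negative log posterior at w_map (should vanish).
grad_at_map = -(X.T @ (y - X @ w_map)) / sigma2 + w_map / tau2

# The two objectives differ only by the constant factor 2 * sigma^2.
w_probe = rng.normal(size=d)
scale_ratio = ridge_loss(w_probe) / neg_log_posterior(w_probe)
```

Because the two objectives differ only by a positive constant factor, they share the same minimizer, which is why the ridge solution is also the MAP estimate in this setting.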

</div>

<div style="font-size:14px">

## All BNN Training Methods (forward → loss compute → backward → weight update):

| **Method**            | **Posterior Type**                           | **Inference Type**     | **Assumptions Made**                                                                                                                                          | **Uncertainty Quality** | **Scalability** | **Compute**    | **Packages**                           | **References**                                       | **Assumptions Handled By Model?**                                        | **Exact Posterior in Limit?**         | **Overfitting Risk**                                                         |
|----------------------|-----------------------------------------------|-------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|------------------|----------------|----------------------------------------|------------------------------------------------------|---------------------------------------------------------------------------|----------------------------------------|--------------------------------------------------------------------------------|
| **Bayes by Backprop** | Mean-field Gaussian                          | Variational              | Weights are independent; posterior is fully factorized Gaussian                                                                                               | Medium                   | High             | Low            | Pyro, Blitz-BNN, Bayesian-Torch        | Blundell et al., 2015                               | ❌ Posterior factorization not modeled explicitly                           | ❌ No                                 | **High** — limited posterior capacity encourages under-regularized solutions |
| **Flipout**           | Factorized Gaussian (decorrelated noise)     | Variational              | Weights are independent; Gaussian posterior; noise decorrelated across examples                                                                               | Medium+                  | High             | Low–Medium     | TensorFlow Probability                 | Wen et al., 2018                                  | ⚠️ Decorrelated noise reduces gradient variance, not structural assumptions  | ❌ No                                 | **Medium–High** — slightly improved over BBP                                 |
| **SGLD**              | Sampled Posterior                            | MCMC                     | Langevin dynamics without MH; uses minibatches; assumes step size ε→0; assumes stochastic gradient noise does not dominate Langevin noise    | High                     | High             | Medium         | PyTorch-Bayes, Emukit                  | Welling & Teh, 2011                                | ⚠️ Approximate posterior unless step size decays and noise is unbiased      | ⚠️ Only asymptotically (ε→0)           | **Medium** — implicit noise helps, but bias may hurt posterior accuracy     |
| **pSGLD**             | Preconditioned Posterior Samples             | MCMC                     | Same as SGLD; assumes curvature can be estimated online to scale gradients; assumes stability in adaptive noise statistics                                     | High+                    | High             | Medium+        | TFP (custom), Pyro (custom)            | Li et al., 2016                                    | ⚠️ Curvature modeled adaptively; still asymptotic correctness only           | ⚠️ Only asymptotically (ε→0)           | **Low–Medium** — better exploration helps avoid local minima                |
| **HMC**               | Exact Posterior                              | MCMC                     | No approximation; assumes smooth potential energy; full data gradients; no subsampling allowed                                                                | Very High                | Low              | Very High      | Stan, PyMC3, TF Probability            | Neal, 2011                                         | ✅ Fully nonparametric; assumptions hold for smooth models                  | ✅ Yes (only asymptotically on sample -> ∞)    | **Low** — proper posterior prevents overfitting                             |
| **Laplace Approx.**   | Gaussian around MAP                          | Deterministic            | Posterior is Gaussian near MAP; assumes curvature (Hessian) captures uncertainty                                                                              | Low–Medium               | High             | Low (Post-hoc) | LaplaceTorch, GPyTorch                 | MacKay, 1992                                       | ❌ Strong Gaussianity assumption near MAP                                   | ❌ No                                 | **High** — narrow posterior underestimates uncertainty                      |
| **Expectation Prop.** | Moment-matched Gaussian                      | Deterministic Approx.    | Approximates each likelihood term with Gaussian; moment-matching used to update posterior                                                                     | Medium–High              | Low              | High           | Edward1, GPy                           | Minka, 2001                                        | ❌ Still assumes Gaussian factors; better fit than mean-field                | ❌ No                                 | **Medium** — improved fit, but still approximate                            |
| **Functional BNN**    | Posterior over Functions (not weights)       | Hybrid (VI + GP-style)   | Prior and posterior over output functions; architecture defines function space distribution                                                                   | Very High                | Low              | Very High      | Neural Processes, GPJax, Functorch     | Garnelo et al., 2018; Rasmussen & Williams, 2006  | ✅ Function-space inference avoids parametric assumptions in weights         | ✅ Yes                                | **Low** — function-level regularization is strong                           |

</div>


- https://www.youtube.com/watch?v=LlzVlqVzeD8&list=PLHSMzCAQRltMGNQ9MxE7YBV87N0btrlUo&ab_channel=PyData
- https://www.youtube.com/watch?v=KhAUfqhLakw&list=PLHSMzCAQRltMGNQ9MxE7YBV87N0btrlUo&index=2&ab_channel=Enthought
- https://www.youtube.com/watch?v=i5PEMt21dO8&list=PLBjSxdPpAJGz-zSjO1Lpkc-0ibLTcz2o9&ab_channel=SMILES-SummerSchoolofMachineLearningatSK

| Feature / Library                | **Pyro**                  | **Blitz-Bayesian-PyTorch**  | **Bayesian-Torch**            | **PyMC**                       | **NumPyro**               | **TensorFlow Probability (TFP)**        |
| -------------------------------- | ------------------------- | --------------------------- | ----------------------------- | ------------------------------ | ------------------------- | --------------------------------------- |
| **Backend**                      | PyTorch                   | PyTorch                     | PyTorch                       | Aesara / JAX                   | JAX                       | TensorFlow                              |
| **Type**                         | Probabilistic Programming | Lightweight BNN             | Modular Bayesian Layers       | Probabilistic Programming      | Probabilistic Programming | Probabilistic Programming + Layers      |
| **Inference Methods**            | SVI, HMC, NUTS            | Variational Inference       | VI, MC Dropout                | NUTS, HMC, ADVI                | NUTS, HMC, SVI            | HMC, VI, EM                             |
| **BNN Support**                  | ✔️ Custom BNNs            | ✔️ Easy BNNs via decorators | ✔️ Deep BNNs & drop-in layers | ⚠️ Basic BNN support          | ⚠️ Some support (manual)  | ✔️ Keras BNN Layers                    |
| **Ease of Use**                  | Medium                    | Easy                        | Medium                        | Easy                           | Medium                    | Medium                                  |
| **Deep Learning Scale**          | ✔️ Yes                    | ✔️ Yes                     | ✔️ Yes                        | ❌ Limited                    | ⚠️ Limited GPU support    | ✔️ Yes (via TensorFlow)                |
| **GPU Acceleration**             | ✔️ Yes                    | ✔️ Yes                     | ✔️ Yes                        | ⚠️ Limited (JAX backend only) | ✔️ JAX (fast!)            | ✔️ TensorFlow                          |
| **Good for Probabilistic Logic** | ✔️ Yes                    | ❌                         | ❌                            | ✔️ Yes                        | ✔️ Yes                    | ✔️ Yes                                 |
| **Learning Curve**               | Steep                     | Low                         | Medium                        | Medium                         | Medium                    | Medium                                  |
| **Community & Maturity**         | Large (Uber, academic)    | Small                       | Medium                        | Large & mature                 | Growing fast (Google)     | Large (Google)                          |
| **Best Use Case**                | Custom probabilistic BNNs | Quick, practical BNNs       | Plug-and-play BNNs            | Statistical models, small BNNs | Fast HMC/VI for research  | Keras-style probabilistic deep learning |

<div style="font-size:14px">

---
### MCMC-pSGLD: (Markov Chain Monte Carlo - preconditioned Stochastic Gradient Langevin Dynamics)
#### Forward Pass:
- **Monte Carlo**: most models built on this Bayesian formulation are generative (they model an approximation of the real joint distribution), and this one is no exception,<br>
  so we cannot simply plug the mean of each parameter node into the network to compute the output.<br>
  A single prediction requires multiple forward passes, one per parameter sample drawn from the trained posterior $p(\theta \mid \mathcal{D}_{\text{train}})$,<br>
  using a random sampling method whose draws follow the true posterior distribution (high-dimensional, intractable, unnormalized)<br>

$$
\begin{aligned}
\underbrace{p(y_{\text{test}} \mid X_{\text{test}}, \mathcal{D}_{\text{train}})}_{\textbf{Bayesian Prediction}}
&= \int 
\underbrace{p(y_{\text{test}} \mid X_{\text{test}}, \theta)}_{\textbf{Likelihood (model output sample)}}
\cdot 
\underbrace{p(\theta \mid \mathcal{D}_{\text{train}})}_{\textbf{True Posterior}}
\, d\theta 
\\[1.2em]
&= \underbrace{\mathbb{E}_{\theta \sim p(\theta \mid \mathcal{D}_{\text{train}})} \left[ p(y_{\text{test}} \mid X_{\text{test}}, \theta) \right]}_{\textbf{Expectation over Posterior}}
\\[1.2em]
&\approx \underbrace{\frac{1}{T} \sum_{i=1}^T p(y_{\text{test}} \mid X_{\text{test}}, \theta^{(i)})}_{\textbf{Monte Carlo Estimate}}
\quad \text{where } \theta^{(i)} \sim p(\theta \mid \mathcal{D}_{\text{train}})
\end{aligned}
$$
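
The Monte Carlo estimate above can be illustrated on a case where the integral is known in closed form. The sketch below assumes a one-parameter Gaussian model, $y \sim \mathcal{N}(\theta, \sigma^2)$ with posterior $\theta \sim \mathcal{N}(m, s^2)$, so that the exact predictive is $\mathcal{N}(m, s^2 + \sigma^2)$; the numbers are made up, and the posterior draws stand in for samples that an MCMC method would produce.

```python
import numpy as np

def normal_pdf(y, mean, var):
    """Density of N(mean, var) evaluated at y (broadcasts over arrays)."""
    return np.exp(-0.5 * (y - mean) ** 2 / var) / np.sqrt(2 * np.pi * var)

rng = np.random.default_rng(2)
m, s2 = 1.0, 0.3       # assumed posterior N(m, s2) over theta
sigma2 = 0.5           # assumed observation noise variance
T = 200_000            # number of posterior samples

theta_samples = rng.normal(m, np.sqrt(s2), size=T)   # stand-in for MCMC draws

y_test = 1.7
# Monte Carlo estimate: average the likelihood over posterior samples.
mc_predictive = float(np.mean(normal_pdf(y_test, theta_samples, sigma2)))
# Closed-form predictive density for this conjugate model.
exact_predictive = float(normal_pdf(y_test, m, s2 + sigma2))
# Plug-in prediction that uses only the posterior mean of theta.
plug_in = float(normal_pdf(y_test, m, sigma2))
```

The plug-in value ignores parameter uncertainty and therefore differs from the exact predictive, while the Monte Carlo average converges to it as `T` grows.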

#### Loss Computation:
- Refer to previous section (without Gaussian/i.i.d. assumptions)

- The MAP point estimate already contains the loss definition: <br>
    $ w^* = \arg\min_w \left( - \sum_{i=1}^N \log P(y_i \mid y_{<i}, x_{\le i}, w) - \log P(w) \right) $

#### Backward Pass:
- we explore the weights through the full posterior (a joint distribution over all weights), not just the MAP point estimate
- instead of computing the gradient of a loss (negative log-likelihood), here we compute the gradient of the log-posterior
- the "training" (posterior sampling) happens after all data is available: an algorithm explores the parameter space (the state space of the Markov chain, which is also the support of the posterior) over many iterations so that the visited states approximate the joint posterior distribution. An iteration (a time step of the Markov chain in latent space) is not the same thing as a sample (a time step in the sequential data)
- **Markov Chain**: the true posterior $p(\theta \mid \mathcal{D}_{\text{train}})$ can be approximated by the stationary distribution of a discrete-time Markov chain as the number of steps goes to infinity (the state space of this Markov process is also the support of the true posterior, $\theta \in \mathbb{R}^d$)<br>
    if we assume:
    - **Ergodicity**: The chain forgets its starting point.
        - Ergodicity ⇐ Aperiodicity + Irreducibility (given that a stationary distribution exists)
            - **Aperiodicity**: No cyclic pattern in transitions
            - **Irreducibility**: Every state is reachable from every other state in finite steps
                - $\forall \theta, \theta', \exists t \in \mathbb{N} \text{ such that } K^t(\theta' \mid \theta) > 0$
    - **Time-homogeneity**: Transition probabilities $K$ are fixed over time
    - Target Invariance via **Detailed Balance**:
        - $p(\theta \mid \mathcal{D}) K(\theta' \mid \theta) = p(\theta' \mid \mathcal{D}) K(\theta \mid \theta') \quad \text{for all } \theta, \theta'$
        - this is a microscopic symmetry in parameter space: forward flow = backward flow
        - implies **Stationarity** (integrate both sides over $\theta$): once the chain is distributed according to the posterior, it stays that way
            - $p(\theta' \mid \mathcal{D}) = \int_\Theta K(\theta' \mid \theta) p(\theta \mid \mathcal{D}) \, d\theta \quad \text{for all } \theta'$
    - ✅ a kernel $K$ that satisfies all of the assumptions above:
        - can always be constructed (it exists), regardless of the training data $\mathcal{D}$
        - however, detailed balance and stationarity are stated with respect to $p(\theta \mid \mathcal{D})$, so $K$ itself depends on $\mathcal{D}$ and needs to be constructed carefully

- Let $\{\theta_t\}_{t=0}^\infty$ be a Markov chain
- $A$ be a measurable region of parameter space (e.g. i-th component of θ>0.5, accuracy(θ)>90%, etc.)

$$
\begin{aligned}
\lim_{t \to \infty} \mathbb{P}(\theta_t \in A)
&= \lim_{t \to \infty}
\underbrace{
\int_{\Theta} \cdots \int_{\Theta}
}_{t+1 \text{ nested integrals}} \;
\underbrace{K(\theta_t \mid \theta_{t-1})}_{\text{transition kernel}} \cdots K(\theta_1 \mid \theta_0)
\underbrace{\mu_0(\theta_0)}_{\text{initial distribution}} \,
\mathbf{1}_A(\theta_t)
\; d\theta_0 \cdots d\theta_t
\\[2ex]
&\quad \textcolor{gray}{\text{// Expand marginal probability of } \theta_t \in A \text{ via the full joint chain law: } \mu_0 \cdot K \cdots K}
\\[2ex]

&= \lim_{t \to \infty}
\int_A
\left(
\int_{\Theta} \cdots \int_{\Theta}
K(\theta_t \mid \theta_{t-1}) \cdots K(\theta_1 \mid \theta_0)
\mu_0(\theta_0)
\; d\theta_0 \cdots d\theta_{t-1}
\right)
d\theta_t
\\[2ex]
&\quad \textcolor{gray}{\text{// Pull indicator } \mathbf{1}_A(\theta_t) \text{ outside as domain of outermost integral becomes } A}
\\[2ex]

&=
\int_A
\left(
\lim_{t \to \infty}
(K^t \mu_0)(\theta_t)
\right)
d\theta_t
\\[2ex]
&\quad \textcolor{gray}{\text{// Recognize the nested integral as repeated application of the Markov operator: } K^t \mu_0}
\\[2ex]

&=
\int_A p(\theta \mid \mathcal{D}) \, d\theta
\\[2ex]
&\quad \textcolor{gray}{
\text{// By ergodic theorem: if } K \text{ is ergodic and satisfies detailed balance w.r.t. } p(\theta \mid \mathcal{D})
\Rightarrow \lim_{t \to \infty} K^t \mu_0 = p(\theta \mid \mathcal{D}) \text{ in distribution}
}
\\[2ex]
\end{aligned}
$$
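
The convergence argument above is easier to see on a finite state space, where the kernel is a matrix and the integrals become sums. The sketch below is a toy illustration with a made-up three-state target; it builds a Metropolis kernel with a uniform proposal (an assumption made only for this example) and checks detailed balance, stationarity, and convergence of $K^t \mu_0$.

```python
import numpy as np

pi = np.array([0.2, 0.5, 0.3])            # illustrative target distribution
n = len(pi)

# Metropolis kernel with a uniform proposal over the other states.
# K[i, j] = probability of moving from state i to state j.
K = np.zeros((n, n))
for i in range(n):
    for j in range(n):
        if i != j:
            K[i, j] = (1.0 / (n - 1)) * min(1.0, pi[j] / pi[i])
    K[i, i] = 1.0 - K[i].sum()            # rejected proposals stay put

# Detailed balance: pi_i * K(j | i) == pi_j * K(i | j) for all i, j.
flow = pi[:, None] * K
detailed_balance_gap = np.abs(flow - flow.T).max()

# Stationarity: applying the kernel to pi returns pi.
stationarity_gap = np.abs(pi @ K - pi).max()

# Convergence: start with all mass on state 0 and apply the kernel 50 times.
mu0 = np.array([1.0, 0.0, 0.0])
mu_t = mu0 @ np.linalg.matrix_power(K, 50)
convergence_gap = np.abs(mu_t - pi).max()
```

All three gaps are at the level of floating-point error here, even though the chain starts from a distribution that is far from the target.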

$$
\begin{aligned}
p(\theta \mid \mathcal{D}) 
&= \frac{p(\mathcal{D}, \theta)}{p(\mathcal{D})}
= \frac{p(\mathcal{D} \mid \theta)\, p(\theta)}{\int_{\mathbb{R}^d} p(\mathcal{D} \mid \vartheta)\, p(\vartheta)\, d\vartheta}
= \frac{1}{\int_{\mathbb{R}^d} p(\mathcal{D} \mid \vartheta)\, p(\vartheta)\, d\vartheta} \cdot p(\mathcal{D} \mid \theta)\, p(\theta) \\[10pt]

&= \frac{1}{Z} \cdot p(\mathcal{D} \mid \theta)\, p(\theta)
= \frac{1}{Z} \cdot \exp\left( \log p(\mathcal{D} \mid \theta) + \log p(\theta) \right)
= \frac{1}{Z} \cdot \exp\left( -[-\log p(\mathcal{D} \mid \theta) - \log p(\theta)] \right) \\[10pt]

&= \underbrace{\frac{1}{Z} \cdot \exp\left( -U(\theta) \right)}_{\text{Gibbs (Boltzmann) form}}
\qquad \text{where:} \quad
\begin{cases}
\text{Potential Energy}: U(\theta) := -\log p(\mathcal{D} \mid \theta) - \log p(\theta) \quad \text{(non-negative only when both factors are} \le 1 \text{; densities may exceed 1)}\\[4pt]
\begin{array}{c}
    \text{Partition Function}\\
    \text{Normalization Constant}
\end{array}
: Z := \int_{\mathbb{R}^d} \exp(-U(\vartheta))\, d\vartheta
\end{cases}
\end{aligned}
$$

- Many models in physics (statistical mechanics, thermodynamics, quantum dynamics) relate a potential energy field to a probability distribution
- the posterior distribution can likewise be written as a potential-energy-based particle-diffusion model in latent space
    - this is only to build intuition about the distribution; the physical analogy is not required
    - we need to find a Markov chain that:
        - has a transition kernel for which this Gibbs distribution is a stationary solution
        - satisfies the previous assumptions
    - some properties that we want or have observed:
        - the lower the energy, the higher the probability (posterior in latent space)
        - the shape of the posterior in latent space is complex (multiple modes/peaks)
        - the kernel needs to work as a compass that guides transitions towards the nearest (depending on the type of space) local maximum of probability, i.e. local minimum of potential energy
            - only then does the Markov chain stay longer in the more probable regions and form the correct distribution
        - it also needs a random (stochastic, diffusive) component to help explore the whole latent space, which helps guarantee some of the previous assumptions such as irreducibility

$$
\begin{aligned}
&\underbrace{d\theta_t = -\nabla U(\theta_t)\,dt + \sqrt{2}\,dW_t}_{\substack{\text{Over-Damped Langevin dynamics SDE:} \\ \text{gradient drift + Gaussian noise}}}
\\[1.2em]
&\Rightarrow 
\underbrace{
\frac{\partial \rho(\theta, t)}{\partial t} = \nabla \cdot \left( \nabla U(\theta)\, \rho(\theta, t) + \nabla \rho(\theta, t) \right)
}_{\substack{\text{Fokker–Planck equation:} \\ \text{evolution of density}}}
\\[1.5em]
&\Rightarrow 
\nabla \cdot \left( \nabla U(\theta)\, \rho(\theta) + \nabla \rho(\theta) \right)
= \nabla \cdot \left( \nabla U(\theta) \cdot \tfrac{1}{Z} e^{-U(\theta)} + \nabla \left( \tfrac{1}{Z} e^{-U(\theta)} \right) \right)
= \nabla \cdot \left( \tfrac{1}{Z} e^{-U(\theta)} \nabla U(\theta) - \tfrac{1}{Z} e^{-U(\theta)} \nabla U(\theta) \right)
= \nabla \cdot (0) = 0
\\[1.5em]
&\Leftrightarrow 
\rho(\theta) = \tfrac{1}{Z} \exp(-U(\theta)) \;\text{is stationary under Langevin dynamics}
\end{aligned}
$$
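
The stationarity argument above says that the flux $\nabla U \, \rho + \nabla \rho$ vanishes identically when $\rho \propto e^{-U}$. The sketch below checks this on a one-dimensional grid, assuming an illustrative double-well potential and using finite differences for $\nabla \rho$; it is a numerical sanity check rather than a proof.

```python
import numpy as np

theta = np.linspace(-3.0, 3.0, 2001)
h = theta[1] - theta[0]

U = theta**4 / 4 - theta**2 / 2            # illustrative double-well potential
grad_U = theta**3 - theta                  # analytic gradient of U

unnormalized = np.exp(-U)
Z = np.sum(unnormalized) * h               # Riemann-sum estimate of the partition function
rho = unnormalized / Z                     # candidate stationary density

grad_rho = np.gradient(rho, h)             # finite-difference derivative of rho
flux = grad_U * rho + grad_rho             # term inside the divergence
max_flux = np.abs(flux).max()
max_drift_term = np.abs(grad_U * rho).max()
```

The drift term alone is far from zero, but it is cancelled by the diffusion term up to discretization error, so the density does not change over time.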

- Time-homogeneous: $U(\theta)$ is fixed for a given posterior
- Ergodicity (read off the Fokker–Planck equation):
    - covers the full support $\theta \in \mathbb{R}^d$
    - each state is reachable due to diffusion (Brownian term)
    - no periodicity due to stochasticity
- Detailed Balance (Microscopic Reversibility):
    - The Fokker–Planck operator is self-adjoint in the weighted space $L^2(p^*)$
    - The generator of Langevin dynamics is reversible with respect to $p^*(\theta)$

- Alternatively:
    - instead of over-damped Langevin: Potential + Diffusion (noise), where friction is high enough to remove momentum
    - we can have:
        - Under-damped Langevin: Potential + Diffusion + Kinetic (momentum)
        - Hamiltonian dynamics (deterministic, no noise term): Potential + Kinetic (energy perfectly conserved)
        - Noisy Hamiltonian SDE: Potential + Kinetic + Diffusion (energy not perfectly conserved)
    - Kinetic term => energy is better preserved => inertial exploration => better long-range exploration

$$
H(\theta, p) = \underbrace{-\log p(\theta \mid \mathcal{D})}_{\text{Potential } U(\theta)} + \underbrace{\frac{1}{2} p^T M^{-1} p}_{\text{Kinetic } K(p)}
$$


#### Weight Update:

$$
\begin{aligned}
&\underbrace{d\theta_t = -\nabla_\theta U(\theta_t)\,dt + \sqrt{2}\,dW_t}_{\text{Langevin SDE (Itô)}} 
\Rightarrow
\underbrace{\theta_{t+1} = \theta_t - \epsilon \nabla_\theta U(\theta_t) + \sqrt{2\epsilon} \, \xi_t}_{\text{Euler–Maruyama discretization} \quad \xi_t \sim \mathcal{N}(0, I)} 
\Rightarrow
\underbrace{q(\theta'|\theta_t) = \mathcal{N}\left(\theta' \mid \theta_t - \epsilon \nabla_\theta U(\theta_t), 2\epsilon I\right)}_{\text{Proposal distribution}} 
\\
&\Rightarrow 
\underbrace{
\alpha(\theta_t, \theta') = \min\left(1, 
\frac{
e^{-U(\theta')} \cdot 
\exp\left(-\frac{1}{4\epsilon} \|\theta_t - \theta' + \epsilon \nabla_\theta U(\theta')\|^2 \right)
}{
e^{-U(\theta_t)} \cdot 
\exp\left(-\frac{1}{4\epsilon} \|\theta' - \theta_t + \epsilon \nabla_\theta U(\theta_t)\|^2 \right)
}
\right)}_{\text{Metropolis–Hastings acceptance prob. (Detailed Balance)}} 
\\
&\Rightarrow 
\underbrace{
\theta_{t+1} =
\begin{cases}
\theta', & \text{with probability } \alpha(\theta_t, \theta') \\
\theta_t, & \text{otherwise}
\end{cases}
}_{\text{MALA (Metropolis-Adjusted Langevin Algorithm)}} 
= 
\underbrace{
\theta_{t+1} =
\begin{cases}
\theta_t - \epsilon \nabla_\theta U(\theta_t) + \sqrt{2\epsilon}\,\xi_t, & \text{if accepted} \\
\theta_t, & \text{otherwise}
\end{cases}
}_{\text{Langevin step with MH correction: samples from } \pi(\theta) \propto e^{-U(\theta)}}
\end{aligned}
$$
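
The proposal and acceptance rule above can be sketched in a few lines. The example below assumes a one-dimensional standard normal target, $U(\theta) = \theta^2/2$, chosen only because its mean and variance are known; the step size, chain length, and burn-in are illustrative.

```python
import numpy as np

def U(theta):
    return 0.5 * theta**2                  # illustrative potential: standard normal target

def grad_U(theta):
    return theta

def log_q(theta_to, theta_from, eps):
    """Log density (up to a constant) of the Langevin proposal q(theta_to | theta_from)."""
    mean = theta_from - eps * grad_U(theta_from)
    return -((theta_to - mean) ** 2) / (4.0 * eps)

def mala(n_steps, eps, theta0, rng):
    """Metropolis-adjusted Langevin sampler for the target exp(-U)."""
    theta = theta0
    samples = np.empty(n_steps)
    accepted = 0
    for t in range(n_steps):
        # Euler-Maruyama step used as the proposal.
        proposal = theta - eps * grad_U(theta) + np.sqrt(2.0 * eps) * rng.normal()
        # Log of the Metropolis-Hastings ratio pi(theta') q(theta | theta') / (pi(theta) q(theta' | theta)).
        log_alpha = (-U(proposal) + log_q(theta, proposal, eps)) - (
            -U(theta) + log_q(proposal, theta, eps)
        )
        if np.log(rng.uniform()) < log_alpha:
            theta = proposal
            accepted += 1
        samples[t] = theta
    return samples, accepted / n_steps

rng = np.random.default_rng(3)
samples, accept_rate = mala(n_steps=20_000, eps=0.5, theta0=3.0, rng=rng)
kept = samples[2_000:]                     # discard burn-in
```

After burn-in, the empirical mean and variance of `kept` should be close to 0 and 1, the moments of the target.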

- in pSGLD:
    - we use mini-batch gradients and a preconditioner, and we skip the Metropolis-Hastings correction that would otherwise remove the bias introduced by the Euler-Maruyama discretization
        - as long as the step size $\epsilon_t \to 0$ slowly and the preconditioner stabilizes, the sampling bias from skipping MH can be minimized (Li et al., 2016)
    - pSGLD tends to give better uncertainty estimates than SGLD because it incorporates local curvature information (via preconditioning)
        - both the gradient step and the injected noise are scaled by the preconditioner, whereas standard SGLD uses isotropic noise

$$
\begin{aligned}
\theta_{t+1} &= \theta_t - \epsilon \left\{ \underbrace{ - \sum_{i=1}^N \nabla_\theta \log p(y_i|x_i, \theta_t) - \nabla_\theta \log p(\theta_t) }_{ \nabla_\theta U(\theta_t) } \right\} + \sqrt{2\epsilon} \, \xi_t \\
&\xRightarrow{\text{minibatch approx.}} \theta_{t+1} = \theta_t - \epsilon \left\{ \underbrace{ - \frac{N}{|\mathcal{B}_t|} \sum_{i \in \mathcal{B}_t} \nabla_\theta \log p(y_i|x_i, \theta_t) - \nabla_\theta \log p(\theta_t) }_{ \hat{\nabla}_\theta U(\theta_t) \quad \text{stochastic gradient estimate from mini-batch}} \right\} + \sqrt{2\epsilon} \, \xi_t \\
&\xRightarrow{\text{preconditioned}} \theta_{t+1} = \theta_t - \epsilon \underbrace{G(\theta_t)}_{\text{diagonal preconditioning matrix (RMSprop-style)}} \hat{\nabla}_\theta U(\theta_t) + \underbrace{\mathcal{N}(0, 2\epsilon G(\theta_t))}_{\text{noise}} \\
&\Rightarrow \theta_{t+1} = \theta_t - \epsilon\, G(\theta_t) \hat{\nabla}_\theta U(\theta_t) + \eta_t, \quad \eta_t \sim \mathcal{N}(0, 2\epsilon G(\theta_t)) \quad \text{(with } \epsilon \to \epsilon/2 \text{ this reads } -\tfrac{\epsilon}{2} G \hat{\nabla}_\theta U + \mathcal{N}(0, \epsilon G) \text{)}
\end{aligned}
$$
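
A minimal sketch of the preconditioned update, assuming a small Bayesian linear regression so that the exact posterior mean is available for comparison. It uses the convention in which the drift is $\epsilon G \hat{\nabla}_\theta U$ and the noise covariance is $2\epsilon G$ (halving $\epsilon$ gives the $\epsilon/2$, $\epsilon G$ form). The RMSprop-style preconditioner $G = 1 / (\lambda + \sqrt{V})$, the hyperparameters `alpha` and `lam`, the fixed step size, and the omission of any curvature-correction term are all assumptions of this example.

```python
import numpy as np

rng = np.random.default_rng(4)
N, d, B = 200, 2, 32
sigma2, tau2 = 1.0, 10.0                      # illustrative noise and prior variances
X = rng.normal(size=(N, d))
y = X @ np.array([1.0, -2.0]) + rng.normal(scale=np.sqrt(sigma2), size=N)

def stochastic_grad_U(theta, idx):
    """Mini-batch estimate of grad U, plus the mean per-example likelihood gradient."""
    resid = y[idx] - X[idx] @ theta
    mean_grad_loglik = X[idx].T @ resid / (sigma2 * len(idx))
    grad_log_prior = -theta / tau2
    grad_U_hat = -(N * mean_grad_loglik) - grad_log_prior
    return grad_U_hat, mean_grad_loglik

def psgld(n_steps, eps, alpha=0.99, lam=1e-5):
    theta = np.zeros(d)
    V = np.zeros(d)                           # running average of squared gradients
    samples = np.empty((n_steps, d))
    for t in range(n_steps):
        idx = rng.choice(N, size=B, replace=False)
        grad_U_hat, g = stochastic_grad_U(theta, idx)
        V = alpha * V + (1.0 - alpha) * g * g
        G = 1.0 / (lam + np.sqrt(V))          # diagonal preconditioner
        noise = np.sqrt(2.0 * eps * G) * rng.normal(size=d)
        theta = theta - eps * G * grad_U_hat + noise
        samples[t] = theta
    return samples

samples = psgld(n_steps=6_000, eps=1e-4)
posterior_mean_est = samples[2_000:].mean(axis=0)   # discard burn-in

# Exact posterior mean of this conjugate model, for comparison.
exact_mean = np.linalg.solve(X.T @ X + (sigma2 / tau2) * np.eye(d), X.T @ y)
```

With a fixed step size the samples carry some discretization and mini-batch bias, so the comparison with `exact_mean` is only expected to hold approximately.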


</div>


<div style="font-size:14px">

---
### VI-Flipout: (Variational Inference - Flipout)

<span style="color:yellow">
Warning:

- in Variational Inference (VI), each neural-network parameter is forced to live, marginally, inside a fixed, pre-chosen distributional family (e.g. Gaussians), which means:
    - No multi-modal behavior
    - No skewness
    - No heavy tails
    - No nonlinear dependencies between weights
    - so the approximation misses the bare minimum needed to model the true posterior (high-dimensional, highly complex, unnormalized, intractable)

- This has far-reaching consequences beyond posterior approximation, because it leaks into the model’s functional behavior:
    - by imposing strong geometric constraints on the latent space, it may group together very different functions (far apart in function/feature space)
        - this is a serious problem when the model is sampled randomly as a generative model (unnatural in-between regions)
    - in some scenarios, the constraint can act as a regularizer that improves generalization, even though it is fundamentally incorrect

- MCMC preserves probability (the true posterior) but, like VI, does not preserve geometry (local latent space <-> local feature space):
    - interpolation: it also cannot ensure that nearby points in latent space correspond to similar outputs in data/feature space
    - geometry-aware models include:
        - Autoencoder-based models
        - Normalizing Flows
        - Diffusion Models
        - Energy-Based Models + Score-Based Learning
        - Neural ODEs & Continuous Normalizing Flows
        - Metric Learning / Contrastive Learning

- even though there is no mathematical guarantee, in practice a basic NN architecture still offers some protection: features remain mostly "continuous" in latent space, although with high variance:
    - Gradient descent (or pSGLD) keeps weight updates small, most of the time
    - you stay in a "smooth" part of the network function space, most of the time
    - smooth activations reduce the chance of sharp transitions
    - gradient-based optimization encourages local stability
    - **DO NOT take these for granted**

</span> 

Let a **fully connected layer** have weight matrix $\mathbf{W} \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}$. In Flipout:

* Variational posterior:

  $$
  q(\mathbf{W}) = \mathcal{N}(\mu, \sigma^2)
  \quad \text{with reparam:} \quad
  \mathbf{W} = \mu + \sigma \odot \epsilon
  \quad \epsilon \sim \mathcal{N}(0, I)
  $$

* Flipout perturbs each weight **per sample** using:

  $$
  \mathbf{W}^{(i)} = \mu + (\sigma \odot \epsilon) \odot \left( s^{(i)} r^{(i)\top} \right)
  $$

  where:

  * $r^{(i)} \in \{+1, -1\}^{d_{\text{in}}}$, $s^{(i)} \in \{+1, -1\}^{d_{\text{out}}}$
  * the entries of $r^{(i)}$ and $s^{(i)}$ are i.i.d. Rademacher random variables ($+1$ or $-1$, each with probability $0.5$)
  * Note: $\epsilon \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}$ shared across batch

#### FORWARD PASS (with Flipout)

Let the input batch be $\mathbf{X} \in \mathbb{R}^{B \times d_{\text{in}}}$, where $B$ is the batch size. For each sample $i \in \{1, \dots, B\}$:

1. **Shared weight perturbation**:

   $$
   \Delta \mathbf{W} = \sigma \odot \epsilon
   $$

2. **Pseudo-independent perturbed weights (via Flipout)**:

   $$
   \Delta \mathbf{W}^{(i)} = \Delta \mathbf{W} \odot \left( s^{(i)} r^{(i)\top} \right)
   $$

3. **Compute output** (for each example $i$):

   $$
   \mathbf{y}^{(i)} = \mathbf{x}^{(i)} \cdot \mu^\top + \left( \left( \mathbf{x}^{(i)} \odot r^{(i)} \right) \cdot \Delta \mathbf{W}^\top \right) \odot s^{(i)}
   $$

4. **Batched matrix form** (Flipout trick):

   $$
   \mathbf{Y} = \mathbf{X} \mu^\top + \left( (\mathbf{X} \odot \mathbf{R}) \cdot \Delta \mathbf{W}^\top \right) \odot \mathbf{S}
   $$

   where:

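   * $\mathbf{R} \in \{+1, -1\}^{B \times d_{\text{in}}}$ and $\mathbf{S} \in \{+1, -1\}^{B \times d_{\text{out}}}$ are matrices of random signs whose $i$-th rows are $r^{(i)}$ and $s^{(i)}$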
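
The batched form can be checked against the per-example definition with a few lines of NumPy. The sketch below uses small made-up shapes and random values; it is not tied to any particular library's Flipout layer.

```python
import numpy as np

rng = np.random.default_rng(5)
B, d_in, d_out = 4, 3, 2

mu = rng.normal(size=(d_out, d_in))                 # variational means
sigma = np.abs(rng.normal(size=(d_out, d_in)))      # variational standard deviations
eps = rng.normal(size=(d_out, d_in))                # one noise draw shared by the batch
delta_W = sigma * eps

X = rng.normal(size=(B, d_in))
R = rng.choice([-1.0, 1.0], size=(B, d_in))         # row i holds r^(i)
S = rng.choice([-1.0, 1.0], size=(B, d_out))        # row i holds s^(i)

# Batched Flipout: two matrix products, no per-example weight matrices.
Y_batched = X @ mu.T + ((X * R) @ delta_W.T) * S

# Reference: build each W^(i) explicitly and apply it to x^(i).
Y_explicit = np.stack([
    (mu + delta_W * np.outer(S[i], R[i])) @ X[i]
    for i in range(B)
])
```

Both computations give the same outputs, but the batched form never materializes the $B$ separate weight matrices, which is what makes per-example perturbations affordable.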

#### LOSS FUNCTION: Evidence Lower Bound (ELBO)

We optimize the **negative ELBO**:

$$
\mathcal{L}_{\text{VI}}(\theta) = -\mathbb{E}_{q(\mathbf{W})} \left[ \log p(\mathcal{D} \mid \mathbf{W}) \right] + \text{KL}\left[ q(\mathbf{W}) \,\|\, p(\mathbf{W}) \right]
$$

* First term (expected log-likelihood) approximated via Monte Carlo (Flipout samples).
* Second term (KL divergence) is analytical for Gaussian:

If:

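* Prior: $p(\mathbf{W}) = \mathcal{N}(0, \sigma_0^2 I)$
* Posterior: $q(\mathbf{W}) = \mathcal{N}(\mu, \sigma^2)$, factorized over the individual weights

Then, with $j$ indexing the $d = d_{\text{out}} d_{\text{in}}$ weights: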

$$
\text{KL}(q \| p) = \sum_{j=1}^d \left[ \log \frac{\sigma_0}{\sigma_j} + \frac{\sigma_j^2 + \mu_j^2}{2\sigma_0^2} - \frac{1}{2} \right]
$$
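The closed form can be checked numerically. The sketch below implements the sum above and compares it with a Monte Carlo estimate of $\mathbb{E}_q[\log q - \log p]$; the parameter values are arbitrary assumptions, and `kl_diag_gaussian` is a hypothetical helper name.

```python
import numpy as np

def kl_diag_gaussian(mu, sigma, sigma0):
    """KL( N(mu, diag(sigma^2)) || N(0, sigma0^2 I) ), summed over weights."""
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    return np.sum(np.log(sigma0 / sigma)
                  + (sigma**2 + mu**2) / (2.0 * sigma0**2)
                  - 0.5)

# Arbitrary example values (assumed for illustration).
mu = np.array([0.5, -1.0])
sigma = np.array([0.3, 2.0])
sigma0 = 1.5

# Monte Carlo estimate of E_q[log q(w) - log p(w)].
rng = np.random.default_rng(0)
w = mu + sigma * rng.normal(size=(200_000, 2))
log_q = np.sum(-0.5 * np.log(2 * np.pi * sigma**2)
               - (w - mu)**2 / (2 * sigma**2), axis=1)
log_p = np.sum(-0.5 * np.log(2 * np.pi * sigma0**2)
               - w**2 / (2 * sigma0**2), axis=1)

kl_closed = kl_diag_gaussian(mu, sigma, sigma0)
kl_mc = np.mean(log_q - log_p)
print(kl_closed, kl_mc)
```

When $q$ equals the prior ($\mu = 0$, $\sigma = \sigma_0$), every term in the sum vanishes and the KL is zero.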

#### BACKWARD PASS

We compute gradients of:

$$
\mathcal{L} = - \sum_{i=1}^B \log p(y^{(i)} \mid \mathbf{x}^{(i)}, \mathbf{W}^{(i)}) + \text{KL}(q \| p)
$$

Backprop proceeds as:

* Gradients flow through Flipout layers via the reparameterization trick.
* The key is to ensure that gradients w\.r.t. $\mu$ and $\sigma$ are **unbiased** estimates of the gradient of the ELBO.

Let’s derive:

- Gradient w\.r.t. μ

$$
\nabla_\mu \mathcal{L} \approx -\sum_{i=1}^B \nabla_{\mathbf{W}^{(i)}} \log p(y^{(i)} \mid \mathbf{x}^{(i)}, \mathbf{W}^{(i)}) \cdot \nabla_\mu \mathbf{W}^{(i)} + \nabla_\mu \text{KL}
$$

But:

$$
\nabla_\mu \mathbf{W}^{(i)} = I
\quad \Rightarrow \quad
\nabla_\mu \mathcal{L} \approx -\sum_{i=1}^B \nabla_{\mathbf{W}^{(i)}} \log p(y^{(i)} \mid \mathbf{x}^{(i)}, \mathbf{W}^{(i)}) + \nabla_\mu \text{KL}
$$

- Gradient w\.r.t. σ

$$
\nabla_\sigma \mathbf{W}^{(i)} = \epsilon \odot (s^{(i)} r^{(i)\top})
\quad \Rightarrow \quad
\nabla_\sigma \mathcal{L} \approx -\sum_{i=1}^B \nabla_{\mathbf{W}^{(i)}} \log p(y^{(i)} \mid \mathbf{x}^{(i)}, \mathbf{W}^{(i)}) \odot \left( \epsilon \odot (s^{(i)} r^{(i)\top}) \right) + \nabla_\sigma \text{KL}
$$

- KL Gradient Terms

Analytical:

$$
\frac{\partial}{\partial \mu_j} \text{KL} = \frac{\mu_j}{\sigma_0^2}, \quad
\frac{\partial}{\partial \sigma_j} \text{KL} = -\frac{1}{\sigma_j} + \frac{\sigma_j}{\sigma_0^2}
$$
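A finite-difference check is one way to confirm these gradients. The sketch below assumes a unit-variance Gaussian likelihood, so that $-\log p(y \mid \mathbf{x}, \mathbf{W}) = \tfrac{1}{2}\|\mathbf{W}\mathbf{x} - y\|^2$ up to a constant, and differentiates the loss $\mathcal{L} = -\sum_i \log p + \text{KL}$ for one fixed draw of $\epsilon$ and the signs. All sizes and data are toy assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
B, d_in, d_out, sigma0 = 4, 3, 2, 1.0  # toy sizes and prior std (assumed)
X = rng.normal(size=(B, d_in))
Y = rng.normal(size=(B, d_out))
eps = rng.normal(size=(d_out, d_in))               # shared across the batch
R = rng.integers(0, 2, size=(B, d_in)) * 2 - 1     # input-side signs
S = rng.integers(0, 2, size=(B, d_out)) * 2 - 1    # output-side signs

def loss(mu, sigma):
    """Negative log-likelihood (unit-variance Gaussian, assumed) plus KL."""
    nll = 0.0
    for i in range(B):
        W_i = mu + sigma * eps * np.outer(S[i], R[i])
        resid = W_i @ X[i] - Y[i]
        nll += 0.5 * resid @ resid
    kl = np.sum(np.log(sigma0 / sigma)
                + (sigma**2 + mu**2) / (2 * sigma0**2) - 0.5)
    return nll + kl

def analytic_grads(mu, sigma):
    # Start from the analytic KL gradients.
    g_mu = mu / sigma0**2
    g_sigma = -1.0 / sigma + sigma / sigma0**2
    for i in range(B):
        signs = np.outer(S[i], R[i])
        W_i = mu + sigma * eps * signs
        # Gradient of -log p w.r.t. W^(i) for the assumed likelihood.
        G_i = np.outer(W_i @ X[i] - Y[i], X[i])
        g_mu = g_mu + G_i                    # dW/dmu is the identity
        g_sigma = g_sigma + G_i * eps * signs  # dW/dsigma = eps * (s r^T)
    return g_mu, g_sigma

def numeric_grad(f, x, h=1e-6):
    """Central finite differences, one entry at a time."""
    g = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        step = np.zeros_like(x)
        step[idx] = h
        g[idx] = (f(x + step) - f(x - step)) / (2 * h)
    return g

mu = rng.normal(size=(d_out, d_in))
sigma = 0.5 + rng.random(size=(d_out, d_in))
g_mu, g_sigma = analytic_grads(mu, sigma)
num_mu = numeric_grad(lambda m: loss(m, sigma), mu)
num_sigma = numeric_grad(lambda s: loss(mu, s), sigma)
print(np.allclose(g_mu, num_mu, atol=1e-5),
      np.allclose(g_sigma, num_sigma, atol=1e-5))
```

In practice an autodiff framework computes these gradients automatically; the explicit version is only useful for seeing where each factor comes from.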

#### WEIGHT UPDATE (Stochastic Gradient Descent)

Let $\eta$ be the learning rate. We apply gradient descent (or Adam):

* $\mu \leftarrow \mu - \eta \cdot \nabla_\mu \mathcal{L}$
* $\sigma \leftarrow \sigma - \eta \cdot \nabla_\sigma \mathcal{L}$

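To keep $\sigma$ positive and the updates numerically stable, one often optimizes an unconstrained parameter $\rho$ with $\sigma = \log(1 + \exp(\rho))$ (softplus) instead of $\sigma$ itself.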
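A minimal sketch of one update on an unconstrained parameter $\rho$, assuming the softplus link $\sigma = \log(1 + \exp(\rho))$. The gradient values are made up for illustration; the chain rule factor is $d\sigma / d\rho = \text{sigmoid}(\rho)$.

```python
import numpy as np

def softplus(rho):
    # log(1 + exp(rho)), computed without overflow for large rho
    return np.logaddexp(0.0, rho)

def sigmoid(rho):
    return 1.0 / (1.0 + np.exp(-rho))

rho = np.array([-5.0, 0.0, 3.0])
sigma = softplus(rho)                      # always strictly positive

grad_sigma = np.array([0.2, -0.1, 0.4])    # hypothetical dL/dsigma values
grad_rho = grad_sigma * sigmoid(rho)       # chain rule: dsigma/drho = sigmoid(rho)

eta = 0.1                                  # learning rate
rho_new = rho - eta * grad_rho             # plain gradient step on rho
sigma_new = softplus(rho_new)              # still positive after the update
```

A gradient step taken directly on $\sigma$ can push it below zero; stepping on $\rho$ cannot, whatever the step size.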


<div>


In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# ---------- Bayesian Neural Network Definition ----------
class BayesianMLP(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim):
        super(BayesianMLP, self).__init__()
        dims = [input_dim] + hidden_dims + [output_dim]
        self.layers = nn.ModuleList()
        for i in range(len(dims)-1):
            self.layers.append(nn.Linear(dims[i], dims[i+1]))

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        return self.layers[-1](x)

# ---------- pSGLD Sampler ----------
class pSGLD:
    def __init__(self, params, lr=1e-3, alpha=0.99, eps=1e-8, weight_decay=1e-2):
        self.params = list(params)
        self.lr = lr
        self.alpha = alpha
        self.eps = eps
        self.weight_decay = weight_decay
        # Initialize running average of squared gradients
        self.state = {p: torch.zeros_like(p.data) for p in self.params}

    @torch.no_grad()
    def step(self):
        for p in self.params:
            if p.grad is None:
                continue
            grad = p.grad.data + self.weight_decay * p.data
            v = self.state[p]
            v.mul_(self.alpha).addcmul_(grad, grad, value=1 - self.alpha)
            precond = 1.0 / (torch.sqrt(v) + self.eps)
            noise = torch.randn_like(p.data) * torch.sqrt(self.lr * precond)
            p.data.add_(-0.5 * self.lr * precond * grad + noise)

# ---------- Training and Sampling ----------
def train_mcmc(model, X, y, num_epochs=1000, batch_size=64,
               lr=1e-3, alpha=0.99, eps=1e-8, weight_decay=1e-2,
               burn_in=500, collect_every=10, num_samples=100,
               likelihood_noise=1.0):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    dataset = TensorDataset(torch.Tensor(X), torch.Tensor(y))
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    sampler = pSGLD(model.parameters(), lr=lr, alpha=alpha,
                    eps=eps, weight_decay=weight_decay)

    samples = []
    total_steps = 0
    for epoch in range(num_epochs):
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            model.zero_grad()
            preds = model(xb).squeeze()
            # Gaussian likelihood log p(y|f) ~ -(1/(2*sigma^2))*(y-f)^2
            neg_log_lik = 0.5 / (likelihood_noise**2) * F.mse_loss(preds, yb, reduction='sum')
            neg_log_prior = 0.0
            # Gaussian prior N(0, I): -(1/2)*w^2
            for p in model.parameters():
                neg_log_prior += 0.5 * torch.sum(p**2)
            loss = neg_log_lik + neg_log_prior
            loss.backward()
            sampler.step()
            total_steps += 1

            # Collect samples after burn-in
            if total_steps > burn_in and total_steps % collect_every == 0:
                # Deep copy parameters
                state = {k: v.clone().cpu() for k, v in model.state_dict().items()}
                samples.append(state)
                if len(samples) >= num_samples:
                    return samples
    return samples

# ---------- Prediction with Posterior Samples ----------
def predict_mcmc(model, samples, X_test):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    X_test = torch.Tensor(X_test).to(device)
    preds = []
    with torch.no_grad():
        for state in samples:
            model.load_state_dict(state)
            out = model(X_test).cpu().numpy()
            preds.append(out)
    preds = np.stack(preds, axis=0)  # [num_samples, N_test, output_dim]
    mean = preds.mean(axis=0)
    std = preds.std(axis=0)
    return mean, std

# ---------- Usage Example ----------
if __name__ == '__main__':
    # Generate synthetic data if X, y not provided
    N, D = 1000, 10
    X = np.random.randn(N, D)
    true_w = np.random.randn(D)
    y = X.dot(true_w) + 0.1 * np.random.randn(N)

    # Initialize model
    model = BayesianMLP(input_dim=D, hidden_dims=[50, 50], output_dim=1)

    # Train and collect posterior samples.
    # 1000 points with batch size 64 give 16 steps per epoch, so 320 epochs
    # give 5120 steps: enough to pass burn_in=1000 and then collect
    # 200 samples at every 20th step (the last one at step 5000).
    samples = train_mcmc(model, X, y,
                         num_epochs=320,
                         batch_size=64,
                         lr=1e-4,
                         weight_decay=1e-4,
                         burn_in=1000,
                         collect_every=20,
                         num_samples=200,
                         likelihood_noise=0.1)

    # Predict on test set
    X_test = np.random.randn(100, D)
    mean_pred, std_pred = predict_mcmc(model, samples, X_test)
    print("Mean predictions:\n", mean_pred)
    print("Uncertainty (std) predictions:\n", std_pred)


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import math

# Define the BNN model
class BNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(BNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Define the pSGLD optimizer
class PSGLD(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, beta=0.99, epsilon=1e-8):
        defaults = dict(lr=lr, beta=beta, epsilon=epsilon)
        super(PSGLD, self).__init__(params, defaults)

    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]
                if 'v' not in state:
                    state['v'] = torch.zeros_like(p.data)
                v = state['v']
                beta = group['beta']
                # Running average of squared gradients (RMSprop-style)
                v.mul_(beta).add_((1 - beta) * grad ** 2)
                preconditioner = 1 / (v.sqrt() + group['epsilon'])
                noise = torch.randn_like(p.data)
                # Langevin step with preconditioner G: the drift is lr * G * grad,
                # so the noise needs variance 2 * lr * G, i.e. std sqrt(2 * lr * G)
                p.data.add_(-group['lr'] * preconditioner * grad
                            + torch.sqrt(2 * group['lr'] * preconditioner) * noise)

# Generate synthetic data (assuming X and y are not provided)
n_samples = 120
X = torch.linspace(-3, 3, n_samples).reshape(-1, 1)
y = torch.sin(X) + 0.1 * torch.randn(n_samples, 1)

# Split into train and test sets
n_train = 100
X_train, X_test = X[:n_train], X[n_train:]
y_train, y_test = y[:n_train], y[n_train:]

# Move to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
X_train = X_train.to(device)
y_train = y_train.to(device)
X_test = X_test.to(device)
y_test = y_test.to(device)

# Set up the model
input_dim = 1
hidden_dim = 50
output_dim = 1
model = BNN(input_dim, hidden_dim, output_dim).to(device)

# Set up the optimizer
lr = 0.001
beta = 0.99
epsilon = 1e-8
optimizer = PSGLD(model.parameters(), lr=lr, beta=beta, epsilon=epsilon)

# Set up the loss function
loss_fn = nn.MSELoss()

# Set up the DataLoader
batch_size = 32
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Training parameters
total_epochs = 1000
burn_in = 500
collect_interval = 10
model_samples = []

# Training loop
for epoch in range(total_epochs):
    model.train()
    for X_batch, y_batch in train_loader:
        optimizer.zero_grad()
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        loss.backward()
        optimizer.step()
    if epoch > burn_in and (epoch - burn_in) % collect_interval == 0:
        model_samples.append({k: v.clone() for k, v in model.state_dict().items()})

# Testing
model.eval()
predictions = []
for state in model_samples:
    model.load_state_dict(state)
    with torch.no_grad():
        y_pred = model(X_test)
    predictions.append(y_pred)
y_pred_mean = torch.mean(torch.stack(predictions), dim=0)
mse = loss_fn(y_pred_mean, y_test)
print(f'Test MSE: {mse.item()}')

Test MSE: inf
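A Langevin sampler is easier to debug on a target whose answer is known. The sketch below is a toy check, not any of the notebook's optimizers: it runs a Langevin update with a constant preconditioner $G$ on a 1-D Gaussian target with standard deviation 2, once with noise of standard deviation $\sqrt{2hG}$ and once with the noise scaled by $G$ itself. The step size, chain counts, and the function name `run_chains` are assumptions for illustration.

```python
import numpy as np

def run_chains(G, correct_noise, h=0.05, n_steps=2000, n_chains=4000,
               target_std=2.0, seed=0):
    """Run many independent Langevin chains; return the std of their end states."""
    rng = np.random.default_rng(seed)
    theta = np.zeros(n_chains)
    for _ in range(n_steps):
        grad_U = theta / target_std**2           # U(theta) = theta^2 / (2 s^2)
        xi = rng.normal(size=n_chains)
        if correct_noise:
            noise = np.sqrt(2.0 * h * G) * xi    # variance 2 h G
        else:
            noise = np.sqrt(2.0 * h) * G * xi    # noise scaled by G, not sqrt(G)
        theta = theta - h * G * grad_U + noise
    return theta.std()

results = {(G, ok): run_chains(G, ok)
           for G in (0.5, 4.0) for ok in (True, False)}
for (G, ok), std in results.items():
    label = "sqrt(2hG) noise" if ok else "sqrt(2h)*G noise"
    print(f"G={G}, {label}: sample std = {std:.3f}")
```

With the $\sqrt{2hG}$ scaling the sample spread stays close to the target for both values of $G$; with the other scaling the spread depends on $G$, so the chain no longer samples the intended distribution.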


In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_regression

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Generate synthetic regression dataset
X, y = make_regression(n_samples=1000, n_features=10, noise=0.1, random_state=42)
X = X.astype(np.float32)
y = y.astype(np.float32).reshape(-1, 1)

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Standardize features and target
scaler_x = StandardScaler().fit(X_train)
scaler_y = StandardScaler().fit(y_train)
X_train = scaler_x.transform(X_train)
X_test = scaler_x.transform(X_test)
y_train = scaler_y.transform(y_train).flatten()
y_test = scaler_y.transform(y_test).flatten()

# Convert to PyTorch tensors
train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
test_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Bayesian Neural Network Architecture
class BNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(BNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)

# Custom pSGLD Optimizer (Preconditioned SGLD)
class pSGLD(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, alpha=alpha, eps=eps, weight_decay=weight_decay)
        super(pSGLD, self).__init__(params, defaults)
    
    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                
                grad = p.grad.data
                state = self.state[p]
                
                # Initialize state
                if len(state) == 0:
                    state['step'] = 0
                    state['square_avg'] = torch.zeros_like(p.data)
                
                state['step'] += 1
                square_avg = state['square_avg']
                alpha = group['alpha']
                lr = group['lr']
                eps = group['eps']
                weight_decay = group['weight_decay']
                
                # Add weight decay (Gaussian prior)
                if weight_decay != 0:
                    grad.add_(p.data, alpha=weight_decay)
                
                # Update squared gradient average
                square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
                
                # Compute preconditioner (RMS)
                preconditioner = 1.0 / (torch.sqrt(square_avg) + eps)
                
                # Add Gaussian noise for Langevin dynamics. With the drift
                # 0.5 * lr * G * grad used below, the noise variance is lr * G.
                # randn_like keeps the noise on the same device as the parameter.
                noise = torch.randn_like(p.data) * torch.sqrt(lr * preconditioner)

                # Update parameters
                p.data.add_(-0.5 * lr * preconditioner * grad + noise)

# Hyperparameters
input_dim = X_train.shape[1]
hidden_dim = 128
output_dim = 1
num_epochs = 500
burn_in = 100  # Discard samples from the first 100 epochs
num_samples = 50  # Maximum number of posterior samples to keep

# Initialize model and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BNN(input_dim, hidden_dim, output_dim).to(device)
optimizer = pSGLD(model.parameters(), lr=1e-3, alpha=0.99, weight_decay=1e-4)

# Training with MCMC sampling
model.train()
samples = []  # Store weight samples for Bayesian averaging

for epoch in range(num_epochs):
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        optimizer.zero_grad()
        outputs = model(batch_x).flatten()
        # Note: F.mse_loss averages over the batch; to target the actual
        # posterior, the loss would need to be rescaled to the full-data
        # negative log-likelihood.
        loss = F.mse_loss(outputs, batch_y)
        loss.backward()
        optimizer.step()

    # Store weights after burn-in period with thinning (every 10th epoch)
    if epoch > burn_in and epoch % 10 == 0:
        samples.append({k: v.detach().clone().cpu() for k, v in model.state_dict().items()})
        if len(samples) >= num_samples:
            break

# Bayesian Model Averaging for Prediction
def predict_bma(model, samples, dataloader):
    model.eval()
    means, stds = [], []
    with torch.no_grad():
        for batch_x, _ in dataloader:
            batch_x = batch_x.to(device)
            batch_preds = []

            # One prediction per stored weight sample
            for sample in samples:
                model.load_state_dict(sample)
                outputs = model(batch_x).flatten().cpu().numpy()
                batch_preds.append(outputs)

            # Calculate mean and std across samples
            batch_preds = np.array(batch_preds)  # [num_samples, batch]
            means.append(batch_preds.mean(axis=0))
            stds.append(batch_preds.std(axis=0))

    return np.concatenate(means), np.concatenate(stds)

# Generate predictions and evaluate
y_pred, y_std = predict_bma(model, samples, test_loader)
test_mse = np.mean((y_pred - y_test) ** 2)
print(f"Test MSE: {test_mse:.4f}")

# Uncertainty quantification example
print("\nUncertainty Quantification (First 5 test points):")
for i in range(5):
    print(f"True: {y_test[i]:.3f}, Pred: {y_pred[i]:.3f} +/- {y_std[i]:.3f}")