[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pranavm19/SBI-Tutorial/blob/main/notebooks/02_NFlows.ipynb)

In [1]:
# !python -m pip install sbi corner

In [46]:
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output, display

In [18]:
# Recall the two moons model
def two_moons_sbi(theta, sigma=0.01):
    """Simulate two moons data x for parameters theta of shape (n_samples, 2)."""
    n_samples = theta.shape[0]
    alpha = np.random.uniform(-np.pi/2, np.pi/2, n_samples)
    r = sigma * np.random.randn(n_samples) + 1
    x_1 = r * np.cos(alpha) + 1 - np.abs(theta[:, 0] + theta[:, 1])/np.sqrt(2)
    x_2 = r * np.sin(alpha) + (- theta[:, 0] + theta[:, 1])/np.sqrt(2)

    x = np.stack([x_1, x_2], axis=-1)

    return x

Outline

0. Intro (why neural SBI, inference and density estimation with neural networks)
1. What are normalizing flows, and how do they work? (change of variables formula + building intuition with Fig. 16.2 of the UDL textbook)
2. The I/O of normalizing flows: what do we want them to do?
3. Build an affine coupling layer (Real NVP equations, introducing context into the equations, implementing AffineCouplingLayer, NormalizingFlow)
4. Normalizing flows using affine coupling layers and permutation masks (train on the two moons model)
5. Reproduce using the `sbi` package
6. Performing inference (corner plots, predictive checks)

### **What are normalizing flows?**

#### Inference and density estimation


#### Change of variables formula


#### Transforming a data distribution to a base density



In [None]:
# Base density, transform, and data distribution

In [None]:
# Make Figure 16.2 from the UDL book

#### The API

Cool, so we have seen how a bijective transform can be used to "normalize" a given density. Let's make them a bit more powerful 💪

A normalizing flow has two components: the bijective transform (or a composition of several) and the prior, or base, distribution. At train time, we map a batch of (x, context) pairs to z, compute the log-likelihood with the change of variables formula, and evaluate the loss. At test time, we draw z from the prior and map (z, context) back through the inverse transform to generate samples x from the learned data distribution; we can also evaluate the probability density of any x. To do this, we will need:

`NF = NormalizingFlow(flows, prior)`
- `NF.forward(x, context) -> z, ldj`
- `NF.inverse(z, context) -> x, ldj`
- `NF.sample(num_samples, context) -> x`
- `NF.log_prob(x, context) -> log p(x | context)`

The `flows` object is usually a list of layers, each of which is a bijective transform:

`T = FlowTransform()`
- `T.forward(x, context) -> z, ldj`
- `T.inverse(z, context) -> x, ldj`
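A minimal sketch of the `FlowTransform` interface, assuming the convention that both `forward` and `inverse` return the log-absolute-determinant of the Jacobian of their own direction. `ElementwiseAffine` is a hypothetical toy transform that ignores the context; it shows how the change of variables formula, $\log p_x(x) = \log p_z(T(x)) + \log\lvert\det \partial T / \partial x\rvert$, turns a `forward` call into a density, and how `inverse` turns prior samples into data samples.

```python
import torch
import torch.nn as nn


class ElementwiseAffine(nn.Module):
    """Hypothetical toy bijection z = (x - mu) / sigma that ignores the context."""

    def __init__(self, dim):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(dim))
        self.log_sigma = nn.Parameter(torch.zeros(dim))

    def forward(self, x, context=None):
        # data -> latent; dz/dx = 1/sigma, so log|det J| = -sum(log_sigma)
        z = (x - self.mu) * torch.exp(-self.log_sigma)
        ldj = -self.log_sigma.sum().expand(x.shape[0])
        return z, ldj

    def inverse(self, z, context=None):
        # latent -> data; log|det J| of the inverse is +sum(log_sigma)
        x = z * torch.exp(self.log_sigma) + self.mu
        ldj = self.log_sigma.sum().expand(z.shape[0])
        return x, ldj


torch.manual_seed(0)
dim = 2
T = ElementwiseAffine(dim)
with torch.no_grad():
    T.mu.copy_(torch.tensor([1.0, -2.0]))
    T.log_sigma.copy_(torch.tensor([0.5, -0.3]))
prior = torch.distributions.Normal(torch.zeros(dim), torch.ones(dim))

with torch.no_grad():
    # Train-time direction: density of x via the change of variables formula
    x = torch.randn(5, dim)
    z, ldj = T(x)
    log_px = prior.log_prob(z).sum(dim=-1) + ldj

    # Closed form for comparison: x ~ N(mu, diag(sigma^2))
    reference = torch.distributions.Normal(T.mu, T.log_sigma.exp()).log_prob(x).sum(dim=-1)

    # Test-time direction: push prior samples through the inverse
    x_samples, _ = T.inverse(prior.sample((1000,)))

    # Round trip: the inverse undoes forward and its ldj cancels forward's ldj
    x_rec, ldj_inv = T.inverse(z)
```

The `NormalizingFlow` class later in this notebook follows the same pattern: it sums the per-layer `ldj` terms in `forward` and adds them to the prior's log-probability in `log_prob`.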

### **Why are they suitable for Simulation-Based Inference?**

Here, we will specifically deal with the case of Neural Posterior Estimation.
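As a brief sketch of the standard NPE setup (the notation is ours: $q_\phi$ is the conditional flow and $\phi$ its weights), we simulate pairs $\theta_i \sim p(\theta)$, $x_i \sim p(x \mid \theta_i)$ and minimize the expected negative log-density of the parameters given the data:

$$
\mathcal{L}(\phi) \;=\; -\,\mathbb{E}_{\theta \sim p(\theta),\; x \sim p(x \mid \theta)}\bigl[\log q_\phi(\theta \mid x)\bigr] \;\approx\; -\frac{1}{N}\sum_{i=1}^{N} \log q_\phi(\theta_i \mid x_i).
$$

Since $\mathcal{L}(\phi) = \mathbb{E}_{x \sim p(x)}\bigl[\mathrm{KL}\bigl(p(\theta \mid x)\,\|\,q_\phi(\theta \mid x)\bigr)\bigr] + \text{const}$, a sufficiently flexible flow that minimizes this loss approximates the true posterior for every $x$ in the support of $p(x)$ at once, which is what makes the posterior amortized. Normalizing flows fit this role because they give both an exact $\log q_\phi(\theta \mid x)$ for training and cheap sampling for inference.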

### **Real NVP - Affine Coupling Flows**

A popular normalizing flow architecture is **Real NVP** (Dinh et al., 2017), whose core building block is the **affine coupling** transformation.

Suppose we split the data $x$ of dimension $D$ into two parts, $x = [x_{1:d},\, x_{(d+1):D}]$. In a **coupling layer**, we leave one part unchanged and apply a learnable affine transformation to the other part. Specifically, let:

$$
    y_{1:d} \;=\; x_{1:d},
$$
$$
    y_{(d+1):D} \;=\; x_{(d+1):D}\,\odot \exp\bigl(s_\theta(x_{1:d})\bigr)\;+\; t_\theta(x_{1:d}),
$$

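where $\odot$ denotes elementwise multiplication. The functions $s_\theta(\cdot)$ and $t_\theta(\cdot)$ (the "scale" and "shift" networks) are typically small neural networks that take the "frozen" part $x_{1:d}$ as input. Here the subscript $\theta$ denotes the networks' weights; do not confuse it with the simulator parameters $\theta$ used elsewhere in this notebook.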

- **Invertibility**: This transformation is **invertible** because you can solve for $x_{(d+1):D}$ by reversing the shift and scale operations:
  $$
    x_{(d+1):D} 
     = \Bigl(y_{(d+1):D} - t_\theta(y_{1:d})\Bigr)\,\odot \exp\Bigl(- s_\theta(y_{1:d})\Bigr).
  $$
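  The log-absolute-determinant of the Jacobian, $\log \left\lvert \det \frac{\partial y}{\partial x} \right\rvert$, is simply
  $$
    \sum_{j=1}^{D-d} s_\theta(x_{1:d})_j,
  $$
  because the Jacobian is triangular, with ones on the diagonal for the frozen part and $\exp\bigl(s_\theta(x_{1:d})\bigr)$ for the transformed part.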

- **Why is it non-linear?** 
  Although the transformation is written as an *affine* function for the second block, **the parameters of that affine transformation are themselves neural-network outputs**, i.e., $s_\theta(\cdot)$ and $t_\theta(\cdot)$. This makes the overall mapping
  $$
    x \mapsto y
  $$
  **non-linear** in $x$. The "frozen" part $x_{1:d}$ is fed through a neural network to produce scale and shift factors, which can be highly non-linear functions of $x_{1:d}$.

- **Coupling layers and permutations**: In Real NVP, we often interleave such coupling layers with **permutation** layers to ensure that over multiple layers, each dimension eventually appears in the "frozen" part and the "transformed" part. This broadens the flexibility of the flow, letting it model complex dependencies across all dimensions.
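To see where the sum-of-scales log-determinant comes from, order the coordinates as $[x_{1:d},\, x_{(d+1):D}]$. Because $y_{1:d}$ does not depend on $x_{(d+1):D}$, the Jacobian is lower block-triangular:

$$
\frac{\partial y}{\partial x} \;=\;
\begin{bmatrix}
\mathbb{I}_d & \mathbf{0} \\
\dfrac{\partial y_{(d+1):D}}{\partial x_{1:d}} & \operatorname{diag}\Bigl(\exp\bigl(s_\theta(x_{1:d})\bigr)\Bigr)
\end{bmatrix},
\qquad
\log\left\lvert \det \frac{\partial y}{\partial x} \right\rvert \;=\; \sum_{j=1}^{D-d} s_\theta(x_{1:d})_j .
$$

The determinant of a block-triangular matrix is the product of the determinants of its diagonal blocks, so the dense lower-left block, which contains the derivatives of the $s_\theta$ and $t_\theta$ networks, never has to be computed. This is why the networks can be arbitrarily complex without making the likelihood harder to evaluate.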

In summary, **Real NVP** is a straightforward yet powerful example of how normalizing flows combine tractable Jacobians (via affine coupling) with flexible function approximators (neural networks for scale and shift). 
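Both claims above, exact invertibility and $\log\lvert\det\rvert = \sum_j s_j$, can be checked numerically. The sketch below uses a hypothetical unconditional coupling layer with small, randomly initialized networks `s_net` and `t_net` (the sizes and `Tanh` activations are arbitrary choices, not the ones used later in this notebook) and compares the analytic log-determinant with one computed from the full Jacobian via `torch.autograd.functional.jacobian`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
D, d = 4, 2  # total dimension D and size d of the frozen block x_{1:d}

# Hypothetical scale and shift networks s(.) and t(.); sizes and Tanh are arbitrary choices
s_net = nn.Sequential(nn.Linear(d, 16), nn.Tanh(), nn.Linear(16, D - d))
t_net = nn.Sequential(nn.Linear(d, 16), nn.Tanh(), nn.Linear(16, D - d))


def coupling_forward(x):
    """y_{1:d} = x_{1:d};  y_{d+1:D} = x_{d+1:D} * exp(s(x_{1:d})) + t(x_{1:d})."""
    x1, x2 = x[..., :d], x[..., d:]
    s, t = s_net(x1), t_net(x1)
    y = torch.cat([x1, x2 * torch.exp(s) + t], dim=-1)
    return y, s.sum(dim=-1)  # analytic log|det J| = sum_j s_j


def coupling_inverse(y):
    """x_{d+1:D} = (y_{d+1:D} - t(y_{1:d})) * exp(-s(y_{1:d}))."""
    y1, y2 = y[..., :d], y[..., d:]
    s, t = s_net(y1), t_net(y1)
    return torch.cat([y1, (y2 - t) * torch.exp(-s)], dim=-1)


x = torch.randn(D)  # a single (unbatched) point, so the Jacobian is D x D
y, ldj = coupling_forward(x)

# Brute-force check: full Jacobian via autograd, then log|det J| via slogdet
J = torch.autograd.functional.jacobian(lambda v: coupling_forward(v)[0], x)
sign, logabsdet = torch.linalg.slogdet(J)

x_rec = coupling_inverse(y)  # should recover x up to float32 round-off

print(f"analytic log|det J| = {ldj.item():.5f}, autograd = {logabsdet.item():.5f}")
print("upper-right block of J (should be zero):", J[:d, d:])
```

The upper-right block of `J` is zero and its lower-right block is diagonal, matching the block-triangular structure that makes the log-determinant cheap.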


Question 1. For NPE, we need to train a conditional normalizing flow, where $x$ are the conditioning variables (also called the context) and $\theta$ are the primary variables that get normalized. However, the equations above contain no conditioning variables. Can you modify the forward and inverse equations to show how the context is used?

Question 2. With all the equations to implement a conditional normalizing flow at hand, let's implement it! Complete the AffineCouplingLayer class below...

In [3]:
class AffineCouplingLayer(nn.Module):
    def __init__(self, input_dim, context_dim):
        super().__init__()
        self.input_dim = input_dim
        self.context_dim = context_dim
        self.split_idx = input_dim - (input_dim // 2) # first part gets more dims if input_dim is odd

        # Define scale and shift networks
        self.scale_net = nn.Sequential(
            nn.Linear(self.split_idx + context_dim, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 64), 
            nn.LeakyReLU(),
            nn.Linear(64, input_dim - self.split_idx)
        )
        self.shift_net = nn.Sequential(
            nn.Linear(self.split_idx + context_dim, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 64), 
            nn.LeakyReLU(),
            nn.Linear(64, input_dim - self.split_idx)
        )

    def forward(self, x, context):
        # Split input tensor along the last dimension
        x_identity = x[..., :self.split_idx]
        x_transform = x[..., self.split_idx:]

        # Concatenate identity and context for both networks
        # identity_context = # FILL IN THE BLANK
        # scale = # FILL IN THE BLANK
        # shift = # FILL IN THE BLANK

        identity_context = torch.cat((x_identity, context), dim=-1)
        scale = self.scale_net(identity_context)
        shift = self.shift_net(identity_context)
        
        # Compute log-determinant of the Jacobian
        # ldj = # FILL IN THE BLANK
        ldj = torch.sum(scale, dim=-1)

        # Affine transformation on x_transform
        # z_transform = # FILL IN THE BLANK
        
        z_transform = x_transform * torch.exp(scale) + shift

        # Concatenate unchanged part with transformed part
        z = torch.cat((x_identity, z_transform), dim=-1)
        return z, ldj
    
    def inverse(self, z, context):
        # Inverse transform: split z into identity and transformed parts
        z_identity = z[..., :self.split_idx]
        z_transform = z[..., self.split_idx:]

        # Concatenate identity and context, pass them into the networks
        # identity_context = # FILL IN THE BLANK
        # scale = # FILL IN THE BLANK
        # shift = # FILL IN THE BLANK

        identity_context = torch.cat((z_identity, context), dim=-1)
        scale = self.scale_net(identity_context)
        shift = self.shift_net(identity_context)
        
        # Compute log-determinant of the Jacobian (negated for the inverse direction)
        # ldj = # FILL IN THE BLANK
        ldj = -torch.sum(scale, dim=-1)

        # Inverse affine transformation
        # x_transform = # FILL IN THE BLANK
        x_transform = (z_transform - shift) * torch.exp(-scale)

        # Concatenate identity and transformed parts
        x = torch.cat((z_identity, x_transform), dim=-1)
        return x, ldj


In [4]:
class NormalizingFlow(nn.Module):
    """
    A normalizing flow model composed of a sequence of invertible layers (e.g. affine coupling and permutation layers) and a prior distribution.
    """
    def __init__(self, flows, prior=None):
        super().__init__()
        self.flows = nn.ModuleList(flows)
        self.dim = self.flows[0].input_dim
        # Standard normal prior by default. Note: torch.distributions objects are not nn.Modules,
        # so model.to(device) does not move the prior's tensors.
        if prior is None:
            self.prior = torch.distributions.MultivariateNormal(
                torch.zeros(self.dim), torch.eye(self.dim))
        else:
            self.prior = prior

        self.train_loss = []

    def forward(self, x, context):
        """
        Applies a sequence of flow transformations and accumulates the log-determinants.
        """
        ldj = torch.zeros(x.shape[0], device=x.device)
        for flow in self.flows:
            x, ldj_ = flow(x, context)
            ldj += ldj_
        return x, ldj

    def inverse(self, z, context):
        """
        Inverts the flow transformation from latent space back to the input space.
        """
        ldj = torch.zeros(z.shape[0], device=z.device)
        for flow in reversed(self.flows):
            z, ldj_ = flow.inverse(z, context)
            ldj += ldj_  # log-determinants are already negated in inverse
        return z, ldj

    @torch.no_grad()
    def sample(self, num_samples, context):
        """
        Generate num_samples samples from the model; context must have shape (num_samples, context_dim).
        """
        device = next(self.parameters()).device
        z = self.prior.sample((num_samples,)).to(device)
        x, _ = self.inverse(z, context)
        return x

    def log_prob(self, x, context): 
        """
        Compute the log probability of x under the flow model.
        """
        z, ldj = self(x, context)
        log_pz = self.prior.log_prob(z)
        return log_pz + ldj


In [5]:
class PermutationLayer(nn.Module):
    def __init__(self, num_features):
        super().__init__()
        # Create a random permutation for the feature indices.
        perm = torch.randperm(num_features)
        self.register_buffer("perm", perm)
        self.register_buffer("inv_perm", torch.argsort(perm))
    
    def forward(self, x, context):
        # Permuting the features; no effect on the log-determinant.
        x_permuted = x[..., self.perm]
        # Log-determinant is zero for a permutation
        log_det = torch.zeros(x.size(0), device=x.device)
        return x_permuted, log_det
    
    def inverse(self, x, context):
        # Inverse permutation
        x_inv = x[..., self.inv_perm]
        log_det = torch.zeros(x.size(0), device=x.device)
        return x_inv, log_det


In [45]:
# Handle dimensions
input_dim = 2
context_dim = 2
n_layers = 4
flows = []

# Define the model and optimizer
for i in range(n_layers):
    flows.append(AffineCouplingLayer(input_dim, context_dim))
    flows.append(PermutationLayer(input_dim))

model = NormalizingFlow(flows)
optimizer = optim.Adam(model.parameters(), lr=3e-4)

In [None]:
# Fix an observed data point x_obs
x_obs = torch.tensor([[0.0, 0.0]], dtype=torch.float32)

# Create a grid of theta-values over which to evaluate the posterior
n_samples = 100
theta0_vals = torch.linspace(-2, 2, n_samples)
theta1_vals = torch.linspace(-2, 2, n_samples)
TH0, TH1 = torch.meshgrid(theta0_vals, theta1_vals, indexing='xy')
theta_grid = torch.cat([TH0.reshape(-1,1), TH1.reshape(-1,1)], dim=1)

# Training settings
num_iter = 5000
num_update_iter = 100
batch_size = 256
losses = []

# Prepare figure
fig, (ax_loss, ax_posterior) = plt.subplots(1, 2, figsize=(10, 4))
plt.ion()

for i in range(num_iter):
    # Generate data
    theta = np.random.uniform(-4, 4, size=(batch_size, 2))
    x = two_moons_sbi(theta)

    # Training step: theta is passed as the context, so the flow fits the likelihood p(x | theta)
    x = torch.tensor(x, dtype=torch.float32)
    theta = torch.tensor(theta, dtype=torch.float32)
    optimizer.zero_grad()
    loss = -model.log_prob(x, theta).mean()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    # Update plots interactively
    if i % num_update_iter == 0:
        ax_loss.cla()
        ax_posterior.cla()

        # Training loss 
        ax_loss.plot(losses, label='Train Loss')
        ax_loss.set_title('Training Loss')
        ax_loss.set_xlabel('Iteration')
        ax_loss.set_ylabel('Negative Log-Likelihood')
        ax_loss.legend()

        # Approximate posterior 
        with torch.no_grad():
            # Replicate x_obs for every point in theta_grid so the shape matches
            x_obs_tiled = x_obs.repeat(theta_grid.shape[0], 1)

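            # The flow models p(x | theta), so this is the learned likelihood of x_obs; with the
            # uniform prior on [-4, 4]^2 it is proportional to the posterior p(theta | x_obs) on this grid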
            post_vals = model.log_prob(x_obs_tiled, theta_grid).exp()
            post_2d = post_vals.view(n_samples, n_samples)

        # Contour-plot the posterior in theta-space
        c = ax_posterior.contourf(
            TH0.numpy(), TH1.numpy(), post_2d.numpy(),
            levels=50, alpha=0.8
        )
        ax_posterior.set_title(f'Posterior at iteration {i}')
        ax_posterior.set_xlabel(r'$\theta_0$')
        ax_posterior.set_ylabel(r'$\theta_1$')
        
        # Optionally add a colorbar if you like
        # fig.colorbar(c, ax=ax_posterior)

        clear_output(wait=True)
        display(fig)

plt.ioff();

In [None]:
# The really cool thing is that the flow is amortized: we don't need to retrain it
# to infer the posterior for a new observation! (more or less...)
x_obs = torch.tensor([[0.1, 0.1]], dtype=torch.float32)

# Create a grid of theta-values over which to evaluate the posterior
n_samples = 100
theta0_vals = torch.linspace(-3, 3, n_samples)
theta1_vals = torch.linspace(-3, 3, n_samples)
TH0, TH1 = torch.meshgrid(theta0_vals, theta1_vals, indexing='xy')
theta_grid = torch.cat([TH0.reshape(-1,1), TH1.reshape(-1,1)], dim=1)

fig, ax = plt.subplots(1, 1, figsize=[4, 4])

with torch.no_grad():
    # Replicate x_obs for every point in theta_grid so the shape matches
    x_obs_tiled = x_obs.repeat(theta_grid.shape[0], 1)

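    # The flow models p(x | theta), so this is the learned likelihood of x_obs; with the
    # uniform prior on [-4, 4]^2 it is proportional to the posterior p(theta | x_obs) on this grid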
    post_vals = model.log_prob(x_obs_tiled, theta_grid).exp()
    post_2d = post_vals.view(n_samples, n_samples)

# Contour-plot the posterior in theta-space
c = ax.contourf(
    TH0.numpy(), TH1.numpy(), post_2d.numpy(),
    levels=50, alpha=0.8
)
ax.set_title(f'Posterior for x_obs = ({x_obs[0, 0].item():.1f}, {x_obs[0, 1].item():.1f})')
ax.set_xlabel(r'$\theta_0$')
ax.set_ylabel(r'$\theta_1$')

plt.show()

In [None]:
# Plot posterior using corner

### **Repeat using `sbi` toolbox**



In [None]:
# Use sbi for the same analysis

### **Posterior checks**

### **Outro**

#### Additional reading
