# Posterior Network: Uncertainty Estimation without OOD Samples via Density-Based Pseudo-Counts

Posterior Networks (PostNet) extend the idea of Evidential Deep Learning (EDL) by producing a full Dirichlet distribution over class probabilities for each input. However, instead of evidence being directly predicted by the neural network, PostNet does so by deriving evidence from class-conditional density estimates in a latent space. This assures that out-of-distribution (OOD) samples are not needed during training, as uncertainty increases for inputs that lie outside the learned density.

In this notebook, we will:
- Build a small encoder to map inputs into a latent space
- Train a separate normalizing flow per class to model the class-conditional densities
- Convert densities into evidence (pseudo-counts)
- Construct Dirichlet posteriors and evaluate uncertainty
- Compare uncertainty on in-distribution (ID) and OOD samples.


## Imports and Setup

In [6]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T

%uv pip install nflows
from nflows.distributions.normal import StandardNormal
from nflows.flows import Flow
from nflows.nn.nets import MLP
from nflows.transforms import AffineCouplingTransform, CompositeTransform, ReversePermutation

device = "cuda" if torch.cuda.is_available() else "cpu"
device

/Users/nadzuken/probly/.venv/bin/python: No module named uv
Note: you may need to restart the kernel to use updated packages.


'cpu'

## Data Preparation

Posterior Networks require:
- an in-distribution (ID) dataset used for training and standard evaluation
- an out-of-distribution dataset used only for testing epistemic uncertainty

Here, we use **MNIST** as the ID dataset and **FashionMNIST** as the OOD dataset

In [7]:
transform = T.Compose([T.ToTensor()])

# In-distribution data

train_data = torchvision.datasets.MNIST(
    root="~/datasets",
    train=True,
    download=True,
    transform=transform,
)

test_data = torchvision.datasets.MNIST(
    root="~/datasets",
    train=False,
    download=True,
    transform=transform,
)

train_loader    = DataLoader(train_data, batch_size=256, shuffle=True)
test_loader     = DataLoader(test_data, batch_size=256, shuffle=False)

print ("MNIST loaded (ID).")

# Out-of-distribution data

ood_data = torchvision.datasets.FashionMNIST(
    root="~/datasets",
    train=False,
    download=True,
    transform=transform,
)

ood_loader = DataLoader(ood_data, batch_size=256, shuffle=False)


MNIST loaded (ID).


## Model Definition

Posterior Networks are composed of:
1. **Encoder**: maps each image to a low-dimensional latent vector.
2. **Class-conditional normalizing flows**: one flow per class, modeling the density P(z|c) in latent space. These densities provide the evidence used to construct the Dirichlet distribution.

In [8]:
# Encoder: maps images (x) -> latent (z)

class Encoder(nn.Module):
    def __init__(self, latent_dim=2) -> None:  # noqa: ANN001
        """Initializes an instance of the Encoder class."""
        super().__init__()

        self.net = nn.Sequential(
            nn.Flatten(), # turns a 28x28 image into a vector of size 784
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim), # final output = latent vector z
        )

    def forward(self, x) -> None:  # noqa: ANN001, D102
        return self.net(x)

# Normalizing Flow for P(z | c)

class ContextIgnoreNet(nn.Module):
    def __init__(self, in_features, out_features) -> None:  # noqa: ANN001, D107
        super().__init__()
        self.net = MLP(
            in_shape=(in_features,),
            out_shape=(out_features,),
            hidden_sizes=[32, 32],
        )

    def forward(self, x, context=None) -> None:  # noqa: ANN001, ARG002, D102
        return self.net(x)



def make_flow(latent_dim) -> None:  # noqa: ANN001

    # Required function that returns a transform network
    def transform_net_create_fn(in_features, out_features) -> None:  # noqa: ANN001
        return ContextIgnoreNet(in_features, out_features)

    base_dist = StandardNormal([latent_dim])

    transform = CompositeTransform([
        AffineCouplingTransform(
            mask=[0, 1],
            transform_net_create_fn=transform_net_create_fn,
        ),
        ReversePermutation(features=latent_dim),
        AffineCouplingTransform(
            mask=[1, 0],
            transform_net_create_fn=transform_net_create_fn,
        ),
    ])

    return Flow(transform, base_dist)

latent_dim = 2
encoder = Encoder(latent_dim).to(device)
flows = nn.ModuleList([make_flow(latent_dim).to(device) for _ in range(10)])

## Training Function

In [9]:
def train_postnet(encoder, flows, train_loader, epochs=5, lr=1e-3, device="cuda") -> None:  # noqa: ANN001
    encoder.train()
    flows.train()

    # Combine encoder & flows paramters so one optimizer updates all of them
    params = list(encoder.parameters()) + list(flows.parameters())
    optimizer = optim.Adam(params, lr=lr)

    class_counts = torch.zeros(10).to(device)
    for _, y in train_loader:
        for c in range(10):
            class_counts[c] += (y == c).sum()

    for epoch in range(epochs):
        total_loss = 0.0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)  # noqa: PLW2901

            # Encode image -> latent vector
            z = encoder(x)

            # Compute density P(z|c) for all classes
            densities = torch.stack(
                [flow.log_prob(z).exp() for flow in flows], dim=1,
            )

            # Compute pseudo-counts beta
            beta = densities * class_counts

            # Dirichlet parameters alpha
            alpha = beta + 1.0
            alpha0 = alpha.sum(dim=1, keepdim=True)

            # Expected cross-entropy term
            digamma = torch.digamma
            expected_ce = digamma(alpha0) - digamma(alpha[range(len(y)), y])

            # Entropy of Dirichlet
            entropy = -(alpha * (digamma(alpha) - digamma(alpha0))).sum(dim=1)

            # Total loss
            loss = (expected_ce - entropy).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs} - Loss: {total_loss/len(train_loader):.4f}")

## Training Loop

In [10]:
epochs = 5

for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    train_postnet(
        encoder, flows, train_loader,
        epochs=1,
        lr=1e-3,
        device=device,
    )


Epoch 1/5
Epoch 1/1 - Loss: -20575.7173

Epoch 2/5
Epoch 1/1 - Loss: -22023.9898

Epoch 3/5
Epoch 1/1 - Loss: -22041.8797

Epoch 4/5
Epoch 1/1 - Loss: -22042.5871

Epoch 5/5
Epoch 1/1 - Loss: -22042.8257
