# Day 4: Keeping Neural Networks Simple (MDL)

Hinton & van Camp (1993): "Keeping Neural Networks Simple by Minimizing the Description Length of the Weights"

This paper showed that representing weights as **Gaussian distributions** and minimizing their description length produces the same objective as variational Bayesian inference. The result: networks that find flat, robust minima instead of sharp, brittle ones.

## What this notebook covers

1. **The gappy sine wave** — training on data with a missing region
2. **Training a Bayesian network** — the loss is error + KL divergence (complexity)
3. **Uncertainty visualization** — running multiple forward passes to see where the model is confident vs. uncertain
4. **Weight distributions** — seeing which weights matter (low sigma) and which can be pruned (high sigma)

## The core idea

Standard weights are point estimates (w = 5.123). MDL weights are distributions (w ~ N(5.1, 0.2)). If the network works fine despite weight noise, you didn't need the precision — and the saved precision is literally saved bits of description length.
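To make the idea of a noisy weight concrete, here is a minimal sketch that draws samples of a single weight and pushes them through a toy one-weight model. It assumes the 0.2 in N(5.1, 0.2) is a standard deviation; `sample_weight` and the toy model `y = w * x` are illustrative names, not part of the notebook's `implementation` module.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_weight(mu, sigma, size, rng):
    """Draw weights as w = mu + sigma * eps, with eps ~ N(0, 1)."""
    eps = rng.standard_normal(size)
    return mu + sigma * eps

# Assumption: 0.2 is treated as a standard deviation, not a variance.
mu, sigma = 5.1, 0.2
w_samples = sample_weight(mu, sigma, size=10_000, rng=rng)

# Toy one-weight "network": y = w * x. The output spread scales with sigma.
x = 0.5
y_samples = w_samples * x

print(f"mean w = {w_samples.mean():.3f}, std w = {w_samples.std():.3f}")
print(f"std of output y = {y_samples.std():.3f}")
```

If the spread in `y_samples` is small enough not to hurt the fit, the weight did not need to be specified more precisely than `sigma`.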

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import sys

# Add the current working directory to the import path
sys.path.insert(0, str(Path.cwd()))

from implementation import MDLNetwork
from visualization import (plot_loss_dynamics, plot_uncertainty_envelope,
                           plot_weight_distributions, plot_snr_analysis,
                           analyze_compression_stats)

print("All imports successful.")

## 1. The Data: A Gappy Sine Wave

To see why Bayesian weights are useful, we need a dataset that exposes the difference between "confident" and "honest" predictions.

We generate a sine wave but **delete the middle part** (|x| < 1). A standard NN will confidently predict through the gap as if it has data there. A Bayesian NN should show high uncertainty in the gap — because it genuinely hasn't seen data there.

In [None]:
# Generate noisy sine wave with a GAP
def generate_gappy_data(n=100):
    # Left side (-3 to -1)
    X1 = np.random.uniform(-3, -1, n//2)
    # Right side (1 to 3)
    X2 = np.random.uniform(1, 3, n//2)
    # Combine
    X = np.concatenate([X1, X2])
    # Add noise (sized to X, so an odd n does not cause a shape mismatch)
    y = np.sin(X) + np.random.normal(0, 0.1, X.shape)
    return X.reshape(-1, 1), y.reshape(-1, 1)

X_train, y_train = generate_gappy_data(100)

# Visualization range (including the gap)
X_test = np.linspace(-4, 4, 200).reshape(-1, 1)

plt.figure(figsize=(10, 5))
plt.scatter(X_train, y_train, c='red', label='Training Data')
plt.axvspan(-1, 1, color='gray', alpha=0.2, label='The GAP (No Data)')
plt.plot(X_test, np.sin(X_test), 'k--', alpha=0.5, label='True Sine Wave')
plt.title("The Challenge: Predict what happens in the Gap")
plt.legend()
plt.show()

## 2. Training the Bayesian Network

The MDL network has **two** loss terms (this is the paper's core equation):

1. **Error Cost (NLL):** How well do the predictions fit the data?
2. **Complexity Cost (KL):** How many bits to describe the weights?

Total loss = Error + beta * KL

The `kl_weight` (beta) controls the balance. Set it too low and the network overfits and reports almost no uncertainty. Set it too high and the network underfits, staying close to the prior and ignoring the data. The right value gives honest uncertainty where data is missing.
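The complexity term can be sketched with the closed-form KL divergence between two Gaussians. This is an illustration under stated assumptions, not the code behind `MDLNetwork.total_kl()`: it assumes a zero-mean Gaussian prior with standard deviation `prior_sigma`, and `gaussian_kl` and `total_loss` are hypothetical helper names.

```python
import numpy as np

def gaussian_kl(mu, sigma, prior_sigma=1.0):
    """KL( N(mu, sigma^2) || N(0, prior_sigma^2) ) in nats, summed over weights.

    The zero-mean Gaussian prior is an assumption made for this sketch.
    """
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    per_weight = (np.log(prior_sigma / sigma)
                  + (sigma**2 + mu**2) / (2.0 * prior_sigma**2)
                  - 0.5)
    return float(per_weight.sum())

def total_loss(error, kl, beta):
    """Error cost plus beta-weighted complexity cost."""
    return error + beta * kl

# A posterior identical to the prior costs nothing to describe.
kl_free = gaussian_kl(mu=[0.0], sigma=[1.0])
# A precise weight far from the prior mean is expensive.
kl_costly = gaussian_kl(mu=[2.0], sigma=[0.1])

print(f"KL (posterior == prior): {kl_free:.4f} nats")
print(f"KL (precise weight):     {kl_costly:.4f} nats")
print(f"Total loss example:      {total_loss(0.05, kl_costly, beta=0.1):.4f}")
```

Under these assumptions, increasing `beta` makes precise, far-from-prior weights more expensive relative to the error term, which is the "simplicity pressure" described above.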

In [None]:
# Initialize Network
net = MDLNetwork(input_size=1, hidden_size=20, output_size=1)

# Hyperparameters
epochs = 2000
lr = 0.01
kl_weight = 0.1  # The "Simplicity Pressure"

# Storage for plotting
history = {'total': [], 'nll': [], 'kl': []}

print("Training Bayesian Network...")
print("=" * 50)

for epoch in range(epochs):
    # 1. Forward Pass (This samples random weights!)
    # Every time we call this, the network is slightly different.
    preds = net.forward(X_train)
    
    # 2. Data Loss (MSE; matches Gaussian NLL with fixed noise variance, up to scale and a constant)
    nll = np.mean((preds - y_train)**2)
    d_nll = 2 * (preds - y_train) / len(X_train)
    
    # 3. Complexity Loss (KL Divergence)
    kl = net.total_kl() / len(X_train)
    
    # 4. Total Loss
    loss = nll + kl_weight * kl
    
    # Store history
    history['total'].append(loss)
    history['nll'].append(nll)
    history['kl'].append(kl)
    
    # 5. Backward Pass
    net.backward(d_nll)
    
    # 6. Update Weights
    net.update_weights(lr, kl_weight)
    
    if epoch % 200 == 0:
        print(f"Epoch {epoch:4d} | Total: {loss:.4f} | Error: {nll:.4f} | Complexity: {kl:.4f}")

print("\nTraining complete.")

## 3. Loss Dynamics: Complexity vs. Error

Watch how the two loss terms interact during training.
The **Error** (NLL) drops quickly as the network fits the data. The **Complexity** (KL) may rise initially as weights move from the prior, then stabilizes as the network finds a compact solution.

In [None]:
plot_loss_dynamics(history)

## 4. The Uncertainty Envelope

This is the signature visualization for Bayesian neural networks.

We run the network **100 times** on the test data. Because weights are distributions (not fixed numbers), each forward pass samples different weights and produces a slightly different prediction.

* **Where we have training data:** predictions cluster tightly (low variance).
* **In the gap:** predictions spread out (high variance) — the network honestly reflects that it has no data here.

This spread is the model's epistemic uncertainty, and it comes directly from the Gaussian weight distributions in Hinton & van Camp's formulation.
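The repeated-forward-pass procedure can be sketched as a small Monte Carlo helper. This is not the code inside `plot_uncertainty_envelope`; `mc_predict` and `toy_forward` are hypothetical names, and the toy model is a stand-in with one noisy weight and one noisy bias so the block runs on its own.

```python
import numpy as np

def mc_predict(forward_fn, X, n_samples=100):
    """Stack n_samples stochastic forward passes; return per-input mean and std."""
    preds = np.stack([forward_fn(X) for _ in range(n_samples)], axis=0)
    return preds.mean(axis=0), preds.std(axis=0)

rng = np.random.default_rng(0)

def toy_forward(X):
    # Hypothetical stand-in for a stochastic forward pass:
    # fresh weight and bias samples are drawn on every call.
    w = 1.0 + 0.3 * rng.standard_normal()
    b = 0.0 + 0.05 * rng.standard_normal()
    return w * X + b

X_demo = np.linspace(-4, 4, 9).reshape(-1, 1)
mean, std = mc_predict(toy_forward, X_demo, n_samples=500)

for x_val, m, s in zip(X_demo[:, 0], mean[:, 0], std[:, 0]):
    print(f"x = {x_val:+.1f} | mean = {m:+.3f} | std = {s:.3f}")
```

Assuming `net.forward` resamples weights on each call, as the training loop comments state, the same helper could be applied as `mc_predict(net.forward, X_test)`, with the envelope drawn as the mean plus or minus a multiple of the standard deviation.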

In [None]:
plot_uncertainty_envelope(net, X_train, y_train, X_test, n_samples=100)

## 5. Weight Distributions

The paper's title is about "Minimizing the Description Length of the Weights." Here's what that looks like in practice.

Each weight has a learned sigma (standard deviation):
* **Small sigma:** the network needs this weight to be precise — it carries information about the data.
* **Large sigma:** the network tolerates large noise on this weight, so it carries almost zero bits of information.

Weights with large sigma relative to their mean (low signal-to-noise ratio) are effectively "free" under the bits-back coding scheme. They've been compressed away.

In [None]:
plot_weight_distributions(net, 'layer1')
plot_snr_analysis(net)

## 6. Compression Statistics

We can quantify compression by looking at Signal-to-Noise Ratio (SNR = |mu|/sigma).
Weights with SNR below a threshold (e.g., 0.5) carry negligible information — they could be pruned to zero without hurting the model.
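The SNR rule can be sketched in a few lines. This is an illustration, not the body of `analyze_compression_stats`; `snr_prune_mask` is a hypothetical helper and the `mu` and `sigma` arrays are made-up values chosen to show both outcomes.

```python
import numpy as np

def snr_prune_mask(mu, sigma, threshold=0.5):
    """Return per-weight SNR = |mu| / sigma and a boolean mask of weights to keep."""
    snr = np.abs(mu) / sigma
    keep = snr >= threshold
    return snr, keep

# Made-up example values: two precise weights, two noisy ones.
mu = np.array([2.0, -1.5, 0.05, 0.3])
sigma = np.array([0.1, 0.5, 1.0, 0.9])

snr, keep = snr_prune_mask(mu, sigma, threshold=0.5)

# Pruning sets the mean of low-SNR weights to zero.
pruned_mu = np.where(keep, mu, 0.0)

print("SNR:        ", np.round(snr, 3))
print("Keep mask:  ", keep)
print("Pruned mean:", pruned_mu)
print(f"Fraction pruned: {1.0 - keep.mean():.2f}")
```

The threshold of 0.5 mirrors the value passed to `analyze_compression_stats`; a different threshold trades compression against accuracy and would need to be checked against the model's error.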

In [None]:
analyze_compression_stats(net, threshold_snr=0.5)

## 7. Key Takeaways

**1. Generalization through compression (paper's core argument):**
The KL penalty forces the network to use as few bits as possible to describe its weights. Weights that don't help explain the data get pushed back toward the prior — effectively pruned. What remains is the simplest model that fits the training data, which is exactly the MDL principle.

**2. Honest uncertainty (direct consequence of the formulation):**
A standard network would draw a confident (and wrong) line through the gap in our data. The MDL network's weight distributions produce spread-out predictions where data is missing — it distinguishes "I've seen this" from "I'm guessing."

**3. Noise during training is the mechanism, not a trick:**
Sampling weights from distributions during training is how the bits-back argument works. The noise in the weights is what allows the coding cost to be reduced (Hinton & van Camp, 1993, Section 2).

---

**Next:** Pointer Networks (Vinyals et al., 2015).