# 🔶 K-FAC from Scratch — Layerwise Kronecker Factors, Implementation & Interactive Comparison

**What this notebook does**

- Implements a small fully connected neural network (one hidden layer) in NumPy.
- Implements three optimizers:
  - SGD (vanilla)
  - Adam (NumPy)
  - **K-FAC**: **Kronecker-factored Approximate Curvature** (layerwise), with damping and periodic factor updates.
- Trains on a synthetic 2-class dataset and visualizes:
  - Loss and accuracy vs iterations
  - Step norms per iteration (preconditioned for K-FAC)
  - Parameter trajectories projected into a 2D PCA subspace (visual diagnostic)
- Adds interactive controls (per-optimizer learning rates, damping, K-FAC inversion frequency, epochs, batch size).

**Pedagogical notes**

- The K-FAC implementation here is simplified and focused on clarity: it demonstrates how to collect layerwise Kronecker factors (covariances of activations and gradients), form damped inverses, and precondition the parameter gradients using the identity
$$
\mathrm{vec}(\Delta W) \approx (G^{-1} \otimes A^{-1}) \, \mathrm{vec}(\nabla_W)
\quad\Longleftrightarrow\quad
\Delta W \approx A^{-1} \, \nabla_W \, G^{-1}
$$
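for a dense layer with weight matrix $W$ (inputs × outputs), input-activation second-moment matrix $A = \mathbb{E}[a a^\top]$, and pre-activation-gradient second-moment matrix $G = \mathbb{E}[g g^\top]$, where $\mathrm{vec}$ stacks columns. See Martens & Grosse (2015).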
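
The identity above can be checked numerically. The following is a minimal sketch, assuming random symmetric positive-definite stand-ins for $A$ and $G$ (the helper `random_spd` and the small layer sizes are illustrative, not part of the notebook):

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out = 4, 3  # illustrative layer sizes

def random_spd(n):
    # hypothetical helper: a random symmetric positive-definite matrix
    M = rng.standard_normal((n, n))
    return M @ M.T + n * np.eye(n)

A_inv = np.linalg.inv(random_spd(d_in))    # stands in for the damped inverse of A
G_inv = np.linalg.inv(random_spd(d_out))   # stands in for the damped inverse of G
grad_W = rng.standard_normal((d_in, d_out))

def vec(M):
    # column-stacking vec, matching the convention of the identity
    return M.reshape(-1, order="F")

# Left side: explicit Kronecker product, a (d_in*d_out) x (d_in*d_out) matrix
lhs = np.kron(G_inv, A_inv) @ vec(grad_W)
# Right side: two small matrix products, no large matrix is ever formed
rhs = vec(A_inv @ grad_W @ G_inv)

identity_holds = np.allclose(lhs, rhs)
print("identity holds:", identity_holds)
```

The right-hand form is what makes K-FAC affordable: it needs only a `d_in × d_in` and a `d_out × d_out` inverse instead of one inverse of size `d_in·d_out`.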

**References**
- Martens, J., & Grosse, R. (2015). *Optimizing Neural Networks with Kronecker-Factored Approximate Curvature (K-FAC)*. 

### Step-1: Imports and Environment Setup

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output, Markdown
try:
    import ipywidgets as widgets
    from ipywidgets import interact, FloatSlider, IntSlider, Dropdown, Checkbox
except ImportError:
    # if ipywidgets not installed, the notebook will still run static examples
    widgets = None
    interact = None

np.random.seed(42)

### Step-2: Setting up a Synthetic Dataset

In [2]:
def make_two_moons(n_samples=200, noise=0.15):
    # a compact two-moons generator (numpy)
    n = n_samples // 2
    theta = np.linspace(0, np.pi, n)
    x1 = np.stack([np.cos(theta), np.sin(theta)], axis=1) + noise * np.random.randn(n,2)
    x2 = np.stack([1 - np.cos(theta), -np.sin(theta) - 0.5], axis=1) + noise * np.random.randn(n,2)
    X = np.vstack([x1, x2])
    y = np.hstack([np.zeros(n, dtype=int), np.ones(n, dtype=int)])
    # shuffle
    idx = np.random.permutation(len(y))
    return X[idx], y[idx]

X, y = make_two_moons(n_samples=400, noise=0.12)
# simple one-hot labels
Y = np.eye(2)[y]

# train / val split
split = int(0.8 * X.shape[0])
X_train, Y_train = X[:split], Y[:split]
X_val, Y_val = X[split:], Y[split:]

### Step-3: Defining a Small MLP

In [3]:
def relu(x):
    return np.maximum(0, x)

def relu_grad(x):
    return (x > 0).astype(float)

def softmax(logits):
    ex = np.exp(logits - logits.max(axis=1, keepdims=True))
    return ex / ex.sum(axis=1, keepdims=True)

def cross_entropy_loss(probs, targets):
    # targets: one-hot
    eps = 1e-12
    return -np.mean(np.sum(targets * np.log(probs + eps), axis=1))

class SmallMLP:
    def __init__(self, d_in=2, d_hidden=32, d_out=2, scale=0.1):
        self.params = {}
        self.params['W1'] = scale * np.random.randn(d_in, d_hidden)
        self.params['b1'] = np.zeros(d_hidden)
        self.params['W2'] = scale * np.random.randn(d_hidden, d_out)
        self.params['b2'] = np.zeros(d_out)

    def forward(self, X):
        z1 = X.dot(self.params['W1']) + self.params['b1']
        a1 = relu(z1)
        z2 = a1.dot(self.params['W2']) + self.params['b2']
        probs = softmax(z2)
        cache = {'X': X, 'z1': z1, 'a1': a1, 'z2': z2, 'probs': probs}
        return probs, cache

    def loss_and_grads(self, X, Y):
        probs, cache = self.forward(X)
        loss = cross_entropy_loss(probs, Y)
        # grads by backprop (batch)
        N = X.shape[0]
        dz2 = (probs - Y) / N  # dL/dz2
        dW2 = cache['a1'].T.dot(dz2)
        db2 = dz2.sum(axis=0)
        da1 = dz2.dot(self.params['W2'].T)
        dz1 = da1 * relu_grad(cache['z1'])
        dW1 = X.T.dot(dz1)
        db1 = dz1.sum(axis=0)
        grads = {'W1': dW1, 'b1': db1, 'W2': dW2, 'b2': db2}
        return loss, grads, cache
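
A quick way to gain confidence in hand-written backprop like the above is a finite-difference check. The sketch below is self-contained: it re-implements the same forward pass and gradient formulas as standalone functions (`forward`, `loss_fn`, `analytic_grads`, and `numeric_grad` are illustrative names, not part of the notebook) and compares them against central differences on a tiny random problem.

```python
import numpy as np

def forward(params, X):
    z1 = X @ params["W1"] + params["b1"]
    a1 = np.maximum(0, z1)
    z2 = a1 @ params["W2"] + params["b2"]
    ex = np.exp(z2 - z2.max(axis=1, keepdims=True))
    return z1, a1, ex / ex.sum(axis=1, keepdims=True)

def loss_fn(params, X, Y):
    probs = forward(params, X)[2]
    return -np.mean(np.sum(Y * np.log(probs + 1e-12), axis=1))

def analytic_grads(params, X, Y):
    # same formulas as loss_and_grads: softmax + cross-entropy, then ReLU
    z1, a1, probs = forward(params, X)
    dz2 = (probs - Y) / X.shape[0]
    dz1 = (dz2 @ params["W2"].T) * (z1 > 0)
    return {"W1": X.T @ dz1, "b1": dz1.sum(axis=0),
            "W2": a1.T @ dz2, "b2": dz2.sum(axis=0)}

def numeric_grad(params, key, X, Y, h=1e-5):
    # central differences, one parameter entry at a time
    g = np.zeros_like(params[key])
    for i in np.ndindex(g.shape):
        old = params[key][i]
        params[key][i] = old + h
        loss_plus = loss_fn(params, X, Y)
        params[key][i] = old - h
        loss_minus = loss_fn(params, X, Y)
        params[key][i] = old
        g[i] = (loss_plus - loss_minus) / (2 * h)
    return g

rng = np.random.default_rng(0)
params = {"W1": 0.5 * rng.standard_normal((2, 5)), "b1": 0.1 * rng.standard_normal(5),
          "W2": 0.5 * rng.standard_normal((5, 2)), "b2": 0.1 * rng.standard_normal(2)}
X = rng.standard_normal((8, 2))
Y = np.eye(2)[rng.integers(0, 2, size=8)]

grads = analytic_grads(params, X, Y)
max_err = max(np.abs(grads[k] - numeric_grad(params, k, X, Y)).max() for k in params)
print("max abs difference:", max_err)
```

The check can fail spuriously if a pre-activation lands almost exactly on the ReLU kink, so a single large discrepancy at one entry is worth re-testing with a different seed before suspecting the formulas.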

### Step-4: Preparing Baseline Optimizers for Comparison (SGD, Adam)

In [4]:
class SGDOptimizer:
    def __init__(self, params, lr=0.1, weight_decay=0.0):
        self.params = params
        self.lr = lr
        self.wd = weight_decay

    def step(self, grads):
        for k, g in grads.items():
            self.params[k] -= self.lr * (g + self.wd * self.params[k])

class AdamOptimizer:
    def __init__(self, params, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0):
        self.params = params
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.wd = weight_decay
        self.m = {k: np.zeros_like(v) for k, v in params.items()}
        self.v = {k: np.zeros_like(v) for k, v in params.items()}
        self.t = 0

    def step(self, grads):
        self.t += 1
        for k, g in grads.items():
            self.m[k] = self.beta1 * self.m[k] + (1 - self.beta1) * g
            self.v[k] = self.beta2 * self.v[k] + (1 - self.beta2) * (g * g)
            m_hat = self.m[k] / (1 - self.beta1**self.t)
            v_hat = self.v[k] / (1 - self.beta2**self.t)
            self.params[k] -= self.lr * (m_hat / (np.sqrt(v_hat) + self.eps) + self.wd * self.params[k])

### Step-5: Defining Helper Functions for K-FAC

In [5]:
def compute_covariance(mat, eps=1e-6):
    # mat: (batch, dim), one sample per row; returns the uncentered second moment E[x x^T]
    # K-FAC uses uncentered moments; no bias column is appended here, so biases are handled separately
    cov = (mat.T @ mat) / mat.shape[0]
    # numerical stability: add tiny eps to diagonal
    cov += eps * np.eye(cov.shape[0])
    return cov

def damp_and_invert(mat, damping):
    # mat: symmetric positive (approx). returns inverse of (mat + damping * I)
    m = mat + damping * np.eye(mat.shape[0])
    try:
        inv = np.linalg.inv(m)
    except np.linalg.LinAlgError:
        inv = np.linalg.pinv(m)
    return inv
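
To see why the damping term matters, consider a singular factor. For a softmax with cross-entropy, each row of `probs - Y` sums to zero, so with two classes the output factor has rank one. The sketch below uses synthetic data with that structure (all names are illustrative) and shows that the largest eigenvalue of the damped inverse is then set entirely by `1 / damping`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic per-sample output gradients whose rows sum to zero, as for a 2-class softmax
g = rng.standard_normal((64, 1)) * np.array([[1.0, -1.0]])   # shape (64, 2)
G = g.T @ g / g.shape[0]                                      # rank-1 second moment

largest_eigs = {}
for damping in (1e-1, 1e-2, 1e-3):
    inv = np.linalg.inv(G + damping * np.eye(2))
    # eigvalsh is appropriate because the damped inverse is symmetric
    largest_eigs[damping] = np.linalg.eigvalsh(inv).max()
    print(f"damping={damping:g}  largest eigenvalue of inverse={largest_eigs[damping]:.1f}")
```

Any gradient component along the null direction of `G` is therefore amplified by roughly `1 / damping`, which is one reason the damping value interacts so strongly with the learning rate.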

### Step-6: Implementing K-FAC Optimizer (Layerwise)

In [6]:
class KFACOptimizer:
    def __init__(self, model, lr=0.1, damping=1e-3, kl_clip=0.001, factor_ema=0.95, invert_every=10, weight_decay=0.0):
        self.model = model
        self.lr = lr
        self.damping = damping
        self.kl_clip = kl_clip
        self.factor_ema = factor_ema
        self.invert_every = invert_every
        self.weight_decay = weight_decay
        # storage for factors and inverses
        self.A = {}  # activation covariances
        self.G = {}  # gradient covariances
        self.A_inv = {}
        self.G_inv = {}
        self.steps = 0

    def _extract_activations_and_errors(self, cache, grads):
        # For our 2-layer MLP:
        # layer1: W1 has shape (d_in, d_hidden); its input activations are X (batch, d_in)
        # layer2: W2 has shape (d_hidden, d_out); its input activations are a1 (batch, d_hidden)
        activations = {'W1': cache['X'], 'W2': cache['a1']}
        # `errors` holds the batch-averaged weight gradients dW = A^T @ dz / N.
        # They are NOT used to build G: K-FAC needs second moments of the per-sample
        # pre-activation gradients dz, which step() reads from cache['dz1'] and cache['dz2'].
        errors = {'W1': grads['W1'], 'W2': grads['W2']}
        return activations, errors

    def step(self, cache, grads, batch_size):
        # Update the running estimates of A and G.
        # A_l = E[a_l a_l^T] is built from the layer inputs. G_l = E[dz_l dz_l^T] needs the
        # per-sample pre-activation gradients, which the training loop must place in
        # cache['dz1'] and cache['dz2'] (see forward_with_per_sample_dz).
        activations, errors = self._extract_activations_and_errors(cache, grads)
        assert 'dz1' in cache and 'dz2' in cache, "cache must contain per-sample dz1 and dz2 for K-FAC"
        A_W1 = compute_covariance(activations['W1'])
        A_W2 = compute_covariance(activations['W2'])
        G_W2 = compute_covariance(cache['dz2'])  # covariance of output pre-activation gradients (d_out x d_out)
        G_W1 = compute_covariance(cache['dz1'])  # covariance of hidden pre-activation gradients (d_hidden x d_hidden)

        # EMA update
        if 'W1' not in self.A:
            self.A['W1'] = A_W1
            self.A['W2'] = A_W2
            self.G['W1'] = G_W1
            self.G['W2'] = G_W2
        else:
            self.A['W1'] = self.factor_ema * self.A['W1'] + (1 - self.factor_ema) * A_W1
            self.A['W2'] = self.factor_ema * self.A['W2'] + (1 - self.factor_ema) * A_W2
            self.G['W1'] = self.factor_ema * self.G['W1'] + (1 - self.factor_ema) * G_W1
            self.G['W2'] = self.factor_ema * self.G['W2'] + (1 - self.factor_ema) * G_W2

        self.steps += 1

        # invert factors periodically
        if self.steps % self.invert_every == 0 or not self.A_inv:
            self.A_inv['W1'] = damp_and_invert(self.A['W1'], self.damping)
            self.A_inv['W2'] = damp_and_invert(self.A['W2'], self.damping)
            self.G_inv['W1'] = damp_and_invert(self.G['W1'], self.damping)
            self.G_inv['W2'] = damp_and_invert(self.G['W2'], self.damping)

        # Precondition the weight gradients per layer.
        precond = {}
        # Weights are stored as (inputs, outputs), so the update is A_inv @ grad @ G_inv.
        precond['W2'] = self.A_inv['W2'].dot(grads['W2']).dot(self.G_inv['W2'])
        precond['W1'] = self.A_inv['W1'].dot(grads['W1']).dot(self.G_inv['W1'])
        # Biases are left unpreconditioned (plain gradient step) for simplicity.
        # Copy so the in-place weight decay below does not modify the caller's grads.
        precond['b2'] = grads['b2'].copy()
        precond['b1'] = grads['b1'].copy()

        # optionally apply weight decay
        if self.weight_decay:
            precond['W2'] += self.weight_decay * self.model.params['W2']
            precond['W1'] += self.weight_decay * self.model.params['W1']
            precond['b2'] += self.weight_decay * self.model.params['b2']
            precond['b1'] += self.weight_decay * self.model.params['b1']

        # Norm of the preconditioned weight step (before scaling by lr), returned for diagnostics.
        step_norm = np.sqrt(np.sum(precond['W1']**2) + np.sum(precond['W2']**2))
        # NOTE: self.kl_clip is stored but not applied; the step uses the raw learning rate.
        clipped_lr = self.lr
        # apply parameter update
        self.model.params['W1'] -= clipped_lr * precond['W1']
        self.model.params['W2'] -= clipped_lr * precond['W2']
        self.model.params['b1'] -= clipped_lr * precond['b1']
        self.model.params['b2'] -= clipped_lr * precond['b2']

        return precond, step_norm
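
The optimizer above stores `kl_clip` but does not use it. One heuristic found in open-source K-FAC implementations rescales the whole step so that a quadratic proxy for the predicted KL change, `lr**2 * sum(<precond, grad>)`, stays below `kl_clip`. The sketch below illustrates that idea under those assumptions; the function name `kl_clip_scale` is hypothetical and this is not the notebook's method.

```python
import numpy as np

def kl_clip_scale(precond, grads, lr, kl_clip=1e-3):
    """Return a factor nu in (0, 1] by which to scale the preconditioned step."""
    # Quadratic proxy for the KL change of the update -lr * precond
    vg_sum = lr**2 * sum(np.sum(precond[k] * grads[k]) for k in precond)
    if vg_sum <= 0:
        # Degenerate case (zero or non-descent direction): leave the step unscaled
        return 1.0
    return min(1.0, float(np.sqrt(kl_clip / vg_sum)))

# Toy example: a preconditioned step that is 10x the raw gradient
grads = {"W": np.array([[3.0, 4.0]])}
precond = {"W": 10.0 * grads["W"]}
nu = kl_clip_scale(precond, grads, lr=0.5, kl_clip=1e-3)
print("scale factor:", nu)

# A tiny step is left untouched
nu_small = kl_clip_scale({"W": 1e-6 * grads["W"]}, grads, lr=0.5, kl_clip=1e-3)
print("scale factor for a tiny step:", nu_small)
```

In `KFACOptimizer.step` this factor would multiply `self.lr` before the parameter update; whether it helps on this dataset is something to verify empirically.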

### Step-7: Functions for Training Loop

In [7]:
def forward_with_per_sample_dz(model, X, Y):
    # compute forward and per-sample dz values used by K-FAC
    # returns loss, grads (batch-averaged), cache (with per-sample dz1, dz2)
    # we will compute per-sample dzs by computing pre-activation outputs for each sample
    # and using vectorized formulas
    z1 = X.dot(model.params['W1']) + model.params['b1']
    a1 = relu(z1)
    z2 = a1.dot(model.params['W2']) + model.params['b2']
    probs = softmax(z2)
    loss = cross_entropy_loss(probs, Y)
    N = X.shape[0]
    dz2 = probs - Y                       # per-sample dL_i/dz2, shape (N, d_out); not divided by N
    dW2 = a1.T.dot(dz2) / N               # batch-averaged weight gradient
    db2 = dz2.sum(axis=0) / N
    da1 = dz2.dot(model.params['W2'].T)   # per-sample dL_i/da1
    dz1 = da1 * relu_grad(z1)             # per-sample dL_i/dz1, shape (N, d_hidden)
    dW1 = X.T.dot(dz1) / N                # batch-averaged weight gradient
    db1 = dz1.sum(axis=0) / N
    grads = {'W1': dW1, 'b1': db1, 'W2': dW2, 'b2': db2}
    cache = {'X': X, 'z1': z1, 'a1': a1, 'z2': z2, 'probs': probs, 'dz1': dz1, 'dz2': dz2}
    return loss, grads, cache
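
The function above builds `dz2` from the observed labels, which yields the so-called empirical Fisher. Martens & Grosse (2015) instead estimate the gradient factors using targets sampled from the model's own predictive distribution. The sketch below shows one way such sampled output gradients could be produced; `sampled_output_grads` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def sampled_output_grads(probs, rng):
    """Per-sample dz2 = probs - y_sampled, with y_sampled drawn from probs."""
    N, C = probs.shape
    cdf = probs.cumsum(axis=1)
    u = rng.random((N, 1))
    # Inverse-CDF sampling: count how many cdf entries lie strictly below u
    sampled = (u > cdf).sum(axis=1)
    # Guard against round-off leaving the last cdf entry slightly below 1
    sampled = np.minimum(sampled, C - 1)
    Y_sampled = np.eye(C)[sampled]
    return probs - Y_sampled

rng = np.random.default_rng(0)
probs = np.tile([0.7, 0.3], (20000, 1))
dz = sampled_output_grads(probs, rng)

# Under the model's own distribution the output gradient has zero mean
print("mean of sampled dz2:", dz.mean(axis=0))
```

Under this scheme the sampled `dz2` (and the `dz1` backpropagated from it) would feed only the factor estimates, while the parameter gradient would still be computed from the true labels.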

def accuracy(model, X, Y_true):
    probs, _ = model.forward(X)
    preds = probs.argmax(axis=1)
    return (preds == Y_true).mean()

### Step-8: Training (Single Run)

In [8]:
def train_model(optimizer_name='kfac', lr=0.1, damping=1e-3, invert_every=10, batch_size=64, n_epochs=10, print_every=50):
    model = SmallMLP(d_hidden=32)
    n_samples = X_train.shape[0]
    steps_per_epoch = max(1, n_samples // batch_size)
    total_steps = n_epochs * steps_per_epoch

    # initialize optimizer
    if optimizer_name == 'sgd':
        opt = SGDOptimizer(model.params, lr=lr)
    elif optimizer_name == 'adam':
        opt = AdamOptimizer(model.params, lr=lr)
    elif optimizer_name == 'kfac':
        kfac = KFACOptimizer(model, lr=lr, damping=damping, invert_every=invert_every)
        opt = kfac
    else:
        raise ValueError("unknown optimizer")

    losses = []
    vals = []
    step_norms = []
    # training loop
    step = 0
    for epoch in range(n_epochs):
        # simple shuffling
        idx = np.random.permutation(n_samples)
        for b in range(steps_per_epoch):
            batch_idx = idx[b*batch_size:(b+1)*batch_size]
            xb = X_train[batch_idx]
            yb = Y_train[batch_idx]

            # compute loss, grads, cache with per-sample dzs if K-FAC
            loss, grads, cache = forward_with_per_sample_dz(model, xb, yb)
            losses.append(loss)
            # step
            if optimizer_name in ('sgd', 'adam'):
                opt.step(grads)
                # proxy for the step size: norm of lr * grad over the weights
                # (exact for SGD without weight decay; only a rough proxy for Adam, whose update is rescaled)
                step_norms.append(np.sqrt(np.sum((lr*grads['W1'])**2) + np.sum((lr*grads['W2'])**2)))
            else:  # K-FAC
                precond, s_norm = opt.step(cache, grads, batch_size)
                step_norms.append(s_norm)

            # validation metric every so often
            if step % print_every == 0:
                val_acc = accuracy(model, X_val, y[split:])
                vals.append((step, val_acc))
            step += 1

    # final validation accuracy
    final_acc = accuracy(model, X_val, y[split:])
    return {'losses': np.array(losses), 'val_records': vals, 'step_norms': np.array(step_norms), 'final_acc': final_acc, 'model': model}

In [9]:
# quick smoke-run to ensure things work
res_sgd = train_model(optimizer_name='sgd', lr=0.2, n_epochs=5, batch_size=64, print_every=200)
res_adam = train_model(optimizer_name='adam', lr=0.01, n_epochs=5, batch_size=64, print_every=200)
res_kfac = train_model(optimizer_name='kfac', lr=0.5, damping=1e-2, invert_every=5, n_epochs=5, batch_size=64, print_every=200)
print("SGD final val acc:", res_sgd['final_acc'])
print("Adam final val acc:", res_adam['final_acc'])
print("K-FAC final val acc:", res_kfac['final_acc'])

SGD final val acc: 0.9875
Adam final val acc: 0.9875
K-FAC final val acc: 0.5875


### Step-9: Visualizations

In [10]:
def plot_results(results_dict, title_suffix=""):
    plt.figure(figsize=(14,4))
    plt.subplot(1,3,1)
    for name, res in results_dict.items():
        plt.plot(res['losses'], label=name)
    plt.yscale('log')
    plt.title("Training loss (log scale) " + title_suffix)
    plt.xlabel("step"); plt.ylabel("loss"); plt.legend(); plt.grid(ls='--', lw=0.3)

    plt.subplot(1,3,2)
    for name, res in results_dict.items():
        sn = res['step_norms']
        plt.plot(sn, label=name)
    plt.title("Step norms per step")
    plt.xlabel("step"); plt.ylabel("step norm"); plt.legend(); plt.grid(ls='--', lw=0.3)

    plt.subplot(1,3,3)
    for name, res in results_dict.items():
        val_rec = res['val_records']
        if val_rec:
            xs = [v[0] for v in val_rec]
            ys = [v[1] for v in val_rec]
            plt.plot(xs, ys, marker='o', label=name)
    plt.title("Validation accuracy checkpoints")
    plt.xlabel("step"); plt.ylabel("val acc"); plt.legend(); plt.grid(ls='--', lw=0.3)

    plt.tight_layout()
    plt.show()

In [None]:
if widgets is None:
    display(Markdown("**ipywidgets not installed — run a static experiment instead.**"))
else:
    def interactive_compare(lr_sgd=0.2, lr_adam=0.01, lr_kfac=0.5, damping=1e-2, invert_every=5, n_epochs=8, batch_size=64):
        clear_output(wait=True)
        display(Markdown(f"### Running experiments: n_epochs={n_epochs}, batch_size={batch_size}"))
        # run three experiments sequentially (small models are cheap)
        res_sgd = train_model(optimizer_name='sgd', lr=lr_sgd, n_epochs=n_epochs, batch_size=batch_size, print_every=200)
        display(Markdown(f"SGD final val acc: **{res_sgd['final_acc']:.3f}**"))
        res_adam = train_model(optimizer_name='adam', lr=lr_adam, n_epochs=n_epochs, batch_size=batch_size, print_every=200)
        display(Markdown(f"Adam final val acc: **{res_adam['final_acc']:.3f}**"))
        res_kfac = train_model(optimizer_name='kfac', lr=lr_kfac, damping=damping, invert_every=invert_every, n_epochs=n_epochs, batch_size=batch_size, print_every=200)
        display(Markdown(f"K-FAC final val acc: **{res_kfac['final_acc']:.3f}**"))
        results = {'SGD': res_sgd, 'Adam': res_adam, 'K-FAC': res_kfac}
        plot_results(results, title_suffix=f"(epochs={n_epochs})")
    interact(
        interactive_compare,
        lr_sgd=FloatSlider(value=0.2, min=0.01, max=1.0, step=0.01, description='lr SGD'),
        lr_adam=FloatSlider(value=0.01, min=0.001, max=0.1, step=0.001, description='lr Adam'),
        lr_kfac=FloatSlider(value=0.5, min=0.01, max=2.0, step=0.01, description='lr K-FAC'),
        damping=FloatSlider(value=1e-2, min=1e-6, max=1e-1, step=1e-5, description='damping'),
        invert_every=IntSlider(value=5, min=1, max=20, step=1, description='inv every'),
        n_epochs=IntSlider(value=8, min=1, max=30, step=1, description='epochs'),
        batch_size=IntSlider(value=64, min=8, max=200, step=8, description='batch')
    )

### Step-10: Parameter subspace trajectory projection (PCA)

In [12]:
def param_trajectory_pca(res_dict, keys=('W1', 'W2')):
    # Collect flattened parameter vectors across steps for each run and project to 2D via PCA
    trajectories = {}
    for name in res_dict:
        # Per-step snapshots are not saved by train_model, so re-run a short trace
        # from a fresh model and record the weights after every step.
        model = SmallMLP(d_hidden=32)
        snapshots = []
        n_samples = X_train.shape[0]
        batch_size = 64
        steps_per_epoch = max(1, n_samples // batch_size)
        total_steps = 50
        idx = np.random.permutation(n_samples)
        # Build the optimizer once so Adam moments and K-FAC factors persist across steps.
        if name == 'SGD':
            opt = SGDOptimizer(model.params, lr=0.2)
        elif name == 'Adam':
            opt = AdamOptimizer(model.params, lr=0.01)
        else:
            opt = KFACOptimizer(model, lr=0.5, damping=1e-2, invert_every=5)
        for step in range(total_steps):
            # Cycle through the full batches of one fixed permutation.
            start = (step % steps_per_epoch) * batch_size
            batch_idx = idx[start:start + batch_size]
            xb = X_train[batch_idx]
            yb = Y_train[batch_idx]
            loss, grads, cache = forward_with_per_sample_dz(model, xb, yb)
            if name in ('SGD', 'Adam'):
                opt.step(grads)
            else:
                opt.step(cache, grads, batch_size)
            vec = np.concatenate([model.params[k].ravel() for k in keys])
            snapshots.append(vec)
        trajectories[name] = np.array(snapshots)
    # stack all trajectories to fit PCA
    all_vecs = np.vstack([v for v in trajectories.values()])
    # mean center
    mean = all_vecs.mean(axis=0)
    V = all_vecs - mean
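    # PCA via SVD of the centered (n_snapshots, n_params) matrix; rows of VT are principal directions
    U, S, VT = np.linalg.svd(V, full_matrices=False)
    top2 = VT[:2].T  # (n_params, 2) basis of the top-2 principal directions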
    fig = plt.figure(figsize=(6,5))
    for name, traj in trajectories.items():
        proj = (traj - mean).dot(top2)
        plt.plot(proj[:,0], proj[:,1], '-o', label=name)
    plt.title("Parameter trajectory projections (PCA of sampled params)")
    plt.xlabel("PC1"); plt.ylabel("PC2"); plt.legend(); plt.grid(ls='--', lw=0.3)
    plt.show()
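
The shape conventions in the PCA step are easy to get wrong, so here is a standalone sketch on synthetic data. It assumes snapshots are stored one per row, as in `param_trajectory_pca`; the synthetic trajectory and all variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_snapshots, n_params = 150, 128

# Synthetic "trajectory": variance concentrated in two orthonormal directions plus small noise
basis = np.linalg.qr(rng.standard_normal((n_params, 2)))[0]          # (n_params, 2)
coords = rng.standard_normal((n_snapshots, 2)) * np.array([5.0, 2.0])
snapshots = coords @ basis.T + 0.01 * rng.standard_normal((n_snapshots, n_params))

# Center, then take the SVD of the (n_snapshots, n_params) matrix
mean = snapshots.mean(axis=0)
centered = snapshots - mean
U, S, VT = np.linalg.svd(centered, full_matrices=False)

top2 = VT[:2].T            # (n_params, 2): principal directions in parameter space
proj = centered @ top2     # (n_snapshots, 2): coordinates to plot
explained = (S[:2] ** 2).sum() / (S ** 2).sum()
print("projection shape:", proj.shape)
print("variance explained by 2 PCs:", round(float(explained), 4))
```

With samples in rows, the principal directions are the rows of `VT`; with samples in columns they would be the columns of `U` instead.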

## Next steps and caveats

- The K-FAC above uses small-batch empirical covariances and dense inverses — in practice:
  - Use layerwise damping scheduled adaptively.
  - Use running-averages (EMA) with a reasonable decay (we used `factor_ema`).
  - For convolutional layers, K-FAC uses structured Kronecker factors (not shown here).
  - Real implementations are vectorized and integrated into autograd frameworks (PyTorch/TensorFlow) for efficiency.

- This notebook is intentionally explicit: it trades some performance for clarity so you can see the algebraic steps:
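  - input factor: $A = \mathbb{E}[a a^\top]$, built from the layer's input activations $a$
  - output factor: $G = \mathbb{E}[g g^\top]$, where $g = \partial L / \partial z$ is the per-sample pre-activation gradient
  - preconditioned gradient: $(A + \lambda I)^{-1} \, \nabla_W \, (G + \lambda I)^{-1}$ with damping $\lambda$.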

## Notes, performance & practical tips

- The NumPy implementation is intentionally simple and educational. On larger models and real datasets:
    - Use minibatch sizes large enough to estimate factors accurately.
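    - Add damping (Tikhonov / Levenberg-Marquardt style) to keep the factor inverses well conditioned; too little damping lets near-singular directions produce very large steps.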
    - Update factor inverses less frequently to save compute.
    - Combine K-FAC with momentum/Adam-style schedules in practice (there are many heuristics).
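
On the damping point: this notebook adds the same `damping` to both factors. Martens & Grosse (2015) describe a factored Tikhonov scheme that splits a single damping value between the two factors according to their relative scale. The sketch below assumes the trace-based choice of the balancing constant `pi`; the function name `factored_damping_inverses` is hypothetical.

```python
import numpy as np

def factored_damping_inverses(A, G, damping):
    """Invert A and G with the damping split between them by the factor pi."""
    # pi balances the average eigenvalue (trace / dim) of the two factors.
    # Assumes both traces are strictly positive.
    pi = np.sqrt((np.trace(A) / A.shape[0]) / (np.trace(G) / G.shape[0]))
    sqrt_d = np.sqrt(damping)
    A_inv = np.linalg.inv(A + pi * sqrt_d * np.eye(A.shape[0]))
    G_inv = np.linalg.inv(G + (sqrt_d / pi) * np.eye(G.shape[0]))
    return A_inv, G_inv, pi

# Toy factors with very different scales
A = 4.0 * np.eye(3)
G = 0.01 * np.eye(2)
A_inv, G_inv, pi = factored_damping_inverses(A, G, damping=1e-2)
print("pi:", pi)
print("damping added to A:", pi * np.sqrt(1e-2))
print("damping added to G:", np.sqrt(1e-2) / pi)
```

The two added terms multiply to `damping`, so the larger-scale factor receives proportionally more of it rather than both factors receiving the same absolute amount.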