# Optimal Brain Surgeon (OBS)

## Introduction

**Which parameters should be pruned in a neural network?**

**Current gap**: Magnitude pruning removes the smallest weights, and Optimal Brain Damage (OBD) ranks weights using only the diagonal of the Hessian. Both ignore interactions between weights (the off-diagonal Hessian terms), so they can remove important weights and typically require retraining to recover performance.

**Improvement:** OBS uses the full inverse Hessian to estimate each weight's importance and, after deleting a weight, applies a compensating update to the remaining weights. To second order this preserves the network's function, avoids pruning the wrong weights, and removes the need to retrain after each deletion.

## Approach

1. Train a "reasonably" large network to minimum error 
2. Compute the inverse Hessian $H^{-1}$.
3. Find the $q$ that gives the smallest saliency $L_q = \frac{w_q^2}{2 (H^{-1})_{qq}}$. If this increase in error is much smaller than $E$, delete $w_q$ and go to step 4. Otherwise, go to step 5.
4. Use $q$ from step 3 to update all the weights with the compensating update $\delta w$ derived below. Go to step 2.
5. No more weights can be deleted without a large increase in $E$. (At this point it may be desirable to retrain the network.)
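The loop below is a minimal sketch of steps 2 to 5 in plain Python, assuming a toy quadratic error $E(w) = \frac{1}{2}(w - w^*)^\top H (w - w^*)$ whose Hessian $H$ is known exactly; in a real network $H^{-1}$ has to be estimated from data. The names `inverse`, `obs_prune`, and `max_increase` are hypothetical, and `max_increase` is an absolute threshold standing in for "much smaller than $E$". After each deletion the sketch recomputes $H^{-1}$ on the remaining weights only, which is exact for this quadratic error.

```python
# Sketch of the OBS loop (steps 2-5) on a toy quadratic error
#   E(w) = 0.5 * (w - w_star)^T H (w - w_star),
# so H is constant and known exactly. In a real network H must be estimated.


def inverse(a):
    """Gauss-Jordan inverse with partial pivoting (fine for tiny matrices)."""
    n = len(a)
    m = [[float(x) for x in row] + [float(i == j) for j in range(n)]
         for i, row in enumerate(a)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[piv] = m[piv], m[col]
        p = m[col][col]
        m[col] = [x / p for x in m[col]]
        for r in range(n):
            if r != col:
                f = m[r][col]
                m[r] = [x - f * y for x, y in zip(m[r], m[col])]
    return [row[n:] for row in m]


def obs_prune(w, hessian, max_increase):
    """Delete lowest-saliency weights while the predicted increase <= max_increase.

    `max_increase` is a hypothetical stand-in for "much smaller than E".
    Returns (pruned weights, deleted indices in order, their saliencies).
    """
    w = [float(x) for x in w]
    alive = list(range(len(w)))  # indices of weights not yet deleted
    deleted, saliencies = [], []
    while alive:
        # Step 2: inverse of the Hessian restricted to the remaining weights.
        h_inv = inverse([[hessian[i][j] for j in alive] for i in alive])
        # Step 3: saliency L_q = w_q^2 / (2 [H^-1]_qq) for every remaining weight.
        sal = [w[i] ** 2 / (2.0 * h_inv[k][k]) for k, i in enumerate(alive)]
        k = min(range(len(alive)), key=sal.__getitem__)
        if sal[k] > max_increase:
            break  # Step 5: no deletion is cheap enough.
        # Step 4: compensate all remaining weights along H^-1 e_q, driving w_q to zero.
        scale = w[alive[k]] / h_inv[k][k]
        for r, i in enumerate(alive):
            w[i] -= scale * h_inv[r][k]
        q = alive.pop(k)
        w[q] = 0.0  # clear floating-point residue
        deleted.append(q)
        saliencies.append(sal[k])
    return w, deleted, saliencies


H = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.9],
     [0.0, 0.9, 1.0]]
w_star = [0.5, 0.6, 0.7]
pruned, deleted, sal = obs_prune(w_star, H, max_increase=0.1)
print(deleted)  # [1]
```

With these numbers only weight 1 is deleted. Weight 2 absorbs the change through the 0.9 coupling (it moves from 0.7 to about 1.24), and the next candidate, weight 0 with saliency 0.125, exceeds the threshold.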

## Result

OBS finds the right weights to remove, whereas OBD sometimes removes the wrong ones because its diagonal approximation ignores interactions between weights.
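The contrast with OBD can be seen on a hypothetical three-weight example (numbers chosen for illustration, not taken from any experiment). OBD's saliency $\frac{1}{2} H_{qq} w_q^2$ uses only the diagonal of $H$, while OBS uses $(H^{-1})_{qq}$; when two weights are strongly coupled, the two rankings can differ. The helper `inverse` is a small hand-written routine, not a library call.

```python
# Toy comparison (hypothetical numbers): OBD vs OBS saliencies for three weights.
# Weights 1 and 2 are coupled by the off-diagonal Hessian entry 0.9; OBD only sees diag(H).


def inverse(a):
    """Gauss-Jordan inverse with partial pivoting (fine for tiny matrices)."""
    n = len(a)
    m = [[float(x) for x in row] + [float(i == j) for j in range(n)]
         for i, row in enumerate(a)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[piv] = m[piv], m[col]
        p = m[col][col]
        m[col] = [x / p for x in m[col]]
        for r in range(n):
            if r != col:
                f = m[r][col]
                m[r] = [x - f * y for x, y in zip(m[r], m[col])]
    return [row[n:] for row in m]


H = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.9],
     [0.0, 0.9, 1.0]]
w = [0.5, 0.6, 0.7]
h_inv = inverse(H)

obd = [0.5 * H[q][q] * w[q] ** 2 for q in range(3)]        # OBD saliency: diagonal only
obs = [w[q] ** 2 / (2.0 * h_inv[q][q]) for q in range(3)]  # OBS saliency: full inverse Hessian

obd_choice = min(range(3), key=obd.__getitem__)
obs_choice = min(range(3), key=obs.__getitem__)
print(obd_choice, obs_choice)  # 0 1
```

OBD deletes weight 0, whose removal costs 0.125 even after the other weights are re-optimized (it is decoupled from them), whereas OBS deletes weight 1 at a cost of about 0.034, because weight 2 can compensate for it.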

## Optimization Math

$$
\delta E 
= \left( \frac{\partial E}{\partial w} \right)^{\!\top} 
\delta w 
+ \frac{1}{2} \delta w^\top H \, \delta w 
+ \mathcal{O}(\|\delta w\|^3)
$$

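The network was trained to a (local) minimum, so the gradient $\partial E / \partial w$ is zero; for small $\delta w$ the third-order term is also negligible:
$$
\delta E 
= \cancel{\left( \frac{\partial E}{\partial w} \right)^{\!\top} \delta w}
+ \frac{1}{2} \delta w^\top H \, \delta w 
+ \cancel{\mathcal{O}(\|\delta w\|^3)}
$$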
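A quick numerical check of this approximation, on a hypothetical two-weight error surface (not a real network) chosen so that its minimum is at $w^* = (1, 2)$ with Hessian $H = \begin{pmatrix} 2 & 1 \\ 1 & 2 \end{pmatrix}$. At the minimum the gradient term vanishes, so the actual change in error should equal the quadratic term up to a higher-order remainder.

```python
# Hypothetical error surface with minimum at w* = (1, 2):
#   E(w) = a^2 + a*b + b^2 + a^4,  a = w0 - 1,  b = w1 - 2.
# Its gradient is zero at w*, and its Hessian there is [[2, 1], [1, 2]].


def error(w0, w1):
    a, b = w0 - 1.0, w1 - 2.0
    return a * a + a * b + b * b + a ** 4


H = [[2.0, 1.0], [1.0, 2.0]]
w_star = (1.0, 2.0)
dw = (0.01, -0.02)  # small perturbation

actual = error(w_star[0] + dw[0], w_star[1] + dw[1]) - error(*w_star)
quadratic = 0.5 * sum(dw[i] * H[i][j] * dw[j] for i in range(2) for j in range(2))

# The two differ only by the a**4 remainder (about 1e-8 here).
print(actual, quadratic, actual - quadratic)
```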


$$
\min_{\delta w} \;\; \frac{1}{2}\,\delta w^\top H \,\delta w
\qquad \text{subject to} \qquad \text{removing one weight } w_q \text{ completely}
$$

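$$
\delta w_q = -\,w_q 
\qquad\text{or equivalently}\qquad
e_q^\top \delta w = -\,w_q
$$
where $e_q$ is the unit vector in weight space along $w_q$ (the $q$-th standard basis vector), so $e_q^\top \delta w$ picks out the $q$-th component of $\delta w$.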

$$
\min_{\delta w} \;\; \frac{1}{2}\,\delta w^\top H \,\delta w
\qquad \text{subject to} \qquad
e_q^\top \delta w + w_q = 0
$$

Introduce a Lagrange multiplier $\lambda$ for the constraint:
$$
\mathcal{L} 
= \frac{1}{2} \delta w^\top H \, \delta w 
\;+\; \lambda \left( e_q^\top \delta w + w_q \right)
$$


Take gradient w.r.t. $\delta w$:
$$
\frac{\partial \mathcal{L}}{\partial (\delta w)} 
= H \delta w + \lambda e_q = 0
$$

Thus,
$$
\delta w = -\lambda H^{-1} e_q \tag{1}
$$

Now take gradient w.r.t. $\lambda$:
$$
\frac{\partial \mathcal{L}}{\partial \lambda}
= e_q^\top \delta w + w_q = 0
$$
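The two stationarity conditions together form one linear system in $(\delta w, \lambda)$:
$$
\begin{pmatrix} H & e_q \\ e_q^\top & 0 \end{pmatrix}
\begin{pmatrix} \delta w \\ \lambda \end{pmatrix}
=
\begin{pmatrix} 0 \\ -w_q \end{pmatrix}
$$
The sketch below solves this system directly by Gaussian elimination for a hypothetical $2 \times 2$ Hessian and weight vector; the helper `solve` and the variable names are illustrative only. The $q$-th component of the solution comes out as $-w_q$, as the constraint requires, which gives an independent check on the closed-form solution.

```python
# Solve the stationarity conditions of the OBS Lagrangian as one linear system:
#   H dw + lambda e_q = 0
#   e_q^T dw          = -w_q
# H and w are hypothetical; any symmetric positive-definite H works.


def solve(a, b):
    """Gaussian elimination with partial pivoting; assumes `a` is nonsingular."""
    n = len(a)
    m = [list(map(float, row)) + [float(rhs)] for row, rhs in zip(a, b)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[piv] = m[piv], m[col]
        for r in range(col + 1, n):
            f = m[r][col] / m[col][col]
            m[r] = [x - f * y for x, y in zip(m[r], m[col])]
    x = [0.0] * n
    for r in reversed(range(n)):
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x


H = [[2.0, 1.0], [1.0, 2.0]]
w = [0.4, -0.3]
q = 0  # index of the weight to remove
n = len(w)

# KKT matrix [[H, e_q], [e_q^T, 0]] and right-hand side [0, ..., 0, -w_q].
kkt = [H[i] + [1.0 if i == q else 0.0] for i in range(n)]
kkt.append([1.0 if j == q else 0.0 for j in range(n)] + [0.0])
rhs = [0.0] * n + [-w[q]]

*dw, lam = solve(kkt, rhs)
print(dw, lam)  # dw[q] equals -w[q]
```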

Substitute (1):
$$
e_q^\top (-\lambda H^{-1} e_q) + w_q = 0
$$

$$
-\lambda \, e_q^\top H^{-1} e_q + w_q = 0
$$

Thus,
$$
\lambda = \frac{w_q}{e_q^\top H^{-1} e_q} \tag{2}
$$

Plug (2) into (1):
$$
\delta w
= -\frac{w_q H^{-1} e_q}{e_q^\top H^{-1} e_q}
= -\frac{w_q H^{-1} e_q}{(H^{-1})_{qq}}
$$

As a check, the $q$-th component is $e_q^\top \delta w = -w_q (H^{-1})_{qq} / (H^{-1})_{qq} = -w_q$, so the constraint holds and $w_q$ is driven exactly to zero.

Now compute the loss increase $L_q$ by evaluating $\mathcal{L}$ at the optimum:
$$
L_q 
= \frac{1}{2} \delta w^\top H \delta w 
+ \lambda( e_q^\top \delta w + w_q )
$$

The constraint term is zero at the optimum, so:
$$
L_q = \frac{1}{2} \delta w^\top H \delta w
$$

Substitute $\delta w = -\frac{w_q}{(H^{-1})_{qq}} H^{-1} e_q$:
$$
L_q 
= \frac{1}{2}
\left( 
-\frac{w_q H^{-1} e_q}{(H^{-1})_{qq}}
\right)^\top
H
\left( 
-\frac{w_q H^{-1} e_q}{(H^{-1})_{qq}}
\right)
$$

Factor out $\frac{w_q^2}{(H^{-1})_{qq}^2}$ (the two minus signs cancel):
$$
L_q 
= \frac{w_q^2}{2 (H^{-1})_{qq}^2} 
\left( e_q^\top H^{-1} H H^{-1} e_q \right)
$$

Here $(H^{-1} e_q)^\top = e_q^\top H^{-1}$ because $H^{-1}$ is symmetric, and $H^{-1} H H^{-1} = H^{-1}$, so:
$$
L_q
= \frac{w_q^2}{2 (H^{-1})_{qq}^2}
\left( e_q^\top H^{-1} e_q \right)
$$

But:
$$
e_q^\top H^{-1} e_q = (H^{-1})_{qq}
$$

Therefore:
$$
L_q = 
\frac{1}{2} \frac{w_q^2}{(H^{-1})_{qq}}
$$
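For a purely quadratic error the second-order expansion is exact, so the saliency $L_q = \frac{w_q^2}{2 (H^{-1})_{qq}}$ should equal the actual increase in error after the compensating step. The sketch below checks this on a hypothetical $3 \times 3$ positive-definite Hessian, moving along $H^{-1} e_q$ (column $q$ of $H^{-1}$) with the scale that satisfies $e_q^\top \delta w = -w_q$. It also compares against simply zeroing $w_q$ with no compensation, whose cost is $\frac{1}{2} H_{qq} w_q^2$ and can never be smaller, since the compensated step is the constrained minimizer. The helpers `inverse` and `quadratic_error` are illustrative, not library functions.

```python
# Check L_q = w_q^2 / (2 [H^-1]_qq) against a direct evaluation, assuming a
# quadratic error E(w) = 0.5 (w - w_star)^T H (w - w_star) with hypothetical H.


def inverse(a):
    """Gauss-Jordan inverse with partial pivoting (fine for tiny matrices)."""
    n = len(a)
    m = [[float(x) for x in row] + [float(i == j) for j in range(n)]
         for i, row in enumerate(a)]
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[piv] = m[piv], m[col]
        p = m[col][col]
        m[col] = [x / p for x in m[col]]
        for r in range(n):
            if r != col:
                f = m[r][col]
                m[r] = [x - f * y for x, y in zip(m[r], m[col])]
    return [row[n:] for row in m]


def quadratic_error(w, w_star, H):
    d = [a - b for a, b in zip(w, w_star)]
    return 0.5 * sum(d[i] * H[i][j] * d[j] for i in range(len(d)) for j in range(len(d)))


H = [[4.0, 1.0, 0.5],   # symmetric and diagonally dominant, hence positive definite
     [1.0, 3.0, 0.8],
     [0.5, 0.8, 2.0]]
w_star = [0.7, -0.2, 1.1]  # trained weights (E is minimal here)
q = 1                      # weight to remove
h_inv = inverse(H)

saliency = w_star[q] ** 2 / (2.0 * h_inv[q][q])
# Compensating step along column q of H^-1, scaled so that w_q lands on zero.
scale = w_star[q] / h_inv[q][q]
w_new = [w_star[i] - scale * h_inv[i][q] for i in range(len(w_star))]

increase = quadratic_error(w_new, w_star, H)
naive = quadratic_error([0.0 if i == q else x for i, x in enumerate(w_star)], w_star, H)
print(saliency, increase, naive)  # increase matches saliency; naive is larger
```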
