# 3. Sharpness-aware minimization (SAM) with Sparse Networks - 25 points

## 3.1 Get a sparse network through pruning
**Pruning** is a technique used to reduce the size and complexity of a neural network model by removing (setting to zero) less important parameters. The goal is to create a more efficient model that retains its predictive accuracy while being smaller, which can improve both inference speed and memory usage.

Let's first train a simple model on the MNIST dataset to learn about pruning. We use just 10% of the dataset for both training and testing.

### 3.1.1 Train a dense network with SGD
Let us first train a dense model with SGD. We reuse the model for the discriminator of the GAN in Homework 2 and name it 'Classifier'.

In [1]:
from lib.part3.utils import *
max_epochs = 10
device = "cpu" # Change this if you can and want to use a GPU device
model = Classifier().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)



We define a single optimization step of SGD.

In [2]:
def optimize_sgd(model, optimizer, img, label):
    optimizer.zero_grad()
    output = model(img)
    loss = cross_entropy(output, label)
    loss.backward()
    optimizer.step()

The following cell runs the training loop; this might take a few minutes.

In [3]:
train_model(model, optimizer, optimize_sgd, max_epochs)



Epoch 0 with 0.9325000048 accuracy on the validation set.
Epoch 1 with 0.9508333206 accuracy on the validation set.
Epoch 2 with 0.9566666484 accuracy on the validation set.
Epoch 3 with 0.959166646 accuracy on the validation set.
Epoch 4 with 0.9616666436 accuracy on the validation set.
Epoch 5 with 0.9633333087 accuracy on the validation set.
Epoch 6 with 0.9666666389 accuracy on the validation set.
Epoch 7 with 0.9683333039 accuracy on the validation set.
Epoch 8 with 0.9708333611 accuracy on the validation set.
Epoch 9 with 0.9725000262 accuracy on the validation set.


#### Evaluate model
Evaluate the model on the test set.

In [4]:
acc = evaluate(model)
print(f"Accuracy of {round(acc, 4)} on the test set.")

Accuracy of 0.9773 on the test set.


### 3.1.2 Sparse network with magnitude-based pruning

Magnitude-based pruning specifically focuses on **removing weights that have the smallest absolute values**, under the assumption that weights with smaller magnitudes contribute less to the model's output.

**(1)** (6 points) Implement magnitude-based pruning below, which removes the given fraction of weights with the smallest absolute values.

In [5]:
def magnitude_prune(model, prune_fraction):
    """Zero out the `prune_fraction` smallest-magnitude entries of each weight tensor, in place."""
    for name, param in model.named_parameters():
        if "weight" in name and param.requires_grad:
            # FILL: Get weight's absolute values
            weight_abs = param.data.abs()

            # FILL: Compute the threshold
            n_total = weight_abs.numel()
            n_prune = int(prune_fraction * n_total)

            if n_prune <= 0:
                continue  # nothing to prune in this tensor
            if n_prune >= n_total:
                param.data.zero_()  # prune the whole tensor
                continue

            # The n_prune-th smallest magnitude is the per-tensor cut-off.
            threshold = torch.kthvalue(weight_abs.reshape(-1), n_prune).values

            # FILL: Prune weights at or below the threshold (ties can prune slightly more than n_prune)
            prune_mask = weight_abs <= threshold
            param.data[prune_mask] = 0.0
    return model
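
As a quick sanity check on the pruning routine, it can help to measure the sparsity that was actually reached. The sketch below is not part of the assignment: `weight_sparsity` is a hypothetical helper, and a single `torch.nn.Linear` layer stands in for `Classifier`. It applies the same k-th-smallest-magnitude threshold as `magnitude_prune` and reports the fraction of zeroed weights.

```python
import torch

def weight_sparsity(model: torch.nn.Module) -> float:
    """Fraction of exactly-zero entries across all weight tensors (biases excluded)."""
    zeros, total = 0, 0
    for name, param in model.named_parameters():
        if "weight" in name:
            zeros += int((param == 0).sum().item())
            total += param.numel()
    return zeros / max(total, 1)

# Toy stand-in for the classifier (assumption: any module with "weight" parameters works).
torch.manual_seed(0)
layer = torch.nn.Linear(8, 4)

with torch.no_grad():
    w_abs = layer.weight.abs()
    n_prune = int(0.5 * w_abs.numel())                       # prune 50% of the weights
    threshold = torch.kthvalue(w_abs.reshape(-1), n_prune).values
    layer.weight[w_abs <= threshold] = 0.0                   # same rule as magnitude_prune

sparsity = weight_sparsity(layer)
print(f"weight sparsity: {sparsity:.3f}")
```

Barring ties in the weight magnitudes, the reported sparsity should match the requested prune fraction. Biases are left untouched, so the sparsity over all parameters is slightly lower.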

In [6]:
import copy
# Copy a model for pruning
sparse_model = copy.deepcopy(model)
# Get a sparse model by pruning 50% of the entries in each weight tensor (biases are kept)
sparse_model = magnitude_prune(sparse_model, prune_fraction=0.5)

Copy the sparse model for the SAM finetuning later in 3.2.

In [7]:
sparse_model_sam = copy.deepcopy(sparse_model)

Evaluate the model after pruning.

In [8]:
acc = evaluate(sparse_model)
print(f"Accuracy of {round(acc, 4)} on the test set.")

Accuracy of 0.9417 on the test set.


### 3.1.3 Finetune the sparse model

We finetune the sparse model after pruning with SGD to recover its performance.

In [9]:

optimizer_sparse = torch.optim.SGD(sparse_model.parameters(), lr=0.05)

finetune_epoch = 3
train_model(sparse_model, optimizer_sparse, optimize_sgd, finetune_epoch)


Epoch 0 with 0.959166646 accuracy on the validation set.
Epoch 1 with 0.9649999738 accuracy on the validation set.
Epoch 2 with 0.9725000262 accuracy on the validation set.


In [10]:
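# Note: `optimizer` was built from the dense `model`'s parameters, so this call does not
# update `sparse_model` (the validation accuracy below stays constant).
# Pass `optimizer_sparse` instead to continue finetuning the sparse model.
finetune_epoch = 3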
train_model(sparse_model, optimizer, optimize_sgd, finetune_epoch)

Epoch 0 with 0.9725000262 accuracy on the validation set.
Epoch 1 with 0.9725000262 accuracy on the validation set.
Epoch 2 with 0.9725000262 accuracy on the validation set.


Evaluate the sparse model after finetuning.

In [11]:
acc = evaluate(sparse_model)
print(f"Accuracy of {round(acc, 4)} on the test set.")

Accuracy of 0.978 on the test set.
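
One caveat about the finetuning above: `optimize_sgd` updates every parameter, including the weights that pruning set to zero, so unless `train_model` re-applies a mask internally (not visible here), the finetuned model is generally no longer 50% sparse. A minimal sketch of mask-preserving finetuning follows. `build_masks` and `apply_masks` are hypothetical helpers, and a single linear layer with random data stands in for the classifier and MNIST.

```python
import torch

def build_masks(model):
    """Boolean keep-masks: True where a weight is currently nonzero."""
    return {name: p.detach() != 0
            for name, p in model.named_parameters() if "weight" in name}

@torch.no_grad()
def apply_masks(model, masks):
    """Zero out every weight whose mask entry is False."""
    for name, p in model.named_parameters():
        if name in masks:
            p.mul_(masks[name].to(p.dtype))

# Toy demo: prune half of a linear layer, then take one SGD step.
torch.manual_seed(0)
layer = torch.nn.Linear(8, 4)
with torch.no_grad():
    w_abs = layer.weight.abs()
    threshold = torch.kthvalue(w_abs.reshape(-1), w_abs.numel() // 2).values
    layer.weight[w_abs <= threshold] = 0.0
masks = build_masks(layer)

opt = torch.optim.SGD(layer.parameters(), lr=0.05)
x, y = torch.randn(16, 8), torch.randint(0, 4, (16,))
loss = torch.nn.functional.cross_entropy(layer(x), y)
loss.backward()
opt.step()
zeros_after_step = int((layer.weight == 0).sum())   # pruned weights receive gradient too

apply_masks(layer, masks)                            # re-impose the sparsity pattern
zeros_after_mask = int((layer.weight == 0).sum())
print(zeros_after_step, zeros_after_mask)
```

Calling `apply_masks` after every optimizer step keeps the pruned positions at zero throughout finetuning. Whether to do so depends on whether the goal is a sparse final model or only a pruning-based initialization.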


**(2)** (2 points) What are the pros and cons of sparse networks?

<font color="blue">
**(2) Pros and cons of sparse networks**

**Pros**
- **Lower memory/storage**: many parameters are zero, so the model can be stored more compactly (especially with sparse formats).
- **Potential speedups**: fewer effective multiplications can reduce compute *if* the hardware/software supports sparse ops well (or if sparsity is structured).
- **Regularization effect**: pruning can remove weak/unused connections and sometimes improves generalization.
- **Deployment-friendly**: can help fit models on memory-constrained devices.

**Cons**
- **Unstructured sparsity often gives little/no speedup on GPUs**: dense kernels are highly optimized; sparse kernels can have overhead and irregular memory access.
- **Accuracy can drop** after pruning, especially with high pruning rates; usually needs finetuning.
- **More engineering complexity**: masking, sparse formats, and tool support add complexity to training/inference pipelines.
- **Hardware dependence**: real acceleration usually requires structured sparsity or specialized libraries/accelerators.
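
To make the memory argument concrete, here is a back-of-the-envelope sketch comparing dense storage with a COO sparse layout. The byte sizes (float32 values, int64 indices) and the 784×256 layer shape are assumptions chosen for illustration, not properties of `Classifier`.

```python
def dense_bytes(rows, cols, value_bytes=4):
    """Dense storage: one value per entry."""
    return rows * cols * value_bytes

def coo_bytes(nnz, value_bytes=4, index_bytes=8):
    """COO storage: one value plus a (row, col) index pair per nonzero."""
    return nnz * (value_bytes + 2 * index_bytes)

rows, cols = 784, 256  # hypothetical layer shape
for sparsity in (0.5, 0.8, 0.95):
    nnz = round(rows * cols * (1 - sparsity))
    print(f"sparsity={sparsity:.2f}  dense={dense_bytes(rows, cols)}  coo={coo_bytes(nnz)}")
```

Under these assumptions each nonzero costs 20 bytes instead of 4, so COO only pays off above roughly 80% sparsity. At the 50% sparsity used in this section, storing the weights in COO form would take more memory than the dense tensor.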


## 3.2 Train the sparse model with SAM

Sharpness-aware minimization (SAM) is an optimization technique that is not satisfied with a low loss at a single point; instead, it seeks a neighborhood with uniformly low loss. SAM is motivated by the link between the geometry of the loss landscape and generalization: a minimum surrounded by a uniformly low-loss neighborhood is expected to generalize better than an equally low minimum in a region where the loss varies sharply.

To be specific, we consider a model with weight vector $\mathbf{w}$ and training loss $L_S$. SAM aims to minimize the maximum loss within a small region, usually an $\ell_2$ ball of radius $\rho$. Note that $\rho$ is a small value close to $0$. Therefore, SAM can be formulated as a minimax optimization problem:
$$\min_{\mathbf{w}} \max_{\mathbf{\epsilon}: \|\mathbf{\epsilon}\|_2\leq \rho} L_S (\mathbf{w} + \mathbf{\epsilon})$$

**(3)** (3 points) Please solve the inner maximization problem using a first-order Taylor expansion.

<font color="blue">
    
**(3) Inner maximization via first-order Taylor expansion**

We want to solve:
$$
\max_{\|\epsilon\|_2 \le \rho} L_S(w + \epsilon) 
$$

Using a first-order Taylor expansion around $w$:
$$L_S(w+\epsilon) \approx L_S(w) + \nabla L_S(w)^\top \epsilon $$

So the inner problem becomes:
$$\max_{\|\epsilon\|_2 \le \rho} \left(L_S(w) + \nabla L_S(w)^\top \epsilon\right) $$

Since $L_S(w)$ is constant w.r.t. $\epsilon$, this is:
$$ L_S(w) + \max_{\|\epsilon\|_2 \le \rho} \nabla L_S(w)^\top \epsilon $$

By Cauchy–Schwarz:
$$ \nabla L_S(w)^\top \epsilon \le \|\nabla L_S(w)\|_2 \cdot \|\epsilon\|_2 \le \rho \|\nabla L_S(w)\|_2 $$

The maximum is achieved when $\epsilon$ is aligned with the gradient:
$$ \epsilon^* = \rho \frac{\nabla L_S(w)}{\|\nabla L_S(w)\|_2} $$

Therefore:
$$ \max_{\|\epsilon\|_2 \le \rho} L_S(w + \epsilon) \approx L_S(w) + \rho \|\nabla L_S(w)\|_2 $$
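
The closed form can be checked numerically. The sketch below uses a toy 2-D gradient (an assumption for illustration, not taken from the model) and confirms that no point on the boundary of the $\rho$-ball gives a larger linearized gain than $\epsilon^*$.

```python
import math
import random

def sam_epsilon(grad, rho):
    """Closed-form maximizer of grad . eps subject to ||eps||_2 <= rho."""
    norm = math.sqrt(sum(g * g for g in grad))
    return [rho * g / norm for g in grad]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

grad, rho = [3.0, -4.0], 0.05          # toy gradient with ||grad||_2 = 5
eps_star = sam_epsilon(grad, rho)
best = dot(grad, eps_star)             # should equal rho * ||grad||_2

# Compare against random points on the boundary of the rho-ball.
random.seed(0)
worst_gap = float("inf")
for _ in range(1000):
    theta = random.uniform(0.0, 2.0 * math.pi)
    eps = [rho * math.cos(theta), rho * math.sin(theta)]
    worst_gap = min(worst_gap, best - dot(grad, eps))

print(f"best gain = {best:.4f}, smallest gap to a random boundary point = {worst_gap:.2e}")
```

The gain at $\epsilon^*$ equals $\rho \|\nabla L_S(w)\|_2$, and the gap to every sampled boundary point is non-negative up to floating-point error, as Cauchy–Schwarz predicts.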


**(4)** (8 points) Now we will train the same model using the SAM optimizer.
Please implement SAM in the two steps below. The first step is for the maximizer, which computes the $\epsilon$ obtained in question (3). The second step is the usual step for the minimizer: $\mathbf{w}_{t+1} = \mathbf{w}_{t} - \eta_t \nabla L_S (\mathbf{w}_t + \mathbf{\epsilon}_t)$, where $\eta_t$ is the step size. Note that we set $\rho=0.05$.

Hint: be careful about weight updates.

In [12]:
class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, lr=0.01, rho=0.05):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, lr)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    def _grad_norm(self):
        # Note that p.grad gets the gradient; p.data gets the weight.
        norm = torch.norm(
                    torch.stack([
                        p.grad.norm(p=2)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        norm += 1e-12 # Avoid zero norm
        return norm

    @torch.no_grad()
    def first_step(self):
        # Add the perturbation on the weight.
        # Hint: the norm of the gradient can be calculated by _grad_norm() function.
        # Hint: self.param_groups to get access to the weight and gradient of each parameter
        # FILL
        grad_norm = self._grad_norm()

        for group in self.param_groups:
            scale = group["rho"] / grad_norm
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale          # epsilon for this parameter tensor
                p.add_(e_w)                   # w <- w + epsilon
                self.state[p]["e_w"] = e_w    # save epsilon to undo later

        self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        # FILL
        # Hint: Remember to change the parameters back
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = self.state[p].get("e_w", None)
                if e_w is not None:
                    p.sub_(e_w)               # restore: w <- w - epsilon

        self.base_optimizer.step()

        self.zero_grad()
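
To see the two steps without any optimizer machinery, here is a minimal sketch of one SAM update on a toy quadratic loss with an analytic gradient. The loss $L(w) = \tfrac{1}{2}\sum_i a_i w_i^2$, the curvatures `a`, and the helper names are assumptions for illustration. The point is the one made by the hint: the gradient is evaluated at $w + \epsilon$, but the update is applied to the original $w$.

```python
import math

def grad_quadratic(w, a):
    """Gradient of L(w) = 0.5 * sum_i a_i * w_i^2."""
    return [ai * wi for ai, wi in zip(a, w)]

def loss_quadratic(w, a):
    return 0.5 * sum(ai * wi * wi for ai, wi in zip(a, w))

def sam_step(w, a, lr=0.05, rho=0.05):
    g = grad_quadratic(w, a)
    norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12
    eps = [rho * gi / norm for gi in g]                 # first step: ascent direction
    w_adv = [wi + ei for wi, ei in zip(w, eps)]         # perturbed weights w + eps
    g_adv = grad_quadratic(w_adv, a)                    # gradient at the perturbed point
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]   # second step: update the original w

a = [1.0, 10.0]   # one flat and one sharp direction
w = [1.0, 1.0]
w_sam = sam_step(w, a)
w_sgd = [wi - 0.05 * gi for wi, gi in zip(w, grad_quadratic(w, a))]
print("SAM:", w_sam)
print("SGD:", w_sgd)
```

On this toy problem the SAM step moves further along the sharp direction (the second coordinate) than the plain SGD step, because the gradient at the perturbed point is larger there.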

Define a `SAM` optimizer for the model. We recommend using `SGD` as the base optimizer with a learning rate of $0.05$ (the same as for plain SGD).

In [13]:
base_optimizer = torch.optim.SGD
sam_optimizer = SAM(sparse_model_sam.parameters(), base_optimizer, lr=0.05)

**(5)** (4 points) Please define the optimizing process of SAM.

In [14]:
def optimize_sam(model, optimizer, img, label):

    enable_running_stats(model)
    # First forward-backward pass
    # FILL
    # Hint: use sam_optimizer above
    # Hint: use loss 'cross_entropy'
    output = model(img)
    loss = cross_entropy(output, label)
    loss.backward()
    optimizer.first_step()

    disable_running_stats(model)
    # Second forward-backward pass
    # FILL
    output = model(img)
    loss = cross_entropy(output, label)
    loss.backward()
    optimizer.second_step()

In [15]:
train_model(sparse_model_sam, sam_optimizer, optimize_sam, finetune_epoch)

Epoch 0 with 0.959166646 accuracy on the validation set.
Epoch 1 with 0.9633333087 accuracy on the validation set.
Epoch 2 with 0.9624999762 accuracy on the validation set.


#### Evaluate model
Evaluate the sparse model finetuned with SAM on the test set.

In [16]:
acc = evaluate(sparse_model_sam)
print(f"Accuracy of {round(acc, 4)} on the test set.")

Accuracy of 0.9774 on the test set.


**(6)** (2 points) Give a conclusion comparing SAM with SGD. Is there any drawback of SAM?

<font color="blue">
    
**(6) Conclusion comparing SAM with SGD + drawbacks**

From the results:
- Dense model: test accuracy $\approx 0.9773$
- After pruning 50% of the weights: test accuracy drops to $\approx 0.9417$
- Finetuning sparse with SGD: recovers to $\approx 0.978$
- Finetuning sparse with SAM: recovers to $\approx 0.9774$

**Conclusion:** In this run, both optimizers recover the accuracy lost to pruning and reach the level of the dense model. The difference between them ($0.978$ vs. $0.9774$) is less than $0.1$ percentage points, which is too small to count as evidence for either optimizer from a single run on a 10% subset of MNIST. The intuition behind SAM is that it prefers parameters in flatter regions by updating with the gradient at a worst-case nearby point $w+\epsilon^*$, which often improves generalization; these results neither confirm nor contradict that intuition.

**Drawbacks of SAM:**
- **Higher training cost**: SAM requires **two forward+backward passes per step**, so training is roughly 2× slower than SGD.
- **Extra hyperparameter**: the neighborhood size $\rho$ needs tuning and interacts with learning rate/batch size.
- **Care with BatchNorm**: SAM often needs special handling of running stats (enable/disable) to be stable and effective.
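
The BatchNorm point can be illustrated with a small sketch. The implementations of `enable_running_stats` and `disable_running_stats` in `lib.part3.utils` are not shown in this notebook, so the helpers below (`freeze_bn_stats`, `unfreeze_bn_stats`) are hypothetical stand-ins showing one common approach: temporarily setting the BatchNorm momentum to zero so that the second forward pass at $w+\epsilon$ does not alter the running statistics.

```python
import torch
from torch.nn.modules.batchnorm import _BatchNorm

def freeze_bn_stats(model):
    """Stop BatchNorm layers from updating running statistics (momentum -> 0)."""
    for m in model.modules():
        if isinstance(m, _BatchNorm):
            m.backup_momentum = m.momentum
            m.momentum = 0.0

def unfreeze_bn_stats(model):
    """Restore the momentum saved by freeze_bn_stats."""
    for m in model.modules():
        if isinstance(m, _BatchNorm) and hasattr(m, "backup_momentum"):
            m.momentum = m.backup_momentum

torch.manual_seed(0)
net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.BatchNorm1d(4))
bn = net[1]
x = torch.randn(8, 4)

freeze_bn_stats(net)
net(x)                                   # forward pass in training mode
mean_frozen = bn.running_mean.clone()    # unchanged: still the initial zeros

unfreeze_bn_stats(net)
net(x)
mean_updated = bn.running_mean.clone()   # now moved toward the batch mean
print(mean_frozen, mean_updated)
```

With this pattern, only the first forward pass of each SAM step (at the unperturbed weights) contributes to the running mean and variance.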
