# 3. Sharpness-aware minimization (SAM) wirh Sparse Networks - 25 points

## 3.1 Get a sparse networks through pruning
**Pruning** is a technique used to reduce the size and complexity of a neural network model by removing **(setting to zero)** less important parameters. The goal is to create a more efficient model that retains its predictive accuracy while being smaller, which can improve both inference speed and memory usage.

Let's train a simple model on the MNIST dataset to learn about pruning at first. We just use 10% of the dataset for both training and testing.

### 3.1.1 Train a dense network with SGD
Let us first train a dense model with SGD. We reuse the model for the discriminator of the GAN in Homework 2 and name it 'Classifier'.

In [1]:
from lib.part3.utils import *
max_epochs = 10
device = "cpu" # Change this if you can and want to use a GPU device
model = Classifier().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

We define the optimizing process of SGD.

In [2]:
def optimize_sgd(model, optimizer, img, label):
    optimizer.zero_grad()
    output = model(img)
    loss = cross_entropy(output, label)
    loss.backward()
    optimizer.step()

The following cell runs the training loop, this might take a few minutes.

In [3]:
train_model(model, optimizer, optimize_sgd, max_epochs)

Epoch 0 with 0.933 accuracy on the validation set.
Epoch 1 with 0.954 accuracy on the validation set.
Epoch 2 with 0.957 accuracy on the validation set.
Epoch 3 with 0.962 accuracy on the validation set.
Epoch 4 with 0.963 accuracy on the validation set.
Epoch 5 with 0.967 accuracy on the validation set.
Epoch 6 with 0.967 accuracy on the validation set.
Epoch 7 with 0.966 accuracy on the validation set.
Epoch 8 with 0.966 accuracy on the validation set.
Epoch 9 with 0.966 accuracy on the validation set.


#### Evaluate model
Evaluate the model on the test set.

In [4]:
acc = evaluate(model)
print(f"Accuracy of {round(acc, 4)} on the test set.")

Accuracy of 0.9778 on the test set.


### 3.1.2 Sparse network with magnitude-based pruning

Magnitude-based pruning specifically focuses on **removing weights that have the smallest absolute values**, under the assumption that weights with smaller magnitudes contribute less to the model's output.

**(1)** (6 points) Realize magnitude-based pruning below, which removes a part of weights that have the smallest absolute values.

In [24]:
def magnitude_prune(model, prune_fraction):
    for name, param in model.named_parameters():
        if "weight" in name and param.requires_grad:
            # FILL: Get weight's absolute values
            weight_abs = param.abs()
            # FILL: Compute the threshold (s.t prune_fraction % of the weights with the lowest abs are removed)
            threshold = torch.kthvalue(weight_abs.flatten(), int(prune_fraction * weight_abs.numel()))[0]
            # FILL: Prune weights below the threshold
            with torch.no_grad():
                param[weight_abs < threshold] = 0
            
    return model

In [25]:
import copy
# Copy a model for pruning
sparse_model = copy.deepcopy(model)
# Get a sparse model by pruning 50% parameters
sparse_model = magnitude_prune(sparse_model, prune_fraction=0.5)

Copy a sparse model for SAM implementation later in 3.2.

In [26]:
sparse_model_sam = copy.deepcopy(sparse_model)

Evaluation after pruning

In [27]:
acc = evaluate(sparse_model)
print(f"Accuracy of {round(acc, 4)} on the test set.")

Accuracy of 0.9386 on the test set.


### 3.1.3 Finetune the sparse model

We finetune the sparse model after pruning with SGD to recover its performance.

In [28]:
finetune_epoch = 3
train_model(sparse_model, optimizer, optimize_sgd, finetune_epoch)

Epoch 0 with 0.942 accuracy on the validation set.
Epoch 1 with 0.942 accuracy on the validation set.
Epoch 2 with 0.942 accuracy on the validation set.


Evaluate the sparse model after finetuning.

In [29]:
acc = evaluate(sparse_model)
print(f"Accuracy of {round(acc, 4)} on the test set.")

Accuracy of 0.9543 on the test set.


**(2)** (2 point) What are the pros and cons of sparse networks?

**Pros:**
* Faster training
* Faster evaluation
* Keeping the weights associated to fewer features only could lead to more explainability and/or better feature selection, just like an $\ell_1$ regularizer that forces sparcity (espiecally at smaller scale)

**Cons:**
* Lower accuracy
* Need to finetune afterwards

## 3.2 Train the sparse model with SAM

Sharpness-aware minimization (SAM) is a new optimization technique, which is satisfied with not just a low loss, instead it seeks a neighborhood with uniformly low loss. SAM is motivated by the link between the geometry of the loss landscape and generalization. It makes sense that a low loss within a uniformly low loss neighborhood will generalize better than a low loss within a region of higher variance.

To be specific, we consider a model with the weight vector of $\mathbf{w}$ and the training loss $L_S$. SAM aims to minimize the maximum loss within a small region which is usually a $\ell_2$ ball with $\rho$ radius. Note that $\rho$ is a small value close to $0$. Therefore, SAM can be formulated as a minimax optimization problem:
$$\min_{\mathbf{w}} \max_{\mathbf{\epsilon}: \|\mathbf{\epsilon}\|_2\leq \rho} L_S (\mathbf{w} + \mathbf{\epsilon})$$

**(3)** (3 points) Please solve the inner maximum problem by first-order Taylor expansion.

<font color=‘blue’>
    
Write your answer here

With the first-order Taylor expansion, the objective function becomes:

$$L_S (\mathbf{w} + \mathbf{\epsilon}) \approx L_S (\mathbf{w}) + \epsilon^{\top} \nabla L_S (\mathbf{w})$$

Which leads to the following inner maximum problem:

$$\max_{\mathbf{\epsilon}: \|\mathbf{\epsilon}\|_2\leq \rho} \epsilon^{\top} \nabla L_S (\mathbf{w})$$

As the objective is linear, we know that the optimum will be on the boundary:

$$\max_{\mathbf{\epsilon}: \|\mathbf{\epsilon}\|_2 = \rho} \epsilon^{\top} \nabla L_S (\mathbf{w})$$

Using $\|\mathbf{\epsilon}\|^2_2 = \rho^2$ as an equivalent constraint that simplifies computations, this problem can easily be solved with Lagrange multipliers and leads to the following 2 by 2 system of equations:

$$\nabla_{\epsilon}(\epsilon^{\top} \nabla_{\mathbf{w}} L_S (\mathbf{w})) = \nabla_{\mathbf{w}} L_S (\mathbf{w}) = \lambda \nabla_{\epsilon} (\|\mathbf{\epsilon}\|^2_2 - \rho^2) = 2 \epsilon$$
$$\|\mathbf{\epsilon}\|^2_2 = \rho^2$$

By substitution, we get:

$$\epsilon^{\star} = \rho \frac{L_S (\mathbf{w})}{\| \nabla L_S (\mathbf{w})\|_2}$$

Which is indeed the most aligned vector with $\nabla L_S (\mathbf{w})$ that has a norm equal to $\rho$, resulting in a maximization of the scalar product.

**(4)** (8 points) Now we will train the same model using the SAM optimizer.
Please implement SAM by the two steps below. The first step is for the maximizer which calculates $\epsilon$ obtained in question (1). The second step is the normal step for the minimizer: $\mathbf{w}_{t+1} = \mathbf{w}_{t} - \eta_t \nabla L_S (\mathbf{w}_t + \mathbf{\epsilon}_t)$ where $\eta_t$ is step size. Note that we set $\rho=0.05$.

Hint: be careful about weight updates.

In [30]:
class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, lr=0.01, rho=0.05):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, lr)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    def _grad_norm(self):
        # Note that p.grad gets the gradient; p.data gets the weight.
        norm = torch.norm(
                    torch.stack([
                        p.grad.norm(p=2)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        norm += 1e-12 # Avoid zero norm
        return norm

    @torch.no_grad()
    def first_step(self):
        # Add the perturbation on the weight.
        # Hint: the norm of the gradient can be calculated by _grad_norm() function.
        # Hint: self.param_groups to get access to the weight and gradient of each parameter
        # FILL

        rho = self.param_groups[0]['rho']
        grad_norm = self._grad_norm()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    epsilon_t = rho*p.grad/(grad_norm)
                    p.add_(epsilon_t)
        

        self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        # FILL
        # Hint: Remember to change the parameters back #retirer perturbation ?

        rho = self.param_groups[0]['rho']
        grad_norm = self._grad_norm()

        lr = self.param_groups[0]['lr']
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    grad = p.grad
                    #tjrs le bon epsilon ou faut le return dans first step ?
                    epsilon_t = rho*grad/(grad_norm)
                    p.add_(-lr*grad - epsilon_t)



        self.zero_grad()

Define an optimizer of `SAM` for the model. We recommend using `SGD` as base optimizer with a learning rate of $0.05$ (which is same with SGD).

In [35]:
base_optimizer = torch.optim.SGD
sam_optimizer = SAM(sparse_model_sam.parameters(), base_optimizer, lr=0.05)

**(5)** (4 points) Please define the optimizing process of SAM.

In [36]:
def optimize_sam(model, optimizer, img, label):

    enable_running_stats(model)
    # First forward-backward pass
    # FILL
    # Hint: use sam_optimizer above
    # Hint: use loss 'cross_entropy'
    pred = model(img)
    loss = cross_entropy(pred, label)
    loss.backward()
    optimizer.first_step()

    disable_running_stats(model)
    # Second forward-backward pass
    # FILL
    pred = model(img)
    loss = cross_entropy(pred, label)
    loss.backward()
    optimizer.second_step()

In [37]:
train_model(sparse_model_sam, sam_optimizer, optimize_sam, finetune_epoch)

Epoch 0 with 0.966 accuracy on the validation set.
Epoch 1 with 0.969 accuracy on the validation set.
Epoch 2 with 0.973 accuracy on the validation set.


#### Evaluate model
Evaluate the sparse model finetuned with SAM on the test set.

In [34]:
acc = evaluate(sparse_model_sam)
print(f"Accuracy of {round(acc, 4)} on the test set.")

Accuracy of 0.9758 on the test set.


**(6)** (2 points) Give a conclusion comparing SAM with SGD. Is there any drawback of SAM?

<font color=‘blue’>
    
Write your answer here

Training with SAM allows us to get better accuracy. We almost recover the accuracy that we had without pruning.

However, SAM takes roughly twice as long to train because it has to do two forward-backward passes for the min and max part of the problem. Finding a suitable $\rho$ also adds a difficulty to the problem.