### Deep Learning Homework 5

Starting from the implementation in the notebook `05-pruning.ipynb`, extend the `magnitude_pruning` function to support incremental (iterative) pruning. In its current form, pruning a second time does not work, because later calls of `magnitude_pruning` have no way of knowing which parameters have already been pruned and should be ignored. Enhance the routine so that it can prune networks sequentially (i.e., if we pass an MLP already pruned of 20% of its parameters, we want to prune *another* 20% of parameters).

Hint: use the mask.
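Here is a minimal, framework-free sketch of the idea behind the hint. It assumes the weights of one layer are given as a flat Python list and the mask as a parallel list of 0/1 flags. The helper `prune_step` and the toy values are hypothetical and serve only as an illustration. Already-pruned weights are zeroed before sorting, so they occupy the first `k` sorted positions. Reading the threshold at position `k + rate * (n - k)` then removes `rate` of the *remaining* weights. Multiplying by the old mask keeps pruned weights pruned.

```python
def prune_step(weights, mask, rate):
    """One round of mask-aware magnitude pruning on a flat list (hypothetical helper)."""
    # Treat already-pruned weights as zero, even if their stored value is not
    masked = [abs(w) * m for w, m in zip(weights, mask)]
    already_pruned = mask.count(0)
    # Already-pruned entries (value 0) sort to the front of the list
    ranked = sorted(masked)
    position = int(rate * (len(weights) - already_pruned) + already_pruned)
    thresh = ranked[position]
    # Keep a weight only if it was alive before and survives the new threshold
    return [int(v >= thresh) * m for v, m in zip(masked, mask)]


weights = [0.9, -0.1, 0.5, 0.05, -0.7, 0.3, 0.2, -0.8, 0.6, 0.4]
mask = [1] * len(weights)
for _ in range(3):
    mask = prune_step(weights, mask, 0.2)
    print(mask.count(0), mask)
```

Each call removes 20% of the surviving weights, rounded down, so the number of zeros grows from 2 to 3 to 4 on this ten-element list. This mirrors the threshold computation that `magnitude_pruning` performs on tensors.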

In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import pylab as pl
from IPython.display import clear_output

In [2]:
# Define a model
class MLP(nn.Module):
    """
    Implements a neural network to train on MNIST 
    """
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(784, 384),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(384),
            torch.nn.Linear(384, 384),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(384),
            torch.nn.Linear(384, 384),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(384),
            torch.nn.Linear(384, 10),
            torch.nn.LogSoftmax(dim=1)
        )
        
    def forward(self, X):        
        return self.layers(X)
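
The "Total number of parameters" printed by the pruning tests can be checked by hand from the layer sizes. The short sketch below assumes, as in the test cell, that the names `"1"`, `"4"`, `"7"` and `"10"` select the weight and bias of the four `Linear` layers of `self.layers`, and that the three `BatchNorm1d` layers are left untouched.

```python
# Layer sizes of the MLP above: 784 -> 384 -> 384 -> 384 -> 10
sizes = [784, 384, 384, 384, 10]

# Each nn.Linear(n_in, n_out) holds an (n_out x n_in) weight matrix plus n_out biases
linear_params = [n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:])]
print(linear_params, sum(linear_params))

# Each BatchNorm1d(384) adds 384 trainable weights and 384 biases (not pruned here)
batchnorm_params = 3 * 2 * 384
print(batchnorm_params)
```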

In [3]:
def magnitude_pruning(model, pruning_rate, mask=None, layers_to_prune=None, verbose=False):
    """
    Implements magnitude pruning on a model and returns the pruning mask.
    
    Parameters
    ----------
    model: nn.Module
        Model to which the pruning is applied
    
    pruning_rate: float
        Fraction of the remaining (not yet pruned) weights to set to 0
    
    mask: list of Tensors, Optional
        Mask returned by a previous call, marking the already pruned weights,
        default: None
    
    layers_to_prune: list of strings, Optional
        Substrings identifying the parameters to prune: a parameter is pruned
        if any entry occurs in its name. Useful to avoid pruning batchnorm
        layers. If None, every parameter is pruned. Default: None

    verbose: bool, Optional
        Whether to print pruning statistics, mainly for testing purposes,
        default: False

    Returns
    -------
    mask: list of Tensors
        List with one entry per parameter tensor of the model (in the order
        of `model.named_parameters()`), each a tensor of the same shape filled
        with 0 (weight removed) or 1 (weight kept). The mask is useful to
        retrain the model afterwards or to apply pruning again
    """
    # If no layers are specified, prune every parameter
    if layers_to_prune is None:
        layers_to_prune = [name for name, _ in model.named_parameters()]

    # 1. vectorize distribution of abs(parameter)

    # Zero out the parameters that are already pruned
    # Note: this step ensures that new parameters are pruned even when
    # the previously pruned ones have become non-zero (e.g. after retraining)
    if mask is not None and len(mask) > 0:
        params_to_prune = [par * m for (name, par), m in zip(model.named_parameters(), mask)
                           if any(l in name for l in layers_to_prune)]
    else:
        params_to_prune = [par for name, par in model.named_parameters()
                           if any(l in name for l in layers_to_prune)]
    
    flat = torch.cat([pars.abs().flatten() for pars in params_to_prune], dim=0)
    
    # 2. sort this distribution
    flat = flat.sort()[0]

    # 3. obtain the threshold
    n_params = flat.size()[0]
    # Number of already pruned weights
    already_pruned = 0 if mask is None else sum([torch.sum((m-1).abs()).item() for m in mask])
    position = int(pruning_rate * (n_params - already_pruned) + already_pruned)
    thresh = flat[position]
    
    if verbose:
        s = f"Total number of parameters: {n_params}\n"
        s += f"parameters already pruned: {int(already_pruned)}\n"
        s += f"parameters pruned this run: {int(position-already_pruned)}\n"
        s += f"pruning rate: {(position-already_pruned) / (n_params-already_pruned)}\n"
        print(s)

    # keep the previous mask, so that already pruned weights stay pruned
    old_mask = mask if mask is not None and len(mask) > 0 else None
    # reset the mask
    mask = []
    # 4. binarize the parameters & 5. compose these booleans into the mask &
    # 6. obtain the new structure of parameters
    # A for loop is used instead of a list comprehension for clarity:
    # * if the parameter belongs to a layer to prune -> populate the mask with 1s and 0s
    # * otherwise -> just populate the mask with ones
    # This way, the mask can be applied to the model immediately.
    for i, (name, par) in enumerate(model.named_parameters()):
        # Pruned layers
        if any(l in name for l in layers_to_prune):
            m = torch.where(par.abs() >= thresh, 1, 0)
            if old_mask is not None:
                # weights pruned in earlier calls stay pruned, even if
                # retraining has moved them away from zero
                m = m * old_mask[i]
            mask.append(m)
            par.data *= m
        # Unpruned layers
        else:
            mask.append(torch.ones_like(par))

    # 7. what do we need to return?
    return mask
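
`position` is computed as `already_pruned + pruning_rate * (n_params - already_pruned)`, so each call removes a fixed fraction of the weights that are still alive, and the number of newly pruned weights shrinks geometrically. The sketch below replays only that integer arithmetic, with no tensors involved. `pruning_schedule` is a hypothetical helper and is not part of the notebook. It reproduces the counts in the log printed by the test cell.

```python
def pruning_schedule(n_params, rate, rounds):
    """Replays the threshold-position arithmetic of magnitude_pruning (hypothetical helper)."""
    already_pruned = 0
    history = []
    for _ in range(rounds):
        # Same formula as in magnitude_pruning: skip the already pruned weights,
        # then take `rate` of the remaining ones
        position = int(rate * (n_params - already_pruned) + already_pruned)
        history.append((already_pruned, position - already_pruned))
        already_pruned = position
    return history


for done, new in pruning_schedule(600970, 0.2, 6):
    print(f"already pruned: {done:>6}  pruned this run: {new:>6}")
```

After `r` rounds, roughly `0.8 ** r` of the prunable weights survive. After the ten rounds of the test cell, that is about 10.7%.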

In [4]:
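net = MLP()

# Test one-shot pruning on a separate, freshly initialized model,
# so that the iterative test below starts from an unpruned network
print("One shot pruning:\n")
_ = magnitude_pruning(MLP(), 0.2, layers_to_prune=["1", "4", "7", "10"], verbose=True)

# Test iterative pruning: the mask returned by each call is fed to the next one
print("\nRecursive pruning:\n")
mask = []
for i in range(10):
    mask = magnitude_pruning(net, 0.2, mask, layers_to_prune=["1", "4", "7", "10"], verbose=True)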

One shot pruning:

Total number of parameters: 600970
parameters already pruned: 0
parameters pruned this run: 120194
pruning rate: 0.2


Recursive pruning:

Total number of parameters: 600970
parameters already pruned: 0
parameters pruned this run: 120194
pruning rate: 0.2

Total number of parameters: 600970
parameters already pruned: 120194
parameters pruned this run: 96155
pruning rate: 0.19999958400585718

Total number of parameters: 600970
parameters already pruned: 216349
parameters pruned this run: 76924
pruning rate: 0.19999948000759188

Total number of parameters: 600970
parameters already pruned: 293273
parameters pruned this run: 61539
pruning rate: 0.1999987000198247

Total number of parameters: 600970
parameters already pruned: 354812
parameters pruned this run: 49231
pruning rate: 0.19999756254113213

Total number of parameters: 600970
parameters already pruned: 404043
parameters pruned this run: 39385
pruning rate: 0.1999979687904655

Total number of parameters: 600970
p