<a href="https://colab.research.google.com/github/wandb/edu/blob/main/lightning/performance/pruning_cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://i.imgur.com/gb6B4ig.png" width="400" alt="Weights & Biases" />

# Pruning and Sparsity

In [None]:
%%capture

!pip install wandb pytorch_lightning==1.3.2 torchviz

repo_url = "https://raw.githubusercontent.com/wandb/edu/main/"
utils_path = "lightning/utils.py"
# Download a util file of helper methods for this notebook
!curl {repo_url + utils_path} > utils.py

In [None]:
import math

# usual DL imports
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import wandb

# special import for pruning with Lightning
from pytorch_lightning.callbacks import ModelPruning

# utilities for PyTorch Lightning and wandb
import utils

In [None]:
!wandb login

# Utilities for Tracking and Logging Sparsity

_Note_: The methods and classes in this section just handle details of logging
pruned networks. These details are subject to change as the `torch.nn.utils.prune`
library develops, and are not important to understanding pruning and sparsity in neural networks.
This section may be safely skipped.
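The helpers below rely on how `torch.nn.utils.prune` reparametrizes a pruned tensor. The original values move to a parameter named `weight_orig`, a binary `weight_mask` buffer is registered, and `weight` becomes their product, which is recomputed before each forward pass. Here is a minimal sketch on a standalone layer (the layer sizes are arbitrary):

```python
import torch
from torch.nn.utils import prune

torch.manual_seed(0)
layer = torch.nn.Linear(4, 3)  # 12 weights
prune.l1_unstructured(layer, name="weight", amount=0.5)  # zero the 6 smallest-magnitude weights

param_names = sorted(name for name, _ in layer.named_parameters())
buffer_names = [name for name, _ in layer.named_buffers()]
print(param_names, buffer_names)  # ['bias', 'weight_orig'] ['weight_mask']

# the effective weight is the original weight times the mask
assert torch.equal(layer.weight, layer.weight_orig * layer.weight_mask)
print(int((layer.weight == 0).sum()))  # 6
```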

In [None]:
class SparsityLogCallback(pl.Callback):
  """PyTorch Lightning Callback for logging the sparsity of weight tensors in a PyTorch Module.
  """

  def __init__(self):
    super().__init__()

  def on_validation_epoch_end(self, trainer, module):
    self.log_sparsities(trainer, module)

  def get_sparsities(self, module):
    weights = self.get_weights(module)
    names = [".".join(name.split(".")[:-1]) for name, _ in module.named_parameters()
             if "weight" in name.split(".")[-1]]
    sparsities = [torch.sum(weight == 0) / weight.numel() for weight in weights]

    return {"sparsity/" + name: sparsity for name, sparsity in zip(names, sparsities)}

  def log_sparsities(self, trainer, module):
    sparsities = self.get_sparsities(module)
    sparsities["sparsity/total"] = 1 - fraction_nonzero(module)
    sparsities["global_step"] = trainer.global_step
    trainer.logger.experiment.log(sparsities)


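@staticmethod
def get_weights(module):
  """Returns the effective weight tensors of a module, applying pruning masks.

  While torch.nn.utils.prune is active, a pruned "weight" is stored as a
  "weight_orig" parameter plus a "weight_mask" buffer. Masks are matched to
  parameters by name, so networks with only some layers pruned are handled too."""
  masks = {name[:-len("_mask")]: mask for name, mask in module.named_buffers()
           if name.split(".")[-1] == "weight_mask"}
  weights = []
  with torch.no_grad():
    for name, parameter in module.named_parameters():
      if "weight" not in name.split(".")[-1]:
        continue
      key = name[:-len("_orig")] if name.endswith("_orig") else name
      weights.append(masks[key] * parameter if key in masks else parameter)

  return weights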

SparsityLogCallback.get_weights = get_weights
# patches the FilterLogCallback for compatibility with networks during pruning
utils.FilterLogCallback.get_weights = get_weights


def count_nonzero(module):
  """Counts the total number of non-zero parameters in a module.

  For compatibility with networks with active torch.nn.utils.prune methods,
  checks for _mask buffers, which are applied during forward passes and so
  represent the actual sparsity of the network."""
  # e.g. "fc_layers.0.linear.weight_mask" -> key "fc_layers.0.linear.weight"
  masks = {name[:-len("_mask")]: mask_tensor for name, mask_tensor in module.named_buffers()
           if name.endswith("_mask")}

  nparams = 0
  with torch.no_grad():
    for name, tensor in module.named_parameters():
      # a pruned tensor is stored as "<name>_orig", with its mask under "<name>_mask"
      key = name[:-len("_orig")] if name.endswith("_orig") else name
      if key in masks:
        tensor = masks[key]
      nparams += int(torch.sum(tensor != 0))

  return nparams


def fraction_nonzero(lit_module):
  """Gives the fraction of parameters that are non-zero in a module."""

  return count_nonzero(lit_module) / lit_module.count_params()

# Setup Code: Model, Data, and Configuration

In [None]:
class FullyConnected(pl.LightningModule):

  def __init__(self, in_features, out_features, activation=None, dropout=0.):
    super().__init__()
    self.linear = torch.nn.Linear(in_features, out_features)
    if activation is None:  # defaults to passing inputs unchanged
      activation = torch.nn.Identity()
    self.activation = activation

    if dropout:
      self.post_act = torch.nn.Dropout(dropout)
    else:
      self.post_act = torch.nn.Identity()

  def forward(self, x):
    return self.post_act(self.activation(self.linear(x)))
    
    
class Convolution(pl.LightningModule):

  def __init__(self, in_channels, out_channels, kernel_size,
               activation=None, dropout=0.):
    super().__init__()
    self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size)
    if activation is None:
      activation = torch.nn.Identity()  # defaults to passing inputs unchanged
    self.activation = activation

    if dropout:
      self.post_act = torch.nn.Dropout2d(dropout)
    else:
      self.post_act = torch.nn.Identity()

  def forward(self, x):
    return self.post_act(self.activation(self.conv2d(x)))


class LitCNN(utils.LoggedImageClassifierModule):
  """A simple CNN Model, with under-the-hood wandb
  and pytorch-lightning features (logging, metrics, etc.).
  """

  def __init__(self, config, max_images_to_display=32):  # make the model
    super().__init__(max_images_to_display=max_images_to_display)

    # first, convolutional component
    self.conv_layers = torch.nn.Sequential(
      # hidden conv layer
      Convolution(in_channels=1, kernel_size=config["kernel_size"],
                  activation=config["activation"],
                  out_channels=config["conv.channels"][0],
                  dropout=config["conv.dropout"]),
      # hidden conv layer
      Convolution(in_channels=config["conv.channels"][0], kernel_size=config["kernel_size"],
                  activation=config["activation"],
                  out_channels=config["conv.channels"][1],
                  dropout=config["conv.dropout"]),
      # pooling often follows 2 convs
      torch.nn.MaxPool2d(config["pool_size"]),
    )

    # need a fixed-size input for fully-connected component,
    #  so apply a "re-sizing" layer, to size set in config
    self.resize_layer = torch.nn.AdaptiveAvgPool2d(
      (config["final_height"], config["final_width"]))

    # now, we can apply our fully-connected component
    final_size = config["final_height"] * config["final_width"] * config["conv.channels"][-1]
    self.fc_layers = torch.nn.Sequential( # specify our LEGOs. edit this by adding to the list!
      FullyConnected(in_features=final_size, activation=config["activation"],
                     out_features=config["fc1.size"],
                     dropout=config["fc.dropout"]),
      FullyConnected(in_features=config["fc1.size"], activation=config["activation"],
                     out_features=config["fc2.size"],
                     dropout=config["fc.dropout"]),
      FullyConnected(in_features=config["fc2.size"],  # "read-out" layer
                     out_features=10),
    )

    self.loss = config["loss"]
    self.optimizer = config["optimizer"]
    self.optimizer_params = config["optimizer.params"]
    config.update({f"channels_{ii}": channels
                   for ii, channels in enumerate(config["conv.channels"])})

  def forward(self, x):  # produce outputs
    # first apply convolutional layers
    for layer in self.conv_layers: 
      x = layer(x)

    # then convert to a fixed-size vector
    x = self.resize_layer(x)
    x = torch.flatten(x, start_dim=1)

    # then apply the fully-connected layers
    for layer in self.fc_layers: # snap together the LEGOs
      x = layer(x)

    return F.log_softmax(x, dim=1)  # compute log of softmax, for numerical reasons

  def configure_optimizers(self):  # ⚡: setup for .fit
    return self.optimizer(self.parameters(), **self.optimizer_params)

In [None]:
config = {  # basic config, without pruning
  "batch_size": 512,
  "max_epochs": 5,
  "kernel_size": 9,
  "conv.channels": [128, 256],
  "conv.dropout": 0.5,
  "pool_size": 2,
  "final_height": 10,
  "final_width": 10,
  "fc1.size": 1024,
  "fc2.size": 512,
  "fc.dropout": 0.5,
  "activation": torch.nn.ReLU(),
  "loss": torch.nn.NLLLoss(),  # cross-entropy loss
  "optimizer": torch.optim.Adam,
  "optimizer.params": {"lr": 0.0001,
                       "weight_decay": 5e-3}  # weight decay makes weights decay to 0
}

In [None]:
# 📸 set up the dataset of images
dmodule = utils.MNISTDataModule(batch_size=config["batch_size"])
dmodule.prepare_data()
dmodule.setup()

# Training Code

In [None]:
def train(network, dmodule, config):

  with wandb.init(config=config, entity="wandb", project="prune", job_type="profile") as run:
    
    callbacks = []
    # Pruning:
    #  if doing model pruning, add in callbacks
    for prune_config in config["pruning"].values():
      callbacks.append(make_pruner(prune_config, network, n_epochs=config["max_epochs"]))

    filter_logger_callback = utils.FilterLogCallback(
      image_size=[], log_input=True, log_output=False)
    sparsity_logger_callback = SparsityLogCallback()

    callbacks.extend([filter_logger_callback, sparsity_logger_callback])

    # 👟 configure Trainer 
    trainer = pl.Trainer(
        precision=16,  # use half-precision floats
        gpus=1,  # use the GPU for .forward
        logger=pl.loggers.WandbLogger(
          log_model=True, save_code=True),  # log to Weights & Biases
        callbacks=callbacks,
        max_epochs=config["max_epochs"], log_every_n_steps=1,
        progress_bar_refresh_rate=50)
                        
    # 🏃‍♀️ run the Trainer on the model
    trainer.fit(network, dmodule)
    
  return network

# Baseline: No Pruning

In [None]:
# 🥅 instantiate the network
network = LitCNN(config)

config["pruning"] = {}
network = train(network, dmodule, config)

# Pruning Networks

## Pruning Callbacks: Global and Local
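Pruning can be *global* or *local*. A local pruner removes a fixed fraction of each tensor separately. A global pruner pools the tensors and removes the smallest-magnitude weights overall, so layers whose weights are small end up sparser. In the configs below, `global_prune_config` sets `use_global_unstructured` to `True`, while the `conv` and `linear` pruners work layer by layer. The following minimal sketch uses `torch.nn.utils.prune` directly, on the assumption that this is roughly what `ModelPruning` does under the hood. The hand-set weights are purely illustrative.

```python
import torch
from torch.nn.utils import prune

def make_layers():
  """Two layers with hand-set weights: `big` has larger magnitudes than `small`."""
  big, small = torch.nn.Linear(10, 10), torch.nn.Linear(10, 10)
  with torch.no_grad():
    big.weight.copy_(torch.linspace(1., 2., 100).reshape(10, 10))
    small.weight.copy_(torch.linspace(0.01, 0.02, 100).reshape(10, 10))
  return big, small

def sparsity(layer):
  """Fraction of the layer's (effective) weights that are exactly zero."""
  return float((layer.weight == 0).float().mean())

# local: prune 50% of each layer's weights independently
big, small = make_layers()
for layer in (big, small):
  prune.l1_unstructured(layer, name="weight", amount=0.5)
local_result = (sparsity(big), sparsity(small))
print("local: ", *local_result)  # 0.5 0.5

# global: prune 50% of the pooled weights, wherever they are smallest
big, small = make_layers()
prune.global_unstructured([(big, "weight"), (small, "weight")],
                          pruning_method=prune.L1Unstructured, amount=0.5)
global_result = (sparsity(big), sparsity(small))
print("global:", *global_result)  # 0.0 1.0
```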

In [None]:
# Helper functions for building ModelPruning Callbacks
#  details here are mostly unimportant; see the docs for ModelPruning
#  for more on how pruning works and is configured

def make_pruner(prune_config, network=None, n_epochs=None):
  """Builds a ModelPruning PyTorch Lightning Callback from a dictionary.

  Aside from the keyword arguments to pl.callbacks.ModelPruning, this dictionary
  may contain the keys "target_sparsity" and "parameters".

  target_sparsity is combined with n_epochs to determine the value of the
  "amount" keyword argument to ModelPruning, which specifies how much pruning to
  do on each epoch.

  parameters can be None, "conv", or "linear". It is used to fetch the
  parameters which are to be pruned from the provided network. See
  get_parameters_to_prune for details. Note that None corresponds to pruning
  all parameters.

  The input dictionary is copied, not modified, so the same config can be
  reused across calls to train.
  """
  prune_config = dict(prune_config)  # shallow copy: keys are popped below
  if "target_sparsity" in prune_config.keys():
    target = prune_config.pop("target_sparsity")
    assert n_epochs is not None, "when specifying target sparsity, must provide number of epochs"
    prune_config["amount"] = compute_iterative_prune(target, n_epochs)

  assert "amount" in prune_config.keys(), "must specify stepwise pruning amount or target"

  if "parameters" in prune_config.keys():
    parameters = prune_config.pop("parameters")
    if parameters is not None:
      assert network is not None, "when specifying parameters, must provide network"
    prune_config["parameters_to_prune"] = get_parameters_to_prune(parameters, network)

  assert "parameters_to_prune" in prune_config.keys(), "must specify which parameters to prune, or None"

  return ModelPruning(**prune_config)


def get_parameters_to_prune(parameters, network):
  """Return the weights of network matching the parameters value.

  Parameters must be one of "conv" or "linear", or None,
  in which case None is also returned.
  """
  if parameters == "conv":
    return [(layer.conv2d, "weight") for layer in network.conv_layers
            if isinstance(layer, Convolution)]
  elif parameters == "linear":
    return [(layer.linear, "weight") for layer in network.fc_layers
            if isinstance(layer, FullyConnected)]
  elif parameters is None:
    return
  else:
    raise ValueError(f"could not understand parameters value: {parameters}")

def compute_iterative_prune(target_sparsity, n_epochs):
  return 1 - math.pow(1 - target_sparsity, 1 / n_epochs)
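The formula in `compute_iterative_prune` comes from the following argument. Assume, as with iterative pruning in `torch.nn.utils.prune`, that each epoch prunes a fraction $a$ of the weights that are still unpruned. After $n$ epochs, a fraction $(1-a)^n$ of the weights survives. Setting the final sparsity equal to the target $s$ and solving for $a$ gives:

$$
1 - (1 - a)^n = s
\quad\Longrightarrow\quad
a = 1 - (1 - s)^{1/n}.
$$

For example, $s = 0.9$ over $n = 5$ epochs gives $a = 1 - 0.1^{1/5} \approx 0.369$.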

In [None]:
global_prune_config = {  # config for applying pruning to the entire network
  "parameters": None,
  "pruning_fn": "l1_unstructured",
  "target_sparsity": 0.9,  # target sparsity level for this pruner
  "use_global_unstructured": True,
  "pruning_dim": None,
  "pruning_norm": None,
}

conv_prune_config = {  # config for applying pruning channelwise to conv layers
  "parameters": "conv", 
  "pruning_fn": "ln_structured",
  "target_sparsity": 0.9,
  "use_global_unstructured": False,
  "pruning_dim": 0,
  "pruning_norm": 1,
}

linear_prune_config = {  # config for applying pruning featurewise to linear layers
  "parameters": "linear",
  "pruning_fn": "ln_structured",
  "target_sparsity": 0.9,
  "use_global_unstructured": False,
  "pruning_dim": 1,
  "pruning_norm": 1,
}

pnetwork = LitCNN(config)

config["pruning"] = {}
config["pruning"]["global"] = global_prune_config  # comment to remove global pruning
config["pruning"]["conv"] = conv_prune_config  # comment to remove channelwise pruning in conv layers
config["pruning"]["linear"] = linear_prune_config  # comment to remove featurewise pruning in linear layers

pnetwork = train(pnetwork, dmodule, config)

# Exercises

#### 1. Training a Network with 99% Sparsity

With the default settings above
(unstructured pruning applied globally
and feature-wise pruning applied to linear and convolutional layers,
each with a target sparsity of `0.9`),
typical final sparsities are close to 99%.

Review the Weights & Biases dashboard for a training run with these settings.

Do you notice anything interesting in the loss and accuracy traces?
Without pruning, these are typically close to monotonic, especially on the training set.

The "sparsity" section tracks the degree of sparsity for
each layer and for the network as a whole.

With a multiplicative, iterative pruning strategy like the one used here,
the fraction of remaining weights pruned at each step stays constant.
As a result, the largest absolute increase in sparsity comes at the end of the first epoch,
when it increases from `0` to `0.33`,
and each later step adds less.
The magnitude of the weights pruned in each step
generally increases throughout training.
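A few lines of arithmetic are enough to check the contrast between absolute and relative increases. The sketch below models a single idealized pruner on a fixed set of weights, reusing the formula from `compute_iterative_prune`. It ignores interactions between the global and structured pruners, so its numbers will not match the dashboard exactly.

```python
import math

def compute_iterative_prune(target_sparsity, n_epochs):
  # same formula as the notebook's helper
  return 1 - math.pow(1 - target_sparsity, 1 / n_epochs)

def sparsity_schedule(target_sparsity, n_epochs):
  """Returns (sparsity, absolute_increase, relative_increase) after each epoch."""
  amount = compute_iterative_prune(target_sparsity, n_epochs)
  schedule, sparsity = [], 0.
  for _ in range(n_epochs):
    remaining = 1 - sparsity
    new_sparsity = sparsity + amount * remaining  # prune `amount` of what is left
    increase = new_sparsity - sparsity
    schedule.append((new_sparsity, increase, increase / remaining))
    sparsity = new_sparsity
  return schedule

for epoch, (s, abs_inc, rel_inc) in enumerate(sparsity_schedule(0.9, 5), start=1):
  print(f"epoch {epoch}: sparsity {s:.3f}, absolute +{abs_inc:.3f}, relative {rel_inc:.3f}")
```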

Based on training set performance,
does the accuracy hit caused by pruning track
a) the absolute increase in sparsity,
b) the relative increase in sparsity (the fraction of remaining weights pruned),
or c) the magnitude of the pruned weights?

The weights in the input layer are visualized as images
alongside the training and validation metrics
(the step index, on the x axis, is shared between the charts).
Each patch, separated by black bars from the other patches,
represents the kernel for a single unit in the convolutional layer.
During training, weights are pruned by being set to 0,
which is rendered as gray here.
Notice that some kernels are entirely gray,
and that by the end almost all of the kernels are either fully or mostly gray.
Relate these and other features of this chart to the sparsity over time curves
and to the pruning strategies employed.
Try removing the convolutional pruning
(`config["pruning"]["conv"]`),
re-running training,
and note the differences.
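Channelwise structured pruning is what produces fully gray kernels. The sketch below applies `torch.nn.utils.prune.ln_structured` with `n=1, dim=0` to a small stand-in convolution. These arguments mirror `pruning_norm` and `pruning_dim` in `conv_prune_config`, on the assumption that ModelPruning's `"ln_structured"` forwards them this way. For comparison, it also applies unstructured L1 pruning at the same amount. The layer sizes are arbitrary.

```python
import torch
from torch.nn.utils import prune

torch.manual_seed(0)
structured = torch.nn.Conv2d(1, 8, kernel_size=3)    # 8 kernels of 3x3 weights (72 total)
unstructured = torch.nn.Conv2d(1, 8, kernel_size=3)

# channelwise: remove the 4 output channels (kernels) with the smallest L1 norm
prune.ln_structured(structured, name="weight", amount=0.5, n=1, dim=0)
# unstructured: remove the 36 individual weights with the smallest magnitude
prune.l1_unstructured(unstructured, name="weight", amount=0.5)

def fully_zero_kernels(conv):
  """Number of output channels whose kernel weights are all zero."""
  return int((conv.weight.flatten(start_dim=1) == 0).all(dim=1).sum())

print(fully_zero_kernels(structured))         # 4
print(int((unstructured.weight == 0).sum()))  # 36
```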

It can be helpful to additionally compare these weights
with the weights learned in the baseline,
with no pruning.

#### 2. Train Longer, Get Sparser?

In typical pruning strategies,
pruning is applied once per epoch.

When epochs are shorter (larger `batch_size`, so fewer optimizer steps per epoch)
or when there are fewer of them (smaller `max_epochs`),
the impact of pruning on accuracy can be greater.

Test out this statement by checking for
- reduced performance with smaller `max_epochs` (2-3) and/or larger `batch_size` (2x to 4x larger)
- improved performance with larger `max_epochs` (up to 50) and/or smaller `batch_size` (2x to 8x smaller)
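Pruning happens once per epoch, so recovery depends roughly on how many optimizer steps the network gets between pruning events. The sketch below assumes a training set of 60,000 images, which is the full MNIST training split. `utils.MNISTDataModule` may hold some of these out for validation. It also assumes the last, partial batch is not dropped.

```python
import math

def steps_between_prunes(n_train, batch_size):
  """Optimizer steps per epoch, counting the last, partial batch as a step."""
  return math.ceil(n_train / batch_size)

for batch_size in (128, 512, 2048):
  print(batch_size, steps_between_prunes(60_000, batch_size))
# 128 469
# 512 118
# 2048 30
```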

Can you find a setting of the parameters for which
the final network is approximately 99% sparse
but where validation accuracy is 95% or higher?
Note that, with the default configuration above,
about 10 out of every 11 parameters (roughly 91%)
are in the linear layers.
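The layer shapes in `LitCNN` determine how parameters split between the convolutional and linear layers. The sketch below recomputes that split from the relevant config keys. The helper is hypothetical and assumes single-channel inputs, square kernels, and 10 classes, as in the model above. In the notebook itself, `count_litcnn_params(config)` would work directly on the full config.

```python
def count_litcnn_params(config, in_channels=1, n_classes=10):
  """Returns (conv, linear) parameter counts, weights plus biases, for LitCNN."""
  k = config["kernel_size"]
  c0, c1 = config["conv.channels"]
  conv = (in_channels * c0 * k * k + c0) + (c0 * c1 * k * k + c1)

  final_size = config["final_height"] * config["final_width"] * c1
  fc1, fc2 = config["fc1.size"], config["fc2.size"]
  linear = (final_size * fc1 + fc1) + (fc1 * fc2 + fc2) + (fc2 * n_classes + n_classes)
  return conv, linear

# the relevant entries of the default config above
default = {"kernel_size": 9, "conv.channels": [128, 256], "final_height": 10,
           "final_width": 10, "fc1.size": 1024, "fc2.size": 512}
conv, linear = count_litcnn_params(default)
print(conv, linear, f"{linear / (conv + linear):.3f}")  # 2664960 26745354 0.909
```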

Networks trained for many epochs may end up with training losses
significantly higher than the final validation loss,
especially when feature-wise pruning,
as in the `linear` and `conv` pruners,
is turned on.
Why might this be?

_Hint_: DropOut is disabled during validation.

#### 3. Structured Pruning Considered Harmful

Structured pruning,
which removes entire input or output channels,
leads more easily to speedups:
create a new, smaller network with all of the pruned
neurons removed
(effectively reducing the number of neurons in each layer).
This takes some care, but it is straightforward,
and the smaller network runs inference faster on commodity CPU/GPU hardware.
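The linear layers here use featurewise (`pruning_dim=1`) pruning. An input feature whose weight column is entirely zero never affects the layer's output, so it can be removed together with the unit in the previous layer that produces it. The sketch below performs that surgery on a pair of `torch.nn.Linear` layers, for example `network.fc_layers[0].linear` and `network.fc_layers[1].linear`. The helper name is hypothetical. The sketch assumes pruning has already been made permanent, so that `.weight` is a plain parameter with zeroed columns.

```python
import torch

def remove_dead_inputs(prev, layer):
  """Hypothetical helper: drops the input features of `layer` whose weight columns
  are all zero, along with the matching output units of `prev`.

  The new pair computes the same function as long as only elementwise
  operations (e.g. ReLU) sit between the two layers."""
  with torch.no_grad():
    keep = (layer.weight != 0).any(dim=0)  # True for columns with a non-zero weight
    idx = keep.nonzero(as_tuple=True)[0]

    new_prev = torch.nn.Linear(prev.in_features, len(idx), bias=prev.bias is not None)
    new_prev.weight.copy_(prev.weight[idx])  # keep only the units still in use
    if prev.bias is not None:
      new_prev.bias.copy_(prev.bias[idx])

    new_layer = torch.nn.Linear(len(idx), layer.out_features, bias=layer.bias is not None)
    new_layer.weight.copy_(layer.weight[:, idx])  # drop the all-zero columns
    if layer.bias is not None:
      new_layer.bias.copy_(layer.bias)
  return new_prev, new_layer

# usage: zero out two input features of `layer`, as featurewise pruning would
torch.manual_seed(0)
prev, layer = torch.nn.Linear(8, 6), torch.nn.Linear(6, 3)
with torch.no_grad():
  layer.weight[:, [1, 4]] = 0.
small_prev, small_layer = remove_dead_inputs(prev, layer)

x = torch.randn(5, 8)
full = layer(torch.relu(prev(x)))
small = small_layer(torch.relu(small_prev(x)))
print(small_prev.out_features, torch.allclose(full, small, atol=1e-6))  # 4 True
```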

By comparison, unstructured pruning,
which removes specific weights,
can only lead to speedups
when using special-purpose hardware and software for sparse matrix
multiplication, which,
as of early 2021,
were not widely used or available.

However, structured pruning can have more deleterious effects on training.

Turn off the structured pruners
(featurewise `linear` and channelwise `conv`)
and increase the global sparsity target to `0.99`.
This will train a network that only has global unstructured pruning
to the same sparsity level as one with the original default settings
(sparsity target `0.9` for each pruner separately).
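The reason for `0.99` is that each linear or convolutional weight faces two pruners: the global one and a structured one. The back-of-the-envelope sketch below treats the global pruner as if it removed 90% of every layer. It also assumes the pruners' masks are independent, so a weight survives only by escaping both. Real masks overlap, which is why observed sparsities land close to 99% rather than exactly at it.

```python
import math

def combined_sparsity(*targets):
  """Sparsity after several pruners, assuming their masks are independent."""
  survival = math.prod(1 - t for t in targets)  # fraction surviving every pruner
  return 1 - survival

print(round(combined_sparsity(0.9, 0.9), 4))  # 0.99
```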

Compare the final validation performance
of the two networks
(and make sure the other hyperparameters are the same!).
Which network does better?

_Note_: though the distinction between, say, 94% accuracy and 97% accuracy
is small in absolute terms, it represents a reduction by half of
the number of errors -- more akin to the difference between 33% and 66% accuracy
than between 33% and 36%.
These seemingly marginal improvements are indeed worth fighting over,
provided accuracy is in fact the correct model metric.
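The same comparison can be expressed in terms of error rates. The helper below is a hypothetical quick check and is not part of the notebook's utilities:

```python
def error_ratio(acc_before, acc_after):
  """Ratio of error rates: 0.5 means the number of errors was halved."""
  return (1 - acc_after) / (1 - acc_before)

print(f"{error_ratio(0.94, 0.97):.2f}")  # 0.50
print(f"{error_ratio(0.33, 0.66):.2f}")  # 0.51
print(f"{error_ratio(0.33, 0.36):.2f}")  # 0.96
```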

#### Challenge: DropOut and Sparsity



DropOut is effectively a form of random
feature-wise pruning, applied at each training step.
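During training, `Dropout2d` (used after the convolutional layers here) zeroes each channel independently with probability `p`. It rescales the survivors by `1 / (1 - p)` and draws a fresh mask on every call. In evaluation mode, it passes inputs through unchanged. The small check below illustrates this behavior; the tensor shape is arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
p = 0.5
x = torch.ones(2, 8, 4, 4)  # (batch, channels, height, width), all ones

y = F.dropout2d(x, p=p, training=True)
per_channel = y.reshape(2 * 8, 4 * 4)  # one row per (example, channel) pair
# each channel is either fully "pruned" (all 0) or kept and scaled to 1 / (1 - p) = 2
assert all(bool((row == 0).all() or (row == 2.).all()) for row in per_channel)
print("fraction of channels dropped:", float((per_channel[:, 0] == 0).float().mean()))

# with training=False (as during validation), dropout does nothing
assert torch.equal(F.dropout2d(x, p=p, training=False), x)
```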

It seems plausible that including DropOut during training
might reduce the magnitude of the "jumps" in the loss
caused by pruning -- the network has already learned a strategy
that is robust to random pruning, so perhaps it is also more robust
to structured pruning.

Test this hypothesis with a series
of runs using varying pruning strategies,
first with and then without DropOut
(you can turn off DropOut by setting the `.dropout` hyperparameters to `0.`).

Does it hold up?
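The helper below is one way to organize these runs. It is hypothetical and not part of the notebook's utilities. It yields a separate copy of the config for every pruning strategy, both with and without DropOut. Deep copies matter here because `LitCNN` adds keys to the config it receives, and pruner dictionaries should not be shared between runs.

```python
import copy
import itertools

def dropout_ablation_configs(base_config, strategies):
  """Yields (name, config) pairs for every pruning strategy, with and without dropout.

  `strategies` maps a strategy name to a dict shaped like config["pruning"]."""
  for (name, pruning), use_dropout in itertools.product(strategies.items(), (True, False)):
    cfg = copy.deepcopy(base_config)
    cfg["pruning"] = copy.deepcopy(pruning)
    if not use_dropout:  # turn off DropOut in both components
      cfg["conv.dropout"] = 0.
      cfg["fc.dropout"] = 0.
    yield f"{name}/{'dropout' if use_dropout else 'no-dropout'}", cfg
```

In the notebook, each `(name, cfg)` pair could then be passed to `train(LitCNN(cfg), dmodule, cfg)`.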