<a href="https://colab.research.google.com/github/wandb/edu/blob/main/lightning/performance/quantization_cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://i.imgur.com/gb6B4ig.png" width="400" alt="Weights & Biases" />

# Quantization

_Note_: This notebook trains and profiles multiple models
and can take several minutes to run.

We recommend that you run all of the cells first (Runtime > Run All)
and then read through the code and exercises while the notebook executes.

You can also check out the results in the public
[W&B Workspace for this notebook](https://wandb.ai/wandb/quantize/workspace).

In [None]:
%%capture

!pip install wandb pytorch_lightning==1.3.2 torchviz

repo_url = "https://raw.githubusercontent.com/wandb/edu/main/"
utils_path = "lightning/utils.py"
# Download a util file of helper methods for this notebook
!curl {repo_url + utils_path} > utils.py

In [None]:
# standard libraries, for profiling
import os
import time

# usual DL imports
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import wandb

# special import for quantization with Lightning
from pytorch_lightning.callbacks import QuantizationAwareTraining

# utilities for PyTorch Lightning and wandb
import utils

In [None]:
# filter out deprecation warnings from torch.quantization
import warnings

warnings.filterwarnings(action="ignore", category=DeprecationWarning, module=r"torch.quantization")
warnings.filterwarnings(action="ignore", category=UserWarning, module=r"torch.quantization")

In [None]:
!wandb login

# Setup Code: Model, Data, and Configuration

In [None]:
def print_model_size(network):
  """Save model to disk and print filesize"""
  torch.save(network.state_dict(), "tmp.pt")
  size_mb = os.path.getsize("tmp.pt") / 1e6
  print(f"{round(size_mb, 2)} MB")
  os.remove('tmp.pt')
  return size_mb


class FullyConnected(pl.LightningModule):

  def __init__(self, in_features, out_features,
               activation=None, batchnorm=False):
    super().__init__()
    self.linear = torch.nn.Linear(in_features, out_features)

    if activation is None:  # defaults to passing inputs unchanged
      activation = torch.nn.Identity()
    self.activation = activation

    if batchnorm:
      self.post_act = torch.nn.BatchNorm1d(out_features)
    else:
      self.post_act = torch.nn.Identity()

  def forward(self, x):
    return self.post_act(self.activation(self.linear(x)))


class Convolution(pl.LightningModule):

  def __init__(self, in_channels, out_channels, kernel_size,
               activation=None, batchnorm=False):
    super().__init__()
    self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size)

    if activation is None:
      activation = torch.nn.Identity()  # defaults to passing inputs unchanged
    self.activation = activation

    if batchnorm:
      self.post_act = torch.nn.BatchNorm2d(out_channels)
    else:
      self.post_act = torch.nn.Identity()

  def forward(self, x):
    return self.post_act(self.activation(self.conv2d(x)))


class LitCNN(utils.LoggedImageClassifierModule):
  """A simple CNN Model, with under-the-hood wandb
  and pytorch-lightning channels (logging, metrics, etc.).
  """

  def __init__(self, config, max_images_to_display=32):  # make the model
    super().__init__(max_images_to_display=max_images_to_display)

    # first, convolutional component
    self.conv_layers = torch.nn.Sequential(  # specify our LEGOs. edit this by adding to the list!
      # hidden conv layer
      Convolution(in_channels=1, kernel_size=config["kernel_size"],
                  activation=config["activation"],
                  out_channels=config["conv.channels"][0],
                  batchnorm=config["batchnorm"]),
      # hidden conv layer
      Convolution(in_channels=config["conv.channels"][0], kernel_size=config["kernel_size"],
                  activation=config["activation"],
                  out_channels=config["conv.channels"][1],
                  batchnorm=config["batchnorm"]),
      # pooling often follows 2 convs
      torch.nn.MaxPool2d(config["pool_size"]),
    )


    # need a fixed-size input for fully-connected component,
    #  so apply a "re-sizing" layer, to size set in config
    self.resize_layer = torch.nn.AdaptiveAvgPool2d(
      (config["final_height"], config["final_width"]))

    # now, we can apply our fully-connected component
    final_size = config["final_height"] * config["final_width"] * config["conv.channels"][-1]
    self.fc_layers = torch.nn.Sequential( # specify our LEGOs. edit this by adding to the list!
      FullyConnected(in_features=final_size, activation=config["activation"],
                     out_features=config["fc1.size"], batchnorm=False),
      FullyConnected(in_features=config["fc1.size"], activation=config["activation"],
                     out_features=config["fc2.size"], batchnorm=False),
      FullyConnected(in_features=config["fc2.size"],  # "read-out" layer
                     out_features=10, batchnorm=False),
    )

    self.output_layer = torch.nn.LogSoftmax(dim=1)
    
    # for quantization
    self.quant = torch.quantization.QuantStub()  # quantize inputs
    self.dequant = torch.quantization.DeQuantStub()   # dequantize outputs

    self.loss = config["loss"]
    self.optimizer = config["optimizer"]
    self.optimizer_params = config["optimizer.params"]
    config.update({f"channels_{ii}": channels
                   for ii, channels in enumerate(config["conv.channels"])})

  def forward(self, x):  # produce outputs
    x = self.quant(x)  # apply quantization, if applicable

    # first apply convolutional layers
    for layer in self.conv_layers: 
      x = layer(x)

    # then convert to a fixed-size vector
    x = self.resize_layer(x)
    x = torch.flatten(x, start_dim=1)

    # then apply the fully-connected layers
    for layer in self.fc_layers:
      x = layer(x)

    x = self.dequant(x)  # remove quantization, if applicable
    return x

  def training_step(self, batch, idx):
    xs, ys = batch
    y_hats = self.output_layer(self.forward(xs))
    loss = self.loss(y_hats, ys)

    logging_scalars = {"loss": loss}
    for metric in self.training_metrics:
        self.add_metric(metric, logging_scalars, y_hats, ys)

    self.do_logging(xs, ys, idx, y_hats, logging_scalars)

    return loss

  def validation_step(self, batch, idx):
    xs, ys = batch
    y_hats = self.output_layer(self.forward(xs))
    loss = self.loss(y_hats, ys)

    logging_scalars = {"loss": loss}
    for metric in self.validation_metrics:
        self.add_metric(metric, logging_scalars, y_hats, ys)

    self.do_logging(xs, ys, idx, y_hats, logging_scalars, step="val")

    return loss

  def configure_optimizers(self):  # ⚡: setup for .fit
    return self.optimizer(self.parameters(), **self.optimizer_params)

In [None]:
config = {
  "quantization": "post",  # "post" | "qat" | "none"
  "batch_size": 1024,
  "max_epochs": 1,
  "batchnorm": True,
  "kernel_size": 7,
  "conv.channels": [128, 256],
  "pool_size": 2,
  "final_height": 10,
  "final_width": 10,
  "fc1.size": 1024,
  "fc2.size": 512,
  "activation": torch.nn.ReLU(),
  "loss": torch.nn.NLLLoss(),
  "optimizer": torch.optim.Adam,
  "optimizer.params": {"lr": 0.0001},
}

In [None]:
# 📸 set up the dataset of images
dmodule = utils.MNISTDataModule(batch_size=config["batch_size"], validation_size=1024)
dmodule.prepare_data()
dmodule.setup()

In [None]:
def train(network, dmodule, config):
  callbacks = []

  # QAT:
  #  if doing quant-aware training, add to callbacks
  if config["quantization"] == "qat":
    callbacks.append(QuantizationAwareTraining(input_compatible=False))

  with wandb.init(config=config, project="quantize", entity="wandb", job_type="train") as run:
    # 👟 configure Trainer 
    trainer = pl.Trainer(
      gpus=1, max_epochs=config["max_epochs"], log_every_n_steps=1,
      logger=pl.loggers.WandbLogger(log_model=True, save_code=True),
      callbacks=callbacks,
      progress_bar_refresh_rate=50)
                        
    # 🏃‍♀️ run the Trainer on the model
    trainer.fit(network, dmodule)

  # STATIC:
  #  if doing static post-training quantization, apply it now
  if config["quantization"] == "post":
    xs, _ = next(iter(dmodule.train_dataloader()))
    network = run_static_quantization(network, xs)  # see below for implementation

  return network


def profile(network, dmodule, config):

  # ⏱️ time the model and check the validation accuracy
  with wandb.init(config=config, project="quantize", entity="wandb", job_type="profile") as run:
    val_trainer = pl.Trainer(
      gpus=0,  # profile on CPU, not GPU
      max_epochs=1, logger=pl.loggers.WandbLogger(log_model=True, save_code=True),
      progress_bar_refresh_rate=50
    )

    network.eval()
    start = time.process_time()
    val_trainer.validate(network, val_dataloaders=dmodule.val_dataloader())
    runtime = time.process_time() - start

    # report metrics to wandb
    wandb.summary["runtime"] = runtime
    wandb.summary["size_mb"] = print_model_size(network)
    wandb.summary["params"] = network.count_params()

  return network 


def train_and_profile(dmodule, config):
  network = LitCNN(config)

  network = train(network, dmodule, config)
  network = profile(network, dmodule, config)

  return network

In [None]:
xs, ys = next(iter(dmodule.train_dataloader()))

# Baseline: No Quantization

In [None]:
config["quantization"] = "none"

baseline_nn = train_and_profile(dmodule, config)

# Post-Training Quantization

In [None]:
def run_static_quantization(network, xs, qconfig="fbgemm"):
  """Return a quantized version of supplied network.

  Runs forward pass of network with xs, so make sure they're on
  the same device. Returns a copy of the network, so watch memory consumption.

  Note that this uses torch.quantization, rather than PyTorchLightning.

  network: torch.nn.Module, network to be quantized.
  xs: torch.Tensor, valid inputs for network.forward.
  qconfig: string, "fbgemm" to quantize for server/x86, "qnnpack" for mobile/ARM
  """
  # set up quantization
  network.qconfig = torch.quantization.get_default_qconfig(qconfig)
  network.eval()

  # attach methods for collecting activation statistics to set quantization bounds
  qnetwork = torch.quantization.prepare(network)
  
  # run inputs through network, collect stats
  qnetwork.forward(xs)
  
  # convert network to 8-bit integers using quantization statistics
  qnetwork = torch.quantization.convert(qnetwork)

  return qnetwork
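
The "activation statistics" collected during the calibration pass are, at their simplest, the minimum and maximum values seen at each point in the network. The sketch below is plain Python rather than `torch.quantization` code, and the helper names (`affine_qparams`, `quantize`, `dequantize`) are hypothetical. It illustrates one common way such a range can be turned into a scale and zero point for 8-bit values; the observers PyTorch actually uses may differ in detail.

```python
def affine_qparams(x_min, x_max, qmin=0, qmax=255):
    """Map the observed float range [x_min, x_max] onto the integers [qmin, qmax]."""
    # widen the range to include 0.0, so that zero is exactly representable
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # degenerate case: every observed value was zero
    zero_point = qmin - round(x_min / scale)
    return scale, int(zero_point)


def quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Round a float to the nearest integer code, clamped to [qmin, qmax]."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    """Recover the float value that an integer code stands for."""
    return (q - zero_point) * scale


# illustrative "activations" standing in for values seen during calibration
activations = [-0.4, 0.0, 0.25, 1.0, 2.6]
scale, zero_point = affine_qparams(min(activations), max(activations))
codes = [quantize(a, scale, zero_point) for a in activations]
recovered = [dequantize(q, scale, zero_point) for q in codes]
```

Each recovered value lies within half a quantization step (`scale / 2`) of the original, which is the rounding error that static quantization introduces for in-range values.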

In [None]:
config["quantization"] = "post"
static_quant_nn = train_and_profile(dmodule, config)

# Quantization-Aware Training

In [None]:
config["quantization"] = "qat"

qat_nn = train_and_profile(dmodule, config)

# Exercises


#### 1. Comparing Model Size, Runtime, and Accuracy

Quantization improves two performance characteristics
that are critical for running models on edge CPUs:
1. it reduces the model's memory footprint by ~4x
(from 32 bits to ~8 bits per parameter),
2. it reduces the latency of large matrix multiplications,
typically by a factor of 2 or less.

Compare the model sizes (in MB)
and the runtimes for the three models.
Do you observe the typical improvements in this case?
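
As a rough reference point for the size comparison, the parameter count of the network can be worked out by hand from the `config` above. The sketch below is a back-of-the-envelope estimate only: it counts convolutional and fully-connected parameters, ignores batchnorm, and ignores quantization metadata (scales and zero points), file-format overhead, and any parameters that may remain in floating point after conversion. The helper names are hypothetical.

```python
def conv2d_params(in_channels, out_channels, kernel_size):
    """Weights plus biases of a square-kernel 2D convolution."""
    return in_channels * out_channels * kernel_size ** 2 + out_channels


def linear_params(in_features, out_features):
    """Weights plus biases of a fully-connected layer."""
    return in_features * out_features + out_features


# values copied from the config cell above
kernel_size = 7
channels = [128, 256]
final_size = 10 * 10 * channels[-1]  # final_height * final_width * channels
fc1_size, fc2_size, n_classes = 1024, 512, 10

n_params = (
    conv2d_params(1, channels[0], kernel_size)
    + conv2d_params(channels[0], channels[1], kernel_size)
    + linear_params(final_size, fc1_size)
    + linear_params(fc1_size, fc2_size)
    + linear_params(fc2_size, n_classes)
)

fp32_mb = n_params * 4 / 1e6  # 32 bits = 4 bytes per parameter
int8_mb = n_params * 1 / 1e6  # 8 bits = 1 byte per parameter
print(f"{n_params:,} parameters: ~{fp32_mb:.1f} MB at 32 bits, ~{int8_mb:.1f} MB at 8 bits")
```

Nearly all of the parameters sit in the first fully-connected layer, so the sizes reported by `print_model_size` should be dominated by that layer.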

The biggest drawback to quantization is that
it can reduce accuracy,
especially when statically quantizing large models.
Do you observe any accuracy penalty for quantization?

_Note_: quantization-aware training is less stable,
in this case, and gives a wider range of accuracies.
Check the [Project page](https://wandb.ai/wandb/quantize/workspace)
to see other runs and get a sense of the distribution,
or repeatedly run the notebook with the same parameters.

Generally, QAT is only beneficial in performance terms for certain models,
especially MobileNet-style architectures.

Check out [this video](https://www.youtube.com/watch?v=c3MT2qV5f9w)
for a more thorough discussion
on this and other details of quantization in PyTorch.

#### 2. A Closer Look at Quantization-Aware Training

In QAT,
we add extra "fake-quantization" operations to the
graph during training,
then drop them once the model is quantized.
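
To make "fake quantization" concrete: the operation rounds a value to the nearest representable quantized level and immediately converts it back to floating point, so training continues in full precision while still experiencing the rounding and clipping error. The sketch below is a plain-Python illustration with a hypothetical `fake_quantize` helper and made-up weights and scale; it is not the code that the Lightning callback inserts.

```python
def fake_quantize(x, scale, zero_point=0, qmin=-128, qmax=127):
    """Quantize x to an 8-bit integer code, then immediately dequantize it.

    The result is still a float, but it can only take on values
    that the quantized model would be able to represent.
    """
    q = round(x / scale) + zero_point
    q = max(qmin, min(qmax, q))  # clamp to the representable range
    return (q - zero_point) * scale


# illustrative values only
weights = [0.013, -0.402, 0.26, 1.96, 3.0]
scale = 0.02

simulated = [fake_quantize(w, scale) for w in weights]
errors = [abs(w - s) for w, s in zip(weights, simulated)]
```

In-range values move by at most half a step (`scale / 2`), while the out-of-range value `3.0` saturates at the largest representable level, `127 * scale`.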

Head to the Weights & Biases run page for a training (not profiling!)
run that used QAT
and find the model's compute graph (Files tab, `graph.png`),
where the operations that make up the model are represented.
Compare the graph to that from a run without QAT.
What differences do you see?

Adding these extra modules can increase the runtime for training
(which occurs at full precision).
Compare the runtimes for training runs
with and without QAT. Do you see a difference?
Compare this to the runtimes for training runs
with and without static quantization. Is there a difference here?

Also, note that training occurs on the GPU,
rather than the CPU on which the quantized models run,
and it operates on a training set that is 5x larger than
the validation set used during profiling.
Compare the runtimes of the training and profiling runs
(see the Overview tab on the W&B run page).
Are the training runs 5x longer?
Which provides more of a speedup:
running on the GPU or applying quantization?

#### 3. CHALLENGE: Fusing Modules

Additional performance improvements can be obtained by "fusing"
multiple modules together into a single module,
enabling some operations to occur more quickly.

The cell below creates a new, "fused" model
that combines the ReLU activation computation
with the matrix multiplication layers.

Do you see any reduction in the runtime?

In [None]:
config["quantization"] = "post+fusion"

# copy over parameters etc. from baseline_nn to new network
state = baseline_nn.state_dict()
fused_nn = LitCNN(config)
fused_nn.load_state_dict(state)

for layer in fused_nn.conv_layers:
  if isinstance(layer, Convolution):
    torch.quantization.fuse_modules(  # "fuses" multiple modules into one
      layer, ["conv2d", "activation",], inplace=True)

for layer in fused_nn.fc_layers:
  if isinstance(layer, FullyConnected):
    if isinstance(layer.activation, torch.nn.ReLU):
      torch.quantization.fuse_modules(
        layer, ["linear", "activation",], inplace=True)

fused_nn = run_static_quantization(fused_nn, xs)

fused_nn = profile(fused_nn, dmodule, config)

Improvements are more obvious for much larger tensors
(try resizing the inputs with a `torch.nn.AdaptiveAvgPool2d` layer
before applying the `conv_layers`!)
and when more modules are fused.

Batchnorm can be fused when it occurs after convolutional layers or before ReLU layers
(see the docstring for `torch.quantization.fuse_modules`).
In that order, all three modules can be fused into one.
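
One way to see why batchnorm can be fused with a preceding layer: in evaluation mode, batchnorm applies a fixed affine map to each channel, and an affine map applied right after a linear operation can be folded into that operation's weights and bias. The sketch below checks this for a single scalar "channel" in plain Python. The helper names and numbers are illustrative assumptions, and this is not what `fuse_modules` does line by line.

```python
import math


def batchnorm_eval(y, gamma, beta, mean, var, eps=1e-5):
    """Batchnorm in eval mode: normalize with running stats, then scale and shift."""
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fold_batchnorm(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold eval-mode batchnorm into the weight and bias of the preceding layer."""
    factor = gamma / math.sqrt(var + eps)
    return w * factor, (b - mean) * factor + beta


# illustrative layer parameters and running statistics
w, b = 0.8, -0.1
gamma, beta, mean, var = 1.5, 0.2, 0.3, 4.0

w_fused, b_fused = fold_batchnorm(w, b, gamma, beta, mean, var)

x = 2.0
two_step = batchnorm_eval(w * x + b, gamma, beta, mean, var)  # layer, then batchnorm
one_step = w_fused * x + b_fused                              # single fused layer
```

The two results agree up to floating point error, so the fused layer does the work of two modules in one multiply-add. The same folding is not possible when a nonlinearity such as ReLU sits between the layer and the batchnorm.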

Change the definition of the `Convolution` layer so that batchnorm
comes in between the convolutional layer and the activation,
then add it to the list of `modules_to_fuse` above.
Does this bring any additional speedup?

# Endnote

The most common error with quantized models involves improper quantization of inputs and/or outputs. Expand this section for details.


Quantized `uint8` tensors are a fundamentally different type from regular floating point tensors,
and so different operations are implemented for each.
For example, a quantized tensor cannot be quantized again,
while a floating point tensor cannot be un-quantized;
likewise, an operation whose weights are not quantized
cannot accept quantized inputs.
This is akin to the differences between tensors on CPU and GPU:
you can't add a CPU tensor to a GPU tensor, for example.

The error messages look like so:

```
RuntimeError: Could not run 'aten::{OP_NAME}' with arguments from the {'QuantizedCPU'/'CPU'} backend
```

where `OP_NAME` is the name of an operator
(e.g. a name that includes `relu`, `nll_loss`, or `linear`).

You can use `OP_NAME` to identify where in the model the mismatch is occurring,
e.g. at the input or the output or, more rarely, in between.

If the error mentions the `QuantizedCPU` backend,
that means the inputs are quantized and the operator is not compatible with them.
As of version 1.8 of PyTorch,
most modules are not compatible with quantized tensors,
nor is any of `torch.nn.functional`,
but the most common modules are.
This can be resolved by applying a `DeQuantStub`
to the inputs
(see `.dequant` in the module above).

If the error mentions the `CPU` backend,
that means the operation is quantized but the inputs are not.
In general, this is resolved by
applying a `QuantStub` to the inputs
(see `.quant` in the module above).
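
The two cases above can be summarized in a small triage helper. The sketch below is a hypothetical convenience function, not part of PyTorch or of this notebook's `utils`. It assumes the error text follows the template shown earlier, and it captures the operator name together with any namespace prefix.

```python
import re

# matches the template from the error message shown above
_PATTERN = re.compile(
    r"Could not run '(?P<op>[^']+)' with arguments from the '(?P<backend>\w+)' backend"
)


def diagnose(message):
    """Return (operator, backend, hint) for a backend-mismatch error, else None."""
    match = _PATTERN.search(message)
    if match is None:
        return None  # not one of the errors discussed in this endnote
    op, backend = match["op"], match["backend"]
    if backend == "QuantizedCPU":
        hint = "inputs are quantized but the operator is not: apply a DeQuantStub first"
    elif backend == "CPU":
        hint = "operator is quantized but the inputs are not: apply a QuantStub first"
    else:
        hint = "backend not covered in this endnote"
    return op, backend, hint


example = "RuntimeError: Could not run 'aten::relu' with arguments from the 'QuantizedCPU' backend"
result = diagnose(example)
```

In practice, `str(err)` from a caught `RuntimeError` can be passed to `diagnose`, and the operator name then points to the place in `forward` where the stub is missing.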