<a href="https://colab.research.google.com/github/wandb/edu/blob/main/lightning/projects/constrained_emotion_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://i.imgur.com/gb6B4ig.png" width="400" alt="Weights & Biases" />

# Designing a Memory-Constrained Emotion Classifier

In [None]:
%%capture
# (%%capture must stay the first line of the cell: it hides the install output)
# install the third-party libraries used by this notebook
!pip install pytorch_lightning torchviz wandb

repo_url = "https://raw.githubusercontent.com/wandb/edu/main/"
utils_path = "lightning/utils.py"
# Download a util file of helper methods for this notebook
#  (it provides utils.LoggedImageClassifierModule, the base class of the model below)
!curl {repo_url + utils_path} --output utils.py

In [None]:
from pathlib import Path
import math
import os
import subprocess

import pandas as pd
import pytorch_lightning as pl
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.datasets
import wandb

from pytorch_lightning.callbacks import ModelPruning, QuantizationAwareTraining

import utils

In [None]:
!wandb login

## Facial Expression `DataModule` and `DataLoaders`

In [None]:
class FERDataModule(pl.LightningDataModule):
  """DataModule for downloading and preparing the FER2013 dataset.

  After prepare_data, faces are float tensors of shape (1, 48, 48) scaled to
  [0, 1] and emotions are the dataset's class labels; the last
  `validation_size` fraction of the rows of fer2013.csv is held out for
  validation.
  """
  tar_url = "https://www.dropbox.com/s/opuvvdv3uligypx/fer2013.tar"
  local_path = Path("fer2013")

  def __init__(self, batch_size=64):
    super().__init__()  # ⚡: we inherit from LightningDataModule
    self.batch_size = batch_size
    # validation batches are 10x the size of training batches
    self.val_batch_size = 10 * self.batch_size

  def prepare_data(self, validation_size=0.2, force_reload=False):
    """Download, parse, normalize, and split the data (once).

    validation_size: float in [0, 1), fraction of the examples, taken from the
      end of the file, that is held out for validation.
    force_reload: bool, re-read and re-split even if data was already prepared.

    Sets training_data / validation_data (TensorDatasets) and the metadata
    attributes num_total, num_classes, num_train and num_validation.

    Raises ValueError if validation_size is outside [0, 1).

    NOTE(review): Lightning recommends not assigning state in prepare_data
    (it may only run on one process in multi-device training). That is fine
    for the single-GPU setup in this notebook; move the split into `setup`
    before scaling out.
    """
    # ⚡: how do we set up the data?
    if hasattr(self, "training_data") and not force_reload:
      return  # only re-run if we haven't been run before

    if not 0 <= validation_size < 1:
      raise ValueError(
        f"validation_size must be in [0, 1), got {validation_size}")

    # download the data from the internet
    self.download_data()

    # read it from a .csv file
    faces, emotions = self.read_data()

    # normalize pixel values from [0, 255] to [0, 1]
    faces = torch.divide(faces, 255.)

    # split it into training and validation. Slicing at num_train (instead of
    #  at -num_validation) stays correct when num_validation is 0.
    num_total = len(faces)
    num_validation = int(num_total * validation_size)
    num_train = num_total - num_validation

    self.training_data = torch.utils.data.TensorDataset(
      faces[:num_train], emotions[:num_train])
    self.validation_data = torch.utils.data.TensorDataset(
      faces[num_train:], emotions[num_train:])

    # record metadata; labels are 0-indexed, so there are max label + 1 classes
    self.num_total = num_total
    self.num_classes = int(torch.max(emotions)) + 1
    self.num_train = num_train
    self.num_validation = num_validation

  def train_dataloader(self):  # ⚡: how do we go from dataset to dataloader?
    """The DataLoaders returned by a DataModule produce data for a model.

    This DataLoader is used during training. It reshuffles every epoch so that
    batches are not always drawn in the file's fixed order."""
    return DataLoader(self.training_data, batch_size=self.batch_size,
                      shuffle=True, num_workers=1, pin_memory=True)

  def val_dataloader(self):  # ⚡: what about during validation?
    """The DataLoaders returned by a DataModule produce data for a model.

    This DataLoader is used during validation, at the end of each epoch."""
    return DataLoader(self.validation_data, batch_size=self.val_batch_size,
                      num_workers=1, pin_memory=True)

  def download_data(self):
    """Download and untar the dataset into the working directory.

    Skipped if local_path already exists."""
    if not os.path.exists(self.local_path):
      print("Downloading the face emotion dataset...")
      subprocess.check_output(
          f"curl -SL {self.tar_url} | tar xz", shell=True)
      print("...done")

  def read_data(self):
    """Read the data from a .csv into torch Tensors.

    Returns (faces, emotions): faces is a float32 tensor of shape
    (num_examples, 1, 48, 48) holding raw pixel values in [0, 255]; emotions
    is the "emotion" column of the .csv as a tensor.
    """
    data = pd.read_csv(self.local_path / "fer2013.csv")
    pixels = data["pixels"].tolist()
    width, height = 48, 48
    faces = []
    for pixel_sequence in pixels:
      # each row is a space-separated string of height * width 8-bit values
      face = np.asarray(pixel_sequence.split(' '), dtype=np.uint8)
      faces.append(face.reshape(1, height, width).astype("float32"))

    faces = np.asarray(faces)
    emotions = data["emotion"].to_numpy()

    return torch.tensor(faces), torch.tensor(emotions)

# Utility Code

These cells provide helper functions for compressing the model and measuring
the effect of compression: static quantization, weight pruning,
and metrics for model file size and (non-zero) parameter count.

### Quantization

In [None]:
def run_static_quantization(network, xs, qconfig="fbgemm"):
  """Return a statically-quantized copy of the supplied network.

  Eager-mode static quantization produces CPU-only kernels, so the copy and
  the calibration inputs are moved to the CPU before calibration. The supplied
  network stays on its device, but it is switched to eval mode and gets its
  `qconfig` attribute set. The copy is made by torch.quantization.prepare, so
  watch memory consumption.

  Note that this uses torch.quantization, rather than PyTorch Lightning.

  network: torch.nn.Module to be quantized; it should route its computation
    through QuantStub / DeQuantStub so inputs and outputs are (de)quantized.
  xs: torch.Tensor, valid inputs for network.forward; used to calibrate the
    activation quantization ranges.
  qconfig: string, "fbgemm" to quantize for server/x86, "qnnpack" for
    mobile/ARM. (To run a "qnnpack" model, torch.backends.quantized.engine
    must also be set to "qnnpack".)

  Returns: the quantized copy of network, on the CPU.
  """
  # set up quantization on the original; prepare copies this configuration
  network.qconfig = torch.quantization.get_default_qconfig(qconfig)
  network.eval()

  # copy the network, attaching observers that collect activation statistics
  #  to set the quantization bounds
  qnetwork = torch.quantization.prepare(network)
  qnetwork.cpu()  # quantized kernels only run on the CPU

  # calibrate: run inputs through the observed copy (no gradients needed).
  #  Call the module rather than .forward so any forward hooks also fire.
  with torch.no_grad():
    qnetwork(xs.cpu())

  # convert the calibrated copy to 8-bit weights and activations
  return torch.quantization.convert(qnetwork)

### Pruning

In [None]:
# Helper functions for building ModelPruning Callbacks

def make_pruner(prune_config, network=None, n_epochs=None):
  """Builds a ModelPruning PyTorch Lightning Callback from a dictionary.

  Besides the keyword arguments of pl.callbacks.ModelPruning, prune_config may
  contain two convenience keys. Both are translated before ModelPruning is
  built:

  "target_sparsity": the fraction of weights to have pruned after n_epochs.
    It is converted, using n_epochs, into the "amount" keyword argument of
    ModelPruning, which specifies how much pruning to do on each epoch.
  "parameters": None, "conv", or "linear". It is used to fetch the parameters
    which are to be pruned from the provided network (see
    get_parameters_to_prune), and the result becomes the "parameters_to_prune"
    keyword argument. None corresponds to pruning all parameters; the other
    values require network.

  prune_config itself is not modified, so it can be reused for later runs.

  Raises ValueError if n_epochs or network is needed but missing, or if the
  resulting arguments lack "amount" or "parameters_to_prune".
  """
  # shallow copy: the keys below are popped / inserted on our copy only, so the
  #  caller's dict never ends up holding module references from this network
  prune_config = dict(prune_config)

  if "target_sparsity" in prune_config:
    if n_epochs is None:
      raise ValueError(
        "when specifying target sparsity, must provide number of epochs")
    target = prune_config.pop("target_sparsity")
    prune_config["amount"] = compute_iterative_prune(target, n_epochs)

  if "amount" not in prune_config:
    raise ValueError("must specify stepwise pruning amount or target")

  if "parameters" in prune_config:
    parameters = prune_config.pop("parameters")
    if parameters is not None and network is None:
      raise ValueError("when specifying parameters, must provide network")
    prune_config["parameters_to_prune"] = get_parameters_to_prune(parameters, network)

  if "parameters_to_prune" not in prune_config:
    raise ValueError("must specify which parameters to prune, or None")

  return ModelPruning(**prune_config)


def get_parameters_to_prune(parameters, network):
  """Select which weights of network a pruner should act on.

  parameters: "conv" selects the weight of every torch.nn.Conv2d in network,
    "linear" selects the weight of every torch.nn.Linear, and None selects
    nothing explicitly (None is returned, which make_pruner treats as pruning
    all parameters).
  network: torch.nn.Module to search; not touched when parameters is None.

  Returns a list of (module, "weight") pairs, or None.
  Raises ValueError for any other value of parameters.
  """
  layer_kinds = (("conv", torch.nn.Conv2d), ("linear", torch.nn.Linear))
  for kind, layer_cls in layer_kinds:
    if parameters == kind:
      chosen = []
      for module in network.modules():
        if isinstance(module, layer_cls):
          chosen.append((module, "weight"))
      return chosen

  if parameters is not None:
    raise ValueError(f"could not understand parameters value: {parameters}")
  return None


def compute_iterative_prune(target_sparsity, n_epochs):
  """Return the per-epoch pruning fraction that reaches target_sparsity.

  If a fraction `amount` of the remaining weights is pruned at each of
  n_epochs steps, (1 - amount) ** n_epochs of the weights survive. Solving
  (1 - amount) ** n_epochs == 1 - target_sparsity gives
  amount = 1 - (1 - target_sparsity) ** (1 / n_epochs).

  target_sparsity: number in [0, 1], final fraction of weights pruned.
  n_epochs: int >= 1, number of pruning steps.

  Raises ValueError for out-of-range arguments.
  """
  if n_epochs < 1:
    raise ValueError(f"n_epochs must be at least 1, got {n_epochs}")
  if not 0 <= target_sparsity <= 1:
    raise ValueError(
      f"target_sparsity must be in [0, 1], got {target_sparsity}")
  return 1 - math.pow(1 - target_sparsity, 1 / n_epochs)

### Metrics

In [None]:
# Metric calculation for model file size (quantization)
#  and total non-zero parameters (pruning)

def print_model_size(network):
  """Save model to disk and print filesize.

  The state_dict is written to a private temporary directory, so no file in
  the working directory is overwritten or deleted, and the temporary file is
  removed even if saving fails.

  network: torch.nn.Module whose state_dict is serialized with torch.save.
  Returns: float, the serialized size in megabytes (1 MB = 1e6 bytes).
  """
  import tempfile  # local import: only this helper needs it

  with tempfile.TemporaryDirectory() as tmpdir:
    # keep the original file name: torch.save may embed it in the archive
    path = os.path.join(tmpdir, "tmp.pt")
    torch.save(network.state_dict(), path)
    size_mb = os.path.getsize(path) / 1e6
  print(f"{round(size_mb, 2)} MB")
  return size_mb


def count_nonzero(module):
  """Counts the total number of non-zero parameters in a module.

  For compatibility with networks with active torch.nn.utils.prune methods,
  a parameter stored as `<name>_orig` next to a `<name>_mask` buffer is
  counted through its effective value (orig * mask). The masks are applied
  during forward passes, so this represents the actual sparsity of the
  network. An entry counts only if both the stored weight and its mask are
  non-zero.

  module: torch.nn.Module to inspect.
  Returns: int, the number of non-zero entries across all parameters.
  """
  # "<prefix>_mask" buffer -> keyed by "<prefix>" (e.g. "linear.weight")
  masks = {name[:-len("_mask")]: mask_tensor
           for name, mask_tensor in module.named_buffers()
           if name.endswith("_mask")}

  nparams = 0
  with torch.no_grad():
    for name, tensor in module.named_parameters():
      nonzero = tensor != 0
      # only "<prefix>_orig" parameters are paired with a "<prefix>_mask"
      if name.endswith("_orig"):
        mask = masks.get(name[:-len("_orig")])
        if mask is not None:
          nonzero = nonzero & (mask != 0)
      nparams += int(torch.sum(nonzero))

  return nparams

## Defining the Model

### Classes for `FullyConnected` and `Convolution`al Blocks

In [None]:
class FullyConnected(pl.LightningModule):
  """Linear layer, then an activation, then optional BatchNorm1d and Dropout.

  in_features, out_features: sizes passed to torch.nn.Linear.
  activation: zero-argument callable (e.g. a torch.nn module class) that
    builds the activation module; None means torch.nn.Identity.
  batchnorm: truthy to apply BatchNorm1d over out_features after activation.
  dropout: Dropout probability applied last; a falsy value (0.) disables it.
  """

  def __init__(self, in_features, out_features,
               activation=None, batchnorm=False, dropout=0.):
    super().__init__()
    self.linear = torch.nn.Linear(in_features, out_features)

    # without an explicit activation, pass the linear output through unchanged
    activation_cls = torch.nn.Identity if activation is None else activation
    self.activation = activation_cls()

    extras = [torch.nn.BatchNorm1d(out_features)] if batchnorm else []
    if dropout:
      extras.append(torch.nn.Dropout(dropout))
    self.post_act = torch.nn.Sequential(*extras)

  def forward(self, x):
    hidden = self.linear(x)
    hidden = self.activation(hidden)
    return self.post_act(hidden)


class Convolution(pl.LightningModule):
  """Conv2d layer, then an activation, then optional BatchNorm2d and Dropout2d.

  in_channels, out_channels, kernel_size: passed to torch.nn.Conv2d (no other
    arguments, so its defaults apply).
  activation: zero-argument callable (e.g. a torch.nn module class) that
    builds the activation module; None means torch.nn.Identity.
  batchnorm: truthy to apply BatchNorm2d over out_channels after activation.
  dropout: Dropout2d probability applied last; a falsy value (0.) disables it.
  """

  def __init__(self, in_channels, out_channels, kernel_size,
               activation=None, batchnorm=False, dropout=0.):
    super().__init__()
    self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size)

    # without an explicit activation, pass the conv output through unchanged
    activation_cls = activation if activation is not None else torch.nn.Identity
    self.activation = activation_cls()

    tail = (([torch.nn.BatchNorm2d(out_channels)] if batchnorm else [])
            + ([torch.nn.Dropout2d(dropout)] if dropout else []))
    self.post_act = torch.nn.Sequential(*tail)

  def forward(self, x):
    feature_maps = self.conv2d(x)
    activated = self.activation(feature_maps)
    return self.post_act(activated)

### Model Class: `LitEmotionClassifier`

In [None]:
class LitEmotionClassifier(utils.LoggedImageClassifierModule):
  """Emotion classifier for 1x48x48 face images.

  The baseline network flattens each input and maps it to one score per
  emotion label with a single FullyConnected layer. LogSoftmax is applied in
  the training/validation steps, not in forward.

  config: dict providing "optimizer" (an optimizer class), "optimizer.params"
    (its keyword arguments) and "loss" (called as loss(log_probs, targets)).
  max_images_to_display: forwarded to utils.LoggedImageClassifierModule.
  """

  def __init__(self, config, max_images_to_display=32):
    super().__init__(max_images_to_display=max_images_to_display)

    self.labels = ["Angry", "Disgusted", "Afraid", "Happy",
                   "Sad", "Surprised", "Neutral"]
    n_inputs, n_outputs = 1 * 48 * 48, len(self.labels)

    # layers are defined here and applied in forward;
    #  for compatibility, build them from the FullyConnected and Convolution blocks
    self.linear = FullyConnected(in_features=n_inputs, out_features=n_outputs)

    # turns scores into log-probabilities; used by the steps, not by forward
    self.output_layer = torch.nn.LogSoftmax(dim=1)

    # quantization entry/exit points; they do nothing unless quantization is applied
    self.quant = torch.quantization.QuantStub()
    self.dequant = torch.quantization.DeQuantStub()

    self.optimizer = config["optimizer"]
    self.optimizer_params = config["optimizer.params"]
    self.loss = config["loss"]

  def forward(self, x):
    # LogSoftmax is applied outside of forward, for compatibility with quantization
    flat = torch.flatten(self.quant(x), start_dim=1)
    return self.dequant(self.linear(flat))

  def _emotion_step(self, batch, idx, metrics_attr, **logging_kwargs):
    """Shared body of training_step and validation_step; returns the loss.

    metrics_attr names the attribute holding the metrics to compute;
    logging_kwargs are forwarded to do_logging.
    """
    xs, ys = batch
    log_probs = self.output_layer(self.forward(xs))
    loss = self.loss(log_probs, ys)

    logging_scalars = {"loss": loss}
    for metric in getattr(self, metrics_attr):
      self.add_metric(metric, logging_scalars, log_probs, ys)

    self.do_logging(xs, ys, idx, log_probs, logging_scalars, **logging_kwargs)
    return loss

  def training_step(self, batch, idx):
    return self._emotion_step(batch, idx, "training_metrics")

  def validation_step(self, batch, idx):
    return self._emotion_step(batch, idx, "validation_metrics", step="val")

  def configure_optimizers(self):
    optimizer_cls, kwargs = self.optimizer, self.optimizer_params
    return optimizer_cls(self.parameters(), **kwargs)

## Defining the Training Process

In [None]:
def train(network, dmodule, config, project="lit-fer-constrained", entity="wandb"):
  """Fit network on dmodule while logging to Weights & Biases; return it.

  config: dict of run settings, also logged as the wandb run config.
    "max_epochs": number of training epochs.
    "quantization": "none", "qat" (quantization-aware training callback) or
      "post" (static quantization after training).
    "pruning": optional dict of make_pruner configs, one ModelPruning
      callback each; missing or empty means no pruning. Pruning can only be
      used with "quantization": "none".
  project, entity: wandb project and entity that receive the run.

  Returns the trained network, or its statically-quantized copy when
  config["quantization"] == "post". Raises ValueError if pruning is combined
  with quantization.
  """
  with wandb.init(config=config, project=project, entity=entity):

    callbacks = []
    quantization = config["quantization"]

    # QAT:
    #  if doing quant-aware training, add to callbacks
    if quantization == "qat":
      callbacks.append(QuantizationAwareTraining(input_compatible=False))

    # PRUNING:
    #  if doing pruning, add to callbacks
    pruning = config.get("pruning") or {}
    if pruning:
      if quantization != "none":
        raise ValueError("cannot combine pruning and quantization")

      for prune_config in pruning.values():
        # pass a copy: make_pruner may pop/insert keys, and the caller's config
        #  must stay reusable (and free of this network's modules) for later runs
        callbacks.append(make_pruner(dict(prune_config), network,
                                     n_epochs=config["max_epochs"]))

    # 👟 configure Trainer
    trainer = pl.Trainer(
      gpus=1, max_epochs=config["max_epochs"], log_every_n_steps=1,
      logger=pl.loggers.WandbLogger(log_model=True, save_code=True),
      callbacks=callbacks,
      progress_bar_refresh_rate=50)

    # 🏃‍♀️ run the Trainer on the model
    trainer.fit(network, dmodule)

    wandb.summary["params"] = network.count_params()
    wandb.summary["nonzero_params"] = count_nonzero(network)

    # STATIC:
    #  if doing static post-training quantization, apply it now,
    #  calibrating on one batch of training data
    if quantization == "post":
      xs, _ = next(iter(dmodule.train_dataloader()))
      network = run_static_quantization(network, xs)

    # report metrics to wandb
    wandb.summary["size_mb"] = print_model_size(network)

  return network

## Loading the Data, Building the Model, and Running Training

In [None]:
# run settings: LitEmotionClassifier reads "optimizer", "optimizer.params" and
#  "loss"; train reads "max_epochs", "quantization" and "pruning"
config = {
  "batch_size": 256,
  "max_epochs": 10,
  "activation": torch.nn.ReLU,  # not read by the baseline LitEmotionClassifier
  "loss": torch.nn.NLLLoss(),  # expects log-probabilities (the model's LogSoftmax output)
  "optimizer": torch.optim.Adam,
  "optimizer.params": {"lr": 0.001},
  "quantization": "none",  # "none" or "qat" or "post"
  # pruning is configured at the end of this cell
}

## Loading the data

dmodule = FERDataModule(batch_size=config["batch_size"])
dmodule.prepare_data()  # downloads and splits the data on first call only

# for debugging purposes (checking shapes, etc.), make these available
dloader = dmodule.train_dataloader()  # set up the Loader

# NOTE: "cuda" is hard-coded here and in train (gpus=1), so a GPU runtime is required
example_batch = next(iter(dloader))  # grab a batch from the Loader
example_x, example_y = example_batch[0].to("cuda"), example_batch[1].to("cuda")

print(f"Input Shape: {example_x.shape}")
print(f"Target Shape: {example_y.shape}")

## Building the model

lec = LitEmotionClassifier(config)

# sanity-check the output shape with one forward pass on the example batch
lec.to("cuda")
outputs = lec.forward(example_x)
print(f"Output Shape: {outputs.shape}")

## Pruning

# an empty dict means no pruning; each entry becomes one ModelPruning callback
config["pruning"] = {}  # see wandb.me/lit-prune-colab for more examples
global_prune_config = {  # config for applying pruning to the entire network
  "parameters": None,  # None: prune all parameters (see make_pruner)
  "pruning_fn": "l1_unstructured",
  "target_sparsity": 0.9,  # target sparsity level for this pruner
  # (make_pruner turns target_sparsity into a per-epoch amount using max_epochs;
  #  the remaining keys are passed straight to ModelPruning)
  "use_global_unstructured": True,
  "pruning_dim": None,
  "pruning_norm": None,
}
# config["pruning"]["global"] = global_prune_config  # uncomment to apply global pruning

## Training
lec = train(lec, dmodule, config)