<a href="https://colab.research.google.com/github/wandb/edu/blob/main/lightning/cnn/architecture_search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://i.imgur.com/gb6B4ig.png" width="400" alt="Weights & Biases">

# Image Classification Architecture Search on CIFAR-10

This notebook lets you join an ongoing, parallelized architecture
search for image classification (specifically, the
[CIFAR-10 labeling task](https://www.cs.toronto.edu/~kriz/cifar.html))
just by executing the cells.
Read the instructions at the bottom to see how to launch a search of your own.

> Note that [Colab restricts GPU usage](https://research.google.com/colaboratory/faq.html),
especially for notebooks run non-interactively,
so if you leave this notebook running
for more than a few hours within a short period,
you're likely to see your access curtailed
unless you're a paid user.

The cells below define a way to sample and train a random architecture that combines zero or more convolutional blocks (each with one or two convolutional layers) and zero or more fully connected layers,
followed by a fully connected classifier,
using a provided random seed for reproducibility.

The results of this architecture search will be logged to
[this Weights & Biases dashboard](https://wandb.ai/wandb/archsearch-cifar10/sweeps/bmhxqxr0),
where you can see which architectures perform best and search for patterns
in the submitted runs.

# Installs and Imports

In [None]:
%%capture
!pip install -qqq pytorch_lightning torchviz wandb

repo_url = "https://raw.githubusercontent.com/wandb/edu/main/"
utils_path = "lightning/utils.py"
# Download a util file of helper methods
!curl {repo_url + utils_path} > utils.py

In [None]:
import pytorch_lightning as pl

import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

import wandb

import utils

In [None]:
!wandb login

# Dataset

In [None]:
from math import floor

class CIFAR10DataModule(pl.LightningDataModule):
  """Dataloaders and setup for the CIFAR10 dataset.
  """

  def __init__(self, batch_size, train_size=0.8, debug=False):
    """

    Arguments:
    batch_size: int. Size of batches in training, validation, and test
    train_size: int or float. If int, number of examples in the training set;
                if float, fraction of examples in the training set.
    debug:  bool. If True, keep only every 10th training example.
    """
    super().__init__()

    self.data_dir = "./data"
    self.seed = 117

    self.train_size = train_size
    self.batch_size = batch_size 
    self.debug = debug

    self.transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

  def prepare_data(self):
    """Download the dataset.
    """
    torchvision.datasets.CIFAR10(root=self.data_dir, train=True,
                                 download=True, transform=self.transform)

    torchvision.datasets.CIFAR10(root=self.data_dir, train=False,
                                 download=True, transform=self.transform)

  def setup(self, stage=None):
    """Set up training and test data and perform our train/val split.
    """
    if stage in (None, "fit"):
      cifar10_full = torchvision.datasets.CIFAR10(self.data_dir, train=True,
                                                  transform=self.transform)
      if self.debug:  # keep every 10th example
        cifar10_full.data = cifar10_full.data[::10]
        cifar10_full.targets = cifar10_full.targets[::10]  # CIFAR10 stores labels in .targets

      total_size, *self.dims = cifar10_full.data.shape
      train_size, val_size = self.get_split_sizes(self.train_size, total_size)

      split_generator = torch.Generator().manual_seed(self.seed)
      self.train, self.val = torch.utils.data.random_split(
          cifar10_full, [train_size, val_size], split_generator)

    if stage in (None, "test"):
      self.test = torchvision.datasets.CIFAR10(self.data_dir, train=False,
                                               transform=self.transform)


  def train_dataloader(self):
    trainloader = torch.utils.data.DataLoader(self.train, batch_size=self.batch_size,
                                              shuffle=True, num_workers=2, pin_memory=True)
    return trainloader

  def val_dataloader(self):
    valloader = torch.utils.data.DataLoader(self.val, batch_size=self.batch_size,
                                            shuffle=False, num_workers=2, pin_memory=True)
    return valloader

  def test_dataloader(self):
    testloader = torch.utils.data.DataLoader(self.test, batch_size=self.batch_size,
                                             shuffle=False, num_workers=2, pin_memory=True)
    return testloader

  @staticmethod
  def get_split_sizes(train_size, total_size):
    if isinstance(train_size, float):
      train_size = floor(total_size * train_size)

    if not isinstance(train_size, int):
      raise TypeError("train_size must be an int or a float")
    val_size = total_size - train_size

    return train_size, val_size
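With the defaults, `get_split_sizes` turns the fractional `train_size=0.8` into integer example counts. The sketch below is a standalone copy of that arithmetic with an explicit type check; the name `split_sizes` is hypothetical, and the totals assume the 50,000-image CIFAR-10 training set (5,000 images when `debug=True` keeps every 10th example).

```python
from math import floor


def split_sizes(train_size, total_size):
    """Hypothetical standalone copy of CIFAR10DataModule.get_split_sizes."""
    if isinstance(train_size, float):
        # a float is read as the fraction of examples used for training
        train_size = floor(total_size * train_size)
    if not isinstance(train_size, int):
        raise TypeError("train_size must be an int or a float")
    # whatever is not used for training becomes the validation set
    return train_size, total_size - train_size


print(split_sizes(0.8, 50_000))     # full training set
print(split_sizes(0.8, 5_000))      # debug=True keeps every 10th image
print(split_sizes(45_000, 50_000))  # an int is an absolute count
```

Because validation gets the remainder, the two sizes always sum to the total, which is what `torch.utils.data.random_split` requires.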

# Network

In [None]:
###
# Shape Handling and Inference
###

# when building a random architecture, we have to take care to track the shapes
#  programmatically

def sequential_output_shape(self, h_w):
  """Utility function for computing the output shape of a torch.nn.Sequential"""
  for element in self:
    try:
      h_w = element.output_shape(h_w)
    except AttributeError:  # optimistically assume any layer without the method doesn't change shape
      pass
  
  return h_w


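def sequential_feature_dim(self):
  """Utility function returning the feature dim (channels or units) of the last layer that reports one, else None."""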
  for element in reversed(self):
    try:
      feature_dim = element.feature_dim()
      if feature_dim is not None:
        return feature_dim
    except AttributeError:
      pass


def conv2d_output_shape(self, h_w):
  """Utility function for computing output shape of 2d convolutional operators."""

  props = self.kernel_size, self.stride, self.padding, self.dilation  # grab operator properties
  props = [p if isinstance(p, tuple) else (p, p) for p in props]  # diagonalize into tuples as needed
  props = list(zip(*props))  # "transpose" operator properties -- list indices are height/width rather than property id

  h = conv1d_output_shape(h_w[0], *props[0])  # calculate h from height parameters of props
  w = conv1d_output_shape(h_w[1], *props[1])  # calculate w from width parameters of props

  assert h > 0 and w > 0, f"Invalid parameters: output shape would be {(h, w)}"

  return h, w


def conv1d_output_shape(length, kernel_size, stride, padding, dilation):
  """Computes the change in dimensions for a 1d convolutional operator."""
  return floor((length + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)


torch.nn.AdaptiveAvgPool2d.output_shape = lambda self, h_w: self.output_size
torch.nn.Linear.output_shape = lambda self, inp: self.out_features
torch.nn.Conv2d.output_shape = conv2d_output_shape
torch.nn.MaxPool2d.output_shape = conv2d_output_shape
torch.nn.Sequential.output_shape = sequential_output_shape

torch.nn.Linear.feature_dim = lambda self: self.out_features
torch.nn.Conv2d.feature_dim = lambda self: self.out_channels
torch.nn.Sequential.feature_dim = sequential_feature_dim
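`conv1d_output_shape` applies the standard output-length formula for convolution and pooling along one axis, $L_{\text{out}} = \left\lfloor \frac{L_{\text{in}} + 2p - d(k-1) - 1}{s} \right\rfloor + 1$, for kernel size $k$, stride $s$, padding $p$, and dilation $d$. The pure-Python sketch below (the helpers `out_len` and `track_shape` are hypothetical) pushes a shape through a stack of layers axis by axis, as `sequential_output_shape` does, with padding left at zero as in the sampled configs. It raises `ValueError` where the notebook uses an `assert`, and shows how a few large strides and dilations can shrink a 128×128 input to nothing.

```python
def out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one axis; integer // floors exactly like math.floor."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def track_shape(h_w, layers):
    """Push an (h, w) shape through conv layers given as (height, width) pairs.

    Padding is left at 0, matching the sampled conv configs.
    """
    h, w = h_w
    for layer in layers:
        (kh, kw), (sh, sw), (dh, dw) = layer["kernel_size"], layer["stride"], layer["dilation"]
        h = out_len(h, kh, sh, dilation=dh)
        w = out_len(w, kw, sw, dilation=dw)
        if h <= 0 or w <= 0:
            raise ValueError(f"feature map shrank to {(h, w)}")
    return h, w


layers = [
    {"kernel_size": (3, 3), "stride": (2, 2), "dilation": (1, 1)},
    {"kernel_size": (7, 5), "stride": (3, 1), "dilation": (2, 1)},
]
h, w = track_shape((128, 128), layers)
print((h, w))      # (17, 59)
print(h * w * 32)  # 32096: flattened size if the last conv has 32 channels

# three aggressive layers in a row: 128 -> 39 -> 9 -> invalid
deep = [{"kernel_size": (7, 7), "stride": (3, 3), "dilation": (2, 2)}] * 3
try:
    track_shape((128, 128), deep)
except ValueError as err:
    print(err)
```

The product `h * w * channels` is the same quantity `CNN.__init__` computes as `final_size` for the Conv -> FC transition.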

In [None]:
class CNN(utils.LoggedImageClassifierModule):
  """A simple CNN Model, with under-the-hood wandb
  and pytorch-lightning features (logging, metrics, etc.).
  """

  def __init__(self, labels, config):
    super().__init__(labels=labels)

    self.loss = torch.nn.CrossEntropyLoss()

    self.optimizer = config["optimizer"]
    self.optimizer_params = config["optimizer.params"]

    self.input_channels = 3
    self.num_classes = 10

    self.resizing_shape = (128, 128)
    self.resize_layer = torch.nn.AdaptiveAvgPool2d(self.resizing_shape)

    # Build conv body 
    conv_config = filter_to_subconfig(config, "conv")
    self.conv = build_conv_from_config(
        conv_config, self.input_channels)

    # Infer shape of Conv -> FC transition
    self.conv_feature_dim = self.conv.feature_dim()
    self.final_shape = self.conv.output_shape(self.resizing_shape)
    self.final_size = self.final_shape[0] * self.final_shape[1] * self.conv_feature_dim

    # Build FC block
    fc_config = filter_to_subconfig(config, "fc")
    self.classifier = build_fc_from_config(
        fc_config, self.final_size)

    # Add classifier head
    # Add classifier head; output_shape returns final_size unchanged if the FC block is empty
    self.classifier.add_module("classification",
        torch.nn.Linear(self.classifier.output_shape(self.final_size), self.num_classes))

  def forward(self, xs):
    xs = self.resize_layer(xs)

    xs = self.conv(xs)

    xs = xs.view(-1, self.final_size)

    xs = self.classifier(xs)

    return xs

  def configure_optimizers(self):
    return self.optimizer(self.parameters(), **self.optimizer_params)

##
# Building Networks from Configuration Dictionaries
##

# This section defines the logic for building modules from a configuration
#  and for hooking them together

def build_conv_from_config(config, in_channels):
  conv = []
  for block in range(config["n_blocks"]):
    block_config = config[f"block_{block}"]
    conv_block = build_block_from_config(block_config, in_channels)
    in_channels = conv_block.feature_dim()
    conv.append(conv_block)

  conv = torch.nn.Sequential(*conv)
  conv.feature_dim = lambda : in_channels

  return conv


def build_fc_from_config(fc_config, in_features):
  fc = []
  for layer in range(fc_config["n_layers"]):
    layer_config = fc_config[f"layer_{layer}"]
    fc_layer = torch.nn.Linear(in_features=in_features, **layer_config)
    in_features = fc_layer.out_features
    fc.append(fc_layer)
    if fc_config["batchnorm_pre"]:
      fc.append(torch.nn.BatchNorm1d(in_features))
    fc.append(fc_config["activation"]())
    if fc_config["batchnorm"] and not fc_config["batchnorm_pre"]:
      fc.append(torch.nn.BatchNorm1d(in_features))
    if fc_config["dropout"]:
      fc.append(torch.nn.Dropout(fc_config["dropout"]))

  fc = torch.nn.Sequential(*fc)
  fc.feature_dim = lambda : in_features

  return fc


def build_block_from_config(block_config, in_channels):
  conv_block = []
  for layer in range(block_config["n_convs"]):
    conv_config = block_config[f"layer_{layer}"]
    conv = torch.nn.Conv2d(in_channels, **conv_config)
    in_channels = conv.out_channels
    conv_block.append(conv)
    if block_config["batchnorm_pre"]:
      conv_block.append(torch.nn.BatchNorm2d(in_channels))
    conv_block.append(block_config["activation"]())
    if block_config["batchnorm"] and not block_config["batchnorm_pre"]:
      conv_block.append(torch.nn.BatchNorm2d(in_channels))
    if block_config["dropout"]:
      conv_block.append(torch.nn.Dropout2d(block_config["dropout"]))

  conv_block = torch.nn.Sequential(*conv_block)
  conv_block.feature_dim = lambda : in_channels
  return conv_block


def filter_to_subconfig(config, prefix):
  return config[prefix] 
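`CNN` resizes each 32×32 CIFAR-10 image to `resizing_shape = (128, 128)` with `torch.nn.AdaptiveAvgPool2d`, so here adaptive pooling is used to upsample. Assuming PyTorch's adaptive-pooling index rule (output index `i` averages input indices from `floor(i * in / out)` up to, but not including, `ceil((i + 1) * in / out)`), the pure-Python sketch below reproduces the windows along one axis. For 32 to 128, every window holds exactly one input pixel, so each input pixel is repeated four times: the resize behaves like 4× nearest-neighbor upsampling rather than a smoothing average. The helpers `adaptive_windows` and `adaptive_avg_pool_1d` are hypothetical.

```python
def adaptive_windows(in_size, out_size):
    """Half-open input windows [start, end) averaged for each output index (1-D)."""
    windows = []
    for i in range(out_size):
        start = (i * in_size) // out_size          # floor(i * in / out)
        end = -(-((i + 1) * in_size) // out_size)  # ceil((i + 1) * in / out)
        windows.append((start, end))
    return windows


def adaptive_avg_pool_1d(values, out_size):
    """Average a plain list of numbers over the adaptive windows."""
    return [sum(values[s:e]) / (e - s) for s, e in adaptive_windows(len(values), out_size)]


windows = adaptive_windows(32, 128)
print(windows[:5])                          # [(0, 1), (0, 1), (0, 1), (0, 1), (1, 2)]
print({e - s for s, e in windows})          # {1}: every window is a single pixel
print(adaptive_avg_pool_1d([1.0, 2.0], 8))  # [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0]
```

When the output is smaller than the input, the same rule produces wider windows and genuine averaging; `adaptive_avg_pool_1d([1.0, 3.0], 1)` gives `[2.0]`.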

In [None]:
###
# Generating a Random Architecture from a Fixed Seed
###

# This section maps a seed value to an architecture.
#  The seed can be any valid Python seed; the public sweep uses integers.

import random
def randbool(): return bool(random.randint(0, 1))

def generate_random_config(seed):
  p_batchnorm = 0.67
  max_dropout = 0.5
  random.seed(seed)

  config = {}
  config["conv"], config["fc"] = {}, {}

  config["conv"]["n_blocks"] = random.choice([0, 1, 1, 2, 2, 4, 4, 4])
  config["fc"]["n_layers"] = random.choice([0, 1, 2, 2, 2, 4, 8])

  config["conv"]["batchnorm"] = random.random() < p_batchnorm
  config["conv"]["batchnorm_pre"] = randbool() if config["conv"]["batchnorm"] else None

  config["fc"]["batchnorm"] = random.random() < p_batchnorm
  config["fc"]["batchnorm_pre"] = randbool() if config["fc"]["batchnorm"] else None

  config["fc"]["dropout"] = random.random() * max_dropout if randbool() else None
  config["conv"]["dropout"] = random.random() * max_dropout if randbool() else None

  config["conv"]["activation"] = random.choice([torch.nn.ReLU, torch.nn.GELU, torch.nn.Sigmoid, torch.nn.SiLU])
  config["fc"]["activation"] = random.choice([torch.nn.ReLU, torch.nn.GELU, torch.nn.Sigmoid, torch.nn.SiLU])

  for block in range(config["conv"]["n_blocks"]):
    block_config = generate_random_conv_block_config(shared_config=config["conv"], index=block)
    config["conv"][f"block_{block}"] = block_config

  for layer in range(config["fc"]["n_layers"]):
    layer_config = generate_random_fc_layer_config(shared_config=config["fc"])
    config["fc"][f"layer_{layer}"] = layer_config

  return config


def generate_random_conv_block_config(shared_config, index):
  block_config = {}
  block_config["activation"] = shared_config["activation"]
  block_config["batchnorm"], block_config["batchnorm_pre"] = shared_config["batchnorm"], shared_config["batchnorm_pre"]
  block_config["dropout"] = shared_config["dropout"]

  block_config["n_convs"] = random.randint(1, 2)
  block_config["n_channels"] = random.choice([16, 32, 128])

  for layer in range(block_config["n_convs"]):
    block_config[f"layer_{layer}"] = generate_random_conv_config(n_channels=block_config["n_channels"])
    
  return block_config


def generate_random_fc_layer_config(shared_config):
  fc_layer_config = {}
  fc_layer_config["out_features"] = random.choice([16, 32, 128])
  return fc_layer_config


def generate_random_conv_config(n_channels):
  conv_config = {}
  conv_config["out_channels"] = n_channels
  conv_config["kernel_size"] = generate_random_tuple_diag_bias(lambda : random.choice([1, 3, 3, 3, 5, 7]))
  conv_config["stride"] = generate_random_tuple_diag_bias(lambda : random.choice([1, 1, 1, 2, 2, 3]))
  conv_config["dilation"] = generate_random_tuple_diag_bias(lambda : random.choice([1, 1, 1, 1, 2]))

  return conv_config


def generate_random_tuple_diag_bias(sampler):
  tupl = sampler()
  if randbool():
    tupl = (tupl, sampler())
  else:
    tupl = (tupl, tupl)

  return tupl
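`generate_random_config` reseeds Python's global `random` module and then draws every choice in a fixed order, so the same seed always maps to the same architecture. Repeated entries in a `random.choice` list act as weights: `[0, 1, 1, 2, 2, 4, 4, 4]` picks four conv blocks with probability 3/8. The sketch below illustrates both points; `sample_depths` and `implied_probabilities` are hypothetical helpers, and the sketch swaps in a private `random.Random(seed)` instance for the global seed so that other code calling `random` cannot disturb the stream. Since `random.Random(seed)` starts from the same state as `random.seed(seed)`, its two draws should match the `n_blocks` and `n_layers` that `generate_random_config` samples first for the same seed.

```python
import random
from collections import Counter
from fractions import Fraction

CONV_BLOCK_CHOICES = [0, 1, 1, 2, 2, 4, 4, 4]  # list used for config["conv"]["n_blocks"]
FC_LAYER_CHOICES = [0, 1, 2, 2, 2, 4, 8]       # list used for config["fc"]["n_layers"]


def sample_depths(seed):
    """Draw (n_blocks, n_layers) from a private generator seeded with `seed`."""
    rng = random.Random(seed)
    return rng.choice(CONV_BLOCK_CHOICES), rng.choice(FC_LAYER_CHOICES)


def implied_probabilities(choices):
    """Exact probability of each value when random.choice picks uniformly from the list."""
    counts = Counter(choices)
    return {value: Fraction(n, len(choices)) for value, n in sorted(counts.items())}


print(sample_depths(117) == sample_depths(117))  # True: same seed, same draws
print(implied_probabilities(CONV_BLOCK_CHOICES))
# {0: Fraction(1, 8), 1: Fraction(1, 4), 2: Fraction(1, 4), 4: Fraction(3, 8)}
print(implied_probabilities(FC_LAYER_CHOICES)[2])  # 3/7
```

One caveat of the draw-order design: inserting a new `random` call early in `generate_random_config` would silently remap every seed to a different architecture.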

# Define Training Function

In [None]:
def train():
  labels = ["airplane", "automobile", "bird", "cat", "deer",
            "dog", "frog", "horse", "ship", "truck"]

  with wandb.init() as run:

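    # default seed for runs outside a sweep; under a sweep, wandb.config already holds the sampled seed
    wandb.config.setdefaults({"seed": 117})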
    config = generate_random_config(wandb.config.seed)
    config.update({
        "optimizer": torch.optim.Adam,
        "optimizer.params": {"lr": 0.0003},
        "batch_size": 128,
        "max_epochs": 2,
    })
    wandb.config.update(config)

    dm = CIFAR10DataModule(batch_size=config["batch_size"])
    cnn = CNN(labels, config)
    
    # logs the input weights to Weights & Biases
    filter_logger = utils.FilterLogCallback(image_size=(3,) + cnn.resizing_shape,
                                            log_input=True, log_output=False)
  
    # 👟 configure Trainer
    trainer = pl.Trainer(accelerator="gpu", devices=1,  # use the GPU for .forward
                         logger=pl.loggers.WandbLogger(
                           log_model=True, save_code=True),  # log to Weights & Biases
                         max_epochs=config["max_epochs"], log_every_n_steps=1,
                         callbacks=[filter_logger,
                                    pl.callbacks.TQDMProgressBar(refresh_rate=50)])
                        
    # 🏃‍♀️ run the Trainer on the model
    trainer.fit(cnn, dm)

# Join the Parallel Architecture Search

Execute this cell to start running an "agent" that can participate
in the architecture search.

The results from the large public sweep for this project are logged to
[this sweep dashboard](https://wandb.ai/wandb/archsearch-cifar10/sweeps/bmhxqxr0).

> Note: this cell will run forever unless stopped.
If you leave the notebook running for longer than an hour or two,
Google will shut it down automatically, and
if this happens more than once in a short time,
you may find your access to Colab GPUs restricted.
To avoid this, change the `count` argument to an integer
near `20`; the cell will then stop after executing that many training runs,
which takes roughly 20 to 30 minutes.

To launch your own version of this search instead,
skip this cell and follow the instructions in the cells below.

In [None]:
wandb.agent(sweep_id="bmhxqxr0",  # ID of the public sweep
            function=train, count=None,
            entity="wandb", project="archsearch-cifar10")

# Sweep Init

To start a separate
architecture search of your own,
run the following two cells,
then replace the `sweep_id` in the `wandb.agent` cell above
with the sweep ID printed by the second cell below,
and execute that cell.

Change the `entity` from `wandb` to your username (or a team you belong to)
so that the sweep lives among your personal (and optionally private) projects;
creating a sweep requires write access to its entity.
Make the same change in the `wandb.agent` cell as well.

In [None]:
sweep_config = {"method": "random",
                "metric": {
                  "name": "validation/accuracy",
                  "goal": "maximize"
                  },
                "parameters": {
                  "seed": {
                    "distribution": "int_uniform",
                    "min": 0,
                    "max": 10000000
                    }
                  },
                }

In [None]:
sweep_id = wandb.sweep(sweep_config, entity="wandb", project="archsearch-cifar10")
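With `"method": "random"`, each run in the sweep draws its `seed` independently from the `int_uniform` range, so two parallel agents can occasionally receive the same seed and train the same architecture. Assuming the range includes both endpoints (N = 10,000,001 seeds) and draws are independent and uniform, the birthday-problem calculation below estimates the chance of at least one repeated seed after n runs, both exactly and with the usual approximation $1 - e^{-n(n-1)/2N}$. Both helper functions are hypothetical.

```python
import math


def p_any_repeat(n_runs, n_seeds=10_000_001):
    """Exact probability that n_runs independent uniform draws contain a repeat."""
    p_all_distinct = 1.0
    for i in range(n_runs):
        # the i-th draw must avoid the i seeds already used
        p_all_distinct *= 1 - i / n_seeds
    return 1 - p_all_distinct


def p_any_repeat_approx(n_runs, n_seeds=10_000_001):
    """Birthday approximation 1 - exp(-n(n-1) / 2N)."""
    return 1 - math.exp(-n_runs * (n_runs - 1) / (2 * n_seeds))


for n in (100, 1_000, 10_000):
    print(n, round(p_any_repeat(n), 4), round(p_any_repeat_approx(n), 4))
```

Under these assumptions, about 1,000 runs give roughly a 5% chance of at least one repeated seed, so a duplicated architecture in the dashboard may simply be a repeated seed.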