# [0.3] Optimization & Hyperparameters (exercises)

> **ARENA [Streamlit Page](https://arena-chapter0-fundamentals.streamlit.app/03_[0.3]_Optimization)**
>
> **Colab: [exercises](https://colab.research.google.com/github/callummcdougall/ARENA_3.0/blob/main/chapter0_fundamentals/exercises/part3_optimization/0.3_Optimization_exercises.ipynb?t=20250527) | [solutions](https://colab.research.google.com/github/callummcdougall/ARENA_3.0/blob/main/chapter0_fundamentals/exercises/part3_optimization/0.3_Optimization_solutions.ipynb?t=20250527)**

Please send any problems / bugs on the `#errata` channel in the [Slack group](https://join.slack.com/t/arena-uk/shared_invite/zt-39iwnhbj4-pMWUvZkkt2wpvaxkvJ0q2rRQ), and ask any questions on the dedicated channels for this chapter of material.

You can collapse each section so only the headers are visible, by clicking the arrow symbol on the left hand side of the markdown header cells.

Links to all other chapters: [(0) Fundamentals](https://arena-chapter0-fundamentals.streamlit.app/), [(1) Transformer Interpretability](https://arena-chapter1-transformer-interp.streamlit.app/), [(2) RL](https://arena-chapter2-rl.streamlit.app/).

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/headers/header-03.png" width="350">

# Introduction

In today's exercises, we will explore various optimization algorithms and their roles in training deep learning models. We will delve into the inner workings of different optimization techniques such as Stochastic Gradient Descent (SGD), RMSprop, and Adam, and learn how to implement them using code. Additionally, we will discuss the concept of loss landscapes and their significance in visualizing the challenges faced during the optimization process. By the end of this set of exercises, you will have a solid understanding of optimization algorithms and their impact on model performance. We'll also take a look at Weights and Biases, a tool that can be used to track and visualize the training process, and test different values of hyperparameters to find the most effective ones.

> Note - the third set of exercises in this section is on distributed training, and has different requirements: specifically, you'll need to SSH into a virtual machine which has multiple GPUs, and run the exercises from a Python file (not notebook or Colab). However, you can still treat the first 2 sections as normal and then make this switch for the third section.

## Content & Learning Objectives

### 1️⃣ Optimizers

These exercises will take you through how different optimization algorithms work (specifically SGD, RMSprop and Adam). You'll write your own optimisers, and use plotting functions to visualise gradient descent on loss landscapes.

> ##### Learning Objectives
>
> * Understand how different optimization algorithms work
> * Translate pseudocode for these algorithms into code
> * Understand the idea of loss landscapes, and how they can visualize specific challenges in the optimization process

### 2️⃣ Weights and Biases

In this section, we'll look at methods for choosing hyperparameters effectively. You'll learn how to use **Weights and Biases**, a useful tool for hyperparameter search. By the end of today, you should be able to use Weights and Biases to train the ResNet you created in the last set of exercises.

> ##### Learning Objectives
>
> * Write modular, extensible code for training models
> * Learn what the most important hyperparameters are, and methods for efficiently searching over hyperparameter space
> * Learn how to use Weights & Biases for logging your runs
> * Adapt your code from yesterday to log training runs to Weights & Biases, and use this service to run **hyperparameter sweeps**

### 3️⃣ Distributed Training

In this section, we'll take you through the basics of distributed training, which is the process via which training is split over multiple separate GPUs to improve efficiency and capacity.

> ##### Learning Objectives
>
> * Understand the different kinds of parallelization used in deep learning (data, pipeline, tensor)
> * Understand how primitive operations in `torch.distributed` work, and how they come together to enable distributed training
> * Launch and benchmark your own distributed training runs, to train your implementation of `ResNet34` from scratch

### 4️⃣ Bonus

This section gives you suggestions for further exploration of optimizers, and Weights & Biases.

## Setup code

In [2]:
import os
import sys
from pathlib import Path

IN_COLAB = "google.colab" in sys.modules

chapter = "chapter0_fundamentals"
repo = "ARENA_3.0"
branch = "main"

# Install dependencies
try:
    import jaxtyping
except:
    %pip install einops jaxtyping torchinfo wandb

# Get root directory, handling 3 different cases: (1) Colab, (2) notebook not in ARENA repo, (3) notebook in ARENA repo
root = (
    "/content"
    if IN_COLAB
    else "/root"
    if repo not in os.getcwd()
    else str(next(p for p in Path.cwd().parents if p.name == repo))
)

if Path(root).exists() and not Path(f"{root}/{chapter}").exists():
    if not IN_COLAB:
        !sudo apt-get install unzip
        %pip install jupyter ipython --upgrade

    if not os.path.exists(f"{root}/{chapter}"):
        !wget -P {root} https://github.com/callummcdougall/ARENA_3.0/archive/refs/heads/{branch}.zip
        !unzip {root}/{branch}.zip '{repo}-{branch}/{chapter}/exercises/*' -d {root}
        !mv {root}/{repo}-{branch}/{chapter} {root}/{chapter}
        !rm {root}/{branch}.zip
        !rmdir {root}/{repo}-{branch}


assert Path(f"{root}/{chapter}/exercises").exists(), "Unexpected error: please manually clone ARENA repo into `root`"

if f"{root}/{chapter}/exercises" not in sys.path:
    sys.path.append(f"{root}/{chapter}/exercises")

os.chdir(f"{root}/{chapter}/exercises")


In [4]:
import importlib
import os
import sys
import time
from dataclasses import dataclass, replace
from pathlib import Path
from typing import Callable, Iterable, Literal

import numpy as np
import torch as t
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
import wandb
from IPython.core.display import HTML
from IPython.display import display
from jaxtyping import Float, Int
from torch import Tensor, optim
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from tqdm import tqdm

# Make sure exercises are in the path
chapter = "chapter0_fundamentals"
section = "part3_optimization"
root_dir = next(p for p in Path.cwd().parents if (p / chapter).exists())
exercises_dir = root_dir / chapter / "exercises"
section_dir = exercises_dir / section
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))

from torch.optim import Adam, AdamW, SGD, RMSprop

import part3_optimization.tests as tests
from part2_cnns.solutions import Linear, ResNet34, get_resnet_for_feature_extraction
from part3_optimization.utils import plot_fn, plot_fn_with_points
from plotly_utils import bar, imshow, line

device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")

<details>
<summary>Help - I get a NumPy-related error</summary>

This is an annoying colab-related issue which I haven't been able to find a satisfying fix for. If you restart runtime (but don't delete runtime), and run just the imports cell above again (but not the `%pip install` cell), the problem should go away.
</details>

# 1️⃣ Optimizers

> ##### Learning Objectives
>
> * Understand how different optimization algorithms work
> * Translate pseudocode for these algorithms into code
> * Understand the idea of loss landscapes, and how they can visualize specific challenges in the optimization process

## Reading

Some of these are strongly recommended, while others are optional. If you like, you can jump back to some of these videos while you're going through the material, if you feel like you need to.

* Andrew Ng's video series on gradient descent variants: [Gradient Descent With Momentum](https://www.youtube.com/watch?v=k8fTYJPd3_I) (9 mins), [RMSProp](https://www.youtube.com/watch?v=_e-LFe_igno) (7 mins), [Adam](https://www.youtube.com/watch?v=JXQT_vxqwIs&list=PLkDaE6sCZn6Hn0vK8co82zjQtt3T2Nkqc&index=23) (7 mins)
    * These videos are strongly recommended, especially the RMSProp video
* [A Visual Explanation of Gradient Descent Methods](https://medium.com/towards-data-science/a-visual-explanation-of-gradient-descent-methods-momentum-adagrad-rmsprop-adam-f898b102325c)
    * This is also strongly recommended; if you only want to read/watch one thing, make it this
* [Why Momentum Really Works (distill.pub)](https://distill.pub/2017/momentum/)
    * This is optional, but a fascinating read if you have time and are interested in engaging with the mathematical details of optimization

## Gradient Descent

Tomorrow, we'll look in detail at how the backpropagation algorithm works. But for now, let's take it as read that calling `loss.backward()` on a scalar `loss` will compute the gradients $\frac{\partial loss}{\partial w}$ for every parameter `w` in the model, and store these values in `w.grad`. How do we use these gradients to update our parameters in a way which decreases loss?

A loss function can be any differentiable function such that we prefer a lower value. To apply gradient descent, we start by initializing the parameters to random values (the details of this are subtle), and then repeatedly compute the gradient of the loss with respect to the model parameters. It [can be proven](https://tutorial.math.lamar.edu/Classes/CalcIII/DirectionalDeriv.aspx) that for an infinitesimal step, moving in the direction of the gradient would increase the loss by the largest amount out of all possible directions.

We actually want to decrease the loss, so we subtract the gradient to go in the opposite direction. Taking infinitesimal steps is no good, so we pick some learning rate $\lambda$ (also called the step size) and scale our step by that amount to obtain the update rule for gradient descent:

$$\theta_t \leftarrow \theta_{t-1} - \lambda \nabla L(\theta_{t-1})$$

We know that an infinitesimal step will decrease the loss, but a finite step will only do so if the loss function is linear enough in the neighbourhood of the current parameters. If the loss function is too curved, we might actually increase our loss.
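
As a rough illustration of the update rule and of the curvature caveat above, here is a minimal sketch on a made-up toy loss $L(\theta) = \sum_i \theta_i^2$. The function names `gradient_descent_step` and `toy_loss` are hypothetical helpers for this example only, not part of the course code.

```python
import torch as t


def toy_loss(theta: t.Tensor) -> t.Tensor:
    # Hypothetical toy loss for illustration: L(theta) = sum(theta ** 2)
    return (theta**2).sum()


def gradient_descent_step(theta: t.Tensor, lr: float) -> t.Tensor:
    """Returns theta - lr * grad L(theta), leaving the input tensor untouched."""
    theta = theta.detach().clone().requires_grad_(True)
    toy_loss(theta).backward()  # fills theta.grad with dL/dtheta = 2 * theta
    with t.no_grad():
        return theta - lr * theta.grad


theta0 = t.tensor([1.0, -2.0])
small_step = gradient_descent_step(theta0, lr=0.1)  # modest step: loss goes down
large_step = gradient_descent_step(theta0, lr=1.5)  # overshoots the minimum: loss goes up

print(toy_loss(theta0).item(), toy_loss(small_step).item(), toy_loss(large_step).item())
```

For this particular loss each step multiplies $\theta$ by $(1 - 2\lambda)$, so any learning rate above 1 makes the loss grow, even though the step direction is still the negative gradient.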

The biggest advantage of this algorithm is that for N bytes of parameters, you only need N additional bytes of memory to store the gradients, which are of the same shape as the parameters. GPU memory is very limited, so this is an extremely relevant consideration. The amount of computation needed is also minimal: one multiply and one add per parameter.

The biggest disadvantage is that we're completely ignoring the curvature of the loss function, which the gradient (a vector of first-order partial derivatives) does not capture. Intuitively, we can take a larger step if the loss function is flat in some direction or a smaller step if it is very curved. Generally, you could represent this by some matrix $P$ that pre-multiplies the gradients to rescale them to account for the curvature. $P$ is called a preconditioner, and gradient descent is equivalent to approximating $P$ by an identity matrix, which is a very bad approximation.

Most competing optimizers can be interpreted as trying to do something more sensible for $P$, subject to the constraint that GPU memory is at a premium. In particular, constructing $P$ explicitly is infeasible, since it's an $N \times N$ matrix and N can be hundreds of billions. One idea is to use a diagonal $P$, which only requires N additional memory. An example of a more sophisticated scheme is [Shampoo](https://arxiv.org/pdf/1802.09568.pdf).
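
To make the idea of a diagonal $P$ concrete, here is a small sketch on a hand-picked quadratic loss whose curvature differs by a factor of 100 between the two coordinates. Everything here (the loss, the `run` helper, the chosen learning rates) is an assumption made for illustration; in this toy case the ideal diagonal preconditioner is known exactly, which is not true for real networks.

```python
import torch as t

# Toy loss L(theta) = 0.5 * sum(curvature * theta ** 2); its Hessian is diag(curvature)
curvature = t.tensor([100.0, 1.0])


def grad(theta: t.Tensor) -> t.Tensor:
    return curvature * theta


def run(precond_diag: t.Tensor, lr: float, n_steps: int) -> t.Tensor:
    """Gradient descent where the gradient is rescaled elementwise by a diagonal preconditioner."""
    theta = t.tensor([1.0, 1.0])
    for _ in range(n_steps):
        theta = theta - lr * precond_diag * grad(theta)
    return theta


# Identity preconditioner (plain gradient descent): lr is capped by the steep direction,
# so progress along the flat direction is slow
plain = run(t.ones(2), lr=0.01, n_steps=50)

# Diagonal preconditioner equal to the inverse curvature: both directions are rescaled
# to the same effective curvature, so a single step with lr=1 lands on the minimum
preconditioned = run(1.0 / curvature, lr=1.0, n_steps=1)

print(plain, preconditioned)
```

Note that the preconditioner here only stores one number per parameter, which is the memory argument made above for diagonal schemes.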

> The algorithm is called **Shampoo** because you put shampoo on your hair before using conditioner, and this method is a pre-conditioner.
>     
> If you take away just one thing from this entire curriculum, please don't let it be this.

## Stochastic Gradient Descent

The terms gradient descent and SGD are used loosely in deep learning. To be technical, there are three variations:

- Batch gradient descent - the loss function is the loss over the entire dataset. This requires too much computation unless the dataset is small, so it is rarely used in deep learning.
- Stochastic gradient descent - the loss function is the loss on a randomly selected example. The gradient on any particular example may point in a completely different direction from the gradient on the entire dataset, but in expectation it points in the right direction. This has some nice properties but doesn't parallelize well, so it is rarely used in deep learning.
- Mini-batch gradient descent - the loss function is the loss on a batch of examples of size `batch_size`. This is the standard in deep learning.

The class `torch.optim.SGD` can be used for any of these by varying the number of examples passed in. We will be using only mini-batch gradient descent in this course.
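
The sketch below shows how the same `torch.optim.SGD` optimizer covers all three variants, with only the set of examples used for the loss changing. The regression data and the helper `step_on` are hypothetical and exist only for this illustration.

```python
import torch as t

t.manual_seed(0)

# Hypothetical noise-free linear regression data, made up for this example
X = t.randn(256, 3)
true_w = t.tensor([1.0, -2.0, 0.5])
y = X @ true_w

w = t.zeros(3, requires_grad=True)
optimizer = t.optim.SGD([w], lr=0.05)


def step_on(idx: t.Tensor) -> float:
    """Takes one optimizer step using only the examples selected by `idx`."""
    loss = ((X[idx] @ w - y[idx]) ** 2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()


step_on(t.arange(256))  # batch gradient descent: the whole dataset
step_on(t.randint(0, 256, (1,)))  # stochastic gradient descent: a single random example
for _ in range(300):
    step_on(t.randperm(256)[:32])  # mini-batch gradient descent: 32 random examples

print(w.detach())
```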

## Batch Size

In addition to choosing a learning rate or learning rate schedule, we need to choose the batch size or batch size schedule as well. Intuitively, using a larger batch means that the estimate of the gradient is closer to that of the true gradient over the entire dataset, but this requires more compute. Each element of the batch can be computed in parallel so with sufficient compute, one can increase the batch size without increasing wall-clock time. For small-scale experiments, a good heuristic is thus "fill up all of your GPU memory".

At a larger scale, we would expect diminishing returns of increasing the batch size, but empirically it's worse than that - a batch size that is too large generalizes more poorly in many scenarios. The intuition that a closer approximation to the true gradient is always better is therefore incorrect. See [this paper](https://arxiv.org/pdf/1706.02677.pdf) for one discussion of this.

For a batch size schedule, most commonly you'll see batch sizes increase over the course of training. The intuition is that a rough estimate of the proper direction is good enough early in training, but later in training it's important to preserve our progress and not "bounce around" too much.

You will commonly see batch sizes that are a multiple of 32. One motivation for this is that when using CUDA, threads are grouped into "warps" of 32 threads which execute the same instructions in parallel. So a batch size of 64 would allow two warps to be fully utilized, whereas a size of 65 would require waiting for a third warp to finish. As batch sizes become larger, this wastage becomes less important.

Powers of two are also common - the idea here is that work can be recursively divided up among different GPUs or within a GPU. For example, a matrix multiplication can be expressed by recursively dividing each matrix into four equal blocks and performing eight smaller matrix multiplications between the blocks.
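
As a small sketch of the block decomposition just described (one level of the recursion only), the hypothetical helper `block_matmul` below splits two square matrices into four equal blocks and combines eight smaller products. It assumes the side length is even.

```python
import torch as t


def block_matmul(A: t.Tensor, B: t.Tensor) -> t.Tensor:
    """Multiplies two square matrices of even size using four equal blocks per matrix."""
    h = A.shape[0] // 2
    A11, A12, A21, A22 = A[:h, :h], A[:h, h:], A[h:, :h], A[h:, h:]
    B11, B12, B21, B22 = B[:h, :h], B[:h, h:], B[h:, :h], B[h:, h:]

    # Eight smaller matrix multiplications, two per output block
    C11 = A11 @ B11 + A12 @ B21
    C12 = A11 @ B12 + A12 @ B22
    C21 = A21 @ B11 + A22 @ B21
    C22 = A21 @ B12 + A22 @ B22

    top = t.cat([C11, C12], dim=1)
    bottom = t.cat([C21, C22], dim=1)
    return t.cat([top, bottom], dim=0)


t.manual_seed(0)
A = t.randn(8, 8, dtype=t.float64)
B = t.randn(8, 8, dtype=t.float64)
C = block_matmul(A, B)
```

Each of the eight products could in principle be handed to a different worker, and when the size is a power of two the same split can be applied again inside each block.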

In tomorrow's exercises, you'll have the option to explore batch sizes in more detail.

## Common Themes in Gradient-Based Optimizers

### Weight Decay

Weight decay means that on each iteration, in addition to a regular step, we also shrink each parameter very slightly towards 0 by multiplying it by a scaling factor close to 1, e.g. 0.9999. Empirically, this seems to help but there are no proofs that apply to deep neural networks.

In the case of linear regression, weight decay is mathematically equivalent to having a prior that each parameter is Gaussian distributed - in other words it's very unlikely that the true parameter values are very positive or very negative. This is an example of "**inductive bias**" - we make an assumption that helps us in the case where it's justified, and hurts us in the case where it's not justified.

For a `Linear` layer, it's common practice to apply weight decay only to the weight and not the bias. It's also common to not apply weight decay to the parameters of a batch normalization layer. Again, there is empirical evidence (such as [Jia et al 2018](https://arxiv.org/pdf/1807.11205.pdf)) and there are heuristic arguments to justify these choices, but no rigorous proofs. Note that PyTorch will implement weight decay on the weights *and* biases of linear layers by default - see the bonus exercises tomorrow for more on this.
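
One way to apply weight decay to weights but not biases is to pass parameter groups to the optimizer. The sketch below is only an illustration: the small model is made up, and selecting parameters by the suffix of their name is a simplifying assumption (it would also pick up the `weight` of a batch norm layer, which you may want to exclude). Gradients are set to zero by hand so that the only effect of the step is the decay itself.

```python
import torch as t

t.manual_seed(0)
model = t.nn.Sequential(t.nn.Linear(4, 8), t.nn.ReLU(), t.nn.Linear(8, 2))

# Assumption for this sketch: split parameters purely by name suffix
weights = [p for name, p in model.named_parameters() if name.endswith("weight")]
biases = [p for name, p in model.named_parameters() if name.endswith("bias")]

lr, weight_decay = 0.1, 0.01
optimizer = t.optim.SGD(
    [
        {"params": weights, "weight_decay": weight_decay},
        {"params": biases, "weight_decay": 0.0},
    ],
    lr=lr,
)

weights_before = [p.detach().clone() for p in weights]
biases_before = [p.detach().clone() for p in biases]

# Pretend the loss gradient is exactly zero, so the step consists of weight decay only
for p in model.parameters():
    p.grad = t.zeros_like(p)
optimizer.step()

shrink_factor = 1 - lr * weight_decay  # each weight is multiplied by this factor
print(shrink_factor)
```

With `torch.optim.SGD`, the decay term is added to the gradient before the learning rate is applied, which is why the multiplicative factor here is `1 - lr * weight_decay`.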

### Momentum

Momentum means that the step includes a term proportional to a moving average of past gradients. [Distill.pub](https://distill.pub/2017/momentum/) has a great article on momentum, which you should definitely read if you have time. Don't worry if you don't understand all of it; skimming parts of it can be very informative. For instance, the first half discusses the **condition number** (a very important concept to understand in optimisation), and concludes by giving an intuitive argument for why we generally set the momentum parameter close to 1 for ill-conditioned problems (those with a very large condition number).
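
To get a feel for what the momentum parameter does, the plain-Python sketch below accumulates a constant gradient using the rule $b_t = \mu b_{t-1} + g_t$ (the convention `torch.optim.SGD` uses when `dampening=0`). The helper name `momentum_buffer` is hypothetical. With a constant gradient $g$, the buffer approaches $g / (1 - \mu)$, so a momentum close to 1 amplifies a consistent gradient direction by a large factor.

```python
def momentum_buffer(grads: list[float], mu: float) -> list[float]:
    """Returns the sequence of momentum buffer values b_t = mu * b_{t-1} + g_t, with b_0 = 0."""
    b = 0.0
    history = []
    for g in grads:
        b = mu * b + g
        history.append(b)
    return history


constant_grads = [1.0] * 2000
low = momentum_buffer(constant_grads, mu=0.5)  # approaches 1 / (1 - 0.5) = 2
high = momentum_buffer(constant_grads, mu=0.99)  # approaches 1 / (1 - 0.99) = 100

print(low[-1], high[-1])
```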

## Visualising optimization with pathological curvatures

A pathological curvature is a type of surface that resembles a ravine and is particularly tricky for plain SGD optimization. In words, pathological curvatures typically have a steep gradient in one direction with an optimum at the center, while in a second direction we have a shallower gradient towards a (global) optimum. Let’s first create an example surface of this and visualize it. The code below creates 2 visualizations (3D and 2D) and also adds the minimum point to the plot (note this is the min in the visible region, not the global minimum).

In [None]:
def pathological_curve_loss(x: Tensor, y: Tensor):
    # Example of a pathological curvature. There are many more possible, feel free to experiment here!
    x_loss = t.tanh(x) ** 2 + 0.01 * t.abs(x)
    y_loss = t.sigmoid(y)
    return x_loss + y_loss


plot_fn(pathological_curve_loss, min_points=[(0, "y_min")])

In terms of optimization, you can imagine that `x` and `y` are weight parameters, and the curvature represents the loss surface over the space of `x` and `y`. Note that in typical networks, we have many, many more parameters than two, and such curvatures can occur in multi-dimensional spaces as well.

Ideally, our optimization algorithm would find the center of the ravine and focus on optimizing the parameters in the direction of `y`. However, if we encounter a point along the ridges, the gradient is much greater in `x` than `y`, and we might end up jumping from one side to the other. Due to the large gradients, we would have to reduce our learning rate, slowing down learning significantly.

To test our algorithms, we can implement a simple function to train two parameters on such a surface.

### Exercise - implement `opt_fn_with_sgd`

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵🔵🔵⚪
>
> You should spend up to 15-20 minutes on this exercise.
> ```

Implement the `opt_fn_with_sgd` function using `torch.optim.SGD`. This function optimizes parameters `(x, y)` (which represent coordinates at which we evaluate a function) using gradient descent on that function value. In other words, this should look just like your optimization loops in previous days' material, except rather than passing in `model.parameters()` to your optimizer, you pass in `(xy,)` (because it needs to be an iterable of parameters, not just a single parameter).

Remember, your update steps `optimizer.step()` will automatically change the values of `xy` inplace - this means that you shouldn't store past values like `xy_list.append(xy)` because then past elements of that list will be modified when `xy` is updated. Instead, you should use something like `xy_list.append(xy.detach().clone())` to make sure you're returning a copy of the tensor, which won't continue to be modified.
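
The aliasing problem described above can be seen in isolation with the small sketch below. The in-place subtraction is just a stand-in for `optimizer.step()`, and the list names are made up for this example.

```python
import torch as t

xy = t.tensor([1.0, 2.0], requires_grad=True)
aliased, copied = [], []

for _ in range(3):
    aliased.append(xy)  # stores a reference to the same tensor every time
    copied.append(xy.detach().clone())  # stores an independent snapshot
    with t.no_grad():
        xy -= 0.5  # stand-in for optimizer.step(), which also updates xy in place

print([a.tolist() for a in aliased])  # every entry shows the latest value of xy
print([c.tolist() for c in copied])  # entries show the value of xy at each step
```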

We've also provided you with a function `plot_fn_with_points`, which plots a function as well as a list of points produced by functions like the one above. The code below starts from `(2.5, 2.5)` and adds the resulting trajectory of `(x, y)` coordinates to the contour plot. Does it find the minimum? Play with the learning rate and momentum a bit and see how close you can get within 100 iterations.

In [None]:
def opt_fn_with_sgd(
    fn: Callable, xy: Float[Tensor, "2"], lr=0.001, momentum=0.98, n_iters: int = 100
) -> Float[Tensor, "n_iters 2"]:
    """
    Optimize a given function starting from the specified point.

    xy: shape (2,). The (x, y) starting point.
    n_iters: number of steps.
    lr, momentum: parameters passed to the torch.optim.SGD optimizer.

    Return: (n_iters+1, 2). The (x, y) values, from initial values pre-optimization to values after step `n_iters`.
    """
    # Make sure tensor has requires_grad=True, otherwise it can't be optimized (more on this tomorrow!)
    assert xy.requires_grad

    raise NotImplementedError()


points = []

optimizer_list = [
    (optim.SGD, {"lr": 0.1, "momentum": 0.0}),
    (optim.SGD, {"lr": 0.02, "momentum": 0.99}),
]

for optimizer_class, params in optimizer_list:
    xy = t.tensor([2.5, 2.5], requires_grad=True)
    xys = opt_fn_with_sgd(pathological_curve_loss, xy=xy, lr=params["lr"], momentum=params["momentum"])
    points.append((xys, optimizer_class, params))
    print(f"{params=}, last point={xys[-1]}")

plot_fn_with_points(pathological_curve_loss, points=points, min_points=[(0, "y_min")])

<details>
<summary>Help - I'm not sure if my <code>opt_fn_with_sgd</code> is implemented properly.</summary>

With a learning rate of `0.02` and momentum of `0.99`, my SGD was able to reach `[ 0.8110, -6.3344]` after 100 iterations.
</details>

<details>
<summary>Help - I'm getting <code>Can't call numpy() on Tensor that requires grad</code>.</summary>

This is a protective mechanism built into PyTorch. The idea is that once you convert your Tensor to NumPy, PyTorch can no longer track gradients, but you might not understand this and expect backprop to work on NumPy arrays.

All you need to do to convince PyTorch you're a responsible adult is to call `detach()` on the tensor first, which returns a view that does not require grad and isn't part of the computation graph.
</details>


<details><summary>Solution</summary>

```python
def opt_fn_with_sgd(
    fn: Callable, xy: Float[Tensor, "2"], lr=0.001, momentum=0.98, n_iters: int = 100
) -> Float[Tensor, "n_iters 2"]:
    """
    Optimize a given function starting from the specified point.

    xy: shape (2,). The (x, y) starting point.
    n_iters: number of steps.
    lr, momentum: parameters passed to the torch.optim.SGD optimizer.

    Return: (n_iters+1, 2). The (x, y) values, from initial values pre-optimization to values after step `n_iters`.
    """
    # Make sure tensor has requires_grad=True, otherwise it can't be optimized (more on this tomorrow!)
    assert xy.requires_grad

    optimizer = optim.SGD((xy,), lr=lr, momentum=momentum)

    xy_list = [xy.detach().clone()]  # so that we don't unintentionally modify past values in `xy_list`

    for i in range(n_iters):
        fn(xy[0], xy[1]).backward()
        optimizer.step()
        optimizer.zero_grad()
        xy_list.append(xy.detach().clone())

    return t.stack(xy_list)
```
</details>

## Build Your Own Optimizers

Now let's build our own drop-in replacement for these three classes from `torch.optim`. For each of the exercises you'll have to translate pseudocode that we give you into actual code. If you want an extra challenge, you can try and work directly from the pseudocode in the PyTorch documentation page rather than what we give you.

> **A warning regarding in-place operations**
>
> Be careful with expressions like `x = x + y` and `x += y`. They are NOT equivalent in Python.
>
> - The first one allocates a new `Tensor` of the appropriate size, stores the sum of `x` and `y` in it, then rebinds the name `x` to point to the new tensor. The original tensor is not modified.
> - The second one modifies the storage referred to by `x` to contain the sum of `x` and `y` - it is an "in-place" operation. `x.add_(y)` and `torch.add(x, y, out=x)` also work the same way.
>
> Another example: if `x` and `y` are the same shape, then `x = y` won't change the value of `x` inplace, but `x.copy_(y)` will (i.e. changing its values to the values of `y`).
>
> When you're updating parameters in your network you _should_ use inplace operations (because your `optimizer` was passed an iterable of parameters, and so defining a new parameter value via `theta = theta - step` will take it out of the optimizer's scope - it will continue to point to the old, unmodified version).
>
> However, be careful of using inplace operations where you shouldn't be - you don't want to accidentally do something like modify the gradients manually!
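
The difference between rebinding and in-place modification can be checked directly. In the sketch below (variable names are made up for the example), a second name is kept pointing at the original tensor so we can see whether it was modified.

```python
import torch as t

# Case 1: `x = x + 1` creates a new tensor and rebinds the name x
x = t.zeros(3)
x_original = x
x = x + 1

# Case 2: `y += 1` modifies the existing tensor's storage in place
y = t.zeros(3)
y_original = y
y += 1

# Case 3: `copy_` overwrites the values of an existing tensor in place
z = t.zeros(3)
z_original = z
z.copy_(t.tensor([7.0, 8.0, 9.0]))

print(x_original, x)  # the original is still all zeros; x is a different tensor
print(y_original is y, y)  # same tensor object, now all ones
print(z_original is z, z)  # same tensor object, with the copied values
```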

### Exercise - implement SGD

> ```yaml
> Difficulty: 🔴🔴🔴🔴⚪
> Importance: 🔵🔵🔵⚪⚪
>
> You should spend up to 25-35 minutes on this exercise.
> This is the first of several exercises like it. The first will probably take the longest.
> ```

First, you should implement stochastic gradient descent. It should be like the [PyTorch version](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD), but assume `nesterov=False`, `maximize=False`, and `dampening=0`. The pseudocode simplifies to:

$
b_0 \leftarrow 0 \\
\text {for } t=1 \text { to } \ldots \text { do } \\
\quad\; g_t \leftarrow \nabla_\theta f_t\left(\theta_{t-1}\right) \\
\quad\; \text {if } \lambda \neq 0 \\
\quad\;\quad\; g_t \leftarrow g_t+\lambda \theta_{t-1} \\
\quad\; \text {if } \mu \neq 0 \\
\quad\;\quad\; b_t \leftarrow \mu b_{t-1} + g_t \\
\quad\;\quad\; g_t \leftarrow b_t \\
\quad\; \theta_t \leftarrow \theta_{t-1} - \gamma g_t
$

where $\theta_t$ are the parameters, $g_t$ are the gradients (after being modified by operations like weight decay & momentum if necessary), and $b_t$ are the values we track to implement momentum.

<details>
<summary>Derivation of the simplified pseudocode</summary>

We start by removing the "if nesterov" and "if maximize" sections, since we're not using either of those. We also substitute $\tau=0$ since we're not using dampening. This gives us:

$
\text {for } t=1 \text { to } \ldots \text { do } \\
\quad\; g_t \leftarrow \nabla_\theta f_t\left(\theta_{t-1}\right) \\
\quad\; \text {if } \lambda \neq 0 \\
\quad\;\quad\; g_t \leftarrow g_t+\lambda \theta_{t-1} \\
\quad\; \text {if } \mu \neq 0 \\
\quad\;\quad\; \text{if } t>1 \\
\quad\;\quad\;\quad\; b_t \leftarrow \mu b_{t-1} + g_t \\
\quad\;\quad\; \text{else} \\
\quad\;\quad\;\quad\; b_t \leftarrow g_t \\
\quad\;\quad\; g_t \leftarrow b_t \\
\quad\; \theta_t \leftarrow \theta_{t-1} - \gamma g_t
$

Finally, we observe that we can set $b_0 = 0$ and then remove the special case handling of the $t=1$ case, which gives us the pseudocode above.

</details>

You should complete the `step` method below, which implements the algorithm described by the pseudocode above. Note that we've added the `torch.inference_mode` decorator to the `step` method, which is equivalent to using the context manager `with torch.inference_mode():`. This is similar to `torch.no_grad`; the difference between them isn't worth getting into here but in general know that `torch.inference_mode` is mostly preferred.

The configurations used during `tests.test_sgd` will start simple (e.g. all parameters set to zero except `lr`) and gradually move to more complicated ones. This will help you track exactly where in your model the error is coming from.


You should also read the `__init__` and `zero_grad` methods, making sure you understand how these work and what they are doing. Note that setting `grad=None` like the code below is treated as equivalent to setting `grad` equal to a tensor of zeros, i.e. the first time we're required to do an operation on the gradient it'll be replaced with this. Making it be `None` by default is the standard, so as to not use unnecessary memory.
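
The short sketch below illustrates why gradients need resetting at all, and that `grad = None` behaves like a zero gradient for the next backward pass. It uses a single made-up tensor rather than a model.

```python
import torch as t

w = t.tensor([1.0, 2.0], requires_grad=True)
grad_is_none_initially = w.grad is None  # no gradient tensor is allocated before backward

(w**2).sum().backward()
first = w.grad.clone()  # d/dw sum(w^2) = 2 * w

(w**2).sum().backward()  # without a reset, the new gradient is added to the old one
accumulated = w.grad.clone()

w.grad = None  # same reset that the `zero_grad` method in this exercise performs
(w**2).sum().backward()
after_reset = w.grad.clone()  # back to a single gradient's worth

print(first, accumulated, after_reset)
```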

In [None]:
class SGD:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
    ):
        """Implements SGD with momentum.

        Like the PyTorch version, but assume nesterov=False, maximize=False, and dampening=0
            https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD
        """
        self.params = list(params)  # turn params into a list (it might be a generator, so iterating over it empties it)
        self.lr = lr
        self.mu = momentum
        self.lmda = weight_decay

        self.b = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        """Zeros all gradients of the parameters in `self.params`."""
        for param in self.params:
            param.grad = None

    @t.inference_mode()
    def step(self) -> None:
        """Performs a single optimization step of the SGD algorithm."""
        raise NotImplementedError()

    def __repr__(self) -> str:
        return f"SGD(lr={self.lr}, momentum={self.mu}, weight_decay={self.lmda})"


tests.test_sgd(SGD)

<details><summary>Solution</summary>

```python
class SGD:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
    ):
        """Implements SGD with momentum.

        Like the PyTorch version, but assume nesterov=False, maximize=False, and dampening=0
            https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD
        """
        self.params = list(params)  # turn params into a list (it might be a generator, so iterating over it empties it)
        self.lr = lr
        self.mu = momentum
        self.lmda = weight_decay

        self.b = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        """Zeros all gradients of the parameters in `self.params`."""
        for param in self.params:
            param.grad = None

    @t.inference_mode()
    def step(self) -> None:
        """Performs a single optimization step of the SGD algorithm."""
        for b, theta in zip(self.b, self.params):
            g = theta.grad
            if self.lmda != 0:
                g = g + self.lmda * theta  # this shouldn't be inplace since we don't want to modify theta.grad
            if self.mu != 0:
                b.copy_(self.mu * b + g)  # this does need to be inplace, since we're modifying the value in `self.b`
                g = b
            theta -= self.lr * g  # inplace operation, to modify params

    def __repr__(self) -> str:
        return f"SGD(lr={self.lr}, momentum={self.mu}, weight_decay={self.lmda})"


tests.test_sgd(SGD)
```
</details>

If you feel comfortable with this implementation, you can skim through the remaining ones, since there are diminishing marginal returns to be gained from doing the actual exercises. We still recommend you read the content on the optimizers before the actual exercises, because they contain useful theory to understand. If you want an extra challenge in the actual exercises, you can try and implement the optimization algorithms directly from the PyTorch documentation pseudocode rather than from the simplified pseudocode we give you.

### RMSProp (and adaptive methods)

From SGD, we'll move onto discussing **adaptive gradient descent methods**. These are methods which automatically adjust the learning rate of each parameter during training, based on the size of gradients at previous steps. In a sense this is similar to how momentum operates in SGD, but we don't tend to describe SGD plus momentum as an adaptive method. When discussing momentum, we usually think of the analogy of a ball rolling down a hill, and the ball's velocity accelerates until it reaches some terminal velocity. The momentum parameter $\mu$ controls the terminal velocity: as $\mu \to 1$ the terminal velocity gets very high, which also means it can take a long time to adjust its speed when it enters new territory. In contrast, adaptive methods are better thought of as deliberate, conscious updates to the learning rate of parameters based on past values. They allow us to speed up when we need to, but without sacrificing our ability to adapt quickly when we enter new regimes.

The first adaptive method we'll look at is **RMSprop**. This was actually the second main adaptive method proposed in the optimization literature, after AdaGrad (the problem with AdaGrad is that it decays the learning rates too quickly, which is what RMSprop fixes). RMSprop is similar to SGD, with an added dynamic: **the size of each parameter step is scaled according to the variance of past gradients** (more precisely, a moving average of their squares), with higher variance leading to smaller steps. Intuitively, if you're in a very monotonic region of the loss landscape then you want to take larger steps (since you know where you're going and you just want to get there quickly), whereas if you're in a very noisy region and possibly oscillating around minima then you want to take smaller steps.
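
To make the AdaGrad comparison concrete, here is a minimal sketch using plain Python floats and a constant gradient of 1. It assumes the standard AdaGrad accumulator (a running sum of squared gradients, which isn't spelled out above) and RMSprop's moving average without momentum. The function names are illustrative and not part of the course code.

```python
import math


def adagrad_step_sizes(lr: float, grads: list[float], eps: float = 1e-8) -> list[float]:
    """Per-step update magnitude when the accumulator is a running SUM of squared gradients."""
    total = 0.0
    sizes = []
    for g in grads:
        total += g**2  # never shrinks, so the effective learning rate can only decrease
        sizes.append(lr * abs(g) / (math.sqrt(total) + eps))
    return sizes


def rmsprop_step_sizes(lr: float, grads: list[float], alpha: float = 0.99, eps: float = 1e-8) -> list[float]:
    """Per-step update magnitude when the accumulator is an exponential moving AVERAGE (no momentum)."""
    v = 0.0
    sizes = []
    for g in grads:
        v = alpha * v + (1 - alpha) * g**2  # old gradients are gradually forgotten
        sizes.append(lr * abs(g) / (math.sqrt(v) + eps))
    return sizes


grads = [1.0] * 1000
adagrad = adagrad_step_sizes(lr=0.01, grads=grads)
rmsprop = rmsprop_step_sizes(lr=0.01, grads=grads)

for step in (1, 10, 100, 1000):
    print(f"t={step:>4}: AdaGrad step = {adagrad[step - 1]:.6f}, RMSprop step = {rmsprop[step - 1]:.6f}")
```

With a constant gradient, the AdaGrad step keeps shrinking like $1/\sqrt{t}$, while the RMSprop step settles near `lr`. The early RMSprop steps are larger than `lr` because `v` starts at zero.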

One final note - when we're using non-adaptive methods like SGD we tend to scale the learning rate up with the batch size. Broadly speaking, this is because a larger batch size means our gradients will have smaller variance, and so we can safely use a larger learning rate. This generally isn't necessary for adaptive methods since the learning rates will be adjusted automatically during training based on the variance of our gradients - we don't need to manually scale them ourselves. Most commonly during optimization, we'll start with the default hyperparameters for whatever adaptive optimizer we're using, and then adjust from there.
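
As a rough illustration of the batch size point, the sketch below simulates noisy per-example gradients with the standard library and measures how the variance of the minibatch average changes with batch size. The "true gradient" of 1.0, the noise level, and the helper name `minibatch_gradient` are all assumptions made for this toy example.

```python
import random
import statistics


def minibatch_gradient(batch_size: int, rng: random.Random) -> float:
    """Mean of `batch_size` noisy per-example gradients (true gradient 1.0, noise std 1.0)."""
    return sum(rng.gauss(1.0, 1.0) for _ in range(batch_size)) / batch_size


rng = random.Random(0)
for batch_size in (4, 64):
    estimates = [minibatch_gradient(batch_size, rng) for _ in range(2000)]
    print(f"batch_size={batch_size:>2}: variance of gradient estimate = {statistics.pvariance(estimates):.4f}")
```

For independent per-example noise, the variance of the averaged gradient falls roughly in proportion to `1 / batch_size`, which is why larger batches tolerate larger learning rates.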

### Exercise - implement RMSprop

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵⚪⚪⚪
>
> You should spend up to 15-25 minutes on this exercise.
> ```

Below, you should implement RMSprop in the same way as you implemented SGD. The pseudocode is slightly more complicated, since we now have to track 2 variables: $b_t$ for applying the momentum effect, and $v_t$ for tracking the variance of past gradients (we've called these `b` and `v` below).

[Here](https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html) is a link to the PyTorch version; alternatively, you can use our simplified pseudocode again:

<details>
<summary>Click here for the simplified pseudocode</summary>

$
b_0 \leftarrow 0, \quad v_0 \leftarrow 0 \\
\text {for } t=1 \text { to } \ldots \text { do } \\
\quad\; g_t \leftarrow \nabla_\theta f_t\left(\theta_{t-1}\right) \\
\quad\; \text {if } \lambda \neq 0 \\
\quad\;\quad\; g_t \leftarrow g_t+\lambda \theta_{t-1} \\
\quad\; v_t \leftarrow \alpha v_{t-1} + (1-\alpha) g_t^2 \\
\quad\; g_t \leftarrow g_t / (\sqrt{v_t} + \epsilon) \\
\quad\; \text {if } \mu \neq 0 \\
\quad\;\quad\; b_t \leftarrow \mu b_{t-1} + g_t \\
\quad\;\quad\; g_t \leftarrow b_t \\
\quad\; \theta_t \leftarrow \theta_{t-1} - \gamma g_t
$

Note that we've ordered the pseudocode slightly differently from the PyTorch docs, so that we divide $g_t$ by $\sqrt{v_t} + \epsilon$ before applying momentum. Both ways are equivalent though.

</details>

In [None]:
class RMSprop:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float = 0.01,
        alpha: float = 0.99,
        eps: float = 1e-08,
        weight_decay: float = 0.0,
        momentum: float = 0.0,
    ):
        """Implements RMSprop.

        Like the PyTorch version, but assumes centered=False
            https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html
        """
        self.params = list(params)  # turn params into a list (because it might be a generator)
        self.lr = lr
        self.eps = eps
        self.mu = momentum
        self.lmda = weight_decay
        self.alpha = alpha

        self.b = [t.zeros_like(p) for p in self.params]
        self.v = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = None

    @t.inference_mode()
    def step(self) -> None:
        raise NotImplementedError()

    def __repr__(self) -> str:
        return (
            f"RMSprop(lr={self.lr}, eps={self.eps}, momentum={self.mu}, weight_decay={self.lmda}, alpha={self.alpha})"
        )


tests.test_rmsprop(RMSprop)

<details><summary>Solution</summary>

```python
class RMSprop:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float = 0.01,
        alpha: float = 0.99,
        eps: float = 1e-08,
        weight_decay: float = 0.0,
        momentum: float = 0.0,
    ):
        """Implements RMSprop.

        Like the PyTorch version, but assumes centered=False
            https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html
        """
        self.params = list(params)  # turn params into a list (because it might be a generator)
        self.lr = lr
        self.eps = eps
        self.mu = momentum
        self.lmda = weight_decay
        self.alpha = alpha

        self.b = [t.zeros_like(p) for p in self.params]
        self.v = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = None

    @t.inference_mode()
    def step(self) -> None:
        for theta, b, v in zip(self.params, self.b, self.v):
            g = theta.grad
            if self.lmda != 0:
                g = g + self.lmda * theta
            v.copy_(self.alpha * v + (1 - self.alpha) * g.pow(2))  # inplace operation, to modify value in self.v
            g = g / (v.sqrt() + self.eps)  # not inplace operation
            if self.mu > 0:
                b.copy_(self.mu * b + g)  # inplace operation, to modify value in self.b
                g = b
            theta -= self.lr * g  # inplace operation, to modify params

    def __repr__(self) -> str:
        return (
            f"RMSprop(lr={self.lr}, eps={self.eps}, momentum={self.mu}, weight_decay={self.lmda}, alpha={self.alpha})"
        )


tests.test_rmsprop(RMSprop)
```
</details>

### Adam, and "momentum"

We'll end by implementing Adam and AdamW, two of the most popular optimizers in deep learning. These combine the benefits of RMSprop and SGD with momentum: they have the same variance-based scaling as RMSprop, plus an update rule based on the first moment of the gradients.

There's an important clarification to make here - the first order adjustment of Adam is sometimes called momentum as a shorthand, but there's an important sense in which it isn't. The key difference is that SGD's momentum causes acceleration until we hit terminal velocity, which could be very large for $\mu \approx 1$. In contrast, Adam's momentum is an exponentially weighted moving average - the parameter $\beta_1$ controls how quickly it adjusts (with a value closer to 1 meaning it adjusts to newer values more slowly), but it doesn't change the terminal velocity in any sense. Mathematically, the difference between these two is minimal (all you'd need to do is take Adam's update rule $m_t \leftarrow \beta_1 m_{t-1} + (1-\beta_1) g_t$ and change it to $m_t \leftarrow \beta_1 m_{t-1} + g_t$ for it to have the same qualitative behaviour as SGD), but this extra factor makes a lot of difference!
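
The difference between the two update rules is easy to see numerically. The sketch below feeds a constant gradient of 1 into both rules using plain Python floats; the helper names are illustrative and not part of the course code.

```python
def sgd_momentum_buffer(grads: list[float], mu: float) -> float:
    """SGD-style accumulation: b_t = mu * b_{t-1} + g_t."""
    b = 0.0
    for g in grads:
        b = mu * b + g
    return b


def adam_first_moment(grads: list[float], beta1: float) -> float:
    """Adam-style moving average: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t."""
    m = 0.0
    for g in grads:
        m = beta1 * m + (1 - beta1) * g
    return m


constant_grads = [1.0] * 2000
for coeff in (0.9, 0.99):
    b = sgd_momentum_buffer(constant_grads, mu=coeff)
    m = adam_first_moment(constant_grads, beta1=coeff)
    print(f"coeff={coeff}: SGD buffer -> {b:.2f}, Adam average -> {m:.2f}")
```

Under a constant gradient $g$, the SGD buffer saturates at $g / (1 - \mu)$ (10 and 100 times the gradient here), whereas Adam's average saturates at $g$ itself whatever the value of $\beta_1$.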

### Exercise - implement Adam

> ```yaml
> Difficulty: 🔴🔴🔴🔴⚪
> Importance: 🔵🔵🔵⚪⚪
>
> You should spend up to 15-20 minutes on this exercise.
> ```

This should just be an extension of your RMSprop implementation. You still have 2 variables to track, but now the variable $b_t$ for applying momentum has been replaced with $m_t$ for tracking the exponentially weighted moving average of first order moments.

[Here's](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html) a link to the PyTorch version; alternatively, you can use the simplified pseudocode below:

<details>
<summary>Click here for the simplified pseudocode</summary>

$
\text {for } t=1 \text { to } \ldots \text { do } \\
\quad\; g_t \leftarrow \nabla_\theta f_t\left(\theta_{t-1}\right) \\
\quad\; \text {if } \lambda \neq 0 \\
\quad\;\quad\; g_t \leftarrow g_t+\lambda \theta_{t-1} \\
\quad\; m_t \leftarrow \beta_1 m_{t-1} + (1-\beta_1) g_t \\
\quad\; v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g_t^2 \\
\quad\; \widehat{m_t} \leftarrow m_t / (1 - \beta_1^t) \\
\quad\; \widehat{v_t} \leftarrow v_t / (1 - \beta_2^t) \\
\quad\; \theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t} / (\sqrt{\widehat{v_t}} + \epsilon)
$

</details>

Note - we bias-correct (or "center") our first & second moment estimators by dividing by $1 - \beta^t$, which means for this optimizer we do have to track the variable $t$ (remember to increment it after each call to `step`). We do this because Adam's exponentially weighted moving averages start at zero, so they would otherwise take a while to converge to the true mean (their early values behave like the truncated sum of a geometric series). We leave it as an exercise for the reader to derive this (hint - try assuming the expected value $\mathbb{E}[g_t] = g_0$ is the same for all $t$, what does the expression $\mathbb{E}[m_t]$ simplify to?).
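
<details>
<summary>Numerical check of the bias correction (try the hint first)</summary>

If you'd like to sanity-check your derivation, here is a small sketch in plain Python. It assumes a constant gradient of 5 and a moving average initialised at zero; the helper name `ema_with_bias_correction` is illustrative and not part of the course code.

```python
def ema_with_bias_correction(grads: list[float], beta: float) -> list[tuple[float, float]]:
    """Returns (raw, corrected) moving-average estimates after each step, starting from m_0 = 0."""
    m = 0.0
    history = []
    for step, g in enumerate(grads, start=1):
        m = beta * m + (1 - beta) * g
        m_hat = m / (1 - beta**step)  # divide out the weight still sitting on the zero initialisation
        history.append((m, m_hat))
    return history


history = ema_with_bias_correction([5.0] * 50, beta=0.9)
for step in (1, 2, 10, 50):
    raw, corrected = history[step - 1]
    print(f"t={step:>2}: raw m_t = {raw:.3f}, corrected m_hat = {corrected:.3f}")
```

The raw estimate starts at a tenth of the true gradient and only slowly approaches it, while the corrected estimate matches the constant gradient from the first step.

</details>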

In [None]:
class Adam:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float = 0.001,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        weight_decay: float = 0.0,
    ):
        """Implements Adam.

        Like the PyTorch version, but assumes amsgrad=False and maximize=False
            https://pytorch.org/docs/stable/generated/torch.optim.Adam.html
        """
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.lmda = weight_decay
        self.t = 1

        self.m = [t.zeros_like(p) for p in self.params]
        self.v = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = None

    @t.inference_mode()
    def step(self) -> None:
        raise NotImplementedError()

    def __repr__(self) -> str:
        return f"Adam(lr={self.lr}, beta1={self.beta1}, beta2={self.beta2}, eps={self.eps}, weight_decay={self.lmda})"


tests.test_adam(Adam)

<details><summary>Solution</summary>

```python
class Adam:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float = 0.001,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        weight_decay: float = 0.0,
    ):
        """Implements Adam.

        Like the PyTorch version, but assumes amsgrad=False and maximize=False
            https://pytorch.org/docs/stable/generated/torch.optim.Adam.html
        """
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.lmda = weight_decay
        self.t = 1

        self.m = [t.zeros_like(p) for p in self.params]
        self.v = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = None

    @t.inference_mode()
    def step(self) -> None:
        for theta, m, v in zip(self.params, self.m, self.v):
            g = theta.grad
            if self.lmda != 0:
                g = g + self.lmda * theta
            m.copy_(self.beta1 * m + (1 - self.beta1) * g)
            v.copy_(self.beta2 * v + (1 - self.beta2) * g.pow(2))
            m_hat = m / (1 - self.beta1**self.t)
            v_hat = v / (1 - self.beta2**self.t)
            theta -= self.lr * m_hat / (v_hat.sqrt() + self.eps)
        self.t += 1

    def __repr__(self) -> str:
        return f"Adam(lr={self.lr}, beta1={self.beta1}, beta2={self.beta2}, eps={self.eps}, weight_decay={self.lmda})"
```
</details>

### Exercise - implement AdamW

> ```yaml
> Difficulty: 🔴🔴⚪⚪⚪
> Importance: 🔵🔵⚪⚪⚪
>
> You should spend up to 10-15 minutes on this exercise.
> ```

Finally, you'll adapt your Adam implementation to implement AdamW. This is a very small modification of the Adam update rule, where we apply weight decay in a different way (by modifying the weights $\theta_t$ themselves, rather than modifying the gradients $g_t$ and then using those modified gradients in the first & second moment calculations). This means that every weight shrinks by the same factor $1 - \gamma\lambda$ at each step, just as it would under weight decay in plain SGD. In Adam, by contrast, the $\lambda \theta$ term (the gradient of an L2 penalty) gets passed through the adaptive rescaling, so weights with a history of large gradients end up regularized less than weights with small ones. Decoupling the decay like this is seen as the more "correct" way to implement weight decay, and so AdamW is now generally preferred over Adam.

You can read more about this variant of Adam [here](https://arxiv.org/abs/1711.05101). The PyTorch docs are [here](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html), and the pseudocode is again provided for you below (but for this exercise we do recommend trying to go without it - having to work with more complex pseudocode and parse out the bits that actually matter is a useful exercise!).
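
To see why the placement of the decay term matters, here is a scalar sketch of the very first optimizer step ($t = 1$) under both rules, using plain Python floats. It relies on the fact that the bias-corrected moments at $t = 1$ reduce to $g$ and $g^2$. The function names and the chosen numbers are illustrative assumptions, and the loss gradient is set to zero so that only the decay acts.

```python
import math


def adam_l2_first_step(theta: float, grad: float, lr: float, lmda: float, eps: float = 1e-8) -> float:
    """First Adam step (t=1) with weight decay folded into the gradient."""
    g = grad + lmda * theta
    m_hat, v_hat = g, g**2  # bias-corrected moments at t=1
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps)


def adamw_first_step(theta: float, grad: float, lr: float, lmda: float, eps: float = 1e-8) -> float:
    """First AdamW step (t=1): shrink the weight directly, then apply the usual Adam update."""
    theta = theta * (1 - lr * lmda)
    m_hat, v_hat = grad, grad**2  # the moments only ever see the loss gradient
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps)


for lmda in (0.001, 0.1):
    adam = adam_l2_first_step(theta=2.0, grad=0.0, lr=0.1, lmda=lmda)
    adamw = adamw_first_step(theta=2.0, grad=0.0, lr=0.1, lmda=lmda)
    print(f"weight_decay={lmda}: Adam -> {adam:.4f}, AdamW -> {adamw:.4f}")
```

In this extreme case the Adam step has size roughly `lr` whatever the value of `weight_decay`, because the decay term is normalised by its own magnitude. The AdamW step scales with `weight_decay` as intended.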

<details>
<summary>Click here for the simplified pseudocode</summary>

$
\text {for } t=1 \text { to } \ldots \text { do } \\
\quad\; g_t \leftarrow \nabla_\theta f_t\left(\theta_{t-1}\right) \\
\quad\; \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
\quad\; m_t \leftarrow \beta_1 m_{t-1} + (1-\beta_1) g_t \\
\quad\; v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g_t^2 \\
\quad\; \widehat{m_t} \leftarrow m_t / (1 - \beta_1^t) \\
\quad\; \widehat{v_t} \leftarrow v_t / (1 - \beta_2^t) \\
\quad\; \theta_t \leftarrow \theta_t - \gamma \widehat{m_t} / (\sqrt{\widehat{v_t}} + \epsilon)
$

</details>

In [None]:
class AdamW:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float = 0.001,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        weight_decay: float = 0.0,
    ):
        """Implements AdamW.

        Like the PyTorch version, but assumes amsgrad=False and maximize=False
            https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
        """
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.lmda = weight_decay
        self.t = 1

        self.m = [t.zeros_like(p) for p in self.params]
        self.v = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = None

    @t.inference_mode()
    def step(self) -> None:
        raise NotImplementedError()

    def __repr__(self) -> str:
        return f"AdamW(lr={self.lr}, beta1={self.beta1}, beta2={self.beta2}, eps={self.eps}, weight_decay={self.lmda})"


tests.test_adamw(AdamW)

<details><summary>Solution</summary>

```python
class AdamW:
    def __init__(
        self,
        params: Iterable[t.nn.parameter.Parameter],
        lr: float = 0.001,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-08,
        weight_decay: float = 0.0,
    ):
        """Implements AdamW.

        Like the PyTorch version, but assumes amsgrad=False and maximize=False
            https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
        """
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.lmda = weight_decay
        self.t = 1

        self.m = [t.zeros_like(p) for p in self.params]
        self.v = [t.zeros_like(p) for p in self.params]

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = None

    @t.inference_mode()
    def step(self) -> None:
        for theta, m, v in zip(self.params, self.m, self.v):
            g = theta.grad
            theta *= 1 - self.lr * self.lmda
            m.copy_(self.beta1 * m + (1 - self.beta1) * g)
            v.copy_(self.beta2 * v + (1 - self.beta2) * g.pow(2))
            m_hat = m / (1 - self.beta1**self.t)
            v_hat = v / (1 - self.beta2**self.t)
            theta -= self.lr * m_hat / (v_hat.sqrt() + self.eps)
        self.t += 1

    def __repr__(self) -> str:
        return f"AdamW(lr={self.lr}, beta1={self.beta1}, beta2={self.beta2}, eps={self.eps}, weight_decay={self.lmda})"
```
</details>

## Plotting multiple optimisers

Finally, we've provided some code which should allow you to plot more than one of your optimisers at once.

### Exercise - experiment with different optimizers & params

> ```yaml
> Difficulty: 🔴🔴⚪⚪⚪
> Importance: 🔵🔵🔵⚪⚪
>
> You should spend up to 20-30 minutes on this exercise.
> ```

We've given you a function below which works just like `opt_fn_with_sgd` from earlier, but takes in a general optimizer and hyperparameters (as a dictionary of keyword arguments like `lr` and `momentum`).

You should use this function to play around with different optimizers and hyperparameters, comparing their performance. The code below gives one example of such a comparison. Run it now and see what you get:

In [None]:
def opt_fn(
    fn: Callable,
    xy: Tensor,
    optimizer_class,
    optimizer_hyperparams: dict,
    n_iters: int = 100,
) -> Tensor:
    """Optimize a given function starting from the specified point.

    optimizer_class: one of the optimizers you've defined, either SGD, RMSprop, Adam or AdamW
    optimizer_hyperparams: keyword arguments passed to your optimizer (e.g. lr and weight_decay)
    """
    assert xy.requires_grad

    optimizer = optimizer_class([xy], **optimizer_hyperparams)

    xy_list = [xy.detach().clone()]  # so that we don't unintentionally modify past values in `xy_list`

    for i in range(n_iters):
        fn(xy[0], xy[1]).backward()
        optimizer.step()
        optimizer.zero_grad()
        xy_list.append(xy.detach().clone())

    return t.stack(xy_list)


points = []

optimizer_list = [
    (SGD, {"lr": 0.03, "momentum": 0.99}),
    (RMSprop, {"lr": 0.02, "alpha": 0.99, "momentum": 0.8}),
    (Adam, {"lr": 0.2, "betas": (0.99, 0.99), "weight_decay": 0.005}),
    (AdamW, {"lr": 0.2, "betas": (0.99, 0.99), "weight_decay": 0.005}),
]

for optimizer_class, params in optimizer_list:
    xy = t.tensor([2.5, 2.5], requires_grad=True)
    xys = opt_fn(
        pathological_curve_loss,
        xy=xy,
        optimizer_class=optimizer_class,
        optimizer_hyperparams=params,
    )
    points.append((xys, optimizer_class, params))

plot_fn_with_points(pathological_curve_loss, min_points=[(0, "y_min")], points=points)

Note that the focus shouldn't be on figuring out "which one is the best optimizer" - this loss landscape (and other examples we'll give you) were specifically designed to be pathological, and exhibit interesting kinds of behaviours from optimizers. The focus should instead be on understanding how the characteristics of optimizers we discussed in the previous sections are reflected visually in the plots produced on these loss landscapes. Some questions you might want to ask:

- We discussed that Adam (and AdamW) bias-correct (or "center") their first and second moment estimates, so that the early values aren't shrunk towards zero - without the correction they start off small and take a long time to grow. Is this reflected in the plots, i.e. with Adam/AdamW taking larger early steps relative to SGD or RMSprop?
- The momentum used in SGD and RMSprop causes acceleration until "terminal velocity", which is usually a higher cap than Adam and AdamW. Is this reflected in the step size (and the instability) of those optimizers? Do Adam and AdamW seem to adapt slightly faster when they enter new terrain?
- Are there any landscapes where weight decay is advantageous, and can you see why it would be?

Some more functions you might want to try out (with their minima marked on the plots):

In [None]:
def bivariate_gaussian(x, y, x_mean=0.0, y_mean=0.0, x_sig=1.0, y_sig=1.0):
    norm = 1 / (2 * np.pi * x_sig * y_sig)
    x_exp = 0.5 * ((x - x_mean) ** 2) / (x_sig**2)
    y_exp = 0.5 * ((y - y_mean) ** 2) / (y_sig**2)
    return norm * t.exp(-x_exp - y_exp)


means = [(1.0, -0.5), (-1.0, 0.5), (-0.5, -0.8)]


def neg_trimodal_func(x, y):
    """
    This function has 3 global minima, at `means`. Unstable methods can overshoot these minima, and non-adaptive methods
    can fail to converge to them in the first place given how shallow the gradients are everywhere except in the close
    vicinity of the minima.
    """
    z = -bivariate_gaussian(x, y, x_mean=means[0][0], y_mean=means[0][1], x_sig=0.2, y_sig=0.2)
    z -= bivariate_gaussian(x, y, x_mean=means[1][0], y_mean=means[1][1], x_sig=0.2, y_sig=0.2)
    z -= bivariate_gaussian(x, y, x_mean=means[2][0], y_mean=means[2][1], x_sig=0.2, y_sig=0.2)
    return z


plot_fn(neg_trimodal_func, x_range=(-2, 2), y_range=(-2, 2), min_points=means)

In [None]:
def rosenbrocks_banana_func(x: Tensor, y: Tensor, a=1, b=100) -> Tensor:
    """
    This function has a global minimum at `(a, a**2)` so in this case `(1, 1)`. It's characterized by a long, narrow,
    parabolic valley (parameterized by `y = x**2`). Various gradient descent methods have trouble navigating this
    valley because they often oscillate unstably (gradients from the `b`-term dwarf the gradients from the `a`-term).

    See more on this function: https://en.wikipedia.org/wiki/Rosenbrock_function.
    """
    return (a - x) ** 2 + b * (y - x**2) ** 2 + 1


plot_fn(rosenbrocks_banana_func, x_range=(-2.5, 2.5), y_range=(-2, 4), z_range=(0, 100), min_points=[(1, 1)])
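
To get a feel for the claim in the docstring that the `b`-term gradients dwarf the `a`-term gradients, the sketch below evaluates the two contributions to the gradient by hand, using plain Python floats. The derivatives are worked out analytically from the formula above; the helper name `rosenbrock_grad_terms` and the chosen points are illustrative.

```python
def rosenbrock_grad_terms(x: float, y: float, a: float = 1.0, b: float = 100.0) -> dict[str, tuple[float, float]]:
    """Splits the gradient of (a - x)^2 + b * (y - x^2)^2 into its a-term and b-term parts."""
    a_term = (2 * (x - a), 0.0)  # (d/dx, d/dy) of (a - x)^2
    b_term = (4 * b * x * (x**2 - y), 2 * b * (y - x**2))  # (d/dx, d/dy) of b * (y - x^2)^2
    return {"a_term": a_term, "b_term": b_term}


for point in [(-1.5, 2.5), (0.5, 0.5), (1.0, 1.0)]:
    terms = rosenbrock_grad_terms(*point)
    print(f"(x, y) = {point}: a_term = {terms['a_term']}, b_term = {terms['b_term']}")
```

Away from the valley floor `y = x**2` the `b`-term dominates, so the gradient mostly points across the valley rather than along it. At the minimum `(1, 1)` both terms vanish.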

<details>
<summary>Some example visualizations & observations</summary>

Let's start with the negative trimodal function. You should find that weight decay massively helps performance here, but this is for pretty uninteresting reasons - it essentially adds a slope towards the origin, and when the ball rolls towards the origin it will probably also get caught in one of the three minima. So it doesn't tell us much about the actual optimizers.

More interestingly, we can compare the optimizers when they have weight decay switched off. You should find that Adam can outperform SGD and RMSprop here, because the way it uses "momentum" is better suited to this task. For one thing, the bias correction of its first and second moments means it can take larger early steps relative to SGD and RMSprop (which both take a while to accelerate). For another, momentum causes RMSprop step sizes to increase in an unstable way, which is why it will overshoot the minima and get stuck on the other side without careful hyperparameter tuning. SGD is even worse - because of its lack of variance-based scaling, it'll utterly fail to move anywhere unless it starts out very close to one of the three minima.

```python
optimizer_list = [
    (SGD, {"lr": 0.1, "momentum": 0.5}),
    (RMSprop, {"lr": 0.1, "alpha": 0.99, "momentum": 0.5}),
    (Adam, {"lr": 0.1, "betas": (0.9, 0.999)}),
]

points = []
for optimizer_class, params in optimizer_list:
    xy = t.tensor([1.0, 1.0], requires_grad=True)
    xys = opt_fn(neg_trimodal_func, xy=xy, optimizer_class=optimizer_class, optimizer_hyperparams=params)
    points.append((xys, optimizer_class, params))

plot_fn_with_points(neg_trimodal_func, points=points, x_range=(-2, 2), y_range=(-2, 2), min_points=means)
```

<div style="text-align: left"><embed src="https://info-arena.github.io/ARENA_img/misc/media-03/0304-points.html" width="1020" height="470"></div>

Next, Rosenbrock's banana. This function has a global minimum at `(1, 1)` inside a long, narrow, parabolic valley. Basic gradient descent often zigzags back and forth along the valley, making very slow progress, so momentum is essential to perform well on this task. This is a rare case where SGD plus momentum converges faster than Adam: the higher terminal velocity enables larger steps, and the extreme slope of the loss landscape prevents the kind of instability that usually hinders SGD. However, some caveats: SGD requires a very small learning rate to prevent unstable oscillations (given how steep the valley is), whereas Adam is much more stable. Furthermore, if we extend the number of iterations, we see that Adam does also converge, and it does so with fewer oscillations than SGD (it stays within the parabolic valley).

```python
optimizer_list = [
    (SGD, {"lr": 0.001, "momentum": 0.99}),
    (Adam, {"lr": 0.1, "betas": (0.9, 0.999)}),
]

points = []
for optimizer_class, params in optimizer_list:
    xy = t.tensor([-1.5, 2.5], requires_grad=True)
    xys = opt_fn(
        rosenbrocks_banana_func, xy=xy, optimizer_class=optimizer_class, optimizer_hyperparams=params, n_iters=500
    )
    points.append((xys, optimizer_class, params))

plot_fn_with_points(
    rosenbrocks_banana_func, x_range=(-2.5, 2.5), y_range=(-2, 4), z_range=(0, 100), min_points=[(1, 1)], points=points
)
```

<div style="text-align: left"><embed src="https://info-arena.github.io/ARENA_img/misc/media-03/0305-points.html" width="1020" height="470"></div>

</details>

## Bonus - parameter groups

> *If you're interested in these exercises then you can go through them, if not then you can move on to the next section (Weights and Biases).*

Rather than passing a single iterable of parameters into an optimizer, you have the option to pass a list of parameter groups, each one with different hyperparameters. As an example of how this might work:

```python
optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
```

The first argument here is a list of dictionaries, with each dictionary defining a separate parameter group. Each should contain a `params` key, which contains an iterable of parameters belonging to this group. The dictionaries may also contain hyperparameters as extra keys. If a hyperparameter is not specified in a group, PyTorch uses the value passed as a keyword argument to the optimizer. So the example above is equivalent to:

```python
optim.SGD([
    {'params': model.base.parameters(), 'lr': 1e-2, 'momentum': 0.9},
    {'params': model.classifier.parameters(), 'lr': 1e-3, 'momentum': 0.9}
])
```

All hyperparameters have default values, with the exception of `lr`, which is why you need to specify it either as a keyword arg to the optimizer or separately in each group.
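
The precedence rule described above (group-specific value, then optimizer keyword argument, then default) can be sketched with an ordinary dictionary merge. This is a toy illustration in plain Python, not how PyTorch is implemented internally: the helper name `resolve_param_groups` is hypothetical, and strings stand in for real parameters.

```python
def resolve_param_groups(groups: list[dict], defaults: dict, **kwargs) -> list[dict]:
    """Fill in each group's hyperparameters: group-specific > optimizer keyword argument > default."""
    resolved = []
    for group in groups:
        merged = {**defaults, **kwargs, **group}  # later entries override earlier ones
        if "lr" not in merged:
            raise ValueError("each group needs 'lr', either in the group or as a keyword argument")
        resolved.append(merged)
    return resolved


groups = [
    {"params": ["base.weight", "base.bias"]},  # placeholder names standing in for real parameters
    {"params": ["classifier.weight"], "lr": 1e-3},
]
for group in resolve_param_groups(groups, defaults={"momentum": 0.0, "weight_decay": 0.0}, lr=1e-2, momentum=0.9):
    print(group)
```

The first group picks up `lr=1e-2` from the keyword argument, the second keeps its own `lr=1e-3`, and both inherit `momentum=0.9` and the default `weight_decay`.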

PyTorch optimisers store all their params and hyperparams in the `param_groups` attribute. This is why, when we want to modify an optimizer's learning rate (which we'll do later on in the course), we need to use `optimizer.param_groups[0]["lr"] = new_lr` even if we didn't specify any parameter groups (each group is a dictionary, so hyperparameters are accessed by key).

### When to use parameter groups

Parameter groups can be useful in several different circumstances. A few examples:

* Finetuning a model by freezing earlier layers and only training later layers is an extreme form of parameter grouping. We can use the parameter group syntax to apply a modified form, where the earlier layers have a smaller learning rate. This allows these earlier layers to adapt to the specifics of the problem, while making sure they don't forget all the useful features they've already learned.
* Often it's good to treat weights and biases differently, e.g. effects like weight decay are often applied to weights but not biases. PyTorch doesn't differentiate between these two, so you'll have to do this manually using parameter groups.
    * In particular, you might do this later in the course, if you choose the "train BERT from scratch" exercises during the transformers chapter.
* On the subject of transformers, weight decay is often *not* applied to embeddings and layernorms in transformer models.

More generally, if you're trying to replicate a paper, it's important to be able to use all the same training details that the original authors did, so you can get the same results.

### Exercise - rewrite SGD to use parameter groups

> ```yaml
> Difficulty: 🔴🔴🔴🔴⚪
> Importance: 🔵⚪⚪⚪⚪
>
> You should spend up to 30-40 minutes on this exercise.
> It's somewhat useful to understand the idea of parameter groups, less so to know how they're actually implemented.
> ```

You should rewrite the `SGD` optimizer from the earlier exercises, to use `param_groups`. This will involve filling in the 3 methods `__init__`, `zero_grad`, and `step`. Some guidance:

- In `__init__` you should create `self.param_groups`, which is a list of dictionaries with each one containing `"params"` as well as all the hyperparameters for that group. Remember the hierarchy for hparams: "specified for group" > "specified as a keyword argument" > "default value".
- In `zero_grad` and `step` the logic is the same as before, but now you need a double nested for loop: once over the param groups in `self.param_groups`, and once over the params in each group. For the latter, make sure you're using the group-specific hyperparameters (i.e. the ones you hopefully stored in `self.param_groups` in the init method).

In [None]:
class SGD:
    def __init__(self, params, **kwargs):
        """Implements SGD with momentum.

        Accepts parameters in groups, or an iterable.

        Like the PyTorch version, but assume nesterov=False, maximize=False, and dampening=0
            https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD
        """
        # Deal with case where we didn't supply groups, so we just make it into a single dictionary
        if not isinstance(params, (list, tuple)):
            params = [{"params": params}]

        # Make sure each group["params"] is a list of params not a generator (so we don't iterate over & destroy it!)
        for p in params:
            p["params"] = list(p["params"])

        self.param_groups = []

        # YOUR CODE HERE - fill in `self.param_groups`
        raise NotImplementedError()

    def zero_grad(self) -> None:
        raise NotImplementedError()

    @t.inference_mode()
    def step(self) -> None:
        raise NotImplementedError()


tests.test_sgd_param_groups(SGD)

<details><summary>Solution</summary>

```python
class SGD:
    def __init__(self, params, **kwargs):
        """Implements SGD with momentum.

        Accepts parameters in groups, or an iterable.

        Like the PyTorch version, but assume nesterov=False, maximize=False, and dampening=0
            https://pytorch.org/docs/stable/generated/torch.optim.SGD.html#torch.optim.SGD
        """
        # Deal with case where we didn't supply groups, so we just make it into a single dictionary
        if not isinstance(params, (list, tuple)):
            params = [{"params": params}]

        # Make sure each group["params"] is a list of params not a generator (so we don't iterate over & destroy it!)
        for p in params:
            p["params"] = list(p["params"])

        self.param_groups = []

        for param_group in params:
            # Set hyperparameters hierarchically: specified for group > specified as a keyword argument > default value
            # We do this via a dictionary merge (right takes precedence over left)
            param_group = {"momentum": 0.0, "weight_decay": 0.0, **kwargs, **param_group}

            # Check that "lr" is supplied
            assert "lr" in param_group, "Error: one of the param groups didn't specify 'lr'"

            # Initialise the momentum buffers "b" for this group
            param_group["b"] = [t.zeros_like(p) for p in param_group["params"]]

            self.param_groups.append(param_group)

    def zero_grad(self) -> None:
        for param_group in self.param_groups:
            for p in param_group["params"]:
                p.grad = None

    @t.inference_mode()
    def step(self) -> None:
        # loop through each param group

        for param_group in self.param_groups:
            # Get hparams for this group
            lmda = param_group["weight_decay"]
            mu = param_group["momentum"]
            lr = param_group["lr"]

            # Same code as for SGD implementation before, but using group-specific hparams
            for b, theta in zip(param_group["b"], param_group["params"]):
                g = theta.grad
                if lmda != 0:
                    g = g + lmda * theta  # this shouldn't be inplace since we don't want to modify theta.grad
                if mu != 0:
                    b.copy_(mu * b + g)  # this does need to be inplace, since we're modifying the value in `self.b`
                    g = b
                theta -= lr * g  # inplace operation, to modify params


tests.test_sgd_param_groups(SGD)
```
</details>
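
The hyperparameter precedence in the solution above (per-group value > keyword argument > default) follows purely from the order of the dictionary merge. Here is a minimal standalone sketch of that merge; the helper name `resolve_hparams` and the values are hypothetical, and plain strings stand in for parameter tensors:

```python
def resolve_hparams(group: dict, **kwargs) -> dict:
    """Merge defaults, optimizer-level kwargs and per-group settings (rightmost wins)."""
    defaults = {"momentum": 0.0, "weight_decay": 0.0}
    return {**defaults, **kwargs, **group}


# Hypothetical groups: strings stand in for parameter tensors
groups = [
    {"params": ["layer0.weight"]},              # inherits everything from kwargs / defaults
    {"params": ["layer1.weight"], "lr": 1e-3},  # overrides the keyword-argument lr
]
resolved = [resolve_hparams(g, lr=1e-2, momentum=0.9) for g in groups]

for r in resolved:
    print(r)
```

The first group ends up with the keyword-argument learning rate, while the second keeps its own `lr`; both share `momentum=0.9` and the default `weight_decay`.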

## Bonus - Muon Optimizer

Hot off the press is a new optimizer called *Muon*. Muon is specialized for only those parameters of a network that are *hidden* and *at least 2-dimensional*:

* For image classification, this means skipping anything that directly interfaces with the input or the output.
* For language models, this means skipping the embedding and unembedding layers.

The dimensionality requirement also means skipping biases, $\gamma$ or $\beta$ in layernorms, etc. All other parameters are optimized using Adam as usual.
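
As a rough sketch of this partitioning rule (not the reference implementation), the helper below decides from a parameter's name and shape whether it would be handled by Muon or by Adam. The parameter names, shapes, and the `io_prefixes` tuple are all hypothetical; real models name their input and output layers differently.

```python
def uses_muon(name: str, shape: tuple[int, ...], io_prefixes: tuple[str, ...] = ("embed", "unembed")) -> bool:
    """Return True if a parameter is hidden (not input/output) and at least 2-dimensional."""
    is_hidden = not name.startswith(io_prefixes)
    return is_hidden and len(shape) >= 2


# Hypothetical parameter names and shapes for a tiny language model
param_shapes = {
    "embed.weight": (50257, 768),     # interfaces with the input -> Adam
    "blocks.0.attn.W_Q": (768, 768),  # hidden and 2D -> Muon
    "blocks.0.mlp.bias": (3072,),     # 1D bias -> Adam
    "blocks.0.ln.gamma": (768,),      # 1D layernorm scale -> Adam
    "unembed.weight": (768, 50257),   # interfaces with the output -> Adam
}

muon_params = [n for n, s in param_shapes.items() if uses_muon(n, s)]
adam_params = [n for n, s in param_shapes.items() if not uses_muon(n, s)]
print("Muon:", muon_params)
print("Adam:", adam_params)
```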

<figure>
  <img src="https://pbs.twimg.com/media/GZELa2YbUAAqbW6?format=jpg&name=large" width="500">
  <figcaption>Comparison of Muon to AdamW optimizer for NanoGPT training speedrun. Taken from <a href="https://x.com/kellerjordan0/status/1842300916864844014">https://x.com/kellerjordan0/status/1842300916864844014</a></figcaption>
</figure>

Muon has shown great promise for language models, shaving a massive 40%(!) off the training time for the [nanoGPT speedrun](https://www.tylerromero.com/posts/nanogpt-speedrun-worklog/), a collaborative project to train a model as performant as GPT-2 as fast as possible, based on [Andrej Karpathy's nanoGPT implementation](https://github.com/karpathy/llm.c/discussions/481).

We might make this into an actual exercise later, but for now, here's a series of resources should you wish to implement it yourself:
* [Introduction to Muon](https://kellerjordan.github.io/posts/muon/)
* [NanoGPT Speedrun Project](https://github.com/KellerJordan/modded-nanogpt/). The current world record is sub-3 minutes(!!)
  - This is on an 8xH100 cluster, with about 16 PFLOPS of compute. On a single consumer GPU (RTX 3090) with ~140 TFLOPS, and assuming no out-of-memory issues, this would take ~5 hours, which is still incredibly impressive.
* [Twitter thread on Muon](https://x.com/kellerjordan0/status/1842300916864844014)
* [Derivation of Muon](https://jeremybernste.in/writing/deriving-muon)
  * Not required reading, but if you're curious about the math
* [Reference Muon implementation in PyTorch](https://github.com/KellerJordan/Muon)

# 2️⃣ Weights and Biases

> ##### Learning Objectives
>
> * Write modular, extensible code for training models
> * Learn what the most important hyperparameters are, and methods for efficiently searching over hyperparameter space
> * Learn how to use Weights & Biases for logging your runs
> * Adapt your code from yesterday to log training runs to Weights & Biases, and use this service to run **hyperparameter sweeps**

Next, we'll look at methods for choosing hyperparameters effectively. You'll learn how to use **Weights and Biases**, a useful tool for hyperparameter search.

The exercises themselves will be based on your ResNet implementations from yesterday, although the principles should carry over to other models you'll build in this course (such as transformers next week).

Note, this page only contains a few exercises, and they're relatively short. You're encouraged to spend some time playing around with Weights and Biases, but you should also spend some more time finetuning your ResNet from yesterday (you might want to finetune ResNet during the morning, and look at today's material in the afternoon - you can discuss this with your partner). You should also spend some time reviewing the last three days of material, to make sure there are no large gaps in your understanding.

## Finetuning & feature extraction

> We'll start with a brief discussion of the related concepts **finetuning** and **feature extraction**. If you've already gone through yesterday's bonus material on feature extraction then you can skip this section.

**Finetuning** can mean slightly different things in different contexts, but broadly speaking it means using the weights of an already trained network as the starting values for training a new network. Because training networks from scratch is very computationally expensive, this is a common practice in ML.

The specific type of finetuning we'll be doing here is called **feature extraction**. This is when we freeze most layers of a model except the last few, and perform gradient descent on those. We call this feature extraction because the earlier layers of the model have already learned to identify important features of the data (and these features are also relevant for the new task), so all that we have to do is train a few final layers in the model to extract these features.

*Terminology note - sometimes feature extraction and finetuning are defined differently, with finetuning referring to the training of all the weights in a pretrained model (usually with a small or decaying learning rate), and feature extraction referring to the freezing of some layers and training of others. To avoid confusion here, we'll use the term "feature extraction" rather than "finetuning".*

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/feature_extraction.png" width="400">

The way we implement feature extraction in PyTorch is by **freezing** all but the last few layers of our model, meaning gradients don't propagate back through them (and we don't perform gradient descent updates on them) - more on gradient freezing tomorrow! We've used the `get_resnet_for_feature_extraction` function to do this (the code for this is given to you below so you won't have to write it yourself). This function creates a version of the `ResNet34` model, loads in weights from the PyTorch ResNet34 implementation, freezes all layers, and replaces the final linear layer with an unfrozen randomly initialized linear layer with a certain number of output features (in our case 10 because we're doing feature extraction on CIFAR10 - see next section).

## CIFAR10

The benchmark we'll be doing feature extraction on is [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html), which consists of 60000 32x32 colour images in 10 different classes (as opposed to the 1000 different classes that `ResNet34` was originally trained on). Don't peek at what other people online have done for CIFAR10 (it's a common benchmark), because the point is to develop your own process by which you can figure out how to improve your model. Just reading the results of someone else would prevent you from learning how to get the answers. To get an idea of what's possible: using one V100 and a modified ResNet, one entry in the DAWNBench competition was able to achieve 94% test accuracy in 24 epochs and 76 seconds. 94% is approximately [human level performance](http://karpathy.github.io/2011/04/27/manually-classifying-cifar10/).

Below is some boilerplate code for downloading and transforming `CIFAR10` data (this shouldn't take more than a minute to run the first time). Note, even though CIFAR10 data is 32x32, we'll resize it to 224x224 like we did for ImageNet yesterday, because ResNet expects 224x224 images as input.

In [3]:
def get_cifar() -> tuple[datasets.CIFAR10, datasets.CIFAR10]:
    """Returns CIFAR-10 train and test sets."""
    cifar_trainset = datasets.CIFAR10(exercises_dir / "data", train=True, download=True, transform=IMAGENET_TRANSFORM)
    cifar_testset = datasets.CIFAR10(exercises_dir / "data", train=False, download=True, transform=IMAGENET_TRANSFORM)
    return cifar_trainset, cifar_testset


IMAGE_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

IMAGENET_TRANSFORM = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)


cifar_trainset, cifar_testset = get_cifar()

imshow(
    cifar_trainset.data[:15],
    facet_col=0,
    facet_col_wrap=5,
    facet_labels=[cifar_trainset.classes[i] for i in cifar_trainset.targets[:15]],
    title="CIFAR-10 images",
    height=600,
    width=1000,
)


## Train function (modular)

First, let's build on the training function we used yesterday. Previously, we just used a single `train` function which took a dataclass as argument. But this resulted in a very long function with many nested loops and some repeated code. Instead, we'll write our code in the form of a `ResNetFinetuner` class with multiple methods, each one being responsible for a single part of the training process. This will make our code more modular, and easier to read and debug.

We've given you the `ResNetFinetuner` class below, as well as a dataclass which contains all the hyperparameters we'll use (again this helps us keep everything organized). You should read this and make sure you understand the role of each method. A brief summary:

* `pre_training_setup` defines the model, optimizer, dataset, and objects for logging data. Note that it's not good practice to have this logic run in `__init__`, because it's something we only need to do just before actually training (this structural flexibility will prove useful later, when we introduce Weights & Biases).
* `training_step` does a single gradient update step on a single batch of data, and logs & returns the loss.
* `evaluate` computes the accuracy of the model over the whole test set, and logs & returns this accuracy. Note the use of the `torch.inference_mode()` decorator, which disables gradient tracking (this is equivalent to using it as a context manager).
* `train` combines this all: it performs the pre-training setup, then alternates between training & evaluation for some number of epochs. Note that `model.train()` and `model.eval()` are called before these stages respectively - for why we have to do this, see yesterday's discussion of BatchNorm.

In [7]:
from torch.optim import AdamW, Adam, SGD, RMSprop

In [5]:
@dataclass
class ResNetFinetuningArgs:
    n_classes: int = 10
    batch_size: int = 128
    epochs: int = 3
    learning_rate: float = 1e-3
    weight_decay: float = 0.0


class ResNetFinetuner:
    def __init__(self, args: ResNetFinetuningArgs):
        self.args = args

    def pre_training_setup(self):
        self.model = get_resnet_for_feature_extraction(self.args.n_classes).to(device)
        self.optimizer = AdamW(
            self.model.out_layers[-1].parameters(), lr=self.args.learning_rate, weight_decay=self.args.weight_decay
        )
        self.trainset, self.testset = get_cifar()
        self.train_loader = DataLoader(self.trainset, batch_size=self.args.batch_size, shuffle=True)
        self.test_loader = DataLoader(self.testset, batch_size=self.args.batch_size, shuffle=False)
        self.logged_variables = {"loss": [], "accuracy": []}
        self.examples_seen = 0

    def training_step(
        self,
        imgs: Float[Tensor, "batch channels height width"],
        labels: Int[Tensor, "batch"],
    ) -> Float[Tensor, ""]:
        """Perform a gradient update step on a single batch of data."""
        imgs, labels = imgs.to(device), labels.to(device)

        logits = self.model(imgs)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        self.examples_seen += imgs.shape[0]
        self.logged_variables["loss"].append(loss.item())
        return loss

    @t.inference_mode()
    def evaluate(self) -> float:
        """Evaluate the model on the test set and return the accuracy."""
        self.model.eval()
        total_correct, total_samples = 0, 0

        for imgs, labels in tqdm(self.test_loader, desc="Evaluating"):
            imgs, labels = imgs.to(device), labels.to(device)
            logits = self.model(imgs)
            total_correct += (logits.argmax(dim=1) == labels).sum().item()
            total_samples += len(imgs)

        accuracy = total_correct / total_samples
        self.logged_variables["accuracy"].append(accuracy)
        return accuracy

    def train(self) -> dict[str, list[float]]:
        self.pre_training_setup()

        accuracy = self.evaluate()

        for epoch in range(self.args.epochs):
            self.model.train()

            pbar = tqdm(self.train_loader, desc="Training")
            for imgs, labels in pbar:
                loss = self.training_step(imgs, labels)
                pbar.set_postfix(loss=f"{loss:.3f}", ex_seen=f"{self.examples_seen:06}")

            accuracy = self.evaluate()
            pbar.set_postfix(loss=f"{loss:.3f}", accuracy=f"{accuracy:.2f}", ex_seen=f"{self.examples_seen:06}")

        return self.logged_variables

With this class, we can perform feature extraction on our model as follows:

In [None]:
args = ResNetFinetuningArgs()
trainer = ResNetFinetuner(args)
logged_variables = trainer.train()

Evaluating: 100%|██████████| 79/79 [39:18<00:00, 29.85s/it]
Training:  19%|█▉        | 75/391 [41:15<2:55:03, 33.24s/it, ex_seen=009600, loss=1.011]

In [None]:
line(
    y=[logged_variables["loss"][: 391 * 3 + 1], logged_variables["accuracy"][:4]],
    x_max=len(logged_variables["loss"][: 391 * 3 + 1]) * args.batch_size,
    yaxis2_range=[0, 1],
    use_secondary_yaxis=True,
    labels={"x": "Examples seen", "y1": "Cross entropy loss", "y2": "Test Accuracy"},
    title="Feature extraction with ResNet34",
    width=800,
)

Let's see how well our ResNet performs on a few randomly chosen inputs!

In [None]:
def test_resnet_on_random_input(model: ResNet34, n_inputs: int = 3, seed: int | None = 42):
    if seed is not None:
        np.random.seed(seed)
    indices = np.random.choice(len(cifar_trainset), n_inputs).tolist()
    classes = [cifar_trainset.classes[cifar_trainset.targets[i]] for i in indices]
    imgs = cifar_trainset.data[indices]
    device = next(model.parameters()).device
    with t.inference_mode():
        x = t.stack(list(map(IMAGENET_TRANSFORM, imgs)))
        logits: Tensor = model(x.to(device))
    probs = logits.softmax(-1)
    if probs.ndim == 1:
        probs = probs.unsqueeze(0)
    for img, label, prob in zip(imgs, classes, probs):
        display(HTML(f"<h2>Classification probabilities (true class = {label})</h2>"))
        imshow(img, width=200, height=200, margin=0, xaxis_visible=False, yaxis_visible=False)
        bar(prob, x=cifar_trainset.classes, width=600, height=400, text_auto=".2f", labels={"x": "Class", "y": "Prob"})


test_resnet_on_random_input(trainer.model)

## What is Weights and Biases?

Weights and Biases is a cloud service that allows you to log data from experiments. Your logged data is shown in graphs during training, and you can easily compare logs across different runs. It also allows you to run **sweeps**, where you can specify a distribution over hyperparameters and then start a sequence of test runs which use hyperparameters sampled from this distribution.

Before you run any of the code below, you should visit the [Weights and Biases homepage](https://wandb.ai/home), and create your own account.

We can keep the same training loop structure when using Weights and Biases; we just have to add a few function calls. The key functions to know are:

#### `wandb.init`

This initialises a training run. It should be called once, at the start of your training loop.

A few important arguments are:

* `project` - the name of the project where you're sending the new run. For example, this could be `'day3-resnet'` for us. You can have many different runs in each project.
* `name` - a display name for this run. By default, if this isn't supplied, wandb generates a random 2-word name for you (e.g. `gentle-sunflower-42`).
* `config` - a dictionary containing hyperparameters for this run. If you pass this dictionary, then you can compare different runs' hyperparameters & results in a single table. Alternatively, you can pass a dataclass.

#### `wandb.watch`

This function tells wandb to watch a model - this means that it will log the gradients and parameters of the model during training. We'll call this function once, after we've created our model. The 3 most important arguments are:

* `models` - a module or list of modules (e.g. in our case we might just want to log the weights of the final linear layer, because the others aren't being trained)
* `log` - determines what gets tracked, possible values are `'gradients'` (default), `'parameters'` or `'all'`
* `log_freq` - the number of batches between each logging step (default is 1000)

Why do we log parameters and gradients? Mainly this is [helpful for debugging](https://wandb.ai/wandb_fc/articles/reports/Debugging-Neural-Networks-with-PyTorch-and-W-B-Using-Gradients-and-Visualizations--Vmlldzo1NDQxNTA5), because it helps us identify problems like exploding or vanishing gradients, dead ReLUs, etc.

#### `wandb.log`

For logging metrics to the wandb dashboard. This is used as `wandb.log(data, step)`, where `step` is an integer (the x-axis on your metric plots) and `data` is a dictionary of metrics (i.e. the keys are metric names, and the values are metric values).

#### `wandb.finish`

This function should be called at the end of your training loop. It finishes the run, and saves the results to the cloud.

If a run terminates early (either because of an error or because you manually terminated it), remember to still run `wandb.finish()` - it will speed things up when you start a new run (otherwise you have to wait for the previous run to be terminated & uploaded).

### Exercise - rewrite training loop with wandb

> ```yaml
> Difficulty: 🔴🔴⚪⚪⚪
> Importance: 🔵🔵🔵🔵⚪
>
> You should spend up to 10-25 minutes on this exercise.
> ```

You should now take the training loop from above (i.e. the `ResNetFinetuner` class) and rewrite it to use the four `wandb` functions above (in place of the `logged_variables` dictionary, which you can now remove). This will require:

- Initializing your run
    - Your new `pre_training_setup` method should call `wandb.init` and then `wandb.watch`, as well as everything it previously did
    - For `wandb.init`, you can use the project & name arguments from your new dataclass (see below)
    - For `wandb.watch`, be careful of the `log_freq` value - you want to make sure you're logging more than once per epoch
- Logging variables to wandb during your run
    - i.e. replace updating of `self.logged_variables` with calls to `wandb.log`
    - We recommend tracking `self.examples_seen` and passing this as the `step` argument to your logging calls, this way it's easier to compare across different runs with e.g. different batch sizes (more on this later)
- Finishing the run
    - i.e. calling `wandb.finish` at the end of your training loop

This is all you need to do to get wandb working, so the vast majority of the code you write below will be copied and pasted from the previous `ResNetFinetuner` class. We've given you a template for this below, along with a new dataclass. Both the dataclass and the trainer class use inheritance to remove code duplication (e.g. because we don't need to rewrite our `__init__` method, it'll be the same as for `ResNetFinetuner`).

Note, we generally recommend keeping your progress bars even when using wandb, because they update slightly faster than the wandb dashboard and can give you a better sense of whether something is going wrong in training.
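
Returning to the `log_freq` point above, here is the arithmetic for the default arguments. This is a small sketch which assumes the standard CIFAR-10 split of 50,000 training and 10,000 test images, and a `DataLoader` that keeps the final partial batch:

```python
import math

train_size, test_size = 50_000, 10_000  # standard CIFAR-10 split (assumed)
batch_size = 128                        # default in ResNetFinetuningArgs

# The final partial batch is kept, so we round up
train_batches = math.ceil(train_size / batch_size)
test_batches = math.ceil(test_size / batch_size)
print(train_batches, test_batches)

# wandb.watch logs every `log_freq` batches, so count the logging events per epoch
logs_per_epoch = {log_freq: train_batches // log_freq for log_freq in (1000, 50)}
print(logs_per_epoch)
```

With 391 batches per epoch, the default `log_freq=1000` would log less than once per epoch, whereas a value like 50 gives several logging events per epoch.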

In [None]:
@dataclass
class WandbResNetFinetuningArgs(ResNetFinetuningArgs):
    """Contains new params for use in wandb.init, as well as all the ResNetFinetuningArgs params."""

    wandb_project: str | None = "day3-resnet"
    wandb_name: str | None = None


class WandbResNetFinetuner(ResNetFinetuner):
    args: WandbResNetFinetuningArgs  # adding this line helps with typechecker!
    examples_seen: int = 0  # for tracking the total number of examples seen; used as step argument in wandb.log

    def pre_training_setup(self):
        """Initializes the wandb run using `wandb.init` and `wandb.watch`."""
        super().pre_training_setup()
        raise NotImplementedError()

    def training_step(
        self,
        imgs: Float[Tensor, "batch channels height width"],
        labels: Int[Tensor, "batch"],
    ) -> Float[Tensor, ""]:
        """Equivalent to ResNetFinetuner.training_step, but logging the loss to wandb."""
        raise NotImplementedError()

    @t.inference_mode()
    def evaluate(self) -> float:
        """Equivalent to ResNetFinetuner.evaluate, but logging the accuracy to wandb."""
        raise NotImplementedError()

    def train(self) -> None:
        """Equivalent to ResNetFinetuner.train, but with wandb initialization & calling `wandb.finish` at the end."""
        self.pre_training_setup()
        raise NotImplementedError()


args = WandbResNetFinetuningArgs()
trainer = WandbResNetFinetuner(args)
trainer.train()

<details><summary>Solution</summary>

```python
@dataclass
class WandbResNetFinetuningArgs(ResNetFinetuningArgs):
    """Contains new params for use in wandb.init, as well as all the ResNetFinetuningArgs params."""

    wandb_project: str | None = "day3-resnet"
    wandb_name: str | None = None


class WandbResNetFinetuner(ResNetFinetuner):
    args: WandbResNetFinetuningArgs  # adding this line helps with typechecker!
    examples_seen: int = 0  # for tracking the total number of examples seen; used as step argument in wandb.log

    def pre_training_setup(self):
        """Initializes the wandb run using `wandb.init` and `wandb.watch`."""
        super().pre_training_setup()
        wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)
        wandb.watch(self.model.out_layers[-1], log="all", log_freq=50)
        self.examples_seen = 0

    def training_step(
        self,
        imgs: Float[Tensor, "batch channels height width"],
        labels: Int[Tensor, "batch"],
    ) -> Float[Tensor, ""]:
        """Equivalent to ResNetFinetuner.training_step, but logging the loss to wandb."""
        imgs, labels = imgs.to(device), labels.to(device)

        logits = self.model(imgs)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        self.examples_seen += imgs.shape[0]
        wandb.log({"loss": loss.item()}, step=self.examples_seen)
        return loss

    @t.inference_mode()
    def evaluate(self) -> float:
        """Equivalent to ResNetFinetuner.evaluate, but logging the accuracy to wandb."""
        self.model.eval()
        total_correct, total_samples = 0, 0

        for imgs, labels in tqdm(self.test_loader, desc="Evaluating"):
            imgs, labels = imgs.to(device), labels.to(device)
            logits = self.model(imgs)
            total_correct += (logits.argmax(dim=1) == labels).sum().item()
            total_samples += len(imgs)

        accuracy = total_correct / total_samples
        wandb.log({"accuracy": accuracy}, step=self.examples_seen)
        return accuracy

    def train(self) -> None:
        """Equivalent to ResNetFinetuner.train, but with wandb initialization & calling `wandb.finish` at the end."""
        self.pre_training_setup()
        accuracy = self.evaluate()

        for epoch in range(self.args.epochs):
            self.model.train()

            pbar = tqdm(self.train_loader, desc="Training")
            for imgs, labels in pbar:
                loss = self.training_step(imgs, labels)
                pbar.set_postfix(loss=f"{loss:.3f}", ex_seen=f"{self.examples_seen:06}")

            accuracy = self.evaluate()
            pbar.set_postfix(loss=f"{loss:.3f}", accuracy=f"{accuracy:.2f}", ex_seen=f"{self.examples_seen:06}")

        wandb.finish()
```
</details>

When you run the code for the first time, you'll have to log in to Weights and Biases, and paste an API key into VSCode. After this is done, your Weights and Biases training run will start. It'll give you a lot of output text, one line of which will look like:

```
View run at https://wandb.ai/<USERNAME>/<PROJECT-NAME>/runs/<RUN-NAME>
```

which you can click on to visit the run page.

A nice thing about using Weights and Biases is that you don't need to worry about generating your own plots: this is all done for you when you visit the page.

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/wandb-day3.png" width="900">

### Run & project pages

The page you visit will show you a plot of all the variables you've logged, among other things. You can do many things with these plots (e.g. click on the "edit" icon for your `loss` plot, and apply smoothing & change axis bounds to get a better picture of your loss curve).

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/wandb-day3-smoothed.png" width="1000">

The charts are a useful feature of the run page, but they're not the only one. You can also navigate to the project page (click on the option to the right of **Projects** on the bar at the top of the Wandb page), and see superimposed plots of all the runs in this project. You can also click on the **Table** icon on the left hand sidebar to see a table of all the runs in this project, which contains useful information (e.g. runtime, the most recent values of any logged variables, etc). Comparing runs like this becomes especially useful once we start doing hyperparameter search.

You can also look at the system tab to inspect things like GPU utilization - this is a good way of checking whether you're saturating your GPU or whether you can afford to increase your batch size more. This tab will be especially useful in the next section, when we move onto distributed training.

## Some training heuristics

One important skill which every aspiring ML researcher should develop is the ability to play around with hyperparameters and improve a model's training. At times this is more of an art than a science, because frequently rules and heuristics which work most of the time will break down in certain special cases. For example, a common heuristic for number of workers in a `DataLoader` is to set them to be 4 times the number of GPUs you have available (see later sections on distributed computing for more on this). However, setting these values too high can lead to issues where your CPU is bottlenecked by the workers and your epochs take a long time to start - it took me a long time to realize this was happening when I was initially writing these exercises!

Sweeping over hyperparameters (which we'll cover shortly) can help remove some of the difficulty here, because you can use sweep methods that guide you towards an optimal set of hyperparameter choices rather than having to manually pick your own. However, here are a few heuristics that you might find useful in a variety of situations:

- **Setting batch size**
    - Generally you should aim to **saturate your GPU** with data - this means choosing a batch size that's as large as possible without causing memory errors
        - You should generally aim for over 70% utilization of your GPU
    - Note, this means you should generally try for a larger batch size in your testloader than your trainloader (because evaluation is done without gradients, and so uses less memory per example)
        - A good starting point is 4x the size, but this will vary between models
- **Choosing a learning rate**
    - Inspecting loss curves can be a good way of evaluating our learning rate
        - If loss is decreasing very slowly & monotonically then this is a sign you should increase the learning rate, whereas very large loss spikes are a sign that you should decrease it
    - A common strategy is **warmup**, i.e. having a smaller learning rate for a short period of time at the start of training - we'll do this a lot in later material
    - [Jeremy Jordan](https://www.jeremyjordan.me/nn-learning-rate/) has a good blog post on learning rates
- **Balancing learning rate and batch size**
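    - For standard optimizers like `SGD`, it's a good idea to scale the learning rate up with the batch size, because larger batches give less noisy gradient estimates. Common heuristics are linear scaling, or square-root scaling (which keeps the variance of each parameter step roughly the same)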
    - However for **adaptive optimizers** such as `Adam` (where the size of parameter updates automatically adjusts based on the first and second moments of our gradients), this isn't as necessary
        - This is why we generally start with default parameters for Adam, and then adjust from there
- **Misc. advice**
    - If you're training a larger model, it's sometimes a good idea to start with a smaller version of that same model. Good hyperparameters tend to transfer if the architecture & data are the same; the main difference is that the larger model may require more regularization to prevent overfitting.
    - Bad hyperparameters are usually clearly worse by the end of the first 1-2 epochs. You can manually abort runs that don't look promising (or do it automatically - see discussion of Hyperband in wandb sweeps at the end of this section)
    - Overfitting at the start is better than underfitting, because it means your model is capable of learning and has enough capacity

## Hyperparameter search

One way to search for good hyperparameters is to choose a set of values for each hyperparameter, and then search all combinations of those specific values. This is called **grid search**. The values don't need to be evenly spaced and you can incorporate any knowledge you have about plausible values from similar problems to choose the set of values. Searching the product of sets takes exponential time, so is really only feasible if there are a small number of hyperparameters. I would recommend forgetting about grid search if you have more than 3 hyperparameters, which in deep learning is "always".

A much better idea is for each hyperparameter, decide on a sampling distribution and then on each trial just sample a random value from that distribution. This is called **random search** and back in 2012, you could get a [publication](https://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf) for this. The diagram below shows the main reason that random search outperforms grid search. Empirically, some hyperparameters matter more than others, and random search benefits from having tried more distinct values in the important dimensions, increasing the chances of finding a "peak" between the grid points.

<img src="https://raw.githubusercontent.com/callummcdougall/Fundamentals/main/images/grid_vs_random.png" width="540">


It's worth noting that both of these searches are vastly less efficient than gradient descent at finding optima - imagine if you could only train neural networks by randomly initializing them and checking the loss! Either of these search methods without a dose of human (or eventually AI) judgement is just a great way to turn electricity into a bunch of models that don't perform very well.
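
To make the grid-vs-random comparison concrete, here is a small illustrative sketch using the same budget of 9 trials for both methods. The hyperparameter names and ranges are made up for the example:

```python
import itertools
import random

n_trials = 9
rng = random.Random(0)

# Grid search: 3 values per hyperparameter -> 3 x 3 = 9 trials
lr_grid = [1e-4, 1e-3, 1e-2]
wd_grid = [0.0, 1e-2, 1e-1]
grid_trials = list(itertools.product(lr_grid, wd_grid))

# Random search: 9 trials, each drawing a fresh value for every hyperparameter
random_trials = [(rng.uniform(1e-4, 1e-2), rng.uniform(0.0, 1e-1)) for _ in range(n_trials)]

for name, trials in [("grid", grid_trials), ("random", random_trials)]:
    distinct_lrs = len({lr for lr, _ in trials})
    print(f"{name}: {len(trials)} trials, {distinct_lrs} distinct learning rates")
```

For the same number of trials, the grid only ever tries 3 distinct learning rates, while random search tries a new one on every trial. If learning rate is the hyperparameter that matters most, random search has explored that dimension three times as densely.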

## Running hyperparameter sweeps with `wandb`

Now we've come to one of the most impressive features of `wandb` - being able to perform hyperparameter sweeps. We do this by defining a `sweep_config` dict which tells us how our hyperparameters will be randomly sampled, then we write a `train` function which takes no arguments and launches a training run with our modified hyperparameters. Lastly we use `wandb.sweep` and `wandb.agent` to run our sweep. We'll go through each step of this below.

### Sweep config syntax

The basic syntax for a sweep config looks like this:

```python
sweep_config = dict(
    method = method, # can be "grid", "random" or "bayes"
    metric = dict(
        name = metric_name, # name of the metric you're optimising (should be a numeric type logged in `wandb.log`)
        goal = goal, # either "maximize" or "minimize"
    ),
    parameters = dict(
        param_1 = dict(...),
        param_2 = dict(...),
        ...
    ),
)
```

The `method` argument determines how we perform search: `grid` is over all combinations, `random` independently samples each hyperparameter, and `bayes` uses Bayesian optimization to sample hyperparameters. The `metric` dict determines what logged variable we're optimizing, and in what direction. Lastly, `parameters` is a dict of the parameters we're varying, mapping each parameter name to a dictionary describing how we want that parameter to be sampled. Possible ways to specify distributions include:

```python
parameters = dict(
    param_1 = dict(values = [...]), # uniformly sample from list of values
    param_2 = dict(values = [...], probabilities = [...]), # sample from list with given probabilities
    param_3 = dict(min = ..., max = ...), # uniform distribution over [min, max), can either be ints or floats
    param_4 = dict(min = ..., max = ..., distribution = "log_uniform_values"), # use log-uniform distribution instead
)
```

Note on log uniform distribution - this essentially means we return `value` s.t. `log(value)` is uniformly distributed between `log(min)` and `log(max)`. It can be a useful way to sample hyperparameters which take values in a very large range.
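As a rough illustration of this definition, the sketch below samples from a log-uniform distribution using only the standard library. The helper name `sample_log_uniform` is our own, and this is not how `wandb` implements it internally. It checks that each decade of the range gets roughly the same share of samples:

```python
import math
import random


def sample_log_uniform(min_value: float, max_value: float, rng: random.Random) -> float:
    """Returns `value` such that log(value) is uniform between log(min_value) and log(max_value)."""
    return math.exp(rng.uniform(math.log(min_value), math.log(max_value)))


rng = random.Random(0)
samples = [sample_log_uniform(1e-4, 1e-1, rng) for _ in range(10_000)]

# Each of the 3 decades in [1e-4, 1e-1] should receive roughly a third of the samples
decade_fractions = {}
for low in [1e-4, 1e-3, 1e-2]:
    decade_fractions[low] = sum(low <= s < 10 * low for s in samples) / len(samples)
    print(f"[{low:.0e}, {10 * low:.0e}): {decade_fractions[low]:.1%}")
```

A plain uniform distribution over the same range would instead put about 90% of its samples in the top decade, which is rarely what you want for something like a learning rate.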

You can read more about the syntax [here](https://docs.wandb.ai/guides/sweeps/define-sweep-configuration), but the examples we've given you above should be enough to complete the rest of these exercises.

<details>
<summary>Note on using YAML files for sweeps (optional)</summary>

Rather than using a dictionary, you can alternatively store the `sweep_config` data in a YAML file if you prefer. You will then be able to run a sweep via the following terminal commands:

```
wandb sweep sweep_config.yaml

wandb agent <SWEEP_ID>
```

where `SWEEP_ID` is the value returned from the first terminal command. You will also need to add another line to the YAML file, specifying the program to be run. For instance, your YAML file might start like this:

```yaml
program: train.py
method: random
metric:
    name: test_accuracy
    goal: maximize
```

For more, see [this link](https://docs.wandb.ai/guides/sweeps/define-sweep-configuration).

</details>

### Exercise - define a sweep config & update `args`

> ```yaml
> Difficulty: 🔴🔴⚪⚪⚪
> Importance: 🔵🔵🔵⚪⚪
>
> You should spend up to 10-20 minutes on this exercise.
> Learning how to use wandb for sweeps is very useful, so make sure you understand all parts of this code.
> ```

Using the syntax discussed above, you should define a dictionary `sweep_config` which has the following rules for hyperparameter sweeps:

* Hyperparameters are chosen **randomly**, according to the distributions given in the dictionary
* Your goal is to **maximize** the **accuracy** metric
* The hyperparameters you vary are:
    * Learning rate - a log-uniform distribution between 1e-4 and 1e-1
    * Batch size - sampled uniformly from (32, 64, 128, 256)
    * Weight decay - with 50% probability set to 0, and with 50% probability log-uniform between 1e-4 and 1e-2

You should also fill in the `update_args` function, which returns a modified version of `args` based on the hyperparameters sampled by the sweep. In other words, it should take an `args` object and a dictionary of sampled parameters that might look something like `{"lr": 0.001, "batch_size": 64, ...}`, and return a new `args` object with these fields modified.

In [None]:
# YOUR CODE HERE - fill `sweep_config` so it has the requested behaviour
sweep_config = dict(
    method = ...,
    metric = ...,
    parameters = ...,
)


def update_args(args: WandbResNetFinetuningArgs, sampled_parameters: dict) -> WandbResNetFinetuningArgs:
    """
    Returns a new args object with modified values. The dictionary `sampled_parameters` will have the same keys as
    your `sweep_config["parameters"]` dict, and values equal to the sampled values of those hyperparameters.
    """
    assert set(sampled_parameters.keys()) == set(sweep_config["parameters"].keys())

    # YOUR CODE HERE - update `args` based on `sampled_parameters`
    raise NotImplementedError()


tests.test_sweep_config(sweep_config)
tests.test_update_args(update_args, sweep_config)

<details>
<summary>Help - I'm not sure how to implement the weight decay distribution that was requested.</summary>

The easiest option is to include 2 parameters: one is a boolean and determines whether to use weight decay, one is log-uniform and gives you the value in the cases where it's non-zero. Both parameters are used to set the final value in `args`.

</details>

<details>
<summary>Solution</summary>

```python
sweep_config = dict(
    method="random",
    metric=dict(name="accuracy", goal="maximize"),
    parameters=dict(
        learning_rate=dict(min=1e-4, max=1e-1, distribution="log_uniform_values"),
        batch_size=dict(values=[32, 64, 128, 256]),
        weight_decay=dict(min=1e-4, max=1e-2, distribution="log_uniform_values"),
        weight_decay_bool=dict(values=[True, False]),
    ),
)

def update_args(args: WandbResNetFinetuningArgs, sampled_parameters: dict) -> WandbResNetFinetuningArgs:
    assert set(sampled_parameters.keys()) == set(sweep_config["parameters"].keys())

    args.learning_rate = sampled_parameters["learning_rate"]
    args.batch_size = sampled_parameters["batch_size"]
    args.weight_decay = sampled_parameters["weight_decay"] if sampled_parameters["weight_decay_bool"] else 0.0
    return args
```

Alternatively, for a solution with less repetition, you can use the `dataclasses.replace` function to update multiple fields of `args` at once:

```python
def update_args(args: WandbResNetFinetuningArgs, sampled_parameters: dict) -> WandbResNetFinetuningArgs:
    assert set(sampled_parameters.keys()) == set(sweep_config["parameters"].keys())

    sampled_parameters["weight_decay"] *= float(sampled_parameters.pop("weight_decay_bool"))
    return replace(args, **sampled_parameters)
```

If you use this solution, you need to be careful that the names of your fields in `sweep_config` match the names of the fields in `WandbResNetFinetuningArgs`.
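To see why the names matter, here is a small standalone sketch of how `dataclasses.replace` behaves. `ToyArgs` is a hypothetical stand-in for `WandbResNetFinetuningArgs`, with made-up default values:

```python
from dataclasses import dataclass, replace


@dataclass
class ToyArgs:  # hypothetical stand-in for `WandbResNetFinetuningArgs`
    learning_rate: float = 1e-3
    batch_size: int = 64
    weight_decay: float = 0.0


args = ToyArgs()
new_args = replace(args, learning_rate=0.01, batch_size=128)
print(args)  # the original object is left unchanged
print(new_args)  # a new object with the two fields overridden

# A key which isn't a field of the dataclass raises an error rather than being silently ignored
try:
    replace(args, lr=0.01)  # `lr` is not a field of `ToyArgs`
    raised = False
except TypeError as e:
    raised = True
    print(f"TypeError: {e}")
```

Note that `replace` returns a new object, whereas the first solution above modifies `args` in place.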

</details>

Now we've done this, we'll define a `train` function that takes no arguments and launches a training run with our modified hyperparameters. This is done in the following way:

- The train function calls `wandb.init`
- Our sampled hyperparameters are now available in `wandb.config`, so we use this object to update `args`
- We then launch a training run based on these new hyperparameters

The line `sweep_id = wandb.sweep(...)` initializes a hyperparameter sweep (giving it an ID) and the line `wandb.agent(...)` starts an agent that runs the training script `train` 3 times, with different randomly sampled sets of hyperparameters each time.

Note that we pass `reinit=False` into our `wandb.init` call - this is so we ignore the second `wandb.init` call that takes place in our pretraining setup when we run `trainer.train()` (so we can avoid the hassle of having to rewrite this method to remove that line).

In [None]:
def train():
    # Define args & initialize wandb
    args = WandbResNetFinetuningArgs()
    wandb.init(project=args.wandb_project, name=args.wandb_name, reinit=False)

    # After initializing wandb, we can update args using `wandb.config`
    args = update_args(args, dict(wandb.config))

    # Train the model with these new hyperparameters (the second `wandb.init` call will be ignored)
    trainer = WandbResNetFinetuner(args)
    trainer.train()


sweep_id = wandb.sweep(sweep=sweep_config, project="day3-resnet-sweep")
wandb.agent(sweep_id=sweep_id, function=train, count=3)
wandb.finish()

When you run this code, you should click on the link which looks like:

```
View sweep at https://wandb.ai/<USERNAME>/<PROJECT-NAME>/sweeps/<SWEEP-NAME>
```

This link will bring you to a page comparing each of the runs in your sweep. You'll be able to see overlaid graphs of each of their training loss and test accuracy, as well as a bunch of other cool things like:

* Bar charts of the [importance](https://docs.wandb.ai/ref/app/features/panels/parameter-importance) (and correlation) of each hyperparameter wrt the target metric. Note that only looking at the correlation could be misleading - something can have a correlation of 1, but still have a very small effect on the metric.
* A [parallel coordinates plot](https://docs.wandb.ai/ref/app/features/panels/parallel-coordinates), which summarises the relationship between the hyperparameters in your config and the model metric you're optimising.

What can you infer from these results? Are there any hyperparameters which are especially correlated / anticorrelated with the target metric? Are there any results which suggest the model is being undertrained?

You might also want to play around with Bayesian hyperparameter search, if you get the time! Note that wandb sweeps also offer [early termination](https://docs.wandb.ai/guides/sweeps/define-sweep-configuration/#early_terminate) of runs that don't look promising, based on the [Hyperband](https://www.jmlr.org/papers/volume18/16-558/16-558.pdf) algorithm.
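As a sketch of what this might look like, early termination is requested by adding an `early_terminate` entry to the sweep config. The key names below follow our reading of the docs linked above, and the value `min_iter=3` is an arbitrary assumption, so check the docs before relying on this:

```python
sweep_config_with_early_stopping = dict(
    method="random",
    metric=dict(name="accuracy", goal="maximize"),
    parameters=dict(
        learning_rate=dict(min=1e-4, max=1e-1, distribution="log_uniform_values"),
    ),
    # Hyperband-based early termination (assumed syntax, taken from the linked wandb docs)
    early_terminate=dict(type="hyperband", min_iter=3),
)
```

Early termination decisions are based on the metric values you log during training, so it's only useful if you log the target metric more than once per run.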

To conclude - `wandb` is an incredibly useful tool when training models, and you should find yourself using it a fair amount throughout this program. You can always return to this page of exercises if you forget how any part of it works!

# 3️⃣ Distributed Training

> ##### Learning Objectives
>
> * Understand the different kinds of parallelization used in deep learning (data, pipeline, tensor)
> * Understand how primitive operations in `torch.distributed` work, and how they come together to enable distributed training
> * Launch and benchmark your own distributed training runs, to train your implementation of `ResNet34` from scratch

## Intro to distributed training

Distributed training is a model training paradigm that involves spreading the training workload across multiple worker nodes, which can significantly speed up training (and makes it possible to train models that would be too large or too slow to train on a single device). While distributed training can be used for any type of ML model, it is most beneficial for large models and compute-intensive tasks such as deep learning.

There are 2 main families of distributed training methods: **data parallelism** and **model parallelism**. In data parallelism, we split batches of data across different processes, run forward & backward passes on each separately, and accumulate the gradients to update the model parameters. In model parallelism, the model is segmented into different parts that can run concurrently in different nodes, and each one runs on the same data. Model parallelism further splits into horizontal and vertical parallelism, depending on whether we're splitting the model up into parallel or sequential parts respectively. Most often horizontal parallelism is called **tensor parallelism** (because it involves splitting up the weights in a single layer across multiple GPUs, into what we commonly call **sharded weights**), and vertical parallelism is called **pipeline parallelism**.

<img src="https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/parallelism.png" width="1100">

Data & model parallelism are both widely used, and can be more or less appropriate in different circumstances (e.g. some kind of model parallelism is necessary when your model is too large to fit on a single GPU). However it is possible to create hybrid forms of parallelism by combining these; this is especially common when training large models like current SOTA LLMs. In these exercises, we'll focus on just data parallelism, although we'll suggest a few bonus exercises that explore model parallelism.
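The core idea behind data parallelism can be checked without any GPUs. The sketch below uses a one-parameter linear model with made-up data (both are our own illustrative assumptions), and shows that averaging the gradients computed on equally sized shards of a batch recovers the gradient of the full batch:

```python
def grad_mse(w: float, xs: list[float], ys: list[float]) -> float:
    """Gradient with respect to w of mean((w * x - y) ** 2) over the given examples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

# Single-process baseline: gradient over the full batch
full_grad = grad_mse(w, xs, ys)

# Data parallelism: each of 2 "workers" gets half the batch and computes a local gradient...
world_size = 2
shard_size = len(xs) // world_size
local_grads = [
    grad_mse(w, xs[rank * shard_size : (rank + 1) * shard_size], ys[rank * shard_size : (rank + 1) * shard_size])
    for rank in range(world_size)
]
# ...then the local gradients are averaged across workers before the optimizer step
synced_grad = sum(local_grads) / world_size

print(f"{full_grad=}, {local_grads=}, {synced_grad=}")
```

This equality relies on every shard having the same size. In a real setup the averaging step is a communication operation between processes rather than a sum over a Python list.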

### Summary of exercises

The exercises below will take you through **data parallelism**. You'll start by learning how to use the basic send and receive functions in PyTorch's distributed training module `torch.distributed` to transfer tensors between different processes (and different GPUs). Note that **you'll need multiple GPUs for these exercises** - we've included instructions in a dropdown below.

<details>
<summary>Getting multiple GPUs</summary>

The instructions for booting up a machine from vastai can already be found on the Streamlit homepage (i.e. navigate to "Home" on the sidebar, then to the section "Virtual Machines"). The only extra thing you'll need to do here is filter for an appropriate machine type.

We recommend filtering for "Disk Space To Allocate" (i.e. the primary slider on the top of the filter menu) of at least 40GB, not for the model (which is actually quite small) but for installing the ARENA dependencies. You should also filter for number of GPUs: we recommend 4x or 8x. You can do this using the options menu at the very top of the list of machines. Lastly, we recommend filtering for a decent PCIE Bandwidth (e.g. at least 20GB/s) - this is important for efficient gradient synchronization between GPUs. We're training a small model today: approx 22m parameters, which translates to ~88MB total size of weights, and so we'll transfer 88MB of data between GPUs per process (since we're transferring the model's gradients, which have the same size as the weights). We don't want this to be a bottleneck, which is why we should filter for this bandwidth.
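The arithmetic behind these numbers is sketched below. We assume 4 bytes per parameter (i.e. `float32`), and the transfer time is an idealised estimate for sending one copy of the gradients at the full quoted bandwidth, ignoring latency and any protocol overhead:

```python
n_params = 22_000_000  # approximate parameter count quoted above
bytes_per_param = 4  # assumption: float32 weights and gradients

grad_size_mb = n_params * bytes_per_param / 1e6
print(f"Gradient size: {grad_size_mb:.0f} MB")

bandwidth_gb_per_s = 20  # the minimum PCIE bandwidth suggested above
transfer_time_ms = grad_size_mb / (bandwidth_gb_per_s * 1e3) * 1e3
print(f"Idealised time to send one copy of the gradients: {transfer_time_ms:.1f} ms")
```

At a much lower bandwidth, this per-step cost could become comparable to the time taken by the forward and backward passes themselves.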

Once you've filtered for this, we recommend picking an RTX 3090 or 4090 machine. These won't be as powerful as an A100, but the purpose today is more to illustrate the basic ideas behind distributed training than to push your model training to its limits. Note that if you were using an A100 then you should filter for a high NVLink Bandwidth rather than PCIE (since A100s use NVLink instead of PCIE).

</details>

Once you've done this, you'll use those 2 primitive point-to-point functions to build up some more advanced functions: `broadcast` (which gets a tensor from one process to all others), `gather` (which gathers all tensors from different devices to a single device) and `all_reduce` (which combines both `broadcast` and `gather` to make aggregate tensor values across all processes). These functions (`all_reduce` in particular) are key parts of how distributed computing works.

Lastly, you'll learn how to use these functions to build a distributed training loop, which will be adapted from the `ResNetTrainer` code from your previous exercises. We also explain how you can use `DistributedDataParallel` to abstract away these low-level operations, which you might find useful later on (although you will benefit from building these components up from scratch, and understanding how they work under the hood).

### Running these exercises

> These exercises can't all be run in a notebook or Colab, because distributed training typically requires spawning multiple processes and Jupyter notebooks run in a single interactive process - they're not designed for this kind of use-case.

You have 2 different options:

1. **Do everything in a Python file** (either `# %%`-separated cells or [execute on selection](https://stackoverflow.com/questions/38952053/how-can-i-run-text-selected-in-the-active-editor-in-vs-codes-integrated-termina)), but make sure to wrap any execution code in `if __name__ == "__main__":` blocks. This makes sure that when you launch multiple processes they don't recursively launch their own processes, and they'll only execute the code you want them to.
2. **Write your functions in a Python file, then import & run them in a notebook**. For example in the example code below, you could define the `send_receive` function in a Python file, then import this function & pass it into the `mp.spawn()` call.

In either case, make sure when you run `mp.spawn` you're passing in the most updated version of your function. This means saving the Python file after you make changes, and also using something like `importlib.reload()` if you're running the code in a notebook.
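The sketch below illustrates the reloading behaviour in isolation. It writes a throwaway module called `my_dist_functions` to a temporary directory (a hypothetical name, standing in for the Python file you're editing), imports it, "edits" the file, and shows that the change is only picked up after `importlib.reload`:

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Create a throwaway module on disk (this stands in for the Python file you're editing)
module_dir = tempfile.mkdtemp()
module_path = Path(module_dir) / "my_dist_functions.py"
module_path.write_text("def describe():\n    return 'first version'\n")

sys.path.insert(0, module_dir)
importlib.invalidate_caches()
import my_dist_functions

before = my_dist_functions.describe()

# "Edit and save" the file: the already-imported module object still holds the old function
module_path.write_text("def describe():\n    return 'second version, after editing'\n")
stale = my_dist_functions.describe()

# Reloading re-executes the file, so the module now holds the updated function
importlib.reload(my_dist_functions)
after = my_dist_functions.describe()

print(f"{before=}, {stale=}, {after=}")
```

One thing to watch out for: if you imported the function with `from my_module import my_function`, then you'll need to re-run that import after reloading, because the name in your notebook still points at the old function object.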

In [None]:
IN_COLAB = "google.colab" in sys.modules
assert not IN_COLAB, "Should be doing these exercises in VS Code"

## Basic send & receiving

The code below is a simplified example that demonstrates distributed communication between multiple processes.

At the highest level, `mp.spawn()` launches multiple worker processes, each with a unique rank. For each worker, we create a new Python interpreter (called a "child process") which will execute the function passed to `mp.spawn` (which in this case is `send_receive`). The function has to have the type signature `fn(rank, *args)` where `args` is the tuple we pass into `mp.spawn()`. The total number of processes is determined by `world_size`. Note that this isn't the same as the number of GPUs - in fact, in the code below we've not moved any of our data to GPUs, we're just using the distributed API to sync data across multiple processes. We'll introduce GPUs in the code below this!

We require the environment variables `MASTER_ADDR` and `MASTER_PORT` to be set before launching & communicating between processes. The former specifies the IP address or hostname of the machine that will act as the central coordinator (the "master" node) for setting up and managing the distributed environment, while the latter specifies the port number that the master node will use for communication. In our case we're running all our processes from a single node, so all we need is for this to be an unused port on our machine.
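If you'd rather not hardcode a port, one option is to ask the operating system for an unused one. This is a minimal sketch using only the standard library; `find_free_port` is our own helper (not part of `torch.distributed`), and there is a small chance that another program grabs the port between this check and your processes starting:

```python
import os
import socket


def find_free_port() -> int:
    """Asks the OS for a currently unused TCP port, by binding a temporary socket to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


# These need to be set in the parent process before spawning, so the child processes inherit them
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(find_free_port())
print(f"Using port {os.environ['MASTER_PORT']}")
```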

Now, breaking down the `send_receive` function line by line:

- `dist.init_process_group` initializes each process with a common address and port, and a communication backend. It also gives each process a unique rank, so they know who is sending & receiving data.
- If the function is being run by rank 0, then we create a tensor of zeros and send it using `dist.send`.
- If the function is being run by rank 1, then we create a tensor of ones and wait to receive a tensor from rank 0 using `dist.recv`. This will overwrite the data in the original tensor that we created, i.e. so we're just left with a tensor of zeros.
- `dist.destroy_process_group()` is called at the end of the function to destroy the process group and release resources.

The functions `dist.send` and `dist.recv` are the basic primitives for point-to-point communication between processes (we'll look at the primitives for collective communication later on). Each `recv` for a given source process `src` will wait until it receives a `send` from that source to continue, and likewise each `send` to a given destination process `dst` will wait until it receives a `recv` from that process to continue. We call these **blocking operations** (later on we'll look at non-blocking operations).

In [None]:
WORLD_SIZE = t.cuda.device_count()

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12345"


def send_receive(rank, world_size):
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

    if rank == 0:
        # Send tensor to rank 1
        sending_tensor = t.zeros(1)
        print(f"{rank=}, sending {sending_tensor=}")
        dist.send(tensor=sending_tensor, dst=1)
    elif rank == 1:
        # Receive tensor from rank 0
        received_tensor = t.ones(1)
        print(f"{rank=}, creating {received_tensor=}")
        dist.recv(received_tensor, src=0)  # this line overwrites the tensor's data with our `sending_tensor`
        print(f"{rank=}, received {received_tensor=}")

    dist.destroy_process_group()


if MAIN:
    world_size = 2  # simulate 2 processes
    mp.spawn(send_receive, args=(world_size,), nprocs=world_size, join=True)

Now, let's adapt this toy example to work with our multiple GPUs! You can check how many GPUs you have access to using `torch.cuda.device_count()`.

In [None]:
assert t.cuda.is_available()
assert t.cuda.device_count() > 1, "This example requires at least 2 GPUs per machine"

Before writing our new code, let's first return to the `backend` argument for `dist.init_process_group`. There are 3 main backends for distributed training: MPI, GLOO and NCCL. The first two are more general-purpose and support both CPU & GPU tensor communication, while NCCL is a GPU-only protocol optimized specifically for NVIDIA GPUs. It provides better bandwidth and lower latency for GPU-GPU communication, and so we'll be using it for subsequent exercises.

When sending & receiving tensors between GPUs with a NCCL backend, there are 3 important things to remember:

1. Send & received tensors should be of the same datatype.
2. Tensors need to be moved to the GPU before sending or receiving.
3. No two processes should be using the same GPU.

Because of this third point, each process `rank` will be using the GPU with index `rank` - hence we'll sometimes refer to the process rank and its corresponding GPU index interchangeably. However it's worth emphasizing that this only applies to our specific data parallelism & NCCL backend example, and so this correspondence doesn't have to exist in general.

The code below is a slightly modified version of the prior code; all we're doing is changing the backend to NCCL & moving the tensors to the appropriate device before sending or receiving.

Note - if at any point during this section you get errors related to the socket, then you can kill the processes by running `kill -9 <pid>` where `<pid>` is the process ID. If the process ID isn't given in the error message, you can find it using `lsof -i :<port>` where `<port>` is the port number specified in `os.environ["MASTER_PORT"]` (note you might have to `sudo apt-get install lsof` before you can run this). If your code is still failing, try changing the port in `os.environ["MASTER_PORT"]` and running it again.

<!-- Note - an alternative to explicitly defining the device here is to run the line `torch.cuda.set_device(rank)`, then code like `tensor.cuda()` will automatically send the tensor to the correct device. Which one you use is a matter of preference, however for the solutions & demo code we'll stick with the explicit device definition. -->

In [None]:
def send_receive_nccl(rank, world_size):
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    device = t.device(f"cuda:{rank}")

    if rank == 0:
        # Create a tensor, send it to rank 1
        sending_tensor = t.tensor([rank], device=device)
        print(f"{rank=}, {device=}, sending {sending_tensor=}")
        dist.send(sending_tensor, dst=1)  # the tensor is already on this rank's GPU, as NCCL requires
    elif rank == 1:
        # Receive tensor from rank 0 (the receiving tensor also needs to be on the GPU)
        received_tensor = t.tensor([rank], device=device)
        print(f"{rank=}, {device=}, creating {received_tensor=}")
        dist.recv(received_tensor, src=0)  # this line overwrites the tensor's data with our `sending_tensor`
        print(f"{rank=}, {device=}, received {received_tensor=}")

    dist.destroy_process_group()


if MAIN:
    world_size = 2  # simulate 2 processes
    mp.spawn(send_receive_nccl, args=(world_size,), nprocs=world_size, join=True)

## Collective communication primitives

We'll now move from basic point-to-point communication to **collective communication**. This refers to operations that synchronize data across multiple processes, rather than just between a single sender and receiver. There are 3 important kinds of collective communication functions:

- **Broadcast**: send a tensor from one process to all other processes
- **Gather**: collect tensors from all processes and concatenate them into a single tensor
- **Reduce**: like gather, but perform a reduction operation (e.g. sum, mean) rather than concatenation

The latter 2 functions have different variants depending on whether you want the final result to be in just a single destination process or in all of them: for example `dist.gather` will gather data to a single destination process, while `dist.all_gather` will make sure every process ends up with all the data.

The functions we're most interested in building are `broadcast` and `all_reduce` - the former for making sure all processes have the same initial model parameters, and the latter for aggregating gradients across all processes.
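Before implementing these with real inter-process communication, it can help to pin down what each operation should produce. The sketch below is a pure-Python simulation with no communication at all: `data[i]` plays the role of the tensor held by process `i`, and the `sim_*` helpers are our own names (not `torch.distributed` functions) which return what every process would hold afterwards.

```python
# `data[i]` plays the role of the tensor held by process i (so here world_size = 3)
data = [[0, 0], [1, 2], [10, 20]]


def sim_broadcast(data, src=0):
    # Every process ends up with a copy of the source process's tensor
    return [list(data[src]) for _ in data]


def sim_gather(data, dst=0):
    # The destination process ends up with all tensors concatenated, the others are unchanged
    gathered = [x for row in data for x in row]
    return [gathered if rank == dst else list(row) for rank, row in enumerate(data)]


def sim_reduce(data, dst=0):
    # The destination process ends up with the elementwise sum, the others are unchanged
    total = [sum(col) for col in zip(*data)]
    return [total if rank == dst else list(row) for rank, row in enumerate(data)]


def sim_all_reduce(data):
    # Reduce to process 0, then broadcast the result so every process has it
    return sim_broadcast(sim_reduce(data, dst=0), src=0)


print(f"{sim_broadcast(data)=}")
print(f"{sim_gather(data)=}")
print(f"{sim_reduce(data)=}")
print(f"{sim_all_reduce(data)=}")
```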

### Exercise - implement `broadcast`

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵🔵⚪⚪
>
> You should spend up to 10-20 minutes on this exercise.
> ```

Below, you should implement `broadcast`. If you have tensor $T_i$ on process $i$ for each index, then after running this function you should have $T_s$ on all processes, where $s$ is the source process. If you're confused, you can see exactly what is expected of you by reading the test code in `tests.py`. Again, remember that you should be running tests either from the command line or in the Python interactive terminal, not in a notebook cell.

In [None]:
def broadcast(tensor: Tensor, rank: int, world_size: int, src: int = 0):
    """
    Broadcast `tensor` from rank `src` to all other ranks (modifying `tensor` in place on the receiving ranks).
    """
    raise NotImplementedError()


if MAIN:
    tests.test_broadcast(broadcast, WORLD_SIZE)

<details><summary>Solution</summary>

```python
def broadcast(tensor: Tensor, rank: int, world_size: int, src: int = 0):
    """
    Broadcast `tensor` from rank `src` to all other ranks (modifying `tensor` in place on the receiving ranks).
    """
    if rank == src:
        for other_rank in range(world_size):
            if other_rank != src:
                dist.send(tensor, dst=other_rank)
    else:
        received_tensor = t.zeros_like(tensor)
        dist.recv(received_tensor, src=src)
        tensor.copy_(received_tensor)
```
</details>

### Exercise - implement `all_reduce`

> ```yaml
> Difficulty: 🔴🔴🔴⚪⚪
> Importance: 🔵🔵🔵⚪⚪
>
> You should spend up to 10-20 minutes on this exercise.
> ```

You should now implement `reduce` and `all_reduce`. The former will aggregate the tensors at some destination process (either sum or mean), and the latter will do the same but then broadcast the result to all processes.

Note, more complicated allreduce algorithms exist than this naive one, and you'll be able to look at some of them in the bonus material.
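To give a flavour of what a less naive algorithm looks like, here is a pure-Python simulation of ring all-reduce. This is our own choice of illustration (the bonus material may cover different algorithms), and it simulates the message passing with list operations rather than real communication. Each tensor is split into one chunk per process, and at every step each process sends a single chunk to its right-hand neighbour:

```python
def ring_all_reduce(data: list[list[float]]) -> list[list[float]]:
    """
    Simulates ring all-reduce (sum). `data[r]` is the tensor on rank r, which must have
    exactly `world_size` elements, so that each element acts as one chunk.
    """
    n = len(data)
    assert all(len(row) == n for row in data), "this sketch assumes one element per chunk"
    chunks = [list(row) for row in data]

    # Phase 1 (reduce-scatter): at each step, rank r sends chunk (r - step) % n to rank r + 1,
    # which adds it to its own copy. Afterwards, rank r holds the full sum for chunk (r + 1) % n.
    for step in range(n - 1):
        messages = [(r, (r - step) % n, chunks[r][(r - step) % n]) for r in range(n)]
        for r, c, value in messages:
            chunks[(r + 1) % n][c] += value

    # Phase 2 (all-gather): the completed chunks are passed around the ring, overwriting stale values
    for step in range(n - 1):
        messages = [(r, (r + 1 - step) % n, chunks[r][(r + 1 - step) % n]) for r in range(n)]
        for r, c, value in messages:
            chunks[(r + 1) % n][c] = value

    return chunks


ring_result = ring_all_reduce([[1, 2, 3], [10, 20, 30], [100, 200, 300]])
print(ring_result)
```

The appeal of this scheme is that no single process is a bottleneck: each process sends `2 * (n - 1)` chunks of size `1 / n` of the tensor, whereas in the naive version rank 0 has to receive a full tensor from every other process and then send a full tensor back to each of them.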

In [None]:
def reduce(tensor, rank, world_size, dst=0, op: Literal["sum", "mean"] = "sum"):
    """
    Reduces gradients to rank `dst`, so this process contains the sum or mean of all tensors across processes.
    """
    raise NotImplementedError()


def all_reduce(tensor, rank, world_size, op: Literal["sum", "mean"] = "sum"):
    """
    Allreduce the tensor across all ranks, using 0 as the initial gathering rank.
    """
    raise NotImplementedError()


if MAIN:
    tests.test_reduce(reduce, WORLD_SIZE)
    tests.test_all_reduce(all_reduce, WORLD_SIZE)

<details><summary>Solution</summary>

```python
def reduce(tensor, rank, world_size, dst=0, op: Literal["sum", "mean"] = "sum"):
    """
    Reduces gradients to rank `dst`, so this process contains the sum or mean of all tensors across processes.
    """
    if rank != dst:
        dist.send(tensor, dst=dst)
    else:
        for other_rank in range(world_size):
            if other_rank != dst:
                received_tensor = t.zeros_like(tensor)
                dist.recv(received_tensor, src=other_rank)
                tensor += received_tensor
    if op == "mean":
        tensor /= world_size


def all_reduce(tensor, rank, world_size, op: Literal["sum", "mean"] = "sum"):
    """
    Allreduce the tensor across all ranks, using 0 as the initial gathering rank.
    """
    reduce(tensor, rank, world_size, dst=0, op=op)
    broadcast(tensor, rank, world_size, src=0)
```
</details>

<!-- <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">Running reduce on dst=0, with initial tensors: [0, 0], [1, 2], [10, 20]
Rank 1, op='sum', expected non-reduced tensor([1., 2.]), got tensor([1., 2.])
Rank 1, op='mean', expected non-reduced tensor([0.3333, 0.6667]), got tensor([0.3333, 0.6667])
Rank 0, op='sum', expected reduced tensor([11., 22.]), got tensor([11., 22.])
Rank 2, op='sum', expected non-reduced tensor([10., 20.]), got tensor([10., 20.])
Rank 0, op='mean', expected reduced tensor([3.6667, 7.3333]), got tensor([3.6667, 7.3333])
Rank 2, op='mean', expected non-reduced tensor([3.3333, 6.6667]), got tensor([3.3333, 6.6667])
All tests in `test_reduce` passed!

Running all_reduce, with initial tensors: [0, 0], [1, 2], [10, 20]
Rank 1, op='sum', expected non-reduced tensor([11., 22.]), got tensor([11., 22.])
Rank 2, op='sum', expected non-reduced tensor([11., 22.]), got tensor([11., 22.])
Rank 0, op='sum', expected non-reduced tensor([11., 22.]), got tensor([11., 22.])
Rank 1, op='mean', expected non-reduced tensor([3.6667, 7.3333]), got tensor([3.6667, 7.3333])
Rank 2, op='mean', expected non-reduced tensor([3.6667, 7.3333]), got tensor([3.6667, 7.3333])
Rank 0, op='mean', expected non-reduced tensor([3.6667, 7.3333]), got tensor([3.6667, 7.3333])
All tests in `test_all_reduce` passed!</pre> -->

Once you've passed these tests, you can run the code below to see how this works for a toy example of model training. In this case our model just has a single parameter and we're performing gradient descent using the squared error between its parameters and the input data as our loss function (in other words we're training the model's parameters to equal the mean of the input data).

The data in the example below is the same as the rank index, i.e. `r = 0, 1`. For initial parameter `x = 2` this gives us errors of `(r-x).pow(2) = 4, 1` respectively, and gradients (with respect to `x`) of `2(x-r) = 4, 2`. Summing these (since `all_reduce` uses `op="sum"` by default) gives us a gradient of `6`, so after a single optimization step with learning rate `lr=0.1` we get our parameter changing to `2.0 - 0.6 = 1.4`.

In [None]:
class SimpleModel(t.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.param = t.nn.Parameter(t.tensor([2.0]))

    def forward(self, x: t.Tensor):
        return x - self.param


def run_simple_model(rank, world_size):
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    device = t.device(f"cuda:{rank}")
    model = SimpleModel().to(device)  # Move the model to the device corresponding to this process
    optimizer = t.optim.SGD(model.parameters(), lr=0.1)

    input = t.tensor([rank], dtype=t.float32, device=device)
    output = model(input)
    loss = output.pow(2).sum()
    loss.backward()  # Each rank has separate gradients at this point

    print(f"Rank {rank}, before all_reduce, grads: {model.param.grad=}")
    all_reduce(model.param.grad, rank, world_size)  # Synchronize gradients
    print(f"Rank {rank}, after all_reduce, synced grads (summed over processes): {model.param.grad=}")

    optimizer.step()  # Step with the optimizer (this will update all models the same way)
    print(f"Rank {rank}, new param: {model.param.data}")

    dist.destroy_process_group()


if MAIN:
    world_size = 2
    mp.spawn(run_simple_model, args=(world_size,), nprocs=world_size, join=True)

## Full training loop

We'll now use everything we've learned to put together a full training loop! Rather than finetuning a pretrained model as we've been doing so far, you'll be training your ResNet from scratch (although still using the same CIFAR10 dataset). We've given you a function `get_untrained_resnet` which uses the `ResNet34` class from yesterday's solutions, although you're encouraged to replace this function with your implementation if you've completed those exercises.

There are 4 key elements you'll need to change from the non-distributed version of training:

1. **Weight broadcasting at initialization**
    - For each process you'll need to initialize your model and move it onto the corresponding GPU, but you also want to make sure each process is working with the same model. You do this by **broadcasting weights in the `__init__` method**, e.g. using process 0 as the shared source process.
    - Note - you may find you'll have to broadcast `param.data` rather than `param` when you iterate through the model's parameters, because broadcasting only works for tensors, not parameters. Parameters are a special class wrapping around and extending standard PyTorch tensors - we'll look at this in more detail tomorrow!
2. **Dataloader sampling at each epoch**
    - Distributed training works by splitting each batch of data across all the running processes, and so we need to implement this by splitting each batch randomly across our GPUs.
    - Some sample code for this is given below - we recommend you start with this (although you're welcome to play around with some of the parameters here like `num_workers` and `pin_memory`).
3. **Gradient syncing at each training step**
    - After each `loss.backward()` call but before stepping with the optimizer, you'll need to use `all_reduce` to sync the gradient of each parameter in the model across processes.
    - Just like in the example we gave above, calling `all_reduce` on `param.grad` should work, because `.grad` is a standard PyTorch tensor.
4. **Aggregating correct predictions after each evaluation step**\*
    - We can also split the evaluation step across GPUs - we use `all_reduce` at the end of the `evaluate` method to sum the total number of correct predictions across GPUs.
    - This is optional, and often it's not implemented because the evaluation step isn't a bottleneck compared to training, however we've included it in our solutions for completeness.
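
To see why steps 1 and 3 keep every replica identical, here is a minimal pure-Python sketch that needs no GPUs. It assumes a one-parameter model like `SimpleModel` with loss `(x - param) ** 2`; the helper names and the starting value `2.0` are illustrative choices, not part of the course code.

```python
def grad(param: float, x: float) -> float:
    """Derivative of (x - param) ** 2 with respect to param."""
    return -2.0 * (x - param)


def simulated_all_reduce_mean(values: list[float]) -> list[float]:
    """Every rank ends up holding the mean of all ranks' values."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


world_size = 2
params = [2.0] * world_size  # identical weights, as after broadcasting from rank 0
inputs = [float(rank) for rank in range(world_size)]  # each rank sees different data

local_grads = [grad(p, x) for p, x in zip(params, inputs)]  # these differ per rank
synced_grads = simulated_all_reduce_mean(local_grads)  # these are identical on every rank

lr = 0.1
new_params = [p - lr * g for p, g in zip(params, synced_grads)]  # plain SGD step on each rank

# Gradient of the mean loss over the full batch, computed in a single process
full_batch_grad = sum(grad(2.0, x) for x in inputs) / len(inputs)
```

Because every rank starts from the same weights and applies the same averaged gradient, the replicas stay in sync, and the averaged gradient matches what a single process would compute on the combined batch (for equal-sized shards and a mean-reduced loss).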

<details>
<summary>Dataloader sampling example code</summary>

```python
self.train_sampler = t.utils.data.DistributedSampler(
    self.trainset,
    num_replicas=self.args.world_size, # we'll divide each batch up into this many random sub-batches
    rank=self.rank, # this determines which sub-batch this process gets
)
self.train_loader = t.utils.data.DataLoader(
    self.trainset,
    self.args.batch_size, # this is the sub-batch size, i.e. the batch size that each GPU gets
    sampler=self.train_sampler,
    num_workers=2,  # setting this low so as not to risk bottlenecking CPU resources
    pin_memory=True,  # this can improve data transfer speed between CPU and GPU
)

for epoch in range(self.args.epochs):
    self.train_sampler.set_epoch(epoch)  # reshuffle so each epoch uses a different split
    for imgs, labels in self.train_loader:
        ...
```

</details>
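
If it helps to picture what a distributed sampler does, the following is a rough pure-Python sketch of the idea: every rank builds the same epoch-seeded shuffle, then takes a strided slice of it. The function `shard_indices` is hypothetical, and it will not reproduce PyTorch's actual random numbers or every option of `DistributedSampler`.

```python
import math
import random


def shard_indices(dataset_len: int, num_replicas: int, rank: int, epoch: int, seed: int = 0) -> list[int]:
    """Hypothetical stand-in for the index selection inside a distributed sampler."""
    indices = list(range(dataset_len))
    random.Random(seed + epoch).shuffle(indices)  # same permutation on every rank

    # Pad by repeating early indices so every rank gets the same number of samples
    per_rank = math.ceil(dataset_len / num_replicas)
    total = per_rank * num_replicas
    indices += indices[: total - dataset_len]

    # Strided slice: rank r takes items r, r + N, r + 2N, ...
    return indices[rank:total:num_replicas]


shards = [shard_indices(10, num_replicas=2, rank=r, epoch=0) for r in range(2)]
padded_shards = [shard_indices(10, num_replicas=3, rank=r, epoch=0) for r in range(3)]
```

Seeding with `seed + epoch` is why the epoch needs to be passed in each time: without it, every epoch would reuse the same split.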

### Exercise - complete `DistResNetTrainer`

> ```yaml
> Difficulty: 🔴🔴🔴🔴🔴
> Importance: 🔵🔵🔵⚪⚪
>
> You should spend up to 30-60 minutes on this exercise.
> If you get stuck on specific bits, you're encouraged to look at the solutions for guidance.
> ```

We've given you the function `dist_train_resnet_from_scratch` which you'll be able to pass into `mp.spawn` just like the examples above, and we've given you a very light template for the `DistResNetTrainer` class which you should fill in. Your job is just to make the 4 adjustments described above. We recommend not using inheritance for this, because there are lots of minor modifications you'll need to make to the previous code and so you won't be reducing code duplication by very much.

A few last tips before we get started:

- If your code is running slowly, we recommend you also `wandb.log` the duration of each stage of the training step from the rank 0 process (fwd pass, bwd pass, and `all_reduce` for parameter syncing), as well as logging the duration of the training & evaluation phases across the epoch. These kinds of logs are generally very helpful for debugging slow code.
- Since running this code won't directly return your model as output, it's good practice to save your model at the end of training using `torch.save`.
- We recommend you increment `examples_seen` by the total number of examples across processes, i.e. `len(input) * world_size`. This will help when you're comparing across different runs with different world sizes (it's convenient for them to have a consistent x-axis).
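
As one possible way to collect the per-stage timings mentioned above, here is a small sketch using a context manager. The `timed` helper and the dummy `time.sleep` stages are assumptions for illustration; in your trainer the stages would be the real forward pass, backward pass and `all_reduce` calls.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(durations: dict[str, float], key: str):
    """Record the wall-clock duration of the enclosed block under `key`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        durations[key] = time.perf_counter() - start


durations: dict[str, float] = {}
examples_seen = 0
batch_size, world_size = 64, 2  # per-GPU batch size and number of processes (example values)

for _ in range(3):  # three dummy training steps
    with timed(durations, "fwd_time"):
        time.sleep(0.001)  # stand-in for the forward pass
    with timed(durations, "bwd_time"):
        time.sleep(0.001)  # stand-in for the backward pass
    examples_seen += batch_size * world_size  # count examples across all processes
    # On rank 0 you might then call: wandb.log(durations, step=examples_seen)
```

One caveat: CUDA kernels are launched asynchronously, so on a GPU you may want to call `t.cuda.synchronize()` before reading the clock, otherwise the time can be attributed to a later stage.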

def get_untrained_resnet(n_classes: int) -> ResNet34:
    """Gets untrained resnet using code from part2_cnns.solutions (you can replace this with your implementation)."""
    resnet = ResNet34()
    resnet.out_layers[-1] = Linear(resnet.out_features_per_group[-1], n_classes)
    return resnet


@dataclass
class DistResNetTrainingArgs(WandbResNetFinetuningArgs):
    world_size: int = 1
    wandb_project: str | None = "day3-resnet-dist-training"


class DistResNetTrainer:
    args: DistResNetTrainingArgs

    def __init__(self, args: DistResNetTrainingArgs, rank: int):
        self.args = args
        self.rank = rank
        self.device = t.device(f"cuda:{rank}")

    def pre_training_setup(self):
        raise NotImplementedError()

    def training_step(self, imgs: Tensor, labels: Tensor) -> Tensor:
        raise NotImplementedError()

    @t.inference_mode()
    def evaluate(self) -> float:
        raise NotImplementedError()

    def train(self):
        raise NotImplementedError()


def dist_train_resnet_from_scratch(rank, world_size):
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    args = DistResNetTrainingArgs(world_size=world_size)
    trainer = DistResNetTrainer(args, rank)
    trainer.train()
    dist.destroy_process_group()


if MAIN:
    world_size = t.cuda.device_count()
    mp.spawn(dist_train_resnet_from_scratch, args=(world_size,), nprocs=world_size, join=True)

<details><summary>Solution</summary>

```python
def get_untrained_resnet(n_classes: int) -> ResNet34:
    """Gets untrained resnet using code from part2_cnns.solutions (you can replace this with your implementation)."""
    resnet = ResNet34()
    resnet.out_layers[-1] = Linear(resnet.out_features_per_group[-1], n_classes)
    return resnet


@dataclass
class DistResNetTrainingArgs(WandbResNetFinetuningArgs):
    world_size: int = 1
    wandb_project: str | None = "day3-resnet-dist-training"


class DistResNetTrainer:
    args: DistResNetTrainingArgs

    def __init__(self, args: DistResNetTrainingArgs, rank: int):
        self.args = args
        self.rank = rank
        self.device = t.device(f"cuda:{rank}")

    def pre_training_setup(self):
        self.model = get_untrained_resnet(self.args.n_classes).to(self.device)
        if self.args.world_size > 1:
            for param in self.model.parameters():
                broadcast(param.data, self.rank, self.args.world_size, src=0)
                # dist.broadcast(param.data, src=0)

        self.optimizer = t.optim.AdamW(
            self.model.parameters(), lr=self.args.learning_rate, weight_decay=self.args.weight_decay
        )

        self.trainset, self.testset = get_cifar()
        self.train_sampler = self.test_sampler = None
        if self.args.world_size > 1:
            self.train_sampler = DistributedSampler(self.trainset, num_replicas=self.args.world_size, rank=self.rank)
            self.test_sampler = DistributedSampler(self.testset, num_replicas=self.args.world_size, rank=self.rank)
        dataloader_shared_kwargs = dict(batch_size=self.args.batch_size, num_workers=2, pin_memory=True)
        self.train_loader = DataLoader(self.trainset, sampler=self.train_sampler, **dataloader_shared_kwargs)
        self.test_loader = DataLoader(self.testset, sampler=self.test_sampler, **dataloader_shared_kwargs)
        self.examples_seen = 0

        if self.rank == 0:
            wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)

    def training_step(self, imgs: Tensor, labels: Tensor) -> Tensor:
        t0 = time.time()

        # Forward pass
        imgs, labels = imgs.to(self.device), labels.to(self.device)
        logits = self.model(imgs)
        t1 = time.time()

        # Backward pass
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        t2 = time.time()

        # Gradient synchronization
        if self.args.world_size > 1:
            for param in self.model.parameters():
                all_reduce(param.grad, self.rank, self.args.world_size, op="mean")
                # dist.all_reduce(param.grad, op=dist.ReduceOp.SUM); param.grad /= self.args.world_size
        t3 = time.time()

        # Optimizer step, update examples seen & log data
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.examples_seen += imgs.shape[0] * self.args.world_size
        if self.rank == 0:
            wandb.log(
                {"loss": loss.item(), "fwd_time": (t1 - t0), "bwd_time": (t2 - t1), "dist_time": (t3 - t2)},
                step=self.examples_seen,
            )
        return loss

    @t.inference_mode()
    def evaluate(self) -> float:
        self.model.eval()
        total_correct, total_samples = 0, 0

        for imgs, labels in tqdm(self.test_loader, desc="Evaluating", disable=self.rank != 0):
            imgs, labels = imgs.to(self.device), labels.to(self.device)
            logits = self.model(imgs)
            total_correct += (logits.argmax(dim=1) == labels).sum().item()
            total_samples += len(imgs)

        # Turn total_correct & total_samples into a tensor, so we can use all_reduce to sum them across processes
        tensor = t.tensor([total_correct, total_samples], device=self.device)
        all_reduce(tensor, self.rank, self.args.world_size, op="sum")
        total_correct, total_samples = tensor.tolist()

        accuracy = total_correct / total_samples
        if self.rank == 0:
            wandb.log({"accuracy": accuracy}, step=self.examples_seen)
        return accuracy

    def train(self):
        self.pre_training_setup()

        accuracy = self.evaluate()  # baseline accuracy before any training

        for epoch in range(self.args.epochs):
            t0 = time.time()

            if self.args.world_size > 1:
                self.train_sampler.set_epoch(epoch)
                self.test_sampler.set_epoch(epoch)

            self.model.train()

            pbar = tqdm(self.train_loader, desc="Training", disable=self.rank != 0)
            for imgs, labels in pbar:
                loss = self.training_step(imgs, labels)
                pbar.set_postfix(loss=f"{loss:.3f}", ex_seen=f"{self.examples_seen=:06}")

            accuracy = self.evaluate()

            if self.rank == 0:
                wandb.log({"epoch_duration": time.time() - t0}, step=self.examples_seen)
                pbar.set_postfix(loss=f"{loss:.3f}", accuracy=f"{accuracy:.3f}", ex_seen=f"{self.examples_seen=:06}")

        if self.rank == 0:
            wandb.finish()
            t.save(self.model.state_dict(), f"resnet_{self.rank}.pth")


def dist_train_resnet_from_scratch(rank, world_size):
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    args = DistResNetTrainingArgs(world_size=world_size)
    trainer = DistResNetTrainer(args, rank)
    trainer.train()
    dist.destroy_process_group()
```
</details>

## Bonus - DDP

In practice, the most convenient way to use DDP is to wrap your model in `torch.nn.parallel.DistributedDataParallel`, which removes the need for explicitly calling `broadcast` at the start and `all_reduce` at the end of each training step. When you define a model in this way, it will automatically broadcast its weights to all processes, and the gradients will sync after each `loss.backward()` call. Here's the example `SimpleModel` code from above, rewritten to use these features:

from torch.nn.parallel import DistributedDataParallel as DDP


def run(rank: int, world_size: int):
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

    device = t.device(f"cuda:{rank}")
    model = DDP(SimpleModel().to(device), device_ids=[rank])  # Wrap the model with DDP
    optimizer = t.optim.SGD(model.parameters(), lr=0.1)

    input = t.tensor([rank], dtype=t.float32, device=device)
    output = model(input)
    loss = output.pow(2).sum()
    loss.backward()  # DDP handles gradient synchronization

    optimizer.step()
    print(f"Rank {rank}, new param: {model.module.param.data}")

    dist.destroy_process_group()


if MAIN:
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)

Can you use these features to rewrite your ResNet training code? Can you compare it to the code you wrote and see how much faster the built-in DDP version is? Note, you won't be able to separate the time taken for backward passes and gradient synchronization since these happen in the same line, but you can assume that the time taken for the backward pass is approximately unchanged and so any speedup you see is due to the better gradient synchronization.

## Bonus - ring operations

Our all-reduce operation would scale quite badly with a large number of processes. It chooses a single process as the source process to receive and then send out all data, so this process risks becoming a bottleneck. One of the most popular alternatives is **ring all-reduce**. Broadly speaking, ring-based algorithms work by sending data in a cyclic pattern (i.e. worker `n` sends to worker `(n+1) % N`, where `N` is the total number of workers). After each sending round, we perform a reduction operation on the data that was just sent. [This blog post](https://andrew.gibiansky.com/blog/machine-learning/baidu-allreduce/) illustrates the ring all-reduce algorithm for the sum operation.
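
To make the bottleneck argument concrete, here is a back-of-the-envelope sketch of the data handled by the busiest process under each scheme. It assumes the naive version has the source process receive one full tensor from, and send one full tensor to, each of the other `N - 1` processes, and that the ring version splits the tensor into `N` equal chunks as in the blog post. The function names are illustrative.

```python
def naive_busiest_traffic(n_workers: int, tensor_bytes: int) -> float:
    """Bytes sent plus received by the source process: N-1 tensors in, N-1 tensors out."""
    return 2 * (n_workers - 1) * tensor_bytes


def ring_busiest_traffic(n_workers: int, tensor_bytes: int) -> float:
    """Bytes sent plus received by any worker: 2(N-1) rounds, one chunk of size S/N each way."""
    chunk_bytes = tensor_bytes / n_workers
    return 2 * 2 * (n_workers - 1) * chunk_bytes


traffic = {n: (naive_busiest_traffic(n, 100), ring_busiest_traffic(n, 100)) for n in (2, 8, 64)}
```

Under these assumptions the two schemes move the same amount of data through the busiest process when `N = 2`, while the naive cost grows linearly in `N` and the ring cost stays below `4 * tensor_bytes`.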

Can you implement the ring all-reduce algorithm by filling in the function below & passing tests? Once you've implemented it, you can compare the speed of your ring all-reduce vs the all-reduce we implemented earlier - is it faster? Do you expect it to be faster in this particular case?

def ring_all_reduce(tensor: Tensor, rank, world_size, op: Literal["sum", "mean"] = "sum") -> None:
    """
    Ring all_reduce implementation using non-blocking send/recv to avoid deadlock.
    """
    raise NotImplementedError()


if MAIN:
    tests.test_all_reduce(ring_all_reduce)

<details>
<summary>Solution</summary>

```python
def ring_all_reduce(tensor: Tensor, rank, world_size, op: Literal["sum", "mean"] = "sum") -> None:
    """
    Ring all_reduce implementation using non-blocking send/recv to avoid deadlock.
    """
    # Clone the tensor as the "send_chunk" for initial accumulation
    send_chunk = tensor.clone()

    # Step 1: Reduce-Scatter phase
    for _ in range(world_size - 1):
        # Compute the ranks involved in this round of sending/receiving
        send_to = (rank + 1) % world_size
        recv_from = (rank - 1 + world_size) % world_size

        # Prepare a buffer for the received chunk
        recv_chunk = t.zeros_like(send_chunk)

        # Non-blocking send and receive
        send_req = dist.isend(send_chunk, dst=send_to)
        recv_req = dist.irecv(recv_chunk, src=recv_from)
        send_req.wait()
        recv_req.wait()

        # Accumulate the received chunk into the tensor
        tensor += recv_chunk

        # Update send_chunk for the next iteration
        send_chunk = recv_chunk

    # Step 2: All-Gather phase
    send_chunk = tensor.clone()
    for _ in range(world_size - 1):
        # Compute the ranks involved in this round of sending/receiving
        send_to = (rank + 1) % world_size
        recv_from = (rank - 1 + world_size) % world_size

        # Prepare a buffer for the received chunk
        recv_chunk = t.zeros_like(send_chunk)

        # Non-blocking send and receive, and wait for completion
        send_req = dist.isend(send_chunk, dst=send_to)
        recv_req = dist.irecv(recv_chunk, src=recv_from)
        send_req.wait()
        recv_req.wait()

        # Update the tensor with received data
        tensor.copy_(recv_chunk)

        # Update send_chunk for the next iteration
        send_chunk = recv_chunk

    # Step 3: Average the final result
    if op == "mean":
        tensor /= world_size
```

We should expect this algorithm to be better when we scale up the number of GPUs, but it won't always be faster in small-world settings like ours: the naive all-reduce algorithm requires fewer individual communication steps, and this could outweigh the benefits of the ring-based all-reduce.

</details>
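
If you'd like to check your understanding of the chunked variant described in the blog post without any GPUs, the sketch below simulates it in plain Python. It is a toy model under simplifying assumptions: workers are entries in a list, each chunk is a single number (so the vector length must equal the number of workers), and "sending" is just copying between lists.

```python
def simulate_ring_all_reduce(data: list[list[float]]) -> list[list[float]]:
    """Simulate chunked ring all-reduce (sum) over `data[rank]`, one vector per worker."""
    n = len(data)
    assert all(len(vec) == n for vec in data), "this sketch uses one element per chunk"
    bufs = [list(vec) for vec in data]  # copy so the inputs are left untouched

    # Phase 1: reduce-scatter. In round s, worker r sends chunk (r - s) % n to
    # worker (r + 1) % n, which adds it to its own copy of that chunk.
    for s in range(n - 1):
        # Build all messages first, so the sends in a round happen "simultaneously"
        messages = [(r, (r - s) % n, bufs[r][(r - s) % n]) for r in range(n)]
        for r, chunk, value in messages:
            bufs[(r + 1) % n][chunk] += value
    # Worker r now holds the fully reduced chunk (r + 1) % n.

    # Phase 2: all-gather. In round s, worker r forwards chunk (r + 1 - s) % n,
    # and the receiver overwrites its stale copy with the reduced value.
    for s in range(n - 1):
        messages = [(r, (r + 1 - s) % n, bufs[r][(r + 1 - s) % n]) for r in range(n)]
        for r, chunk, value in messages:
            bufs[(r + 1) % n][chunk] = value

    return bufs


workers = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0], [100.0, 200.0, 300.0]]
result = simulate_ring_all_reduce(workers)
```

Each worker only ever sends one chunk per round, which is where the bandwidth saving over sending whole tensors comes from.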

# ☆ Bonus

Congratulations on getting to the end of the main content! This section gives some suggestions for more features of Weights and Biases to explore, or some other experiments you can run.

## Scaling Laws

These bonus exercises are taken directly from Jacob Hilton's [online deep learning curriculum](https://github.com/jacobhilton/deep_learning_curriculum/blob/master/2-Scaling-Laws.md) (which is what the original version of the ARENA course was based on).

First, you can start by reading the [Chinchilla paper](https://arxiv.org/abs/2203.15556). This is a correction to the original scaling laws paper: parameter count scales linearly with token budget for compute-optimal models, not ~quadratically. The difference comes from using a separately-tuned learning rate schedule for each token budget, rather than using a single training run to measure performance for every token budget. This highlights the importance of hyperparameter tuning for measuring scaling law exponents.

You don't have to read the entire paper, just skim the graphs. Don't worry if they don't all make sense yet (it will be more illuminating when we study LLMs next week). Note that, although it specifically applies to language models, the key underlying ideas of tradeoffs between optimal dataset size and model size are generally applicable.

### Suggested exercise

Perform your own study of scaling laws for MNIST.

- Write a script to train a small CNN on MNIST, or find one you have written previously.
- Training for a single epoch only, vary the model size and dataset size. For the model size, multiply the width by powers of sqrt(2) (rounding if necessary - the idea is to vary the amount of compute used per forward pass by powers of 2). For the dataset size, multiply the fraction of the full dataset used by powers of 2 (i.e. 1, 1/2, 1/4, ...). To reduce noise, use a few random seeds and always use the full validation set.
- The learning rate will need to vary with model size. Either tune it carefully for each model size, or use the rule of thumb that for Adam, the learning rate should be proportional to the initialization scale, i.e. `1/sqrt(fan_in)` for the standard Kaiming He initialization (which is what PyTorch generally uses by default).
    - Note - `fan_in` refers to the variable $N_{in}$, which is `in_features` for a linear layer, and `in_channels * kernel_size * kernel_size` for a convolutional layer - in other words, the number of input parameters/activations we take a sumproduct over to get each output activation.
- Plot the amount of compute used (on a log scale) against validation loss. The compute-efficient frontier should follow an approximate power law (a straight line when both axes are on a log scale). How does validation accuracy behave?
- Study how the compute-efficient model size varies with compute. This should also follow an approximate power law. Try to estimate its exponent.
- Repeat your entire experiment with 20% [dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html) to see how this affects the scaling exponents.
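
As a possible starting point for the bookkeeping in this exercise, the sketch below generates a width schedule, applies the `1/sqrt(fan_in)` learning rate rule of thumb, and fits a power law by linear regression in log-log space. All function names and the base values (`base_width=16`, `base_lr=1e-3`, `base_fan_in=64`) are assumptions, and the loss values are synthetic rather than real results.

```python
import math


def width_schedule(base_width: int, n_sizes: int) -> list[int]:
    """Widths growing by a factor of sqrt(2) per step, rounded to the nearest integer."""
    return [round(base_width * math.sqrt(2) ** i) for i in range(n_sizes)]


def adam_lr(fan_in: int, base_lr: float = 1e-3, base_fan_in: int = 64) -> float:
    """Rule-of-thumb learning rate proportional to 1/sqrt(fan_in)."""
    return base_lr * math.sqrt(base_fan_in / fan_in)


def fit_power_law(xs: list[float], ys: list[float]) -> tuple[float, float]:
    """Least-squares fit of log(y) = log(a) + b * log(x); returns (a, b)."""
    lx, ly = [math.log(x) for x in xs], [math.log(y) for y in ys]
    mx, my = sum(lx) / len(lx), sum(ly) / len(ly)
    b = sum((u - mx) * (v - my) for u, v in zip(lx, ly)) / sum((u - mx) ** 2 for u in lx)
    return math.exp(my - b * mx), b


widths = width_schedule(base_width=16, n_sizes=5)
fractions = [2.0 ** -k for k in range(4)]  # 1, 1/2, 1/4, 1/8 of the dataset

# Synthetic frontier points generated from loss = 3 * compute ** -0.25 (not real data)
compute = [1e3, 1e4, 1e5, 1e6]
loss = [3.0 * c ** -0.25 for c in compute]
a, b = fit_power_law(compute, loss)  # b is the estimated scaling exponent
```

With real runs you would pass only the points on the compute-efficient frontier to `fit_power_law`, and expect a noisier fit than this synthetic example gives.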

## Other WandB features

Here are a few more Weights & Biases features you might also want to play around with:

* [Logging media and objects in experiments](https://docs.wandb.ai/guides/track/log?fbclid=IwAR3NxKsGpEjZwq3vSwYkohZllMpBwxHgOCc_k0ByuD9XGUsi_Scf5ELvGsQ) - you'll be doing this during the RL week, and it's useful when you're training generative image models like VAEs and diffusion models.
* [Code saving](https://docs.wandb.ai/guides/app/features/panels/code?fbclid=IwAR2BkaXbRf7cqEH8kc1VzqH_kOJWGxqjUb_JCBq_SCnXOx1oF-Rt-hHydb4) - this captures all Python source code files in the current directory and all subdirectories. It's great for reproducibility, and also for sharing your code with others.
* [Saving and loading PyTorch models](https://wandb.ai/wandb/common-ml-errors/reports/How-to-Save-and-Load-Models-in-PyTorch--VmlldzozMjg0MTE?fbclid=IwAR1Y9MzFTxIiVBJG06b4ppitwKWR4H5_ncKyT2F_rR5Z_IHawmpBTKskPcQ) - you can do this easily using `torch.save`, but it's also possible to do this directly through Weights and Biases as an **artifact**.

## The Optimizer's Curse

The [optimizer's curse](https://www.lesswrong.com/posts/5gQLrJr2yhPzMCcni/the-optimizer-s-curse-and-how-to-beat-it) applies to tuning hyperparameters. The main take-aways are:

- You can expect your best hyperparameter combination to actually underperform in the future. You chose it because it was the best on some metric, but that metric has an element of noise/luck, and the more combinations you test the larger this effect is.
- Look at the overall trends and correlations in context and try to make sense of the values you're seeing. Just because you ran a long search process doesn't mean your best output is really the best.
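
A small simulation can make the first point tangible. The sketch below assumes a toy setting in which every hyperparameter combination has exactly the same true accuracy and each evaluation adds Gaussian noise; the numbers `0.80` and `0.02` are made up for illustration.

```python
import random

rng = random.Random(0)
TRUE_ACCURACY = 0.80  # every configuration is equally good in this toy setup
NOISE_STD = 0.02  # evaluation noise (assumed)


def best_observed(n_configs: int) -> float:
    """Best validation score among n equally good configs, each measured with noise."""
    return max(rng.gauss(TRUE_ACCURACY, NOISE_STD) for _ in range(n_configs))


def mean_best(n_configs: int, n_trials: int = 2000) -> float:
    """Average best-observed score over many repeated searches."""
    return sum(best_observed(n_configs) for _ in range(n_trials)) / n_trials


results = {n: mean_best(n) for n in (1, 10, 100)}
```

Even though no configuration is truly better than `0.80`, the average winning score climbs as the search gets larger, so the winner's measured score overstates how it will perform when re-evaluated.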

For more on this, see [Preventing "Overfitting" of Cross-Validation Data](https://ai.stanford.edu/~ang/papers/cv-final.pdf) by Andrew Ng.