# [0.4] Optimization & Hyperparameters

Colab: [exercises](https://colab.research.google.com/drive/12PiedkLJH9GkIu5F_QC7iFf8gJb3k8mp) | [solutions](https://colab.research.google.com/drive/1yKhqfOJUTcdu0zhuRONjXLCmoLx7frxP)

[Streamlit page](https://arena-ch0-fundamentals.streamlit.app/[0.4]_Optimization)

Please send any problems / bugs on the `#errata` channel in the [Slack group](https://join.slack.com/t/arena-la82367/shared_invite/zt-1uvoagohe-JUv9xB7Vr143pdx1UBPrzQ), and ask any questions on the dedicated channels for this chapter of material.


<img src="https://raw.githubusercontent.com/callummcdougall/Fundamentals/main/images/stats.png" width="350">


# Introduction


In today's exercises, we will explore various optimization algorithms and their roles in training deep learning models. We will delve into the inner workings of different optimization techniques such as Stochastic Gradient Descent (SGD), RMSprop, and Adam, and learn how to implement them using code. Additionally, we will discuss the concept of loss landscapes and their significance in visualizing the challenges faced during the optimization process. By the end of this set of exercises, you will have a solid understanding of optimization algorithms and their impact on model performance. We'll also take a look at Weights and Biases, a tool that can be used to track and visualize the training process, and test different values of hyperparameters to find the most effective ones.


## Content & Learning Objectives


### 2️⃣ Weights and Biases

In this section, we'll look at methods for choosing hyperparameters effectively. You'll learn how to use **Weights and Biases**, a useful tool for hyperparameter search. By the end of today, you should be able to use Weights and Biases to train the ResNet you created in the last set of exercises.

> ##### Learning Objectives
>
> * Learn what the most important hyperparameters are, and methods for efficiently searching over hyperparameter space
> * Adapt your code from yesterday to log training runs to Weights & Biases, and use this service to run hyperparameter sweeps

### 3️⃣ Bonus

This section gives you suggestions for further exploration of optimizers, and Weights & Biases.


## Setup (don't read, just run!)


In [1]:
try:
    import google.colab # type: ignore
    IN_COLAB = True
except:
    IN_COLAB = False

if IN_COLAB:
    # Install packages
    %pip install torchinfo
    %pip install transformer_lens

    # Code to download the necessary files (e.g. solutions, test funcs)
    import os, sys
    if not os.path.exists("chapter0_fundamentals"):
        !wget https://github.com/callummcdougall/ARENA_2.0/archive/refs/heads/main.zip
        !unzip /content/main.zip 'ARENA_2.0-main/chapter0_fundamentals/exercises/*'
        sys.path.append("/content/ARENA_2.0-main/chapter0_fundamentals/exercises")
        os.remove("/content/main.zip")
        os.rename("ARENA_2.0-main/chapter0_fundamentals", "chapter0_fundamentals")
        os.rmdir("ARENA_2.0-main")
        os.chdir("chapter0_fundamentals/exercises")
else:
    from IPython import get_ipython
    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Things that need to be done manually

# Nothing

# END

In [2]:
import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys
import pandas as pd
import torch as t
from torch import Tensor, optim
import torch.nn.functional as F
from torchvision import datasets
from torch.utils.data import DataLoader, Subset
from typing import Callable, Iterable, Tuple, Optional, Type
from jaxtyping import Float
from dataclasses import dataclass
from tqdm.notebook import tqdm
from pathlib import Path
import numpy as np
from IPython.display import display, HTML

# Make sure exercises are in the path
chapter = r"chapter0_fundamentals"
if IN_COLAB:
    exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
else:
    exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/utils/{chapter}/exercises").resolve()
section_dir = exercises_dir / "part4_optimization"
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))
os.chdir(section_dir)

from plotly_utils import bar, imshow, plot_train_loss_and_test_accuracy_from_trainer
from part3_resnets.solutions import IMAGENET_TRANSFORM, ResNet34, get_resnet_for_feature_extraction
from part4_optimization.utils import plot_fn, plot_fn_with_points
import part4_optimization.tests as tests

device = t.device("cuda" if t.cuda.is_available() else "cpu")

MAIN = __name__ == "__main__"



<details>
<summary>Help - I get a NumPy-related error</summary>

This is an annoying colab-related issue which I haven't been able to find a satisfying fix for. If you restart runtime (but don't delete runtime), and run just the imports cell above again (but not the `%pip install` cell), the problem should go away.
</details>


# 2️⃣ Weights and Biases


> ##### Learning Objectives
>
> * Learn what the most important hyperparameters are, and methods for efficiently searching over hyperparameter space
> * Adapt your code from yesterday to log training runs to Weights & Biases, and use this service to run hyperparameter sweeps


Next, we'll look at methods for choosing hyperparameters effectively. You'll learn how to use **Weights and Biases**, a useful tool for hyperparameter search.

The exercises themselves will be based on your ResNet implementations from yesterday, although the principles should carry over to other models you'll build in this course (such as transformers next week).

Note, this page only contains a few exercises, and they're relatively short. You're encouraged to spend some time playing around with Weights and Biases, but you should also spend some more time finetuning your ResNet from yesterday (you might want to finetune ResNet during the morning, and look at today's material in the afternoon - you can discuss this with your partner). You should also spend some time reviewing the last three days of material, to make sure there are no large gaps in your understanding.


## CIFAR10

The benchmark we'll be training on is [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html), which consists of 60000 32x32 colour images in 10 different classes. Don't peek at what other people online have done for CIFAR10 (it's a common benchmark), because the point is to develop your own process by which you can figure out how to improve your model. Just reading the results of someone else would prevent you from learning how to get the answers. To get an idea of what's possible: using one V100 and a modified ResNet, one entry in the DAWNBench competition was able to achieve 94% test accuracy in 24 epochs and 76 seconds. 94% is approximately [human level performance](http://karpathy.github.io/2011/04/27/manually-classifying-cifar10/).

Below is some boilerplate code for downloading and transforming `CIFAR10` data (this shouldn't take more than a minute to run the first time). There are a few differences between this and our code from yesterday - for instance, we omit the `transforms.Resize` from our `transform` object, because CIFAR10 data is already the correct size (unlike the sample images from last week).


In [1]:
def get_cifar(subset: int = 1):
    cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=IMAGENET_TRANSFORM)
    cifar_testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=IMAGENET_TRANSFORM)
    if subset > 1:
        cifar_trainset = Subset(cifar_trainset, indices=range(0, len(cifar_trainset), subset))
        cifar_testset = Subset(cifar_testset, indices=range(0, len(cifar_testset), subset))
    return cifar_trainset, cifar_testset

cifar_trainset, cifar_testset = get_cifar()

imshow(
    cifar_trainset.data[:15],
    facet_col=0,
    facet_col_wrap=5,
    facet_labels=[cifar_trainset.classes[i] for i in cifar_trainset.targets[:15]],
    title="CIFAR-10 images",
    height=600
)


We have also provided a basic training & testing loop, almost identical to the one you used yesterday. This one doesn't use `wandb` at all, although it does plot the train loss and test accuracy when the function finishes running. You should run this function to verify your model is working, and that the loss is going down. Also, make sure you understand what each part of this function is doing.


## Train function - simple (from yesterday)


First, let's take our dataclass from yesterday, to hold our training arguments.

In [None]:
@dataclass
class ResNetTrainingArgs():
    batch_size: int = 64
    epochs: int = 3
    optimizer: Type[t.optim.Optimizer] = t.optim.Adam
    learning_rate: float = 1e-3
    n_classes: int = 10
    subset: int = 10

Next, here's some code to train our model (taken from yesterday's CNN solutions). You should feel free to replace it with your own implementation, if you prefer.


In [None]:
class ResNetTrainer:
	def __init__(self, args: ResNetTrainingArgs):
		self.args = args
		self.model = get_resnet_for_feature_extraction(args.n_classes).to(device)
		self.optimizer = args.optimizer(self.model.out_layers[-1].parameters(), lr=args.learning_rate)
		self.trainset, self.testset = get_cifar(subset=args.subset)
		self.logged_variables = {"loss": [], "accuracy": []}

	def _shared_train_val_step(self, imgs: Tensor, labels: Tensor) -> Tuple[Tensor, Tensor]:
		imgs = imgs.to(device)
		labels = labels.to(device)
		logits = self.model(imgs)
		return logits, labels

	def training_step(self, imgs: Tensor, labels: Tensor) -> t.Tensor:
		logits, labels = self._shared_train_val_step(imgs, labels)
		loss = F.cross_entropy(logits, labels)
		self.update_step(loss)
		return loss

	@t.inference_mode()
	def validation_step(self, imgs: Tensor, labels: Tensor) -> t.Tensor:
		logits, labels = self._shared_train_val_step(imgs, labels)
		classifications = logits.argmax(dim=1)
		n_correct = t.sum(classifications == labels)
		return n_correct

	def update_step(self, loss: Float[Tensor, '']):
		loss.backward()
		self.optimizer.step()
		self.optimizer.zero_grad()

	def train_dataloader(self):
		self.model.train()
		return DataLoader(self.trainset, batch_size=self.args.batch_size, shuffle=True)

	def val_dataloader(self):
		self.model.eval()
		return DataLoader(self.testset, batch_size=self.args.batch_size, shuffle=True)

	def train(self):
		progress_bar = tqdm(total=self.args.epochs * len(self.trainset) // self.args.batch_size)
		accuracy = t.nan

		for epoch in range(self.args.epochs):

			# Training loop (includes updating progress bar)
			for imgs, labels in self.train_dataloader():
				loss = self.training_step(imgs, labels)
				self.logged_variables["loss"].append(loss.item())
				desc = f"Epoch {epoch+1}/{self.args.epochs}, Loss = {loss:.2f}, Accuracy = {accuracy:.2f}"
				progress_bar.set_description(desc)
				progress_bar.update()

			# Compute accuracy by summing n_correct over all batches, and dividing by number of items
			accuracy = sum(self.validation_step(imgs, labels) for imgs, labels in self.val_dataloader()) / len(self.testset)

			self.logged_variables["accuracy"].append(accuracy.item())

Lastly, we run our model.


In [None]:
args = ResNetTrainingArgs()
trainer = ResNetTrainer(args)
trainer.train()

In [None]:
plot_train_loss_and_test_accuracy_from_trainer(trainer, title="Training ResNet on CIFAR-10 data")

Let's see how well our ResNet performs on the first few inputs!


In [None]:
def test_resnet_on_random_input(model: ResNet34, n_inputs: int = 3):
	indices = np.random.choice(len(cifar_trainset), n_inputs).tolist()
	classes = [cifar_trainset.classes[cifar_trainset.targets[i]] for i in indices]
	imgs = cifar_trainset.data[indices]
	device = next(model.parameters()).device
	with t.inference_mode():
		x = t.stack(list(map(IMAGENET_TRANSFORM, imgs)))
		logits: t.Tensor = model(x.to(device))
	probs = logits.softmax(-1)
	if probs.ndim == 1: probs = probs.unsqueeze(0)
	for img, label, prob in zip(imgs, classes, probs):
		display(HTML(f"<h2>Classification probabilities (true class = {label})</h2>"))
		imshow(
			img,
			width=200, height=200, margin=0,
			xaxis_visible=False, yaxis_visible=False
		)
		bar(
			prob,
			x=cifar_trainset.classes,
			template="ggplot2", width=600, height=400,
			labels={"x": "Classification", "y": "Probability"},
			text_auto='.2f', showlegend=False,
		)


test_resnet_on_random_input(trainer.model)

## What is Weights and Biases?

Weights and Biases is a cloud service that allows you to log data from experiments. Your logged data is shown in graphs during training, and you can easily compare logs across different runs. It also allows you to run **sweeps**, where you can specify a distribution over hyperparameters and then start a sequence of test runs which use hyperparameters sampled from this distribution.

Before you run any of the code below, you should visit the [Weights and Biases homepage](https://wandb.ai/home), and create your own account.


In [None]:
import wandb

We'll be able to keep the same training loop structure when using Weights and Biases; we'll just have to add a few function calls. The key functions to know are:


#### `wandb.init`

This initialises a training run. It should be called once, at the start of your training loop.

A few important arguments are:

* `project` - the name of the project where you're sending the new run. For example, this could be `'day4-resnet'` for us. You can have many different runs in each project.
* `name` - a display name for this run. By default, if this isn't supplied, wandb generates a random 2-word name for you (e.g. `gentle-sunflower-42`).
* `config` - a dictionary containing hyperparameters for this run. If you pass this dictionary, then you can compare different runs' hyperparameters & results in a single table. Alternatively, you can pass a dataclass.

For these first two, we'll create a new dataclass (which inherits from our previous one, so gets all the same data plus this new data):

In [None]:
@dataclass
class ResNetTrainingArgsWandb(ResNetTrainingArgs):
    wandb_project: Optional[str] = 'day4-resnet'
    wandb_name: Optional[str] = None
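
To get a feel for what a `config` dictionary built from a dataclass might contain, here is a minimal sketch using the standard library's `dataclasses.asdict`. `ExampleArgs` is a hypothetical, simplified stand-in for the dataclass above (it leaves out the `optimizer` field), and filtering out the `wandb_*` fields is just one possible design choice, not something wandb requires.

```python
from dataclasses import dataclass, asdict
from typing import Optional


@dataclass
class ExampleArgs:
    # Hypothetical stand-in for ResNetTrainingArgsWandb, with plain-valued fields only
    batch_size: int = 64
    epochs: int = 3
    learning_rate: float = 1e-3
    wandb_project: Optional[str] = "day4-resnet"
    wandb_name: Optional[str] = None


args = ExampleArgs(learning_rate=3e-4)

# asdict turns the dataclass into a plain {field name: value} dictionary
config_dict = asdict(args)

# Keep only the hyperparameters; the wandb_* fields describe the run, not the model
hyperparams = {k: v for k, v in config_dict.items() if not k.startswith("wandb_")}
print(hyperparams)
```

A dictionary like `hyperparams` is the kind of object you could pass as `config`, so that each run's hyperparameters show up as columns when you compare runs.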

#### `wandb.watch`

This function tells wandb to watch a model. This means that it will log the gradients and parameters of the model during training. We'll call this function once, after we've created our model.

The first argument to this function is your model (or a list of models). Another two important arguments:

* `log`, which can take the value `'gradients'`, `'parameters'`, or `'all'`, and which determines what gets tracked. Default is `'gradients'`.
* `log_freq`, which is an integer. Logging happens once every `log_freq` batches. Default is 1000.

#### `wandb.log`

For logging metrics to the wandb dashboard. This is used as `wandb.log(data, step)`, where `step` is an integer (the x-axis on your metric plots) and `data` is a dictionary of metrics (i.e. the keys are metric names, and the values are metric values).

#### `wandb.finish`

This function should be called at the end of your training loop. It finishes the run, and saves the results to the cloud. If you terminate a run early, remember to still call this.
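
The logging pattern described above (a step counter that increases with every update, metrics logged as dictionaries, and a guaranteed call to finish) can be sketched without touching the wandb service at all. Everything below is an assumption made for illustration: `InMemoryRun` is a hypothetical stand-in that only records what would be logged, and the loss and accuracy values are placeholders rather than real training results.

```python
class InMemoryRun:
    """Hypothetical stand-in for a wandb run: it just stores what would be logged."""

    def __init__(self):
        self.history = []  # list of (step, metrics) pairs
        self.finished = False

    def log(self, data: dict, step: int):
        self.history.append((step, dict(data)))

    def finish(self):
        self.finished = True


def toy_training(run: InMemoryRun, n_epochs: int = 2, batches_per_epoch: int = 3) -> int:
    step = 0  # counts update steps across all epochs
    try:
        for epoch in range(n_epochs):
            for _ in range(batches_per_epoch):
                step += 1
                placeholder_loss = 1.0 / step  # not a real loss, just a decreasing number
                run.log({"loss": placeholder_loss}, step=step)
            # Logged once per epoch, at the step reached so far
            run.log({"accuracy": 0.5 + 0.1 * epoch}, step=step)
    finally:
        # Runs even if training raises or is interrupted
        run.finish()
    return step


run = InMemoryRun()
total_steps = toy_training(run)
print(total_steps, len(run.history), run.finished)
```

Wrapping the loop in `try` / `finally` is one way to make sure the finishing call still happens when a run is terminated early.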

### Exercise - rewrite training loop

```c
Difficulty: 🟠🟠⚪⚪⚪
Importance: 🟠🟠🟠🟠⚪

You should spend up to 10-20 minutes on this exercise.
```

You should now take the training loop from above (or your own implementation - whichever you prefer) and rewrite it to use the four weights and biases functions above.

A few notes:

* You can get rid of the dictionary for logging variables, because you'll be using `wandb.log` instead.
* It's often useful to have a global `step` variable in your training code, which keeps track of the number of update steps which have taken place. You can pass this step argument to `wandb.log`. Alternatively, you can just omit the `step` argument.
* If you use `wandb.watch`, you'll need to decrease the `log_freq` value (since your training by default has fewer than 1000 batches). You'll also want to log just the parameters that are changing (remember that most of them are frozen). See yesterday's code for how to do this (specifically, the `get_resnet_for_feature_extraction` function).
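
To check the claim about `log_freq`, here is a rough count of how many batches a default run processes. It assumes the 50,000-image CIFAR10 training set, the default values in `ResNetTrainingArgs`, and a `DataLoader` that keeps the final partial batch (the PyTorch default); `total_batches` is a helper name made up for this sketch.

```python
import math


def total_batches(dataset_size: int, subset: int, batch_size: int, epochs: int) -> int:
    # Same slicing as get_cifar: keep every `subset`-th item
    n_items = len(range(0, dataset_size, subset))
    # The last, smaller batch is kept, so round up
    batches_per_epoch = math.ceil(n_items / batch_size)
    return epochs * batches_per_epoch


n_batches = total_batches(dataset_size=50_000, subset=10, batch_size=64, epochs=3)
print(n_batches)  # 237
```

With only a few hundred batches in total, a `log_freq` of 1000 would give you little or nothing to look at, which is why a much smaller value is a sensible choice here.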



In [None]:
# YOUR CODE HERE - write `ResNetTrainerWandb` class


args = ResNetTrainingArgsWandb()
trainer = ResNetTrainerWandb(args)
trainer.train()

When you run the code for the first time, you'll have to log in to Weights and Biases, and paste an API key into VSCode. After this is done, your Weights and Biases training run will start. It'll give you a lot of output text, one line of which will look like:

```
View run at https://wandb.ai/<USERNAME>/<PROJECT-NAME>/runs/<RUN-NAME>
```

which you can click on to visit the run page.

A nice thing about using Weights and Biases is that you don't need to worry about generating your own plots: that will all be done for you when you visit the page.

<details>
<summary>Solution (one implementation)</summary>

```python
class ResNetTrainerWandb:
	def __init__(self, args: ResNetTrainingArgsWandb):
		self.args = args
		self.model = get_resnet_for_feature_extraction(args.n_classes).to(device)
		self.optimizer = args.optimizer(self.model.out_layers[-1].parameters(), lr=args.learning_rate)
		self.trainset, self.testset = get_cifar(subset=args.subset)
		self.step = 0
		wandb.init(project=args.wandb_project, name=args.wandb_name, config=args)
		wandb.watch(self.model.out_layers[-1], log="all", log_freq=20)

	def _shared_train_val_step(self, imgs: Tensor, labels: Tensor) -> Tuple[Tensor, Tensor]:
		imgs = imgs.to(device)
		labels = labels.to(device)
		logits = self.model(imgs)
		return logits, labels

	def training_step(self, imgs: Tensor, labels: Tensor) -> t.Tensor:
		logits, labels = self._shared_train_val_step(imgs, labels)
		loss = F.cross_entropy(logits, labels)
		self.update_step(loss)
		return loss

	@t.inference_mode()
	def validation_step(self, imgs: Tensor, labels: Tensor) -> t.Tensor:
		logits, labels = self._shared_train_val_step(imgs, labels)
		classifications = logits.argmax(dim=1)
		n_correct = t.sum(classifications == labels)
		return n_correct

	def update_step(self, loss: Float[Tensor, '']):
		loss.backward()
		self.optimizer.step()
		self.optimizer.zero_grad()
		self.step += 1

	def train_dataloader(self):
		self.model.train()
		return DataLoader(self.trainset, batch_size=self.args.batch_size, shuffle=True)

	def val_dataloader(self):
		self.model.eval()
		return DataLoader(self.testset, batch_size=self.args.batch_size, shuffle=True)

	def train(self):
		progress_bar = tqdm(total=self.args.epochs * len(self.trainset) // self.args.batch_size)
		accuracy = t.nan

		for epoch in range(self.args.epochs):

			# Training loop (includes updating progress bar)
			for imgs, labels in self.train_dataloader():
				loss = self.training_step(imgs, labels)
				wandb.log({"loss": loss.item()}, step=self.step)
				desc = f"Epoch {epoch+1}/{self.args.epochs}, Loss = {loss:.2f}, Accuracy = {accuracy:.2f}"
				progress_bar.set_description(desc)
				progress_bar.update()

			# Compute accuracy by summing n_correct over all batches, and dividing by number of items
			accuracy = sum(self.validation_step(imgs, labels) for imgs, labels in self.val_dataloader()) / len(self.testset)
			wandb.log({"accuracy": accuracy.item()}, step=self.step)

		wandb.finish()
```
</details>

### Run & project pages

The page you visit will show you a plot of all the variables you've logged, among other things. You can do many things with these plots (e.g. click on the "edit" icon for your `loss` plot, and apply smoothing to get a better picture of your loss curve).
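
If you're curious what smoothing does to a noisy loss curve, the sketch below applies one common choice, an exponential moving average. This is an illustration only: the smoothing options on the W&B page may be implemented differently, and both `ema_smooth` and the loss values are made up for the example.

```python
def ema_smooth(values, weight=0.9):
    """Exponential moving average: each point mixes the running average with the new value."""
    smoothed = []
    running = values[0]
    for v in values:
        running = weight * running + (1 - weight) * v
        smoothed.append(running)
    return smoothed


# Made-up noisy loss values that trend downwards
noisy_loss = [2.0, 1.0, 1.8, 0.9, 1.6, 0.8, 1.4, 0.7]
smooth_loss = ema_smooth(noisy_loss, weight=0.8)


def max_jump(values):
    return max(abs(b - a) for a, b in zip(values, values[1:]))


print(max_jump(noisy_loss), max_jump(smooth_loss))
```

A larger `weight` gives a smoother curve, at the cost of reacting more slowly to real changes in the loss.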

The charts are a useful feature of the run page that gets opened when you click on the run page link, but they're not the only feature. You can also navigate to the project page (click on the option to the right of **Projects** on the bar at the top of the Wandb page), and see superimposed plots of all the runs in this project. You can also click on the **Table** icon on the left hand sidebar to see a table of all the runs in this project, which contains useful information (e.g. runtime, the most recent values of any logged variables, etc). However, comparing runs like this becomes especially useful when we start doing hyperparameter search.

## Hyperparameter search

One way to search for good hyperparameters is to choose a set of values for each hyperparameter, and then search all combinations of those specific values. This is called **grid search**. The values don't need to be evenly spaced and you can incorporate any knowledge you have about plausible values from similar problems to choose the set of values. Searching the product of sets takes exponential time, so is really only feasible if there are a small number of hyperparameters. I would recommend forgetting about grid search if you have more than 3 hyperparameters, which in deep learning is "always".
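
To make the exponential cost concrete, here is a small sketch that enumerates a grid with `itertools.product`. The specific values in `grid` and the `grid_size` helper are assumptions chosen for illustration.

```python
from itertools import product

# Hypothetical grid: a handful of candidate values per hyperparameter
grid = {
    "learning_rate": [1e-4, 1e-3, 1e-2, 1e-1],
    "batch_size": [32, 64, 128, 256],
    "epochs": [1, 2, 3],
}

# Every combination of one value per hyperparameter
combos = [dict(zip(grid, values)) for values in product(*grid.values())]
print(len(combos))  # 4 * 4 * 3 = 48 training runs


def grid_size(n_hyperparams: int, values_each: int = 4) -> int:
    # Number of runs grows exponentially in the number of hyperparameters
    return values_each ** n_hyperparams


print(grid_size(3), grid_size(10))  # 64 vs 1048576
```

Even with only four values per hyperparameter, going from 3 to 10 hyperparameters takes you from 64 runs to over a million.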

A much better idea is to decide on a sampling distribution for each hyperparameter, and then on each trial sample a random value from each of those distributions. This is called **random search** and back in 2012, you could get a [publication](https://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf) for this. The diagram below shows the main reason that random search outperforms grid search. Empirically, some hyperparameters matter more than others, and random search benefits from having tried more distinct values in the important dimensions, increasing the chances of finding a "peak" between the grid points.

<img src="https://raw.githubusercontent.com/callummcdougall/Fundamentals/main/images/grid_vs_random.png" width="540">
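
The point made by the diagram can be checked numerically. The sketch below spends the same budget of 9 trials on a 3x3 grid and on random sampling over two hypothetical hyperparameters scaled to the range [0, 1], then counts how many distinct values each method tried along a single axis.

```python
import random
from itertools import product

rng = random.Random(0)  # fixed seed so the example is reproducible
n_trials = 9

# Grid search: 3 values per axis, 3 * 3 = 9 trials
grid_axis = [0.0, 0.5, 1.0]
grid_trials = list(product(grid_axis, grid_axis))

# Random search: 9 trials, each axis sampled independently
random_trials = [(rng.random(), rng.random()) for _ in range(n_trials)]


def distinct_along(trials, axis: int) -> int:
    # How many different values were tried for one hyperparameter
    return len({trial[axis] for trial in trials})


print(distinct_along(grid_trials, 0), distinct_along(random_trials, 0))
```

If only the first hyperparameter really matters, the grid has effectively tested 3 settings of it, while random search has tested 9 for the same cost.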


It's worth noting that both of these searches are vastly less efficient than gradient descent at finding optima - imagine if you could only train neural networks by randomly initializing them and checking the loss! Either of these search methods without a dose of human (or eventually AI) judgement is just a great way to turn electricity into a bunch of models that don't perform very well.

## Running hyperparameter sweeps with `wandb`

Now we've come to one of the most impressive features of `wandb` - being able to perform hyperparameter sweeps.

To perform hyperparameter sweeps, we follow the following 3-step process:

1. Define a `sweep_config` dict, which specifies how we'll randomize hyperparameters during our sweep (more on the exact syntax of this below).
2. Define a training function. Importantly, this function must take no arguments (i.e. you should be able to call it as `train()`).
    * You will be able to access the values inside the `sweep_config["parameters"]` dict using `wandb.config`. You should do this inside your `train()` function.
3. Run a sweep, using the `wandb.sweep` and `wandb.agent` functions. This will run the training function with different hyperparameters each time.
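
The three steps above can be mimicked with a toy loop that involves no wandb code at all. This is a conceptual sketch under stated assumptions, not how wandb is implemented: `current_config` plays the role of `wandb.config`, `toy_agent` is a made-up stand-in for the agent, and only the `values` style of parameter is handled.

```python
import random

# Step 1: a config describing how each hyperparameter is chosen
toy_sweep_config = dict(
    method="random",
    parameters=dict(
        batch_size=dict(values=[32, 64, 128, 256]),
        epochs=dict(values=[1, 2, 3]),
    ),
)

current_config = {}  # stand-in for wandb.config, filled in before each trial
results = []


# Step 2: a training function that takes no arguments and reads the shared config
def train():
    results.append(dict(current_config))  # a real function would train a model here


# Step 3: a loop that samples hyperparameters and calls the function `count` times
def toy_agent(sweep_config: dict, function, count: int, seed: int = 0):
    rng = random.Random(seed)
    for _ in range(count):
        current_config.clear()
        for name, spec in sweep_config["parameters"].items():
            current_config[name] = rng.choice(spec["values"])
        function()


toy_agent(toy_sweep_config, train, count=3)
print(results)
```

The key design point this illustrates is why `train` takes no arguments: the hyperparameters arrive through a shared config object rather than through the function signature.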

### Exercise - define a sweep config (step 1)

```c
Difficulty: 🟠🟠⚪⚪⚪
Importance: 🟠🟠🟠⚪⚪

You should spend up to 10-15 minutes on this exercise.

Learning how to use wandb for sweeps is very useful, so make sure you understand all parts of this code.
```

You should define a dictionary `sweep_config`, which sets out the following rules for hyperparameter sweeps:

* Hyperparameters are chosen **randomly**, according to the distributions given in the dictionary
* Your goal is to **maximize** the **accuracy** metric
* The hyperparameters you vary are:
    * `learning_rate` - a log-uniform distribution between 1e-4 and 1e-1
    * `batch_size` - randomly chosen from (32, 64, 128, 256)
    * `epochs` - randomly chosen from (1, 2, 3)

*(A note on the log-uniform distribution - this means a random value `X` will be chosen between `min` and `max` s.t. `log(X)` is uniformly distributed between `log(min)` and `log(max)`. Can you see why a log uniform distribution for the learning rate makes more sense than a uniform distribution?)*

You can read the syntax for sweep config dictionaries [here](https://docs.wandb.ai/guides/sweeps/define-sweep-configuration).
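
The definition of the log-uniform distribution in the note above translates directly into code: sample uniformly in log space, then exponentiate. The sketch below is written from that definition (it is not wandb's sampler, and `sample_log_uniform` is a name chosen for this example). It compares what fraction of samples land below `1e-3` under each distribution.

```python
import math
import random


def sample_log_uniform(lo: float, hi: float, rng: random.Random) -> float:
    # log(X) is uniform between log(lo) and log(hi)
    return math.exp(rng.uniform(math.log(lo), math.log(hi)))


rng = random.Random(0)
log_samples = [sample_log_uniform(1e-4, 1e-1, rng) for _ in range(3000)]
lin_samples = [rng.uniform(1e-4, 1e-1) for _ in range(3000)]


def fraction_below(samples, threshold: float) -> float:
    return sum(s < threshold for s in samples) / len(samples)


# Share of samples in the lowest decade, [1e-4, 1e-3)
print(fraction_below(log_samples, 1e-3))
print(fraction_below(lin_samples, 1e-3))
```

Comparing the two printed fractions should help you answer the question posed in the note.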

In [None]:
sweep_config = dict()

# YOUR CODE HERE - fill `sweep_config`


tests.test_sweep_config(sweep_config)

<details>
<summary>Solution</summary>

```python
sweep_config = dict(
    method = 'random',
    metric = dict(name = 'accuracy', goal = 'maximize'),
    parameters = dict(
        batch_size = dict(values = [32, 64, 128, 256]),
        epochs = dict(min = 1, max = 4),
        learning_rate = dict(max = 0.1, min = 0.0001, distribution = 'log_uniform_values'),
    )
)
```
</details>

### Exercise - define a training function (step 2)

```c
Difficulty: 🟠🟠🟠⚪⚪
Importance: 🟠🟠⚪⚪⚪

You should spend up to 10-15 minutes on this exercise.
```

Now, we have a `train` function. This takes no arguments, and it implements a training loop just like we've seen before. Note that we've set things like `args.batch_size` from the `wandb.config` dictionary, which is how we access the hyperparameters which are set at the start of each sweep.

<details>
<summary>Question - why do we set these parameters after defining our <code>args</code> object?</summary>

Weights & Biases is very particular about the order in which things are defined. In particular, we can't access the `wandb.config` object until we've called `wandb.init()`. The easiest way to get around this is to slightly restructure our trainer's `__init__` method as follows:

* Call `wandb.init` first (don't pass a `config` argument),
* Then override the values in `args` with the values in `wandb.config`,
* Then do the rest of the things in your trainer's `__init__` method (e.g. defining your model and optimizer).

You might want to create a class `ResNetTrainerWandbSweeps` for this purpose (which inherits from `ResNetTrainerWandb`, but has a different `__init__` method).
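
The "override the values in `args`" step can be sketched in isolation. In the snippet below, `ExampleArgs` is a hypothetical stand-in for the training args dataclass and `sampled` is a plain dictionary standing in for `wandb.config`; the helper `override_from_config` is one possible approach, not the only one.

```python
from dataclasses import dataclass, fields


@dataclass
class ExampleArgs:
    # Hypothetical stand-in for the training args dataclass
    batch_size: int = 64
    epochs: int = 3
    learning_rate: float = 1e-3
    n_classes: int = 10


def override_from_config(args, config: dict):
    # Only touch keys that are real fields, so unrelated config entries are ignored
    valid = {f.name for f in fields(args)}
    for key, value in config.items():
        if key in valid:
            setattr(args, key, value)
    return args


# Stand-in for the hyperparameters chosen for one trial of the sweep
sampled = {"batch_size": 128, "epochs": 2, "learning_rate": 0.01}
args = override_from_config(ExampleArgs(), sampled)
print(args)
```

Fields that the sweep doesn't vary (here `n_classes`) keep their defaults, and the model and optimizer should only be built after this step so that they see the overridden values.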

</details>

In [None]:
# (2) Define a training function which takes no arguments, and uses `wandb.config` to get hyperparams

def train():
    pass


<details>
<summary>Solution</summary>

```python
class ResNetTrainerWandbSweeps(ResNetTrainerWandb):
	'''
	New training class made specifically for hyperparameter sweeps, which overrides the values in `args` with
	those in `wandb.config` before defining model/optimizer/datasets.
	'''
	def __init__(self, args: ResNetTrainingArgsWandb):
		wandb.init(project=args.wandb_project, name=args.wandb_name)
		args.batch_size = wandb.config["batch_size"]
		args.epochs = wandb.config["epochs"]
		args.learning_rate = wandb.config["learning_rate"]
		self.args = args
		self.model = get_resnet_for_feature_extraction(args.n_classes).to(device)
		self.optimizer = args.optimizer(self.model.out_layers[-1].parameters(), lr=args.learning_rate)
		self.trainset, self.testset = get_cifar(subset=args.subset)
		self.step = 0
		wandb.watch(self.model.out_layers[-1], log="all", log_freq=20)


def train():
	args = ResNetTrainingArgsWandb()
	trainer = ResNetTrainerWandbSweeps(args)
	trainer.train()
```
</details>

### Run your sweep (step 3)

Finally, you can use the code below to run your sweep! This will probably take a while, because you're doing three separate full training and validation runs.

In [None]:
sweep_id = wandb.sweep(sweep=sweep_config, project='day4-resnet-sweep')
wandb.agent(sweep_id=sweep_id, function=train, count=3)
wandb.finish()

When you run this code, you should click on the link which looks like:

```
View sweep at https://wandb.ai/<USERNAME>/<PROJECT-NAME>/sweeps/<SWEEP-NAME>
```

This link will bring you to a page comparing each of the runs in your sweep. You'll be able to see overlaid graphs of their training loss and test accuracy, as well as a bunch of other cool things like:

* Bar charts of the [importance](https://docs.wandb.ai/ref/app/features/panels/parameter-importance) (and correlation) of each hyperparameter wrt the target metric. Note that only looking at the correlation could be misleading - something can have a correlation of 1, but still have a very small effect on the metric.
* A [parallel coordinates plot](https://docs.wandb.ai/ref/app/features/panels/parallel-coordinates), which summarises the relationship between the hyperparameters in your config and the model metric you're optimising.

What can you infer from these results? Are there any hyperparameters which are especially correlated / anticorrelated with the target metric? Are there any results which suggest the model is being undertrained?
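
The caveat above, that a correlation of 1 can coexist with a very small effect, is easy to demonstrate with made-up numbers. In the sketch below the accuracies are invented so that they depend perfectly linearly, but only very weakly, on batch size; `pearson` is a hand-written Pearson correlation.

```python
import math


def pearson(xs, ys) -> float:
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    std_x = math.sqrt(sum((x - mean_x) ** 2 for x in xs))
    std_y = math.sqrt(sum((y - mean_y) ** 2 for y in ys))
    return cov / (std_x * std_y)


batch_sizes = [32, 64, 128, 256]
# Invented results: accuracy rises perfectly linearly with batch size, but only by a tiny amount
accuracies = [0.9 + 1e-6 * b for b in batch_sizes]

corr = pearson(batch_sizes, accuracies)
effect = max(accuracies) - min(accuracies)
print(corr, effect)
```

Here the correlation is essentially 1, yet the best and worst runs differ by a few hundredths of a percentage point, so the hyperparameter is "perfectly correlated" with the metric while barely mattering.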



<details>
<summary>Note on using YAML files (optional)</summary>

Rather than using a dictionary, you can alternatively store the `sweep_config` data in a YAML file if you prefer. You will then be able to run a sweep via the following terminal commands:

```
wandb sweep sweep_config.yaml

wandb agent <SWEEP_ID>
```

where `SWEEP_ID` is the value returned from the first terminal command. You will also need to add another line to the YAML file, specifying the program to be run. For instance, your YAML file might start like this:

```yaml
program: train.py
method: random
metric:
    name: test_accuracy
    goal: maximize
```

For more, see [this link](https://docs.wandb.ai/guides/sweeps/define-sweep-configuration).

</details>

# 3️⃣ Bonus


Congratulations on getting to the end of the main content! This section gives some suggestions for more features of Weights and Biases to explore, or some other experiments you can run.


## Scaling Laws


These bonus exercises are taken directly from Jacob Hilton's [online deep learning curriculum](https://github.com/jacobhilton/deep_learning_curriculum/blob/master/2-Scaling-Laws.md) (which is what the original version of the ARENA course was based on).

First, you can start by reading the [Chinchilla paper](https://arxiv.org/abs/2203.15556). This is a correction to the original scaling laws paper: parameter count scales linearly with token budget for compute-optimal models, not ~quadratically. The difference comes from using a separately-tuned learning rate schedule for each token budget, rather than using a single training run to measure performance for every token budget. This highlights the importance of hyperparameter tuning for measuring scaling law exponents.

You don't have to read the entire paper, just skim the graphs. Don't worry if they don't all make sense yet (it will be more illuminating when we study LLMs next week). Note that, although it specifically applies to language models, the key underlying ideas of tradeoffs between optimal dataset size and model size are generally applicable.

### Suggested exercise

Perform your own study of scaling laws for MNIST.

- Write a script to train a small CNN on MNIST, or find one you have written previously.
- Training for a single epoch only, vary the model size and dataset size. For the model size, multiply the width by powers of sqrt(2) (rounding if necessary - the idea is to vary the amount of compute used per forward pass by powers of 2). For the dataset size, multiply the fraction of the full dataset used by powers of 2 (i.e. 1, 1/2, 1/4, ...). To reduce noise, use a few random seeds and always use the full validation set.
- The learning rate will need to vary with model size. Either tune it carefully for each model size, or use the rule of thumb that for Adam, the learning rate should be proportional to the initialization scale, i.e. `1/sqrt(fan_in)` for the standard Kaiming He initialization (which is what PyTorch generally uses by default).
    - Note - `fan_in` refers to the variable $N_{in}$, which is `in_features` for a linear layer, and `in_channels * kernel_size * kernel_size` for a convolutional layer - in other words, the number of input parameters/activations we take a sumproduct over to get each output activation.
- Plot the amount of compute used (on a log scale) against validation loss. The compute-efficient frontier should follow an approximate power law (straight line on a log scale). How does validation accuracy behave?
- Study how the compute-efficient model size varies with compute. This should also follow an approximate power law. Try to estimate its exponent.
- Repeat your entire experiment with 20% [dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html) to see how this affects the scaling exponents.
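
The experiment grid and the exponent estimate described above can be sketched in plain Python. This is a minimal sketch rather than a reference solution: all function names, `base_width`, `base_lr` and `base_fan_in` are hypothetical, and the learning rate helper simply applies the `1/sqrt(fan_in)` rule of thumb relative to a base configuration you have already tuned. Scaling the width by `sqrt(2)` roughly doubles compute because, for a layer whose input and output sizes both grow with the width, the cost of a forward pass grows roughly with the width squared.

```python
import math


def width_schedule(base_width: int, n_steps: int) -> list[int]:
    """Widths growing by factors of sqrt(2), rounded to the nearest integer."""
    return [round(base_width * math.sqrt(2) ** k) for k in range(n_steps)]


def dataset_fractions(n_steps: int) -> list[float]:
    """Fractions of the full training set: 1, 1/2, 1/4, ..."""
    return [1 / 2**k for k in range(n_steps)]


def conv_fan_in(in_channels: int, kernel_size: int) -> int:
    """fan_in of a square-kernel conv layer, as defined in the note above."""
    return in_channels * kernel_size * kernel_size


def scaled_lr(base_lr: float, base_fan_in: int, fan_in: int) -> float:
    """Rule of thumb (an assumption, not a guarantee): lr proportional to 1/sqrt(fan_in)."""
    return base_lr * math.sqrt(base_fan_in / fan_in)


def fit_power_law(xs: list[float], ys: list[float]) -> tuple[float, float]:
    """Fit y = a * x**b by least squares on (log x, log y); returns (a, b)."""
    log_x = [math.log(x) for x in xs]
    log_y = [math.log(y) for y in ys]
    mean_x = sum(log_x) / len(log_x)
    mean_y = sum(log_y) / len(log_y)
    cov = sum((lx - mean_x) * (ly - mean_y) for lx, ly in zip(log_x, log_y))
    var = sum((lx - mean_x) ** 2 for lx in log_x)
    slope = cov / var
    intercept = mean_y - slope * mean_x
    return math.exp(intercept), slope
```

To estimate an exponent, pass the points on your compute-efficient frontier (for example, compute against best validation loss, or compute against best model size) to `fit_power_law`; the second return value is the slope of the line on the log-log plot.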


## Other WandB features: Saving & Logging

Here are a few more Weights & Biases features you might also want to play around with:

* [Logging media and objects in experiments](https://docs.wandb.ai/guides/track/log?fbclid=IwAR3NxKsGpEjZwq3vSwYkohZllMpBwxHgOCc_k0ByuD9XGUsi_Scf5ELvGsQ) - you'll be doing this during the RL week, and it's useful when you're training generative image models like VAEs and diffusion models.
* [Code saving](https://docs.wandb.ai/guides/app/features/panels/code?fbclid=IwAR2BkaXbRf7cqEH8kc1VzqH_kOJWGxqjUb_JCBq_SCnXOx1oF-Rt-hHydb4) - this captures all Python source code files in the current directory and all subdirectories. It's great for reproducibility, and also for sharing your code with others.
* [Saving and loading PyTorch models](https://wandb.ai/wandb/common-ml-errors/reports/How-to-Save-and-Load-Models-in-PyTorch--VmlldzozMjg0MTE?fbclid=IwAR1Y9MzFTxIiVBJG06b4ppitwKWR4H5_ncKyT2F_rR5Z_IHawmpBTKskPcQ) - you can do this easily using `torch.save`, but it's also possible to do this directly through Weights and Biases as an **artifact**.

## Train your model from scratch

Now that you understand how to run training loops, you can try a big one - training your ResNet from scratch on CIFAR-10 data!

Here are some tips and suggestions for things you can experiment with:

- First, try to reduce training time.
    - Starting with a smaller ResNet than the full `ResNet34` is a good idea. Good hyperparameters on the small model tend to transfer over to the larger model because the architecture and the data are the same; the main difference is the larger model may require more regularization to prevent overfitting.
    - Bad hyperparameters are usually clearly worse by the end of the first 1-2 epochs. If you can train for fewer epochs, you can test more hyperparameters with the same compute. You can manually abort runs that don't look promising, or you can try to do it automatically; [Hyperband](https://www.jmlr.org/papers/volume18/16-558/16-558.pdf) is a popular algorithm for this.
    - Play with optimizations like [Automatic Mixed Precision](https://pytorch.org/docs/stable/amp.html) to see if you get a speed boost.
- Random search for a decent learning rate and batch size combination that allows your model to mostly memorize (overfit) the training set.
    - It's better to overfit at the start than underfit, because it means your model is capable of learning and has enough capacity.
    - Learning rate is often the most important single hyperparameter, so it's important to get a good-enough value early.
    - Eventually, you'll want a learning rate schedule. Usually, you'll start low and gradually increase, then gradually decrease, but many other schedules are feasible. [Jeremy Jordan](https://www.jeremyjordan.me/nn-learning-rate/) has a good blog post on learning rates.
    - A larger batch size increases GPU memory usage, and doubling the batch size [often allows doubling the learning rate](https://arxiv.org/pdf/1706.02677.pdf), up to a point where this relationship breaks down. The heuristic is that larger batches give a more accurate estimate of the direction to update in. Note that on the test set, you can vary the batch size independently, and usually the largest value that will fit on your GPU will be the most efficient.
- Add regularization to reduce the amount of overfitting and train for longer to see if it's enough.
    - Data augmentation is the first thing to do - flipping the image horizontally and Cutout are known to be effective.
    - Play with the `label_smoothing` parameter of [torch.nn.CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html).
    - Try adding weight decay to Adam. This is a bit tricky - see this [fast.ai](https://www.fast.ai/2018/07/02/adam-weight-decay/) article if you want to do this, as well as the [PyTorch pseudocode](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
- Try a bit of architecture search: play with various numbers of blocks and block groups. Or pick some fancy newfangled nonlinearity and see if it works better than ReLU.
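
As an illustration of the "start low, increase, then decrease" schedule and the batch size heuristic from the tips above, here is a minimal sketch in plain Python. The function names are hypothetical, and the cosine shape of the decay is just one common choice, not something the tips prescribe. The linear scaling helper encodes the "double the batch size, double the learning rate" heuristic, which only holds up to a point.

```python
import math


def warmup_cosine_factor(step: int, warmup_steps: int, total_steps: int) -> float:
    """Multiplier applied to the peak learning rate at a given optimizer step.

    Increases linearly during warmup, then follows a cosine decay towards zero.
    """
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))


def linearly_scaled_lr(base_lr: float, base_batch_size: int, batch_size: int) -> float:
    """Heuristic only: scale the learning rate in proportion to the batch size."""
    return base_lr * batch_size / base_batch_size


if __name__ == "__main__":
    peak_lr = linearly_scaled_lr(base_lr=1e-3, base_batch_size=128, batch_size=256)
    for step in (0, 50, 100, 500, 1000):
        lr = peak_lr * warmup_cosine_factor(step, warmup_steps=100, total_steps=1000)
        print(f"step {step:4d}: lr = {lr:.6f}")
```

If you want PyTorch to apply the schedule for you, a function like `warmup_cosine_factor` (with `warmup_steps` and `total_steps` fixed, e.g. via a `lambda`) can be passed as the `lr_lambda` argument of `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's initial learning rate by the returned factor.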


## The Optimizer's Curse

The [optimizer's curse](https://www.lesswrong.com/posts/5gQLrJr2yhPzMCcni/the-optimizer-s-curse-and-how-to-beat-it) applies to tuning hyperparameters. The main take-aways are:

- You can expect your best hyperparameter combination to actually underperform in the future. You chose it because it was the best on some metric, but that metric has an element of noise/luck, and the more combinations you test the larger this effect is.
- Look at the overall trends and correlations in context and try to make sense of the values you're seeing. Just because you ran a long search process doesn't mean your best output is really the best.

For more on this, see [Preventing "Overfitting" of Cross-Validation Data](https://ai.stanford.edu/~ang/papers/cv-final.pdf) by Andrew Ng.
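
A small simulation can make the first take-away concrete. The sketch below is a toy model under stated assumptions, not a description of any real sweep: each hyperparameter combination has a "true" score drawn around 0.90, the measured score adds independent Gaussian noise, and we pick the combination with the best measured score. The function name and all numeric values are made up for illustration.

```python
import random
import statistics


def selection_bias(
    n_configs: int,
    n_trials: int = 500,
    true_std: float = 0.01,
    noise_std: float = 0.01,
    seed: int = 0,
) -> float:
    """Average of (measured score - true score) for the winning config.

    A positive value means the winner looked better during the search
    than it really is.
    """
    rng = random.Random(seed)
    gaps = []
    for _ in range(n_trials):
        true_scores = [rng.gauss(0.90, true_std) for _ in range(n_configs)]
        measured = [t + rng.gauss(0.0, noise_std) for t in true_scores]
        # Select the config that looks best on the noisy metric.
        best = max(range(n_configs), key=measured.__getitem__)
        gaps.append(measured[best] - true_scores[best])
    return statistics.mean(gaps)


if __name__ == "__main__":
    for n in (1, 2, 10, 50):
        print(f"{n:3d} configs tested: winner overestimated by {selection_bias(n):+.4f}")
```

In this toy model the overestimate grows with the number of combinations tested, which is why re-evaluating your chosen hyperparameters on fresh seeds or held-out data gives a more honest picture than the score that won the search.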


---


`wandb` is an incredibly useful tool when training models, and you should find yourself using it a fair amount throughout this program. You can always return to this page of exercises if you forget how any part of it works.
