# Tune the hyperparameters of a CNN on MNIST

This tutorial walks through using Ax to tune two hyperparameters (learning rate and momentum) for a PyTorch CNN trained on the MNIST dataset using SGD with momentum. It is adapted from https://ax.dev/tutorials/tune_cnn.html, and background on the underlying methods is available at https://ax.dev/docs/bayesopt.html

1. Run through the tutorial.
2. Write your own code for grid search and random search (remember the logarithmic transforms).
3. Create your own `net` structure to go into the `train_evaluate()` function and try optimising your own network. Consider also adapting the network structure. You could also try adapting some of the examples used in earlier labs to see whether you can improve on your results with hyperparameter search.
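
For exercise 2, one possible starting point is sketched below. It assumes the same bounds as the Ax search space used later in this tutorial (`lr` in [1e-6, 0.4] on a log scale, `momentum` in [0, 1]). All helper names (`log_uniform`, `random_search_candidates`, `grid_search_candidates`, `run_search`) are hypothetical and not part of Ax; `evaluate_fn` stands in for any function that maps a parameter dictionary to a single score, such as `train_evaluate()`.

```python
import itertools
import math
import random

# Assumed bounds, mirroring the Ax search space defined in section 3.
LR_BOUNDS = (1e-6, 0.4)
MOMENTUM_BOUNDS = (0.0, 1.0)


def log_uniform(low, high, rng):
    """Sample uniformly in log space, then map back to the original scale."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def random_search_candidates(n, seed=0):
    """Draw n candidates: lr log-uniformly, momentum uniformly."""
    rng = random.Random(seed)
    return [
        {"lr": log_uniform(*LR_BOUNDS, rng), "momentum": rng.uniform(*MOMENTUM_BOUNDS)}
        for _ in range(n)
    ]


def grid_search_candidates(n_lr, n_momentum):
    """Build a grid that is evenly spaced in log10(lr) and linear in momentum."""
    lo, hi = math.log10(LR_BOUNDS[0]), math.log10(LR_BOUNDS[1])
    lrs = [10 ** (lo + i * (hi - lo) / (n_lr - 1)) for i in range(n_lr)]
    m_lo, m_hi = MOMENTUM_BOUNDS
    momenta = [m_lo + j * (m_hi - m_lo) / (n_momentum - 1) for j in range(n_momentum)]
    return [{"lr": lr, "momentum": m} for lr, m in itertools.product(lrs, momenta)]


def run_search(candidates, evaluate_fn):
    """Evaluate every candidate and return (best_score, best_parameters)."""
    scored = [(evaluate_fn(c), c) for c in candidates]
    return max(scored, key=lambda pair: pair[0])
```

Once `train_evaluate()` has been defined, a call such as `run_search(random_search_candidates(20), train_evaluate)` would use the same budget of 20 trials as the Ax run. Note that the grid includes the endpoint `momentum = 1.0`, which you may prefer to exclude by tightening `MOMENTUM_BOUNDS`.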


In [2]:
!pip3 install ax-platform 


In [3]:
import torch
import numpy as np

from ax.plot.contour import plot_contour
from ax.plot.trace import optimization_trace_single_method
from ax.service.managed_loop import optimize
from ax.utils.notebook.plotting import render, init_notebook_plotting
from ax.utils.tutorials.cnn_utils import load_mnist, train, evaluate, CNN

init_notebook_plotting(offline=True)


In [4]:
torch.manual_seed(12345)
dtype = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

## 1. Load MNIST data
First, we need to load the MNIST data and partition it into training, validation, and test sets.

Note: this will download the dataset if necessary.

In [5]:
BATCH_SIZE = 512
train_loader, valid_loader, test_loader = load_mnist(batch_size=BATCH_SIZE)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


## 2. Define function to optimize
In this tutorial, we want to optimize classification accuracy on the validation set as a function of the learning rate and momentum. The function takes in a parameterization (set of parameter values), computes the classification accuracy, and returns a dictionary of metric name ('accuracy') to a tuple with the mean and standard error.

In [6]:
print(CNN())

def train_evaluate(parameterization):
    net = CNN()
    net = train(net=net, train_loader=train_loader, parameters=parameterization, dtype=dtype, device=device)
    return evaluate(
        net=net,
        data_loader=valid_loader,
        dtype=dtype,
        device=device,
    )

CNN(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=1280, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=10, bias=True)
)
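
The description above mentions returning a dictionary that maps the metric name to a (mean, standard error) tuple. As a minimal sketch of that return shape, the hypothetical helper below computes accuracy from predictions and labels and attaches a binomial standard error, assuming the predictions are independent. It is not the tutorial's `evaluate` helper, whose internals are not shown here.

```python
import math


def accuracy_with_sem(predictions, labels):
    """Return {"accuracy": (mean, sem)} for paired predictions and labels.

    The standard error uses the binomial formula sqrt(p * (1 - p) / n),
    which assumes the individual predictions are independent.
    """
    n = len(labels)
    if n == 0 or n != len(predictions):
        raise ValueError("predictions and labels must be non-empty and the same length")
    correct = sum(p == y for p, y in zip(predictions, labels))
    mean = correct / n
    sem = math.sqrt(mean * (1.0 - mean) / n)
    return {"accuracy": (mean, sem)}
```

Reporting a standard error alongside the mean gives the optimiser an indication of how noisy each evaluation is, rather than leaving it to infer the noise level.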


## 3. Run the optimization loop
Here, we set the bounds on the learning rate and momentum and set the parameter space for the learning rate to be on a log scale. 

In [7]:
best_parameters, values, experiment, model = optimize(
    parameters=[
        {"name": "lr", "type": "range", "bounds": [1e-6, 0.4], "log_scale": True},
        {"name": "momentum", "type": "range", "bounds": [0.0, 1.0]},
    ],
    evaluation_function=train_evaluate,
    objective_name='accuracy',
)

[INFO 02-09 12:37:59] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 5 arms, GPEI for subsequent arms], generated 0 arm(s) so far). Iterations after 5 will take longer to generate due to model-fitting.
[INFO 02-09 12:37:59] ax.service.managed_loop: Started full optimization with 20 steps.
[INFO 02-09 12:37:59] ax.service.managed_loop: Running optimization trial 1...
[INFO 02-09 12:38:12] ax.service.managed_loop: Running optimization trial 2...
[INFO 02-09 12:38:16] ax.service.managed_loop: Running optimization trial 3...
[INFO 02-09 12:38:19] ax.service.managed_loop: Running optimization trial 4...
[INFO 02-09 12:38:22] ax.service.managed_loop: Running optimization trial 5...
[INFO 02-09 12:38:27] ax.service.managed_loop: Running optimization trial 6...
[INFO 02-09 12:38:31] ax.service.managed_loop: Running optimization trial 7...
[INFO 02-09 12:38:35] ax.service.managed_loop: Running optimization t
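
The log above reports a "Sobol+GPEI" strategy: quasi-random Sobol points for the first 5 arms, then a Gaussian process model combined with the expected improvement (EI) acquisition function. As a rough illustration of what EI measures, the sketch below evaluates the standard closed-form EI for a single candidate whose predicted accuracy is Gaussian with mean `mu` and standard deviation `sigma`, when maximising. The function name is hypothetical and this is not Ax's implementation.

```python
import math


def expected_improvement(mu, sigma, best_so_far):
    """Closed-form expected improvement for maximisation under a Gaussian prediction."""
    if sigma <= 0.0:
        # No uncertainty: the improvement is deterministic.
        return max(mu - best_so_far, 0.0)
    z = (mu - best_so_far) / sigma
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)  # standard normal density
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))         # standard normal CDF
    return (mu - best_so_far) * cdf + sigma * pdf
```

EI is large either when the predicted mean beats the current best or when the prediction is uncertain, which is how the optimiser balances exploiting good regions against exploring untested ones.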

We can inspect the optimal parameters and their outcomes:

In [0]:
best_parameters

{'lr': 0.0002357769832851689, 'momentum': 0.4011007806149234}

In [0]:
means, covariances = values
means, covariances

({'accuracy': 0.9161664968680362},
 {'accuracy': {'accuracy': 7.351145247679816e-09}})
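
To read the second element of `values` as an uncertainty, one option is to take the square root of the diagonal entry. The sketch below does this with the numbers copied from the output above; it assumes that entry is the variance of the model's predicted mean accuracy, and the helper name is hypothetical.

```python
import math

# Values copied from the output above.
means = {"accuracy": 0.9161664968680362}
covariances = {"accuracy": {"accuracy": 7.351145247679816e-09}}


def predicted_mean_and_std(means, covariances, metric):
    """Return the predicted mean and the standard deviation implied by the variance."""
    variance = covariances[metric][metric]
    return means[metric], math.sqrt(variance)


mean, std = predicted_mean_and_std(means, covariances, "accuracy")
print(f"accuracy = {mean:.4f} +/- {std:.2e}")
```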

## 4. Plot response surface

Contour plot showing classification accuracy as a function of the two hyperparameters.

The black squares show the points that were actually evaluated. Notice how they are clustered in the optimal region.

In [0]:
# some boilerplate to make things render on colab
import plotly.io as pio
pio.renderers.default = 'colab'
def configure_plotly_browser_state():
  import IPython
  display(IPython.core.display.HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              plotly: 'https://cdn.plot.ly/plotly-1.5.1.min.js?noext',
            },
          });
        </script>
        '''))

In [0]:
configure_plotly_browser_state()
render(plot_contour(model=model, param_x='lr', param_y='momentum', metric_name='accuracy'))

## 5. Plot best objective as function of the iteration

Show the model accuracy improving as we identify better hyperparameters.

In [10]:
configure_plotly_browser_state()
# `optimization_trace_single_method` expects a 2-d array of means, because it expects to average means
# from multiple optimization runs, so we wrap our best objectives array in another array.
best_objectives = np.array([[trial.objective_mean*100 for trial in experiment.trials.values()]])
best_objective_plot = optimization_trace_single_method(
    y=np.maximum.accumulate(best_objectives, axis=1),
    title="Model performance vs. # of iterations",
    ylabel="Classification Accuracy, %",
)
render(best_objective_plot)
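
The `np.maximum.accumulate` call above turns the per-trial accuracies into a best-so-far curve. A minimal standard-library equivalent is sketched below; the accuracy values are made up purely for illustration.

```python
from itertools import accumulate

# Illustrative per-trial accuracies (not results from this tutorial).
trial_accuracies = [0.10, 0.50, 0.30, 0.90, 0.70]

# Running maximum: each entry is the best accuracy seen up to that trial.
best_so_far = list(accumulate(trial_accuracies, max))
print(best_so_far)
```

Because each entry is the maximum over all earlier trials, the curve never decreases, even when an individual trial performs worse than its predecessors.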

## 6. Train CNN with best hyperparameters and evaluate on test set
Note that the resulting accuracy on the test set might not exactly match the maximum accuracy achieved on the validation set during optimization.

In [11]:
data = experiment.fetch_data()
df = data.df
best_arm_name = df.arm_name[df['mean'] == df['mean'].max()].values[0]
best_arm = experiment.arms_by_name[best_arm_name]
best_arm

Arm(name='19_0', parameters={'lr': 0.0002430489320267667, 'momentum': 0.920156017935921})

In [0]:
combined_train_valid_set = torch.utils.data.ConcatDataset([
    train_loader.dataset.dataset, 
    valid_loader.dataset.dataset,
])
combined_train_valid_loader = torch.utils.data.DataLoader(
    combined_train_valid_set, 
    batch_size=BATCH_SIZE, 
    shuffle=True,
)

In [0]:
net = train(
    net=CNN(),
    train_loader=combined_train_valid_loader, 
    parameters=best_arm.parameters,
    dtype=dtype,
    device=device,
)
test_accuracy = evaluate(
    net=net,
    data_loader=test_loader,
    dtype=dtype,
    device=device,
)

In [14]:
print(f"Classification Accuracy (test set): {round(test_accuracy*100, 2)}%")

Classification Accuracy (test set): 10.42%
