In [1]:
from collections import OrderedDict
from dataclasses import dataclass
import os.path as path

import torch as t
import torch.nn.functional as F
import torchvision as tv

import numpy as np

from ax.service.managed_loop import optimize

In [2]:
# Pick the compute device once: use the GPU when CUDA is usable, else the CPU.
if t.cuda.is_available():
    DEVICE = t.device("cuda")
else:
    DEVICE = t.device("cpu")

# Root directory (under the user's home) where datasets are stored.
DATAROOT = path.expanduser("~/mldata/pytorch")

# Display the selected device as the cell output.
DEVICE

device(type='cpu')

In [3]:
# Preprocessing: convert each image to a tensor, then normalize the single
# channel with mean 0.5 and std 0.5.
xform = tv.transforms.Compose([
    tv.transforms.ToTensor(),
    tv.transforms.Normalize((0.5,), (0.5,))
])
datapath = path.join(DATAROOT, "fashion-mnist")

# The official training split is divided 80/20 into train and validation sets.
train_val_set = tv.datasets.FashionMNIST(datapath, download=True, train=True, transform=xform)
train_size = int(len(train_val_set) * 0.8)
val_size = len(train_val_set) - train_size

# Seed the split with a dedicated generator so the train/validation partition
# is identical on every "Restart & Run All"; otherwise validation accuracies
# from different sessions are measured on different data.
SPLIT_SEED = 42
trainset, valset = t.utils.data.random_split(
    train_val_set,
    [train_size, val_size],
    generator=t.Generator().manual_seed(SPLIT_SEED),
)

# Held-out official test split (not used during hyperparameter search).
testset = tv.datasets.FashionMNIST(datapath, download=True, train=False, transform=xform)

In [4]:
def create_model():
    """Build a fresh fully connected classifier.

    The network flattens its input and applies three Linear+ReLU stages
    (784 -> 128 -> 64 -> 32) followed by a final Linear layer producing
    10 logits.

    Returns:
        A new ``t.nn.Sequential`` whose children are named ``flatten``,
        ``fc1``/``relu1`` ... ``fc3``/``relu3`` and ``logits``.
    """
    layers = OrderedDict()
    layers["flatten"] = t.nn.Flatten()

    # Hidden stack: each stage consumes the previous stage's width.
    in_features = 784
    for idx, out_features in enumerate((128, 64, 32), start=1):
        layers[f"fc{idx}"] = t.nn.Linear(in_features, out_features)
        layers[f"relu{idx}"] = t.nn.ReLU()
        in_features = out_features

    # Output layer: one logit per class.
    layers["logits"] = t.nn.Linear(in_features, 10)
    return t.nn.Sequential(layers)

In [5]:
def accuracy(outputs, targets):
    """Return the fraction of rows whose arg-max prediction equals the target.

    Args:
        outputs: tensor of per-class scores; the predicted class of each row
            is the index of its largest value along dim 1.
        targets: tensor of true class indices, one per row of ``outputs``.

    Returns:
        Number of correct predictions divided by ``targets.shape[0]``.

    Raises:
        AssertionError: if ``outputs`` and ``targets`` differ in their first
            dimension.
    """
    n_samples = targets.shape[0]
    assert outputs.shape[0] == n_samples
    hits = (outputs.argmax(dim=1) == targets).sum().item()
    return hits / n_samples

In [6]:
@dataclass
class Hyperparams:
    """Container for the tunable settings of one training run."""

    # Batch size (passed to the training DataLoader by train_evaluate).
    batch_size: int = 10
    # Number of epochs to train for.
    epochs: int = 10
    # Learning rate handed to the optimizer.
    learning_rate: float = 0.0001

    def to_dict(self):
        """Return the three hyperparameters as a plain ``dict`` keyed by name."""
        field_names = ("batch_size", "epochs", "learning_rate")
        return {name: getattr(self, name) for name in field_names}

In [12]:
def train(model, optim, loss_fn, epochs, trainloader, valloader):
    """Train ``model`` for ``epochs`` epochs, then return its validation accuracy.

    Validation runs once, after the final epoch, not after every epoch.

    Args:
        model: the module to train; it is moved to ``DEVICE`` first.
        optim: optimizer used to update the model's parameters.
        loss_fn: callable ``loss_fn(outputs, targets)`` returning the loss to
            back-propagate.
        epochs: number of full passes over ``trainloader``.
        trainloader: iterable of ``(images, targets)`` training batches.
        valloader: iterable of ``(images, targets)`` validation batches.

    Returns:
        The fraction of validation samples classified correctly, as computed
        by ``accuracy``.

    Raises:
        ValueError: if ``valloader`` yields no batches.
    """
    model = model.to(DEVICE)
    for _ in range(epochs):
        # Process the training set
        model.train()
        with t.enable_grad():
            for images, targets in trainloader:
                images = images.to(DEVICE)
                targets = targets.to(DEVICE)

                optim.zero_grad()
                # Call the module itself (not .forward) so registered hooks run.
                outputs = model(images)
                loss = loss_fn(outputs, targets)
                loss.backward()
                optim.step()

    # Calculate the validation metrics.
    # Batches are gathered on the CPU and concatenated once: this avoids the
    # device mismatch of cat-ing CUDA outputs onto a CPU accumulator, does not
    # hard-code the number of classes, and avoids re-copying on every batch.
    batch_outputs = []
    batch_targets = []
    model.eval()
    with t.no_grad():
        for images, targets in valloader:
            outputs = model(images.to(DEVICE))
            batch_outputs.append(outputs.cpu())
            batch_targets.append(targets.cpu())

    if not batch_outputs:
        raise ValueError("valloader yielded no batches; cannot compute validation accuracy")

    return accuracy(t.cat(batch_outputs), t.cat(batch_targets))

In [15]:
def train_evaluate(hparams):
    """Train a fresh model with the given hyperparameters and score it.

    Args:
        hparams: mapping whose keys are ``Hyperparams`` field names
            (``batch_size``, ``epochs``, ``learning_rate``).

    Returns:
        ``{"accuracy": (val_acc, 0.0)}`` where ``val_acc`` is the validation
        accuracy returned by ``train`` and the second tuple element is fixed
        at ``0.0``.
    """
    hp = Hyperparams(**hparams)

    # Build the model first, then the optimizer over its parameters.
    net = create_model()
    sgd = t.optim.SGD(net.parameters(), lr=hp.learning_rate)
    criterion = t.nn.CrossEntropyLoss()

    # Training batches are shuffled; validation uses fixed batches of 5000.
    train_batches = t.utils.data.DataLoader(trainset, batch_size=hp.batch_size, shuffle=True)
    val_batches = t.utils.data.DataLoader(valset, batch_size=5000)

    score = train(net, sgd, criterion, hp.epochs, train_batches, val_batches)
    return {"accuracy": (score, 0.0)}

In [16]:
# Search space: one entry per Hyperparams field that train_evaluate receives.
hparams = [
    # Batch size is picked from a fixed set of choices.
    dict(name="batch_size", type="choice", value_type="int", values=[16, 32, 64]),
    # Integer range for the number of epochs.
    dict(name="epochs", type="range", value_type="int", bounds=[7, 13]),
    # Learning rate is searched on a log scale.
    dict(name="learning_rate", type="range", bounds=[1e-6, 0.4], log_scale=True),
]

# Run the optimization loop against validation accuracy.
best_params, values, experiment, model = optimize(
    hparams,
    evaluation_function=train_evaluate,
    objective_name="accuracy",
    total_trials=5,  # default is 20
)

[INFO 03-15 22:37:19] ax.modelbridge.dispatch_utils: Using Sobol generation strategy.
[INFO 03-15 22:37:20] ax.service.managed_loop: Started full optimization with 20 steps.
[INFO 03-15 22:37:20] ax.service.managed_loop: Running optimization trial 1...
[INFO 03-15 22:40:25] ax.service.managed_loop: Running optimization trial 2...
[INFO 03-15 22:43:19] ax.service.managed_loop: Running optimization trial 3...
[INFO 03-15 22:47:37] ax.service.managed_loop: Running optimization trial 4...
[INFO 03-15 22:51:48] ax.service.managed_loop: Running optimization trial 5...
[INFO 03-15 22:54:15] ax.service.managed_loop: Running optimization trial 6...
[INFO 03-15 22:58:11] ax.service.managed_loop: Running optimization trial 7...
[INFO 03-15 23:01:08] ax.service.managed_loop: Running optimization trial 8...
[INFO 03-15 23:04:31] ax.service.managed_loop: Running optimization trial 9...
[INFO 03-15 23:07:06] ax.service.managed_loop: Running optimization trial 10...
[INFO 03-15 23:10:42] ax.service.ma

In [17]:
best_params

{'epochs': 13, 'learning_rate': 0.25718724285671446, 'batch_size': 32}

In [18]:
values

({'accuracy': 0.88525}, {'accuracy': {'accuracy': 0.0}})