In [None]:
import os
import random

import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from matplotlib import pyplot as plt

from ray import train, tune
from ray.tune.schedulers import ASHAScheduler

In [None]:
# Seed every RNG this kernel uses so notebook-side randomness is repeatable.
# NOTE: these seeds only affect the notebook process itself. Ray Tune runs each
# trial in a separate worker process, so the trainable must seed itself.
SEED = 0

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the generators on all devices, CUDA included
torch.cuda.manual_seed_all(SEED)  # explicit; silently ignored without CUDA

# Stricter option (raises on ops lacking a deterministic implementation):
# torch.use_deterministic_algorithms(True)
# Force deterministic cuDNN kernels and disable the auto-tuner, which may pick
# different (possibly non-deterministic) algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Hash randomization is fixed when the interpreter starts, so this does NOT
# change string hashing in the running kernel; it is only inherited by
# subprocesses started after this point.
os.environ["PYTHONHASHSEED"] = str(SEED)


In [None]:
class ConvNet(nn.Module):
    """Minimal MNIST classifier: a single 3x3 conv, max pooling, and a linear head.

    The architecture is intentionally fixed in this example; only the optimizer
    hyperparameters are tuned.
    """

    def __init__(self):
        super().__init__()
        # One input channel -> three feature maps.
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        # 192 = 3 channels * 8 * 8: for 28x28 inputs the 3x3 conv yields 26x26
        # and 3x3 max pooling yields 8x8. Other input sizes need a new width.
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        """Return log-probabilities over the 10 output classes for input ``x``."""
        feature_maps = self.conv1(x)
        pooled = F.max_pool2d(feature_maps, 3)
        activated = F.relu(pooled)
        flattened = activated.view(-1, 192)
        logits = self.fc(flattened)
        return F.log_softmax(logits, dim=1)

In [None]:
# Batch caps applied by train_epoch (MAX_ITER) and validate (TEST_SIZE) so the
# example runs quickly.
# NOTE(review): with batch_size=64 these caps may exceed the number of batches
# in one MNIST pass, in which case they never trigger — confirm the intent.
MAX_ITER, TEST_SIZE = 2024, 256

def train_epoch(model, optimizer, train_loader, max_iter=None):
    """Train ``model`` in place for at most one pass over ``train_loader``.

    Args:
        model: network to train. It is switched to training mode, and each
            batch is moved to the device holding its parameters.
        optimizer: optimizer that updates ``model``'s parameters.
        train_loader: iterable of ``(data, target)`` batches. ``model(data)``
            must return log-probabilities, because the loss is ``F.nll_loss``.
        max_iter: maximum number of batches to train on. ``None`` (default)
            uses the module-level ``MAX_ITER``.
    """
    if max_iter is None:
        max_iter = MAX_ITER
    # Follow the model's own device so data and parameters always match.
    device = next(model.parameters()).device
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Cap the epoch so the example runs quickly; ``>=`` means exactly
        # ``max_iter`` batches are used.
        if batch_idx >= max_iter:
            break
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()


def validate(model, data_loader, max_batches=None):
    """Return the classification accuracy of ``model`` on ``data_loader``.

    Args:
        model: network whose argmax over dim 1 is the predicted class. It is
            switched to eval mode and run without gradient tracking.
        data_loader: iterable of ``(data, target)`` batches.
        max_batches: maximum number of batches to evaluate. ``None`` (default)
            uses the module-level ``TEST_SIZE``.

    Returns:
        Fraction of correctly classified samples, in ``[0, 1]``.

    Raises:
        ValueError: if no samples were evaluated (empty loader or
            ``max_batches <= 0``).
    """
    if max_batches is None:
        max_batches = TEST_SIZE
    # Evaluate on the device holding the model (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            # ``>=`` means exactly ``max_batches`` batches are evaluated.
            if batch_idx >= max_batches:
                break
            data, target = data.to(device), target.to(device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    if total == 0:
        raise ValueError(
            "validate() saw no samples; data_loader is empty or max_batches <= 0")
    return correct / total

In [None]:
def train_mnist(config):
    """Ray Tune trainable: train ``ConvNet`` on MNIST and report accuracy each epoch.

    Args:
        config: hyperparameters for one trial.
            Required keys: ``"lr"`` and ``"momentum"`` (for SGD).
            Optional keys: ``"seed"`` (default 0), ``"epochs"`` (default 5),
            ``"batch_size"`` (default 64).

    Reports ``{"accuracy": <test accuracy>}`` via ``train.report`` after every epoch.
    """
    # Ray runs each trial in its own worker process, so seeds set in the
    # notebook kernel do not reach this function; seed here for reproducibility.
    seed = config.get("seed", 0)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    batch_size = config.get("batch_size", 64)

    # Data Setup
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # NOTE(review): concurrent trials all download into ~/data; consider
    # downloading once before tuning (or guarding the download with a lock).
    train_loader = DataLoader(
        datasets.MNIST("~/data", train=True, download=True, transform=mnist_transforms),
        batch_size=batch_size,
        shuffle=True)
    # Evaluation needs no shuffling; a fixed order also keeps any capped
    # evaluation on the same batches from epoch to epoch.
    test_loader = DataLoader(
        datasets.MNIST("~/data", train=False, transform=mnist_transforms),
        batch_size=batch_size,
        shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ConvNet().to(device)

    optimizer = optim.SGD(
        model.parameters(), lr=config["lr"], momentum=config["momentum"])
    for _ in range(config.get("epochs", 5)):
        train_epoch(model, optimizer, train_loader)
        acc = validate(model, test_loader)
        train.report({"accuracy": acc})


In [None]:
# Smoke test: run a single fixed configuration through Tune.
# To debug outside Ray, call the trainable directly instead:
#     train_mnist({"lr": 0.01, "momentum": 0.95})
tuner = tune.Tuner(
    train_mnist,
    param_space=dict(lr=0.01, momentum=0.95),
)
results = tuner.fit()

In [None]:
# Full grid over learning rate and momentum: 4 x 2 = 8 trials.
# For random search instead, use samplers such as tune.choice([...]),
# tune.uniform(low, high) or tune.loguniform(low, high).
search_space = {
    "lr": tune.grid_search([0.1, 0.01, 0.005, 0.001]),
    "momentum": tune.grid_search([0.9, 0.95]),
}

# Request a GPU share only when one exists: asking for a GPU on a CPU-only
# machine leaves every trial pending because the request can never be met.
# NOTE(review): this checks the notebook's machine; revisit for remote clusters.
trainable_with_cpu_gpu = tune.with_resources(
    train_mnist,
    {"cpu": 1, "gpu": 0.5 if torch.cuda.is_available() else 0},
)
tuner = tune.Tuner(
    trainable_with_cpu_gpu,
    param_space=search_space,
)
results = tuner.fit()

In [None]:
# One metrics DataFrame per trial, keyed by the trial's output path.
dfs = {result.path: result.metrics_dataframe for result in results}

# Overlay every trial's accuracy curve on a single axes.
fig, ax = plt.subplots()
for trial_df in dfs.values():
    # Trials that failed before reporting may have no metrics; skip them.
    if trial_df is None or trial_df.empty:
        continue
    lr, momentum = trial_df[["config/lr", "config/momentum"]].iloc[0].tolist()
    # train_mnist reports once per epoch, so row i is epoch i + 1.
    epochs = range(1, len(trial_df) + 1)
    ax.plot(epochs, trial_df["accuracy"], label=f"lr:{lr},m:{momentum}")
ax.set(xlabel="epoch", ylabel="test accuracy", title="MNIST test accuracy per trial")
ax.legend();