<a href="https://www.kaggle.com/code/averma111/pytorch-raytune?scriptVersionId=128577628" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
# Original Code here:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
import os
import argparse
from filelock import FileLock
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import ray
from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler

# Change these values if you want the training to run quicker or slower.
EPOCH_SIZE = 512
TEST_SIZE = 256


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, 192)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)


def train(model, optimizer, train_loader, device=None):
    device = device or torch.device("cpu")
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if batch_idx * len(data) > EPOCH_SIZE:
            return
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()


def test(model, data_loader, device=None):
    device = device or torch.device("cpu")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            if batch_idx * len(data) > TEST_SIZE:
                break
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    return correct / total


def get_data_loaders():
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    # We add FileLock here because multiple workers will want to
    # download data, and this may cause overwrites since
    # DataLoader is not threadsafe.
    with FileLock(os.path.expanduser("~/data.lock")):
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                "~/data", train=True, download=True, transform=mnist_transforms
            ),
            batch_size=64,
            shuffle=True,
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                "~/data", train=False, download=True, transform=mnist_transforms
            ),
            batch_size=64,
            shuffle=True,
        )
    return train_loader, test_loader


def train_mnist(config):
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    train_loader, test_loader = get_data_loaders()
    model = ConvNet().to(device)

    optimizer = optim.SGD(
        model.parameters(), lr=config["lr"], momentum=config["momentum"]
    )

    while True:
        train(model, optimizer, train_loader, device)
        acc = test(model, test_loader, device)
        # Set this to run Tune.
        tune.report(mean_accuracy=acc)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--cuda", action="store_true", default=False, help="Enables GPU training"
    )
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    parser.add_argument(
        "--ray-address",
        help="Address of Ray cluster for seamless distributed execution.",
    )
    parser.add_argument(
        "--server-address",
        type=str,
        default=None,
        required=False,
        help="The address of server to connect to if using Ray Client.",
    )
    args, _ = parser.parse_known_args()

    if args.server_address:
        ray.init(f"ray://{args.server_address}")
    elif args.ray_address:
        ray.init(address=args.ray_address)
    else:
        ray.init(num_cpus=2 if args.smoke_test else None)

    # for early stopping
    sched = AsyncHyperBandScheduler()

    analysis = tune.run(
        train_mnist,
        metric="mean_accuracy",
        mode="max",
        name="exp",
        scheduler=sched,
        stop={
            "mean_accuracy": 0.98,
            "training_iteration": 5 if args.smoke_test else 100,
        },
        resources_per_trial={"cpu": 2, "gpu": int(args.cuda)},  # set this for GPUs
        num_samples=1 if args.smoke_test else 50,
        config={
            "lr": tune.loguniform(1e-4, 1e-2),
            "momentum": tune.uniform(0.1, 0.9),
        },
    )

    print("Best config is:", analysis.best_config)

2023-05-06 21:47:57,739	INFO worker.py:1544 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m


0,1
Current time:,2023-05-06 21:49:34
Running for:,00:01:33.40
Memory:,1.9/31.4 GiB

Trial name,status,loc,lr,momentum,acc,iter,total time (s)
train_mnist_ab879_00000,TERMINATED,172.19.2.2:449,0.00958148,0.71137,0.925,100,35.0095
train_mnist_ab879_00001,TERMINATED,172.19.2.2:507,0.00020048,0.800533,0.125,1,0.438999
train_mnist_ab879_00002,TERMINATED,172.19.2.2:507,0.00018875,0.624856,0.103125,1,0.372328
train_mnist_ab879_00003,TERMINATED,172.19.2.2:507,0.00187054,0.705482,0.19375,1,0.374709
train_mnist_ab879_00004,TERMINATED,172.19.2.2:507,0.000660986,0.259157,0.09375,1,0.387726
train_mnist_ab879_00005,TERMINATED,172.19.2.2:507,0.000936292,0.696077,0.10625,1,0.384456
train_mnist_ab879_00006,TERMINATED,172.19.2.2:507,0.00181865,0.655631,0.05,1,0.395782
train_mnist_ab879_00007,TERMINATED,172.19.2.2:507,0.00674117,0.700678,0.553125,4,1.26273
train_mnist_ab879_00008,TERMINATED,172.19.2.2:507,0.000440106,0.628815,0.05,1,0.398797
train_mnist_ab879_00009,TERMINATED,172.19.2.2:507,0.000139104,0.749956,0.1,1,0.36617


[2m[36m(train_mnist pid=449)[0m Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
[2m[36m(train_mnist pid=449)[0m Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /root/data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]
100%|██████████| 9912422/9912422 [00:00<00:00, 95881404.85it/s]


[2m[36m(train_mnist pid=449)[0m Extracting /root/data/MNIST/raw/train-images-idx3-ubyte.gz to /root/data/MNIST/raw
[2m[36m(train_mnist pid=449)[0m 
[2m[36m(train_mnist pid=449)[0m Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
[2m[36m(train_mnist pid=449)[0m Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /root/data/MNIST/raw/train-labels-idx1-ubyte.gz
[2m[36m(train_mnist pid=449)[0m Extracting /root/data/MNIST/raw/train-labels-idx1-ubyte.gz to /root/data/MNIST/raw
[2m[36m(train_mnist pid=449)[0m 
[2m[36m(train_mnist pid=449)[0m Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
[2m[36m(train_mnist pid=449)[0m Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /root/data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 55668976.94it/s]
  0%|          | 0/1648877 [00:00<?, ?it/s]


[2m[36m(train_mnist pid=449)[0m Extracting /root/data/MNIST/raw/t10k-images-idx3-ubyte.gz to /root/data/MNIST/raw
[2m[36m(train_mnist pid=449)[0m 
[2m[36m(train_mnist pid=449)[0m Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
[2m[36m(train_mnist pid=449)[0m Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /root/data/MNIST/raw/t10k-labels-idx1-ubyte.gz
[2m[36m(train_mnist pid=449)[0m Extracting /root/data/MNIST/raw/t10k-labels-idx1-ubyte.gz to /root/data/MNIST/raw
[2m[36m(train_mnist pid=449)[0m 


100%|██████████| 1648877/1648877 [00:00<00:00, 24083336.75it/s]
100%|██████████| 4542/4542 [00:00<00:00, 28264879.48it/s]


Trial name,date,done,episodes_total,experiment_id,hostname,iterations_since_restore,mean_accuracy,node_ip,pid,time_since_restore,time_this_iter_s,time_total_s,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
train_mnist_ab879_00000,2023-05-06_21-48-41,True,,af7b71c0903541d0a58762229884ce5a,40858ec3dfba,100,0.925,172.19.2.2,449,35.0095,0.288851,35.0095,1683409721,0,,100,ab879_00000,0.00609446
train_mnist_ab879_00001,2023-05-06_21-48-11,True,,c32c2c54dd6c42efbf198db7b4ee2ffb,40858ec3dfba,1,0.125,172.19.2.2,507,0.438999,0.438999,0.438999,1683409691,0,,1,ab879_00001,0.00514054
train_mnist_ab879_00002,2023-05-06_21-48-12,True,,c32c2c54dd6c42efbf198db7b4ee2ffb,40858ec3dfba,1,0.103125,172.19.2.2,507,0.372328,0.372328,0.372328,1683409692,0,,1,ab879_00002,0.00514054
train_mnist_ab879_00003,2023-05-06_21-48-12,True,,c32c2c54dd6c42efbf198db7b4ee2ffb,40858ec3dfba,1,0.19375,172.19.2.2,507,0.374709,0.374709,0.374709,1683409692,0,,1,ab879_00003,0.00514054
train_mnist_ab879_00004,2023-05-06_21-48-12,True,,c32c2c54dd6c42efbf198db7b4ee2ffb,40858ec3dfba,1,0.09375,172.19.2.2,507,0.387726,0.387726,0.387726,1683409692,0,,1,ab879_00004,0.00514054
train_mnist_ab879_00005,2023-05-06_21-48-13,True,,c32c2c54dd6c42efbf198db7b4ee2ffb,40858ec3dfba,1,0.10625,172.19.2.2,507,0.384456,0.384456,0.384456,1683409693,0,,1,ab879_00005,0.00514054
train_mnist_ab879_00006,2023-05-06_21-48-13,True,,c32c2c54dd6c42efbf198db7b4ee2ffb,40858ec3dfba,1,0.05,172.19.2.2,507,0.395782,0.395782,0.395782,1683409693,0,,1,ab879_00006,0.00514054
train_mnist_ab879_00007,2023-05-06_21-48-15,True,,c32c2c54dd6c42efbf198db7b4ee2ffb,40858ec3dfba,4,0.553125,172.19.2.2,507,1.26273,0.294743,1.26273,1683409695,0,,4,ab879_00007,0.00514054
train_mnist_ab879_00008,2023-05-06_21-48-15,True,,c32c2c54dd6c42efbf198db7b4ee2ffb,40858ec3dfba,1,0.05,172.19.2.2,507,0.398797,0.398797,0.398797,1683409695,0,,1,ab879_00008,0.00514054
train_mnist_ab879_00009,2023-05-06_21-48-15,True,,c32c2c54dd6c42efbf198db7b4ee2ffb,40858ec3dfba,1,0.1,172.19.2.2,507,0.36617,0.36617,0.36617,1683409695,0,,1,ab879_00009,0.00514054


2023-05-06 21:49:34,114	INFO tune.py:798 -- Total run time: 93.91 seconds (93.39 seconds for the tuning loop).


Best config is: {'lr': 0.007756426293398399, 'momentum': 0.8935793736293866}
