In [1]:
import argparse
from typing import Dict

import torch
import ray.train as train
from ray.train.trainer import Trainer
from ray.train.callbacks import JsonLoggerCallback
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

In [2]:
print (torch.cuda.is_available())

True


In [3]:
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="~/data",
    train=True,
    download=True,
    transform=ToTensor(),
)

In [4]:
# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="~/data",
    train=False,
    download=True,
    transform=ToTensor(),
)

In [5]:
# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512), nn.ReLU(), nn.Linear(512, 512), nn.ReLU(),
            nn.Linear(512, 10), nn.ReLU())

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

In [6]:
def train_epoch(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

In [7]:
def validate_epoch(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n "
          f"Accuracy: {(100 * correct):>0.1f}%, "
          f"Avg loss: {test_loss:>8f} \n")
    return test_loss


In [8]:
def train_func(config: Dict):
    batch_size = config["batch_size"]
    lr = config["lr"]
    epochs = config["epochs"]

    # Create data loaders.
    train_dataloader = DataLoader(training_data, batch_size=batch_size)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    train_dataloader = train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = train.torch.prepare_data_loader(test_dataloader)

    # Create model.
    model = NeuralNetwork()
    model = train.torch.prepare_model(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    loss_results = []

    for _ in range(epochs):
        train_epoch(train_dataloader, model, loss_fn, optimizer)
        loss = validate_epoch(test_dataloader, model, loss_fn)
        train.report(loss=loss)
        loss_results.append(loss)

    return loss_results

In [9]:
def train_fashion_mnist(num_workers=2, use_gpu=False):
    trainer = Trainer(
        backend="torch", num_workers=num_workers, use_gpu=use_gpu)
    trainer.start()
    result = trainer.run(
        train_func=train_func,
        config={
            "lr": 1e-3,
            "batch_size": 64,
            "epochs": 4
        },
        callbacks=[JsonLoggerCallback()])
    trainer.shutdown()
    print(f"Loss results: {result}")

In [10]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address",
        required=False,
        type=str,
        help="the address to use for Ray")
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=4,
        help="Sets number of workers for training.")
    parser.add_argument(
        "--use-gpu",
        action="store_true",
        default=True,
        help="Enables GPU training")
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing.")

    args, _ = parser.parse_known_args()

    import ray

    if args.smoke_test:
        ray.init(num_cpus=2)
    else:
        ray.init(address=args.address)
    train_fashion_mnist(num_workers=args.num_workers, use_gpu=args.use_gpu)

2022-02-20 14:00:52,422	INFO services.py:1374 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8266[39m[22m
2022-02-20 14:00:54,281	INFO trainer.py:190 -- Trainer logs will be logged in: /home/steve/ray_results/train_2022-02-20_14-00-54
[2m[36m(BaseWorkerMixin pid=29111)[0m 2022-02-20 14:00:56,690	INFO torch.py:66 -- Setting up process group for: env:// [rank=0, world_size=4]
[2m[36m(BaseWorkerMixin pid=29104)[0m 2022-02-20 14:00:56,684	INFO torch.py:66 -- Setting up process group for: env:// [rank=3, world_size=4]
[2m[36m(BaseWorkerMixin pid=29101)[0m 2022-02-20 14:00:56,685	INFO torch.py:66 -- Setting up process group for: env:// [rank=1, world_size=4]
[2m[36m(BaseWorkerMixin pid=29108)[0m 2022-02-20 14:00:56,720	INFO torch.py:66 -- Setting up process group for: env:// [rank=2, world_size=4]
2022-02-20 14:00:57,827	INFO trainer.py:196 -- Run results will be logged in: /home/steve/ray_results/train_2022-02-20_14-00-54/run_001
[2m[36m(BaseWorkerMixin pid=29104)[0

[2m[36m(BaseWorkerMixin pid=29111)[0m loss: 2.311540  [    0/60000]
[2m[36m(BaseWorkerMixin pid=29104)[0m loss: 2.310010  [    0/60000]
[2m[36m(BaseWorkerMixin pid=29108)[0m loss: 2.307373  [    0/60000]
[2m[36m(BaseWorkerMixin pid=29101)[0m loss: 2.308183  [    0/60000]
[2m[36m(BaseWorkerMixin pid=29111)[0m loss: 2.304651  [ 6400/60000]
[2m[36m(BaseWorkerMixin pid=29104)[0m loss: 2.300238  [ 6400/60000]
[2m[36m(BaseWorkerMixin pid=29108)[0m loss: 2.290715  [ 6400/60000]
[2m[36m(BaseWorkerMixin pid=29101)[0m loss: 2.301900  [ 6400/60000]
[2m[36m(BaseWorkerMixin pid=29111)[0m loss: 2.299919  [12800/60000]
[2m[36m(BaseWorkerMixin pid=29104)[0m loss: 2.293612  [12800/60000]
[2m[36m(BaseWorkerMixin pid=29108)[0m loss: 2.287717  [12800/60000]
[2m[36m(BaseWorkerMixin pid=29101)[0m loss: 2.298029  [12800/60000]
[2m[36m(BaseWorkerMixin pid=29111)[0m Test Error: 
[2m[36m(BaseWorkerMixin pid=29111)[0m  Accuracy: 6.1%, Avg loss: 2.287458 
[2m[36m(BaseWo