In [None]:
import ray
# Start (or connect to) the local Ray runtime. ignore_reinit_error=True keeps this
# cell safe to re-run in the same kernel; a second plain ray.init() raises instead
# of reusing the already-running instance.
ray.init(ignore_reinit_error=True)
from ray import tune


In [None]:
# Show the installed Ray version so the executed notebook records which release
# produced its outputs.
# NOTE(review): the ray.util.sgd API used below is not available in newer Ray
# releases; confirm the pinned version.
ray_version = ray.__version__
ray_version

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from ray.util.sgd.torch import TrainingOperator
from ray.util.sgd.torch.examples.train_example import LinearDataset

In [None]:
class MyTrainingOperator(TrainingOperator):
    """Ray SGD training operator that fits a 1-in/1-out linear model to LinearDataset.

    ``setup`` builds the data loaders, model, optimizer, loss and LR scheduler,
    then registers them with Ray SGD. Hyperparameters are read from ``config``:

    - ``batch_size`` (required): batch size for both data loaders.
    - ``lr`` (default ``1e-4``): SGD learning rate.
    - ``scheduler_step_size`` (default ``5``): StepLR ``step_size``.
    - ``scheduler_gamma`` (default ``0.9``): StepLR multiplicative decay ``gamma``.
    """

    def setup(self, config):
        """Create every training component from ``config`` and register it with Ray SGD.

        Args:
            config: dict of hyperparameters (see the class docstring).

        Raises:
            KeyError: if ``config`` has no ``"batch_size"`` entry.
        """
        # Data: two independently constructed datasets with identical arguments,
        # one for training and one for validation.
        # NOTE(review): (2, 5) are LinearDataset's own constructor arguments;
        # confirm their meaning in ray.util.sgd.torch.examples.train_example.
        train_dataset = LinearDataset(2, 5)
        val_dataset = LinearDataset(2, 5)
        batch_size = config["batch_size"]
        train_loader = DataLoader(train_dataset, batch_size=batch_size)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)

        # Model: a single linear unit (1 input feature -> 1 output).
        model = nn.Linear(1, 1)

        optimizer = torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-4))

        criterion = torch.nn.MSELoss()

        # Multiply the LR by `gamma` every `step_size` scheduler steps. How often
        # the scheduler steps is chosen by the trainer (its scheduler_step_freq).
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config.get("scheduler_step_size", 5),
            gamma=config.get("scheduler_gamma", 0.9),
        )

        # Register all components with Ray SGD. This lets Ray SGD do
        # framework-level setup such as CUDA, DDP, distributed sampling and FP16.
        # The values returned by self.register are stored on self, so custom
        # training/validation methods use the registered objects.
        self.model, self.optimizer, self.criterion, self.scheduler = \
            self.register(models=model, optimizers=optimizer,
                          criterion=criterion, schedulers=scheduler)
        self.register_data(train_loader=train_loader,
                           validation_loader=val_loader)

In [None]:
from ray.util.sgd import TorchTrainer

# Trainer hyperparameters, surfaced as named constants so they are easy to find and tune.
NUM_WORKERS = 2
LEARNING_RATE = 0.001
# NOTE(review): 64 ** 4 == 16_777_216 samples per batch. This is presumably meant
# to exceed each worker's data shard, so every epoch is one full batch. Confirm
# the intent, and use a small value (e.g. 64) for genuine mini-batch SGD.
BATCH_SIZE = 64 ** 4

trainer = TorchTrainer(
    training_operator_cls=MyTrainingOperator,
    # Step the LR scheduler once per epoch (only relevant because the
    # operator registers a scheduler).
    scheduler_step_freq="epoch",
    config={"lr": LEARNING_RATE, "batch_size": BATCH_SIZE},
    num_workers=NUM_WORKERS,
    use_gpu=False,
)


In [None]:
# Number of train/validate rounds. Each trainer.train() call presumably runs one
# pass over the training data; confirm against the Ray SGD version in use.
NUM_EPOCHS = 10

try:
    for _ in range(NUM_EPOCHS):
        metrics = trainer.train()
        print(metrics)
        val_metrics = trainer.validate()
        print(val_metrics)
finally:
    # Release the resources held by the trainer's workers even if training fails.
    # The trainer cannot be used after this, so re-run the TorchTrainer cell
    # before re-running this one.
    trainer.shutdown()
print("success!")