# PyTorch Distributed Training

See the "Pytorch distributed training" section of the [README](README.md) for setting up the Ray cluster before running this example.

For more information about Ray Train and this PyTorch distributed training example, see the original [getting-started-pytorch](https://docs.ray.io/en/latest/train/getting-started-pytorch.html) documentation.

## 1. Install requirements

In [1]:
%%bash
# pip install "ray[data,train,tune,serve]"==2.9.0
# pip install torch torchvision
# pip install IPython

Forward local port 10001 to the Ray head service so that the Ray client can reach the cluster:
```bash
kubectl port-forward svc/raycluster-kuberay-head-svc 10001:10001 -n default
```
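
Before calling `ray.init`, it can help to confirm that the forwarded port accepts connections. The following is a minimal sketch using only the standard library; the helper name `port_is_open` is hypothetical and not part of Ray, and the host and port assume the port-forward command above.

```python
import socket


def port_is_open(host: str, port: int, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to (host, port) succeeds."""
    try:
        # create_connection resolves the host and attempts a TCP handshake.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers connection refused, timeouts, and name resolution failures.
        return False


# 10001 is the local port used by the kubectl port-forward command.
print(port_is_open("localhost", 10001))
```

If this prints `False`, the port-forward is most likely not running or is bound to a different local port.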

In [1]:
import os
import tempfile

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import ray
from ray.train.torch import TorchTrainer
import ray.train.torch

print(ray.__version__)

2.9.0


## 2. Connect to the Ray cluster

In [2]:
runtime_env = {
    "pip": ["torch", "torchvision", "IPython"],
}
ray.init(address="ray://localhost:10001", runtime_env=runtime_env)

Python version: 3.8.13
Ray version: 2.7.0
Dashboard: http://10.244.0.79:8265


In [3]:
print(ray.cluster_resources())

{'object_store_memory': 6313651814.0, 'node:10.244.0.79': 1.0, 'node:__internal_head__': 1.0, 'node:10.244.0.78': 1.0, 'memory': 22000000000.0, 'GPU': 2.0, 'CPU': 8.0, 'accelerator_type:G': 2.0}
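
The resource dictionary above can be used to derive the worker count instead of hard-coding it. A minimal sketch, assuming one training worker per GPU when GPUs are requested and a caller-chosen number of CPUs per worker otherwise; `suggest_scaling` is a hypothetical helper, not a Ray API, and the dictionary copies only the relevant keys from the output above.

```python
def suggest_scaling(resources: dict, want_gpu: bool, cpus_per_worker: int = 1) -> dict:
    """Derive scaling keyword arguments from a cluster resource dictionary."""
    gpus = int(resources.get("GPU", 0))
    cpus = int(resources.get("CPU", 0))
    if want_gpu and gpus > 0:
        # Assumption: one training worker per available GPU.
        return {"num_workers": gpus, "use_gpu": True}
    # CPU-only: as many workers as fit, but always at least one.
    return {"num_workers": max(1, cpus // cpus_per_worker), "use_gpu": False}


# Relevant keys copied from the ray.cluster_resources() output.
resources = {"CPU": 8.0, "GPU": 2.0, "memory": 22000000000.0}
print(suggest_scaling(resources, want_gpu=True))
print(suggest_scaling(resources, want_gpu=False, cpus_per_worker=4))
```

The returned dictionary can be unpacked into `ray.train.ScalingConfig(**kwargs)`. The sketch does not reserve CPUs for other actors on the cluster, so treat the CPU-only result as an upper bound.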


## 3. Define training

- `train_func` is the Python code that executes on each distributed training worker.

- `ScalingConfig` defines the number of distributed training workers and whether to use GPUs.

- `TorchTrainer` launches the distributed training job.
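
To make the division of work concrete: in data-parallel training, every worker runs the same `train_func` on its own shard of the dataset. The sketch below is a simplified, pure-Python picture of rank-strided sharding, which is the scheme PyTorch's `DistributedSampler` applies after shuffling and padding the index list. `shard_indices` is a hypothetical helper for illustration only, and it ignores shuffling.

```python
import math


def shard_indices(dataset_len: int, rank: int, world_size: int) -> list:
    """Return the sample indices one worker would process in an epoch."""
    indices = list(range(dataset_len))
    # Pad by repeating leading indices so every rank gets the same count.
    per_rank = math.ceil(dataset_len / world_size)
    total = per_rank * world_size
    pad = total - dataset_len
    if pad:
        indices += (indices * math.ceil(pad / dataset_len))[:pad]
    # Each rank takes every world_size-th index, starting at its own rank.
    return indices[rank:total:world_size]


for rank in range(2):
    print(rank, shard_indices(7, rank, world_size=2))
```

With the 60,000 FashionMNIST training images and two workers, each worker would see 30,000 samples per epoch under this scheme.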

In [4]:
def train_func(config):
    # Model, Loss, Optimizer
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
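    # [1] Prepare model. `prepare_model` moves the model to the worker's device
    # (CPU here, because `use_gpu=False`) and wraps it in DistributedDataParallel,
    # so no manual `model.to(...)` call is needed.
    model = ray.train.torch.prepare_model(model)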
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    # Data
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
    # [2] Prepare dataloader.
    train_loader = ray.train.torch.prepare_data_loader(train_loader)

    # Training
    n_epochs = 4
    for epoch in range(n_epochs):
        for images, labels in train_loader:
            # `prepare_data_loader` already moves each batch to the worker's
            # device, so no manual `images.to(...)` / `labels.to(...)` is needed.
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

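        # [3] Report metrics and checkpoint.
        # Note: `loss` is the loss of the last batch in this epoch.
        metrics = {"loss": loss.item(), "epoch": epoch}
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            # `prepare_model` wrapped the model in DistributedDataParallel,
            # so save the weights of the underlying module.
            torch.save(
                model.module.state_dict(),
                os.path.join(temp_checkpoint_dir, "model.pt")
            )
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
            )
        # Print from a single worker to avoid duplicate lines.
        if ray.train.get_context().get_world_rank() == 0:
            print(metrics)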
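
The checkpoint step in `train_func` writes into a temporary directory and hands that directory to `ray.train.report` inside the `with` block, so the files have to be persisted before the block exits. The sketch below imitates that pattern with the standard library only. `persist_checkpoint` is a hypothetical stand-in for the copy to persistent storage, not Ray's implementation, and a JSON file stands in for `model.pt`.

```python
import json
import os
import shutil
import tempfile


def persist_checkpoint(src_dir: str, storage_dir: str, index: int) -> str:
    """Copy a finished checkpoint directory into persistent storage."""
    dest = os.path.join(storage_dir, f"checkpoint_{index:06d}")
    shutil.copytree(src_dir, dest)
    return dest


storage_dir = tempfile.mkdtemp()  # stands in for the shared storage path
for epoch in range(2):
    with tempfile.TemporaryDirectory() as tmp:
        # Write the "model state" into the temporary directory.
        with open(os.path.join(tmp, "state.json"), "w") as f:
            json.dump({"epoch": epoch}, f)
        # Persist while the temporary directory still exists.
        saved = persist_checkpoint(tmp, storage_dir, epoch)
    # The temporary directory is gone here, but the copy remains.
    print(os.path.basename(saved), os.path.exists(saved))
```

The design point is ordering: anything that should outlive the epoch must be copied out before the `with` block ends.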

In [5]:
# Configure scaling and resource requirements.
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=False)

In [6]:
# Distributed training job.
trainer = ray.train.torch.TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    run_config=ray.train.RunConfig(
        storage_path="/home/ray/nfs",
        name="nfs",
    )
)
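
With this `RunConfig`, results are written under `<storage_path>/<name>`, which is `/home/ray/nfs/nfs` here, with one subdirectory per trial and one `checkpoint_NNNNNN` directory per reported checkpoint, as the paths in the training logs show. The following sketch lists checkpoints in such a layout; `list_checkpoints` is a hypothetical helper, and the demo builds a throwaway tree with a made-up trial name rather than reading the real NFS share.

```python
import tempfile
from pathlib import Path


def list_checkpoints(experiment_dir: Path) -> list:
    """Return checkpoint directories under every trial, oldest first."""
    found = [p for p in experiment_dir.glob("*/checkpoint_*") if p.is_dir()]
    # Zero-padded names sort correctly as plain strings.
    return sorted(found, key=lambda p: (p.parent.name, p.name))


# Build a throwaway tree that mimics the layout; the trial name is made up.
root = Path(tempfile.mkdtemp())
experiment_dir = root / "nfs"  # corresponds to <storage_path>/<name>
for i in range(3):
    (experiment_dir / "TorchTrainer_example" / f"checkpoint_{i:06d}").mkdir(parents=True)

for path in list_checkpoints(experiment_dir):
    print(path.relative_to(root))
```

The last element of the returned list is the most recent checkpoint of the last trial, which is usually the one to load for evaluation.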

## 4. Launch training job

In [7]:
result = trainer.fit()

[2m[36m(TunerInternal pid=3003)[0m [output] This will use the new output engine with verbosity 1. To disable the new output and use the legacy output engine, set the environment variable RAY_AIR_NEW_OUTPUT=0. For more information, please see https://github.com/ray-project/ray/issues/36949


[2m[36m(TunerInternal pid=3003)[0m 
[2m[36m(TunerInternal pid=3003)[0m View detailed results here: /home/ray/nfs/nfs
[2m[36m(TunerInternal pid=3003)[0m To visualize your results with TensorBoard, run: `tensorboard --logdir /home/ray/ray_results/nfs`


[2m[36m(TunerInternal pid=3003)[0m AIR_VERBOSITY is set, ignoring passed-in ProgressReporter for now.
[2m[36m(TunerInternal pid=3003)[0m GPUs are detected in your Ray cluster, but GPU training is not enabled for this trainer. To enable GPU training, make sure to set `use_gpu` to True in your scaling config.
[2m[36m(TrainTrainable pid=619, ip=10.244.0.78)[0m GPUs are detected in your Ray cluster, but GPU training is not enabled for this trainer. To enable GPU training, make sure to set `use_gpu` to True in your scaling config.
[2m[36m(TorchTrainer pid=619, ip=10.244.0.78)[0m GPUs are detected in your Ray cluster, but GPU training is not enabled for this trainer. To enable GPU training, make sure to set `use_gpu` to True in your scaling config.


[2m[36m(TunerInternal pid=3003)[0m 
[2m[36m(TunerInternal pid=3003)[0m Training started without custom configuration.


[2m[36m(TorchTrainer pid=619, ip=10.244.0.78)[0m Starting distributed worker processes: ['675 (10.244.0.78)', '676 (10.244.0.78)']
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Setting up process group for: env:// [rank=0, world_size=2]
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Moving model to device: cpu
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Wrapping provided model in DistributedDataParallel.


[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw/train-images-idx3-ubyte.gz
[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]0.78)[0m 
  0%|          | 0/26421880 [00:00<?, ?it/s]0.78)[0m 
  1%|          | 294912/26421880 [00:00<00:09, 2866403.53it/s]
  1%|          | 229376/26421880 [00:00<00:12, 2131837.48it/s]
  3%|▎         | 917504/26421880 [00:00<00:05, 4529960.09it/s]
  3%|▎         | 884736/26421880 [00:00<00:05, 4300825.88it/s]
  6%|▌         | 1572864/26421880 [00:00<00:04, 5373327.49it/s]
  5%|▌         | 1409024/26421880 [00:00<00:05, 4659092.99it/s]
  8%|▊         | 2129920/26421880 [00:00<00:04, 5430414.64it/s]
  8%|▊         | 2064384/26421880 [00:00<00:04, 5326267.58it/s]
 10%|█         | 2686976/26421880 [00:00<00:04, 5419314.02it/s]
 10%|█         | 2654208/26421880 [00:00<00:04, 5405525.79it/s]
 13%|█▎        | 3309568/26421880 [00:00<00:04, 5579685.57it/s]
 12%|█▏        | 3211264/26421880 [00:00<00:04, 5405383.19it/s]
 15%|█▍        | 3932160/26421880 [00:00<00:03, 5718821.95it/s]
 14%|█▍        | 3768320/26421880 [00:00<00:04, 5378143.62it/s

[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m Extracting /tmp/data/FashionMNIST/raw/train-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Extracting /tmp/data/FashionMNIST/raw/train-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m 
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m 
[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m Downloading http://fashio

  0%|          | 0/29515 [00:00<?, ?it/s]44.0.78)[0m 
100%|██████████| 29515/29515 [00:00<00:00, 1170050.78it/s]
100%|██████████| 29515/29515 [00:00<00:00, 1187025.43it/s]


[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Extracting /tmp/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m 
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m Extracting /tmp/data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw
[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m 
[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Downloading http://fashion-mn

  0%|          | 0/4422102 [00:00<?, ?it/s].0.78)[0m 
  7%|▋         | 294912/4422102 [00:00<00:01, 2896520.95it/s]
  0%|          | 0/4422102 [00:00<?, ?it/s].0.78)[0m 
 27%|██▋       | 1212416/4422102 [00:00<00:00, 6453858.97it/s]
  5%|▌         | 229376/4422102 [00:00<00:02, 1923218.36it/s]
 43%|████▎     | 1900544/4422102 [00:00<00:00, 6589055.06it/s]
 16%|█▋        | 720896/4422102 [00:00<00:01, 3374434.08it/s]
 59%|█████▊    | 2588672/4422102 [00:00<00:00, 6498797.77it/s]
 27%|██▋       | 1179648/4422102 [00:00<00:00, 3854447.42it/s]
 39%|███▉      | 1736704/4422102 [00:00<00:00, 4496035.27it/s]
 73%|███████▎  | 3244032/4422102 [00:00<00:00, 6250637.93it/s]
 53%|█████▎    | 2326528/4422102 [00:00<00:00, 4953700.75it/s]
 88%|████████▊ | 3899392/4422102 [00:00<00:00, 5921718.16it/s]
 66%|██████▌   | 2916352/4422102 [00:00<00:00, 5265885.84it/s]
100%|██████████| 4422102/4422102 [00:00<00:00, 5935315.11it/s]
 85%|████████▌ | 3768320/4422102 [00:00<00:00, 6297029.48it/s]
100%|██████

[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m Extracting /tmp/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw
[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m 
[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Extracting /tmp/data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/data/FashionMNIST/raw
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m 
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 8633457.41it/s]


[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m Extracting /tmp/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw
[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m 
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Extracting /tmp/data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/data/FashionMNIST/raw
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m 


100%|██████████| 5148/5148 [00:00<00:00, 37682856.88it/s]


[2m[36m(TunerInternal pid=3003)[0m 
[2m[36m(TunerInternal pid=3003)[0m Training finished iteration 1 at 2024-01-12 01:34:57. Total running time: 7min 54s
[2m[36m(TunerInternal pid=3003)[0m ╭─────────────────────────────────────────╮
[2m[36m(TunerInternal pid=3003)[0m │ Training result                         │
[2m[36m(TunerInternal pid=3003)[0m ├─────────────────────────────────────────┤
[2m[36m(TunerInternal pid=3003)[0m │ checkpoint_dir_name   checkpoint_000000 │
[2m[36m(TunerInternal pid=3003)[0m │ time_this_iter_s              193.12779 │
[2m[36m(TunerInternal pid=3003)[0m │ time_total_s                  193.12779 │
[2m[36m(TunerInternal pid=3003)[0m │ training_iteration                    1 │
[2m[36m(TunerInternal pid=3003)[0m │ epoch                                 0 │
[2m[36m(TunerInternal pid=3003)[0m │ loss                            0.32324 │
[2m[36m(TunerInternal pid=3003)[0m ╰─────────────────────────────────────────╯
[2m[36m(TunerInter

[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/ray/nfs/nfs/TorchTrainer_bf0eb_00000_0_2024-01-12_01-27-03/checkpoint_000000)
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/ray/nfs/nfs/TorchTrainer_bf0eb_00000_0_2024-01-12_01-27-03/checkpoint_000000)


[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m {'loss': 0.17258714139461517, 'epoch': 1}
[2m[36m(TunerInternal pid=3003)[0m 
[2m[36m(TunerInternal pid=3003)[0m Training finished iteration 2 at 2024-01-12 01:38:03. Total running time: 10min 59s
[2m[36m(TunerInternal pid=3003)[0m ╭─────────────────────────────────────────╮
[2m[36m(TunerInternal pid=3003)[0m │ Training result                         │
[2m[36m(TunerInternal pid=3003)[0m ├─────────────────────────────────────────┤
[2m[36m(TunerInternal pid=3003)[0m │ checkpoint_dir_name   checkpoint_000001 │
[2m[36m(TunerInternal pid=3003)[0m │ time_this_iter_s              185.30114 │
[2m[36m(TunerInternal pid=3003)[0m │ time_total_s                  378.42893 │
[2m[36m(TunerInternal pid=3003)[0m │ training_iteration                    2 │
[2m[36m(TunerInternal pid=3003)[0m │ epoch                                 1 │
[2m[36m(TunerInternal pid=3003)[0m │ loss                            0.17259 │
[2m

[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/ray/nfs/nfs/TorchTrainer_bf0eb_00000_0_2024-01-12_01-27-03/checkpoint_000001)
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/ray/nfs/nfs/TorchTrainer_bf0eb_00000_0_2024-01-12_01-27-03/checkpoint_000001)


[2m[36m(TunerInternal pid=3003)[0m 
[2m[36m(TunerInternal pid=3003)[0m Training finished iteration 3 at 2024-01-12 01:41:09. Total running time: 14min 5s
[2m[36m(TunerInternal pid=3003)[0m ╭─────────────────────────────────────────╮
[2m[36m(TunerInternal pid=3003)[0m │ Training result                         │
[2m[36m(TunerInternal pid=3003)[0m ├─────────────────────────────────────────┤
[2m[36m(TunerInternal pid=3003)[0m │ checkpoint_dir_name   checkpoint_000002 │
[2m[36m(TunerInternal pid=3003)[0m │ time_this_iter_s               186.5469 │
[2m[36m(TunerInternal pid=3003)[0m │ time_total_s                  564.97583 │
[2m[36m(TunerInternal pid=3003)[0m │ training_iteration                    3 │
[2m[36m(TunerInternal pid=3003)[0m │ epoch                                 2 │
[2m[36m(TunerInternal pid=3003)[0m │ loss                            0.16846 │
[2m[36m(TunerInternal pid=3003)[0m ╰─────────────────────────────────────────╯
[2m[36m(TunerInter

[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/ray/nfs/nfs/TorchTrainer_bf0eb_00000_0_2024-01-12_01-27-03/checkpoint_000002)
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/ray/nfs/nfs/TorchTrainer_bf0eb_00000_0_2024-01-12_01-27-03/checkpoint_000002)


[2m[36m(TunerInternal pid=3003)[0m 
[2m[36m(TunerInternal pid=3003)[0m Training finished iteration 4 at 2024-01-12 01:44:15. Total running time: 17min 11s
[2m[36m(TunerInternal pid=3003)[0m ╭─────────────────────────────────────────╮
[2m[36m(TunerInternal pid=3003)[0m │ Training result                         │
[2m[36m(TunerInternal pid=3003)[0m ├─────────────────────────────────────────┤
[2m[36m(TunerInternal pid=3003)[0m │ checkpoint_dir_name   checkpoint_000003 │
[2m[36m(TunerInternal pid=3003)[0m │ time_this_iter_s               186.0695 │
[2m[36m(TunerInternal pid=3003)[0m │ time_total_s                  751.04532 │
[2m[36m(TunerInternal pid=3003)[0m │ training_iteration                    4 │
[2m[36m(TunerInternal pid=3003)[0m │ epoch                                 3 │
[2m[36m(TunerInternal pid=3003)[0m │ loss                            0.16956 │
[2m[36m(TunerInternal pid=3003)[0m ╰─────────────────────────────────────────╯
[2m[36m(TunerInte

[2m[36m(RayTrainWorker pid=676, ip=10.244.0.78)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/ray/nfs/nfs/TorchTrainer_bf0eb_00000_0_2024-01-12_01-27-03/checkpoint_000003)
[2m[36m(RayTrainWorker pid=675, ip=10.244.0.78)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/ray/nfs/nfs/TorchTrainer_bf0eb_00000_0_2024-01-12_01-27-03/checkpoint_000003)
