
Set an upper limit on CPU threads in distributed training #18677

Merged
merged 26 commits into master from feature/set-num-threads on Oct 4, 2023

Conversation

awaelchli (Member) commented Sep 29, 2023

What does this PR do?

Fixes #16737

Aligns our default launcher more closely with torchrun's settings for the number of threads:
https://github.com/pytorch/pytorch/blob/e0be9ebc181fbf7b2ed4e641cd24a0dadd063f27/torch/distributed/run.py#L702-L714
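
For reference, a paraphrased sketch of what the linked torchrun code does (illustration only, not the actual source; the helper name is hypothetical): when OMP_NUM_THREADS is unset and more than one process runs per node, each worker defaults to a single thread and a warning asks the user to tune the value.

import os


def _cap_omp_threads_like_torchrun(nproc_per_node: int) -> None:
    # Hypothetical helper paraphrasing torchrun's default behavior.
    if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1:
        print(
            "Setting OMP_NUM_THREADS=1 to avoid system overload; "
            "tune this variable for optimal performance."
        )
        os.environ["OMP_NUM_THREADS"] = "1"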

Benchmarks

All experiments ran on an 8xA100 40GB machine with 256 CPU cores, using the PyTorch 2.2.0.dev20230920+cu121 nightly.

Computer Vision Example

import os

import lightning.pytorch as pl
import timm
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms as tfs
from torchvision.datasets import CIFAR10
import time

epochs = 5
epoch_size = 100


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = timm.create_model("rexnet_150", pretrained=True, num_classes=10)
        self.t0 = 0
        self.epoch_time = []

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.1)

    def on_train_start(self):
        # torch.set_num_threads(1)
        # torch.set_num_interop_threads(1)
        print(os.environ.get("OMP_NUM_THREADS"))
        print(torch.get_num_threads(), torch.get_num_interop_threads())

    def on_train_epoch_start(self):
        self.t0 = time.perf_counter()

    def on_train_epoch_end(self):
        self.epoch_time.append(time.perf_counter() - self.t0)

    def on_train_end(self):
        avg_epoch_time = sum(self.epoch_time) / len(self.epoch_time)
        it_per_sec = epochs * epoch_size / sum(self.epoch_time)
        self.trainer.print(f"avg time per epoch: {avg_epoch_time:.3f}")
        self.trainer.print(f"it/s: {it_per_sec:.3f}")

    def training_step(self, batch):
        x, y = batch
        logits = self.model(x)
        return F.cross_entropy(logits, y)


def run():
    transform = tfs.Compose([tfs.Resize((224, 224)), tfs.ToTensor()])
    dataset = CIFAR10(".", train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=64)
    model = LitModel()
    # devices was varied per run (single GPU vs. DDP with 2 GPUs); see the results below
    trainer = pl.Trainer(max_epochs=epochs, limit_train_batches=epoch_size, accelerator="cuda", devices=2, strategy="ddp")
    trainer.fit(model, dataloader)


if __name__ == "__main__":
    run()

Single GPU

| num_threads | it/s |
|---|---|
| cpu_count | 0.30 |
| cpu_count / num_procs | 9.673 |
| 1 | 9.634 |

DDP with 2 processes/GPUs

| num_threads | it/s |
|---|---|
| cpu_count | 0.16 |
| cpu_count / num_procs | 5.940 |
| 1 | 7.640 |

Transformer Example

import lightning as L
import torch
import torch.nn.functional as F
from lightning.pytorch.demos import Transformer, WikiText2
from torch.utils.data import DataLoader
import time
import os

epochs = 1
epoch_size = 1000


class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size)
        self.t0 = 0
        self.epoch_time = []

    def on_train_start(self):
        # torch.set_num_threads(1)
        # torch.set_num_interop_threads(1)
        print(os.environ.get("OMP_NUM_THREADS"))
        print(torch.get_num_threads(), torch.get_num_interop_threads())

    def on_train_epoch_start(self):
        self.t0 = time.perf_counter()

    def on_train_epoch_end(self):
        self.epoch_time.append(time.perf_counter() - self.t0)

    def on_train_end(self):
        avg_epoch_time = sum(self.epoch_time) / len(self.epoch_time)
        it_per_sec = epochs * epoch_size / sum(self.epoch_time)
        self.trainer.print(f"avg time per epoch: {avg_epoch_time:.3f}")
        self.trainer.print(f"it/s: {it_per_sec:.3f}")

    def training_step(self, batch, batch_idx):
        input, target = batch
        output = self.model(input, target)
        loss = F.nll_loss(output, target.view(-1))
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def main():
    L.seed_everything(42)

    dataset = WikiText2()
    train_dataloader = DataLoader(dataset, batch_size=20, shuffle=True)
    model = LanguageModel(vocab_size=dataset.vocab_size)
    # devices was varied per run (DDP with 2 or 6 GPUs); see the results below
    trainer = L.Trainer(gradient_clip_val=0.25, max_epochs=epochs, limit_train_batches=epoch_size, accelerator="cuda", devices=6, strategy="ddp")
    trainer.fit(model, train_dataloader)
    


if __name__ == "__main__":
    main()

DDP with 2 processes/GPUs

| num_threads | it/s |
|---|---|
| cpu_count | 32.481 |
| cpu_count / num_procs | 34.267 |
| 1 | 32.910 |

DDP with 6 processes/GPUs

| num_threads | it/s |
|---|---|
| cpu_count | 63.514 |
| cpu_count / num_procs | 65.260 |
| 1 | 64.409 |

Discussion

In cases where CPU load is high, as in the CV example where images need to be resized, the system can be overloaded if too many threads are launched. By capping the number of threads per GPU process at num_cpus / num_procs, we ensure that the total number of threads across all processes on a machine does not greatly exceed the number of available cores. This setting does not affect CPU-light tasks, as seen in the transformer example above. Torchrun/elastic limits each DDP process to 1 thread but warns that the value should be tuned. We could do the same, but in this PR we opt for the higher limit and in return don't emit a warning.
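
A minimal sketch of the capping idea described above (illustration only, not the exact code added in this PR; the helper names are hypothetical): each process gets an equal share of the machine's cores unless the user has already set OMP_NUM_THREADS explicitly.

import os


def _suggested_num_threads(num_processes: int) -> int:
    # Hypothetical helper illustrating the num_cpus / num_procs cap.
    num_cpus = os.cpu_count() or 1
    return max(num_cpus // num_processes, 1)


def _set_thread_limit_if_unset(num_processes: int) -> None:
    # Respect an explicit user setting; otherwise cap threads per process.
    if "OMP_NUM_THREADS" not in os.environ:
        os.environ["OMP_NUM_THREADS"] = str(_suggested_num_threads(num_processes))

For example, on the 256-core benchmark machine above, 8 processes would get 32 threads each under this rule.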

cc @carmocca @justusschock @awaelchli @Borda

@github-actions github-actions bot added the fabric (lightning.fabric.Fabric) and pl (Generic label for PyTorch Lightning package) labels Sep 29, 2023
@awaelchli awaelchli added this to the 2.0.x milestone Oct 4, 2023
@awaelchli awaelchli marked this pull request as ready for review October 4, 2023 12:16
@awaelchli awaelchli changed the title Set num threads Set an upper limit on CPU threads in distributed training Oct 4, 2023
github-actions bot (Contributor) commented Oct 4, 2023

⚡ Required checks status: All passing 🟢

Groups summary

🟢 pytorch_lightning: Tests workflow
Check ID Status
pl-cpu (macOS-11, lightning, 3.8, 1.12, oldest) success
pl-cpu (macOS-11, lightning, 3.9, 1.12) success
pl-cpu (macOS-11, lightning, 3.10, 1.13) success
pl-cpu (macOS-11, lightning, 3.10, 2.0) success
pl-cpu (ubuntu-20.04, lightning, 3.8, 1.12, oldest) success
pl-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
pl-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
pl-cpu (windows-2022, lightning, 3.8, 1.12, oldest) success
pl-cpu (windows-2022, lightning, 3.9, 1.12) success
pl-cpu (windows-2022, lightning, 3.10, 1.13) success
pl-cpu (windows-2022, lightning, 3.10, 2.0) success
pl-cpu (macOS-11, pytorch, 3.8, 1.13) success
pl-cpu (ubuntu-20.04, pytorch, 3.8, 1.13) success
pl-cpu (windows-2022, pytorch, 3.8, 1.13) success
pl-cpu (macOS-12, pytorch, 3.11, 2.0) success
pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.0) success
pl-cpu (windows-2022, pytorch, 3.11, 2.0) success

These checks are required after the changes to src/lightning/fabric/cli.py, src/lightning/fabric/strategies/launchers/multiprocessing.py, src/lightning/fabric/strategies/launchers/subprocess_script.py, src/lightning/fabric/utilities/distributed.py, src/lightning/pytorch/strategies/launchers/multiprocessing.py, src/lightning/pytorch/strategies/launchers/subprocess_script.py, tests/tests_pytorch/conftest.py, tests/tests_pytorch/strategies/test_ddp_integration.py.

🟢 pytorch_lightning: Azure GPU
Check ID Status
[pytorch-lightning (GPUs) (testing Lightning latest)](https://dev.azure.com/Lightning-AI/72ab7ed8-b00f-4b6e-b131-3388f7ffafa7/_build/results?buildId=177636&view=logs&jobId=47e66f3c-897a-5428-da11-bf5c7745762e) success
[pytorch-lightning (GPUs) (testing PyTorch latest)](https://dev.azure.com/Lightning-AI/72ab7ed8-b00f-4b6e-b131-3388f7ffafa7/_build/results?buildId=177636&view=logs&jobId=3f274fac-2e11-54ca-487e-194c91f3ae9f) success

These checks are required after the changes to src/lightning/pytorch/strategies/launchers/multiprocessing.py, src/lightning/pytorch/strategies/launchers/subprocess_script.py, tests/tests_pytorch/conftest.py, tests/tests_pytorch/strategies/test_ddp_integration.py, src/lightning/fabric/cli.py, src/lightning/fabric/strategies/launchers/multiprocessing.py, src/lightning/fabric/strategies/launchers/subprocess_script.py, src/lightning/fabric/utilities/distributed.py.

🟢 pytorch_lightning: Benchmarks
Check ID Status
lightning.Benchmarks success

These checks are required after the changes to src/lightning/fabric/cli.py, src/lightning/fabric/strategies/launchers/multiprocessing.py, src/lightning/fabric/strategies/launchers/subprocess_script.py, src/lightning/fabric/utilities/distributed.py, src/lightning/pytorch/strategies/launchers/multiprocessing.py, src/lightning/pytorch/strategies/launchers/subprocess_script.py.

🟢 fabric: Docs
Check ID Status
docs-make (fabric, doctest) success
docs-make (fabric, html) success

These checks are required after the changes to src/lightning/fabric/cli.py, src/lightning/fabric/strategies/launchers/multiprocessing.py, src/lightning/fabric/strategies/launchers/subprocess_script.py, src/lightning/fabric/utilities/distributed.py.

🟢 pytorch_lightning: Docs
Check ID Status
docs-make (pytorch, doctest) success
docs-make (pytorch, html) success

These checks are required after the changes to src/lightning/pytorch/strategies/launchers/multiprocessing.py, src/lightning/pytorch/strategies/launchers/subprocess_script.py, docs/source-pytorch/conf.py.

🟢 lightning_fabric: CPU workflow
Check ID Status
fabric-cpu (macOS-11, lightning, 3.8, 1.12, oldest) success
fabric-cpu (macOS-11, lightning, 3.9, 1.12) success
fabric-cpu (macOS-11, lightning, 3.10, 1.13) success
fabric-cpu (macOS-11, lightning, 3.10, 2.0) success
fabric-cpu (ubuntu-20.04, lightning, 3.8, 1.12, oldest) success
fabric-cpu (ubuntu-20.04, lightning, 3.9, 1.12) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 1.13) success
fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.0) success
fabric-cpu (windows-2022, lightning, 3.8, 1.12, oldest) success
fabric-cpu (windows-2022, lightning, 3.9, 1.12) success
fabric-cpu (windows-2022, lightning, 3.10, 1.13) success
fabric-cpu (windows-2022, lightning, 3.10, 2.0) success
fabric-cpu (macOS-11, fabric, 3.8, 1.13) success
fabric-cpu (ubuntu-20.04, fabric, 3.8, 1.13) success
fabric-cpu (windows-2022, fabric, 3.8, 1.13) success
fabric-cpu (macOS-12, fabric, 3.11, 2.0) success
fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.0) success
fabric-cpu (windows-2022, fabric, 3.11, 2.0) success

These checks are required after the changes to src/lightning/fabric/cli.py, src/lightning/fabric/strategies/launchers/multiprocessing.py, src/lightning/fabric/strategies/launchers/subprocess_script.py, src/lightning/fabric/utilities/distributed.py, tests/tests_fabric/conftest.py, tests/tests_fabric/utilities/test_distributed.py.

🟢 lightning_fabric: Azure GPU
Check ID Status
[lightning-fabric (GPUs) (testing Fabric latest)](https://dev.azure.com/Lightning-AI/72ab7ed8-b00f-4b6e-b131-3388f7ffafa7/_build/results?buildId=177638&view=logs&jobId=3f274fac-2e11-54ca-487e-194c91f3ae9f) success
[lightning-fabric (GPUs) (testing Lightning latest)](https://dev.azure.com/Lightning-AI/72ab7ed8-b00f-4b6e-b131-3388f7ffafa7/_build/results?buildId=177638&view=logs&jobId=b2def368-7fa8-5edf-f15e-38e7ac88d76c) success

These checks are required after the changes to src/lightning/fabric/cli.py, src/lightning/fabric/strategies/launchers/multiprocessing.py, src/lightning/fabric/strategies/launchers/subprocess_script.py, src/lightning/fabric/utilities/distributed.py, tests/tests_fabric/conftest.py, tests/tests_fabric/utilities/test_distributed.py.

🟢 mypy
Check ID Status
mypy success

These checks are required after the changes to src/lightning/fabric/cli.py, src/lightning/fabric/strategies/launchers/multiprocessing.py, src/lightning/fabric/strategies/launchers/subprocess_script.py, src/lightning/fabric/utilities/distributed.py, src/lightning/pytorch/strategies/launchers/multiprocessing.py, src/lightning/pytorch/strategies/launchers/subprocess_script.py.

🟢 install
Check ID Status
install-pkg (ubuntu-22.04, app, 3.8) success
install-pkg (ubuntu-22.04, app, 3.11) success
install-pkg (ubuntu-22.04, fabric, 3.8) success
install-pkg (ubuntu-22.04, fabric, 3.11) success
install-pkg (ubuntu-22.04, pytorch, 3.8) success
install-pkg (ubuntu-22.04, pytorch, 3.11) success
install-pkg (ubuntu-22.04, lightning, 3.8) success
install-pkg (ubuntu-22.04, lightning, 3.11) success
install-pkg (ubuntu-22.04, notset, 3.8) success
install-pkg (ubuntu-22.04, notset, 3.11) success
install-pkg (macOS-12, app, 3.8) success
install-pkg (macOS-12, app, 3.11) success
install-pkg (macOS-12, fabric, 3.8) success
install-pkg (macOS-12, fabric, 3.11) success
install-pkg (macOS-12, pytorch, 3.8) success
install-pkg (macOS-12, pytorch, 3.11) success
install-pkg (macOS-12, lightning, 3.8) success
install-pkg (macOS-12, lightning, 3.11) success
install-pkg (macOS-12, notset, 3.8) success
install-pkg (macOS-12, notset, 3.11) success
install-pkg (windows-2022, app, 3.8) success
install-pkg (windows-2022, app, 3.11) success
install-pkg (windows-2022, fabric, 3.8) success
install-pkg (windows-2022, fabric, 3.11) success
install-pkg (windows-2022, pytorch, 3.8) success
install-pkg (windows-2022, pytorch, 3.11) success
install-pkg (windows-2022, lightning, 3.8) success
install-pkg (windows-2022, lightning, 3.11) success
install-pkg (windows-2022, notset, 3.8) success
install-pkg (windows-2022, notset, 3.11) success

These checks are required after the changes to src/lightning/fabric/cli.py, src/lightning/fabric/strategies/launchers/multiprocessing.py, src/lightning/fabric/strategies/launchers/subprocess_script.py, src/lightning/fabric/utilities/distributed.py, src/lightning/pytorch/strategies/launchers/multiprocessing.py, src/lightning/pytorch/strategies/launchers/subprocess_script.py.


Thank you for your contribution! 💜

Note
This comment is automatically generated and updates for 60 minutes every 180 seconds. If you have any other questions, contact carmocca for help.

@awaelchli awaelchli modified the milestones: 2.0.x, 2.1 Oct 4, 2023
@awaelchli awaelchli added the strategy: ddp (DistributedDataParallel) label Oct 4, 2023
Borda (Member) commented Oct 4, 2023

Have you tried some more extensive data preprocessing? This looks like you are hitting an IO bottleneck.

codecov bot commented Oct 4, 2023

Codecov Report

Merging #18677 (05fc41d) into master (343f804) will decrease coverage by 29%.
The diff coverage is 90%.

Additional details and impacted files
@@            Coverage Diff            @@
##           master   #18677     +/-   ##
=========================================
- Coverage      84%      55%    -29%     
=========================================
  Files         428      423      -5     
  Lines       33550    33478     -72     
=========================================
- Hits        28070    18329   -9741     
- Misses       5480    15149   +9669     

@mergify mergify bot added the has conflicts label Oct 4, 2023
awaelchli (Member, Author) commented Oct 4, 2023

@Borda I don't see how I'm hitting an IO bottleneck. The number of dataloader workers is the same across the comparison experiments, so the IO load is identical. The only thing that changes between these runs is the number of threads used for intra-op parallelism. The experiment simply shows that when too many threads run in parallel and compete with each other, the effective throughput degrades.
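
A minimal illustration of the two knobs being discussed (not part of this PR): the DataLoader's num_workers controls the data-loading worker processes, while the settings below control the thread parallelism inside each training process.

import torch

# Intra-op parallelism: the knob this PR caps (also influenced by OMP_NUM_THREADS).
torch.set_num_threads(1)
# Inter-op parallelism; must be set before any parallel work has started.
torch.set_num_interop_threads(1)
print(torch.get_num_threads(), torch.get_num_interop_threads())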

Borda (Member) commented Oct 4, 2023

> @Borda I don't see how I'm hitting an IO bottleneck. [...]

I am not saying that your comparison is wrong or unfair. I am just saying that when preprocessing takes much longer, for example image augmentation, which very often runs on the CPU, you can see some advantage in using a small number of workers... By IO bottleneck I mean, for example, that if your IO speed is 10 MB/s, that speed is shared across all workers rather than available to each one, so if you divide 10 MB/s among 5 workers they will be slower, because each gets at most 2 MB/s, minus some switching overhead...

@mergify mergify bot added the ready (PRs ready to be merged) label and removed the has conflicts label Oct 4, 2023
@github-actions github-actions bot added the docs (Documentation related) label Oct 4, 2023
@awaelchli awaelchli removed the docs (Documentation related) label Oct 4, 2023
@awaelchli awaelchli merged commit 09a0fb2 into master Oct 4, 2023
109 checks passed
@awaelchli awaelchli deleted the feature/set-num-threads branch October 4, 2023 23:57
Labels
fabric (lightning.fabric.Fabric), performance, pl (Generic label for PyTorch Lightning package), ready (PRs ready to be merged), strategy: ddp (DistributedDataParallel)
Development

Successfully merging this pull request may close these issues.

Extreme single thread cpu kernel usage while training on GPU